# Main generator


In [32]:
import random
from copy import deepcopy
from typing import Literal, Optional

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

In [33]:
# Without alignment
QWERTY_LOW_LAYOUT_: list[list[str]] = [
    ["`", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "-", "=", "<back>"],
    ["<tab>", "q", "w", "e", "r", "t", "y", "u", "i", "o", "p", "[", "]", "\\"],
    ["<caps>", "a", "s", "d", "f", "g", "h", "j", "k", "l", ";", "'", "<enter>"],
    ["<shift>", "z", "x", "c", "v", "b", "n", "m", ",", ".", "/", "<shift>"],
    [
        "<ctrl>",
        "<alt>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<alt>",
        "<ctrl>",
    ],
]

QWERTY_HIGH_LAYOUT_: list[list[str]] = [
    ["~", "!", "@", "#", "$", "%", "^", "&", "*", "(", ")", "_", "+", "<back>"],
    ["<tab>", "Q", "W", "E", "R", "T", "Y", "U", "I", "O", "P", "{", "}", "|"],
    ["<caps>", "A", "S", "D", "F", "G", "H", "J", "K", "L", ":", '"', "<enter>"],
    ["<shift>", "Z", "X", "C", "V", "B", "N", "M", "<", ">", "?", "<shift>"],
    [
        "<ctrl>",
        "<alt>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<alt>",
        "<ctrl>",
    ],
]

In [34]:
QWERTY_LOW_LAYOUT: list[list[str]] = [
    ["`", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "-", "=", "<back>"],
    ["<tab>", "q", "w", "e", "r", "t", "y", "u", "i", "o", "p", "[", "]", "\\"],
    [
        "<caps>",
        "a",
        "s",
        "d",
        "f",
        "g",
        "h",
        "j",
        "k",
        "l",
        ";",
        "'",
        "<enter>",
        "<enter>",
    ],
    [
        "<shift>",
        "<shift>",
        "z",
        "x",
        "c",
        "v",
        "b",
        "n",
        "m",
        ",",
        ".",
        "/",
        "<shift>",
        "<shift>",
    ],
    [
        "<ctrl>",
        "<alt>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<alt>",
        "<ctrl>",
    ],
]

QWERTY_HIGH_LAYOUT: list[list[str]] = [
    ["~", "!", "@", "#", "$", "%", "^", "&", "*", "(", ")", "_", "+", "<back>"],
    ["<tab>", "Q", "W", "E", "R", "T", "Y", "U", "I", "O", "P", "{", "}", "|"],
    [
        "<caps>",
        "A",
        "S",
        "D",
        "F",
        "G",
        "H",
        "J",
        "K",
        "L",
        ":",
        '"',
        "<enter>",
        "<enter>",
    ],
    [
        "<shift>",
        "<shift>",
        "Z",
        "X",
        "C",
        "V",
        "B",
        "N",
        "M",
        "<",
        ">",
        "?",
        "<shift>",
        "<shift>",
    ],
    [
        "<ctrl>",
        "<alt>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<space>",
        "<alt>",
        "<ctrl>",
    ],
]

In [35]:
(
    len(QWERTY_HIGH_LAYOUT),
    len(QWERTY_HIGH_LAYOUT[0]),
    len(QWERTY_HIGH_LAYOUT[1]),
    len(QWERTY_HIGH_LAYOUT[2]),
    len(QWERTY_HIGH_LAYOUT[3]),
    len(QWERTY_HIGH_LAYOUT[4]),
)

(5, 14, 14, 14, 14, 11)

In [36]:
def get_buttons_set(
    low_layout: list[list[str]], high_layout: list[list[str]]
) -> set[str]:
    buttons: set[str] = set()

    for layout in [low_layout, high_layout]:
        for i in range(len(layout)):
            for btn in layout[i]:
                buttons.add(btn)

    return buttons


def get_keyboard_shape(layout: list[list[str]]) -> tuple[int, ...]:
    shape = [len(row) for row in layout]

    return tuple(shape)


BUTTONS_SET = get_buttons_set(QWERTY_LOW_LAYOUT, QWERTY_HIGH_LAYOUT)
KEYBOARD_LAYOUT_SHAPE = get_keyboard_shape(QWERTY_LOW_LAYOUT)
KEYBOARD_LAYOUT_CUMSUM_SHAPE = np.cumsum(KEYBOARD_LAYOUT_SHAPE)
KEYS = sum(KEYBOARD_LAYOUT_SHAPE) * 2

In [37]:
def convert_int_to_cord(n: int) -> tuple[int, int, int]:
    shift = 0
    row = 0
    column = 0
    if n >= KEYS:
        return 2, 0, 0
    if n < 0:
        return 2, 0, 0
    if n >= KEYS // 2:
        n -= KEYS // 2
        shift = 1
    for i in range(len(KEYBOARD_LAYOUT_CUMSUM_SHAPE)):
        if n < KEYBOARD_LAYOUT_CUMSUM_SHAPE[i]:
            row = i
            break
    if row > 0:
        n -= KEYBOARD_LAYOUT_CUMSUM_SHAPE[row - 1]
    column = n
    return shift, row, column

In [38]:
KEYBOARD_LAYOUT_SHAPE

(14, 14, 14, 14, 11)

In [39]:
def encode_decode_buttons(buttons: set[str]) -> tuple[dict[str, int], dict[int, str]]:
    letters_dict = {}
    for idx, letter in enumerate("abcdefghijklmnopqrsuvwxyz"):
        letters_dict[letter] = idx + 1
        letters_dict[letter.upper()] = -(idx + 1)

    encode_value = (len(letters_dict) // 2) + 1
    encode_dict = {}
    decode_dict = {}
    for btn in buttons:
        if btn in letters_dict:
            decode_dict[letters_dict[btn]] = btn
            encode_dict[btn] = letters_dict[btn]
        else:
            decode_dict[encode_value] = btn
            encode_dict[btn] = encode_value
            encode_value += 1
    return encode_dict, decode_dict


ENCODE_DICT, DECODE_DICT = encode_decode_buttons(BUTTONS_SET)
ENCODED_BUTTONS_SET = {ENCODE_DICT[btn] for btn in BUTTONS_SET}
SHIFT_CODE = ENCODE_DICT["<shift>"]

In [40]:
for btn in BUTTONS_SET:
    assert btn == DECODE_DICT[ENCODE_DICT[btn]]

In [41]:
Layout = list[list[int]]


def encode_layout(layout: list[list[str]]) -> Layout:
    return [[ENCODE_DICT[btn] for btn in layout[i]] for i in range(len(layout))]


def decode_layout(layout: Layout) -> list[list[str]]:
    return [[DECODE_DICT[btn] for btn in layout[i]] for i in range(len(layout))]


QWERTY_ENCODED_HIGH: Layout = encode_layout(QWERTY_HIGH_LAYOUT)
QWERTY_ENCODED_LOW: Layout = encode_layout(QWERTY_LOW_LAYOUT)

In [42]:
def get_all_buttons_encoded(high_layout: Layout, low_layout: Layout) -> list[int]:
    all_buttons = []

    for layout in (low_layout, high_layout):
        for row in layout:
            all_buttons.extend(row)
    return all_buttons


ALL_BUTTONS_ENCODED = get_all_buttons_encoded(QWERTY_ENCODED_HIGH, QWERTY_ENCODED_LOW)

In [43]:
LogType = Literal["basic"] | Literal["debug"] | Literal["error"]


class Logger:
    def __init__(self, verbose: bool = True, hide_types: list[LogType] = []) -> None:
        self.verbose = verbose
        self.hide_types = set(hide_types)

    def log(self, message: str, log_type: LogType = "basic") -> None:
        if self.verbose and log_type not in self.hide_types:
            print(message)


LOGGER = Logger()

In [44]:
Position = tuple[int, int]


class Finger:
    def __init__(
        self, initial_position: Position, name: str, logger: Logger = LOGGER
    ) -> None:
        self.name = name
        self.initial_position = initial_position

        self.logger = logger

        self.reset()

        # Constants

        self.wait_before_return = 4  # in ticks

        self.long_row_move_shift = 3
        self.long_row_move_penalty = 1

        self.row_penalty_coefficient = 1
        self.column_penalty_coefficient = 1.2

    def reset(self):
        self.current_position = self.initial_position
        self.ticks_before_return = 0  # if == 0, returns to the initial position
        self.typed_keys = 0

    def move(self, position: Position):
        self.current_position = position

        self.ticks_before_return = self.wait_before_return
        self.typed_keys += 1

    def tick(self) -> float:
        if self.ticks_before_return > 0:
            self.ticks_before_return -= 1

        if self.ticks_before_return == 0:
            score = self.get_score(self.initial_position)
            self.current_position = self.initial_position
            return score

        return 0

    def get_score(self, target_position: Position) -> float:
        x1, y1 = self.current_position
        x2, y2 = target_position

        row_distance = abs(x1 - x2) ** 2
        column_distance = abs(y1 - y2) ** 2

        penalty = 0
        if row_distance > self.long_row_move_shift:
            penalty = self.long_row_move_penalty
        return (
            row_distance * self.row_penalty_coefficient
            + column_distance * self.column_penalty_coefficient
            + penalty
        )

    def show_statistics(self):
        self.logger.log(
            f"Name: {self.name:22} \
            Typed keys: {self.typed_keys:5} \
            Ticks before return: {self.ticks_before_return:5} \
            Current position: {self.current_position}\t\
            Default position: {self.initial_position}"
        )

In [45]:
DEFAULT_FINGERS: list[Finger] = [
    Finger((2, 1), "левый мизинец"),
    Finger((2, 2), "левый безымянный"),
    Finger((2, 3), "левый средний"),
    Finger((2, 4), "левый указательный"),
    Finger((4, 3), "левый большой"),
    Finger((4, 6), "правый большой"),
    Finger((2, 7), "правый указательный"),
    Finger((2, 8), "правый средний"),
    Finger((2, 9), "правый безымянный"),
    Finger((2, 10), "правый мизинец"),
]

SwapType = Literal["low_layout"] | Literal["high_layout"] | Literal["between_layouts"]


class KeyboardLayout:
    @staticmethod
    def layout_to_dict(
        layout: Layout, unused_layout: Layout
    ) -> dict[int, list[Position]]:
        layout_dict: dict[int, list[Position]] = {}

        for i in range(len(layout)):
            for j in range(len(layout[i])):
                button = layout[i][j]
                if button in layout_dict:
                    layout_dict[button].append((i, j))
                else:
                    layout_dict[button] = [(i, j)]

        for i in range(len(unused_layout)):
            for j in range(len(unused_layout[i])):
                button = unused_layout[i][j]
                if button not in layout_dict:
                    layout_dict[button] = []

        return layout_dict

    def _finish_move(self):
        for finger in self.fingers:
            self.total_score += finger.tick()

    def __init__(self, low_layout: Layout, high_layout: Layout, logger: Logger = LOGGER):
        self.low_layout = deepcopy(low_layout)
        self.high_layout = deepcopy(high_layout)

        self.low_layout_dict = KeyboardLayout.layout_to_dict(
            self.low_layout, self.high_layout
        )
        self.high_layout_dict = KeyboardLayout.layout_to_dict(
            self.high_layout, self.low_layout
        )

        self.logger = logger

        self.fingers = deepcopy(DEFAULT_FINGERS)

        self.reset()

    def reset(self):
        self.total_score: float = 0
        self.typed_keys: int = 0
        for f in self.fingers:
            f.reset()

    def move_one_finger(
        self, positions: list[Position], busy_finger_id: Optional[int] = None
    ) -> tuple[tuple[int, Position], float]:
        best_finger_id: int = 0
        best_score = np.inf

        final_position: Position = (0, 0)

        for position in positions:
            scores = [
                finger.get_score(position) if i != busy_finger_id else np.inf
                for i, finger in enumerate(self.fingers)
            ]

            candidate_finger_id = int(np.argmin(scores))
            candidate_score = scores[candidate_finger_id]

            if candidate_score < best_score:
                best_score = candidate_score
                best_finger_id = candidate_finger_id
                final_position = position

        return (best_finger_id, final_position), best_score

    def move_two_fingers(
        self, positions: list[Position]
    ) -> tuple[tuple[int, Position], tuple[int, Position], float]:
        shift_positions = self.low_layout_dict[SHIFT_CODE]
        if len(shift_positions) == 0:
            print("ERROR SHIFT IS UNREACHABLE")
            return (0, (0, 0)), (0, (0, 0)), 9999

        # firstly reach SHIFT, then - positions
        finger_shift_info_1, shift_distance_1 = self.move_one_finger(shift_positions)
        finger_btn_info_1, d1_btn = self.move_one_finger(
            positions, finger_shift_info_1[0]
        )
        total_distance_1 = shift_distance_1 + d1_btn

        # firstly reach positions, then - SHIFT
        finger_btn_info_2, d1_btn = self.move_one_finger(positions)
        finger_shift_info_2, shift_distance_2 = self.move_one_finger(
            shift_positions, finger_btn_info_2[0]
        )
        total_distance_2 = shift_distance_2 + d1_btn

        if total_distance_1 < total_distance_2:
            return finger_btn_info_1, finger_shift_info_1, total_distance_1

        return finger_btn_info_2, finger_shift_info_2, total_distance_2

    def find_button(self, button: int):
        if button in self.low_layout_dict:
            (finger_id, finger_position), score = self.move_one_finger(
                self.low_layout_dict[button]
            )

            self.fingers[finger_id].move(finger_position)
            self.total_score += score
            self.typed_keys += 1

            self.logger.log(f"{button}:\t{self.fingers[finger_id].name}")

        elif button in self.high_layout_dict:
            (
                (finger_id_1, finger_position_1),
                (finger_id_2, finger_position_2),
                score,
            ) = self.move_two_fingers(self.high_layout_dict[button])

            self.fingers[finger_id_1].move(finger_position_1)
            self.fingers[finger_id_2].move(finger_position_2)
            self.total_score += score
            self.typed_keys += 2

            self.logger.log(
                f"{button}:\t{self.fingers[finger_id_1].name} + {self.fingers[finger_id_2].name}"
            )

        else:
            self.logger.log(f"NO SUCH KEY: {button}")

        self._finish_move()

    def type_text(self, text: list[str]) -> float:
        for button in text:
            self.find_button(ENCODE_DICT[button])

        return self.total_score

    def type_encoded_text(self, encoded_text: list[int]) -> float:
        for button in encoded_text:
            self.find_button(button)
        return self.total_score

    def swap_buttons(self, position1: Position, position2: Position, swap_type: SwapType):
        if swap_type == "high_layout":
            layout_from = layout_to = self.high_layout
            layout_from_dict = layout_to_dict = self.high_layout_dict
        elif swap_type == "low_layout":
            layout_from = layout_to = self.low_layout
            layout_from_dict = layout_to_dict = self.low_layout_dict
        else:  # swap_type == "between_layouts"
            layout_from = self.low_layout
            layout_to = self.high_layout
            layout_from_dict = self.low_layout_dict
            layout_to_dict = self.high_layout_dict

        x1, y1 = position1
        btn1 = layout_from[x1][y1]
        x2, y2 = position2
        btn2 = layout_to[x2][y2]

        layout_from[x1][y1], layout_to[x2][y2] = layout_to[x2][y2], layout_from[x1][y1]

        layout_from_dict[btn1].remove(position1)
        layout_to_dict[btn2].remove(position2)

        if btn2 in layout_from_dict:
            layout_from_dict[btn2].append(position1)
        else:
            layout_from_dict[btn2] = [position1]
        if btn1 in layout_to_dict:
            layout_to_dict[btn1].append(position2)
        else:
            layout_to_dict[btn1] = [position2]

    def decode_layouts(self) -> tuple[list[list[str]], list[list[str]]]:
        return (decode_layout(self.low_layout), decode_layout(self.high_layout))

    def get_sting_layouts(self) -> str:
        low_layout, high_layout = self.decode_layouts()
        result_string = "High layout:\n"
        for row in high_layout:
            for s in row:
                result_string += f"{s:8}"
            result_string += "\n"
        result_string += "\n"

        result_string += "\nLow layout:\n"
        for row in low_layout:
            for s in row:
                result_string += f"{s:8}"
            result_string += "\n"
        result_string += "\n"

        return result_string

    def show_statistics(self):
        self.logger.log("\nStatistics:")
        for f in self.fingers:
            f.show_statistics()

    def flatten(self):
        flatten = []
        for row in self.low_layout:
            flatten.extend(row)
        for row in self.high_layout:
            flatten.extend(row)
        return torch.as_tensor(flatten, dtype=torch.float32)

    def get_average_score(self) -> float:
        return self.total_score / self.typed_keys

In [46]:
def generate_random_layout(
    all_buttons_encoded: list[int],
    keyboard_shape: tuple[int, ...],
    seed: Optional[int] = None,
) -> tuple[Layout, Layout]:
    if seed is not None:
        random.seed(seed)

    all_buttons = all_buttons_encoded.copy()

    low_layout = [[0 for _ in range(row_length)] for row_length in keyboard_shape]
    high_layout = [[0 for _ in range(row_length)] for row_length in keyboard_shape]

    # Push single SHIFT to the low layout
    all_buttons.remove(SHIFT_CODE)
    random.shuffle(all_buttons)

    shift_position_low_layout = random.randint(0, len(all_buttons_encoded) // 2)
    all_buttons.insert(shift_position_low_layout, SHIFT_CODE)
    pointer = 0
    for layout in (low_layout, high_layout):
        for i, row_length in enumerate(keyboard_shape):
            for j in range(row_length):
                layout[i][j] = all_buttons[pointer]
                pointer += 1

    return low_layout, high_layout


sample_low_layout, sample_high_layout = generate_random_layout(
    ALL_BUTTONS_ENCODED, KEYBOARD_LAYOUT_SHAPE
)
sample_low_layout, sample_high_layout

([[35, -21, 25, 16, 62, 15, 72, 38, 75, -23, -7, 47, -1, -24],
  [53, 56, 32, 67, 27, 36, 67, -5, 60, 30, 67, 31, 67, 67],
  [-18, 53, -13, 53, 19, 67, 46, 74, 21, 18, -20, 45, 69, 49],
  [47, 67, 3, 17, 10, 47, 60, -9, 4, -10, -22, 55, 1, -8],
  [53, 77, 54, 39, 20, 44, 14, 53, 57, 57, -19]],
 [[61, -3, 22, 7, 67, 33, 67, 36, -25, 71, 9, 67, 46, 28],
  [53, -17, 41, 48, 36, 63, 76, 58, 60, -16, 42, 8, 26, 50],
  [40, -12, 37, 65, 11, 67, 2, 51, 53, 34, 26, 68, -14, 23],
  [36, 59, 67, 52, -11, 73, 29, 6, 53, 67, 67, -6, 60, 12],
  [64, -15, 66, 5, 43, 24, 70, -4, 13, 47, -2]])

### Test Keyboard

In [47]:
qwerty_keyboard = KeyboardLayout(QWERTY_ENCODED_LOW, QWERTY_ENCODED_HIGH, LOGGER)
print(qwerty_keyboard.get_sting_layouts())

High layout:
~       !       @       #       $       %       ^       &       *       (       )       _       +       <back>  
<tab>   Q       W       E       R       T       Y       U       I       O       P       {       }       |       
<caps>  A       S       D       F       G       H       J       K       L       :       "       <enter> <enter> 
<shift> <shift> Z       X       C       V       B       N       M       <       >       ?       <shift> <shift> 
<ctrl>  <alt>   <space> <space> <space> <space> <space> <space> <space> <alt>   <ctrl>  


Low layout:
`       1       2       3       4       5       6       7       8       9       0       -       =       <back>  
<tab>   q       w       e       r       t       y       u       i       o       p       [       ]       \       
<caps>  a       s       d       f       g       h       j       k       l       ;       '       <enter> <enter> 
<shift> <shift> z       x       c       v       b       n       m       ,       .       /    

In [48]:
print(qwerty_keyboard.low_layout_dict[ENCODE_DICT["s"]])
print(qwerty_keyboard.low_layout_dict[ENCODE_DICT["2"]])
print(qwerty_keyboard.high_layout_dict[ENCODE_DICT["S"]])
print(qwerty_keyboard.high_layout_dict[ENCODE_DICT["@"]])

[(2, 2)]
[(0, 2)]
[(2, 2)]
[(0, 2)]


In [49]:
qwerty_keyboard.swap_buttons((2, 2), (0, 2), "low_layout")
qwerty_keyboard.swap_buttons((2, 2), (0, 2), "high_layout")

print(qwerty_keyboard.low_layout_dict[ENCODE_DICT["s"]])
print(qwerty_keyboard.low_layout_dict[ENCODE_DICT["2"]])
print(qwerty_keyboard.high_layout_dict[ENCODE_DICT["S"]])
print(qwerty_keyboard.high_layout_dict[ENCODE_DICT["@"]])

[(0, 2)]
[(2, 2)]
[(0, 2)]
[(2, 2)]


In [50]:
print(qwerty_keyboard.low_layout[0][5])
print(qwerty_keyboard.high_layout[1][6])

45
-24


In [51]:
qwerty_keyboard.swap_buttons((0, 5), (1, 6), "between_layouts")
print(qwerty_keyboard.low_layout[0][5])
print(qwerty_keyboard.high_layout[1][6])

-24
45


In [52]:
qwerty_keyboard.show_statistics()


Statistics:
Name: левый мизинец                      Typed keys:     0             Ticks before return:     0             Current position: (2, 1)	            Default position: (2, 1)
Name: левый безымянный                   Typed keys:     0             Ticks before return:     0             Current position: (2, 2)	            Default position: (2, 2)
Name: левый средний                      Typed keys:     0             Ticks before return:     0             Current position: (2, 3)	            Default position: (2, 3)
Name: левый указательный                 Typed keys:     0             Ticks before return:     0             Current position: (2, 4)	            Default position: (2, 4)
Name: левый большой                      Typed keys:     0             Ticks before return:     0             Current position: (4, 3)	            Default position: (4, 3)
Name: правый большой                     Typed keys:     0             Ticks before return:     0             Current position:

In [53]:
qwerty_keyboard.type_text(list("Procrastination!"))

-16:	правый мизинец + левый мизинец
18:	левый указательный
15:	правый безымянный
3:	левый средний
18:	левый указательный
1:	левый мизинец
19:	левый безымянный
56:	левый указательный
9:	правый средний
14:	правый указательный
1:	левый мизинец
56:	левый указательный
9:	правый средний
15:	правый безымянный
14:	правый указательный
64:	левый мизинец + левый безымянный


44.400000000000006

In [54]:
qwerty_keyboard.show_statistics()


Statistics:
Name: левый мизинец                      Typed keys:     4             Ticks before return:     3             Current position: (0, 1)	            Default position: (2, 1)
Name: левый безымянный                   Typed keys:     2             Ticks before return:     3             Current position: (3, 1)	            Default position: (2, 2)
Name: левый средний                      Typed keys:     1             Ticks before return:     0             Current position: (2, 3)	            Default position: (2, 3)
Name: левый указательный                 Typed keys:     4             Ticks before return:     0             Current position: (2, 4)	            Default position: (2, 4)
Name: левый большой                      Typed keys:     0             Ticks before return:     0             Current position: (4, 3)	            Default position: (4, 3)
Name: правый большой                     Typed keys:     0             Ticks before return:     0             Current position:

As we can see, we successfully can swap different keys in given keyboard and calculate distance function fro given string.


## Data loading & preprocessing

In [55]:
data_frame = pd.read_csv("../data/raw/cpp_programs.csv", index_col=0)
data_frame.head()

Unnamed: 0,text
0,"point operator-(point p1,point p2)\r\n p1.x-=..."
1,"int main() {\r\n map<char, int> m;\r\n m..."
2,"int main()\r\n\tint n,s[1001],cnt=0;\r\n\tscan..."
3,int main() {\r\n\tcin >> n;\r\n\tfor (int i = ...
4,"int main()\r\n char a, b, c, d, kozir;\..."


In [56]:
text_lengths = data_frame["text"].apply(len)
print(f"Mean length: {text_lengths.mean()}")
print(f"Max length: {text_lengths.max()}")
print(f"Min length: {text_lengths.min()}")

Mean length: 484.69309686721283
Max length: 62514
Min length: 13


In [57]:
MAX_LENGTH = 400
PADDING_VALUE = ENCODE_DICT["<space>"]

SPECIAL_SYMBOLS = {
    "\t": "<tab>",
    "\n": "<enter>",
    " ": "<space>",
}


def preprocess_text(text: str) -> list[int]:
    splitted_text = [
        SPECIAL_SYMBOLS[s] if s in SPECIAL_SYMBOLS else s for s in list(text)
    ]
    encoded_text = [
        ENCODE_DICT[symbol]
        for symbol in list(filter(lambda s: s in BUTTONS_SET, splitted_text))[:MAX_LENGTH]
    ]
    return encoded_text + [PADDING_VALUE for _ in range(MAX_LENGTH - len(encoded_text))]


data_frame["encoded_text"] = data_frame["text"].apply(preprocess_text)

In [58]:
dataset = np.array([np.array(text) for text in data_frame["encoded_text"].to_list()])
dataset.shape

(119989, 400)

## Measure QWERTY

In [59]:
qwerty_keyboard = KeyboardLayout(
    QWERTY_ENCODED_LOW, QWERTY_ENCODED_HIGH, Logger(verbose=False)
)

random_low, random_high = generate_random_layout(
    ALL_BUTTONS_ENCODED, KEYBOARD_LAYOUT_SHAPE, seed=42
)
random_keyboard = KeyboardLayout(random_low, random_high, Logger(verbose=False))


loop = tqdm(dataset[:100])
for text in loop:
    qwerty_score_total = qwerty_keyboard.type_encoded_text(text)
    random_score_total = random_keyboard.type_encoded_text(text)
    loop.set_postfix({"qwerty": qwerty_score_total, "random": random_score_total})

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:05<00:00, 17.32it/s, qwerty=1e+5, random=1.23e+5]  


In [60]:
print(f"{qwerty_keyboard.total_score=:.0f}")
print(f"{random_keyboard.total_score=:.0f}")

qwerty_keyboard.total_score=100223
random_keyboard.total_score=123039


In [61]:
print(random_keyboard.get_sting_layouts())

High layout:
H       &       V       <tab>   8       p       v       l       <space> 0       <space> q       @       <space> 
5       g       X       x       c       =       n       N       <space> |       <ctrl>  t       G       z       
>       "       y       D       `       <space> ?       <space> <ctrl>  <shift> <shift> O       E       {       
[       %       M       _       <alt>   a       \       <shift> 3       4       <shift> (       -       #       
B       <space> Y       <back>  <caps>  j       <space> <space> $       6       <caps>  


Low layout:
]       m       F       b       <enter> u       1       :       d       ^       9       w       <alt>   '       
<shift> U       7       P       s       <space> .       <alt>   C       ,       f       K       <space> <ctrl>  
r       <alt>   Q       *       <shift> k       2       !       Z       S       <tab>   }       /       <shift> 
<space> R       W       h       J       <space> ;       <       <enter> A       <enter> <spac

## Score estimator 

In [111]:
def generate_random_layouts(n: int) -> list[KeyboardLayout]:
    layouts = []
    for i in range(n):
        random_low, random_high = generate_random_layout(
            ALL_BUTTONS_ENCODED, KEYBOARD_LAYOUT_SHAPE, seed=i
        )
        layouts.append(KeyboardLayout(random_low, random_high, Logger(verbose=False)))
    return layouts


def estimate_layouts(layouts: list[KeyboardLayout], dataset: np.ndarray) -> list[float]:
    scores = []
    loop = tqdm(dataset)
    for layout in layouts:
        layout.reset()
    for text in loop:
        for layout in layouts:
            layout.type_encoded_text(text)
    for layout in layouts:
        scores.append(layout.total_score)
        # scores.append(layout.get_average_score())
    return torch.as_tensor(scores)

In [63]:
random_layouts_100 = generate_random_layouts(500)
random_layouts_100_scores = estimate_layouts(random_layouts_100, dataset[:10])

100%|██████████| 10/10 [02:49<00:00, 16.94s/it]


In [64]:
# for i in range(100):
#     if random_layouts_100_scores[i] < qwerty_score_total:
#         print(i, random_layouts_100_scores[i])

In [65]:
print(random_layouts_100[81].get_sting_layouts())

High layout:
:       W       <tab>   p       &       <tab>   <space> V       <space> h       C       <shift> 7       <space> 
=       R       L       j       <alt>   l       x       <ctrl>  K       2       |       O       d       <space> 
E       a       m       J       <shift> {       G       o       %       *       ;       S       <ctrl>  \       
9       <shift> w       U       '       <caps>  c       }       k       <enter> N       `       _       <enter> 
D       <space> <space> t       y       4       <ctrl>  F       Y       M       <space> 


Low layout:
u       <shift> <enter> X       <space> 5       <alt>   <back>  <space> ,       "       3       <caps>  <space> 
<enter> <ctrl>  >       <alt>   f       e       <back>  b       v       ^       ~       0       $       <space> 
Z       <       <shift> P       /       r       .       )       s       g       1       z       I       <shift> 
<shift> B       6       <space> ]       A       !       @       Q       ?       i       T    

In [66]:
sum(KEYBOARD_LAYOUT_SHAPE) * 2

134

In [67]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [68]:
class Net(nn.Module):
    def __init__(self, in_size: int = KEYS):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(in_size, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x

In [69]:
score_model = Net()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
score_model.to(device)
optimizer = torch.optim.Adam(score_model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [70]:
data_loader = list(zip(random_layouts_100, random_layouts_100_scores))

In [71]:
def train_epoch(model, loader, optimizer, loss_fn, epoch):
    model.train()
    train_loss = 0.0
    total = 0
    loop = loader
    # loop = tqdm(
    #     loader,
    #     total=len(loader),
    #     desc=f"Epoch {epoch}: train",
    #     leave=True,
    # )
    for data in loop:
        layout, score = data
        data = layout.flatten()
        values, targets = data.to(device), score.to(device)
        optimizer.zero_grad()
        outputs = model(values)
        # outputs.backward(my_loss)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        total += 1
        train_loss += loss.item()
    return train_loss / total
    # loop.set_postfix({"loss": train_loss / total})

In [72]:
loop = tqdm(range(400), desc="Train:")
for epoch in loop:
    loss = train_epoch(score_model, data_loader, optimizer, criterion, epoch)
    loop.set_description("Loss: {:.4f}".format(loss))

  return F.mse_loss(input, target, reduction=self.reduction)
Loss: 56362.1231: 100%|██████████| 400/400 [07:21<00:00,  1.10s/it] 


In [73]:
a, b = data_loader[13]
print(b)
a, b = a.flatten().to(device), b.to(device)

tensor(11684.5996)


In [74]:
score_model(a)

tensor([12101.0703], device='cuda:0', grad_fn=<ReluBackward0>)

In [139]:
class ActionNet(nn.Module):
    def __init__(self, in_size: int = KEYS):
        super(ActionNet, self).__init__()
        self.fc1 = nn.Linear(in_size, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, KEYS)
        self.fc4 = nn.Linear(64, KEYS)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x1 = F.softmax(self.fc3(x), dim=0)
        x2 = F.softmax(self.fc4(x), dim=0)
        return x1, x2

In [154]:
action_model = ActionNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
action_optimizer = torch.optim.Adam(action_model.parameters(), lr=1e-3)
action_model.to(device)

ActionNet(
  (fc1): Linear(in_features=134, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=134, bias=True)
  (fc4): Linear(in_features=64, out_features=134, bias=True)
)

In [155]:
new_keyboard = KeyboardLayout(
    QWERTY_ENCODED_LOW, QWERTY_ENCODED_HIGH, Logger(verbose=False)
)

new_keyboard_score = estimate_layouts([new_keyboard], dataset[:10])
last_score = new_keyboard_score.item()
best_score = new_keyboard_score.item()
best_score

100%|██████████| 10/10 [00:00<00:00, 35.71it/s]


10272.599609375

In [156]:
new_keyboard.get_average_score()

2.3330910742675557

In [157]:
steps = 20
best_keyboard = deepcopy(new_keyboard)

for i in tqdm(range(steps)):
    action_optimizer.zero_grad()

    cur_layout = new_keyboard.flatten().to(device)
    output_1, output_2 = action_model(cur_layout)
    n_1, n_2 = torch.argmax(output_1, dim=0).item(), torch.argmax(output_2, dim=0).item()
    cord_1, cord_2 = convert_int_to_cord(n_1), convert_int_to_cord(n_2)

    loss_1 = np.zeros(KEYS)
    loss_2 = np.zeros(KEYS)

    cur_loss = score_model(cur_layout)

    if (
        cord_1[0] == 2
        or cord_2[0] == 2
        or (cord_1[1] == cord_2[1] and cord_1[2] == cord_2[2])
    ):
        print(output_1[:10])
        loss = last_score
        for i in range(KEYS):
            loss_1[i] = loss
            loss_2[i] = loss
        loss_1[n_1] = 0
        loss_2[n_2] = 0
    else:
        print(output_1[:10])
        pos_1 = (cord_1[1], cord_1[2])
        pos_2 = (cord_2[1], cord_2[2])
        if cord_1[0] == 0 and cord_2[0] == 0:
            new_keyboard.swap_buttons(pos_1, pos_2, "low_layout")
        elif cord_1[0] == 1 and cord_2[0] == 1:
            new_keyboard.swap_buttons(pos_1, pos_2, "high_layout")
        else:
            new_keyboard.swap_buttons(pos_1, pos_2, "between_layouts")

        if i % 10 == 0:
            new_keyboard_score = estimate_layouts([new_keyboard], dataset[:10])
            score = new_keyboard_score.item()
        else:
            new_layout = new_keyboard.flatten().to(device)
            score = score_model(new_layout).item()
        # print(score)
        if score < best_score:
            best_score = score
            best_keyboard = deepcopy(new_keyboard)
        loss_1[n_1] = score
        loss_2[n_2] = score
        last_score = score

    print(loss_1[:10])
    back_score_1 = torch.as_tensor(loss_1)
    back_score_1 = back_score_1.to(device)
    back_score_2 = torch.as_tensor(loss_2)
    back_score_2 = back_score_2.to(device)
    output_1.backward(back_score_1, retain_graph=True)
    output_2.backward(back_score_2)

    action_optimizer.step()

  0%|          | 0/20 [00:00<?, ?it/s]

tensor([9.9978e-05, 4.0685e-05, 1.3357e-05, 6.5578e-05, 7.2768e-07, 2.3466e-03,
        2.1941e-04, 7.8379e-08, 4.8077e-04, 4.4775e-08], device='cuda:0',
       grad_fn=<SliceBackward0>)


100%|██████████| 10/10 [00:00<00:00, 32.89it/s]
 50%|█████     | 10/20 [00:00<00:00, 29.71it/s]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
tensor([3.9583e-05, 1.8214e-06, 4.4327e-08, 4.5765e-05, 4.6259e-08, 1.0615e-04,
        1.2514e-05, 5.0852e-10, 2.5919e-06, 2.8420e-11], device='cuda:0',
       grad_fn=<SliceBackward0>)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
tensor([6.3480e-05, 4.2534e-06, 1.0808e-08, 8.5929e-04, 1.0329e-07, 1.4147e-04,
        5.1622e-05, 1.6009e-09, 1.5702e-06, 3.1955e-12], device='cuda:0',
       grad_fn=<SliceBackward0>)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
tensor([4.6567e-07, 2.6730e-07, 1.4714e-10, 2.3469e-04, 3.1595e-09, 5.0389e-06,
        3.7720e-06, 3.9927e-10, 4.2528e-08, 5.0389e-15], device='cuda:0',
       grad_fn=<SliceBackward0>)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
tensor([1.0100e-08, 2.1407e-08, 5.5372e-12, 4.4931e-05, 7.6682e-11, 1.7576e-07,
        2.8006e-07, 7.2006e-11, 1.4178e-09, 2.0314e-17], device='cuda:0',
       grad_fn=<SliceBackward0>)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
tensor([1.0916e-09, 2.7020e-09, 1.0592e-12, 1.3144e-05, 5.4185e-12, 2.0912e-08,
        5.43

100%|██████████| 10/10 [00:00<00:00, 29.03it/s]
100%|██████████| 20/20 [00:00<00:00, 24.05it/s]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], device='cuda:0',
       grad_fn=<SliceBackward0>)
[ 0. inf inf inf inf inf inf inf inf inf]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], device='cuda:0',
       grad_fn=<SliceBackward0>)
[ 0. inf inf inf inf inf inf inf inf inf]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], device='cuda:0',
       grad_fn=<SliceBackward0>)
[ 0. inf inf inf inf inf inf inf inf inf]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], device='cuda:0',
       grad_fn=<SliceBackward0>)
[ 0. inf inf inf inf inf inf inf inf inf]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], device='cuda:0',
       grad_fn=<SliceBackward0>)
[ 0. inf inf inf inf inf inf inf inf inf]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], device='cuda:0',
       grad_fn=<SliceBackward0>)
[ 0. inf inf inf inf inf inf inf inf inf]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dev




In [80]:
print(new_keyboard.get_sting_layouts())

High layout:
~       N       @       #       $       %       )       &       *       (       <shift> _       +       <back>  
<tab>   Q       W       E       R       T       Y       U       I       O       P       {       }       X       
<caps>  A       !       D       F       G       H       <       K       L       :       "       <enter> <enter> 
<shift> <shift> Z       |       C       V       B       S       M       J       >       ?       <shift> <enter> 
<ctrl>  <alt>   <space> <space> <space> <space> <space> <space> <space> <alt>   <ctrl>  


Low layout:
`       1       2       3       4       5       w       7       8       9       0       -       =       <back>  
<tab>   q       6       e       r       t       y       u       i       o       p       [       ]       \       
<caps>  a       s       d       f       ^       h       j       k       l       ;       '       g       <enter> 
<shift> <shift> z       x       c       v       b       n       m       ,       .       /    

In [92]:
best_score

10272.599609375