# In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader



# --- Paths -------------------------------------------------------------------
CSV_PATH = "translit_mistakes.csv"
MODEL_SAVE_PATH = "translit_model.pt"

# --- Training / encoding hyperparameters -------------------------------------
BATCH_SIZE = 64
MAX_EPOCHS = 5
MAX_LEN = 32  # fixed length of every encoded sequence, special tokens included

# Prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Special tokens (placed first in the vocabulary) -------------------------
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN = "<pad>", "<bos>", "<eos>"



# NOTE(review): by default pd.read_csv parses strings such as "NA", "null" and
# "nan" as missing values; if such strings can be legitimate words in this
# dataset, consider read_csv(..., keep_default_na=False).
df = pd.read_csv(CSV_PATH)
# Only the two columns the model consumes must be present; rows that merely
# lack values in unrelated columns are kept.
df.dropna(subset=["MISTAKE", "CORRECT"], inplace=True)

# Force string dtype so purely numeric cells are still treated as text.
df["MISTAKE"] = df["MISTAKE"].astype(str)
df["CORRECT"] = df["CORRECT"].astype(str)

# Character vocabulary covering both the misspelled and the corrected words.
all_text = list(df["MISTAKE"]) + list(df["CORRECT"])
unique_chars = {ch for word in all_text for ch in word}

# Special tokens occupy the first indices; characters follow in sorted order
# so the index assignment is deterministic across runs.
vocab_list = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN] + sorted(unique_chars)
vocab2idx = {tok: i for i, tok in enumerate(vocab_list)}
idx2vocab = dict(enumerate(vocab_list))

PAD_IDX = vocab2idx[PAD_TOKEN]
BOS_IDX = vocab2idx[BOS_TOKEN]
EOS_IDX = vocab2idx[EOS_TOKEN]
VOCAB_SIZE = len(vocab_list)

print(f"The dictionary is formed. Size: {VOCAB_SIZE} symbols (including special tokens).")



def text_to_tensor(text: str, max_len=MAX_LEN) -> torch.Tensor:
    """Encode ``text`` as a fixed-length 1-D LongTensor of vocabulary indices.

    Layout: ``[BOS, c1, ..., ck, EOS, PAD, ..., PAD]``, exactly ``max_len``
    long. Text longer than ``max_len - 2`` characters is truncated *before*
    EOS is appended, so every sequence of length >= 2 still ends with EOS.
    Previously EOS was sliced off, and the model never learned to stop on
    long words.

    Characters missing from ``vocab2idx`` are encoded as ``PAD_IDX``, because
    the vocabulary has no unknown-character token.
    NOTE(review): a mid-sequence PAD is ignored by the loss and makes
    ``tensor_to_text`` stop early; consider adding an ``<unk>`` token.

    Args:
        text: The string to encode, one token per character.
        max_len: Total output length, including BOS, EOS and padding.

    Returns:
        ``torch.long`` tensor of shape ``(max_len,)``.
    """
    # Reserve two slots for BOS and EOS so truncation never drops EOS.
    char_budget = max(max_len - 2, 0)
    body = [vocab2idx.get(ch, PAD_IDX) for ch in text[:char_budget]]

    # The trailing slice only has an effect for degenerate max_len < 2.
    tokens = ([BOS_IDX] + body + [EOS_IDX])[:max_len]
    tokens += [PAD_IDX] * (max_len - len(tokens))

    return torch.tensor(tokens, dtype=torch.long)

def tensor_to_text(tensor: torch.Tensor) -> str:
    """Decode a 1-D tensor of vocabulary indices back into a string.

    BOS indices are skipped wherever they appear. Decoding stops at the first
    EOS or PAD index, and every other index is mapped through ``idx2vocab``.
    """
    decoded = []
    for element in tensor:
        token_id = element.item()
        if token_id == BOS_IDX:
            continue
        if token_id in (EOS_IDX, PAD_IDX):
            break
        decoded.append(idx2vocab[token_id])
    return "".join(decoded)



class TranslitDataset(Dataset):
    """Map-style dataset of (misspelled, corrected) word pairs.

    Each item is a pair of index tensors produced by ``text_to_tensor``: the
    encoded ``MISTAKE`` column as the source and the encoded ``CORRECT``
    column as the target.
    """

    def __init__(self, dataframe, max_len=MAX_LEN):
        # Rows are accessed positionally (iloc), so a non-contiguous index is fine.
        self.df = dataframe
        self.max_len = max_len

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        return (
            text_to_tensor(record["MISTAKE"], self.max_len),
            text_to_tensor(record["CORRECT"], self.max_len),
        )

# Wrap the cleaned DataFrame and batch it with per-epoch shuffling.
dataset = TranslitDataset(dataframe=df, max_len=MAX_LEN)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)


class Attention(nn.Module):
    """Additive attention: ``score_j = v^T tanh(W1 q + W2 k_j)``.

    ``q`` is the decoder state and ``k_j`` is the encoder output at source
    position ``j``. The scores are softmax-normalised over source positions
    and used to average the encoder outputs into a context vector.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.W1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W2 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        """Compute the attention context for one decoder step.

        Args:
            decoder_hidden: ``(batch, hidden)`` decoder state used as the query.
            encoder_outputs: ``(batch, src_len, hidden)``; must be 3-D because
                it is consumed by ``torch.bmm``.
            mask: Optional ``(batch, src_len)`` tensor, truthy where a source
                position may be attended (e.g. ``src != PAD_IDX``). Masked
                positions receive zero weight. If a whole row is masked, its
                weights become uniform rather than NaN. Defaults to ``None``,
                which attends to every position as before.

        Returns:
            ``(context, attention_weights)`` of shapes ``(batch, hidden)`` and
            ``(batch, src_len)``.
        """
        # (batch, 1, hidden): broadcasts against every source position.
        query = decoder_hidden.unsqueeze(1)

        energy = self.v(
            torch.tanh(self.W1(query) + self.W2(encoder_outputs))
        ).squeeze(2)  # (batch, src_len)

        if mask is not None:
            # Use the most negative finite value instead of -inf so that a
            # fully masked row cannot produce NaN in softmax.
            energy = energy.masked_fill(~mask.bool(), torch.finfo(energy.dtype).min)

        attention_weights = torch.softmax(energy, dim=1)

        # Weighted sum of encoder outputs: (batch, 1, src_len) @ (batch, src_len, hidden).
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attention_weights

class EncoderRNN(nn.Module):
    """Embed source token ids and run them through a single-layer LSTM.

    ``forward`` returns the LSTM's ``(outputs, (h_n, c_n))`` tuple unchanged.
    With ``batch_first=True``, ``outputs`` is ``(batch, src_len, hidden)``
    and ``h_n`` / ``c_n`` are ``(1, batch, hidden)``.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, pad_idx):
        super().__init__()
        self.hidden_size = hidden_size
        # Registered before the LSTM to keep parameter order and init RNG order.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            embed_size,
            hidden_size,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, src):
        return self.lstm(self.embedding(src))


class DecoderRNN(nn.Module):
    """One-step LSTM decoder conditioned on an attention context.

    At each step the embedded input token is concatenated with the attention
    context, computed from the *previous* hidden state. The result is fed
    through the LSTM, and the LSTM output is projected to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, pad_idx):
        super().__init__()
        self.hidden_size = hidden_size
        # Submodule registration order (embedding, lstm, attention, fc_out) is
        # kept as-is: it fixes state_dict key order and initialisation order.
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)
        # Input is [token embedding ; context vector].
        self.lstm = nn.LSTM(
            input_size=embed_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.attention = Attention(hidden_size)
        self.fc_out = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_token, hidden, cell, encoder_outputs):
        """Advance the decoder by one time step.

        Args:
            input_token: ``(batch,)`` token ids for this step.
            hidden: ``(batch, hidden)`` previous hidden state.
            cell: ``(batch, hidden)`` previous cell state.
            encoder_outputs: ``(batch, src_len, hidden)`` attention keys/values.

        Returns:
            ``(logits, new_hidden, new_cell, attention_weights)``.
        """
        token_emb = self.embedding(input_token).unsqueeze(1)        # (B, 1, E)
        ctx, weights = self.attention(hidden, encoder_outputs)       # (B, H), (B, S)
        step_input = torch.cat((token_emb, ctx.unsqueeze(1)), dim=2)  # (B, 1, E + H)

        # nn.LSTM expects states shaped (num_layers, batch, hidden).
        prev_state = (hidden.unsqueeze(0), cell.unsqueeze(0))
        step_out, (h_next, c_next) = self.lstm(step_input, prev_state)

        logits = self.fc_out(step_out.squeeze(1))
        return logits, h_next.squeeze(0), c_next.squeeze(0), weights


class Seq2Seq(nn.Module):
    """Attention encoder-decoder trained with full teacher forcing."""

    def __init__(self, vocab_size, embed_size, hidden_size, pad_idx):
        super().__init__()
        self.encoder = EncoderRNN(vocab_size, embed_size, hidden_size, pad_idx)
        self.decoder = DecoderRNN(vocab_size, embed_size, hidden_size, pad_idx)
        self.hidden_size = hidden_size

    def forward(self, src, tgt):
        """Decode ``tgt`` with teacher forcing.

        Args:
            src: ``(batch, src_len)`` source token ids.
            tgt: ``(batch, tgt_len)`` target token ids. ``tgt[:, t]`` is the
                decoder input at step ``t``. With this file's ``text_to_tensor``
                encoding, ``tgt[:, 0]`` is BOS.

        Returns:
            Logits of shape ``(batch, tgt_len - 1, vocab_size)``, where step
            ``t`` predicts ``tgt[:, t + 1]``. If ``tgt_len < 2`` there is
            nothing to predict and an empty ``(batch, 0, vocab_size)`` tensor
            is returned; previously ``torch.cat`` raised on the empty list.
        """
        batch_size, tgt_len = tgt.shape

        encoder_outputs, (h, c) = self.encoder(src)
        # The encoder LSTM is single-layer: drop the num_layers axis.
        h = h.squeeze(0)
        c = c.squeeze(0)

        if tgt_len < 2:
            vocab_size = self.decoder.fc_out.out_features
            return encoder_outputs.new_zeros((batch_size, 0, vocab_size))

        step_logits = []
        for t in range(tgt_len - 1):
            # Teacher forcing: always feed the ground-truth token.
            logits, h, c, _ = self.decoder(tgt[:, t], h, c, encoder_outputs)
            step_logits.append(logits)

        return torch.stack(step_logits, dim=1)


# Model dimensions.
EMBED_SIZE = 128
HIDDEN_SIZE = 256

model = Seq2Seq(VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE, PAD_IDX)
model = model.to(DEVICE)

# Padding positions in the target carry no signal, so they are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


def train_one_epoch(model, dataloader, optimizer, loss_fn, device=DEVICE, *, max_grad_norm=None):
    """Run one teacher-forced training pass over ``dataloader``.

    Args:
        model: Module mapping ``(src, tgt)`` to logits of shape
            ``(batch, tgt_len - 1, num_classes)``.
        dataloader: Iterable of ``(src, tgt)`` index-tensor batches.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Classification loss over flattened logits and targets.
        device: Device the batches are moved to.
        max_grad_norm: If given, gradients are clipped to this total L2 norm
            before each optimizer step. Defaults to ``None``, which means no
            clipping, as before.

    Returns:
        Mean of the per-batch losses, or ``0.0`` if ``dataloader`` yields no
        batches (previously a ``ZeroDivisionError``).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)

        logits_seq = model(src, tgt)

        # Logit step t predicts tgt[:, t + 1], so targets are tgt without its first column.
        tgt_y = tgt[:, 1:]

        # Take the class count from the logits rather than the global
        # VOCAB_SIZE, so the function works for any model passed in.
        num_classes = logits_seq.size(-1)
        loss = loss_fn(logits_seq.reshape(-1, num_classes), tgt_y.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        # Count batches instead of using len(dataloader), which unsized iterables lack.
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0


# Main training loop: one full pass over the data per epoch, reporting the
# mean per-batch loss. Epochs are printed 1-based.
for epoch in range(MAX_EPOCHS):
    avg_loss = train_one_epoch(model, dataloader, optimizer, loss_fn, device=DEVICE)
    print(f"[Epoch {epoch+1}/{MAX_EPOCHS}] loss = {avg_loss:.4f}")


# Bundle the weights with everything needed to rebuild the vocabulary and
# architecture at inference time.
save_data = dict(
    model_state_dict=model.state_dict(),
    vocab2idx=vocab2idx,
    idx2vocab=idx2vocab,
    pad_idx=PAD_IDX,
    bos_idx=BOS_IDX,
    eos_idx=EOS_IDX,
    embed_size=EMBED_SIZE,
    hidden_size=HIDDEN_SIZE,
    max_len=MAX_LEN,
)
torch.save(save_data, MODEL_SAVE_PATH)
print(f"The model and dictionary are saved in {MODEL_SAVE_PATH}")
