In [1]:
import random

import numpy as np
import pandas as pd
import regex as re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Pick the best available backend: Apple MPS first, then CUDA, otherwise CPU.
# Falling back to CPU keeps the notebook runnable on machines without a GPU.
device = torch.device(
  "mps" if torch.backends.mps.is_available()
  else "cuda" if torch.cuda.is_available()
  else "cpu"
)

In [4]:
# Load the en/fr sentence pairs with pandas, requesting the str dtype for both columns.
dtypes = dict.fromkeys(("en", "fr"), "str")
df = pd.read_csv("small_en-fr.csv", dtype=dtypes)
# Accumulator for the text the tokenizer is trained on (filled in a later cell).
all_text = ""



In [25]:
# Concatenate the first 1000 en/fr pairs into one string, each sentence followed by a space.
# The string is rebuilt from scratch (not appended to) so re-running this cell
# cannot duplicate the corpus; join also avoids repeated string concatenation.
all_text = "".join(
  str(en) + " " + str(fr) + " "
  for en, fr in zip(df["en"][:1000], df["fr"][:1000])
)

In [26]:
# Raw UTF-8 byte values (0-255) of the corpus; iterating over bytes yields ints.
encoded = list(all_text.encode("utf-8"))

In [27]:
# Tokenizer: byte-pair encoding (BPE) helpers.
def get_stats(ids):
  """Count how often each adjacent (left, right) pair occurs in `ids`.

  Returns a dict mapping each pair tuple to its number of occurrences,
  with keys in order of first appearance.
  """
  pair_counts = {}
  for left, right in zip(ids, ids[1:]):
    key = (left, right)
    if key in pair_counts:
      pair_counts[key] += 1
    else:
      pair_counts[key] = 1
  return pair_counts

def merge(ids, pair, idx):
  """Return a copy of `ids` where each occurrence of `pair` is replaced by `idx`.

  The scan runs left to right and matches do not overlap: once a pair is
  replaced, scanning resumes after its second element.
  """
  merged = []
  total = len(ids)
  pos = 0
  while pos < total:
    has_next = pos + 1 < total
    if has_next and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
      merged.append(idx)
      pos += 2
      continue
    merged.append(ids[pos])
    pos += 1
  return merged

In [28]:
# BPE hyperparameters:
num_merges = 100
# Maps (id, id) pair -> new token id. Insertion order is the merge order,
# which encode() relies on.
replacements = {}
# NOTE: this loop rewrites `encoded` in place, so re-running the cell without
# first re-running the byte-encoding cell trains on already-merged ids.
for i in range(num_merges):
  stats = get_stats(encoded)
  if not stats:
    # Fewer than two tokens left: nothing to merge, and max() would raise
    # ValueError on an empty dict.
    break
  best = max(stats, key=stats.get)
  new_id = 256 + i  # ids 0-255 are raw bytes; merged tokens start at 256
  encoded = merge(encoded, best, new_id)
  replacements[best] = new_id

In [35]:
def encode(text):
  """Tokenize `text`: take its UTF-8 bytes, then apply every learned merge.

  Merges are applied one at a time in the insertion order of `replacements`.
  """
  tokens = list(text.encode("utf-8"))
  for pair, token_id in list(replacements.items()):
    tokens = merge(tokens, pair, token_id)
  return tokens
def decode(ids):
    """Turn a sequence of token ids back into text.

    Every merged token is expanded into the pair it replaced, repeatedly,
    until only raw byte values remain; those bytes are decoded as UTF-8.
    The merge table is read from `replacements` alone, so the result does not
    depend on `num_merges` staying in sync with it.

    Args:
        ids: iterable of token ids (raw bytes 0-255 or ids from `replacements`).

    Returns:
        The decoded string.

    Raises:
        ValueError: if an id is neither a byte value nor a known merged token.
        UnicodeDecodeError: if the expanded bytes are not valid UTF-8.
    """
    # Invert the merge table once: token id -> the pair it stands for.
    pair_for = {token_id: pair for pair, token_id in replacements.items()}

    raw_bytes = []
    # Explicit stack, reversed so that pop() yields ids in their original order.
    pending = list(ids)[::-1]
    while pending:
        token = pending.pop()
        pair = pair_for.get(token)
        if pair is None:
            raw_bytes.append(token)
        else:
            # Push right then left so the left element is expanded first.
            pending.append(pair[1])
            pending.append(pair[0])

    return bytearray(raw_bytes).decode("utf-8")


In [27]:
# Encoder: embeds the source tokens and runs them through a GRU.
class Encoder(nn.Module):
    """GRU encoder that returns only the final hidden state of the sequence."""

    def __init__(self, input_dim, emb_dim, hidden_dim):
        """Build the embedding table (`input_dim` x `emb_dim`) and the GRU (`emb_dim` -> `hidden_dim`)."""
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.gru = nn.GRU(input_size=emb_dim, hidden_size=hidden_dim)

    def forward(self, src):
        """Encode `src` ([src len, batch size] token ids) into a hidden state.

        Returns the GRU's final hidden state, [1, batch size, hidden dim];
        the per-step outputs are discarded.
        """
        token_vectors = self.embedding(src)  # [src len, batch size, emb dim]
        _, final_hidden = self.gru(token_vectors)
        return final_hidden
        

In [28]:
# Decoder: predicts the next target token from the previous token and hidden state.
class Decoder(nn.Module):
    """Single-step GRU decoder.

    Attributes:
        output_dim: size of the target vocabulary. Stored because
            Translator.forward reads `decoder.output_dim` to size its output buffer.
    """

    def __init__(self, output_dim, emb_dim, hidden_dim):
        """Build the embedding, the GRU and the projection onto `output_dim` logits."""
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.gru = nn.GRU(emb_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: [batch size] ids of the previous target token.
            hidden: [1, batch size, hidden dim] current hidden state.

        Returns:
            (prediction, hidden): logits of shape [batch size, output dim] and
            the updated hidden state.
        """
        input = input.unsqueeze(0)  # add a sequence axis of length 1: [1, batch size]
        embedded = self.embedding(input)  # [1, batch size, emb dim]
        output, hidden = self.gru(embedded, hidden)
        prediction = self.fc(output.squeeze(0))  # [batch size, output dim]
        return prediction, hidden


In [29]:
# Combined sequence-to-sequence model built from an encoder and a decoder.
class Translator(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time."""

    def __init__(self, encoder, decoder, device):
        """Store the encoder, the decoder and the device used for the output buffer."""
        super(Translator, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Produce decoder logits for every target position after the first.

        Args:
            src: [src len, batch size] source token ids.
            trg: [trg len, batch size] target token ids; row 0 is fed to the
                decoder as its first input.
            teacher_forcing_ratio: probability, per step, of feeding the ground
                truth token instead of the decoder's own best guess.

        Returns:
            Tensor [trg len, batch size, trg vocab size]; row 0 is left as zeros
            because no prediction is made for the first target position.
        """
        trg_len, batch_size = trg.shape[0], trg.shape[1]

        # Prefer an explicit `output_dim` on the decoder; otherwise read the size
        # from its output projection so decoders without that attribute still work.
        trg_vocab_size = getattr(self.decoder, "output_dim", None)
        if trg_vocab_size is None:
            trg_vocab_size = self.decoder.fc.out_features

        # Allocate the result buffer directly on the target device.
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=self.device)

        # The encoder's final hidden state initializes the decoder.
        hidden = self.encoder(src)

        decoder_input = trg[0, :]
        for t in range(1, trg_len):
            output, hidden = self.decoder(decoder_input, hidden)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            # Next input: ground truth when teacher forcing, else the argmax prediction.
            decoder_input = trg[t] if teacher_force else output.argmax(1)

        return outputs
