In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np 
from torch.utils.data import DataLoader, Dataset
import regex as re
import random

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")

In [2]:
# import dataset using pd
dtypes = {
  'en': 'str',
  'fr': 'str'
}
df = pd.read_csv("small_en-fr.csv", dtype=dtypes) 
all_text = ""



In [3]:
for en, fr in zip(df["en"][:1000], df["fr"][:1000]):
  all_text += str(en) + " " + str(fr) + " "

In [4]:
encoded = list(map(int, all_text.encode("utf-8")))

In [5]:
# tokenizer: build a list of tokens in each language. BPE utilized here. 
def get_stats(ids):
  counts = {}
  for pair in zip(ids, ids[1:]):
    counts[pair] = counts.get(pair, 0) + 1
  return counts

def merge(ids, pair, idx):
  newids = []
  i = 0
  while i < len(ids):
    if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

In [6]:
# bpe hyperparameters: 
num_merges = 100 
replacements = {}  # order of merges matters
for i in range(num_merges):
  stats = get_stats(encoded)
  best = max(stats, key=stats.get)
  encoded = merge(encoded, best, 256 + i)
  replacements[best] = 256 + i

In [7]:
def encode(text):
  ids = list(map(int, text.encode("utf-8")))
  for pair in list(replacements):
    merged = merge(ids, pair, replacements[pair])
    ids = merged
  return ids
def decode(ids):
    decoded_ids = ids[:]
    for merge_id in range(256 + num_merges - 1, 255, -1):  # Go backwards through merge ids
        original_pair = None
        for pair, replacement in replacements.items():
            if replacement == merge_id:
                original_pair = pair
                break
        if original_pair is not None:
            i = 0
            while i < len(decoded_ids):
                if decoded_ids[i] == merge_id:
                    decoded_ids = decoded_ids[:i] + list(original_pair) + decoded_ids[i+1:]
                i += 1

    byte_array = bytearray(decoded_ids)
    return byte_array.decode("utf-8")


In [8]:
# encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.gru = nn.GRU(emb_dim, hidden_dim)

    def forward(self, src):
        embedded = self.embedding(src)  # embedded: [src len, batch size, emb dim]
        outputs, hidden = self.gru(embedded)  # hidden: [1, batch size, hidden dim]
        return hidden
        

In [17]:
# decoder
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.gru = nn.GRU(emb_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.output_dim = output_dim

    def forward(self, input, hidden):
        # input: [batch size]
        # hidden: [1, batch size, hidden dim]
        input = input.unsqueeze(0)  # input: [1, batch size]
        embedded = self.embedding(input)  # embedded: [1, batch size, emb dim]
        output, hidden = self.gru(embedded, hidden)
        prediction = self.fc(output.squeeze(0))  # prediction: [batch size, output dim]
        return prediction, hidden


In [18]:
# combined model from encoder and decoder
class Translator(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Translator, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # src: [src len, batch size]
        # trg: [trg len, batch size]
        trg_len = trg.shape[0]
        batch_size = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        
        # tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        # last hidden state of the encoder is used as the initial hidden state of the decoder
        hidden = self.encoder(src)
        
        # first input to the decoder is the <sos> tokens
        input = trg[0,:]
        
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1) 
            input = trg[t] if teacher_force else top1
        
        return outputs


In [19]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    source_batch, target_batch = zip(*batch)
    padding = num_merges + 255 + 1
    
    # Pad sequences with the padding token ID, here assumed to be 0
    padded_source_batch = pad_sequence(source_batch, batch_first=True, padding_value=padding)
    padded_target_batch = pad_sequence(target_batch, batch_first=True, padding_value=padding)
    
    return padded_source_batch, padded_target_batch

class TranslationDataset(Dataset):
    def __init__(self, dataframe, source_column, target_column, encode_fn):
        """
        Args:
            dataframe (pd.DataFrame): DataFrame containing the translation pairs.
            source_column (str): Name of the column containing the source text.
            target_column (str): Name of the column containing the target text.
            encode_fn (callable): Function to encode text to numerical format.
        """
        self.dataframe = dataframe
        self.source_column = source_column
        self.target_column = target_column
        self.encode_fn = encode_fn

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        source_text = self.dataframe.iloc[idx][self.source_column]
        target_text = self.dataframe.iloc[idx][self.target_column]
        
        # Encode the source and target texts
        source_encoded = self.encode_fn(source_text)
        target_encoded = self.encode_fn(target_text)
        
        # Convert to PyTorch tensors
        source_tensor = torch.tensor(source_encoded, dtype=torch.long)
        target_tensor = torch.tensor(target_encoded, dtype=torch.long)
        
        return source_tensor, target_tensor


In [21]:
# training loop and hyperparameters
# Assume input_dim and output_dim are defined based on your vocabulary size
input_dim = 256 + num_merges  # Adjust based on your BPE merges
output_dim = 256 + num_merges  # Assuming the same vocabulary for src and trg for simplicity
emb_dim = 256
hidden_dim = 512
encoder = Encoder(emb_dim, emb_dim, hidden_dim)
decoder = Decoder(hidden_dim, emb_dim, hidden_dim)
model = Translator(encoder, decoder, device).to(device)

# Hyperparameters
learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=255+num_merges+1)  

dataset = TranslationDataset(dataframe=df, source_column='en', target_column='fr', encode_fn=encode)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for src, trg in dataloader:  # Assuming your dataloader yields batches of src and trg sentences
        src = src.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        # Reshape output to [batch size * trg len, output dim]
        # Reshape trg to [batch size * trg len]
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        
        loss = criterion(output, trg)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f'Epoch: {epoch+1}, Loss: {epoch_loss / len(dataloader)}')


RuntimeError: Expected hidden size (1, 330, 512), got [1, 238, 512]