In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import torch.nn.utils.rnn as rnn_utils
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Hyperparameters

In [None]:
learning_rate = 5e-3
batch_size = 512
n_epochs = 150
test_size = 0.01  # fraction held out for validation (1% of the data, around 700 samples)
seed = 42
# Seed PyTorch's RNG on all devices so weight initialisation and the default
# DataLoader shuffle are reproducible; `seed` alone only reaches train_test_split.
torch.manual_seed(seed)

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset

In [None]:
class TransliterationDataset(Dataset):
    """Character-level Arabizi -> Arabic transliteration pairs from a Hugging Face dataset.

    Each row of the ``train`` split is read as ``[arabizi_text, arabic_text]``.
    Characters are mapped to integer ids, with a few ids reserved:

    * source vocabulary: ``0`` = padding;
    * target vocabulary: ``0`` = padding, ``1`` = begin-of-sequence,
      ``2`` = end-of-sequence.

    Args:
        DATA_HUB: Dataset id passed to ``datasets.load_dataset``.
    """

    # Reserved ids. PAD matches the padding value used by the collate function.
    PAD_IDX = 0
    BOS_IDX = 1
    EOS_IDX = 2

    def __init__(self, DATA_HUB='atlasia/ATAM'):
        # Load the dataset as a list of [arabizi, arabic] rows
        self.data = load_dataset(DATA_HUB)['train'].to_pandas().values.tolist()
        self._build_vocabularies()

    def _build_vocabularies(self):
        """Build the char->id maps and vocabulary sizes from ``self.data``."""
        # Sets of all unique characters in the source and target languages
        self.arabizi_chars = set(''.join(d[0] for d in self.data))
        self.arabic_chars = set(''.join(d[1] for d in self.data))
        # Enumerate in sorted order: iteration order of a set of str depends on
        # per-process hash randomisation, so an unsorted enumeration would give
        # a different char->id mapping on every run (breaking saved checkpoints).
        # Real characters start after the reserved ids so that padding never
        # aliases a real character.
        self.char2idx_ary = {
            char: idx for idx, char in enumerate(sorted(self.arabizi_chars), start=self.PAD_IDX + 1)
        }
        self.char2idx_ar = {
            char: idx for idx, char in enumerate(sorted(self.arabic_chars), start=self.EOS_IDX + 1)
        }
        # Reverse target map, for decoding predictions back to text
        self.idx2char_ar = {idx: char for char, idx in self.char2idx_ar.items()}
        # Vocabulary sizes include the reserved ids
        self.vocab_size_src = len(self.char2idx_ary) + 1
        self.vocab_size_tgt = len(self.char2idx_ar) + 3

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target)`` LongTensors for row ``idx``.

        The target is wrapped in BOS/EOS. A decoder fed ``target[:-1]`` and
        scored on ``target[1:]`` therefore learns to emit the first character
        and to stop.
        """
        darija, darija_ar = self.data[idx]
        input_seq = [self.char2idx_ary[char] for char in darija]
        target_seq = [self.BOS_IDX] + [self.char2idx_ar[char] for char in darija_ar] + [self.EOS_IDX]
        return torch.LongTensor(input_seq), torch.LongTensor(target_seq)

def collate_function(batch):
    """Pad a batch of ``(input, target)`` id tensors to one shared length.

    Inputs and targets are padded with ``0`` up to the length of the longest
    sequence among *all* of them. Both returned tensors therefore have shape
    ``(batch_size, longest)``.
    """
    sources = [pair[0] for pair in batch]
    labels = [pair[1] for pair in batch]
    # Padding sources and labels in a single call gives them a common width.
    stacked = rnn_utils.pad_sequence(sources + labels, batch_first=True, padding_value=0)
    n_pairs = len(sources)
    return stacked[:n_pairs], stacked[n_pairs:]

In [None]:
dataset = TransliterationDataset()
# A seeded split keeps the validation set identical across runs.
# NOTE(review): sklearn presumably indexes the Dataset item by item here, so both
# splits would be materialised lists of (input, target) tensor pairs — confirm.
train_data, val_data = train_test_split(
    dataset,
    test_size=test_size,
    random_state=seed,
)

In [None]:
print(f'The training dataset has {len(train_data)} samples.')
print(f'The validation dataset has {len(val_data)} samples.')
# A dedicated, seeded generator makes the training shuffle order reproducible,
# independent of other consumers of the global torch RNG.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_function,
    generator=torch.Generator().manual_seed(seed),
)
# Evaluation does not need shuffling; a fixed order keeps validation runs comparable.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=collate_function)

# Model

In [None]:
class TransliterationModel(nn.Module):
    """Character-level encoder-decoder Transformer for transliteration.

    ``src`` and ``tgt`` are batch-first ``(batch, seq_len)`` tensors of token
    ids. ``forward`` returns ``(batch, tgt_len, vocab_size_tgt)`` logits.

    Args:
        vocab_size_src: Size of the source vocabulary.
        vocab_size_tgt: Size of the target vocabulary.
        d_model: Embedding / Transformer width.
        nhead: Number of attention heads; must divide ``d_model``.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        pad_idx: Optional padding id. When set, positions holding it are
            excluded from attention via key-padding masks. ``None`` (the
            default) applies no padding masks.
    """

    def __init__(self, vocab_size_src, vocab_size_tgt, d_model=128, nhead=2, num_encoder_layers=2, num_decoder_layers=2, pad_idx=None):
        super(TransliterationModel, self).__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.embedding_src = nn.Embedding(vocab_size_src, d_model)
        self.embedding_tgt = nn.Embedding(vocab_size_tgt, d_model)
        # batch_first=True matches the (batch, seq) tensors from the data pipeline.
        # With the default seq-first layout, attention would run across the
        # samples of a batch instead of across the characters of a sequence.
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(d_model, vocab_size_tgt)

    def _positional_encoding(self, seq_len, device):
        """Return sinusoidal position encodings of shape ``(1, seq_len, d_model)``.

        Self-attention is order-agnostic; without these the model cannot tell
        character positions apart.
        """
        positions = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        freqs = 10000.0 ** (-torch.arange(0, self.d_model, 2, device=device, dtype=torch.float32) / self.d_model)
        angles = positions * freqs  # (seq_len, ceil(d_model / 2))
        pe = torch.zeros(seq_len, self.d_model, device=device)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        return pe.unsqueeze(0)

    @staticmethod
    def _causal_mask(size, device):
        """Boolean ``(size, size)`` mask where ``True`` blocks attention to future positions."""
        return torch.ones(size, size, dtype=torch.bool, device=device).triu(diagonal=1)

    def forward(self, src, tgt):
        """Compute target logits.

        Args:
            src: ``(batch, src_len)`` source token ids.
            tgt: ``(batch, tgt_len)`` decoder-input token ids.

        Returns:
            ``(batch, tgt_len, vocab_size_tgt)`` logits.
        """
        src_emb = self.embedding_src(src) + self._positional_encoding(src.size(1), src.device)
        tgt_emb = self.embedding_tgt(tgt) + self._positional_encoding(tgt.size(1), tgt.device)
        # Without a causal mask each decoder position could attend to the next
        # target token, which is exactly the label it is trained to predict.
        tgt_mask = self._causal_mask(tgt.size(1), tgt.device)
        src_padding_mask = tgt_padding_mask = None
        if self.pad_idx is not None:
            src_padding_mask = src.eq(self.pad_idx)
            tgt_padding_mask = tgt.eq(self.pad_idx)
        output = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )
        return self.fc(output)

In [None]:
# Build the model with vocabulary sizes taken from the dataset, then the loss and optimizer.
model = TransliterationModel(
    vocab_size_src=dataset.vocab_size_src,
    vocab_size_tgt=dataset.vocab_size_tgt,
).to(device)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def train(model, train_loader, criterion, optimizer, n_epochs=10, save_path='transliteration_transformer.pth'):
    """Train ``model`` with teacher forcing, then save its ``state_dict``.

    Each ``(inputs, targets)`` batch is moved to the model's device. The
    decoder is fed ``targets[:, :-1]`` and the loss compares its output with
    ``targets[:, 1:]`` (next-token prediction).

    Args:
        model: Module called as ``model(src, tgt)`` that returns logits whose
            last dimension is the target vocabulary.
        train_loader: Yields ``(inputs, targets)`` LongTensor batches.
        criterion: Loss taking ``(N, vocab)`` logits and ``(N,)`` labels.
        optimizer: Optimizer over ``model``'s parameters.
        n_epochs: Number of passes over ``train_loader``.
        save_path: File the final ``state_dict`` is written to.

    Returns:
        List of per-epoch mean losses (batch-size weighted).
    """
    # Batches arrive on the CPU; send them wherever the model lives.
    device = next(model.parameters()).device
    model.train()
    epoch_losses = []
    progress = tqdm(range(n_epochs))
    for epoch in progress:
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()

            # Pad to a common length, with one extra target column so that the
            # shifted decoder input targets[:, :-1] is as long as the source.
            # new_zeros keeps the padding on the same device/dtype as the batch.
            max_seq_length = max(inputs.size(1), targets.size(1))
            inputs = torch.cat([inputs, inputs.new_zeros(inputs.size(0), max_seq_length - inputs.size(1))], dim=1)
            targets = torch.cat([targets, targets.new_zeros(targets.size(0), max_seq_length - targets.size(1) + 1)], dim=1)

            # Decoder sees every target token except the last ...
            outputs = model(inputs, targets[:, :-1])

            # ... and is scored against every target token except the first.
            # Flatten to (batch * seq_len, vocab) vs (batch * seq_len,) for the loss.
            outputs = outputs.reshape(-1, outputs.size(-1))
            labels = targets[:, 1:].reshape(-1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_losses.append(epoch_loss)
        # Report through tqdm instead of per-batch prints, which flood the output.
        progress.set_postfix(loss=f"{epoch_loss:.4f}")
        tqdm.write(f"Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    return epoch_losses


In [None]:
# Run training with the epoch count from the hyperparameter cell
# (trailing `;` keeps the notebook from echoing any return value).
train(model, train_loader, criterion, optimizer, n_epochs=n_epochs);