In [None]:
import torch
import torch.nn as nn
import math
import os
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk
import sys

In [11]:
# Project directories, all expressed relative to the notebook's working directory.
# The trailing separators are kept on purpose: later code appends file names
# to these strings directly.
cwd = os.getcwd()
data_dir, artifacts_dir, src_dir = (
    os.path.join(cwd, relative)
    for relative in ("../data/", "../artifacts/", "../src/")
)

# Absolute locations of the two sides of the parallel corpus
en_path = os.path.abspath(data_dir + 'UNPC.ar-en.en')
ar_path = os.path.abspath(data_dir + 'UNPC.ar-en.ar')

# Make the project's src/ directory importable before pulling in local modules
sys.path.append(src_dir)

from model import Transformer

In [5]:
class TranslationDataset(Dataset):
    """Parallel English/Arabic dataset that tokenizes each pair on access.

    The English text is used as the source sequence and the Arabic text as the
    target sequence; the target is wrapped as ``<SOS> ... <EOS>``.
    """

    def __init__(self, dataset_path, en_tokenizer_path, ar_tokenizer_path, max_seq_len= 512):
        """
        Args:
            dataset_path: Path to saved Hugging Face dataset
            en_tokenizer_path: Path to English tokenizer
            ar_tokenizer_path: Path to Arabic tokenizer
            max_seq_len: Maximum sequence length in tokens. For the target
                this budget includes the <SOS> and <EOS> tokens.

        Raises:
            ValueError: If <PAD>, <SOS> or <EOS> is missing from the English
                tokenizer's vocabulary.
        """
        # Load dataset
        self.dataset= load_from_disk(dataset_path)

        # Load tokenizers
        self.en_tokenizer= Tokenizer.from_file(en_tokenizer_path)
        self.ar_tokenizer= Tokenizer.from_file(ar_tokenizer_path)

        # Get special tokens
        # NOTE(review): the ids are read from the English tokenizer and also
        # applied to Arabic targets -- this assumes both tokenizers assign the
        # same ids to the special tokens; confirm against the tokenizer files.
        self.pad_id= self.en_tokenizer.token_to_id("<PAD>")
        self.sos_id= self.en_tokenizer.token_to_id("<SOS>")
        self.eos_id= self.en_tokenizer.token_to_id("<EOS>")

        # token_to_id returns None for unknown tokens; fail here with a clear
        # message instead of later inside torch.tensor().
        missing= [
            name for name, token_id in (
                ("<PAD>", self.pad_id), ("<SOS>", self.sos_id), ("<EOS>", self.eos_id)
            ) if token_id is None
        ]
        if missing:
            raise ValueError(f"Special tokens missing from English tokenizer: {missing}")

        self.max_seq_length= max_seq_len

    def __len__(self):
        """Return the number of sentence pairs in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize and return the sentence pair at ``idx``.

        Returns:
            dict with
                'en_ids': 1-D LongTensor of source token ids, truncated to
                    ``max_seq_length``
                'ar_ids': 1-D LongTensor ``[sos] + target ids + [eos]``,
                    at most ``max_seq_length`` long when that limit is >= 2
                'en_text' / 'ar_text': the raw strings from the record
        """
        item= self.dataset[idx]
        en_text= item['en']
        ar_text= item['ar']

        # Tokenize English (source)
        en_encoding= self.en_tokenizer.encode(en_text)
        en_ids= en_encoding.ids[:self.max_seq_length]

        # Tokenize Arabic (target) - add special tokens.
        # Two positions are reserved for <SOS>/<EOS>; clamp at zero so a tiny
        # limit cannot become a negative slice bound.
        ar_budget= max(self.max_seq_length - 2, 0)
        ar_encoding= self.ar_tokenizer.encode(ar_text)
        # The sequence must be closed with <EOS> (not a second <SOS>) so the
        # decoder has an end-of-sentence target to learn.
        ar_ids= [self.sos_id] + ar_encoding.ids[:ar_budget] + [self.eos_id]

        return {
            'en_ids': torch.tensor(en_ids, dtype=torch.long),
            'ar_ids': torch.tensor(ar_ids, dtype= torch.long),
            'en_text': en_text,
            'ar_text': ar_text
        }

def collate_fn(batch, pad_id):
    """Pad a list of dataset items into batch tensors and build attention masks.

    Args:
        batch: list of dicts carrying 'en_ids', 'ar_ids' (1-D tensors) and
            'en_text', 'ar_text'.
        pad_id: token id used to right-pad shorter sequences.

    Returns:
        dict with the padded id tensors, boolean padding masks for both
        sides, a lower-triangular causal mask for the target, and the raw
        texts in batch order.
    """
    src_seqs, tgt_seqs, src_texts, tgt_texts = [], [], [], []
    for sample in batch:
        src_seqs.append(sample['en_ids'])
        tgt_seqs.append(sample['ar_ids'])
        src_texts.append(sample['en_text'])
        tgt_texts.append(sample['ar_text'])

    # Right-pad every sequence to the longest one in the batch
    pad = torch.nn.utils.rnn.pad_sequence
    src_padded = pad(src_seqs, batch_first=True, padding_value=pad_id)
    tgt_padded = pad(tgt_seqs, batch_first=True, padding_value=pad_id)

    # True where a real token is present; two singleton axes are inserted
    src_mask = src_padded.ne(pad_id)[:, None, None, :]  # (batch_size, 1, 1, src_len)
    tgt_mask = tgt_padded.ne(pad_id)[:, None, None, :]  # (batch_size, 1, 1, tgt_len)

    # Lower-triangular mask: position i may look at positions <= i
    steps = tgt_padded.size(1)
    lower_tri = torch.ones(steps, steps).tril().bool()
    causal = lower_tri[None, None]  # (1, 1, tgt_len, tgt_len)

    return {
        'en_ids': src_padded,
        'ar_ids': tgt_padded,
        'en_mask': src_mask,
        'ar_mask': tgt_mask,
        'causal_mask': causal,
        'en_texts': src_texts,
        'ar_texts': tgt_texts,
    }

def create_data_loaders(data_dir_path, tokenizer_path, batch_size= 32, max_seq_length= 512, num_workers= 4):
    """
    Create train and validation data loaders.

    Args:
        data_dir_path: Directory containing the saved 'train_ds' and 'val_ds'
            datasets (with or without a trailing path separator).
        tokenizer_path: Directory containing 'bpe_tokenizer_en.json' and
            'bpe_tokenizer_ar.json' (with or without a trailing separator).
        batch_size: Number of sentence pairs per batch.
        max_seq_length: Maximum token length passed to TranslationDataset.
        num_workers: Number of DataLoader worker processes.

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    # Local import keeps this function self-contained
    from functools import partial

    # Dataset paths -- join the components properly so the result does not
    # depend on the caller supplying a trailing separator.
    train_ds_path= os.path.join(data_dir_path, 'train_ds')
    val_ds_path= os.path.join(data_dir_path, 'val_ds')
    # Tokenizer paths
    en_tokenizer_path= os.path.join(tokenizer_path, 'bpe_tokenizer_en.json')
    ar_tokenizer_path= os.path.join(tokenizer_path, 'bpe_tokenizer_ar.json')

    # Create datasets
    train_dataset= TranslationDataset(train_ds_path, en_tokenizer_path, ar_tokenizer_path, max_seq_length)
    val_dataset= TranslationDataset(val_ds_path, en_tokenizer_path, ar_tokenizer_path, max_seq_length)

    # The dataset already looked up <PAD> in the English tokenizer; reuse it
    # rather than loading the tokenizer file again.
    pad_id= train_dataset.pad_id

    # A partial over a module-level function can be pickled, unlike a lambda;
    # this matters when worker processes are started with the spawn method.
    collate= partial(collate_fn, pad_id= pad_id)

    # Settings shared by both loaders
    loader_kwargs= dict(
        batch_size= batch_size,
        collate_fn= collate,
        num_workers= num_workers,
        pin_memory= True,
    )

    # Create data loaders
    train_loader= DataLoader(train_dataset, shuffle= True, **loader_kwargs)
    val_loader= DataLoader(val_dataset, shuffle= False, **loader_kwargs)
    return train_loader, val_loader
    

In [None]:

# Build the loaders from the configured data / artifacts directories
train_loader, val_loader = create_data_loaders(
    data_dir,
    artifacts_dir,
    batch_size=32,
    max_seq_length=512,
)

# Pull a single training batch and report its structure
batch = next(iter(train_loader))
print(
    f"Batch keys: {batch.keys()}",
    f"English IDs shape: {batch['en_ids'].shape}",
    f"Arabic IDs shape: {batch['ar_ids'].shape}",
    f"English mask shape: {batch['en_mask'].shape}",
    f"Causal mask shape: {batch['causal_mask'].shape}",
    sep="\n",
)

# Show the raw text and token ids of one example from that batch
sample_idx = 0
print(
    f"\nSample English text: {batch['en_texts'][sample_idx]}",
    f"Sample Arabic text: {batch['ar_texts'][sample_idx]}",
    f"Sample English IDs: {batch['en_ids'][sample_idx]}",
    f"Sample Arabic IDs: {batch['ar_ids'][sample_idx]}",
    sep="\n",
)

Batch keys: dict_keys(['en_ids', 'ar_ids', 'en_mask', 'ar_mask', 'causal_mask', 'en_texts', 'ar_texts'])
English IDs shape: torch.Size([32, 87])
Arabic IDs shape: torch.Size([32, 89])
English mask shape: torch.Size([32, 1, 1, 87])
Causal mask shape: torch.Size([1, 1, 89, 89])

Sample English text: 5. The independent expert also met with Mariano Fernández, the Special Representative of the Secretary-General and Head of the United Nations Stabilization Mission in Haiti (MINUSTAH), and with his deputies, Kevin Kennedy and Nigel Fischer. He wishes to thank all the members of their team who provided him with effective support.
Sample Arabic text: 5- واجتمع الخبير المستقبل أيضاً بالممثل الخاص للأمين العام ورئيس بعثة الأمم المتحدة لتحقيق الاستقرار في هايتي (بعثة الأمم المتحدة في هايتي)، ماريانو فرناندز، ونائبيه، كيفين كيندي، ونايجل فيشر، ويود أن يتوجه بالشكر إلى جميع أعضاء فريقهما لما قدموه من دعم فعال.
Sample English IDs: tensor([   22,    15,   828,  3125,  3507,  1035,  2981,   843, 21796,