In [None]:
import torch
from torch.utils.data import DataLoader
from torch import nn
import pandas as pd
import torch.optim as optim
from tqdm import tqdm
from farasa.pos import FarasaPOSTagger
import torch.nn.init as init
from DataSetClass import Parallel_Data
from Preprocessing import Preprocessor
from model import Encoder, Decoder, Seq2Seq

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(device)

# Preparation for training

In [None]:
# The three splits are built from their own pickle file but share the same
# Arabic and English vocabulary files.
_VOCAB_FILES = ("./arabic_tokens.json", "./english_tokens.json")

train_data = Parallel_Data("./preprocessed_train_data.pkl", *_VOCAB_FILES)
val_data = Parallel_Data("./preprocessed_val_data.pkl", *_VOCAB_FILES)
test_data = Parallel_Data("./preprocessed_test_data.pkl", *_VOCAB_FILES)

## Weight initialization (Xavier for input/linear weights, orthogonal for recurrent weights) to give the model a good starting point

In [None]:
def init_weights(module):
    """Initialise a single sub-module in place (intended for ``nn.Module.apply``).

    * ``nn.LSTM`` / ``nn.RNN`` / ``nn.GRU``: parameters whose name contains
      ``weight_ih`` get Xavier-uniform values, ``weight_hh`` get orthogonal
      values, and every ``bias`` is zeroed. For an LSTM, the second quarter of
      each bias vector (the forget-gate slice in PyTorch's i, f, g, o layout)
      is then set to 1.
    * ``nn.Linear``: Xavier-uniform weight and, when present, a zero bias.

    Modules of any other type are left untouched. Returns ``None``.
    """
    if isinstance(module, (nn.LSTM, nn.RNN, nn.GRU)):
        # Only LSTMs have a forget gate whose bias is bumped to 1.
        has_forget_gate = isinstance(module, nn.LSTM)
        for param_name, tensor in module.named_parameters():
            values = tensor.data
            if 'weight_ih' in param_name:
                init.xavier_uniform_(values)
            elif 'weight_hh' in param_name:
                init.orthogonal_(values)
            elif 'bias' in param_name:
                values.zero_()
                if has_forget_gate:
                    length = values.size(0)
                    values[length // 4:length // 2].fill_(1)
        return

    # Linear layers (used in attention and decoder output)
    if isinstance(module, nn.Linear):
        init.xavier_uniform_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()


## model parameters and data loading functions

In [None]:
# Vocabulary sizes determine the input dimension of each encoder and the
# output dimension of the decoder.
input_dim_arabic = len(train_data.arabic_tokens)
input_dim_postag = len(train_data.postags)
OUTPUT_DIM = len(train_data.english_tokens)

ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 256

N_LAYERS = 1


train_dataloader = DataLoader(train_data, 32, shuffle=True)
val_dataloader = DataLoader(val_data, 256, shuffle=False)
# Fix: this loader used to wrap `train_data`, so the final "test" BLEU was
# actually measured on the training set.
test_dataloader = DataLoader(test_data, 256, shuffle=False)

# Two encoders: one over Arabic tokens, one over POS tags (smaller embedding).
enc_arabic = Encoder(input_dim_arabic, ENC_EMB_DIM, HID_DIM, N_LAYERS, device)
enc_postag = Encoder(input_dim_postag, 128, HID_DIM, N_LAYERS, device)

dec = Decoder(
    output_dim=OUTPUT_DIM,
    emb_dim=DEC_EMB_DIM,
    hid_dim=HID_DIM,
    n_layers=N_LAYERS,
    # NOTE(review): presumably 2 encoders x 2 directions x HID_DIM - confirm
    # against the Encoder definition in model.py.
    enc_hid_dim=HID_DIM * 4,
    attn_dim=64
)

enc_arabic.apply(init_weights)
enc_postag.apply(init_weights)
dec.apply(init_weights)

model = Seq2Seq(enc_arabic, enc_postag, dec, device).to(device)


# Count how often each English token id occurs in the training targets; the
# counts drive the inverse-sqrt-frequency class weights below.
token_counts = torch.zeros(len(train_data.english_tokens), dtype=torch.long)


for _, trg_batch, _, _ in train_dataloader:
    trg = trg_batch.to("cpu")
    token_counts += torch.bincount(
        trg.flatten(), minlength=len(train_data.english_tokens)
    )

# Index 0 is ignored by the loss (ignore_index=0 below), so drop its count.
token_counts[0] = 0

# Fix: clamp counts to >= 1 instead of adding 1e-5. With the epsilon, a token
# that never occurs got a weight of 1/sqrt(1e-5) (~316), far above any seen
# token; that skewed the mean normalisation below and dominated the
# label-smoothing term of the loss.
weights = 1.0 / torch.sqrt(token_counts.float().clamp(min=1.0))
weights[0] = 0
# NOTE(review): the last vocabulary entry is zero-weighted as well; which
# token that is cannot be seen from this notebook - confirm it is intended.
weights[-1] = 0
weights = weights / weights.mean()

criterion = nn.CrossEntropyLoss(
    weight=weights.to(device),
    ignore_index=0,
    label_smoothing=0.1
)

optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

# attention plotter to monitor heat map changes

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

def plot_attention(attention, source_tokens, target_tokens, epoch=None, filename=None):
    """Draw an attention heat map and optionally save it to disk.

    Args:
        attention: 2-D array-like handed to ``Axes.matshow``; columns are
            labelled with ``source_tokens`` and rows with ``target_tokens``.
        source_tokens: Labels for the x axis (one tick per token).
        target_tokens: Labels for the y axis (one tick per token).
        epoch: Optional epoch number shown in the title. ``0`` is a valid
            epoch and is displayed.
        filename: When given, the figure is written to this path and then
            closed so repeated calls (one per epoch) do not accumulate open
            figures. Without a filename the figure is left open for the
            caller / notebook backend to display.
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    heatmap = ax.matshow(attention, cmap='viridis')
    fig.colorbar(heatmap)

    # One tick per token. Setting explicit tick positions already labels every
    # token; the former MultipleLocator(1) override was redundant and, on older
    # Matplotlib versions, added a tick at -1 that shifted every label by one.
    ax.set_xticks(np.arange(len(source_tokens)))
    ax.set_yticks(np.arange(len(target_tokens)))
    ax.set_xticklabels(source_tokens, rotation=90)
    ax.set_yticklabels(target_tokens)

    # `is not None` (not truthiness) so that epoch 0 still appears in the title.
    if epoch is not None:
        ax.set_title(f"Attention Weights (Epoch {epoch})")
    else:
        ax.set_title("Attention Weights")
    fig.tight_layout()

    if filename:
        fig.savefig(filename, bbox_inches='tight')
        # The figure is persisted on disk; release it to avoid leaking one
        # open figure per call.
        plt.close(fig)

## Train and Evaluation Functions

In [None]:
def train(model, dataloader, optimizer, criterion, clip):
    """Run one training epoch and return the mean loss per batch.

    Each batch is unpacked as ``(src, trg, src_length, postags)`` and moved to
    the notebook-level ``device``. The model is called with a teacher-forcing
    ratio of 0.5. The first position along dim 1 of both the model output and
    the target is dropped before the loss is computed. Gradients are clipped
    to a total norm of ``clip`` before every optimizer step.
    """
    model.train()
    running_loss = 0
    for batch in dataloader:
        src, trg, src_length, postags = batch
        src = src.to(device)
        trg = trg.to(device)
        postags = postags.to(device)
        src_length = src_length.to(device)

        optimizer.zero_grad()

        logits, _ = model(src, trg, src_length, postags, teacher_forcing_ratio=0.5)

        # Flatten every position after the first into one row per token.
        vocab_size = logits.shape[-1]
        flat_logits = logits[:, 1:].reshape(-1, vocab_size)
        flat_targets = trg[:, 1:].reshape(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += batch_loss.item()
    print("done")
    return running_loss / len(dataloader)

def evaluate(model, dataloader, criterion, epoch, return_attention, sample_index=0):
    """Compute the mean per-batch loss and optionally plot one attention map.

    The model runs in eval mode without teacher forcing and without gradient
    tracking. If the model returns attentions for the last batch, the
    attention of example ``sample_index`` in that batch is saved through
    ``plot_attention`` as ``attention_epoch_{epoch}.png``.

    Args:
        model: Called as ``model(src, trg, src_len, postags,
            teacher_forcing_ratio=0, return_attentions=return_attention)`` and
            expected to return ``(output, attentions)``.
        dataloader: Yields ``(src, trg, src_len, postags)`` batches.
        criterion: Loss applied to the flattened output/target (first position
            along dim 1 dropped).
        epoch: Used only for the plot title and file name.
        return_attention: Forwarded to the model.
        sample_index: Index, within the last batch, of the example to plot.

    Returns:
        Sum of batch losses divided by ``len(dataloader)``.

    Raises:
        ValueError: If the dataloader has no batches.

    Side effects:
        Switches the model back to training mode before returning.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("evaluate() received an empty dataloader")

    model.eval()
    epoch_loss = 0
    src = trg = attentions = None
    with torch.no_grad():
        for batch in dataloader:
            src, trg, src_len, postags = batch
            src_len = src_len.to(device)

            src, trg = src.to(device), trg.to(device)
            postags = postags.to(device)

            output, attentions = model(
                src, trg, src_len, postags,
                teacher_forcing_ratio=0,
                return_attentions=return_attention,
            )
            output_dim = output.shape[-1]

            output = output[:, 1:].reshape(-1, output_dim)
            trg_flat = trg[:, 1:].reshape(-1)

            epoch_loss += criterion(output, trg_flat).item()

    # From here on src / trg / attentions refer to the LAST batch.
    # Fix: always bind sample_attentions (it used to be undefined, raising a
    # NameError, whenever the model returned no attentions).
    sample_attentions = None
    if attentions is not None and len(attentions) > 0:
        # NOTE(review): assumes stacking + squeeze(1) yields
        # (steps, batch, src_len) - confirm against Seq2Seq in model.py.
        attn_matrix = torch.stack(attentions).squeeze(1)[:, sample_index, :]
        # Fix: move to CPU first; .numpy() fails on CUDA tensors.
        sample_attentions = attn_matrix.cpu().numpy()

    if sample_attentions is not None:
        inv_src_vocab = {i: w for w, i in train_data.arabic_tokens.items()}
        inv_trg_vocab = {i: w for w, i in train_data.english_tokens.items()}

        # Ids dropped from the axis labels: 0 plus the <s> / </s> markers.
        src_skip = {0, train_data.arabic_tokens["<s>"], train_data.arabic_tokens["</s>"]}
        trg_skip = {0, train_data.english_tokens["<s>"], train_data.english_tokens["</s>"]}

        src_tokens = [inv_src_vocab[idx] for idx in src[sample_index].tolist()
                      if idx not in src_skip]
        # NOTE(review): the y-axis labels come from the reference target, not
        # from the model's predictions.
        trg_tokens = [inv_trg_vocab[idx] for idx in trg[sample_index].tolist()
                      if idx not in trg_skip]

        plot_filename = f"attention_epoch_{epoch}.png"
        plot_attention(
            sample_attentions,
            source_tokens=src_tokens,
            target_tokens=trg_tokens,
            epoch=epoch,
            filename=plot_filename
        )
    model.train()

    return epoch_loss / num_batches

# Evaluation and Inference functions

In [None]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
import torch

# Make sure the NLTK 'punkt' data is available (download runs quietly).
nltk.download('punkt', quiet=True)

# Notebook-level BLEU smoothing function (NLTK's method4).
# compute_dataset_bleu creates its own local smoother, so it does not read this.
_smoothing = SmoothingFunction()
smooth_fn = _smoothing.method4

## Inference

In [None]:
def generate(model, src, src_length, postags, max_length, start_index, end_index):
    """Greedy batched decoding.

    Args:
        model: Object exposing ``encoder_arabic``, ``encoder_postag``,
            ``create_mask``, ``enc2dec`` and ``decoder`` as used below.
        src: Source token ids; dim 0 is the batch.
        src_length: Source lengths, passed to both encoders and ``create_mask``.
        postags: POS-tag ids passed to the POS-tag encoder.
        max_length: Width of the returned id matrix (including the start token).
        start_index: Id written to column 0 of every output row.
        end_index: Id that terminates a sequence; also used to pre-fill the
            output, so every position after a sequence ends holds this id.

    Returns:
        LongTensor of shape ``(batch_size, max_length)`` on ``src.device``.

    Note:
        Autograd is not disabled here; the visible caller wraps the call in
        ``torch.no_grad()``.
    """
    batch_size = src.size(0)
    device = src.device

    # Encode source sequences
    enc_outs_arabic, hidden_arabic, cell_arabic = model.encoder_arabic(src, src_length)
    enc_outs_postag, hidden_postag, cell_postag = model.encoder_postag(postags, src_length)

    # Combine encoder outputs along the feature dimension
    combined_enc_outs = torch.cat((enc_outs_arabic, enc_outs_postag), dim=2)
    max_src_len = combined_enc_outs.size(1)

    # Create mask from source lengths
    mask = model.create_mask(src_length, max_src_len)

    # Initialize decoder states from the concatenated encoder states
    hidden = model.enc2dec(torch.cat((hidden_arabic, hidden_postag), dim=2))
    cell = model.enc2dec(torch.cat((cell_arabic, cell_postag), dim=2))

    # Fix: the body referenced undefined names `eos_index` / `sos_index`
    # (NameError); use the `end_index` / `start_index` parameters instead.
    output_ids = torch.full((batch_size, max_length), end_index, dtype=torch.long, device=device)
    output_ids[:, 0] = start_index

    # Track finished sequences
    unfinished = torch.ones(batch_size, dtype=torch.bool, device=device)

    # Batch indices of the sequences that are still being decoded
    active_mask = torch.arange(batch_size, device=device)

    # Autoregressive generation
    for t in range(1, max_length):
        # Last predicted token of every active sequence
        prev_tokens = output_ids[active_mask, t - 1]

        # Run the decoder on the active sequences only
        decoder_output, hidden_step, cell_step, _ = model.decoder(
            input=prev_tokens,
            hidden=hidden[:, active_mask, :],
            cell=cell[:, active_mask, :],
            encoder_outputs=combined_enc_outs[active_mask],
            mask=mask[active_mask]
        )

        # Greedy token selection
        next_tokens = decoder_output.argmax(dim=-1)
        output_ids[active_mask, t] = next_tokens

        # Write the new states back into the full-batch state tensors
        hidden[:, active_mask, :] = hidden_step
        cell[:, active_mask, :] = cell_step

        # A sequence stays active until it emits the end token
        unfinished[active_mask] = (next_tokens != end_index)
        active_mask = torch.nonzero(unfinished, as_tuple=False).squeeze(-1)

        # Early termination if no active sequences remain
        if active_mask.nelement() == 0:
            break

    return output_ids

## BLEU accuracy metric

In [None]:
def _ids_to_words(token_ids, start_id, end_id, inv_vocab):
    """Convert ids to words, skipping ``start_id`` and stopping at the first ``end_id``."""
    words = []
    for tok_id in token_ids:
        if tok_id == start_id:
            continue
        if tok_id == end_id:
            break
        words.append(inv_vocab.get(tok_id, "<unk>"))
    return words


def compute_dataset_bleu(model, dataloader, english_tokens, device, max_length=50, weights=(0.25, 0.25, 0.25, 0.25)):
    """Corpus-level BLEU of greedy translations over a whole dataloader.

    Args:
        model: Model passed to ``generate``; it is switched to eval mode and
            left in that mode.
        dataloader: Yields ``(src, trg, src_len, postags)`` batches.
        english_tokens: Target vocabulary mapping token -> id; must contain
            ``"<s>"`` and ``"</s>"`` for those markers to be stripped.
        device: Device the inputs are moved to before generation.
        max_length: Maximum generated length handed to ``generate``.
        weights: n-gram weights forwarded to ``corpus_bleu``.

    Returns:
        The ``corpus_bleu`` score (NLTK smoothing ``method1``), or ``0.0`` when
        the dataloader yields no examples.
    """
    model.eval()
    references = []
    hypotheses = []

    # Get special token IDs
    start_id = english_tokens.get("<s>")
    end_id = english_tokens.get("</s>")

    # Inverse vocabulary for decoding (unknown ids fall back to "<unk>")
    inv_trg_vocab = {idx: token for token, idx in english_tokens.items()}

    # Smoothing function for BLEU
    smooth_fn = SmoothingFunction().method1

    with torch.no_grad():
        for src, trg, src_len, postags in dataloader:
            src, src_len = src.to(device), src_len.to(device)
            postags = postags.to(device)

            pred_ids = generate(
                model,
                src=src,
                src_length=src_len,
                postags=postags,
                max_length=max_length,
                start_index=start_id,
                end_index=end_id
            )

            # Process each example in the batch
            for i in range(pred_ids.size(0)):
                ref_words = _ids_to_words(trg[i].tolist(), start_id, end_id, inv_trg_vocab)
                # Fix: the hypothesis used to keep its leading "<s>" while the
                # reference dropped it, which penalised every prediction with
                # a guaranteed mismatching token. Both are now cleaned alike.
                hyp_words = _ids_to_words(pred_ids[i].tolist(), start_id, end_id, inv_trg_vocab)

                references.append([ref_words])  # Wrap in list for corpus_bleu
                hypotheses.append(hyp_words)

    # corpus_bleu cannot score an empty corpus; report 0 instead of failing.
    if not hypotheses:
        return 0.0

    # Compute corpus-level BLEU
    return corpus_bleu(
        list_of_references=references,
        hypotheses=hypotheses,
        weights=weights,
        smoothing_function=smooth_fn
    )

# Training loop

In [None]:
for epoch in tqdm(range(40), desc="Epochs"):

    train_loss = train(model, train_dataloader, optimizer, criterion, clip=1)
    val_loss = evaluate(model, val_dataloader, criterion, epoch, True)

    print(f"Epoch {epoch + 1:02}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
    bleu_score = compute_dataset_bleu(model, val_dataloader, train_data.english_tokens, device, val_data.max_length_english)
    print(bleu_score)
    # Fix: the scheduler was created but never stepped, so the learning rate
    # never decayed. ReduceLROnPlateau is built with its default mode ('min'),
    # so it must track a quantity that should decrease: the validation loss,
    # not the BLEU score the old commented-out call passed.
    scheduler.step(val_loss)

## testing sentences

In [None]:
def translate_sentence(sentence, src_vocab, postag_vocab, trg_vocab, model, device, max_len=50, postagger=None):
    """Greedily translate one raw sentence and return the predicted target ids.

    Args:
        sentence: Raw source sentence handed to the POS tagger.
        src_vocab: Source token -> id mapping; must contain ``"<s>"``,
            ``"</s>"`` and ``"<UNK>"``.
        postag_vocab: POS tag -> id mapping with the same three special keys.
        trg_vocab: Target token -> id mapping; must contain ``"<s>"`` and
            ``"</s>"``.
        model: Object exposing ``encoder_arabic``, ``encoder_postag``,
            ``enc2dec`` and ``decoder`` as used below.
        device: Device on which the input tensors are created.
        max_len: Maximum number of tokens generated after ``"<s>"``.
        postagger: Optional tagger exposing ``tag_segments(sentence)``. Pass a
            shared instance to avoid building a new ``FarasaPOSTagger`` on
            every call; when omitted, one is created as before.

    Returns:
        List of target ids starting with the ``"<s>"`` id and, if it was
        generated within ``max_len`` steps, ending with the ``"</s>"`` id.
    """
    # Tokenize and POS tag the sentence
    if postagger is None:
        postagger = FarasaPOSTagger()
    sequence = postagger.tag_segments(sentence)
    tokens = [item.tokens[0] for item in sequence]
    tags = [item.tags[0] for item in sequence]

    # Numericalize tokens and tags, wrapped in <s> ... </s>
    numericalized_tokens = (
        [src_vocab["<s>"]]
        + [src_vocab.get(token, src_vocab["<UNK>"]) for token in tokens]
        + [src_vocab["</s>"]]
    )
    numericalized_tags = (
        [postag_vocab["<s>"]]
        + [postag_vocab.get(tag, postag_vocab["<UNK>"]) for tag in tags]
        + [postag_vocab["</s>"]]
    )

    # Convert to tensors with a leading batch dimension of 1
    tensor_tokens = torch.tensor(numericalized_tokens).unsqueeze(0).to(device)  # [1, seq_len]
    tensor_tags = torch.tensor(numericalized_tags).unsqueeze(0).to(device)      # [1, seq_len]
    src_len = torch.tensor([len(numericalized_tokens)]).to(device)

    # Mask is all True since this is a single, non-padded sequence
    mask = torch.ones(1, len(numericalized_tokens), dtype=torch.bool, device=device)

    eos_id = trg_vocab["</s>"]
    trg_indexes = [trg_vocab["<s>"]]

    # Inference must run in eval mode (e.g. dropout disabled); the caller's
    # mode is restored afterwards so this has no lasting side effect.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Encode tokens and tags, then concatenate along the feature axis
            enc_outs_arabic, hidden_arabic, cell_arabic = model.encoder_arabic(tensor_tokens, src_len)
            enc_outs_postag, hidden_postag, cell_postag = model.encoder_postag(tensor_tags, src_len)
            combined_enc_outs = torch.cat((enc_outs_arabic, enc_outs_postag), dim=2)

            # Combine and project the final encoder states for the decoder
            hidden = model.enc2dec(torch.cat((hidden_arabic, hidden_postag), dim=2))
            cell = model.enc2dec(torch.cat((cell_arabic, cell_postag), dim=2))

            for _ in range(max_len):
                trg_tensor = torch.tensor([trg_indexes[-1]], dtype=torch.long, device=device)

                output, hidden, cell, _ = model.decoder(
                    trg_tensor,
                    hidden,
                    cell,
                    combined_enc_outs,
                    mask
                )

                pred_token = output.argmax(1).item()
                trg_indexes.append(pred_token)

                # Stop if EOS is generated
                if pred_token == eos_id:
                    break
    finally:
        model.train(was_training)

    return trg_indexes

In [None]:
# Translate one sample Arabic sentence with the trained model; `out` holds the
# predicted target-token ids.
sample_sentence = "تعرف على الحزمة الأكثر استخدامًا في جميع ملفات مصدر <ENG>"
out = translate_sentence(
    sample_sentence,
    train_data.arabic_tokens,
    train_data.postags,
    train_data.english_tokens,
    model,
    device,
)

In [None]:
# Map the predicted ids back to English tokens; the bare name on the last line
# makes the notebook display the result.
inv_trg_vocab = {index: word for word, index in train_data.english_tokens.items()}
translated_tokens = list(map(inv_trg_vocab.__getitem__, out))
translated_tokens

In [None]:
# Corpus BLEU over `test_dataloader`.
# NOTE(review): the generation length cap is taken from val_data, not
# test_data - confirm this is intended.
bleu_score = compute_dataset_bleu(
    model=model,
    dataloader=test_dataloader,
    english_tokens=train_data.english_tokens,
    device=device,
    max_length=val_data.max_length_english,
)