In [1]:
import torch
from torch.utils.data import DataLoader
from torch import nn
import pandas as pd
import torch.optim as optim
from tqdm import tqdm
from farasa.pos import FarasaPOSTagger

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
from DataSetClass import Parallel_Data
from Preprocessing import Preprocessor
from model import Encoder, Decoder, Seq2Seq

In [9]:
train_data = Parallel_Data("./preprocessed_train_data.pkl","./arabic_tokens.json","./english_tokens.json")
val_data = Parallel_Data("./preprocessed_val_data.pkl","./arabic_tokens.json","./english_tokens.json")
test_data = Parallel_Data("./preprocessed_test_data.pkl","./arabic_tokens.json","./english_tokens.json")

In [31]:
input_dim_arabic = len(train_data.arabic_tokens)
input_dim_postag = len(train_data.postags)
OUTPUT_DIM = len(train_data.english_tokens)
ENC_EMB_DIM = 100
DEC_EMB_DIM = 100
HID_DIM = 1024
N_LAYERS = 2
train_dataloader = DataLoader(train_data,16,shuffle=True)
val_dataloader = DataLoader(val_data,256,shuffle=True)
#test_dataloader = DataLoader(train_data,256,shuffle=True)

enc_arabic = Encoder(input_dim_arabic, ENC_EMB_DIM, HID_DIM, N_LAYERS,device)
enc_postag = Encoder(input_dim_postag, ENC_EMB_DIM, HID_DIM, N_LAYERS,device)

dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS)
model = Seq2Seq(enc_arabic, enc_postag, dec, device).to(device)

# Initialize token frequency counts
token_counts = torch.zeros(len(train_data.english_tokens), dtype=torch.long)

# Accumulate token frequencies from target batches
for _,trg_batch,_,_ in train_dataloader:
    trg = trg_batch.to("cpu")  # Ensure on CPU for bincount
    token_counts += torch.bincount(
        trg.flatten(), minlength=len(train_data.english_tokens)
    )

# Avoid division by zero for padding (index 0)
token_counts[0] = 0

# Compute inverse square root frequency weights
weights = 1.0 / torch.sqrt(token_counts.float() + 1e-5)
weights[0] = 0  # Padding should not contribute to the loss
# UNK token weight
weights[-1] = 0
# Normalize weights to keep loss scale reasonable
weights = weights / weights.mean()

# Define loss function with weighting
criterion = nn.CrossEntropyLoss(
    weight=weights.to(device),
    ignore_index=0,       # Padding index
    label_smoothing=0.1
)

optimizer = optim.Adam(model.parameters(),lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,10)

In [32]:
def train(model, dataloader, optimizer, criterion, clip):
    model.train()  
    epoch_loss = 0
    for src, trg, src_length , postags in dataloader:
        src, trg = src.to(device), trg.to(device)
        postags = postags.to(device)
        optimizer.zero_grad()
        
        output = model(src, trg,src_length, postags, teacher_forcing_ratio = 0.6)

        output_dim = output.shape[-1]

        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()  
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()

        epoch_loss += loss.item()
    
    return epoch_loss / len(dataloader)

def evaluate(model, dataloader, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for src, trg, src_len, postags in dataloader:
            src, trg = src.to(device), trg.to(device)
            postags = postags.to(device)

            output = model(src, trg, src_len, postags, 0)
            output_dim = output.shape[-1]

            output = output[:, 1:].reshape(-1, output_dim)
            trg = trg[:, 1:].reshape(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()
            
    model.train()
    return epoch_loss / len(dataloader)

In [None]:
for epoch in tqdm(range(40), desc="Epochs"):

    train_loss = train(model, val_dataloader, optimizer, criterion, clip=1)
    val_loss = evaluate(model, val_dataloader, criterion)

    print(f"Epoch {epoch + 1:02}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
    
    scheduler.step()

Epochs:   0%|          | 0/40 [00:00<?, ?it/s]

In [19]:
def translate_sentence(sentence, src_vocab, postag_vocab, trg_vocab, model, device, max_len=50):
    postagger = FarasaPOSTagger()
    sequence = postagger.tag_segments(sentence)
    tokens = [item.tokens[0] for item in sequence]
    tags = [item.tags[0] for item in sequence ]

    numericalized_tokens = (
        [src_vocab["<s>"]]
        + [src_vocab.get(token, src_vocab["<UNK>"]) for token in tokens]
        + [src_vocab["</s>"]]
    )
    numericalized_tags = (
        [postag_vocab["<s>"]]
        + [src_vocab.get(tag, postag_vocab["<UNK>"]) for tag in tags]
        + [postag_vocab["</s>"]]
    )
    tensor_tokens = torch.tensor(numericalized_tokens).unsqueeze(0).to(device)  # shape: [1, seq_len]
    tensor_tags   = torch.tensor(numericalized_tags).unsqueeze(0).to(device)  # shape: [1, seq_len]

    srclen = torch.tensor([len(numericalized_tokens)],dtype=torch.int64).to(device)

    # Encode the source sentence
    with torch.no_grad():
        hidden_tokens, cell_tokens = model.encoder_arabic(tensor_tokens,srclen)
        hidden_tags, cell_tags = model.encoder_postag(tensor_tags,srclen)
        hidden = model.enc2dec(torch.cat((hidden_tokens, hidden_tags), dim=2))
        cell = model.enc2dec(torch.cat((cell_tokens, cell_tags), dim=2))
        
    # Initialize the decoder input with <SOS>
    trg_indexes = [trg_vocab["<s>"]]
    
    for _ in range(max_len):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
        with torch.no_grad():
            output, hidden, cell = model.decoder(trg_tensor, hidden, cell)

        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)
        if pred_token == trg_vocab["</s>"]:
            break

    return trg_indexes[1:-1]

In [28]:
out = translate_sentence("ما هو متوسط ساعات اليوم ؟",train_data.arabic_tokens,train_data.postags\
                         ,train_data.english_tokens,model,device)

In [30]:
inv_trg_vocab = {i: w for w, i in train_data.english_tokens.items() }
translated_tokens = [inv_trg_vocab[idx] for idx in out]
translated_tokens

['what',
 'are',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the',
 'of',
 'the']