In [3]:
import pandas as pd
from Models.AutoModel import get_model
from Training.Trainer import Trainer
from Training.TrainingArguments import TrainingArguments
from Tokenizers.Tokenizers import Callable_tokenizer
from Models.ModelArgs import ModelArgs

from utils import MT_Dataset, MyCollate, compute_bleu, get_parameters_info
import torch

In [4]:
# Input locations for this run: the Arabic/English parallel-corpus splits,
# the tokenizer model file, and the JSON configs for model and training.
# All paths are relative, so they resolve against the kernel's working directory.
train_csv_path, valid_csv_path = (
    "out/data/ar-en_train.csv",
    "out/data/ar-en_valid.csv",
)
tokenizer_path = "out/tokenizers/ar-en_tokenizer.model"
model_config_path, training_config_path = (
    "Configurations/model_config.json",
    "Configurations/training_config.json",
)

In [9]:
print("---------------------Starting Tokenizer Loading...---------------------")
tokenizer = Callable_tokenizer(tokenizer_path)
# vocab_size is passed to get_model() when the model is built further down.
vocab_size = len(tokenizer)
print(f"Tokenizer length {vocab_size}", "Tokenizer Loading Done.", sep="\n")

---------------------Starting Tokenizer Loading...---------------------
Tokenizer length 8192
Tokenizer Loading Done.


In [6]:
print("---------------------Starting Data Loading...---------------------")

# Read every field as a literal string. With pandas' default NA handling,
# sentences such as "None", "NA", "null" or "nan" (and empty fields) are turned
# into float NaN, which the tokenizer presumably cannot encode as text.
csv_read_options = dict(dtype=str, keep_default_na=False)
train_df = pd.read_csv(train_csv_path, **csv_read_options)
valid_df = pd.read_csv(valid_csv_path, **csv_read_options)


def build_mt_dataset(split_df, src_col="ar", trg_col="en"):
    """Wrap one parallel-corpus split in an MT_Dataset.

    split_df: DataFrame with source sentences in `src_col` and target
        sentences in `trg_col` (Arabic -> English by default).
    Returns the MT_Dataset built with the notebook-level `tokenizer`.
    """
    return MT_Dataset(input_sentences_list=split_df[src_col].to_list(),
                      target_sentences_list=split_df[trg_col].to_list(),
                      callable_tokenizer=tokenizer)


train_ds = build_mt_dataset(train_df)
valid_ds = build_mt_dataset(valid_df)

# Batches are padded with the tokenizer's '<pad>' id.
mycollate = MyCollate(batch_first=True,
                      pad_value=tokenizer.get_tokenId('<pad>'))

print(f"Training data length {len(train_ds)}, Validation data length {len(valid_ds)}")
# Fetch the first example once instead of indexing the dataset three times
# (each __getitem__ call may re-tokenize the sentence pair).
first_example = train_ds[0]
print(f"Source tokens shape: {first_example[0].shape}, Target_fwd tokens shape {first_example[1].shape}, Target_loss tokens shape {first_example[2].shape}")
print("Data Loading Done.")

---------------------Starting Data Loading...---------------------
Training data length 591452, Validation data length 73932
Source tokens shape: torch.Size([5]), Target_fwd tokens shape torch.Size([10]), Target_loss tokens shape torch.Size([10])
Data Loading Done.


In [8]:
print("---------------------Parsing Model arguments...---------------------")
# Architecture hyper-parameters are loaded from the model JSON config.
model_args = ModelArgs(config_path=model_config_path)
print(model_args, "Parsing Done.", sep="\n")

---------------------Parsing Model arguments...---------------------
ModelArgs(
model_type=transformer,
dim_embed=512,
dim_model=512,
dim_feedforward=2048,
num_layers=6,
dropout=0.3,
maxlen=512,
flash_attention=False
)
Parsing Done.


In [11]:
print("---------------------Loading the model...---------------------")
model = get_model(model_args, vocab_size)

# get_parameters_info yields three parallel sequences: module names,
# trainable parameter counts and non-trainable parameter counts.
names, tr, nontr = get_parameters_info(model=model)
row_template = "{:<25}{:>15,}{:>15,}"
print("{:<25}{:>15}{:>15}".format("Module", "Trainable", "Non-Trainable"))
for module_name, trainable_count, frozen_count in zip(names, tr, nontr):
    print(row_template.format(module_name, trainable_count, frozen_count))
print("Model Loading Done.")

---------------------Loading the model...---------------------
Module                         Trainable  Non-Trainable
embed_shared_src_trg_cls       4,194,304              0
positonal_shared_src_trg         262,144              0
dropout                                0              0
transformer_encoder           18,914,304              0
transformer_decoder           25,224,192              0
classifier                     4,202,496              0
TotalParams                   52,797,440              0
Model Loading Done.


In [12]:
print("---------------------Parsing Training arguments...---------------------")
# Optimisation and runtime settings are loaded from the training JSON config.
training_args = TrainingArguments(training_config_path)
print(training_args, "Parsing Done.", sep="\n")

---------------------Parsing Training arguments...---------------------
TrainingArguments(
  save_models_dir='./out/models',
  save_plots_dir='./out/plots',
  learning_rate=0.0001,
  max_steps=500,
  seed=123,
  precision='high',
  device='cpu',
  batch_size=64,
  cpu_num_workers=4,
  weight_decay=0.01,
  onnx=False,
  run_name='experiment_01',
  pin_memory=True,
  warmup_steps=100,
  save_steps=100,
  eval_steps=100,
  torch_compile=False
)
Parsing Done.


In [44]:
print("---------------------Start training...---------------------")
trainer = Trainer(
    args=training_args,
    model=model,
    train_ds=train_ds,
    valid_ds=valid_ds,
    collator=mycollate,
    compute_metrics_func=compute_bleu,
)

# train() returns a pair, unpacked here as training / validation losses
# (presumably the logged loss history -- confirm in Training.Trainer).
train_losses, valid_losses = trainer.train()
print("Training Done.")

---------------------Start training...---------------------
Training Done.


In [78]:
# source, target_forward, target_loss = next(iter(trainer.train_loader))
# source.shape, target_forward.shape, target_loss.shape

(torch.Size([64, 40]), torch.Size([64, 47]), torch.Size([64, 47]))

In [81]:
# logits, loss =model.forward(source=source,
#               target_forward=target_forward,
#               target_loss=target_loss,
#               src_pad_tokenId=0, 
#               trg_pad_tokenId=0)
# candidates = torch.argmax(logits, dim=-1)

In [82]:
# def compute_bleu(references:torch.Tensor, candidates:torch.Tensor):
#     batch_size = candidates.size(0)
#     total_bleu = 0
#     smoothing = SmoothingFunction().method2  # Use smoothing to handle zero n-gram overlaps
#     for i in range(batch_size):
#         mask_i = references[i]!=0
#         candidate = candidates[i][mask_i].tolist()
#         references_one = [references[i][mask_i].tolist()]
#         bleu_score = sentence_bleu(references_one, candidate, weights=[0.33,0.33,0.33,0.0], smoothing_function=smoothing)
#         print(round(bleu_score, 4))
#         total_bleu += bleu_score
    
#     return  round(total_bleu / batch_size, 4)

In [84]:
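# # # Test Case 1: row 0 is a perfect match; row 1 differs in its last unpadded token (7 vs 8)
# # references = torch.tensor([[1, 2, 3, 4, 5, 6], [5, 6, 7, 7, 0, 0]])
# # candidates = torch.tensor([[1, 2, 3, 4, 5, 6], [5, 6, 7, 8, 9, 9]])

# bleu_score = compute_bleu(target_loss, candidates)
# print(f"Test Case 1 - BLEU Score: {bleu_score}")  # NOTE: scores model argmax predictions against target_loss, not the toy tensors above, so 1.0 is not expected here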

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0.0223
0
0
0
0
0
0
0
0
0
0
0
0
0
Test Case 1 - BLEU Score: 0.0003
