In [1]:
import sys
from pathlib import Path

project_root = Path.cwd().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print(f"Project root added to path: {project_root}")
print(f"Current working directory: {Path.cwd()}")

Project root added to path: /home/mohamed-ashraf/Desktop/projects/Arabic-Diacritization
Current working directory: /home/mohamed-ashraf/Desktop/projects/Arabic-Diacritization/models/mohamed_ashraf


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm
from utils.utils import create_data_pipeline
from models.mohamed_ashraf.bilstm3 import BiLSTM

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
with open(project_root / "utils/letter2idx.pickle", "rb") as file:
    letter2idx = pickle.load(file)

with open(project_root / "utils/diacritic2id.pickle", "rb") as file:
    diacritic2id = pickle.load(file)

idx2letter = {value: key for key, value in letter2idx.items()}
idx2diacritic = {value: key for key, value in diacritic2id.items()}

print(letter2idx)
print(idx2letter)
print(diacritic2id)
print(idx2diacritic)

{'ÿ∏': 0, 'Ÿä': 1, 'ÿ∫': 2, 'ŸÜ': 3, 'ŸÇ': 4, 'ÿ∞': 5, 'ÿØ': 6, 'ÿÆ': 7, 'ÿ±': 8, 'ÿ∑': 9, 'Ÿâ': 10, 'ŸÖ': 11, 'ŸÑ': 12, '<PAD>': 13, 'ÿ™': 14, 'ÿ¨': 15, 'ÿ¢': 16, 'ÿß': 17, 'ÿ≥': 18, 'ÿ¶': 19, 'ÿπ': 20, 'ŸÅ': 21, 'ÿµ': 22, 'Ÿá': 23, 'ÿ≤': 24, 'ŸÉ': 25, 'ÿ¥': 26, 'ÿ£': 27, 'Ÿà': 28, 'ÿ®': 29, 'ÿ§': 30, 'ÿ∂': 31, 'ÿ©': 32, 'ÿ´': 33, 'ÿ°': 34, 'ÿ≠': 35, 'ÿ•': 36, ' ': 37}
{0: 'ÿ∏', 1: 'Ÿä', 2: 'ÿ∫', 3: 'ŸÜ', 4: 'ŸÇ', 5: 'ÿ∞', 6: 'ÿØ', 7: 'ÿÆ', 8: 'ÿ±', 9: 'ÿ∑', 10: 'Ÿâ', 11: 'ŸÖ', 12: 'ŸÑ', 13: '<PAD>', 14: 'ÿ™', 15: 'ÿ¨', 16: 'ÿ¢', 17: 'ÿß', 18: 'ÿ≥', 19: 'ÿ¶', 20: 'ÿπ', 21: 'ŸÅ', 22: 'ÿµ', 23: 'Ÿá', 24: 'ÿ≤', 25: 'ŸÉ', 26: 'ÿ¥', 27: 'ÿ£', 28: 'Ÿà', 29: 'ÿ®', 30: 'ÿ§', 31: 'ÿ∂', 32: 'ÿ©', 33: 'ÿ´', 34: 'ÿ°', 35: 'ÿ≠', 36: 'ÿ•', 37: ' '}
{'Ÿé': 0, 'Ÿã': 1, 'Ÿè': 2, 'Ÿå': 3, 'Ÿê': 4, 'Ÿç': 5, 'Ÿí': 6, 'Ÿë': 7, 'ŸëŸé': 8, 'ŸëŸã': 9, 'ŸëŸè': 10, 'ŸëŸå': 11, 'ŸëŸê': 12, 'ŸëŸç': 13, '': 14, '<PAD>': 15}
{0: 'Ÿé', 1: 'Ÿã', 2: 'Ÿè', 3: 'Ÿå', 4: 'Ÿê', 5: 'Ÿç', 6: 'Ÿí', 7: 'Ÿë', 8: 'ŸëŸé', 9: 'Ÿë

In [5]:
vocab_size = len(letter2idx) 
num_classes = len(diacritic2id)
print("Vocab size:", vocab_size)
print("Num classes:", num_classes)

Vocab size: 38
Num classes: 16


In [6]:
def train_model(model, train_loader, val_loader, epochs=10, learning_rate=0.001):
    criterion = nn.CrossEntropyLoss(ignore_index=diacritic2id['<PAD>'])
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
    
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    
    best_val_loss = float('inf')
    best_model_path = 'best_bilstm_model.pth'

    for epoch in range(epochs):
        model.train()

        total_train_loss = 0
        total_train_correct = 0
        total_train_tokens = 0

        with tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]') as pbar:
            for batch_X, batch_y, _, lengths in pbar:
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                
                optimizer.zero_grad()

                outputs = model(batch_X, lengths)

                B, T, C = outputs.shape
                loss = criterion(outputs.view(B*T, C), batch_y.view(B*T))

                preds = outputs.argmax(dim=-1)
                mask = (batch_y != diacritic2id['<PAD>'])

                correct = (preds[mask] == batch_y[mask]).sum().item()
                total = mask.sum().item()

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_train_loss += loss.item()
                total_train_correct += correct
                total_train_tokens += total

                acc = correct / total if total > 0 else 0.0

                pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{acc:.4f}'
                })

        avg_train_loss = total_train_loss / len(train_loader)
        avg_train_acc = total_train_correct / total_train_tokens

        model.eval()
        total_val_loss = 0
        total_val_correct = 0
        total_val_tokens = 0

        with torch.no_grad():
            with tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]') as pbar:
                for batch_X, batch_y, _, lengths in pbar:
                    batch_X, batch_y = batch_X.to(device), batch_y.to(device)

                    outputs = model(batch_X, lengths)
                    B, T, C = outputs.shape

                    loss = criterion(outputs.view(B*T, C), batch_y.view(B*T))

                    preds = outputs.argmax(dim=-1)
                    mask = (batch_y != diacritic2id['<PAD>'])

                    correct = (preds[mask] == batch_y[mask]).sum().item()
                    total = mask.sum().item()

                    total_val_loss += loss.item()
                    total_val_correct += correct
                    total_val_tokens += total

                    acc = correct / total if total > 0 else 0.0

                    pbar.set_postfix({
                        'Loss': f'{loss.item():.4f}',
                        'Acc': f'{acc:.4f}'
                    })

        avg_val_loss = total_val_loss / len(val_loader)
        avg_val_acc = total_val_correct / total_val_tokens

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'val_accuracy': avg_val_acc,
                'train_loss': avg_train_loss,
                'train_accuracy': avg_train_acc
            }, best_model_path)
            print(f"  ‚Ü≥ Best model saved! (val_loss: {best_val_loss:.4f})")

        scheduler.step(avg_val_loss)

        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}')
        print(f'  LR: {optimizer.param_groups[0]["lr"]:.6f}')

        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_acc)
        val_losses.append(avg_val_loss)
        val_accuracies.append(avg_val_acc)

    return {
        'train_loss': train_losses,
        'train_accuracy': train_accuracies,
        'val_loss': val_losses,
        'val_accuracy': val_accuracies
    }

In [7]:
def pad_collate_fn(batch):
    x_batch, y_batch, mask_batch = zip(*batch)
    lengths_x = [len(x) for x in x_batch]
    x_padded = torch.nn.utils.rnn.pad_sequence(x_batch, batch_first=True, padding_value=letter2idx['<PAD>'])
    y_padded = torch.nn.utils.rnn.pad_sequence(y_batch, batch_first=True, padding_value=diacritic2id['<PAD>'])
    mask_spadded = torch.nn.utils.rnn.pad_sequence(mask_batch, batch_first=True, padding_value=0)
    return x_padded, y_padded, mask_spadded, torch.tensor(lengths_x, dtype=torch.long)

In [8]:
train_dataset, train_loader = create_data_pipeline(
    corpus_path=str(project_root / 'data/train.txt'), 
    letter2idx=letter2idx, 
    diacritic2idx=diacritic2id, 
    collate_fn=pad_collate_fn,
    batch_size=32
)

val_dataset, val_loader = create_data_pipeline(
    corpus_path=str(project_root / 'data/val.txt'), 
    letter2idx=letter2idx, 
    diacritic2idx=diacritic2id,
    collate_fn=pad_collate_fn,
    train=False, 
    batch_size=32
)

In [9]:
def print_data_statistics(train_dataset, val_dataset, train_loader, val_loader, letter2idx, diacritic2id):
    print("="*80)
    print(" " * 25 + "üìä DATA STATISTICS")
    print("="*80)
    
    train_lengths = [len(x) for x, _, _ in train_dataset]
    val_lengths = [len(x) for x, _, _ in val_dataset]
    
    print(f"\n{'Dataset Sizes:':<30}")
    print(f"  {'Training samples:':<28} {len(train_dataset):>10,}")
    print(f"  {'Validation samples:':<28} {len(val_dataset):>10,}")
    print(f"  {'Total samples:':<28} {len(train_dataset) + len(val_dataset):>10,}")
    print(f"  {'Train/Val ratio:':<28} {len(train_dataset)/len(val_dataset):>10.2f}")
    
    print(f"\n{'Batch Information:':<30}")
    print(f"  {'Training batches:':<28} {len(train_loader):>10,}")
    print(f"  {'Validation batches:':<28} {len(val_loader):>10,}")
    print(f"  {'Batch size:':<28} {train_loader.batch_size:>10}")
    
    print(f"\n{'Training Sequence Lengths:':<30}")
    print(f"  {'Min length:':<28} {min(train_lengths):>10}")
    print(f"  {'Max length:':<28} {max(train_lengths):>10}")
    print(f"  {'Mean length:':<28} {np.mean(train_lengths):>10.2f}")
    print(f"  {'Median length:':<28} {np.median(train_lengths):>10.0f}")
    print(f"  {'Std deviation:':<28} {np.std(train_lengths):>10.2f}")
    print(f"  {'25th percentile:':<28} {np.percentile(train_lengths, 25):>10.0f}")
    print(f"  {'75th percentile:':<28} {np.percentile(train_lengths, 75):>10.0f}")
    print(f"  {'95th percentile:':<28} {np.percentile(train_lengths, 95):>10.0f}")
    print(f"  {'99th percentile:':<28} {np.percentile(train_lengths, 99):>10.0f}")
    
    print(f"\n{'Validation Sequence Lengths:':<30}")
    print(f"  {'Min length:':<28} {min(val_lengths):>10}")
    print(f"  {'Max length:':<28} {max(val_lengths):>10}")
    print(f"  {'Mean length:':<28} {np.mean(val_lengths):>10.2f}")
    print(f"  {'Median length:':<28} {np.median(val_lengths):>10.0f}")
    print(f"  {'Std deviation:':<28} {np.std(val_lengths):>10.2f}")
    
    print(f"\n{'Vocabulary Information:':<30}")
    print(f"  {'Vocabulary size:':<28} {len(letter2idx):>10}")
    print(f"  {'Number of diacritics:':<28} {len(diacritic2id):>10}")
    print(f"  {'Special tokens:':<28} {'<PAD>, <UNK>':>10}")
    
    total_train_chars = sum(train_lengths)
    total_val_chars = sum(val_lengths)
    print(f"\n{'Total Characters:':<30}")
    print(f"  {'Training characters:':<28} {total_train_chars:>10,}")
    print(f"  {'Validation characters:':<28} {total_val_chars:>10,}")
    print(f"  {'Total characters:':<28} {total_train_chars + total_val_chars:>10,}")
    
    print(f"\n{'Memory Estimates (approx):':<30}")
    avg_train_len = np.mean(train_lengths)
    avg_val_len = np.mean(val_lengths)
    train_batch_mem = train_loader.batch_size * max(train_lengths) * 4 / (1024**2)  # 4 bytes per int
    val_batch_mem = val_loader.batch_size * max(val_lengths) * 4 / (1024**2)
    print(f"  {'Max train batch (input):':<28} {train_batch_mem:>9.2f} MB")
    print(f"  {'Max val batch (input):':<28} {val_batch_mem:>9.2f} MB")
    
    print(f"\n{'Sequence Length Distribution:':<30}")
    bins = [0, 50, 100, 200, 300, 400, 512]
    print(f"  {'Range':<15} {'Train':<15} {'Val':<15}")
    for i in range(len(bins) - 1):
        train_count = sum(1 for l in train_lengths if bins[i] < l <= bins[i+1])
        val_count = sum(1 for l in val_lengths if bins[i] < l <= bins[i+1])
        train_pct = (train_count / len(train_lengths)) * 100
        val_pct = (val_count / len(val_lengths)) * 100
        print(f"  {f'{bins[i]}-{bins[i+1]}':<15} {f'{train_count:,} ({train_pct:.1f}%)':<15} {f'{val_count:,} ({val_pct:.1f}%)':<15}")
    
    print("\n" + "="*80)
    print("‚úì Data statistics computed successfully!")
    print("="*80 + "\n")

In [10]:
print_data_statistics(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    train_loader=train_loader,
    val_loader=val_loader,
    letter2idx=letter2idx,
    diacritic2id=diacritic2id
)

                         üìä DATA STATISTICS

Dataset Sizes:                
  Training samples:               183,466
  Validation samples:               8,921
  Total samples:                  192,387
  Train/Val ratio:                  20.57

Batch Information:            
  Training batches:                 5,734
  Validation batches:                 279
  Batch size:                          32

Training Sequence Lengths:    
  Min length:                           3
  Max length:                         512
  Mean length:                      54.89
  Median length:                       33
  Std deviation:                    69.46
  25th percentile:                     16
  75th percentile:                     65
  95th percentile:                    178
  99th percentile:                    385

Validation Sequence Lengths:  
  Min length:                           3
  Max length:                         512
  Mean length:                      55.03
  Median length:            

In [11]:
model = BiLSTM(vocab_size=vocab_size, num_classes=num_classes).to(device)

print("Model architecture:")
print(model)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Model architecture:
BiLSTM(
  (embedding): Embedding(38, 256, padding_idx=13)
  (bilstm): LSTM(256, 256, num_layers=3, batch_first=True, dropout=0.2, bidirectional=True)
  (emb_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (lstm_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc1_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (fc2): Linear(in_features=256, out_features=16, bias=True)
)

Total parameters: 4,353,808
Trainable parameters: 4,353,808


In [12]:
print("\\nModel Summary:")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

history = train_model(model, train_loader, val_loader, epochs=10)

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_accuracy'], label='Training Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history['train_loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()

torch.save({
    'model_state_dict': model.state_dict(),
    'vocab_size': vocab_size,
    'num_classes': num_classes,
    'history': history
}, 'bilstm_model.pth')

print("Model saved successfully as 'bilstm_model.pth'!")

\nModel Summary:
Total parameters: 4,353,808


Epoch 1/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5734/5734 [04:00<00:00, 23.83it/s, Loss=0.0904, Acc=0.9683]
Epoch 1/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 279/279 [00:03<00:00, 77.40it/s, Loss=0.1203, Acc=0.9641]


  ‚Ü≥ Best model saved! (val_loss: 0.0908)
Epoch 1/10:
  Train Loss: 0.1495, Train Acc: 0.9497
  Val Loss: 0.0908, Val Acc: 0.9699
  LR: 0.001000


Epoch 2/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5734/5734 [04:03<00:00, 23.59it/s, Loss=0.0772, Acc=0.9685]
Epoch 2/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 279/279 [00:03<00:00, 76.76it/s, Loss=0.1325, Acc=0.9567]


  ‚Ü≥ Best model saved! (val_loss: 0.0770)
Epoch 2/10:
  Train Loss: 0.0871, Train Acc: 0.9709
  Val Loss: 0.0770, Val Acc: 0.9749
  LR: 0.001000


Epoch 3/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5734/5734 [03:59<00:00, 23.91it/s, Loss=0.1232, Acc=0.9582]
Epoch 3/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 279/279 [00:03<00:00, 79.30it/s, Loss=0.1103, Acc=0.9622]


  ‚Ü≥ Best model saved! (val_loss: 0.0718)
Epoch 3/10:
  Train Loss: 0.0760, Train Acc: 0.9746
  Val Loss: 0.0718, Val Acc: 0.9770
  LR: 0.001000


Epoch 4/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5734/5734 [03:59<00:00, 23.92it/s, Loss=0.0397, Acc=0.9917]
Epoch 4/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 279/279 [00:03<00:00, 76.93it/s, Loss=0.1097, Acc=0.9687]


  ‚Ü≥ Best model saved! (val_loss: 0.0690)
Epoch 4/10:
  Train Loss: 0.0699, Train Acc: 0.9766
  Val Loss: 0.0690, Val Acc: 0.9779
  LR: 0.001000


Epoch 5/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5734/5734 [04:03<00:00, 23.56it/s, Loss=0.0690, Acc=0.9737]
Epoch 5/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 279/279 [00:03<00:00, 74.32it/s, Loss=0.1039, Acc=0.9668]


  ‚Ü≥ Best model saved! (val_loss: 0.0669)
Epoch 5/10:
  Train Loss: 0.0661, Train Acc: 0.9779
  Val Loss: 0.0669, Val Acc: 0.9785
  LR: 0.001000


Epoch 6/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5734/5734 [04:02<00:00, 23.61it/s, Loss=0.0620, Acc=0.9851]
Epoch 6/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 279/279 [00:03<00:00, 78.17it/s, Loss=0.0975, Acc=0.9696]


  ‚Ü≥ Best model saved! (val_loss: 0.0664)
Epoch 6/10:
  Train Loss: 0.0632, Train Acc: 0.9788
  Val Loss: 0.0664, Val Acc: 0.9791
  LR: 0.001000


Epoch 7/10 [Train]:  91%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 5224/5734 [03:42<00:21, 23.43it/s, Loss=0.0351, Acc=0.9869]


KeyboardInterrupt: 