In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import csv
import numpy as np
import re
from typing import List, Any

# PyTorch CRF
from torch.cuda.amp import autocast, GradScaler
from torchcrf import CRF

In [2]:
model_path = f"../models/ArabicBiLSTMCRFModel.pth"
input_path = f"../input/test_no_diacritics.txt"
output_path = f"../output/output_bilstm_crf_main.txt"

In [None]:

# Configurations
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global registries
DATASET_REGISTRY: dict[str, Any] = {}
MODEL_REGISTRY: dict[str, Any] = {}

# Model hyperparameters
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
BATCH_SIZE = 128
NUM_EPOCHS = 6
LEARNING_RATE = 0.001
NUM_LAYERS = 3
DROPOUT = 0.2

# Data parameters
ARABIC_LETTERS = sorted(
    np.load('../data/utils/arabic_letters.pkl', allow_pickle=True))
DIACRITICS = sorted(np.load(
    '../data/utils/diacritics.pkl', allow_pickle=True))
PUNCTUATIONS = {".", "،", ":", "؛", "؟", "!", '"', "-"}

VALID_CHARS = set(ARABIC_LETTERS).union(
    set(DIACRITICS)).union(PUNCTUATIONS).union({" "})

CHAR2ID = {char: id for id, char in enumerate(ARABIC_LETTERS)}
CHAR2ID[" "] = len(ARABIC_LETTERS)
CHAR2ID["<PAD>"] = len(ARABIC_LETTERS) + 1
PAD = CHAR2ID["<PAD>"]
SPACE = CHAR2ID[" "]
ID2CHAR = {id: char for char, id in CHAR2ID.items()}

DIACRITIC2ID = np.load('../data/utils/diacritic2id.pkl', allow_pickle=True)
ID2DIACRITIC = {id: diacritic for diacritic, id in DIACRITIC2ID.items()}

In [4]:

def register_dataset(name):
    def decorator(cls):
        DATASET_REGISTRY[name] = cls
        return cls
    return decorator


def generate_dataset(dataset_name: str, *args, **kwargs):
    try:
        dataset_cls = DATASET_REGISTRY[dataset_name]
    except KeyError:
        raise ValueError(f"Dataset '{dataset_name}' is not recognized.")
    return dataset_cls(*args, **kwargs)



def register_model(name):
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def generate_model(model_name: str, *args, **kwargs):
    try:
        model_cls = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Model '{model_name}' is not recognized.")
    return model_cls(*args, **kwargs)

In [5]:

@register_dataset("ArabicDataset")
class ArabicDataset(Dataset):
    def __init__(self, file_path: str):
        self.data_X, self.data_Y = self.generate_tensor_data(file_path)

    def __len__(self):
        return len(self.data_X)

    def __getitem__(self, idx):
        return self.data_X[idx], self.data_Y[idx]

    def generate_tensor_data(self, data_path: str):
        data_Y = self.load_data(data_path)
        data_X = self.extract_text_without_diacritics(data_Y)

        encoded_data_X, encoded_data_Y = self.encode_data(data_X, data_Y)
        data_X = torch.tensor(
            encoded_data_X, dtype=torch.int64)
        data_Y = torch.tensor(
            encoded_data_Y, dtype=torch.int64)

        return data_X, data_Y

    def load_data(self, file_path: str):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    line = re.sub(
                        f'[^{re.escape("".join(VALID_CHARS))}]', '', line)
                    line = re.sub(r'\s+', ' ', line)
                    sentences = re.split(
                        f'[{re.escape("".join(PUNCTUATIONS))}]', line)
                    sentences = [s.strip() for s in sentences if s.strip()]
                    data.extend(sentences)

        return np.array(data)

    def extract_text_without_diacritics(self, dataY):
        dataX = dataY.copy()
        for diacritic, _ in DIACRITIC2ID.items():
            dataX = np.char.replace(
                dataX, diacritic, '')
        return dataX

    def encode_data(self, dataX: List[str], dataY: List[str]):
        encoded_data_X = []
        for sentence in dataX:
            encoded_data_X.append([CHAR2ID[char]
                                   for char in sentence if char in CHAR2ID])
        encoded_data_Y = []
        for sentence in dataY:
            encoded_data_Y.append(self.extract_diacritics(sentence))

        max_sentence_len = max(len(sentence) for sentence in encoded_data_X)
        padded_dataX = np.full(
            (len(encoded_data_X), max_sentence_len), PAD, dtype=np.int64)
        for i, seq in enumerate(encoded_data_X):
            padded_dataX[i, :len(seq)] = seq

        padded_dataY = np.full(
            (len(encoded_data_Y), max_sentence_len), 0, dtype=np.int64)  # Use 0 instead of PAD for CRF
        for i, seq in enumerate(encoded_data_Y):
            padded_dataY[i, :len(seq)] = seq

        return padded_dataX, padded_dataY

    def extract_diacritics(self, sentence: str):
        result = []
        i = 0
        n = len(sentence)
        on_char = False

        while i < n:
            ch = sentence[i]
            if ch in DIACRITICS:
                on_char = False
                # check if next char forms a stacked diacritic
                if i+1 < n and sentence[i+1] in DIACRITICS:
                    combined = ch + sentence[i+1]
                    if combined in DIACRITIC2ID:
                        result.append(DIACRITIC2ID[combined])
                        i += 2
                        continue
                result.append(DIACRITIC2ID[ch])
            elif ch in CHAR2ID:
                if on_char:
                    result.append(DIACRITIC2ID[''])
                on_char = True
            i += 1
        if on_char:
            result.append(DIACRITIC2ID[''])
        return result

In [6]:

@register_model("BiLSTMCRFArabicModel")
class BiLSTMCRFArabicModel(nn.Module):
    """
    Bi-LSTM + CRF Model for Arabic Diacritization.
    Same as LSTMArabicModel but with CRF layer for sequence-level constraints.
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, PAD):
        super(BiLSTMCRFArabicModel, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.PAD = PAD
        
        # Character embedding layer
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=PAD)
        
        # Bi-LSTM layer
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            batch_first=True, bidirectional=True,
                            num_layers=NUM_LAYERS, dropout=DROPOUT)
        
        # Linear layer to project LSTM output to tag space
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        
        # CRF layer for sequence-level decoding
        self.crf = CRF(output_dim, batch_first=True)

    def _get_lstm_features(self, x):
        """
        Get emission scores from Bi-LSTM.
        Args:
            x: Input tensor (batch_size, seq_len)
        Returns:
            emissions: Tensor (batch_size, seq_len, output_dim)
        """
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)
        emissions = self.fc(lstm_out)
        return emissions

    def forward(self, x, tags=None, mask=None):
        """
        Forward pass.
        During training (tags provided): returns CRF loss
        During inference (tags=None): returns best tag sequence
        """
        emissions = self._get_lstm_features(x)
        
        if tags is not None:
            # Training mode: return negative log-likelihood loss
            if mask is None:
                mask = (x != self.PAD)
            loss = -self.crf(emissions, tags, mask=mask, reduction='mean')
            return loss
        else:
            # Inference mode: decode best sequence
            if mask is None:
                mask = (x != self.PAD)
            return self.crf.decode(emissions, mask=mask)

    def decode(self, x, mask=None):
        """Decode the best tag sequence using Viterbi algorithm."""
        emissions = self._get_lstm_features(x)
        if mask is None:
            mask = (x != self.PAD)
        return self.crf.decode(emissions, mask=mask)

In [7]:
def train(model: nn.Module, train_dataset: Dataset, val_dataset: Dataset, model_path: str, patience: int = 3):
    """
    Train the model with validation monitoring and early stopping.
    
    Args:
        model: The BiLSTM-CRF model
        train_dataset: Training dataset
        val_dataset: Validation dataset
        model_path: Path to save the best model
        patience: Number of epochs to wait for improvement before stopping (0 = disabled)
    """
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True
    )
    
    val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scaler = GradScaler()  # For mixed precision training

    if torch.cuda.is_available():
        model = model.cuda()

    best_val_acc = 0.0
    epochs_without_improvement = 0
    
    for epoch in range(NUM_EPOCHS):
        # Training phase
        total_loss = 0
        num_batches = 0

        model.train()
        for train_X, train_Y in tqdm(train_data_loader, desc=f"Training Epoch {epoch + 1}"):
            train_X = train_X.to(DEVICE)
            train_Y = train_Y.to(DEVICE)
            mask = (train_X != PAD)

            optimizer.zero_grad()

            # Mixed precision forward pass
            with autocast():
                loss = model(train_X, tags=train_Y, mask=mask)

            # Mixed precision backward pass
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        print(f'Epoch {epoch + 1} | Train Loss: {avg_loss:.4f}')
        
        # Validation phase
        model.eval()
        total_correct = 0
        total_tokens = 0
        
        with torch.no_grad():
            for val_X, val_Y in tqdm(val_data_loader, desc=f"Validation Epoch {epoch + 1}"):
                val_X = val_X.to(DEVICE)
                val_Y = val_Y.to(DEVICE)
                mask = (val_X != PAD)

                # Decode using CRF
                predictions_list = model.decode(val_X, mask=mask)
                
                # Convert to tensor
                batch_size, seq_len = val_X.shape
                prediction = torch.full_like(val_Y, 0)
                for i, pred_seq in enumerate(predictions_list):
                    prediction[i, :len(pred_seq)] = torch.tensor(pred_seq, device=DEVICE)

                padding_mask = (val_X == PAD)
                everything_mask = ~padding_mask

                total_correct += ((prediction == val_Y) & everything_mask).sum().item()
                total_tokens += everything_mask.sum().item()

        val_accuracy = (total_correct / total_tokens) * 100
        print(f'Epoch {epoch + 1} | Validation Accuracy: {val_accuracy:.2f}%')
        
        # Save best model
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            epochs_without_improvement = 0
            torch.save(model.state_dict(), model_path)
            print(f'  → New best model saved! (Accuracy: {val_accuracy:.2f}%)')
        else:
            epochs_without_improvement += 1
            print(f'  → No improvement ({epochs_without_improvement}/{patience})')
        
        # Early stopping
        if patience > 0 and epochs_without_improvement >= patience:
            print(f'\nEarly stopping triggered after {epoch + 1} epochs')
            print(f'Best validation accuracy: {best_val_acc:.2f}%')
            break
    
    print(f'\nTraining completed!')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')
    print(f'Model saved to {model_path}')

In [8]:
def evaluate(model: torch.nn.Module, val_dataset: torch.utils.data.Dataset):

    val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    if torch.cuda.is_available():
        model = model.cuda()

    total_correct_without_ending = 0
    total_tokens_without_ending = 0
    total_correct_ending = 0
    total_tokens_ending = 0
    total_correct = 0
    total_tokens = 0
    model.eval()
    with torch.no_grad():

        for val_X, val_Y in tqdm(val_data_loader):
            val_X = val_X.to(DEVICE)
            val_Y = val_Y.to(DEVICE)
            mask = (val_X != PAD)

            # Decode using CRF
            predictions_list = model.decode(val_X, mask=mask)
            
            # Convert list of predictions to tensor for comparison
            batch_size, seq_len = val_X.shape
            prediction = torch.full_like(val_Y, 0)
            for i, pred_seq in enumerate(predictions_list):
                prediction[i, :len(pred_seq)] = torch.tensor(pred_seq, device=DEVICE)

            padding_mask = (val_X == PAD)  # Identify padding from inputs, not labels
            shifted = torch.roll(val_X, shifts=-1, dims=1)
            end_of_word_mask = (shifted == SPACE) | (shifted == PAD)

            last_char_mask = end_of_word_mask & (~padding_mask)
            rest_of_word_mask = (~end_of_word_mask) & (~padding_mask)
            everything_mask = ~padding_mask

            total_correct_ending += ((prediction == val_Y)
                                     & last_char_mask).sum().item()
            total_tokens_ending += last_char_mask.sum().item()

            total_correct_without_ending += ((prediction == val_Y) &
                                             rest_of_word_mask).sum().item()
            total_tokens_without_ending += rest_of_word_mask.sum().item()

            total_correct += ((prediction == val_Y) &
                              everything_mask).sum().item()
            total_tokens += everything_mask.sum().item()

        val_accuracy = (total_correct / total_tokens) * 100
        val_accuracy_without_ending = (total_correct_without_ending /
                                       total_tokens_without_ending) * 100
        val_accuracy_ending = (total_correct_ending /
                               total_tokens_ending) * 100
        print(
            f"Validation Accuracy (Overall): {val_accuracy:.2f}%\n" +
            f"Validation Accuracy (Without Last Character): {val_accuracy_without_ending:.2f}%\n" +
            f"Validation Accuracy (Last Character): {val_accuracy_ending:.2f}%\n")

In [9]:
def predict(model, encoded_sentence):
    input_tensor = torch.tensor(
        [encoded_sentence], dtype=torch.int64).to(DEVICE)
    mask = (input_tensor != PAD)
    with torch.no_grad():
        predictions_list = model.decode(input_tensor, mask=mask)
    return np.array(predictions_list[0])

In [10]:
def infer(model, model_path, input_path, output_path, text_path=None):
    """
    Run inference on input data and generate diacritized output.

    Args:
        model: The model to use for inference
        model_path: Path to model weights
        input_path: Path to input file (CSV or TXT)
        output_path: Path to output file
        text_path: Path to text file (required for CSV input to get full context)
    """
    model_state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(model_state_dict)

    # Check if input is CSV format (submission format) or text format
    is_csv_input = input_path.endswith(".csv")

    if is_csv_input:
        # Read CSV with id,line_number,letter format (may have case_ending column)
        with open(input_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            rows = list(reader)

        # Build a set of IDs we need to output predictions for
        target_ids = {int(row["id"]) for row in rows}

        # Read the text file to get full sentences for context
        if text_path is None:
            text_path = os.path.join(
                os.path.dirname(input_path), "dataset_no_diacritics.txt"
            )

        with open(text_path, "r", encoding="utf-8") as f:
            input_lines = f.readlines()

        # Store predictions with their IDs
        output_csv = [["ID", "label"]]
        output_list = []
        current_id = 0  # Global character ID counter

        model.eval()
        for sentence in input_lines:
            encoded_sentence = [CHAR2ID[char] for char in sentence if char in CHAR2ID]

            if len(encoded_sentence) == 0:
                output_list.append("")
                continue

            predictions = predict(model, encoded_sentence)

            diacritized_sentence = ""
            pred_idx = 0
            for char in sentence:
                if char in CHAR2ID:
                    diacritic_id = int(predictions[pred_idx])
                    diacritic = ID2DIACRITIC[diacritic_id]
                    pred_idx += 1

                    # Only output for Arabic letters that are in target IDs
                    if char in ARABIC_LETTERS:
                        if current_id in target_ids:
                            output_csv.append([current_id, diacritic_id])
                        current_id += 1

                    diacritized_sentence += char + diacritic
                else:
                    diacritized_sentence += char

            output_list.append(diacritized_sentence.strip())
    else:
        # Original text file format
        with open(input_path, "r", encoding="utf-8") as f:
            input_data = f.readlines()

        output_list = []
        output_csv = [["ID", "label"]]
        current_id = 0

        model.eval()
        for sentence in input_data:
            encoded_sentence = [CHAR2ID[char] for char in sentence if char in CHAR2ID]

            if len(encoded_sentence) == 0:
                output_list.append("")
                continue

            predictions = predict(model, encoded_sentence)

            diacritized_sentence = ""
            pred_idx = 0
            for char in sentence:
                if char in CHAR2ID:
                    diacritic_id = int(predictions[pred_idx])
                    diacritic = ID2DIACRITIC[diacritic_id]
                    pred_idx += 1
                    if char in ARABIC_LETTERS:
                        output_csv.append([current_id, diacritic_id])
                        current_id += 1
                    diacritized_sentence += char + diacritic
                else:
                    diacritized_sentence += char

            output_list.append(diacritized_sentence.strip())

    # Write diacritized text output
    with open(output_path, "w", encoding="utf-8") as f:
        for line in output_list:
            f.write(line + "\n")

    # Write CSV submission output
    output_path_csv = os.path.splitext(output_path)[0] + ".csv"
    with open(output_path_csv, "w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerows(output_csv)

In [11]:
train_dataset = generate_dataset("ArabicDataset", "../data/train.txt")

In [12]:
val_dataset = generate_dataset("ArabicDataset", "../data/val.txt")

In [13]:
model = generate_model(
    model_name="BiLSTMCRFArabicModel",
    vocab_size=len(CHAR2ID),
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=len(DIACRITIC2ID),
    PAD=PAD
)

In [14]:
# Uncomment to load pre-trained model
model_state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [15]:
# Train with validation monitoring
# patience=3: Stop if no improvement for 3 epochs (set to 0 to disable)
# patience=0: Train for all NUM_EPOCHS without early stopping
train(model, train_dataset, val_dataset, model_path, patience=3)

  scaler = GradScaler()  # For mixed precision training
  with autocast():
Training Epoch 1: 100%|██████████| 1456/1456 [17:15<00:00,  1.41it/s]


Epoch 1 | Train Loss: 3.0516


Validation Epoch 1: 100%|██████████| 71/71 [00:14<00:00,  5.02it/s]


Epoch 1 | Validation Accuracy: 97.80%
  → New best model saved! (Accuracy: 97.80%)

Training completed!
Best validation accuracy: 97.80%
Model saved to ../models/ArabicBiLSTMCRFModel.pth


In [16]:
evaluate(model, val_dataset)

100%|██████████| 71/71 [00:14<00:00,  5.05it/s]

Validation Accuracy (Overall): 97.80%
Validation Accuracy (Without Last Character): 98.33%
Validation Accuracy (Last Character): 95.75%






In [18]:
infer(model, model_path, "../input/test_no_diacritics.csv", output_path)