In [1]:
# Cell 1: Imports
import numpy as np
import re
import os
import csv
import pickle
from tqdm import tqdm
from typing import List, Any, Dict, Tuple

# PyTorch for CNN and Bi-LSTM
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Skip-gram (Word2Vec)
from gensim.models import Word2Vec

# PyTorch CRF
from torchcrf import CRF

In [2]:
# Cell 2: Paths configuration
model_path = "../models/BiLSTMCRFModel.pth"
skipgram_model_path = "../models/SkipgramWordEmbeddings.model"
cnn_model_path = "../models/CNNCharEncoder.pth"
cnn_cache_path = "../models/cnn_word_embeddings_cache.pkl"
input_path = "../input/test_no_diacritics.txt"
output_path = "../output/output_bilstm_crf.txt"

In [3]:
# Cell 3: Constants and hyperparameters

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Global registries
DATASET_REGISTRY: dict[str, Any] = {}
MODEL_REGISTRY: dict[str, Any] = {}

# Skip-gram hyperparameters
SKIPGRAM_EMBEDDING_DIM = 100
WINDOW_SIZE = 5
MIN_COUNT = 1
SG = 1  # 1 for Skip-gram, 0 for CBOW
WORKERS = 4
SKIPGRAM_EPOCHS = 10

# CNN hyperparameters
CNN_CHAR_EMBEDDING_DIM = 30
CNN_NUM_FILTERS = 50
CNN_KERNEL_SIZES = [2, 3, 4]  # n-gram sizes
CNN_OUTPUT_DIM = CNN_NUM_FILTERS * len(CNN_KERNEL_SIZES)  # 150
CNN_BATCH_SIZE = 512  # Batch size for CNN inference

# Bi-LSTM hyperparameters
BILSTM_INPUT_DIM = SKIPGRAM_EMBEDDING_DIM + CNN_OUTPUT_DIM  # 100 + 150 = 250
BILSTM_HIDDEN_DIM = 256
BILSTM_NUM_LAYERS = 2
BILSTM_DROPOUT = 0.3

# Training hyperparameters
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

# Data parameters
ARABIC_LETTERS = sorted(
    np.load('../data/utils/arabic_letters.pkl', allow_pickle=True))
DIACRITICS = sorted(np.load(
    '../data/utils/diacritics.pkl', allow_pickle=True))
PUNCTUATIONS = {".", "،", ":", "؛", "؟", "!", '"', "-"}

VALID_CHARS = set(ARABIC_LETTERS).union(
    set(DIACRITICS)).union(PUNCTUATIONS).union({" "})

CHAR2ID = {char: id for id, char in enumerate(ARABIC_LETTERS)}
CHAR2ID[" "] = len(ARABIC_LETTERS)
CHAR2ID["<PAD>"] = len(ARABIC_LETTERS) + 1
PAD = CHAR2ID["<PAD>"]
SPACE = CHAR2ID[" "]
ID2CHAR = {id: char for char, id in CHAR2ID.items()}
VOCAB_SIZE = len(CHAR2ID)

DIACRITIC2ID = np.load('../data/utils/diacritic2id.pkl', allow_pickle=True)
ID2DIACRITIC = {id: diacritic for diacritic, id in DIACRITIC2ID.items()}
NUM_TAGS = len(DIACRITIC2ID)
print(f"Number of diacritic classes: {NUM_TAGS}")

Using device: cuda
Number of diacritic classes: 15


In [4]:
# Cell 4: Registry functions

def register_dataset(name):
    def decorator(cls):
        DATASET_REGISTRY[name] = cls
        return cls
    return decorator


def generate_dataset(dataset_name: str, *args, **kwargs):
    try:
        dataset_cls = DATASET_REGISTRY[dataset_name]
    except KeyError:
        raise ValueError(f"Dataset '{dataset_name}' is not recognized.")
    return dataset_cls(*args, **kwargs)


def register_model(name):
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def generate_model(model_name: str, *args, **kwargs):
    try:
        model_cls = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Model '{model_name}' is not recognized.")
    return model_cls(*args, **kwargs)

In [5]:
# Cell 5: Dataset class

@register_dataset("ArabicBiLSTMDataset")
class ArabicBiLSTMDataset:
    def __init__(self, file_path: str, skipgram_model: Word2Vec = None):
        self.skipgram_model = skipgram_model
        self.sentences_with_diacritics = self.load_data(file_path)
        self.sentences_without_diacritics = self.extract_text_without_diacritics(
            self.sentences_with_diacritics)
        
        # Tokenize into words for Skip-gram
        self.tokenized_sentences = [sentence.split() for sentence in self.sentences_without_diacritics]
        
        # Extract diacritics per character
        self.diacritics_per_sentence = [
            self.extract_diacritics(sentence) 
            for sentence in self.sentences_with_diacritics
        ]
        
        # Collect all unique words for CNN batch processing
        self.unique_words = set()
        for sentence in self.tokenized_sentences:
            self.unique_words.update(sentence)
        self.unique_words = list(self.unique_words)

    def __len__(self):
        return len(self.sentences_without_diacritics)

    def __getitem__(self, idx):
        return self.tokenized_sentences[idx], self.diacritics_per_sentence[idx]

    def load_data(self, file_path: str):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    line = re.sub(
                        f'[^{re.escape("".join(VALID_CHARS))}]', '', line)
                    line = re.sub(r'\s+', ' ', line)
                    sentences = re.split(
                        f'[{re.escape("".join(PUNCTUATIONS))}]', line)
                    sentences = [s.strip() for s in sentences if s.strip()]
                    data.extend(sentences)
        return np.array(data)

    def extract_text_without_diacritics(self, dataY):
        dataX = dataY.copy()
        for diacritic, _ in DIACRITIC2ID.items():
            dataX = np.char.replace(dataX, diacritic, '')
        return dataX

    def extract_diacritics(self, sentence: str):
        """Extract diacritics for each character in the sentence."""
        result = []
        i = 0
        n = len(sentence)
        on_char = False

        while i < n:
            ch = sentence[i]
            if ch in DIACRITICS:
                on_char = False
                if i+1 < n and sentence[i+1] in DIACRITICS:
                    combined = ch + sentence[i+1]
                    if combined in DIACRITIC2ID:
                        result.append(DIACRITIC2ID[combined])
                        i += 2
                        continue
                result.append(DIACRITIC2ID[ch])
            elif ch in CHAR2ID:
                if on_char:
                    result.append(DIACRITIC2ID[''])
                on_char = True
            i += 1
        if on_char:
            result.append(DIACRITIC2ID[''])
        return result

    def get_corpus_for_skipgram(self):
        """Return tokenized sentences for Skip-gram training."""
        return self.tokenized_sentences

In [6]:
# Cell 6: CNN Character Encoder

class CNNCharEncoder(nn.Module):
    """
    CNN-based character encoder that produces a fixed-size embedding for each word
    based on its character sequence.
    """
    def __init__(self, vocab_size, char_embedding_dim, num_filters, kernel_sizes):
        super(CNNCharEncoder, self).__init__()
        
        self.char_embedding = nn.Embedding(vocab_size, char_embedding_dim, padding_idx=PAD)
        
        # Multiple CNN layers with different kernel sizes to capture different n-grams
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=char_embedding_dim, 
                     out_channels=num_filters, 
                     kernel_size=k,
                     padding=k//2)
            for k in kernel_sizes
        ])
        
        self.output_dim = num_filters * len(kernel_sizes)
    
    def forward(self, char_ids):
        """
        Args:
            char_ids: Tensor of shape (batch_size, max_word_len)
        Returns:
            word_embedding: Tensor of shape (batch_size, output_dim)
        """
        # Embed characters: (batch, max_word_len, char_embedding_dim)
        embedded = self.char_embedding(char_ids)
        
        # Transpose for Conv1d: (batch, char_embedding_dim, max_word_len)
        embedded = embedded.transpose(1, 2)
        
        # Apply each conv layer and max-pool
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded))  # (batch, num_filters, seq_len)
            pooled = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)  # (batch, num_filters)
            conv_outputs.append(pooled)
        
        # Concatenate all conv outputs
        word_embedding = torch.cat(conv_outputs, dim=1)  # (batch, output_dim)
        return word_embedding


def word_to_char_ids(word: str, max_len: int = 20) -> List[int]:
    """Convert a word to a list of character IDs, padded to max_len."""
    char_ids = [CHAR2ID.get(c, PAD) for c in word if c in CHAR2ID]
    # Pad or truncate
    if len(char_ids) < max_len:
        char_ids = char_ids + [PAD] * (max_len - len(char_ids))
    else:
        char_ids = char_ids[:max_len]
    return char_ids

In [7]:
# Cell 7: Batch CNN embedding computation with caching

def compute_cnn_embeddings_batch(cnn_model: CNNCharEncoder, words: List[str], 
                                  batch_size: int = CNN_BATCH_SIZE) -> Dict[str, np.ndarray]:
    """
    Compute CNN embeddings for all words in batch on GPU.
    Returns a dictionary mapping word -> embedding.
    """
    print(f"Computing CNN embeddings for {len(words)} unique words on {DEVICE}...")
    cnn_model.eval()
    cnn_model.to(DEVICE)
    
    word_to_embedding = {}
    
    # Process in batches
    for i in tqdm(range(0, len(words), batch_size), desc="CNN batch processing"):
        batch_words = words[i:i+batch_size]
        
        # Convert words to char IDs
        batch_char_ids = [word_to_char_ids(word) for word in batch_words]
        batch_tensor = torch.tensor(batch_char_ids, dtype=torch.long).to(DEVICE)
        
        # Forward pass on GPU
        with torch.no_grad():
            embeddings = cnn_model(batch_tensor).cpu().numpy()
        
        # Store in cache
        for word, emb in zip(batch_words, embeddings):
            word_to_embedding[word] = emb
    
    print(f"Computed {len(word_to_embedding)} CNN word embeddings")
    return word_to_embedding


def save_cnn_cache(cache: Dict[str, np.ndarray], path: str):
    """Save CNN embedding cache to disk."""
    with open(path, 'wb') as f:
        pickle.dump(cache, f)
    print(f"CNN cache saved to {path}")


def load_cnn_cache(path: str) -> Dict[str, np.ndarray]:
    """Load CNN embedding cache from disk."""
    with open(path, 'rb') as f:
        cache = pickle.load(f)
    print(f"Loaded CNN cache with {len(cache)} words")
    return cache

In [8]:
# Cell 8: Skip-gram training and embedding functions

def train_skipgram_model(corpus: List[List[str]], save_path: str) -> Word2Vec:
    """Train a Skip-gram Word2Vec model on the corpus."""
    print("Training Skip-gram model...")
    model = Word2Vec(
        sentences=corpus,
        vector_size=SKIPGRAM_EMBEDDING_DIM,
        window=WINDOW_SIZE,
        min_count=MIN_COUNT,
        sg=SG,
        workers=WORKERS,
        epochs=SKIPGRAM_EPOCHS
    )
    model.save(save_path)
    print(f"Skip-gram model saved to {save_path}")
    return model


def load_skipgram_model(path: str) -> Word2Vec:
    """Load a pre-trained Skip-gram model."""
    return Word2Vec.load(path)


# Pre-cache Skip-gram lookups for speed
SKIPGRAM_CACHE: Dict[str, np.ndarray] = {}

def get_skipgram_embedding(skipgram_model: Word2Vec, word: str) -> np.ndarray:
    """Get Skip-gram embedding with caching."""
    if word not in SKIPGRAM_CACHE:
        if word in skipgram_model.wv:
            SKIPGRAM_CACHE[word] = skipgram_model.wv[word]
        else:
            SKIPGRAM_CACHE[word] = np.zeros(SKIPGRAM_EMBEDDING_DIM)
    return SKIPGRAM_CACHE[word]

In [9]:
# Cell 9: Bi-LSTM + CRF Model

@register_model("BiLSTMCRFModel")
class BiLSTMCRFModel(nn.Module):
    """
    Bi-LSTM + CRF Model for Arabic Diacritization.
    
    Architecture:
    1. Input: Pre-computed embeddings (Skip-gram word + CNN char) per character
    2. Bi-LSTM: Captures bidirectional context across the sequence
    3. Linear: Projects LSTM hidden states to tag space
    4. CRF: Models transition probabilities between diacritic tags
    
    The Bi-LSTM makes predictions, and the CRF layer validates them by:
    - Learning which tag transitions are valid/likely
    - Enforcing global sequence-level constraints during decoding
    """
    def __init__(self, input_dim, hidden_dim, num_tags, num_layers=2, dropout=0.3):
        super(BiLSTMCRFModel, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.num_tags = num_tags
        
        # Optional: Project input embeddings
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        
        # Bi-LSTM layer
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        # Linear layer to project LSTM output to tag space
        # Bi-LSTM outputs hidden_dim * 2 (forward + backward)
        self.hidden2tag = nn.Linear(hidden_dim * 2, num_tags)
        
        # CRF layer for sequence-level tag decoding
        self.crf = CRF(num_tags, batch_first=True)
        
        self.dropout = nn.Dropout(dropout)
    
    def _get_lstm_features(self, x):
        """
        Get emission scores from Bi-LSTM.
        
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
        Returns:
            emissions: Tensor of shape (batch_size, seq_len, num_tags)
        """
        # Project input
        x = self.input_projection(x)
        x = self.dropout(x)
        
        # Bi-LSTM
        lstm_out, _ = self.lstm(x)
        lstm_out = self.dropout(lstm_out)
        
        # Project to tag space
        emissions = self.hidden2tag(lstm_out)
        
        return emissions
    
    def forward(self, x, tags, mask=None):
        """
        Compute CRF loss for training.
        
        Args:
            x: Input features (batch_size, seq_len, input_dim)
            tags: True tags (batch_size, seq_len)
            mask: Mask tensor (batch_size, seq_len), 1 for valid positions
        Returns:
            loss: Negative log-likelihood loss
        """
        emissions = self._get_lstm_features(x)
        
        # CRF computes negative log-likelihood
        # We negate it because CRF returns log-likelihood, and we want to minimize loss
        loss = -self.crf(emissions, tags, mask=mask, reduction='mean')
        
        return loss
    
    def decode(self, x, mask=None):
        """
        Decode the best tag sequence using Viterbi algorithm.
        
        Args:
            x: Input features (batch_size, seq_len, input_dim)
            mask: Mask tensor (batch_size, seq_len)
        Returns:
            best_tags: List of lists containing best tag sequences
        """
        emissions = self._get_lstm_features(x)
        
        # Viterbi decode
        best_tags = self.crf.decode(emissions, mask=mask)
        
        return best_tags

In [10]:
# Cell 10: Feature extraction for Bi-LSTM

def extract_sentence_features(skipgram_model: Word2Vec, cnn_cache: Dict[str, np.ndarray],
                               tokenized_sentence: List[str]) -> np.ndarray:
    """
    Extract combined Skip-gram + CNN features for each character in a sentence.
    
    For each character, we use the embedding of its parent word.
    Returns array of shape (num_chars, input_dim)
    """
    features = []
    
    for word in tokenized_sentence:
        # Get word-level embeddings
        sg_emb = get_skipgram_embedding(skipgram_model, word)
        cnn_emb = cnn_cache.get(word, np.zeros(CNN_OUTPUT_DIM))
        
        # Concatenate Skip-gram and CNN embeddings
        word_emb = np.concatenate([sg_emb, cnn_emb])
        
        # Assign same embedding to each character in the word
        for char in word:
            if char in CHAR2ID and char != ' ':
                features.append(word_emb)
    
    return np.array(features) if features else np.array([]).reshape(0, BILSTM_INPUT_DIM)

In [11]:
# Cell 11: Batch Collation Function

def collate_batch(batch_data: List[Tuple[np.ndarray, List[int]]]):
    """
    Collate function for DataLoader.
    Pads sequences to the same length within a batch.
    
    Args:
        batch_data: List of (features, tags) tuples
    Returns:
        features_padded: Tensor of shape (batch_size, max_seq_len, input_dim)
        tags_padded: Tensor of shape (batch_size, max_seq_len)
        mask: Tensor of shape (batch_size, max_seq_len)
    """
    features_list = [item[0] for item in batch_data]
    tags_list = [item[1] for item in batch_data]
    
    # Find max sequence length in batch
    max_len = max(len(f) for f in features_list)
    
    batch_size = len(features_list)
    
    # Pad features and tags
    features_padded = np.zeros((batch_size, max_len, BILSTM_INPUT_DIM))
    tags_padded = np.zeros((batch_size, max_len), dtype=np.int64)
    mask = np.zeros((batch_size, max_len), dtype=np.uint8)
    
    for i, (feat, tags) in enumerate(zip(features_list, tags_list)):
        seq_len = len(feat)
        features_padded[i, :seq_len, :] = feat
        tags_padded[i, :seq_len] = tags
        mask[i, :seq_len] = 1
    
    return (
        torch.tensor(features_padded, dtype=torch.float32),
        torch.tensor(tags_padded, dtype=torch.long),
        torch.tensor(mask, dtype=torch.bool)
    )

In [12]:
# Cell 12: Train function

def train(model: BiLSTMCRFModel, train_dataset, skipgram_model: Word2Vec,
          cnn_cache: Dict[str, np.ndarray], model_path: str):
    """
    Train the Bi-LSTM + CRF model.
    """
    print("Preparing training data...")
    
    # Extract features for all sentences
    training_data = []
    for idx in tqdm(range(len(train_dataset)), desc="Extracting features"):
        tokenized_sentence, diacritics = train_dataset[idx]
        features = extract_sentence_features(skipgram_model, cnn_cache, tokenized_sentence)
        
        # Ensure features and diacritics have the same length
        if len(features) == len(diacritics) and len(features) > 0:
            training_data.append((features, diacritics))
    
    print(f"Prepared {len(training_data)} valid training samples")
    
    # Create DataLoader
    train_loader = DataLoader(
        training_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_batch
    )
    
    # Setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    model.to(DEVICE)
    model.train()
    
    for epoch in range(NUM_EPOCHS):
        total_loss = 0
        num_batches = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for features, tags, mask in progress_bar:
            features = features.to(DEVICE)
            tags = tags.to(DEVICE)
            mask = mask.to(DEVICE)
            
            # Forward pass
            loss = model(features, tags, mask)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Average Loss: {avg_loss:.4f}")
    
    # Save model
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

In [13]:
# Cell 13: Evaluate function

def evaluate(model: BiLSTMCRFModel, val_dataset, skipgram_model: Word2Vec,
             cnn_cache: Dict[str, np.ndarray]):
    """
    Evaluate the Bi-LSTM + CRF model on validation data.
    """
    print("Preparing validation data...")
    
    # Extract features for all sentences
    val_data = []
    last_char_info = []  # Track which positions are last char of word
    
    for idx in tqdm(range(len(val_dataset)), desc="Extracting validation features"):
        tokenized_sentence, diacritics = val_dataset[idx]
        features = extract_sentence_features(skipgram_model, cnn_cache, tokenized_sentence)
        
        if len(features) == len(diacritics) and len(features) > 0:
            val_data.append((features, diacritics))
            
            # Track last character positions
            last_chars = []
            for word in tokenized_sentence:
                word_chars = [c for c in word if c in CHAR2ID and c != ' ']
                for i, _ in enumerate(word_chars):
                    last_chars.append(i == len(word_chars) - 1)
            last_char_info.append(last_chars)
    
    print(f"Prepared {len(val_data)} valid validation samples")
    
    # Create DataLoader
    val_loader = DataLoader(
        val_data,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_batch
    )
    
    model.to(DEVICE)
    model.eval()
    
    all_predictions = []
    all_targets = []
    all_masks = []
    
    print("Running predictions...")
    with torch.no_grad():
        for features, tags, mask in tqdm(val_loader, desc="Evaluating"):
            features = features.to(DEVICE)
            mask = mask.to(DEVICE)
            
            # Decode using CRF
            predictions = model.decode(features, mask)
            
            # Store results
            all_predictions.extend(predictions)
            all_targets.extend(tags.numpy().tolist())
            all_masks.extend(mask.cpu().numpy().tolist())
    
    # Calculate accuracies
    total_correct = 0
    total_tokens = 0
    total_correct_ending = 0
    total_tokens_ending = 0
    total_correct_without_ending = 0
    total_tokens_without_ending = 0
    
    sample_idx = 0
    for preds, targets, mask in zip(all_predictions, all_targets, all_masks):
        last_chars = last_char_info[sample_idx] if sample_idx < len(last_char_info) else []
        
        for i, (pred, target, m) in enumerate(zip(preds, targets, mask)):
            if m:  # Only count valid positions
                total_tokens += 1
                if pred == target:
                    total_correct += 1
                
                is_last_char = last_chars[i] if i < len(last_chars) else False
                if is_last_char:
                    total_tokens_ending += 1
                    if pred == target:
                        total_correct_ending += 1
                else:
                    total_tokens_without_ending += 1
                    if pred == target:
                        total_correct_without_ending += 1
        
        sample_idx += 1
    
    val_accuracy = (total_correct / total_tokens) * 100 if total_tokens > 0 else 0
    val_accuracy_ending = (total_correct_ending / total_tokens_ending) * 100 if total_tokens_ending > 0 else 0
    val_accuracy_without_ending = (total_correct_without_ending / total_tokens_without_ending) * 100 if total_tokens_without_ending > 0 else 0
    
    print(f"Validation Accuracy (Overall): {val_accuracy:.2f}%")
    print(f"Validation Accuracy (Without Last Character): {val_accuracy_without_ending:.2f}%")
    print(f"Validation Accuracy (Last Character): {val_accuracy_ending:.2f}%")

In [14]:
# Cell 14: Predict function

def predict(model: BiLSTMCRFModel, skipgram_model: Word2Vec,
            cnn_cache: Dict[str, np.ndarray], sentence: str) -> List[int]:
    """
    Predict diacritics for a single sentence.
    """
    # Clean sentence
    clean_sentence = sentence
    for diacritic in DIACRITICS:
        clean_sentence = clean_sentence.replace(diacritic, '')
    
    tokenized = clean_sentence.split()
    
    # Handle new words not in CNN cache
    for word in tokenized:
        if word not in cnn_cache:
            cnn_cache[word] = np.zeros(CNN_OUTPUT_DIM)
    
    features = extract_sentence_features(skipgram_model, cnn_cache, tokenized)
    
    if len(features) == 0:
        return []
    
    # Prepare input
    features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    mask = torch.ones(1, len(features), dtype=torch.bool).to(DEVICE)
    
    model.eval()
    with torch.no_grad():
        predictions = model.decode(features_tensor, mask)[0]
    
    return predictions

In [15]:
# Cell 15: Infer function

def infer(model: BiLSTMCRFModel, skipgram_model: Word2Vec,
          cnn_cache: Dict[str, np.ndarray], input_path: str, output_path: str):
    """
    Run inference on an input file and save results.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        input_data = f.readlines()
    
    output_list = []
    output_csv = [["ID", "Label"]]
    current_id = 0
    
    for sentence in tqdm(input_data, desc="Inference"):
        sentence = sentence.strip()
        if not sentence:
            output_list.append("")
            continue
        
        clean_sentence = sentence
        for diacritic in DIACRITICS:
            clean_sentence = clean_sentence.replace(diacritic, '')
        
        predictions = predict(model, skipgram_model, cnn_cache, clean_sentence)
        
        diacritized_sentence = ""
        pred_idx = 0
        for char in clean_sentence:
            diacritized_sentence += char
            if char in ARABIC_LETTERS and pred_idx < len(predictions):
                diacritic_id = predictions[pred_idx]
                diacritic = ID2DIACRITIC.get(diacritic_id, '')
                diacritized_sentence += diacritic
                output_csv.append([current_id, diacritic_id])
                current_id += 1
                pred_idx += 1
        
        output_list.append(diacritized_sentence)
    
    with open(output_path, 'w', encoding='utf-8') as f:
        for line in output_list:
            f.write(line + '\n')
    
    output_path_csv = os.path.splitext(output_path)[0] + ".csv"
    with open(output_path_csv, "w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerows(output_csv)
    
    print(f"Output saved to {output_path} and {output_path_csv}")

In [16]:
# Cell 16: Load training dataset
train_dataset = generate_dataset("ArabicBiLSTMDataset", "../data/train.txt")
print(f"Loaded {len(train_dataset)} training sentences")
print(f"Found {len(train_dataset.unique_words)} unique words")

Loaded 186315 training sentences
Found 105795 unique words


In [17]:
# Cell 17: Train or load Skip-gram model
if os.path.exists(skipgram_model_path):
    print("Loading existing Skip-gram model...")
    skipgram_model = load_skipgram_model(skipgram_model_path)
else:
    corpus = train_dataset.get_corpus_for_skipgram()
    skipgram_model = train_skipgram_model(corpus, skipgram_model_path)

Training Skip-gram model...
Skip-gram model saved to ../models/SkipgramWordEmbeddings.model


In [18]:
# Cell 18: Initialize CNN Character Encoder
cnn_model = CNNCharEncoder(
    vocab_size=VOCAB_SIZE,
    char_embedding_dim=CNN_CHAR_EMBEDDING_DIM,
    num_filters=CNN_NUM_FILTERS,
    kernel_sizes=CNN_KERNEL_SIZES
).to(DEVICE)

print(f"CNN Character Encoder initialized on {DEVICE}")
print(f"CNN output dimension: {cnn_model.output_dim}")

CNN Character Encoder initialized on cuda
CNN output dimension: 150


In [19]:
# Cell 19: Pre-compute CNN embeddings for all unique words

if os.path.exists(cnn_cache_path):
    print("Loading existing CNN cache...")
    cnn_cache = load_cnn_cache(cnn_cache_path)
else:
    cnn_cache = compute_cnn_embeddings_batch(cnn_model, train_dataset.unique_words)
    save_cnn_cache(cnn_cache, cnn_cache_path)

# Save CNN model weights
torch.save(cnn_model.state_dict(), cnn_model_path)
print(f"CNN model saved to {cnn_model_path}")

Computing CNN embeddings for 105795 unique words on cuda...


CNN batch processing: 100%|██████████| 207/207 [00:00<00:00, 451.26it/s]


Computed 105795 CNN word embeddings
CNN cache saved to ../models/cnn_word_embeddings_cache.pkl
CNN model saved to ../models/CNNCharEncoder.pth


In [20]:
# Cell 20: Create Bi-LSTM + CRF model
bilstm_crf_model = generate_model(
    "BiLSTMCRFModel",
    input_dim=BILSTM_INPUT_DIM,
    hidden_dim=BILSTM_HIDDEN_DIM,
    num_tags=NUM_TAGS,
    num_layers=BILSTM_NUM_LAYERS,
    dropout=BILSTM_DROPOUT
)

print(f"Bi-LSTM + CRF Model created")
print(f"Input dimension: {BILSTM_INPUT_DIM}")
print(f"Hidden dimension: {BILSTM_HIDDEN_DIM}")
print(f"Number of tags: {NUM_TAGS}")
print(bilstm_crf_model)

Bi-LSTM + CRF Model created
Input dimension: 250
Hidden dimension: 256
Number of tags: 15
BiLSTMCRFModel(
  (input_projection): Linear(in_features=250, out_features=256, bias=True)
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (hidden2tag): Linear(in_features=512, out_features=15, bias=True)
  (crf): CRF(num_tags=15)
  (dropout): Dropout(p=0.3, inplace=False)
)


In [21]:
# Cell 21: Train the model
train(bilstm_crf_model, train_dataset, skipgram_model, cnn_cache, model_path)

Preparing training data...


Extracting features: 100%|██████████| 186315/186315 [00:04<00:00, 39388.45it/s]


Prepared 20019 valid training samples


Epoch 1/10: 100%|██████████| 626/626 [00:04<00:00, 128.68it/s, loss=1.4810]


Epoch 1/10, Average Loss: 2.7939


Epoch 2/10: 100%|██████████| 626/626 [00:04<00:00, 143.43it/s, loss=1.5066]


Epoch 2/10, Average Loss: 1.6933


Epoch 3/10: 100%|██████████| 626/626 [00:04<00:00, 147.37it/s, loss=1.3519]


Epoch 3/10, Average Loss: 1.4253


Epoch 4/10: 100%|██████████| 626/626 [00:04<00:00, 146.42it/s, loss=0.5862]


Epoch 4/10, Average Loss: 1.2667


Epoch 5/10: 100%|██████████| 626/626 [00:04<00:00, 143.03it/s, loss=1.4435]


Epoch 5/10, Average Loss: 1.1443


Epoch 6/10: 100%|██████████| 626/626 [00:04<00:00, 143.13it/s, loss=0.7545]


Epoch 6/10, Average Loss: 1.0600


Epoch 7/10: 100%|██████████| 626/626 [00:04<00:00, 145.30it/s, loss=0.2166]


Epoch 7/10, Average Loss: 0.9869


Epoch 8/10: 100%|██████████| 626/626 [00:04<00:00, 138.62it/s, loss=0.4995]


Epoch 8/10, Average Loss: 0.9036


Epoch 9/10: 100%|██████████| 626/626 [00:04<00:00, 145.76it/s, loss=0.1705]


Epoch 9/10, Average Loss: 0.8622


Epoch 10/10: 100%|██████████| 626/626 [00:04<00:00, 141.55it/s, loss=0.5806]


Epoch 10/10, Average Loss: 0.8079
Model saved to ../models/BiLSTMCRFModel.pth


In [22]:
# Cell 22: Load validation dataset and update CNN cache
val_dataset = generate_dataset("ArabicBiLSTMDataset", "../data/val.txt")

# Add validation words to cache if not present
new_words = [w for w in val_dataset.unique_words if w not in cnn_cache]
if new_words:
    print(f"Computing CNN embeddings for {len(new_words)} new validation words...")
    new_cache = compute_cnn_embeddings_batch(cnn_model, new_words)
    cnn_cache.update(new_cache)

Computing CNN embeddings for 2319 new validation words...
Computing CNN embeddings for 2319 unique words on cuda...


CNN batch processing: 100%|██████████| 5/5 [00:00<00:00, 737.42it/s]

Computed 2319 CNN word embeddings





In [23]:
# Cell 23: Evaluate the model
evaluate(bilstm_crf_model, val_dataset, skipgram_model, cnn_cache)

Preparing validation data...


Extracting validation features: 100%|██████████| 9068/9068 [00:00<00:00, 33428.53it/s]


Prepared 964 valid validation samples
Running predictions...


Evaluating: 100%|██████████| 31/31 [00:00<00:00, 319.58it/s]

Validation Accuracy (Overall): 90.12%
Validation Accuracy (Without Last Character): 90.39%
Validation Accuracy (Last Character): 89.21%





In [25]:
# Cell 24: Run inference (optional)
infer(bilstm_crf_model, skipgram_model, cnn_cache, input_path, output_path)

Inference: 100%|██████████| 199/199 [00:02<00:00, 93.31it/s] 

Output saved to ../output/output_bilstm_crf.txt and ../output/output_bilstm_crf.csv



