In [1]:
# Cell 1: Imports
import numpy as np
import re
import os
import csv
import pickle
from tqdm import tqdm
from typing import List, Any, Dict, Tuple

# Skip-gram (Word2Vec)
from gensim.models import Word2Vec

# CRF
import sklearn_crfsuite
from sklearn_crfsuite import metrics

In [2]:
# Cell 2: Paths configuration
model_path = "../models/SkipgramCRFModel.pkl"
skipgram_model_path = "../models/SkipgramWordEmbeddings.model"
input_path = "../input/test_no_diacritics.txt"
output_path = "../output/output_crf.txt"

In [3]:
# Cell 3: Constants and hyperparameters

# Global registries
DATASET_REGISTRY: dict[str, Any] = {}
MODEL_REGISTRY: dict[str, Any] = {}

# Skip-gram hyperparameters
EMBEDDING_DIM = 100
WINDOW_SIZE = 5
MIN_COUNT = 1
SG = 1  # 1 for Skip-gram, 0 for CBOW
WORKERS = 4
SKIPGRAM_EPOCHS = 10

# CRF hyperparameters
CRF_ALGORITHM = 'lbfgs'
CRF_C1 = 0.1  # L1 regularization
CRF_C2 = 0.1  # L2 regularization
CRF_MAX_ITERATIONS = 100

# Context window for CRF features
CONTEXT_WINDOW = 2

# Data parameters
ARABIC_LETTERS = sorted(
    np.load('../data/utils/arabic_letters.pkl', allow_pickle=True))
DIACRITICS = sorted(np.load(
    '../data/utils/diacritics.pkl', allow_pickle=True))
PUNCTUATIONS = {".", "،", ":", "؛", "؟", "!", '"', "-"}

VALID_CHARS = set(ARABIC_LETTERS).union(
    set(DIACRITICS)).union(PUNCTUATIONS).union({" "})

CHAR2ID = {char: id for id, char in enumerate(ARABIC_LETTERS)}
CHAR2ID[" "] = len(ARABIC_LETTERS)
CHAR2ID["<PAD>"] = len(ARABIC_LETTERS) + 1
PAD = CHAR2ID["<PAD>"]
SPACE = CHAR2ID[" "]
ID2CHAR = {id: char for char, id in CHAR2ID.items()}

DIACRITIC2ID = np.load('../data/utils/diacritic2id.pkl', allow_pickle=True)
ID2DIACRITIC = {id: diacritic for diacritic, id in DIACRITIC2ID.items()}

In [4]:
# Cell 4: Registry functions

def register_dataset(name):
    def decorator(cls):
        DATASET_REGISTRY[name] = cls
        return cls
    return decorator


def generate_dataset(dataset_name: str, *args, **kwargs):
    try:
        dataset_cls = DATASET_REGISTRY[dataset_name]
    except KeyError:
        raise ValueError(f"Dataset '{dataset_name}' is not recognized.")
    return dataset_cls(*args, **kwargs)


def register_model(name):
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def generate_model(model_name: str, *args, **kwargs):
    try:
        model_cls = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Model '{model_name}' is not recognized.")
    return model_cls(*args, **kwargs)

In [5]:
# Cell 5: Dataset class with Skip-gram feature extraction

@register_dataset("ArabicSkipgramDataset")
class ArabicSkipgramDataset:
    def __init__(self, file_path: str, skipgram_model: Word2Vec = None):
        self.skipgram_model = skipgram_model
        self.sentences_with_diacritics = self.load_data(file_path)
        self.sentences_without_diacritics = self.extract_text_without_diacritics(
            self.sentences_with_diacritics)
        
        # Tokenize into words for Skip-gram
        self.tokenized_sentences = [sentence.split() for sentence in self.sentences_without_diacritics]
        
        # Extract diacritics per character
        self.diacritics_per_sentence = [
            self.extract_diacritics(sentence) 
            for sentence in self.sentences_with_diacritics
        ]

    def __len__(self):
        return len(self.sentences_without_diacritics)

    def __getitem__(self, idx):
        return self.tokenized_sentences[idx], self.diacritics_per_sentence[idx]

    def load_data(self, file_path: str):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    line = re.sub(
                        f'[^{re.escape("".join(VALID_CHARS))}]', '', line)
                    line = re.sub(r'\s+', ' ', line)
                    sentences = re.split(
                        f'[{re.escape("".join(PUNCTUATIONS))}]', line)
                    sentences = [s.strip() for s in sentences if s.strip()]
                    data.extend(sentences)
        return np.array(data)

    def extract_text_without_diacritics(self, dataY):
        dataX = dataY.copy()
        for diacritic, _ in DIACRITIC2ID.items():
            dataX = np.char.replace(dataX, diacritic, '')
        return dataX

    def extract_diacritics(self, sentence: str):
        """Extract diacritics for each character in the sentence."""
        result = []
        i = 0
        n = len(sentence)
        on_char = False

        while i < n:
            ch = sentence[i]
            if ch in DIACRITICS:
                on_char = False
                if i+1 < n and sentence[i+1] in DIACRITICS:
                    combined = ch + sentence[i+1]
                    if combined in DIACRITIC2ID:
                        result.append(str(DIACRITIC2ID[combined]))
                        i += 2
                        continue
                result.append(str(DIACRITIC2ID[ch]))
            elif ch in CHAR2ID:
                if on_char:
                    result.append(str(DIACRITIC2ID['']))
                on_char = True
            i += 1
        if on_char:
            result.append(str(DIACRITIC2ID['']))
        return result

    def get_corpus_for_skipgram(self):
        """Return tokenized sentences for Skip-gram training."""
        return self.tokenized_sentences

In [6]:
# Cell 6: CRF Model class and Skip-gram training

def train_skipgram_model(corpus: List[List[str]], save_path: str) -> Word2Vec:
    """Train a Skip-gram Word2Vec model on the corpus."""
    print("Training Skip-gram model...")
    model = Word2Vec(
        sentences=corpus,
        vector_size=EMBEDDING_DIM,
        window=WINDOW_SIZE,
        min_count=MIN_COUNT,
        sg=SG,
        workers=WORKERS,
        epochs=SKIPGRAM_EPOCHS
    )
    model.save(save_path)
    print(f"Skip-gram model saved to {save_path}")
    return model


def load_skipgram_model(path: str) -> Word2Vec:
    """Load a pre-trained Skip-gram model."""
    return Word2Vec.load(path)


def word_to_features(skipgram_model: Word2Vec, sentence: List[str], char_idx: int, word_idx: int, char_in_word: str, is_last_char: bool) -> Dict[str, Any]:
    """
    Extract features for a single character using Skip-gram word embeddings.
    
    Features include:
    - The word's embedding (if available)
    - Character identity
    - Position in word
    - Context word embeddings
    """
    features = {
        'bias': 1.0,
        'char': char_in_word,
        'char_idx_in_sentence': char_idx,
        'is_last_char_in_word': is_last_char,
    }
    
    # Current word embedding
    current_word = sentence[word_idx]
    if current_word in skipgram_model.wv:
        embedding = skipgram_model.wv[current_word]
        for i, val in enumerate(embedding):
            features[f'word_emb_{i}'] = float(val)
    else:
        for i in range(EMBEDDING_DIM):
            features[f'word_emb_{i}'] = 0.0
    
    # Previous word embedding (context)
    if word_idx > 0:
        prev_word = sentence[word_idx - 1]
        if prev_word in skipgram_model.wv:
            embedding = skipgram_model.wv[prev_word]
            for i, val in enumerate(embedding):
                features[f'prev_word_emb_{i}'] = float(val)
        else:
            for i in range(EMBEDDING_DIM):
                features[f'prev_word_emb_{i}'] = 0.0
    else:
        features['BOS'] = True  # Beginning of sentence
        for i in range(EMBEDDING_DIM):
            features[f'prev_word_emb_{i}'] = 0.0
    
    # Next word embedding (context)
    if word_idx < len(sentence) - 1:
        next_word = sentence[word_idx + 1]
        if next_word in skipgram_model.wv:
            embedding = skipgram_model.wv[next_word]
            for i, val in enumerate(embedding):
                features[f'next_word_emb_{i}'] = float(val)
        else:
            for i in range(EMBEDDING_DIM):
                features[f'next_word_emb_{i}'] = 0.0
    else:
        features['EOS'] = True  # End of sentence
        for i in range(EMBEDDING_DIM):
            features[f'next_word_emb_{i}'] = 0.0
    
    return features


def sentence_to_features(skipgram_model: Word2Vec, tokenized_sentence: List[str], sentence_without_diacritics: str) -> List[Dict[str, Any]]:
    """
    Convert an entire sentence to CRF features.
    Each character gets features based on its word's Skip-gram embedding.
    """
    features = []
    char_idx = 0
    
    for word_idx, word in enumerate(tokenized_sentence):
        for i, char in enumerate(word):
            if char in CHAR2ID and char != ' ':
                is_last_char = (i == len(word) - 1)
                feat = word_to_features(
                    skipgram_model, tokenized_sentence, char_idx, word_idx, char, is_last_char
                )
                features.append(feat)
                char_idx += 1
    
    return features


@register_model("CRFArabicModel")
class CRFArabicModel:
    def __init__(self):
        self.crf = sklearn_crfsuite.CRF(
            algorithm=CRF_ALGORITHM,
            c1=CRF_C1,
            c2=CRF_C2,
            max_iterations=CRF_MAX_ITERATIONS,
            all_possible_transitions=True
        )
    
    def fit(self, X_train: List[List[Dict]], y_train: List[List[str]]):
        """Train the CRF model."""
        self.crf.fit(X_train, y_train)
    
    def predict(self, X: List[List[Dict]]) -> List[List[str]]:
        """Predict diacritics for input sequences."""
        return self.crf.predict(X)
    
    def save(self, path: str):
        """Save the CRF model to disk."""
        with open(path, 'wb') as f:
            pickle.dump(self.crf, f)
    
    def load(self, path: str):
        """Load a CRF model from disk."""
        with open(path, 'rb') as f:
            self.crf = pickle.load(f)

In [7]:
# Cell 7: Train function

def train(model: CRFArabicModel, train_dataset: ArabicSkipgramDataset, 
          skipgram_model: Word2Vec, model_path: str):
    """
    Train the CRF model on the training dataset.
    """
    print("Preparing training data for CRF...")
    X_train = []
    y_train = []
    
    for idx in tqdm(range(len(train_dataset)), desc="Extracting features"):
        tokenized_sentence, diacritics = train_dataset[idx]
        sentence_without_diacritics = train_dataset.sentences_without_diacritics[idx]
        
        features = sentence_to_features(
            skipgram_model, tokenized_sentence, sentence_without_diacritics
        )
        
        # Ensure features and diacritics have the same length
        if len(features) == len(diacritics):
            X_train.append(features)
            y_train.append(diacritics)
    
    print(f"Training CRF on {len(X_train)} sentences...")
    model.fit(X_train, y_train)
    model.save(model_path)
    print(f"CRF model saved to {model_path}")

In [8]:
# Cell 8: Evaluate function

def evaluate(model: CRFArabicModel, val_dataset: ArabicSkipgramDataset, 
             skipgram_model: Word2Vec):
    """
    Evaluate the CRF model on the validation dataset.
    Reports accuracy for:
    - Overall
    - Without last character (morphology)
    - Last character only (syntax/case endings)
    """
    print("Preparing validation data...")
    X_val = []
    y_val = []
    last_char_indices = []  # Track which indices are last characters in words
    
    for idx in tqdm(range(len(val_dataset)), desc="Extracting validation features"):
        tokenized_sentence, diacritics = val_dataset[idx]
        sentence_without_diacritics = val_dataset.sentences_without_diacritics[idx]
        
        features = sentence_to_features(
            skipgram_model, tokenized_sentence, sentence_without_diacritics
        )
        
        if len(features) == len(diacritics):
            X_val.append(features)
            y_val.append(diacritics)
            
            # Track last character positions
            sentence_last_chars = [feat.get('is_last_char_in_word', False) for feat in features]
            last_char_indices.append(sentence_last_chars)
    
    print("Running predictions...")
    y_pred = model.predict(X_val)
    
    # Calculate accuracies
    total_correct = 0
    total_tokens = 0
    total_correct_ending = 0
    total_tokens_ending = 0
    total_correct_without_ending = 0
    total_tokens_without_ending = 0
    
    for sent_idx in range(len(y_val)):
        for token_idx in range(len(y_val[sent_idx])):
            pred = y_pred[sent_idx][token_idx]
            true = y_val[sent_idx][token_idx]
            is_last_char = last_char_indices[sent_idx][token_idx]
            
            total_tokens += 1
            if pred == true:
                total_correct += 1
            
            if is_last_char:
                total_tokens_ending += 1
                if pred == true:
                    total_correct_ending += 1
            else:
                total_tokens_without_ending += 1
                if pred == true:
                    total_correct_without_ending += 1
    
    val_accuracy = (total_correct / total_tokens) * 100 if total_tokens > 0 else 0
    val_accuracy_ending = (total_correct_ending / total_tokens_ending) * 100 if total_tokens_ending > 0 else 0
    val_accuracy_without_ending = (total_correct_without_ending / total_tokens_without_ending) * 100 if total_tokens_without_ending > 0 else 0
    
    print(f"Validation Accuracy (Overall): {val_accuracy:.2f}%")
    print(f"Validation Accuracy (Without Last Character): {val_accuracy_without_ending:.2f}%")
    print(f"Validation Accuracy (Last Character): {val_accuracy_ending:.2f}%")

In [9]:
# Cell 9: Predict function

def predict(model: CRFArabicModel, skipgram_model: Word2Vec, 
            sentence: str) -> List[str]:
    """
    Predict diacritics for a single sentence.
    """
    # Remove any existing diacritics
    clean_sentence = sentence
    for diacritic in DIACRITICS:
        clean_sentence = clean_sentence.replace(diacritic, '')
    
    tokenized = clean_sentence.split()
    features = sentence_to_features(skipgram_model, tokenized, clean_sentence)
    
    if not features:
        return []
    
    predictions = model.predict([features])[0]
    return predictions

In [10]:
# Cell 10: Infer function

def infer(model: CRFArabicModel, skipgram_model: Word2Vec,
          input_path: str, output_path: str):
    """
    Run inference on an input file and save results.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        input_data = f.readlines()
    
    output_list = []
    output_csv = [["ID", "Label"]]
    current_id = 0
    
    for sentence in tqdm(input_data, desc="Inference"):
        sentence = sentence.strip()
        if not sentence:
            output_list.append("")
            continue
        
        # Clean sentence
        clean_sentence = sentence
        for diacritic in DIACRITICS:
            clean_sentence = clean_sentence.replace(diacritic, '')
        
        predictions = predict(model, skipgram_model, clean_sentence)
        
        # Reconstruct diacritized sentence
        diacritized_sentence = ""
        pred_idx = 0
        for char in clean_sentence:
            diacritized_sentence += char
            if char in ARABIC_LETTERS and pred_idx < len(predictions):
                diacritic_id = int(predictions[pred_idx])
                diacritic = ID2DIACRITIC.get(diacritic_id, '')
                diacritized_sentence += diacritic
                output_csv.append([current_id, diacritic_id])
                current_id += 1
                pred_idx += 1
        
        output_list.append(diacritized_sentence)
    
    with open(output_path, 'w', encoding='utf-8') as f:
        for line in output_list:
            f.write(line + '\n')
    
    output_path_csv = os.path.splitext(output_path)[0] + ".csv"
    with open(output_path_csv, "w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerows(output_csv)
    
    print(f"Output saved to {output_path} and {output_path_csv}")

In [11]:
# Cell 11: Load training dataset
train_dataset = generate_dataset("ArabicSkipgramDataset", "../data/train.txt")

In [12]:
# Cell 12: Train Skip-gram model on training corpus
corpus = train_dataset.get_corpus_for_skipgram()
skipgram_model = train_skipgram_model(corpus, skipgram_model_path)

Training Skip-gram model...
Skip-gram model saved to ../models/SkipgramWordEmbeddings.model


In [13]:
# Cell 13: Create CRF model
crf_model = generate_model("CRFArabicModel")

In [14]:
# Cell 14: Train CRF model
train(crf_model, train_dataset, skipgram_model, model_path)

Preparing training data for CRF...


Extracting features: 100%|██████████| 186315/186315 [05:21<00:00, 579.79it/s]


Training CRF on 20019 sentences...
CRF model saved to ../models/SkipgramCRFModel.pkl


In [15]:
# Cell 15: Load validation dataset
val_dataset = generate_dataset("ArabicSkipgramDataset", "../data/val.txt")

In [16]:
# Cell 16: Evaluate the model
evaluate(crf_model, val_dataset, skipgram_model)

Preparing validation data...


Extracting validation features: 100%|██████████| 9068/9068 [00:15<00:00, 592.02it/s]


Running predictions...
Validation Accuracy (Overall): 82.33%
Validation Accuracy (Without Last Character): 81.87%
Validation Accuracy (Last Character): 83.82%


In [17]:
# Cell 17: Run inference (optional)
# infer(crf_model, skipgram_model, input_path, output_path)