In [None]:
import torch
from torch.utils.data import Dataset
import numpy as np
import re
import pickle
from typing import List, Any


In [None]:
# Filesystem locations, relative to the notebook's working directory.
# Plain string literals: there is nothing to interpolate, so no f-prefix is needed.
model_path = "../models/ArabicHMMModel.pkl"         # pickled HMM written/read further down
input_path = "../input/dataset_no_diacritics.txt"   # NOTE(review): not referenced in the visible cells
output_path = "../output/output.txt"                # NOTE(review): not referenced in the visible cells

In [None]:

# ---- Runtime configuration ----
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Registries populated by the @register_dataset / @register_model decorators ----
DATASET_REGISTRY: dict[str, Any] = {}
MODEL_REGISTRY: dict[str, Any] = {}

# ---- Alphabet resources ----
# NOTE: these .pkl files are loaded with allow_pickle=True, i.e. they are
# unpickled; only point these paths at files you trust.
ARABIC_LETTERS = sorted(
    np.load('../data/utils/arabic_letters.pkl', allow_pickle=True))
DIACRITICS = sorted(
    np.load('../data/utils/diacritics.pkl', allow_pickle=True))
PUNCTUATIONS = {".", "،", ":", "؛", "؟", "!", '"', "-"}

# Every character that survives corpus cleaning.
VALID_CHARS = set(ARABIC_LETTERS) | set(DIACRITICS) | PUNCTUATIONS | {" "}

# ---- Character vocabulary: letters first, then space, then the padding token ----
SPACE = len(ARABIC_LETTERS)
PAD = SPACE + 1
CHAR2ID = {letter: index for index, letter in enumerate(ARABIC_LETTERS)}
CHAR2ID[" "] = SPACE
CHAR2ID["<PAD>"] = PAD
ID2CHAR = {index: symbol for symbol, index in CHAR2ID.items()}

# ---- Diacritic label vocabulary and its inverse ----
DIACRITIC2ID = np.load('../data/utils/diacritic2id.pkl', allow_pickle=True)
ID2DIACRITIC = {index: mark for mark, index in DIACRITIC2ID.items()}


In [4]:

def register_dataset(name):
    """Build a class decorator that stores the decorated class in DATASET_REGISTRY.

    name: key under which the class is stored (an existing entry is overwritten).
    returns: a decorator that records the class and hands it back unchanged.
    """
    def _store(dataset_cls):
        DATASET_REGISTRY[name] = dataset_cls
        return dataset_cls

    return _store


def generate_dataset(dataset_name: str, *args, **kwargs):
    """Instantiate the dataset class registered under ``dataset_name``.

    dataset_name: key previously used with ``register_dataset``
    *args, **kwargs: forwarded untouched to the class constructor
    returns: the newly constructed dataset instance
    raises: ValueError if no class is registered under ``dataset_name``
    """
    try:
        factory = DATASET_REGISTRY[dataset_name]
    except KeyError:
        raise ValueError(f"Dataset '{dataset_name}' is not recognized.")
    else:
        return factory(*args, **kwargs)



def register_model(name):
    """Build a class decorator that stores the decorated class in MODEL_REGISTRY.

    name: key under which the class is stored (an existing entry is overwritten).
    returns: a decorator that records the class and hands it back unchanged.
    """
    def _store(model_cls):
        MODEL_REGISTRY[name] = model_cls
        return model_cls

    return _store


def generate_model(model_name: str, *args, **kwargs):
    """Instantiate the model class registered under ``model_name``.

    model_name: key previously used with ``register_model``
    *args, **kwargs: forwarded untouched to the class constructor
    returns: the newly constructed model instance
    raises: ValueError if no class is registered under ``model_name``
    """
    try:
        factory = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Model '{model_name}' is not recognized.")
    else:
        return factory(*args, **kwargs)


In [5]:

@register_dataset("ArabicDataset")
class ArabicDataset(Dataset):
    """Character-level diacritization dataset built from a diacritized text file.

    Each item is a pair ``(char_ids, diacritic_ids)`` of equal-length int64
    tensors, right-padded with ``PAD`` to the longest sentence in the file.
    """

    def __init__(self, file_path: str):
        """Read ``file_path`` and materialise the padded input/label tensors."""
        self.data_X, self.data_Y = self.generate_tensor_data(file_path)

    def __len__(self):
        """Number of sentences in the dataset."""
        return len(self.data_X)

    def __getitem__(self, idx):
        """Return the ``(char_ids, diacritic_ids)`` pair at position ``idx``."""
        return self.data_X[idx], self.data_Y[idx]

    def generate_tensor_data(self, data_path: str):
        """Load, strip, encode and pad the corpus.

        returns: ``(data_X, data_Y)`` int64 tensors of identical shape
                 ``(num_sentences, max_sentence_len)``.
        """
        data_Y = self.load_data(data_path)
        data_X = self.extract_text_without_diacritics(data_Y)

        encoded_data_X, encoded_data_Y = self.encode_data(data_X, data_Y)
        data_X = torch.tensor(
            encoded_data_X, dtype=torch.int64)
        data_Y = torch.tensor(
            encoded_data_Y, dtype=torch.int64)

        return data_X, data_Y

    def load_data(self, file_path: str):
        """Read the corpus and split it into cleaned sentences.

        Characters outside ``VALID_CHARS`` are dropped, whitespace runs are
        collapsed to a single space, and every line is split on the characters
        in ``PUNCTUATIONS``; empty fragments are discarded.

        returns: 1-D numpy array of strings (one entry per sentence).
        """
        # Build the patterns once instead of once per input line.
        invalid_chars = re.compile(
            f'[^{re.escape("".join(VALID_CHARS))}]')
        whitespace_run = re.compile(r'\s+')
        sentence_break = re.compile(
            f'[{re.escape("".join(PUNCTUATIONS))}]')

        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                line = invalid_chars.sub('', line)
                line = whitespace_run.sub(' ', line)
                fragments = (s.strip() for s in sentence_break.split(line))
                data.extend(s for s in fragments if s)

        # dtype=str keeps an empty corpus a *string* array (np.array([]) would
        # be float64 and break np.char.replace downstream).
        return np.array(data, dtype=str)

    def extract_text_without_diacritics(self, dataY):
        """Return a copy of ``dataY`` with every key of ``DIACRITIC2ID`` removed."""
        dataX = dataY.copy()
        for diacritic in DIACRITIC2ID:
            if not diacritic:
                # The "no diacritic" label is the empty string; nothing to strip.
                continue
            dataX = np.char.replace(dataX, diacritic, '')
        return dataX

    def encode_data(self, dataX: List[str], dataY: List[str]):
        """Map sentences to id sequences and right-pad them with ``PAD``.

        dataX: sentences without diacritics (inputs)
        dataY: the same sentences with diacritics (label source)
        returns: ``(padded_dataX, padded_dataY)`` int64 arrays of shape
                 ``(num_sentences, max_sentence_len)``.
        """
        encoded_data_X = [
            [CHAR2ID[char] for char in sentence if char in CHAR2ID]
            for sentence in dataX
        ]
        encoded_data_Y = [self.extract_diacritics(sentence)
                          for sentence in dataY]

        # default=0 keeps an empty corpus from crashing max().
        max_sentence_len = max(
            (len(sentence) for sentence in encoded_data_X), default=0)

        padded_dataX = np.full(
            (len(encoded_data_X), max_sentence_len), PAD, dtype=np.int64)
        for i, seq in enumerate(encoded_data_X):
            padded_dataX[i, :len(seq)] = seq

        padded_dataY = np.full(
            (len(encoded_data_Y), max_sentence_len), PAD, dtype=np.int64)
        for i, seq in enumerate(encoded_data_Y):
            padded_dataY[i, :len(seq)] = seq

        return padded_dataX, padded_dataY

    def extract_diacritics(self, sentence: str):
        """Return one diacritic id per base character of ``sentence``.

        A base character is any character present in ``CHAR2ID``. The label is
        the diacritic that directly follows it (two consecutive diacritics are
        merged when their concatenation is a key of ``DIACRITIC2ID``), or
        ``DIACRITIC2ID['']`` when none follows.

        Diacritics that have no base character waiting for a label (at the
        start of the sentence, or after a character that is already labelled)
        are ignored, so the result always has exactly one entry per base
        character and stays aligned with the encoded input.
        """
        result = []
        i = 0
        n = len(sentence)
        awaiting_label = False  # True while the last base char has no label yet

        while i < n:
            ch = sentence[i]
            if ch in DIACRITICS:
                if awaiting_label:
                    awaiting_label = False
                    # check if next char forms a stacked diacritic
                    if i + 1 < n and sentence[i + 1] in DIACRITICS:
                        combined = ch + sentence[i + 1]
                        if combined in DIACRITIC2ID:
                            result.append(DIACRITIC2ID[combined])
                            i += 2
                            continue
                    result.append(DIACRITIC2ID[ch])
                # else: stray diacritic with nothing to attach to -> skip it
            elif ch in CHAR2ID:
                if awaiting_label:
                    # previous base char carried no diacritic
                    result.append(DIACRITIC2ID[''])
                awaiting_label = True
            i += 1

        if awaiting_label:
            result.append(DIACRITIC2ID[''])
        return result


In [None]:
@register_model("HMMArabicModel")
class HMMArabicModel:
    """Supervised first-order hidden Markov model.

    Hidden states are diacritic ids and observations are character ids.
    Parameters are estimated by counting over labelled sequences (``fit``);
    decoding uses the Viterbi algorithm (``viterbi``).
    """

    def __init__(self, num_states, num_observations, pad_state_id=None):
        """
        num_states: size of the hidden-state id space
        num_observations: size of the observation id space
        pad_state_id: optional state id marking padding; ignored while counting
        """
        self.num_states = num_states
        self.num_observations = num_observations
        self.pad_state_id = pad_state_id

        self.log_pi = None             # (num_states,) initial state log probabilities
        self.log_transition = None     # (num_states, num_states)
        self.log_emission = None       # (num_states, num_observations)

    @staticmethod
    def _log_normalize(counts):
        """Normalise ``counts`` along the last axis and return the element-wise log.

        All-zero rows are divided by 1 instead of 0, and the resulting log(0)
        silently becomes -inf rather than emitting a RuntimeWarning.
        """
        totals = counts.sum(axis=-1, keepdims=True)
        # avoid divide by zero
        totals[totals == 0] = 1.0
        with np.errstate(divide="ignore"):
            return np.log(counts / totals)

    def fit(self, seq_obs, seq_states, laplace=1.0):
        """
        seq_obs: list of observation sequences (char ids)
        seq_states: list of corresponding state sequences (diac ids)
        laplace: add-k smoothing constant
        returns: self, so calls can be chained
        """
        pi_counts = np.zeros(self.num_states, dtype=np.float64)
        transition_counts = np.zeros((self.num_states, self.num_states), dtype=np.float64)
        emission_counts = np.zeros((self.num_states, self.num_observations), dtype=np.float64)

        for obs, states in zip(seq_obs, seq_states):
            # skip empty sequences and pairs whose lengths disagree
            if len(obs) == 0 or len(obs) != len(states):
                continue

            # initial state; a sequence that starts on padding is skipped entirely
            s0 = states[0]
            if self.is_pad(s0):
                continue
            pi_counts[s0] += 1.0

            # emissions and transitions
            for i in range(len(obs)):
                s = states[i]
                o = obs[i]
                if self.is_pad(s) or o is None:
                    continue
                emission_counts[s, o] += 1.0
                if i + 1 < len(obs):
                    s_next = states[i + 1]
                    if not self.is_pad(s_next):
                        transition_counts[s, s_next] += 1.0

        # apply Laplace smoothing then normalize
        self.log_pi = self._log_normalize(pi_counts + laplace)
        self.log_transition = self._log_normalize(transition_counts + laplace)
        self.log_emission = self._log_normalize(emission_counts + laplace)

        return self

    def is_pad(self, state_id):
        """True when a padding id is configured and ``state_id`` equals it."""
        return self.pad_state_id is not None and state_id == self.pad_state_id

    def viterbi(self, obs_seq):
        """
        obs_seq: list of observation ids (ints)
        returns: best state sequence (list of state ids)
        raises: RuntimeError if the model parameters have not been estimated yet
        """
        T = len(obs_seq)
        N = self.num_states
        if T == 0:
            return []

        if self.log_pi is None or self.log_transition is None or self.log_emission is None:
            raise RuntimeError(
                "HMMArabicModel parameters are not set; call fit() before viterbi().")

        # Large finite stand-in for log(0), used for "impossible" entries
        neginf = -1e300

        # delta[t, i] = max log prob of a path ending in state i at time t
        delta = np.full((T, N), neginf, dtype=np.float64)
        # psi[t, j] = best predecessor of state j at time t (row 0 is unused)
        psi = np.zeros((T, N), dtype=np.int32)

        def emission_column(o):
            # if observation index out of range, treat emission log-prob as neginf
            if 0 <= o < self.num_observations:
                return self.log_emission[:, o]
            return np.full(N, neginf)

        # init
        delta[0, :] = self.log_pi + emission_column(obs_seq[0])

        for t in range(1, T):
            emit_t = emission_column(obs_seq[t])
            # scores[i, j] = delta[t-1, i] + logA[i, j], for all i and j at once
            scores = delta[t - 1][:, None] + self.log_transition
            # argmax keeps the first maximal predecessor, like a per-column scan
            psi[t, :] = np.argmax(scores, axis=0)
            delta[t, :] = scores.max(axis=0) + emit_t

        # backtrack
        states = [0] * T
        states[T - 1] = int(np.argmax(delta[T - 1, :]))
        for t in range(T - 2, -1, -1):
            states[t] = int(psi[t + 1, states[t + 1]])
        return states


In [None]:
def evaluate(model, X_val, Y_val):
    """Decode each validation sequence with ``model.viterbi`` and print accuracies.

    model: object exposing ``viterbi(obs_seq)`` that returns predicted state ids
    X_val: iterable of observation (character id) sequences
    Y_val: iterable of reference state sequences, aligned with ``X_val``

    Three percentages are printed: over all positions, over positions that are
    not word-final, and over word-final positions. A position is word-final
    when the next observation is ``SPACE`` or when it is the last position of
    the sequence. Reference entries equal to ``PAD`` are not scored.
    Nothing is returned.
    """
    # Each tally is [correct, seen].
    overall = [0, 0]
    word_final = [0, 0]
    word_inner = [0, 0]

    for observations, gold_states in zip(X_val, Y_val):
        predicted_states = model.viterbi(observations)

        # Flag the positions that close a word.
        length = len(observations)
        is_word_final = [
            bool(observations[pos + 1] == SPACE) if pos + 1 < length else True
            for pos in range(length)
        ]

        for pos, (guess, gold) in enumerate(zip(predicted_states, gold_states)):
            if gold == PAD:
                continue

            bucket = word_final if is_word_final[pos] else word_inner
            for tally in (overall, bucket):
                tally[1] += 1
                if guess == gold:
                    tally[0] += 1

    def _percent(tally):
        correct, seen = tally
        return correct / seen * 100 if seen else 0

    val_accuracy = _percent(overall)
    val_accuracy_ending = _percent(word_final)
    val_accuracy_without_ending = _percent(word_inner)

    print(
        f"Validation Accuracy (Overall): {val_accuracy:.2f}%\n"
        f"Validation Accuracy (Without Last Character): {val_accuracy_without_ending:.2f}%\n"
        f"Validation Accuracy (Last Character): {val_accuracy_ending:.2f}%\n"
    )

In [None]:
def tensor_to_sequences(tensor):
    """Convert each row of ``tensor`` to a list of Python ints, dropping ``PAD`` entries.

    tensor: iterable of rows whose elements can be converted with ``int``
    returns: list with one list of ints per row; every value equal to ``PAD``
             is removed wherever it occurs in the row.
    """
    return [
        [value for value in map(int, row) if value != PAD]
        for row in tensor
    ]

In [None]:
# Build the training split and turn its padded tensors back into
# variable-length id lists (PAD entries removed).
train_dataset = generate_dataset("ArabicDataset", "../data/train.txt")
X_train, Y_train = map(
    tensor_to_sequences, (train_dataset.data_X, train_dataset.data_Y))

In [None]:
# Build the validation split and turn its padded tensors back into
# variable-length id lists (PAD entries removed).
val_dataset = generate_dataset("ArabicDataset", "../data/val.txt")
X_val, Y_val = map(
    tensor_to_sequences, (val_dataset.data_X, val_dataset.data_Y))

In [None]:
# Initialize HMM: one hidden state per diacritic label, one observation per character id.
model = generate_model(
    model_name="HMMArabicModel",
    num_states=len(DIACRITIC2ID),
    num_observations=len(CHAR2ID),
    # Label tensors are padded with the character PAD id, so that id is what
    # the model must treat as "padding" on the state side.
    pad_state_id=PAD
)

# Estimate the HMM parameters from the training split. Without this step the
# model's log_pi / log_transition / log_emission stay None, an untrained model
# gets pickled below, and viterbi() fails during evaluation.
# fit() returns the model itself; re-binding keeps the cell output empty.
model = model.fit(X_train, Y_train)

In [None]:
# Serialise the current model object to `model_path` with pickle.
with open(model_path, mode="wb") as model_file:
    pickle.dump(model, model_file)
print(f"HMM model saved to {model_path}")

In [None]:
# Restore the model from `model_path`.
# NOTE: pickle.load executes code embedded in the file; only load pickles you trust.
with open(model_path, mode="rb") as model_file:
    model = pickle.load(model_file)
print(f"HMM model loaded from {model_path}")

In [None]:
# Print overall, non-word-final and word-final accuracy of `model` on the validation split.
evaluate(model, X_val, Y_val)