In [146]:
import torch
from torch.utils.data import Dataset
import re
import pickle as pkl
import numpy as np

class TashkeelDataset(Dataset):
    """Character-level dataset for Arabic diacritization (tashkeel restoration).

    The text file is split into punctuation-delimited segments (see
    ``_tokenize_lines``); each segment becomes one sample ``(x, y)``:

    * ``x`` -- 1-D LongTensor: ``<SOS>``, one character ID per *non-diacritic*
      character of the segment (characters missing from ``CHAR_TO_ID`` map to
      ``<UNK>``), then ``<EOS>``.
    * ``y`` -- 2-D tensor with one one-hot row per position of ``x`` over the
      ``DIACRITIC_TO_ID`` classes: the mark(s) written after each letter, and
      ``''`` for unmarked letters and for non-letters.

    Args:
        path: UTF-8 text file, one diacritized passage per line.
        dicts_dir: directory holding the ``LETTERS``, ``DIACRITICS``,
            ``CHAR_TO_ID`` and ``DIACRITIC_TO_ID`` pickle files.
    """

    # Segments whose undiacritized length exceeds this are split on whitespace.
    MAX_SEGMENT_LEN = 500

    # Arabic harakat U+064B..U+0652: fathatan, dammatan, kasratan, fatha, damma,
    # kasra, shadda, sukun. Stacked marks (e.g. dammatan + shadda) are removed
    # one code point at a time. Compiled once instead of on every call.
    _TASHKEEL_RE = re.compile('[\u064B-\u0652]')
    # Closing punctuation ends a segment (a break is inserted after it) ...
    _SPLIT_AFTER_RE = re.compile(r'([.,:;؛)\]}»،])')
    # ... and opening punctuation starts one (a break is inserted before it).
    _SPLIT_BEFORE_RE = re.compile(r'([(\[{«])')

    def __init__(self, path, dicts_dir='../utilities/pickle_files'):
        with open(path, 'r', encoding='utf-8') as file:
            self.lines = file.readlines()
        self._load_dicts(dicts_dir)
        self.tokenized_lines = self._tokenize_lines()
        self.embedded_data = self._embedd_lines()

    def __len__(self):
        return len(self.embedded_data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as an ``(x, y)`` pair of tensors (see class docstring)."""
        x, y = self.embedded_data[idx]
        return torch.tensor(x), torch.tensor(y)

    def _remove_tashkeel(self, data):
        """Return ``data`` with every Arabic diacritic mark removed."""
        return self._TASHKEEL_RE.sub('', data)

    def _one_hot_encode(self, indices, size):
        """Encode each index as a length-``size`` list with a 1 at that index, 0 elsewhere."""
        return [[1 if i == elem else 0 for i in range(size)] for elem in indices]

    def _chunk_text(self, text, chunk_size):
        """Greedily pack whitespace-separated words into chunks of at most ``chunk_size`` chars.

        A single word longer than ``chunk_size`` becomes its own (oversized)
        chunk. Empty chunks are dropped.
        """
        chunks = []
        words = re.findall(r'\S+', text)

        current_chunk = ""
        for word in words:
            # +1 accounts for the separating space appended after the word.
            if len(current_chunk) + len(word) + 1 <= chunk_size:
                current_chunk += f"{word} "
            else:
                chunks.append(current_chunk.strip())
                current_chunk = f"{word} "

        if current_chunk:
            chunks.append(current_chunk.strip())

        return list(filter(None, chunks))

    def _tokenize_lines(self):
        """Split every input line at punctuation and cap the segment length.

        Returns the stripped, still-diacritized segments. Segments are kept or
        dropped based on their *undiacritized* length; over-long ones are split
        by ``_chunk_text`` (which measures the diacritized text, so each chunk
        is at most ``MAX_SEGMENT_LEN`` characters including marks).
        """
        tokenized_lines = []

        for line in self.lines:
            splitted_line = self._SPLIT_AFTER_RE.sub(r'\1\n', line)
            splitted_line = self._SPLIT_BEFORE_RE.sub(r'\n\1', splitted_line)

            for sub_line in splitted_line.split('\n'):
                cleaned_len = len(self._remove_tashkeel(sub_line).strip())
                if cleaned_len == 0:
                    continue
                segment = sub_line.strip()
                if cleaned_len <= self.MAX_SEGMENT_LEN:
                    tokenized_lines.append(segment)
                else:
                    tokenized_lines.extend(self._chunk_text(segment, self.MAX_SEGMENT_LEN))

        return tokenized_lines

    def _load_dicts(self, dicts_dir='../utilities/pickle_files'):
        """Load the vocabulary pickles from ``dicts_dir`` and ensure ``<UNK>`` has an ID.

        NOTE(security): ``pickle.load`` can execute arbitrary code; only point
        ``dicts_dir`` at trusted, project-generated files.
        """
        def load(name):
            with open(f'{dicts_dir}/{name}.pickle', 'rb') as file:
                return pkl.load(file)

        self.LETTERS = load('LETTERS')
        self.DIACRITICS = load('DIACRITICS')
        self.CHAR_TO_ID = load('CHAR_TO_ID')
        self.DIACRITIC_TO_ID = load('DIACRITIC_TO_ID')
        # Append <UNK> only when missing: reassigning an existing <UNK> to
        # len(CHAR_TO_ID) would give it an ID outside an embedding table sized
        # by len(CHAR_TO_ID).
        self.CHAR_TO_ID.setdefault('<UNK>', len(self.CHAR_TO_ID))

    def _diacritic_after(self, line, index):
        """Return the diacritic string attached to the letter at ``line[index]``.

        Takes the mark right after the letter. If a second mark follows and the
        pair, in either order, is a ``DIACRITIC_TO_ID`` key, that combined key
        is returned instead. Returns ``''`` when no mark follows.
        """
        nxt = index + 1
        if nxt >= len(line) or line[nxt] not in self.DIACRITICS:
            return ''
        first = line[nxt]
        if nxt + 1 < len(line) and line[nxt + 1] in self.DIACRITICS:
            second = line[nxt + 1]
            if first + second in self.DIACRITIC_TO_ID:
                return first + second
            if second + first in self.DIACRITIC_TO_ID:
                return second + first
        return first

    def _embedd_lines(self):
        """Convert each tokenized segment into an ``(x, y_one_hot)`` pair of Python lists."""
        inputs_embeddings = []
        unk_id = self.CHAR_TO_ID['<UNK>']

        for line in self.tokenized_lines:
            x = [self.CHAR_TO_ID['<SOS>']]
            y = [self.DIACRITIC_TO_ID['<SOS>']]

            for index, char in enumerate(line):
                # Diacritic marks are the prediction targets, not inputs: they are
                # read as the label of the preceding letter and must not appear in
                # x (that would leak the answer and misalign x with y).
                if char in self.DIACRITICS:
                    continue
                x.append(self.CHAR_TO_ID.get(char, unk_id))
                if char in self.LETTERS:
                    y.append(self.DIACRITIC_TO_ID[self._diacritic_after(line, index)])
                else:
                    y.append(self.DIACRITIC_TO_ID[''])

            x.append(self.CHAR_TO_ID['<EOS>'])
            y.append(self.DIACRITIC_TO_ID['<EOS>'])
            inputs_embeddings.append((x, self._one_hot_encode(y, len(self.DIACRITIC_TO_ID))))

        return inputs_embeddings

In [147]:
# train_dataset = TashkeelDataset('../data/train.txt')

In [148]:
# Training data comes from the training split. Loading val.txt here as well
# would train and evaluate on the same text, so the validation loss would say
# nothing about generalisation.
train_dataset = TashkeelDataset('../data/train.txt')
val_dataset = TashkeelDataset('../data/val.txt')

In [149]:
from torch.utils.data import DataLoader

import torch.nn.utils.rnn as rnn_utils

def collate_fn(batch, char_pad_id=None, diacritic_pad_id=None):
    """Pad a list of ``(x, y)`` samples into batch tensors.

    Args:
        batch: sequence of ``(x, y)`` pairs, where ``x`` is a 1-D tensor of
            character IDs and ``y`` is a ``(len(x), num_diacritics)`` one-hot tensor.
        char_pad_id: value used to pad ``x``; defaults to
            ``train_dataset.CHAR_TO_ID['<PAD>']``.
        diacritic_pad_id: class whose one-hot row fills padded ``y`` positions;
            defaults to ``train_dataset.DIACRITIC_TO_ID['<PAD>']``.

    Returns:
        ``(x_padded, y_padded)`` with shapes ``(B, T)`` and ``(B, T, num_diacritics)``,
        where ``T`` is the longest sequence in the batch.
    """
    if char_pad_id is None:
        char_pad_id = train_dataset.CHAR_TO_ID['<PAD>']
    if diacritic_pad_id is None:
        diacritic_pad_id = train_dataset.DIACRITIC_TO_ID['<PAD>']

    x_batch, y_batch = zip(*batch)
    x_padded = rnn_utils.pad_sequence(x_batch, batch_first=True, padding_value=char_pad_id)

    # y rows are one-hot. Padding them with the scalar PAD id would yield rows
    # like [p, p, ..., p] whose argmax is 0, not the PAD class. Instead, pad
    # with zeros and set the PAD column explicitly, so every padded position
    # is a proper one-hot <PAD> row.
    y_padded = rnn_utils.pad_sequence(y_batch, batch_first=True, padding_value=0)
    lengths = torch.tensor([len(y) for y in y_batch])
    is_pad = torch.arange(y_padded.size(1)).unsqueeze(0) >= lengths.unsqueeze(1)
    pad_column = y_padded[..., diacritic_pad_id]  # basic indexing -> view into y_padded
    pad_column[is_pad] = 1
    return x_padded, y_padded

# Training batches come from the training split and are reshuffled every epoch.
# Evaluation batches keep a fixed order, so validation runs are comparable.
dataloader_train = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
dataloader_test = DataLoader(val_dataset, batch_size=256, shuffle=False, collate_fn=collate_fn)

In [158]:
import torch.nn as nn 

class MeshakkelatyModel(nn.Module):
    """Two-layer BiLSTM tagger that predicts one diacritic class per input character.

    Args:
        vocab_size: number of character IDs (rows of the embedding table).
        embedding_dim: size of each character embedding.
        hidden_size: hidden size of each LSTM direction.
        output_size: number of diacritic classes.
        padding_idx: optional character ID (e.g. the ``<PAD>`` ID) whose
            embedding is fixed at zero and receives no gradient; ``None``
            disables this.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, padding_idx=None):
        super(MeshakkelatyModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.blstm1 = nn.LSTM(embedding_dim, hidden_size, bidirectional=True, batch_first=True)
        self.dropout1 = nn.Dropout(0.5)
        self.blstm2 = nn.LSTM(2 * hidden_size, hidden_size, bidirectional=True, batch_first=True)
        self.dropout2 = nn.Dropout(0.5)
        self.dense1 = nn.Linear(2 * hidden_size, 512)
        self.dense2 = nn.Linear(512, 512)
        self.output = nn.Linear(512, output_size)
        # Applied only when probabilities are requested; training consumes logits.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, return_probs=False):
        """Score every position of a batch of character-ID sequences.

        Args:
            x: ``(batch, seq_len)`` LongTensor of character IDs.
            return_probs: if True, apply softmax over the class dimension.
                Otherwise return raw logits, which is what
                ``nn.CrossEntropyLoss`` expects.

        Returns:
            ``(batch, seq_len, output_size)`` tensor of logits (or probabilities).
        """
        x = self.embedding(x)
        x, _ = self.blstm1(x)
        x = self.dropout1(x)
        x, _ = self.blstm2(x)
        x = self.dropout2(x)

        # No pooling over time: every character gets its own label, so the
        # sequence dimension must reach the output layer intact. The ReLUs keep
        # the two dense layers from collapsing into a single linear map.
        x = torch.relu(self.dense1(x))
        x = torch.relu(self.dense2(x))
        x = self.output(x)
        if return_probs:
            x = self.softmax(x)
        return x

In [159]:
# Seed PyTorch's global RNG so that weight initialisation is reproducible.
# DataLoader shuffling without an explicit generator also draws from this RNG.
torch.manual_seed(42)
meshakkelaty = MeshakkelatyModel(
    vocab_size=len(train_dataset.CHAR_TO_ID),
    embedding_dim=25,
    hidden_size=256,
    output_size=len(train_dataset.DIACRITIC_TO_ID),
)

In [160]:
import torch.optim as optim 
# CrossEntropyLoss takes class-index targets. Positions whose target is the
# <PAD> diacritic class (batch padding) are excluded, so padding does not
# dilute or bias the loss.
criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.DIACRITIC_TO_ID['<PAD>'])
optimizer = optim.Adam(meshakkelaty.parameters())
epochs = 10

In [None]:
num_diacritics = len(train_dataset.DIACRITIC_TO_ID)

for epoch in range(epochs):
    # --- training pass ---
    meshakkelaty.train()
    train_loss = 0.0
    for x_batch, y_batch in dataloader_train:
        optimizer.zero_grad()
        y_pred = meshakkelaty(x_batch)
        # Targets arrive one-hot; CrossEntropyLoss wants one class index per position.
        loss = criterion(y_pred.reshape(-1, num_diacritics), y_batch.argmax(dim=-1).reshape(-1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # --- validation pass (dropout off, no gradients) ---
    meshakkelaty.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x_val, y_val in dataloader_test:
            y_val_pred = meshakkelaty(x_val)
            val_loss += criterion(
                y_val_pred.reshape(-1, num_diacritics), y_val.argmax(dim=-1).reshape(-1)
            ).item()

    # Report the mean per-batch loss of each split rather than just the final
    # training batch; max(..., 1) avoids dividing by zero on an empty loader.
    avg_train_loss = train_loss / max(len(dataloader_train), 1)
    avg_val_loss = val_loss / max(len(dataloader_test), 1)
    print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')