In [7]:
import torch 
from torch.utils.data import Dataset
from tqdm import tqdm
import re
import pickle as pkl

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TashkeelDataset(Dataset):
    def __init__(self, name, path):
        self.name = name
        with open(path, 'r', encoding='utf-8') as file:
            self.lines = list(tqdm(file, f"Reading {self.name} Lines"))
        self._load_dicts()
        self.tokenized_lines = self._tokenize_lines()
        self.embedded_data = self._embedd_lines()

    def __len__(self):
        return len(self.embedded_data)

    def __getitem__(self, idx):
        x, y = self.embedded_data[idx]
        return torch.tensor(x).to(device), torch.tensor(y).to(device)
    
    def _remove_tashkeel(self,data):
        #double damma, double fatha, double kasera, damma, fatha, kasera, sukoon, shadd
        TASHKEEL_SET = {'ٌ', 'ً', 'ٍ', 'ُ', 'َ', 'ِ', 'ْ', 'ٌّ', 'ّ'}
        DIACRITICS_REGEX = re.compile('|'.join(TASHKEEL_SET))
        return re.sub(DIACRITICS_REGEX, '', data)
    
    def _one_hot_encode(self, indices, size):
        return [[1 if i == elem else 0 for i in range(size)] for elem in indices]
    
    def _chunk_text(self, text, chunk_size):
        chunks = []
        words = re.findall(r'\S+', text)

        current_chunk = ""
        for word in words:
            if len(current_chunk) + len(word) + 1 <= chunk_size:
                current_chunk += f"{word} "
            else:
                chunks.append(current_chunk.strip())
                current_chunk = f"{word} "

        if current_chunk:
            chunks.append(current_chunk.strip())

        return list(filter(None, chunks))
    
    def _tokenize_lines(self):
        # Define a pattern to match specific punctuation marks
        punctuation_pattern1 = r'([.,:;؛)\]}»،])'
        punctuation_pattern2 = r'([(\[{«])'
        tokenized_lines = []

        for line in tqdm(self.lines, f"Tokenizing {self.name} Lines"):
            # Replace matched punctuation marks with the same followed by a line break
            splitted_line = re.sub(punctuation_pattern1, r'\1\n', line)
            splitted_line = re.sub(punctuation_pattern2, r'\n\1', splitted_line)

            # Further split the splitted line into substrings based on line breaks
            for sub_line in splitted_line.split('\n'):
                cleaned_sub_line = self._remove_tashkeel(sub_line).strip()
                if 0 < len(cleaned_sub_line) <= 500:
                    tokenized_lines.append(sub_line.strip())

                elif len(cleaned_sub_line) > 500:
                    tokenized_lines.extend(self._chunk_text(sub_line.strip(), 500))
    
        return tokenized_lines

    def _load_dicts(self):
        with open( '../utilities/pickle_files/LETTERS.pickle', 'rb') as file:
            self.LETTERS = pkl.load(file)
        with open( '../utilities/pickle_files/DIACRITICS.pickle', 'rb') as file:
            self.DIACRITICS = pkl.load(file)
        with open( '../utilities/pickle_files/CHAR_TO_ID.pickle', 'rb') as file:
            self.CHAR_TO_ID = pkl.load(file)
        with open( '../utilities/pickle_files/DIACRITIC_TO_ID.pickle', 'rb') as file:
            self.DIACRITIC_TO_ID = pkl.load(file)
        
    def _embedd_lines(self):
        inputs_embeddings=[]
        for line in tqdm(self.tokenized_lines, f"Embedding {self.name} Lines"):
            x = [self.CHAR_TO_ID['<SOS>']]
            y = [self.DIACRITIC_TO_ID['<SOS>']]

            for index, char in enumerate(line):
                if char in self.DIACRITICS:
                    continue

                if char in self.CHAR_TO_ID:
                    x.append(self.CHAR_TO_ID[char])
                else:
                    x.append(self.CHAR_TO_ID['<UNK>'])

                if char not in self.LETTERS:
                    y.append(self.DIACRITIC_TO_ID[''])

                else:
                    char_diac = ''
                    if index + 1 < len(line) and line[index + 1] in self.DIACRITICS:
                        char_diac = line[index + 1]
                        if index + 2 < len(line) and line[index + 2] in self.DIACRITICS and char_diac + line[index + 2] in self.DIACRITIC_TO_ID:
                            char_diac += line[index + 2]
                        elif index + 2 < len(line) and line[index + 2] in self.DIACRITICS and line[index + 2] + char_diac in self.DIACRITIC_TO_ID:
                            char_diac = line[index + 2] + char_diac
                    y.append(self.DIACRITIC_TO_ID[char_diac])

            x.append(self.CHAR_TO_ID['<EOS>'])
            y.append(self.DIACRITIC_TO_ID['<EOS>'])
            y = self._one_hot_encode(y, len(self.DIACRITIC_TO_ID))
            
            inputs_embeddings.append((x, y)) 
            
        return inputs_embeddings

In [8]:
train_dataset = TashkeelDataset('train dataset','../data/val.txt')
val_dataset = TashkeelDataset('validation dataset','../data/val.txt')

Reading train dataset Lines: 2500it [00:00, 208385.70it/s]


Tokenizing train dataset Lines: 100%|██████████| 2500/2500 [00:00<00:00, 12685.87it/s]
Embedding train dataset Lines: 100%|██████████| 15362/15362 [00:01<00:00, 8374.72it/s]
Reading validation dataset Lines: 2500it [00:00, 250083.71it/s]
Tokenizing validation dataset Lines: 100%|██████████| 2500/2500 [00:00<00:00, 11955.14it/s]
Embedding validation dataset Lines: 100%|██████████| 15362/15362 [00:02<00:00, 6805.47it/s]


In [9]:
from torch.utils.data import DataLoader

import torch.nn.utils.rnn as rnn_utils

def collate_fn(batch):
    x_batch, y_batch = zip(*batch)
    x_padded = rnn_utils.pad_sequence(x_batch, batch_first=True, padding_value=train_dataset.CHAR_TO_ID['<PAD>'])
    y_padded = rnn_utils.pad_sequence(y_batch, batch_first=True, padding_value=train_dataset.DIACRITIC_TO_ID['<PAD>'])
    return x_padded, y_padded

# Create a DataLoader instance with collate_fn
dataloader_train = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)
dataloader_test = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

In [10]:
import torch.nn as nn 

class MeshakkelatyModel(nn.Module):
    def __init__(self, char_to_id, diacritic_to_id):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=len(char_to_id),
            embedding_dim=25,
            padding_idx=char_to_id['<PAD>']  
        )
        self.lstm1 = nn.LSTM(
            input_size=25,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            dropout=0.5,
            batch_first=True  
        )
        self.linear1 = nn.Linear(2*256, 512)
        self.linear2 = nn.Linear(512, len(diacritic_to_id))

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm1(x)
        x = nn.functional.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [11]:
import torch.optim as optim 
from torchmetrics import Accuracy

meshakkelaty = MeshakkelatyModel(train_dataset.CHAR_TO_ID, train_dataset.DIACRITIC_TO_ID).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(meshakkelaty.parameters())
epochs = 10
metric = Accuracy(task="multiclass", num_classes=len(train_dataset.DIACRITIC_TO_ID)).to(device)

In [None]:
for epoch in range(epochs):
    meshakkelaty.train()

    # Initialize variables to accumulate correct and total predictions
    total_correct = 0
    total_samples = 0

    epoch_progress = tqdm(dataloader_train, desc=f"Epoch {epoch + 1}/{epochs}")

    for x_batch, y_batch in epoch_progress:
    
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        y_pred = meshakkelaty(x_batch)
        loss = criterion(y_pred, y_batch.float())
        loss.backward()
        optimizer.step()

        # Convert one-hot encoded predictions and targets to class indices
        y_pred_class = y_pred.argmax(dim=-1)
        y_batch_class = y_batch.argmax(dim=-1)
        train_acc = metric(y_pred_class, y_batch_class)

        # Update accumulated values
        total_correct += torch.sum(y_pred_class == y_batch_class).item()

        total_samples += y_batch.size(0) * y_batch.size(1)

        # Calculate accuracy for the current batch
        batch_acc = total_correct / total_samples

        # Update the progress bar description with the current accuracy
        epoch_progress.set_description(f"Epoch {epoch + 1}/{epochs}, Train Accuracy: {metric.compute()*100:.4f}%", refresh=True)
        # total_correct = 0 
    # Print a newline to move to the next line after the epoch is finished
    print(f'Epoch {epoch + 1}/{epochs}, Train Accuracy: {metric.compute()*100:.4f}%')
    metric.reset()