## Read data from train.txt and filter it from unwanted patterns


In [90]:
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from Preprocessing import utils, character_encoding
from Models import rnn_pytorch
# import config as conf

# config = conf.ConfigLoader().load_config()

In [91]:
VECTOR_SIZE = 10

NUM_TRAIN_LINES = 10
NUM_TEST_LINES = 3
NUM_EVAL_LINES = int(NUM_TRAIN_LINES * 0.2)
PADDING_SIZE = 150

MODEL = 'lstm'
TRAIN_MODEL = True # True if you want to train the model else load an existing model
NUM_EPOCHS = 200
HIDDEN_SIZE = 1000
LEARNING_RATE = 0.001
BATCH_SIZE = 1
MODEL_NAME = f'model_{MODEL}_{NUM_TRAIN_LINES}L_{NUM_EPOCHS}Epoch_{HIDDEN_SIZE}Hidden.pth'

## Preprocessing

Clean data and save it (uncomment the following lines if you need to re-clean the data)


In [134]:
import re
def filter_data(data: str) -> str:
    # data = re.sub(r"\( \d+ (/ \d+)? \)", "", data)
    # remove all numbers
    data = re.sub(r"\d+", "", data)
    # regex to remove all special characters
    data = re.sub(r"[][//,;()$:\-{}_*؛،:«»`–\"~]", "", data)
    # remove all english letters
    data = re.sub(r"[a-zA-Z]", "", data)
    # Substituting multiple spaces with single space
    data = re.sub(r"([^\S\n])+", " ", data, flags=re.I)
    return data

# split data into sentences
def split_data_to_sentences(data: str) -> list:
    # Split data into sentences using punctuation marks and newlines as delimiters
    # sentences = re.split(r"[.?!\n]", data)
    sentences = re.split(r'[\.\?؟!\n]', data)
    # Remove empty sentences
    sentences = [sentence for sentence in sentences if sentence.strip()]
    return sentences

In [135]:
dataset = utils.read_data("./dataset/train.txt")
filtered_dataset = filter_data(dataset)
sentences = split_data_to_sentences(filtered_dataset)
print(f"Number of sentences: {len(sentences)}")
# print first 10 sentences
for i in range(10):
    print(sentences[i])

Number of sentences: 53641
قَوْلُهُ أَوْ قَطَعَ الْأَوَّلُ يَدَهُ إلَخْ قَالَ الزَّرْكَشِيُّ 
ابْنُ عَرَفَةَ قَوْلُهُ بِلَفْظٍ يَقْتَضِيه كَإِنْكَارِ غَيْرِ حَدِيثٍ بِالْإِسْلَامِ وُجُوبَ مَا عُلِمَ وُجُوبُهُ مِنْ الدِّينِ ضَرُورَةً كَإِلْقَاءِ مُصْحَفٍ بِقَذَرٍ وَشَدِّ زُنَّارٍ ابْنُ عَرَفَةَ قَوْلُ ابْنِ شَاسٍ أَوْ بِفِعْلٍ يَتَضَمَّنُهُ هُوَ كَلُبْسِ الزُّنَّارِ وَإِلْقَاءِ الْمُصْحَفِ فِي صَرِيحِ النَّجَاسَةِ وَالسُّجُودِ لِلصَّنَمِ وَنَحْوِ ذَلِكَ وَسِحْرٍ مُحَمَّدٌ قَوْلُ مَالِكٍ وَأَصْحَابِهِ أَنَّ السَّاحِرَ كَافِرٌ بِاَللَّهِ تَعَالَى قَالَ مَالِكٌ هُوَ كَالزِّنْدِيقِ إذَا عَمِلَ السِّحْرَ بِنَفْسِهِ قُتِلَ وَلَمْ يُسْتَتَبْ 
 قَوْلُهُ لِعَدَمِ مَا تَتَعَلَّقُ إلَخْ أَيْ الْوَصِيَّةُ قَوْلُهُ مَا مَرَّ أَيْ قُبَيْلَ قَوْلِ الْمَتْنِ لَغَتْ وَلَوْ اقْتَصَرَ عَلَى أَوْصَيْت لَهُ بِشَاةٍ أَوْ أَعْطُوهُ شَاةً وَلَا غَنَمَ لَهُ عِنْدَ الْمَوْتِ هَلْ تَبْطُلُ الْوَصِيَّةُ أَوْ يُشْتَرَى لَهُ شَاةٌ وَيُؤْخَذُ مِنْ قَوْلِهِ الْآتِي كَمَا لَوْ لَمْ يَقُلْ مِنْ مَالِي وَلَا مِنْ غَنَمِي

In [92]:
# def save_data(path: str, data: str):
#     with open(path, "w", encoding="utf-8") as f:
#         f.write(data)

# dataset = utils.read_data("./dataset/val.txt")
# filtered_dataset = utils.filter_data(dataset)
# save_data("./dataset/val_filtered.txt", filtered_dataset)

# dataset = utils.read_data("./dataset/train.txt")
# filtered_dataset = utils.filter_data(dataset)
# save_data("./dataset/train_filtered.txt", filtered_dataset)

## Feature Extraction


Split training data to sentences and remove diacritics from each sentence


In [93]:
class CustomDataset(Dataset):
    def __init__(self, test = False, eval = False, testdata = None):
        if test:
            dataset = testdata
            # self.filtered_dataset = utils.filter_data(dataset)
            self.data = utils.split_data_to_sentences(dataset)[0:NUM_TEST_LINES]
        elif eval:
            dataset = utils.read_data("./dataset/train_filtered.txt")
            # self.filtered_dataset = utils.filter_data(dataset)
            self.data = utils.split_data_to_sentences(dataset)[NUM_TRAIN_LINES:NUM_TRAIN_LINES + NUM_EVAL_LINES]
        else:
            dataset = utils.read_data("./dataset/train_filtered.txt")
            # self.filtered_dataset = utils.filter_data(dataset)
            self.data = utils.split_data_to_sentences(dataset)[0:NUM_TRAIN_LINES]
        self.max_length = PADDING_SIZE

    def __getitem__(self, index):
        sentence = self.data[index]
        # separate data (sentence) and label (diacritic of each character)
        sentence, diactritic = character_encoding.remove_diacritics(sentence, True)
        # get sentence vector
        sentence = character_encoding.getSentenceVector(sentence)
        # get diacritic vector
        diactritic = character_encoding.getDiacriticVector(diactritic)
        # add padding to sentence vector or clip it
        sentence,original_length = character_encoding.padding(sentence, len(character_encoding.ARABIC_ALPHABIT) +1,max_length=self.max_length)
        diactritic,_ = character_encoding.padding(diactritic, len(character_encoding.DIACRITICS),max_length=self.max_length)
        # convert to tensor
        sentence = torch.tensor(sentence, dtype=(torch.float32))
        diactritic = torch.tensor(diactritic, dtype=(torch.float32))
        return sentence, diactritic, original_length 

    def __len__(self):
        return len(self.data)

## Building The Model


#### connect to GPU if available


In [95]:
input_size = len(character_encoding.ARABIC_ALPHABIT) + 1
hidden_size = HIDDEN_SIZE
output_size = len(character_encoding.DIACRITICS)

In [None]:
# Create an instance of the RNN classifier
if MODEL == 'rnn':
    model = rnn_pytorch.RNNClassifier(input_size, hidden_size, output_size)

# Creare an instance of the LSTM classifier
elif MODEL == 'lstm':
    model = rnn_pytorch.LSTMClassifier(input_size, hidden_size, output_size)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device = " ,device)
print("Cuda : ",torch.cuda.is_available())
print("Number of Cuda devices :", torch.cuda.device_count())

In [97]:
batch_size = BATCH_SIZE
dataset = CustomDataset()
# Create a dataloader to handle batching and shuffling
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [98]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

### Training Model


In [99]:
num_epochs = NUM_EPOCHS


# Training loop
if TRAIN_MODEL:
    for epoch in range(num_epochs):
        correct_predictions = 0
        total_predictions = 0   
        for i, (sentences, labels, _) in enumerate(tqdm(train_dataloader)): 
            # Reshape input and labels to (batch_size, seq_length, input_size)
            sentences = sentences.view(batch_size, -1,input_size).to(device)
            labels = labels.view(batch_size, -1, output_size).to(device)
            # RNN
            if MODEL == 'rnn':
                hidden_state = model.init_hidden(batch_size=batch_size).to(device) # RNN has one hidden state
                optimizer.zero_grad()
                outputs = model(sentences, hidden_state)
            
            # LSTM
            elif MODEL == 'lstm':
                hidden_state, cell_state = model.init_hidden(batch_size=batch_size)  # LSTM has two hidden states
                hidden_state = hidden_state.to(device)
                cell_state = cell_state.to(device)
                optimizer.zero_grad()
                outputs = model(sentences, (hidden_state, cell_state))

            # Get the model's predictions
            _, predicted = torch.max(outputs.data, 1)
            
            # If your labels are one-hot encoded
            _, labels = torch.max(labels.data, 1)
            
            # Update total and correct predictions counters
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            
            # calculate the loss
            loss = criterion(outputs, labels)
            # backward pass
            loss.backward()
            # update the weights
            optimizer.step()
            

        accuracy = correct_predictions / total_predictions
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Accuracy: {accuracy}')
    torch.save(model.state_dict(), f'./SavedModels/{MODEL_NAME}')
else:
    model.load_state_dict(torch.load( f'./SavedModels/{MODEL_NAME}'))
    print("Model loaded successfully")

100%|██████████| 10/10 [00:00<00:00, 14.81it/s]


Epoch 1/200, Loss: 4.21298360824585, Accuracy: 7.4


100%|██████████| 10/10 [00:00<00:00, 22.95it/s]


Epoch 2/200, Loss: 3.71956729888916, Accuracy: 8.3


100%|██████████| 10/10 [00:00<00:00, 21.59it/s]


Epoch 3/200, Loss: 3.493488073348999, Accuracy: 8.0


100%|██████████| 10/10 [00:00<00:00, 21.92it/s]


Epoch 4/200, Loss: 2.8995375633239746, Accuracy: 8.4


100%|██████████| 10/10 [00:00<00:00, 21.72it/s]


Epoch 5/200, Loss: 2.3798205852508545, Accuracy: 8.4


100%|██████████| 10/10 [00:00<00:00, 22.39it/s]


Epoch 6/200, Loss: 2.537262439727783, Accuracy: 8.2


100%|██████████| 10/10 [00:00<00:00, 22.68it/s]


Epoch 7/200, Loss: 1.8618651628494263, Accuracy: 8.0


100%|██████████| 10/10 [00:00<00:00, 21.96it/s]


Epoch 8/200, Loss: 2.5865564346313477, Accuracy: 8.2


100%|██████████| 10/10 [00:00<00:00, 22.24it/s]


Epoch 9/200, Loss: 2.2896182537078857, Accuracy: 8.5


100%|██████████| 10/10 [00:00<00:00, 22.66it/s]


Epoch 10/200, Loss: 2.0206024646759033, Accuracy: 8.6


100%|██████████| 10/10 [00:00<00:00, 21.94it/s]


Epoch 11/200, Loss: 1.9192785024642944, Accuracy: 8.3


100%|██████████| 10/10 [00:00<00:00, 22.25it/s]


Epoch 12/200, Loss: 1.3587666749954224, Accuracy: 8.7


100%|██████████| 10/10 [00:00<00:00, 22.65it/s]


Epoch 13/200, Loss: 2.023587942123413, Accuracy: 8.5


100%|██████████| 10/10 [00:00<00:00, 22.12it/s]


Epoch 14/200, Loss: 4.752702713012695, Accuracy: 8.8


100%|██████████| 10/10 [00:00<00:00, 22.03it/s]


Epoch 15/200, Loss: 2.254818916320801, Accuracy: 8.6


100%|██████████| 10/10 [00:00<00:00, 22.44it/s]


Epoch 16/200, Loss: 2.0638153553009033, Accuracy: 8.4


100%|██████████| 10/10 [00:00<00:00, 21.70it/s]


Epoch 17/200, Loss: 2.706198215484619, Accuracy: 8.4


100%|██████████| 10/10 [00:00<00:00, 21.81it/s]


Epoch 18/200, Loss: 3.709850311279297, Accuracy: 8.9


100%|██████████| 10/10 [00:00<00:00, 21.94it/s]


Epoch 19/200, Loss: 1.9530967473983765, Accuracy: 8.8


100%|██████████| 10/10 [00:00<00:00, 22.72it/s]


Epoch 20/200, Loss: 1.1978126764297485, Accuracy: 9.1


100%|██████████| 10/10 [00:00<00:00, 22.49it/s]


Epoch 21/200, Loss: 2.463221311569214, Accuracy: 9.3


100%|██████████| 10/10 [00:00<00:00, 22.32it/s]


Epoch 22/200, Loss: 1.7901307344436646, Accuracy: 9.3


100%|██████████| 10/10 [00:00<00:00, 22.95it/s]


Epoch 23/200, Loss: 2.395540952682495, Accuracy: 9.3


100%|██████████| 10/10 [00:00<00:00, 22.60it/s]


Epoch 24/200, Loss: 3.5281436443328857, Accuracy: 9.3


100%|██████████| 10/10 [00:00<00:00, 21.18it/s]


Epoch 25/200, Loss: 1.062536597251892, Accuracy: 9.3


100%|██████████| 10/10 [00:00<00:00, 22.43it/s]


Epoch 26/200, Loss: 1.504431128501892, Accuracy: 9.4


100%|██████████| 10/10 [00:00<00:00, 22.00it/s]


Epoch 27/200, Loss: 2.292299509048462, Accuracy: 9.6


100%|██████████| 10/10 [00:00<00:00, 22.61it/s]


Epoch 28/200, Loss: 1.4968101978302002, Accuracy: 9.7


100%|██████████| 10/10 [00:00<00:00, 22.20it/s]


Epoch 29/200, Loss: 2.1249194145202637, Accuracy: 9.8


100%|██████████| 10/10 [00:00<00:00, 21.98it/s]


Epoch 30/200, Loss: 1.5716270208358765, Accuracy: 9.5


100%|██████████| 10/10 [00:00<00:00, 21.77it/s]


Epoch 31/200, Loss: 1.726570725440979, Accuracy: 9.8


100%|██████████| 10/10 [00:00<00:00, 22.20it/s]


Epoch 32/200, Loss: 3.195695400238037, Accuracy: 9.8


100%|██████████| 10/10 [00:00<00:00, 22.23it/s]


Epoch 33/200, Loss: 1.3021489381790161, Accuracy: 9.6


100%|██████████| 10/10 [00:00<00:00, 22.30it/s]


Epoch 34/200, Loss: 1.630414366722107, Accuracy: 9.4


100%|██████████| 10/10 [00:00<00:00, 22.46it/s]


Epoch 35/200, Loss: 1.9971938133239746, Accuracy: 9.7


100%|██████████| 10/10 [00:00<00:00, 21.85it/s]


Epoch 36/200, Loss: 2.1519861221313477, Accuracy: 9.4


100%|██████████| 10/10 [00:00<00:00, 22.70it/s]


Epoch 37/200, Loss: 1.4618159532546997, Accuracy: 9.7


100%|██████████| 10/10 [00:00<00:00, 22.54it/s]


Epoch 38/200, Loss: 1.9522532224655151, Accuracy: 9.5


100%|██████████| 10/10 [00:00<00:00, 21.77it/s]


Epoch 39/200, Loss: 1.9487431049346924, Accuracy: 9.9


100%|██████████| 10/10 [00:00<00:00, 22.00it/s]


Epoch 40/200, Loss: 1.8759714365005493, Accuracy: 9.7


100%|██████████| 10/10 [00:00<00:00, 22.43it/s]


Epoch 41/200, Loss: 1.3051799535751343, Accuracy: 9.9


100%|██████████| 10/10 [00:00<00:00, 21.35it/s]


Epoch 42/200, Loss: 2.986445188522339, Accuracy: 9.9


100%|██████████| 10/10 [00:00<00:00, 22.61it/s]


Epoch 43/200, Loss: 2.168454647064209, Accuracy: 10.0


100%|██████████| 10/10 [00:00<00:00, 22.43it/s]


Epoch 44/200, Loss: 0.7363598346710205, Accuracy: 9.9


100%|██████████| 10/10 [00:00<00:00, 22.57it/s]


Epoch 45/200, Loss: 2.0997703075408936, Accuracy: 9.9


100%|██████████| 10/10 [00:00<00:00, 22.03it/s]


Epoch 46/200, Loss: 1.7379558086395264, Accuracy: 10.0


100%|██████████| 10/10 [00:00<00:00, 22.34it/s]


Epoch 47/200, Loss: 1.0325998067855835, Accuracy: 10.0


100%|██████████| 10/10 [00:00<00:00, 22.68it/s]


Epoch 48/200, Loss: 1.4379255771636963, Accuracy: 10.0


100%|██████████| 10/10 [00:00<00:00, 22.16it/s]


Epoch 49/200, Loss: 1.9150443077087402, Accuracy: 9.9


100%|██████████| 10/10 [00:00<00:00, 23.11it/s]


Epoch 50/200, Loss: 1.9134526252746582, Accuracy: 9.9


100%|██████████| 10/10 [00:00<00:00, 22.36it/s]


Epoch 51/200, Loss: 1.2240911722183228, Accuracy: 10.1


100%|██████████| 10/10 [00:00<00:00, 22.71it/s]


Epoch 52/200, Loss: 1.0060890913009644, Accuracy: 10.0


100%|██████████| 10/10 [00:00<00:00, 22.52it/s]


Epoch 53/200, Loss: 2.0352158546447754, Accuracy: 10.2


100%|██████████| 10/10 [00:00<00:00, 22.10it/s]


Epoch 54/200, Loss: 1.5298584699630737, Accuracy: 10.3


100%|██████████| 10/10 [00:00<00:00, 22.51it/s]


Epoch 55/200, Loss: 1.5079989433288574, Accuracy: 10.7


100%|██████████| 10/10 [00:00<00:00, 22.67it/s]


Epoch 56/200, Loss: 1.4004788398742676, Accuracy: 10.7


100%|██████████| 10/10 [00:00<00:00, 22.38it/s]


Epoch 57/200, Loss: 2.381242036819458, Accuracy: 10.7


100%|██████████| 10/10 [00:00<00:00, 22.48it/s]


Epoch 58/200, Loss: 0.45981737971305847, Accuracy: 10.9


100%|██████████| 10/10 [00:00<00:00, 22.56it/s]


Epoch 59/200, Loss: 0.5840107202529907, Accuracy: 11.0


100%|██████████| 10/10 [00:00<00:00, 21.68it/s]


Epoch 60/200, Loss: 1.0948246717453003, Accuracy: 10.9


100%|██████████| 10/10 [00:00<00:00, 22.35it/s]


Epoch 61/200, Loss: 1.204560399055481, Accuracy: 10.9


100%|██████████| 10/10 [00:00<00:00, 22.56it/s]


Epoch 62/200, Loss: 1.6792856454849243, Accuracy: 10.9


100%|██████████| 10/10 [00:00<00:00, 22.53it/s]


Epoch 63/200, Loss: 0.638964831829071, Accuracy: 10.7


100%|██████████| 10/10 [00:00<00:00, 22.67it/s]


Epoch 64/200, Loss: 1.2253533601760864, Accuracy: 10.3


100%|██████████| 10/10 [00:00<00:00, 22.61it/s]


Epoch 65/200, Loss: 1.0360772609710693, Accuracy: 10.7


100%|██████████| 10/10 [00:00<00:00, 22.69it/s]


Epoch 66/200, Loss: 0.5826403498649597, Accuracy: 11.0


100%|██████████| 10/10 [00:00<00:00, 22.67it/s]


Epoch 67/200, Loss: 1.0647671222686768, Accuracy: 11.4


100%|██████████| 10/10 [00:00<00:00, 22.53it/s]


Epoch 68/200, Loss: 1.6515352725982666, Accuracy: 11.4


100%|██████████| 10/10 [00:00<00:00, 22.65it/s]


Epoch 69/200, Loss: 0.9276534914970398, Accuracy: 11.6


100%|██████████| 10/10 [00:00<00:00, 22.34it/s]


Epoch 70/200, Loss: 0.6607543230056763, Accuracy: 12.2


100%|██████████| 10/10 [00:00<00:00, 22.69it/s]


Epoch 71/200, Loss: 0.34777435660362244, Accuracy: 11.7


100%|██████████| 10/10 [00:00<00:00, 22.47it/s]


Epoch 72/200, Loss: 0.5694684386253357, Accuracy: 12.3


100%|██████████| 10/10 [00:00<00:00, 22.52it/s]


Epoch 73/200, Loss: 0.5088360905647278, Accuracy: 12.6


100%|██████████| 10/10 [00:00<00:00, 22.83it/s]


Epoch 74/200, Loss: 0.2310366928577423, Accuracy: 12.8


100%|██████████| 10/10 [00:00<00:00, 22.44it/s]


Epoch 75/200, Loss: 0.5144473314285278, Accuracy: 12.0


100%|██████████| 10/10 [00:00<00:00, 22.66it/s]


Epoch 76/200, Loss: 0.9694252610206604, Accuracy: 13.1


100%|██████████| 10/10 [00:00<00:00, 22.65it/s]


Epoch 77/200, Loss: 0.2028830349445343, Accuracy: 13.1


100%|██████████| 10/10 [00:00<00:00, 22.67it/s]


Epoch 78/200, Loss: 0.7161510586738586, Accuracy: 13.0


100%|██████████| 10/10 [00:00<00:00, 22.16it/s]


Epoch 79/200, Loss: 1.1655596494674683, Accuracy: 13.2


100%|██████████| 10/10 [00:00<00:00, 23.07it/s]


Epoch 80/200, Loss: 0.32859867811203003, Accuracy: 13.8


100%|██████████| 10/10 [00:00<00:00, 22.81it/s]


Epoch 81/200, Loss: 0.34566181898117065, Accuracy: 13.6


100%|██████████| 10/10 [00:00<00:00, 21.83it/s]


Epoch 82/200, Loss: 0.18088622391223907, Accuracy: 14.3


100%|██████████| 10/10 [00:00<00:00, 22.69it/s]


Epoch 83/200, Loss: 0.3601594567298889, Accuracy: 14.3


100%|██████████| 10/10 [00:00<00:00, 22.85it/s]


Epoch 84/200, Loss: 0.3071140646934509, Accuracy: 14.4


100%|██████████| 10/10 [00:00<00:00, 22.46it/s]


Epoch 85/200, Loss: 0.09901653230190277, Accuracy: 14.3


100%|██████████| 10/10 [00:00<00:00, 22.47it/s]


Epoch 86/200, Loss: 0.06650716811418533, Accuracy: 14.4


100%|██████████| 10/10 [00:00<00:00, 22.51it/s]


Epoch 87/200, Loss: 0.07958663254976273, Accuracy: 14.5


100%|██████████| 10/10 [00:00<00:00, 22.59it/s]


Epoch 88/200, Loss: 0.19528713822364807, Accuracy: 14.7


100%|██████████| 10/10 [00:00<00:00, 21.75it/s]


Epoch 89/200, Loss: 0.21774829924106598, Accuracy: 14.7


100%|██████████| 10/10 [00:00<00:00, 21.82it/s]


Epoch 90/200, Loss: 0.12764742970466614, Accuracy: 14.7


100%|██████████| 10/10 [00:00<00:00, 23.42it/s]


Epoch 91/200, Loss: 0.12168438732624054, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 22.56it/s]


Epoch 92/200, Loss: 0.09680188447237015, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 22.02it/s]


Epoch 93/200, Loss: 0.16076292097568512, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 21.89it/s]


Epoch 94/200, Loss: 0.2173587530851364, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 22.54it/s]


Epoch 95/200, Loss: 0.07838869094848633, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 22.63it/s]


Epoch 96/200, Loss: 0.06893880665302277, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 22.45it/s]


Epoch 97/200, Loss: 0.16442899405956268, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 22.73it/s]


Epoch 98/200, Loss: 0.04970838129520416, Accuracy: 14.9


100%|██████████| 10/10 [00:00<00:00, 22.63it/s]


Epoch 99/200, Loss: 0.10693736374378204, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.68it/s]


Epoch 100/200, Loss: 0.03810399770736694, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.65it/s]


Epoch 101/200, Loss: 0.06839499622583389, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 23.18it/s]


Epoch 102/200, Loss: 0.049293115735054016, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.67it/s]


Epoch 103/200, Loss: 0.08027070015668869, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.35it/s]


Epoch 104/200, Loss: 0.061376411467790604, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.81it/s]


Epoch 105/200, Loss: 0.03606867045164108, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.76it/s]


Epoch 106/200, Loss: 0.05325060710310936, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.41it/s]


Epoch 107/200, Loss: 0.03920093551278114, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.74it/s]


Epoch 108/200, Loss: 0.053625721484422684, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.55it/s]


Epoch 109/200, Loss: 0.05062159150838852, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.26it/s]


Epoch 110/200, Loss: 0.04874056950211525, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.06it/s]


Epoch 111/200, Loss: 0.0464148111641407, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 17.58it/s]


Epoch 112/200, Loss: 0.04580335691571236, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.41it/s]


Epoch 113/200, Loss: 0.06308296322822571, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.56it/s]


Epoch 114/200, Loss: 0.02595626190304756, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.87it/s]


Epoch 115/200, Loss: 0.02552586793899536, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.36it/s]


Epoch 116/200, Loss: 0.03915301710367203, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.35it/s]


Epoch 117/200, Loss: 0.022353999316692352, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.36it/s]


Epoch 118/200, Loss: 0.038062144070863724, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.66it/s]


Epoch 119/200, Loss: 0.036136284470558167, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.13it/s]


Epoch 120/200, Loss: 0.03507022187113762, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.95it/s]


Epoch 121/200, Loss: 0.04684089869260788, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.32it/s]


Epoch 122/200, Loss: 0.026628468185663223, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.03it/s]


Epoch 123/200, Loss: 0.034061018377542496, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.42it/s]


Epoch 124/200, Loss: 0.025550486519932747, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.61it/s]


Epoch 125/200, Loss: 0.02016339637339115, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 23.42it/s]


Epoch 126/200, Loss: 0.03365060314536095, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.47it/s]


Epoch 127/200, Loss: 0.019879642874002457, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.61it/s]


Epoch 128/200, Loss: 0.02761462889611721, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 20.98it/s]


Epoch 129/200, Loss: 0.03585266321897507, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.30it/s]


Epoch 130/200, Loss: 0.018034987151622772, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.82it/s]


Epoch 131/200, Loss: 0.026142466813325882, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.50it/s]


Epoch 132/200, Loss: 0.025612840428948402, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.05it/s]


Epoch 133/200, Loss: 0.014480914920568466, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.55it/s]


Epoch 134/200, Loss: 0.01422110479325056, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.23it/s]


Epoch 135/200, Loss: 0.013937229290604591, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.11it/s]


Epoch 136/200, Loss: 0.02322077564895153, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 20.77it/s]


Epoch 137/200, Loss: 0.024615531787276268, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.47it/s]


Epoch 138/200, Loss: 0.01875891536474228, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.40it/s]


Epoch 139/200, Loss: 0.028323763981461525, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.41it/s]


Epoch 140/200, Loss: 0.0126134492456913, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.35it/s]


Epoch 141/200, Loss: 0.020784636959433556, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.49it/s]


Epoch 142/200, Loss: 0.026625512167811394, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.44it/s]


Epoch 143/200, Loss: 0.013881579972803593, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 23.02it/s]


Epoch 144/200, Loss: 0.019901273772120476, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 19.18it/s]


Epoch 145/200, Loss: 0.017997991293668747, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 20.89it/s]


Epoch 146/200, Loss: 0.01780480146408081, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.52it/s]


Epoch 147/200, Loss: 0.01310577429831028, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 23.50it/s]


Epoch 148/200, Loss: 0.03909536078572273, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.56it/s]


Epoch 149/200, Loss: 0.032294370234012604, Accuracy: 14.8


100%|██████████| 10/10 [00:00<00:00, 22.95it/s]


Epoch 150/200, Loss: 1.322967290878296, Accuracy: 11.6


100%|██████████| 10/10 [00:00<00:00, 22.29it/s]


Epoch 151/200, Loss: 1.365296483039856, Accuracy: 10.1


100%|██████████| 10/10 [00:00<00:00, 20.43it/s]


Epoch 152/200, Loss: 1.3590443134307861, Accuracy: 10.1


100%|██████████| 10/10 [00:00<00:00, 20.61it/s]


Epoch 153/200, Loss: 1.0742533206939697, Accuracy: 10.2


100%|██████████| 10/10 [00:00<00:00, 21.89it/s]


Epoch 154/200, Loss: 0.9480804204940796, Accuracy: 10.9


100%|██████████| 10/10 [00:00<00:00, 22.16it/s]


Epoch 155/200, Loss: 1.0592083930969238, Accuracy: 11.2


100%|██████████| 10/10 [00:00<00:00, 21.95it/s]


Epoch 156/200, Loss: 0.9356400966644287, Accuracy: 11.3


100%|██████████| 10/10 [00:00<00:00, 21.67it/s]


Epoch 157/200, Loss: 0.6368989944458008, Accuracy: 11.4


100%|██████████| 10/10 [00:00<00:00, 22.64it/s]


Epoch 158/200, Loss: 0.7842395901679993, Accuracy: 11.7


100%|██████████| 10/10 [00:00<00:00, 21.72it/s]


Epoch 159/200, Loss: 0.8753691911697388, Accuracy: 11.5


100%|██████████| 10/10 [00:00<00:00, 22.02it/s]


Epoch 160/200, Loss: 0.9289971590042114, Accuracy: 11.6


100%|██████████| 10/10 [00:00<00:00, 22.05it/s]


Epoch 161/200, Loss: 1.2953850030899048, Accuracy: 12.1


100%|██████████| 10/10 [00:00<00:00, 21.27it/s]


Epoch 162/200, Loss: 0.44998157024383545, Accuracy: 12.8


100%|██████████| 10/10 [00:00<00:00, 22.80it/s]


Epoch 163/200, Loss: 0.16438211500644684, Accuracy: 12.7


100%|██████████| 10/10 [00:00<00:00, 21.28it/s]


Epoch 164/200, Loss: 0.2409161627292633, Accuracy: 13.0


100%|██████████| 10/10 [00:00<00:00, 23.28it/s]


Epoch 165/200, Loss: 0.7127290964126587, Accuracy: 12.4


100%|██████████| 10/10 [00:00<00:00, 21.93it/s]


Epoch 166/200, Loss: 0.310350239276886, Accuracy: 13.3


100%|██████████| 10/10 [00:00<00:00, 20.68it/s]


Epoch 167/200, Loss: 0.20879653096199036, Accuracy: 13.3


100%|██████████| 10/10 [00:00<00:00, 22.69it/s]


Epoch 168/200, Loss: 0.16487380862236023, Accuracy: 14.1


100%|██████████| 10/10 [00:00<00:00, 22.38it/s]


Epoch 169/200, Loss: 0.2448234111070633, Accuracy: 14.1


100%|██████████| 10/10 [00:00<00:00, 22.12it/s]


Epoch 170/200, Loss: 0.4213153123855591, Accuracy: 14.5


100%|██████████| 10/10 [00:00<00:00, 22.99it/s]


Epoch 171/200, Loss: 0.08025755733251572, Accuracy: 14.6


100%|██████████| 10/10 [00:00<00:00, 22.03it/s]


Epoch 172/200, Loss: 0.28631046414375305, Accuracy: 14.5


100%|██████████| 10/10 [00:00<00:00, 20.42it/s]


Epoch 173/200, Loss: 0.18284514546394348, Accuracy: 14.5


100%|██████████| 10/10 [00:00<00:00, 23.11it/s]


Epoch 174/200, Loss: 0.19585581123828888, Accuracy: 14.7


100%|██████████| 10/10 [00:00<00:00, 20.99it/s]


Epoch 175/200, Loss: 0.10405723750591278, Accuracy: 14.7


100%|██████████| 10/10 [00:00<00:00, 20.71it/s]


Epoch 176/200, Loss: 0.06632588058710098, Accuracy: 14.7


100%|██████████| 10/10 [00:00<00:00, 22.66it/s]


Epoch 177/200, Loss: 0.11911473423242569, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.17it/s]


Epoch 178/200, Loss: 0.03368233144283295, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 23.16it/s]


Epoch 179/200, Loss: 0.17982591688632965, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.02it/s]


Epoch 180/200, Loss: 0.15529794991016388, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.46it/s]


Epoch 181/200, Loss: 0.07326646149158478, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.00it/s]


Epoch 182/200, Loss: 0.06461438536643982, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.39it/s]


Epoch 183/200, Loss: 0.02422991953790188, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.81it/s]


Epoch 184/200, Loss: 0.017892587929964066, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.72it/s]


Epoch 185/200, Loss: 0.030798494815826416, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.93it/s]


Epoch 186/200, Loss: 0.05190489441156387, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.49it/s]


Epoch 187/200, Loss: 0.07408594340085983, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 23.61it/s]


Epoch 188/200, Loss: 0.02108892984688282, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.75it/s]


Epoch 189/200, Loss: 0.036867182701826096, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.75it/s]


Epoch 190/200, Loss: 0.010280094109475613, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 22.77it/s]


Epoch 191/200, Loss: 0.028868356719613075, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 20.79it/s]


Epoch 192/200, Loss: 0.026945296674966812, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.23it/s]


Epoch 193/200, Loss: 0.012547415681183338, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 19.00it/s]


Epoch 194/200, Loss: 0.02915782481431961, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.36it/s]


Epoch 195/200, Loss: 0.015106111764907837, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.87it/s]


Epoch 196/200, Loss: 0.021989209577441216, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.91it/s]


Epoch 197/200, Loss: 0.015618045814335346, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.82it/s]


Epoch 198/200, Loss: 0.017759811133146286, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 23.35it/s]


Epoch 199/200, Loss: 0.02816978469491005, Accuracy: 15.0


100%|██████████| 10/10 [00:00<00:00, 21.79it/s]

Epoch 200/200, Loss: 0.016239430755376816, Accuracy: 15.0





## Model Evaluation


Preparing Validation data to be passed into the `model.evaluate()`


In [100]:
# Assuming you have a test dataset prepared in the same format as your training dataset
eval_dataset = CustomDataset(eval=True)  # You'll need to modify your CustomDataset class to accept this parameter and load the eval data
eval_dataloader = DataLoader(eval_dataset, batch_size = batch_size)

# Switch the model to evaluation mode
model.eval()

# Initialize the test loss
test_loss = 0

# We don't need to compute gradients during evaluation, so we wrap this in torch.no_grad()
with torch.no_grad():
    for sentences, labels, _ in eval_dataloader:
        sentences = sentences.view(batch_size, -1, input_size).to(device)
        labels = labels.view(batch_size, -1, output_size).to(device)
        # RNN
        if MODEL == 'rnn':
            hidden_state = model.init_hidden(batch_size=batch_size).to(device) # RNN has one hidden state
            outputs = model(sentences, hidden_state)
        
        # LSTM
        elif MODEL == 'lstm':
            hidden_state, cell_state = model.init_hidden(batch_size=batch_size)  # LSTM has two hidden states
            hidden_state = hidden_state.to(device)
            cell_state = cell_state.to(device)
            outputs = model(sentences, (hidden_state, cell_state))

        # Compute the loss
        loss = criterion(outputs, labels)

        # Accumulate the test loss
        test_loss += loss.item()

# Compute the average test loss
avg_test_loss = test_loss / len(eval_dataloader)

print(f'Average Evaluation Loss: {avg_test_loss}')


Average Evaluation Loss: 36.29100036621094


# Testing


testing on a given sentence


In [101]:
test_set = utils.read_data(f"./dataset/train_filtered.txt")
# test_set = utils.read_data(f"./dataset/val_filtered.txt")
# filtered_training_set = utils.filter_data(training_set)
# test_sentences = utils.split_data_to_sentences(filtered_training_set)[0:1]

In [102]:
# Assuming you have a test dataset prepared in the same format as your training dataset
test_dataset = CustomDataset(test=True, testdata = test_set) 
test_dataloader = DataLoader(test_dataset, batch_size = batch_size)

# Switch the model to evaluation mode
model.eval()

# We don't need to compute gradients during evaluation, so we wrap this in torch.no_grad()
sentences_diacritics_prediction = []
sentences_without_diacritics = []
original_diacritics = []
original_sentences_len = []
with torch.no_grad():
    for sentences, labels, sentence_length in test_dataloader:
        original_sentences_len.append(sentence_length)
        sentences_without_diacritics.extend(sentences)
        original_diacritics.extend(labels)
        sentences = sentences.view(batch_size, -1, input_size).to(device)
        labels = labels.view(batch_size, -1, output_size).to(device)
        # RNN
        if MODEL == 'rnn':
            hidden_state = model.init_hidden(batch_size=batch_size).to(device) # RNN has one hidden state
            outputs = model(sentences, hidden_state)
        
        
        # LSTM
        elif MODEL == 'lstm':
            hidden_state, cell_state = model.init_hidden(batch_size=batch_size)  # LSTM has two hidden states
            hidden_state = hidden_state.to(device)
            cell_state = cell_state.to(device)
            outputs = model(sentences, (hidden_state, cell_state))

        sentences_diacritics_prediction.extend(outputs.argmax(dim=2).cpu()) 

sentences_diacritics_prediction = np.array(sentences_diacritics_prediction)
sentences_without_diacritics = np.array(sentences_without_diacritics)
original_diacritics = np.array(original_diacritics)

print("Sentences diacritics prediction : ",sentences_diacritics_prediction.shape)
print("Sentences Without diacritics    : ",sentences_without_diacritics.shape)
print("Original diacritics             : ",original_diacritics.shape)

Sentences diacritics prediction :  (3, 150)
Sentences Without diacritics    :  (3, 150, 38)
Original diacritics             :  (3, 150, 15)


In [103]:
def AverageDER(sentences_diacritics_prediction,sentences_without_diacritics,original_sentences_len):
    diacritic_error_rate = 0
    number_of_mis_classified = 0
    number_of_chars_to_classify = 0
    for i, p in enumerate(sentences_diacritics_prediction):
        pred = character_encoding.index_to_char(p)
        s = character_encoding.oneHot_to_sentence(sentences_without_diacritics[i][0:original_sentences_len[i]]) # sentence without diacritics
        d = character_encoding.oneHot_to_diacritic(original_diacritics[i][0:original_sentences_len[i]])         # original diacritics of the sentence
        original_text = character_encoding.restore_diacritics(s, d)
        restored_text = character_encoding.restore_diacritics(s, pred)
        diac, miss = character_encoding.diacritics_error_rate(d, pred)
        # print("Original Sentence : ", original_text)
        # print("Restored Sentence : ", restored_text)
        # print(f"DER sentence [{i}] = {diac} %")
        diacritic_error_rate += diac
        number_of_mis_classified += miss
        number_of_chars_to_classify += int(original_sentences_len[i][0]) if len(pred) > original_sentences_len[i] else len(pred)

    # diacritic_error_rate /= len(sentences_diacritics_prediction)
    diacritic_error_rate = number_of_mis_classified / number_of_chars_to_classify * 100
    print("Diacritic Error Rate = ", diacritic_error_rate, "%")
    print("Diacritic Correct Rate = ", 100 - diacritic_error_rate, "%")
    print("Number of Misclassified = ", number_of_mis_classified, "out of", number_of_chars_to_classify)
    return diacritic_error_rate

In [104]:
avg_der = AverageDER(sentences_diacritics_prediction,sentences_without_diacritics,original_sentences_len)

Diacritic Error Rate =  81.65680473372781 %
Diacritic Correct Rate =  18.34319526627219 %
Number of Misclassified =  276 out of 338


In [105]:
# # Switch the model to evaluation mode
# model.eval()
# # Assume 'input_sentence' is your input sentence
# input_sentence = test_sentences[0]
# print("Input sentence : ",input_sentence)

# # Process the input_sentence in the same way as you did for your training data
# sentence_without_diacritics, original_diacritics = character_encoding.remove_diacritics(input_sentence, True)
# sentence = character_encoding.getSentenceVector(sentence_without_diacritics)
# sentence,_ = character_encoding.padding(sentence, len(character_encoding.ARABIC_ALPHABIT) + 2, max_length=PADDING_SIZE)
# diacritic = character_encoding.getDiacriticVector(original_diacritics)
# diacritic,_ = character_encoding.padding(diacritic, len(character_encoding.DIACRITICS), max_length=PADDING_SIZE)
# sentence = torch.tensor(sentence, dtype=(torch.float32)).unsqueeze(0).to(device)  # Add an extra dimension for batch and move to device

# # We don't need to compute gradients during evaluation, so we wrap this in torch.no_grad()
# with torch.no_grad():
#     hidden = model.init_hidden(batch_size=1).to(device)  # Batch size is 1 for inference
#     # Forward pass
#     output = model(sentence, hidden)
# print(sentence.shape)

# # The output is the model's prediction, you might want to post-process this output to convert it back into a readable format
# prediction = output.argmax(dim=2)  # This gives you the index of the highest value in the output tensor


In [106]:
# predicted_diacritics = character_encoding.index_to_char(prediction[0])
# der, miss = diacritics_error_rate(original_diacritics, predicted_diacritics)
# print("Diacritics error rate : ", der, "%")
# print("Correct diacritics rate : ", 100 - der, "%")
# print("Number of miss : ", miss, "out of ", len(original_diacritics))
# print("Original Sentence : ", input_sentence)
# restored_sentence = character_encoding.restore_diacritics(sentence_without_diacritics,predicted_diacritics)
# print("Restored Sentence : ", restored_sentence)
# print("Original diacritics : ", character_encoding.map_text_to_diacritic(original_diacritics))
# print("Predicted diacritics : ",character_encoding.map_text_to_diacritic( predicted_diacritics))