In [4]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import re
import sys
import os
import pickle
sys.path.insert(0, os.path.abspath('../..'))
from preprocessing.preprocess import preprocess_data, remove_diactrics
import textwrap

In [5]:
def count_spaces(input_string):
  space_count = len(re.findall(r'\s', input_string))
  return space_count

In [22]:
dataset_path = '../../dataset'

limit = 2
# preprocess and save the data
preprocess_data(data_type='train', dataset_path=dataset_path)
preprocess_data(data_type='val', dataset_path=dataset_path)

max_len = 600
# load data
training_data = []
training_spaces = []
# load data
with open(f'{dataset_path}/cleaned_train_data_without_diacritics.txt', 'r', encoding='utf-8') as file:
    # read all lines into array of lines
    training_data_lines = file.readlines()
    for i in range(len(training_data_lines)):
        training_data_lines[i] = re.compile(r'[\n\r\t]').sub('', training_data_lines[i])
        training_data_lines[i] = re.compile(r'\s+').sub(' ', training_data_lines[i])
        training_data_lines[i] = training_data_lines[i].strip()

        dot_splitted_list = training_data_lines[i].split('.')
        
        # remove last string if empty
        if dot_splitted_list[-1] == '':
            dot_splitted_list = dot_splitted_list[:-1]
        
        for dot_splitted in dot_splitted_list:
            dot_splitted = dot_splitted.strip()
            # Split the line into sentences of max_len, without cutting words
            sentences = textwrap.wrap(dot_splitted, max_len)

            for sentence in sentences:
                training_data.append(sentence)
                training_spaces.append(count_spaces(sentence))
                
# validation
validation_data = []
validation_spaces = []
# load data
with open(f'{dataset_path}/cleaned_val_data_without_diacritics.txt', 'r', encoding='utf-8') as file:
    # read all lines into array of lines
    validation_data_lines = file.readlines()
    for i in range(len(validation_data_lines)):
        validation_data_lines[i] = re.compile(r'[\n\r\t]').sub('', validation_data_lines[i])
        validation_data_lines[i] = re.compile(r'\s+').sub(' ', validation_data_lines[i])
        validation_data_lines[i] = validation_data_lines[i].strip()

        dot_splitted_list = validation_data_lines[i].split('.')
        
        # remove last string if empty
        if dot_splitted_list[-1] == '':
            dot_splitted_list = dot_splitted_list[:-1]
        
        for dot_splitted in dot_splitted_list:
            dot_splitted = dot_splitted.strip()
            # Split the line into sentences of max_len, without cutting words
            sentences = textwrap.wrap(dot_splitted, max_len)

            for sentence in sentences:
                validation_data.append(sentence)
                validation_spaces.append(count_spaces(sentence))


training_data_with_diacritics = []
spaces_index = 0

with open(f'{dataset_path}/cleaned_train_data_with_diacritics.txt', 'r', encoding='utf-8') as file:
    training_data_with_diacritics_lines = file.readlines()
    for i in range(len(training_data_with_diacritics_lines)):
        training_data_with_diacritics_lines[i] = re.compile(r'[\n\r\t]').sub('', training_data_with_diacritics_lines[i])
        training_data_with_diacritics_lines[i] = re.compile(r'\s+').sub(' ', training_data_with_diacritics_lines[i])
        training_data_with_diacritics_lines[i] = training_data_with_diacritics_lines[i].strip()
        
        dot_splitted_list = training_data_with_diacritics_lines[i].split('.')
        
        # remove last string if empty
        if dot_splitted_list[-1] == '':
            dot_splitted_list = dot_splitted_list[:-1]
            
        for dot_splitted in dot_splitted_list:
            dot_splitted = dot_splitted.strip()
            remaining = dot_splitted
            remaining_length = len(remaining)
            while(remaining_length > 0):
                spaces_to_include = training_spaces[spaces_index]
                spaces_index += 1
                words = remaining.split()
                if len(words) <= spaces_to_include + 1:
                    training_data_with_diacritics.append(remaining.strip())
                    remaining_length = 0
                    break
                else:
                    sentence = ' '.join(words[:spaces_to_include + 1])
                    training_data_with_diacritics.append(sentence.strip())
                    remaining = ' '.join(words[spaces_to_include + 1:]).strip()
                    remaining_length = len(remaining)

validation_data_with_diacritics = []
spaces_index = 0

with open(f'{dataset_path}/cleaned_val_data_with_diacritics.txt', 'r', encoding='utf-8') as file:
    validation_data_with_diacritics_lines = file.readlines()
    for i in range(len(validation_data_with_diacritics_lines)):
        validation_data_with_diacritics_lines[i] = re.compile(r'[\n\r\t]').sub('', validation_data_with_diacritics_lines[i])
        validation_data_with_diacritics_lines[i] = re.compile(r'\s+').sub(' ', validation_data_with_diacritics_lines[i])
        validation_data_with_diacritics_lines[i] = validation_data_with_diacritics_lines[i].strip()
        
        dot_splitted_list = validation_data_with_diacritics_lines[i].split('.')
        
        # remove last string if empty
        if dot_splitted_list[-1] == '':
            dot_splitted_list = dot_splitted_list[:-1]
            
        for dot_splitted in dot_splitted_list:
            dot_splitted = dot_splitted.strip()
            remaining = dot_splitted
            remaining_length = len(remaining)
            while(remaining_length > 0):
                spaces_to_include = validation_spaces[spaces_index]
                spaces_index += 1
                words = remaining.split()
                if len(words) <= spaces_to_include + 1:
                    validation_data_with_diacritics.append(remaining.strip())
                    remaining_length = 0
                    break
                else:
                    sentence = ' '.join(words[:spaces_to_include + 1])
                    validation_data_with_diacritics.append(sentence.strip())
                    remaining = ' '.join(words[spaces_to_include + 1:]).strip()
                    remaining_length = len(remaining)

# test equality
for i in range(len(training_data)):
    if(training_data[i] != remove_diactrics([training_data_with_diacritics[i]])[0])  or len(training_data[i]) > max_len:
        print('Diacritized text != cleaned text!')

for i in range(len(validation_data)):
    if validation_data[i] != remove_diactrics([validation_data_with_diacritics[i]])[0] or len(validation_data[i]) > max_len :
        print('Diacritized text != cleaned text!')

print(len(training_data))
print(len(validation_data))
print(len(training_data_with_diacritics))
print(len(validation_data_with_diacritics))

print(training_data[0])
print(training_data_with_diacritics[0])

## Tokenize the text into sequences at the character level
char_to_index = {'ؤ': 1, 'ا': 2, 'ك': 3, 'ب': 4, 'ف': 5, 'و': 6, 'ث': 7, 'خ': 8, 'ظ': 9, 'أ': 10, 'ض': 11, 'س': 12, 'ج': 13, 'ى': 14, 'ص': 15, 'ع': 16, 'ذ': 17, 'ت': 18, ' ': 19, 'د': 20, 'ح': 21, 'ز': 22, 'ش': 23, 'إ': 24, 'ه': 25, 'ء': 26, 'م': 27, 'ق': 28, 'ة': 29, 'ن': 30, 'آ': 31, 'غ': 32, 'ي': 33, 'ر': 34, '!': 35, 'ل': 36, 'ئ': 37, 'ط': 38}
index_to_char = {1: 'ؤ', 2: 'ا', 3: 'ك', 4: 'ب', 5: 'ف', 6: 'و', 7: 'ث', 8: 'خ', 9: 'ظ', 10: 'أ', 11: 'ض', 12: 'س', 13: 'ج', 14: 'ى', 15: 'ص', 16: 'ع', 17: 'ذ', 18: 'ت', 19: ' ', 20: 'د', 21: 'ح', 22: 'ز', 23: 'ش', 24: 'إ', 25: 'ه', 26: 'ء', 27: 'م', 28: 'ق', 29: 'ة', 30: 'ن', 31: 'آ', 32: 'غ', 33: 'ي', 34: 'ر', 35: '!', 36: 'ل', 37: 'ئ', 38: 'ط'}

#for char, index in char_to_index.items():
#    print(f'{char} => {index}')

54936
2766
54936
2766
قوله : أو قطع الأول يده إلخ قال الزركشي
قَوْلُهُ : أَوْ قَطَعَ الْأَوَّلُ يَدَهُ إلَخْ قَالَ الزَّرْكَشِيُّ


In [3]:
# define the diacritics unicode and their corresponding labels classes indices
# note that index 0 is reserved for no diacritic
labels = {
    # no diacritic
    0: 0,
    # fath
    1614: 1,
    # damm
    1615: 2,
    # kasr
    1616: 3,
    # shadd
    1617: 4,
    # sukun
    1618: 5,
    # tanween bel fath
    1611: 6,
    # tanween bel damm
    1612: 7,
    # tanween bel kasr
    1613: 8,
    # shadd and fath
    (1617, 1614): 9,
    # shadd and damm
    (1617, 1615): 10,
    # shadd and kasr
    (1617, 1616): 11,
    # shadd and tanween bel fath
    (1617, 1611): 12,
    # shadd and tanween bel damm
    (1617, 1612): 13,
    # shadd and tanween bel kasr
    (1617, 1613): 14
}

indicies_to_labels = {
    # no diacritic
    0: 0,
    # fath
    1: 1614,
    # damm
    2: 1615,
    # kasr
    3: 1616,
    # shadd
    4: 1617,
    # sukun
    5: 1618,
    # tanween bel fath
    6: 1611,
    # tanween bel damm
    7: 1612,
    # tanween bel kasr
    8: 1613,
    # shadd and fath
    9: (1617, 1614),
    # shadd and damm
    10: (1617, 1615),
    # shadd and kasr
    11: (1617, 1616),
    # shadd and tanween bel fath
    12: (1617, 1611),
    # shadd and tanween bel damm
    13: (1617, 1612),
    # shadd and tanween bel kasr
    14: (1617, 1613)
}

print(labels)
print(indicies_to_labels)

{0: 0, 1614: 1, 1615: 2, 1616: 3, 1617: 4, 1618: 5, 1611: 6, 1612: 7, 1613: 8, (1617, 1614): 9, (1617, 1615): 10, (1617, 1616): 11, (1617, 1611): 12, (1617, 1612): 13, (1617, 1613): 14}
{0: 0, 1: 1614, 2: 1615, 3: 1616, 4: 1617, 5: 1618, 6: 1611, 7: 1612, 8: 1613, 9: (1617, 1614), 10: (1617, 1615), 11: (1617, 1616), 12: (1617, 1611), 13: (1617, 1612), 14: (1617, 1613)}


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); print(device)

cuda


In [5]:
# build one array that holds all sequences of training data
training_data_sequences = [[char_to_index[char] for char in sequence] for sequence in training_data]

# build one array that holds all sequences of validation data
validation_data_sequences = [[char_to_index[char] for char in sequence] for sequence in validation_data]

# Find the maximum sequence length
max_len = max(max(len(sequence) for sequence in training_data_sequences), max(len(sequence) for sequence in validation_data_sequences))

# Pad sequences to the maximum length
training_data_sequences = [sequence + [0] * (max_len - len(sequence)) for sequence in training_data_sequences]

# Pad sequences to the maximum length
validation_data_sequences = [sequence + [0] * (max_len - len(sequence)) for sequence in validation_data_sequences]

print(training_data_sequences[0][:10])
print(training_data[0][:10])
print(validation_data_sequences[0][:10])
print(validation_data[0][:10])

training_data_sequences = torch.tensor(training_data_sequences).to(device)


validation_data_sequences = torch.tensor(validation_data_sequences).to(device)

[40, 27, 11, 9, 24, 37, 24, 1, 27, 24]
قوله : أو 
[24, 40, 27, 11, 9, 24, 37, 24, 27, 11]
 قوله : ول


In [6]:
training_data_labels = []
training_size = len(training_data_with_diacritics)
for sentence_index in range(training_size):
  sentence_labels = []
  sentence_size = len(training_data_with_diacritics[sentence_index])
  index = 0
  while index < sentence_size:
      if ord(training_data_with_diacritics[sentence_index][index]) not in labels:
          # char is not a diacritic
          if (index + 1) < sentence_size and ord(training_data_with_diacritics[sentence_index][index + 1]) in labels:
              # char has a diacritic
              if ord(training_data_with_diacritics[sentence_index][index + 1]) == 1617:
                  # char has a shadd diacritic
                  if (index + 2) < sentence_size and ord(training_data_with_diacritics[sentence_index][index + 2]) in labels:
                      # char has a shadd and another diacritic
                      sentence_labels.append(labels[(1617, ord(training_data_with_diacritics[sentence_index][index + 2]))])
                      # skip next 2 diacritics chars
                      index += 3  # increment by 3 to skip two diacritic chars
                      continue
                  else:
                      # char has a shadd and no other diacritic
                      sentence_labels.append(labels[1617])
                      # skip next diacritic char
                      index += 2
                      continue
              # char has a diacritic other than shadd
              sentence_labels.append(labels[ord(training_data_with_diacritics[sentence_index][index + 1])])
              # skip next diacritic char
              index += 2  # increment by 2 to skip one diacritic char
              continue
          else:
              # char has no diacritic
              sentence_labels.append(0)
      index += 1  # increment by 1 for normal iteration

  training_data_labels.append(sentence_labels)

print(len(training_data_labels))
print(training_data_labels[0][:10])
print(training_data[0][:10])

# Pad sequences to the maximum length
training_data_labels = [sequence + [0] * (max_len - len(sequence)) for sequence in training_data_labels]

print(len(training_data_labels) * len(training_data_labels[0]))
print(len(training_data_sequences) * len(training_data_sequences[0]))

training_data_labels = torch.tensor(training_data_labels).to(device)

10
[1, 5, 2, 2, 0, 0, 0, 1, 5, 0]
قوله : أو 
6420
6420


In [7]:
validation_data_labels = []
validation_size = len(validation_data_with_diacritics)
for sentence_index in range(validation_size):
  sentence_labels = []
  sentence_size = len(validation_data_with_diacritics[sentence_index])
  index = 0
  while index < sentence_size:
      if ord(validation_data_with_diacritics[sentence_index][index]) not in labels:
          # char is not a diacritic
          if (index + 1) < sentence_size and ord(validation_data_with_diacritics[sentence_index][index + 1]) in labels:
              # char has a diacritic
              if ord(validation_data_with_diacritics[sentence_index][index + 1]) == 1617:
                  # char has a shadd diacritic
                  if (index + 2) < sentence_size and ord(validation_data_with_diacritics[sentence_index][index + 2]) in labels:
                      # char has a shadd and another diacritic
                      sentence_labels.append(labels[(1617, ord(validation_data_with_diacritics[sentence_index][index + 2]))])
                      # skip next 2 diacritics chars
                      index += 3  # increment by 3 to skip two diacritic chars
                      continue
                  else:
                      # char has a shadd and no other diacritic
                      sentence_labels.append(labels[1617])
                      # skip next diacritic char
                      index += 2
                      continue
              # char has a diacritic other than shadd
              sentence_labels.append(labels[ord(validation_data_with_diacritics[sentence_index][index + 1])])
              # skip next diacritic char
              index += 2  # increment by 2 to skip one diacritic char
              continue
          else:
              # char has no diacritic
              sentence_labels.append(0)
      index += 1  # increment by 1 for normal iteration

  validation_data_labels.append(sentence_labels)

print(len(validation_data_labels))
print(validation_data_labels[0][:10])
print(validation_data[0][:10])

# Pad sequences to the maximum length
validation_data_labels = [sequence + [0] * (max_len - len(sequence)) for sequence in validation_data_labels]

print(len(validation_data_labels) * len(validation_data_labels[0]))
print(len(validation_data_sequences) * len(validation_data_sequences[0]))

validation_data_labels = torch.tensor(validation_data_labels).to(device)

10
[0, 1, 5, 2, 2, 0, 0, 0, 1, 1]
 قوله : ول
6420
6420


In [8]:
training_dataset = TensorDataset(training_data_sequences, training_data_labels)

batch_size = 1
training_dataloader = DataLoader(training_dataset, batch_size=batch_size)

validation_dataset = TensorDataset(validation_data_sequences, validation_data_labels)

batch_size = 1
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size)

In [9]:
class CharLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, output_size, num_layers=1):
        super(CharLSTM, self).__init__()
        # chars embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_size)

        # LSTM layers
        # batch_first: it means that the input tensor has its first dimension representing the batch size
        self.lstm = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)

        # output layer
        self.output = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        embedded = self.embedding(x) # batch_size * seq_length * embedding_size
        lstm_out, _ = self.lstm(embedded) # batch_size * seq_length * hidden_size
        output = self.output(lstm_out)  # batch_size * seq_length * output_size
        output_softmax = F.softmax(output, dim=1)  # Apply softmax to the output
        return output_softmax

num_layers = 1
vocab_size = len(char_to_index) + 1 # +1 for the 0 padding
embedding_size = 2
output_size = len(labels)
hidden_size = 2
lr=0.001
num_epochs = 5

model = CharLSTM(vocab_size, embedding_size,  hidden_size, output_size, num_layers).to(device)

print(model)

CharLSTM(
  (embedding): Embedding(43, 2)
  (lstm): LSTM(2, 2, batch_first=True, bidirectional=True)
  (output): Linear(in_features=4, out_features=15, bias=True)
)


In [22]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# before training
correct_predictions = 0
total_predictions = 0
for validation_batch_sequences, validation_batch_labels in validation_dataloader:
    outputs = model(validation_batch_sequences).float() # batch_size * seq_length * output_size

    # Calculate accuracy using mask
    mask = (validation_batch_sequences != 0).float()

    # Calculate accuracy
    predicted_labels = outputs.argmax(dim=2)  # Get the index with the maximum probability

    # Only consider non-padded elements in accuracy calculation
    correct_predictions += ((predicted_labels == validation_batch_labels) * mask).sum().item()
    total_predictions += mask.sum().item()


accuracy = correct_predictions / total_predictions
print(f'Epoch 0/{num_epochs}, Accuracy: {accuracy * 100:.2f}%')

for epoch in range(num_epochs):
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    for batch_sequences, batch_labels in training_dataloader:
        optimizer.zero_grad()
        outputs = model(batch_sequences).float() # batch_size * seq_length * output_size
        # convert batch_labels to one hot encoding
        batch_labels_one_hot = F.one_hot(batch_labels, num_classes=output_size).float() # batch_size * seq_length * output_size

        # Calculate the mask
        mask = (batch_sequences != 0).float()

        # Apply the mask to both the outputs and the labels
        masked_outputs = outputs * mask.unsqueeze(2)
        masked_labels = batch_labels_one_hot * mask.unsqueeze(2)

        loss = criterion(masked_outputs, masked_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()


    for validation_batch_sequences, validation_batch_labels in validation_dataloader:
        outputs = model(validation_batch_sequences).float() # batch_size * seq_length * output_size
        # Calculate accuracy using mask
        mask = (validation_batch_sequences != 0).float()

        # Calculate accuracy
        predicted_labels = outputs.argmax(dim=2)  # Get the index with the maximum probability

        # Only consider non-padded elements in accuracy calculation
        correct_predictions += ((predicted_labels == validation_batch_labels) * mask).sum().item()
        total_predictions += mask.sum().item()

    accuracy = correct_predictions / total_predictions
    
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy * 100:.2f}%')

Epoch 0/5, Accuracy: 16.63%
Epoch 1/5, Loss: 866.4723424911499, Accuracy: 17.15%
Epoch 2/5, Loss: 866.4669589996338, Accuracy: 17.45%
Epoch 3/5, Loss: 866.461630821228, Accuracy: 17.79%
Epoch 4/5, Loss: 866.4562888145447, Accuracy: 17.75%
Epoch 5/5, Loss: 866.4509482383728, Accuracy: 17.96%
