In [15]:
# to reload modules automatically without having to restart the kernel
%load_ext autoreload
%autoreload 2

import torch
import torch.optim as optim
import torch.utils.data as data
from letters_dataset import LettersDataset
import torch.nn as nn
from train_collections import *
import numpy as np
from tqdm import tqdm
from nltk.stem.isri import ISRIStemmer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# model and training parameters
batch_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epochs = 5

In [3]:
# load train data
dataset = LettersDataset(device=device)
loader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

w = 415


In [4]:
n_chars = dataset.get_input_vocab_size()
n_harakat = dataset.get_output_vocab_size()
print("n_chars: ", n_chars)
print("n_harakat: ", n_harakat)

n_chars:  41
n_harakat:  15


In [16]:
with open("clean_out/merged_unsplited.txt", "r", encoding="utf8") as f:
    text = f.read()

stemmer = ISRIStemmer()
# replace , and - with space
text = text.replace("،", "")
text = text.replace("-", "")
text = text.split("\n")

text = [sentence.split() for sentence in text]
lengths = [len(sentence) for sentence in text]
lengths = np.cumsum(lengths)
text = [[stemmer.stem(word) for word in sentence] for sentence in text]

vocab = set([word for sentence in text for word in sentence] + ["<S>", "</S>", "<UNK>"])
vocab_size = len(vocab)

In [17]:
word2idx = {word: i for i, word in enumerate(vocab)}
idx2word = {i: word for i, word in enumerate(vocab)}
embedding_weights = np.load("embedding/embedding_weights.npy")


def get_word_embedding(word):
    print(word)
    if word in word2idx:
        return embedding_weights[word2idx[word]]
    else:
        return embedding_weights[word2idx["<UNK>"]]


# Average word embedding
def get_sentence_embedding(sentence):
    return np.mean([get_word_embedding(stemmer.stem(word)) for word in sentence.split()], axis=0)

In [5]:
from models.Accio import Accio

model = Accio(input_size=n_chars, output_size=n_harakat, device=device).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss(ignore_index=dataset.char_encoder.get_pad_id())

In [6]:
num_batches = len(loader)
print("Number of batches:", num_batches)
best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    for i, (X_batch, y_batch) in tqdm(enumerate(loader)):
        y_pred = ''
        # y_pred = model(X_batch)['diacritics']
        y_pred = model(X_batch)
        # we transpose because the loss function expects the second dimension to be the classes
        # y_pred is now (batch_size, n_classes, seq_len)
        y_pred = y_pred.transpose(1, 2)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss))

    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for (X_batch, y_batch) in loader:
            # y_pred = model(X_batch)['diacritics']
            y_pred = model(X_batch)
            y_pred = y_pred.transpose(1, 2)
            loss += loss_fn(y_pred, y_batch)
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))


Number of batches: 3150


2it [00:00,  3.48it/s]

Epoch 0, batch 0: Loss = 2.6405


111it [00:02, 52.31it/s]

Epoch 0, batch 100: Loss = 0.2375


207it [00:04, 51.37it/s]

Epoch 0, batch 200: Loss = 0.2042


307it [00:06, 51.77it/s]

Epoch 0, batch 300: Loss = 0.1471


409it [00:08, 52.65it/s]

Epoch 0, batch 400: Loss = 0.1085


511it [00:10, 52.80it/s]

Epoch 0, batch 500: Loss = 0.0845


607it [00:12, 50.12it/s]

Epoch 0, batch 600: Loss = 0.1162


708it [00:14, 49.26it/s]

Epoch 0, batch 700: Loss = 0.0811


810it [00:16, 54.09it/s]

Epoch 0, batch 800: Loss = 0.0852


906it [00:18, 52.74it/s]

Epoch 0, batch 900: Loss = 0.0554


1007it [00:20, 50.65it/s]

Epoch 0, batch 1000: Loss = 0.0530


1106it [00:22, 50.47it/s]

Epoch 0, batch 1100: Loss = 0.0641


1205it [00:24, 48.61it/s]

Epoch 0, batch 1200: Loss = 0.0716


1307it [00:26, 50.46it/s]

Epoch 0, batch 1300: Loss = 0.0583


1408it [00:28, 51.21it/s]

Epoch 0, batch 1400: Loss = 0.0546


1510it [00:30, 51.20it/s]

Epoch 0, batch 1500: Loss = 0.0438


1606it [00:32, 50.55it/s]

Epoch 0, batch 1600: Loss = 0.0408


1711it [00:34, 48.76it/s]

Epoch 0, batch 1700: Loss = 0.0546


1810it [00:36, 47.77it/s]

Epoch 0, batch 1800: Loss = 0.0451


1906it [00:38, 47.28it/s]

Epoch 0, batch 1900: Loss = 0.0468


2005it [00:40, 47.15it/s]

Epoch 0, batch 2000: Loss = 0.0366


2110it [00:42, 48.43it/s]

Epoch 0, batch 2100: Loss = 0.0691


2208it [00:44, 48.57it/s]

Epoch 0, batch 2200: Loss = 0.0385


2305it [00:46, 46.87it/s]

Epoch 0, batch 2300: Loss = 0.0369


2406it [00:49, 47.94it/s]

Epoch 0, batch 2400: Loss = 0.0460


2507it [00:51, 49.34it/s]

Epoch 0, batch 2500: Loss = 0.0336


2609it [00:53, 46.64it/s]

Epoch 0, batch 2600: Loss = 0.0375


2707it [00:55, 47.66it/s]

Epoch 0, batch 2700: Loss = 0.0343


2803it [00:57, 40.67it/s]

Epoch 0, batch 2800: Loss = 0.0399


2909it [00:59, 48.32it/s]

Epoch 0, batch 2900: Loss = 0.0367


3005it [01:01, 47.78it/s]

Epoch 0, batch 3000: Loss = 0.0295


3106it [01:03, 47.53it/s]

Epoch 0, batch 3100: Loss = 0.0349


3150it [01:04, 48.64it/s]


Epoch 0: Cross-entropy: 105.3544


4it [00:00, 37.74it/s]

Epoch 1, batch 0: Loss = 0.0244


108it [00:02, 46.57it/s]

Epoch 1, batch 100: Loss = 0.0384


208it [00:04, 45.29it/s]

Epoch 1, batch 200: Loss = 0.0484


308it [00:07, 45.41it/s]

Epoch 1, batch 300: Loss = 0.0323


406it [00:09, 40.29it/s]

Epoch 1, batch 400: Loss = 0.0224


505it [00:11, 38.36it/s]

Epoch 1, batch 500: Loss = 0.0299


610it [00:14, 45.76it/s]

Epoch 1, batch 600: Loss = 0.0287


710it [00:16, 46.53it/s]

Epoch 1, batch 700: Loss = 0.0314


805it [00:18, 44.99it/s]

Epoch 1, batch 800: Loss = 0.0295


905it [00:20, 46.64it/s]

Epoch 1, batch 900: Loss = 0.0232


1005it [00:22, 46.99it/s]

Epoch 1, batch 1000: Loss = 0.0339


1105it [00:25, 46.59it/s]

Epoch 1, batch 1100: Loss = 0.0275


1209it [00:27, 45.43it/s]

Epoch 1, batch 1200: Loss = 0.0287


1306it [00:29, 45.06it/s]

Epoch 1, batch 1300: Loss = 0.0203


1411it [00:32, 45.04it/s]

Epoch 1, batch 1400: Loss = 0.0203


1506it [00:34, 46.34it/s]

Epoch 1, batch 1500: Loss = 0.0232


1611it [00:36, 47.11it/s]

Epoch 1, batch 1600: Loss = 0.0319


1706it [00:38, 44.95it/s]

Epoch 1, batch 1700: Loss = 0.0303


1806it [00:40, 46.27it/s]

Epoch 1, batch 1800: Loss = 0.0215


1906it [00:43, 45.04it/s]

Epoch 1, batch 1900: Loss = 0.0210


2006it [00:45, 45.79it/s]

Epoch 1, batch 2000: Loss = 0.0226


2106it [00:47, 43.11it/s]

Epoch 1, batch 2100: Loss = 0.0316


2206it [00:50, 45.09it/s]

Epoch 1, batch 2200: Loss = 0.0236


2306it [00:52, 42.38it/s]

Epoch 1, batch 2300: Loss = 0.0440


2406it [00:54, 46.50it/s]

Epoch 1, batch 2400: Loss = 0.0200


2506it [00:56, 45.22it/s]

Epoch 1, batch 2500: Loss = 0.0281


2606it [00:58, 45.30it/s]

Epoch 1, batch 2600: Loss = 0.0234


2706it [01:01, 45.77it/s]

Epoch 1, batch 2700: Loss = 0.0190


2806it [01:03, 46.01it/s]

Epoch 1, batch 2800: Loss = 0.0211


2906it [01:05, 43.68it/s]

Epoch 1, batch 2900: Loss = 0.0258


3006it [01:07, 46.52it/s]

Epoch 1, batch 3000: Loss = 0.0184


3105it [01:10, 43.48it/s]

Epoch 1, batch 3100: Loss = 0.0275


3150it [01:11, 44.19it/s]


Epoch 1: Cross-entropy: 73.4921


3it [00:00, 30.00it/s]

Epoch 2, batch 0: Loss = 0.0228


105it [00:02, 37.68it/s]

Epoch 2, batch 100: Loss = 0.0197


209it [00:05, 44.33it/s]

Epoch 2, batch 200: Loss = 0.0169


309it [00:07, 44.29it/s]

Epoch 2, batch 300: Loss = 0.0302


409it [00:09, 42.83it/s]

Epoch 2, batch 400: Loss = 0.0268


507it [00:12, 43.19it/s]

Epoch 2, batch 500: Loss = 0.0257


606it [00:14, 44.08it/s]

Epoch 2, batch 600: Loss = 0.0172


706it [00:16, 40.72it/s]

Epoch 2, batch 700: Loss = 0.0253


806it [00:19, 45.16it/s]

Epoch 2, batch 800: Loss = 0.0287


906it [00:21, 46.46it/s]

Epoch 2, batch 900: Loss = 0.0251


1006it [00:23, 45.55it/s]

Epoch 2, batch 1000: Loss = 0.0206


1106it [00:25, 45.44it/s]

Epoch 2, batch 1100: Loss = 0.0302


1206it [00:28, 42.97it/s]

Epoch 2, batch 1200: Loss = 0.0260


1306it [00:30, 45.01it/s]

Epoch 2, batch 1300: Loss = 0.0189


1406it [00:32, 45.70it/s]

Epoch 2, batch 1400: Loss = 0.0184


1506it [00:34, 44.95it/s]

Epoch 2, batch 1500: Loss = 0.0218


1606it [00:37, 36.77it/s]

Epoch 2, batch 1600: Loss = 0.0220


1703it [00:39, 41.88it/s]

Epoch 2, batch 1700: Loss = 0.0231


1806it [00:42, 44.42it/s]

Epoch 2, batch 1800: Loss = 0.0245


1906it [00:44, 42.42it/s]

Epoch 2, batch 1900: Loss = 0.0125


2006it [00:46, 42.95it/s]

Epoch 2, batch 2000: Loss = 0.0194


2106it [00:49, 44.12it/s]

Epoch 2, batch 2100: Loss = 0.0270


2208it [00:51, 43.03it/s]

Epoch 2, batch 2200: Loss = 0.0246


2308it [00:54, 42.41it/s]

Epoch 2, batch 2300: Loss = 0.0139


2408it [00:56, 37.18it/s]

Epoch 2, batch 2400: Loss = 0.0191


2507it [00:59, 35.60it/s]

Epoch 2, batch 2500: Loss = 0.0171


2608it [01:02, 44.05it/s]

Epoch 2, batch 2600: Loss = 0.0202


2708it [01:04, 42.11it/s]

Epoch 2, batch 2700: Loss = 0.0168


2808it [01:06, 43.58it/s]

Epoch 2, batch 2800: Loss = 0.0150


2908it [01:09, 44.20it/s]

Epoch 2, batch 2900: Loss = 0.0202


3005it [01:11, 32.88it/s]

Epoch 2, batch 3000: Loss = 0.0208


3106it [01:14, 44.30it/s]

Epoch 2, batch 3100: Loss = 0.0266


3150it [01:15, 41.82it/s]


Epoch 2: Cross-entropy: 62.4418


2it [00:00, 15.87it/s]

Epoch 3, batch 0: Loss = 0.0257


104it [00:03, 36.37it/s]

Epoch 3, batch 100: Loss = 0.0162


208it [00:06, 37.99it/s]

Epoch 3, batch 200: Loss = 0.0161


307it [00:09, 42.59it/s]

Epoch 3, batch 300: Loss = 0.0159


407it [00:11, 40.81it/s]

Epoch 3, batch 400: Loss = 0.0234


507it [00:13, 44.24it/s]

Epoch 3, batch 500: Loss = 0.0134


607it [00:16, 42.40it/s]

Epoch 3, batch 600: Loss = 0.0248


707it [00:18, 41.06it/s]

Epoch 3, batch 700: Loss = 0.0181


807it [00:21, 40.15it/s]

Epoch 3, batch 800: Loss = 0.0201


907it [00:23, 43.07it/s]

Epoch 3, batch 900: Loss = 0.0225


1007it [00:25, 45.01it/s]

Epoch 3, batch 1000: Loss = 0.0214


1107it [00:28, 44.66it/s]

Epoch 3, batch 1100: Loss = 0.0193


1207it [00:30, 41.74it/s]

Epoch 3, batch 1200: Loss = 0.0160


1309it [00:32, 40.19it/s]

Epoch 3, batch 1300: Loss = 0.0261


1404it [00:35, 41.94it/s]

Epoch 3, batch 1400: Loss = 0.0162


1508it [00:37, 42.77it/s]

Epoch 3, batch 1500: Loss = 0.0212


1607it [00:40, 40.72it/s]

Epoch 3, batch 1600: Loss = 0.0126


1707it [00:42, 44.58it/s]

Epoch 3, batch 1700: Loss = 0.0156


1807it [00:45, 44.32it/s]

Epoch 3, batch 1800: Loss = 0.0175


1907it [00:47, 42.98it/s]

Epoch 3, batch 1900: Loss = 0.0192


2007it [00:49, 42.46it/s]

Epoch 3, batch 2000: Loss = 0.0213


2107it [00:52, 44.70it/s]

Epoch 3, batch 2100: Loss = 0.0167


2207it [00:54, 41.25it/s]

Epoch 3, batch 2200: Loss = 0.0189


2308it [00:57, 39.35it/s]

Epoch 3, batch 2300: Loss = 0.0171


2405it [01:00, 33.70it/s]

Epoch 3, batch 2400: Loss = 0.0220


2508it [01:03, 33.82it/s]

Epoch 3, batch 2500: Loss = 0.0221


2606it [01:05, 39.04it/s]

Epoch 3, batch 2600: Loss = 0.0379


2705it [01:08, 31.54it/s]

Epoch 3, batch 2700: Loss = 0.0145


2806it [01:11, 40.52it/s]

Epoch 3, batch 2800: Loss = 0.0216


2906it [01:13, 39.33it/s]

Epoch 3, batch 2900: Loss = 0.0249


3009it [01:16, 38.95it/s]

Epoch 3, batch 3000: Loss = 0.0142


3105it [01:19, 39.98it/s]

Epoch 3, batch 3100: Loss = 0.0137


3150it [01:20, 39.03it/s]


Epoch 3: Cross-entropy: 56.1514


3it [00:00, 21.90it/s]

Epoch 4, batch 0: Loss = 0.0139


106it [00:02, 34.96it/s]

Epoch 4, batch 100: Loss = 0.0221


207it [00:05, 34.15it/s]

Epoch 4, batch 200: Loss = 0.0180


306it [00:08, 38.94it/s]

Epoch 4, batch 300: Loss = 0.0206


407it [00:11, 39.13it/s]

Epoch 4, batch 400: Loss = 0.0169


506it [00:13, 39.78it/s]

Epoch 4, batch 500: Loss = 0.0234


608it [00:16, 41.46it/s]

Epoch 4, batch 600: Loss = 0.0213


706it [00:19, 40.85it/s]

Epoch 4, batch 700: Loss = 0.0176


809it [00:21, 38.58it/s]

Epoch 4, batch 800: Loss = 0.0235


907it [00:24, 41.86it/s]

Epoch 4, batch 900: Loss = 0.0147


1005it [00:26, 35.40it/s]

Epoch 4, batch 1000: Loss = 0.0118


1106it [00:29, 33.40it/s]

Epoch 4, batch 1100: Loss = 0.0173


1207it [00:32, 38.16it/s]

Epoch 4, batch 1200: Loss = 0.0144


1305it [00:35, 40.47it/s]

Epoch 4, batch 1300: Loss = 0.0169


1407it [00:38, 41.07it/s]

Epoch 4, batch 1400: Loss = 0.0154


1504it [00:40, 39.95it/s]

Epoch 4, batch 1500: Loss = 0.0168


1606it [00:43, 33.21it/s]

Epoch 4, batch 1600: Loss = 0.0197


1705it [00:46, 41.97it/s]

Epoch 4, batch 1700: Loss = 0.0227


1809it [00:48, 42.16it/s]

Epoch 4, batch 1800: Loss = 0.0088


1909it [00:51, 42.53it/s]

Epoch 4, batch 1900: Loss = 0.0195


2008it [00:53, 41.52it/s]

Epoch 4, batch 2000: Loss = 0.0198


2108it [00:56, 40.04it/s]

Epoch 4, batch 2100: Loss = 0.0154


2208it [00:58, 39.63it/s]

Epoch 4, batch 2200: Loss = 0.0178


2305it [01:01, 40.18it/s]

Epoch 4, batch 2300: Loss = 0.0142


2408it [01:03, 38.84it/s]

Epoch 4, batch 2400: Loss = 0.0198


2505it [01:06, 38.28it/s]

Epoch 4, batch 2500: Loss = 0.0111


2605it [01:09, 34.16it/s]

Epoch 4, batch 2600: Loss = 0.0134


2705it [01:12, 34.41it/s]

Epoch 4, batch 2700: Loss = 0.0188


2808it [01:15, 41.06it/s]

Epoch 4, batch 2800: Loss = 0.0136


2907it [01:17, 40.97it/s]

Epoch 4, batch 2900: Loss = 0.0146


3004it [01:20, 39.19it/s]

Epoch 4, batch 3000: Loss = 0.0171


3104it [01:24, 21.97it/s]

Epoch 4, batch 3100: Loss = 0.0171


3150it [01:26, 36.54it/s]


Epoch 4: Cross-entropy: 53.1279


In [7]:
val_dataset = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv', val_mode=True, device=device)

val_loader = data.DataLoader(val_dataset, batch_size=batch_size)
print(val_dataset.char_encoder.word2idx)

w = 1129
{'ا': 0, 'ب': 1, 'ت': 2, 'ث': 3, 'ج': 4, 'ح': 5, 'خ': 6, 'د': 7, 'ذ': 8, 'ر': 9, 'ز': 10, 'س': 11, 'ش': 12, 'ص': 13, 'ض': 14, 'ط': 15, 'ظ': 16, 'ع': 17, 'غ': 18, 'ف': 19, 'ق': 20, 'ك': 21, 'ل': 22, 'م': 23, 'ن': 24, 'ه': 25, 'و': 26, 'ي': 27, 'ى': 28, 'ة': 29, 'آ': 30, 'أ': 31, 'إ': 32, 'ء': 33, 'ؤ': 34, 'ئ': 35, ' ': 36, '،': 37, '-': 38, '<pad>': 39, '<unk>': 40}


In [8]:
# evaluaate accuracy on validation set
model.eval()
letter_haraka = []
with torch.no_grad():
    for (X_batch, y_batch) in val_loader:
        # y_pred = model(X_batch)['diacritics']
        y_pred = model(X_batch)
        # we transpose because the loss function expects the second dimension to be the classes
        # y_pred is now (batch_size, n_classes, seq_len)
        y_pred = y_pred.transpose(1, 2)
        _, predicted = torch.max(y_pred.data, 1)
        # Count only non-padding characters
        for x, y in zip(X_batch, predicted):
            for xx, yy in zip(x, y):
                # we reached the end of the sentence
                # print(xx.item())
                # print(val_dataset.char_encoder.get_pad_id())
                # print(val_dataset.char_encoder.get_id_by_token(UNK_TOKEN))
                if xx.item() == val_dataset.char_encoder.get_pad_id():
                    break
                ll = val_dataset.char_encoder.is_arabic_letter(xx.item())
                if ll:
                    letter_haraka.append([ll, yy.item()])

# save ID,Label pairs in a csv file
import pandas as pd

df = pd.DataFrame(letter_haraka, columns=['letter', 'label'])
df.to_csv('./results/letter_haraka.csv', index=True, index_label='ID')



In [9]:
gold_val = pd.read_csv('clean_out/val_gold.csv', index_col=0)
sys_val = pd.read_csv('results/letter_haraka.csv', index_col=0)
# Accuracy per letter
# print(gold_val.head())
# print(sys_val.head())   
# print(gold_val.iloc[0]['label'])

correct = 0
total = len(gold_val)
for i in range(total):
    # print(gold_val[i][0], sys_val[i][0])
    correct += (gold_val.iloc[i]['label'] == sys_val.iloc[i]['label'])

print("Accuracy: %.2f%%" % (100.0 * correct / total))

Accuracy: 94.07%


In [10]:
# save model 
# torch.save(model, 'models/lstm.pth')
# save model state dict
torch.save(model.state_dict(), 'models/Accio_4.pth')
# load model state dict
# model = BiLSTM()
# model.load_state_dict(torch.load('models/bilstm.pth'))
# load model
# model = torch.load('models/___.pth')

In [11]:
# load model and test again

In [12]:
print('DER of the network on the validation set: %d %%' % (100 * (1 - correct / total)))


DER of the network on the validation set: 5 %
