In [2]:
# to reload modules automatically without having to restart the kernel
%load_ext autoreload
%autoreload 2

import torch
import torch.optim as optim
import torch.utils.data as data
from letters_dataset import LettersDataset
import torch.nn as nn
from train_collections import *
import numpy as np
from tqdm import tqdm

In [13]:
# Hyperparameters and compute device shared by the cells below.
batch_size = 64
n_epochs = 5
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Training data: build the dataset on the selected device (all other
# constructor arguments left at their defaults) and wrap it in a
# DataLoader that reshuffles the samples every epoch.
dataset = LettersDataset(device=device)
loader = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

w = 415


In [6]:
# Vocabulary sizes reported by the dataset: the input side (characters) and
# the output side (harakat classes). They are used to size the model.
n_chars, n_harakat = dataset.get_input_vocab_size(), dataset.get_output_vocab_size()
for label, size in (("n_chars: ", n_chars), ("n_harakat: ", n_harakat)):
    print(label, size)

n_chars:  41
n_harakat:  15


In [14]:
from models.Accio import Accio

# Build the network sized to the dataset vocabularies, then move it to the
# selected device.
model = Accio(input_size=n_chars, output_size=n_harakat, device=device)
model = model.to(device)

# Adam with PyTorch's default hyperparameters.
optimizer = optim.Adam(model.parameters())

# Cross-entropy that skips targets equal to the pad id.
# NOTE(review): the pad id is taken from the *character* encoder while the
# targets are harakat ids -- confirm the dataset pads its labels with this
# same id, otherwise padded positions are not actually ignored.
loss_fn = nn.CrossEntropyLoss(
    ignore_index=dataset.char_encoder.get_pad_id(),
)

In [15]:
# Train for `n_epochs`. Every 100 mini-batches the current batch loss is
# printed; after each epoch the model is scored over `loader` and the weights
# with the lowest summed loss so far are kept in `best_model`.
num_batches = len(loader)
print("Number of batches:", num_batches)
best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    # Wrap the loader itself (not the enumerate generator) so tqdm can read
    # its length and show a real progress bar with an ETA.
    for i, (X_batch, y_batch) in enumerate(tqdm(loader)):
        y_pred = model(X_batch)
        # we transpose because the loss function expects the second dimension to be the classes
        # y_pred is now (batch_size, n_classes, seq_len)
        y_pred = y_pred.transpose(1, 2)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss.item()))

    # End-of-epoch scoring.
    # NOTE(review): this pass iterates the *training* `loader`, so the number
    # below is a training loss summed over all batches, not a held-out
    # validation loss -- switch to a validation loader once one is available
    # at this point in the notebook.
    model.eval()
    loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch).transpose(1, 2)
            # .item() keeps the running total a plain Python float.
            loss += loss_fn(y_pred, y_batch).item()
    if loss < best_loss:
        best_loss = loss
        # state_dict() hands back references to the live parameter tensors,
        # which later optimizer steps overwrite in place. Clone them so the
        # snapshot really is the best epoch's weights.
        best_model = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))


Number of batches: 3150


4it [00:00, 10.56it/s]

Epoch 0, batch 0: Loss = 2.6814


103it [00:04, 26.25it/s]

Epoch 0, batch 100: Loss = 0.1166


205it [00:08, 25.60it/s]

Epoch 0, batch 200: Loss = 0.1604


304it [00:11, 26.00it/s]

Epoch 0, batch 300: Loss = 0.1534


406it [00:15, 25.76it/s]

Epoch 0, batch 400: Loss = 0.0936


505it [00:19, 25.03it/s]

Epoch 0, batch 500: Loss = 0.0842


604it [00:23, 24.00it/s]

Epoch 0, batch 600: Loss = 0.0518


703it [00:28, 24.02it/s]

Epoch 0, batch 700: Loss = 0.0568


805it [00:32, 23.88it/s]

Epoch 0, batch 800: Loss = 0.0432


904it [00:36, 23.70it/s]

Epoch 0, batch 900: Loss = 0.0531


1003it [00:40, 23.26it/s]

Epoch 0, batch 1000: Loss = 0.0372


1105it [00:45, 22.93it/s]

Epoch 0, batch 1100: Loss = 0.0479


1204it [00:49, 22.78it/s]

Epoch 0, batch 1200: Loss = 0.0427


1303it [00:53, 22.22it/s]

Epoch 0, batch 1300: Loss = 0.0317


1405it [00:58, 21.46it/s]

Epoch 0, batch 1400: Loss = 0.0284


1504it [01:03, 21.37it/s]

Epoch 0, batch 1500: Loss = 0.0508


1603it [01:07, 21.31it/s]

Epoch 0, batch 1600: Loss = 0.0295


1702it [01:12, 20.13it/s]

Epoch 0, batch 1700: Loss = 0.0313


1803it [01:17, 20.64it/s]

Epoch 0, batch 1800: Loss = 0.0271


1904it [01:22, 20.92it/s]

Epoch 0, batch 1900: Loss = 0.0397


2005it [01:27, 20.33it/s]

Epoch 0, batch 2000: Loss = 0.0262


2104it [01:32, 20.32it/s]

Epoch 0, batch 2100: Loss = 0.0260


2203it [01:37, 19.37it/s]

Epoch 0, batch 2200: Loss = 0.0320


2304it [01:42, 19.58it/s]

Epoch 0, batch 2300: Loss = 0.0273


2403it [01:47, 19.23it/s]

Epoch 0, batch 2400: Loss = 0.0243


2505it [01:52, 19.54it/s]

Epoch 0, batch 2500: Loss = 0.0277


2604it [01:57, 19.64it/s]

Epoch 0, batch 2600: Loss = 0.0146


2705it [02:03, 19.38it/s]

Epoch 0, batch 2700: Loss = 0.0239


2804it [02:08, 19.33it/s]

Epoch 0, batch 2800: Loss = 0.0190


2904it [02:13, 19.66it/s]

Epoch 0, batch 2900: Loss = 0.0184


3004it [02:18, 19.15it/s]

Epoch 0, batch 3000: Loss = 0.0200


3105it [02:23, 19.51it/s]

Epoch 0, batch 3100: Loss = 0.0204


3150it [02:26, 21.56it/s]


Epoch 0: Cross-entropy: 65.9903


2it [00:00, 15.82it/s]

Epoch 1, batch 0: Loss = 0.0232


105it [00:05, 19.97it/s]

Epoch 1, batch 100: Loss = 0.0165


204it [00:10, 19.01it/s]

Epoch 1, batch 200: Loss = 0.0196


303it [00:15, 18.71it/s]

Epoch 1, batch 300: Loss = 0.0257


403it [00:21, 18.65it/s]

Epoch 1, batch 400: Loss = 0.0186


503it [00:26, 18.91it/s]

Epoch 1, batch 500: Loss = 0.0167


603it [00:31, 18.69it/s]

Epoch 1, batch 600: Loss = 0.0254


705it [00:37, 18.94it/s]

Epoch 1, batch 700: Loss = 0.0210


805it [00:42, 19.03it/s]

Epoch 1, batch 800: Loss = 0.0189


903it [00:47, 18.28it/s]

Epoch 1, batch 900: Loss = 0.0114


1003it [00:53, 17.90it/s]

Epoch 1, batch 1000: Loss = 0.0242


1103it [00:58, 18.53it/s]

Epoch 1, batch 1100: Loss = 0.0222


1203it [01:03, 18.68it/s]

Epoch 1, batch 1200: Loss = 0.0262


1303it [01:09, 18.52it/s]

Epoch 1, batch 1300: Loss = 0.0166


1403it [01:14, 18.59it/s]

Epoch 1, batch 1400: Loss = 0.0154


1505it [01:20, 18.80it/s]

Epoch 1, batch 1500: Loss = 0.0171


1603it [01:25, 19.10it/s]

Epoch 1, batch 1600: Loss = 0.0186


1703it [01:30, 18.87it/s]

Epoch 1, batch 1700: Loss = 0.0146


1804it [01:36, 18.04it/s]

Epoch 1, batch 1800: Loss = 0.0206


1904it [01:41, 18.06it/s]

Epoch 1, batch 1900: Loss = 0.0166


2004it [01:47, 18.05it/s]

Epoch 1, batch 2000: Loss = 0.0115


2104it [01:53, 17.95it/s]

Epoch 1, batch 2100: Loss = 0.0142


2204it [01:58, 18.17it/s]

Epoch 1, batch 2200: Loss = 0.0227


2304it [02:03, 18.27it/s]

Epoch 1, batch 2300: Loss = 0.0240


2404it [02:09, 18.06it/s]

Epoch 1, batch 2400: Loss = 0.0175


2504it [02:15, 18.31it/s]

Epoch 1, batch 2500: Loss = 0.0156


2603it [02:20, 17.56it/s]

Epoch 1, batch 2600: Loss = 0.0109


2705it [02:26, 18.24it/s]

Epoch 1, batch 2700: Loss = 0.0152


2803it [02:31, 18.34it/s]

Epoch 1, batch 2800: Loss = 0.0163


2903it [02:36, 18.26it/s]

Epoch 1, batch 2900: Loss = 0.0136


3003it [02:42, 18.17it/s]

Epoch 1, batch 3000: Loss = 0.0182


3103it [02:47, 18.18it/s]

Epoch 1, batch 3100: Loss = 0.0132


3150it [02:50, 18.47it/s]


Epoch 1: Cross-entropy: 51.4951


2it [00:00, 16.95it/s]

Epoch 2, batch 0: Loss = 0.0114


105it [00:05, 18.56it/s]

Epoch 2, batch 100: Loss = 0.0154


203it [00:11, 18.68it/s]

Epoch 2, batch 200: Loss = 0.0201


303it [00:16, 18.39it/s]

Epoch 2, batch 300: Loss = 0.0144


403it [00:22, 18.34it/s]

Epoch 2, batch 400: Loss = 0.0116


503it [00:27, 18.44it/s]

Epoch 2, batch 500: Loss = 0.0180


603it [00:32, 18.52it/s]

Epoch 2, batch 600: Loss = 0.0148


703it [00:38, 18.43it/s]

Epoch 2, batch 700: Loss = 0.0150


803it [00:43, 17.91it/s]

Epoch 2, batch 800: Loss = 0.0176


903it [00:49, 18.39it/s]

Epoch 2, batch 900: Loss = 0.0150


1003it [00:54, 18.40it/s]

Epoch 2, batch 1000: Loss = 0.0149


1103it [01:00, 18.32it/s]

Epoch 2, batch 1100: Loss = 0.0156


1203it [01:05, 18.18it/s]

Epoch 2, batch 1200: Loss = 0.0115


1303it [01:11, 18.48it/s]

Epoch 2, batch 1300: Loss = 0.0162


1403it [01:16, 18.49it/s]

Epoch 2, batch 1400: Loss = 0.0159


1503it [01:21, 18.46it/s]

Epoch 2, batch 1500: Loss = 0.0184


1603it [01:27, 18.51it/s]

Epoch 2, batch 1600: Loss = 0.0142


1703it [01:32, 18.43it/s]

Epoch 2, batch 1700: Loss = 0.0086


1803it [01:38, 18.35it/s]

Epoch 2, batch 1800: Loss = 0.0146


1903it [01:43, 18.28it/s]

Epoch 2, batch 1900: Loss = 0.0116


2003it [01:49, 18.26it/s]

Epoch 2, batch 2000: Loss = 0.0120


2103it [01:54, 18.29it/s]

Epoch 2, batch 2100: Loss = 0.0115


2203it [02:00, 18.47it/s]

Epoch 2, batch 2200: Loss = 0.0122


2303it [02:05, 18.30it/s]

Epoch 2, batch 2300: Loss = 0.0231


2403it [02:11, 18.52it/s]

Epoch 2, batch 2400: Loss = 0.0153


2503it [02:16, 18.35it/s]

Epoch 2, batch 2500: Loss = 0.0121


2603it [02:22, 18.29it/s]

Epoch 2, batch 2600: Loss = 0.0121


2703it [02:27, 18.56it/s]

Epoch 2, batch 2700: Loss = 0.0205


2803it [02:32, 18.47it/s]

Epoch 2, batch 2800: Loss = 0.0218


2903it [02:38, 18.34it/s]

Epoch 2, batch 2900: Loss = 0.0106


3003it [02:43, 18.43it/s]

Epoch 2, batch 3000: Loss = 0.0184


3103it [02:49, 18.45it/s]

Epoch 2, batch 3100: Loss = 0.0176


3150it [02:51, 18.33it/s]


Epoch 2: Cross-entropy: 42.7601


2it [00:00, 16.99it/s]

Epoch 3, batch 0: Loss = 0.0224


103it [00:05, 17.82it/s]

Epoch 3, batch 100: Loss = 0.0101


204it [00:11, 19.05it/s]

Epoch 3, batch 200: Loss = 0.0097


304it [00:16, 18.13it/s]

Epoch 3, batch 300: Loss = 0.0110


404it [00:22, 18.06it/s]

Epoch 3, batch 400: Loss = 0.0155


504it [00:27, 18.12it/s]

Epoch 3, batch 500: Loss = 0.0122


604it [00:33, 18.13it/s]

Epoch 3, batch 600: Loss = 0.0106


704it [00:38, 18.08it/s]

Epoch 3, batch 700: Loss = 0.0126


804it [00:44, 18.12it/s]

Epoch 3, batch 800: Loss = 0.0169


903it [00:49, 18.26it/s]

Epoch 3, batch 900: Loss = 0.0099


1003it [00:55, 18.03it/s]

Epoch 3, batch 1000: Loss = 0.0109


1103it [01:00, 17.93it/s]

Epoch 3, batch 1100: Loss = 0.0131


1203it [01:06, 18.30it/s]

Epoch 3, batch 1200: Loss = 0.0069


1303it [01:12, 17.66it/s]

Epoch 3, batch 1300: Loss = 0.0128


1403it [01:18, 17.38it/s]

Epoch 3, batch 1400: Loss = 0.0158


1503it [01:23, 17.72it/s]

Epoch 3, batch 1500: Loss = 0.0119


1603it [01:29, 17.01it/s]

Epoch 3, batch 1600: Loss = 0.0181


1703it [01:34, 17.69it/s]

Epoch 3, batch 1700: Loss = 0.0168


1802it [01:40, 17.16it/s]

Epoch 3, batch 1800: Loss = 0.0095


1904it [01:46, 18.07it/s]

Epoch 3, batch 1900: Loss = 0.0166


2003it [01:52, 15.08it/s]

Epoch 3, batch 2000: Loss = 0.0080


2103it [01:58, 17.25it/s]

Epoch 3, batch 2100: Loss = 0.0095


2204it [02:04, 17.71it/s]

Epoch 3, batch 2200: Loss = 0.0111


2304it [02:09, 16.58it/s]

Epoch 3, batch 2300: Loss = 0.0128


2404it [02:15, 17.96it/s]

Epoch 3, batch 2400: Loss = 0.0112


2504it [02:20, 18.06it/s]

Epoch 3, batch 2500: Loss = 0.0152


2604it [02:26, 17.75it/s]

Epoch 3, batch 2600: Loss = 0.0134


2704it [02:32, 17.68it/s]

Epoch 3, batch 2700: Loss = 0.0165


2804it [02:37, 17.77it/s]

Epoch 3, batch 2800: Loss = 0.0140


2904it [02:43, 18.06it/s]

Epoch 3, batch 2900: Loss = 0.0148


3004it [02:49, 17.86it/s]

Epoch 3, batch 3000: Loss = 0.0125


3104it [02:54, 17.90it/s]

Epoch 3, batch 3100: Loss = 0.0136


3150it [02:57, 17.78it/s]


Epoch 3: Cross-entropy: 37.7168


2it [00:00, 15.50it/s]

Epoch 4, batch 0: Loss = 0.0121


104it [00:05, 17.89it/s]

Epoch 4, batch 100: Loss = 0.0135


204it [00:11, 17.96it/s]

Epoch 4, batch 200: Loss = 0.0137


304it [00:17, 17.95it/s]

Epoch 4, batch 300: Loss = 0.0085


404it [00:22, 18.18it/s]

Epoch 4, batch 400: Loss = 0.0100


504it [00:28, 17.39it/s]

Epoch 4, batch 500: Loss = 0.0135


603it [00:34, 17.80it/s]

Epoch 4, batch 600: Loss = 0.0083


703it [00:39, 18.47it/s]

Epoch 4, batch 700: Loss = 0.0104


803it [00:45, 18.45it/s]

Epoch 4, batch 800: Loss = 0.0120


903it [00:50, 18.03it/s]

Epoch 4, batch 900: Loss = 0.0126


1005it [00:56, 17.64it/s]

Epoch 4, batch 1000: Loss = 0.0108


1103it [01:02, 18.26it/s]

Epoch 4, batch 1100: Loss = 0.0155


1203it [01:08, 18.60it/s]

Epoch 4, batch 1200: Loss = 0.0093


1303it [01:14, 17.73it/s]

Epoch 4, batch 1300: Loss = 0.0123


1403it [01:19, 17.87it/s]

Epoch 4, batch 1400: Loss = 0.0265


1503it [01:25, 17.86it/s]

Epoch 4, batch 1500: Loss = 0.0110


1603it [01:30, 17.90it/s]

Epoch 4, batch 1600: Loss = 0.0130


1703it [01:36, 17.82it/s]

Epoch 4, batch 1700: Loss = 0.0134


1803it [01:42, 18.04it/s]

Epoch 4, batch 1800: Loss = 0.0170


1903it [01:47, 17.81it/s]

Epoch 4, batch 1900: Loss = 0.0107


2003it [01:53, 17.72it/s]

Epoch 4, batch 2000: Loss = 0.0127


2103it [01:59, 17.65it/s]

Epoch 4, batch 2100: Loss = 0.0156


2203it [02:04, 17.88it/s]

Epoch 4, batch 2200: Loss = 0.0120


2303it [02:10, 17.31it/s]

Epoch 4, batch 2300: Loss = 0.0095


2403it [02:16, 17.80it/s]

Epoch 4, batch 2400: Loss = 0.0127


2503it [02:21, 17.69it/s]

Epoch 4, batch 2500: Loss = 0.0106


2603it [02:27, 17.62it/s]

Epoch 4, batch 2600: Loss = 0.0076


2703it [02:32, 17.45it/s]

Epoch 4, batch 2700: Loss = 0.0092


2803it [02:38, 17.16it/s]

Epoch 4, batch 2800: Loss = 0.0152


2903it [02:44, 17.62it/s]

Epoch 4, batch 2900: Loss = 0.0126


3003it [02:49, 17.82it/s]

Epoch 4, batch 3000: Loss = 0.0132


3103it [02:55, 17.40it/s]

Epoch 4, batch 3100: Loss = 0.0148


3150it [02:58, 17.69it/s]


Epoch 4: Cross-entropy: 34.9279


In [16]:
import pandas as pd

# Predict one haraka class per Arabic letter of the validation set and write
# the (ID, letter, label) rows to ./results/letter_haraka.csv.
val_dataset = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv',val_mode=True, device=device)

val_loader = data.DataLoader(val_dataset,  batch_size=batch_size)
print(val_dataset.char_encoder.word2idx)

# Fetch the encoder and pad id once instead of on every character.
# (assumes get_pad_id() is a plain getter with a fixed result)
val_encoder = val_dataset.char_encoder
pad_id = val_encoder.get_pad_id()

# run the model over the validation set and collect its predictions
model.eval()
letter_haraka = []
with torch.no_grad():
    for (X_batch, y_batch) in val_loader:
        # Same layout as in training: after the transpose the classes sit on
        # dimension 1, so the max over dim 1 gives one class id per position.
        y_pred = model(X_batch).transpose(1, 2)
        _, predicted = torch.max(y_pred, 1)
        # Convert each batch to Python lists once; calling .item() on every
        # element would force a separate device-to-host sync per character.
        for char_ids, haraka_ids in zip(X_batch.tolist(), predicted.tolist()):
            for char_id, haraka_id in zip(char_ids, haraka_ids):
                # the first pad id marks the end of the sentence
                if char_id == pad_id:
                    break
                # Keep only positions the encoder reports as Arabic letters;
                # the truthy return value is stored as the 'letter' column.
                letter = val_encoder.is_arabic_letter(char_id)
                if letter:
                    letter_haraka.append([letter, haraka_id])

# save ID,Label pairs in a csv file
df = pd.DataFrame(letter_haraka, columns=['letter','label'])
df.to_csv('./results/letter_haraka.csv', index=True, index_label='ID')



w = 1129
{'ا': 0, 'ب': 1, 'ت': 2, 'ث': 3, 'ج': 4, 'ح': 5, 'خ': 6, 'د': 7, 'ذ': 8, 'ر': 9, 'ز': 10, 'س': 11, 'ش': 12, 'ص': 13, 'ض': 14, 'ط': 15, 'ظ': 16, 'ع': 17, 'غ': 18, 'ف': 19, 'ق': 20, 'ك': 21, 'ل': 22, 'م': 23, 'ن': 24, 'ه': 25, 'و': 26, 'ي': 27, 'ى': 28, 'ة': 29, 'آ': 30, 'أ': 31, 'إ': 32, 'ء': 33, 'ؤ': 34, 'ئ': 35, ' ': 36, '،': 37, '-': 38, '<pad>': 39, '<unk>': 40}


In [17]:
# Accuracy per letter: row i of the gold file is compared with row i of the
# system output (by position, not by index label).
gold_val = pd.read_csv('clean_out/val_gold.csv',index_col=0)
sys_val = pd.read_csv('results/letter_haraka.csv',index_col=0)

total = len(gold_val)
# A positional comparison only makes sense when both files describe the same
# letters; fail loudly instead of scoring misaligned or missing rows.
if total == 0:
    raise ValueError("gold file 'clean_out/val_gold.csv' contains no rows")
if len(sys_val) != total:
    raise ValueError(
        "row count mismatch: gold has %d rows, system output has %d"
        % (total, len(sys_val))
    )

# Vectorised comparison on the raw arrays, replacing a per-row iloc loop.
matches = gold_val['label'].to_numpy() == sys_val['label'].to_numpy()
correct = int(matches.sum())

print("Accuracy: %.2f%%" % (100.0 * correct / total))

Accuracy: 95.81%


In [18]:
# Persist only the learned weights (the state dict) of the current `model`.
# To restore them later, build the same architecture and then call
#   model.load_state_dict(torch.load('models/Accio.pth'))
# NOTE(review): this writes the weights as they are after the last epoch; the
# `best_model` snapshot tracked by the training loop is not used here.
torch.save(obj=model.state_dict(), f='models/Accio.pth')

In [19]:
# Diacritic error rate = 1 - accuracy, from `correct` and `total` computed by
# the accuracy cell. Printed with two decimals like the accuracy figure; the
# previous %d format truncated the value (e.g. 4.19 was shown as 4).
print('DER of the network on the validation set: %.2f %%' % (100 * (1 - correct / total)))


DER of the network on the validation set: 4 %
