In [9]:
# to reload modules automatically without having to restart the kernel
%load_ext autoreload
%autoreload 2

import torch
import torch.optim as optim
import torch.utils.data as data
from letters_dataset import LettersDataset
import torch.nn as nn
from train_collections import *
from tqdm import tqdm
import pandas as pd
import numpy as np
from nltk.stem.isri import ISRIStemmer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Training hyper-parameters and the compute device shared by every later cell.
n_epochs = 20
batch_size = 64
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
# Training split, wrapped in a shuffling DataLoader.
dataset = LettersDataset(device=device)
loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Pull a single batch and print it as a quick sanity check of the encodings.
sample = next(iter(loader))
print(sample)

# Input / output vocabulary sizes; these size the model's input and output layers.
n_chars = dataset.get_input_vocab_size()
n_harakat = dataset.get_output_vocab_size()
print("n_chars: ", n_chars)
print("n_harakat: ", n_harakat)

w = 417
[tensor([[40, 31, 27,  ..., 39, 39, 39],
        [40, 31,  5,  ..., 39, 39, 39],
        [40, 21, 31,  ..., 39, 39, 39],
        ...,
        [40,  7, 27,  ..., 39, 39, 39],
        [40, 31, 27,  ..., 39, 39, 39],
        [20, 26, 22,  ..., 39, 39, 39]], device='cuda:0'), tensor([[14,  0,  6,  ..., 14, 14, 14],
        [14,  0,  0,  ..., 14, 14, 14],
        [14,  0,  0,  ..., 14, 14, 14],
        ...,
        [14,  0,  6,  ..., 14, 14, 14],
        [14,  0,  6,  ..., 14, 14, 14],
        [ 0,  6,  2,  ..., 14, 14, 14]], device='cuda:0')]
n_chars:  41
n_harakat:  15


In [12]:
def save_checkpoint(model, optimizer, epoch, loss, filename):
    """Write a training checkpoint to ``filename`` using ``torch.save``.

    The saved object is a dict with the keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'`` --
    the same keys ``load_checkpoint`` reads back.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(checkpoint, filename)


def load_checkpoint(model, optimizer, filename, map_location="cpu"):
    """Restore model and optimizer state from a checkpoint made by ``save_checkpoint``.

    Args:
        model: module whose state is overwritten in place via ``load_state_dict``.
        optimizer: optimizer whose state is overwritten in place via ``load_state_dict``.
        filename: path of the checkpoint file.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a GPU can still be opened on a CPU-only
            machine; ``model.load_state_dict`` copies the values into the
            model's existing parameters, and ``Optimizer.load_state_dict``
            moves the optimizer state onto each parameter's device.

    Returns:
        ``(next_epoch, loss)``: the stored epoch plus one (the epoch to resume
        from) and the loss value stored in the checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Resume from the epoch after the one that was saved.
    return checkpoint['epoch'] + 1, checkpoint['loss']

In [13]:
from accio import Accio

# Build the network on the chosen device, attach an Adam optimizer, then
# restore both from a saved checkpoint.
model = Accio(input_size=n_chars, output_size=n_harakat, device=device).to(device)
optimizer = optim.Adam(model.parameters())

checkpoint_path = "models/accio_check_epoch_0.pth"
resume_epoch, checkpoint_loss = load_checkpoint(model, optimizer, checkpoint_path)

In [14]:
import os

# Predict a diacritic (haraka) for every Arabic letter in the test set and
# save the predictions as ID,label pairs in ./results/letter_diacritic.csv.
test_dataset = LettersDataset('clean_out/X_test_no_diacritics.csv', 'clean_out/Y_test_no_diacritics.csv', val_mode=True, device=device)
val_loader = data.DataLoader(test_dataset, batch_size=batch_size)
print(test_dataset.char_encoder.word2idx)

char_encoder = test_dataset.char_encoder
pad_id = char_encoder.get_pad_id()  # loop-invariant, look it up once

model.eval()
letter_haraka = []
with torch.no_grad():
    for (X_batch, y_batch) in val_loader:  # y_batch is not needed for prediction
        torch.cuda.empty_cache()  # Clear CUDA cache to avoid memory error
        y_pred = model(X_batch)
        # Class index with the highest score at each position; identical to the
        # previous transpose(1, 2) followed by torch.max(..., 1).
        predicted = y_pred.argmax(dim=2)
        # Copy each batch to the host once; calling .item() per element forces a
        # device synchronisation for every single character.
        for char_ids, haraka_ids in zip(X_batch.tolist(), predicted.tolist()):
            for char_id, haraka_id in zip(char_ids, haraka_ids):
                if char_id == pad_id:
                    break  # we reached the end of the sentence
                # Only Arabic letters get a label; other characters are skipped.
                if char_encoder.is_arabic_letter(char_id):
                    letter_haraka.append(haraka_id)

# to_csv does not create missing directories, so make sure ./results exists.
os.makedirs('./results', exist_ok=True)
df = pd.DataFrame(letter_haraka, columns=['label'])
df.to_csv('./results/letter_diacritic.csv', index=True, index_label='ID')



w = 1184
{'ا': 0, 'ب': 1, 'ت': 2, 'ث': 3, 'ج': 4, 'ح': 5, 'خ': 6, 'د': 7, 'ذ': 8, 'ر': 9, 'ز': 10, 'س': 11, 'ش': 12, 'ص': 13, 'ض': 14, 'ط': 15, 'ظ': 16, 'ع': 17, 'غ': 18, 'ف': 19, 'ق': 20, 'ك': 21, 'ل': 22, 'م': 23, 'ن': 24, 'ه': 25, 'و': 26, 'ي': 27, 'ى': 28, 'ة': 29, 'آ': 30, 'أ': 31, 'إ': 32, 'ء': 33, 'ؤ': 34, 'ئ': 35, ' ': 36, '،': 37, '-': 38, '<pad>': 39, '<unk>': 40}
