In [1]:
# to reload modules automatically without having to restart the kernel
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import torch
import torch.optim as optim
import torch.utils.data as data
from letters_dataset import LettersDataset
import torch.nn as nn
from train_collections import *
import numpy as np
from tqdm import tqdm

In [2]:
# model and training parameters
batch_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epochs = 10

In [3]:
# load train data
dataset = LettersDataset(device=device)
loader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size)
n_chars = dataset.get_input_vocab_size()
n_harakat = dataset.get_output_vocab_size()
print("n_chars: ", n_chars)
print("n_harakat: ", n_harakat)

w = 417
n_chars:  41
n_harakat:  15


In [4]:
def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }, filename)


def load_checkpoint(model, optimizer, filename):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss']
    return epoch, loss

In [5]:
from accio import Accio

model = Accio(input_size=n_chars, output_size=n_harakat, device=device).to(device)
# model.load_state_dict(torch.load('./models/accio_test_epoch_19.pth'))
optimizer = optim.Adam(model.parameters(), lr=0.001)
load_checkpoint(model, optimizer, './models/accio_check_epoch_0.pth')
loss_fn = nn.CrossEntropyLoss(ignore_index=dataset.char_encoder.get_pad_id())

In [6]:
num_batches = len(loader)
print("Number of batches:", num_batches)
best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    torch.cuda.empty_cache()  # Clear CUDA cache to avoid memory error
    model.train()
    for i, (X_batch, y_batch) in tqdm(enumerate(loader)):
        y_pred = ''
        y_pred = model(X_batch)
        # we transpose because the loss function expects the second dimension to be the classes
        # y_pred is now (batch_size, n_classes, seq_len)
        y_pred = y_pred.transpose(1, 2)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss))

    # save model after each epoch
    # torch.save(model.state_dict(), f'models/accio_all_epoch_{epoch + 4}.pth')
    save_checkpoint(model, optimizer, epoch, 0, f'models/accio_check_epoch_{epoch + 1}.pth')
    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for (X_batch, y_batch) in loader:
            y_pred = model(X_batch)
            y_pred = y_pred.transpose(1, 2)
            loss += loss_fn(y_pred, y_batch)
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))


Number of batches: 3466


1it [00:03,  3.44s/it]

Epoch 0, batch 0: Loss = 0.0030


101it [00:35,  3.16it/s]

Epoch 0, batch 100: Loss = 0.0025


201it [01:07,  3.10it/s]

Epoch 0, batch 200: Loss = 0.0034


301it [01:39,  3.18it/s]

Epoch 0, batch 300: Loss = 0.0026


401it [02:11,  3.14it/s]

Epoch 0, batch 400: Loss = 0.0023


501it [02:43,  3.10it/s]

Epoch 0, batch 500: Loss = 0.0024


601it [03:15,  3.11it/s]

Epoch 0, batch 600: Loss = 0.0029


701it [03:47,  3.20it/s]

Epoch 0, batch 700: Loss = 0.0039


801it [04:19,  3.09it/s]

Epoch 0, batch 800: Loss = 0.0034


901it [04:51,  3.09it/s]

Epoch 0, batch 900: Loss = 0.0041


1001it [05:24,  3.11it/s]

Epoch 0, batch 1000: Loss = 0.0045


1101it [05:56,  3.15it/s]

Epoch 0, batch 1100: Loss = 0.0035


1201it [06:28,  3.12it/s]

Epoch 0, batch 1200: Loss = 0.0044


1301it [07:01,  3.09it/s]

Epoch 0, batch 1300: Loss = 0.0031


1401it [07:34,  3.02it/s]

Epoch 0, batch 1400: Loss = 0.0092


1501it [08:06,  3.11it/s]

Epoch 0, batch 1500: Loss = 0.0050


1601it [08:39,  3.06it/s]

Epoch 0, batch 1600: Loss = 0.0044


1701it [09:11,  3.07it/s]

Epoch 0, batch 1700: Loss = 0.0017


1801it [09:44,  3.02it/s]

Epoch 0, batch 1800: Loss = 0.0050


1901it [10:17,  3.02it/s]

Epoch 0, batch 1900: Loss = 0.0041


2001it [10:51,  3.08it/s]

Epoch 0, batch 2000: Loss = 0.0030


2101it [11:24,  3.12it/s]

Epoch 0, batch 2100: Loss = 0.0027


2201it [11:56,  3.11it/s]

Epoch 0, batch 2200: Loss = 0.0017


2301it [12:29,  3.05it/s]

Epoch 0, batch 2300: Loss = 0.0028


2401it [13:02,  3.13it/s]

Epoch 0, batch 2400: Loss = 0.0037


2501it [13:35,  2.99it/s]

Epoch 0, batch 2500: Loss = 0.0037


2601it [14:08,  3.04it/s]

Epoch 0, batch 2600: Loss = 0.0031


2701it [14:41,  3.04it/s]

Epoch 0, batch 2700: Loss = 0.0031


2801it [15:13,  2.99it/s]

Epoch 0, batch 2800: Loss = 0.0036


2901it [15:46,  3.01it/s]

Epoch 0, batch 2900: Loss = 0.0052


3001it [16:19,  2.89it/s]

Epoch 0, batch 3000: Loss = 0.0029


3101it [16:52,  3.13it/s]

Epoch 0, batch 3100: Loss = 0.0039


3201it [17:25,  3.08it/s]

Epoch 0, batch 3200: Loss = 0.0022


3301it [17:58,  3.08it/s]

Epoch 0, batch 3300: Loss = 0.0052


3401it [18:30,  3.10it/s]

Epoch 0, batch 3400: Loss = 0.0041


3466it [18:51,  3.06it/s]


Epoch 0: Cross-entropy: 11.0859


1it [00:00,  2.47it/s]

Epoch 1, batch 0: Loss = 0.0038


101it [00:33,  3.08it/s]

Epoch 1, batch 100: Loss = 0.0020


201it [01:05,  3.14it/s]

Epoch 1, batch 200: Loss = 0.0028


301it [01:38,  3.15it/s]

Epoch 1, batch 300: Loss = 0.0031


401it [02:10,  3.13it/s]

Epoch 1, batch 400: Loss = 0.0018


501it [02:42,  3.07it/s]

Epoch 1, batch 500: Loss = 0.0038


601it [03:15,  2.97it/s]

Epoch 1, batch 600: Loss = 0.0018


701it [03:48,  3.05it/s]

Epoch 1, batch 700: Loss = 0.0027


801it [04:21,  3.10it/s]

Epoch 1, batch 800: Loss = 0.0031


901it [04:54,  2.99it/s]

Epoch 1, batch 900: Loss = 0.0037


1001it [05:26,  3.05it/s]

Epoch 1, batch 1000: Loss = 0.0029


1101it [06:00,  2.78it/s]

Epoch 1, batch 1100: Loss = 0.0030


1201it [06:37,  3.06it/s]

Epoch 1, batch 1200: Loss = 0.0051


1301it [07:10,  3.02it/s]

Epoch 1, batch 1300: Loss = 0.0037


1401it [07:43,  2.90it/s]

Epoch 1, batch 1400: Loss = 0.0044


1501it [08:17,  2.64it/s]

Epoch 1, batch 1500: Loss = 0.0025


1601it [08:51,  2.96it/s]

Epoch 1, batch 1600: Loss = 0.0032


1701it [09:24,  3.07it/s]

Epoch 1, batch 1700: Loss = 0.0026


1801it [09:56,  2.92it/s]

Epoch 1, batch 1800: Loss = 0.0036


1901it [10:29,  3.00it/s]

Epoch 1, batch 1900: Loss = 0.0034


2001it [11:02,  2.97it/s]

Epoch 1, batch 2000: Loss = 0.0068


2101it [11:35,  3.11it/s]

Epoch 1, batch 2100: Loss = 0.0044


2201it [12:08,  2.98it/s]

Epoch 1, batch 2200: Loss = 0.0046


2301it [12:41,  3.09it/s]

Epoch 1, batch 2300: Loss = 0.0039


2401it [13:13,  3.08it/s]

Epoch 1, batch 2400: Loss = 0.0025


2501it [13:46,  2.99it/s]

Epoch 1, batch 2500: Loss = 0.0023


2601it [14:21,  2.98it/s]

Epoch 1, batch 2600: Loss = 0.0032


2701it [14:53,  3.14it/s]

Epoch 1, batch 2700: Loss = 0.0032


2801it [15:26,  3.06it/s]

Epoch 1, batch 2800: Loss = 0.0029


2901it [15:59,  3.13it/s]

Epoch 1, batch 2900: Loss = 0.0049


3001it [16:32,  3.04it/s]

Epoch 1, batch 3000: Loss = 0.0068


3101it [17:05,  3.01it/s]

Epoch 1, batch 3100: Loss = 0.0048


3201it [17:39,  2.81it/s]

Epoch 1, batch 3200: Loss = 0.0035


3301it [18:12,  3.05it/s]

Epoch 1, batch 3300: Loss = 0.0036


3401it [18:45,  3.10it/s]

Epoch 1, batch 3400: Loss = 0.0031


3466it [19:06,  3.02it/s]


Epoch 1: Cross-entropy: 10.6893


1it [00:00,  2.14it/s]

Epoch 2, batch 0: Loss = 0.0035


101it [00:33,  3.11it/s]

Epoch 2, batch 100: Loss = 0.0022


201it [01:05,  3.17it/s]

Epoch 2, batch 200: Loss = 0.0027


301it [01:37,  3.18it/s]

Epoch 2, batch 300: Loss = 0.0052


401it [02:09,  3.13it/s]

Epoch 2, batch 400: Loss = 0.0030


501it [02:42,  2.74it/s]

Epoch 2, batch 500: Loss = 0.0024


601it [03:16,  2.92it/s]

Epoch 2, batch 600: Loss = 0.0032


701it [03:51,  3.07it/s]

Epoch 2, batch 700: Loss = 0.0039


801it [04:24,  3.11it/s]

Epoch 2, batch 800: Loss = 0.0031


901it [04:59,  2.94it/s]

Epoch 2, batch 900: Loss = 0.0035


1001it [05:32,  3.01it/s]

Epoch 2, batch 1000: Loss = 0.0029


1101it [06:05,  2.97it/s]

Epoch 2, batch 1100: Loss = 0.0035


1201it [06:39,  2.96it/s]

Epoch 2, batch 1200: Loss = 0.0033


1301it [07:11,  3.16it/s]

Epoch 2, batch 1300: Loss = 0.0039


1401it [07:43,  3.07it/s]

Epoch 2, batch 1400: Loss = 0.0035


1501it [08:17,  3.11it/s]

Epoch 2, batch 1500: Loss = 0.0049


1601it [08:49,  2.91it/s]

Epoch 2, batch 1600: Loss = 0.0014


1701it [09:22,  3.08it/s]

Epoch 2, batch 1700: Loss = 0.0027


1801it [09:55,  2.99it/s]

Epoch 2, batch 1800: Loss = 0.0043


1901it [10:27,  3.02it/s]

Epoch 2, batch 1900: Loss = 0.0045


2001it [11:00,  2.99it/s]

Epoch 2, batch 2000: Loss = 0.0024


2101it [11:32,  3.10it/s]

Epoch 2, batch 2100: Loss = 0.0038


2201it [12:05,  3.06it/s]

Epoch 2, batch 2200: Loss = 0.0027


2301it [12:37,  3.16it/s]

Epoch 2, batch 2300: Loss = 0.0044


2401it [13:10,  3.11it/s]

Epoch 2, batch 2400: Loss = 0.0049


2501it [13:42,  3.09it/s]

Epoch 2, batch 2500: Loss = 0.0018


2601it [14:15,  3.05it/s]

Epoch 2, batch 2600: Loss = 0.0059


2701it [14:47,  3.14it/s]

Epoch 2, batch 2700: Loss = 0.0049


2801it [15:19,  3.18it/s]

Epoch 2, batch 2800: Loss = 0.0026


2901it [15:52,  3.16it/s]

Epoch 2, batch 2900: Loss = 0.0049


3001it [16:24,  3.02it/s]

Epoch 2, batch 3000: Loss = 0.0047


3101it [16:57,  3.06it/s]

Epoch 2, batch 3100: Loss = 0.0028


3201it [17:30,  3.10it/s]

Epoch 2, batch 3200: Loss = 0.0052


3301it [18:02,  3.09it/s]

Epoch 2, batch 3300: Loss = 0.0049


3401it [18:35,  3.15it/s]

Epoch 2, batch 3400: Loss = 0.0045


3466it [18:56,  3.05it/s]


Epoch 2: Cross-entropy: 10.4807


1it [00:00,  2.59it/s]

Epoch 3, batch 0: Loss = 0.0020


101it [00:32,  3.17it/s]

Epoch 3, batch 100: Loss = 0.0027


201it [01:03,  3.16it/s]

Epoch 3, batch 200: Loss = 0.0020


301it [01:36,  3.18it/s]

Epoch 3, batch 300: Loss = 0.0017


401it [02:08,  3.12it/s]

Epoch 3, batch 400: Loss = 0.0038


501it [02:41,  3.05it/s]

Epoch 3, batch 500: Loss = 0.0032


601it [03:15,  2.91it/s]

Epoch 3, batch 600: Loss = 0.0030


701it [03:47,  3.05it/s]

Epoch 3, batch 700: Loss = 0.0035


801it [04:21,  3.05it/s]

Epoch 3, batch 800: Loss = 0.0026


901it [04:53,  3.04it/s]

Epoch 3, batch 900: Loss = 0.0029


1001it [05:26,  2.91it/s]

Epoch 3, batch 1000: Loss = 0.0044


1101it [06:00,  3.00it/s]

Epoch 3, batch 1100: Loss = 0.0023


1201it [06:33,  2.97it/s]

Epoch 3, batch 1200: Loss = 0.0033


1301it [07:06,  2.92it/s]

Epoch 3, batch 1300: Loss = 0.0052


1401it [07:39,  3.12it/s]

Epoch 3, batch 1400: Loss = 0.0030


1501it [08:12,  3.04it/s]

Epoch 3, batch 1500: Loss = 0.0047


1601it [08:45,  3.06it/s]

Epoch 3, batch 1600: Loss = 0.0039


1701it [09:18,  3.04it/s]

Epoch 3, batch 1700: Loss = 0.0028


1801it [09:51,  2.81it/s]

Epoch 3, batch 1800: Loss = 0.0025


1901it [10:24,  3.08it/s]

Epoch 3, batch 1900: Loss = 0.0019


2001it [10:57,  3.04it/s]

Epoch 3, batch 2000: Loss = 0.0034


2101it [11:30,  3.08it/s]

Epoch 3, batch 2100: Loss = 0.0026


2201it [12:03,  3.04it/s]

Epoch 3, batch 2200: Loss = 0.0032


2301it [12:36,  3.02it/s]

Epoch 3, batch 2300: Loss = 0.0035


2401it [13:13,  2.74it/s]

Epoch 3, batch 2400: Loss = 0.0054


2501it [13:49,  3.07it/s]

Epoch 3, batch 2500: Loss = 0.0033


2601it [14:25,  1.93it/s]

Epoch 3, batch 2600: Loss = 0.0022


2701it [15:00,  2.94it/s]

Epoch 3, batch 2700: Loss = 0.0032


2801it [15:33,  3.06it/s]

Epoch 3, batch 2800: Loss = 0.0052


2901it [16:07,  2.94it/s]

Epoch 3, batch 2900: Loss = 0.0033


3001it [16:41,  2.98it/s]

Epoch 3, batch 3000: Loss = 0.0036


3101it [17:16,  2.90it/s]

Epoch 3, batch 3100: Loss = 0.0057


3201it [17:49,  3.05it/s]

Epoch 3, batch 3200: Loss = 0.0028


3301it [18:22,  2.86it/s]

Epoch 3, batch 3300: Loss = 0.0030


3401it [18:56,  2.95it/s]

Epoch 3, batch 3400: Loss = 0.0031


3466it [19:17,  2.99it/s]


Epoch 3: Cross-entropy: 10.6800


1it [00:00,  2.50it/s]

Epoch 4, batch 0: Loss = 0.0029


101it [00:33,  2.88it/s]

Epoch 4, batch 100: Loss = 0.0053


201it [01:06,  3.06it/s]

Epoch 4, batch 200: Loss = 0.0028


301it [01:40,  2.99it/s]

Epoch 4, batch 300: Loss = 0.0022


401it [02:14,  3.00it/s]

Epoch 4, batch 400: Loss = 0.0034


501it [02:48,  3.00it/s]

Epoch 4, batch 500: Loss = 0.0044


601it [03:21,  2.93it/s]

Epoch 4, batch 600: Loss = 0.0034


701it [03:55,  2.99it/s]

Epoch 4, batch 700: Loss = 0.0023


801it [04:30,  2.79it/s]

Epoch 4, batch 800: Loss = 0.0040


901it [05:03,  3.01it/s]

Epoch 4, batch 900: Loss = 0.0013


1001it [05:37,  2.92it/s]

Epoch 4, batch 1000: Loss = 0.0025


1101it [06:11,  2.90it/s]

Epoch 4, batch 1100: Loss = 0.0039


1201it [06:44,  3.00it/s]

Epoch 4, batch 1200: Loss = 0.0030


1301it [07:17,  3.03it/s]

Epoch 4, batch 1300: Loss = 0.0027


1401it [07:51,  2.88it/s]

Epoch 4, batch 1400: Loss = 0.0035


1501it [08:24,  3.07it/s]

Epoch 4, batch 1500: Loss = 0.0045


1601it [08:57,  3.08it/s]

Epoch 4, batch 1600: Loss = 0.0039


1701it [09:31,  3.05it/s]

Epoch 4, batch 1700: Loss = 0.0038


1801it [10:04,  3.03it/s]

Epoch 4, batch 1800: Loss = 0.0042


1901it [10:36,  3.08it/s]

Epoch 4, batch 1900: Loss = 0.0017


2001it [11:09,  3.07it/s]

Epoch 4, batch 2000: Loss = 0.0041


2101it [11:43,  2.77it/s]

Epoch 4, batch 2100: Loss = 0.0032


2201it [12:17,  2.82it/s]

Epoch 4, batch 2200: Loss = 0.0033


2301it [12:50,  3.08it/s]

Epoch 4, batch 2300: Loss = 0.0043


2401it [13:23,  2.95it/s]

Epoch 4, batch 2400: Loss = 0.0059


2501it [13:56,  3.09it/s]

Epoch 4, batch 2500: Loss = 0.0026


2601it [14:29,  3.03it/s]

Epoch 4, batch 2600: Loss = 0.0056


2701it [15:02,  3.03it/s]

Epoch 4, batch 2700: Loss = 0.0044


2801it [15:35,  2.99it/s]

Epoch 4, batch 2800: Loss = 0.0044


2901it [16:09,  3.05it/s]

Epoch 4, batch 2900: Loss = 0.0055


3001it [16:42,  3.04it/s]

Epoch 4, batch 3000: Loss = 0.0027


3101it [17:15,  3.04it/s]

Epoch 4, batch 3100: Loss = 0.0041


3201it [17:48,  2.94it/s]

Epoch 4, batch 3200: Loss = 0.0055


3301it [18:21,  3.07it/s]

Epoch 4, batch 3300: Loss = 0.0057


3401it [18:54,  3.06it/s]

Epoch 4, batch 3400: Loss = 0.0029


3466it [19:29,  2.96it/s]


Epoch 4: Cross-entropy: 10.4819


1it [00:00,  2.51it/s]

Epoch 5, batch 0: Loss = 0.0034


101it [00:33,  3.02it/s]

Epoch 5, batch 100: Loss = 0.0025


201it [01:06,  2.97it/s]

Epoch 5, batch 200: Loss = 0.0053


301it [01:39,  3.07it/s]

Epoch 5, batch 300: Loss = 0.0032


401it [02:12,  3.12it/s]

Epoch 5, batch 400: Loss = 0.0029


501it [02:44,  3.01it/s]

Epoch 5, batch 500: Loss = 0.0045


601it [03:19,  3.05it/s]

Epoch 5, batch 600: Loss = 0.0037


701it [03:52,  2.87it/s]

Epoch 5, batch 700: Loss = 0.0022


801it [04:25,  3.01it/s]

Epoch 5, batch 800: Loss = 0.0025


901it [04:58,  2.95it/s]

Epoch 5, batch 900: Loss = 0.0027


1001it [05:31,  3.04it/s]

Epoch 5, batch 1000: Loss = 0.0023


1101it [06:04,  2.98it/s]

Epoch 5, batch 1100: Loss = 0.0027


1201it [06:36,  2.97it/s]

Epoch 5, batch 1200: Loss = 0.0034


1301it [07:10,  2.95it/s]

Epoch 5, batch 1300: Loss = 0.0041


1401it [07:43,  3.09it/s]

Epoch 5, batch 1400: Loss = 0.0022


1501it [08:16,  2.97it/s]

Epoch 5, batch 1500: Loss = 0.0041


1601it [08:49,  2.94it/s]

Epoch 5, batch 1600: Loss = 0.0054


1701it [09:22,  3.05it/s]

Epoch 5, batch 1700: Loss = 0.0020


1801it [09:55,  3.03it/s]

Epoch 5, batch 1800: Loss = 0.0040


1901it [10:27,  3.02it/s]

Epoch 5, batch 1900: Loss = 0.0042


2001it [11:01,  2.98it/s]

Epoch 5, batch 2000: Loss = 0.0058


2101it [11:34,  3.01it/s]

Epoch 5, batch 2100: Loss = 0.0020


2201it [12:07,  3.04it/s]

Epoch 5, batch 2200: Loss = 0.0034


2301it [12:39,  3.03it/s]

Epoch 5, batch 2300: Loss = 0.0037


2401it [13:12,  3.07it/s]

Epoch 5, batch 2400: Loss = 0.0039


2501it [13:45,  2.95it/s]

Epoch 5, batch 2500: Loss = 0.0024


2601it [14:18,  3.06it/s]

Epoch 5, batch 2600: Loss = 0.0035


2701it [14:51,  3.00it/s]

Epoch 5, batch 2700: Loss = 0.0067


2801it [15:24,  3.07it/s]

Epoch 5, batch 2800: Loss = 0.0024


2901it [15:57,  3.03it/s]

Epoch 5, batch 2900: Loss = 0.0034


3001it [16:30,  3.03it/s]

Epoch 5, batch 3000: Loss = 0.0043


3101it [17:03,  3.01it/s]

Epoch 5, batch 3100: Loss = 0.0045


3201it [17:36,  3.08it/s]

Epoch 5, batch 3200: Loss = 0.0033


3301it [18:09,  3.02it/s]

Epoch 5, batch 3300: Loss = 0.0043


3401it [18:42,  3.00it/s]

Epoch 5, batch 3400: Loss = 0.0042


3466it [19:04,  3.03it/s]


Epoch 5: Cross-entropy: 10.3855


1it [00:00,  2.51it/s]

Epoch 6, batch 0: Loss = 0.0028


101it [00:33,  3.07it/s]

Epoch 6, batch 100: Loss = 0.0026


201it [01:06,  2.98it/s]

Epoch 6, batch 200: Loss = 0.0026


301it [01:39,  3.06it/s]

Epoch 6, batch 300: Loss = 0.0023


401it [02:12,  3.05it/s]

Epoch 6, batch 400: Loss = 0.0031


501it [02:45,  3.04it/s]

Epoch 6, batch 500: Loss = 0.0042


601it [03:17,  3.09it/s]

Epoch 6, batch 600: Loss = 0.0023


701it [03:51,  2.98it/s]

Epoch 6, batch 700: Loss = 0.0029


801it [04:24,  2.97it/s]

Epoch 6, batch 800: Loss = 0.0024


901it [04:57,  2.97it/s]

Epoch 6, batch 900: Loss = 0.0035


1001it [05:31,  2.78it/s]

Epoch 6, batch 1000: Loss = 0.0033


1101it [06:07,  2.94it/s]

Epoch 6, batch 1100: Loss = 0.0024


1201it [06:42,  2.98it/s]

Epoch 6, batch 1200: Loss = 0.0039


1301it [07:15,  3.02it/s]

Epoch 6, batch 1300: Loss = 0.0033


1401it [07:49,  3.00it/s]

Epoch 6, batch 1400: Loss = 0.0031


1501it [08:22,  3.05it/s]

Epoch 6, batch 1500: Loss = 0.0056


1601it [08:55,  3.08it/s]

Epoch 6, batch 1600: Loss = 0.0031


1701it [09:28,  2.45it/s]

Epoch 6, batch 1700: Loss = 0.0045


1801it [10:03,  3.05it/s]

Epoch 6, batch 1800: Loss = 0.0027


1901it [10:36,  2.95it/s]

Epoch 6, batch 1900: Loss = 0.0029


2001it [11:10,  3.09it/s]

Epoch 6, batch 2000: Loss = 0.0044


2101it [11:44,  3.07it/s]

Epoch 6, batch 2100: Loss = 0.0032


2201it [12:16,  3.02it/s]

Epoch 6, batch 2200: Loss = 0.0039


2301it [12:50,  3.07it/s]

Epoch 6, batch 2300: Loss = 0.0025


2401it [13:23,  3.08it/s]

Epoch 6, batch 2400: Loss = 0.0034


2501it [13:57,  2.79it/s]

Epoch 6, batch 2500: Loss = 0.0037


2601it [14:32,  2.96it/s]

Epoch 6, batch 2600: Loss = 0.0028


2701it [15:05,  2.95it/s]

Epoch 6, batch 2700: Loss = 0.0033


2801it [15:38,  3.03it/s]

Epoch 6, batch 2800: Loss = 0.0049


2901it [16:11,  3.05it/s]

Epoch 6, batch 2900: Loss = 0.0038


3001it [16:44,  3.03it/s]

Epoch 6, batch 3000: Loss = 0.0027


3101it [17:17,  3.06it/s]

Epoch 6, batch 3100: Loss = 0.0083


3201it [17:50,  2.94it/s]

Epoch 6, batch 3200: Loss = 0.0036


3301it [18:23,  3.04it/s]

Epoch 6, batch 3300: Loss = 0.0028


3401it [18:56,  2.99it/s]

Epoch 6, batch 3400: Loss = 0.0034


3466it [19:18,  2.99it/s]


Epoch 6: Cross-entropy: 10.1721


1it [00:00,  2.52it/s]

Epoch 7, batch 0: Loss = 0.0031


101it [00:32,  3.08it/s]

Epoch 7, batch 100: Loss = 0.0027


201it [01:05,  3.05it/s]

Epoch 7, batch 200: Loss = 0.0022


301it [01:37,  3.06it/s]

Epoch 7, batch 300: Loss = 0.0024


401it [02:10,  3.11it/s]

Epoch 7, batch 400: Loss = 0.0039


501it [02:42,  3.11it/s]

Epoch 7, batch 500: Loss = 0.0031


601it [03:14,  3.09it/s]

Epoch 7, batch 600: Loss = 0.0036


701it [03:47,  3.04it/s]

Epoch 7, batch 700: Loss = 0.0021


801it [04:19,  3.10it/s]

Epoch 7, batch 800: Loss = 0.0026


901it [04:51,  3.09it/s]

Epoch 7, batch 900: Loss = 0.0034


1001it [05:24,  3.12it/s]

Epoch 7, batch 1000: Loss = 0.0028


1101it [05:56,  3.07it/s]

Epoch 7, batch 1100: Loss = 0.0047


1201it [06:28,  3.08it/s]

Epoch 7, batch 1200: Loss = 0.0029


1301it [07:01,  3.07it/s]

Epoch 7, batch 1300: Loss = 0.0030


1401it [07:33,  3.05it/s]

Epoch 7, batch 1400: Loss = 0.0031


1501it [08:05,  3.03it/s]

Epoch 7, batch 1500: Loss = 0.0030


1601it [08:37,  3.08it/s]

Epoch 7, batch 1600: Loss = 0.0030


1701it [09:10,  3.10it/s]

Epoch 7, batch 1700: Loss = 0.0020


1801it [09:42,  3.08it/s]

Epoch 7, batch 1800: Loss = 0.0032


1901it [10:14,  3.04it/s]

Epoch 7, batch 1900: Loss = 0.0038


2001it [10:47,  3.09it/s]

Epoch 7, batch 2000: Loss = 0.0039


2101it [11:19,  3.06it/s]

Epoch 7, batch 2100: Loss = 0.0040


2201it [11:51,  3.02it/s]

Epoch 7, batch 2200: Loss = 0.0034


2301it [12:24,  3.09it/s]

Epoch 7, batch 2300: Loss = 0.0044


2401it [12:56,  3.03it/s]

Epoch 7, batch 2400: Loss = 0.0021


2501it [13:28,  3.10it/s]

Epoch 7, batch 2500: Loss = 0.0031


2601it [14:01,  3.10it/s]

Epoch 7, batch 2600: Loss = 0.0027


2701it [14:33,  3.09it/s]

Epoch 7, batch 2700: Loss = 0.0046


2801it [15:06,  3.06it/s]

Epoch 7, batch 2800: Loss = 0.0032


2901it [15:38,  3.12it/s]

Epoch 7, batch 2900: Loss = 0.0057


3001it [16:11,  2.93it/s]

Epoch 7, batch 3000: Loss = 0.0055


3101it [16:44,  3.12it/s]

Epoch 7, batch 3100: Loss = 0.0028


3201it [17:17,  3.11it/s]

Epoch 7, batch 3200: Loss = 0.0037


3301it [17:50,  3.01it/s]

Epoch 7, batch 3300: Loss = 0.0029


3401it [18:23,  3.01it/s]

Epoch 7, batch 3400: Loss = 0.0030


3466it [18:45,  3.08it/s]


Epoch 7: Cross-entropy: 10.4841


1it [00:00,  2.45it/s]

Epoch 8, batch 0: Loss = 0.0023


101it [00:32,  3.08it/s]

Epoch 8, batch 100: Loss = 0.0041


201it [01:06,  2.82it/s]

Epoch 8, batch 200: Loss = 0.0031


209it [01:09,  2.83it/s]

In [None]:
# load validation data
val_dataset = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv', val_mode=True, device=device)
val_loader = data.DataLoader(val_dataset, batch_size=batch_size)
print(val_dataset.char_encoder.word2idx)

In [None]:
# evaluaate accuracy on validation set
model.eval()
letter_haraka = []
with torch.no_grad():
    for (X_batch, y_batch) in val_loader:
        y_pred = model(X_batch)
        # we transpose because the loss function expects the second dimension to be the classes
        # y_pred is now (batch_size, n_classes, seq_len)
        y_pred = y_pred.transpose(1, 2)
        _, predicted = torch.max(y_pred.data, 1)
        # Count only non-padding characters
        for x, y in zip(X_batch, predicted):
            for xx, yy in zip(x, y):
                # we reached the end of the sentence
                if xx.item() == val_dataset.char_encoder.get_pad_id():
                    break
                ll = val_dataset.char_encoder.is_arabic_letter(xx.item())
                if ll:
                    letter_haraka.append([ll, yy.item()])

# save ID,Label pairs in a csv file
import pandas as pd

df = pd.DataFrame(letter_haraka, columns=['letter', 'label'])
df.to_csv('./results/letter_haraka.csv', index=True, index_label='ID')

In [None]:
gold_val = pd.read_csv('clean_out/val_gold.csv', index_col=0)
sys_val = pd.read_csv('results/letter_haraka.csv', index_col=0)
# Accuracy per letter
correct = 0
total = len(gold_val)
for i in range(total):
    # print(gold_val[i][0], sys_val[i][0])
    correct += (gold_val.iloc[i]['label'] == sys_val.iloc[i]['label'])

print("Accuracy: %.2f%%" % (100.0 * correct / total))

In [None]:
print('DER of the network on the validation set: %d %%' % (100 * (1 - correct / total)))