In [1]:
import torch
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from letters_dataset import LettersDataset
from text_encoder import TextEncoder
import torch.nn as nn
from train_collections import *
import numpy as np
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# autoreload notebook
%load_ext autoreload

In [22]:
embedding_dim = 64
n_epochs = 10
n_hidden = 265
batch_size = 64

In [23]:


dataset = LettersDataset(device=device)
loader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

# load val data
# da = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv')

w = 495


In [24]:
n_chars = dataset.get_input_vocab_size()
n_harakat = dataset.get_output_vocab_size()
n_harakat

17

In [31]:
class CharModel(nn.Module):
    def __init__(self):
        super().__init__()

        # embedding and LSTM layers
        self.embedding = nn.Embedding(n_chars, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=256,
                            num_layers=1, batch_first=True )
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(256, n_harakat)

    def forward(self, x):
        # pass thru embedding layer
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(self.dropout(x))
        return x

class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.embedding = nn.Embedding(n_chars, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=n_hidden, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(2*n_hidden, n_harakat)
        self.dropout = nn.Dropout(0.2)
        


    def forward(self, x):
        # pass thru embedding layer
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(self.dropout(x))
        return x

model = BiLSTM().to(device)

optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss(ignore_index=dataset.char_encoder.get_pad_id())
num_batches = len(loader)
print("Number of batches:", num_batches)
best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    for i, (X_batch,y_batch) in tqdm(enumerate(loader)):
        y_pred = ''
        y_pred = model(X_batch)
        y_pred = y_pred.transpose(1, 2) 
        # print(y_pred.shape)
        # print(y_batch.shape)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss))
        
    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for (X_batch,y_batch) in loader:
            y_pred = model(X_batch)
            y_pred = y_pred.transpose(1, 2) 
            
            loss += loss_fn(y_pred, y_batch)
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))


Number of batches: 2590


3it [00:00,  6.24it/s]

Epoch 0, batch 0: Loss = 2.7657


103it [00:09, 11.20it/s]

Epoch 0, batch 100: Loss = 0.1537


203it [00:18, 10.68it/s]

Epoch 0, batch 200: Loss = 0.0918


303it [00:27, 10.82it/s]

Epoch 0, batch 300: Loss = 0.0745


403it [00:37, 10.93it/s]

Epoch 0, batch 400: Loss = 0.0676


503it [00:46, 10.89it/s]

Epoch 0, batch 500: Loss = 0.0545


603it [00:55, 10.62it/s]

Epoch 0, batch 600: Loss = 0.0596


703it [01:05, 10.33it/s]

Epoch 0, batch 700: Loss = 0.0354


803it [01:14, 11.03it/s]

Epoch 0, batch 800: Loss = 0.0809


903it [01:24, 10.88it/s]

Epoch 0, batch 900: Loss = 0.0492


1003it [01:33, 10.76it/s]

Epoch 0, batch 1000: Loss = 0.0300


1103it [01:42, 10.96it/s]

Epoch 0, batch 1100: Loss = 0.0397


1203it [01:51, 10.87it/s]

Epoch 0, batch 1200: Loss = 0.0358


1303it [02:00, 10.94it/s]

Epoch 0, batch 1300: Loss = 0.0328


1403it [02:10, 10.86it/s]

Epoch 0, batch 1400: Loss = 0.0528


1503it [02:19, 10.20it/s]

Epoch 0, batch 1500: Loss = 0.0277


1603it [02:28, 11.09it/s]

Epoch 0, batch 1600: Loss = 0.0459


1703it [02:37, 10.94it/s]

Epoch 0, batch 1700: Loss = 0.0217


1803it [02:47, 10.67it/s]

Epoch 0, batch 1800: Loss = 0.0326


1903it [02:57, 10.54it/s]

Epoch 0, batch 1900: Loss = 0.0355


2003it [03:06, 10.81it/s]

Epoch 0, batch 2000: Loss = 0.0287


2103it [03:15, 11.10it/s]

Epoch 0, batch 2100: Loss = 0.0242


2203it [03:25, 10.47it/s]

Epoch 0, batch 2200: Loss = 0.0181


2302it [03:34, 10.52it/s]

Epoch 0, batch 2300: Loss = 0.0219


2402it [03:43, 10.80it/s]

Epoch 0, batch 2400: Loss = 0.0270


2502it [03:53, 10.64it/s]

Epoch 0, batch 2500: Loss = 0.0304


2590it [04:01, 10.73it/s]


Epoch 0: Cross-entropy: 66.4396


2it [00:00, 10.93it/s]

Epoch 1, batch 0: Loss = 0.0281


102it [00:09, 10.95it/s]

Epoch 1, batch 100: Loss = 0.0235


202it [00:18, 11.22it/s]

Epoch 1, batch 200: Loss = 0.0284


302it [00:27, 11.12it/s]

Epoch 1, batch 300: Loss = 0.0265


402it [00:36, 10.72it/s]

Epoch 1, batch 400: Loss = 0.0313


502it [00:45, 10.69it/s]

Epoch 1, batch 500: Loss = 0.0282


602it [00:54, 11.19it/s]

Epoch 1, batch 600: Loss = 0.0203


702it [01:03, 11.12it/s]

Epoch 1, batch 700: Loss = 0.0393


803it [01:14, 10.71it/s]

Epoch 1, batch 800: Loss = 0.0262


903it [01:23, 11.05it/s]

Epoch 1, batch 900: Loss = 0.0321


1003it [01:32, 11.08it/s]

Epoch 1, batch 1000: Loss = 0.0271


1103it [01:41, 10.84it/s]

Epoch 1, batch 1100: Loss = 0.0208


1203it [01:51, 10.77it/s]

Epoch 1, batch 1200: Loss = 0.0269


1303it [02:00, 10.96it/s]

Epoch 1, batch 1300: Loss = 0.0309


1403it [02:09, 10.80it/s]

Epoch 1, batch 1400: Loss = 0.0272


1503it [02:19, 10.88it/s]

Epoch 1, batch 1500: Loss = 0.0201


1603it [02:28, 10.84it/s]

Epoch 1, batch 1600: Loss = 0.0239


1701it [02:37, 10.43it/s]

Epoch 1, batch 1700: Loss = 0.0220


1803it [02:47, 10.30it/s]

Epoch 1, batch 1800: Loss = 0.0204


1903it [02:56, 10.51it/s]

Epoch 1, batch 1900: Loss = 0.0203


2003it [03:06, 10.91it/s]

Epoch 1, batch 2000: Loss = 0.0199


2103it [03:15, 10.60it/s]

Epoch 1, batch 2100: Loss = 0.0126


2203it [03:25, 10.82it/s]

Epoch 1, batch 2200: Loss = 0.0211


2303it [03:34, 10.82it/s]

Epoch 1, batch 2300: Loss = 0.0274


2401it [03:43, 10.87it/s]

Epoch 1, batch 2400: Loss = 0.0237


2503it [03:53, 10.92it/s]

Epoch 1, batch 2500: Loss = 0.0252


2590it [04:01, 10.74it/s]


Epoch 1: Cross-entropy: 47.8175


2it [00:00, 10.81it/s]

Epoch 2, batch 0: Loss = 0.0188


102it [00:09, 10.62it/s]

Epoch 2, batch 100: Loss = 0.0285


202it [00:18, 10.61it/s]

Epoch 2, batch 200: Loss = 0.0185


302it [00:28, 10.48it/s]

Epoch 2, batch 300: Loss = 0.0194


402it [00:37, 10.73it/s]

Epoch 2, batch 400: Loss = 0.0175


502it [00:46, 10.81it/s]

Epoch 2, batch 500: Loss = 0.0293


602it [00:56, 10.88it/s]

Epoch 2, batch 600: Loss = 0.0194


702it [01:05, 10.68it/s]

Epoch 2, batch 700: Loss = 0.0223


802it [01:15, 10.79it/s]

Epoch 2, batch 800: Loss = 0.0187


902it [01:24, 10.52it/s]

Epoch 2, batch 900: Loss = 0.0125


1002it [01:33, 10.83it/s]

Epoch 2, batch 1000: Loss = 0.0153


1102it [01:43, 11.00it/s]

Epoch 2, batch 1100: Loss = 0.0166


1202it [01:52, 10.84it/s]

Epoch 2, batch 1200: Loss = 0.0206


1302it [02:01, 10.34it/s]

Epoch 2, batch 1300: Loss = 0.0119


1402it [02:11, 10.76it/s]

Epoch 2, batch 1400: Loss = 0.0142


1502it [02:20, 10.79it/s]

Epoch 2, batch 1500: Loss = 0.0237


1602it [02:30, 10.42it/s]

Epoch 2, batch 1600: Loss = 0.0198


1702it [02:39, 10.75it/s]

Epoch 2, batch 1700: Loss = 0.0177


1802it [02:48, 10.55it/s]

Epoch 2, batch 1800: Loss = 0.0172


1902it [02:58, 10.79it/s]

Epoch 2, batch 1900: Loss = 0.0218


2002it [03:07, 10.72it/s]

Epoch 2, batch 2000: Loss = 0.0155


2102it [03:16, 10.83it/s]

Epoch 2, batch 2100: Loss = 0.0152


2202it [03:26, 10.91it/s]

Epoch 2, batch 2200: Loss = 0.0182


2302it [03:35, 10.98it/s]

Epoch 2, batch 2300: Loss = 0.0216


2402it [03:44, 10.41it/s]

Epoch 2, batch 2400: Loss = 0.0229


2502it [03:54, 10.96it/s]

Epoch 2, batch 2500: Loss = 0.0187


2590it [04:02, 10.68it/s]


Epoch 2: Cross-entropy: 40.5782


1it [00:00,  9.95it/s]

Epoch 3, batch 0: Loss = 0.0163


103it [00:09, 10.81it/s]

Epoch 3, batch 100: Loss = 0.0158


203it [00:18, 10.55it/s]

Epoch 3, batch 200: Loss = 0.0166


303it [00:28, 10.11it/s]

Epoch 3, batch 300: Loss = 0.0191


402it [00:37, 10.76it/s]

Epoch 3, batch 400: Loss = 0.0166


503it [00:47,  9.99it/s]

Epoch 3, batch 500: Loss = 0.0181


602it [00:57,  9.31it/s]

Epoch 3, batch 600: Loss = 0.0228


702it [01:07, 10.86it/s]

Epoch 3, batch 700: Loss = 0.0219


802it [01:16, 10.86it/s]

Epoch 3, batch 800: Loss = 0.0230


902it [01:25, 10.92it/s]

Epoch 3, batch 900: Loss = 0.0143


1002it [01:35, 10.61it/s]

Epoch 3, batch 1000: Loss = 0.0255


1102it [01:44, 10.71it/s]

Epoch 3, batch 1100: Loss = 0.0173


1203it [01:54, 10.81it/s]

Epoch 3, batch 1200: Loss = 0.0188


1303it [02:03, 10.86it/s]

Epoch 3, batch 1300: Loss = 0.0165


1403it [02:12, 10.57it/s]

Epoch 3, batch 1400: Loss = 0.0147


1503it [02:21, 10.83it/s]

Epoch 3, batch 1500: Loss = 0.0157


1603it [02:30, 10.62it/s]

Epoch 3, batch 1600: Loss = 0.0205


1703it [02:40, 11.13it/s]

Epoch 3, batch 1700: Loss = 0.0186


1803it [02:49, 10.84it/s]

Epoch 3, batch 1800: Loss = 0.0171


1903it [02:58, 10.62it/s]

Epoch 3, batch 1900: Loss = 0.0157


2003it [03:07, 11.05it/s]

Epoch 3, batch 2000: Loss = 0.0195


2103it [03:16, 10.74it/s]

Epoch 3, batch 2100: Loss = 0.0195


2203it [03:25, 10.66it/s]

Epoch 3, batch 2200: Loss = 0.0150


2303it [03:35, 10.90it/s]

Epoch 3, batch 2300: Loss = 0.0196


2403it [03:44, 10.55it/s]

Epoch 3, batch 2400: Loss = 0.0171


2503it [03:53, 10.90it/s]

Epoch 3, batch 2500: Loss = 0.0161


2590it [04:01, 10.72it/s]


Epoch 3: Cross-entropy: 36.4875


2it [00:00, 10.64it/s]

Epoch 4, batch 0: Loss = 0.0152


102it [00:09, 10.74it/s]

Epoch 4, batch 100: Loss = 0.0164


202it [00:18, 10.73it/s]

Epoch 4, batch 200: Loss = 0.0137


302it [00:27, 10.97it/s]

Epoch 4, batch 300: Loss = 0.0186


402it [00:37, 10.82it/s]

Epoch 4, batch 400: Loss = 0.0168


502it [00:46, 11.00it/s]

Epoch 4, batch 500: Loss = 0.0101


602it [00:55, 10.72it/s]

Epoch 4, batch 600: Loss = 0.0161


702it [01:04, 10.72it/s]

Epoch 4, batch 700: Loss = 0.0168


802it [01:14, 10.60it/s]

Epoch 4, batch 800: Loss = 0.0116


902it [01:23, 11.02it/s]

Epoch 4, batch 900: Loss = 0.0127


1002it [01:33, 10.62it/s]

Epoch 4, batch 1000: Loss = 0.0252


1102it [01:42, 10.90it/s]

Epoch 4, batch 1100: Loss = 0.0130


1202it [01:51, 11.11it/s]

Epoch 4, batch 1200: Loss = 0.0090


1302it [02:00, 10.97it/s]

Epoch 4, batch 1300: Loss = 0.0110


1402it [02:10, 11.03it/s]

Epoch 4, batch 1400: Loss = 0.0149


1502it [02:19, 10.84it/s]

Epoch 4, batch 1500: Loss = 0.0191


1602it [02:28, 10.84it/s]

Epoch 4, batch 1600: Loss = 0.0201


1702it [02:37, 10.82it/s]

Epoch 4, batch 1700: Loss = 0.0098


1802it [02:47, 10.77it/s]

Epoch 4, batch 1800: Loss = 0.0125


1902it [02:56, 10.59it/s]

Epoch 4, batch 1900: Loss = 0.0156


2002it [03:05, 11.14it/s]

Epoch 4, batch 2000: Loss = 0.0228


2102it [03:14, 10.79it/s]

Epoch 4, batch 2100: Loss = 0.0137


2202it [03:24, 10.70it/s]

Epoch 4, batch 2200: Loss = 0.0143


2302it [03:33, 10.62it/s]

Epoch 4, batch 2300: Loss = 0.0098


2403it [03:43, 11.13it/s]

Epoch 4, batch 2400: Loss = 0.0148


2503it [03:52, 10.34it/s]

Epoch 4, batch 2500: Loss = 0.0145


2590it [04:00, 10.75it/s]


Epoch 4: Cross-entropy: 34.0030


In [34]:
val_dataset = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv', device=device)   

val_loader = data.DataLoader(val_dataset, shuffle=True, batch_size=batch_size)

# evaluaate accuracy on validation set


model.eval()
correct = 0
total = 0

with torch.no_grad():
    for (X_batch,y_batch) in val_loader:
        is_padding = X_batch == val_dataset.char_encoder.get_pad_id()
        y_pred = model(X_batch)
        y_pred = y_pred.transpose(1, 2) 
        _, predicted = torch.max(y_pred.data, 1)
        # Count only non-padding characters
        total += torch.sum(~is_padding).item()
        
        # Count correct predictions
        correct += torch.sum((predicted == y_batch) & (~is_padding)).item()
print("Accuracy: %.2f%%" % (100 * correct / total))



w = 500
Accuracy: 96.30%


In [35]:
print('DER of the network on the validation set: %d %%' % (100 * (1 - correct / total)))


DER of the network on the validation set: 3 %
