In [2]:
import torch
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from letters_dataset import LettersDataset
from text_encoder import TextEncoder
import torch.nn as nn
from train_collections import DS_ARABIC_LETTERS, DS_HARAKAT
import numpy as np
from tqdm import tqdm

# autoreload notebook
# %load_ext autoreload

In [3]:



# --- Run configuration -------------------------------------------------------
# Size of the embedding table: one row per symbol in DS_ARABIC_LETTERS.
dim_vocab = len(DS_ARABIC_LETTERS)
# Number of output classes: one per entry in DS_HARAKAT plus two extra labels.
# NOTE(review): the "+ 2" presumably reserves labels for "no diacritic" and
# padding — confirm against LettersDataset / TextEncoder.
dim_out = len(DS_HARAKAT) + 2
embedding_dim = 64
n_epochs = 10
batch_size = 64
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the random number generators so that weight initialisation, DataLoader
# shuffling and dropout are repeatable from one "Restart & Run All" to the next.
# (Some CUDA kernels may still be non-deterministic.)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)


In [4]:

# Training set: LettersDataset with its default data files; `device` is
# forwarded to the dataset (presumably so samples are created on that device —
# TODO confirm in letters_dataset.py).
dataset = LettersDataset(device=device)

# Mini-batch iterator over the training set, reshuffled every epoch.
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)


w = 495


In [6]:
class CharModel(nn.Module):
    """Per-character sequence tagger: embedding -> LSTM -> dropout -> linear.

    All constructor arguments are optional. Sizes left as ``None`` are read
    from the notebook-level settings ``dim_vocab``, ``embedding_dim`` and
    ``dim_out`` at construction time, so ``CharModel()`` builds exactly the
    same network as before.

    Args:
        vocab_size: number of rows in the embedding table
            (default: ``dim_vocab``).
        embed_dim: width of each embedding vector (default: ``embedding_dim``).
        hidden_size: width of the LSTM hidden state, which is also the input
            width of the final linear layer.
        n_classes: number of scores produced per position
            (default: ``dim_out``).
        dropout: dropout probability applied to the LSTM outputs.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, vocab_size=None, embed_dim=None, hidden_size=256,
                 n_classes=None, dropout=0.2, num_layers=1):
        super().__init__()

        # Fall back to the notebook configuration for anything not given.
        if vocab_size is None:
            vocab_size = dim_vocab
        if embed_dim is None:
            embed_dim = embedding_dim
        if n_classes is None:
            n_classes = dim_out

        # Submodule names and creation order are kept as-is so that saved
        # state_dicts and seeded initialisation stay compatible.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        """Score every position of every sequence in ``x``.

        Args:
            x: integer tensor of symbol ids; because the LSTM is built with
                ``batch_first=True`` the leading dimension is the batch.

        Returns:
            Tensor of unnormalised class scores with the same leading
            dimensions as ``x`` and a trailing dimension of ``n_classes``.
        """
        embedded = self.embedding(x)
        # Only the per-step outputs are needed; the final (h, c) is discarded.
        hidden_states, _ = self.lstm(embedded)
        return self.linear(self.dropout(hidden_states))



model = CharModel().to(device)

optimizer = optim.Adam(model.parameters())
# NOTE(review): padding positions are not excluded from this loss, whereas the
# accuracy cell further down masks them out. Consider
# nn.CrossEntropyLoss(ignore_index=...) once the padding label id is confirmed.
loss_fn = nn.CrossEntropyLoss()
num_batches = len(loader)
print("Number of batches:", num_batches)

# Snapshot of the weights that achieved the lowest end-of-epoch loss so far.
best_model = None
best_loss = np.inf

for epoch in range(n_epochs):
    # ---- optimisation pass ----
    model.train()
    # Passing total lets tqdm show a real progress bar; enumerate() alone has
    # no length.
    for i, (X_batch, y_batch) in tqdm(enumerate(loader), total=num_batches):
        # The model emits (batch, seq, classes); CrossEntropyLoss expects the
        # class dimension at index 1, hence the transpose.
        y_pred = model(X_batch).transpose(1, 2)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss.item()))

    # ---- end-of-epoch evaluation ----
    # NOTE(review): this pass iterates over the *training* loader, so the
    # figure below is a training loss and model selection is not based on
    # held-out data. Swap in a validation loader when one is available here.
    model.eval()
    epoch_loss = 0.0  # sum of the per-batch mean losses over the whole pass
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch).transpose(1, 2)
            epoch_loss += loss_fn(y_pred, y_batch).item()

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        # state_dict() returns references to the live parameters; clone them so
        # later optimisation steps cannot overwrite the saved snapshot.
        best_model = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    print("Epoch %d: Cross-entropy: %.4f" % (epoch, epoch_loss))


Number of batches: 2590


4it [00:00, 10.09it/s]

Epoch 0, batch 0: Loss = 2.9327


106it [00:04, 28.60it/s]

Epoch 0, batch 100: Loss = 0.1666


205it [00:07, 26.16it/s]

Epoch 0, batch 200: Loss = 0.1237


304it [00:11, 27.13it/s]

Epoch 0, batch 300: Loss = 0.1140


403it [00:15, 26.34it/s]

Epoch 0, batch 400: Loss = 0.1006


505it [00:19, 25.62it/s]

Epoch 0, batch 500: Loss = 0.0971


604it [00:22, 27.30it/s]

Epoch 0, batch 600: Loss = 0.0749


706it [00:26, 27.26it/s]

Epoch 0, batch 700: Loss = 0.0707


805it [00:30, 24.99it/s]

Epoch 0, batch 800: Loss = 0.0786


907it [00:34, 26.85it/s]

Epoch 0, batch 900: Loss = 0.0845


1006it [00:38, 26.38it/s]

Epoch 0, batch 1000: Loss = 0.0939


1105it [00:41, 27.75it/s]

Epoch 0, batch 1100: Loss = 0.0651


1204it [00:45, 26.84it/s]

Epoch 0, batch 1200: Loss = 0.0695


1306it [00:49, 26.46it/s]

Epoch 0, batch 1300: Loss = 0.0799


1405it [00:53, 26.88it/s]

Epoch 0, batch 1400: Loss = 0.0779


1504it [00:56, 25.89it/s]

Epoch 0, batch 1500: Loss = 0.0679


1603it [01:00, 25.57it/s]

Epoch 0, batch 1600: Loss = 0.0705


1705it [01:04, 25.77it/s]

Epoch 0, batch 1700: Loss = 0.0454


1804it [01:08, 26.65it/s]

Epoch 0, batch 1800: Loss = 0.0384


1906it [01:12, 26.55it/s]

Epoch 0, batch 1900: Loss = 0.0957


2005it [01:15, 25.72it/s]

Epoch 0, batch 2000: Loss = 0.0717


2104it [01:19, 27.64it/s]

Epoch 0, batch 2100: Loss = 0.0748


2206it [01:23, 26.78it/s]

Epoch 0, batch 2200: Loss = 0.0573


2305it [01:27, 26.38it/s]

Epoch 0, batch 2300: Loss = 0.0518


2404it [01:30, 26.26it/s]

Epoch 0, batch 2400: Loss = 0.0677


2503it [01:34, 25.42it/s]

Epoch 0, batch 2500: Loss = 0.0599


2590it [01:37, 26.45it/s]


Epoch 0: Cross-entropy: 157.1270


3it [00:00, 23.49it/s]

Epoch 1, batch 0: Loss = 0.0517


105it [00:03, 27.29it/s]

Epoch 1, batch 100: Loss = 0.0647


204it [00:07, 27.70it/s]

Epoch 1, batch 200: Loss = 0.0533


303it [00:11, 25.36it/s]

Epoch 1, batch 300: Loss = 0.0678


405it [00:15, 25.28it/s]

Epoch 1, batch 400: Loss = 0.0614


504it [00:19, 26.90it/s]

Epoch 1, batch 500: Loss = 0.0672


603it [00:22, 24.65it/s]

Epoch 1, batch 600: Loss = 0.0645


705it [00:26, 26.71it/s]

Epoch 1, batch 700: Loss = 0.0630


804it [00:30, 25.77it/s]

Epoch 1, batch 800: Loss = 0.0714


906it [00:34, 25.54it/s]

Epoch 1, batch 900: Loss = 0.0441


1005it [00:38, 21.47it/s]

Epoch 1, batch 1000: Loss = 0.0778


1104it [00:42, 27.92it/s]

Epoch 1, batch 1100: Loss = 0.0846


1206it [00:46, 27.99it/s]

Epoch 1, batch 1200: Loss = 0.0638


1305it [00:50, 27.13it/s]

Epoch 1, batch 1300: Loss = 0.0592


1407it [00:54, 26.60it/s]

Epoch 1, batch 1400: Loss = 0.0585


1506it [00:57, 25.07it/s]

Epoch 1, batch 1500: Loss = 0.0602


1605it [01:02, 24.96it/s]

Epoch 1, batch 1600: Loss = 0.0555


1704it [01:05, 28.13it/s]

Epoch 1, batch 1700: Loss = 0.0509


1803it [01:09, 25.49it/s]

Epoch 1, batch 1800: Loss = 0.0608


1905it [01:13, 27.57it/s]

Epoch 1, batch 1900: Loss = 0.0546


2004it [01:16, 28.33it/s]

Epoch 1, batch 2000: Loss = 0.0484


2106it [01:20, 26.75it/s]

Epoch 1, batch 2100: Loss = 0.0610


2205it [01:24, 27.67it/s]

Epoch 1, batch 2200: Loss = 0.0683


2304it [01:27, 26.43it/s]

Epoch 1, batch 2300: Loss = 0.0642


2406it [01:31, 25.91it/s]

Epoch 1, batch 2400: Loss = 0.0471


2505it [01:35, 25.04it/s]

Epoch 1, batch 2500: Loss = 0.0707


2590it [01:39, 26.10it/s]


Epoch 1: Cross-entropy: 137.4701


3it [00:00, 23.60it/s]

Epoch 2, batch 0: Loss = 0.0356


105it [00:03, 27.06it/s]

Epoch 2, batch 100: Loss = 0.0448


204it [00:07, 26.61it/s]

Epoch 2, batch 200: Loss = 0.0569


303it [00:11, 27.79it/s]

Epoch 2, batch 300: Loss = 0.0507


405it [00:15, 25.81it/s]

Epoch 2, batch 400: Loss = 0.0433


504it [00:18, 25.11it/s]

Epoch 2, batch 500: Loss = 0.0458


606it [00:22, 27.92it/s]

Epoch 2, batch 600: Loss = 0.0501


705it [00:26, 26.25it/s]

Epoch 2, batch 700: Loss = 0.0601


804it [00:30, 27.03it/s]

Epoch 2, batch 800: Loss = 0.0568


906it [00:34, 27.19it/s]

Epoch 2, batch 900: Loss = 0.0460


1005it [00:37, 27.02it/s]

Epoch 2, batch 1000: Loss = 0.0543


1104it [00:41, 27.45it/s]

Epoch 2, batch 1100: Loss = 0.0567


1206it [00:45, 27.19it/s]

Epoch 2, batch 1200: Loss = 0.0447


1305it [00:49, 25.83it/s]

Epoch 2, batch 1300: Loss = 0.0577


1407it [00:53, 28.55it/s]

Epoch 2, batch 1400: Loss = 0.0547


1506it [00:56, 27.21it/s]

Epoch 2, batch 1500: Loss = 0.0409


1604it [01:00, 19.58it/s]

Epoch 2, batch 1600: Loss = 0.0828


1706it [01:05, 27.42it/s]

Epoch 2, batch 1700: Loss = 0.0676


1805it [01:08, 27.08it/s]

Epoch 2, batch 1800: Loss = 0.0623


1904it [01:12, 26.05it/s]

Epoch 2, batch 1900: Loss = 0.0455


2006it [01:16, 27.84it/s]

Epoch 2, batch 2000: Loss = 0.0473


2105it [01:19, 26.65it/s]

Epoch 2, batch 2100: Loss = 0.0707


2204it [01:23, 27.30it/s]

Epoch 2, batch 2200: Loss = 0.0517


2306it [01:27, 27.61it/s]

Epoch 2, batch 2300: Loss = 0.0512


2405it [01:30, 27.53it/s]

Epoch 2, batch 2400: Loss = 0.0646


2504it [01:34, 26.28it/s]

Epoch 2, batch 2500: Loss = 0.0511


2590it [01:37, 26.48it/s]


Epoch 2: Cross-entropy: 128.8836


3it [00:00, 24.51it/s]

Epoch 3, batch 0: Loss = 0.0463


105it [00:03, 27.87it/s]

Epoch 3, batch 100: Loss = 0.0682


204it [00:07, 27.03it/s]

Epoch 3, batch 200: Loss = 0.0632


306it [00:11, 28.07it/s]

Epoch 3, batch 300: Loss = 0.0394


405it [00:14, 28.03it/s]

Epoch 3, batch 400: Loss = 0.0372


504it [00:18, 28.15it/s]

Epoch 3, batch 500: Loss = 0.0389


606it [00:21, 27.36it/s]

Epoch 3, batch 600: Loss = 0.0399


705it [00:25, 28.23it/s]

Epoch 3, batch 700: Loss = 0.0560


804it [00:29, 27.28it/s]

Epoch 3, batch 800: Loss = 0.0435


906it [00:32, 26.45it/s]

Epoch 3, batch 900: Loss = 0.0566


1005it [00:36, 26.06it/s]

Epoch 3, batch 1000: Loss = 0.0371


1107it [00:40, 27.29it/s]

Epoch 3, batch 1100: Loss = 0.0422


1203it [00:44, 26.89it/s]

Epoch 3, batch 1200: Loss = 0.0569


1305it [00:47, 27.55it/s]

Epoch 3, batch 1300: Loss = 0.0502


1404it [00:51, 27.38it/s]

Epoch 3, batch 1400: Loss = 0.0662


1506it [00:55, 26.76it/s]

Epoch 3, batch 1500: Loss = 0.0521


1605it [00:58, 27.40it/s]

Epoch 3, batch 1600: Loss = 0.0630


1704it [01:02, 28.04it/s]

Epoch 3, batch 1700: Loss = 0.0574


1806it [01:06, 27.79it/s]

Epoch 3, batch 1800: Loss = 0.0637


1905it [01:10, 27.87it/s]

Epoch 3, batch 1900: Loss = 0.0446


2004it [01:13, 27.19it/s]

Epoch 3, batch 2000: Loss = 0.0760


2106it [01:17, 27.19it/s]

Epoch 3, batch 2100: Loss = 0.0438


2205it [01:21, 27.95it/s]

Epoch 3, batch 2200: Loss = 0.0503


2307it [01:24, 28.00it/s]

Epoch 3, batch 2300: Loss = 0.0535


2406it [01:28, 27.26it/s]

Epoch 3, batch 2400: Loss = 0.0509


2505it [01:32, 27.18it/s]

Epoch 3, batch 2500: Loss = 0.0461


2590it [01:35, 27.18it/s]


Epoch 3: Cross-entropy: 123.8252


3it [00:00, 23.54it/s]

Epoch 4, batch 0: Loss = 0.0546


105it [00:04, 26.35it/s]

Epoch 4, batch 100: Loss = 0.0481


204it [00:07, 25.81it/s]

Epoch 4, batch 200: Loss = 0.0511


306it [00:11, 27.95it/s]

Epoch 4, batch 300: Loss = 0.0352


405it [00:15, 28.36it/s]

Epoch 4, batch 400: Loss = 0.0461


504it [00:18, 27.27it/s]

Epoch 4, batch 500: Loss = 0.0372


606it [00:22, 27.77it/s]

Epoch 4, batch 600: Loss = 0.0524


705it [00:25, 27.21it/s]

Epoch 4, batch 700: Loss = 0.0438


807it [00:29, 28.66it/s]

Epoch 4, batch 800: Loss = 0.0657


906it [00:33, 28.16it/s]

Epoch 4, batch 900: Loss = 0.0530


1005it [00:36, 27.06it/s]

Epoch 4, batch 1000: Loss = 0.0597


1104it [00:40, 28.02it/s]

Epoch 4, batch 1100: Loss = 0.0497


1206it [00:44, 27.08it/s]

Epoch 4, batch 1200: Loss = 0.0407


1305it [00:47, 27.40it/s]

Epoch 4, batch 1300: Loss = 0.0567


1404it [00:51, 26.59it/s]

Epoch 4, batch 1400: Loss = 0.0574


1506it [00:55, 26.86it/s]

Epoch 4, batch 1500: Loss = 0.0521


1605it [00:58, 27.26it/s]

Epoch 4, batch 1600: Loss = 0.0541


1704it [01:02, 28.26it/s]

Epoch 4, batch 1700: Loss = 0.0504


1806it [01:05, 29.11it/s]

Epoch 4, batch 1800: Loss = 0.0428


1905it [01:09, 28.80it/s]

Epoch 4, batch 1900: Loss = 0.0433


2004it [01:12, 27.84it/s]

Epoch 4, batch 2000: Loss = 0.0416


2106it [01:16, 29.08it/s]

Epoch 4, batch 2100: Loss = 0.0465


2205it [01:19, 28.59it/s]

Epoch 4, batch 2200: Loss = 0.0473


2307it [01:23, 29.01it/s]

Epoch 4, batch 2300: Loss = 0.0516


2406it [01:27, 26.41it/s]

Epoch 4, batch 2400: Loss = 0.0382


2505it [01:30, 26.55it/s]

Epoch 4, batch 2500: Loss = 0.0482


2590it [01:33, 27.55it/s]


Epoch 4: Cross-entropy: 120.2747


3it [00:00, 24.06it/s]

Epoch 5, batch 0: Loss = 0.0526


105it [00:03, 27.21it/s]

Epoch 5, batch 100: Loss = 0.0603


204it [00:07, 27.51it/s]

Epoch 5, batch 200: Loss = 0.0447


303it [00:11, 27.04it/s]

Epoch 5, batch 300: Loss = 0.0485


405it [00:15, 27.26it/s]

Epoch 5, batch 400: Loss = 0.0451


507it [00:18, 27.28it/s]

Epoch 5, batch 500: Loss = 0.0598


602it [00:22, 19.57it/s]

Epoch 5, batch 600: Loss = 0.0357


703it [00:27, 21.45it/s]

Epoch 5, batch 700: Loss = 0.0555


804it [00:32, 24.93it/s]

Epoch 5, batch 800: Loss = 0.0505


906it [00:36, 25.03it/s]

Epoch 5, batch 900: Loss = 0.0461


1005it [00:40, 26.23it/s]

Epoch 5, batch 1000: Loss = 0.0446


1104it [00:44, 26.43it/s]

Epoch 5, batch 1100: Loss = 0.0515


1206it [00:47, 26.10it/s]

Epoch 5, batch 1200: Loss = 0.0407


1305it [00:51, 26.65it/s]

Epoch 5, batch 1300: Loss = 0.0264


1404it [00:55, 26.11it/s]

Epoch 5, batch 1400: Loss = 0.0468


1506it [00:59, 26.90it/s]

Epoch 5, batch 1500: Loss = 0.0501


1605it [01:02, 26.61it/s]

Epoch 5, batch 1600: Loss = 0.0506


1704it [01:06, 27.11it/s]

Epoch 5, batch 1700: Loss = 0.0408


1806it [01:10, 27.42it/s]

Epoch 5, batch 1800: Loss = 0.0454


1905it [01:13, 27.44it/s]

Epoch 5, batch 1900: Loss = 0.0302


2004it [01:17, 26.05it/s]

Epoch 5, batch 2000: Loss = 0.0467


2106it [01:21, 28.22it/s]

Epoch 5, batch 2100: Loss = 0.0448


2205it [01:25, 26.51it/s]

Epoch 5, batch 2200: Loss = 0.0443


2304it [01:29, 26.48it/s]

Epoch 5, batch 2300: Loss = 0.0545


2406it [01:32, 25.79it/s]

Epoch 5, batch 2400: Loss = 0.0522


2505it [01:36, 25.97it/s]

Epoch 5, batch 2500: Loss = 0.0467


2590it [01:40, 25.89it/s]


Epoch 5: Cross-entropy: 118.4189


3it [00:00, 24.58it/s]

Epoch 6, batch 0: Loss = 0.0491


105it [00:04, 26.04it/s]

Epoch 6, batch 100: Loss = 0.0443


204it [00:07, 27.42it/s]

Epoch 6, batch 200: Loss = 0.0659


303it [00:11, 26.38it/s]

Epoch 6, batch 300: Loss = 0.0428


405it [00:15, 26.91it/s]

Epoch 6, batch 400: Loss = 0.0600


504it [00:19, 25.64it/s]

Epoch 6, batch 500: Loss = 0.0698


606it [00:23, 26.79it/s]

Epoch 6, batch 600: Loss = 0.0485


705it [00:26, 25.72it/s]

Epoch 6, batch 700: Loss = 0.0333


804it [00:30, 26.19it/s]

Epoch 6, batch 800: Loss = 0.0393


906it [00:34, 26.66it/s]

Epoch 6, batch 900: Loss = 0.0434


1005it [00:38, 24.64it/s]

Epoch 6, batch 1000: Loss = 0.0566


1104it [00:42, 24.53it/s]

Epoch 6, batch 1100: Loss = 0.0464


1206it [00:46, 26.15it/s]

Epoch 6, batch 1200: Loss = 0.0464


1305it [00:50, 26.12it/s]

Epoch 6, batch 1300: Loss = 0.0420


1404it [00:53, 26.64it/s]

Epoch 6, batch 1400: Loss = 0.0356


1506it [00:57, 26.46it/s]

Epoch 6, batch 1500: Loss = 0.0494


1605it [01:01, 24.98it/s]

Epoch 6, batch 1600: Loss = 0.0569


1704it [01:05, 26.04it/s]

Epoch 6, batch 1700: Loss = 0.0352


1803it [01:09, 25.86it/s]

Epoch 6, batch 1800: Loss = 0.0612


1905it [01:13, 25.57it/s]

Epoch 6, batch 1900: Loss = 0.0406


2004it [01:17, 27.50it/s]

Epoch 6, batch 2000: Loss = 0.0317


2106it [01:20, 26.39it/s]

Epoch 6, batch 2100: Loss = 0.0523


2205it [01:24, 25.56it/s]

Epoch 6, batch 2200: Loss = 0.0611


2304it [01:28, 24.57it/s]

Epoch 6, batch 2300: Loss = 0.0480


2406it [01:32, 26.06it/s]

Epoch 6, batch 2400: Loss = 0.0491


2505it [01:36, 26.42it/s]

Epoch 6, batch 2500: Loss = 0.0310


2590it [01:39, 25.98it/s]


In [None]:
# Validation split, read from the cleaned CSV files.
val_dataset = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv', device=device)

# Accuracy does not depend on batch order, so the evaluation loader is not
# shuffled.
val_loader = data.DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# Evaluate per-character accuracy on the validation set, ignoring padding.
model.eval()
correct = 0
total = 0

# Id that marks padding in the input sequences; looked up once, outside the loop.
pad_token = val_dataset.char_encoder.get_pad_token()

with torch.no_grad():
    for X_batch, y_batch in val_loader:
        is_padding = X_batch == pad_token
        # Highest-scoring class per position; the class scores sit on the last
        # dimension of the model output, so no transpose is needed.
        predicted = model(X_batch).argmax(dim=-1)

        # Count only non-padding characters.
        total += torch.sum(~is_padding).item()

        # Count correct predictions among the non-padding characters.
        correct += torch.sum((predicted == y_batch) & (~is_padding)).item()

if total == 0:
    # Avoid a ZeroDivisionError on an empty (or all-padding) validation set.
    print("Accuracy: n/a (no non-padding characters in the validation set)")
else:
    print("Accuracy: %.2f%%" % (100 * correct / total))



w = 500
Accuracy: 83.73%


In [None]:
# DER (presumably "diacritic error rate") is the complement of the accuracy
# computed in the previous cell, over the same non-padding characters.
# Reported with two decimals: the former %d format truncated the value
# (e.g. 16.27 -> 16) and so understated the error.
der = 100 * (1 - correct / total) if total else float('nan')
print('DER of the network on the validation set: %.2f %%' % der)


DER of the network on the validation set: 16 %
