In [1]:
import torch
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from letters_dataset import LettersDataset
from text_encoder import TextEncoder
import torch.nn as nn
from train_collections import DS_ARABIC_LETTERS, DS_HARAKAT
import numpy as np
from tqdm import tqdm

# autoreload notebook
%load_ext autoreload

In [25]:



dim_vocab = len(DS_ARABIC_LETTERS)
dim_out = len(DS_HARAKAT) + 2
embedding_dim = 64
n_epochs = 10
batch_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [26]:

dataset = LettersDataset(device=device)
loader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

# load val data
# da = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv')


w = 495


In [27]:




class CharModel(nn.Module):
    def __init__(self):
        super().__init__()

        # embedding and LSTM layers
        self.embedding = nn.Embedding(dim_vocab, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=256,
                            num_layers=1, batch_first=True )
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(256, dim_out)

    def forward(self, x):
        # pass thru embedding layer
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(self.dropout(x))
        return x



model = CharModel().to(device)

optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
num_batches = len(loader)
print("Number of batches:", num_batches)
best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    for i, batch in enumerate(loader):
        X_batch = batch["input"]
        y_batch = batch["output"]
        y_pred = ''
        y_pred = model(X_batch)
        y_pred = y_pred.transpose(1, 2) 
        # print(y_pred.shape)
        # print(y_batch.shape)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss))
        
    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for batch in loader:
            X_batch = batch["input"]
            y_batch = batch["output"]
            y_pred = model(X_batch)
            y_pred = y_pred.transpose(1, 2) 
            
            loss += loss_fn(y_pred, y_batch)
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))


Number of batches: 2590


1it [00:00,  6.07it/s]

Epoch 0, batch 0: Loss = 2.8615


106it [00:03, 27.23it/s]

Epoch 0, batch 100: Loss = 0.1884


205it [00:07, 29.02it/s]

Epoch 0, batch 200: Loss = 0.1615


307it [00:11, 28.36it/s]

Epoch 0, batch 300: Loss = 0.1526


406it [00:14, 27.35it/s]

Epoch 0, batch 400: Loss = 0.1144


505it [00:18, 27.50it/s]

Epoch 0, batch 500: Loss = 0.1063


604it [00:21, 26.94it/s]

Epoch 0, batch 600: Loss = 0.0903


706it [00:25, 28.02it/s]

Epoch 0, batch 700: Loss = 0.0776


805it [00:29, 27.70it/s]

Epoch 0, batch 800: Loss = 0.0988


904it [00:32, 28.30it/s]

Epoch 0, batch 900: Loss = 0.0607


1006it [00:36, 26.86it/s]

Epoch 0, batch 1000: Loss = 0.0844


1105it [00:40, 28.24it/s]

Epoch 0, batch 1100: Loss = 0.0884


1204it [00:43, 27.93it/s]

Epoch 0, batch 1200: Loss = 0.0798


1306it [00:47, 27.23it/s]

Epoch 0, batch 1300: Loss = 0.0648


1405it [00:51, 28.06it/s]

Epoch 0, batch 1400: Loss = 0.0723


1504it [00:54, 27.42it/s]

Epoch 0, batch 1500: Loss = 0.0925


1606it [00:58, 28.07it/s]

Epoch 0, batch 1600: Loss = 0.0778


1705it [01:02, 26.72it/s]

Epoch 0, batch 1700: Loss = 0.0727


1804it [01:05, 27.80it/s]

Epoch 0, batch 1800: Loss = 0.0744


1906it [01:09, 27.71it/s]

Epoch 0, batch 1900: Loss = 0.0773


2005it [01:13, 27.48it/s]

Epoch 0, batch 2000: Loss = 0.0419


2107it [01:16, 28.27it/s]

Epoch 0, batch 2100: Loss = 0.0625


2206it [01:20, 27.32it/s]

Epoch 0, batch 2200: Loss = 0.1002


2305it [01:24, 24.53it/s]

Epoch 0, batch 2300: Loss = 0.0819


2404it [01:28, 25.54it/s]

Epoch 0, batch 2400: Loss = 0.0669


2506it [01:31, 27.77it/s]

Epoch 0, batch 2500: Loss = 0.0593


2590it [01:35, 27.26it/s]


Epoch 0: Cross-entropy: 164.6821


3it [00:00, 22.39it/s]

Epoch 1, batch 0: Loss = 0.0506


105it [00:03, 27.04it/s]

Epoch 1, batch 100: Loss = 0.0674


204it [00:07, 26.67it/s]

Epoch 1, batch 200: Loss = 0.0720


306it [00:11, 27.17it/s]

Epoch 1, batch 300: Loss = 0.0808


405it [00:15, 26.65it/s]

Epoch 1, batch 400: Loss = 0.0778


504it [00:18, 27.25it/s]

Epoch 1, batch 500: Loss = 0.0776


606it [00:22, 26.25it/s]

Epoch 1, batch 600: Loss = 0.0513


705it [00:26, 27.95it/s]

Epoch 1, batch 700: Loss = 0.0667


804it [00:29, 27.98it/s]

Epoch 1, batch 800: Loss = 0.0924


906it [00:33, 26.49it/s]

Epoch 1, batch 900: Loss = 0.0773


1005it [00:37, 28.39it/s]

Epoch 1, batch 1000: Loss = 0.0498


1104it [00:40, 26.13it/s]

Epoch 1, batch 1100: Loss = 0.0614


1203it [00:44, 26.93it/s]

Epoch 1, batch 1200: Loss = 0.0586


1305it [00:48, 27.79it/s]

Epoch 1, batch 1300: Loss = 0.0597


1404it [00:51, 27.02it/s]

Epoch 1, batch 1400: Loss = 0.0462


1506it [00:55, 26.56it/s]

Epoch 1, batch 1500: Loss = 0.0642


1605it [00:59, 27.56it/s]

Epoch 1, batch 1600: Loss = 0.0507


1704it [01:02, 26.42it/s]

Epoch 1, batch 1700: Loss = 0.0447


1806it [01:06, 27.43it/s]

Epoch 1, batch 1800: Loss = 0.0592


1905it [01:10, 26.72it/s]

Epoch 1, batch 1900: Loss = 0.0540


2004it [01:14, 26.78it/s]

Epoch 1, batch 2000: Loss = 0.0479


2106it [01:17, 27.28it/s]

Epoch 1, batch 2100: Loss = 0.0662


2205it [01:21, 27.56it/s]

Epoch 1, batch 2200: Loss = 0.0423


2304it [01:25, 26.67it/s]

Epoch 1, batch 2300: Loss = 0.0582


2406it [01:28, 26.96it/s]

Epoch 1, batch 2400: Loss = 0.0523


2505it [01:32, 26.68it/s]

Epoch 1, batch 2500: Loss = 0.0388


2590it [01:35, 27.00it/s]


Epoch 1: Cross-entropy: 140.3250


3it [00:00, 25.23it/s]

Epoch 2, batch 0: Loss = 0.0585


105it [00:03, 28.09it/s]

Epoch 2, batch 100: Loss = 0.0498


204it [00:07, 27.21it/s]

Epoch 2, batch 200: Loss = 0.0451


306it [00:11, 26.90it/s]

Epoch 2, batch 300: Loss = 0.0729


405it [00:14, 27.61it/s]

Epoch 2, batch 400: Loss = 0.0788


504it [00:18, 26.13it/s]

Epoch 2, batch 500: Loss = 0.0640


606it [00:22, 26.20it/s]

Epoch 2, batch 600: Loss = 0.0571


705it [00:25, 27.43it/s]

Epoch 2, batch 700: Loss = 0.0628


807it [00:29, 28.16it/s]

Epoch 2, batch 800: Loss = 0.0558


906it [00:33, 27.22it/s]

Epoch 2, batch 900: Loss = 0.0544


1005it [00:36, 27.63it/s]

Epoch 2, batch 1000: Loss = 0.0626


1104it [00:40, 26.99it/s]

Epoch 2, batch 1100: Loss = 0.0575


1206it [00:44, 27.65it/s]

Epoch 2, batch 1200: Loss = 0.0497


1305it [00:48, 27.98it/s]

Epoch 2, batch 1300: Loss = 0.0466


1404it [00:51, 26.66it/s]

Epoch 2, batch 1400: Loss = 0.0531


1506it [00:55, 27.16it/s]

Epoch 2, batch 1500: Loss = 0.0483


1605it [00:59, 27.05it/s]

Epoch 2, batch 1600: Loss = 0.0560


1704it [01:02, 25.76it/s]

Epoch 2, batch 1700: Loss = 0.0504


1803it [01:06, 26.59it/s]

Epoch 2, batch 1800: Loss = 0.0467


1905it [01:10, 26.36it/s]

Epoch 2, batch 1900: Loss = 0.0519


2004it [01:14, 26.46it/s]

Epoch 2, batch 2000: Loss = 0.0445


2106it [01:18, 28.69it/s]

Epoch 2, batch 2100: Loss = 0.0529


2205it [01:21, 28.35it/s]

Epoch 2, batch 2200: Loss = 0.0407


2307it [01:25, 28.30it/s]

Epoch 2, batch 2300: Loss = 0.0665


2403it [01:29, 21.47it/s]

Epoch 2, batch 2400: Loss = 0.0594


2505it [01:33, 21.38it/s]

Epoch 2, batch 2500: Loss = 0.0473


2590it [01:37, 26.56it/s]


Epoch 2: Cross-entropy: 129.9703


3it [00:00, 22.76it/s]

Epoch 3, batch 0: Loss = 0.0376


105it [00:04, 22.49it/s]

Epoch 3, batch 100: Loss = 0.0570


204it [00:08, 23.33it/s]

Epoch 3, batch 200: Loss = 0.0407


303it [00:13, 24.14it/s]

Epoch 3, batch 300: Loss = 0.0563


405it [00:17, 22.61it/s]

Epoch 3, batch 400: Loss = 0.0538


504it [00:22, 22.98it/s]

Epoch 3, batch 500: Loss = 0.0343


603it [00:26, 23.98it/s]

Epoch 3, batch 600: Loss = 0.0628


705it [00:31, 23.83it/s]

Epoch 3, batch 700: Loss = 0.0456


804it [00:35, 24.00it/s]

Epoch 3, batch 800: Loss = 0.0477


906it [00:39, 26.78it/s]

Epoch 3, batch 900: Loss = 0.0678


1005it [00:43, 27.53it/s]

Epoch 3, batch 1000: Loss = 0.0514


1104it [00:46, 27.07it/s]

Epoch 3, batch 1100: Loss = 0.0489


1203it [00:50, 27.75it/s]

Epoch 3, batch 1200: Loss = 0.0620


1305it [00:54, 28.47it/s]

Epoch 3, batch 1300: Loss = 0.0433


1404it [00:57, 28.11it/s]

Epoch 3, batch 1400: Loss = 0.0488


1506it [01:01, 27.01it/s]

Epoch 3, batch 1500: Loss = 0.0483


1605it [01:04, 28.42it/s]

Epoch 3, batch 1600: Loss = 0.0635


1704it [01:08, 27.08it/s]

Epoch 3, batch 1700: Loss = 0.0518


1806it [01:12, 27.01it/s]

Epoch 3, batch 1800: Loss = 0.0442


1905it [01:16, 26.83it/s]

Epoch 3, batch 1900: Loss = 0.0502


2004it [01:19, 27.23it/s]

Epoch 3, batch 2000: Loss = 0.0452


2106it [01:23, 28.04it/s]

Epoch 3, batch 2100: Loss = 0.0506


2205it [01:26, 28.45it/s]

Epoch 3, batch 2200: Loss = 0.0584


2304it [01:30, 27.89it/s]

Epoch 3, batch 2300: Loss = 0.0381


2406it [01:34, 28.63it/s]

Epoch 3, batch 2400: Loss = 0.0424


2505it [01:37, 27.57it/s]

Epoch 3, batch 2500: Loss = 0.0470


2590it [01:40, 25.75it/s]


Epoch 3: Cross-entropy: 124.9516


3it [00:00, 25.68it/s]

Epoch 4, batch 0: Loss = 0.0530


105it [00:03, 27.81it/s]

Epoch 4, batch 100: Loss = 0.0420


204it [00:07, 27.63it/s]

Epoch 4, batch 200: Loss = 0.0666


306it [00:11, 27.44it/s]

Epoch 4, batch 300: Loss = 0.0583


405it [00:14, 27.04it/s]

Epoch 4, batch 400: Loss = 0.0528


504it [00:18, 28.52it/s]

Epoch 4, batch 500: Loss = 0.0454


606it [00:21, 27.28it/s]

Epoch 4, batch 600: Loss = 0.0456


705it [00:25, 28.51it/s]

Epoch 4, batch 700: Loss = 0.0483


804it [00:28, 28.70it/s]

Epoch 4, batch 800: Loss = 0.0476


906it [00:32, 28.21it/s]

Epoch 4, batch 900: Loss = 0.0508


1005it [00:36, 28.40it/s]

Epoch 4, batch 1000: Loss = 0.0437


1104it [00:39, 28.24it/s]

Epoch 4, batch 1100: Loss = 0.0522


1206it [00:43, 28.63it/s]

Epoch 4, batch 1200: Loss = 0.0508


1305it [00:46, 27.84it/s]

Epoch 4, batch 1300: Loss = 0.0512


1404it [00:50, 26.83it/s]

Epoch 4, batch 1400: Loss = 0.0517


1506it [00:54, 27.45it/s]

Epoch 4, batch 1500: Loss = 0.0453


1605it [00:57, 28.00it/s]

Epoch 4, batch 1600: Loss = 0.0368


1704it [01:01, 27.73it/s]

Epoch 4, batch 1700: Loss = 0.0465


1806it [01:05, 26.79it/s]

Epoch 4, batch 1800: Loss = 0.0602


1905it [01:09, 26.37it/s]

Epoch 4, batch 1900: Loss = 0.0681


2004it [01:12, 26.86it/s]

Epoch 4, batch 2000: Loss = 0.0458


2106it [01:16, 24.94it/s]

Epoch 4, batch 2100: Loss = 0.0516


2205it [01:20, 27.62it/s]

Epoch 4, batch 2200: Loss = 0.0400


2304it [01:23, 27.63it/s]

Epoch 4, batch 2300: Loss = 0.0374


2406it [01:27, 26.53it/s]

Epoch 4, batch 2400: Loss = 0.0620


2505it [01:31, 26.91it/s]

Epoch 4, batch 2500: Loss = 0.0439


2590it [01:34, 27.36it/s]


Epoch 4: Cross-entropy: 121.1431


3it [00:00, 24.15it/s]

Epoch 5, batch 0: Loss = 0.0427


105it [00:03, 27.75it/s]

Epoch 5, batch 100: Loss = 0.0417


204it [00:07, 27.70it/s]

Epoch 5, batch 200: Loss = 0.0393


306it [00:11, 27.47it/s]

Epoch 5, batch 300: Loss = 0.0588


405it [00:14, 28.38it/s]

Epoch 5, batch 400: Loss = 0.0505


504it [00:18, 27.66it/s]

Epoch 5, batch 500: Loss = 0.0473


606it [00:22, 28.25it/s]

Epoch 5, batch 600: Loss = 0.0437


705it [00:25, 28.43it/s]

Epoch 5, batch 700: Loss = 0.0440


804it [00:29, 28.17it/s]

Epoch 5, batch 800: Loss = 0.0468


906it [00:32, 27.70it/s]

Epoch 5, batch 900: Loss = 0.0575


1005it [00:36, 27.12it/s]

Epoch 5, batch 1000: Loss = 0.0587


1104it [00:39, 28.01it/s]

Epoch 5, batch 1100: Loss = 0.0538


1206it [00:43, 27.80it/s]

Epoch 5, batch 1200: Loss = 0.0439


1305it [00:47, 24.87it/s]

Epoch 5, batch 1300: Loss = 0.0550


1404it [00:51, 27.70it/s]

Epoch 5, batch 1400: Loss = 0.0614


1506it [00:55, 27.51it/s]

Epoch 5, batch 1500: Loss = 0.0575


1605it [00:58, 27.70it/s]

Epoch 5, batch 1600: Loss = 0.0374


1704it [01:02, 27.56it/s]

Epoch 5, batch 1700: Loss = 0.0582


1806it [01:06, 28.76it/s]

Epoch 5, batch 1800: Loss = 0.0389


1905it [01:09, 29.07it/s]

Epoch 5, batch 1900: Loss = 0.0579


2004it [01:12, 28.71it/s]

Epoch 5, batch 2000: Loss = 0.0370


2106it [01:16, 28.92it/s]

Epoch 5, batch 2100: Loss = 0.0454


2205it [01:19, 28.78it/s]

Epoch 5, batch 2200: Loss = 0.0450


2304it [01:23, 28.73it/s]

Epoch 5, batch 2300: Loss = 0.0422


2406it [01:26, 28.86it/s]

Epoch 5, batch 2400: Loss = 0.0550


2505it [01:30, 28.82it/s]

Epoch 5, batch 2500: Loss = 0.0610


2590it [01:33, 27.73it/s]


Epoch 5: Cross-entropy: 118.4245


3it [00:00, 25.19it/s]

Epoch 6, batch 0: Loss = 0.0433


105it [00:03, 27.29it/s]

Epoch 6, batch 100: Loss = 0.0710


204it [00:07, 27.88it/s]

Epoch 6, batch 200: Loss = 0.0449


306it [00:11, 26.92it/s]

Epoch 6, batch 300: Loss = 0.0321


405it [00:14, 27.38it/s]

Epoch 6, batch 400: Loss = 0.0471


507it [00:18, 28.32it/s]

Epoch 6, batch 500: Loss = 0.0508


606it [00:21, 28.09it/s]

Epoch 6, batch 600: Loss = 0.0362


705it [00:25, 26.07it/s]

Epoch 6, batch 700: Loss = 0.0443


804it [00:29, 25.71it/s]

Epoch 6, batch 800: Loss = 0.0375


906it [00:33, 27.42it/s]

Epoch 6, batch 900: Loss = 0.0628


1005it [00:36, 28.02it/s]

Epoch 6, batch 1000: Loss = 0.0518


1104it [00:40, 27.31it/s]

Epoch 6, batch 1100: Loss = 0.0518


1206it [00:44, 27.87it/s]

Epoch 6, batch 1200: Loss = 0.0530


1305it [00:47, 25.96it/s]

Epoch 6, batch 1300: Loss = 0.0387


1404it [00:51, 26.57it/s]

Epoch 6, batch 1400: Loss = 0.0447


1506it [00:55, 27.67it/s]

Epoch 6, batch 1500: Loss = 0.0566


1605it [00:58, 27.58it/s]

Epoch 6, batch 1600: Loss = 0.0624


1704it [01:02, 27.81it/s]

Epoch 6, batch 1700: Loss = 0.0462


1806it [01:06, 26.87it/s]

Epoch 6, batch 1800: Loss = 0.0405


1905it [01:09, 26.35it/s]

Epoch 6, batch 1900: Loss = 0.0407


2004it [01:13, 27.69it/s]

Epoch 6, batch 2000: Loss = 0.0503


2106it [01:16, 26.85it/s]

Epoch 6, batch 2100: Loss = 0.0623


2205it [01:20, 26.85it/s]

Epoch 6, batch 2200: Loss = 0.0499


2304it [01:24, 27.54it/s]

Epoch 6, batch 2300: Loss = 0.0291


2406it [01:28, 26.45it/s]

Epoch 6, batch 2400: Loss = 0.0576


2505it [01:31, 28.09it/s]

Epoch 6, batch 2500: Loss = 0.0465


2590it [01:34, 27.31it/s]


Epoch 6: Cross-entropy: 116.4141


3it [00:00, 25.95it/s]

Epoch 7, batch 0: Loss = 0.0406


105it [00:03, 27.78it/s]

Epoch 7, batch 100: Loss = 0.0552


204it [00:07, 28.30it/s]

Epoch 7, batch 200: Loss = 0.0500


306it [00:11, 27.06it/s]

Epoch 7, batch 300: Loss = 0.0372


405it [00:14, 28.04it/s]

Epoch 7, batch 400: Loss = 0.0627


504it [00:18, 28.00it/s]

Epoch 7, batch 500: Loss = 0.0436


606it [00:21, 28.74it/s]

Epoch 7, batch 600: Loss = 0.0533


705it [00:25, 28.62it/s]

Epoch 7, batch 700: Loss = 0.0323


804it [00:28, 27.74it/s]

Epoch 7, batch 800: Loss = 0.0406


906it [00:32, 28.11it/s]

Epoch 7, batch 900: Loss = 0.0336


1005it [00:35, 28.37it/s]

Epoch 7, batch 1000: Loss = 0.0346


1104it [00:39, 27.77it/s]

Epoch 7, batch 1100: Loss = 0.0443


1206it [00:43, 28.23it/s]

Epoch 7, batch 1200: Loss = 0.0435


1305it [00:46, 27.92it/s]

Epoch 7, batch 1300: Loss = 0.0423


1407it [00:50, 28.01it/s]

Epoch 7, batch 1400: Loss = 0.0335


1506it [00:54, 27.65it/s]

Epoch 7, batch 1500: Loss = 0.0494


1605it [00:57, 26.93it/s]

Epoch 7, batch 1600: Loss = 0.0458


1707it [01:01, 28.23it/s]

Epoch 7, batch 1700: Loss = 0.0498


1806it [01:04, 27.72it/s]

Epoch 7, batch 1800: Loss = 0.0570


1905it [01:08, 27.61it/s]

Epoch 7, batch 1900: Loss = 0.0476


2007it [01:12, 27.23it/s]

Epoch 7, batch 2000: Loss = 0.0431


2106it [01:15, 26.47it/s]

Epoch 7, batch 2100: Loss = 0.0413


2205it [01:19, 28.24it/s]

Epoch 7, batch 2200: Loss = 0.0492


2304it [01:22, 27.50it/s]

Epoch 7, batch 2300: Loss = 0.0492


2403it [01:26, 27.61it/s]

Epoch 7, batch 2400: Loss = 0.0382


2505it [01:30, 26.95it/s]

Epoch 7, batch 2500: Loss = 0.0532


2590it [01:33, 27.78it/s]


Epoch 7: Cross-entropy: 114.9561


3it [00:00, 25.57it/s]

Epoch 8, batch 0: Loss = 0.0430


105it [00:03, 27.59it/s]

Epoch 8, batch 100: Loss = 0.0411


207it [00:07, 28.32it/s]

Epoch 8, batch 200: Loss = 0.0544


306it [00:11, 27.48it/s]

Epoch 8, batch 300: Loss = 0.0338


405it [00:14, 27.69it/s]

Epoch 8, batch 400: Loss = 0.0401


504it [00:18, 27.70it/s]

Epoch 8, batch 500: Loss = 0.0526


606it [00:22, 26.53it/s]

Epoch 8, batch 600: Loss = 0.0628


705it [00:25, 27.41it/s]

Epoch 8, batch 700: Loss = 0.0369


804it [00:29, 27.56it/s]

Epoch 8, batch 800: Loss = 0.0650


906it [00:32, 27.29it/s]

Epoch 8, batch 900: Loss = 0.0471


1005it [00:36, 27.65it/s]

Epoch 8, batch 1000: Loss = 0.0502


1104it [00:40, 27.38it/s]

Epoch 8, batch 1100: Loss = 0.0490


1206it [00:43, 28.16it/s]

Epoch 8, batch 1200: Loss = 0.0570


1305it [00:47, 27.69it/s]

Epoch 8, batch 1300: Loss = 0.0357


1407it [00:51, 27.65it/s]

Epoch 8, batch 1400: Loss = 0.0441


1506it [00:54, 28.48it/s]

Epoch 8, batch 1500: Loss = 0.0521


1605it [00:58, 27.43it/s]

Epoch 8, batch 1600: Loss = 0.0464


1704it [01:01, 28.29it/s]

Epoch 8, batch 1700: Loss = 0.0648


1806it [01:05, 28.39it/s]

Epoch 8, batch 1800: Loss = 0.0454


1905it [01:09, 28.23it/s]

Epoch 8, batch 1900: Loss = 0.0433


2004it [01:12, 26.79it/s]

Epoch 8, batch 2000: Loss = 0.0475


2106it [01:16, 27.85it/s]

Epoch 8, batch 2100: Loss = 0.0380


2205it [01:19, 27.76it/s]

Epoch 8, batch 2200: Loss = 0.0332


2304it [01:23, 27.18it/s]

Epoch 8, batch 2300: Loss = 0.0466


2406it [01:27, 27.20it/s]

Epoch 8, batch 2400: Loss = 0.0381


2505it [01:30, 27.55it/s]

Epoch 8, batch 2500: Loss = 0.0577


2590it [01:33, 27.62it/s]


Epoch 8: Cross-entropy: 113.7928


3it [00:00, 24.93it/s]

Epoch 9, batch 0: Loss = 0.0467


105it [00:03, 27.20it/s]

Epoch 9, batch 100: Loss = 0.0397


204it [00:07, 26.97it/s]

Epoch 9, batch 200: Loss = 0.0436


306it [00:11, 25.92it/s]

Epoch 9, batch 300: Loss = 0.0551


405it [00:15, 25.96it/s]

Epoch 9, batch 400: Loss = 0.0456


504it [00:19, 25.98it/s]

Epoch 9, batch 500: Loss = 0.0412


606it [00:23, 27.43it/s]

Epoch 9, batch 600: Loss = 0.0431


705it [00:26, 26.99it/s]

Epoch 9, batch 700: Loss = 0.0432


804it [00:30, 26.83it/s]

Epoch 9, batch 800: Loss = 0.0555


906it [00:34, 27.38it/s]

Epoch 9, batch 900: Loss = 0.0366


1005it [00:37, 27.51it/s]

Epoch 9, batch 1000: Loss = 0.0713


1104it [00:41, 27.35it/s]

Epoch 9, batch 1100: Loss = 0.0578


1206it [00:45, 27.38it/s]

Epoch 9, batch 1200: Loss = 0.0343


1305it [00:48, 27.84it/s]

Epoch 9, batch 1300: Loss = 0.0593


1407it [00:52, 27.95it/s]

Epoch 9, batch 1400: Loss = 0.0354


1506it [00:56, 28.41it/s]

Epoch 9, batch 1500: Loss = 0.0342


1605it [00:59, 28.20it/s]

Epoch 9, batch 1600: Loss = 0.0474


1707it [01:03, 27.73it/s]

Epoch 9, batch 1700: Loss = 0.0618


1806it [01:07, 28.38it/s]

Epoch 9, batch 1800: Loss = 0.0518


1905it [01:10, 28.52it/s]

Epoch 9, batch 1900: Loss = 0.0458


2004it [01:14, 25.80it/s]

Epoch 9, batch 2000: Loss = 0.0557


2106it [01:17, 27.62it/s]

Epoch 9, batch 2100: Loss = 0.0537


2205it [01:21, 28.01it/s]

Epoch 9, batch 2200: Loss = 0.0382


2304it [01:24, 28.62it/s]

Epoch 9, batch 2300: Loss = 0.0560


2406it [01:28, 28.88it/s]

Epoch 9, batch 2400: Loss = 0.0456


2505it [01:31, 27.80it/s]

Epoch 9, batch 2500: Loss = 0.0485


2590it [01:34, 27.28it/s]


Epoch 9: Cross-entropy: 112.9521


In [28]:
val_dataset = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv', device=device)   

val_loader = data.DataLoader(val_dataset, shuffle=True, batch_size=batch_size)

# evaluaate accuracy on validation set


model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in val_loader:
        X_batch = batch["input"]
        y_batch = batch["output"]
        is_padding = X_batch == val_dataset.char_encoder.get_pad_token()
        y_pred = model(X_batch)
        y_pred = y_pred.transpose(1, 2) 
        _, predicted = torch.max(y_pred.data, 1)
        # Count only non-padding characters
        total += torch.sum(~is_padding).item()
        
        # Count correct predictions
        correct += torch.sum((predicted == y_batch) & (~is_padding)).item()
print("Accuracy: %.2f%%" % (100 * correct / total))



w = 500
Accuracy: 86.74%


In [29]:
print('DER of the network on the validation set: %d %%' % (100 * (1 - correct / total)))


DER of the network on the validation set: 13 %
