In [1]:
import torch
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from letters_dataset import LettersDataset
from text_encoder import TextEncoder
import torch.nn as nn
from train_collections import *
import numpy as np
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# autoreload notebook
%load_ext autoreload

In [6]:
embedding_dim = 64
n_epochs = 2
n_hidden = 128
batch_size = 64

In [2]:


dataset = LettersDataset(device=device)
loader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

# load val data
# da = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv')

w = 495


In [3]:
n_chars = dataset.get_input_vocab_size()
n_harakat = dataset.get_output_vocab_size()
n_harakat

17

In [7]:
class CharModel(nn.Module):
    def __init__(self):
        super().__init__()

        # embedding and LSTM layers
        self.embedding = nn.Embedding(n_chars, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=256,
                            num_layers=1, batch_first=True )
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(256, n_harakat)

    def forward(self, x):
        # pass thru embedding layer
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(self.dropout(x))
        return x

class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.embedding = nn.Embedding(n_chars, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=n_hidden, bidirectional=True)
        self.linear = nn.Linear(2*n_hidden, n_harakat)
        self.dropout = nn.Dropout(0.2)
        


    def forward(self, x):
        # pass thru embedding layer
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(self.dropout(x))
        return x

model = BiLSTM().to(device)

optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss(ignore_index=dataset.char_encoder.get_pad_id())
num_batches = len(loader)
print("Number of batches:", num_batches)
best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    for i, (X_batch,y_batch) in tqdm(enumerate(loader)):
        y_pred = ''
        y_pred = model(X_batch)
        y_pred = y_pred.transpose(1, 2) 
        # print(y_pred.shape)
        # print(y_batch.shape)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss))
        
    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for (X_batch,y_batch) in loader:
            y_pred = model(X_batch)
            y_pred = y_pred.transpose(1, 2) 
            
            loss += loss_fn(y_pred, y_batch)
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))


Number of batches: 2590


0it [00:00, ?it/s]

1it [00:01,  1.70s/it]

Epoch 0, batch 0: Loss = 2.8833


101it [01:43,  1.01it/s]

Epoch 0, batch 100: Loss = 0.1719


195it [03:22,  1.08s/it]

In [None]:
val_dataset = LettersDataset('clean_out/X_val.csv', 'clean_out/y_val.csv', device=device)   

val_loader = data.DataLoader(val_dataset, shuffle=True, batch_size=batch_size)

# evaluaate accuracy on validation set


model.eval()
correct = 0
total = 0

with torch.no_grad():
    for (X_batch,y_batch) in val_loader:
        is_padding = X_batch == val_dataset.char_encoder.get_pad_id()
        y_pred = model(X_batch)
        y_pred = y_pred.transpose(1, 2) 
        _, predicted = torch.max(y_pred.data, 1)
        # Count only non-padding characters
        total += torch.sum(~is_padding).item()
        
        # Count correct predictions
        correct += torch.sum((predicted == y_batch) & (~is_padding)).item()
print("Accuracy: %.2f%%" % (100 * correct / total))



w = 500
Accuracy: 83.73%


In [None]:
print('DER of the network on the validation set: %d %%' % (100 * (1 - correct / total)))


DER of the network on the validation set: 16 %
