In [25]:
import torch
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from letters_dataset import LettersDataset
from text_encoder import TextEncoder
import torch.nn as nn
from train_collections import DS_ARABIC_LETTERS, DS_HARAKAT
import numpy as np
from tqdm import tqdm

In [20]:

# --- Sizes derived from the project's symbol tables -------------------------
# One embedding row per entry of DS_ARABIC_LETTERS.
dim_vocab = len(DS_ARABIC_LETTERS)
# One output class per entry of DS_HARAKAT plus two extra slots.
# NOTE(review): the meaning of the two extra classes is not visible here
# (presumably special labels defined by the dataset) -- TODO confirm.
dim_out = len(DS_HARAKAT) + 2

# --- Tunable training hyper-parameters --------------------------------------
# embedding width, number of passes over the data, samples per mini-batch
embedding_dim, n_epochs, batch_size = 64, 10, 128


In [4]:

# Instantiate the project dataset and serve it in shuffled mini-batches of
# `batch_size` samples.
dataset = LettersDataset()
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)


w = 495


In [26]:




class CharModel(nn.Module):
    """Embedding -> single-layer LSTM -> dropout -> linear classifier per position.

    All constructor arguments are optional. When a size is omitted it falls
    back to the corresponding notebook-level setting, so ``CharModel()`` builds
    exactly the same network as before, while explicit sizes allow the model
    to be created (and tested) independently of the notebook's globals.

    Args:
        vocab_size: Number of rows in the embedding table.
            Defaults to the notebook-level ``dim_vocab``.
        embed_dim: Width of each embedding vector.
            Defaults to the notebook-level ``embedding_dim``.
        hidden_size: Width of the LSTM hidden state.
        num_classes: Number of output units of the final linear layer.
            Defaults to the notebook-level ``dim_out``.
        dropout: Dropout probability applied to the LSTM outputs.
    """

    def __init__(self, vocab_size=None, embed_dim=None, hidden_size=256,
                 num_classes=None, dropout=0.2):
        super().__init__()

        # Resolve omitted sizes lazily so the globals are only required when
        # the caller actually relies on them.
        if vocab_size is None:
            vocab_size = dim_vocab
        if embed_dim is None:
            embed_dim = embedding_dim
        if num_classes is None:
            num_classes = dim_out

        # Attribute names are kept as-is so previously saved state dicts
        # still load.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size,
                            num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map a batch of index sequences to per-position class scores.

        Args:
            x: Integer index tensor fed to the embedding layer. Because the
                LSTM is built with ``batch_first=True`` the layout is
                ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, num_classes)`` holding the raw
            (un-normalised) outputs of the linear layer.
        """
        embedded = self.embedding(x)
        # The final hidden/cell states are not needed; only the per-step
        # outputs are classified.
        hidden_seq, _ = self.lstm(embedded)
        return self.linear(self.dropout(hidden_seq))



# Build the network, optimiser and loss, then train for `n_epochs` epochs while
# keeping a snapshot of the weights that achieved the lowest evaluation loss.
model = CharModel()

optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    # Wrap the loader itself (not the enumerate object) so tqdm can read
    # len(loader) and display a real total / ETA.
    for i, batch in enumerate(tqdm(loader, desc="Epoch %d" % epoch)):
        X_batch = batch["input"]
        y_batch = batch["output"]
        # The model emits (batch, seq, classes); CrossEntropyLoss expects the
        # class axis at dim 1, hence the transpose.
        # NOTE(review): assumes y_batch holds one target per position,
        # i.e. shape (batch, seq) -- confirm against LettersDataset.
        y_pred = model(X_batch).transpose(1, 2)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print("Epoch %d, batch %d: Loss = %.4f" % (epoch, i, loss.item()))

    # Evaluation
    # NOTE(review): this re-uses the *training* loader because no held-out
    # split is visible in this notebook; the number below is therefore a
    # training loss, not a true validation loss.
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    with torch.no_grad():
        # Batches are dicts (same as in the training loop); tuple-unpacking
        # them would yield the string keys instead of tensors.
        for batch in loader:
            X_batch = batch["input"]
            y_batch = batch["output"]
            # Same class-axis transpose as during training.
            y_pred = model(X_batch).transpose(1, 2)
            # .item() keeps the accumulator a plain float.
            loss_sum += loss_fn(y_pred, y_batch).item()
            n_batches += 1
    # Mean of the per-batch losses, so the figure does not scale with the
    # number of batches. max(..., 1) guards against an empty loader.
    loss = loss_sum / max(n_batches, 1)
    if loss < best_loss:
        best_loss = loss
        # state_dict() returns references to the live parameters; clone them so
        # later optimiser steps cannot overwrite the saved best weights.
        best_model = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}
    print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))


0it [00:00, ?it/s]

1it [00:00,  1.38it/s]

Epoch 0, batch 0: Loss = 2.8491


101it [01:11,  1.43it/s]

Epoch 0, batch 100: Loss = 0.1338


138it [01:38,  1.40it/s]


KeyboardInterrupt: 