In [1]:
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, random_split
from gloss_dataset import GlossDataset
from gloss_model import GlossModel
from torch.utils.tensorboard import SummaryWriter


In [2]:
# get data for training
gd = GlossDataset()
input_size = gd[0][0].shape[1]
class_no = len(gd.classes)
input_size, class_no

(1596, 4)

In [3]:
# provide input and class size
model = GlossModel(input_size, class_no)
optim = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
model


GlossModel(
  (softmax): Softmax(dim=1)
  (relu): ReLU()
  (lstm1): LSTM(1596, 128, batch_first=True)
  (dropout1): Dropout(p=0.15, inplace=False)
  (lstm2): LSTM(128, 64, batch_first=True)
  (dropout2): Dropout(p=0.15, inplace=False)
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (dropout3): Dropout(p=0.15, inplace=False)
  (fc2): Linear(in_features=32, out_features=4, bias=True)
)

In [4]:
# initialize sumamry writer
writer=SummaryWriter()

In [5]:
# Create testing and training dataLoader from single dataset using random_split
# and also set training epoch
split_ratio = 0.8
batch_size = 4
train_size = int(split_ratio*len(gd))
test_size = len(gd)-train_size
train_data, test_data = random_split(gd, [train_size, test_size])
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=True)
epoch=4000

In [6]:
# Start model training
for i in range(epoch+1):
    for x_train, y_train in train_dl:
        optim.zero_grad()
        out = model(x_train)
        loss = loss_fn(out, y_train)
        loss.backward()
        optim.step()
    writer.add_scalar("Loss/epoch", loss.item(), i)
    if (i % 100 == 0):
        print(f"Loss/epoch : {loss.item(),i}")
    writer.close()


Loss/epoch : (3.871016025543213, 0)
Loss/epoch : (3.7347261905670166, 100)
Loss/epoch : (3.172196626663208, 200)
Loss/epoch : (3.175344467163086, 300)
Loss/epoch : (3.1616132259368896, 400)
Loss/epoch : (3.1587371826171875, 500)
Loss/epoch : (3.1569461822509766, 600)
Loss/epoch : (3.156630754470825, 700)
Loss/epoch : (3.15665340423584, 800)
Loss/epoch : (3.1568446159362793, 900)
Loss/epoch : (3.1565518379211426, 1000)
Loss/epoch : (3.1566264629364014, 1100)
Loss/epoch : (3.156419277191162, 1200)
Loss/epoch : (3.156524658203125, 1300)
Loss/epoch : (3.1565184593200684, 1400)
Loss/epoch : (3.156419277191162, 1500)
Loss/epoch : (3.1564266681671143, 1600)
Loss/epoch : (3.156437635421753, 1700)
Loss/epoch : (3.1566691398620605, 1800)
Loss/epoch : (3.1563968658447266, 1900)
Loss/epoch : (3.156400203704834, 2000)
Loss/epoch : (3.156373977661133, 2100)
Loss/epoch : (3.156388282775879, 2200)
Loss/epoch : (3.1564462184906006, 2300)
Loss/epoch : (3.1563820838928223, 2400)
Loss/epoch : (3.156384944

In [7]:
# save the model
torch.save(model, "swaram_lstm.pt")
