In [1]:
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, random_split
from gloss_dataset import GlossDataset
from gloss_model import GlossModel
from torch.utils.tensorboard import SummaryWriter


In [2]:
# get data for training
gd = GlossDataset()
input_size = gd[0][0].shape[1]
class_no = len(gd.classes)
input_size, class_no

(1596, 4)

In [3]:
# provide input and class size
model = GlossModel(input_size, class_no)
model.to(model.device)
optim = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
model


GlossModel(
  (softmax): Softmax(dim=1)
  (sigmoid): Sigmoid()
  (lstm1): LSTM(1596, 128, batch_first=True)
  (dropout1): Dropout(p=0.15, inplace=False)
  (lstm2): LSTM(128, 64, batch_first=True)
  (dropout2): Dropout(p=0.15, inplace=False)
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (dropout3): Dropout(p=0.15, inplace=False)
  (fc2): Linear(in_features=32, out_features=4, bias=True)
)

In [4]:
# initialize sumamry writer
writer=SummaryWriter()

In [5]:
# Create testing and training dataLoader from single dataset using random_split
# and also set training epoch
split_ratio = 0.8
batch_size = 1
train_size = int(split_ratio*len(gd))
test_size = len(gd)-train_size
train_data, test_data = random_split(gd, [train_size, test_size])
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=True)
epoch = 4000


In [6]:
# Start model training
model.train()
for i in range(epoch+1):
    for x_train, y_train in train_dl:
        optim.zero_grad()
        out = model(x_train.to(model.device))
        loss = loss_fn(out, y_train.to(model.device))
        loss.backward()
        optim.step()
    writer.add_scalar("Loss/epoch", loss.item(), i)
    if (i % 100 == 0):
        print(f"Loss/epoch : {loss.item(),i}")
    writer.close()


Loss/epoch : (3.8705081939697266, 0)
Loss/epoch : (3.8001532554626465, 100)
Loss/epoch : (3.3649535179138184, 200)
Loss/epoch : (3.172327995300293, 300)
Loss/epoch : (3.1654043197631836, 400)
Loss/epoch : (3.161846160888672, 500)
Loss/epoch : (3.157863140106201, 600)
Loss/epoch : (3.157156229019165, 700)
Loss/epoch : (3.157242774963379, 800)
Loss/epoch : (3.1571006774902344, 900)
Loss/epoch : (3.157226800918579, 1000)
Loss/epoch : (3.156790018081665, 1100)
Loss/epoch : (3.156585693359375, 1200)
Loss/epoch : (3.1563987731933594, 1300)
Loss/epoch : (3.156470775604248, 1400)
Loss/epoch : (3.156411647796631, 1500)


KeyboardInterrupt: 

In [None]:
# Evaluate Test set
model.eval()
tot_loss = 0
tp = 0
tot_test = 0
with torch.no_grad():
    for x_test, y_test in test_dl:
        out = model(x_test)
        loss = loss_fn(out, y_train)
        tot_loss += loss
        y_pred = torch.argmax(out[2])
        y_test = torch.argmax(y_test)
        if y_pred == y_test:
            tp += 1
        tot_test += 1

# Much more data required for training and test set
acc = tp/tot_test
loss = tot_loss/tot_test
print(f"Test Acc {acc}\nTest Loss : {loss}\n")


In [None]:
# save the model
torch.save(model, "swaram_lstm.pt")
