In [1]:
import torch
import torch.nn as nn
import pdb

import os
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from networks.deep_lstm import DeepLSTM
from networks.resnet import ResNet18
from networks.FCN import FCN
from networks.model import Combined
from preprocessing.make_dataset import GestureDataset

In [3]:
# load data
train_loader = torch.load('/mnt/fyp/data/train_loader.pt')
val_loader = torch.load('/mnt/fyp/data/val_loader.pt')
# device
device = torch.device("cuda:0")

In [4]:
# model instantiation
input_size_teng = 10  
input_size_imu = 18 
hidden_size = 128  
num_layers = 3 
num_classes = 39  

resnet18 = ResNet18()
lstm_teng = DeepLSTM(input_size_teng, hidden_size, num_layers, 20)
lstm_imu = DeepLSTM(input_size_imu, hidden_size, num_layers, 20)

fcn_model = FCN(20*3, num_classes)

combined_model = Combined(resnet18, lstm_teng, lstm_imu, fcn_model).to(device)

# model structure
# print(combined_model)

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(combined_model.parameters(), lr=0.001)

In [6]:
# start to train

num_epochs = 100

train_accuracies = []
val_accuracies = []

for epoch in range(num_epochs):
    
    correct_predictions_train = 0
    total_samples_train = 0

    combined_model.train()

    for batch in train_loader:
        img_data = batch['data'][:,:224*224*3].reshape(-1,3,224,224).to(device)
        teng_data = batch['data'][:,224*224*3:224*224*3+500].reshape(-1,50,10).to(device) 
        imu_data = batch['data'][:,224*224*3+500:].reshape(-1,50,18).to(device) 
        label_data = batch['label'].to(torch.long).to(device)
        
        optimizer.zero_grad()
        outputs = combined_model(img_data, teng_data, imu_data)    
        
        label_data = label_data.squeeze()
        
        loss = criterion(outputs, label_data)
        
        loss.backward()
        optimizer.step()
        
        _, predicted = torch.max(outputs, 1)  
        total_samples_train += label_data.size(0)
        
        correct_predictions_train += (predicted == label_data).sum().item()
    
    accuracy = correct_predictions_train / total_samples_train

    train_accuracies.append(accuracy)
    
    # eval
    
    correct_predictions_val = 0
    total_samples_val = 0
    
    with torch.no_grad():
        for batch in val_loader:
            img_data_val = batch['data'][:,:224*224*3].reshape(-1,3,224,224).to(device)

            teng_data_val = batch['data'][:,224*224*3:224*224*3+500].reshape(-1,50,10).to(device) 

            imu_data_val = batch['data'][:,224*224*3+500:].reshape(-1,50,18).to(device) 

            label_data_val = batch['label'].to(torch.long).to(device) 
            label_data_val = label_data_val.squeeze()

            outputs_val = combined_model(img_data_val, teng_data_val, imu_data_val)

            _, predicted_val = torch.max(outputs_val, 1)  
            
            total_samples_val += label_data_val.size(0)

            correct_predictions_val += (predicted_val == label_data_val).sum().item()

        accuracy = correct_predictions_val / total_samples_val

        val_accuracies.append(accuracy)

        print(f'Accuracy after epoch {epoch + 1}: {accuracy * 100:.2f}%')
        combined_model.eval()

'''
plt.plot(range(1, num_epochs + 1), val_accuracies, marker='o', linestyle='-', color='b')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy vs. Epoch')
plt.grid(True)
plt.savefig('final_img_teng_imu.png')
plt.show()
'''


Accuracy after epoch 1: 72.31%
Accuracy after epoch 2: 86.92%
Accuracy after epoch 3: 92.31%
Accuracy after epoch 4: 93.46%
Accuracy after epoch 5: 96.41%
Accuracy after epoch 6: 94.87%
Accuracy after epoch 7: 97.18%
Accuracy after epoch 8: 93.21%
Accuracy after epoch 9: 96.67%
Accuracy after epoch 10: 97.31%
Accuracy after epoch 11: 98.08%
Accuracy after epoch 12: 98.46%
Accuracy after epoch 13: 98.08%
Accuracy after epoch 14: 95.38%
Accuracy after epoch 15: 95.51%
Accuracy after epoch 16: 96.67%
Accuracy after epoch 17: 98.33%
Accuracy after epoch 18: 98.33%
Accuracy after epoch 19: 97.56%
Accuracy after epoch 20: 97.82%
Accuracy after epoch 21: 98.72%
Accuracy after epoch 22: 96.15%
Accuracy after epoch 23: 97.82%
Accuracy after epoch 24: 96.92%
Accuracy after epoch 25: 98.21%
Accuracy after epoch 26: 97.56%
Accuracy after epoch 27: 98.59%
Accuracy after epoch 28: 97.44%
Accuracy after epoch 29: 98.72%
Accuracy after epoch 30: 97.95%
Accuracy after epoch 31: 98.33%
Accuracy after ep

"\nplt.plot(range(1, num_epochs + 1), val_accuracies, marker='o', linestyle='-', color='b')\nplt.xlabel('Epoch')\nplt.ylabel('Accuracy')\nplt.title('Accuracy vs. Epoch')\nplt.grid(True)\nplt.savefig('final_img_teng_imu.png')\nplt.show()\n"

In [7]:
# save trained model
torch.save(combined_model,"/mnt/fyp/models/img_teng_imu_model.pth")
# save train&val accuracies
train_accuracies = torch.tensor(train_accuracies)
torch.save(train_accuracies,"/mnt/fyp/acc/img_teng_imu_train.pth")
val_accuracies = torch.tensor(val_accuracies)
torch.save(val_accuracies,"/mnt/fyp/acc/img_teng_imu_val.pth")