In [1]:
import torch
import torch.nn as nn
import pdb

import os
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from networks.deep_lstm import DeepLSTM
from networks.FCN import FCN
from networks.model import Combined_teng_imu
from preprocessing.make_dataset import GestureDataset

In [3]:
# load data
# NOTE(review): these .pt files are presumably whole pickled DataLoader objects
# (they are iterated directly in the training cell). torch.load unpickles
# arbitrary Python objects, so only load files from a trusted source. On recent
# PyTorch releases, where torch.load defaults to weights_only=True, such files
# need an explicit weights_only=False -- confirm against the installed version.
train_loader = torch.load('/mnt/fyp/data/train_loader.pt')
val_loader = torch.load('/mnt/fyp/data/val_loader.pt')
# device
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# end-to-end on a machine without CUDA instead of failing at the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# model instantiation
# Input feature widths of the two sensor branches.
input_size_teng, input_size_imu = 10, 18
# LSTM hyper-parameters shared by both branches.
hidden_size, num_layers = 128, 3
# Number of target classes produced by the final network.
num_classes = 39

# Fourth DeepLSTM constructor argument, identical for both branches.
# Presumably the width of each branch's output vector -- confirm against
# networks/deep_lstm.py.
branch_out_size = 20

lstm_teng = DeepLSTM(input_size_teng, hidden_size, num_layers, branch_out_size)
lstm_imu = DeepLSTM(input_size_imu, hidden_size, num_layers, branch_out_size)

# The head's input size is twice the per-branch size (one share per branch;
# presumably the two branch outputs are concatenated -- confirm in networks/model.py).
final_model = FCN(2 * branch_out_size, num_classes)

combined_model = Combined_teng_imu(lstm_teng, lstm_imu, final_model)
combined_model = combined_model.to(device)

# model structure
# print(combined_model)

In [5]:
# Cross-entropy loss on the model outputs, optimised with Adam at a fixed
# learning rate over every parameter of the combined model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=combined_model.parameters(), lr=1e-3)

In [6]:
# start to train

num_epochs = 100

# Layout of one flattened sample in batch['data'], as implied by the slicing:
#   [0, IMAGE_FLAT_SIZE)                  -> not used by this model
#   next SEQ_LEN * TENG_CHANNELS values   -> TENG sequence, reshaped to (SEQ_LEN, TENG_CHANNELS)
#   everything after that                 -> IMU sequence, reshaped to (SEQ_LEN, IMU_CHANNELS)
IMAGE_FLAT_SIZE = 224 * 224 * 3  # presumably a flattened 224x224 RGB image -- TODO confirm
SEQ_LEN = 50
TENG_CHANNELS = 10
IMU_CHANNELS = 18


def _split_batch(batch, target_device):
    """Slice one loader batch into TENG input, IMU input and class labels.

    Args:
        batch: mapping with a 2-D 'data' tensor (one flattened sample per row)
            and a 'label' tensor.
        target_device: device the returned tensors are moved to.

    Returns:
        Tuple ``(teng, imu, labels)`` where ``teng`` has shape
        (-1, SEQ_LEN, TENG_CHANNELS), ``imu`` has shape
        (-1, SEQ_LEN, IMU_CHANNELS) and ``labels`` is a 1-D long tensor.
    """
    flat = batch['data']
    teng_end = IMAGE_FLAT_SIZE + SEQ_LEN * TENG_CHANNELS
    teng = flat[:, IMAGE_FLAT_SIZE:teng_end].reshape(-1, SEQ_LEN, TENG_CHANNELS).to(target_device)
    imu = flat[:, teng_end:].reshape(-1, SEQ_LEN, IMU_CHANNELS).to(target_device)
    # reshape(-1) rather than squeeze(): squeeze() would turn a batch of one
    # label into a 0-dim tensor, which breaks size(0) and the loss call.
    labels = batch['label'].to(torch.long).to(target_device).reshape(-1)
    return teng, imu, labels


train_accuracies = []
val_accuracies = []

for epoch in range(num_epochs):

    # ---- training pass ----
    combined_model.train()

    correct_predictions_train = 0
    total_samples_train = 0

    for batch in train_loader:
        teng_data, imu_data, label_data = _split_batch(batch, device)

        optimizer.zero_grad()
        outputs = combined_model(teng_data, imu_data)

        loss = criterion(outputs, label_data)

        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest output along the class axis.
        predicted = outputs.argmax(dim=1)
        total_samples_train += label_data.size(0)
        correct_predictions_train += (predicted == label_data).sum().item()

    train_accuracy = correct_predictions_train / total_samples_train
    train_accuracies.append(train_accuracy)

    # ---- validation pass ----
    # Switch to eval mode BEFORE validating so that layers which behave
    # differently in training (e.g. dropout, batch norm) are in inference
    # mode; train() is re-applied at the top of the next epoch.
    combined_model.eval()

    correct_predictions_val = 0
    total_samples_val = 0

    with torch.no_grad():
        for batch in val_loader:
            teng_data_val, imu_data_val, label_data_val = _split_batch(batch, device)

            outputs_val = combined_model(teng_data_val, imu_data_val)

            predicted_val = outputs_val.argmax(dim=1)
            total_samples_val += label_data_val.size(0)
            correct_predictions_val += (predicted_val == label_data_val).sum().item()

    # The printed figure is the validation accuracy of this epoch.
    accuracy = correct_predictions_val / total_samples_val
    val_accuracies.append(accuracy)

    print(f'Accuracy after epoch {epoch + 1}: {accuracy * 100:.2f}%')

# Disabled plotting code, kept as a bare triple-quoted string instead of `#`
# comments. Nothing is plotted or saved while it stays quoted; because the
# string is the last expression of the cell, the notebook echoes it back as
# the cell's output value. Remove the surrounding triple quotes to re-enable
# the validation-accuracy plot (it would also write "final_teng_imu.png").
'''
plt.plot(range(1, num_epochs + 1), val_accuracies, marker='o', linestyle='-', color='b')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy vs. Epoch')
plt.grid(True)
plt.savefig("final_teng_imu.png")
plt.show()
'''

Accuracy after epoch 1: 21.92%
Accuracy after epoch 2: 36.92%
Accuracy after epoch 3: 55.00%
Accuracy after epoch 4: 66.79%
Accuracy after epoch 5: 67.56%
Accuracy after epoch 6: 76.03%
Accuracy after epoch 7: 81.79%
Accuracy after epoch 8: 85.64%
Accuracy after epoch 9: 83.21%
Accuracy after epoch 10: 85.90%
Accuracy after epoch 11: 87.05%
Accuracy after epoch 12: 86.79%
Accuracy after epoch 13: 87.69%
Accuracy after epoch 14: 87.69%
Accuracy after epoch 15: 88.08%
Accuracy after epoch 16: 89.49%
Accuracy after epoch 17: 92.56%
Accuracy after epoch 18: 91.67%
Accuracy after epoch 19: 91.28%
Accuracy after epoch 20: 88.97%
Accuracy after epoch 21: 89.23%
Accuracy after epoch 22: 91.03%
Accuracy after epoch 23: 91.28%
Accuracy after epoch 24: 92.05%
Accuracy after epoch 25: 90.26%
Accuracy after epoch 26: 90.26%
Accuracy after epoch 27: 90.90%
Accuracy after epoch 28: 92.44%
Accuracy after epoch 29: 89.87%
Accuracy after epoch 30: 92.82%
Accuracy after epoch 31: 93.97%
Accuracy after ep

'\nplt.plot(range(1, num_epochs + 1), val_accuracies, marker=\'o\', linestyle=\'-\', color=\'b\')\nplt.xlabel(\'Epoch\')\nplt.ylabel(\'Accuracy\')\nplt.title(\'Accuracy vs. Epoch\')\nplt.grid(True)\nplt.savefig("final_teng_imu.png")\nplt.show()\n'

In [7]:
# Persist the whole trained model object (not just its state_dict).
torch.save(combined_model, "/mnt/fyp/models/teng_imu_model.pth")

# Convert the per-epoch accuracy histories to tensors, then write each one
# to its own file.
train_accuracies = torch.tensor(train_accuracies)
val_accuracies = torch.tensor(val_accuracies)
for history, destination in (
    (train_accuracies, "/mnt/fyp/acc/teng_imu_train.pth"),
    (val_accuracies, "/mnt/fyp/acc/teng_imu_val.pth"),
):
    torch.save(history, destination)