# Model Ensembling  
### Candidate models:

- Vanilla CNN
- CNN-RNN
- CNN with Multi-head Attention
- CNN-GRU
- CNN-LSTM


### Final ensemble (majority vote):
- Vanilla CNN
- CNN-GRU
- CNN-LSTM

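### Accuracies (majority-vote ensemble):
The (0-800) / (0-400) ranges refer to the `0:800` / `0:400` slices configured in the preprocessing code.

- Trained on all, tested on all (0-800): 73.5%
- Trained on all, tested on subject 1 (0-800): 74%
- Trained on subject 1, tested on subject 1 (0-800): 44%
- Trained on subject 1, tested on all (0-800): 38.1%

- Trained on all, tested on all (0-400): 69.5%

Note: the cell outputs saved below differ slightly for some of these runs (73.6%, 38.4%, and 67.3% for the 0-400 ensemble).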

In [1]:
import torch
from utils.preprocessing import *
from utils.loops import *
from models.vanilla_cnn import *
from models.cnn_attention import *
from models.cnn_gru import *
from models.cnn_lstm import *
from models.cnn_rnn import *

%load_ext autoreload
%autoreload 2

In [2]:
## Select the compute device.
# Evaluation in this notebook always runs on the CPU. To use Apple-silicon MPS
# instead, replace the assignment below with:
#     device = torch.device("mps") if torch.backends.mps.is_available() else torch.device('cpu')
#     print(device)
device = torch.device('cpu')

In [3]:
## Build the test dataloaders (batch size 64).
# load_data returns three loaders; only the third one (the test loader) is kept,
# so the large training/validation loaders can be freed right away.
test_dataloader_all = load_data(64)[2]                   # every subject
test_dataloader_1 = load_data(64, one_person=True)[2]    # subject 1 only

Shape of training set: (14535, 22, 200)
Shape of validation set: (500, 22, 200)
Shape of training labels: (14535,)
Shape of validation labels: (500,)
Shape of training labels after categorical conversion: (14535, 4)
Shape of validation labels after categorical conversion: (500, 4)
Shape of test labels after categorical conversion: (443, 4)
Shape of training set: (1602, 22, 200)
Shape of validation set: (59, 22, 200)
Shape of training labels: (1602,)
Shape of validation labels: (59,)
Shape of training labels after categorical conversion: (1602, 4)
Shape of validation labels after categorical conversion: (59, 4)
Shape of test labels after categorical conversion: (50, 4)


## Trained on all subjects (0-800)

In [4]:
## Instantiate the ensemble members and load their trained weights (all subjects, 0-800).
# key -> (freshly constructed model, checkpoint path). Insertion order (cnn, gru, lstm)
# matches the order models were previously added to the ensemble dict.
# The RNN and CNN-attention models are excluded from the final ensemble; their
# checkpoints are kept here, commented out, for reference.
model_specs_800 = {
    'cnn': (CNN(kernel_size=11, pad=5),
            'weights/all_subjects_800/CNN_epoch100.pt'),
    # 'rnn': (RNN(input_dim=22, conv_dims=[32, 64, 128, 256], hidden_dim=128, num_layers=1),
    #         'weights/all_subjects_800/RNN_epoch46_400_800.pt'),
    # 'cnn_attention': (CNN_Attention_Model(),
    #                   'weights/all_subjects_800/Attention_epoch100.pt'),
    'gru': (GRU(input_dim=22, conv_dims=[32, 64, 128], hidden_dim=256, num_layers=1),
            'weights/all_subjects_800/GRU_epoch18.pt'),
    'lstm': (LSTM(input_dim=22, conv_dims=[32, 64, 128], hidden_dim=64, num_layers=1),
             'weights/all_subjects_800/LSTM_epoch100_70.pt'),
}

models_800 = {}
for name, (model, ckpt_path) in model_specs_800.items():
    # Load tensors straight onto the evaluation device (same as the former
    # hard-coded 'cpu' while device is the CPU, but follows `device` if it changes).
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Put the model in inference mode so any dropout / batch-norm layers behave
    # deterministically; harmless if test_majority also calls .eval().
    models_800[name] = model.to(device).eval()

### Test on all:

In [33]:
## Evaluate the all-subjects (0-800) ensemble on the test set of all subjects.

# Alternative: average the models' probabilities instead of voting.
# (Uses test_dataloader_all — the previous comment referenced an undefined
#  `test_dataloader`, which would raise NameError if uncommented.)
# accuracy = test_average(models_800, test_dataloader_all, device)
# print('Test Accuracy (average):', accuracy)

# Majority vote
accuracy = test_majority(models_800, test_dataloader_all, device)
print('Test Accuracy (majority vote):', accuracy)

Test Accuracy (majority vote): 0.7358916478555305


### Test on person 1:

In [34]:
# Majority vote: all-subjects (0-800) ensemble scored on subject 1's test loader.
subject1_accuracy = test_majority(models_800, test_dataloader_1, device)
print('Test Accuracy (majority vote):', subject1_accuracy)

Test Accuracy (majority vote): 0.74


## Trained on person 1 (0-800)

In [37]:
## Instantiate the ensemble members and load weights trained on subject 1 only (0-800).
# key -> (freshly constructed model, checkpoint path). Insertion order (cnn, gru, lstm)
# matches the order models were previously added to the ensemble dict.
# The RNN and CNN-attention models are excluded from the final ensemble; their
# checkpoints are kept here, commented out, for reference.
model_specs_1 = {
    'cnn': (CNN(kernel_size=11, pad=5),
            'weights/1_subject_800/CNN_epoch60_one_subject.pt'),
    # 'rnn': (RNN(input_dim=22, conv_dims=[32, 64, 128, 256], hidden_dim=128, num_layers=1),
    #         'weights/1_subject_800/RNN_epoch61_one_subject.pt'),
    # 'cnn_attention': (CNN_Attention_Model(),
    #                   'weights/1_subject_800/Attention_epoch84_one_subject.pt'),
    'gru': (GRU(input_dim=22, conv_dims=[32, 64, 128], hidden_dim=256, num_layers=1),
            'weights/1_subject_800/gru_epoch91_one_subject.pt'),
    'lstm': (LSTM(input_dim=22, conv_dims=[32, 64, 128], hidden_dim=64, num_layers=1),
             'weights/1_subject_800/LSTM_epoch99_one_subject.pt'),
}

models_1 = {}
for name, (model, ckpt_path) in model_specs_1.items():
    # Load tensors straight onto the evaluation device (same as the former
    # hard-coded 'cpu' while device is the CPU, but follows `device` if it changes).
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Put the model in inference mode so any dropout / batch-norm layers behave
    # deterministically; harmless if test_majority also calls .eval().
    models_1[name] = model.to(device).eval()

### Test on person 1:

In [33]:
# Majority vote: subject-1 ensemble scored on subject 1's own test loader.
print('Test Accuracy (majority vote):',
      test_majority(models_1, test_dataloader_1, device))

Test Accuracy (majority vote): 0.44


### Test on all:

In [38]:
# Majority vote: subject-1 ensemble scored on the test loader covering all subjects.
cross_subject_accuracy = test_majority(models_1, test_dataloader_all, device)
print('Test Accuracy (majority vote):', cross_subject_accuracy)

Test Accuracy (majority vote): 0.3837471783295711


## Trained on all subjects (0-400)

In [5]:
# Manual step before running this cell — edit the preprocessing file:
#   1. change 0:800 to 0:400 on lines 24, 121, 144;
#   2. then change 0:400 to 0:200 on lines 53, 67, 74, 81, 90;
#   3. then rerun the data loading below.
# With %autoreload 2 enabled in the imports cell, the edited module should be
# picked up without a kernel restart — verify if results look unchanged.
# NOTE: this rebinds test_dataloader_all, so re-running the 0-800 cells above after
# this point evaluates them on the 0-400 data. test_dataloader_1 is NOT rebuilt here.

## Rebuild the all-subjects test dataloader (batch size 64); only the third loader
## returned by load_data (the test loader) is kept.
test_dataloader_all = load_data(64)[2]

Shape of training set: (14535, 22, 200)
Shape of validation set: (500, 22, 200)
Shape of training labels: (14535,)
Shape of validation labels: (500,)
Shape of training labels after categorical conversion: (14535, 4)
Shape of validation labels after categorical conversion: (500, 4)
Shape of test labels after categorical conversion: (443, 4)


In [10]:
## Instantiate the ensemble members and load their trained weights (all subjects, 0-400).
# key -> (freshly constructed model, checkpoint path). Insertion order (cnn, gru, lstm)
# matches the order models were previously added to the ensemble dict.
# The RNN and CNN-attention models are excluded from the final ensemble; their
# checkpoints are kept here, commented out, for reference.
# NOTE(review): CNN and LSTM are built with in_length=200 but the GRU is not —
# presumably the GRU head does not depend on input length; confirm.
model_specs_400 = {
    'cnn': (CNN(kernel_size=11, pad=5, in_length=200),
            'weights/all_subjects_400/CNN_epoch59_0_400.pt'),
    # 'rnn': (RNN(input_dim=22, conv_dims=[32, 64, 128, 256], hidden_dim=128, num_layers=1),
    #         'weights/all_subjects_400/RNN_epoch65_0_400.pt'),
    # 'cnn_attention': (CNN_Attention_Model(),
    #                   'weights/all_subjects_400/Attention_epoch57_0_400.pt'),
    'gru': (GRU(input_dim=22, conv_dims=[32, 64, 128], hidden_dim=256, num_layers=1),
            'weights/all_subjects_400/GRU_epoch60_0_400.pt'),
    'lstm': (LSTM(input_dim=22, conv_dims=[32, 64, 128], hidden_dim=64, num_layers=1, in_length=200),
             'weights/all_subjects_400/LSTM_epoch83_0_400.pt'),
}

models_400 = {}
for name, (model, ckpt_path) in model_specs_400.items():
    # Load tensors straight onto the evaluation device (same as the former
    # hard-coded 'cpu' while device is the CPU, but follows `device` if it changes).
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Put the model in inference mode so any dropout / batch-norm layers behave
    # deterministically; harmless if test_majority also calls .eval().
    models_400[name] = model.to(device).eval()

### Test on all:

In [11]:
# Majority vote: all-subjects (0-400) ensemble scored on the rebuilt all-subjects test loader.
accuracy_400 = test_majority(models_400, test_dataloader_all, device)
print('Test Accuracy (majority vote):', accuracy_400)

Test Accuracy (majority vote): 0.672686230248307
