In [35]:
import importlib

import torch
import torch.nn as nn
import torch.optim as optim
 
from constants import *
import utils
importlib.reload(utils)

CLASSIFIED_STATES = ['boredom', 'flow', 'frustration', 'neutral']

HIDDEN_SIZE = len(DATA_COLUMNS)
OUTPUT_SIZE = len(CLASSIFIED_STATES)

In [36]:
class MentalStateClassifier(nn.Module):
  def __init__(self, svm):
    super().__init__()
    self.SVM = svm

  def forward(self, x): # (1, 25, 40) -> (1, 4)
    x = x.mean(dim=1) # Average 1 second of data 
    x = self.SVM.predict(x)
    return x
    
model = MentalStateClassifier()
 
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [37]:
n_epochs = 14
batch_size = 8
n_batches = utils.getTrainingRowNum(CLASSIFIED_STATES) // SAMPLING_RATE // batch_size - 5

In [32]:
import copy
import tqdm
import numpy as np
importlib.reload(utils)

best_acc = - np.inf
best_weights = None
train_loss_hist = []
train_acc_hist = []
test_loss_hist = []
test_acc_hist = []
X_batch = None # for exporting to ONNX

for epoch in range(n_epochs):
  epoch_loss = []
  epoch_acc = []
  model.train() # set model to training mode
  data_loader = utils.lazy_load_training_data(states=CLASSIFIED_STATES)

  X_train, y_train = next(data_loader)

  with tqdm.trange(n_batches, unit="batch", mininterval=0) as bar:
    bar.set_description(f"Epoch {epoch}")
    j = 0
    for i in bar:
      start = j * batch_size
      if start + batch_size >= len(X_train):
        j = 0
        start = 0
        X_train, y_train = next(data_loader)
      
      j += 1
      X_batch = X_train[start:start+batch_size]
      y_batch = y_train[start:start+batch_size] 

      # forward pass
      y_pred = model(X_batch)
      loss = loss_fn(y_pred, y_batch)
      
      # backward pass
      optimizer.zero_grad()
      loss.backward()
      # update weights
      optimizer.step()
      # compute and store metrics
      acc = (torch.argmax(y_pred) == torch.argmax(y_batch)).float().mean()
      epoch_loss.append(float(loss))
      epoch_acc.append(float(acc))
      bar.set_postfix(
          loss=float(loss),
          acc=float(acc)
      )
  model.eval()
  X_test, y_test = utils.load_testing_data(states=CLASSIFIED_STATES)
  y_pred = model(X_test)
  ce = loss_fn(y_pred, y_test)
  acc = (torch.argmax(y_pred, 1) == torch.argmax(y_test, 1)).float().mean()
  ce = float(ce)
  acc = float(acc)
  train_loss_hist.append(np.mean(epoch_loss))
  train_acc_hist.append(np.mean(epoch_acc))
  test_loss_hist.append(ce)
  test_acc_hist.append(acc)
  if acc > best_acc:
      best_acc = acc
      best_weights = copy.deepcopy(model.state_dict())
  print(f"Epoch {epoch} validation: Cross-entropy={round(ce, 4)}, Accuracy={round(acc, 4)}")

model.load_state_dict(best_weights)

Epoch 0: 100%|██████████| 546/546 [00:04<00:00, 125.98batch/s, acc=0, loss=0.839]


Epoch 0 validation: Cross-entropy=1.1784, Accuracy=0.4048


Epoch 1: 100%|██████████| 546/546 [00:03<00:00, 145.92batch/s, acc=0, loss=0.629]


Epoch 1 validation: Cross-entropy=1.4152, Accuracy=0.3783


Epoch 2: 100%|██████████| 546/546 [00:06<00:00, 88.21batch/s, acc=0, loss=0.71]  


Epoch 2 validation: Cross-entropy=1.2857, Accuracy=0.4206


Epoch 3: 100%|██████████| 546/546 [00:04<00:00, 113.56batch/s, acc=0, loss=0.624]


Epoch 3 validation: Cross-entropy=1.2167, Accuracy=0.381


Epoch 4: 100%|██████████| 546/546 [00:04<00:00, 132.55batch/s, acc=0, loss=0.693]


Epoch 4 validation: Cross-entropy=1.7379, Accuracy=0.364


Epoch 5: 100%|██████████| 546/546 [00:04<00:00, 135.69batch/s, acc=0, loss=0.418] 


Epoch 5 validation: Cross-entropy=1.4893, Accuracy=0.4476


Epoch 6: 100%|██████████| 546/546 [00:04<00:00, 132.27batch/s, acc=0, loss=0.538] 


Epoch 6 validation: Cross-entropy=1.4962, Accuracy=0.4265


Epoch 7: 100%|██████████| 546/546 [00:03<00:00, 138.46batch/s, acc=0, loss=0.459]


Epoch 7 validation: Cross-entropy=1.1684, Accuracy=0.4254


Epoch 8: 100%|██████████| 546/546 [00:03<00:00, 136.55batch/s, acc=0, loss=0.452]


Epoch 8 validation: Cross-entropy=1.8821, Accuracy=0.3847


Epoch 9: 100%|██████████| 546/546 [00:03<00:00, 139.74batch/s, acc=0, loss=0.831] 


Epoch 9 validation: Cross-entropy=1.27, Accuracy=0.4323


Epoch 10: 100%|██████████| 546/546 [00:04<00:00, 133.22batch/s, acc=0, loss=0.795] 


Epoch 10 validation: Cross-entropy=1.1208, Accuracy=0.4397


Epoch 11: 100%|██████████| 546/546 [00:03<00:00, 139.94batch/s, acc=0, loss=0.353] 


Epoch 11 validation: Cross-entropy=1.3485, Accuracy=0.409


Epoch 12: 100%|██████████| 546/546 [00:04<00:00, 134.15batch/s, acc=0, loss=0.879] 


Epoch 12 validation: Cross-entropy=1.5645, Accuracy=0.4132


Epoch 13: 100%|██████████| 546/546 [00:04<00:00, 127.26batch/s, acc=1, loss=0.0954]


Epoch 13 validation: Cross-entropy=1.5402, Accuracy=0.418


<All keys matched successfully>

In [33]:
X_batch.shape

torch.Size([8, 25, 40])

In [34]:
# %pip install onnx onnxscript
# %pip install onnxruntime

torch.onnx.export(model,                       # model being run
                  X_batch,                         # model input (or a tuple for multiple inputs)
                  './exportedModels/BNM005-without-neutral.onnx',            # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=9,           # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['X'],       # the model's input names
                  output_names = ['Y']       # the model's output names
                  )