## Training a classification model using the TMC data

### Dependencies

In [3]:
import glob
import numpy as np
import os
import collections
import sys

import pandas as pd
import pickle
import matplotlib
import matplotlib.pyplot as plt

import cv2
import imageio
import csv
from tifffile import TiffFile, imsave, imread, imwrite

from scipy import signal
from scipy.ndimage import uniform_filter1d


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from torchvision import transforms
from torchinfo import summary

import argparse
import yaml
import xgboost as xgb 

from sklearn.model_selection import train_test_split, KFold, cross_validate, cross_val_score, cross_val_predict
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.utils.class_weight import compute_class_weight


import monai

%load_ext autoreload
original_sys_path = sys.path.copy()

In [4]:
font = {'weight' : 'normal','size': 21}
matplotlib.rc('font', **font)
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

matplotlib.rcParams['axes.spines.top'] = True
matplotlib.rcParams['axes.spines.left'] = True
matplotlib.rcParams['axes.spines.right'] = True
matplotlib.rcParams['axes.spines.bottom'] = True

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load configs and ground truth

In [6]:
with open('configs.yaml', 'r') as file:
    configs = yaml.safe_load(file)
with open(configs['gt_path'], 'rb') as file:
        grades = pickle.load(file)

### Kinematic model

In [5]:
with open(os.path.join(configs['processed_data_path'], 'kinematic_data.pkl'), 'rb') as file:
    kinematic_data = pickle.load(file)

In [6]:
%autoreload 1
sys.path[:] = original_sys_path
sys.path.append('kinematic_model')
%aimport data_loader
model_configs = configs['kinematic_model_configs'] 
all_sc = list(kinematic_data.keys())
dataset = data_loader.TimeSeriesDataset(data=kinematic_data, 
                                        is_train=True,
                                        gt={key: grades['new'][key] for key in all_sc}, 
                                        gesture_list=configs['gesture_list'], 
                                        target_cycle_len=configs['target_cycle_len'], 
                                        mean=None, std=None,
                                        downsample_rate=configs['downsample_rate'])

AttributeError: module 'data_loader' has no attribute 'TimeSeriesDataset'

In [7]:
for fold, (train_idx, test_idx) in enumerate(kf.split(all_sc)):
    if fold == 0:
        break

NameError: name 'kf' is not defined

In [None]:
cv_configs = configs['cv_configs']
gt_labels = grades['new']
all_sc = np.array(list(gt_labels.keys()))
valid_sc = np.array(list(kinematic_data.keys()))

num_classes = len(np.unique(list(gt_labels.values())))
kf = KFold(n_splits=cv_configs['num_splits'], 
       shuffle=True, 
       random_state=cv_configs['random_seed'])

# Define datasets for the current fold
train_gt = {key: gt_labels[key] for key in all_sc[train_idx] if key in valid_sc}
train_data = {key: kinematic_data[key] for key in all_sc[train_idx] if key in valid_sc}
train_set = data_loader.TimeSeriesDataset(
    data=train_data,
    gt=train_gt,
    gesture_list=configs['gesture_list'], 
    target_cycle_len=configs['target_cycle_len'], 
    mean=None, std=None,
    downsample_rate=configs['downsample_rate']
    )
        

In [None]:
%autoreload 1
%aimport sys
sys.path[:] = original_sys_path
sys.path.append('kinematic_model')
%aimport model

In [None]:
kinematic_model = model.KinematicModel(n_gestures=5, input_channels=8, feat_channel=8, num_classes=3)
kinematic_model = kinematic_model.apply(model.initialize_weights)
kinematic_model.load_state_dict(torch.load('/project/ahoover/mhealth/zeyut/tmc/results/kinematic_model/model_2_fold_1.pth'))
kinematic_model = kinematic_model.to(device)

In [201]:
sample.size()

torch.Size([5, 8, 1200])

In [202]:
kinematic_model.eval()
with torch.no_grad():
    for i in range(10):
        sample,label = train_set[i]
        output = kinematic_model(sample.unsqueeze(0).to(device))
        pred = torch.argmax(output).cpu().detach().numpy()
        print(pred, label, output)

0 tensor(0) tensor([[9.9902e-01, 3.5975e-04, 6.1779e-04]], device='cuda:0')
2 tensor(2) tensor([[0.0017, 0.0042, 0.9941]], device='cuda:0')
1 tensor(1) tensor([[0.0079, 0.9809, 0.0111]], device='cuda:0')
1 tensor(1) tensor([[0.0012, 0.9960, 0.0028]], device='cuda:0')
1 tensor(1) tensor([[7.7609e-04, 9.9842e-01, 8.0242e-04]], device='cuda:0')
1 tensor(1) tensor([[5.9522e-04, 9.9865e-01, 7.5386e-04]], device='cuda:0')
2 tensor(2) tensor([[0.0033, 0.0152, 0.9815]], device='cuda:0')
1 tensor(1) tensor([[0.0099, 0.9756, 0.0145]], device='cuda:0')
1 tensor(1) tensor([[0.0025, 0.9936, 0.0039]], device='cuda:0')
1 tensor(1) tensor([[0.0061, 0.9886, 0.0053]], device='cuda:0')


In [167]:
pred

tensor([2], device='cuda:0')

In [118]:
downsample_rate = configs['downsample_rate']
summary(kinematic_model, input_size=(1, 5, 8, 12000//downsample_rate))

Layer (type:depth-idx)                   Output Shape              Param #
KinematicModel                           [1, 3]                    --
├─CNN_LSTM: 1-1                          [1, 64]                   --
│    └─Conv1d: 2-1                       [1, 16, 600]              5,264
│    └─Conv1d: 2-2                       [1, 32, 300]              10,784
│    └─Conv1d: 2-3                       [1, 32, 150]              5,152
│    └─BatchNorm1d: 2-4                  [1, 32, 150]              64
│    └─ReLU: 2-5                         [1, 32, 150]              --
│    └─LSTM: 2-6                         [1, 150, 64]              16,896
│    └─LSTM: 2-7                         [1, 150, 64]              33,280
│    └─BatchNorm1d: 2-8                  [1, 64, 150]              128
│    └─ReLU: 2-9                         [1, 64, 150]              --
│    └─AdaptiveMaxPool1d: 2-10           [1, 64, 1]                --
│    └─Dropout: 2-11                     [1, 64, 1]               

In [73]:
kinematic_model.eval()
output = kinematic_model(sample.unsqueeze(0).to(device))

### Image model

In [304]:
with open(os.path.join(configs['processed_data_path'], 'image_data.pkl'), 'rb') as file:
    image_data = pickle.load(file)

KeyboardInterrupt: 

In [None]:
%autoreload 1
sys.path[:] = original_sys_path
sys.path.append('image_model')
%aimport data_loader


In [124]:

model_configs = configs['image_model_configs'] 
gt_labels = grades['new']
all_sc = list(gt_labels.keys())


# Define datasets for the current fold
dataset = data_loader.ImageDataset(
    gt={key: gt_labels[key] for key in all_sc},
    data=image_data
)


In [129]:
dataset.data.keys()

dict_keys(['J10', 'H8', 'H1', 'H2'])

In [125]:
sample, label = dataset[0]

IndexError: too many indices for array: array is 0-dimensional, but 3 were indexed

In [101]:
sample.max()

tensor(6141.)

In [94]:
%autoreload 1
%aimport model

In [95]:
unet = model.SwinUNETRClassifier(model_configs['seed_path'], 3)
unet = unet.to(device)

In [96]:
4**4+4**4+8*2**3+16*2**3+64

768

In [97]:
summary(unet, input_size=(1, 1, 128, 128, 128))

torch.Size([1, 384])


Layer (type:depth-idx)                                  Output Shape              Param #
SwinUNETRClassifier                                     [1, 3]                    --
├─SwinTransformer: 1-1                                  [1, 48, 64, 64, 64]       --
│    └─PatchEmbed: 2-1                                  [1, 48, 64, 64, 64]       --
│    │    └─Conv3d: 3-1                                 [1, 48, 64, 64, 64]       (432)
│    └─Dropout: 2-2                                     [1, 48, 64, 64, 64]       --
│    └─ModuleList: 2-3                                  --                        --
│    │    └─BasicLayer: 3-2                             [1, 96, 32, 32, 32]       (107,358)
│    └─ModuleList: 2-4                                  --                        --
│    │    └─BasicLayer: 3-3                             [1, 192, 16, 16, 16]      (399,036)
│    └─ModuleList: 2-5                                  --                        --
│    │    └─BasicLayer: 3-4                