In [1]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn

import argparse
import time
import pandas as pd

from easydict import EasyDict as edict
from tqdm import trange
from torchsummary import summary
from torchstat import stat
from thop import profile

YOUR_PATH = "/home/jyt/workspace/fNIRS_models/code_data_tufts"
sys.path.insert(0, YOUR_PATH + '/fNIRS-mental-workload-classifiers/helpers')
import models
import brain_data
from utils import generic_GetTrainValTestSubjects, seed_everything, makedir_if_not_exist, plot_confusion_matrix, save_pickle, train_one_epoch, eval_model, save_training_curves_FixedTrainValSplit, write_performance_info_FixedTrainValSplit, write_program_time, write_inference_time
from utils import LabelSmoothing, train_one_epoch_fNIRS_T, eval_model_fNIRST, train_one_epoch_Ours_T, eval_model_OursT
from utils import EarlyStopping

OursT parameters

In [2]:
cuda = torch.cuda.is_available()
if cuda:
    print('Detected GPUs', flush = True)
    #device = torch.device('cuda')
    device = torch.device('cuda:{}'.format(0))
else:
    print('DID NOT detect GPUs', flush = True)
    device = torch.device('cpu')

Detected GPUs


In [45]:
model_to_use = models.Ours_T
model = model_to_use(n_class=2, sampling_points=150, patch_length=30, dim=64, depth=6, heads=8, mlp_dim=256).to(device)

In [46]:
summary(model, input_size=(150, 8), batch_size=1)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1                [1, 8, 150]               0
         Rearrange-2              [1, 8, 5, 30]               0
         Rearrange-3              [1, 5, 8, 30]               0
         Rearrange-4                [1, 5, 240]               0
            Linear-5                 [1, 5, 64]          15,424
           Dropout-6                 [1, 6, 64]               0
         LayerNorm-7                 [1, 6, 64]             128
            Linear-8               [1, 6, 1536]          98,304
            Linear-9                 [1, 6, 64]          32,832
          Dropout-10                 [1, 6, 64]               0
        Attention-11                 [1, 6, 64]               0
          PreNorm-12                 [1, 6, 64]               0
         Residual-13                 [1, 6, 64]               0
        LayerNorm-14                 [1

In [70]:
dummy_input = torch.randn(2, 150, 8).to(device)
flops, params = profile(model,(dummy_input,))
print('flops: ', flops, 'params: ', params)
print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))

[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
flops:  23975424.0 params:  1002562.0
flops: 23.98 M, params: 1.00 M


In [64]:
model_to_use_1 = models.fNIRS_PreT
model_1 = model_to_use_1(n_class=2, sampling_point=300, dim=64, depth=6, heads=8, mlp_dim=64).to(device)

In [65]:
summary(model_1, input_size=(2,2,300), batch_size=2)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         AvgPool1d-1                [2, 2, 300]               0
         AvgPool1d-2                [2, 2, 300]               0
         AvgPool1d-3                [2, 2, 300]               0
         LayerNorm-4                [2, 2, 300]             600
         AvgPool1d-5                [2, 2, 300]               0
         AvgPool1d-6                [2, 2, 300]               0
         AvgPool1d-7                [2, 2, 300]               0
         LayerNorm-8                [2, 2, 300]             600
          PreBlock-9             [2, 2, 2, 300]               0
           Conv2d-10              [2, 8, 3, 91]             968
        Rearrange-11                [2, 3, 728]               0
           Linear-12                 [2, 3, 64]          46,656
        LayerNorm-13                 [2, 3, 64]             128
           Conv2d-14              [2, 8

In [57]:
dummy_input = torch.randn(2, 2, 2, 300).to(device)
flops, params = profile(model_1,(dummy_input,))
print('flops: ', flops, 'params: ', params)
print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))

[INFO] Register count_avgpool() for <class 'torch.nn.modules.pooling.AvgPool1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
flops:  51707520.0 params:  1773282.0
flops: 51.71 M, params: 1.77 M


In [4]:
model_to_use_2 = models.DCNN
model_2 = model_to_use_2(n_class=2,dropout=0.5).to(device)

In [5]:
summary(model_2, input_size=(150, 8), batch_size=1)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1                [1, 8, 150]               0
            Conv2d-2            [1, 25, 8, 146]             150
            Conv2d-3            [1, 25, 1, 146]           5,025
       BatchNorm2d-4            [1, 25, 1, 146]              50
               ELU-5            [1, 25, 1, 146]               0
         MaxPool2d-6             [1, 25, 1, 73]               0
           Dropout-7             [1, 25, 1, 73]               0
            Conv2d-8             [1, 50, 1, 69]           6,300
       BatchNorm2d-9             [1, 50, 1, 69]             100
              ELU-10             [1, 50, 1, 69]               0
        MaxPool2d-11             [1, 50, 1, 35]               0
          Dropout-12             [1, 50, 1, 35]               0
           Conv2d-13            [1, 100, 1, 31]          25,100
      BatchNorm2d-14            [1, 100

  return F.log_softmax(out)


In [7]:
dummy_input = torch.randn(1, 150, 8).to(device)
flops, params = profile(model_2,(dummy_input,))
print('flops: ', flops, 'params: ', params)
print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))

[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
flops:  3661274.0 params:  313543.0
flops: 3.66 M, params: 0.31 M


In [15]:
model_to_use_3 = models.DeepConvNet150
model_3 = model_to_use_3().to(device)

In [16]:
summary(model_3, input_size=(150, 8), batch_size=1)

torch.Size([2, 2])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [1, 25, 8, 146]             150
            Conv2d-2            [1, 25, 1, 146]           5,000
       BatchNorm2d-3            [1, 25, 1, 146]              50
               ELU-4            [1, 25, 1, 146]               0
         MaxPool2d-5             [1, 25, 1, 73]               0
           Dropout-6             [1, 25, 1, 73]               0
            Conv2d-7             [1, 50, 1, 69]           6,250
       BatchNorm2d-8             [1, 50, 1, 69]             100
               ELU-9             [1, 50, 1, 69]               0
        MaxPool2d-10             [1, 50, 1, 34]               0
          Dropout-11             [1, 50, 1, 34]               0
           Conv2d-12            [1, 100, 1, 30]          25,000
      BatchNorm2d-13            [1, 100, 1, 30]             200
              ELU-14

In [17]:
dummy_input = torch.randn(1, 150, 8).to(device)
flops, params = profile(model_3,(dummy_input,))
print('flops: ', flops, 'params: ', params)
print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
torch.Size([1, 2])
flops:  3208450.0 params:  139152.0
flops: 3.21 M, params: 0.14 M
