In [56]:
import os
import sys
import glob
import math
import json
import shutil
import numpy as np
# pytorch stuff
import torch
import torch.utils.data as torch_data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
# sklearn
from sklearn.metrics import confusion_matrix

##### Model

In [40]:
class CNN3D(nn.Module):
    def __init__(self, num_classes=9, in_planes=2):
        super(CNN3D, self).__init__()
        # N, C, D, H, W = 1, 1, 160, 5, 22
        self.conv1 = nn.Conv3d(in_channels=in_planes, out_channels=5, kernel_size=(65, 1, 2))
        self.bn1 = nn.BatchNorm3d(5)
        self.conv2 = nn.Conv3d(in_channels=5, out_channels=10, kernel_size=(1, 1, 1))
        self.bn2 = nn.BatchNorm3d(10)
        self.conv3 = nn.Conv3d(in_channels=10, out_channels=5, kernel_size=(1, 1, 1))
        self.bn3 = nn.BatchNorm3d(5)
        self.conv4 = nn.Conv3d(in_channels=5, out_channels=1, kernel_size=(1, 1, 1))
        self.bn4 = nn.BatchNorm3d(1)
        self.dropout1 = nn.Dropout(p=0.6)
        self.dropout2 = nn.Dropout(p=0.4)
        self.fc1   = nn.Linear(11 , 3)
        self.softmax = nn.Sigmoid()


    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool3d(out, 2)
        out = self.bn1(out)

        if self.training:
            out = self.dropout1(out)


        out = self.conv2(out)
        out = F.max_pool3d(out, 2)
        out = self.bn2(out)

        if self.training:
            out = self.dropout1(out)

        out = self.conv3(out)
        out = F.max_pool3d(out, 2)
        out = self.bn3(out)

        if self.training:
            out = self.dropout1(out)

        out = self.conv4(out)
        out = F.max_pool3d(out, 2)
        out = self.bn4(out)

        if self.training:
            out = self.dropout1(out)


        out = out.view(out.size(0), -1)

        if self.training:
            out = self.dropout1(out)


        out = self.fc1(out)

        """
        if self.training:
            out = self.dropout2(out)
        """
        #out = F.log_softmax(out)

        return out


##### Load data

In [97]:
class MultilabelWMLoader(torch_data.Dataset):
    def __init__(self, data_dir, split, num_classes = 1, time_steps = 160 ):
        self.data_dir = data_dir
        self.split = split
        self.time_steps = time_steps

        self.num_classes = num_classes
        self.image_list, self.label_list = [], []

        self.read_lists()

    def read_lists(self):
        data_bins = os.path.join(self.data_dir, self.split)
        assert os.path.exists(data_bins)
        for each_file in glob.glob(data_bins + '\\' + '*.npy'):
            data = np.load(each_file)
            # .reshape((160, 5, 22, 1))
            self.image_list.append(
                data[0].reshape((2, self.time_steps, 5, data[0].shape[3]))
            )

            self.label_list.append(data[1][2])

    def __getitem__(self, index):

        return (self.image_list[index], self.label_list[index])

    def __len__(self):
        return len(self.image_list)
    
def get_dataset_distribution(dataset_loader):

    label_bins = {0:0, 1:0, 2:0}

    for data, target in dataset_loader:
        target = target.numpy()
        for t in target:
            label_bins[t]+=1

    return json.dumps(label_bins, indent=1, sort_keys=True)


In [98]:
val_dataset = MultilabelWMLoader(
        data_dir='C:/Users/dhruv/Downloads/thesis_dl-fnirs-b9882cd9c405022d5e447de45a406334d56d671b/thesis_dl-fnirs-b9882cd9c405022d5e447de45a406334d56d671b/data/multilabel/',
        split='val', time_steps = 250
        )

val_loader = torch_data.DataLoader(
    val_dataset,
    batch_size=1, shuffle=True, num_workers=0
    )

In [99]:
print(get_dataset_distribution(dataset_loader=val_loader))

{
 "0": 35,
 "1": 7,
 "2": 0
}


##### Testing Working Memory

In [114]:
model_path = '../deep-learning/experiments/cnn3d/011/model-cnn3d-epoch-59.pth'
model = CNN3D()
model.load_state_dict(torch.load(model_path)["model"])
model.eval()

CNN3D(
  (conv1): Conv3d(2, 5, kernel_size=(65, 1, 2), stride=(1, 1, 1))
  (bn1): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(5, 10, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (bn2): BatchNorm3d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(10, 5, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (bn3): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv3d(5, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (bn4): BatchNorm3d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout(p=0.6)
  (dropout2): Dropout(p=0.4)
  (fc1): Linear(in_features=11, out_features=3, bias=True)
  (softmax): Sigmoid()
)

In [115]:
pred = []
true = []
with torch.no_grad():
    for data, target in val_loader:
        data, target = data.float(), target.long()
        output = model(data)
        _, predicted = torch.max(output.data, 1)
        pred.append(predicted.item())
        true.append(target.item())
confusion_matrix(y_pred=pred, y_true=true)

array([[35,  0],
       [ 7,  0]], dtype=int64)