# FNIRS + VFT analysis

* fNIRS (functional near-infrared spectroscopy) cerebral blood-flow dataset recorded during a VFT (Verbal Fluency Task), covering three subject groups: normal, depression, and suicidality

In [None]:
import os, time, random
import numpy as np
import pandas as pd
import torch, torchvision
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models

In [None]:
from PIL import Image
from torch.optim import lr_scheduler
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
from efficientnet_pytorch import EfficientNet

# 2. VFT Analysis

## Dataset Preparation

### Basic settings

In [None]:
## Record the current working directory; the model checkpoints saved later
## in this notebook are written under this path.
directory = os.getcwd()
print(directory)

In [None]:
# Preprocessing pipeline applied to every image loaded from disk
data_transforms = transforms.Compose(
    [
        # Augmentations tried earlier, currently disabled:
        #     transforms.Resize(300),
        #     transforms.RandomResizedCrop(300),
        #     transforms.RandomHorizontalFlip(),
        #     transforms.RandomVerticalFlip(),
        #     transforms.ColorJitter(contrast=(0.3, 1), saturation=(0.3, 1)),
        transforms.ToTensor(),
        # Per-channel mean / std; these match the widely used ImageNet statistics
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Load the VFT image dataset from disk, applying data_transforms to each image.
# NOTE(review): ImageFolder is expected to take one sub-directory per class
# under `root` — confirm the folder layout matches the three subject groups.
vft_train = datasets.ImageFolder(root = 'E:/RESEARCH/BRAIN/research_data/VFT__', transform = data_transforms)

In [None]:
# Notebook display: show the dataset object's summary as the cell output
vft_train

### Device setting

In [None]:
## Environment setting: run on the GPU when CUDA is available, otherwise on the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Using PyTorch version:', torch.__version__, ' Device:', DEVICE)

## Model  preparation

### Basic settings

#### arguments

In [None]:
## Hyperparameters and run configuration, followed by RNG seeding
class Args:
    """Plain namespace of settings, read as attributes (e.g. ``args.bs``)."""

    # --- optimisation ---
    epochs = 50        # number of training epochs
    bs = 64            # mini-batch size used by both data loaders
    lr = 0.0001        # learning rate handed to the optimizer
    momentum = 0.9     # NOTE(review): not referenced by the visible code

    # --- model / data ---
    num_channels = 3   # input image channels
    num_classes = 3    # number of target classes

    # --- misc ---
    verbose = 'store_true'  # NOTE(review): looks like a leftover argparse action string
    seed = 712002           # seed shared by every RNG below


args = Args()

# Seed the Python, NumPy and PyTorch generators so runs are repeatable.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [None]:
## Sizes for an 80 % / 20 % train / test partition of the full dataset.
## The training share is truncated to an integer; the test set takes the remainder,
## so the two sizes always add up to len(vft_train).
train_size = int(len(vft_train) * 0.8)
test_size = len(vft_train) - train_size
print('Training dataset size is:', train_size, '/ Test dataset size is:', test_size)

In [None]:
## Randomly partition the dataset into train / test subsets of the sizes computed above.
## No explicit generator is passed, so the split uses PyTorch's global RNG state
## (seeded earlier via torch.manual_seed) at the moment this cell runs.
train_dataset, test_dataset = torch.utils.data.random_split(vft_train, [train_size, test_size])

In [None]:
# Batch iterators: the training set is reshuffled every epoch, the test set keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.bs,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=args.bs,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Sanity check: pull a single batch from the training loader and show its labels.
dataiter = iter(train_loader)
# Use the built-in next(); the iterator's `.next()` alias is not available in
# newer PyTorch releases, while the standard iterator protocol always is.
images, labels = next(dataiter)
print(labels)

### Training model

In [None]:
class cnn_vft(nn.Module):
    """Depthwise-separable CNN classifier for the VFT images.

    The network is one standard 3x3 conv block followed by ten
    depthwise-separable blocks (3x3 depthwise conv + 1x1 pointwise conv, each
    with BatchNorm and ReLU), global average pooling and a linear classifier.

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        def conv_batch(input_size, output_size, stride):
            # Standard 3x3 convolution -> BatchNorm -> ReLU
            return nn.Sequential(
                nn.Conv2d(input_size, output_size, 3, stride, 1, bias=False),
                nn.BatchNorm2d(output_size),
                nn.ReLU(inplace=True)
                )

        def conv_depth(input_size, output_size, stride):
            # Depthwise 3x3 conv (groups == channels), then a 1x1 pointwise conv
            # that changes the channel count; each is followed by BatchNorm + ReLU.
            return nn.Sequential(
                nn.Conv2d(input_size, input_size, 3, stride, 1, groups=input_size, bias=False),
                nn.BatchNorm2d(input_size),
                nn.ReLU(inplace=True),

                nn.Conv2d(input_size, output_size, 1, 1, 0, bias=False),
                nn.BatchNorm2d(output_size),
                nn.ReLU(inplace=True),
                )

        self.model = nn.Sequential(
            # Honour the in_channels argument (previously hard-coded to 3,
            # which silently ignored the constructor parameter).
            conv_batch(in_channels, 32, 2),
            conv_depth(32, 64, 1),
            conv_depth(64, 128, 2),
            conv_depth(128, 128, 1),
            conv_depth(128, 256, 2),
            conv_depth(256, 256, 1),
            conv_depth(256, 512, 2),
            conv_depth(512, 512, 1),
            conv_depth(512, 512, 1),
            conv_depth(512, 1024, 2),
            conv_depth(1024, 1024, 1),
            nn.AdaptiveAvgPool2d(1)
        )
        # Attribute name kept as `fc2` so previously saved state_dicts still load.
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch `x`."""
        x = self.model(x)
        # AdaptiveAvgPool2d(1) leaves (batch, 1024, 1, 1); drop the spatial dims.
        x = x.view(-1, 1024)
        x = self.fc2(x)
        return x


In [None]:
# Setting Optimizer and Objective Function

model = cnn_vft(in_channels=args.num_channels, num_classes=args.num_classes).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = args.lr)
# train() calls scheduler.step() once per epoch, so the one-cycle schedule must
# span the number of epochs. The previous total_steps=600 meant the schedule
# never left its warm-up phase in a 50-epoch run; this now matches the
# scheduler configured for the pre-trained model further down.
# NOTE(review): OneCycleLR manages the learning rate itself (derived from
# max_lr), so the `lr` passed to Adam is presumably overridden — confirm.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.005, total_steps=args.epochs, anneal_strategy='cos')

# print(model)

### Training on dataset

In [None]:
# Function that trains the model for one epoch

def train(model, train_loader, optimizer, log_interval):
    """Run one pass over `train_loader`, updating `model` with `optimizer`.

    Besides its arguments this function reads four module-level names:
    `DEVICE`, `criterion`, `scheduler`, and `epoch` (the loop variable of the
    calling cell, used only in the progress message).

    Args:
        model: network to optimise; switched to training mode here.
        train_loader: iterable yielding (image, label) batches.
        optimizer: optimizer bound to `model`'s parameters.
        log_interval: print the batch loss every `log_interval` batches.
    """
    model.train()
    # Show the learning rate in effect for this epoch.
    print(optimizer.param_groups[0]['lr'])

    for step, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue
        print("Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}".format(
            epoch, step * len(inputs),
            len(train_loader.dataset), 100. * step / len(train_loader),
            batch_loss.item()))

    # The learning-rate schedule advances once per call, i.e. once per epoch.
    scheduler.step()

### Model evaluation

In [None]:
# Function that measures loss and accuracy on held-out data

def evaluate(model, test_loader):
    """Return ``(mean batch loss, accuracy in percent)`` over `test_loader`.

    Reads the module-level `DEVICE` and `criterion`. The loss is the average
    of the per-batch losses; accuracy is computed over
    ``len(test_loader.dataset)`` samples.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0

    with torch.no_grad():  # inference only, no gradients needed
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()
            # Index of the largest logit per sample, kept as a column vector.
            _, predicted = logits.max(1, keepdim = True)
            num_correct += predicted.eq(targets.view_as(predicted)).sum().item()

    mean_loss = loss_sum / len(test_loader)
    accuracy = 100. * num_correct / len(test_loader.dataset)

    return mean_loss, accuracy

In [None]:
# Train for args.epochs epochs, recording test loss and accuracy after each one
total = []  # list of (test_loss, test_accuracy) tuples, one per epoch

# Inclusive upper bound: range(1, args.epochs) stopped one epoch short.
# NOTE(review): train() steps the LR scheduler once per epoch; confirm the
# scheduler's total_steps is at least args.epochs.
for epoch in range(1, args.epochs + 1):
    train(model, train_loader, optimizer, log_interval = 200)
    test_loss, test_accuracy = evaluate(model, test_loader)
    print("\n[EPOCH: {}], \tTest Loss: {:.4f}, \tTest Accuracy: {:.2f} % \n".format(
        epoch, test_loss, test_accuracy))

    total.append((test_loss, test_accuracy))

In [None]:
# total

### Save the model state

In [None]:
# Save the trained custom CNN's parameters (state_dict only, not the whole
# module object) to 'vft_model1.pt' under the working directory.

torch.save(model.state_dict(), directory + '/vft_model1.pt')

## Using pre-trained model

In [None]:
# Build an EfficientNet-B3 from the efficientnet_pytorch package with its output
# sized to args.num_classes, then move it to DEVICE.
# From here on `model` refers to this network instead of the custom CNN above.
model_eff3 = EfficientNet.from_pretrained('efficientnet-b3', num_classes= args.num_classes)
model = model_eff3.to(DEVICE)

In [None]:
# Snapshot the freshly loaded EfficientNet weights, i.e. before any of the
# training done in the cells below.
torch.save(model.state_dict(), directory + '/vft_pretrained_model.pt')

In [None]:
# Loss, optimizer and learning-rate schedule for the pre-trained model

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# One-cycle schedule with cosine annealing, sized to one scheduler step per epoch.
scheduler = lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.001,
    total_steps=args.epochs,
    anneal_strategy='cos',
)

# print(model)

In [None]:
# Function that trains the model for one epoch
# (re-defined in this section with the same behavior as the earlier train())

def train(model, train_loader, optimizer, log_interval):
    """Run one training epoch over `train_loader` and step the LR scheduler.

    Uses the module-level names `DEVICE`, `criterion`, `scheduler` and
    `epoch` in addition to its arguments; `epoch` only appears in the
    progress message.

    Args:
        model: network to optimise; put into training mode here.
        train_loader: iterable yielding (image, label) batches.
        optimizer: optimizer bound to `model`'s parameters.
        log_interval: print the batch loss every `log_interval` batches.
    """
    model.train()
    # Report the learning rate currently set on the first parameter group.
    print(optimizer.param_groups[0]['lr'])

    for step, batch in enumerate(train_loader):
        inputs = batch[0].to(DEVICE)
        targets = batch[1].to(DEVICE)

        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()

        should_log = step % log_interval == 0
        if should_log:
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}".format(
                epoch, step * len(inputs),
                len(train_loader.dataset), 100. * step / len(train_loader),
                batch_loss.item()))

    # One scheduler step per epoch.
    scheduler.step()

In [None]:
# Function that measures loss and accuracy on held-out data
# (re-defined in this section with the same behavior as the earlier evaluate())

def evaluate(model, test_loader):
    """Evaluate `model` on `test_loader`.

    Reads the module-level `DEVICE` and `criterion`.

    Returns:
        Tuple ``(test_loss, test_accuracy)``: the mean of the per-batch losses
        and the percentage of correctly classified samples.
    """
    model.eval()
    running_loss = 0
    hits = 0

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            scores = model(batch_images)
            running_loss += criterion(scores, batch_labels).item()

            # max() returns (values, indices); the indices are the predicted classes.
            top_class = scores.max(1, keepdim = True)[1]
            matches = top_class.eq(batch_labels.view_as(top_class))
            hits += matches.sum().item()

    running_loss /= len(test_loader)
    return running_loss, 100. * hits / len(test_loader.dataset)

In [None]:
# Fine-tune for args.epochs epochs, recording test loss and accuracy after each one

total = []  # list of (test_loss, test_accuracy) tuples, one per epoch

# Inclusive upper bound: range(1, args.epochs) stopped one epoch short.
# NOTE(review): train() steps the LR scheduler once per epoch; confirm the
# scheduler's total_steps is at least args.epochs.
for epoch in range(1, args.epochs + 1):
    train(model, train_loader, optimizer, log_interval = 200)
    test_loss, test_accuracy = evaluate(model, test_loader)
    print("\n[EPOCH: {}], \tTest Loss: {:.4f}, \tTest Accuracy: {:.2f} % \n".format(
        epoch, test_loss, test_accuracy))

    total.append((test_loss, test_accuracy))

In [None]:
# Notebook display: per-epoch (test_loss, test_accuracy) history
total

In [None]:
# Save the parameters of the model trained in the loop above (state_dict only)
# to 'fnirs_pretrained_model.pt' under the working directory.

torch.save(model.state_dict(), directory + '/fnirs_pretrained_model.pt')

## Model performance check

### Heatmap for classification

In [None]:
# Build the confusion matrix on the test set: rows = true class, columns = predicted class
nb_classes = args.num_classes
confusion_matrix = np.zeros((nb_classes, nb_classes))
# Display names per class index.
# NOTE(review): assumes the dataset's label indices follow this order —
# confirm against vft_train.class_to_idx.
classes = {
    "0": "Depression",
    "1": "Normal",
    "2": "Suicidality"
}

# Force inference mode so the result does not depend on which cell ran last
# (train() leaves the model in training mode).
model.eval()
with torch.no_grad():
    for image, label in test_loader:
        outputs = model(image.to(DEVICE))
        _, preds = torch.max(outputs, 1)
        # Move indices to host memory before indexing the NumPy array.
        true_idx = label.view(-1).cpu().numpy()
        pred_idx = preds.view(-1).cpu().numpy()
        # add.at accumulates correctly when the same (true, pred) pair repeats in a batch.
        np.add.at(confusion_matrix, (true_idx, pred_idx), 1)

print(confusion_matrix)

plt.figure(figsize=(8,4))
class_names = list(classes.values())
df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names).astype(int)
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")

heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right',fontsize=10)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right',fontsize=10)
plt.ylabel('True label', fontsize=12)
plt.xlabel('Predicted label', fontsize=12)
# plt.savefig('dep_train_entire_output.png')

### Accuracy, sensitivity, specificity check

In [None]:
# One-vs-rest metrics from the confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix
# Kept in its own name so the per-epoch history list `total` is not overwritten.
n_samples = cm.sum()

# Per-class counts, ordered Depression, Normal, Suicidality.
true_pos = cm.diagonal()
false_neg = cm.sum(axis=1) - true_pos   # rest of the class's row
false_pos = cm.sum(axis=0) - true_pos   # rest of the class's column
# Everything outside the class's row and column. The previous formulas counted
# only the other classes' diagonal cells and so left out samples of the other
# two classes that were confused with each other (e.g. cm[1,2] and cm[2,1]
# for Depression), which understated specificity.
true_neg = n_samples - true_pos - false_neg - false_pos

## Accuracy, Sensitivity, and Specificity
acc = true_pos.sum() / n_samples
# Sensitivity = TP / (TP + FN)
sen_dep, sen_nor, sen_sui = true_pos / (true_pos + false_neg)
# Specificity = TN / (TN + FP)
spe_dep, spe_nor, spe_sui = true_neg / (true_neg + false_pos)

print("Overall classification accuracy is :", round(acc, 4))
print("sensitivity of Depression class is :", round(sen_dep, 4))
print("sensitivity of Normal class is :", round(sen_nor,4))
print("sensitivity of Suicidality class is :", round(sen_sui,4))

print("specificity of Depression class is :", round(spe_dep,4))
print("specificity of Normal class is :", round(spe_nor,4))
print("specificity of Suicidality class is :", round(spe_sui,4))

In [None]:
# Unweighted mean of the three per-class sensitivities / specificities
print("Average sensitivity is ", sum((sen_dep, sen_nor, sen_sui)) / 3)
print("Average specificity is ", sum((spe_dep, spe_nor, spe_sui)) / 3)