In [None]:
#import all modules required
import os
import sys

from sklearn.model_selection import train_test_split
import pandas as pd

import time
import pickle

nitorch_dir = "nitorch/"
sys.path.insert(0, os.path.join(nitorch_dir))
from nitorch.metrics import binary_balanced_accuracy, sensitivity, specificity, prepare_values
from nitorch.data import *
from nitorch.callbacks import EarlyStopping, ModelCheckpoint
from nitorch.trainer import Trainer
from nitorch.initialization import weights_init

import torch.optim as optim
from torch.autograd import Variable
from torch.utils import data

from Code.data import *
from Code.models import *

%matplotlib inline

In [None]:
# Locations of the imaging data and of the spreadsheet with the subject labels.
data_dir = "/data"
UKB_label_info = pd.read_excel(
    "/Data_labels/UKB_label_info.xlsx",
    index_col=None,
)

In [None]:
# Train / validation / test split, stratified on the 'Gender' column.
# A fixed seed keeps the three partitions (and the validation frame that is
# pickled further down) identical across kernel restarts.
SPLIT_SEED = 42

# 15 % of all subjects are held out for validation ...
df_trn, df_val = train_test_split(
    UKB_label_info,
    test_size=0.15,
    stratify=UKB_label_info['Gender'],
    shuffle=True,
    random_state=SPLIT_SEED,
)
# ... and 17.5 % of the remaining 85 % (about 15 % of all subjects) for testing.
df_trn, df_test = train_test_split(
    df_trn,
    test_size=0.175,
    stratify=df_trn['Gender'],
    shuffle=True,
    random_state=SPLIT_SEED,
)

In [None]:
def _gender_counts(frame):
    """Return the number of rows per 'Gender' value as a two-column frame."""
    return frame['Gender'].value_counts().reset_index()


# Class balance of every partition, printed side by side.
df_trn_sam = _gender_counts(df_trn)
df_val_sam = _gender_counts(df_val)
df_test_sam = _gender_counts(df_test)
print(df_trn_sam, df_val_sam, df_test_sam)

In [None]:
# Re-number the rows of every partition from 0 and keep a copy of the
# validation frame on disk.
df_trn = df_trn.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
df_test = df_test.reset_index(drop=True)
df_val.to_pickle("UKB_Val_Data.pkl")

# One dataset object per partition, all reading from the same data directory.
training_set, validation_set, test_set = (
    UKB_FMRI_Dataset_BCNN(data_dir, split_frame)
    for split_frame in (df_trn, df_val, df_test)
)

# Loader settings: mini-batches of 8 for training, single samples otherwise.
batch_size = 8
train_params = dict(batch_size=batch_size, shuffle=True, num_workers=1)
params = dict(batch_size=1, shuffle=True, num_workers=1)
testparams = dict(batch_size=1, shuffle=True, num_workers=0)

train_loader = data.DataLoader(training_set, **train_params)
val_loader = data.DataLoader(validation_set, **params)
test_loader = data.DataLoader(test_set, **testparams)

In [None]:
# ---- Model / training configuration ----
# Folder that receives the saved models and the training figures.
res_dir = "/projects/connectome_gender_classification/trial1/UKB/BCNN/"

# Run-time switches
debug = 1
device = 5  # CUDA device index handed to .cuda() and to the Trainer
multi_gpus = None
gpu = True
num_runs = 3  # number of models trained one after the other

# Optimisation settings
net_name = "UKB_BCNN"
learn_rate = 0.001
weight_decay = 8.4e-07

num_epochs = 100

# Pull one training batch to read the input dimensions off its shape.
images, labels = next(iter(train_loader))
input_size = images.shape[1]
f_size = images.shape[3]
output_size = 1

# Hyper-parameters forwarded positionally to BCNN(e2e, e2n, n2g, f_size, dropout)
e2e = 16
e2n = 128
n2g = 26

dropout = 0.6
min_iters = 0

# Checkpoint callback keyed on balanced accuracy.
# NOTE(review): `mc` is created here but is not part of `callbacks` below.
mc = ModelCheckpoint(
    res_dir,
    retain_metric=binary_balanced_accuracy,
    prepend="auto_save_",
    num_iters=-1,
    ignore_before=15,
    store_best=True,
    mode="max",
)

# Callbacks and metrics handed to the Trainer.
callbacks = [
    EarlyStopping(
        patience=10,
        retain_metric=binary_balanced_accuracy,
        mode="max",
        ignore_before=30,
    )
]
training_callback = None
metrics = [binary_balanced_accuracy]

on_folder_error_continue = False

In [None]:
# Build the network first, then move its parameters to the configured GPU.
net = BCNN(e2e, e2n, n2g, f_size, dropout)
net = net.cuda(device)

In [None]:
# Inspect the network. `net.modules` is referenced without being called, so the
# cell shows the bound-method repr (which embeds the module's own repr) rather
# than iterating over the sub-modules.
net.modules

In [None]:
val_acc = []  # best validation metric reported by every run
for i in range(num_runs):
    print("Starting training for run no: " + str(i))

    # Discard the previous network and continue with a newly constructed one.
    del net
    net = BCNN(e2e, e2n, n2g, f_size, dropout).cuda(device)
    # NOTE(review): eval mode is set right before training; presumably the
    # Trainer switches between train/eval itself — confirm.
    net.eval()

    # Loss, optimiser and file name for this run
    criterion = nn.CrossEntropyLoss().cuda(device)
    optimizer = optim.SGD(
        net.parameters(),
        lr=learn_rate,
        momentum=0.9,
        weight_decay=weight_decay,
    )
    net_name = "UKB_BCNN" + "_model_" + str(i)

    # Trainer wiring
    trainer = Trainer(
        net,
        criterion,
        optimizer,
        metrics=metrics,
        callbacks=callbacks,
        training_time_callback=training_callback,
        prediction_type="classification",
        device=device,
    )

    # Train; the validation loss is shown after every epoch.
    net, report = trainer.train_model(
        train_loader,
        val_loader,
        num_epochs=num_epochs,
        show_train_steps=100,
        show_validation_epochs=1,
    )

    # Persist the weights returned by the trainer under the run-specific name.
    torch.save(net.state_dict(), res_dir + net_name)

    # Write the training curves to res_dir and remember the best metric.
    trainer.visualize_training(report, save_fig_path=res_dir)
    val_acc.append(report['best_metric'])


In [None]:
# Display the balanced-accuracy figure stored under res_dir
# (presumably written by trainer.visualize_training — confirm).
a = plt.imread(res_dir + "_binary_balanced_accuracy.png")
plt.imshow(a)

In [None]:
# Display the loss figure stored under res_dir
# (presumably written by trainer.visualize_training — confirm).
b = plt.imread(res_dir + "_loss.png")
plt.imshow(b)

In [None]:
from torch.autograd import Variable
# Evaluate every saved run on the held-out test set and collect balanced
# accuracy, sensitivity and specificity per run.
net = BCNN(e2e, e2n, n2g, f_size, dropout).cuda(device)
acrcy = []
sntvty = []
sptvty = []
# Only relevant for a single-logit/sigmoid output; the two-logit model below
# takes the arg-max instead.
class_threshold = 0.5
for num in range(num_runs):
    net_name = "UKB_BCNN" + "_model_" + str(num)
    net.load_state_dict(torch.load(res_dir + net_name))
    net.eval()

    all_preds = []
    all_labels = []
    # No gradients are needed for inference, so skip building the autograd graph.
    with torch.no_grad():
        # The batch is unpacked directly: binding it to a variable called
        # `data` would overwrite the `torch.utils.data` module imported at the
        # top and break a later re-run of the DataLoader cell.
        for X, y in test_loader:
            output = net(X.cuda(device))
            # Softmax is monotonic, so the arg-max of the raw scores is the
            # predicted class; dim=1 keeps this valid for any batch size.
            pred = output.argmax(dim=1)
            all_preds.extend(pred.cpu().tolist())
            all_labels.extend(y.reshape(-1).tolist())

    acrcy.append(binary_balanced_accuracy(all_labels, all_preds))
    sntvty.append(sensitivity(all_labels, all_preds))
    sptvty.append(specificity(all_labels, all_preds))

### Visualization

In [None]:
#from https://github.com/jrieke/cnn-interpretability
def sensitivity_analysis(model, image_tensor, target_class=None, postprocess='abs', apply_softmax=True, cuda=True,
                         verbose=False):
    """
    Perform sensitivity analysis (via backpropagation; Simonyan et al. 2014) to
    determine the relevance of each input element for the classification decision.
    Return a relevance heatmap over the input.

    Args:
        model (torch.nn.Module): The pytorch model. Should be set to eval mode.
        image_tensor (torch.Tensor): Floating-point input holding exactly one sample with the batch
                                     dimension first. It is passed to `model` as is (no batch dimension
                                     is added here), so it must already live on the model's device.
        target_class (int or single-element tensor): The target output class for which to produce the
                      heatmap. If `None` (default), use the most likely class from the `model`s output.
        postprocess (None or 'abs' or 'square'): The method to postprocess the heatmap with. `'abs'` is used
                                                 in Simonyan et al. 2014, `'square'` is used in Montavon et al. 2018.
        apply_softmax (boolean): Whether to apply the softmax function to the output. Useful for models that are trained
                                 with `torch.nn.CrossEntropyLoss` and do not apply softmax themselves.
        cuda (boolean): Kept for backward compatibility. The gradient seed is created on the device of the
                        model output, so this flag no longer changes the computation.
        verbose (boolean): Whether to display additional output during the computation.

    Returns:
        A numpy array with the shape of the first sample (`image_tensor.shape[1:]`), holding the
        (post-processed) gradient of the selected class score with respect to the input.

    Raises:
        ValueError: If `postprocess` is not one of None, 'abs' or 'square'.
    """
    if postprocess not in [None, 'abs', 'square']:
        raise ValueError("postprocess must be None, 'abs' or 'square'")

    # Forward pass on a detached leaf so that the gradient is collected on `X`
    # and the caller's tensor is left untouched.
    X = image_tensor.detach().requires_grad_(True)
    output = model(X)
    if apply_softmax:
        # Explicit class dimension; the implicit form is deprecated.
        output = F.softmax(output, dim=1)

    predicted_class = int(output.argmax(dim=1)[0])
    if verbose:
        print('Image was classified as', predicted_class, 'with probability', output[0, predicted_class].item())

    if target_class is None:
        target_class = predicted_class
    elif torch.is_tensor(target_class):
        # Callers may pass the label tensor itself (possibly on the GPU);
        # a plain int can index a tensor on any device.
        target_class = int(target_class.item())

    # Backward pass: seed the gradient with a one-hot vector that selects the
    # target class of the first (only) sample.
    model.zero_grad()
    one_hot_output = torch.zeros_like(output)
    one_hot_output[0, target_class] = 1
    output.backward(gradient=one_hot_output)

    relevance_map = X.grad.detach()[0].cpu().numpy()

    # Postprocess the relevance map.
    if postprocess == 'abs':  # as in Simonyan et al. (2014)
        return np.abs(relevance_map)
    if postprocess == 'square':  # as in Montavon et al. (2018)
        return relevance_map ** 2
    return relevance_map

In [None]:
# Rebuild the network and restore the weights saved by the first run (index 0).
net_name = "UKB_BCNN" + "_model_" + str(0)
net = BCNN(e2e, e2n, n2g, f_size, dropout).cuda(device)
net.load_state_dict(torch.load(res_dir + net_name))

In [None]:
# Compute one gradient-based relevance map per test sample, using the true
# label as the target class.
net.eval()
relevance_map_backprop = []
clas = []
for image, label in test_loader:
    # Keep the label tensor as delivered by the loader so the maps can be
    # grouped by class afterwards.
    clas.append(label)

    image = image.cuda(device)
    # Hand the target over as a plain int: a CUDA label tensor cannot be used
    # to index the CPU-allocated one-hot seed inside sensitivity_analysis on
    # recent PyTorch versions. The test loader yields one sample per batch,
    # so `.item()` is well defined.
    target = int(label.item())
    relevance_map_backprop.append(
        sensitivity_analysis(net, image, target_class=target, postprocess=None, cuda=True)
    )

In [None]:
def check_symmetric(a, rtol=1e-05, atol=1e-08):
    """Tell whether array `a` equals its own transpose within the given tolerances.

    Args:
        a: Array-like object exposing a `.T` attribute (e.g. a numpy array).
        rtol: Relative tolerance forwarded to `np.allclose`.
        atol: Absolute tolerance forwarded to `np.allclose`.

    Returns:
        True if every element of `a` is close to the matching element of `a.T`.
    """
    transposed = a.T
    is_symmetric = np.allclose(a, transposed, rtol=rtol, atol=atol)
    return is_symmetric

In [None]:
from nilearn.plotting import plot_connectome 
import nibabel as nib
from nilearn import plotting

# Load the NIfTI image from which the connectome node positions are derived below.
# NOTE(review): presumably the probabilistic atlas defining the brain regions — confirm.
img = nib.load('/Visualization/test_file.nii.gz')

In [None]:
# Node coordinates for the connectome plots, obtained from the loaded image with
# nilearn's cut-coordinate helper for probabilistic atlases.
coords_connectome =  plotting.find_probabilistic_atlas_cut_coords(maps_img=img)

In [None]:
## Group the relevance maps by class label (0 -> female, anything else -> male)
mat = np.asarray(relevance_map_backprop)
labels = torch.stack(clas).cpu().numpy()

# For every sample keep the first slice of its relevance map
# (presumably the single input channel — confirm).
female_avg = [mat[idx][0] for idx, val in enumerate(labels) if val[0] == 0]
male_avg = [mat[idx][0] for idx, val in enumerate(labels) if not val[0] == 0]

# Element-wise mean over all samples of each group.
m_avg = np.asarray(male_avg).mean(axis=0)
f_avg = np.asarray(female_avg).mean(axis=0)

In [None]:
# Symmetrise the averaged female relevance map and plot it as a connectome;
# edge_threshold='99.5%' restricts the drawing to the strongest edges.
female_connectome = f_avg + f_avg.T
plotting.plot_connectome(
    female_connectome,
    coords_connectome,
    title='Female Activation Connectome - UK Biobank',
    edge_threshold='99.5%',
    colorbar=True,
    node_size=6,
)
plotting.show()

In [None]:
# Symmetrise the averaged male relevance map and plot it as a connectome;
# edge_threshold='99.5%' restricts the drawing to the strongest edges.
male_connectome = m_avg + m_avg.T
display = plotting.plot_connectome(
    male_connectome,
    coords_connectome,
    title='Male Activation Connectome - UK Biobank',
    edge_threshold='99.5%',
    colorbar=True,
    node_size=6,
)
plotting.show()