In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import nibabel as nib
import pandas as pd
from torch.utils import data
from mytools.utils import *
import pathlib, sys, os
from datetime import datetime
from sklearn.metrics import balanced_accuracy_score, precision_score,f1_score,roc_auc_score
# Release any unoccupied memory held by PyTorch's CUDA caching allocator
# (e.g. left over from an earlier run in this kernel) before the experiment runs.
torch.cuda.empty_cache()

In [None]:
# ---------------------------- Experiment configuration ----------------------------
# Split CSVs are read as <data_folder>/<data_name>_{train,test}_f<k>.csv
data_name = 'T1_ADVDLBD_Mlabel_10fold_Match0307v2'
# Parsed as <arch>_<channels>_<head>[_he] (an optional "he" token enables He init)
which_model = 'resnet18_64_linear'
b_size = 8
gpu_device = 0
# Either 'ASL_<gamma_neg>_<gamma_pos>_<clip>' or 'BCE'
loss_name = 'ASL_4_1_0.1'
data_folder = '/home/dw/Desktop/DemCLF/data/data_Mlabel_10fold_0307/'
learning_rate = 0.0001
norm = 'minmax'
# Passed as `sigma` to DatasetFromNiiOTF for the training split (0 = no augmentation)
sigma = 0.1
n_epochs = 300
downsample = '3DS'
data_path = '/media/dw/Data/BINC_T1/BINC/dropbox/nacc_harmonized_may023//beforeHarm_' + downsample + norm + '/'
wd = 0.001  # Adam weight decay
labels = ["AD", "VD", "LBD"]
# One output unit per label; derived from `labels` so the two can never drift apart.
clf_class = len(labels)

# Seed NumPy and PyTorch (weight init, DataLoader shuffling) so reruns are repeatable.
# Note: some cuDNN kernels can still be non-deterministic on GPU.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device(gpu_device)

In [None]:
def _build_model():
    """Instantiate the ResNet described by ``which_model`` and move it to ``device``.

    ``which_model`` is parsed as ``<arch>_<channels>_<head>[_...]``; a ``he``
    token anywhere in the name applies ``init_he`` to every module.
    """
    tokens = which_model.split("_")
    arch, channel, res_head = tokens[0], int(tokens[1]), tokens[2]
    model = get_resnet(arch, channel, clf_class, res_head).to(device)
    # Match "he" as a whole token: a substring test would also fire on
    # unrelated words such as "head" or "sphere".
    if "he" in tokens:
        print('he init')
        model.apply(init_he)
    return model


def _build_criterion_and_dir():
    """Return ``(criterion, model_pth)`` for the configured ``loss_name``.

    Supported values of ``loss_name``:
      * ``ASL_<gamma_neg>_<gamma_pos>_<clip>`` -> ``AsymmetricLossOptimized``
      * ``BCE``                                -> ``nn.BCELoss``

    ``model_pth`` encodes the data set and all hyper-parameters, so different
    runs never share a results folder. It always ends with ``/``.

    Raises:
        ValueError: if ``loss_name`` matches neither pattern.
    """
    prefix = ("/media/dw/Data/BINC_T1/BINC/dropbox/DW/results_GenieCode/task2_"
              f"{data_name}_{downsample}{norm}_OTF/{which_model}_mlr{learning_rate}"
              f"_ep{n_epochs}_batch{b_size}")
    suffix = f"_sigma{sigma}_wd{wd}/"

    if loss_name.startswith('ASL'):
        parts = loss_name.split("_")
        gn, gp, c = int(parts[1]), int(parts[2]), float(parts[3])
        criterion = AsymmetricLossOptimized(gamma_neg=gn, gamma_pos=gp, clip=c).to(device)
        loss_tag = f"_ASL_gn{gn}_gp{gp}_c{c}"
    elif loss_name == 'BCE':
        criterion = nn.BCELoss().to(device)
        loss_tag = "_" + loss_name
    else:
        # Fail fast instead of hitting an unbound `criterion` / `model_pth` later.
        raise ValueError(f"Unsupported loss_name: {loss_name!r}")
    return criterion, prefix + loss_tag + suffix


def _predict(model, loader):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(ids, targets, probs)`` lists, one entry per sample, in loader order.
        ``probs`` is ``sigmoid(model output)``. This assumes the model emits raw
        logits, as the probability export always did. TODO confirm against get_resnet.
    """
    model.eval()
    ids_all, targets_all, probs_all = [], [], []
    with torch.no_grad():
        for inputs, target, ids in loader:
            logits = model(inputs.to(device, dtype=torch.float))
            probs_all.extend(torch.sigmoid(logits).cpu().tolist())
            targets_all.extend(target.tolist())
            ids_all.extend(ids)
    return ids_all, targets_all, probs_all


def trainkfold(train=True, test=True, get_prob=True, n_folds=10):
    """Train and/or evaluate one model per cross-validation fold.

    For each fold ``f`` in ``range(n_folds)`` this reads
    ``<data_folder>/<data_name>_{train,test}_f<f>.csv`` and writes into the
    hyper-parameter-specific ``model_pth`` folder:

      * ``fold<f>.pt``          trained weights                 (``train``)
      * ``fold<f>_loss.csv``    mean training loss per epoch    (``train``)
      * ``fold<f>_predict.npy`` per-label predictions, p >= 0.5 (``test``; empty unless "Mlabel" data)
      * ``fold<f>_target.npy``  ground-truth labels             (``test``)
      * ``fold<f>_prob.csv``    ids, probabilities and labels   (``get_prob``)

    Training of a fold is skipped when its ``.pt`` checkpoint already exists.
    All other settings come from the configuration cell's globals.

    Args:
        train: train folds that do not have a checkpoint yet.
        test: save thresholded predictions and targets for the test split.
        get_prob: save per-sample probabilities for the test split.
        n_folds: number of folds to iterate over (default 10).
    """
    for fold in range(n_folds):
        print(fold)

        model = _build_model()
        criterion, model_pth = _build_criterion_and_dir()
        pathlib.Path(model_pth).mkdir(parents=True, exist_ok=True)
        print(model_pth)

        csv_stem = data_folder + '/' + data_name
        train_set = DatasetFromNiiOTF(data_path=data_path,
                                      csv_path=csv_stem + '_train_f' + str(fold) + '.csv',
                                      label=labels, sigma=sigma, ds=downsample)
        # sigma = 0 disables augmentation for the test split.
        test_set = DatasetFromNiiOTF(data_path=data_path,
                                     csv_path=csv_stem + '_test_f' + str(fold) + '.csv',
                                     label=labels, sigma=0, ds=downsample)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=b_size, num_workers=0, shuffle=True)
        # No shuffling at evaluation time, so every output file follows the test
        # CSV order and the files stay aligned with each other.
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=b_size, num_workers=0, shuffle=False)

        save_name = 'fold' + str(fold)
        ckpt_path = model_pth + '/' + save_name + '.pt'

        # ---------------- Training ----------------
        # Key the skip on the checkpoint itself: the .pt is only written after
        # the last epoch, so its presence means training for this fold is done.
        if train and not os.path.isfile(ckpt_path):
            optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=wd)
            quarter = n_epochs // 4
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[quarter, quarter * 2, quarter * 3], gamma=0.1)
            model_trained = model  # defined even when n_epochs == 0
            fold_train_loss = []
            for epoch in range(1, n_epochs + 1):
                model_trained, epoch_train_loss = trainClass(train_loader, model, criterion,
                                                             optimizer, device, loss_name)
                mean_loss = epoch_train_loss / len(train_set)
                print(epoch, mean_loss)
                scheduler.step()
                fold_train_loss.append(mean_loss)

            torch.save(model_trained.state_dict(), ckpt_path)
            pd.DataFrame(fold_train_loss).to_csv(model_pth + "/" + save_name + "_loss.csv",
                                                 index=False, header=None)

        if not (test or get_prob):
            continue

        # ---------------- Inference (shared by test and get_prob) ----------------
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        ids, targets, probs = _predict(model, test_loader)

        if test:
            # Multi-label: every class is thresholded independently on its probability.
            if "Mlabel" in data_name:
                all_predict = [[1 if p >= 0.5 else 0 for p in row] for row in probs]
            else:
                all_predict = []
            np.save(model_pth + "/" + save_name + "_predict.npy", all_predict)
            np.save(model_pth + "/" + save_name + "_target.npy", np.asarray(targets, dtype=float))

        if get_prob:
            prob_df = pd.DataFrame(probs, columns=['prob' + name for name in labels])
            prob_df.insert(loc=0, column='mri_ID', value=ids)
            prob_df.insert(loc=1, column='tp', value=['test'] * len(ids))
            label_df = pd.DataFrame(targets, columns=labels)
            pd.concat([prob_df, label_df], axis=1).to_csv(model_pth + save_name + "_prob.csv", index=False)