## Setting for inference

In [1]:
# Display the kernel's current working directory (shown as this cell's output).
import os
os.getcwd()

'/scratch/connectome/jubin/ABCD-3DCNN/STEP_5_Transfer_learning/codes'

In [2]:
# Switch to the STEP_5_Transfer_learning directory, one level above the `codes`
# directory the kernel started in -- presumably so the `utils` / `dataloaders`
# packages imported in the next cell resolve (TODO confirm).
# NOTE(review): hard-coded absolute path; the notebook only runs on this machine layout.
os.chdir('/scratch/connectome/jubin/ABCD-3DCNN/STEP_5_Transfer_learning')

In [34]:
## ======= load module ======= ##
from utils.utils import argument_setting, select_model, save_exp_result, checkpoint_load #  
from dataloaders.dataloaders import make_dataset
from dataloaders.preprocessing import preprocessing_cat, preprocessing_num


import os
import glob
import argparse 
from tqdm.auto import tqdm ##progress
import random
from copy import deepcopy

from sklearn.metrics import confusion_matrix, roc_auc_score

import torch
import torch.nn as nn

import numpy as np

import warnings
warnings.filterwarnings("ignore")




if __name__ == "__main__":
    # Remember the directory the notebook is running from.
    cwd = os.getcwd()

    ## ========= Setting ========= ##
    # One shared seed for Python's `random`, NumPy and PyTorch so that every
    # random number generator used by the pipeline starts from the same state.
    seed = 1234
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

In [4]:
# Command-line interface used to build the `args` namespace for this notebook.
# Options are grouped by purpose; every flag, type, default and help string is
# unchanged except that the duplicated 'sfcn' entry in the --model choices was removed.
parser = argparse.ArgumentParser()

# ---- model / dataset selection ----
# NOTE(review): 'densenet201' and 'densenet264' lack the '3D' infix used by the
# other densenet choices -- confirm against select_model before renaming.
parser.add_argument("--model", required=True, type=str, help='',
                    choices=['sfcn', 'simple3D',
                             'vgg3D11', 'vgg3D13', 'vgg3D16', 'vgg3D19',
                             'resnet3D50', 'resnet3D101', 'resnet3D152',
                             'densenet3D121', 'densenet3D169', 'densenet201', 'densenet264'])
parser.add_argument("--dataset", required=True, type=str, choices=['UKB', 'ABCD'], help='')  # revising
parser.add_argument("--data", type=str, help='select data type')  # revising

# ---- data split / input geometry ----
parser.add_argument("--val_size", default=0.1, type=float, required=False, help='')
parser.add_argument("--test_size", default=0.1, type=float, required=False, help='')
# The default is a tuple, but values given on the command line arrive as a list of ints.
parser.add_argument("--resize", default=(96, 96, 96), type=int, nargs="*", required=False, help='')
parser.add_argument("--train_batch_size", default=16, type=int, required=False, help='')
parser.add_argument("--val_batch_size", default=16, type=int, required=False, help='')
parser.add_argument("--test_batch_size", default=1, type=int, required=False, help='')
parser.add_argument("--in_channels", default=1, type=int, required=False, help='')

# ---- optimisation ----
parser.add_argument("--optim", type=str, required=True, help='', choices=['Adam', 'SGD', 'RAdam', 'AdamW'])
parser.add_argument("--scheduler", type=str, default='', help='')  # revising
parser.add_argument("--early_stopping", type=int, default=None, help='')  # revising
parser.add_argument("--lr", default=0.01, type=float, required=False, help='')
parser.add_argument("--lr_adjust", default=0.01, type=float, required=False, help='')
parser.add_argument("--weight_decay", default=0.001, type=float, required=False, help='')
parser.add_argument("--epoch", type=int, required=True, help='')
parser.add_argument("--epoch_FC", type=int, required=False, default=0, help='')

# ---- experiment bookkeeping / targets ----
parser.add_argument("--exp_name", type=str, required=True, help='')
# The nargs='*' options below are None when the flag is omitted entirely.
parser.add_argument("--cat_target", type=str, nargs='*', required=False, help='')
parser.add_argument("--num_target", type=str, nargs='*', required=False, help='')
parser.add_argument("--confusion_matrix", type=str, nargs='*', required=False, help='')

# ---- hardware ----
parser.add_argument("--gpus", type=int, nargs='*', required=False, help='')
# Kept as the strings 'True' / 'False' (not a bool); downstream code compares against "True".
parser.add_argument("--sbatch", type=str, required=False, choices=['True', 'False'])

# ---- transfer learning / checkpoint handling ----
parser.add_argument("--transfer", type=str, required=False, default="", choices=['sex', 'age', 'simclr', 'MAE'])
parser.add_argument("--unfrozen_layer", type=str, required=False, default='0')
parser.add_argument("--load", type=str, required=False, default="")
parser.add_argument("--init_unfrozen", type=str, required=False, default="", help='init unfrozen layers')
parser.add_argument("--scratch", type=str, required=False, default='', help='option for learning from scratch')

# ---- data filtering / augmentation / cross-validation ----
parser.add_argument("--filter", required=False, nargs="+", default=[],
                    help='options for filter data by phenotype. usage: --filter abcd_site:10 sex:1')
parser.add_argument("--augmentation", required=False, nargs="+", default=[],
                    help="Data augmentation - [shift, flip] are available")
parser.add_argument("--cv", required=False, type=int, default=None, choices=[1, 2, 3, 4, 5], help="option for 5-fold CV. 1~5.")


_StoreAction(option_strings=['--cv'], dest='cv', nargs=None, const=None, default=None, type=<class 'int'>, choices=[1, 2, 3, 4, 5], help='option for 5-fold CV. 1~5.', metavar=None)

## Load model

In [51]:
# Checkpoint identifier; it is interpolated into the --load argument of the
# command string and later used as a glob substring to find the model file.
model_name = "ABCD_sex_TL_ALL_10"

In [52]:
# Command-line string that is later split on whitespace and fed to the parser.
# Adjacent literals are concatenated by Python, so this is one single-line string.
com = (
    f'--load {model_name} --cat_target sex '
    '--dataset ABCD --data freesurfer '
    '--model densenet3D121 --resize 80 80 80 '
    '--gpus 0 --test_batch_size 1 '
    '--val_size 0.1 --test_size 0.1 '
    '--exp_name CNN_sex_TL_test --optim AdamW --epoch 0 '
    '--confusion_matrix sex'
)

In [53]:
# Parse the notebook's command string the same way a script would parse sys.argv.
args = parser.parse_args(com.split())
print("*** Categorical target labels are {} and Numerical target labels are {} *** \n".format(
    args.cat_target, args.num_target)
     )

# At least one target is required. The previous if/elif chain could never reach
# its `raise`, and left `num_target` as None when both options were omitted.
if not args.cat_target and not args.num_target:
    raise ValueError('YOU SHOULD SELECT THE TARGET!')

# An omitted nargs='*' option is None; normalise both to lists so later cells
# can iterate over them without None checks.
args.cat_target = args.cat_target or []
args.num_target = args.num_target or []


*** Categorical target labels are ['sex'] and Numerical target labels are None *** 



In [54]:
# Input size expected by each kind of pretrained (transfer) model, paired with
# the message shown when the requested --resize does not contain that size.
resize_rules = {
    'age': (96, "age(MSE/MAE) transfer model's resize should be 96"),
    'MAE': (96, "age(MSE/MAE) transfer model's resize should be 96"),
    'sex': (80, "sex transfer model's resize should be 80"),
}
if args.transfer in resize_rules:
    required_size, resize_message = resize_rules[args.transfer]
    assert required_size in args.resize, resize_message

# Result directory under the current working directory.
save_dir = os.getcwd() + '/result'
# Build the data partitions (the 'test' split is used below) and the subject table.
partition, subject_data = make_dataset(args)  

## ========= Run Experiment and saving result ========= ## 

# Run Experiment
print(f"*** Test for {args.exp_name} Start ***")
net = select_model(subject_data, args) #  

# Restore trained weights when a checkpoint name is given with --load.
if args.load:
    print("*** Model setting for test *** \n")
    checkpoint_pattern = f'/scratch/connectome/jubin/result/model/*{args.load}*'
    # glob returns matches in arbitrary order; sort so the same file is always
    # chosen when the name is a substring of several checkpoints.
    checkpoint_matches = sorted(glob.glob(checkpoint_pattern))
    if not checkpoint_matches:
        # Previously this surfaced as a bare IndexError from `[0]`.
        raise FileNotFoundError(f"No checkpoint matches '{checkpoint_pattern}'")
    if len(checkpoint_matches) > 1:
        print(f"WARNING: {len(checkpoint_matches)} checkpoints match '{args.load}'; "
              f"using {checkpoint_matches[0]}")
    model_dir = checkpoint_matches[0]
    print(f"Loaded {args.load}")
    net = checkpoint_load(net, model_dir)

# Wrap the network in DataParallel and move it onto the GPU(s).
if args.sbatch == "True":
    # sbatch runs: use every CUDA device visible to this process.
    devices = list(range(torch.cuda.device_count()))
    net = nn.DataParallel(net, device_ids=devices)
    net.cuda()
else:
    # Interactive runs: the GPU ids must be given explicitly with --gpus.
    if not args.gpus:
        raise ValueError("GPU DEVICE IDS SHOULD BE ASSIGNED")
    net = nn.DataParallel(net, device_ids=args.gpus)
    net.to(f'cuda:{args.gpus[0]}')

# Loader over the held-out test partition; order is kept fixed (no shuffling).
testloader = torch.utils.data.DataLoader(
    partition['test'],
    batch_size=args.test_batch_size,
    shuffle=False,
    num_workers=4,
)

# Switch the network to evaluation mode before inference.
net.eval()

# Pick the device the input batches are sent to: the first device id of the
# wrapper when the net exposes `.module`, otherwise the GPU implied by the options.
if hasattr(net, 'module'):
    device = net.device_ids[0]
elif args.sbatch == 'True':
    device = 'cuda:0'
else:
    device = f'cuda:{args.gpus[0]}'
# Per-target containers filled by the inference and metric cells:
#   outputs            -> concatenated network outputs
#   y_true             -> concatenated ground-truth values
#   test_acc           -> list of metric values per target
#   confusion_matrices -> per-label confusion-matrix counts
outputs = {}
y_true = {}
test_acc = {}
confusion_matrices = {}

# Categorical targets first, then numerical ones; each starts with empty tensors.
for target_name in [*(args.cat_target or []), *(args.num_target or [])]:
    outputs[target_name] = torch.tensor([])
    y_true[target_name] = torch.tensor([])
    test_acc[target_name] = []

  0%|          | 0/11352 [00:00<?, ?it/s]

Total subjects=11352, train=9082, val=1135, test=1135
In train dataset, sex contains 4333 CASE and 4749 CONTROL
In validation dataset, sex contains 533 CASE and 602 CONTROL
In test dataset, sex contains 544 CASE and 591 CONTROL
*** Making a dataset is completed *** 

*** Test for CNN_sex_TL_test Start ***
*** Model setting for test *** 

Loaded ABCD_sex_TL_ALL_10
The best checkpoint is loaded


In [55]:
# Run the network over the whole test set and gather outputs / labels per target.
# Batches are collected in local lists and concatenated once at the end, so
# re-running this cell overwrites `outputs` / `y_true` instead of appending a
# second copy of the test set, and nothing is re-copied on every batch.
all_targets = list(args.cat_target or []) + list(args.num_target or [])
output_chunks = {target: [] for target in all_targets}
label_chunks = {target: [] for target in all_targets}

with torch.no_grad():
    for image, targets in tqdm(testloader):
        image = image.to(device)

        output = net(image)
        for target in all_targets:
            output_chunks[target].append(output[target].cpu())
            label_chunks[target].append(targets[target].cpu())

for target in all_targets:
    # The leading empty float tensor reproduces the dtype promotion of the
    # original incremental concatenation (integer labels end up as floats) and
    # keeps the result an empty tensor when the loader yields no batches.
    outputs[target] = torch.cat([torch.tensor([]), *output_chunks[target]])
    y_true[target] = torch.cat([torch.tensor([]), *label_chunks[target]])


# Calculate accuracy (ACC), AUROC and the requested confusion matrices for the
# categorical targets.
if args.cat_target:
    for cat_target in args.cat_target:
        labels = y_true[cat_target].detach().cpu()
        scores = outputs[cat_target].detach().cpu()

        # Predicted class = index of the largest score along dim 1.
        predicted = torch.max(scores, 1).indices
        correct = (predicted == labels).sum().item()
        total = labels.size(0)
        test_acc[cat_target].append(100 * (correct / total))

        # Column 1 of the scores is used as the positive-class score.
        # NOTE(review): `auroc` only keeps the value of the last categorical target.
        auroc = roc_auc_score(labels, scores[:, 1])

        # Confusion matrices are only produced for targets with exactly two classes.
        if args.confusion_matrix and len(np.unique(labels.numpy())) == 2:
            for label_cm in args.confusion_matrix:
                # setdefault keeps counts already stored for this label by an
                # earlier target; resetting them here used to wipe those results
                # whenever more than one categorical target was evaluated.
                cm_entry = confusion_matrices.setdefault(label_cm, {
                    'True Positive': 0,
                    'True Negative': 0,
                    'False Positive': 0,
                    'False Negative': 0,
                })
                if label_cm == cat_target:
                    # ravel() of the 2x2 matrix yields tn, fp, fn, tp in that order.
                    tn, fp, fn, tp = confusion_matrix(labels.numpy(), predicted.numpy()).ravel()
                    cm_entry['True Positive'] = int(tp)
                    cm_entry['True Negative'] = int(tn)
                    cm_entry['False Positive'] = int(fp)
                    cm_entry['False Negative'] = int(fn)

# Regression metrics for the numerical targets: R^2 is appended to `test_acc`
# and the mean absolute error is kept in `MAE`.
# NOTE(review): `MAE` holds a tensor and only the value of the last numerical target.
MAE = None
if args.num_target:
    # The loss modules are stateless, so they are created once outside the loop.
    criterion = nn.MSELoss()
    l1loss = nn.L1Loss()
    for num_target in args.num_target:
        predicted = outputs[num_target].float()
        # Labels are given a trailing dimension before being compared with the outputs.
        expected = y_true[num_target].float().unsqueeze(1)
        loss = criterion(predicted, expected)
        MAE = l1loss(predicted, expected)
        # NOTE(review): torch.var uses the unbiased estimator by default while
        # MSELoss averages over all elements, so this differs slightly from the
        # textbook R^2 -- confirm this matches the training-time metric.
        y_var = torch.var(y_true[num_target])
        r_square = 1 - (loss / y_var)
        test_acc[num_target].append(r_square.item())

    # Confusion matrices are dropped only when none were collected. Setting them
    # to None unconditionally discarded the matrices already computed for
    # categorical targets whenever a numerical target was evaluated as well.
    if not confusion_matrices:
        confusion_matrices = None

result = {'test_acc':test_acc,'MAE':MAE}

print(f"Test result: {test_acc} & {MAE} for {args.load}") 

if confusion_matrices is not None:
    result['confusion_matrices'] = confusion_matrices


  0%|          | 0/1135 [00:00<?, ?it/s]

Test result: {'sex': [83.78854625550662]} & None for ABCD_sex_TL_ALL_10


In [43]:
# Display the AUROC computed in the metrics cell.
# NOTE(review): this cell's execution count (43) is lower than the metrics cell's (55),
# so the value shown may come from an earlier run -- re-run after the metrics cell.
auroc

0.9640471036130188

In [50]:
# Display the confusion-matrix counts collected in the metrics cell.
# NOTE(review): executed as In [50], before the metrics cell In [55]; the counts shown
# give (438 + 564) / 1135 = 88.3 % accuracy, not the 83.8 % printed by In [55], so this
# output appears to be stale -- re-run after the metrics cell.
confusion_matrices

{'sex': {'True Positive': 438,
  'True Negative': 564,
  'False Positive': 27,
  'False Negative': 106}}

In [None]:
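# Scratch note (not executable Python): example command-line arguments for a
# test run with the numerical target BMI, kept for reference.
# NOTE: `--data_type` is not an option of the parser defined in this notebook (it defines `--data`).
#
# --num_target BMI --val_size 0.1 --test_size 0.1 --lr 1 --optim AdamW \
# --resize 128 --train_batch_size 64 --val_batch_size 64 --test_batch_size 160 --dataset ABCD --data_type FA_warppped_nii \
# --exp_name BMI_FA_CNN_test --model densenet3D121 --epoch 1 --gpus 0 --load BMI_FA_01_243710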