In [1]:
import os
# Show the kernel's working directory; the next cell changes directory relative to it.
os.getcwd()

'/scratch/connectome/jubin/ABCD-3DCNN/STEP_4_Multimodal-Learning/MultiChannel-Learning/contrastive_learning/codes'

In [2]:
# The notebook starts inside `codes/`, while the project packages imported below
# (`utils`, `dataloaders`, ...) live one level up. Step up only while still inside
# `codes/`, so that re-running this cell does not climb one more directory each time.
if os.path.basename(os.getcwd()) == 'codes':
    os.chdir('..')
os.listdir()

['codes',
 'README.md',
 'envs',
 '__pycache__',
 'run_contrastive_learning.py',
 'models',
 'result',
 'test.py',
 'dataloaders',
 'utils',
 '.ipynb_checkpoints']

In [3]:
## ======= load module ======= ##
from utils.utils import argument_setting, select_model, save_exp_result, checkpoint_load #  
from dataloaders.dataloaders import make_dataset
from dataloaders.preprocessing import preprocessing_cat, preprocessing_num


import os
import glob
import argparse 
from tqdm.auto import tqdm ##progress
import random
from copy import deepcopy

from sklearn.metrics import confusion_matrix

import torch
import torch.nn as nn

import numpy as np

import warnings
warnings.filterwarnings("ignore")


if __name__ == "__main__":
    ## ========= Setting ========= ##
    # One seed shared by the PyTorch, NumPy and Python random number generators.
    seed = 1234
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Directory the kernel is currently running from.
    cwd = os.getcwd()


    


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Command-line interface of the experiment. In this notebook it is fed a
# whitespace-split command string (`com`) instead of sys.argv.
parser = argparse.ArgumentParser()

# Options for model setting
parser.add_argument("--model", type=str, required=True, help='Select model. e.g. densenet3D121, sfcn.',
                    choices=['simple3D', 'sfcn', 'vgg3D11', 'vgg3D13', 'vgg3D16', 'vgg3D19',
                             'resnet3D50', 'resnet3D101', 'resnet3D152',
                             # NOTE(review): the last two names lack the "3D" infix used by the other
                             # densenets - confirm against select_model before renaming them.
                             'densenet3D121', 'densenet3D169', 'densenet201', 'densenet264'])
parser.add_argument("--in_channels", default=1, type=int, help='')

# Options for dataset and data type, split ratio, CV, resize, augmentation
parser.add_argument("--dataset", type=str, choices=['UKB','ABCD'], required=True, help='Select dataset')
# NOTE: `--dataset` and `--data_type` share the prefix `--data`, so the abbreviation
# `--data` is ambiguous and rejected by argparse; always spell these options out.
parser.add_argument("--data_type", nargs='+', type=str, help='Select data type(sMRI, dMRI)',
                    choices=['fmriprep', 'freesurfer', 'FA_unwarpped_nii', 'FA_warpped_nii',
                             'MD_unwarpped_nii', 'MD_warpped_nii', 'RD_unwarpped_nii', 'RD_warpped_nii'])
parser.add_argument("--val_size", default=0.1, type=float, help='')
parser.add_argument("--test_size", default=0.1, type=float, help='')
parser.add_argument("--cv", default=None, type=int, choices=[1,2,3,4,5], help="option for 5-fold CV. 1~5.")
parser.add_argument("--resize", nargs="*", default=(96, 96, 96), type=int, help='')
parser.add_argument("--augmentation", nargs="*", default=[], type=str, choices=['shift','flip'],
                    help="Data augmentation - [shift, flip] are available")

# Hyperparameters for model training
parser.add_argument("--lr", default=0.01, type=float, help='')
parser.add_argument("--lr_adjust", default=0.01, type=float, help='')
parser.add_argument("--epoch", type=int, required=True, help='')
parser.add_argument("--epoch_FC", type=int, default=0, help='Option for training only FC layer')
parser.add_argument("--optim", default='Adam', type=str, choices=['Adam','SGD','RAdam','AdamW'], help='')
parser.add_argument("--weight_decay", default=0.001, type=float, help='')
parser.add_argument("--scheduler", default='', type=str, help='')
parser.add_argument("--early_stopping", default=None, type=int, help='')
parser.add_argument("--train_batch_size", default=16, type=int, help='')
parser.add_argument("--val_batch_size", default=16, type=int, help='')
parser.add_argument("--test_batch_size", default=1, type=int, help='')

# Options for experiment setting
parser.add_argument("--exp_name", type=str, required=True, help='')
parser.add_argument("--gpus", nargs='+', type=int, help='')
# `--sbatch` is kept as the strings 'True'/'False'; later cells compare against "True".
parser.add_argument("--sbatch", type=str, choices=['True', 'False'])
parser.add_argument("--cat_target", nargs='+', default=[], type=str, help='')
parser.add_argument("--num_target", nargs='+', default=[], type=str, help='')
parser.add_argument("--confusion_matrix",  nargs='*', type=str, help='')
parser.add_argument("--filter", nargs="*", default=[], type=str,
                    help='options for filter data by phenotype. usage: --filter abcd_site:10 sex:1')
parser.add_argument("--load", default='', type=str, help='Load model weight that matches {your_exp_dir}/result/*{load}*')
parser.add_argument("--scratch", default='', type=str, help='Option for learning from scratch')
parser.add_argument("--transfer", default='', type=str, choices=['sex','age','simclr','MAE'],
                    help='Choose pretrained model according to your option')
parser.add_argument("--unfrozen_layer", default='0', type=str, help='Select the number of layers that would be unfrozen')
parser.add_argument("--init_unfrozen", default='', type=str, help='Initializes unfrozen layers')


_StoreAction(option_strings=['--init_unfrozen'], dest='init_unfrozen', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='Initializes unfrozen layers', metavar=None)

In [5]:
# Command line for the ADHD classification test: two input modalities
# (freesurfer sMRI + warped MD map), SFCN backbone, one epoch.
com = (
    '--cat_target Attention.Deficit.Hyperactivity.Disorder.x '
    '--dataset ABCD --data_type freesurfer MD_warpped_nii '
    '--model sfcn --resize 80 80 80 --gpus 0 1 '
    '--test_batch_size 1 --val_size 0.1 --test_size 0.1 '
    '--exp_name adhd_test --optim AdamW --epoch 1 --confusion_matrix sex'
)

In [13]:
# Sex-classification test with a loaded checkpoint (overrides any earlier `com`).
# `--data` was an ambiguous abbreviation (it prefixes both --dataset and --data_type)
# and is rejected by the parser above, so the option is spelled out.
# NOTE(review): the evaluation cell reads args.data_type[1]; with a single data type
# that index does not exist - confirm whether a second modality is needed here.
com = '--cv 5 --load SFCNSC_11 --cat_target sex --dataset ABCD --data_type freesurfer --model sfcn --resize 80 80 80 --gpus 0 --test_batch_size 1 --val_size 0.1 --test_size 0.1 --exp_name SFCNSC_11test --optim AdamW --epoch 0 --confusion_matrix sex'

In [10]:
# Sex-classification test with a large test batch (overrides any earlier `com`).
# `--data` was an ambiguous abbreviation (it prefixes both --dataset and --data_type)
# and is rejected by the parser above, so the option is spelled out.
# NOTE(review): the evaluation cell reads args.data_type[1]; with a single data type
# that index does not exist - confirm whether a second modality is needed here.
com = '--cat_target sex --dataset ABCD --data_type freesurfer --model sfcn --resize 80 80 80 --gpus 0 --test_batch_size 128 --val_size 0.1 --test_size 0.1 --exp_name sfcn_test --optim AdamW --epoch 0 --confusion_matrix sex'

In [6]:
# Parse the command string defined above as if it had been typed on the command line.
args = parser.parse_args(com.split())

# Normalise "no target given" (None or empty) to an empty list for both target kinds.
args.cat_target = args.cat_target or []
args.num_target = args.num_target or []

print("*** Categorical target labels are {} and Numerical target labels are {} *** \n".format(
    args.cat_target, args.num_target)
     )

# At least one target is required. The original if/elif chain could never reach
# its `raise`, because the first branch already caught an empty cat_target.
if not args.cat_target and not args.num_target:
    raise ValueError('YOU SHOULD SELECT THE TARGET!')


*** Categorical target labels are ['Attention.Deficit.Hyperactivity.Disorder.x'] and Numerical target labels are [] *** 



In [7]:
# ---------------------------------------------------------------------------
# Build the dataset and model, run inference on the test split, and compute
# accuracy (categorical targets) / R^2 and MAE (numerical targets).
# ---------------------------------------------------------------------------

# The transfer checkpoints require a matching --resize (see the assertion messages).
if args.transfer in ['age','MAE']:
    assert 96 in args.resize, "age(MSE/MAE) transfer model's resize should be 96"
elif args.transfer == 'sex':
    assert 80 in args.resize, "sex transfer model's resize should be 80"

save_dir = os.getcwd() + '/result'

# make_dataset receives one entry of --data_type. The second entry is used when
# present (original behaviour); with a single data type the original raised IndexError.
# NOTE(review): falling back to the only entry assumes make_dataset accepts any
# single data type name - confirm against dataloaders.dataloaders.make_dataset.
if not args.data_type:
    raise ValueError("DATA TYPE SHOULD BE ASSIGNED (--data_type)")
selected_data_type = args.data_type[1] if len(args.data_type) > 1 else args.data_type[0]
partition, subject_data = make_dataset(args, selected_data_type)

## ========= Run Experiment and saving result ========= ##

# Run Experiment
print(f"*** Test for {args.exp_name} Start ***")
net = select_model(subject_data, args)

# Load trained weights when --load is given; fail with a clear message if nothing matches.
if args.load:
    print("*** Model setting for test *** \n")
    checkpoint_candidates = glob.glob(f'/scratch/connectome/jubin/result/model/*{args.load}*')
    if not checkpoint_candidates:
        raise FileNotFoundError(
            f"No checkpoint matching '*{args.load}*' in /scratch/connectome/jubin/result/model")
    model_dir = checkpoint_candidates[0]
    net = checkpoint_load(net, model_dir)
    print(f"Loaded {args.load}")

# Wrap the model in DataParallel and move it to the GPU(s).
if args.sbatch == "True":
    # Use every visible GPU.
    net = nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
    net.cuda()
else:
    if not args.gpus:
        raise ValueError("GPU DEVICE IDS SHOULD BE ASSIGNED")
    net = nn.DataParallel(net, device_ids=args.gpus)
    net.to(f'cuda:{args.gpus[0]}')

testloader = torch.utils.data.DataLoader(partition['test'],
                                         batch_size=args.test_batch_size,
                                         shuffle=False,
                                         num_workers=4)

net.eval()
# Inputs go to the first device of the DataParallel wrapper.
if hasattr(net, 'module'):
    device = net.device_ids[0]
else:
    if args.sbatch == 'True':
        device = 'cuda:0'
    else:
        device = f'cuda:{args.gpus[0]}'

# Per-target accumulators for model outputs and ground truth.
target_names = list(args.cat_target) + list(args.num_target)
outputs = {name: torch.tensor([]) for name in target_names}
y_true = {name: torch.tensor([]) for name in target_names}
test_acc = {name: [] for name in target_names}
confusion_matrices = {}

with torch.no_grad():
    for image, targets in tqdm(testloader):
        image = image.to(device)
        output = net(image)
        for name in target_names:
            outputs[name] = torch.cat((outputs[name], output[name].cpu()))
            y_true[name] = torch.cat((y_true[name], targets[name].cpu()))

# Accuracy (in percent) and optional confusion matrix for each categorical target.
for cat_target in args.cat_target:
    _, predicted = torch.max(outputs[cat_target].data, 1)
    correct = (predicted == y_true[cat_target]).sum().item()
    total = y_true[cat_target].size(0)
    test_acc[cat_target].append(100 * (correct / total))

    is_binary = len(np.unique(y_true[cat_target].numpy())) == 2
    if args.confusion_matrix and is_binary:
        for label_cm in args.confusion_matrix:
            # Create a zero entry only if none exists yet. The original reset the entry
            # for every categorical target, wiping counts computed for an earlier one.
            confusion_matrices.setdefault(label_cm, {'True Positive': 0, 'True Negative': 0,
                                                     'False Positive': 0, 'False Negative': 0})
            if label_cm == cat_target:
                tn, fp, fn, tp = confusion_matrix(y_true[cat_target].numpy(), predicted.numpy()).ravel()
                confusion_matrices[label_cm] = {'True Positive': int(tp), 'True Negative': int(tn),
                                                'False Positive': int(fp), 'False Negative': int(fn)}

# R^2 and MAE for each numerical target.
# NOTE(review): MAE keeps only the value of the last numerical target (as before).
MAE = None
for num_target in args.num_target:
    # assumes the regression output is (N, 1), matching the unsqueeze(1) on the target - TODO confirm
    predicted = outputs[num_target].float()
    target = y_true[num_target].float().unsqueeze(1)
    loss = nn.MSELoss()(predicted, target)
    MAE = nn.L1Loss()(predicted, target)
    # R^2 = 1 - MSE / population variance. The default (unbiased, n-1) variance does
    # not match the plain mean used by MSELoss and skews R^2 slightly.
    y_var = torch.var(y_true[num_target], unbiased=False)
    r_square = 1 - (loss / y_var)
    test_acc[num_target].append(r_square.item())

# Regression-only runs carry no confusion matrices. Mixed runs keep the classification
# ones; previously they were discarded as soon as any numerical target existed.
if args.num_target and not args.cat_target:
    confusion_matrices = None

result = {'test_acc':test_acc,'MAE':MAE}

print(f"Test result: {test_acc} & {MAE} for {args.load}")

if confusion_matrices is not None:
    result['confusion_matrices'] = confusion_matrices


0       NDARINV00BD7VDC
1       NDARINV00CY2MDM
2       NDARINV00LJVZK2
3       NDARINV00U4FTRU
4       NDARINV014RTM1V
             ...       
3986    NDARINVV5MHL75K
3987    NDARINVV5XX9GEF
3988    NDARINVV6KFJX12
3989    NDARINVV6MZ4VB1
3990    NDARINVV6NAXTR2
Name: subjectkey, Length: 3991, dtype: object


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3991/3991 [00:00<00:00, 84888.75it/s]

Total subjects=3991, train=3193, val=399, test=399
In train dataset, Attention.Deficit.Hyperactivity.Disorder.x contains 1393 CASE and 1800 CONTROL
In validation dataset, Attention.Deficit.Hyperactivity.Disorder.x contains 183 CASE and 216 CONTROL
In test dataset, Attention.Deficit.Hyperactivity.Disorder.x contains 160 CASE and 239 CONTROL
*** Making a dataset is completed *** 

*** Test for adhd_test Start ***



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 399/399 [01:45<00:00,  3.78it/s]


Test result: {'Attention.Deficit.Hyperactivity.Disorder.x': [55.388471177944865]} & None for 


In [20]:
# Shape of the last batch left in `image` by the evaluation loop.
image.shape

torch.Size([128, 1, 80, 80, 80])

In [17]:
# Peek at the `labels` attribute of the training split.
partition['train'].labels

{'age': 1.6139065983869547}

In [16]:
# Predictions left over from the most recent metric computation.
predicted

tensor([1, 1, 1,  ..., 1, 1, 1])

In [9]:
# Ground-truth tensors collected per target during evaluation.
y_true

{'sex': tensor([1., 1., 1.,  ..., 0., 0., 1.], dtype=torch.float64)}

In [15]:
# Confusion-matrix counts recorded for the 'sex' label.
confusion_matrices['sex']

{'True Positive': 545,
 'True Negative': 0,
 'False Positive': 591,
 'False Negative': 0}

In [19]:
# Number of evaluated samples for the 'sex' target.
# (Disabled check: count of correct predictions.)
# sum(torch.max(outputs['sex'].data,1)[1] == y_true['sex'])
len(y_true['sex'])

1136

In [22]:
# Re-run the evaluation on the validation split. This repeats the logic of the
# test-set cell; `net`, `device`, `args` and `partition` come from earlier cells.
testloader = torch.utils.data.DataLoader(partition['val'], batch_size=args.test_batch_size, shuffle=False, num_workers=4)

# Per-target accumulators for model outputs and ground truth.
target_names = list(args.cat_target) + list(args.num_target)
outputs = {name: torch.tensor([]) for name in target_names}
y_true = {name: torch.tensor([]) for name in target_names}
test_acc = {name: [] for name in target_names}
confusion_matrices = {}

with torch.no_grad():
    for image, targets in tqdm(testloader):
        image = image.to(device)
        output = net(image)
        for name in target_names:
            outputs[name] = torch.cat((outputs[name], output[name].cpu()))
            y_true[name] = torch.cat((y_true[name], targets[name].cpu()))

# Accuracy (in percent) and optional confusion matrix for each categorical target.
for cat_target in args.cat_target:
    _, predicted = torch.max(outputs[cat_target].data, 1)
    correct = (predicted == y_true[cat_target]).sum().item()
    total = y_true[cat_target].size(0)
    test_acc[cat_target].append(100 * (correct / total))

    is_binary = len(np.unique(y_true[cat_target].numpy())) == 2
    if args.confusion_matrix and is_binary:
        for label_cm in args.confusion_matrix:
            # Create a zero entry only if none exists yet. The original reset the entry
            # for every categorical target, wiping counts computed for an earlier one.
            confusion_matrices.setdefault(label_cm, {'True Positive': 0, 'True Negative': 0,
                                                     'False Positive': 0, 'False Negative': 0})
            if label_cm == cat_target:
                tn, fp, fn, tp = confusion_matrix(y_true[cat_target].numpy(), predicted.numpy()).ravel()
                confusion_matrices[label_cm] = {'True Positive': int(tp), 'True Negative': int(tn),
                                                'False Positive': int(fp), 'False Negative': int(fn)}

# R^2 and MAE for each numerical target.
# NOTE(review): MAE keeps only the value of the last numerical target (as before).
MAE = None
for num_target in args.num_target:
    # assumes the regression output is (N, 1), matching the unsqueeze(1) on the target - TODO confirm
    predicted = outputs[num_target].float()
    target = y_true[num_target].float().unsqueeze(1)
    loss = nn.MSELoss()(predicted, target)
    MAE = nn.L1Loss()(predicted, target)
    # R^2 = 1 - MSE / population variance. The default (unbiased, n-1) variance does
    # not match the plain mean used by MSELoss and skews R^2 slightly.
    y_var = torch.var(y_true[num_target], unbiased=False)
    r_square = 1 - (loss / y_var)
    test_acc[num_target].append(r_square.item())

# Regression-only runs carry no confusion matrices. Mixed runs keep the classification
# ones; previously they were discarded as soon as any numerical target existed.
if args.num_target and not args.cat_target:
    confusion_matrices = None

result = {'test_acc':test_acc,'MAE':MAE}

print(f"Test result: {test_acc} & {MAE} for {args.load}")

if confusion_matrices is not None:
    result['confusion_matrices'] = confusion_matrices


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:25<00:00,  9.46s/it]


Test result: {'sex': [98.94273127753304]} & None for ABCD_sex_TL_ALL_11_3bb9ad


In [24]:
# Re-run the evaluation on the training split (fixed batch size of 256). This repeats
# the logic of the test-set cell; `net`, `device`, `args` and `partition` come from
# earlier cells.
testloader = torch.utils.data.DataLoader(partition['train'], batch_size=256, shuffle=False, num_workers=4)

# Per-target accumulators for model outputs and ground truth.
target_names = list(args.cat_target) + list(args.num_target)
outputs = {name: torch.tensor([]) for name in target_names}
y_true = {name: torch.tensor([]) for name in target_names}
test_acc = {name: [] for name in target_names}
confusion_matrices = {}

with torch.no_grad():
    for image, targets in tqdm(testloader):
        image = image.to(device)
        output = net(image)
        for name in target_names:
            outputs[name] = torch.cat((outputs[name], output[name].cpu()))
            y_true[name] = torch.cat((y_true[name], targets[name].cpu()))

# Accuracy (in percent) and optional confusion matrix for each categorical target.
for cat_target in args.cat_target:
    _, predicted = torch.max(outputs[cat_target].data, 1)
    correct = (predicted == y_true[cat_target]).sum().item()
    total = y_true[cat_target].size(0)
    test_acc[cat_target].append(100 * (correct / total))

    is_binary = len(np.unique(y_true[cat_target].numpy())) == 2
    if args.confusion_matrix and is_binary:
        for label_cm in args.confusion_matrix:
            # Create a zero entry only if none exists yet. The original reset the entry
            # for every categorical target, wiping counts computed for an earlier one.
            confusion_matrices.setdefault(label_cm, {'True Positive': 0, 'True Negative': 0,
                                                     'False Positive': 0, 'False Negative': 0})
            if label_cm == cat_target:
                tn, fp, fn, tp = confusion_matrix(y_true[cat_target].numpy(), predicted.numpy()).ravel()
                confusion_matrices[label_cm] = {'True Positive': int(tp), 'True Negative': int(tn),
                                                'False Positive': int(fp), 'False Negative': int(fn)}

# R^2 and MAE for each numerical target.
# NOTE(review): MAE keeps only the value of the last numerical target (as before).
MAE = None
for num_target in args.num_target:
    # assumes the regression output is (N, 1), matching the unsqueeze(1) on the target - TODO confirm
    predicted = outputs[num_target].float()
    target = y_true[num_target].float().unsqueeze(1)
    loss = nn.MSELoss()(predicted, target)
    MAE = nn.L1Loss()(predicted, target)
    # R^2 = 1 - MSE / population variance. The default (unbiased, n-1) variance does
    # not match the plain mean used by MSELoss and skews R^2 slightly.
    y_var = torch.var(y_true[num_target], unbiased=False)
    r_square = 1 - (loss / y_var)
    test_acc[num_target].append(r_square.item())

# Regression-only runs carry no confusion matrices. Mixed runs keep the classification
# ones; previously they were discarded as soon as any numerical target existed.
if args.num_target and not args.cat_target:
    confusion_matrices = None

result = {'test_acc':test_acc,'MAE':MAE}

print(f"Test result: {test_acc} & {MAE} for {args.load}")

if confusion_matrices is not None:
    result['confusion_matrices'] = confusion_matrices


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [07:27<00:00, 12.42s/it]


Test result: {'sex': [98.66754762691333]} & None for ABCD_sex_TL_ALL_11_3bb9ad


In [30]:
def hi():
    """Print the module-level name `a`; also defines an inner helper that is never called."""
    def wow():
        # Inner helper: prints a fixed string. Not invoked by hi().
        print('wow')

    # `a` is resolved from the global namespace at call time.
    print(a)

In [50]:
# NOTE(review): the original expression was truncated to `torch.`, which is a SyntaxError.
# The recorded output is a torch Generator, so this presumably inspected the default
# generator (or was a `torch.manual_seed(...)` call, which returns it) - confirm.
torch.default_generator

<torch._C.Generator at 0x7fe553d94eb0>

In [79]:
# Mean and population variance (divide by n, not n-1) of the true ages.
# Tensor reductions replace the Python-level sum() that iterated over every element.
age_true = y_true['age'].flatten().double()
m = age_true.mean()
v = ((age_true - m) ** 2).mean()

In [117]:
# Mean absolute error of the age predictions. The divisor now comes from the data
# instead of the hard-coded 1137, so the value stays correct when the split size changes.
(y_true['age'] - outputs['age'].flatten()).abs().mean()

tensor(6.2458, dtype=torch.float64)

In [105]:
# Manual R^2 check: 1 - loss / v, where `loss` is the MSE left over from the evaluation
# cell and `v` is the population variance computed in an earlier cell.
1-loss/v

tensor(0.0473, dtype=torch.float64)

In [115]:
import pandas as pd

# Load the three phenotype tables: the full table (`to`), the ADHD table (`ad`)
# and the suicide-control table (`con`), in that order.
to, ad, con = (
    pd.read_csv(csv_path)
    for csv_path in (
        "/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_phenotype_total.csv",
        "/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_ADHD.csv",
        "/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_suicide_control.csv",
    )
)

In [119]:
# (rows, columns) of the ADHD table followed by the control table.
print(*(table.shape for table in (ad, con)))

(2506, 118) (3108, 45)


In [123]:
import glob

# Every entry of the warped-FA image directory.
images = glob.glob('/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_warpped_nii/*')
# Subject key = file name up to its first dot.
images_subjectkeys = pd.Series(
    [image_path.rsplit("/", 1)[-1].split(".")[0] for image_path in images]
)
images_subjectkeys[:3]

0    NDARINV89B7M962
1    NDARINV6ZU9NKBV
2    NDARINVLPMG7ZFU
dtype: object

In [72]:
# Print every column of the control table whose name contains "Attention".
for c in (column_name for column_name in con.columns if 'Attention' in column_name):
    print(c)

Unspecified.Attention.Deficit.Hyperactivity.Disorder.x
Attention.Deficit.Hyperactivity.Disorder.x


In [117]:
# Subject keys of the ADHD table followed by those of the control table, re-indexed from 0.
sjk = pd.concat((ad['subjectkey'], con['subjectkey'])).reset_index(drop=True)

In [124]:
# Rows of the full phenotype table whose subject appears in the ADHD or control table.
newdf = to.loc[to['subjectkey'].isin(sjk)]

In [125]:
# How many selected subjects have (True) / lack (False) a warped-FA image.
newdf['subjectkey'].isin(images_subjectkeys).value_counts()

True     3993
False    1621
Name: subjectkey, dtype: int64

In [126]:
# Every entry of the freesurfer sMRI image directory.
images2 = glob.glob('/scratch/connectome/3DCNN/data/1.ABCD/2.sMRI_freesurfer/*')
# Subject key = file name up to its first dot.
images_subjectkeys2 = pd.Series(
    [image_path.rsplit("/", 1)[-1].split(".")[0] for image_path in images2]
)
images_subjectkeys2[:3]

0    NDARINVBZJGG4AN
1    NDARINVXPZGM0LG
2    NDARINVU9C36KFY
dtype: object

In [134]:
# Keep subjects that have an sMRI image, then of those the ones that also have a
# warped-FA image, and count the ADHD labels in the result.
smris = newdf.loc[lambda frame: frame['subjectkey'].isin(images_subjectkeys2)]
swithd = smris.loc[lambda frame: frame['subjectkey'].isin(images_subjectkeys)]
swithd['Attention.Deficit.Hyperactivity.Disorder.x'].value_counts()

0.0    2255
1.0    1736
Name: Attention.Deficit.Hyperactivity.Disorder.x, dtype: int64

In [135]:
# Distribution of the `sex` column among subjects that have both image types.
swithd['sex'].value_counts()

1.0    2181
2.0    1803
Name: sex, dtype: int64