In [1]:
import monai

In [2]:
monai.__version__

'1.0.1'

## Change directory & Hyperparameter setting

In [2]:
import os
# Print the directory the kernel started in, then switch to the project root.
# The directory listing printed by the next cell shows models/, utils/, dataloaders/ and
# envs/ there, which presumably is what lets the later `import models...` lines resolve.
# NOTE(review): absolute, user-specific path -- this cell only works on this filesystem.
print(os.getcwd())
os.chdir('/scratch/connectome/jubin/ABCD-3DCNN/STEP_4_Multimodal-Learning/MultiChannel-Learning/contrastive_learning/')

/scratch/connectome/jubin/ABCD-3DCNN/STEP_4_Multimodal-Learning/MultiChannel-Learning/contrastive_learning/codes


In [4]:
print(os.listdir())

['codes', 'README.md', 'envs', '__pycache__', 'run_contrastive_learning.py', 'models', 'result', 'test.py', 'dataloaders', 'utils', '.ipynb_checkpoints']


In [3]:
# Hyperparameters for this run, kept as strings because they are interpolated into the
# command-line string `com` that is handed to the argparse parser further down.
# NOTE(review): `target`, `scheduler`, `lr_adjust` and `cfm` are not interpolated into `com`
# in the visible cell (the target is hard-coded there as `--num_target BMI`), so editing
# them here has no effect on the parsed arguments.
# NOTE(review): `data_type` lists two modalities, but the recorded output of `com` shows only
# `FA_warpped_nii` -- the saved outputs come from an earlier value of this cell.
target="BMI"
data_type="freesurfer_256 FA_warpped_nii"
model="densenet3D121"
epoch_FC="0"
epoch="5"
optimn="AdamW"
scheduler="--scheduler on" # alternative value: step_80
batch="8"
val_size="0.1"
test_size="0.1"
lr="1e-3"
lr_adjust="--lr_adjust 1"
cfm="--confusion_matrix sex"
exp_name='test_dmri'


In [23]:
import os
import json
import argparse 

import pandas as pd
import torch

import models.simple3d as simple3d 
import models.vgg3d as vgg3d 
import models.resnet3d as resnet3d 
import models.densenet3d as densenet3d 
import models.sfcn as sfcn

# Command-line interface of the training script, rebuilt here so the notebook can parse a
# hand-written argument string instead of sys.argv.
parser = argparse.ArgumentParser()

# NOTE(review): 'densenet201' and 'densenet264' lack the '3D' infix used by every other
# DenseNet choice -- confirm against the names select_model() expects.
parser.add_argument("--model", type=str, required=True, help='Select model. e.g. densenet3D121, sfcn.',
                    choices=['simple3D', 'sfcn', 'vgg3D11', 'vgg3D13', 'vgg3D16', 'vgg3D19',
                             'resnet3D50', 'resnet3D101', 'resnet3D152',
                             'densenet3D121', 'densenet3D169', 'densenet201', 'densenet264'])
parser.add_argument("--in_channels", default=1, type=int, help='')

# Options for dataset and data type, split ratio, CV, resize, augmentation
parser.add_argument("--dataset", type=str, choices=['UKB','ABCD'], required=True, help='Selelct dataset')
# --data_type is not required and has no default, so it is None when omitted.
parser.add_argument("--data_type", nargs='+', type=str, help='Select data type(sMRI, dMRI)',
                    choices=['fmriprep', 'freesurfer', 'freesurfer_256', 'FA_unwarpped_nii', 'FA_warpped_nii',
                             'MD_unwarpped_nii', 'MD_warpped_nii', 'RD_unwarpped_nii', 'RD_warpped_nii'])
parser.add_argument("--tissue", default=None, type=str, help='Select tissue mask(Cortical grey matter, \
                    Sub-cortical grey matter, White matter, CSF, Pathological tissue)',
                    choices=['cgm', 'scgm', 'wm', 'csf', 'pt'])
parser.add_argument("--metric", default='cos', type=str, help='')
parser.add_argument("--val_size", default=0.1, type=float, help='')
parser.add_argument("--test_size", default=0.1, type=float, help='')
parser.add_argument("--cv", default=None, type=int, choices=[1,2,3,4,5], help="option for 5-fold CV. 1~5.")
parser.add_argument("--resize", nargs="*", default=(96, 96, 96), type=int, help='')
parser.add_argument("--transform", nargs="*", default=[], type=str, choices=['crop'],
                    help="option for additional transform - [crop] are available")
parser.add_argument("--augmentation", nargs="*", default=[], type=str, choices=['shift','flip'],
                    help="Data augmentation - [shift, flip] are available")

# Hyperparameters for model training
parser.add_argument("--lr", default=0.01, type=float, help='')
# --lr_adjust is the factor args.lr is multiplied by after the FC-only stage (see experiment()).
parser.add_argument("--lr_adjust", default=0.01, type=float, help='')
parser.add_argument("--epoch", type=int, required=True, help='')
parser.add_argument("--epoch_FC", type=int, default=0, help='Option for training only FC layer')
parser.add_argument("--optim", default='Adam', type=str, choices=['Adam','SGD','RAdam','AdamW'], help='')
parser.add_argument("--weight_decay", default=0.001, type=float, help='')
# Values handled by set_lr_scheduler(): '', 'on', 'cos', 'step' / 'step_<N>', 'onecycle'.
parser.add_argument("--scheduler", default='', type=str, help='') 
# Patience in epochs; run_experiment() only honours it once the epoch index reaches 50.
parser.add_argument("--early_stopping", default=None, type=int, help='')
parser.add_argument('--accumulation_steps', default=None, type=int, required=False)
parser.add_argument("--train_batch_size", default=16, type=int, help='')
parser.add_argument("--val_batch_size", default=16, type=int, help='')
parser.add_argument("--test_batch_size", default=1, type=int, help='')

# Options for experiment setting
parser.add_argument("--exp_name", type=str, required=True, help='')
parser.add_argument("--gpus", nargs='+', type=int, help='')
# --sbatch 'True' makes experiment() use every visible CUDA device instead of --gpus.
parser.add_argument("--sbatch", type=str, choices=['True', 'False'])
parser.add_argument("--cat_target", nargs='+', default=[], type=str, help='')
parser.add_argument("--num_target", nargs='+', default=[], type=str, help='')
# NOTE(review): type=str with a bool default. loading_phenotype() tests
# `args.num_normalize == True`, so normalisation only runs when the option is omitted;
# any explicitly passed value (even the string 'True') is a str and compares unequal.
parser.add_argument("--num_normalize", type=str, default=True, help='')
parser.add_argument("--confusion_matrix",  nargs='*', default=[], type=str, help='')
parser.add_argument("--filter", nargs="*", default=[], type=str,
                    help='options for filter data by phenotype. usage: --filter abcd_site:10 sex:1')
parser.add_argument("--load", default='', type=str, help='Load model weight that mathces {your_exp_dir}/result/*{load}*')
parser.add_argument("--scratch", default='', type=str, help='Option for learning from scratch')
parser.add_argument("--transfer", default='', type=str, choices=['sex','age','simclr','MAE'],
                    help='Choose pretrained model according to your option')
# Kept as a string because run_experiment() also compares it against 'all'.
parser.add_argument("--unfrozen_layer", default='0', type=str, help='Select the number of layers that would be unfrozen')
parser.add_argument("--init_unfrozen", default='', type=str, help='Initializes unfrozen layers')
# An empty string means "not debugging"; any other value truncates the image list to 100 rows
# in loading_images() and skips the test stage in experiment().
parser.add_argument("--debug", default='', type=str, help='')


_StoreAction(option_strings=['--debug'], dest='debug', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='', metavar=None)

In [24]:
# Build the argument string that stands in for the script's command line.
# `--num_target`, `--num_normalize`, `--dataset`, `--resize` and `--gpus` are hard-coded here;
# everything else comes from the hyperparameter cell above.
com = f'--optim {optimn} --num_target BMI --num_normalize no --dataset ABCD --data_type {data_type} --val_size {val_size} --test_size {test_size} --lr {lr} --resize 128 128 128 --train_batch_size {batch} --val_batch_size {batch} --exp_name {exp_name} --model {model} --epoch {epoch} --epoch_FC {epoch_FC} --gpus 0'

In [25]:
com

'--optim AdamW --num_target BMI --num_normalize no --dataset ABCD --data_type FA_warpped_nii --val_size 0.1 --test_size 0.1 --lr 1e-3 --resize 128 128 128 --train_batch_size 8 --val_batch_size 8 --exp_name test_dmri --model densenet3D121 --epoch 5 --epoch_FC 0 --gpus 0'

In [26]:
# Parse the hand-built argument string exactly as the script would parse sys.argv.
args = parser.parse_args(com.split())
# Both target lists default to [], so they compare equal when neither option was given.
# NOTE(review): two identical non-empty lists would also trip this check, and the message
# spells the options with hyphens while the parser defines --num_target / --cat_target.
if args.cat_target == args.num_target:
    raise ValueError('--num-target or --cat-target should be specified')

print(f"*** Categorical target labels are {args.cat_target} and Numerical target labels are {args.num_target} *** \n")

*** Categorical target labels are [] and Numerical target labels are ['BMI'] *** 



## run_contrastive_learning.py

In [18]:
def CLIreporter(train_loss, train_acc, val_loss, val_acc):
    """Print a per-target table of train/validation loss and score for one epoch.

    Args:
        train_loss: mapping of target name -> training loss; its keys define the rows.
        train_acc: mapping of target name -> training score (R2 or accuracy).
        val_loss: mapping of target name -> validation loss.
        val_acc: mapping of target name -> validation score.

    Targets whose name contains 'contrastive_loss' have no score and show None there.
    """
    def _train_val_pair(train_metrics, val_metrics, key):
        # "train / val", each value rendered with four decimals
        return f'{train_metrics[key]:2.4f} / {val_metrics[key]:2.4f}'

    table = defaultdict(list)
    for name in train_loss:
        has_score = 'contrastive_loss' not in name
        table['Loss (train/val)'].append(_train_val_pair(train_loss, val_loss, name))
        table['R2 or ACC (train/val)'].append(
            _train_val_pair(train_acc, val_acc, name) if has_score else None
        )

    print(pd.DataFrame(table, index=train_loss.keys()))

In [12]:
## ======= load module ======= ##
import os
import glob
import time
import datetime
import random
import hashlib
from copy import deepcopy
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from utils.utils import argument_setting, select_model, CLIreporter, save_exp_result, checkpoint_save, checkpoint_load
from dataloaders.dataloaders import make_dataset
from dataloaders.preprocessing import preprocessing_cat, preprocessing_num
from envs.experiments import train, validate, test
from envs.transfer import setting_transfer

import warnings
warnings.filterwarnings("ignore")

## ========= Helper Functions =============== ##

def seed_all(SEED):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and CUDA) with the same value.

    Also switches cuDNN to deterministic mode.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(SEED)
    torch.backends.cudnn.deterministic = True

def setup_results(args):
    """Create the empty per-target history container for an experiment.

    Returns a dict with the keys 'train_losses', 'train_accs', 'val_losses' and
    'val_accs', each mapped to its own `defaultdict(list)`. `args` is not read.
    """
    history_names = ('train_losses', 'train_accs', 'val_losses', 'val_accs')
    return {name: defaultdict(list) for name in history_names}

    
def set_optimizer(args, net):
    """Build the optimizer named by `args.optim` over the trainable parameters of `net`.

    Args:
        args: needs `optim` and `lr`; `weight_decay` is read for every choice except 'SGD'.
        net: module whose parameters with `requires_grad` set are optimized.

    Returns:
        A torch optimizer ('SGD' with momentum 0.9, 'Adam', 'RAdam' or 'AdamW').

    Raises:
        ValueError: if `args.optim` is none of the four supported names.
    """
    def trainable_params():
        # frozen parameters are excluded
        return (p for p in net.parameters() if p.requires_grad)

    choice = args.optim
    if choice == 'SGD':
        return optim.SGD(params=trainable_params(), lr=args.lr, momentum=0.9)
    if choice == 'Adam':
        return optim.Adam(params=trainable_params(), lr=args.lr, weight_decay=args.weight_decay)
    if choice in ('RAdam', 'AdamW'):
        # both share the same explicit betas / eps / weight-decay settings
        optimizer_cls = optim.RAdam if choice == 'RAdam' else optim.AdamW
        return optimizer_cls(params=trainable_params(), lr=args.lr,
                             betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay)

    raise ValueError('Invalid optimizer choice')
    
    
def set_lr_scheduler(args, optimizer, len_dataloader):
    """Build the learning-rate scheduler selected by `args.scheduler`.

    Recognised values:
        ''            -> no scheduler (returns None)
        'on'          -> ReduceLROnPlateau in 'max' mode (case-sensitive match)
        'cos'         -> CosineAnnealingWarmRestarts (case-insensitive match)
        '...step...'  -> StepLR; 'step_<N>' sets the step size to N, otherwise 80
        'onecycle'    -> OneCycleLR over `args.epoch` total steps (case-insensitive match)

    `len_dataloader` is accepted but not used.

    Raises:
        Exception: for any other value.
    """
    choice = args.scheduler
    if choice == '':
        return None
    if choice == 'on':
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.1, min_lr=1e-6)

    lowered = choice.lower()
    if lowered == 'cos':
        return optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=2, eta_min=0)
    if 'step' in choice:
        pieces = choice.split('_')
        # only the exact two-part form 'step_<N>' overrides the default step size
        step_size = int(pieces[1]) if len(pieces) == 2 else 80
        return optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1)
    if lowered == 'onecycle':
        return optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, total_steps=args.epoch)

    raise Exception("Invalid scheduler option")
    
    
def run_experiment(args, net, partition, result, mode):
    """Run one training stage and checkpoint the best-scoring epoch.

    Args:
        args: parsed options; reads epoch, epoch_FC, unfrozen_layer, transfer and
            early_stopping here, plus whatever the optimizer/scheduler helpers read.
        net: model to train. `net.module` is accessed when transfer learning is
            configured, so a wrapped model (e.g. nn.DataParallel) is expected then.
        partition: dict holding at least a 'train' entry; passed on to train()/validate().
        result: history dict as built by setup_results(); extended in place.
        mode: 'ALL' trains for args.epoch epochs with args.unfrozen_layer unfrozen;
            any other value trains for args.epoch_FC epochs with '0' unfrozen layers.

    Returns:
        (result, checkpoint_dir). checkpoint_dir is whatever checkpoint_save() returned
        for the best epoch, or None when the loop did not run at all.
    """
    epoch_exp = args.epoch if mode == 'ALL' else args.epoch_FC
    num_unfrozen = args.unfrozen_layer if mode == 'ALL' else '0'
    
    if (args.transfer != '') and (args.unfrozen_layer.lower() != 'all'):
        setting_transfer(args, net.module, num_unfrozen = num_unfrozen)
    optimizer = set_optimizer(args, net)
    scheduler = set_lr_scheduler(args, optimizer, len(partition['train']))

    best_val_loss = float('inf')
    best_train_loss = float('inf')
    # Start below every possible score: the summed validation score can be zero
    # (contrastive-only targets) or negative (R2), and the first epoch must still be
    # checkpointed so that `checkpoint_dir` is always bound on return.
    # NOTE(review): with contrastive-only targets the score stays 0, so only the first
    # epoch is ever checkpointed -- consider selecting on validation loss in that case.
    best_val_acc = float('-inf')
    checkpoint_dir = None
    patience = 0

    for epoch in tqdm(range(epoch_exp)):
        ts = time.time()
        net, train_loss, train_acc = train(net, partition, optimizer, args)
        val_loss, val_acc = validate(net, partition, scheduler, args)
        te = time.time()

        ## sorting the results
        train_loss_sum = 0
        val_loss_sum = 0
        val_acc_sum = 0
        
        for target_name in train_loss:
            result['train_losses'][target_name].append(train_loss[target_name])
            result['val_losses'][target_name].append(val_loss[target_name])
            train_loss_sum += train_loss[target_name]
            val_loss_sum += val_loss[target_name]
            # contrastive-loss entries carry no accuracy / R2 value
            if 'contrastive_loss' not in target_name:
                result['train_accs'][target_name].append(train_acc[target_name])
                result['val_accs'][target_name].append(val_acc[target_name])
                val_acc_sum += val_acc[target_name]
                
        ## saving the checkpoint and results   
        if val_acc_sum > best_val_acc:
            best_val_acc = val_acc_sum
            best_val_loss = val_loss_sum
            best_train_loss = train_loss_sum
            patience = 0
            checkpoint_dir = checkpoint_save(net, epoch, args)  
        else:
            patience += 1
            
        ## visualize the result                   
        save_exp_result(vars(args).copy(), result) 
        CLIreporter(train_loss, train_acc, val_loss, val_acc)
        curr_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}. Current learning rate {curr_lr}. Took {te-ts:2.2f} sec")

        ## Early-Stopping (never triggers before epoch index 50)
        if args.early_stopping is not None:
            if patience >= args.early_stopping and epoch >= 50:
                print(f"*** Validation Loss patience reached {args.early_stopping} epochs. Early Stopping Experiment ***")
                break
    
    # losses recorded at the best-scoring epoch, stored under a stage-specific key
    opt = '' if mode == 'ALL' else '_FC'
    result[f'best_val_loss{opt}'] = best_val_loss
    result[f'best_train_loss{opt}'] = best_train_loss
        
    return result, checkpoint_dir


## ========= Experiment =============== ##
def experiment(partition, subject_data, args):
    """Build the model, train it (optionally in two stages) and evaluate it on the test split.

    Args:
        partition: dict of datasets passed through to run_experiment() and test().
        subject_data: phenotype table handed to select_model().
        args: parsed options. `args.lr` is modified in place when an FC-only stage runs.

    Returns:
        (vars(args), result): the option dict and the result dict. Test metrics are only
        added when `args.debug` is the empty string.
    """
    # the transfer options require a matching --resize value
    if args.transfer in ('age', 'MAE'):
        assert 96 in args.resize, "age(MSE/MAE) transfer model's resize should be 96"
    elif args.transfer == 'sex':
        assert 80 in args.resize, "sex transfer model's resize should be 80"

    # selecting a model
    net = select_model(subject_data, args)

    # loading pretrained weights: --transfer alone, --load, or neither
    transfer_only = (args.transfer != "") and (args.load == "")
    if transfer_only:
        print("*** Model setting for transfer learning *** \n")
        net = checkpoint_load(net, args.transfer)
    elif args.load:
        print("*** Model setting for transfer learning & fine tuning *** \n")
        weight_path = glob.glob(f'/scratch/connectome/jubin/result/model/*{args.load}*')[0]
        print(f"Loaded {weight_path[:-4]}")
        net = checkpoint_load(net, weight_path)
    else:
        print("*** Model setting for learning from scratch ***")

    # wrapping in DataParallel: every visible GPU under sbatch, otherwise the ids in --gpus
    if args.sbatch == "True":
        device_ids = list(range(torch.cuda.device_count()))
    elif args.gpus:
        device_ids = args.gpus
    else:
        raise ValueError("GPU DEVICE IDS SHOULD BE ASSIGNED")
    net = nn.DataParallel(net, device_ids = device_ids)
    net.to(f'cuda:{net.device_ids[0]}')

    # container for the per-epoch histories
    result = setup_results(args)

    # training a model
    print("*** Start training a model *** \n")
    if args.epoch_FC != 0:
        print("*** Transfer Learning - Training FC layers *** \n")
        result, _ = run_experiment(args, net, partition, result, 'FC')

        print(f"Adjust learning rate for Training unfrozen layers from {args.lr} to {args.lr*args.lr_adjust}")
        args.lr *= args.lr_adjust
        result['lr_adjusted'] = args.lr

    print("*** Training unfrozen layers *** \n")
    result, best_checkpoint = run_experiment(args, net, partition, result, 'ALL')

    # debug runs stop here without touching the test split
    if args.debug != '':
        return vars(args), result

    # testing a model: reload the best checkpoint and move it back onto the GPU
    print("\n*** Start testing a model *** \n")
    net.to('cpu')
    torch.cuda.empty_cache()

    net = checkpoint_load(net, best_checkpoint)
    if args.sbatch == 'True':
        net.cuda()
    else:
        net.to(f'cuda:{args.gpus[0]}')
    test_acc, confusion_matrices = test(net, partition, args)
    result['test_acc'] = test_acc
    print(f"===== Test result for {args.exp_name} =====") 
    print(test_acc)

    if confusion_matrices != None:
        print("===== Confusion Matrices =====")
        print(confusion_matrices,'\n')
        result['confusion_matrices'] = confusion_matrices

    return vars(args), result
## ==================================== ##

  from .autonotebook import tqdm as notebook_tqdm


In [33]:
import os
import re
import glob
import random

import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from monai.transforms import (AddChannel, Compose, CenterSpatialCrop, Flip, RandAffine,
                              RandFlip, RandRotate90, Resize, ScaleIntensity, ToTensor)
from monai.data import ImageDataset, NibabelReader

from dataloaders.custom_transform import MaskTissue
from dataloaders.custom_dataset import MultiModalImageDataset
from dataloaders.preprocessing import preprocessing_cat, preprocessing_num

ABCD_data_dir = {
    'fmriprep':'/scratch/connectome/3DCNN/data/1.ABCD/1.sMRI_fmriprep/preprocessed_masked/',
    'freesurfer':'/scratch/connectome/3DCNN/data/1.ABCD/2.sMRI_freesurfer/',
    'freesurfer_256':'/scratch/connectome/3DCNN/data/1.ABCD/2.2.sMRI_freesurfer_256/',
    'FA_unwarpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.1.FA_unwarpped_nii/',
    'FA_warpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_warpped_nii/',
    'MD_unwarpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.3.MD_unwarpped_nii/',
    'MD_warpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.4.MD_warpped_nii/',
    'RD_unwarpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.5.RD_unwarpped_nii/',
    'RD_warpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.6.RD_warpped_nii/',
    '5tt_warped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.7.5tt_warped_nii/'
}

ABCD_phenotype_dir = {
    'total':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_phenotype_total.csv',
    'ADHD_case':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_ADHD.csv',
    'suicide_case':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_suicide_case.csv',
    'suicide_control':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_suicide_control.csv'
}

UKB_data_dir = '/scratch/connectome/3DCNN/data/2.UKB/1.sMRI_fs_cropped/'
UKB_phenotype_dir = '/scratch/connectome/3DCNN/data/2.UKB/2.demo_qc/UKB_phenotype.csv'


def case_control_count(labels, dataset_type, args):
    """Print, for each categorical target, how many rows are labelled 1 (CASE) and 0 (CONTROL).

    Args:
        labels: table indexable by target name, whose columns support `value_counts()`.
        dataset_type: split name used in the printed message.
        args: needs `cat_target`; nothing is printed when it is empty.
    """
    for target in (args.cat_target or ()):
        print(type(labels),target)
        counts = labels[target].value_counts()
        n_case, n_control = counts[1], counts[0]
        print(f'In {dataset_type},\t"{target}" contains {n_case} CASE and {n_control} CONTROL')

            
def loading_images(image_dir, args):
    """Collect image file paths per modality and join them on the subject key.

    Args:
        image_dir: either a mapping of modality name -> directory, or a single directory
            string used for every modality (the caller passes a plain string for UKB).
        args: needs `data_type` (list of modalities), `tissue`, `dataset` and `debug`.

    Returns:
        DataFrame with one path column per modality plus the subject-key column (the
        module-level `subjectkey` name). Only subjects present in every requested
        modality are kept. With a non-empty `args.debug` only the first 100 rows
        are returned.
    """
    image_files = pd.DataFrame()
    # a tissue mask needs the 5tt segmentation alongside the requested modalities
    data_types = args.data_type if (args.tissue is None) else args.data_type + ['5tt_warped_nii']
    for brain_modality in data_types:
        # Indexing a plain directory string by modality name would raise TypeError,
        # so a string is used as-is.
        curr_dir = image_dir if isinstance(image_dir, str) else image_dir[brain_modality]
        curr_files = pd.DataFrame({brain_modality:glob.glob(curr_dir+'*[yz]')}) # to get .npy(sMRI) & .nii.gz(dMRI) files
        # subject id = file name up to the first '.', without the 'sub-' prefix
        curr_files[subjectkey] = curr_files[brain_modality].map(lambda x: x.split("/")[-1].split('.')[0].split('sub-')[-1])
        if args.dataset == 'UKB':
            curr_files[subjectkey] = curr_files[subjectkey].map(int)
        curr_files = curr_files.sort_values(by=subjectkey)
        
        if len(image_files) == 0:
            image_files = curr_files
        else:
            image_files = pd.merge(image_files, curr_files, how='inner', on=subjectkey)
            
    if args.debug:
        image_files = image_files[:100]
        
    return image_files


def get_available_subjects(subject_data, args):
    """Keep only the subjects listed in the ADHD case file or the suicide control file.

    NOTE(review): the control list is read from the 'suicide_control' phenotype file --
    confirm that it is the intended control group for the ADHD target.
    `args` is not read.
    """
    key_lists = [
        pd.read_csv(ABCD_phenotype_dir[group])[subjectkey]
        for group in ('ADHD_case', 'suicide_control')
    ]
    allowed_keys = pd.concat(key_lists).reset_index(drop=True)
    is_allowed = subject_data[subjectkey].isin(allowed_keys)
    return subject_data[is_allowed]


def filter_phenotype(subject_data, filters):
    """Keep only the rows matching every '<column>:<value>' filter.

    Each value is converted with `np.float64` and compared for equality against the
    column. With no filters the input object is returned unchanged.
    """
    for spec in filters:
        column, raw_value = spec.split(':')
        wanted = np.float64(raw_value)
        matches = subject_data[column] == wanted
        subject_data = subject_data.loc[matches]

    return subject_data


def loading_phenotype(phenotype_dir, target_list, args):
    """Load the phenotype table and reduce it to the subject key plus the target columns.

    Args:
        phenotype_dir: path of the phenotype CSV.
        target_list: names of the target columns to keep.
        args: needs `filter` (list of '<column>:<value>' strings) and `num_normalize`;
            it is also passed on to the preprocessing helpers.

    Returns:
        DataFrame sorted by the subject key, without rows that have a missing value in
        any kept column, re-indexed from 0 and run through the preprocessing helpers.
    """
    col_list = target_list + [subjectkey]

    ## get subject ID and target variables
    subject_data = pd.read_csv(phenotype_dir)
    if 'Attention.Deficit.Hyperactivity.Disorder.x' in target_list:
        subject_data = get_available_subjects(subject_data, args)
    # Filter while every phenotype column is still present: the filter columns
    # (e.g. abcd_site) are usually not targets, and restricting the columns first
    # made such filters fail with KeyError.
    subject_data = filter_phenotype(subject_data, args.filter)
    subject_data = subject_data.loc[:,col_list]
    subject_data = subject_data.sort_values(by=subjectkey)
    subject_data = subject_data.dropna(axis = 0)
    subject_data = subject_data.reset_index(drop=True)

    ### preprocessing categorical variables and numerical variables
    subject_data = preprocessing_cat(subject_data, args)
    # NOTE(review): --num_normalize is parsed with type=str, so this is only true for the
    # untouched default (True); any explicitly passed value disables normalisation.
    if args.num_normalize == True:
        subject_data = preprocessing_num(subject_data, args)
    
    return subject_data


def make_balanced_testset(il, num_test, args):
    """Reorder the table so that its tail is a test set balanced over the first categorical target.

    The first `num_test // 2` rows labelled 0 and the first `num_test - num_test // 2`
    rows labelled 1 form the test part; all other rows labelled 0 or 1 form the rest.
    Both parts are sorted by the subject key and concatenated as rest + test.

    Rows whose label is neither 0 nor 1 are dropped. If one label has fewer rows than
    requested, the test part is shorter than `num_test`.

    Args:
        il: DataFrame holding labels and image paths.
        num_test: requested number of test rows.
        args: needs `cat_target`; only its first entry is used.

    Returns:
        DataFrame with a fresh 0..n-1 index, test rows last.
    """
    target = args.cat_target[0]
    n_label0 = num_test//2
    n_label1 = num_test - n_label0

    label0_rows = il[il[target]==0]
    label1_rows = il[il[target]==1]

    # Positional slicing gives the same pieces as np.split(frame, (n,)) without routing a
    # DataFrame through numpy (which relies on the deprecated DataFrame.swapaxes).
    test = pd.concat((label0_rows.iloc[:n_label0], label1_rows.iloc[:n_label1]))
    rest = pd.concat((label0_rows.iloc[n_label0:], label1_rows.iloc[n_label1:]))
    
    test = test.sort_values(by=subjectkey)
    rest = rest.sort_values(by=subjectkey)
    
    imageFiles_labels = pd.concat((rest,test)).reset_index(drop=True)
    
    return imageFiles_labels


# defining train,val, test set splitting function
def _split_rows(frame, split_points):
    """Slice `frame` row-wise at `split_points`, with the index semantics of np.split."""
    bounds = [0, *split_points, len(frame)]
    return [frame.iloc[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]


def partition_dataset(imageFiles_labels, target_list, args):
    """Split the merged image/label table into train, validation and test datasets.

    Rows are split by position, without shuffling: the last `num_test` rows form the test
    set. Without `args.cv` the rows before it are cut into train and validation by
    `args.val_size`; with `args.cv` they are cut into five folds and fold `args.cv`
    becomes the validation set.

    Args:
        imageFiles_labels: DataFrame with the target columns and one path column per modality.
        target_list: names of the label columns.
        args: needs test_size, val_size, cv, cat_target, data_type, resize, transform,
            tissue and augmentation.

    Returns:
        dict with 'train', 'val' and 'test' MultiModalImageDataset objects.
    """
    # NOTE(review): random shuffling is disabled, so the split depends entirely on the row
    # order of `imageFiles_labels`.

    ## Dataset split    
    num_total = len(imageFiles_labels)
    num_test = int(num_total*args.test_size)
    num_val = int(num_total*args.val_size) if args.cv is None else int((num_total-num_test)/5)
    num_train = num_total - (num_val+num_test)
    
    if len(args.cat_target) > 0:
        imageFiles_labels = make_balanced_testset(imageFiles_labels, num_test, args)
    images = imageFiles_labels[args.data_type]
    labels = imageFiles_labels[target_list].to_dict('records')
    
    ## split dataset by 5-fold cv or given split size
    # Image frames are sliced positionally instead of with np.split, which relies on the
    # deprecated DataFrame.swapaxes; the label records are a plain list and still go
    # through numpy.
    if args.cv is None:
        images_train, images_val, images_test = _split_rows(images, [num_train, num_train+num_val])
        labels_train, labels_val, labels_test = np.split(labels, [num_train, num_train+num_val])
    else:
        split_points = [num_val, 2*num_val, 3*num_val, 4*num_val, num_total-num_test]
        images_total, labels_total = _split_rows(images, split_points), np.split(labels, split_points)
        # last chunk is the test set, chunk `cv` is validation, the remaining four are training
        images_test, labels_test = images_total.pop(), labels_total.pop()
        images_val, labels_val = images_total.pop(args.cv-1), labels_total.pop(args.cv-1)
        images_train, labels_train = np.concatenate(images_total), np.concatenate(labels_total)
        num_train, num_val = images_train.shape[0], images_val.shape[0]
        
    print(f"Total subjects={num_total}, train={num_train}, val={num_val}, test={num_test}")

    ## Define transform function
    resize = tuple(args.resize)
    
    # NOTE(review): AddChannel is deprecated in recent MONAI releases -- confirm before
    # upgrading MONAI.
    default_transforms = [ScaleIntensity(), AddChannel(), Resize(resize), ToTensor()] 
    if 'crop' in args.transform:
        # NOTE(review): the crop is inserted ahead of AddChannel, i.e. it runs on an image
        # without a channel axis -- confirm that this is intended.
        default_transforms.insert(0, CenterSpatialCrop(192))
        
    if args.tissue:
        dMRI_transform = [MaskTissue(imageFiles_labels['5tt_warped_nii'], args.tissue)]
        dMRI_transform += default_transforms
        
    # --augmentation accepts several values, so each one is checked independently
    # (an elif here silently ignored 'flip' whenever 'shift' was also requested).
    aug_transforms = []
    if 'shift' in args.augmentation:
        aug_transforms.append(RandAffine(prob=0.1,translate_range=(0,2),padding_mode='zeros'))
    if 'flip' in args.augmentation:
        aug_transforms.append(RandFlip(prob=0.1, spatial_axis=0))
    
    # one transform pipeline per modality; augmentation is applied to the training set only
    train_transforms, val_transforms, test_transforms = [], [], []
    for brain_modality in args.data_type:
        curr_transform = dMRI_transform if args.tissue else default_transforms
        train_transforms.append(Compose(curr_transform + aug_transforms))
        val_transforms.append(Compose(curr_transform))
        test_transforms.append(Compose(curr_transform))
    
    ## make splitted dataset
    train_set = MultiModalImageDataset(image_files=images_train, labels=labels_train, transform=train_transforms)
    val_set = MultiModalImageDataset(image_files=images_val, labels=labels_val, transform=val_transforms)
    test_set = MultiModalImageDataset(image_files=images_test, labels=labels_test, transform=test_transforms)

    partition = {}
    partition['train'] = train_set
    partition['val'] = val_set
    partition['test'] = test_set

#     case_control_count(labels_train, 'train', args)
#     case_control_count(labels_val, 'validation', args)
#     case_control_count(labels_test, 'test', args)

    return partition


# Dataset-dependent settings. `subjectkey` is read as a module-level name by the loader
# functions above; the `global` statement has no effect at module level.
global subjectkey
subjectkey = 'subjectkey' if args.dataset == 'ABCD' else 'eid'
# ABCD uses a {modality: directory} mapping, UKB a single directory string
image_dir = ABCD_data_dir if args.dataset == 'ABCD' else UKB_data_dir
phenotype_dir = ABCD_phenotype_dir['total'] if args.dataset == 'ABCD' else UKB_phenotype_dir
target_list = args.cat_target + args.num_target

image_files = loading_images(image_dir, args)
subject_data = loading_phenotype(phenotype_dir, target_list, args)

# combining image files & labels (inner join: only subjects with both labels and images remain)
imageFiles_labels = pd.merge(subject_data, image_files, how='inner', on=subjectkey)
print(imageFiles_labels.columns)
# partitioning dataset and preprocessing (change the range of categorical variables and standardize numerical variables)
partition = partition_dataset(imageFiles_labels, target_list, args)
print("*** Making a dataset is completed *** \n")

Index(['BMI', 'subjectkey', 'FA_warpped_nii'], dtype='object')
Total subjects=8440, train=6752, val=844, test=844
*** Making a dataset is completed *** 



In [37]:
# Pull the BMI value out of every test-set label record for a quick sanity check.
test_labels = partition['test'].labels
tl = pd.Series([record['BMI'] for record in test_labels])

In [39]:
tl.describe()

count    844.000000
mean      18.498561
std        3.772194
min       10.815385
25%       15.951632
50%       17.489579
75%       20.117835
max       39.961663
dtype: float64

In [47]:
# Recompute the hold-out slice by hand (same arithmetic as partition_dataset without CV)
# to cross-check the labels stored in partition['test'].
num_total = len(imageFiles_labels)
num_test = int(num_total*args.test_size)
num_val = int(num_total*args.val_size)
num_train = num_total - (num_val+num_test)
# Positional slice of the last chunk; equivalent to the third piece of
# np.split(imageFiles_labels, [num_train, num_train+num_val]) without passing a
# DataFrame through numpy (which relies on the deprecated DataFrame.swapaxes).
test_data = imageFiles_labels.iloc[num_train+num_val:]

In [48]:
test_data['BMI'].describe()

count    844.000000
mean      18.498561
std        3.772194
min       10.815385
25%       15.951632
50%       17.489579
75%       20.117835
max       39.961663
Name: BMI, dtype: float64

In [89]:
type(np.split(test_data['BMI'], [1,2])[1])

pandas.core.series.Series

In [50]:
test_data['subjectkey']

7596    NDARINVRR7J1YX2
7597    NDARINVRRAW1J1K
7598    NDARINVRRE6MVYZ
7599    NDARINVRREA9RBW
7600    NDARINVRRFZW203
             ...       
8435    NDARINVV6GUKET3
8436    NDARINVV6KFJX12
8437    NDARINVV6KTVCZH
8438    NDARINVV6MZ4VB1
8439    NDARINVV6NAXTR2
Name: subjectkey, Length: 844, dtype: object

In [66]:
test_data['FA_warpped_nii']

7596    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7597    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7598    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7599    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7600    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
                              ...                        
8435    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
8436    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
8437    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
8438    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
8439    /scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
Name: FA_warpped_nii, Length: 844, dtype: object

In [85]:
test_data

Unnamed: 0,BMI,subjectkey,FA_warpped_nii
7596,21.335905,NDARINVRR7J1YX2,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7597,23.746742,NDARINVRRAW1J1K,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7598,17.203959,NDARINVRRE6MVYZ,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7599,17.268423,NDARINVRREA9RBW,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
7600,19.354870,NDARINVRRFZW203,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
...,...,...,...
8435,22.053727,NDARINVV6GUKET3,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
8436,15.200292,NDARINVV6KFJX12,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
8437,19.433988,NDARINVV6KTVCZH,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...
8438,19.658733,NDARINVV6MZ4VB1,/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_w...


In [84]:
!pwd

/scratch/connectome/jubin/ABCD-3DCNN/STEP_4_Multimodal-Learning/MultiChannel-Learning/contrastive_learning


In [72]:
!ls /scratch/connectome/jubin/sandbox/

 fa_test			       script1.o	 t4.sh
 list_test.sh			       set_name2.sh	 t5.sh
 MRI_registration		       set_name.sh	 t6.sh
 mrtrix				       srun_catcher.sh	 test.sh
 nametest.sh			       submit_test.sh	 train_part.py
'num_worker test.ipynb'		       symfol		 Untitled1.ipynb
 result				       t2.sh		 Untitled.ipynb
 run_3DCNN_hard_parameter_sharing.py   t3.sh		 VarNet1.json


In [80]:
# Copy every FA image of the hand-computed test slice into the sandbox folder,
# spawning one `cp` shell command per file.
# NOTE(review): {path} is interpolated into the shell command unquoted, so a path containing
# spaces or shell metacharacters would break the command.
for path in test_data['FA_warpped_nii']:
    !cp {path} /scratch/connectome/jubin/sandbox/fa_test

In [None]:
# Run Experiment
print(f"*** Experiment {args.exp_name} Start ***")
setting, result = experiment(partition, subject_data, deepcopy(args))
print("===== Experiment Setting Report =====")
print(args)

# Save result
if args.debug == '':
    save_exp_result(setting, result)
print("*** Experiment Done ***\n")
## ====================================== ##

In [107]:
subject_data.BMI.describe()

count    11519.000000
mean        18.810519
std          4.224589
min          5.266641
25%         15.938893
50%         17.648091
75%         20.625592
max         54.992926
Name: BMI, dtype: float64

In [93]:
6043/(6043+5488)

0.5240655623970167

In [66]:
!ls '/scratch/connectome/3DCNN/data/1.ABCD/2.2.sMRI_freesurfer_256/'

sub-NDARINV003RTV85.brain.nii.gz  sub-NDARINVFYWWBYZH.brain.nii.gz
sub-NDARINV007W6H7B.brain.nii.gz  sub-NDARINVFZ16A6M3.brain.nii.gz
sub-NDARINV00BD7VDC.brain.nii.gz  sub-NDARINVFZ1W3N0D.brain.nii.gz
sub-NDARINV00CY2MDM.brain.nii.gz  sub-NDARINVFZ4F6Y38.brain.nii.gz
sub-NDARINV00HEV6HB.brain.nii.gz  sub-NDARINVFZ7AMNW3.brain.nii.gz
sub-NDARINV00J52GPG.brain.nii.gz  sub-NDARINVFZ7B9KJ8.brain.nii.gz
sub-NDARINV00LH735Y.brain.nii.gz  sub-NDARINVFZ7T2G2T.brain.nii.gz
sub-NDARINV00LJVZK2.brain.nii.gz  sub-NDARINVFZ8BBNJK.brain.nii.gz
sub-NDARINV00NPMHND.brain.nii.gz  sub-NDARINVFZ97ZA0Z.brain.nii.gz
sub-NDARINV00R4TXET.brain.nii.gz  sub-NDARINVFZ9KMK9H.brain.nii.gz
sub-NDARINV00U4FTRU.brain.nii.gz  sub-NDARINVFZC8ZGMR.brain.nii.gz
sub-NDARINV00UMK5VC.brain.nii.gz  sub-NDARINVFZCKB9W4.brain.nii.gz
sub-NDARINV00X2TBWJ.brain.nii.gz  sub-NDARINVFZDYMM0B.brain.nii.gz
sub-NDARINV010ZM3H9.brain.nii.gz  sub-NDARINVFZFYY9KA.brain.nii.gz
sub-NDARINV014RTM1V.brain.nii.gz  sub-NDARINVFZP

sub-NDARINV243VF1J7.brain.nii.gz  sub-NDARINVJ716H2Y0.brain.nii.gz
sub-NDARINV24534L6F.brain.nii.gz  sub-NDARINVJ72AUZZW.brain.nii.gz
sub-NDARINV2484CB0H.brain.nii.gz  sub-NDARINVJ74CGJ37.brain.nii.gz
sub-NDARINV248BF2KE.brain.nii.gz  sub-NDARINVJ79PN2P1.brain.nii.gz
sub-NDARINV249JM0NY.brain.nii.gz  sub-NDARINVJ7AJBW84.brain.nii.gz
sub-NDARINV249R6TFP.brain.nii.gz  sub-NDARINVJ7BJFRDD.brain.nii.gz
sub-NDARINV24BT0Y26.brain.nii.gz  sub-NDARINVJ7GL9B46.brain.nii.gz
sub-NDARINV24D005FC.brain.nii.gz  sub-NDARINVJ7HK4WZJ.brain.nii.gz
sub-NDARINV24K5YJBG.brain.nii.gz  sub-NDARINVJ7JTARJZ.brain.nii.gz
sub-NDARINV24KDRDL2.brain.nii.gz  sub-NDARINVJ7L6EW0J.brain.nii.gz
sub-NDARINV24LWV4C5.brain.nii.gz  sub-NDARINVJ7V2VTL2.brain.nii.gz
sub-NDARINV24MJ2521.brain.nii.gz  sub-NDARINVJ87UYTGU.brain.nii.gz
sub-NDARINV24V8FLZ3.brain.nii.gz  sub-NDARINVJ897YYVE.brain.nii.gz
sub-NDARINV24V96Z76.brain.nii.gz  sub-NDARINVJ8A3AM9C.brain.nii.gz
sub-NDARINV24W171KY.brain.nii.gz  sub-NDARINVJ8A