In [1]:
import monai

In [2]:
monai.__version__

'0.8.1'

## Change directory & Hyperparameter setting

In [3]:
import os
print(os.getcwd())
os.chdir('/scratch/connectome/jubin/ABCD-3DCNN/STEP_4_Multimodal-Learning/MultiChannel-Learning/contrastive_learning/')

/scratch/connectome/jubin/ABCD-3DCNN/STEP_4_Multimodal-Learning/MultiChannel-Learning/contrastive_learning/codes


In [4]:
print(os.listdir())

['codes', 'README.md', 'envs', '__pycache__', 'run_contrastive_learning.py', 'models', 'result', 'test.py', 'dataloaders', 'utils', '.ipynb_checkpoints']


In [5]:
target="BMI"
data_type="freesurfer FA_warpped_nii"
model="densenet3D121"
epoch_FC="0"
epoch="5"
optim="AdamW"
scheduler="--scheduler on" # step_80"
batch="8"
val_size="0.1"
test_size="0.1"
lr="1e-3"
lr_adjust="--lr_adjust 1"
unfrozen="all"
exp_name='multichannel_test'


In [6]:
import os
import json
import argparse 

import pandas as pd
import torch

import models.simple3d as simple3d 
import models.vgg3d as vgg3d 
import models.resnet3d as resnet3d 
import models.densenet3d as densenet3d 
import models.sfcn as sfcn

parser = argparse.ArgumentParser()

# Options for model setting
parser.add_argument("--model", type=str, required=True, help='Select model. e.g. densenet3D121, sfcn.',
                    choices=['simple3D', 'sfcn', 'vgg3D11', 'vgg3D13', 'vgg3D16', 'vgg3D19',
                             'resnet3D50', 'resnet3D101', 'resnet3D152',
                             'densenet3D121', 'densenet3D169', 'densenet201', 'densenet264'])
parser.add_argument("--in_channels", default=1, type=int, help='')

# Options for dataset and data type, split ratio, CV, resize, augmentation
parser.add_argument("--dataset", type=str, choices=['UKB','ABCD'], required=True, help='Selelct dataset')
parser.add_argument("--data_type", nargs='+', type=str, help='Select data type(sMRI, dMRI)',
                    choices=['fmriprep', 'freesurfer', 'freesurfer_256', 'FA_unwarpped_nii', 'FA_warpped_nii',
                             'MD_unwarpped_nii', 'MD_warpped_nii', 'RD_unwarpped_nii', 'RD_warpped_nii'])
parser.add_argument("--tissue", default=None, type=str, help='Select tissue mask(Cortical grey matter, \
                    Sub-cortical grey matter, White matter, CSF, Pathological tissue)',
                    choices=['cgm', 'scgm', 'wm', 'csf', 'pt'])
parser.add_argument("--metric", default='cos', type=str, help='')
parser.add_argument("--val_size", default=0.1, type=float, help='')
parser.add_argument("--test_size", default=0.1, type=float, help='')
parser.add_argument("--cv", default=None, type=int, choices=[1,2,3,4,5], help="option for 5-fold CV. 1~5.")
parser.add_argument("--resize", nargs="*", default=(96, 96, 96), type=int, help='')
parser.add_argument("--transform", nargs="*", default=[], type=str, choices=['crop'],
                    help="option for additional transform - [crop] are available")
parser.add_argument("--augmentation", nargs="*", default=[], type=str, choices=['shift','flip'],
                    help="Data augmentation - [shift, flip] are available")

# Hyperparameters for model training
parser.add_argument("--lr", default=0.01, type=float, help='')
parser.add_argument("--lr_adjust", default=0.01, type=float, help='')
parser.add_argument("--epoch", type=int, required=True, help='')
parser.add_argument("--epoch_FC", type=int, default=0, help='Option for training only FC layer')
parser.add_argument("--optim", default='Adam', type=str, choices=['Adam','SGD','RAdam','AdamW'], help='')
parser.add_argument("--weight_decay", default=0.001, type=float, help='')
parser.add_argument("--scheduler", default='', type=str, help='') 
parser.add_argument("--early_stopping", default=None, type=int, help='')
parser.add_argument('--accumulation_steps', default=None, type=int, required=False)
parser.add_argument("--train_batch_size", default=16, type=int, help='')
parser.add_argument("--val_batch_size", default=16, type=int, help='')
parser.add_argument("--test_batch_size", default=1, type=int, help='')

# Options for experiment setting
parser.add_argument("--exp_name", type=str, required=True, help='')
parser.add_argument("--gpus", nargs='+', type=int, help='')
parser.add_argument("--sbatch", type=str, choices=['True', 'False'])
parser.add_argument("--cat_target", nargs='+', default=[], type=str, help='')
parser.add_argument("--num_target", nargs='+', default=[], type=str, help='')
parser.add_argument("--confusion_matrix",  nargs='*', default=[], type=str, help='')
parser.add_argument("--filter", nargs="*", default=[], type=str,
                    help='options for filter data by phenotype. usage: --filter abcd_site:10 sex:1')
parser.add_argument("--load", default='', type=str, help='Load model weight that mathces {your_exp_dir}/result/*{load}*')
parser.add_argument("--scratch", default='', type=str, help='Option for learning from scratch')
parser.add_argument("--transfer", default='', type=str, choices=['sex','age','simclr','MAE'],
                    help='Choose pretrained model according to your option')
parser.add_argument("--unfrozen_layer", default='0', type=str, help='Select the number of layers that would be unfrozen')
parser.add_argument("--init_unfrozen", default='', type=str, help='Initializes unfrozen layers')
parser.add_argument("--debug", default='', type=str, help='')


_StoreAction(option_strings=['--debug'], dest='debug', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='', metavar=None)

In [20]:
com = f'--num_target {target} --dataset ABCD --data_type {data_type} --val_size {val_size} --test_size {test_size} --lr {lr} --optim {optim} --resize 80 80 80  --train_batch_size {batch} --val_batch_size {batch} --exp_name {exp_name} --in_channels 2 --model {model} --epoch {epoch} --epoch_FC {epoch_FC} --unfrozen_layer {unfrozen} --gpus 0 --debug 1'

In [21]:
args = parser.parse_args(com.split())
if args.cat_target == args.num_target:
    raise ValueError('--num-target or --cat-target should be specified')

print(f"*** Categorical target labels are {args.cat_target} and Numerical target labels are {args.num_target} *** \n")

*** Categorical target labels are [] and Numerical target labels are ['BMI'] *** 



## run_constrastive_learning.py

## check MultiChannel DataLoader

In [36]:
from typing import Any, Callable, Optional, Sequence, Union
from itertools import repeat

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from monai.config import DtypeLike
from monai.data.image_reader import ImageReader
from monai.transforms import LoadImage, Randomizable, apply_transform
from monai.utils import MAX_SEED, get_seed

class MultiChannelImageDataset(Dataset, Randomizable):
    """
    Loads image/segmentation pairs of files from the given filename lists. Transformations can be specified
    for the image and segmentation arrays separately.
    The difference between this dataset and `ArrayDataset` is that this dataset can apply transform chain to images
    and segs and return both the images and metadata, and no need to specify transform to load images from files.
    For more information, please see the image_dataset demo in the MONAI tutorial repo,
    https://github.com/Project-MONAI/tutorials/blob/master/modules/image_dataset.ipynb
    """
    def __init__(
        self,
        image_files: pd.DataFrame, # modified
        seg_files: Optional[Sequence[str]] = None,
        labels: pd.DataFrame = None,
        transform: Sequence[Optional[Callable]] = None, # modified
        seg_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        image_only: bool = True,
        transform_with_metadata: bool = False,
        dtype: DtypeLike = np.float32,
        reader: Optional[Union[ImageReader, str]] = None,
        *args,
        **kwargs,
    ) -> None:
        if seg_files is not None and len(image_files) != len(seg_files):
            raise ValueError(
                "Must have same the number of segmentation as image files: "
                f"images={len(image_files)}, segmentations={len(seg_files)}."
            )

        self.image_files = image_files
        self.seg_files = seg_files
        self.labels = labels
        self.transform = transform
        self.seg_transform = seg_transform
        self.label_transform = label_transform
        if image_only and transform_with_metadata:
            raise ValueError("transform_with_metadata=True requires image_only=False.")
        self.image_only = image_only
        self.transform_with_metadata = transform_with_metadata
        self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs)
        self.set_random_state(seed=get_seed())
        self._seed = 0  # transform synchronization seed


    def __len__(self) -> int:
        return len(self.image_files) # modified

    def randomize(self, data: Optional[Any] = None) -> None:
        self._seed = self.R.randint(MAX_SEED, dtype="uint32")

    def __getitem__(self, index: int):
        self.randomize()
        meta_data, seg_meta_data, seg, label = None, None, None, None

        # load data and optionally meta
        if self.image_only:
            img = list(map(self.loader, self.image_files.iloc[index])) # modified
            if self.seg_files is not None:
                seg = self.loader(self.seg_files[index])
        else:
            img, meta_data = list(map(self.loader, self.image_files.iloc[index])) # modified
            if self.seg_files is not None:
                seg, seg_meta_data = self.loader(self.seg_files[index])

        # apply the transforms
        if self.transform is not None:
            if isinstance(self.transform, Randomizable):
                self.transform.set_random_state(seed=self._seed)

            if self.transform_with_metadata:
                img, meta_data = list(map(apply_transform, self.transform, img,
                                     repeat(False), repeat(True))) # modified
            else:
                img = list(map(apply_transform, self.transform, img, repeat(False))) # modified

        if self.seg_files is not None and self.seg_transform is not None:
            if isinstance(self.seg_transform, Randomizable):
                self.seg_transform.set_random_state(seed=self._seed)

            if self.transform_with_metadata:
                seg, seg_meta_data = apply_transform(
                    self.seg_transform, (seg, seg_meta_data), map_items=False, unpack_items=True
                )
            else:
                seg = apply_transform(self.seg_transform, seg, map_items=False)

        if self.labels is not None:
            label = self.labels.iloc[index].to_dict()
            print(label, type(label))
            if self.label_transform is not None:
                label = apply_transform([self.label_transform]*len(label), label, map_items=False)  # type: ignore
          
        # construct outputs
        img = [torch.cat(img, dim=0)] # modified for making multi-channel input
        data = [img]
        if seg is not None:
            data.append(seg)
        if label is not None:
            data.append(label)
        if not self.image_only and meta_data is not None:
            data.append(meta_data)
        if not self.image_only and seg_meta_data is not None:
            data.append(seg_meta_data)
#         if len(data) == 1: # commented because multimodal model takes list of img as an input
#             return data[0]
        # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
        return tuple(data)

In [37]:
import os
import re
import glob
import random

import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from monai.transforms import (AddChannel, Compose, CenterSpatialCrop, Flip, RandAffine,
                              RandFlip, RandRotate90, Resize, ScaleIntensity, ToTensor)
from monai.data import ImageDataset, NibabelReader

from dataloaders.preprocessing import preprocessing_cat, preprocessing_num
from dataloaders.custom_dataset import MultiModalImageDataset

def case_control_count(labels, dataset_type, args):
    if args.cat_target:
        for cat_target in args.cat_target:
            curr_cnt = labels[cat_target].value_counts()
            print(f'In {dataset_type},\t"{cat_target}" contains {curr_cnt[1]} CASE and {curr_cnt[0]} CONTROL')
            
#             target_labels = []

#             for label in labels:
#                 target_labels.append(label[cat_target])
            
#             n_control = target_labels.count(0)
#             n_case = target_labels.count(1) + target_labels.count(2) # revising - count also 2 for UKB data
#             print(f'In {dataset_type} dataset, {cat_target} contains {n_case} CASE and {n_control} CONTROL')
            

## ========= Define directories of data ========= ##
# revising
ABCD_data_dir = {
    'fmriprep':'/scratch/connectome/3DCNN/data/1.ABCD/1.sMRI_fmriprep/preprocessed_masked/',
    'freesurfer':'/scratch/connectome/3DCNN/data/1.ABCD/2.sMRI_freesurfer/',
    'FA_unwarpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.1.FA_unwarpped_nii/',
    'FA_warpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.2.FA_warpped_nii/',
    'MD_unwarpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.3.MD_unwarpped_nii/',
    'MD_warpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.4.MD_warpped_nii/',
    'RD_unwarpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.5.RD_unwarpped_nii/',
    'RD_warpped_nii':'/scratch/connectome/3DCNN/data/1.ABCD/3.6.RD_warpped_nii/'
}

ABCD_phenotype_dir = {
    'total':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_phenotype_total.csv',
    'ADHD_case':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_ADHD.csv',
    'suicide_case':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_suicide_case.csv',
    'suicide_control':'/scratch/connectome/3DCNN/data/1.ABCD/4.demo_qc/ABCD_suicide_control.csv'
}

UKB_data_dir = '/scratch/connectome/3DCNN/data/2.UKB/1.sMRI_fs_cropped/'
UKB_phenotype_dir = '/scratch/connectome/3DCNN/data/2.UKB/2.demo_qc/UKB_phenotype.csv'


## ========= Define helper functions ========= ##

def loading_images(image_dir, args):
    image_files = pd.DataFrame()
    for brain_modality in args.data_type:
        curr_dir = image_dir[brain_modality]
        curr_files = pd.DataFrame({brain_modality:glob.glob(curr_dir+'*[yz]')}) # to get .npy(sMRI) & .nii.gz(dMRI) files
        curr_files[subjectkey] = curr_files[brain_modality].map(lambda x: x.split("/")[-1].split('.')[0])
        if args.dataset == 'UKB':
            curr_files[subjectkey] = curr_files[subjectkey].map(int)
        curr_files.sort_values(by=subjectkey, inplace=True)
        
        if len(image_files) == 0:
            image_files = curr_files
        else:
            image_files = pd.merge(image_files, curr_files, how='inner', on=subjectkey)
            
    if args.debug:
        image_files = image_files[:100]
        
    return image_files


def get_available_subjects(subject_data, args):
    case  = pd.read_csv(ABCD_phenotype_dir['ADHD_case'])[subjectkey]
    control = pd.read_csv(ABCD_phenotype_dir['suicide_control'])[subjectkey]
    filtered_subjectkey = pd.concat([case,control]).reset_index(drop=True)
    subject_data = subject_data[subject_data[subjectkey].isin(filtered_subjectkey)]
    
    return subject_data


def filter_phenotype(subject_data, filters):
    for fil in filters:
        fil_name, fil_option = fil.split(':')
        fil_option = np.float64(fil_option)
        subject_data = subject_data[subject_data[fil_name] == fil_option]
        
    return subject_data


def loading_phenotype(phenotype_dir, target_list, args):
    col_list = target_list + [subjectkey]

    ## get subject ID and target variables
    subject_data = pd.read_csv(phenotype_dir)
    subject_data = subject_data.loc[:,col_list]
    if 'Attention.Deficit.Hyperactivity.Disorder.x' in target_list:
        subject_data = get_available_subjects(subject_data, args)
    subject_data = filter_phenotype(subject_data, args.filter)
    subject_data = subject_data.sort_values(by=subjectkey)
    subject_data = subject_data.dropna(axis = 0)
    subject_data = subject_data.reset_index(drop=True)
    
    if (args.transfer == 'MAE' and args.dataset == 'ABCD') or args.scratch == 'MAE':
        return subject_data

    ### preprocessing categorical variables and numerical variables
    subject_data = preprocessing_cat(subject_data, args)
    subject_data = preprocessing_num(subject_data, args)
    
    return subject_data


def is_multichannel(args):
    if len(args.data_type) > 1 and args.in_channels > 1:
        assert len(args.data_type) == args.in_channels, \
        "the number of data type and in_channel should be same"
        return True
    else:
        return False


# defining train,val, test set splitting function
def partition_dataset(imageFiles_labels, target_list, args):
    ## Make lists of images & lables
    images = imageFiles_labels[args.data_type]
    labels = imageFiles_labels[target_list]
    
    ## Define transform function
    resize = tuple(args.resize)
    
    default_transforms = [ScaleIntensity(), AddChannel(), Resize(resize), ToTensor()] 
    dMRI_transform = [CenterSpatialCrop(192)] + default_transforms
    aug_transforms = []
    
    if 'shift' in args.augmentation:
        aug_transforms.append(RandAffine(prob=0.1,translate_range=(0,2),padding_mode='zeros'))
    elif 'flip' in args.augmentation:
        aug_transforms.append(RandFlip(prob=0.1, spatial_axis=0))
    
    train_transforms, val_transforms, test_transforms = [], [], []
    for brain_modality in args.data_type:
        if re.search('FA|MD|RD',brain_modality) != None:
            train_transforms.append(Compose(dMRI_transform+aug_transforms))
            val_transforms.append(Compose(dMRI_transform))
            test_transforms.append(Compose(dMRI_transform))
        else:
            train_transforms.append(Compose(default_transforms+aug_transforms))
            val_transforms.append(Compose(default_transforms))
            test_transforms.append(Compose(default_transforms))

    ## Dataset split
    num_total = len(images)
    num_test = int(num_total*args.test_size)
    num_val = int(num_total*args.val_size) if args.cv == None else int((num_total-num_test)/5)
    num_train = num_total - (num_val+num_test)
    
    ## split dataset by 5-fold cv or given split size
    if args.cv == None:
        images_train, images_val, images_test = np.split(images, [num_train, num_train+num_val]) # revising
        labels_train, labels_val, labels_test = np.split(labels, [num_train, num_train+num_val])
    else:
        split_points = [num_val, 2*num_val, 3*num_val, 4*num_val, num_total-num_test]
        images_total, labels_total = np.split(images, split_points), np.split(labels, split_points)
        images_test, labels_test = images_total.pop(), labels_total.pop()
        images_val, labels_val = images_total.pop(args.cv-1), labels_total.pop(args.cv-1)
        images_train, labels_train = np.concatenate(images_total), np.concatenate(labels_total)
        num_train, num_val = images_train.shape[0], images_val.shape[0]
        
    print(f"Total subjects={num_total}, train={num_train}, val={num_val}, test={num_test}")

    ## make splitted dataset
    ImgDataset = MultiChannelImageDataset if is_multichannel(args) else MultiModalImageDataset
    train_set = ImgDataset(image_files=images_train, labels=labels_train, transform=train_transforms)
    val_set = ImgDataset(image_files=images_val, labels=labels_val, transform=val_transforms)
    test_set = ImgDataset(image_files=images_test, labels=labels_test, transform=test_transforms)

    partition = {}
    partition['train'] = train_set
    partition['val'] = val_set
    partition['test'] = test_set

    case_control_count(labels_train, 'train', args)
    case_control_count(labels_val, 'validation', args)
    case_control_count(labels_test, 'test', args)

    return partition
## ====================================== ##


## ========= Main function that makes partition of dataset  ========= ##
def make_dataset(args): # revising
    global subjectkey
    subjectkey = 'subjectkey' if args.dataset == 'ABCD' else 'eid'
    image_dir = ABCD_data_dir if args.dataset == 'ABCD' else UKB_data_dir
    phenotype_dir = ABCD_phenotype_dir['total'] if args.dataset == 'ABCD' else UKB_phenotype_dir
    target_list = args.cat_target + args.num_target
    
    image_files = loading_images(image_dir, args)
    subject_data = loading_phenotype(phenotype_dir, target_list, args)

    # combining image files & labels
    imageFiles_labels = pd.merge(subject_data, image_files, how='inner', on=subjectkey)

    # partitioning dataset and preprocessing (change the range of categorical variables and standardize numerical variables)
    partition = partition_dataset(imageFiles_labels, target_list, args)
    print("*** Making a dataset is completed *** \n")
    
    return partition, subject_data

In [38]:
args.seed = 1234
def seed_all(SEED):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
seed_all(args.seed)

In [39]:
partition, subject_data = make_dataset(args)

Total subjects=100, train=80, val=10, test=10
*** Making a dataset is completed *** 



In [40]:
args.seed = 1234
def seed_worker(worker_id):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

g = torch.Generator()
g.manual_seed(args.seed)

'''GradScaler is for calculating gradient with float 16 type'''
scaler = torch.cuda.amp.GradScaler()

trainloader = torch.utils.data.DataLoader(partition['train'],
                                          batch_size=args.train_batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=4,
                                          worker_init_fn=seed_worker,
                                          generator=g)

for i, data in enumerate(trainloader,0):
    image, targets = data
    break
    

{'BMI': 0.39505973568445957}{'BMI': -0.8890379428475166}  <class 'dict'><class 'dict'>

{'BMI': -0.27170970277589945} <class 'dict'>
{'BMI': 1.976151886916753} <class 'dict'>
{'BMI': -0.5590939889144871} <class 'dict'>
{'BMI': -0.13660421773270348} <class 'dict'>
{'BMI': -0.8005013298236575}{'BMI': 3.6355618748948992}  <class 'dict'><class 'dict'>

{'BMI': -0.23484457349586813} <class 'dict'>
{'BMI': -0.5873757939371386} <class 'dict'>
{'BMI': -1.1321592148468074} {'BMI': 1.4013343743940818} <class 'dict'>
<class 'dict'>
{'BMI': -1.2855276955854025} <class 'dict'>{'BMI': 0.31687386032116416}
 <class 'dict'>
{'BMI': 0.07012670240998889}{'BMI': 0.7713097110742154}  <class 'dict'><class 'dict'>

{'BMI': -0.6322268768308481}{'BMI': -0.8610403708983724}  <class 'dict'><class 'dict'>

{'BMI': -1.2316411385074673}{'BMI': -0.6151137082622601}  <class 'dict'><class 'dict'>

{'BMI': 0.059282100701255024} <class 'dict'>
{'BMI': -0.8289438397768333} <class 'dict'>
{'BMI': -0.7699440862291403} <cla

In [42]:
for i in image:
    print(i.shape)

torch.Size([8, 2, 80, 80, 80])


In [19]:
for i in image:
    print(i.shape)

torch.Size([8, 1, 80, 80, 80])
torch.Size([8, 1, 80, 80, 80])


## check Model

In [38]:
## =================================== ##
## ======= DenseNet ======= ##
## =================================== ##

# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# utils
import collections

## ========= DenseNet Model ========= #
#(ref) explanation - https://wingnim.tistory.com/39
#(ref) densenet3d - https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py
#(ref) pytorch - https://pytorch.org/vision/0.8/_modules/torchvision/models/densenet.html

class _DenseLayer(nn.Module):

    def __init__(self, num_input_features, growth_rate, bn_size):
        super().__init__()

        ## DenseNet Composite function: BN -> relu -> 3x3 conv
        # 1
        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False))

        # 2
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False))

        #self.memory_efficient = memory_efficient
    
    def bn_function(self, inputs) -> Tensor:
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    def forward(self, x):
        if isinstance(x, Tensor):
            prev_features = [x]
        else:
            prev_features = x

        bottleneck_output = self.bn_function(prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        return new_features  ## **

class _DenseBlock(nn.ModuleDict):
    # receive and concatenate the outputs of all previous blocks as inputs 
    # growth rate? the number of channel of feature map in each layer
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate):
        super().__init__()
        
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate,
                                growth_rate, bn_size)
            self.add_module("denselayer%d" % (i + 1), layer)
    
    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)

class _Transition(nn.Sequential):
    ## convolution + pooling between block
    # in paper: bach normalization -> 1x1 conv layer -> 2x2 average pooling layer

    def __init__(self, num_input_features, num_output_features):
        super().__init__()
        self.add_module('norm', nn.BatchNorm3d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))

class DenseNet(nn.Module):
    """Densenet-BC model class
    Args:
        growth_rate (int) - how many filters to add each layer (k in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, subject_data, args,
                 n_input_channels=1,conv1_t_size=7,conv1_t_stride=2,no_max_pool=False,
                 growth_rate=32,block_config=(6, 12, 24, 16),num_init_features=64,
                 bn_size=4,drop_rate=0,num_classes=1000):

        super(DenseNet, self).__init__()
        self.subject_data = subject_data
        self.brain_dtypes = args.data_type
        self.cat_target = args.cat_target
        self.num_target = args.num_target 
        self.target = args.cat_target + args.num_target
        
        self.n_input_channels = n_input_channels
        self.conv1_t_size = conv1_t_size
        self.conv1_t_stride = conv1_t_stride
        self.growth_rate = growth_rate
        self.block_config = block_config
        self.num_init_features = num_init_features
        self.bn_size = bn_size
        self.drop_rate = drop_rate

              
        self.feature_extractors = self._make_feature_extractors()

        # Linear layer
        self.FClayers = self._make_fclayers()

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
    
    def _make_feature_extractors(self):
        feature_extractors = []
        for brain_dtype in self.brain_dtypes:
            # First convolution  
            feature_extractor = nn.Sequential(collections.OrderedDict([
                ('conv0',nn.Conv3d(self.n_input_channels,
                                   self.num_init_features,
                                   kernel_size=(self.conv1_t_size, 7, 7),
                                   stride=(self.conv1_t_stride, 2, 2),
                                   padding=(self.conv1_t_size // 2, 3, 3),
                                   bias=False)),
                 ('norm0', nn.BatchNorm3d(self.num_init_features)),
                 ('relu0', nn.ReLU(inplace=True)),
                 ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1))
            ]))
            # Each denseblock
            num_features = self.num_init_features
            for i, num_layers in enumerate(self.block_config):
                block = _DenseBlock(num_layers=num_layers,
                                    num_input_features=num_features,
                                    bn_size=self.bn_size,
                                    growth_rate=self.growth_rate)
                feature_extractor.add_module(f'denseblock{i+1}', block)
                num_features = num_features + num_layers*self.growth_rate

                if i != len(self.block_config) - 1:
                    trans = _Transition(num_input_features = num_features,
                                        num_output_features = num_features // 2)
                    feature_extractor.add_module(f'transition{i+1}', trans)
                    num_features = num_features // 2

            # Final batch norm
            feature_extractor.add_module('norm5', nn.BatchNorm3d(num_features))
            
            feature_extractors.append(feature_extractor)
            
        self.num_features = num_features
        
        return nn.ModuleList(feature_extractors)
    
    def _make_fclayers(self):
        FClayer = []
        
        for cat_label in self.cat_target:
            self.out_dim = len(self.subject_data[cat_label].value_counts())                        
            FClayer.append(nn.Sequential(nn.Linear(self.num_features, self.out_dim)))

        for num_label in self.num_target:
            FClayer.append(nn.Sequential(nn.Linear(self.num_features, 1)))

        return nn.ModuleList(FClayer)


    def forward(self, images):            
        outs = []
        results = {'embeddings':[]}
        for i, x in enumerate(images): # feed each brain modality into its own CNN
            features = self.feature_extractors[i](x)
            print(features.shape)
            results['embeddings'].append(torch.flatten(features, 1))
            print(torch.flatten(features, 1).shape)
            out = F.adaptive_avg_pool3d(features, output_size=(1, 1, 1))
            print(out.shape)
            out = F.relu(out, inplace=True)
            out = torch.flatten(out, 1)
            print(out.shape)
            outs.append(out)
            
        out = torch.cat(outs,1)
        print(out.shape)
        
        for i in range(len(self.FClayers)):
            results[self.target[i]] = self.FClayers[i](out)
            
        return results

def generate_model(model_depth, subject_data, args, **kwargs):
    assert model_depth in [121, 169, 201, 264]

    if model_depth == 121:
        model = DenseNet(subject_data, args,
                         num_init_features=64,
                         growth_rate=32,
                         block_config=(6, 12, 24, 16),
                         **kwargs)
    elif model_depth == 169:
        model = DenseNet(subject_data, args,
                         num_init_features=64,
                         growth_rate=32,
                         block_config=(6, 12, 32, 32),
                         **kwargs)
    elif model_depth == 201:
        model = DenseNet(subject_data, args,
                         num_init_features=64,
                         growth_rate=32,
                         block_config=(6, 12, 48, 32),
                         **kwargs)
    elif model_depth == 264:
        model = DenseNet(subject_data, args,
                         num_init_features=64,
                         growth_rate=32,
                         block_config=(6, 12, 64, 48),
                         **kwargs)
    return model

def densenet3D121(subject_data, args):
    model = generate_model(121, subject_data, args)
    return model

def densenet3D169(subject_data, args):
    model = generate_model(169, subject_data, args)
    return model

def densenet3D201(subject_data, args):
    model = generate_model(201, subject_data, args)
    return model

def densenet3D264(subject_data, args):
    model = generate_model(264, subject_data, args)
    return model


In [64]:
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

class SFCN(nn.Module):
    def __init__(self, subject_data, args, channel_number=[32, 64, 128, 256, 256, 64], output_dim=40, dropout=True):
        super(SFCN, self).__init__()
        # Setting experiment related variables
        self.subject_data = subject_data
        self.cat_target = args.cat_target
        self.num_target = args.num_target 
        self.target = args.cat_target + args.num_target
        
        # Setting model related variables
        self.n_layer = len(channel_number)
        self.channel_number = channel_number
        self.last_feature = channel_number[-1]
        self.output_dim = output_dim
        self.dropout = dropout
        
        self.brain_dtypes = args.data_type
        
        # make feature extractor
        self.feature_extractors = self._make_feature_extractors()
        
        # make classifier part
        self.FClayers = self._make_fclayers()
        
#         avg_shape = max(set(args.resize))//(2**len(channel_number))
#         self.classifier = self._make_classifier(avg_shape)
        
        # initialize trainable weights
        self._init_weights()

        
    @staticmethod
    def conv_layer(in_channel, out_channel, maxpool=True, kernel_size=3, padding=0, maxpool_stride=2):
        if maxpool is True:
            layer = nn.Sequential(
                nn.Conv3d(in_channel, out_channel, padding=padding, kernel_size=kernel_size),
                nn.BatchNorm3d(out_channel),
                nn.MaxPool3d(2, stride=maxpool_stride),
                nn.ReLU(),
            )
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channel, out_channel, padding=padding, kernel_size=kernel_size),
                nn.BatchNorm3d(out_channel),
                nn.ReLU()
            )
        return layer
    
    
    def _make_feature_extractors(self):
        feature_extractors = []
        
        for brain_dtype in self.brain_dtypes:
            feature_extractor = nn.Sequential()

            for i in range(self.n_layer):
                in_channel = 1 if i == 0 else self.channel_number[i-1]
                out_channel = self.channel_number[i]

                curr_kernel_size = 3 if i < self.n_layer-1 else 1
                curr_padding = 1 if i < self.n_layer-1 else 0

                feature_extractor.add_module('conv_%d' % i,
                                             self.conv_layer(
                                                 in_channel, out_channel,
                                                 maxpool=True, kernel_size=curr_kernel_size,
                                                 padding=curr_padding))
                
            feature_extractors.append(feature_extractor)
                
        return nn.ModuleList(feature_extractors)
    
    
    def _make_fclayers(self):
        FClayer = []
        
        for cat_label in self.cat_target:
            self.out_dim = 2                       
            FClayer.append(nn.Sequential(nn.Linear(self.last_feature, self.out_dim)))

        for num_label in self.num_target:
            FClayer.append(nn.Sequential(nn.Linear(self.last_feature, 1)))

        return nn.ModuleList(FClayer)
    
       
    def make_classifier(self, avg_shape):
        self.classifier = nn.Sequential()
        if avg_shape >1:
            self.classifier.add_module('average_pool', nn.AvgPool3d(avg_shape))
        if self.dropout is True:
            self.classifier.add_module('dropout', nn.Dropout(0.5))
        i = self.n_layer
        in_channel = self.channel_number[-1]
        out_channel = output_dim
        self.classifier.add_module('conv_%d' % i,
                                   nn.Conv3d(in_channel, out_channel, padding=0, kernel_size=1))

        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
        
    def forward(self, images):
#         out = list()
#         x_f = self.features(x)
#         x = self.classifier(x_f)
#         x = F.log_softmax(x, dim=1)
#         out.append(x)
#         return out

        outs = []
        results = {'embeddings':[]}
        for i, x in enumerate(images): # feed each brain modality into its own CNN
            features = self.feature_extractors[i](x)
            results['embeddings'].append(torch.flatten(features, 1))
            out = F.adaptive_avg_pool3d(features, output_size=(1, 1, 1))
            out = F.relu(out, inplace=True)
            out = torch.flatten(out, 1)
            outs.append(out)
            
        out = torch.cat(outs,1)
        
        for i in range(len(self.FClayers)):
            results[self.target[i]] = self.FClayers[i](out)
            
        return results
    


In [66]:
net2 = SFCN(subject_data, args).cuda()

In [67]:
output2 = net(image)

In [72]:
targets['Attention.Deficit.Hyperactivity.Disorder.x'].shape

torch.Size([8])

In [69]:
output2['Attention.Deficit.Hyperactivity.Disorder.x'].shape

torch.Size([16, 2])

In [39]:
net = densenet3D121(subject_data, args)

In [26]:
net.target

['Attention.Deficit.Hyperactivity.Disorder.x']

In [51]:
output = net(image)

In [52]:
output

{'embeddings': [tensor([[-0.4822,  0.5738,  0.0774,  ...,  0.1468,  1.3184,  1.4301],
          [-0.4388,  1.1981, -0.1515,  ..., -0.2443,  1.4693,  0.9829],
          [-0.4721,  0.3771,  0.7308,  ..., -0.8011,  1.9918,  0.7602],
          ...,
          [ 1.3875,  1.2068,  1.1519,  ..., -1.1429,  1.9906, -0.8282],
          [-0.2427,  0.4352,  0.4627,  ..., -0.2086,  1.5322,  0.6452],
          [-0.0283, -0.0331, -0.0127,  ..., -0.6111,  1.3100,  1.2850]],
         device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
  tensor([[ 1.0457,  0.8199,  1.2897,  ..., -0.5242,  0.3318, -0.5330],
          [ 1.1884,  0.8106,  0.4973,  ..., -0.8140, -0.3504, -1.9877],
          [ 0.6546,  0.5781,  0.9055,  ..., -0.7825,  0.6448, -0.4094],
          ...,
          [ 1.0259,  0.5164,  0.6972,  ..., -0.0338, -0.7488, -2.3160],
          [ 1.3154,  0.9218,  1.1458,  ..., -1.3427, -0.1512, -0.5285],
          [ 1.0396,  0.4987,  1.4770,  ..., -0.4364,  0.1519, -0.5632]],
         device='cuda:0', grad

In [35]:
net

DenseNet(
  (feature_extractors): ModuleList(
    (0): Sequential(
      (conv0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
      (norm0): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (norm2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv3d(128, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNor

In [28]:
crit_ssim = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')

In [29]:
embed_1, embed_2 = output['embeddings']
label_positive = torch.ones(embed_1.shape[0])
label_negative = -torch.ones(embed_1.shape[0])

In [471]:
embed_1.clone()

tensor([[0.4899, 1.2098, 0.0000, 0.0887, 0.0000, 0.7583, 0.8378, 0.2195, 0.9658,
         0.8562, 0.0000, 0.0000, 0.0000, 0.8016, 0.0000, 0.0000, 1.0345, 1.3318,
         0.7532, 1.7366, 0.2829, 1.0754, 0.0000, 0.0000, 0.0000, 0.1661, 0.0447,
         0.0154, 0.0000, 0.4641, 0.3588, 0.0000, 1.1272, 0.5417, 0.7224, 0.3109,
         1.2832, 1.3940, 0.0000, 1.6496, 0.0000, 1.0792, 0.4733, 0.0000, 0.0000,
         0.0000, 0.1128, 0.1073, 1.3811, 0.0000, 1.1237, 1.4666, 0.0000, 0.4366,
         0.3673, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1392, 0.0000, 0.6729,
         0.0000],
        [0.4929, 1.2085, 0.0000, 0.0856, 0.0000, 0.7530, 0.8353, 0.2161, 0.9649,
         0.8555, 0.0000, 0.0000, 0.0000, 0.7971, 0.0000, 0.0000, 1.0334, 1.3255,
         0.7558, 1.7363, 0.2802, 1.0777, 0.0000, 0.0000, 0.0000, 0.1664, 0.0456,
         0.0177, 0.0000, 0.4630, 0.3590, 0.0000, 1.1293, 0.5392, 0.7255, 0.3060,
         1.2805, 1.3925, 0.0000, 1.6627, 0.0000, 1.0818, 0.4718, 0.0000, 0.0000,
         0

In [233]:
%timeit torch.concat((embed_2[-1:],embed_2[:-1]),0)

8.01 µs ± 78.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [229]:
torch.concat((embed_2[-1:],embed_2[:-1]),0)

tensor([[1.4107, 0.9416, 1.1886, 1.5822, 1.4380, 3.0026, 2.6295, 1.3752, 1.0127,
         1.9757, 1.2542, 1.7054, 0.9270, 1.2276, 0.9402, 1.3779, 1.6086, 1.7499,
         1.0743, 1.1287, 0.9094, 1.1572, 1.4577, 1.0137, 2.0756, 1.2001, 1.2710,
         1.9049, 1.8862, 1.3772, 1.4206, 0.9205, 1.2601, 1.4614, 1.5008, 1.3417,
         1.7736, 1.3703, 1.7096, 2.2104, 0.7348, 1.3079, 1.6511, 1.2537, 0.6134,
         1.6634, 1.1548, 1.0102, 2.3488, 0.9393, 1.0405, 1.8895, 1.8403, 1.1865,
         1.1741, 2.4666, 1.6783, 1.5272, 1.0589, 1.3336, 1.4851, 1.1742, 1.5161,
         0.4171],
        [1.1632, 2.8595, 1.8268, 1.5520, 1.4468, 1.9557, 2.4680, 2.3295, 1.0214,
         1.0486, 1.5386, 1.5192, 1.6651, 1.7470, 1.3922, 0.9683, 1.5688, 1.4232,
         1.0641, 0.9972, 1.4803, 1.4450, 1.4623, 2.1452, 1.5907, 1.6945, 2.3313,
         2.2439, 1.4132, 0.7928, 1.7071, 2.1073, 2.0168, 1.0734, 1.8236, 1.3022,
         0.8966, 0.7699, 0.7538, 1.3051, 1.7875, 1.3180, 1.1171, 1.1247, 1.4811,
         1

In [30]:
%timeit torch.roll(embed_2,1)

KeyboardInterrupt: 

In [31]:
torch.roll(embed_2,1,0)

tensor([[1.6237, 0.6142, 0.7278, 0.7405, 1.8391, 1.5216, 2.6620, 1.4709, 1.4728,
         1.1043, 0.8118, 1.6947, 1.9731, 1.1350, 0.8003, 1.3716, 2.1817, 0.8897,
         2.3871, 1.5077, 1.5912, 1.5758, 1.1941, 1.4748, 1.3659, 1.3840, 0.8293,
         0.6092, 1.3223, 1.3602, 1.7204, 1.4076, 1.1182, 1.2279, 1.6797, 1.4163,
         1.3967, 1.4233, 0.9765, 1.5810, 1.5547, 1.7480, 1.1365, 0.8409, 1.9343,
         1.1086, 0.8552, 2.0802, 1.4789, 1.1731, 0.9449, 0.8527, 1.4200, 1.1805,
         1.6369, 1.9909, 1.2904, 0.4850, 0.4799, 1.0004, 0.4229, 0.6690, 1.2777,
         1.4117],
        [1.2231, 1.0159, 1.9400, 1.2623, 1.5852, 1.3358, 1.1257, 2.6471, 2.1629,
         1.7245, 1.2286, 1.7435, 0.9345, 1.4318, 0.6566, 0.6916, 1.3193, 0.9516,
         1.0414, 2.2415, 1.0737, 2.0053, 1.3512, 0.8876, 1.0069, 1.2072, 1.8058,
         1.8162, 1.3111, 1.4944, 1.2883, 1.8648, 1.4432, 0.7757, 1.4014, 1.5737,
         1.9508, 1.2649, 2.0394, 1.3127, 1.0670, 1.7606, 1.0997, 1.1277, 1.3197,
         1

In [32]:
embed_2

tensor([[1.2231, 1.0159, 1.9400, 1.2623, 1.5852, 1.3358, 1.1257, 2.6471, 2.1629,
         1.7245, 1.2286, 1.7435, 0.9345, 1.4318, 0.6566, 0.6916, 1.3193, 0.9516,
         1.0414, 2.2415, 1.0737, 2.0053, 1.3512, 0.8876, 1.0069, 1.2072, 1.8058,
         1.8162, 1.3111, 1.4944, 1.2883, 1.8648, 1.4432, 0.7757, 1.4014, 1.5737,
         1.9508, 1.2649, 2.0394, 1.3127, 1.0670, 1.7606, 1.0997, 1.1277, 1.3197,
         1.4341, 1.4526, 1.9129, 1.4406, 1.2972, 0.8266, 0.9674, 1.8484, 1.9979,
         1.2216, 0.7840, 1.3228, 2.7306, 1.5392, 0.9966, 1.5077, 1.7736, 1.3773,
         1.4719],
        [1.2315, 1.0099, 2.4215, 1.2168, 1.2165, 1.0611, 0.9396, 0.9516, 1.7010,
         0.8731, 1.9273, 1.9629, 1.3425, 1.9894, 1.6837, 0.4715, 1.6907, 1.4503,
         1.5968, 1.7265, 1.4584, 0.9110, 1.8944, 0.6242, 0.8255, 1.6196, 0.9413,
         1.4027, 1.4903, 1.1916, 0.8441, 0.4398, 0.3489, 2.0114, 2.5329, 0.9612,
         1.1694, 1.1906, 1.0584, 1.5518, 1.2698, 1.5895, 0.7993, 0.7200, 2.0581,
         1

In [445]:
%timeit torch.roll(embed_2,1,0)

7.44 µs ± 26 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [446]:
%timeit embed_2.roll(1,0)

7.39 µs ± 73 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [430]:
rolled = embed_2.roll(1,0)
for i, embed in enumerate(embed_1):
    print(crit_ssim(embed.unsqueeze(0), rolled[i].unsqueeze(0), -torch.ones(1).cuda()))

print('hi',crit_ssim(embed_1, rolled, label_negative.cuda()))

tensor(0.4919, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4926, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4922, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4923, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4922, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4918, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4922, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4926, device='cuda:0', grad_fn=<MeanBackward0>)
hi tensor(0.4922, device='cuda:0', grad_fn=<MeanBackward0>)


## experiements.py (train, valid, test part) & loss_functions.py

In [431]:
torch.roll?

In [12]:
import torch
import torch.nn as nn

def constrastive_loss(output):
    # << should be implemented later >> case where len(args.data_type) >= 3 
    embedding_1, embedding_2 = output['embeddings']
    embedding_2_rolled = embedding_2.roll(1,0)
    
    label_positive = torch.ones(embedding_1.shape[0]).to(f'cuda:{net.device_ids[0]}')
    label_negative = -label_positive

    criterion_ssim = nn.CosineEmbeddingLoss(margin=0.0, reduction='mean').to(f'cuda:{net.device_ids[0]}')
    loss_positive = criterion_ssim(embedding_1, embedding_2, label_positive)
    loss_negative = criterion_ssim(embedding_1, embedding_2_rolled, label_negative)
    
    output.pop('embeddings')
    
    return loss_positive, loss_negative


def calc_acc(tmp_output, label, args, tmp_loss=None):
    _, predicted = torch.max(tmp_output.data, 1)
    correct = (predicted == label).sum().item()
    total = label.size(0)
    
    return (100 * correct / total)


def calc_R2(tmp_output, y_true, args, tmp_loss=None):
    if ('MAE' in [args.transfer, args.scratch]) or tmp_loss == None:
        criterion = nn.MSELoss()
        tmp_loss = criterion(tmp_output.float(), y_true.float().unsqueeze(1))

    y_var = torch.var(y_true, unbiased=False)
    r_square = 1 - (tmp_loss / y_var)
                    
    return r_square.item()


def calculating_loss_acc(targets, output, loss_dict, acc_dict, net, args):
    '''define calculating loss and accuracy function used during training and validation step'''
    # << should be implemented later >> how to set ssim_weight?
    cat_weight = (len(args.cat_target)/(len(args.cat_target)+len(args.num_target)))
    num_weight = 1 - cat_weight
    loss = 0.0
    
    # calculate constrastive_loss
    if len(args.data_type) > 1:
        loss_positive, loss_negative = constrastive_loss(output)
        loss += (loss_positive + loss_negative)/2
        loss_dict['contrastive_loss_positive'].append(loss_positive.item())
        loss_dict['contrastive_loss_negative'].append(loss_negative.item())
        
    # calculate target_losses & accuracies
    for curr_target in output:
        tmp_output = output[curr_target]
        label = targets[curr_target].to(f'cuda:{net.device_ids[0]}')
        label = label.repeat(2)
        tmp_label = label.long() if curr_target in args.cat_target else label.float().unsqueeze(1)
        weight = cat_weight if curr_target in args.cat_target else num_weight
        
        if curr_target in args.cat_target:
            criterion = nn.CrossEntropyLoss().to(f'cuda:{net.device_ids[0]}')
        elif curr_target == 'age' and 'MAE' in [args.transfer, args.scratch]:
            criterion = nn.L1Loss().to(f'cuda:{net.device_ids[0]}')
        else:
            criterion = nn.MSELoss().to(f'cuda:{net.device_ids[0]}')
        
        # Loss
        tmp_loss = criterion(tmp_output.float(), tmp_label)
        loss += tmp_loss * weight
        loss_dict[curr_target].append(tmp_loss.item())
        
        # Acc
        acc_func = calc_acc if curr_target in args.cat_target else calc_R2
        acc = acc_func(tmp_output, label, args, tmp_loss)
        acc_dict[curr_target].append(acc) 
            
    return loss

In [239]:
from collections import defaultdict
print(defaultdict.fromkeys([1,2,3],[]))
print(defaultdict.fromkeys(['a','b'],0))

defaultdict(None, {1: [], 2: [], 3: []})
defaultdict(None, {'a': 0, 'b': 0})


In [270]:
a = defaultdict(int)
a['hi']=0
a['2']=1
a

defaultdict(int, {'hi': 0, '2': 1})

In [271]:
def change_dict(mydict):
    for key in mydict:
        mydict[key] += 1
change_dict(a)
a

defaultdict(int, {'hi': 1, '2': 2})

In [294]:
a = defaultdict(list)

In [297]:
for i in np.arange(10):
    a['age'].append(i+10)
    a['hi'].append(i+14)

In [324]:
map(lambda x: np.mean(a[x]),a)

<map at 0x7f7eb8d78ed0>

In [46]:
import random
from collections import defaultdict

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix

# from envs.loss_functions import calculating_loss_acc, calc_acc, calc_R2

### ========= Train,Validate, and Test ========= ###
'''The process of calcuating loss and accuracy metrics is as follows.
   1) sequentially calculate loss and accuracy metrics of target labels with for loop.
   2) store the result information with dictionary type.
   3) return the dictionary, which form as {'cat_target':value, 'num_target:value}
   This process is intended to easily deal with loss values from each target labels.'''


'''All of the loss from predictions are summated and this loss value is used for backpropagation.'''

# define training step
def train(net,partition,optimizer,args):
    def seed_worker(worker_id):
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    g = torch.Generator()
    g.manual_seed(args.seed)
    
    '''GradScaler is for calculating gradient with float 16 type'''
    scaler = torch.cuda.amp.GradScaler()

    trainloader = torch.utils.data.DataLoader(partition['train'],
                                              batch_size=args.train_batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4,
                                              worker_init_fn=seed_worker,
                                              generator=g)

    net.train()
    
    train_loss = defaultdict(list)
    train_acc =  defaultdict(list)

    for i, data in enumerate(trainloader,0):
        optimizer.zero_grad()
        image, targets = data
        image = list(map(lambda x: x.to(f'cuda:{net.device_ids[0]}'), image))
        output = net(image)
        loss = calculating_loss_acc(targets, output, train_loss, train_acc, net, args)
        
        # multi-head model sum all the loss from predicting each target variable and back propagation
        scaler.scale(loss).backward() 
        scaler.step(optimizer)
        scaler.update()

    # calculating total loss and acc of separate mini-batch
    for target in train_loss:
        train_loss[target] = np.mean(train_loss[target])
        train_acc[target] = np.mean(train_acc[target])

    return net, train_loss, train_acc


# define validation step
def validate(net,partition,scheduler,args):
    def seed_worker(worker_id):
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    g = torch.Generator()
    g.manual_seed(args.seed)
    
    valloader = torch.utils.data.DataLoader(partition['val'],
                                            batch_size=args.val_batch_size,
                                            shuffle=False,
                                            pin_memory=True,
                                            num_workers=4,
                                            worker_init_fn=seed_worker,
                                            generator=g)

    net.eval()

    val_loss = defaultdict(list)
    val_acc = defaultdict(list)

    with torch.no_grad():
        for i, data in enumerate(valloader,0):
            image, targets = data
            image = list(map(lambda x: x.to(f'cuda:{net.device_ids[0]}'), image))
            output = net(image)
            loss = calculating_loss_acc(targets, output, val_loss, val_acc, net, args)

    for target in val_loss:
        val_loss[target] = np.mean(val_loss[target])
        val_acc[target] = np.mean(val_acc[target])

    # learning rate scheduler
    if scheduler:
        if args.scheduler == 'on':
            scheduler.step(sum(val_acc.values()))
        else:
            scheduler.step()

    return val_loss, val_acc


def calc_confusion_matrix(confusion_matrices, curr_target, output, y_true):
    _, predicted = torch.max(output.data,1)
    tn, fp, fn, tp = confusion_matrix(y_true.numpy(), predicted.numpy()).ravel()
    confusion_matrices[curr_target]['True Positive'] = int(tp)
    confusion_matrices[curr_target]['True Negative'] = int(tn)
    confusion_matrices[curr_target]['False Positive'] = int(fp)
    confusion_matrices[curr_target]['False Negative'] = int(fn)     

    
# define test step
def test(net,partition,args):
    def seed_worker(worker_id):
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    g = torch.Generator()
    g.manual_seed(args.seed)
    
    testloader = torch.utils.data.DataLoader(partition['test'],
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             worker_init_fn=seed_worker,
                                             generator=g)

    net.eval()
    if hasattr(net, 'module'):
        device = net.device_ids[0]
    else: 
        device = 'cuda:0' if args.sbatch =='True' else f'cuda:{args.gpus[0]}'
    
    outputs = defaultdict(list)
    y_true = defaultdict(list)
    test_acc = defaultdict(list)
    confusion_matrices = defaultdict(defaultdict(dict))
    
    with torch.no_grad():
        for i, data in enumerate(tqdm(testloader),0):
            image, targets = data
            image = list(map(lambda x: x.to(f'cuda:{net.device_ids[0]}'), image))
            output = net(image)
            
            for curr_target in output:
                outputs[curr_target].append(output[curr_target].cpu())
                y_true[curr_target].append(targets[curr_target].cpu())
    
    # caculating ACC and R2 at once  
    for curr_target in output:
        if curr_target == 'embeddings':
            continue
            
        outputs[curr_target] = torch.cat(outputs[cat_target])
        acc_func = calc_acc if cat_target in args.cat_target else calc_R2
        curr_acc = acc_func(outputs[curr_target], y_true[curr_target], args, None)
        test_acc[curr_target].append(curr_acc)
        
        if curr_target in args.confusion_matrix:
            calc_confusion_matrix(confusion_matrices, curr_target,
                                  outputs[curr_target], y_true[curr_target])

    return test_acc, confusion_matrices

## ============================================ ##


In [35]:
defaultdict(defaultdict)

defaultdict(collections.defaultdict, {})

## final test

In [42]:
def seed_worker(worker_id):
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

g = torch.Generator()
g.manual_seed(args.seed)
    
trainloader = torch.utils.data.DataLoader(partition['train'],
                                          batch_size=args.train_batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=4,
                                          worker_init_fn=seed_worker,
                                          generator=g)

net.train()

train_loss = defaultdict(list)
train_acc =  defaultdict(list)

for i, data in enumerate(trainloader,0):
#     optimizer.zero_grad()
    image, targets = data
    image = list(map(lambda x: x.to(f'cuda:{net.device_ids[0]}'), image))
    output = net(image)
#     loss = calculating_loss_acc(targets, output, train_loss, train_acc, net, args)
    break

In [43]:
output

{'embeddings': [tensor([[1.8290, 1.3430, 1.8028, 1.7541, 1.8268, 0.8958, 1.9593, 1.1711, 1.4457,
           1.7346, 1.0857, 1.3100, 1.5850, 1.7867, 0.9789, 1.3299, 2.4745, 1.3464,
           1.7834, 2.1281, 0.4086, 1.0139, 0.9655, 2.2807, 0.9923, 1.6390, 0.9387,
           1.1539, 1.5020, 1.6226, 0.9755, 0.5529, 1.0618, 1.4284, 1.0075, 0.9535,
           1.0771, 1.8091, 1.3573, 1.6884, 1.9233, 2.0752, 1.5433, 1.3085, 1.8264,
           1.8865, 1.4412, 1.3588, 1.9732, 2.3270, 1.6048, 1.5629, 0.5728, 2.6758,
           0.7274, 0.9523, 1.3422, 1.5376, 1.6402, 1.3779, 1.6199, 2.4969, 2.0641,
           1.4194],
          [2.6745, 2.0473, 2.2602, 1.4448, 1.1304, 1.1188, 2.7850, 0.6655, 2.9663,
           1.4921, 1.1243, 2.8253, 1.4241, 2.0195, 0.9776, 2.4197, 1.4592, 1.3429,
           2.0401, 0.9693, 0.8517, 1.5039, 1.1708, 1.6983, 0.9352, 2.6030, 1.6148,
           0.9847, 1.8217, 0.6398, 1.5258, 0.7483, 1.7654, 0.7673, 1.3659, 1.6577,
           0.6337, 2.5614, 1.2102, 1.2793, 1.5980, 2.