# Multi-site Domain Adaptation for rs-fMRI Classification

Recent applications of deep learning (DL) techniques have shown promising results in a range of predictive and classification-based analyses of medical imagery, especially **resting-state fMRI** (rs-fMRI). In practice, supervised DL algorithms tend to perform better when they are trained on large datasets.

However, large rs-fMRI datasets are difficult to acquire from a single site, so one database often pools data collected at multiple sites. This increases the amount of data available to a DL model. However, differences in acquisition protocols and subject demographics across sites introduce a bias known as **batch effects**, which can limit how well the model learns.


**Multi-site domain adaptation (DA) models** aim to extract site-invariant features by mapping the neuroimaging data from different sites onto a shared space. This can retain more of the useful information while reducing the impact of batch effects. To assess the viability of multi-site DA models, the public `ABIDE` and `ADHD200` datasets were used for classification tasks such as control vs. patient and male vs. female. The DA models `DANN`, `MDAN`, `DARN`, `MDMN` and `MSDA` were compared against baseline DL models.
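The `MSDA` model in the main loop is configured with `configs['moments'] = 5`. This suggests moment matching in the spirit of M3SDA, where site features are pulled together by penalizing differences between their low-order moments. The sketch below illustrates that idea on plain NumPy arrays. It is not the implementation in `models.py`. Two choices in it are assumptions: using mean-centred moments for k >= 2, and averaging over all site pairs.

```python
import numpy as np
from itertools import combinations


def moment_distance(domains, num_moments=5):
    """Moment-matching distance between the feature distributions of several sites.

    domains: list of (n_i, d) arrays, one per site (sources and target).
    For each order k = 1..num_moments, compute a per-dimension moment for
    every site (the mean for k = 1, the k-th central moment for k >= 2) and
    add the mean pairwise L2 distance between the sites' moment vectors.
    """
    if len(domains) < 2:
        raise ValueError("need at least two domains")
    pairs = list(combinations(range(len(domains)), 2))
    total = 0.0
    for k in range(1, num_moments + 1):
        moments = []
        for x in domains:
            mean = x.mean(axis=0)
            moments.append(mean if k == 1 else ((x - mean) ** k).mean(axis=0))
        total += sum(np.linalg.norm(moments[a] - moments[b]) for a, b in pairs) / len(pairs)
    return float(total)


# Usage: a shifted site (a simulated batch effect) ends up farther from site_a
rng = np.random.default_rng(0)
site_a = rng.normal(0.0, 1.0, size=(1000, 10))
site_b = rng.normal(0.0, 1.0, size=(1000, 10))
site_c = rng.normal(3.0, 1.0, size=(1000, 10))
print(moment_distance([site_a, site_b]), moment_distance([site_a, site_c]))
```

A DA model would minimize a term like this on the learned features, alongside the classification loss, rather than on the raw inputs.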


This notebook provides the code for the experiments carried out in (cite).




## Importing necessary libraries

In [None]:
import numpy as np 
import pandas as pd 
from torch.utils.data.sampler import SubsetRandomSampler,RandomSampler
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader      
import scipy.io as sio
import math
from easydict import EasyDict as ED
import copy
import os
from nilearn.regions import RegionExtractor
from nilearn.decomposition import DictLearning
from nilearn.connectome import ConnectivityMeasure
import nibabel
import pickle
import glob
from sklearn.model_selection import StratifiedKFold
from torch.autograd import Function
import imblearn 
from imblearn.over_sampling import RandomOverSampler
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
import nilearn

## Importing custom modules

These modules provide utility functions along with the classes for each of the DA models.

In [None]:
import sys
sys.path.append('../')
from models import DarnMLP, DarnConv, MdmnMLP, MdmnConv, MsdaMLP, MsdaConv
from load_data import load_fmri_data, multi_data_loader, data_loader
import utils
from utils import set_configs, OverSampler
import time
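`load_data.multi_data_loader` is imported here, but its source is not shown. The main loop iterates `for xs, ys in train_loader` and then indexes `xs[j]` for `j in range(num_src_domains)`, so it yields one mini-batch per source site. The hypothetical `multi_data_loader_sketch` below shows one way such a loader could work. Two details are assumptions, not the behaviour of the real module: re-indexing the remaining sites 0..N-1 in sorted key order, and cycling over small sites so that each contributes `batch_size` samples.

```python
import numpy as np


def multi_data_loader_sketch(insts, labels, batch_size, seed=None):
    """Hypothetical stand-in for load_data.multi_data_loader.

    insts, labels: dicts {site_key: array}, one entry per source site.
    Yields (xs, ys): lists with one shuffled mini-batch per site, ordered by
    sorted site key, so xs[j] is the j-th remaining source site. Smaller sites
    are cycled so every site contributes batch_size samples per step, and one
    pass ends when the largest site has been covered once.
    """
    rng = np.random.default_rng(seed)
    keys = sorted(insts)
    perms = {k: rng.permutation(len(insts[k])) for k in keys}
    largest = max(len(insts[k]) for k in keys)
    n_batches = -(-largest // batch_size)  # ceiling division
    for b in range(n_batches):
        positions = np.arange(b * batch_size, (b + 1) * batch_size)
        xs, ys = [], []
        for k in keys:
            idx = perms[k][positions % len(perms[k])]  # wrap around small sites
            xs.append(insts[k][idx])
            ys.append(labels[k][idx])
        yield xs, ys


# Usage: site 1 has been popped as the target, leaving sources 0 and 2
rng = np.random.default_rng(0)
insts = {0: rng.normal(size=(50, 4)), 2: rng.normal(size=(120, 4))}
labels = {0: rng.integers(0, 2, 50), 2: rng.integers(0, 2, 120)}
batches = list(multi_data_loader_sketch(insts, labels, batch_size=32, seed=0))
```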

## Hyperparameter and arguments setting

The following cell sets the hyperparameters that do not depend on the data, namely:

1. `classifier_config` : Hyperparams for the label classifier, which predicts the class label of each sample.

2. `domain_config` : Hyperparams for the domain classifier, which some DA models use to encourage domain invariance.

3. `args` : Global hyperparams used throughout the experiment to configure data streaming, training and testing.

Note: `args` for `ADHD` and `ABIDE` are provided separately for convenience; the `ABIDE` block is commented out.
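Item 2 above describes a domain classifier that encourages domain invariance. In `domain_config` it ends in a single output (`output_dim: 1`), i.e. one logit. DANN-style adversarial training commonly trains such a head through a gradient reversal layer, and `torch.autograd.Function` is already imported above. The sketch below is a minimal illustration under that assumption. `GradReverse` and `DomainHead` are hypothetical names, and the models in `models.py` may wire the domain classifier differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; multiplies the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient reaching the feature extractor;
        # lambd is a plain float, so its gradient is None.
        return -ctx.lambd * grad_output, None


class DomainHead(nn.Module):
    """Hypothetical domain classifier shaped like domain_config:
    input_dim -> one hidden layer -> a single logit (output_dim = 1)."""

    def __init__(self, input_dim=1000, hidden=100, drop_rate=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(), nn.Dropout(drop_rate),
            nn.Linear(hidden, 1))

    def forward(self, feats, lambd=1.0):
        return self.net(GradReverse.apply(feats, lambd)).squeeze(-1)


# Usage: domain label 0 = source site, 1 = target site
torch.manual_seed(0)
head = DomainHead()
feats = torch.randn(8, 1000, requires_grad=True)
domain_y = torch.tensor([0., 0., 0., 0., 1., 1., 1., 1.])
loss = F.binary_cross_entropy_with_logits(head(feats), domain_y)
loss.backward()  # the head learns to separate sites; feats.grad pushes the other way
```

A single backward pass trains the domain head to tell sites apart, while the reversed gradient pushes the upstream features toward being site-indistinguishable.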

In [None]:

classifier_config = {
    'input_dim' : 1000,
    # 'hidden_layers' : [512,256,100],
    'hidden_layers' : [100],
    'output_dim' : 2,
    'drop_rate' : 0.5,
    'process_final' : False,
    'lr': 1e-4,
    'decay': 0,  # L2 regularization (weight decay)
    'batch_size' : 35,
    'epochs' : 100
    }

domain_config = {
    'input_dim' : 1000,
    'hidden_layers' : [100],
    # 'hidden_layers' : [512,256,100],
    'output_dim' : 1,
    'drop_rate' : 0.5,
    'process_final' : False,
    'lr': 1e-4,
    'decay': 0,
    'batch_size' : 35,
    'epochs' : 100
    }

# args = ED({
#   "name" : "ABIDE",
#   "method" : "mdmn",
#   "data_path" : "data/ABIDE_aal_allData.pkl",
#   "result_path": "Torch_cache/model_ckpts/DOM_ADAP_CV",
#   "mode" : "L2",
#   "lr" : 1e-4,
#   "mu" : 1,
#   "gamma" : 0.5,
#   "epoch" : 100,
#   "batch_size" : 100,
#   "cuda" : 0,
#   "seed" : 0,
#   "folds" : 10,
#   "mod" : 1,
#   "sampling" : None
# })

args = ED({
  "name" : "ADHD",
  "method" : "mdmn",  # Options: "dann", "darn", "mdmn", "msda"; baselines: "src", "tar"
  "data_path" : "data/ADHD_aal_allData.pkl",
  "result_path": "Torch_cache/model_ckpts/DOM_ADAP_CV",
  "mode" : "L2", # Options: "dynamic" and "L2" for DARN only.
  "lr" : 3e-4,
  "mu" : 3,
  "gamma" : 0.6,
  "epoch" : 50,
  "batch_size" : 200,
  "cuda" : 0,
  "seed" : 0,
  "folds" : 10, # for stratified K-Fold splitting
  "mod" : 0, # flip to 1 for oversampling
  "sampling" : None # None : 50/50 class distribution for all domains; [0.0,0.9] : ratio of (major_class/total_data) wherever possible 
  
})
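Setting `mod` to 1 switches on oversampling, which `load_fmri_data` presumably delegates to `utils.OverSampler`. That helper is not shown here; `imblearn`'s `RandomOverSampler` is also imported above. Assuming `mod = 1` oversamples to a 50/50 class split, the hypothetical helper below sketches that behaviour by duplicating randomly chosen minority-class rows. `RandomOverSampler(random_state=0).fit_resample(X, y)` from `imblearn` performs the same kind of resampling. Oversampling should only be applied to training data, never to a held-out fold.

```python
import numpy as np


def random_oversample(x, y, seed=0):
    """Hypothetical helper: duplicate randomly chosen rows of every minority class
    until each class is as large as the majority class (50/50 for binary labels)."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    keep = [np.arange(len(y))]  # all original rows are kept
    for cls, cnt in zip(classes, counts):
        if cnt < target:
            cls_idx = np.flatnonzero(y == cls)
            keep.append(rng.choice(cls_idx, size=target - cnt, replace=True))
    idx = np.concatenate(keep)
    return x[idx], y[idx]


# Usage on one site: 7 samples of class 0 vs 3 of class 1 becomes 7 vs 7
x = np.arange(10).reshape(10, 1)
y = np.array([0] * 7 + [1] * 3)
x_bal, y_bal = random_oversample(x, y)
print(np.bincount(y_bal))
```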


## Main Train and Test loop

Next, the arguments and parameters that depend on the training and test data are set.
After setting up the environment, each site is used in turn as the target domain, with the remaining sites as sources. Each target is evaluated with stratified 10-fold cross-validation (`args.folds`).

Note: the feature-extractor parameters are also set below.
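The site-level protocol in the cell below amounts to a leave-one-site-out split. Each site becomes the target once, the training splits of the other sites serve as sources, and for `"dann"`/`"src"` the sources are concatenated into a single domain. The hypothetical generator below sketches only this bookkeeping with NumPy. As in the notebook, the target is then split with `StratifiedKFold(args.folds, shuffle=True, random_state=args.seed)`.

```python
import numpy as np


def leave_one_site_out(train_insts, train_labels, test_insts, test_labels,
                       merge_sources=False):
    """Hypothetical sketch of the per-site split used in the main loop.

    All arguments are dicts {site_key: array}. For each site (in sorted key
    order) yield (target_key, source_insts, source_labels, target_x, target_y),
    where the sources are the training splits of every other site and the
    target is that site's test split. With merge_sources=True the sources are
    concatenated into a single domain with key 0.
    """
    for tgt in sorted(train_insts):
        src_x = {k: v for k, v in train_insts.items() if k != tgt}
        src_y = {k: v for k, v in train_labels.items() if k != tgt}
        if merge_sources:
            keys = sorted(src_x)
            src_x = {0: np.concatenate([src_x[k] for k in keys])}
            src_y = {0: np.concatenate([src_y[k] for k in keys])}
        yield tgt, src_x, src_y, test_insts[tgt], test_labels[tgt]


# Usage with three toy sites (the same arrays stand in for train and test splits)
insts = {0: np.zeros((5, 3)), 1: np.ones((4, 3)), 2: np.full((6, 3), 2.0)}
labels = {k: np.arange(len(v)) % 2 for k, v in insts.items()}
for tgt, src_x, src_y, tx, ty in leave_one_site_out(insts, labels, insts, labels,
                                                   merge_sources=True):
    print(tgt, {k: v.shape for k, v in src_x.items()}, tx.shape)
```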

In [None]:

################################################################################

######################## logger, Seed and Folds ################################

device = torch.device("cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size

result_path = os.path.join(args.result_path,
                          args.name,
                          args.method,
                          args.mode,
                          'lr_'+str(args.lr),
                            )
if not os.path.exists(result_path):
    os.makedirs(result_path)

logger = utils.get_logger(os.path.join(result_path,
                                      "gamma_%g_seed_%d.log" % (args.gamma,
                                                                args.seed)))
logger.info("Hyperparameter setting = %s" % args)

# Set random number seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
skf = StratifiedKFold(args.folds, random_state = args.seed, shuffle=True)


################################################################################

#################### Loading the datasets ######################################

time_start = time.time()

# the insts and labels are now in dict format
data_names, train_insts, train_labels, test_insts, test_labels = load_fmri_data(
    args.name, args.data_path, logger, mod=args.mod, alpha=args.sampling)

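if args.method == "src+":
  # NOTE: add_site_info is not defined or imported in this notebook;
  # it must be provided before running the "src+" baseline.
  train_insts, train_labels, test_insts, test_labels = add_site_info(train_insts, train_labels, test_insts, test_labels)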

configs = set_configs(train_insts[list(train_insts.keys())[0]].shape[-1], len(data_names) - 1)

configs["mode"] = args.mode
configs["mu"] = args.mu
configs["gamma"] = args.gamma
configs["num_src_domains"] = len(data_names) - 1
num_datasets = len(data_names)
configs['moments'] = 5
logger.info("Time used to process the %s = %g seconds." % (args.name, time.time() - time_start))
logger.info("-" * 100)

test_results = {}
np_test_results = np.zeros(num_datasets)
target_c_dict = {}
################################################################################


##################### misc settings ############################################

if args.method in ["dann", "src", "tar", "src+"]:
  # Combine all sources into a single domain for these methods
  num_src_domains = configs["num_src_domains"] = 1
else:
  num_src_domains = configs["num_src_domains"]

logger.info("Model setting = %s." % configs)
if args.method in ['src', 'tar',"src+"]:
  args.mu = configs["mu"] =  0
if args.method == "dann":
  args.mode = configs["mode"] = "dynamic"

################################################################################

###################### Train/Test Loops ########################################

alpha_list = np.zeros([num_datasets, num_src_domains, args.folds,args.epoch])



for i in range(num_datasets):

  logger.info('########### Site: %s #############' % data_names[i])

  #------- src tgt data split-----------------------------------#
  
  source_insts = copy.deepcopy(train_insts)
  source_labels = copy.deepcopy(train_labels)
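  # Drop the target site from the sources and read the target from the test
  # split without mutating test_insts, so the cell can be re-run safely.
  source_insts.pop(i, None)
  source_labels.pop(i, None)
  target_insts, target_labels = test_insts[i], test_labels[i]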

  if args.method in ["dann", "src", "src+"]:  # single-source methods (num_src_domains = 1)
    tmp_insts = []
    tmp_labels = []
    for key in source_insts.keys():
      tmp_insts.extend(source_insts[key])
      tmp_labels.extend(source_labels[key])
    new_inst,new_labels = {},{}
    new_inst[0] = np.squeeze(np.array(tmp_insts))
    new_labels[0] = np.array(tmp_labels)
    source_insts = new_inst
    source_labels = new_labels
  if args.method == "tar":
    source_insts, source_labels = {},{}
    source_insts[0] = target_insts
    source_labels[0] = target_labels

  #---------------------------------------------------------------#
  
  target_c = np.zeros((2,2)) # confusion matrix for each domain
  target_acc = 0.0
  target_feats = []
  # print(np.unique(target_labels,return_counts=True))
  #-------------------10foldCV-------------------------------------#
  for fold,(train_idx, test_idx) in enumerate(skf.split(target_insts,target_labels)):
    # print(np.shape(target_insts),np.shape(train_idx),np.shape(test_idx))

    #----------- Model selection -----------------------------------#

    if args.method == 'mdmn':
      # print('using mdmn')
      model = MdmnMLP(configs, classifier_config, domain_config).to(device)
    elif args.method == 'msda':
      model = MsdaMLP(configs, classifier_config).to(device)
    else:
      # print('using darn')
      model = DarnMLP(configs, classifier_config, domain_config).to(device)
    
    #---------------------------------------------------------------#

    #------------------- Optimizer ---------------------------------#
    if args.method != 'msda':
      optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
    else:
      opt_G = optim.Adagrad(model.G_params, lr=args.lr)
      opt_C1 = optim.Adagrad(model.C1_params, lr=args.lr)
      opt_C2 = optim.Adagrad(model.C2_params, lr=args.lr)
    # scheduler = optim.lr_scheduler.OneCycleLR(optimizer,max_lr = 1e-3, epochs=args.epoch, steps_per_epoch=len(train_idx))
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode = 'max',patience=5)
    #---------------------------------------------------------------#
  

    
    logger.info("-"*10)
    logger.info("Starting Fold %d" %(fold+1) )
    logger.info("-"*10)

    time_start = time.time()
    max_acc = 0.0
    max_c = np.zeros((2, 2))  # defined even if the accuracy never exceeds 0
    best_feats = []
    for t in range(args.epoch):
      #--------------------Training---------------------------------#
      if args.method =="tar":
        train_loader =  data_loader(target_insts[train_idx], 
                                target_labels[train_idx],
                                batch_size=batch_size, shuffle=True)
      else:
        train_loader = multi_data_loader(source_insts, source_labels, batch_size)
        
      running_loss = 0.0
      feats = []
      c = np.zeros((2,2))        
      model.train()

      for xs, ys in train_loader:

        if args.method != "tar":
          for j in range(num_src_domains):
  
            xs[j] = torch.squeeze(torch.tensor(xs[j], 
                                              requires_grad=False)).to(device).float()
            ys[j] = torch.tensor(ys[j], requires_grad=False).to(device).long()
        
        else:
          xs = [torch.reshape(torch.tensor(xs, requires_grad=False, 
                            dtype=torch.float32).to(device).float(),(xs.shape[0],-1))]
          ys = [torch.tensor(ys, requires_grad=False,
                            dtype=torch.long).to(device)]    
            
        ridx = np.random.choice(train_idx, batch_size)
        tinputs = target_insts[ridx, :]
        tinputs = torch.tensor(tinputs, 
                              requires_grad=False).to(device).float().squeeze()
        
        # print(np.shape(xs),np.shape(ys),np.shape(tinputs) ) #debugging
        if args.method != 'msda':
          optimizer.zero_grad()
          loss, alpha = model(xs, ys, tinputs)
          loss.backward()
          optimizer.step()
          # scheduler.step()
        else:
          loss = utils.msda_train_step(model, xs, ys,
                                      tinputs, opt_G, opt_C1, opt_C2)
        running_loss += loss.item()
      #--------------------------------------------------------------#

      #--------------------Testing-----------------------------------#

      model.eval()
      test_loader = data_loader(target_insts[test_idx], 
                                target_labels[test_idx],
                                batch_size=len(test_idx), shuffle=False)
      test_acc = 0.

      with torch.no_grad():  # no gradients are needed for evaluation
        for xt, yt in test_loader:
          xt = torch.tensor(xt, dtype=torch.float32).to(device).squeeze()
          yt = torch.tensor(yt, dtype=torch.int64).to(device)
          preds_labels = torch.max(model.inference(xt), 1)[1]
          feats.extend(model.get_feats(xt).detach().cpu().numpy())
          # Confusion matrix: rows = predicted class, columns = true class
          np.add.at(c, (preds_labels.cpu().numpy(), yt.cpu().numpy()), 1)
          test_acc += torch.sum(preds_labels == yt).item()
      test_acc /= target_insts[test_idx].shape[0]
      # scheduler.step(test_acc)

      if max_acc<test_acc:
        max_acc = test_acc
        max_c = c
        best_feats = feats
        # feat_path = os.path.join(result_path,"%s_gamma_%g_seed_%d_%s_%s_feat.npy" 
        #                          % (data_names[i],args.gamma,
        #                           args.seed,
        #                           args.method,
        #                           args.mode))
        # np.save(feat_path, best_feats)
        logger.info("Epoch[%d/%d] | Train_loss : %.7f | Acc: %.3f" %(t, args.epoch,
                                                                    running_loss,
                                                                    max_acc))
        if args.method == 'mdmn' or (args.method == 'darn' and args.mode == 'L2'):
          logger.info("Epoch %d, Fold %d, Alpha on %s: %s" % (t,fold ,data_names[i], alpha))
          alpha_list[i, :, fold, t] = alpha

      if max_acc == 1.0: #stop if acc 100%
        break
      #------------------------------------------------------------#
    target_acc += max_acc
    target_c += max_c
    target_feats.extend(best_feats)
  target_acc /= args.folds
  

  feat_path = os.path.join(result_path,
                          "%s_gamma_%g_seed_%g_mu_%g_feats.pkl" % (data_names[i], args.gamma,
                                                                    args.seed, args.mu))
  conf_path = os.path.join(result_path,
                          "gamma_%g_seed_%g_mu_%g_conf.pkl" % ( args.gamma,
                                                                    args.seed, args.mu))
  target_c_dict[data_names[i]] = target_c 
  print(target_c)

  with open(feat_path,'wb') as fp:
    pickle.dump(target_feats,fp)
  with open(conf_path,'wb') as fp:
    pickle.dump(target_c_dict,fp)
    
  test_results[data_names[i]] = target_acc
  np_test_results[i] = target_acc
  # matrices[data_names[i]] = target_c
  # logger.info("%s Confusion Matrix: %s" % (data_names[i],str(target_c)))
  logger.info("%s Accuracy: %.4f" %(data_names[i], target_acc))

logger.info("Test Results: %s" % str(test_results))
# logger.info("Conf Matrices: %s" % str(matrices))
################################################################################  


################# Save results to files ########################################
test_file = os.path.join(result_path,
                        "gamma_%g_seed_%d_test.txt" % (args.gamma,
                                                        args.seed))
np.savetxt(test_file, np_test_results, fmt='%.6g')
# test_file2 = os.path.join(result_path,
#                         "gamma_%g_seed_%d_%s_%s_matrices.pkl" % (args.gamma,
#                                                         args.seed,
#                                                         args.method,
#                                                         args.mode))



# with open(test_file2,'wb') as fp:
#   pickle.dump(matrices,fp)
# fp.close()

# if args.method == 'mdmn' or (args.method == 'darn' and args.mode == 'L2'):
#     for i in range(num_datasets):
#         alpha_file = os.path.join(result_path,
#                                   "gamma_%g_seed_%d_alpha%d.txt" % (args.gamma,
#                                                                     args.seed,
#                                                                     i))
#         np.savetxt(alpha_file, alpha_list[i], fmt='%.6g')

################################################################################
mean_acc = np.mean(list(test_results.values()))  # unweighted mean over sites
logger.info('mean Acc:%.6f' % mean_acc)
logger.info("Done")
logger.info("*" * 100)
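For each fold, the loop above records the best test-fold accuracy over all epochs, together with its confusion matrix. It averages these maxima over the folds to get a per-site accuracy and finally logs the unweighted mean over sites. Because the epoch is selected on the same fold that is reported, these figures are best-epoch accuracies. The helpers below are a minimal NumPy sketch of that bookkeeping, and their names are hypothetical. The confusion matrix follows the loop's convention: rows are predicted labels and columns are true labels.

```python
import numpy as np


def confusion_matrix_2x2(preds, targets):
    """Rows index the predicted class and columns the true class,
    matching c[preds_labels[ii]][yt[ii]] += 1 in the loop above."""
    c = np.zeros((2, 2))
    np.add.at(c, (np.asarray(preds), np.asarray(targets)), 1)
    return c


def site_accuracy(fold_epoch_accs):
    """fold_epoch_accs[f][t]: test-fold accuracy after epoch t of fold f.
    Keep the best epoch of each fold, then average over folds."""
    return float(np.mean([max(accs) for accs in fold_epoch_accs]))


def mean_over_sites(site_accs):
    """Unweighted mean of the per-site accuracies (the final 'mean Acc')."""
    return float(np.mean(list(site_accs.values())))


c = confusion_matrix_2x2([0, 1, 1, 0], [0, 1, 0, 0])
print(c, np.trace(c) / c.sum())  # accuracy = correct predictions / total
```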