<a href="https://colab.research.google.com/github/greyhound101/Orbuculum/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%writefile download_abide_preproc.py
def collect_and_download(derivative, pipeline, strategy, out_dir, less_than, greater_than, site, sex, diagnosis):
    """
    Collect and download images from the ABIDE preprocessed directory on
    FCP-INDI's S3 bucket.

    The phenotypic CSV is fetched from S3, every row is tested against the
    requested criteria, and one file per matching participant is downloaded
    to ``out_dir/Outputs/<pipeline>/<strategy>/<derivative>/``. Files that
    already exist locally are not downloaded again.

    Parameters
    ----------
    derivative : string
        derivative or measure of interest; derivatives containing 'roi'
        are fetched as '.1D' files, everything else as '.nii.gz'
    pipeline : string
        pipeline used to process data of interest
    strategy : string
        noise removal strategy used to process data of interest
    out_dir : string
        filepath to a local directory to save files to (created if missing)
    less_than : float
        upper age (years) threshold for participants of interest (exclusive)
    greater_than : float
        lower age (years) threshold for participants of interest (exclusive)
    site : string or None
        acquisition site of interest (case-insensitive); None keeps all sites
    sex : string or None
        'M' or 'F' to indicate whether to download male or female data;
        any other value keeps both
    diagnosis : string
        'asd', 'tdc', or 'both' corresponding to the diagnosis of the
        participants for whom data should be downloaded

    Returns
    -------
    None
        this function does not return a value; it downloads data from
        S3 to a local directory

    Raises
    ------
    Exception
        if one of the required columns is missing from the phenotype header
    """

    # Import packages
    import os
    import urllib.request as request

    # Init variables
    mean_fd_thresh = 0.2
    s3_prefix = 'https://s3.amazonaws.com/fcp-indi/data/Projects/'\
                'ABIDE_Initiative'
    s3_pheno_path = '/'.join([s3_prefix, 'Phenotypic_V1_0b_preprocessed1.csv'])

    # Format input arguments to be lower case, if not already
    derivative = derivative.lower()
    pipeline = pipeline.lower()
    strategy = strategy.lower()

    # Check derivative for extension
    if 'roi' in derivative:
        extension = '.1D'
    else:
        extension = '.nii.gz'

    # If output path doesn't exist, create it
    if not os.path.exists(out_dir):
        print('Could not find {0}, creating now...'.format(out_dir))
        os.makedirs(out_dir)

    # Load the phenotype file from S3; the context manager closes the HTTP
    # response as soon as every line has been read
    with request.urlopen(s3_pheno_path) as s3_pheno_file:
        pheno_list = s3_pheno_file.readlines()
    print(pheno_list[0])

    # Get header indices
    header = pheno_list[0].decode().split(',')
    try:
        site_idx = header.index('SITE_ID')
        file_idx = header.index('FILE_ID')
        age_idx = header.index('AGE_AT_SCAN')
        sex_idx = header.index('SEX')
        dx_idx = header.index('DX_GROUP')
        mean_fd_idx = header.index('func_mean_fd')
    except ValueError as exc:
        # list.index raises ValueError when a column name is absent
        err_msg = 'Unable to extract header information from the pheno file: {0}\nHeader should have pheno info:' \
                  ' {1}\nError: {2}'.format(s3_pheno_path, str(header), exc)
        raise Exception(err_msg) from exc

    # Go through pheno file and build download paths
    print('Collecting images of interest...')
    s3_paths = []
    for pheno_row in pheno_list[1:]:

        # Comma separate the row
        cs_row = pheno_row.decode().split(',')

        try:
            # See if it was preprocessed
            row_file_id = cs_row[file_idx]
            # Read in participant info
            row_site = cs_row[site_idx]
            row_age = float(cs_row[age_idx])
            row_sex = cs_row[sex_idx]
            row_dx = cs_row[dx_idx]
            row_mean_fd = float(cs_row[mean_fd_idx])
        except (IndexError, ValueError):
            # Short row or a non-numeric age / mean FD value
            print('Error extracting info from phenotypic file, skipping...')
            continue

        # If the filename isn't specified, skip
        if row_file_id == 'no_filename':
            continue
        # If mean fd is too large, skip
        if row_mean_fd >= mean_fd_thresh:
            continue

        # Test phenotypic criteria (separate if's look cleaner than one long if)
        # Test sex (the pheno file codes male as '1' and female as '2')
        if (sex == 'M' and row_sex != '1') or (sex == 'F' and row_sex != '2'):
            continue

        # Test diagnosis (ASD is coded '1', TDC is coded '2')
        if (diagnosis == 'asd' and row_dx != '1') or (diagnosis == 'tdc' and row_dx != '2'):
            continue

        # Test site
        if site is not None and site.lower() != row_site.lower():
            continue

        # Test age range (both bounds are exclusive)
        if not greater_than < row_age < less_than:
            continue

        filename = row_file_id + '_' + derivative + extension
        s3_path = '/'.join([s3_prefix, 'Outputs', pipeline, strategy, derivative, filename])
        print('Adding {0} to download queue...'.format(s3_path))
        s3_paths.append(s3_path)

    # And download the items
    total_num_files = len(s3_paths)
    for path_idx, s3_path in enumerate(s3_paths):
        # Slice the prefix off explicitly: str.lstrip() removes any leading
        # characters found in its argument rather than a literal prefix, so
        # it could also eat the start of the relative path
        rel_path = s3_path[len(s3_prefix):].lstrip('/')
        download_file = os.path.join(out_dir, rel_path)
        download_dir = os.path.dirname(download_file)
        os.makedirs(download_dir, exist_ok=True)

        if os.path.exists(download_file):
            print('File {0} already exists, skipping...'.format(download_file))
            continue

        try:
            print('Retrieving: {0}'.format(download_file))
            request.urlretrieve(s3_path, download_file)
            print('{0:.3f}% complete'.format(100*(float(path_idx+1)/total_num_files)))
        except Exception as exc:
            # Best effort: report the failure and carry on with the next file
            print('There was a problem downloading {0}.\n Check input arguments and try again.'.format(s3_path))
            print('Error: {0}'.format(exc))
            # Drop a partially written file so that a later run retries it
            # instead of treating it as already downloaded
            if os.path.exists(download_file):
                os.remove(download_file)

    # Print all done
    print('Done!')


# Make module executable
if __name__ == '__main__':

    # Import packages
    import argparse
    import os
    import sys

    # Init argument parser (the module has no docstring, so spell the
    # description out instead of passing __doc__)
    parser = argparse.ArgumentParser(
        description='Download ABIDE preprocessed data from the FCP-INDI S3 bucket.')

    # Diagnosis flags
    parser.add_argument('-a', '--asd', required=False, default=False, action='store_true',
                        help='Only download data for participants with ASD.'
                             ' Specifying neither or both -a and -c will download data from all participants.')
    parser.add_argument('-c', '--tdc', required=False, default=False, action='store_true',
                        help='Only download data for participants who are typically developing controls.'
                             ' Specifying neither or both -a and -c will download data from all participants.')

    # Required arguments
    parser.add_argument('-d', '--derivative', nargs=1, required=True, type=str,
                        help='Derivative of interest (e.g. \'reho\')')
    parser.add_argument('-p', '--pipeline', nargs=1, required=True, type=str,
                        help='Pipeline used to preprocess the data (e.g. \'cpac\')')
    parser.add_argument('-s', '--strategy', nargs=1, required=True, type=str,
                        help='Noise-removal strategy used during preprocessing (e.g. \'nofilt_noglobal\')')
    parser.add_argument('-o', '--out_dir', nargs=1, required=True, type=str,
                        help='Path to local folder to download files to')

    # Optional arguments
    parser.add_argument('-lt', '--less_than', nargs=1, required=False,
                        type=float, help='Upper age threshold (in years) of participants to download (e.g. for '
                                         'subjects 30 or younger, \'-lt 31\')')
    # Parsed as float, like --less_than, so fractional ages are accepted
    parser.add_argument('-gt', '--greater_than', nargs=1, required=False,
                        type=float, help='Lower age threshold (in years) of participants to download (e.g. for '
                                         'subjects 31 or older, \'-gt 30\')')
    parser.add_argument('-t', '--site', nargs=1, required=False, type=str,
                        help='Site of interest to download from (e.g. \'Caltech\')')
    parser.add_argument('-x', '--sex', nargs=1, required=False, type=str,
                        help='Participant sex of interest to download only (e.g. \'M\' or \'F\')')

    # Parse and gather arguments
    args = parser.parse_args()

    # Init variables
    desired_derivative = args.derivative[0].lower()
    desired_pipeline = args.pipeline[0].lower()
    desired_strategy = args.strategy[0].lower()
    download_data_dir = os.path.abspath(args.out_dir[0])

    # Init optional arguments; argparse leaves an omitted option as None

    # for diagnosis if both ASD and TDC flags are set to true or false, we download both
    if args.tdc == args.asd:
        desired_diagnosis = 'both'
        print('Downloading data for ASD and TDC participants')
    elif args.tdc:
        desired_diagnosis = 'tdc'
        print('Downloading data for TDC participants')
    else:
        desired_diagnosis = 'asd'
        print('Downloading data for ASD participants')

    if args.less_than is None:
        desired_age_max = 200.0
        print('No upper age threshold specified')
    else:
        desired_age_max = args.less_than[0]
        # The value is a float, so it cannot be rendered with the 'd' format code
        print('Using upper age threshold of {0}...'.format(desired_age_max))

    if args.greater_than is None:
        desired_age_min = -1.0
        print('No lower age threshold specified')
    else:
        desired_age_min = args.greater_than[0]
        print('Using lower age threshold of {0}...'.format(desired_age_min))

    if args.site is None:
        desired_site = None
        print('No site specified, using all sites...')
    else:
        desired_site = args.site[0]

    if args.sex is None:
        desired_sex = None
        print('No sex specified, using all sexes...')
    else:
        desired_sex = args.sex[0].upper()
        if desired_sex == 'M':
            print('Downloading only male subjects...')
        elif desired_sex == 'F':
            print('Downloading only female subjects...')
        else:
            print('Please specify \'M\' or \'F\' for sex and try again')
            # Non-zero status so that callers can detect the bad argument
            sys.exit(1)

    # Call the collect and download routine
    collect_and_download(desired_derivative, desired_pipeline, desired_strategy, download_data_dir, desired_age_max,
                         desired_age_min, desired_site, desired_sex, desired_diagnosis)

Writing download_abide_preproc.py


In [None]:
!python3 download_abide_preproc.py -d 'rois_ho' -p 'cpac' -s 'filt_noglobal' -t "NYU" -o'/content'

Downloading data for ASD and TDC participants
No upper age threshold specified
No lower age threshold specified
No sex specified, using all sexes...
b',Unnamed: 0,SUB_ID,X,subject,SITE_ID,FILE_ID,DX_GROUP,DSM_IV_TR,AGE_AT_SCAN,SEX,HANDEDNESS_CATEGORY,HANDEDNESS_SCORES,FIQ,VIQ,PIQ,FIQ_TEST_TYPE,VIQ_TEST_TYPE,PIQ_TEST_TYPE,ADI_R_SOCIAL_TOTAL_A,ADI_R_VERBAL_TOTAL_BV,ADI_RRB_TOTAL_C,ADI_R_ONSET_TOTAL_D,ADI_R_RSRCH_RELIABLE,ADOS_MODULE,ADOS_TOTAL,ADOS_COMM,ADOS_SOCIAL,ADOS_STEREO_BEHAV,ADOS_RSRCH_RELIABLE,ADOS_GOTHAM_SOCAFFECT,ADOS_GOTHAM_RRB,ADOS_GOTHAM_TOTAL,ADOS_GOTHAM_SEVERITY,SRS_VERSION,SRS_RAW_TOTAL,SRS_AWARENESS,SRS_COGNITION,SRS_COMMUNICATION,SRS_MOTIVATION,SRS_MANNERISMS,SCQ_TOTAL,AQ_TOTAL,COMORBIDITY,CURRENT_MED_STATUS,MEDICATION_NAME,OFF_STIMULANTS_AT_SCAN,VINELAND_RECEPTIVE_V_SCALED,VINELAND_EXPRESSIVE_V_SCALED,VINELAND_WRITTEN_V_SCALED,VINELAND_COMMUNICATION_STANDARD,VINELAND_PERSONAL_V_SCALED,VINELAND_DOMESTIC_V_SCALED,VINELAND_COMMUNITY_V_SCALED,VINELAND_DAILYLVNG_STANDARD,V

In [None]:
!python3 download_abide_preproc.py -d 'rois_ho' -p 'cpac' -s 'filt_noglobal' -t "USM" -o'/content'

Downloading data for ASD and TDC participants
No upper age threshold specified
No lower age threshold specified
No sex specified, using all sexes...
b',Unnamed: 0,SUB_ID,X,subject,SITE_ID,FILE_ID,DX_GROUP,DSM_IV_TR,AGE_AT_SCAN,SEX,HANDEDNESS_CATEGORY,HANDEDNESS_SCORES,FIQ,VIQ,PIQ,FIQ_TEST_TYPE,VIQ_TEST_TYPE,PIQ_TEST_TYPE,ADI_R_SOCIAL_TOTAL_A,ADI_R_VERBAL_TOTAL_BV,ADI_RRB_TOTAL_C,ADI_R_ONSET_TOTAL_D,ADI_R_RSRCH_RELIABLE,ADOS_MODULE,ADOS_TOTAL,ADOS_COMM,ADOS_SOCIAL,ADOS_STEREO_BEHAV,ADOS_RSRCH_RELIABLE,ADOS_GOTHAM_SOCAFFECT,ADOS_GOTHAM_RRB,ADOS_GOTHAM_TOTAL,ADOS_GOTHAM_SEVERITY,SRS_VERSION,SRS_RAW_TOTAL,SRS_AWARENESS,SRS_COGNITION,SRS_COMMUNICATION,SRS_MOTIVATION,SRS_MANNERISMS,SCQ_TOTAL,AQ_TOTAL,COMORBIDITY,CURRENT_MED_STATUS,MEDICATION_NAME,OFF_STIMULANTS_AT_SCAN,VINELAND_RECEPTIVE_V_SCALED,VINELAND_EXPRESSIVE_V_SCALED,VINELAND_WRITTEN_V_SCALED,VINELAND_COMMUNICATION_STANDARD,VINELAND_PERSONAL_V_SCALED,VINELAND_DOMESTIC_V_SCALED,VINELAND_COMMUNITY_V_SCALED,VINELAND_DAILYLVNG_STANDARD,V

In [None]:
!python3 download_abide_preproc.py -d 'rois_ho' -p 'cpac' -s 'filt_noglobal' -t "UCLA_1" -o'/content'

Downloading data for ASD and TDC participants
No upper age threshold specified
No lower age threshold specified
No sex specified, using all sexes...
b',Unnamed: 0,SUB_ID,X,subject,SITE_ID,FILE_ID,DX_GROUP,DSM_IV_TR,AGE_AT_SCAN,SEX,HANDEDNESS_CATEGORY,HANDEDNESS_SCORES,FIQ,VIQ,PIQ,FIQ_TEST_TYPE,VIQ_TEST_TYPE,PIQ_TEST_TYPE,ADI_R_SOCIAL_TOTAL_A,ADI_R_VERBAL_TOTAL_BV,ADI_RRB_TOTAL_C,ADI_R_ONSET_TOTAL_D,ADI_R_RSRCH_RELIABLE,ADOS_MODULE,ADOS_TOTAL,ADOS_COMM,ADOS_SOCIAL,ADOS_STEREO_BEHAV,ADOS_RSRCH_RELIABLE,ADOS_GOTHAM_SOCAFFECT,ADOS_GOTHAM_RRB,ADOS_GOTHAM_TOTAL,ADOS_GOTHAM_SEVERITY,SRS_VERSION,SRS_RAW_TOTAL,SRS_AWARENESS,SRS_COGNITION,SRS_COMMUNICATION,SRS_MOTIVATION,SRS_MANNERISMS,SCQ_TOTAL,AQ_TOTAL,COMORBIDITY,CURRENT_MED_STATUS,MEDICATION_NAME,OFF_STIMULANTS_AT_SCAN,VINELAND_RECEPTIVE_V_SCALED,VINELAND_EXPRESSIVE_V_SCALED,VINELAND_WRITTEN_V_SCALED,VINELAND_COMMUNICATION_STANDARD,VINELAND_PERSONAL_V_SCALED,VINELAND_DOMESTIC_V_SCALED,VINELAND_COMMUNITY_V_SCALED,VINELAND_DAILYLVNG_STANDARD,V

In [None]:
!python3 download_abide_preproc.py -d 'rois_ho' -p 'cpac' -s 'filt_noglobal' -t "UM_1" -o'/content'

Downloading data for ASD and TDC participants
No upper age threshold specified
No lower age threshold specified
No sex specified, using all sexes...
b',Unnamed: 0,SUB_ID,X,subject,SITE_ID,FILE_ID,DX_GROUP,DSM_IV_TR,AGE_AT_SCAN,SEX,HANDEDNESS_CATEGORY,HANDEDNESS_SCORES,FIQ,VIQ,PIQ,FIQ_TEST_TYPE,VIQ_TEST_TYPE,PIQ_TEST_TYPE,ADI_R_SOCIAL_TOTAL_A,ADI_R_VERBAL_TOTAL_BV,ADI_RRB_TOTAL_C,ADI_R_ONSET_TOTAL_D,ADI_R_RSRCH_RELIABLE,ADOS_MODULE,ADOS_TOTAL,ADOS_COMM,ADOS_SOCIAL,ADOS_STEREO_BEHAV,ADOS_RSRCH_RELIABLE,ADOS_GOTHAM_SOCAFFECT,ADOS_GOTHAM_RRB,ADOS_GOTHAM_TOTAL,ADOS_GOTHAM_SEVERITY,SRS_VERSION,SRS_RAW_TOTAL,SRS_AWARENESS,SRS_COGNITION,SRS_COMMUNICATION,SRS_MOTIVATION,SRS_MANNERISMS,SCQ_TOTAL,AQ_TOTAL,COMORBIDITY,CURRENT_MED_STATUS,MEDICATION_NAME,OFF_STIMULANTS_AT_SCAN,VINELAND_RECEPTIVE_V_SCALED,VINELAND_EXPRESSIVE_V_SCALED,VINELAND_WRITTEN_V_SCALED,VINELAND_COMMUNICATION_STANDARD,VINELAND_PERSONAL_V_SCALED,VINELAND_DOMESTIC_V_SCALED,VINELAND_COMMUNITY_V_SCALED,VINELAND_DAILYLVNG_STANDARD,V

In [None]:
import os
import glob
import shutil
# Sort the downloaded ROI time-series files into one folder per site.
for site_name in ['NYU', 'UCLA_1', 'USM', 'UM_1']:
    site_dir = os.path.join('/content', site_name)
    # exist_ok keeps the cell re-runnable: an existing folder is neither an
    # error nor a reason to create a different folder instead
    os.makedirs(site_dir, exist_ok=True)
    os.makedirs(site_dir + '_correlation_matrix', exist_ok=True)
    for path in glob.glob('/content/Outputs/cpac/filt_noglobal/rois_ho/' + site_name + '*'):
        shutil.move(path, site_dir + '/')

In [None]:
# Time-series files to delete before further processing. The paths are
# relative, so they resolve against the current working directory.
rmvs = [
    'UM_1/UM_1_0050302_rois_ho.1D',
    'UM_1/UM_1_0050320_rois_ho.1D',
    'UM_1/UM_1_0050273_rois_ho.1D',
    'UM_1/UM_1_0050280_rois_ho.1D',
    'UM_1/UM_1_0050313_rois_ho.1D',
    'UM_1/UM_1_0050367_rois_ho.1D',
    'UM_1/UM_1_0050354_rois_ho.1D',
    'UM_1/UM_1_0050352_rois_ho.1D',
    'UM_1/UM_1_0050353_rois_ho.1D',
    'UM_1/UM_1_0050362_rois_ho.1D',
    'UM_1/UM_1_0050301_rois_ho.1D',
    'UM_1/UM_1_0050316_rois_ho.1D',
    'UM_1/UM_1_0050324_rois_ho.1D',
    'UM_1/UM_1_0050282_rois_ho.1D',
    'UM_1/UM_1_0050336_rois_ho.1D',
    'UM_1/UM_1_0050343_rois_ho.1D',
    'UM_1/UM_1_0050330_rois_ho.1D',
    'UM_1/UM_1_0050298_rois_ho.1D',
    'USM/USM_0050532_rois_ho.1D',
    'USM/USM_0050497_rois_ho.1D',
    'USM/USM_0050520_rois_ho.1D',
    'USM/USM_0050505_rois_ho.1D',
    'USM/USM_0050510_rois_ho.1D',
    'USM/USM_0050526_rois_ho.1D',
    'USM/USM_0050466_rois_ho.1D',
    'USM/USM_0050491_rois_ho.1D',
    'USM/USM_0050492_rois_ho.1D',
    'USM/USM_0050507_rois_ho.1D',
    'USM/USM_0050435_rois_ho.1D',
    'USM/USM_0050528_rois_ho.1D',
    'USM/USM_0050449_rois_ho.1D',
    'USM/USM_0050467_rois_ho.1D',
    'USM/USM_0050490_rois_ho.1D',
    'USM/USM_0050498_rois_ho.1D',
    'USM/USM_0050463_rois_ho.1D',
    'USM/USM_0050501_rois_ho.1D',
    'USM/USM_0050437_rois_ho.1D',
    'NYU/NYU_0051110_rois_ho.1D',
    'NYU/NYU_0050961_rois_ho.1D',
    'NYU/NYU_0051123_rois_ho.1D',
    'NYU/NYU_0051155_rois_ho.1D',
    'NYU/NYU_0050958_rois_ho.1D',
    'NYU/NYU_0051114_rois_ho.1D',
    'NYU/NYU_0051058_rois_ho.1D',
    'NYU/NYU_0051118_rois_ho.1D',
    'UCLA_1/UCLA_1_0051281_rois_ho.1D',
    'UCLA_1/UCLA_1_0051236_rois_ho.1D',
    'UCLA_1/UCLA_1_0051220_rois_ho.1D',
    'UCLA_1/UCLA_1_0051266_rois_ho.1D',
    'UCLA_1/UCLA_1_0051260_rois_ho.1D',
    'UCLA_1/UCLA_1_0051212_rois_ho.1D',
    'UCLA_1/UCLA_1_0051227_rois_ho.1D',
    'UCLA_1/UCLA_1_0051271_rois_ho.1D',
    'UCLA_1/UCLA_1_0051272_rois_ho.1D',
]
for path in rmvs:
    try:
        os.remove(path)
    except FileNotFoundError:
        # Never downloaded or already removed: nothing to do
        pass
    except OSError as exc:
        # Stay best-effort, but do not hide real failures (e.g. permissions)
        print('Could not remove {0}: {1}'.format(path, exc))

In [None]:
pip install deepdish

Collecting deepdish
  Downloading https://files.pythonhosted.org/packages/6e/39/2a47c852651982bc5eb39212ac110284dd20126bdc7b49bde401a0139f5d/deepdish-0.3.6-py2.py3-none-any.whl
Installing collected packages: deepdish
Successfully installed deepdish-0.3.6


In [None]:
pip install nilearn

Collecting nilearn
[?25l  Downloading https://files.pythonhosted.org/packages/4a/bd/2ad86e2c00ecfe33b86f9f1f6d81de8e11724e822cdf1f5b2d0c21b787f1/nilearn-0.7.1-py3-none-any.whl (3.0MB)
[K     |████████████████████████████████| 3.1MB 11.6MB/s 
Installing collected packages: nilearn
Successfully installed nilearn-0.7.1


In [None]:
import warnings
# Suppress every warning raised from here on, to keep the cell output short.
# NOTE(review): this also hides numerical warnings such as divide-by-zero.
warnings.filterwarnings('ignore')

In [None]:
# Number of consecutive time-series rows in each sliding window
# (used as seq_len in make_correlation_matrix).
WINDOW_SIZE = 32
# Number of ROI labels generated below (num_region / labels).
# NOTE(review): "HO" presumably refers to the Harvard-Oxford atlas behind the
# 'rois_ho' derivative - confirm.
HO_NUM_REGION = 111
# Site folders holding the raw .1D files; each one is handed to truncation().
SITE = ["UM_1", "NYU", "UCLA_1", "USM"]
# [folder name, prefix length] pairs consumed by create_vector(): the first
# `prefix length` characters of a file name (e.g. 'USM_0050532' -> 11) are
# used as the key for the label lookup.
SITE_FOLDER = [["USM_correlation_matrix", 11],
               ["UM_1_correlation_matrix", 12],
               ["NYU_correlation_matrix", 11],
               ["UCLA_1_correlation_matrix", 14]]
from tqdm import tqdm
import os
import multiprocessing
import numpy as np
import deepdish as dd
from nilearn.connectome import ConnectivityMeasure
from nilearn import plotting


def make_correlation_matrix(line, site, filename):
    """
    Save one Fisher-transformed correlation matrix per sliding window.

    Every run of WINDOW_SIZE consecutive rows of ``line`` forms a window.
    For each window the ROI-by-ROI correlation matrix is computed,
    transformed with arctanh, its diagonal is zeroed, and the result is
    written to ``<folder>/<site>_correlation_matrix/<filename>_<j>.h5``
    where ``j`` is the index of the window's first row.

    Parameters
    ----------
    line : list of string
        time-series rows; each row holds whitespace separated numbers
    site : string
        site name, used to build the output directory name
    filename : string
        name of the source file, used as prefix of every output file

    Returns
    -------
    None
        nothing is written when there are fewer rows than WINDOW_SIZE
    """
    seq_len = WINDOW_SIZE
    n = len(line) - seq_len + 1
    if n <= 0:
        # Not even one complete window: nothing to compute or save
        return

    # Parse each row once; successive windows overlap in all but one row,
    # so parsing inside the window loop would repeat the same work
    rows = [np.array(list(map(float, row.split()))) for row in line]

    # Create the directory that is actually written to (below `folder`),
    # once, instead of attempting a hard-coded path for every window
    out_dir = os.path.join(folder, '{}_correlation_matrix'.format(site))
    os.makedirs(out_dir, exist_ok=True)

    correlation_measure = ConnectivityMeasure(kind='correlation')
    for j in range(n):
        window = np.array(rows[j: j + seq_len])
        correlation_matrix = correlation_measure.fit_transform([window])[0]
        # arctanh is infinite on the unit diagonal; the diagonal carries no
        # information and is reset to 0 right after the transform
        fisher = np.arctanh(correlation_matrix)
        np.fill_diagonal(fisher, 0)
        dd.io.save(os.path.join(out_dir, '{}_{}.h5'.format(filename, j)), fisher)

        # # check whether there are lines all 0 in this subject
        # for i in range(num_region):
        #     # if np.all(correlation_matrix.sum(axis=0)[82] == 1):
        #     if np.all(fisher[i] == 0):
        #         bad += 1
        #         bad_lst.append("{}".format(filename))
        #         break
        # if i == (num_region - 1):
        #     good += 1
        #     good_lst.append("{}".format(filename))
        # # plot matrix heat map
        # plotting.plot_matrix(fisher, figure=(10, 8), labels=labels, vmax=0.8, vmin=-0.8, reorder=True)


def truncation(site):
    """
    Build the sliding-window correlation matrices for every file of a site.

    Each file directly inside ``<folder>/<site>`` is read, its first line
    is dropped, and the remaining lines are handed to
    make_correlation_matrix().

    Parameters
    ----------
    site : string
        name of the site sub-directory below ``folder``

    Raises
    ------
    FileNotFoundError
        if the site directory does not exist
    """
    file_dir = os.path.join(folder, site)
    try:
        # Only the top-level file listing is needed, so take the first
        # entry produced by os.walk instead of walking the whole tree
        _, _, filenames = next(os.walk(file_dir))
    except StopIteration:
        raise FileNotFoundError('Site directory not found: {}'.format(file_dir)) from None

    for filename in tqdm(filenames):
        # The context manager closes the file even if reading fails
        with open(os.path.join(file_dir, filename)) as f:
            lines = f.readlines()[1:]  # skip the first (header) line
        make_correlation_matrix(lines, site, filename)


num_region = HO_NUM_REGION
labels = [str(i) for i in range(num_region)]
# Base directory for the per-site input and output folders
folder = os.getcwd()
cores = multiprocessing.cpu_count()
# pool.map blocks until every site is processed; the context manager then
# shuts the worker processes down instead of leaving them alive
with multiprocessing.Pool(cores) as pool:
    pool.map(truncation, SITE)

100%|██████████| 67/67 [13:56<00:00, 12.49s/it]
100%|██████████| 47/47 [03:11<00:00,  4.08s/it]
100%|██████████| 163/163 [19:17<00:00,  7.10s/it]
100%|██████████| 44/44 [04:38<00:00,  6.34s/it]


[None, None, None, None]

In [None]:
import csv
# Map FILE_ID -> binary label for the four sites of interest.
# Column positions follow the phenotype header printed by the download
# script: 5 = SITE_ID, 6 = FILE_ID, 7 = DX_GROUP.
rows = {}
# csv.reader needs an iterable of lines (an open file), not a path string;
# the file is comma separated, so the default dialect applies
with open('/content/Phenotypic_V1_0b_preprocessed1.csv', newline='') as csvfile:
    reader = csv.reader(csvfile)
    for record in reader:
        if len(record) < 8:
            # Blank or truncated line
            continue
        if record[5] in ['UM_1', 'NYU', 'USM', 'UCLA_1']:
            name, lab = record[6:8]
            # DX_GROUP 1 (ASD) -> 1, DX_GROUP 2 (TDC) -> 0
            lab = int(lab) % 2
            rows[name] = lab

IndexError: ignored

In [None]:
import os
os.listdir('/content/UCLA_1_correlation_matrix/')

['UCLA_1_0051226_rois_ho.1D_48.h5',
 'UCLA_1_0051250_rois_ho.1D_73.h5',
 'UCLA_1_0051267_rois_ho.1D_35.h5',
 'UCLA_1_0051229_rois_ho.1D_56.h5',
 'UCLA_1_0051282_rois_ho.1D_48.h5',
 'UCLA_1_0051248_rois_ho.1D_55.h5',
 'UCLA_1_0051257_rois_ho.1D_49.h5',
 'UCLA_1_0051218_rois_ho.1D_60.h5',
 'UCLA_1_0051256_rois_ho.1D_69.h5',
 'UCLA_1_0051257_rois_ho.1D_77.h5',
 'UCLA_1_0051250_rois_ho.1D_30.h5',
 'UCLA_1_0051226_rois_ho.1D_68.h5',
 'UCLA_1_0051268_rois_ho.1D_20.h5',
 'UCLA_1_0051257_rois_ho.1D_28.h5',
 'UCLA_1_0051237_rois_ho.1D_2.h5',
 'UCLA_1_0051257_rois_ho.1D_32.h5',
 'UCLA_1_0051224_rois_ho.1D_39.h5',
 'UCLA_1_0051257_rois_ho.1D_41.h5',
 'UCLA_1_0051256_rois_ho.1D_72.h5',
 'UCLA_1_0051229_rois_ho.1D_26.h5',
 'UCLA_1_0051268_rois_ho.1D_71.h5',
 'UCLA_1_0051251_rois_ho.1D_34.h5',
 'UCLA_1_0051249_rois_ho.1D_84.h5',
 'UCLA_1_0051262_rois_ho.1D_49.h5',
 'UCLA_1_0051253_rois_ho.1D_76.h5',
 'UCLA_1_0051221_rois_ho.1D_52.h5',
 'UCLA_1_0051276_rois_ho.1D_75.h5',
 'UCLA_1_0051264_rois_ho.1D_4

In [None]:
reader.loc[reader['SITE_ID']=='UCLA_1']

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,SUB_ID,X,subject,SITE_ID,FILE_ID,DX_GROUP,DSM_IV_TR,AGE_AT_SCAN,SEX,HANDEDNESS_CATEGORY,HANDEDNESS_SCORES,FIQ,VIQ,PIQ,FIQ_TEST_TYPE,VIQ_TEST_TYPE,PIQ_TEST_TYPE,ADI_R_SOCIAL_TOTAL_A,ADI_R_VERBAL_TOTAL_BV,ADI_RRB_TOTAL_C,ADI_R_ONSET_TOTAL_D,ADI_R_RSRCH_RELIABLE,ADOS_MODULE,ADOS_TOTAL,ADOS_COMM,ADOS_SOCIAL,ADOS_STEREO_BEHAV,ADOS_RSRCH_RELIABLE,ADOS_GOTHAM_SOCAFFECT,ADOS_GOTHAM_RRB,ADOS_GOTHAM_TOTAL,ADOS_GOTHAM_SEVERITY,SRS_VERSION,SRS_RAW_TOTAL,SRS_AWARENESS,SRS_COGNITION,SRS_COMMUNICATION,SRS_MOTIVATION,...,WISC_IV_SIM_SCALED,WISC_IV_VOCAB_SCALED,WISC_IV_INFO_SCALED,WISC_IV_BLK_DSN_SCALED,WISC_IV_PIC_CON_SCALED,WISC_IV_MATRIX_SCALED,WISC_IV_DIGIT_SPAN_SCALED,WISC_IV_LET_NUM_SCALED,WISC_IV_CODING_SCALED,WISC_IV_SYM_SCALED,EYE_STATUS_AT_SCAN,AGE_AT_MPRAGE,BMI,anat_cnr,anat_efc,anat_fber,anat_fwhm,anat_qi1,anat_snr,func_efc,func_fber,func_fwhm,func_dvars,func_outlier,func_quality,func_mean_fd,func_num_fd,func_perc_fd,func_gsr,qc_rater_1,qc_notes_rater_1,qc_anat_rater_2,qc_anat_notes_rater_2,qc_func_rater_2,qc_func_notes_rater_2,qc_anat_rater_3,qc_anat_notes_rater_3,qc_func_rater_3,qc_func_notes_rater_3,SUB_IN_SMP
878,878,879,51201,879,51201,UCLA_1,UCLA_1_0051201,1,1,13.52,1,R,,104.0,98.0,109.0,WASI,WASI,WASI,22.0,16.0,7.0,4.0,1.0,3.0,9.0,3.0,6.0,1.0,1.0,7.0,1.0,8.0,5.0,,,,,,,...,,,,,,,,,,,1,13.52,,8.793185,1.733320,4.809707,2.960791,0.032907,12.528147,0.581065,60.042466,2.129714,1.194919,0.011555,0.013289,0.312288,63.0,52.066116,0.039166,OK,,maybe,Motion,OK,,OK,,OK,,1
879,879,880,51202,880,51202,UCLA_1,UCLA_1_0051202,1,1,11.56,1,R,,98.0,110.0,86.0,WASI,WASI,WASI,27.0,19.0,10.0,3.0,1.0,3.0,9.0,2.0,7.0,0.0,1.0,-9999.0,-9999.0,-9999.0,-9999.0,,,,,,,...,,,,,,,,,,,1,11.56,,9.096065,1.354653,10.668705,2.664233,0.018289,12.757453,0.562170,87.547926,2.101467,1.260376,0.028926,0.016332,0.310231,36.0,29.752066,0.049217,OK,,OK,,OK,,OK,,OK,,1
880,880,881,51203,881,51203,UCLA_1,UCLA_1_0051203,1,1,13.37,1,R,,103.0,91.0,116.0,WASI,WASI,WASI,18.0,13.0,8.0,4.0,1.0,3.0,5.0,3.0,2.0,0.0,1.0,-9999.0,-9999.0,-9999.0,-9999.0,,,,,,,...,,,,,,,,,,,1,13.37,,8.772454,1.401283,11.382680,3.080632,0.021908,14.390008,0.555803,74.772327,2.169794,1.131333,0.021848,0.027800,0.684801,68.0,56.198347,0.045634,OK,,maybe,Motion,OK,,OK,,OK,,0
881,881,882,51204,882,51204,UCLA_1,UCLA_1_0051204,1,1,14.57,1,R,,98.0,110.0,90.0,WISC_IV_FULL,WISC_IV_FULL,WISC_IV_FULL,17.0,17.0,9.0,5.0,1.0,3.0,9.0,2.0,7.0,0.0,1.0,7.0,1.0,8.0,5.0,,,,,,,...,,,,,,,,,,,1,14.57,,10.637707,1.995920,5.972081,3.371854,0.016671,18.342188,0.531164,74.789150,2.055451,1.145183,0.009950,0.018130,0.437586,87.0,71.900826,0.032909,OK,,OK,,OK,,OK,,OK,,1
882,882,883,51205,883,51205,UCLA_1,UCLA_1_0051205,1,1,17.94,1,R,,102.0,105.0,99.0,WAIS,WAIS,WAIS,16.0,19.0,5.0,4.0,1.0,3.0,2.0,0.0,2.0,0.0,1.0,2.0,0.0,2.0,1.0,,,,,,,...,,,,,,,,,,,1,17.94,,9.139554,2.195452,3.890805,3.068135,0.017858,15.320306,0.598935,49.263387,2.091364,1.237724,0.002115,0.008336,0.064932,5.0,4.132231,0.053096,OK,,OK,,OK,,OK,,OK,,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
955,955,956,51278,956,51278,UCLA_1,UCLA_1_0051278,2,0,9.21,1,L,,109.0,121.0,97.0,WASI,WASI,WASI,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,1,9.21,,10.545404,1.575898,6.872414,2.839821,0.034987,13.675356,0.555613,99.530714,2.056917,1.002942,0.010600,0.017133,0.351817,56.0,46.280992,0.036071,OK,,OK,,OK,,OK,,OK,,1
956,956,957,51279,957,51279,UCLA_1,UCLA_1_0051279,2,0,13.82,2,R,,106.0,99.0,109.0,WASI,WASI,WASI,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,1,13.82,,9.064599,2.134575,2.833002,3.068417,0.023949,13.370993,0.556405,75.815389,2.032701,1.179329,0.000880,0.003014,0.031960,1.0,0.826446,0.044641,OK,,OK,,OK,,OK,,OK,,0
957,957,958,51280,958,51280,UCLA_1,UCLA_1_0051280,2,0,9.50,1,R,,109.0,123.0,96.0,WASI,WASI,WASI,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,1,9.69,,10.004408,1.898254,7.603178,3.393677,0.026945,16.545860,0.568095,79.610148,1.930497,1.293018,0.000505,0.004228,0.031893,0.0,0.000000,0.039031,OK,,OK,,OK,,OK,,OK,,1
958,958,959,51281,959,51281,UCLA_1,UCLA_1_0051281,2,0,11.83,1,R,,108.0,115.0,100.0,WASI,WASI,WASI,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,1,11.91,,9.096981,1.362094,8.897065,2.733390,0.024802,12.609585,0.543495,87.950891,2.042465,1.107621,0.002691,0.004035,0.064947,4.0,3.305785,0.037095,OK,,OK,,OK,,OK,,OK,,1


In [None]:
import pandas as pd
reader = pd.read_csv('/content/Phenotypic_V1_0b_preprocessed1.csv')
reader

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,SUB_ID,X,subject,SITE_ID,FILE_ID,DX_GROUP,DSM_IV_TR,AGE_AT_SCAN,SEX,HANDEDNESS_CATEGORY,HANDEDNESS_SCORES,FIQ,VIQ,PIQ,FIQ_TEST_TYPE,VIQ_TEST_TYPE,PIQ_TEST_TYPE,ADI_R_SOCIAL_TOTAL_A,ADI_R_VERBAL_TOTAL_BV,ADI_RRB_TOTAL_C,ADI_R_ONSET_TOTAL_D,ADI_R_RSRCH_RELIABLE,ADOS_MODULE,ADOS_TOTAL,ADOS_COMM,ADOS_SOCIAL,ADOS_STEREO_BEHAV,ADOS_RSRCH_RELIABLE,ADOS_GOTHAM_SOCAFFECT,ADOS_GOTHAM_RRB,ADOS_GOTHAM_TOTAL,ADOS_GOTHAM_SEVERITY,SRS_VERSION,SRS_RAW_TOTAL,SRS_AWARENESS,SRS_COGNITION,SRS_COMMUNICATION,SRS_MOTIVATION,...,WISC_IV_SIM_SCALED,WISC_IV_VOCAB_SCALED,WISC_IV_INFO_SCALED,WISC_IV_BLK_DSN_SCALED,WISC_IV_PIC_CON_SCALED,WISC_IV_MATRIX_SCALED,WISC_IV_DIGIT_SPAN_SCALED,WISC_IV_LET_NUM_SCALED,WISC_IV_CODING_SCALED,WISC_IV_SYM_SCALED,EYE_STATUS_AT_SCAN,AGE_AT_MPRAGE,BMI,anat_cnr,anat_efc,anat_fber,anat_fwhm,anat_qi1,anat_snr,func_efc,func_fber,func_fwhm,func_dvars,func_outlier,func_quality,func_mean_fd,func_num_fd,func_perc_fd,func_gsr,qc_rater_1,qc_notes_rater_1,qc_anat_rater_2,qc_anat_notes_rater_2,qc_func_rater_2,qc_func_notes_rater_2,qc_anat_rater_3,qc_anat_notes_rater_3,qc_func_rater_3,qc_func_notes_rater_3,SUB_IN_SMP
0,0,1,50002,1,50002,PITT,no_filename,1,1,16.77,1,Ambi,,103.0,116.0,89.0,WASI,WASI,WASI,16.0,9.0,5.0,4.0,1.0,4.0,12.0,4.0,8.0,3.0,1.0,,,,,,,,,,,...,,,,,,,,,,,2,,,10.201539,1.194664,16.223458,3.878000,0.152711,12.072452,0.613128,45.446551,1.873339,1.054931,0.000641,0.011443,0.116828,8.0,3.980100,0.054346,fail,,OK,,fail,ic-parietal-cerebellum,OK,,fail,ERROR #24,1
1,1,2,50003,2,50003,PITT,Pitt_0050003,1,1,24.45,1,R,,124.0,128.0,115.0,WASI,WASI,WASI,27.0,22.0,5.0,3.0,1.0,4.0,13.0,5.0,8.0,1.0,1.0,,,,,,,,,,,...,,,,,,,,,,,2,,,7.165701,1.126752,10.460008,4.282238,0.161716,9.241155,0.578301,56.286350,2.012112,0.949857,0.000474,0.031781,0.322092,135.0,67.164179,0.041862,OK,,OK,,OK,,OK,,OK,,1
2,2,3,50004,3,50004,PITT,Pitt_0050004,1,1,19.09,1,R,,113.0,108.0,117.0,WASI,WASI,WASI,19.0,12.0,5.0,3.0,1.0,4.0,18.0,6.0,12.0,2.0,1.0,,,,,,,,,,,...,,,,,,,,,,,2,,,7.698144,1.226218,9.725750,3.881684,0.174186,9.323463,0.578960,63.317943,1.866104,1.180605,0.008262,0.014260,0.127745,29.0,14.427861,0.046745,OK,,OK,,OK,,OK,,OK,,1
3,3,4,50005,4,50005,PITT,Pitt_0050005,1,1,13.73,2,R,,119.0,117.0,118.0,WASI,WASI,WASI,23.0,19.0,3.0,4.0,1.0,4.0,12.0,4.0,8.0,1.0,1.0,,,,,,,,,,,...,,,,,,,,,,,2,,,9.071807,1.256278,11.198226,3.628667,0.119269,10.814200,0.556064,70.800354,1.918278,1.092030,0.001711,0.019205,0.128136,22.0,10.945274,0.027963,OK,,OK,,maybe,ic-parietal-cerebellum,OK,,OK,,0
4,4,5,50006,5,50006,PITT,Pitt_0050006,1,1,13.37,1,L,,109.0,99.0,119.0,WASI,WASI,WASI,13.0,10.0,4.0,3.0,1.0,4.0,12.0,4.0,8.0,4.0,1.0,,,,,,,,,,,...,,,,,,,,,,,2,,,8.026798,1.407166,6.282055,3.674539,0.130647,10.123574,0.562942,75.364679,2.213873,1.086830,0.001500,0.006919,0.070143,3.0,1.492537,0.054006,OK,,OK,,maybe,ic-parietal slight,OK,,OK,,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1107,1107,1108,51583,1108,51583,SBL,SBL_0051583,1,2,35.00,1,,100.0,95.0,105.0,84.0,WAIS_III,WAIS_III,WAIS_III,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,,,,,,,,,,,,...,,,,,,,,,,,2,,-9999.00,3.899774,1.697271,3.465151,3.318830,0.096813,5.434540,0.507184,91.232616,2.022145,1.278364,0.001204,0.006403,0.116186,24.0,11.940299,0.037362,OK,,OK,,OK,ic-cerebellum-temporal_lobe,OK,,OK,,0
1108,1108,1109,51584,1109,51584,SBL,SBL_0051584,1,2,49.00,1,,100.0,-9999.0,133.0,135.0,,GIT,GIT,9.0,9.0,3.0,0.0,0.0,4.0,12.0,4.0,8.0,-9999.0,0.0,,,,,,,,,,,...,,,,,,,,,,,2,,23.24,2.757735,1.073076,7.633618,3.309370,0.104817,4.516250,0.486408,108.510115,2.064103,1.122410,0.001508,0.012669,0.140171,37.0,18.407960,0.033177,OK,,OK,,maybe,vmpfc dropout,OK,,OK,,0
1109,1109,1110,51585,1110,51585,SBL,SBL_0051585,1,1,27.00,1,,90.0,96.0,99.0,106.0,WAIS_III,WAIS_III,WAIS_III,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,-9999.0,,,,,,,,,,,,...,,,,,,,,,,,2,,-9999.00,3.413469,1.358238,4.335700,3.324550,0.109490,4.933960,0.465152,90.100651,2.157835,1.110226,0.002469,0.008923,0.154887,52.0,25.870647,0.026048,OK,,OK,,maybe,ic-cerebellum-temporal_lobe,OK,,OK,,0
1110,1110,1111,51606,1111,51606,MAX_MUN,MaxMun_a_0051606,1,2,29.00,2,R,,118.0,,,WST,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,2,,,7.839007,1.754363,12.270055,3.232170,0.083964,16.403174,0.573711,77.402099,1.780653,1.152828,0.000784,0.005814,0.048246,0.0,0.000000,0.053727,OK,,OK,,maybe,ic-cerebellum,OK,,OK,,0


In [None]:
import numpy as np
import deepdish as dd
import csv
import os
import multiprocessing


def create_vector(site_folder):
    """Vectorise every matrix file of one site and save them as one HDF5 file.

    Relies on the module-level ``fold`` (base directory) being set before
    the function is called.

    Parameters
    ----------
    site_folder : tuple
        ``(site, num)``. ``site`` is the name of the sub-directory of
        ``fold`` holding the per-subject files; ``num`` is the number of
        leading characters of each file name that form the key looked up
        in ``abide_preprocessed.csv``.

    Returns
    -------
    None
        The result is written to ``<fold>/<site>.h5`` as a dict with the
        keys ``'data'``, ``'label'`` and ``'id'``.

    Raises
    ------
    FileNotFoundError
        If the site directory or the CSV file does not exist.
    KeyError
        If a file's ``filename[:num]`` prefix has no entry in the CSV.
    """
    site, num = site_folder
    file_dir = os.path.join(fold, site)
    try:
        # Only the top-level file list is needed; do not walk the whole tree.
        filenames = next(os.walk(file_dir))[2]
    except StopIteration:
        raise FileNotFoundError(
            'site directory not found: {}'.format(file_dir)) from None

    # Map the 7th CSV column (file identifier) to the parity of the 8th
    # column, e.g. {'UM_1_0050272': 1, ...}, for rows whose 6th column is
    # one of the four sites of interest.
    wanted_sites = {'UM_1', 'NYU', 'USM', 'UCLA_1'}
    labels_by_file = {}
    with open(os.path.join(fold, 'abide_preprocessed.csv'), newline='') as csvfile:
        # The reader splits on spaces, so row[0] is the text up to the first
        # space; the leading comma-separated columns are taken from there.
        reader = csv.reader(csvfile, delimiter=' ', quotechar='|')
        for row in reader:
            if not row:
                continue  # blank line
            fields = row[0].split(',')
            if len(fields) < 8 or fields[5] not in wanted_sites:
                continue  # header / truncated row / other site
            labels_by_file[fields[6]] = int(fields[7]) % 2

    data = []
    label = []
    ids = []
    for filename in filenames:
        # Load each file from the directory it was listed in.
        matrix = dd.io.load(os.path.join(file_dir, filename))
        # Keep the strict upper triangle, flattened.
        upper = np.triu(matrix, 1).reshape(-1)
        # NOTE(review): this drops *every* zero, including genuine zero
        # entries above the diagonal, so vector lengths can differ between
        # subjects — confirm this is acceptable for the data at hand.
        upper = upper[upper != 0]
        upper[upper < 0] = 0  # clamp negative entries to zero
        data.append(upper)
        label.append(labels_by_file[filename[:num]])
        ids.append(filename)

    dataset = {
        'data': np.array(data),
        'label': np.array(label, dtype=np.int32),
        'id': np.array(ids),
    }
    # os.path.join supplies the separator that plain concatenation omitted.
    dd.io.save(os.path.join(fold, '{}.h5'.format(site)), dataset)


np.random.seed(5)
# Base directory used by create_vector for its inputs and outputs.
fold = os.getcwd()
# Use at most four worker processes.
cores = min(4, multiprocessing.cpu_count())
# The context manager shuts the workers down once map() has returned, so no
# idle processes are left behind in the notebook kernel.
with multiprocessing.Pool(cores) as pool:
    pool.map(create_vector, SITE_FOLDER)

FileNotFoundError: ignored

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as tdist



class MLP(nn.Module):
    """Single-hidden-layer classifier that returns log-probabilities.

    Layer order in ``forward``: dropout -> linear -> ReLU -> batch norm ->
    dropout -> linear -> log-softmax over the last dimension.

    Parameters
    ----------
    dim_in : int
        Number of input features.
    dim_hidden : int
        Width of the hidden layer.
    dim_out : int
        Number of output classes.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim_in, dim_hidden)
        self.bn1 = nn.BatchNorm1d(dim_hidden)
        # nn.ReLU takes a single `inplace` flag, not a feature size; passing
        # the hidden width there silently switched the layer to in-place mode.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Return log-softmax scores for the input batch ``x``."""
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

class Encoder(nn.Module):
    """Feature extractor: dropout -> linear -> ReLU -> batch norm.

    Parameters
    ----------
    dim_in : int
        Number of input features.
    dim_hidden : int
        Number of features produced.
    """

    def __init__(self, dim_in, dim_hidden):
        super().__init__()
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(dim_hidden)

    def forward(self, x):
        """Return the normalised hidden representation of ``x``."""
        projected = self.fc(self.dropout(x))
        return self.bn(self.relu(projected))

class Classifier(nn.Module):
    """Encoder followed by dropout and a linear head; emits log-probabilities.

    Parameters
    ----------
    dim_in : int
        Number of input features.
    dim_hidden : int
        Width of the encoder output.
    dim_out : int
        Number of output classes.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.encoder = Encoder(dim_in, dim_hidden)
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Return log-softmax scores (over the last dimension) for ``x``."""
        features = self.dropout(self.encoder(x))
        logits = self.fc(features)
        return F.log_softmax(logits, dim=-1)

class Discriminator(nn.Module):
    """Small binary discriminator: linear(dim_in, 4) -> ReLU -> linear(4, 1) -> sigmoid.

    Parameters
    ----------
    dim_in : int
        Number of input features.
    """

    def __init__(self, dim_in):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(dim_in, 4)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(4, 1)

    def forward(self, x):
        """Return one score in (0, 1) per input row."""
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        # torch.sigmoid replaces the deprecated functional alias F.sigmoid.
        return torch.sigmoid(x)

class MoE(nn.Module):
    """Gated blend of a local classifier's output with an externally supplied one.

    A per-sample gate ``a = sigmoid(linear(x))`` in (0, 1) weights the local
    classifier output ``yl`` against the caller-provided ``yg``:
    ``res = yl * a + yg * (1 - a)``.

    Parameters
    ----------
    dim_in : int
        Number of input features.
    dim_hidden : int
        Hidden width of the local classifier.
    dim_out : int
        Number of output classes.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MoE, self).__init__()
        self.classifier = Classifier(dim_in, dim_hidden, dim_out)
        self.gate = nn.Linear(dim_in, 1)

    def forward(self, x, yg):
        """Return ``(res, a)``: the blended output and the gate values.

        ``yg`` must be broadcast-compatible with the classifier output.
        """
        yl = self.classifier(x)
        # torch.sigmoid replaces the deprecated functional alias F.sigmoid.
        a = torch.sigmoid(self.gate(x))
        res = yl * a + yg * (1 - a)
        return res, a

In [None]:
%%writefile train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
import time
import deepdish as dd
import torch.distributions as tdist
import os
import argparse
import numpy as np
import copy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

EPS = 1e-15

# Sites in the fixed order used for models, loaders, learning rates and
# result files throughout main().
_SITE_NAMES = ('NYU', 'UM', 'USM', 'UCLA')
# Number of cross-validation folds stored in each index file (keys '0'..'4').
_N_FOLDS = 5


def _fold_indices(fold_ids, split):
    """Return (train, test) indices: fold ``split`` is held out, the others are joined in order."""
    train = None
    for k in range(_N_FOLDS):
        if k == split:
            continue
        part = fold_ids[str(k)]
        train = part if train is None else train + part
    return train, fold_ids[str(split)]


def _standardize(x_train, x_test, per_site):
    """Z-score every site with train-set statistics, computed per site or pooled over all sites."""
    if per_site:
        stats = [(x.mean(0, keepdim=True), x.std(0, keepdim=True)) for x in x_train]
    else:
        pooled = torch.cat(tuple(x_train), 0)
        stats = [(pooled.mean(0, keepdim=True), pooled.std(0, keepdim=True))] * len(x_train)
    # NOTE(review): a feature with zero spread yields a division by zero here.
    train_out = [(x - mean) / dev for x, (mean, dev) in zip(x_train, stats)]
    test_out = [(x - mean) / dev for x, (mean, dev) in zip(x_test, stats)]
    return train_out, test_out


def main(args):
    """Train one local MLP per site with periodic weight averaging, then evaluate and save.

    Every ``args.pace`` local steps (and on the last step of each epoch) the
    floating-point entries of the four local state dicts are averaged with
    equal weights, optionally perturbed with Gaussian or Laplace noise, and
    the result is copied into the global model and back into every local
    model.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``seed``, ``res_dir``, ``model_dir``, ``vec_dir``, ``id_dir``
        (falls back to ``'./idx'`` when absent), ``type`` (``'G'`` selects
        Gaussian noise, anything else Laplace), ``noise``, ``pace``,
        ``overlap``, ``split`` (0-4), ``sepnorm``, ``nsteps``, ``epochs``,
        ``dim``, ``lr1``..``lr4`` and ``test_batch_size1``..``test_batch_size4``.

    Side effects
    ------------
    Creates ``<res_dir>/<type><noise>/<pace>`` and ``model_dir``; writes the
    final global weights to ``<model_dir>/<split>.pth``, the last epoch's
    per-site outputs to ``<SITE>_<split>.h5`` and the per-epoch training
    losses to ``train_loss.h5`` in the result directory; prints progress.

    Raises
    ------
    ValueError
        If ``args.split`` is not one of 0..4.
    """
    torch.manual_seed(args.seed)

    res_dir = os.path.join(args.res_dir, args.type + str(args.noise), str(args.pace))
    os.makedirs(res_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    if args.split not in range(_N_FOLDS):
        raise ValueError('split must be one of 0..4, got {!r}'.format(args.split))

    # ---- load features, labels and fold indices for every site ----------
    id_dir = getattr(args, 'id_dir', './idx')
    id_suffix = '_sub_overlap.h5' if args.overlap else '_sub.h5'
    x_train, y_train, x_test, y_test = [], [], [], []
    for site in _SITE_NAMES:
        site_data = dd.io.load(os.path.join(args.vec_dir, site + '_correlation_matrix.h5'))
        features = torch.from_numpy(site_data['data']).float()
        labels = torch.from_numpy(site_data['label']).long()
        fold_ids = dd.io.load(os.path.join(id_dir, site + id_suffix))
        train_idx, test_idx = _fold_indices(fold_ids, args.split)
        x_train.append(features[train_idx])
        y_train.append(labels[train_idx])
        x_test.append(features[test_idx])
        y_test.append(labels[test_idx])

    x_train, x_test = _standardize(x_train, x_test, args.sepnorm)

    # ---- data loaders ----------------------------------------------------
    train_sets = [TensorDataset(x, y) for x, y in zip(x_train, y_train)]
    # Batch size is chosen so that one epoch provides at least nsteps batches.
    train_loaders = [
        DataLoader(ds, batch_size=len(ds) // args.nsteps, shuffle=True)
        for ds in train_sets
    ]
    train_loader = DataLoader(ConcatDataset(train_sets), batch_size=500, shuffle=False)

    test_batch_sizes = [args.test_batch_size1, args.test_batch_size2,
                        args.test_batch_size3, args.test_batch_size4]
    test_loaders = [
        DataLoader(TensorDataset(x, y), batch_size=bs, shuffle=False)
        for x, y, bs in zip(x_test, y_test, test_batch_sizes)
    ]

    # ---- models and optimisers -------------------------------------------
    # NOTE(review): MLP is neither defined nor imported in this cell/file; it
    # must be importable from wherever the model classes live.
    # NOTE(review): the input width 6105 is hard-coded — presumably the length
    # of one vectorised correlation matrix; confirm against the data files.
    models = [MLP(6105, args.dim, 2).to(device) for _ in _SITE_NAMES]
    learning_rates = [args.lr1, args.lr2, args.lr3, args.lr4]
    optimizers = [
        optim.Adam(m.parameters(), lr=lr, weight_decay=5e-2)
        for m, lr in zip(models, learning_rates)
    ]
    n_sites = len(models)
    # Uniform aggregation weights.
    weights = [0.25] * n_sites

    model = MLP(6105, args.dim, 2).to(device)
    print(model)
    nnloss = nn.NLLLoss()

    def sample_noise(param):
        """Draw noise shaped like ``param`` with scale ``args.noise * std(param)``."""
        scale = args.noise * torch.std(param.detach().cpu())
        if args.type == 'G':
            noise_dist = tdist.Normal(torch.tensor([0.0]), scale)
        else:
            noise_dist = tdist.Laplace(torch.tensor([0.0]), scale)
        return noise_dist.sample(param.size()).squeeze().to(device)

    def aggregate():
        """Average the local models into the global model and broadcast it back."""
        with torch.no_grad():
            global_state = model.state_dict()
            local_states = [m.state_dict() for m in models]
            for key, global_param in global_state.items():
                if local_states[0][key].dtype == torch.int64:
                    # Integer entries are taken from the first site, not averaged.
                    global_param.data.copy_(local_states[0][key])
                    continue
                mixed = torch.zeros_like(global_param)
                for weight, state in zip(weights, local_states):
                    contribution = state[key]
                    # A zero noise level means "no noise"; skipping the draw
                    # also avoids building a distribution with scale 0, which
                    # argument validation rejects.
                    if args.noise != 0:
                        contribution = contribution + sample_noise(contribution)
                    mixed += weight * contribution
                # update global model
                global_param.data.copy_(mixed)
                # update local models
                for state in local_states:
                    state[key].data.copy_(global_param)

    def train(epoch):
        """Run ``args.nsteps`` local steps per site; return the four mean losses."""
        for local_model, optimizer in zip(models, optimizers):
            local_model.train()
            # Halve the learning rate every 20 epochs.
            # NOTE(review): this also fires at epoch 0, so the configured rate
            # is halved before the first step — confirm that is intended.
            if epoch % 20 == 0:
                for group in optimizer.param_groups:
                    group['lr'] = 0.5 * group['lr']

        # Fresh iterators every epoch: a single set created up front runs dry
        # during the second epoch and raises StopIteration.
        batch_iters = [iter(loader) for loader in train_loaders]

        loss_sum = [0] * n_sites
        n_seen = [0] * n_sites
        for t in range(args.nsteps):
            for i, (local_model, optimizer) in enumerate(zip(models, optimizers)):
                optimizer.zero_grad()
                inputs, labels = next(batch_iters[i])
                n_seen[i] += labels.size(0)
                inputs = inputs.to(device)
                labels = labels.to(device)
                loss = nnloss(local_model(inputs), labels)
                loss.backward()
                loss_sum[i] += loss.item() * labels.size(0)
                optimizer.step()
            if (t + 1) % args.pace == 0 or t == args.nsteps - 1:
                aggregate()

        return tuple(loss_sum[i] / n_seen[i] for i in range(n_sites))

    def test(federated_model, dataloader, train=False):
        """Evaluate ``federated_model``; return (loss, accuracy, targets, outputs, preds)."""
        federated_model.eval()
        test_loss = 0
        correct = 0
        outputs = []
        preds = []
        targets = []
        with torch.no_grad():
            for data, target in dataloader:
                # Only the first label of each batch is recorded.
                # NOTE(review): presumably every batch holds samples that share
                # one label (batch size = samples per subject) — confirm.
                targets.append(target[0].detach().numpy())
                data = data.to(device)
                target = target.to(device)
                output = federated_model(data)
                outputs.append(output.detach().cpu().numpy())
                test_loss += nnloss(output, target).item() * target.size(0)
                pred = output.data.max(1)[1]
                preds.append(pred.detach().cpu().numpy())
                correct += pred.eq(target.view(-1)).sum().item()

        test_loss /= len(dataloader.dataset)
        correct /= len(dataloader.dataset)
        if train:
            print('Train set local: Average loss: {:.4f}, Average acc: {:.4f}'.format(test_loss, correct))
        else:
            print('Test set local: Average loss: {:.4f}, Average acc: {:.4f}'.format(test_loss, correct))
        return test_loss, correct, targets, outputs, preds

    # ---- training loop -----------------------------------------------------
    best_acc = 0
    train_loss = {i: [] for i in range(n_sites)}
    last_eval = {}  # site name -> outputs/preds/targets of the latest epoch
    for epoch in range(args.epochs):
        start_time = time.time()
        print(f"Epoch Number {epoch + 1}")
        losses = train(epoch)
        print(' L1 loss: {:.4f}, L2 loss: {:.4f}, L3 loss: {:.4f}, L4 loss: {:.4f}'.format(*losses))
        for i, site_loss in enumerate(losses):
            train_loss[i].append(site_loss)
        test(model, train_loader, train=True)

        accs = []
        for site, loader in zip(_SITE_NAMES, test_loaders):
            print('==={}==='.format(site))
            _, acc, targets, outputs, preds = test(model, loader, train=False)
            accs.append(acc)
            last_eval[site] = {'outputs': outputs, 'preds': preds, 'targets': targets}
        mean_acc = sum(accs) / len(accs)
        if mean_acc > best_acc:
            best_acc = mean_acc
        total_time = time.time() - start_time
        print('Communication time over the network', round(total_time, 2), 's\n')

    # ---- persist results ---------------------------------------------------
    model_wts = copy.deepcopy(model.state_dict())
    torch.save(model_wts, os.path.join(args.model_dir, str(args.split) + '.pth'))
    for site, result in last_eval.items():
        dd.io.save(os.path.join(res_dir, site + '_' + str(args.split) + '.h5'), result)
    dd.io.save(os.path.join(res_dir, 'train_loss.h5'), {'loss': train_loss})
    print('Best Acc:', best_acc)
    print('split:', args.split, '   noise:', args.noise, '   pace:', args.pace)


#==========================================================================
def str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` would treat every non-empty string, including ``'False'``,
    as ``True``. An empty string maps to ``False``.

    Raises
    ------
    argparse.ArgumentTypeError
        If the token is not a recognised true/false spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # specify for dataset site
    parser.add_argument('--split', type=int, default=0, choices=[0, 1, 2, 3, 4], help='select 0-4 fold')
    parser.add_argument('--pace', type=int, default=20, help='communication pace')
    parser.add_argument('--noise', type=float, default=0, help='noise level for gaussian or err level for Lap')
    parser.add_argument('--type', type=str, default='G', help='Gaussian or Lap')
    # do not need to change
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr1', type=float, default=1e-5)
    parser.add_argument('--lr2', type=float, default=1e-5)
    parser.add_argument('--lr3', type=float, default=1e-5)
    parser.add_argument('--lr4', type=float, default=1e-5)
    parser.add_argument('--clip', type=float, default=5.0, help='gradient clip')
    parser.add_argument('--dim', type=int, default=16,help='hidden dim of MLP')
    parser.add_argument('--nsteps', type=int, default=60, help='training steps/epoach')
    parser.add_argument('-tbs1', '--test_batch_size1', type=int, default=145, help='NYU test batch size')
    parser.add_argument('-tbs2', '--test_batch_size2', type=int, default=265, help='UM test batch size')
    parser.add_argument('-tbs3', '--test_batch_size3', type=int, default=205, help='USM test batch size')
    parser.add_argument('-tbs4', '--test_batch_size4', type=int, default=85, help='UCLA test batch size')
    parser.add_argument('--overlap', type=str2bool, default=True, help='augmentation method')
    parser.add_argument('--sepnorm', type=str2bool, default=True, help='normalization method')
    parser.add_argument('--id_dir', type=str, default='./idx')
    parser.add_argument('--res_dir', type=str, default='/content/result/fed_overlap')
    parser.add_argument('--vec_dir', type=str, default='/content/data/HO_vector_overlap')
    parser.add_argument('--model_dir', type=str, default='/content/model/fed_overlap')
    # Parse first so that --help or a bad argument never touches the filesystem.
    args = parser.parse_args()
    # Ensure the parent of each configured directory exists (with the default
    # paths: /content/result, /content/data and /content/model). exist_ok makes
    # re-running the script safe.
    for configured_dir in (args.res_dir, args.vec_dir, args.model_dir):
        parent = os.path.dirname(os.path.normpath(configured_dir)) or '.'
        os.makedirs(parent, exist_ok=True)
    main(args)

Overwriting train.py


In [None]:
!python train.py

Traceback (most recent call last):
  File "train.py", line 364, in <module>
    main(args)
  File "train.py", line 30, in main
    data1 = dd.io.load(os.path.join(args.vec_dir,'NYU_correlation_matrix.h5'))
  File "/usr/local/lib/python3.7/dist-packages/deepdish/io/hdf5io.py", line 635, in load
    with tables.open_file(path, mode='r') as h5file:
  File "/usr/local/lib/python3.7/dist-packages/tables/file.py", line 320, in open_file
    return File(filename, mode, title, root_uep, filters, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tables/file.py", line 784, in __init__
    self._g_new(filename, mode, **params)
  File "tables/hdf5extension.pyx", line 371, in tables.hdf5extension.File._g_new
  File "/usr/local/lib/python3.7/dist-packages/tables/utils.py", line 157, in check_file_access
    raise IOError("``%s`` does not exist" % (filename,))
OSError: ``/content/data/HO_vector_overlap/NYU_correlation_matrix.h5`` does not exist
