<a href="https://colab.research.google.com/github/DiogoLepri/ASD2_Project/blob/main/ASD-DiagNet2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive (Colab-only helper); the data paths used
# by later cells live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Make the project folder the working directory; the correlation cache file
# written/read further down uses a path relative to it.
%cd /content/drive/MyDrive/ASD2_Project/

/content/drive/MyDrive/ASD2_Project


In [5]:
!pip install pandas numpy matplotlib scikit-learn torch pyprind scipy




In [1]:
# Experiment configuration read by the cells below.
# ROI atlas: selects the rois_<atlas> input folder and the name of the cache file.
#options: cc200, dosenbach160, aal
p_ROI = "cc200"
# Number of cross-validation folds.
p_fold = 10
# Site to evaluate; only used when p_mode == "percenter".
p_center = "Caltech"
# "whole" runs the all-sites cell, "percenter" runs the single-site cell.
p_mode = "whole"
# Forwarded as the augmentation switch of the training data loader.
p_augmentation = True
# Method selector checked by the eigen-decomposition and training cells.
p_Method = "ASD-DiagNet"

In [2]:
parameter_list = [p_ROI, p_fold, p_center, p_mode, p_augmentation, p_Method]

# Collect the (label, value) rows first, then print them in a single pass.
# The center is only reported in per-center mode and the augmentation flag
# only for the ASD-DiagNet method.
summary_rows = [("ROI atlas: ", p_ROI), ("per Center or whole: ", p_mode)]
if p_mode == 'percenter':
    summary_rows.append(("Center's name: ", p_center))
summary_rows.append(("Method's name: ", p_Method))
if p_Method == "ASD-DiagNet":
    summary_rows.append(("Augmentation: ", p_augmentation))

print("*****List of patameters****")
for row_label, row_value in summary_rows:
    print(row_label, row_value)


*****List of patameters****
ROI atlas:  cc200
per Center or whole:  whole
Method's name:  ASD-DiagNet
Augmentation:  True


In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from functools import reduce
from sklearn.impute import SimpleImputer
import time
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
import pyprind
import sys
import pickle
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold, StratifiedKFold
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from scipy import stats
from sklearn import tree
import functools
import numpy.ma as ma # for masked arrays
import pyprind
import random
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier

## Importing the data

In [4]:
def get_key(filename):
    """Return the subject key encoded at the start of a ROI file name.

    The name is split on ``'_'``. When the fourth field is the literal
    ``'rois'`` the key is the first three fields joined by ``'_'``;
    otherwise it is the first two fields.

    Names with fewer than four fields (stray files in the data folder, for
    instance) used to raise ``IndexError``; they now follow the two-field
    rule instead.

    Args:
        filename: file name (not a path) to derive the key from.

    Returns:
        The key as a string.
    """
    parts = filename.split('_')
    # Guard the index so short names cannot raise.
    n_key_fields = 3 if len(parts) > 3 and parts[3] == 'rois' else 2
    return '_'.join(parts[:n_key_fields])

In [5]:
# Path to your time series data
data_main_path = f'/content/drive/MyDrive/ASD2_Project/Outputs/cpac/filt_global/rois_{p_ROI}'
flist = os.listdir(data_main_path)
print("Number of files in data_main_path:", len(flist))

# Convert filenames to subject keys
flist = [get_key(fname) for fname in flist]

# Load phenotypic CSV
df_labels = pd.read_csv('/content/drive/MyDrive/ASD2_Project/Phenotypic_V1_0b_preprocessed1.csv')
print("Columns in df_labels:", df_labels.columns)

# Convert ASD/Control to 1/0
df_labels['DX_GROUP'] = df_labels['DX_GROUP'].map({1: 1, 2: 0})
print("Number of rows in df_labels:", len(df_labels))

# Z-score the age column (assumes the column is named 'AGE_AT_SCAN';
# adjust if your CSV uses a different name).
# NOTE(review): mean and std are taken over every row of the CSV, i.e. also
# over subjects that later end up in test folds.
df_labels['AGE_AT_SCAN'] = df_labels['AGE_AT_SCAN'].astype(np.float32)
ages = df_labels['AGE_AT_SCAN']
ages_norm = (ages - ages.mean()) / ages.std()

# Build both lookup tables in one column-wise pass:
#   labels:    FILE_ID -> 0/1 diagnosis
#   phen_dict: FILE_ID -> [normalized age]
# Zipping the columns keeps each normalized age aligned with its own row
# whatever the DataFrame index looks like, and avoids two slow iterrows() loops.
labels = {}
phen_dict = {}
for file_id, y_label, norm_age in zip(df_labels['FILE_ID'],
                                      df_labels['DX_GROUP'],
                                      ages_norm.to_numpy()):
    if file_id == 'no_filename':
        continue
    assert file_id not in labels, f"duplicate FILE_ID in phenotypic CSV: {file_id}"
    labels[file_id] = y_label
    phen_dict[file_id] = [norm_age]

# Quick check
print("Example entries from phen_dict (FILE_ID: [norm_age]):", list(phen_dict.items())[:5])

Number of files in data_main_path: 884
Columns in df_labels: Index(['Unnamed: 0.1', 'Unnamed: 0', 'SUB_ID', 'X', 'subject', 'SITE_ID',
       'FILE_ID', 'DX_GROUP', 'DSM_IV_TR', 'AGE_AT_SCAN',
       ...
       'qc_notes_rater_1', 'qc_anat_rater_2', 'qc_anat_notes_rater_2',
       'qc_func_rater_2', 'qc_func_notes_rater_2', 'qc_anat_rater_3',
       'qc_anat_notes_rater_3', 'qc_func_rater_3', 'qc_func_notes_rater_3',
       'SUB_IN_SMP'],
      dtype='object', length=106)
Number of rows in df_labels: 1112
Example entries from phen_dict (FILE_ID: [norm_age]): [('Pitt_0050003', [0.92095]), ('Pitt_0050004', [0.25398603]), ('Pitt_0050005', [-0.41297793]), ('Pitt_0050006', [-0.45777398]), ('Pitt_0050007', [0.09097814])]


### Helper functions for computing correlations

In [6]:
def get_label(filename):
    """Look up the diagnosis label stored for ``filename`` in the module-level ``labels`` dict.

    An unknown key trips the assertion (AssertionError).
    """
    is_known = filename in labels
    assert is_known
    return labels[filename]


def get_corr_data(filename, data_dir=None):
    """Return the strictly-lower-triangular correlations of one subject as a flat vector.

    The folder is scanned for entries whose name starts with ``filename``; the
    matching file is read as a tab-separated table, the correlation matrix of
    its columns is computed (NaNs replaced by 0) and the entries below the
    diagonal are returned in row-major order.

    Args:
        filename: subject key / file-name prefix to look for.
        data_dir: folder to search. Defaults to the module-level
            ``data_main_path`` when omitted.

    Returns:
        1-D ``numpy.ndarray`` holding the entries below the main diagonal.

    Raises:
        FileNotFoundError: if no entry of the folder starts with ``filename``
            (previously this surfaced as an unbound-variable error).
    """
    folder = data_main_path if data_dir is None else data_dir
    matches = [name for name in os.listdir(folder) if name.startswith(filename)]
    if not matches:
        raise FileNotFoundError(
            "no file starting with %r found in %r" % (filename, folder))
    # The original loop ended up using the last matching directory entry;
    # keep that choice but read a single file instead of every match.
    df = pd.read_csv(os.path.join(folder, matches[-1]), sep='\t')

    with np.errstate(invalid="ignore"):
        corr = np.nan_to_num(np.corrcoef(df.T))
    # Row-major indices of the strict lower triangle (k=-1 excludes the diagonal).
    return corr[np.tril_indices(corr.shape[0], k=-1)]

def get_corr_matrix(filename, data_dir=None):
    """Return the full correlation matrix of one subject's columns.

    The folder is scanned for entries whose name starts with ``filename``; the
    matching file is read as a tab-separated table and the correlation matrix
    of its columns is returned with NaNs replaced by 0.

    Args:
        filename: subject key / file-name prefix to look for.
        data_dir: folder to search. Defaults to the module-level
            ``data_main_path`` when omitted.

    Returns:
        2-D square ``numpy.ndarray``.

    Raises:
        FileNotFoundError: if no entry of the folder starts with ``filename``
            (previously this surfaced as an unbound-variable error).
    """
    folder = data_main_path if data_dir is None else data_dir
    matches = [name for name in os.listdir(folder) if name.startswith(filename)]
    if not matches:
        raise FileNotFoundError(
            "no file starting with %r found in %r" % (filename, folder))
    # The original loop ended up using the last matching directory entry;
    # keep that choice but read a single file instead of every match.
    df = pd.read_csv(os.path.join(folder, matches[-1]), sep='\t')
    with np.errstate(invalid="ignore"):
        return np.nan_to_num(np.corrcoef(df.T))

def confusion(g_turth,predictions):
    """Compute accuracy, sensitivity and specificity for binary 0/1 predictions.

    Args:
        g_turth: ground-truth labels (0 or 1); column-shaped input is flattened.
        predictions: predicted labels (0 or 1), same length as ``g_turth``.

    Returns:
        Tuple ``(accuracy, sensitivity, specificity)``. A ratio whose
        denominator is zero (e.g. sensitivity when there are no positives)
        is returned as NaN.
    """
    y_true = np.asarray(g_turth).ravel()
    y_hat = np.asarray(predictions).ravel()
    # Fix the label set: without it sklearn returns a 1x1 matrix when only one
    # class occurs in both inputs, and the four-way unpacking below fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_hat, labels=[0, 1]).ravel()

    def _ratio(numerator, denominator):
        # Same value numpy's 0/0 would give, without the RuntimeWarning.
        return numerator / denominator if denominator else float('nan')

    accuracy = _ratio(tp + tn, tp + fp + tn + fn)
    sensitivity = _ratio(tp, tp + fn)
    specificty = _ratio(tn, tn + fp)
    return accuracy,sensitivity,specificty

def get_regs(samplesnames,regnum, corr_data=None):
    """Pick the feature indices with the highest and lowest mean value over a set of samples.

    Args:
        samplesnames: iterable of sample keys to average over.
        regnum: how many indices to take from each end of the ranking.
        corr_data: mapping ``key -> (feature_vector, label)``. Defaults to the
            module-level ``all_corr`` when omitted.

    Returns:
        1-D integer array of length ``2 * regnum``: first the ``regnum``
        indices with the largest mean (largest first), then the ``regnum``
        indices with the smallest mean (ordered from larger to smaller mean).
        ``regnum <= 0`` yields an empty array (the former ``[-0:]`` slice
        returned every index in that case).
    """
    source = all_corr if corr_data is None else corr_data
    if regnum <= 0:
        return np.array([], dtype=np.intp)
    datas = np.array([source[sn][0] for sn in samplesnames])
    # One vectorised column mean and one argsort instead of a Python loop over
    # every feature and two separate sorts.
    avg = datas.mean(axis=0)
    order = avg.argsort()
    highs = order[-regnum:][::-1]
    lows = order[:regnum][::-1]
    return np.concatenate((highs, lows), axis=0)


## Computing the correlations (cached to disk)

In [7]:
# Correlation features are computed once per atlas and cached as a pickle in
# the working directory; later runs load the cache instead of recomputing.
corr_cache_path = './correlations_file'+p_ROI+'.pkl'

if not os.path.exists(corr_cache_path):
    pbar=pyprind.ProgBar(len(flist))
    all_corr = {}
    for f in flist:
        # Label lookup first so an unknown key fails before the file is read.
        lab = get_label(f)
        all_corr[f] = (get_corr_data(f), lab)
        pbar.update()

    print('Corr-computations finished')

    # Context manager so the handle is flushed and closed even on error.
    with open(corr_cache_path, 'wb') as cache_file:
        pickle.dump(all_corr, cache_file)
    print('Saving to file finished')

else:
    # NOTE: unpickling can execute arbitrary code; only load a cache file that
    # this notebook produced.
    with open(corr_cache_path, 'rb') as cache_file:
        all_corr = pickle.load(cache_file)

0% [##############################] 100% | ETA: 00:00:00
Total time elapsed: 00:00:37


Corr-computations finished
Saving to file finished


## Computing eigenvalues and eigenvectors

In [8]:
if p_Method=="ASD-DiagNet":
    # eig_data[key] holds, per subject, the eigenvalue magnitudes sorted from
    # largest to smallest, the same magnitudes divided by their sum, and the
    # eigenvectors in the matching order.
    eig_data = {}
    pbar = pyprind.ProgBar(len(flist))
    for f in flist:
        corr_mat = get_corr_matrix(f)
        eig_vals, eig_vecs = np.linalg.eig(corr_mat)

        # Sanity check: every eigenvector (a column of eig_vecs) has unit norm.
        for column in eig_vecs.T:
            np.testing.assert_array_almost_equal(1.0, np.linalg.norm(column))

        magnitudes = np.abs(eig_vals)
        total_magnitude = np.sum(magnitudes)

        # Stable descending order: ties keep their original relative position.
        order = np.argsort(-magnitudes, kind='stable')

        eig_data[f] = {'eigvals': magnitudes[order],
                       'norm-eigvals': (magnitudes / total_magnitude)[order],
                       'eigvecs': [eig_vecs[:, j] for j in order]}
        pbar.update()

0% [##############################] 100% | ETA: 00:00:00
Total time elapsed: 00:01:45


## Calculating Eros similarity

In [9]:
def norm_weights(sub_flist, eig_source=None):
    """Sum the normalized eigenvalue vectors of the given samples.

    Args:
        sub_flist: iterable of sample keys.
        eig_source: mapping ``key -> {'norm-eigvals': array, ...}``. Defaults
            to the module-level ``eig_data`` when omitted.

    Returns:
        1-D float array: element-wise sum of ``'norm-eigvals'`` over
        ``sub_flist``. The sum is not divided by the number of samples.
        An empty ``sub_flist`` yields zeros.
    """
    source = eig_data if eig_source is None else eig_source
    names = list(sub_flist)
    # Take the vector length from the samples actually summed rather than from
    # an unrelated global file list.
    if names:
        num_dim = len(source[names[0]]['norm-eigvals'])
    else:
        any_entry = next(iter(source.values()), None)
        num_dim = 0 if any_entry is None else len(any_entry['norm-eigvals'])

    total = np.zeros(shape=num_dim)
    for name in names:
        total += source[name]['norm-eigvals']
    return total

def cal_similarity(d1, d2, weights, lim=None, use_abs=False):
    """Weighted sum of inner products between matching vectors of two sequences.

    Computes ``sum_i w_i * <d1[i], d2[i]>`` over the entries of the weight
    vector in use.

    Args:
        d1, d2: indexable sequences of 1-D vectors (e.g. eigenvectors).
        weights: 1-D array of weights. It is never modified.
        lim: if given, only the first ``lim`` weights are used and they are
            rescaled to sum to 1. If ``None`` all weights are used unscaled.
        use_abs: if True, the absolute value of each inner product is used,
            which makes the score independent of the arbitrary sign of an
            eigenvector. Defaults to False, the previous behaviour.

    Returns:
        The similarity score as a scalar.
    """
    weights = np.asarray(weights)
    if lim is None:
        weights_arr = weights
    else:
        head = weights[:lim]
        # Out-of-place division: also works for integer weight arrays, where
        # an in-place true division raises.
        weights_arr = head / np.sum(head)

    res = 0.0
    for i, w in enumerate(weights_arr):
        projection = np.inner(d1[i], d2[i])
        if use_abs:
            projection = np.abs(projection)
        res += w * projection
    return res

## Defining dataset class

In [10]:
class CC200Dataset(Dataset):
    """Map-style dataset over per-subject feature vectors with optional mixing augmentation.

    ``data`` maps a sample key to ``(feature_vector, label)``. Indices
    ``0 .. len(flist) - 1`` return the real samples. When augmentation is
    enabled the dataset reports ``aug_factor * len(flist)`` items and every
    higher index returns a random convex combination of a sample and one of
    its ``num_neighbs`` most similar same-label neighbours.

    Each item is ``(features, label)`` or, when ``phenotype_data`` is given,
    ``(features, phenotype, label)``, all as ``torch.FloatTensor``.

    Args:
        pkl_filename: pickle file holding the ``data`` mapping (takes
            precedence over ``data``).
        data: the mapping itself; a shallow copy is stored.
        samples_list: keys to expose; defaults to every key of ``data``.
        phenotype_data: optional mapping ``key -> sequence of floats``.
        augmentation: enable the neighbour-mixing samples.
        aug_factor: total size multiplier when augmentation is on.
        num_neighbs: number of nearest neighbours kept per sample.
        eig_data: mapping ``key -> {'eigvecs': ...}`` used for similarity.
        similarity_fn: callable ``(eigvecs_a, eigvecs_b, weights) -> score``.
        verbose: print a note when loading from pickle.
        regs: index array selecting the features to return; ``None`` keeps
            the whole vector.

    Raises:
        ValueError: if neither ``pkl_filename`` nor ``data`` is given, or if
            augmentation is requested without ``eig_data``/``similarity_fn``.
    """

    def __init__(self, pkl_filename=None, data=None, samples_list=None,
                 phenotype_data=None,
                 augmentation=False, aug_factor=1, num_neighbs=5,
                 eig_data=None, similarity_fn=None, verbose=False, regs=None):
        self.regs = regs
        self.phenotype_data = phenotype_data
        if pkl_filename is not None:
            if verbose:
                print('Loading ..!', end=' ')
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            with open(pkl_filename, 'rb') as pkl_file:
                self.data = pickle.load(pkl_file)
        elif data is not None:
            self.data = data.copy()
        else:
            # Fail loudly instead of returning a half-initialised object that
            # breaks later with an AttributeError.
            raise ValueError('Either PKL file or data is needed!')

        # Prepare the sample list
        if samples_list is None:
            self.flist = [f for f in self.data]
        else:
            self.flist = [f for f in samples_list]
        self.labels = np.array([self.data[f][1] for f in self.flist])

        if augmentation:
            if eig_data is None or similarity_fn is None:
                raise ValueError('augmentation requires eig_data and similarity_fn')
            self.num_data = aug_factor * len(self.flist)
            self.neighbors = self._find_neighbors(eig_data, similarity_fn, num_neighbs)
        else:
            self.num_data = len(self.flist)

    def _find_neighbors(self, eig_data, similarity_fn, num_neighbs):
        """Return ``key -> [most similar same-label keys]`` (at most ``num_neighbs`` each)."""
        # Ordered, de-duplicated candidate pools per label. Iterating lists
        # rather than sets keeps tie-breaking between equally similar
        # candidates the same from run to run.
        pool_label0 = list(dict.fromkeys(
            f for f, lab in zip(self.flist, self.labels) if lab == 0))
        pool_label1 = list(dict.fromkeys(
            f for f, lab in zip(self.flist, self.labels) if lab == 1))

        # Weights come from the samples of this dataset (self.flist), which is
        # also valid when no explicit samples_list was passed.
        weights = norm_weights(self.flist)

        neighbors = {}
        pbar = pyprind.ProgBar(len(self.flist))
        for f in self.flist:
            label = self.data[f][1]
            pool = pool_label0 if label == 0 else pool_label1
            eig_f = eig_data[f]['eigvecs']
            sim_list = [(similarity_fn(eig_f, eig_data[cand]['eigvecs'], weights), cand)
                        for cand in pool if cand != f]
            # Stable sort: most similar first.
            sim_list.sort(key=lambda item: item[0], reverse=True)
            neighbors[f] = [cand for _, cand in sim_list[:num_neighbs]]
            pbar.update()
        return neighbors

    def _select_regions(self, vector):
        """Return a copy of ``vector`` restricted to ``self.regs`` (whole vector if ``regs`` is None)."""
        if self.regs is None:
            # Indexing with None would insert a new axis instead of selecting everything.
            return vector.copy()
        return vector[self.regs].copy()

    def _pack(self, features, fname, label):
        """Convert one sample to tensors, adding the phenotype entry when available."""
        x = torch.FloatTensor(features)
        y = torch.FloatTensor((label,))
        if self.phenotype_data is not None:
            return x, torch.FloatTensor(self.phenotype_data[fname]), y
        return x, y

    def __getitem__(self, index):
        # Non-augmented samples
        if index < len(self.flist):
            fname = self.flist[index]
            features = self._select_regions(self.data[fname][0])
            return self._pack(features, fname, self.labels[index])

        # Augmentation branch: mix the sample with one of its neighbours.
        f1 = self.flist[index % len(self.flist)]
        d1, y1 = self._select_regions(self.data[f1][0]), self.data[f1][1]
        if len(self.neighbors[f1]) > 0:
            f2 = np.random.choice(self.neighbors[f1])
        else:
            f2 = f1  # fallback to self if no neighbors exist
        d2, y2 = self._select_regions(self.data[f2][0]), self.data[f2][1]
        assert y1 == y2
        r = np.random.uniform(low=0, high=1)
        mixed = r * d1 + (1 - r) * d2
        # The phenotype of the first sample is used for the mixed sample.
        return self._pack(mixed, f1, y1)

    def __len__(self):
        return self.num_data


## Defining data loader function

In [11]:
def get_loader(pkl_filename=None, data=None, samples_list=None,
               batch_size=64,
               num_workers=1, mode='train',
               *, augmentation=False, aug_factor=1, num_neighbs=5,
               eig_data=None, similarity_fn=None, verbose=False, regions=None,
               phenotype_data=None):
    """Build a ``CC200Dataset`` and wrap it in a ``DataLoader``.

    Only ``mode='train'`` shuffles and honours ``augmentation``; every other
    mode yields an unshuffled, non-augmented loader.

    Args:
        pkl_filename, data, samples_list, aug_factor, num_neighbs, eig_data,
        similarity_fn, verbose, phenotype_data: forwarded to ``CC200Dataset``.
        batch_size, num_workers: forwarded to ``DataLoader``.
        mode: ``'train'`` or anything else for evaluation.
        augmentation: request augmented samples (train mode only).
        regions: feature indices, forwarded as the dataset's ``regs``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    is_train = (mode == 'train')

    dataset = CC200Dataset(pkl_filename=pkl_filename, data=data, samples_list=samples_list,
                           augmentation=(augmentation if is_train else False),
                           aug_factor=aug_factor,
                           # Was accepted but never forwarded, so the dataset
                           # always fell back to its own default.
                           num_neighbs=num_neighbs,
                           eig_data=eig_data, similarity_fn=similarity_fn,
                           verbose=verbose, regs=regions,
                           phenotype_data=phenotype_data)

    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=is_train,
                      num_workers=num_workers)


## Defining Autoencoder class

In [12]:
class MTAutoEncoder(nn.Module):
    """Single-layer autoencoder with a linear classifier head on the latent code.

    The encoder is ``Linear(num_inputs, num_latent)`` followed by ``tanh``.
    With ``tied=True`` the decoder reuses the encoder weight matrix (no bias);
    otherwise a separate ``Linear(num_latent, num_inputs)`` is created. The
    classifier takes the latent code concatenated with one extra column
    (``age_data``) and outputs a single logit, optionally preceded by dropout.
    ``num_classes`` is accepted but not used.
    """

    def __init__(self, num_inputs=990,
                 num_latent=200, tied=True,
                 num_classes=2, use_dropout=False):
        super().__init__()
        self.tied = tied
        self.num_latent = num_latent

        self.fc_encoder = nn.Linear(num_inputs, num_latent)
        if not tied:
            self.fc_decoder = nn.Linear(num_latent, num_inputs)

        # Classifier input is the latent code plus one phenotype column.
        head_layers = [nn.Dropout(p=0.5)] if use_dropout else []
        head_layers.append(nn.Linear(self.num_latent + 1, 1))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, age_data=None, eval_classifier=False):
        """Return ``(reconstruction, logit)``; ``logit`` is None unless ``eval_classifier`` is set.

        Raises:
            ValueError: if ``eval_classifier`` is set and ``age_data`` is None.
        """
        latent = torch.tanh(self.fc_encoder(x))

        x_logit = None
        if eval_classifier:
            if age_data is None:
                raise ValueError("Classifier mode requires age_data input")
            # age_data is concatenated along dim 1, so it must be 2-D with the
            # same batch size as the latent code.
            x_logit = self.classifier(torch.cat((latent, age_data), dim=1))

        # Decode: tied weights reuse the encoder matrix, untied uses its own layer.
        if not self.tied:
            rec = self.fc_decoder(latent)
        else:
            rec = F.linear(latent, self.fc_encoder.weight.t())

        return rec, x_logit

mtae = MTAutoEncoder()
mtae


MTAutoEncoder(
  (fc_encoder): Linear(in_features=990, out_features=200, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=201, out_features=1, bias=True)
  )
)

## Defining training and testing functions

In [13]:
def train(model, epoch, train_loader, p_bernoulli=None, mode='both', lam_factor=1.0):
    """Run one training pass over ``train_loader`` and return the per-batch losses.

    Relies on the module-level ``batch_size``, ``device``, ``optimizer``,
    ``criterion_ae`` and ``criterion_clf``. Batches whose size differs from
    ``batch_size`` are skipped.

    Args:
        model: network called as ``model(images, age_data, eval_classifier)``
            and returning ``(reconstruction, logits)``.
        epoch: accepted for API compatibility; not used in the body.
        train_loader: yields ``(images, age_data, batch_y)`` batches.
        p_bernoulli: if not None, the autoencoder input is multiplied by a
            Bernoulli(p) mask (denoising); the target stays the clean input.
        mode: ``'both'`` (reconstruction + classification), ``'ae'`` or ``'clf'``.
        lam_factor: weight of the classification loss in ``'both'`` mode.

    Returns:
        List of ``[ae_loss, clf_loss]`` pairs, one per processed batch; the
        loss not used by ``mode`` is reported as ``0.0``.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values
            (previously this ended in an unbound-variable error).
    """
    if mode not in ('both', 'ae', 'clf'):
        raise ValueError("mode must be 'both', 'ae' or 'clf', got %r" % (mode,))

    model.train()
    train_losses = []
    p_tensor = None  # Bernoulli probabilities, built on the first processed batch
    for i, (images, age_data, batch_y) in enumerate(train_loader):
        # Skip incomplete batches
        if len(images) != batch_size:
            continue

        images = images.to(device)
        age_data = age_data.to(device)
        batch_y = batch_y.to(device)

        if p_bernoulli is not None:
            # Lazily created, so it exists whichever batch is processed first.
            if p_tensor is None:
                p_tensor = torch.ones_like(images) * p_bernoulli
            rand_bernoulli = torch.bernoulli(p_tensor)

        optimizer.zero_grad()

        # Autoencoder branch
        if mode in ('both', 'ae'):
            ae_input = images * rand_bernoulli if p_bernoulli is not None else images
            rec, _ = model(ae_input, age_data, False)
            # criterion_ae output is divided by the batch size.
            loss_ae = criterion_ae(rec, images) / len(images)

        # Classifier branch
        if mode in ('both', 'clf'):
            _, logits = model(images, age_data, True)
            loss_clf = criterion_clf(logits, batch_y)

        # Combine losses based on mode
        if mode == 'both':
            loss_total = loss_ae + lam_factor * loss_clf
            train_losses.append([loss_ae.detach().cpu().numpy(), loss_clf.detach().cpu().numpy()])
        elif mode == 'ae':
            loss_total = loss_ae
            train_losses.append([loss_ae.detach().cpu().numpy(), 0.0])
        else:  # mode == 'clf'
            loss_total = loss_clf
            train_losses.append([0.0, loss_clf.detach().cpu().numpy()])

        loss_total.backward()
        optimizer.step()

    return train_losses



def test(model, criterion, test_loader, eval_classifier=False, num_batch=None):
    """Evaluate ``model`` on ``test_loader`` and return classification metrics.

    Relies on the module-level ``device`` and ``confusion``. The model is
    switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: network called as ``model(images, age_data, eval_classifier)``
            and returning ``(reconstruction, logits)``.
        criterion: kept for API compatibility; the reconstruction loss it used
            to compute was never returned, so it is no longer evaluated.
        test_loader: yields ``(images, age_data, batch_y)`` batches.
        eval_classifier: compute predictions and metrics when True.
        num_batch: if given, only batches ``1 .. num_batch - 1`` (1-based)
            are evaluated, as before.

    Returns:
        ``(accuracy, sensitivity, specificity)`` when ``eval_classifier`` is
        True. ``(None, None, None)`` otherwise (this case used to raise an
        unbound-variable error).
    """
    y_true, y_hat = [], []
    with torch.no_grad():
        model.eval()
        for i, (images, age_data, batch_y) in enumerate(test_loader, 1):
            if num_batch is not None and i >= num_batch:
                # Nothing after this point is evaluated, so stop iterating.
                break
            images = images.to(device)
            age_data = age_data.to(device)
            batch_y = batch_y.to(device)

            _, logits = model(images, age_data, eval_classifier)
            if eval_classifier:
                proba = torch.sigmoid(logits).detach().cpu().numpy()
                # Threshold at 0.5: probabilities below it become class 0.
                preds = np.ones_like(proba, dtype=np.int32)
                preds[proba < 0.5] = 0
                y_hat.extend(preds)
                y_true.extend(batch_y.cpu().numpy().astype(np.int32).tolist())

    if not eval_classifier:
        return None, None, None

    mlp_acc, mlp_sens, mlp_spef = confusion(y_true, y_hat)
    return mlp_acc, mlp_sens, mlp_spef



In [14]:
# Use the GPU when CUDA is available, otherwise the CPU; the training and
# evaluation functions move every batch (and the cells below the model) to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [15]:

# Whole-dataset evaluation: 10 repeats of stratified p_fold-fold cross-validation.
# For every fold a fresh model is trained for num_epochs epochs (the first 20
# on reconstruction + classification, the rest on classification only) and the
# (accuracy, sensitivity, specificity) triple on the held-out fold is recorded.
if p_Method == "ASD-DiagNet" and p_mode == "whole":

    num_corr = len(all_corr[flist[0]][0])
    print("num_corr:  ",num_corr)

    batch_size = 8
    learning_rate_ae, learning_rate_clf = 0.0001, 0.0001
    num_epochs = 25

    p_bernoulli = None
    augmentation = p_augmentation
    use_dropout = False

    aug_factor = 2
    num_neighbs = 5
    lim4sim = 2
    n_lat = int(num_corr/4)
    print(n_lat)
    start= time.time()

    print('p_bernoulli: ', p_bernoulli)
    print('augmentaiton: ', augmentation, 'aug_factor: ', aug_factor,
          'num_neighbs: ', num_neighbs, 'lim4sim: ', lim4sim)
    print('use_dropout: ', use_dropout, '\n')


    sim_function = functools.partial(cal_similarity, lim=lim4sim)
    crossval_res_kol=[]
    flist = np.array(flist)
    for rp in range(10):
        kf = StratifiedKFold(n_splits=p_fold, random_state=1, shuffle=True)
        # Re-shuffle the subjects so every repeat sees different folds.
        np.random.shuffle(flist)
        y_arr = np.array([get_label(f) for f in flist])
        for kk,(train_index, test_index) in enumerate(kf.split(flist, y_arr)):
            train_samples, test_samples = flist[train_index], flist[test_index]

            verbose = (kk == 0)

            # Feature selection uses the training samples of this fold only.
            regions_inds = get_regs(train_samples,int(num_corr/4))

            num_inpp = len(regions_inds)
            n_lat = int(num_inpp/2)
            # phenotype_data is required: train() and test() unpack three
            # tensors per batch (features, age, label).
            train_loader=get_loader(data=all_corr, samples_list=train_samples,
                                    batch_size=batch_size, mode='train',
                                    augmentation=augmentation, aug_factor=aug_factor,
                                    num_neighbs=num_neighbs, eig_data=eig_data, similarity_fn=sim_function,
                                    verbose=verbose,regions=regions_inds,
                                    phenotype_data=phen_dict)

            test_loader=get_loader(data=all_corr, samples_list=test_samples,
                                   batch_size=batch_size, mode='test', augmentation=False,
                                   verbose=verbose,regions=regions_inds,
                                   phenotype_data=phen_dict)

            model = MTAutoEncoder(tied=True, num_inputs=num_inpp, num_latent=n_lat, use_dropout=use_dropout)
            model.to(device)
            criterion_ae = nn.MSELoss(reduction='sum')
            criterion_clf = nn.BCEWithLogitsLoss()
            optimizer = optim.SGD([{'params': model.fc_encoder.parameters(), 'lr': learning_rate_ae},
                                   {'params': model.classifier.parameters(), 'lr': learning_rate_clf}],
                                  momentum=0.9)

            for epoch in range(1, num_epochs+1):
                if epoch <= 20:
                    train_losses = train(model, epoch, train_loader, p_bernoulli, mode='both')
                else:
                    train_losses = train(model, epoch, train_loader, p_bernoulli, mode='clf')

            # Evaluate once and reuse the result for both the log and the tally.
            res_mlp = test(model, criterion_ae, test_loader, eval_classifier=True)
            print(res_mlp)
            crossval_res_kol.append(res_mlp)
        # Running average over all folds of all repeats so far.
        print("averages:")
        print(np.mean(np.array(crossval_res_kol),axis = 0))
        finish= time.time()

        print(finish-start)



num_corr:   19900
4975
p_bernoulli:  None
augmentaiton:  True aug_factor:  2 num_neighbs:  5 lim4sim:  2
use_dropout:  False 



KeyboardInterrupt: 

In [20]:
# Configuration override for the per-center (intra-site) run below.
# These assignments replace the values set in the first configuration cell,
# so the cell order matters when re-running the notebook.
p_ROI = "cc200"
p_fold = 5  # Use 5-fold for intra-site evaluation
# Site whose subjects are evaluated; compared against the first '_' field of each subject key.
p_center = "KKI"
p_mode = "percenter"  # Change to "percenter" for intra-site evaluation
p_augmentation = True
p_Method = "ASD-DiagNet"

In [21]:
# Per-center evaluation: 10 repeats of stratified cross-validation restricted
# to the subjects of one site (p_center).
if p_Method == "ASD-DiagNet" and p_mode == "percenter":
    # Rebuild the subject list from disk first so nothing below depends on
    # whatever an earlier cell left in `flist`.
    flist = os.listdir(data_main_path)
    flist = [get_key(f) for f in flist]
    num_corr = len(all_corr[flist[0]][0])

    # Group subject keys by site: the site is the first '_'-separated field.
    centers_dict = {}
    for f in flist:
        key = f.split('_')[0]
        centers_dict.setdefault(key, []).append(f)

    if p_center not in centers_dict:
        raise KeyError(f"Unknown center {p_center!r}; available centers: {sorted(centers_dict)}")

    flist = np.array(centers_dict[p_center])
    y_arr = np.array([get_label(f) for f in flist])

    # Determine the number of splits dynamically: a stratified split cannot
    # have more folds than the smallest class has members.
    unique_labels, counts = np.unique(y_arr, return_counts=True)
    new_n_splits = int(min(p_fold, counts.min()))

    if new_n_splits < 2:
        print(f"Skipping center {p_center} due to insufficient samples in one class.")
    else:
        print(f"Using {new_n_splits}-fold cross-validation for center {p_center}.")

        start = time.time()
        batch_size = 8
        learning_rate_ae, learning_rate_clf = 0.0001, 0.0001
        num_epochs = 25
        p_bernoulli = None
        augmentation = p_augmentation
        use_dropout = False
        aug_factor = 2
        num_neighbs = 5
        lim4sim = 2
        n_lat = int(num_corr / 4)

        sim_function = functools.partial(cal_similarity, lim=lim4sim)
        all_rp_res = []

        for rp in range(10):
            print(f"Running repeat {rp + 1} for center {p_center}...")
            crossval_res_kol = []
            # Shuffle with a per-repeat seed: without shuffling every repeat
            # would reuse exactly the same folds.
            kf = StratifiedKFold(n_splits=new_n_splits, shuffle=True, random_state=rp)
            for kk, (train_index, test_index) in enumerate(kf.split(flist, y_arr)):
                train_samples, test_samples = flist[train_index], flist[test_index]

                verbose = (kk == 0)
                # Feature selection uses the training samples of this fold only.
                regions_inds = get_regs(train_samples, int(num_corr / 4))
                num_inpp = len(regions_inds)
                n_lat = int(num_inpp / 2)

                train_loader = get_loader(data=all_corr, samples_list=train_samples,
                                          batch_size=batch_size, mode='train',
                                          augmentation=augmentation, aug_factor=aug_factor,
                                          num_neighbs=num_neighbs, eig_data=eig_data,
                                          similarity_fn=sim_function, verbose=verbose, regions=regions_inds,
                                          phenotype_data=phen_dict)

                test_loader = get_loader(data=all_corr, samples_list=test_samples,
                                         batch_size=batch_size, mode='test', augmentation=False,
                                         verbose=verbose, regions=regions_inds,
                                         phenotype_data=phen_dict)

                model = MTAutoEncoder(tied=True, num_inputs=num_inpp, num_latent=n_lat, use_dropout=use_dropout)
                model.to(device)
                criterion_ae = nn.MSELoss(reduction='sum')
                criterion_clf = nn.BCEWithLogitsLoss()
                optimizer = optim.SGD([{'params': model.fc_encoder.parameters(), 'lr': learning_rate_ae},
                                       {'params': model.classifier.parameters(), 'lr': learning_rate_clf}],
                                      momentum=0.9)

                # First 20 epochs: reconstruction + classification; afterwards classification only.
                for epoch in range(1, num_epochs + 1):
                    if epoch <= 20:
                        train_losses = train(model, epoch, train_loader, p_bernoulli, mode='both')
                    else:
                        train_losses = train(model, epoch, train_loader, p_bernoulli, mode='clf')

                res_mlp = test(model, criterion_ae, test_loader, eval_classifier=True)
                crossval_res_kol.append(res_mlp)

            repeat_mean = np.mean(np.array(crossval_res_kol), axis=0)
            print(f"Result of repeat {rp + 1} for center {p_center}: {repeat_mean}")
            all_rp_res.append(repeat_mean)

        print(f"Average result for 10 repeats for center {p_center}: {np.mean(np.array(all_rp_res), axis=0)}")
        finish = time.time()
        print(f"Total running time for center {p_center}: {finish - start:.2f} seconds")


Using 5-fold cross-validation for center KKI.
Running repeat 1 for center KKI...
Result of repeat 1 for center KKI: [0.69285714 0.         1.        ]
Running repeat 2 for center KKI...
Result of repeat 2 for center KKI: [0.69285714 0.         1.        ]
Running repeat 3 for center KKI...
Result of repeat 3 for center KKI: [0.69285714 0.         1.        ]
Running repeat 4 for center KKI...
Result of repeat 4 for center KKI: [0.69285714 0.         1.        ]
Running repeat 5 for center KKI...
Result of repeat 5 for center KKI: [0.69285714 0.         1.        ]
Running repeat 6 for center KKI...
Result of repeat 6 for center KKI: [0.69285714 0.         1.        ]
Running repeat 7 for center KKI...
Result of repeat 7 for center KKI: [0.69285714 0.         1.        ]
Running repeat 8 for center KKI...
Result of repeat 8 for center KKI: [0.69285714 0.         1.        ]
Running repeat 9 for center KKI...
Result of repeat 9 for center KKI: [0.69285714 0.         1.        ]
Running r

Results per center, listed as "paper result - my result":

Accuracy%

  O%   -   1%     -     2%

Caltech: 52.8% - 60.5%

CMU: 68.5% - 51.6%

KKI: 69.5% - 69.3%

Leuven: 61.3% - 54.1%

MaxMun: 48.6% - 63.9%

NYU: 68.0% - 67.9%

OHSU: 82% - 73.3%

Olin: 65.1% - 68.8%

Pitt: 67.8% - 63.3%

SBL: 51.6% - 43.6%

SDSU: 63.0% - 64.0%

Stanford: 64.2% - 64.9%

Trinity: 54.1% - 53.1%

UCLA: 73.2% - 67.1%

USM: 68.2% - 68.2%

UM: 63.8% - 65.7%

Yale: 63.6% - 64.1%