Georgia Tech CSE7850 Machine Learning in Computational Biology Spring 2024

Group 21 Members: Kyle Peters, Sophia Imhof, Shrramana Ganesh Sudhakar, Dan H Nguyen

In [None]:
!pip uninstall tensorflow
!pip install nilearn
!pip install deepdish
!pip install torch_geometric==2.4.0
!pip install torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
!pip install numpy==1.23.4
!pip install neptune

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu121.html


In [None]:
!pip install BrainGB

Collecting BrainGB
  Downloading BrainGB-1.0.4-py3-none-any.whl (18 kB)
Collecting node2vec>=0.4.3 (from BrainGB)
  Downloading node2vec-0.4.6-py3-none-any.whl (7.0 kB)
Collecting networkx>=2.5.1 (from BrainGB)
  Downloading networkx-2.8.8-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.2->BrainGB)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.10.2->BrainGB)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.10.2->BrainGB)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.10.2->BrainGB)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylin

In [None]:
from google.colab import userdata
import neptune
import os
import warnings
import glob
import csv
import numpy as np
import scipy.io as sio
from nilearn import connectome
import pandas as pd
from scipy.spatial import distance
from scipy import signal
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import Normalizer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
import networkx as nx
import numpy as np
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader
import torch_geometric
from sklearn.model_selection import StratifiedKFold
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from tqdm import tqdm
import enum
from torch import Tensor
from torch import nn
import torch
from sklearn.model_selection import train_test_split
warnings.filterwarnings("ignore")




In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# **Create Dataset**
Creating, processing, and loadding the ABIDE dataset were Adapted from https://github.com/HennyJie/BrainGB

In [None]:
root_path = "/content/drive/MyDrive/comp_bio/"

warnings.filterwarnings("ignore")
# Input data variables
class Reader:

    def __init__(self, root_path, id_file_path=None) -> None:

        root_folder = root_path
        self.data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal')
        self.phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv')
        self.id_file = id_file_path


    def fetch_filenames(self, subject_IDs, file_type, atlas):
        """
            subject_list : list of short subject IDs in string format
            file_type    : must be one of the available file types
            filemapping  : resulting file name format
        returns:
            filenames    : list of filetypes (same length as subject_list)
        """

        filemapping = {'func_preproc': '_func_preproc.nii.gz',
                    'rois_' + atlas: '_rois_' + atlas + '.1D'}
        # The list to be filled
        filenames = []

        # Fill list with requested file paths
        for i in range(len(subject_IDs)):
            os.chdir(self.data_folder)
            find_files = glob.glob('*' + subject_IDs[i] + filemapping[file_type])
            if len(find_files) > 0:
                filenames.append(find_files[0])
            else:
                if os.path.isdir(self.data_folder + '/' + subject_IDs[i]):
                    os.chdir(self.data_folder + '/' + subject_IDs[i])
                    filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0])
                else:
                    filenames.append('N/A')
        return filenames


    # Get timeseries arrays for list of subjects
    def get_timeseries(self, subject_list, atlas_name, silence=False):
        """
            subject_list : list of short subject IDs in string format
            atlas_name   : the atlas based on which the timeseries are generated e.g. aal, cc200
        returns:
            time_series  : list of timeseries arrays, each of shape (timepoints x regions)
        """

        timeseries = []
        for i in range(len(subject_list)):
            subject_folder = os.path.join(self.data_folder, subject_list[i])
            ro_file = [f for f in os.listdir(subject_folder) if f.endswith('_rois_' + atlas_name + '.1D')]
            fl = os.path.join(subject_folder, ro_file[0])
            if silence != True:
                print("Reading timeseries file %s" % fl)
            timeseries.append(np.loadtxt(fl, skiprows=0))

        return timeseries


    #  compute connectivity matrices
    def subject_connectivity(self, timeseries, subjects, atlas_name, kind, iter_no='', seed=1234,
                            n_subjects='', save=True, save_path=None):
        """
            timeseries   : timeseries table for subject (timepoints x regions)
            subjects     : subject IDs
            atlas_name   : name of the parcellation atlas used
            kind         : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
            iter_no      : tangent connectivity iteration number for cross validation evaluation
            save         : save the connectivity matrix to a file
            save_path    : specify path to save the matrix if different from subject folder
        returns:
            connectivity : connectivity matrix (regions x regions)
        """

        if kind in ['TPE', 'TE', 'correlation','partial correlation']:
            if kind not in ['TPE', 'TE']:
                conn_measure = connectome.ConnectivityMeasure(kind=kind)
                connectivity = conn_measure.fit_transform(timeseries)
            else:
                if kind == 'TPE':
                    conn_measure = connectome.ConnectivityMeasure(kind='correlation')
                    conn_mat = conn_measure.fit_transform(timeseries)
                    conn_measure = connectome.ConnectivityMeasure(kind='tangent')
                    connectivity_fit = conn_measure.fit(conn_mat)
                    connectivity = connectivity_fit.transform(conn_mat)
                else:
                    conn_measure = connectome.ConnectivityMeasure(kind='tangent')
                    connectivity_fit = conn_measure.fit(timeseries)
                    connectivity = connectivity_fit.transform(timeseries)

        if save:
            if not save_path:
                save_path = self.data_folder
            if kind not in ['TPE', 'TE']:
                for i, subj_id in enumerate(subjects):
                    subject_file = os.path.join(save_path, subj_id,
                                                subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat')
                    sio.savemat(subject_file, {'connectivity': connectivity[i]})
                return connectivity
            else:
                for i, subj_id in enumerate(subjects):
                    subject_file = os.path.join(save_path, subj_id,
                                                subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_' + str(
                                                    iter_no) + '_' + str(seed) + '_' + validation_ext + str(
                                                    n_subjects) + '.mat')
                    sio.savemat(subject_file, {'connectivity': connectivity[i]})
                return connectivity_fit


    # Get the list of subject IDs

    def get_ids(self, num_subjects=None):
        """
        return:
            subject_IDs    : list of all subject IDs
        """

        subject_IDs = np.genfromtxt(self.id_file, dtype=str)

        if num_subjects is not None:
            subject_IDs = subject_IDs[:num_subjects]

        return subject_IDs


    # Get phenotype values for a list of subjects
    def get_subject_score(self, subject_list, score):
        scores_dict = {}

        with open(self.phenotype) as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                if row['SUB_ID'] in subject_list:
                    if score == 'HANDEDNESS_CATEGORY':
                        if (row[score].strip() == '-9999') or (row[score].strip() == ''):
                            scores_dict[row['SUB_ID']] = 'R'
                        elif row[score] == 'Mixed':
                            scores_dict[row['SUB_ID']] = 'Ambi'
                        elif row[score] == 'L->R':
                            scores_dict[row['SUB_ID']] = 'Ambi'
                        else:
                            scores_dict[row['SUB_ID']] = row[score]
                    elif (score == 'FIQ' or score == 'PIQ' or score == 'VIQ'):
                        if (row[score].strip() == '-9999') or (row[score].strip() == ''):
                            scores_dict[row['SUB_ID']] = 100
                        else:
                            scores_dict[row['SUB_ID']] = float(row[score])

                    else:
                        scores_dict[row['SUB_ID']] = row[score]

        return scores_dict


    # preprocess phenotypes. Categorical -> ordinal representation
    @staticmethod
    def preprocess_phenotypes(pheno_ft, params):
        if params['model'] == 'MIDA':
            ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder='passthrough')
        else:
            ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder='passthrough')

        pheno_ft = ct.fit_transform(pheno_ft)
        pheno_ft = pheno_ft.astype('float32')

        return (pheno_ft)


    # create phenotype feature vector to concatenate with fmri feature vectors
    @staticmethod
    def phenotype_ft_vector(pheno_ft, num_subjects, params):
        gender = pheno_ft[:, 0]
        if params['model'] == 'MIDA':
            eye = pheno_ft[:, 0]
            hand = pheno_ft[:, 2]
            age = pheno_ft[:, 3]
            fiq = pheno_ft[:, 4]
        else:
            eye = pheno_ft[:, 2]
            hand = pheno_ft[:, 3]
            age = pheno_ft[:, 4]
            fiq = pheno_ft[:, 5]

        phenotype_ft = np.zeros((num_subjects, 4))
        phenotype_ft_eye = np.zeros((num_subjects, 2))
        phenotype_ft_hand = np.zeros((num_subjects, 3))

        for i in range(num_subjects):
            phenotype_ft[i, int(gender[i])] = 1
            phenotype_ft[i, -2] = age[i]
            phenotype_ft[i, -1] = fiq[i]
            phenotype_ft_eye[i, int(eye[i])] = 1
            phenotype_ft_hand[i, int(hand[i])] = 1

        if params['model'] == 'MIDA':
            phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1)
        else:
            phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1)

        return phenotype_ft


    # Load precomputed fMRI connectivity networks
    def get_networks(self, subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal",
                    variable='connectivity'):
        """
            subject_list : list of subject IDs
            kind         : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
            atlas_name   : name of the parcellation atlas used
            variable     : variable name in the .mat file that has been used to save the precomputed networks
        return:
            matrix      : feature matrix of connectivity networks (num_subjects x network_size)
        """

        all_networks = []
        for subject in subject_list:
            if len(kind.split()) == 2:
                kind = '_'.join(kind.split())
            fl = os.path.join(self.data_folder, subject,
                                subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat")


            matrix = sio.loadmat(fl)[variable]
            all_networks.append(matrix)

        if kind in ['TE', 'TPE']:
            norm_networks = [mat for mat in all_networks]
        else:
            norm_networks = [np.arctanh(mat) for mat in all_networks]

        networks = np.stack(norm_networks)

        return networks


class KeyDotDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(e)

    def __setattr__(self, key, value):
        self[key] = value

# Load Data

In [None]:
from nilearn import datasets
import argparse
import os
import shutil
import sys


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main(args):
    root_folder = args.root_path
    data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/')
    if not os.path.exists(data_folder):
        os.makedirs(data_folder)

    pipeline = args.pipeline
    atlas = args.atlas
    download = args.download

    # Files to fetch

    files = ['rois_' + atlas]

    filemapping = {'func_preproc': 'func_preproc.nii.gz',
                   files[0]: files[0] + '.1D'}


    # Download database files
    if download == True:
        abide = datasets.fetch_abide_pcp(data_dir=root_folder, pipeline=pipeline,
                                         band_pass_filtering=True, global_signal_regression=False, derivatives=files,
                                         quality_checked=False)
    reader = Reader(root_folder, args.id_file_path)
    subject_IDs = reader.get_ids() #changed path to data path
    subject_IDs = subject_IDs.tolist()

    # Create a folder for each subject
    for s, fname in zip(subject_IDs, reader.fetch_filenames(subject_IDs, files[0], atlas)):
        subject_folder = os.path.join(data_folder, s)
        if not os.path.exists(subject_folder):
            os.mkdir(subject_folder)

        # Get the base filename for each subject
        base = fname.split(files[0])[0]

        # Move each subject file to the subject folder
        for fl in files:
            if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])):
                shutil.move(base + filemapping[fl], subject_folder)

    time_series = reader.get_timeseries(subject_IDs, atlas)

    # Compute and save connectivity matrices
    reader.subject_connectivity(time_series, subject_IDs, atlas, 'correlation')
    reader.subject_connectivity(time_series, subject_IDs, atlas, 'partial correlation')


args = {
 "pipeline": "cpac",
 "atlas": "cc200",
 "download":True,
 "root_path":root_path,
 "id_file_path": root_path+"subject_IDs.txt"
}
args = KeyDotDict(args)
main(args)

Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/50128/Olin_0050128_rois_cc200.1D
Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/51203/UCLA_1_0051203_rois_cc200.1D
Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/50325/UM_1_0050325_rois_cc200.1D
Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/50117/Olin_0050117_rois_cc200.1D
Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/50573/Yale_0050573_rois_cc200.1D
Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/50741/Leuven_2_0050741_rois_cc200.1D
Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/50779/KKI_0050779_rois_cc200.1D
Reading timeseries file /content/drive/MyDrive/comp_bio/ABIDE_pcp/ABIDE_pcp/cpac/filt_noglobal/5100

# Process Data

In [None]:
import sys
import argparse
import pandas as pd
import numpy as np
import deepdish as dd
import warnings
import os

# Process boolean command line arguments
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main(args):

    print('Arguments: \n', args)

    data_folder = os.path.join(args.root_path, 'ABIDE_pcp/cpac/filt_noglobal/')


    params = dict()

    params['seed'] = args.seed  # seed for random initialisation

    # Algorithm choice
    params['atlas'] = args.atlas  # Atlas for network construction
    atlas = args.atlas  # Atlas for network construction (node definition)

    reader = Reader(args.root_path, args.id_file_path)
    # Get subject IDs and class labels
    subject_IDs = reader.get_ids()
    labels = reader.get_subject_score(subject_IDs, score='DX_GROUP')

    # Number of subjects and classes for binary classification
    num_classes = args.nclass
    num_subjects = len(subject_IDs)
    params['n_subjects'] = num_subjects

    # Initialise variables for class labels and acquisition sites
    # 1 is autism, 2 is control
    y_data = np.zeros([num_subjects, num_classes]) # n x 2
    y = np.zeros([num_subjects, 1]) # n x 1

    # Get class labels for all subjects
    for i in range(num_subjects):
        y_data[i, int(labels[subject_IDs[i]]) - 1] = 1
        y[i] = int(labels[subject_IDs[i]])



    # Compute feature vectors (vectorised connectivity networks)
    fea_corr = reader.get_networks(subject_IDs, iter_no='', kind='correlation', atlas_name=atlas) #(1035, 200, 200)
    fea_pcorr = reader.get_networks(subject_IDs, iter_no='', kind='partial correlation', atlas_name=atlas) #(1035, 200, 200)

    if not os.path.exists(os.path.join(data_folder,'raw')):
        os.makedirs(os.path.join(data_folder,'raw'))
    for i, subject in enumerate(subject_IDs):
        dd.io.save(os.path.join(data_folder,'raw',subject+'.h5'),{'corr':fea_corr[i],'pcorr':fea_pcorr[i],'label':(y[i]-1)})

args = {
 "root_path":root_path,
 "atlas": "cc200",
 "seed": "0",
 "nclass":2,
 "id_file_path": root_path+"subject_IDs.txt"
}
args = KeyDotDict(args)
main(args)

Arguments: 
 {'root_path': '/content/drive/MyDrive/comp_bio/ABIDE_pcp', 'atlas': 'cc200', 'seed': '0', 'nclass': 2, 'id_file_path': '/content/drive/MyDrive/comp_bio/ABIDE_pcpsubject_IDs.txt'}


# Generate Dataset

In [None]:
import deepdish as dd
import os.path as osp
import os
import numpy as np
import argparse
from pathlib import Path
import pandas as pd


def main(args):
    data_dir =  os.path.join(args.root_path, 'ABIDE_pcp/cpac/filt_noglobal/raw')
    timeseires = os.path.join(args.root_path, 'ABIDE_pcp/cpac/filt_noglobal/')

    meta_file = os.path.join(args.root_path, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv')

    meta_file = pd.read_csv(meta_file, header=0)

    id2site = meta_file[["subject", "SITE_ID"]]

    # pandas to map
    id2site = id2site.set_index("subject")
    id2site = id2site.to_dict()['SITE_ID']

    times = []

    labels = []
    pcorrs = []

    corrs = []

    site_list = []

    for f in os.listdir(data_dir):
        if osp.isfile(osp.join(data_dir, f)):
            fname = f.split('.')[0]
            site = id2site[int(fname)]


            files = os.listdir(osp.join(timeseires, fname))

            file = list(filter(lambda x: x.endswith("1D"), files))[0]

            time = np.loadtxt(osp.join(timeseires, fname, file), skiprows=0).T

            if time.shape[1] < 100:
                continue

            temp = dd.io.load(osp.join(data_dir,  f))
            pcorr = temp['pcorr'][()]

            pcorr[pcorr == float('inf')] = 0

            att = temp['corr'][()]

            att[att == float('inf')] = 0

            label = temp['label']

            times.append(time[:,:100])
            labels.append(label[0])
            corrs.append(att)
            pcorrs.append(pcorr)
            site_list.append(site)

    np.save(Path(args.root_path)/'ABIDE_pcp/abide.npy', {'timeseires': np.array(times), "label": np.array(labels),"corr": np.array(corrs),"pcorr": np.array(pcorrs), 'site': np.array(site_list)})

args = {
 "root_path":root_path,
}
args = KeyDotDict(args)
main(args)

# **Modeling**

In [None]:
import torch
import random
from BrainGB.dataset import BrainDataset
from BrainGB.dataset.transforms import Adj
from torch_geometric.loader import DataLoader
import os
import numpy as np
def set_seeds(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)  # set random seed for numpy
    torch.manual_seed(seed)  # set random seed for CPU
    torch.cuda.manual_seed_all(seed)  # set random seed for all GPUs

set_seeds(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

root_path = "/content/drive/MyDrive/comp_bio/ABIDE_pcp"
dataset = "ABIDE"
node_features = Adj()
dataset = BrainDataset(root=root_path,
                           name=dataset,
                           pre_transform=node_features)
y = []
for d in dataset:
    y.append(d.y.item())
num_features = dataset[0].x.shape[1]
nodes_num = dataset.num_nodes

Processing...
Done!


# Supervised Fine Tuning

In [None]:
import torch
from sklearn import metrics
from torch_geometric.nn import global_add_pool, global_mean_pool, MessagePassing
from torch.nn import Parameter
import numpy as np
from torch.nn import functional as F
from torch_geometric.nn.inits import glorot, zeros
from typing import Tuple
from torch import Tensor
from torch_geometric.nn import GCNConv
from torch import nn
import math
import torch_geometric
from torch_geometric.nn import SAGEConv, GATConv
from torch_geometric.nn import MessagePassing
from sklearn.model_selection import StratifiedKFold
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from tqdm import tqdm

class MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, activation, n_classes=0):
        super(MLP, self).__init__()
        self.net = []
        self.net.append(torch.nn.Linear(input_dim, hidden_dim))
        self.net.append(activation())
        for _ in range(num_layers - 1):
            self.net.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.net.append(activation())
        self.net = torch.nn.Sequential(*self.net)
        self.shortcut = torch.nn.Linear(input_dim, hidden_dim)

        if n_classes != 0:
            self.classifier = torch.nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        out = self.net(x) + self.shortcut(x)
        if hasattr(self, 'classifier'):
            return out, self.classifier(out)
        return out

@torch.no_grad()
def evaluate(model, device, loader):
    model.eval()
    preds, trues, preds_prob = [], [], []

    for data in loader:
        data = data.to(device)

        c = model(data)
        pred = c.max(dim=1)[1]
        preds += pred.detach().cpu().tolist()
        preds_prob += torch.softmax(c, dim=-1)[:, 1].detach().cpu().tolist()
        trues += data.y.detach().cpu().tolist()

    train_auc = metrics.roc_auc_score(trues, preds_prob)
    train_accuracy = metrics.accuracy_score(trues, preds)
    return train_accuracy, train_auc

# **Baselines**

Baseline Random Forest

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
import numpy as np

n_splits = 5
n_runs = 20

X = np.stack([data.x.mean(dim=0).numpy() for data in dataset])
y = np.array([data.y.item() for data in dataset])

all_accuracies = []
all_auroc = []

for run in range(n_runs):
    fold_accuracies = []
    fold_auroc = []
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=run)

    for train_index, test_index in skf.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        classifier = RandomForestClassifier(n_estimators=100, random_state=run)
        classifier.fit(X_train, y_train)
        y_pred = classifier.predict(X_test)
        y_pred_prob = classifier.predict_proba(X_test)[:, 1]
        accuracy = accuracy_score(y_test, y_pred)
        auroc = roc_auc_score(y_test, y_pred_prob)
        fold_accuracies.append(accuracy)
        fold_auroc.append(auroc)

    mean_accuracy = np.mean(fold_accuracies)
    all_accuracies.append(mean_accuracy)
    print(f'Run {run+1}, Mean Fold Accuracy: {mean_accuracy:.2f}')

    mean_auroc = np.mean(fold_auroc)
    all_auroc.append(mean_auroc)
    print(f'Run {run+1}, Mean Fold AUROC: {mean_auroc:.2f}')

overall_mean_accuracy = np.mean(all_accuracies)
overall_std_accuracy = np.std(all_accuracies)
print(f'Overall Mean Accuracy: {overall_mean_accuracy:.2f}, Std Dev: {overall_std_accuracy:.2f}')

overall_mean_auroc = np.mean(all_auroc)
overall_std_auroc = np.std(all_auroc)
print(f'Overall Mean AUROC: {overall_mean_auroc:.2f}, Std Dev: {overall_std_auroc:.2f}')

Run 1, Mean Fold Accuracy: 0.54
Run 1, Mean Fold AUROC: 0.56
Run 2, Mean Fold Accuracy: 0.52
Run 2, Mean Fold AUROC: 0.54
Run 3, Mean Fold Accuracy: 0.52
Run 3, Mean Fold AUROC: 0.56
Run 4, Mean Fold Accuracy: 0.56
Run 4, Mean Fold AUROC: 0.58
Run 5, Mean Fold Accuracy: 0.55
Run 5, Mean Fold AUROC: 0.55
Run 6, Mean Fold Accuracy: 0.55
Run 6, Mean Fold AUROC: 0.56
Run 7, Mean Fold Accuracy: 0.51
Run 7, Mean Fold AUROC: 0.53
Run 8, Mean Fold Accuracy: 0.53
Run 8, Mean Fold AUROC: 0.56
Run 9, Mean Fold Accuracy: 0.55
Run 9, Mean Fold AUROC: 0.57
Run 10, Mean Fold Accuracy: 0.53
Run 10, Mean Fold AUROC: 0.55
Run 11, Mean Fold Accuracy: 0.56
Run 11, Mean Fold AUROC: 0.59
Run 12, Mean Fold Accuracy: 0.55
Run 12, Mean Fold AUROC: 0.58
Run 13, Mean Fold Accuracy: 0.53
Run 13, Mean Fold AUROC: 0.56
Run 14, Mean Fold Accuracy: 0.54
Run 14, Mean Fold AUROC: 0.56
Run 15, Mean Fold Accuracy: 0.54
Run 15, Mean Fold AUROC: 0.57
Run 16, Mean Fold Accuracy: 0.55
Run 16, Mean Fold AUROC: 0.58
Run 17, Me

Baseline Graph Convolutional Network

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import DataLoader

class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv2 = GCNConv(32, 32)
        self.fc = torch.nn.Linear(32, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        x = global_mean_pool(x, batch)
        x = self.fc(x)

        return x


In [None]:
!pip install torchmetrics
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchmetrics.classification import MulticlassAUROC

def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

def train_and_evaluate(model, train_loader, test_loader, device, criterion, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()

    model.eval()
    correct = 0
    total = 0
    auroc_score = MulticlassAUROC(num_classes=2, average='weighted').to(device)

    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            auroc_score.update(out, data.y.to(device))
            correct += (pred == data.y).sum().item()
            total += data.y.size(0)

    auroc_score = auroc_score.compute()

    return correct / total, auroc_score.cpu().data.numpy()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
initial_seed = 0

accuracies = []
auroc_gcn = []

for run in range(20):
    set_seed(initial_seed + run)
    train_idx, test_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
    train_dataset = [dataset[i] for i in train_idx]
    test_dataset = [dataset[i] for i in test_idx]

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = GCN(num_node_features=dataset.num_node_features, num_classes=2).to(device)
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss()

    accuracy, auroc = train_and_evaluate(model, train_loader, test_loader, device, criterion, optimizer)
    accuracies.append(accuracy)
    print(f'Run {run + 1}, Accuracy: {accuracy:.4f}')
    auroc_gcn.append(auroc)
    print(f'Run {run + 1}, AUROC: {auroc:.4f}')

mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
print(f'Overall Mean Accuracy: {mean_accuracy:.4f}, Std Dev: {std_dev_accuracy:.4f}')

mean_auroc = np.mean(auroc_gcn)
std_dev_auroc = np.std(auroc_gcn)
print(f'Overall Mean AUROC: {mean_auroc:.4f}, Std Dev: {std_dev_auroc:.4f}')


Collecting torchmetrics
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.2 torchmetrics-1.3.2
Run 1, Accuracy: 0.5149
Run 1, AUROC: 0.5106
Run 2, Accuracy: 0.4851
Run 2, AUROC: 0.5059
Run 3, Accuracy: 0.5149
Run 3, AUROC: 0.4952
Run 4, Accuracy: 0.5149
Run 4, AUROC: 0.5051
Run 5, Accuracy: 0.4851
Run 5, AUROC: 0.5001
Run 6, Accuracy: 0.5149
Run 6, AUROC: 0.5002
Run 7, Accuracy: 0.5149
Run 7, AUROC: 0.5000
Run 8, Accuracy: 0.5149
Run 8, AUROC: 0.4952
Run 9, Accuracy: 0.5149
Run 9, AUROC: 0.5048
Run 10, Accuracy: 0.5149
Run 10, AUROC: 0.5004
Run 11, Accuracy: 0.5149
Run 11, AUROC: 0.5099
Run 12, Accuracy: 0.4851
Run 12, AUROC

Baseline Support Vector Machine

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
import random

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)


set_seeds(0)
features = []
labels = []

for data in dataset:
    aggregated_features = data.x.mean(axis=0)
    features.append(aggregated_features)
    labels.append(data.y.item())

features = np.stack(features)
labels = np.array(labels)


scaler = StandardScaler()
features = scaler.fit_transform(features)

initial_seed = 0
accuracies = []
auroc_svm = []
for run in range(20):
    seed = initial_seed + run
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=seed)


    model = SVC(kernel='linear', random_state=seed)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    y_pred_prob = classifier.predict_proba(X_test)[:, 1]

    accuracy = accuracy_score(y_test, y_pred)
    accuracies.append(accuracy)
    print(f'Run {run + 1}, Accuracy: {accuracy:.4f}')

    auroc = roc_auc_score(y_test, y_pred_prob)
    auroc_svm.append(auroc)
    print(f'Run {run + 1}, AUROC: {auroc:.4f}')

mean_accuracy = np.mean(accuracies)
std_accuracy = np.std(accuracies)
print(f'Overall Mean Accuracy: {mean_accuracy:.4f}, Std Dev: {std_accuracy:.4f}')

mean_auroc = np.mean(auroc_svm)
std_auroc = np.std(auroc_svm)
print(f'Overall Mean AUROC: {mean_auroc:.4f}, Std Dev: {std_auroc:.4f}')


Run 1, Accuracy: 0.5743
Run 1, AUROC: 0.4596
Run 2, Accuracy: 0.6040
Run 2, AUROC: 0.5072
Run 3, Accuracy: 0.5743
Run 3, AUROC: 0.5226
Run 4, Accuracy: 0.5545
Run 4, AUROC: 0.5433
Run 5, Accuracy: 0.5743
Run 5, AUROC: 0.5027
Run 6, Accuracy: 0.6337
Run 6, AUROC: 0.5200
Run 7, Accuracy: 0.5446
Run 7, AUROC: 0.4672
Run 8, Accuracy: 0.5990
Run 8, AUROC: 0.4763
Run 9, Accuracy: 0.6040
Run 9, AUROC: 0.5280
Run 10, Accuracy: 0.5743
Run 10, AUROC: 0.5238
Run 11, Accuracy: 0.5891
Run 11, AUROC: 0.5275
Run 12, Accuracy: 0.5495
Run 12, AUROC: 0.5188
Run 13, Accuracy: 0.5792
Run 13, AUROC: 0.5457
Run 14, Accuracy: 0.5693
Run 14, AUROC: 0.4402
Run 15, Accuracy: 0.5891
Run 15, AUROC: 0.5715
Run 16, Accuracy: 0.5891
Run 16, AUROC: 0.5691
Run 17, Accuracy: 0.5792
Run 17, AUROC: 0.5364
Run 18, Accuracy: 0.5644
Run 18, AUROC: 0.5494
Run 19, Accuracy: 0.6089
Run 19, AUROC: 0.5523
Run 20, Accuracy: 0.5495
Run 20, AUROC: 0.4800
Overall Mean Accuracy: 0.5802, Std Dev: 0.0221
Overall Mean AUROC: 0.5171, Std

# Convolutional Neural Network

In [None]:
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
# Perform train-test split 80/20
train_dataset, test_dataset = train_test_split(dataset, test_size=0.2)

batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
data = train_dataset[0]
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Is undirected: {data.is_undirected()}')
print(data)

Number of nodes: 200
Number of edges: 40000
Average node degree: 200.00
Is undirected: True
Data(edge_index=[2, 40000], edge_attr=[40000], y=[1], site='USM', x=[200, 200], num_nodes=200)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc = nn.Linear(589824, 2)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = x.view(-1, 1, 200, 200)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        return x

In [None]:
!pip install -q torchmetrics

In [None]:
#test method used from GNN Homework 3
from torchmetrics.classification import MulticlassAUROC
from torchmetrics.classification import MulticlassAccuracy

def cnn_test(model, loss_fn, loader, device):
      model.eval()
      total_loss = 0.0
      auroc_score = MulticlassAUROC(num_classes=2, average='weighted').to(device)
      acc_score = MulticlassAccuracy(num_classes=2, average='weighted').to(device)

      with torch.no_grad():
            for data in loader:
                  out = model(data.x.to(device))
                  loss = loss_fn(out, data.y.to(device))
                  total_loss += loss
                  auroc_score.update(out, data.y.to(device))
                  acc_score.update(out, data.y.to(device))

      auroc_score = auroc_score.compute()
      acc_score = acc_score.compute()

      return total_loss / len(loader), auroc_score, acc_score

In [None]:
best_valid_loss = float('inf')
early_stopping_counter = 0
patience = 5

def train():
    model.train()
    total_loss = 0.0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x.to(device))
        loss = l(out, data.y.to(device))
        total_loss += loss
        loss.backward()
        optimizer.step()
    return total_loss / len(train_loader)

In [None]:
#training one time
device = "cuda" if torch.cuda.is_available() else "cpu"
num_edges = train_dataset[0].num_edges
model = CNN(2).to(device)
l = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 20
for epoch in range(num_epochs):
    train_loss = train()
    lt, train_auroc, train_acc = cnn_test(model, l, train_loader, device)
    lv, valid_auroc, valid_acc = cnn_test(model, l, val_loader, device)

    # Early stopping
    if lv < best_valid_loss:
        best_valid_loss = lv
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print(f'Early stopping triggered at epoch {epoch}')
            break
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.5f}, Train Auc: {train_auroc:.4f}, Train ACC: {train_acc:.4f}, Valid Loss: {lv:.5f}, Valid Auc: {valid_auroc:.4f}, Valid ACC: {valid_acc:.4f}')

Epoch: 000, Train Loss: 5.16828, Train Auc: 0.8985, Train ACC: 0.7794, Valid Loss: 1.51290, Valid Auc: 0.7030, Valid ACC: 0.6040
Epoch: 001, Train Loss: 0.91359, Train Auc: 0.9745, Train ACC: 0.8463, Valid Loss: 1.62644, Valid Auc: 0.6310, Valid ACC: 0.5693
Epoch: 002, Train Loss: 0.50273, Train Auc: 0.9954, Train ACC: 0.9616, Valid Loss: 1.24913, Valid Auc: 0.7123, Valid ACC: 0.6337
Epoch: 003, Train Loss: 0.38130, Train Auc: 0.9973, Train ACC: 0.7621, Valid Loss: 2.73698, Valid Auc: 0.7099, Valid ACC: 0.5792
Epoch: 004, Train Loss: 0.33008, Train Auc: 0.9994, Train ACC: 0.9864, Valid Loss: 1.18378, Valid Auc: 0.7299, Valid ACC: 0.6584
Epoch: 005, Train Loss: 0.13687, Train Auc: 0.9999, Train ACC: 0.9975, Valid Loss: 1.39877, Valid Auc: 0.7312, Valid ACC: 0.6436
Epoch: 006, Train Loss: 0.03583, Train Auc: 0.9996, Train ACC: 0.9926, Valid Loss: 1.42638, Valid Auc: 0.7469, Valid ACC: 0.6683
Epoch: 007, Train Loss: 0.07434, Train Auc: 1.0000, Train ACC: 0.9814, Valid Loss: 1.70383, Valid

In [None]:
#perform k=20-fold cross validations
k = 20
num_epochs = 10
# Initialize lists to store evaluation metrics for each fold
validation_auc = []
validation_accuracies = []
rand = 0 #change train/valid split randomly
fold = 0
while fold < k:
    print(f'Fold {fold + 1}/{k}')
    fold_size = 200
    train_dataset, test_dataset = train_test_split(dataset, test_size=0.1, random_state=rand)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
    # Initialize model, optimizer, etc.
    model = CNN(num_classes=2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    best_valid_loss = float('inf')
    for epoch in range(num_epochs):
        train_loss = train()
        lt, train_auroc, train_acc = cnn_test(model, l, train_loader, device)
        lv, valid_auroc, valid_acc = cnn_test(model, l, val_loader, device)
        if lv < best_valid_loss:
            best_valid_loss = lv
            early_stopping_counter = 0

    # Store validation loss and accuracy for this fold
    print(f'{k}: Validation AUC: {valid_auroc:.4f}, Validation Accuracy: {valid_acc:.4f}')
    validation_auc.append(valid_auroc)
    validation_accuracies.append(valid_acc)
    fold += 1
    rand += 1

Fold 1/20
20: Validation AUC: 0.7475, Validation Accuracy: 0.6337
Fold 2/20
20: Validation AUC: 0.8094, Validation Accuracy: 0.7327
Fold 3/20
20: Validation AUC: 0.6680, Validation Accuracy: 0.6931
Fold 4/20
20: Validation AUC: 0.7005, Validation Accuracy: 0.6436
Fold 5/20
20: Validation AUC: 0.7305, Validation Accuracy: 0.6634
Fold 6/20
20: Validation AUC: 0.7963, Validation Accuracy: 0.7129
Fold 7/20
20: Validation AUC: 0.7447, Validation Accuracy: 0.5644
Fold 8/20
20: Validation AUC: 0.7861, Validation Accuracy: 0.6832
Fold 9/20
20: Validation AUC: 0.8016, Validation Accuracy: 0.7030
Fold 10/20
20: Validation AUC: 0.7358, Validation Accuracy: 0.6436
Fold 11/20
20: Validation AUC: 0.7265, Validation Accuracy: 0.6337
Fold 12/20
20: Validation AUC: 0.7990, Validation Accuracy: 0.7426
Fold 13/20
20: Validation AUC: 0.7535, Validation Accuracy: 0.6931
Fold 14/20
20: Validation AUC: 0.7301, Validation Accuracy: 0.6832
Fold 15/20
20: Validation AUC: 0.7488, Validation Accuracy: 0.6634
Fold

In [None]:
# Calculate average validation loss and accuracy across all folds
validation_accuracies_cp = [tensor.cpu().numpy() for tensor in validation_accuracies]
validation_auc_cp = [tensor.cpu().numpy() for tensor in validation_auc]

mean1 = np.mean(validation_accuracies_cp)
std1 = np.std(validation_accuracies_cp)

mean2 = np.mean(validation_auc_cp)
std2 = np.std(validation_auc_cp)

print("Accuracy:")
print("Mean:", mean1)
print("Standard Deviation:", std1)

print("AUC ROC:")
print("Mean:", mean2)
print("Standard Deviation:", std2)

Accuracy:
Mean: 0.6727723
Standard Deviation: 0.0399091
AUC ROC:
Mean: 0.7405056
Standard Deviation: 0.040880784


# Graph Neural Network

```

```



In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GraphConv
from torchmetrics.classification import MulticlassAUROC

### DO NOT CHANGE ANY CODE ABOVE THIS LINE IN THIS CELL ###

class GNN(torch.nn.Module):
    def __init__(self):
        super(GNN, self).__init__()
        # TODO: Finish the design of this GNN model
        # self.conv1 = GraphConv(3, 4)
        # self.conv2 = GraphConv(4, 4)
        # self.conv3 = GraphConv(4, 4)
        # self.lin = Linear(4, 10)
        # self.dropout = 0.5
        self.dropout = 0.5
        self.gconv1 = GraphConv(200, 64)
        self.gconv2 = GraphConv(64,64)
        self.gconv3 = GraphConv(64, 64)
        self.lin = Linear(64, 2)

    def forward(self, x, edge_index, batch):
        # TODO: Finish this function
        # 1. Obtain node embeddings
        x = self.gconv1(x, edge_index)
        x = x.relu()
        x = self.gconv2(x, edge_index)
        x = x.relu()
        x = self.gconv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [None]:
model = GNN().to(device)
print(model)

GNN(
  (gconv1): GraphConv(200, 64)
  (gconv2): GraphConv(64, 64)
  (gconv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [None]:
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

def train_and_evaluate(model, train_loader, test_loader, device, criterion, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()

    model.eval()
    correct = 0
    total = 0
    auroc_score = MulticlassAUROC(num_classes=2, average='weighted').to(device)
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            auroc_score.update(out, data.y.to(device))
            correct += (pred == data.y).sum().item()
            total += data.y.size(0)
    auroc_score = auroc_score.compute()
    return correct / total, auroc_score.cpu().data.numpy()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
initial_seed = 0

accuracies = []
auroc_lst = []
for run in range(20):
    set_seed(initial_seed + run)  # Systematically vary the seed
    train_idx, test_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
    train_dataset = [dataset[i] for i in train_idx]
    test_dataset = [dataset[i] for i in test_idx]

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = GNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = torch.nn.CrossEntropyLoss()
    accuracy, auroc = train_and_evaluate(model, train_loader, test_loader, device, criterion, optimizer)
    accuracies.append(accuracy)
    auroc_lst.append(auroc)
    print(f'Run {run + 1}, Accuracy: {accuracy:.4f}')
    print(f'Run {run + 1}, AUROC: {auroc:.4f}')

mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
print(f'Overall Mean Accuracy: {mean_accuracy:.4f}, Std Dev: {std_dev_accuracy:.4f}')

mean_auroc = np.mean(auroc_lst)
std_dev_auroc = np.std(auroc_lst)
print(f'Overall Mean AUROC: {mean_auroc:.4f}, Std Dev: {std_dev_auroc:.4f}')

Run 1, Accuracy: 0.5149
Run 1, AUROC: 0.5110
Run 2, Accuracy: 0.5198
Run 2, AUROC: 0.5178
Run 3, Accuracy: 0.5446
Run 3, AUROC: 0.5415
Run 4, Accuracy: 0.5446
Run 4, AUROC: 0.5507
Run 5, Accuracy: 0.4901
Run 5, AUROC: 0.4886
Run 6, Accuracy: 0.5545
Run 6, AUROC: 0.5483
Run 7, Accuracy: 0.5297
Run 7, AUROC: 0.5266
Run 8, Accuracy: 0.4950
Run 8, AUROC: 0.4991
Run 9, Accuracy: 0.5693
Run 9, AUROC: 0.5732
Run 10, Accuracy: 0.4554
Run 10, AUROC: 0.4538
Run 11, Accuracy: 0.5396
Run 11, AUROC: 0.5401
Run 12, Accuracy: 0.5396
Run 12, AUROC: 0.5399
Run 13, Accuracy: 0.4703
Run 13, AUROC: 0.4688
Run 14, Accuracy: 0.5099
Run 14, AUROC: 0.5096
Run 15, Accuracy: 0.5545
Run 15, AUROC: 0.5530
Run 16, Accuracy: 0.5891
Run 16, AUROC: 0.5920
Run 17, Accuracy: 0.5446
Run 17, AUROC: 0.5441
Run 18, Accuracy: 0.5842
Run 18, AUROC: 0.5827
Run 19, Accuracy: 0.5594
Run 19, AUROC: 0.5560
Run 20, Accuracy: 0.4703
Run 20, AUROC: 0.4710
Overall Mean Accuracy: 0.5290, Std Dev: 0.0368
Overall Mean AUROC: 0.5284, Std

# **Transformer Model**

In [None]:
# cls tokens comes from: https://github.com/yandex-research/rtdl
class _TokenInitialization(enum.Enum):
    UNIFORM = 'uniform'
    NORMAL = 'normal'

    @classmethod
    def from_str(cls, initialization: str) -> '_TokenInitialization':
        try:
            return cls(initialization)
        except ValueError:
            valid_values = [x.value for x in _TokenInitialization]
            raise ValueError(f'initialization must be one of {valid_values}')

    def apply(self, x: Tensor, d: int) -> None:
        d_sqrt_inv = 1 / math.sqrt(d)
        if self == _TokenInitialization.UNIFORM:
            # used in the paper "Revisiting Deep Learning Models for Tabular Data";
            # is equivalent to `nn.init.kaiming_uniform_(x, a=math.sqrt(5))` (which is
            # used by torch to initialize nn.Linear.weight, for example)
            nn.init.uniform_(x, a=-d_sqrt_inv, b=d_sqrt_inv)
        elif self == _TokenInitialization.NORMAL:
            nn.init.normal_(x, std=d_sqrt_inv)


class CLSToken(nn.Module):
    def __init__(self, d_token: int, initialization: str) -> None:
        super().__init__()
        initialization_ = _TokenInitialization.from_str(initialization)
        self.weight = nn.Parameter(Tensor(d_token))
        initialization_.apply(self.weight, d_token)

    def expand(self, *leading_dimensions: int) -> Tensor:
        if not leading_dimensions:
            return self.weight
        new_dims = (1,) * (len(leading_dimensions) - 1)
        return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1)

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat([x, self.expand(len(x), 1)], dim=1)

class TransformerBlock(nn.Module):
  def __init__(self, in_features, num_heads, att_dropout, mlp_dropout):
    super().__init__()
    self.attention = nn.MultiheadAttention(in_features, num_heads, dropout=att_dropout, batch_first=True) #nn.MultiheadAttention(in_features, num_heads, dropout=att_dropout, batch_first=True)

    self.mlp = nn.Sequential(
        nn.Linear(in_features, 4*in_features),
        nn.Dropout(mlp_dropout),
        nn.ReLU(),
        nn.Linear(4*in_features, in_features),
    )
    self.norm1 = nn.LayerNorm(in_features)
    self.norm2 = nn.LayerNorm(in_features)

  def forward(self, x):
    out, attention = self.attention(x, x, x)

    out = self.norm1(out + x)
    out2 = self.mlp(out)
    out2 = self.norm2(out2 + out)

    return out2


class TransformerModel(torch.nn.Module):
    def __init__(self, num_nodes=200, num_classes=2):
        super(TransformerModel, self).__init__()
        self.num_nodes = num_nodes
        self.node_fcn = nn.Linear(256, 8)

        self.fcn = nn.Sequential(
            nn.Linear(1600, 256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(256, 32),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(32, num_classes)
        )

        self.input_embed = nn.Sequential(
            nn.Linear(200, 200),
        )


        d_token = 512
        heads = 8
        initialization = _TokenInitialization.from_str("uniform")
        self.transformer = nn.Sequential(
            nn.Linear(200, d_token),
            CLSToken(
                d_token, initialization
            ),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0)
        )

        self.out = nn.Linear(d_token,num_classes)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = data.x.reshape((data.x.shape[0]//200, 200, 200))
        x = self.transformer(x)
        cls_token = x[:, -1] #x[:,0:200].mean(dim=-2)
        return self.out(cls_token)

eval = True
eval_runs = 20
hyperparameter_tuning_split_rand_seed = 0

skf = StratifiedKFold(n_splits=5, shuffle=True)
lr = 5e-5
weight_decay = 1e-4
train_batch_size = 4
test_batch_size = 4
epochs = 10
device = "cuda:0"
run = {}
run["train/accuracy"] = []
run["train/auc"] = []
run["train/loss"] = []
run["test/accuracy"] = []
run["test/auc"] = []


# We don't want to use the test_indexes during evaluation since those were used during hyperparameter tuning
if eval:
    train_index, test_index = train_test_split(np.arange(len(y)), test_size=0.2, random_state=hyperparameter_tuning_split_rand_seed)
    dataset_used = dataset[train_index]
else:
    dataset_used = dataset

accuracy = []
auc = []
for k in range(eval_runs):
    if eval:
        train_index, test_index = train_test_split(np.arange(len(y)), test_size=0.2)
    else:
        train_index, test_index = train_test_split(np.arange(len(y)), test_size=0.2, random_state=hyperparameter_tuning_split_rand_seed)

    model = TransformerModel().to(device)
    #model.transformer.load_state_dict(stored_model.transformer.state_dict()) -> used previously for fine tuning the self supervised pretrained model
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    train_set, test_set = dataset[train_index], dataset[test_index]

    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=test_batch_size, shuffle=False)

    # train
    model.eval()
    epoch_num = epochs
    for i in range(epoch_num):
        loss_all = 0
        for data in tqdm(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            loss_all += loss.item()
        epoch_loss = loss_all / len(train_loader.dataset)

        train_accuracy, train_auc = evaluate(model, device, train_loader)
        print(f'(Train) | Epoch={i:03d}, loss={epoch_loss:.4f}, '
                    f'train_accuracy={(train_accuracy * 100):.2f},'
                    f'train_auc={(train_auc * 100):.2f}')
        run["train/accuracy"].append(train_accuracy)
        run["train/auc"].append(train_auc)
        run["train/loss"].append(epoch_loss)

        model.eval()
        if (i + 1) % 1 == 0:
            test_accuracy, test_auc = evaluate(model, device, test_loader)
            print(f'test_accuracy={(test_accuracy * 100):.2f}, ' \
                  f'test_auc={(test_auc * 100):.2f}\n')

            run["test/accuracy"].append(test_accuracy)
            run["test/auc"].append(test_auc)


    if not eval:
        break

    # Rerun evaluation just to make 100% sure the results are accurate
    model.eval()
    test_accuracy, test_auc = evaluate(model, device, test_loader)
    accuracy.append(test_accuracy)
    auc.append(test_auc)

    if not eval:
        break

accuracy = np.array(accuracy)
auc = np.array(auc)

print(accuracy.mean(), accuracy.std())
print(auc.mean(), auc.std())

100%|██████████| 202/202 [00:13<00:00, 15.31it/s]


(Train) | Epoch=000, loss=0.1945, train_accuracy=48.57,train_auc=72.15
test_accuracy=50.00, test_auc=61.84



100%|██████████| 202/202 [00:12<00:00, 16.21it/s]


(Train) | Epoch=001, loss=0.1782, train_accuracy=66.42,train_auc=74.50
test_accuracy=57.92, test_auc=65.45



100%|██████████| 202/202 [00:12<00:00, 16.23it/s]


(Train) | Epoch=002, loss=0.1756, train_accuracy=67.41,train_auc=74.31
test_accuracy=64.36, test_auc=68.14



100%|██████████| 202/202 [00:13<00:00, 15.33it/s]


(Train) | Epoch=003, loss=0.1634, train_accuracy=49.07,train_auc=81.34
test_accuracy=50.50, test_auc=73.53



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=004, loss=0.1532, train_accuracy=50.31,train_auc=81.55
test_accuracy=50.50, test_auc=69.02



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=005, loss=0.1375, train_accuracy=81.66,train_auc=91.88
test_accuracy=61.88, test_auc=69.69



100%|██████████| 202/202 [00:12<00:00, 16.75it/s]


(Train) | Epoch=006, loss=0.1201, train_accuracy=73.48,train_auc=91.77
test_accuracy=60.40, test_auc=67.37



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=007, loss=0.1002, train_accuracy=92.32,train_auc=98.42
test_accuracy=63.37, test_auc=68.73



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=008, loss=0.0690, train_accuracy=90.95,train_auc=99.04
test_accuracy=59.90, test_auc=63.57



100%|██████████| 202/202 [00:12<00:00, 15.58it/s]


(Train) | Epoch=009, loss=0.0562, train_accuracy=97.15,train_auc=99.67
test_accuracy=58.42, test_auc=63.78



100%|██████████| 202/202 [00:12<00:00, 16.81it/s]


(Train) | Epoch=000, loss=0.2010, train_accuracy=51.30,train_auc=62.35
test_accuracy=50.50, test_auc=50.69



100%|██████████| 202/202 [00:11<00:00, 16.94it/s]


(Train) | Epoch=001, loss=0.1774, train_accuracy=49.19,train_auc=65.79
test_accuracy=49.50, test_auc=53.36



100%|██████████| 202/202 [00:11<00:00, 17.00it/s]


(Train) | Epoch=002, loss=0.1797, train_accuracy=48.70,train_auc=73.89
test_accuracy=49.50, test_auc=63.02



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=003, loss=0.1743, train_accuracy=48.70,train_auc=76.31
test_accuracy=49.50, test_auc=61.42



100%|██████████| 202/202 [00:12<00:00, 16.82it/s]


(Train) | Epoch=004, loss=0.1609, train_accuracy=53.28,train_auc=79.18
test_accuracy=51.49, test_auc=68.60



100%|██████████| 202/202 [00:12<00:00, 16.60it/s]


(Train) | Epoch=005, loss=0.1523, train_accuracy=77.45,train_auc=85.81
test_accuracy=60.89, test_auc=64.93



100%|██████████| 202/202 [00:12<00:00, 16.50it/s]


(Train) | Epoch=006, loss=0.1360, train_accuracy=80.67,train_auc=89.42
test_accuracy=66.34, test_auc=68.40



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=007, loss=0.1219, train_accuracy=84.14,train_auc=93.07
test_accuracy=59.90, test_auc=67.02



100%|██████████| 202/202 [00:11<00:00, 16.86it/s]


(Train) | Epoch=008, loss=0.0986, train_accuracy=87.36,train_auc=95.51
test_accuracy=62.87, test_auc=71.98



100%|██████████| 202/202 [00:11<00:00, 16.84it/s]


(Train) | Epoch=009, loss=0.0698, train_accuracy=89.10,train_auc=98.25
test_accuracy=63.86, test_auc=70.60



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=000, loss=0.1892, train_accuracy=49.57,train_auc=62.43
test_accuracy=57.43, test_auc=52.37



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=001, loss=0.1794, train_accuracy=61.46,train_auc=72.33
test_accuracy=51.49, test_auc=63.97



100%|██████████| 202/202 [00:11<00:00, 16.85it/s]


(Train) | Epoch=002, loss=0.1761, train_accuracy=51.92,train_auc=75.03
test_accuracy=44.55, test_auc=67.02



100%|██████████| 202/202 [00:11<00:00, 16.95it/s]


(Train) | Epoch=003, loss=0.1665, train_accuracy=65.55,train_auc=81.99
test_accuracy=64.36, test_auc=73.78



100%|██████████| 202/202 [00:11<00:00, 16.99it/s]


(Train) | Epoch=004, loss=0.1525, train_accuracy=67.53,train_auc=86.80
test_accuracy=58.91, test_auc=74.81



100%|██████████| 202/202 [00:11<00:00, 17.02it/s]


(Train) | Epoch=005, loss=0.1381, train_accuracy=62.33,train_auc=92.25
test_accuracy=63.86, test_auc=73.53



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=006, loss=0.1132, train_accuracy=84.26,train_auc=95.88
test_accuracy=68.32, test_auc=75.49



100%|██████████| 202/202 [00:11<00:00, 16.85it/s]


(Train) | Epoch=007, loss=0.0938, train_accuracy=80.55,train_auc=98.00
test_accuracy=65.35, test_auc=72.07



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=008, loss=0.0761, train_accuracy=94.92,train_auc=99.16
test_accuracy=66.83, test_auc=74.85



100%|██████████| 202/202 [00:12<00:00, 16.83it/s]


(Train) | Epoch=009, loss=0.0508, train_accuracy=98.14,train_auc=99.85
test_accuracy=68.32, test_auc=74.75



100%|██████████| 202/202 [00:11<00:00, 17.04it/s]


(Train) | Epoch=000, loss=0.2038, train_accuracy=51.80,train_auc=63.80
test_accuracy=48.51, test_auc=59.91



100%|██████████| 202/202 [00:11<00:00, 16.95it/s]


(Train) | Epoch=001, loss=0.1791, train_accuracy=51.80,train_auc=70.08
test_accuracy=48.51, test_auc=64.87



100%|██████████| 202/202 [00:12<00:00, 16.51it/s]


(Train) | Epoch=002, loss=0.1764, train_accuracy=51.80,train_auc=72.17
test_accuracy=48.51, test_auc=64.96



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=003, loss=0.1711, train_accuracy=54.52,train_auc=75.44
test_accuracy=50.50, test_auc=70.09



100%|██████████| 202/202 [00:11<00:00, 16.95it/s]


(Train) | Epoch=004, loss=0.1704, train_accuracy=54.40,train_auc=79.91
test_accuracy=50.00, test_auc=75.99



100%|██████████| 202/202 [00:11<00:00, 17.00it/s]


(Train) | Epoch=005, loss=0.1597, train_accuracy=70.88,train_auc=83.85
test_accuracy=61.88, test_auc=73.59



100%|██████████| 202/202 [00:11<00:00, 17.03it/s]


(Train) | Epoch=006, loss=0.1428, train_accuracy=77.45,train_auc=89.95
test_accuracy=60.40, test_auc=73.20



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=007, loss=0.1334, train_accuracy=79.93,train_auc=94.45
test_accuracy=63.37, test_auc=73.12



100%|██████████| 202/202 [00:11<00:00, 16.95it/s]


(Train) | Epoch=008, loss=0.1151, train_accuracy=89.84,train_auc=96.77
test_accuracy=63.37, test_auc=70.72



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=009, loss=0.0799, train_accuracy=89.10,train_auc=96.45
test_accuracy=63.86, test_auc=70.64



100%|██████████| 202/202 [00:11<00:00, 17.06it/s]


(Train) | Epoch=000, loss=0.1912, train_accuracy=51.43,train_auc=64.60
test_accuracy=50.50, test_auc=56.87



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=001, loss=0.1821, train_accuracy=55.14,train_auc=73.80
test_accuracy=51.98, test_auc=65.43



100%|██████████| 202/202 [00:11<00:00, 16.94it/s]


(Train) | Epoch=002, loss=0.1722, train_accuracy=51.55,train_auc=77.07
test_accuracy=50.50, test_auc=64.89



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=003, loss=0.1619, train_accuracy=75.71,train_auc=84.08
test_accuracy=59.41, test_auc=65.13



100%|██████████| 202/202 [00:11<00:00, 16.94it/s]


(Train) | Epoch=004, loss=0.1476, train_accuracy=81.78,train_auc=89.31
test_accuracy=63.86, test_auc=67.17



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=005, loss=0.1369, train_accuracy=76.33,train_auc=90.91
test_accuracy=57.43, test_auc=68.78



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=006, loss=0.1066, train_accuracy=84.14,train_auc=93.34
test_accuracy=60.40, test_auc=64.94



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=007, loss=0.0941, train_accuracy=82.78,train_auc=96.97
test_accuracy=55.94, test_auc=62.23



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=008, loss=0.0780, train_accuracy=91.70,train_auc=98.93
test_accuracy=60.40, test_auc=67.00



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=009, loss=0.0606, train_accuracy=91.45,train_auc=98.09
test_accuracy=58.91, test_auc=65.69



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=000, loss=0.1968, train_accuracy=51.92,train_auc=70.38
test_accuracy=48.02, test_auc=60.98



100%|██████████| 202/202 [00:11<00:00, 17.11it/s]


(Train) | Epoch=001, loss=0.1802, train_accuracy=51.92,train_auc=69.82
test_accuracy=48.02, test_auc=59.37



100%|██████████| 202/202 [00:11<00:00, 16.99it/s]


(Train) | Epoch=002, loss=0.1769, train_accuracy=51.05,train_auc=74.23
test_accuracy=53.96, test_auc=63.24



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=003, loss=0.1722, train_accuracy=70.88,train_auc=79.17
test_accuracy=63.37, test_auc=66.20



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=004, loss=0.1555, train_accuracy=71.38,train_auc=79.28
test_accuracy=58.91, test_auc=66.34



100%|██████████| 202/202 [00:11<00:00, 17.03it/s]


(Train) | Epoch=005, loss=0.1440, train_accuracy=73.73,train_auc=85.17
test_accuracy=57.43, test_auc=63.89



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=006, loss=0.1471, train_accuracy=77.82,train_auc=89.87
test_accuracy=56.93, test_auc=65.11



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=007, loss=0.1313, train_accuracy=84.26,train_auc=93.23
test_accuracy=60.40, test_auc=66.03



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=008, loss=0.1003, train_accuracy=82.90,train_auc=95.24
test_accuracy=58.91, test_auc=62.31



100%|██████████| 202/202 [00:11<00:00, 17.00it/s]


(Train) | Epoch=009, loss=0.0792, train_accuracy=93.68,train_auc=98.84
test_accuracy=60.89, test_auc=66.94



100%|██████████| 202/202 [00:11<00:00, 17.06it/s]


(Train) | Epoch=000, loss=0.1895, train_accuracy=48.70,train_auc=65.69
test_accuracy=49.50, test_auc=58.95



100%|██████████| 202/202 [00:11<00:00, 17.01it/s]


(Train) | Epoch=001, loss=0.1815, train_accuracy=51.30,train_auc=73.05
test_accuracy=50.50, test_auc=66.72



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=002, loss=0.1718, train_accuracy=50.19,train_auc=76.12
test_accuracy=49.50, test_auc=68.01



100%|██████████| 202/202 [00:11<00:00, 16.94it/s]


(Train) | Epoch=003, loss=0.1588, train_accuracy=71.62,train_auc=80.63
test_accuracy=65.35, test_auc=68.70



100%|██████████| 202/202 [00:11<00:00, 16.98it/s]


(Train) | Epoch=004, loss=0.1489, train_accuracy=78.31,train_auc=88.91
test_accuracy=64.36, test_auc=71.99



100%|██████████| 202/202 [00:11<00:00, 17.01it/s]


(Train) | Epoch=005, loss=0.1310, train_accuracy=72.37,train_auc=92.30
test_accuracy=60.89, test_auc=74.17



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=006, loss=0.1122, train_accuracy=80.42,train_auc=94.17
test_accuracy=64.85, test_auc=70.60



100%|██████████| 202/202 [00:12<00:00, 16.69it/s]


(Train) | Epoch=007, loss=0.0992, train_accuracy=84.76,train_auc=96.90
test_accuracy=64.85, test_auc=74.19



100%|██████████| 202/202 [00:11<00:00, 17.06it/s]


(Train) | Epoch=008, loss=0.0737, train_accuracy=93.93,train_auc=99.13
test_accuracy=69.31, test_auc=74.45



100%|██████████| 202/202 [00:11<00:00, 17.02it/s]


(Train) | Epoch=009, loss=0.0600, train_accuracy=97.52,train_auc=99.74
test_accuracy=70.30, test_auc=75.43



100%|██████████| 202/202 [00:11<00:00, 17.06it/s]


(Train) | Epoch=000, loss=0.1857, train_accuracy=49.94,train_auc=66.24
test_accuracy=48.02, test_auc=55.46



100%|██████████| 202/202 [00:11<00:00, 16.98it/s]


(Train) | Epoch=001, loss=0.1819, train_accuracy=55.27,train_auc=74.49
test_accuracy=52.97, test_auc=68.40



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=002, loss=0.1759, train_accuracy=66.17,train_auc=75.81
test_accuracy=62.38, test_auc=68.17



100%|██████████| 202/202 [00:11<00:00, 16.86it/s]


(Train) | Epoch=003, loss=0.1633, train_accuracy=74.23,train_auc=82.63
test_accuracy=63.86, test_auc=71.17



100%|██████████| 202/202 [00:11<00:00, 16.94it/s]


(Train) | Epoch=004, loss=0.1537, train_accuracy=77.82,train_auc=87.30
test_accuracy=65.35, test_auc=70.55



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=005, loss=0.1370, train_accuracy=67.41,train_auc=92.01
test_accuracy=59.90, test_auc=70.05



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=006, loss=0.1214, train_accuracy=82.53,train_auc=92.21
test_accuracy=60.40, test_auc=67.54



100%|██████████| 202/202 [00:11<00:00, 16.85it/s]


(Train) | Epoch=007, loss=0.1019, train_accuracy=87.73,train_auc=95.27
test_accuracy=63.37, test_auc=67.44



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=008, loss=0.0908, train_accuracy=91.82,train_auc=97.94
test_accuracy=60.40, test_auc=65.99



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=009, loss=0.0722, train_accuracy=92.57,train_auc=98.71
test_accuracy=60.89, test_auc=67.00



100%|██████████| 202/202 [00:11<00:00, 16.88it/s]


(Train) | Epoch=000, loss=0.1922, train_accuracy=50.56,train_auc=63.56
test_accuracy=53.47, test_auc=62.90



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=001, loss=0.1828, train_accuracy=49.44,train_auc=70.28
test_accuracy=46.53, test_auc=67.16



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=002, loss=0.1766, train_accuracy=65.92,train_auc=73.73
test_accuracy=63.86, test_auc=70.51



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=003, loss=0.1777, train_accuracy=50.56,train_auc=75.28
test_accuracy=53.47, test_auc=73.91



100%|██████████| 202/202 [00:11<00:00, 16.94it/s]


(Train) | Epoch=004, loss=0.1684, train_accuracy=66.29,train_auc=75.47
test_accuracy=64.85, test_auc=72.79



100%|██████████| 202/202 [00:12<00:00, 16.81it/s]


(Train) | Epoch=005, loss=0.1638, train_accuracy=72.86,train_auc=82.52
test_accuracy=66.34, test_auc=71.97



100%|██████████| 202/202 [00:12<00:00, 16.81it/s]


(Train) | Epoch=006, loss=0.1445, train_accuracy=71.50,train_auc=87.95
test_accuracy=67.33, test_auc=73.55



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=007, loss=0.1264, train_accuracy=82.53,train_auc=91.78
test_accuracy=68.81, test_auc=74.15



100%|██████████| 202/202 [00:11<00:00, 17.02it/s]


(Train) | Epoch=008, loss=0.1189, train_accuracy=71.38,train_auc=95.50
test_accuracy=58.91, test_auc=74.01



100%|██████████| 202/202 [00:11<00:00, 16.95it/s]


(Train) | Epoch=009, loss=0.0982, train_accuracy=88.72,train_auc=97.78
test_accuracy=62.87, test_auc=72.06



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=000, loss=0.1937, train_accuracy=50.68,train_auc=62.25
test_accuracy=52.97, test_auc=59.58



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=001, loss=0.1811, train_accuracy=50.68,train_auc=70.58
test_accuracy=52.97, test_auc=63.43



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=002, loss=0.1762, train_accuracy=61.34,train_auc=73.62
test_accuracy=55.45, test_auc=61.42



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=003, loss=0.1699, train_accuracy=49.32,train_auc=76.98
test_accuracy=47.03, test_auc=65.36



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=004, loss=0.1652, train_accuracy=73.85,train_auc=82.12
test_accuracy=65.84, test_auc=72.32



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=005, loss=0.1532, train_accuracy=60.10,train_auc=85.23
test_accuracy=53.96, test_auc=73.42



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=006, loss=0.1374, train_accuracy=73.61,train_auc=89.59
test_accuracy=62.38, test_auc=69.65



100%|██████████| 202/202 [00:11<00:00, 17.01it/s]


(Train) | Epoch=007, loss=0.1218, train_accuracy=85.25,train_auc=93.07
test_accuracy=63.86, test_auc=72.77



100%|██████████| 202/202 [00:11<00:00, 17.05it/s]


(Train) | Epoch=008, loss=0.1001, train_accuracy=88.97,train_auc=97.23
test_accuracy=65.35, test_auc=72.95



100%|██████████| 202/202 [00:11<00:00, 17.04it/s]


(Train) | Epoch=009, loss=0.0772, train_accuracy=87.11,train_auc=96.99
test_accuracy=60.40, test_auc=70.97



100%|██████████| 202/202 [00:11<00:00, 17.07it/s]


(Train) | Epoch=000, loss=0.1895, train_accuracy=50.68,train_auc=69.57
test_accuracy=52.97, test_auc=68.51



100%|██████████| 202/202 [00:11<00:00, 17.07it/s]


(Train) | Epoch=001, loss=0.1794, train_accuracy=50.81,train_auc=73.95
test_accuracy=52.48, test_auc=73.74



100%|██████████| 202/202 [00:12<00:00, 16.64it/s]


(Train) | Epoch=002, loss=0.1744, train_accuracy=50.68,train_auc=65.61
test_accuracy=52.97, test_auc=61.85



100%|██████████| 202/202 [00:11<00:00, 17.11it/s]


(Train) | Epoch=003, loss=0.1768, train_accuracy=50.68,train_auc=72.17
test_accuracy=52.97, test_auc=67.95



100%|██████████| 202/202 [00:11<00:00, 17.03it/s]


(Train) | Epoch=004, loss=0.1689, train_accuracy=68.53,train_auc=76.59
test_accuracy=63.37, test_auc=69.76



100%|██████████| 202/202 [00:11<00:00, 17.05it/s]


(Train) | Epoch=005, loss=0.1530, train_accuracy=75.22,train_auc=85.38
test_accuracy=61.39, test_auc=73.24



100%|██████████| 202/202 [00:11<00:00, 17.05it/s]


(Train) | Epoch=006, loss=0.1392, train_accuracy=75.71,train_auc=89.58
test_accuracy=64.36, test_auc=71.68



100%|██████████| 202/202 [00:11<00:00, 17.01it/s]


(Train) | Epoch=007, loss=0.1200, train_accuracy=56.38,train_auc=88.87
test_accuracy=50.50, test_auc=68.21



100%|██████████| 202/202 [00:11<00:00, 17.06it/s]


(Train) | Epoch=008, loss=0.1083, train_accuracy=84.76,train_auc=94.79
test_accuracy=66.34, test_auc=71.16



100%|██████████| 202/202 [00:11<00:00, 17.02it/s]


(Train) | Epoch=009, loss=0.0786, train_accuracy=90.09,train_auc=96.89
test_accuracy=67.82, test_auc=72.31



100%|██████████| 202/202 [00:11<00:00, 17.11it/s]


(Train) | Epoch=000, loss=0.1978, train_accuracy=63.57,train_auc=68.65
test_accuracy=57.92, test_auc=62.48



100%|██████████| 202/202 [00:11<00:00, 17.06it/s]


(Train) | Epoch=001, loss=0.1806, train_accuracy=50.93,train_auc=72.94
test_accuracy=51.98, test_auc=68.50



100%|██████████| 202/202 [00:11<00:00, 17.10it/s]


(Train) | Epoch=002, loss=0.1751, train_accuracy=50.93,train_auc=73.82
test_accuracy=51.98, test_auc=73.53



100%|██████████| 202/202 [00:11<00:00, 17.11it/s]


(Train) | Epoch=003, loss=0.1717, train_accuracy=66.54,train_auc=77.69
test_accuracy=63.37, test_auc=73.15



100%|██████████| 202/202 [00:11<00:00, 17.05it/s]


(Train) | Epoch=004, loss=0.1536, train_accuracy=54.52,train_auc=85.45
test_accuracy=52.97, test_auc=76.74



100%|██████████| 202/202 [00:11<00:00, 17.03it/s]


(Train) | Epoch=005, loss=0.1441, train_accuracy=79.06,train_auc=88.83
test_accuracy=71.78, test_auc=77.38



100%|██████████| 202/202 [00:11<00:00, 17.07it/s]


(Train) | Epoch=006, loss=0.1330, train_accuracy=81.29,train_auc=92.50
test_accuracy=67.33, test_auc=74.95



100%|██████████| 202/202 [00:11<00:00, 17.04it/s]


(Train) | Epoch=007, loss=0.1061, train_accuracy=83.89,train_auc=95.89
test_accuracy=70.30, test_auc=76.05



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=008, loss=0.0853, train_accuracy=86.62,train_auc=96.09
test_accuracy=64.85, test_auc=67.60



100%|██████████| 202/202 [00:12<00:00, 16.72it/s]


(Train) | Epoch=009, loss=0.0696, train_accuracy=90.58,train_auc=99.26
test_accuracy=63.86, test_auc=74.06



100%|██████████| 202/202 [00:12<00:00, 16.78it/s]


(Train) | Epoch=000, loss=0.1897, train_accuracy=50.43,train_auc=71.54
test_accuracy=53.96, test_auc=64.16



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=001, loss=0.1760, train_accuracy=62.83,train_auc=66.49
test_accuracy=57.43, test_auc=62.86



100%|██████████| 202/202 [00:11<00:00, 17.00it/s]


(Train) | Epoch=002, loss=0.1769, train_accuracy=62.08,train_auc=74.14
test_accuracy=61.39, test_auc=68.08



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=003, loss=0.1669, train_accuracy=50.43,train_auc=77.62
test_accuracy=53.96, test_auc=67.15



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=004, loss=0.1566, train_accuracy=72.61,train_auc=84.97
test_accuracy=60.40, test_auc=68.13



100%|██████████| 202/202 [00:11<00:00, 16.88it/s]


(Train) | Epoch=005, loss=0.1391, train_accuracy=61.21,train_auc=84.04
test_accuracy=58.91, test_auc=69.11



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=006, loss=0.1290, train_accuracy=85.01,train_auc=92.96
test_accuracy=63.86, test_auc=70.77



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=007, loss=0.1042, train_accuracy=88.48,train_auc=95.50
test_accuracy=64.85, test_auc=70.56



100%|██████████| 202/202 [00:11<00:00, 16.96it/s]


(Train) | Epoch=008, loss=0.0879, train_accuracy=93.31,train_auc=98.21
test_accuracy=67.33, test_auc=68.99



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=009, loss=0.0718, train_accuracy=91.57,train_auc=98.21
test_accuracy=62.38, test_auc=65.27



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=000, loss=0.1929, train_accuracy=54.65,train_auc=62.09
test_accuracy=49.01, test_auc=52.88



100%|██████████| 202/202 [00:11<00:00, 17.04it/s]


(Train) | Epoch=001, loss=0.1781, train_accuracy=57.50,train_auc=69.35
test_accuracy=51.49, test_auc=59.44



100%|██████████| 202/202 [00:11<00:00, 17.06it/s]


(Train) | Epoch=002, loss=0.1782, train_accuracy=57.74,train_auc=74.23
test_accuracy=50.50, test_auc=60.99



100%|██████████| 202/202 [00:11<00:00, 16.88it/s]


(Train) | Epoch=003, loss=0.1657, train_accuracy=72.00,train_auc=80.50
test_accuracy=60.40, test_auc=68.15



100%|██████████| 202/202 [00:12<00:00, 16.83it/s]


(Train) | Epoch=004, loss=0.1493, train_accuracy=68.28,train_auc=82.70
test_accuracy=54.95, test_auc=64.70



100%|██████████| 202/202 [00:11<00:00, 16.98it/s]


(Train) | Epoch=005, loss=0.1446, train_accuracy=79.31,train_auc=88.92
test_accuracy=61.88, test_auc=70.69



100%|██████████| 202/202 [00:11<00:00, 16.84it/s]


(Train) | Epoch=006, loss=0.1255, train_accuracy=85.63,train_auc=94.57
test_accuracy=59.90, test_auc=69.74



100%|██████████| 202/202 [00:12<00:00, 16.81it/s]


(Train) | Epoch=007, loss=0.1068, train_accuracy=80.30,train_auc=97.42
test_accuracy=64.36, test_auc=69.50



100%|██████████| 202/202 [00:11<00:00, 16.85it/s]


(Train) | Epoch=008, loss=0.0763, train_accuracy=94.55,train_auc=98.70
test_accuracy=61.88, test_auc=69.41



100%|██████████| 202/202 [00:12<00:00, 16.79it/s]


(Train) | Epoch=009, loss=0.0677, train_accuracy=93.68,train_auc=98.51
test_accuracy=60.89, test_auc=67.48



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=000, loss=0.1955, train_accuracy=52.66,train_auc=69.59
test_accuracy=45.05, test_auc=63.25



100%|██████████| 202/202 [00:11<00:00, 16.95it/s]


(Train) | Epoch=001, loss=0.1796, train_accuracy=47.34,train_auc=73.02
test_accuracy=54.95, test_auc=65.43



100%|██████████| 202/202 [00:11<00:00, 17.00it/s]


(Train) | Epoch=002, loss=0.1785, train_accuracy=47.58,train_auc=72.80
test_accuracy=54.95, test_auc=63.77



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=003, loss=0.1746, train_accuracy=59.60,train_auc=77.97
test_accuracy=60.89, test_auc=69.95



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=004, loss=0.1681, train_accuracy=58.49,train_auc=79.87
test_accuracy=48.51, test_auc=69.60



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=005, loss=0.1562, train_accuracy=75.84,train_auc=83.32
test_accuracy=62.38, test_auc=70.36



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=006, loss=0.1454, train_accuracy=79.18,train_auc=89.57
test_accuracy=58.91, test_auc=69.84



100%|██████████| 202/202 [00:12<00:00, 16.81it/s]


(Train) | Epoch=007, loss=0.1257, train_accuracy=80.55,train_auc=89.82
test_accuracy=56.44, test_auc=68.24



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=008, loss=0.1125, train_accuracy=66.42,train_auc=92.96
test_accuracy=52.48, test_auc=74.15



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=009, loss=0.0843, train_accuracy=93.18,train_auc=98.00
test_accuracy=68.32, test_auc=71.55



100%|██████████| 202/202 [00:11<00:00, 17.08it/s]


(Train) | Epoch=000, loss=0.1976, train_accuracy=48.08,train_auc=62.95
test_accuracy=51.98, test_auc=54.58



100%|██████████| 202/202 [00:11<00:00, 17.08it/s]


(Train) | Epoch=001, loss=0.1793, train_accuracy=56.63,train_auc=66.14
test_accuracy=51.49, test_auc=57.12



100%|██████████| 202/202 [00:11<00:00, 17.14it/s]


(Train) | Epoch=002, loss=0.1776, train_accuracy=48.08,train_auc=75.88
test_accuracy=51.98, test_auc=71.52



100%|██████████| 202/202 [00:11<00:00, 17.14it/s]


(Train) | Epoch=003, loss=0.1664, train_accuracy=53.04,train_auc=76.75
test_accuracy=49.01, test_auc=70.39



100%|██████████| 202/202 [00:11<00:00, 16.84it/s]


(Train) | Epoch=004, loss=0.1638, train_accuracy=51.92,train_auc=78.54
test_accuracy=54.95, test_auc=71.73



100%|██████████| 202/202 [00:11<00:00, 17.18it/s]


(Train) | Epoch=005, loss=0.1511, train_accuracy=78.31,train_auc=88.46
test_accuracy=70.79, test_auc=75.96



100%|██████████| 202/202 [00:11<00:00, 17.23it/s]


(Train) | Epoch=006, loss=0.1479, train_accuracy=77.70,train_auc=89.05
test_accuracy=64.85, test_auc=71.10



100%|██████████| 202/202 [00:11<00:00, 17.17it/s]


(Train) | Epoch=007, loss=0.1223, train_accuracy=79.68,train_auc=94.57
test_accuracy=63.37, test_auc=75.19



100%|██████████| 202/202 [00:11<00:00, 17.12it/s]


(Train) | Epoch=008, loss=0.1063, train_accuracy=89.10,train_auc=96.91
test_accuracy=66.83, test_auc=75.00



100%|██████████| 202/202 [00:11<00:00, 17.12it/s]


(Train) | Epoch=009, loss=0.0791, train_accuracy=92.81,train_auc=97.72
test_accuracy=66.83, test_auc=74.67



100%|██████████| 202/202 [00:11<00:00, 17.24it/s]


(Train) | Epoch=000, loss=0.1885, train_accuracy=49.32,train_auc=67.92
test_accuracy=47.03, test_auc=55.61



100%|██████████| 202/202 [00:11<00:00, 17.15it/s]


(Train) | Epoch=001, loss=0.1744, train_accuracy=50.68,train_auc=68.91
test_accuracy=52.97, test_auc=56.96



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=002, loss=0.1719, train_accuracy=68.65,train_auc=79.12
test_accuracy=56.44, test_auc=65.46



100%|██████████| 202/202 [00:11<00:00, 17.22it/s]


(Train) | Epoch=003, loss=0.1689, train_accuracy=62.21,train_auc=81.25
test_accuracy=55.45, test_auc=66.20



100%|██████████| 202/202 [00:11<00:00, 17.07it/s]


(Train) | Epoch=004, loss=0.1502, train_accuracy=76.08,train_auc=86.10
test_accuracy=59.41, test_auc=65.29



100%|██████████| 202/202 [00:11<00:00, 17.21it/s]


(Train) | Epoch=005, loss=0.1393, train_accuracy=67.04,train_auc=89.28
test_accuracy=51.49, test_auc=65.20



100%|██████████| 202/202 [00:11<00:00, 17.17it/s]


(Train) | Epoch=006, loss=0.1238, train_accuracy=85.25,train_auc=93.74
test_accuracy=61.39, test_auc=65.51



100%|██████████| 202/202 [00:11<00:00, 17.00it/s]


(Train) | Epoch=007, loss=0.0943, train_accuracy=92.57,train_auc=97.75
test_accuracy=63.37, test_auc=66.76



100%|██████████| 202/202 [00:12<00:00, 16.70it/s]


(Train) | Epoch=008, loss=0.0840, train_accuracy=90.95,train_auc=98.57
test_accuracy=61.39, test_auc=67.03



100%|██████████| 202/202 [00:12<00:00, 16.79it/s]


(Train) | Epoch=009, loss=0.0593, train_accuracy=93.56,train_auc=99.40
test_accuracy=63.37, test_auc=68.28



100%|██████████| 202/202 [00:11<00:00, 17.01it/s]


(Train) | Epoch=000, loss=0.1856, train_accuracy=51.67,train_auc=69.84
test_accuracy=49.01, test_auc=67.05



100%|██████████| 202/202 [00:12<00:00, 16.49it/s]


(Train) | Epoch=001, loss=0.1826, train_accuracy=51.67,train_auc=71.74
test_accuracy=49.01, test_auc=64.37



100%|██████████| 202/202 [00:12<00:00, 16.82it/s]


(Train) | Epoch=002, loss=0.1738, train_accuracy=67.78,train_auc=76.50
test_accuracy=61.39, test_auc=68.13



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=003, loss=0.1703, train_accuracy=58.74,train_auc=79.87
test_accuracy=53.96, test_auc=68.04



100%|██████████| 202/202 [00:11<00:00, 16.83it/s]


(Train) | Epoch=004, loss=0.1530, train_accuracy=77.82,train_auc=85.66
test_accuracy=64.36, test_auc=71.09



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=005, loss=0.1440, train_accuracy=79.43,train_auc=90.82
test_accuracy=62.87, test_auc=69.08



100%|██████████| 202/202 [00:11<00:00, 16.94it/s]


(Train) | Epoch=006, loss=0.1176, train_accuracy=80.55,train_auc=94.04
test_accuracy=62.38, test_auc=70.69



100%|██████████| 202/202 [00:11<00:00, 16.88it/s]


(Train) | Epoch=007, loss=0.1049, train_accuracy=91.82,train_auc=97.83
test_accuracy=63.37, test_auc=68.08



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=008, loss=0.0761, train_accuracy=88.85,train_auc=98.22
test_accuracy=60.40, test_auc=67.38



100%|██████████| 202/202 [00:11<00:00, 16.85it/s]


(Train) | Epoch=009, loss=0.0558, train_accuracy=97.15,train_auc=99.74
test_accuracy=66.34, test_auc=70.81



100%|██████████| 202/202 [00:11<00:00, 16.92it/s]


(Train) | Epoch=000, loss=0.1883, train_accuracy=51.67,train_auc=70.31
test_accuracy=49.01, test_auc=69.38



100%|██████████| 202/202 [00:11<00:00, 16.85it/s]


(Train) | Epoch=001, loss=0.1871, train_accuracy=55.27,train_auc=71.65
test_accuracy=52.48, test_auc=67.78



100%|██████████| 202/202 [00:12<00:00, 16.80it/s]


(Train) | Epoch=002, loss=0.1753, train_accuracy=60.84,train_auc=73.39
test_accuracy=55.94, test_auc=68.12



100%|██████████| 202/202 [00:11<00:00, 16.86it/s]


(Train) | Epoch=003, loss=0.1747, train_accuracy=60.84,train_auc=74.91
test_accuracy=56.44, test_auc=68.06



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=004, loss=0.1604, train_accuracy=67.66,train_auc=77.82
test_accuracy=62.38, test_auc=67.98



100%|██████████| 202/202 [00:11<00:00, 16.93it/s]


(Train) | Epoch=005, loss=0.1550, train_accuracy=77.70,train_auc=87.32
test_accuracy=68.32, test_auc=74.16



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=006, loss=0.1395, train_accuracy=75.84,train_auc=87.29
test_accuracy=65.84, test_auc=73.13



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=007, loss=0.1190, train_accuracy=83.89,train_auc=96.41
test_accuracy=65.35, test_auc=71.84



100%|██████████| 202/202 [00:12<00:00, 16.55it/s]


(Train) | Epoch=008, loss=0.0931, train_accuracy=88.97,train_auc=98.44
test_accuracy=63.86, test_auc=71.57



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=009, loss=0.0707, train_accuracy=93.43,train_auc=99.03
test_accuracy=61.88, test_auc=69.66



100%|██████████| 202/202 [00:11<00:00, 16.98it/s]


(Train) | Epoch=000, loss=0.2087, train_accuracy=49.32,train_auc=63.26
test_accuracy=47.03, test_auc=54.24



100%|██████████| 202/202 [00:11<00:00, 16.97it/s]


(Train) | Epoch=001, loss=0.1784, train_accuracy=49.32,train_auc=74.28
test_accuracy=47.03, test_auc=61.88



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=002, loss=0.1779, train_accuracy=50.68,train_auc=74.68
test_accuracy=52.97, test_auc=64.54



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=003, loss=0.1700, train_accuracy=67.04,train_auc=80.03
test_accuracy=60.89, test_auc=67.11



100%|██████████| 202/202 [00:11<00:00, 16.89it/s]


(Train) | Epoch=004, loss=0.1565, train_accuracy=68.28,train_auc=86.68
test_accuracy=62.38, test_auc=68.93



100%|██████████| 202/202 [00:11<00:00, 17.02it/s]


(Train) | Epoch=005, loss=0.1419, train_accuracy=80.67,train_auc=90.48
test_accuracy=60.40, test_auc=68.93



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=006, loss=0.1248, train_accuracy=76.95,train_auc=90.64
test_accuracy=60.40, test_auc=67.23



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=007, loss=0.1124, train_accuracy=85.01,train_auc=95.56
test_accuracy=61.39, test_auc=70.06



100%|██████████| 202/202 [00:11<00:00, 16.86it/s]


(Train) | Epoch=008, loss=0.0868, train_accuracy=92.81,train_auc=98.04
test_accuracy=61.39, test_auc=67.82



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=009, loss=0.0653, train_accuracy=90.46,train_auc=98.43
test_accuracy=59.41, test_auc=67.20

0.63490099009901 0.033825867249298154
0.6995713069017249 0.03316233758192342


# **L1 Decay Experiment**
Attempt to limit the number of nodes being used to make the final prediction my multiplying all of the features of each graph node by a different weight for each node.  These weights would be decayed using l1 weight decay to try to force some to 0.

In [None]:
# cls tokens comes from: https://github.com/yandex-research/rtdl
class TransformerModel(torch.nn.Module):
    def __init__(self, num_nodes=200, num_classes=2):
        super(TransformerModel, self).__init__()
        self.num_nodes = num_nodes
        self.node_fcn = nn.Linear(256, 8)

        self.fcn = nn.Sequential(
            nn.Linear(1600, 256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(256, 32),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(32, num_classes)
        )

        self.input_embed = nn.Sequential(
            nn.Linear(200, 200),
        )


        d_token = 512
        heads = 8
        initialization = _TokenInitialization.from_str("uniform")
        self.transformer = nn.Sequential(
            nn.Linear(200, d_token),
            CLSToken(
                d_token, initialization
            ),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0)
        )

        self.out = nn.Linear(d_token,num_classes)
        self.filter = nn.Parameter(torch.ones(200).reshape((1,-1)))


    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = data.x.reshape((data.x.shape[0]//200, 200, 200))
        x = x * self.filter.unsqueeze(1)* self.filter.unsqueeze(-1)
        x = self.transformer(x)
        cls_token = x[:, -1]
        return self.out(cls_token)

eval = True
eval_runs = 20
hyperparameter_tuning_split_rand_seed = 0

skf = StratifiedKFold(n_splits=5, shuffle=True)
lr = 5e-5
weight_decay = 1e-4
train_batch_size = 4
test_batch_size = 4
epochs = 10
device = "cuda:0"


# We don't want to use the test_indexes during evaluation since those were used during hyperparameter tuning
if eval:
    train_index, test_index = train_test_split(np.arange(len(y)), test_size=0.2, random_state=hyperparameter_tuning_split_rand_seed)
    dataset_used = dataset[train_index]
else:
    dataset_used = dataset

accuracy = []
auc = []

if eval:
    train_index, test_index = train_test_split(np.arange(len(y)), test_size=0.2)
else:
    train_index, test_index = train_test_split(np.arange(len(y)), test_size=0.2, random_state=hyperparameter_tuning_split_rand_seed)

model = TransformerModel().to(device)
#model.transformer.load_state_dict(stored_model.transformer.state_dict()) -> used previously for fine tuning the self supervised pretrained model
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

train_set, test_set = dataset[train_index], dataset[test_index]

train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=test_batch_size, shuffle=False)

# train
model.eval()
epoch_num = epochs
for i in range(epoch_num):
    loss_all = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y) + 0.001 * torch.norm(model.filter, 1)**2
        loss.backward()
        optimizer.step()
        loss_all += loss.item()
    epoch_loss = loss_all / len(train_loader.dataset)

    train_accuracy, train_auc = evaluate(model, device, train_loader)
    print(f'(Train) | Epoch={i:03d}, loss={epoch_loss:.4f}, '
                f'train_accuracy={(train_accuracy * 100):.2f},'
                f'train_auc={(train_auc * 100):.2f}')
    run["train/accuracy"].append(train_accuracy)
    run["train/auc"].append(train_auc)
    run["train/loss"].append(epoch_loss)

    model.eval()
    if (i + 1) % 1 == 0:
        test_accuracy, test_auc = evaluate(model, device, test_loader)
        print(f'test_accuracy={(test_accuracy * 100):.2f}, ' \
              f'test_auc={(test_auc * 100):.2f}\n')

        run["test/accuracy"].append(test_accuracy)
        run["test/auc"].append(test_auc)


# Rerun evaluation just to make 100% sure the results are accurate
model.eval()
test_accuracy, test_auc = evaluate(model, device, test_loader)
accuracy.append(test_accuracy)
auc.append(test_auc)


accuracy = np.array(accuracy)
auc = np.array(auc)

print(accuracy.mean(), accuracy.std())
print(auc.mean(), auc.std())

100%|██████████| 202/202 [00:12<00:00, 16.42it/s]


(Train) | Epoch=000, loss=10.1129, train_accuracy=48.82,train_auc=68.53
test_accuracy=49.01, test_auc=60.26



100%|██████████| 202/202 [00:12<00:00, 16.45it/s]


(Train) | Epoch=001, loss=9.8948, train_accuracy=51.18,train_auc=70.44
test_accuracy=50.99, test_auc=62.19



100%|██████████| 202/202 [00:12<00:00, 16.80it/s]


(Train) | Epoch=002, loss=9.6946, train_accuracy=51.18,train_auc=68.38
test_accuracy=50.99, test_auc=60.19



100%|██████████| 202/202 [00:11<00:00, 17.04it/s]


(Train) | Epoch=003, loss=9.5012, train_accuracy=66.54,train_auc=74.16
test_accuracy=61.88, test_auc=68.14



100%|██████████| 202/202 [00:11<00:00, 17.00it/s]


(Train) | Epoch=004, loss=9.3045, train_accuracy=55.27,train_auc=73.98
test_accuracy=55.94, test_auc=69.22



100%|██████████| 202/202 [00:11<00:00, 16.91it/s]


(Train) | Epoch=005, loss=9.1095, train_accuracy=69.76,train_auc=78.51
test_accuracy=62.87, test_auc=71.20



100%|██████████| 202/202 [00:12<00:00, 16.83it/s]


(Train) | Epoch=006, loss=8.9120, train_accuracy=68.03,train_auc=83.85
test_accuracy=60.40, test_auc=68.75



100%|██████████| 202/202 [00:11<00:00, 16.87it/s]


(Train) | Epoch=007, loss=8.7130, train_accuracy=78.19,train_auc=91.11
test_accuracy=59.90, test_auc=69.05



100%|██████████| 202/202 [00:11<00:00, 16.88it/s]


(Train) | Epoch=008, loss=8.5176, train_accuracy=83.89,train_auc=94.49
test_accuracy=59.41, test_auc=68.64



100%|██████████| 202/202 [00:11<00:00, 16.90it/s]


(Train) | Epoch=009, loss=8.3106, train_accuracy=85.50,train_auc=97.18
test_accuracy=63.37, test_auc=70.92

0.6336633663366337 0.0
0.7092282043738355 0.0


In [None]:
model.filter.max() - model.filter.min()

tensor(0.0008, device='cuda:0', grad_fn=<SubBackward0>)

In [None]:
model.filter

Parameter containing:
tensor([[0.9010, 0.9010, 0.9010, 0.9010, 0.9012, 0.9010, 0.9010, 0.9010, 0.9010,
         0.9011, 0.9010, 0.9010, 0.9010, 0.9010, 0.9011, 0.9010, 0.9010, 0.9010,
         0.9011, 0.9011, 0.9010, 0.9010, 0.9011, 0.9011, 0.9010, 0.9010, 0.9009,
         0.9011, 0.9010, 0.9012, 0.9011, 0.9011, 0.9010, 0.9010, 0.9011, 0.9010,
         0.9010, 0.9010, 0.9010, 0.9011, 0.9013, 0.9011, 0.9012, 0.9011, 0.9010,
         0.9011, 0.9010, 0.9011, 0.9010, 0.9010, 0.9011, 0.9011, 0.9011, 0.9010,
         0.9010, 0.9011, 0.9010, 0.9010, 0.9011, 0.9011, 0.9011, 0.9010, 0.9011,
         0.9010, 0.9011, 0.9010, 0.9010, 0.9009, 0.9010, 0.9010, 0.9012, 0.9010,
         0.9011, 0.9010, 0.9010, 0.9011, 0.9009, 0.9011, 0.9010, 0.9012, 0.9010,
         0.9011, 0.9011, 0.9011, 0.9010, 0.9009, 0.9010, 0.9010, 0.9010, 0.9011,
         0.9010, 0.9010, 0.9010, 0.9011, 0.9010, 0.9011, 0.9010, 0.9010, 0.9011,
         0.9011, 0.9011, 0.9010, 0.9010, 0.9010, 0.9010, 0.9010, 0.9011, 0.9011,
      

# **Self supervision**
We tried to pretrain on an auxillary tasks and then fine tune on the target task, but the results were worse when pretraining.

In [None]:
all_data = []
for i in range(len(dataset)):
  all_data.append(dataset[i].x.numpy())

all_data = np.concatenate(all_data,axis=0)
mean, std = all_data.mean(), all_data.std()

In [None]:
mean = all_data.mean(axis=0)
cov = np.cov(all_data, rowvar=False)
def sample_vals(n = 16):
    # Generate samples
    return np.random.multivariate_normal(mean, cov, n*200).reshape((n, 200, 200))

labels = []
graphs = []

for _ in tqdm(range(10000)):
    sample = sample_vals(1)[0]
    upper_triangle_indices = np.triu_indices(200, k=1)
    lower_triangle_indices = upper_triangle_indices[1],upper_triangle_indices[0]
    sample[lower_triangle_indices] = sample[upper_triangle_indices]
    np.fill_diagonal(sample, 0)

    graphs.append(sample)
    G = nx.from_numpy_matrix(np.matrix(sample), create_using=nx.Graph)
    val1 = np.array(list(nx.eigenvector_centrality_numpy(G, weight="weight").values()))
    val2 = np.array(list(nx.katz_centrality_numpy(G, weight="weight").values()))
    val3 = np.array(list(nx.current_flow_closeness_centrality(G, weight="weight").values()))
    labels.append([val1, val2, val3])


np.save("/content/drive/MyDrive/comp_bio/random_graphs.npy", np.stack(graphs,axis=0))
np.save("/content/drive/MyDrive/comp_bio/random_labels.npy", np.stack(labels,axis=0))
graphs = np.load("/content/drive/MyDrive/comp_bio/random_graphs.npy")
labels = np.load("/content/drive/MyDrive/comp_bio/random_labels.npy")
labels = (labels - labels.mean(axis=0)) / labels.std(axis=0)
max_val = torch.Tensor(labels).quantile(.99).item()
min_val = torch.Tensor(labels).quantile(.01).item()
labels[labels>max_val] = max_val
labels[labels<min_val] = min_val

class CustomDataset(Dataset):
    def __init__(self, graphs, labels):
        super(CustomDataset, self).__init__()
        self.graphs = graphs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        adj = torch.Tensor(self.graphs[idx]).float()
        y = torch.Tensor(self.labels[idx]).float()
        num_nodes = 200
        return adj,y

num_test = int(len(labels) * 0.05)
dataset1 = CustomDataset(graphs[num_test:], labels[num_test:])
dataset2 = CustomDataset(graphs[:num_test], labels[:num_test])

100%|██████████| 10000/10000 [1:50:07<00:00,  1.51it/s]


In [None]:

class CustomNetwork(torch.nn.Module):
    def __init__(self, gnn, num_nodes=200, num_classes=2):
        super(CustomNetwork, self).__init__()
        self.gnn = gnn

        d_token = 512
        heads = 8
        initialization = _TokenInitialization.from_str("uniform")
        self.transformer = nn.Sequential(
            nn.Linear(200, d_token),
            CLSToken(
                d_token, initialization
            ),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0),
            TransformerBlock(d_token, heads, 0, 0)
        )

        self.num_nodes = num_nodes

        node_out_size = d_token
        self.node_output = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(node_out_size, node_out_size),
                nn.ReLU(),
                nn.Linear(node_out_size, node_out_size),
                nn.ReLU(),
                nn.Linear(node_out_size, 1)
            ),
            nn.Sequential(
                nn.Linear(node_out_size, node_out_size),
                nn.ReLU(),
                nn.Linear(node_out_size, node_out_size),
                nn.ReLU(),
                nn.Linear(node_out_size, 1)
            ),
            nn.Sequential(
                nn.Linear(node_out_size, node_out_size),
                nn.ReLU(),
                nn.Linear(node_out_size, node_out_size),
                nn.ReLU(),
                nn.Linear(node_out_size, 1)
            )]
        )

        self.input_embed = nn.Sequential(
            nn.Linear(200, 200),
        )


    def forward(self, x):
        x = self.transformer(x)
        node_tokens = x[:, :-1]
        out_tensors = []
        for out in self.node_output:
            out_val = out(node_tokens)
            out_tensors.append(out_val)

        return out_tensors


run = neptune.init_run(project="LimitedSupervision/CompBio",
    api_token=userdata.get('NEPTUNE_SECRET'))

lr = 5e-5
weight_decay = 0
train_batch_size = 32
test_batch_size = 32
epochs = 100

train_dataloader = torch_geometric.data.DataLoader(dataset1, batch_size=train_batch_size, shuffle=True)
test_dataloader = torch_geometric.data.DataLoader(dataset2, batch_size=test_batch_size, shuffle=False)

gcn = GCN(num_features, nodes_num, num_classes=2, pooling="sum", num_layers = 8)
model = CustomNetwork(gcn).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.MSELoss()

device="cuda:0"
accs, aucs, macros = [], [], []
epoch_num = epochs
for i in range(epoch_num):
    loss_all = 0
    model.eval()
    for j, (x, y) in enumerate(tqdm(train_dataloader)):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        out = model(x)

        loss = 0
        for j in range(len(y[0])):
            loss += criterion(out[j].reshape(-1), y[:, j].reshape(-1))/3

        loss.backward()

        optimizer.step()
        loss_all += loss.item()

        #if j>1:
        #    break

    epoch_loss = loss_all / len(train_dataloader.dataset)

    print(f'(Train) | Epoch={i:03d}, loss={epoch_loss:.4f}')
    run["train/loss"].append(epoch_loss)

    loss_all = 0
    model.eval()
    if (i + 1) % 1 == 0:
        with torch.no_grad():
            for j, (x, y) in enumerate(test_dataloader):
                x = x.to(device)
                y = y.to(device)
                out = model(x)
                loss = 0
                for j in range(len(y[0])):
                    loss += criterion(out[j].reshape(-1), y[:, j].reshape(-1))
                    break
                loss_all += loss.item()
                #if j>1:
                #    break
        epoch_loss = loss_all / len(test_dataloader.dataset)

        print(f'(Test) | Epoch={i:03d}, loss={epoch_loss:.4f}')
        run["test/loss"].append(epoch_loss)

torch.save(model.state_dict(), "model.pth")
stored_model = model

SecretNotFoundError: Secret NEPTUNE_SECRET does not exist.