In [None]:
import pandas as pd
import numpy as np
import random
import re
import pickle 
from collections import defaultdict, Counter
from scipy.linalg import fractional_matrix_power
import warnings
import json
import math
from scipy import stats

import sklearn
from sklearn import metrics, naive_bayes
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, precision_score, recall_score, accuracy_score, f1_score, confusion_matrix, roc_curve, balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from xgboost import XGBClassifier

import importlib

import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CyclicLR
from torch.utils.data import Dataset, DataLoader, RandomSampler

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from matplotlib.cbook import boxplot_stats
from matplotlib.ticker import PercentFormatter
from matplotlib.font_manager import FontProperties
import shap

from base_utils import get_new_cols, check_filename, drop_useless_cols, load_data, write_file
from preprocessing_utils import remove_rarely_mutated_genes, log1p_mutations, remove_nearly_constant, apply_scaler, COMBAT, CorrelationAnalysis, ExpressionStabilitySelector

# Relative data directories, keyed by cohort and then by data source.
# GBM and LGG each have two cBioPortal exports ('firehose' and 'pancan'); the combined
# GBMLGG study has a single cBioPortal directory; GDC and GLASS are flat directories.
paths = {
    'gbm': {
        'cbio': {
            'firehose': 'data/tcga/gbm/cbio_firehose/',
            'pancan': 'data/tcga/gbm/cbio_pancan/',
        },
        'xena': 'data/tcga/gbm/xena/',
    },
    'lgg': {
        'cbio': {
            'firehose': 'data/tcga/lgg/cbio_firehose/',
            'pancan': 'data/tcga/lgg/cbio_pancan/',
        },
        'xena': 'data/tcga/lgg/xena/',
    },
    'gbmlgg': {
        'cbio': 'data/tcga/gbmlgg/cbio/',
        'xena': 'data/tcga/gbmlgg/xena/',
    },
    'gdc': 'data/tcga/gdc/',
    'glass': 'data/glass/',
}

**1. RETRIEVE, CLEAN, AND ENCODE DATA**

In [None]:
# Clinical columns dropped from the Xena phenotype tables (GLIOMA._clean_clinical_dfs, source='xena').
# Columns starting with '_GENOMIC_ID' or 'CDE_' are dropped there as well, independently of this list.
# NOTE(review): the 'ldh1_*' names presumably mirror the (misspelled IDH1) headers in the Xena files — confirm.
ignore_xena = ['radiation_therapy', 'targeted_molecular_therapy', 'primary_therapy_outcome_success', 'chemo_therapy', '_PANCAN_CNA_PANCAN_K8', '_PANCAN_Cluster_Cluster_PANCAN', '_PANCAN_DNAMethyl_PANCAN', '_PANCAN_RPPA_PANCAN_K8', '_PANCAN_mutation_PANCAN', '_PATIENT', 'animal_insect_allergy_history', 'animal_insect_allergy_types', 'asthma_history', 'bcr_followup_barcode', 'bcr_patient_barcode', 'bcr_sample_barcode', 'bcr_patient_uuid', 'eastern_cancer_oncology_group', 'eczema_history', 'first_diagnosis_age_asth_ecz_hay_fev_mold_dust', 'first_diagnosis_age_of_animal_insect_allergy', 'first_diagnosis_age_of_food_allergy', 'followup_case_report_form_submission_reason', 'food_allergy_history', 'food_allergy_types', 'form_completion_date', 'hay_fever_history', 'headache_history', 'initial_pathologic_diagnosis_method', 'initial_weight', 'mental_status_changes', 'mold_or_dust_allergy_history', 'motor_movement_changes', 'oct_embedded', 'person_neoplasm_cancer_status', 'postoperative_rx_tx', 'preoperative_antiseizure_meds', 'preoperative_corticosteroids', 'seizure_history', 'sensory_changes', 'shortest_dimension', 'tissue_prospective_collection_indicator', 'tissue_retrospective_collection_indicator', 'vial_number', 'visual_changes', 'year_of_initial_pathologic_diagnosis', 'In_Cancer_Cell_Paper', 'CDE_sourcesite', 'CDE_suspect', '_PANCAN_DNAMethyl_GBM', 'additional_chemo_therapy', 'additional_drug_therapy', 'additional_immuno_therapy', 'additional_pharmaceutical_therapy', 'additional_radiation_therapy', 'hormonal_therapy', 'immuno_therapy', 'days_to_last_followup', 'GeneExp_Subtype', 'pathology_report_file_name','days_to_collection', 'sample_type_id', '_INTEGRATION', 'CDE_missing', 'CDE_missingflag', 'CDE_previously_treated', 'G_CIMP_STATUS', 'first_presenting_symptom', 'first_presenting_symptom_longest_duration', 'days_to_performance_status_assessment', 'family_history_of_cancer', 'family_history_of_primary_brain_tumor', 'history_ionizing_rt_to_head', 
'history_of_neoadjuvant_treatment', 'inherited_genetic_syndrome_result', 'inherited_genetic_syndrome_found', 'intermediate_dimension', 'karnofsky_performance_score', 'ldh1_mutation_found', 'ldh1_mutation_test_method', 'ldh1_mutation_tested', 'longest_dimension', 'patient_id', 'performance_status_scale_timing', 'pretreatment_history', 'prior_glioma', 'supratentorial_localization', 'days_to_birth', 'days_to_death', 'icd_o_3_site', 'other_dx', 'vital_status', '_primary_disease', 'CDE_vital_status', 'CDE_survival_time', 'CDE_DxAge', 'followup_treatment_success', 'sample_type', 'tumor_tissue_site', 'icd_10', 'estimate_stromal_score', 'estimate_immune_score', 'estimate_combined_score']
# Clinical columns dropped from cBioPortal tables: used as-is for GLASS (source='glass') and, with
# 'patient_id' and 'os_status' appended, for the TCGA cBioPortal studies (source='cbio').
ignore_cbio = ['radiation_treatment_adjuvant', 'targeted_molecular_therapy', 'treatment_outcome_first_course', 'pharmaceutical_tx_adjuvant', 'other_patient_id', 'form_completion_date', 'prospective_collection', 'retrospective_collection', 'other_sample_id', 'days_to_collection', 'sample_initial_weight', 'oct_embedded', 'pathology_report_file_name', 'pathology_report_uuid', 'sample_type_id', 'shortest_dimension', 'vial_number', 'tumor_status', 'whole_exome_sequencing', 'absolute_extract_ploidy', 'animal_insect_allergy_age', 'animal_insect_allergy_hist', 'asthma_eczema_allergy_first_diagnosis', 'asthma_history', 'atrx_status', 'bcr_status', 'braf_kiaa1549_fusion', 'braf_v600e_status', 'chr_19_20_co_gain', 'daxx_status', 'ecog_score', 'eczema_history', 'family_history_of_cancer', 'family_history_of_primary_brain_tumor', 'first_symptom_longest_duration', 'food_allergy_age', 'food_allergy_history', 'food_allergy_types', 'hay_fever_history', 'headache_history', 'history_ionizing_rt_to_head', 'history_lgg_dx_of_brain_tissue', 'history_neoadjuvant_medication', 'history_neoadjuvant_steroid_tx', 'history_neoadjuvant_trtyn', 'history_other_malignancy', 'hm27', 'hm450', 'idh1_mutation_test_indicator', 'idh1_mutation_test_method', 'tissue_source_site','idh_specific_dna_methylation_cluster', 'idh_specific_rna_expression_cluster', 'inherited_genetic_syndrome_indicator', 'inherited_genetic_syndrome_specified', 'initial_pathologic_dx_year', 'karnofsky_performance_score', 'longest_dimension', 'method_of_initial_sample_procurement', 'mold_or_dust_allergy_history', 'pan_glioma_dna_methylation_cluster', 'pan_glioma_rna_expression_cluster', 'original_subtype', 'percent_aneuploidy', 'performance_status_days_to', 'performance_status_timing', 'purity_absolute', 'random_forest_sturm_cluster', 'related_symptom_first_present', 'rnaseq_data', 'rppa', 'rppa_cluster', 'seizure_history', 'snp6', 'specimen_second_longest_dimension', 'study', 'supervised_dna_methylation_cluster', 
'supratentorial_localization', 'symp_changes_mental_status', 'symp_changes_motor_movement', 'symp_changes_sensory', 'symp_changes_visual', 'telomere_length_estimate_in_blood_normal_kb', 'telomere_length_estimate_in_tumor_kb', 'telomere_maintenance', 'tert_expression_log2', 'tert_expression_status', 'tmb_nonsynonymous', 'transcriptome_subtype', 'u133a', 'whole_genome_sequencing', 'icd_o_3_site', 'mgmt_promoter_status', 'os_months', 'sample_type', 'icd_10', 'stromal_score', 'estimate_score', 'immune_score', 'estimate_stromal_score', 'estimate_immune_score', 'estimate_combined_score']
# Columns dropped from the GDC TCGA Clinical Data Resource table (source='cdr').
ignore_cdr = ['treatment_outcome_first_course', 'initial_pathologic_dx_year', 'last_contact_days_to', 'tumor_status', 'type']

class GLIOMA:
    def __init__(self, paths):
        self.paths = paths
        self.explicit_cases = set()
        self.fig_tracker = {'tcga': {}, 'glass': {}}
        self.get_hgnc()
        self.get_genomic()
        self.get_clinical()
        self.build_outputs_table()
        self.relabel_WHO_CNS5()
        self.finalize_data()

    def get_hgnc(self):
        hgnc = self._get_data('HGNC_all.txt', dp='data/other/', clean=False).query("Locus_group == 'protein-coding gene'").drop(columns=['Locus_group'])
        hgnc['NCBI_Gene_ID'] = hgnc['NCBI_Gene_ID'].fillna(hgnc['HGNC_NCBI_Gene_ID'])
        hgnc['Ensembl_Gene_ID'] = hgnc['Ensembl_Gene_ID'].fillna(hgnc['Ensembl_Gene_ID_secondary'])
        hgnc['HGNC_ID'] = hgnc['HGNC_ID'].str.replace('HGNC:', '', regex=False)
        self.ncbi_id_dict = dict(zip(hgnc['NCBI_Gene_ID'], hgnc['Approved_symbol']))
        self.hgnc_id_dict = dict(zip(hgnc['HGNC_ID'], hgnc['Approved_symbol']))
        self.ensembl_id_dict = dict(zip(hgnc['Ensembl_Gene_ID'], hgnc['Approved_symbol']))
        self.symbol_to_approved = {}
        for _, row in hgnc.iterrows():
            approved = row['Approved_symbol']
            self.symbol_to_approved[approved] = approved
            for col in ['Previous_symbols', 'Alias_symbols']:
                if pd.notna(row[col]):
                    for s in row[col].split(','):
                        self.symbol_to_approved[s.strip()] = approved

    def get_genomic(self):
        # TCGA ('sample' is the gene column in xena expression data and the sample ID column in xena mutation data)
        self.tcga_expr = self._process_expression_data('xena_mrna_normalized.txt', dp=self.paths['gbmlgg']['xena'], gene_col='sample', by_id=None) 
        self.tcga_mut = self._process_mutation_data('data_mutations.txt', dp=self.paths['gbmlgg']['cbio'], gene_col='Hugo_Symbol', id_col='Tumor_Sample_Barcode', by_id='HGNC_ID')
        self.tcga_mut.index = self.tcga_mut.index.astype(str) + '-01'  # GBMLGG is all primary; I verified on cbioportal, idk why they change upon download
        self.tcga_expr, self.tcga_mut = [self._primary_genomic_only(df, '-01') for df in [self.tcga_expr, self.tcga_mut]]
        self.tcga_pts_with_genomic = list(set(self.tcga_expr.index.tolist()) & set(self.tcga_mut.index.tolist()))
        self.fig_tracker['tcga'].update({'Has primary sample': set(self.tcga_expr.index.tolist() + self.tcga_mut.index.tolist()), 'Has expr': self.tcga_expr.index.tolist(), 'Has mut': self.tcga_mut.index.tolist(), 'Has genomic': self.tcga_pts_with_genomic})
        # GLASS
        self.glass_expr = self._process_expression_data('data_mrna_seq_tpm_zscores_ref_all_samples.txt', dp=self.paths['glass'], gene_col='Hugo_Symbol', by_id=None)
        self.glass_mut = self._process_mutation_data('data_mutations.txt', dp=self.paths['glass'], gene_col='Hugo_Symbol', id_col='Tumor_Sample_Barcode', by_id='Entrez_Gene_Id')
        self.glass_expr, self.glass_mut = [self._primary_genomic_only(df, '-TP') for df in [self.glass_expr, self.glass_mut]]
        self.glass_pts_with_genomic = list(set(self.glass_expr.index.tolist()) & set(self.glass_mut.index.tolist()))
        self.fig_tracker['glass'].update({'Has primary sample': set(self.glass_expr.index.tolist() + self.glass_mut.index.tolist()), 'Has expr': self.glass_expr.index.tolist(), 'Has mut': self.glass_mut.index.tolist(), 'Has genomic': self.glass_pts_with_genomic})
    
    def _primary_genomic_only(self, df, primary_flag):
        return df.loc[df.index.str.endswith(primary_flag)].copy()

    def _process_expression_data(self, filename, dp, gene_col, by_id=None):
        expr_df = self._get_data(filename, dp=dp).dropna(subset=[gene_col])
        expr_df = self._update_gene_symbols(expr_df, gene_col, id_type=by_id)
        expr_df = expr_df.groupby('approved_symbol').mean()
        expr_df.index = list(expr_df.index) # get rid of "approved symbol" label
        expr_df = expr_df.T
        return expr_df.add_suffix('_expr')

    def _process_mutation_data(self, filename, dp, gene_col, id_col, xena_flag=False, by_id=None):
        q = f"{'chr' if xena_flag else 'Chromosome'} not in ['X', 'Y'] and {'effect' if xena_flag else 'Variant_Classification'} not in ['Silent', 'RNA']"
        df = self._get_data(filename, dp=dp, low_memory=False).query(q).dropna(subset=[gene_col])
        df = self._update_gene_symbols(df, gene_col, id_type=by_id)
        if dp == self.paths['gbmlgg']['cbio']: self.tcga_raw_mut = df.copy()
        elif dp == self.paths['glass']: self.glass_raw_mut = df.copy()
        mut_df = pd.crosstab(df[id_col], df['approved_symbol']).astype(float)
        mut_df.columns = [f"{col}_mut" for col in mut_df.columns]
        return mut_df

    def _update_gene_symbols(self, df, gene_col, id_type=None):
        df['approved_symbol'] = df[gene_col].map(self.symbol_to_approved)
        if id_type == 'Entrez_Gene_Id':
            df['approved_symbol'] = df.apply(lambda row: self.ncbi_id_dict.get(row[id_type], row['approved_symbol']) if row[id_type]!=0 else row['approved_symbol'], axis=1)
        elif id_type == 'HGNC_ID':
            df['approved_symbol'] = df['Entrez_Gene_Id'].map(self.ncbi_id_dict).fillna(df['approved_symbol'])
        return df.dropna(subset=['approved_symbol']).drop(columns=[gene_col])

    def get_clinical(self):
        """ Load, merge, and harmonize the clinical tables for TCGA and GLASS.

        TCGA sources - Xena (GBMLGG, GBM, LGG), cBioPortal clinical (GBMLGG, GBM, LGG), GDC (already all combined).
        Each source's columns carry a suffix so they can sit side by side after the outer concat:
        _1/_1GBM/_1LGG = Xena, _2/_2GBM/_2LGG = cBioPortal, _3 = GDC CDR, _4 = GDC follow-up.
        GLASS source - cBioPortal clinical (no suffix).

        Also loads the cBioPortal timelines consumed by build_outputs_table, restricts both clinical tables
        (via _limit_df) to tcga_pts_with_genomic / glass_pts_with_genomic, applies a few manual TCGA fixes,
        and records cohort-flow counts in self.fig_tracker.
        """
        self.gbmlgg_xena, self.gbm_xena, self.lgg_xena = [self._load_and_prep_xena(*i) for i in (['gbmlgg','_1'], ['gbm','_1GBM'], ['lgg','_1LGG'])]
        self.tcga_xena = pd.concat([self.gbmlgg_xena, self.gbm_xena, self.lgg_xena], axis=1) # outer by default
        self.gbmlgg_cbio, self.gbm_cbio, self.lgg_cbio = [self._load_and_prep_cbio(*i) for i in (['gbmlgg','_2'], ['gbm','_2GBM'], ['lgg','_2LGG'])]
        self.tcga_cbio = pd.concat([self.gbmlgg_cbio, self.gbm_cbio, self.lgg_cbio], axis=1)
        self.gdc_cdr, self.gdc_followup = [self._load_and_prep_gdc(*i) for i in (['cdr','_3'], ['followup','_4'])]
        # One wide table per TCGA sample: all sources side by side, columns that are empty everywhere removed, sorted by name
        self.tcga_clin = pd.concat([self.tcga_xena, self.tcga_cbio, self.gdc_cdr, self.gdc_followup], axis=1, join='outer').dropna(axis=1, how='all')
        self.tcga_clin = self.tcga_clin[sorted(self.tcga_clin.columns)]
        self.tcga_tss = self.tcga_clin['tissue_source_site_1'] # batch labels
        self.fig_tracker['tcga']['Has clin'] = self.tcga_clin.index.tolist()
        # 'Has primary sample' is the set created in get_genomic; extend it with clinical-only samples
        self.fig_tracker['tcga']['Has primary sample'].update(self.tcga_clin.index.tolist())
        """ GLASS source - cBioPortal clinical """
        self.glass_clin = self._load_and_prep_cbio('glass', '')
        self.glass_tss = self.glass_clin['tissue_source'] # batch labels
        self.fig_tracker['glass']['Has clin'] = self.glass_clin.index.tolist()
        self.fig_tracker['glass']['Has primary sample'].update(self.glass_clin.index.tolist())
        """ Timelines (LGG doesn't have any relevant status instances) """
        self.gbm_timeline_dx, self.gbm_timeline_rx = [self._get_data(f"data_timeline_{i}.txt", dp=self.paths['gbm']['cbio']['pancan'], lower=True) for i in ['status', 'treatment']]
        self.lgg_timeline_rx = self._get_data('data_timeline_treatment.txt', dp=self.paths['lgg']['cbio']['pancan'], lower=True)
        self.glass_timeline_sx = self._get_data('data_timeline_surgery.txt', dp=self.paths['glass'], lower=True)
        self._consolidate_columns()
        # NOTE(review): _limit_df is defined outside this view; presumably keeps only rows whose index is in the given list — confirm
        self.tcga_clin = self._limit_df(self.tcga_clin, self.tcga_pts_with_genomic)
        self.glass_clin = self._limit_df(self.glass_clin, self.glass_pts_with_genomic)
        # Keep batch labels aligned with the restricted clinical tables
        self.tcga_tss, self.glass_tss = self.tcga_tss.loc[self.tcga_clin.index], self.glass_tss.loc[self.glass_clin.index]
        self.fig_tracker['tcga']['Has all 3'], self.fig_tracker['glass']['Has all 3'] = self.tcga_clin.index.tolist(), self.glass_clin.index.tolist()
        # Remove any TCGA patients from the GLASS dataset
        self.fig_tracker['glass']['TCGA-sourced pts removed'] = [i for i in list(self.glass_clin.index) if i.startswith('TCGA')]
        # NOTE(review): self.glass_tss is not re-aligned after this drop in this method — confirm it is realigned later
        self.glass_clin.drop(index = self.fig_tracker['glass']['TCGA-sourced pts removed'], inplace=True)
        # Coerce (non-timeline) TCGA TTR columns to numerical format (days to additional surgery has a few ints as strings)
        ttr_cols = ['days_to_tumor_recurrence_1GBM', 'DFI.time_3', 'nte_dx_days_to_3', 'days_to_nte_after_initial_treatment_1', 'days_to_nte_after_initial_treatment_4', 'days_to_nte_after_initial_treatment_1GBM', 'days_to_nte_after_initial_treatment_1LGG', 'days_to_nte_additional_surgery_procedure_1', 'days_to_additional_surgery_locoregional_procedure_1', 'days_to_additional_surgery_metastatic_procedure_1']
        self.tcga_clin[ttr_cols] = self.tcga_clin[ttr_cols].apply(pd.to_numeric, errors='coerce')
        # Confirmed no cases of IDHwt & 1p19q co-deletion & filled in combined status for 4 patients using idh_status_2
        self.tcga_clin.loc[['TCGA-12-1597-01', 'TCGA-28-2499-01', 'TCGA-76-4927-01', 'TCGA-76-4932-01'], 'idh_codel_subtype_2'] = 'IDHwt'
        # Normalize the spelling so only 'IDHwt' / 'IDHmut-noncodel' / 'IDHmut-codel' remain for relabel_WHO_CNS5
        self.tcga_clin.loc[self.tcga_clin['idh_codel_subtype_2']=='IDHmut-non-codel', 'idh_codel_subtype_2'] = 'IDHmut-noncodel'
        self.tcga_clin.drop(columns =['idh_status_2', 'idh_1p19q_subtype_2'], inplace=True)
        self.fig_tracker['tcga']['Missing IDH mutation/codel status (exclude)'] = ['TCGA-06-5417-01']
        # NOTE(review): self.tcga_tss still contains this sample after the drop — confirm it is realigned later
        self.tcga_clin.drop(index=['TCGA-06-5417-01'], inplace=True)
        # Collapse the two GBM spellings into the single label used as a key in relabel_WHO_CNS5
        self.tcga_clin.loc[self.tcga_clin['histological_type_1'].isin(['Untreated primary (de novo) GBM', 'Glioblastoma Multiforme (GBM)']), 'histological_type_1'] = 'Glioblastoma'

    def _load_and_prep_xena(self, cohort, suffix):
        df = self._get_data('xena_phenotypes.txt', dp=self.paths[cohort]['xena'])
        df = self._set_idx(df.dropna(subset=['sampleID']), idx_col='sampleID')
        df = self._limit_to_primary(df, key=f'{cohort}_xena')
        return self._clean_clinical_dfs(df, suffix, source='xena')

    def _load_and_prep_cbio(self, cohort, suffix):
        path = self.paths[cohort] if cohort == 'glass' else (self.paths[cohort]['cbio'] if cohort == 'gbmlgg' else self.paths[cohort]['cbio']['firehose'])
        patient = self._get_data('data_clinical_patient.txt', dp=path, lower=True)
        sample = self._get_data('data_clinical_sample.txt', dp=path, lower=True)
        if cohort == 'gbmlgg': sample['sample_id'] = sample['patient_id'] + '-01'
        df = patient.merge(sample, on='patient_id')
        df = self._set_idx(df.dropna(subset=['sample_id']), idx_col='sample_id')
        if cohort != 'gbmlgg': 
            df = self._limit_to_primary(df, key=f'{cohort}_cbio') # GBMLGG is all primary; I verified on cbioportal, idk why they change upon download
        return self._clean_clinical_dfs(df, suffix, source='glass') if cohort == 'glass' else self._clean_clinical_dfs(df, suffix, source='cbio')

    def _load_and_prep_gdc(self, kind, suffix):
        """Load a GDC pan-cancer clinical table restricted to the GBM/LGG cohorts and clean it.

        kind='cdr' reads the TCGA Clinical Data Resource (tcga_cdr.csv, filtered on 'type') and drops
        all-empty columns. Any other kind reads the PANCAN patient-with-follow-up table (filtered on 'acronym')
        and keeps only the patient barcode plus a few follow-up fields. Files are read from
        self.paths['gdc']; `suffix` and the source label (`kind`) are passed to _clean_clinical_dfs.
        """
        # Use the instance's paths (not the module-level `paths` global) so GLIOMA(paths=...) is honoured,
        # consistent with every other loader in this class.
        gdc_dir = self.paths['gdc']
        if kind == 'cdr':
            df = load_data('csv', 'tcga_cdr.csv', data_path=gdc_dir, clean=True, index_col=0).query("type in ['LGG', 'GBM']").dropna(axis=1, how='all')
        else:
            df = load_data('tsv', 'clinical_PANCAN_patient_with_followup.tsv', data_path=gdc_dir, clean=True, encoding='latin1', low_memory=False).query("acronym in ['LGG', 'GBM']")
            df = df[['bcr_patient_barcode', 'ethnicity', 'laterality', 'days_to_new_tumor_event_after_initial_treatment', 'prior_glioma']].copy()
        return self._clean_clinical_dfs(df, suffix=suffix, source=kind)

    def _limit_to_primary(self, df, key, id_col='index'):
        # don't need to collect for glass because glass TTR is exclusively based on this approach (we actually don't gain any for TCGA doing this either but it doesn't hurt)
        query_val = 'Recurrent Tumor' if key in ['gbm_xena', 'lgg_xena', 'gbmlgg_xena'] else ('Recurrence' if key in ['gbm_cbio', 'lgg_cbio'] else None)
        if query_val is not None:
            self.explicit_cases.update(list(df.query(f"sample_type == @query_val").index.to_series().str[:-2] + '01'))
        primary_flag = '-TP' if key.startswith('glass') else '-01'
        return df.loc[df.index.str.endswith(primary_flag)].copy() if id_col == 'index' else df.loc[df[id_col].str.endswith(primary_flag)].copy()

    def _clean_clinical_dfs(self, df, suffix, source):
        """Drop unused clinical columns, re-key patient-level GDC tables by sample, and suffix column names.

        source: one of 'cbio', 'xena', 'cdr', 'followup', 'glass'. It selects the ignore list. For every
        source except 'followup', columns in that list or starting with '_GENOMIC_ID' / 'CDE_' are removed.
        'cdr' and 'followup' are patient-level, so they are indexed by '<bcr_patient_barcode>-01' to line up
        with the primary-sample IDs of the other sources. Finally 'new_tumor_event' is shortened to 'nte' in
        column names, and `suffix` is appended.
        """
        ignore_by_source = {
            'cbio': ignore_cbio + ['patient_id', 'os_status'],
            'xena': ignore_xena,
            'cdr': ignore_cdr,
            'followup': [],
            'glass': ignore_cbio,
        }
        ignored = set(ignore_by_source[source])
        if source == 'followup':
            cleaned = df.copy()
        else:
            dropped = [col for col in df.columns if col in ignored or col.startswith(('_GENOMIC_ID', 'CDE_'))]
            cleaned = df.drop(columns=dropped)
        if source in ('cdr', 'followup'):
            cleaned['sample_id'] = cleaned['bcr_patient_barcode'] + '-01'
            cleaned = self._set_idx(cleaned, idx_col='sample_id', also_drop=['bcr_patient_barcode'])
        cleaned.columns = [f"{col.replace('new_tumor_event', 'nte')}{suffix}" for col in cleaned.columns]
        return cleaned

    def _combine_column_values(self, df, target_col, fallback_cols):
        return df[[target_col] + fallback_cols].bfill(axis=1).iloc[:, 0]

    def _consolidate_columns(self):
        """Merge redundant clinical columns, then drop duplicate or unused columns from the TCGA and GLASS tables.

        Missing values in grade_2, ethnicity_4 and sex_2 are filled from their fallback columns first. This has
        to happen before the drop below, which removes several of those fallbacks. Drop groups A/B/C are
        explained in the string note that follows. `other` lists further TCGA columns no longer needed (when
        called from get_clinical, tissue_source_site_1 has already been copied to self.tcga_tss). to_dropG plus
        every 'treatment_*' column is removed from GLASS.
        pandas drop() uses errors='raise' here, so a KeyError is raised if any listed column is absent.
        """
        # target column -> fallback columns used (left to right) to fill its missing values
        fill_from = {'grade_2': ['neoplasm_histologic_grade_1'], 'ethnicity_4': ['ethnicity_2LGG','ethnicity_2GBM'], 'sex_2': ['sex_2GBM','sex_2LGG']}
        for target_col, fallback_cols in fill_from.items():
            self.tcga_clin[target_col] = self._combine_column_values(self.tcga_clin, target_col, fallback_cols)
        """ Enable word wrap for details - (A): No missing values for age_at_initial_pathologic_diagnosis (xena == cdr); No missing values for gender (xena == cdr); No missing values for icd_o_3_histology; histological_type_1 = the individuals combined; (B): All the grades missing from grade_2 are in the other grade columns (xena, cdr, and cbio_lgg, which all match); race_3 = the individuals combined; (C): Exact duplicates: histological_type_1LGG/histological_type_2LGG, histological_diagnosis_2GBM/histological_type_1GBM, histological_type_3/histological_type_1, CDE_tmz_chemoradiation_standard_1GBM/CDE_alk_chemoradiation_standard_1GBM, OS.time_3/DSS.time_3, additional_surgery_locoregional_procedure_1LGG/additional_surgery_locoregional_procedure_1, additional_surgery_metastatic_procedure_1LGG/additional_surgery_metastatic_procedure_1, days_to_additional_surgery_locoregional_procedure_1LGG/days_to_additional_surgery_locoregional_procedure_1, days_to_additional_surgery_metastatic_procedure_1LGG/days_to_additional_surgery_metastatic_procedure_1, days_to_nte_additional_surgery_procedure_1GBM/days_to_nte_additional_surgery_procedure_1, laterality_1LGG/laterality_1, laterality_2LGG/laterality_1LGG, new_neoplasm_event_type_1GBM/new_neoplasm_event_type_1, targeted_molecular_therapy_1LGG/targeted_molecular_therapy_1, tumor_location_1LGG/tumor_location_1, tumor_site_2LGG/tumor_location_1LGG """
        to_dropA = ['age_2GBM', 'age_at_initial_pathologic_diagnosis_1GBM', 'age_at_initial_pathologic_diagnosis_1LGG', 'age_2LGG', 'age_2', 'age_at_initial_pathologic_diagnosis_3', 'birth_days_to_3', 'gender_1GBM', 'gender_1LGG', 'gender_3', 'icd_o_3_histology_2LGG', 'icd_o_3_histology_1LGG', 'histological_type_1GBM', 'histological_type_1LGG', 'sex_2GBM', 'sex_2LGG', 'gender_1']
        to_dropB = ['histological_grade_3', 'neoplasm_histologic_grade_1', 'neoplasm_histologic_grade_1LGG', 'grade_2LGG', 'race_2GBM', 'race_2LGG', 'ethnicity_2LGG', 'ethnicity_2GBM']
        to_dropC = ['histological_diagnosis_2GBM', 'histological_diagnosis_2LGG', 'histological_diagnosis_2', 'icd_o_3_histology_1', 'histological_type_3', 'DSS.time_3', 'additional_surgery_locoregional_procedure_1LGG', 'additional_surgery_metastatic_procedure_1LGG', 'days_to_additional_surgery_locoregional_procedure_1LGG', 'days_to_additional_surgery_metastatic_procedure_1LGG', 'days_to_nte_additional_surgery_procedure_1GBM', 'laterality_1LGG', 'laterality_2LGG', 'new_neoplasm_event_type_1GBM', 'tumor_location_1LGG', 'tumor_site_2LGG']
        other = ['PFI.time_3', 'PFI_3', 'additional_surgery_locoregional_procedure_1', 'additional_surgery_metastatic_procedure_1', 'tissue_source_site_1', 'tissue_source_site_1GBM', 'tissue_source_site_1LGG']
        # GLASS-only columns that are not used downstream
        to_dropG = ['case_project', 'mgmt_methylation', 'tissue_source', 'surgery_indication', 'mgmt_methylation_method', 'dna_aliquot_barcode', 'rna_aliquot_barcode', 'aliquot_analysis_type', 'patient_id', 'surgery_type', 'surgery_extent_of_resection', 'purity', 'alkylating_agent_tx']
        self.tcga_clin.drop(columns = [*to_dropA, *to_dropB, *to_dropC, *other], inplace=True)
        self.glass_clin.drop(columns = to_dropG + [g for g in self.glass_clin.columns.tolist() if g.startswith('treatment_')], inplace=True) 

    def build_outputs_table(self):
        """ Derive the time-to-recurrence outcome 'ttr' for TCGA and GLASS and attach it to the clinical tables.

        TCGA:
          - recurrence dates are pulled from the GBM/LGG cBioPortal timelines;
          - progression-only cases are excluded (_exclude_progression_only);
          - 'ttr' is taken from the highest-priority tier with a value (_get_tcga_ttr);
          - explicit recurrence cases still lacking 'ttr' are filled from DFS months x 30.44 (none in practice);
          - samples without 'ttr' are dropped, along with the helper columns in `no_longer_needed`.
        GLASS: 'ttr' is the earliest start_date among a patient's '-R1' surgeries in the surgical timeline,
        matched back to the '<patient>-TP' sample. Samples without one are dropped by the inner concat.
        """
        # ---- TCGA ----
        # 1. Extract TCGA outcomes from the diagnosis and treatment timelines
        gbm_dx = self._extract_timeline_outcomes(self.gbm_timeline_dx, 'status', {'Locoregional Disease': 'locoreg_disease_dx_start', 'Recurrence': 'recurrence_dx_start'})
        gbm_rx = self._extract_timeline_outcomes(self.gbm_timeline_rx, 'anatomic_treatment_site', {'Distant Recurrence': 'distant_recurrence_rx_start', 'Distant Site': 'distant_site_rx_start', 'Local Recurrence': 'local_recurrence_rx_start'})
        gbm_rx = gbm_rx.merge(self._extract_timeline_outcomes(self.gbm_timeline_rx, 'regimen_indication', {'Recurrence': 'recurrence_regimen_start'}), how='outer', left_index=True, right_index=True)
        lgg_rx = self._extract_timeline_outcomes(self.lgg_timeline_rx, 'anatomic_treatment_site', {'Local Recurrence': 'local_recurrence_rx_start'})
        lgg_rx = lgg_rx.merge(self._extract_timeline_outcomes(self.lgg_timeline_rx, 'regimen_indication', {'Recurrence': 'recurrence_regimen_start'}), how='outer', left_index=True, right_index=True)
        # GBM and LGG treatment outcomes are stacked (disjoint samples), then joined to the GBM status outcomes
        timelines = gbm_dx.merge(pd.concat([gbm_rx, lgg_rx]), how='outer', left_index=True, right_index=True)
        self.timelines = timelines
        tcga_ttr = self.tcga_clin.merge(timelines, how='left', left_index=True, right_index=True)
        # 2. Exclude progression-only cases
        self.fig_tracker['tcga']['Has rec sample'] = list(self.explicit_cases)
        self.fig_tracker['tcga']['Exclude progression'] = defaultdict(list)
        tcga_ttr = tcga_ttr[tcga_ttr.apply(self._exclude_progression_only, axis=1)]
        # 3. Calculate TTR
        # NOTE(review): tcga_ttr is a boolean-mask selection; pandas may emit SettingWithCopyWarning on the next column assignment
        self.fig_tracker['tcga']['Get ttr'] = defaultdict(lambda: defaultdict(list))
        tcga_ttr['ttr'] = tcga_ttr.apply(self._get_tcga_ttr, axis=1)
        # 4. Fill DFS months (converted to days) where TTR is missing for explicit recurrence cases
        tcga_ttr['dfs_months'] = tcga_ttr['dfs_months_2GBM'].fillna(tcga_ttr['dfs_months_2LGG'])
        self.fig_tracker['tcga']['Explicit cases'] = self.explicit_cases
        mask = tcga_ttr['ttr'].isna() & tcga_ttr.index.isin(self.explicit_cases) # none
        self.fig_tracker['tcga']['Get ttr']['Explicit w/o ttr'] = tcga_ttr.loc[mask].index.tolist() # none
        tcga_ttr.loc[mask, 'ttr'] = tcga_ttr.loc[mask, 'dfs_months'] * 30.44 # none
        # 5. Final cleanup
        self.fig_tracker['tcga']['No ttr'] = tcga_ttr[tcga_ttr['ttr'].isna()].index.tolist()
        self.fig_tracker['tcga']['Get ttr']['Explicit w/o ttr filled in'] = [i for i in self.fig_tracker['tcga']['Get ttr']['Explicit w/o ttr'] if i not in self.fig_tracker['tcga']['No ttr']]
        tcga_ttr.dropna(subset=['ttr'], inplace=True)
        no_longer_needed = ['days_to_tumor_recurrence_1GBM', 'nte_dx_days_to_3', 'days_to_nte_after_initial_treatment_1', 'days_to_nte_after_initial_treatment_1GBM', 'days_to_nte_after_initial_treatment_1LGG', 'days_to_nte_additional_surgery_procedure_1', 'days_to_additional_surgery_locoregional_procedure_1', 'days_to_additional_surgery_metastatic_procedure_1', 'DFI.time_3', 'dfs_status_2GBM', 'dfs_months_2GBM', 'DFI_3', 'nte_type_3', 'new_neoplasm_event_type_1', 'days_to_tumor_progression_1GBM', 'nte_after_initial_treatment_1GBM', 'nte_after_initial_treatment_2LGG', 'lost_follow_up_1GBM', 'lost_follow_up_1LGG', 'lost_follow_up_1', 'nte_after_initial_treatment_1LGG', 'nte_after_initial_treatment_1', 'dfs_months_2LGG', 'dfs_status_2LGG', 'days_to_nte_after_initial_treatment_4']
        # inner join: only samples that kept a 'ttr' survive in tcga_clin
        self.tcga_clin = pd.concat([self.tcga_clin, tcga_ttr[['ttr']]], axis=1, join='inner').drop(columns=no_longer_needed)
        """ GLASS """
        # 1. Extract GLASS TTR using the surgical timeline (this is the only way)
        glass_ttr = self.glass_timeline_sx[self.glass_timeline_sx['sample_id'].str.endswith('-R1')].copy()
        # earliest first-recurrence surgery per patient (sort ascending, keep the first row per patient)
        glass_ttr = glass_ttr.sort_values('start_date').drop_duplicates('patient_id').rename(columns={'start_date': 'ttr'}).dropna(subset=['ttr'])
        glass_ttr.index = list(glass_ttr['patient_id'] + '-TP') # so we can match TTR back to the patients
        self.fig_tracker['glass']['Has ttr'] = [i for i in self.glass_clin.index.tolist() if i in glass_ttr.index.tolist()]
        self.fig_tracker['glass']['No ttr'] = [i for i in self.glass_clin.index.tolist() if i not in self.fig_tracker['glass']['Has ttr']]
        # 2. Merge GLASS TTR back to clinical
        self.glass_clin = pd.concat([self.glass_clin, glass_ttr[['ttr']]], axis=1, join='inner')

    def _extract_timeline_outcomes(self, df, pivot_col, col_map):
        df_filtered = df[df[pivot_col].isin(col_map.keys())]
        self.explicit_cases.update(list(df_filtered['patient_id'] + '-01'))
        pivot_df = df_filtered.pivot_table(index='patient_id', columns=pivot_col, values='start_date', aggfunc='min')
        pivot_df.index = pivot_df.index + '-01' # this data is static to the patient
        pivot_df.rename(columns=col_map, inplace=True)
        return pivot_df

    def _exclude_progression_only(self, row):
        def track(reason, keep, add_to_explicit=False):
            self.fig_tracker['tcga']['Exclude progression'][reason].append(row.name)
            if add_to_explicit: self.explicit_cases.add(row.name)
            return keep
        recurrence_cols = ['locoreg_disease_dx_start', 'recurrence_dx_start', 'distant_recurrence_rx_start', 'distant_site_rx_start', 'local_recurrence_rx_start', 'recurrence_regimen_start', 'days_to_tumor_recurrence_1GBM']
        new_tumor_events = row[['nte_type_3', 'new_neoplasm_event_type_1']].dropna()
        if row.name in self.explicit_cases: 
            return True
        if row[recurrence_cols].notna().any(): # all these patients are already in explicit_cases, but for consistency's sake
            return track('Has non-null recurrence col (explicit)', True, add_to_explicit=True)
        if any(nte in ['Recurrence', 'Locoregional Disease'] for nte in new_tumor_events): 
            return track('Has recurrence NTE (explicit)', True, add_to_explicit=True)
        if row.get('DFI_3') == 1:
            return track('Has DFI=1 (explicit)', True, add_to_explicit=True)
        if row.get('DFI_3') == 0: 
            return track('DFI=0 (exclude)', False)
        if new_tumor_events.empty: # otherwise the next condition returns False either way
            return True
        if all(nte == 'Progression of Disease' for nte in new_tumor_events): 
            return track('Progression events only (exclude)', False)
        return True
        
    def _get_tcga_ttr(self, row):
        tiers = [
            ['recurrence_dx_start', 'distant_recurrence_rx_start', 'local_recurrence_rx_start', 'days_to_tumor_recurrence_1GBM', 'recurrence_regimen_start'],
            ['DFI.time_3'],
            ['locoreg_disease_dx_start', 'distant_site_rx_start', 'nte_dx_days_to_3', 'days_to_nte_after_initial_treatment_1', 'days_to_nte_after_initial_treatment_4', 'days_to_nte_after_initial_treatment_1GBM', 'days_to_nte_after_initial_treatment_1LGG'],
            ['days_to_nte_additional_surgery_procedure_1', 'days_to_additional_surgery_locoregional_procedure_1']
        ]
        for i, tier in enumerate(tiers):
            valid = [(col, row[col]) for col in tier if pd.notna(row[col])]
            if valid:
                col_used, val_used = min(valid, key=lambda x: x[1])
                self.fig_tracker['tcga']['Get ttr'][f'Tier {i+1}'][col_used].append(row.name)
                return val_used
        return np.nan

    def relabel_WHO_CNS5(self):
        """Relabel TCGA and GLASS tumors under the WHO CNS5 classification and drop excluded cases.

        References: https://doi-org.revproxy.brown.edu/10.1002/cncr.33918; https://pmc.ncbi.nlm.nih.gov/articles/PMC9427889/
        (1) astrocytoma = IDHmut, no codeletion; (2) oligodendroglioma = IDHmut, codeletion;
        (3) glioblastoma = IDHwt. IDH-mutant glioblastomas without codeletion are now astrocytomas,
        and 'anaplastic' is omitted. "Diffuse astrocytoma, IDHwt without molecular features of
        glioblastoma, is a rare entity" -- for TCGA the +7/-10 and TERT promoter status can be checked.

        Sets ``self.new_mappings`` (histology -> IDH/codel subtype -> new label; 'EXCLUDE' drops the
        patient, 'CHECK' defers to the molecular check in ``_relabel_tumors``). Resets
        ``self.cns5_tracker`` and the 'CNS5 check' / 'CNS5 exclude' entries of ``self.fig_tracker``.
        Adds ``histology_CNS5`` to ``self.tcga_clin`` / ``self.glass_clin`` and removes 'EXCLUDE' rows.
        """
        # Lower-grade histologies share one rule: IDHwt needs a molecular check, otherwise IDH/codel decides.
        lower_grade_rule = {'IDHwt': 'CHECK', 'IDHmut-noncodel': 'Astrocytoma', 'IDHmut-codel': 'Oligodendroglioma'}
        self.new_mappings = {
            # IDHmut-noncodel glioblastomas become astrocytomas. Guidelines don't say how to classify
            # codeleted glioblastomas (they cannot be astrocytomas), so they are excluded
            # (only 1 mut-codel in GLASS, 0 in TCGA).
            'Glioblastoma': {'IDHwt': 'Glioblastoma', 'IDHmut-noncodel': 'Astrocytoma', 'IDHmut-codel': 'EXCLUDE'},
            **{hist: dict(lower_grade_rule) for hist in ('Astrocytoma', 'Oligodendroglioma', 'Oligoastrocytoma')},
        }
        histologies = ('Glioblastoma', 'Astrocytoma', 'Oligodendroglioma', 'Oligoastrocytoma')
        subtypes = ('IDHwt', 'IDHmut-noncodel', 'IDHmut-codel')
        self.cns5_tracker = {cohort: {hist: dict.fromkeys(subtypes, 0) for hist in histologies} for cohort in ('tcga', 'glass')}
        for cohort in ('tcga', 'glass'):
            # Per histology, one id list per evaluation method. GLASS cannot be checked for GBM, so its
            # CHECK cases all default to 'Astrocytoma WT'; the nesting is kept for consistency.
            self.fig_tracker[cohort]['CNS5 check'] = defaultdict(lambda: defaultdict(list))
            self.fig_tracker[cohort]['CNS5 exclude'] = defaultdict(list)
        label_columns = {'tcga': ('histological_type_1', 'idh_codel_subtype_2'), 'glass': ('histology', 'idh_codel_status')}
        for cohort, (hist_col, idh_col) in label_columns.items():
            clin = getattr(self, f'{cohort}_clin')
            clin['histology_CNS5'] = self._relabel_tumors(cohort, clin, hist_col, idh_col)
        for cohort in label_columns:
            kept = getattr(self, f'{cohort}_clin').query("histology_CNS5 != 'EXCLUDE'").copy()
            setattr(self, f'{cohort}_clin', kept)

    def _relabel_tumors(self, cohort, df, histology_col, idh_codel_col):
        new_labels = []
        for pt in df.index.tolist():
            histology, idh_codel = df.loc[pt, histology_col], df.loc[pt, idh_codel_col]
            self.cns5_tracker[cohort][histology][idh_codel] += 1
            new = self.new_mappings[histology][idh_codel]
            if new == 'EXCLUDE': 
                self.fig_tracker[cohort]['CNS5 exclude'][histology].append(pt)
            elif new == 'CHECK':
                if cohort == 'glass':
                    self.fig_tracker['glass']['CNS5 check'][histology]['Astrocytoma WT'].append(pt)
                    new = 'Astrocytoma WT' # We can't check for GBM in glass so they all default to astrocytoma wt
                elif cohort == 'tcga':
                    if self.tcga_clin.loc[pt, 'chr_7_gain_chr_10_loss_2'] == "Gain chr 7 & loss chr 10": 
                        self.fig_tracker['tcga']['CNS5 check'][histology]['7+/10-'].append(pt)
                        new = 'Glioblastoma'
                    elif self.tcga_clin.loc[pt, 'tert_promoter_status_2'] == "Mutant": 
                        self.fig_tracker['tcga']['CNS5 check'][histology]['TERTp mutant'].append(pt)
                        new = 'Glioblastoma'
                    else: 
                        self.fig_tracker['tcga']['CNS5 check'][histology]['Astrocytoma WT'].append(pt)
                        new = 'Astrocytoma WT'
            new_labels.append(new)
        return new_labels

    def finalize_data(self):
        """Apply the final cohort exclusions, assign early/late-recurrence labels and align features.

        Steps:
            1. TCGA: drop patients with a prior glioma.
            2. Both cohorts: drop samples whose tissue-source-site (batch) label occurs only once.
            3. Drop clinical columns that are mostly missing, redundant, superseded by
               ``histology_CNS5`` or no longer needed.
            4. Keep only patients present both in the clinical table and in ``*_pts_with_genomic``;
               drop all-missing columns.
            5. ``label = 1`` when ``ttr`` is below the unweighted mean of the two cohort medians, else 0.
            6. Rename clinical columns (positionally) and map GLASS grades to 'G2'/'G3'/'G4'.
            7. Restrict mutation/expression data to genes that are non-constant in both cohorts, then
               tabulate variant distributions via ``_get_mut_distribution``.

        Patient ids and gene lists built from set intersections are sorted, so row/column order does
        not depend on string-hash randomization (PYTHONHASHSEED) and results are reproducible.
        """
        # 1. Prior glioma (TCGA only)
        self.fig_tracker['tcga']['Prior glioma (exclude)'] = self.tcga_clin.query("prior_glioma_4 == 'YES'").index.tolist()
        self.tcga_clin.drop(index=self.fig_tracker['tcga']['Prior glioma (exclude)'], inplace=True)
        # 2. Samples that are the only member of their tissue-source-site batch
        #    (presumably because batch correction needs more than one sample per batch)
        solo_key = 'Samples with unique tissue source sites (batch of 1; exclude)'
        self.tcga_tss = self.tcga_tss.loc[self.tcga_clin.index].copy()
        batch_label_counts_tcga = self.tcga_tss.value_counts()
        solo_samples_tcga = self.tcga_tss[self.tcga_tss.isin(batch_label_counts_tcga[batch_label_counts_tcga == 1].index.tolist())].index.tolist()
        self.fig_tracker['tcga'][solo_key] = solo_samples_tcga
        self.tcga_clin.drop(index=solo_samples_tcga, inplace=True)
        self.glass_tss = self.glass_tss.loc[self.glass_clin.index].copy()
        batch_label_counts_glass = self.glass_tss.value_counts()
        solo_samples_glass = self.glass_tss[self.glass_tss.isin(batch_label_counts_glass[batch_label_counts_glass == 1].index.tolist())].index.tolist()
        self.fig_tracker['glass'][solo_key] = solo_samples_glass
        self.glass_clin.drop(index=solo_samples_glass, inplace=True)  # dropping an empty list is a no-op
        # 3. Columns with too many missing values or an equivalent with fewer missing values + columns replaced by histology_CNS5 + columns no longer needed
        self.tcga_clin.drop(columns=['idh1_mutation_2LGG', 'laterality_1', 'laterality_4', 'OS.time_3', 'OS_3', 'DSS_3', 'histological_type_1', 'chr_7_gain_chr_10_loss_2', 'tert_promoter_status_2', 'oncotree_code_2LGG', 'cancer_type_detailed_2LGG', 'prior_glioma_4', 'death_days_to_3', 'tumor_location_1'], inplace=True)
        self.glass_clin.drop(columns=['cancer_type_detailed', 'tumor_classification', 'histology', 'idh_status', 'codel_status', 'oncotree_code', 'surgery_laterality', 'surgery_location'], inplace=True)
        # 4. Patients with clinical + genomic data; sorted for a reproducible row order
        self.tcga_all3 = sorted(set(self.tcga_clin.index.tolist()) & set(self.tcga_pts_with_genomic))
        self.glass_all3 = sorted(set(self.glass_clin.index.tolist()) & set(self.glass_pts_with_genomic))
        self.tcga_clin, self.tcga_expr, self.tcga_mut = [df.loc[self.tcga_all3].dropna(axis=1, how='all').copy() for df in [self.tcga_clin, self.tcga_expr, self.tcga_mut]]
        self.glass_clin, self.glass_expr, self.glass_mut = [df.loc[self.glass_all3].dropna(axis=1, how='all').copy() for df in [self.glass_clin, self.glass_expr, self.glass_mut]]
        self.tcga_tss = self.tcga_tss.loc[self.tcga_all3].copy()
        self.glass_tss = self.glass_tss.loc[self.glass_all3].copy()
        # 5. Label the data: early recurrence (1) if ttr is below the averaged cohort median
        tcga_median, glass_median = [df['ttr'].median() for df in [self.tcga_clin, self.glass_clin]]
        threshold = (tcga_median + glass_median)/2
        print(f"Unweighted averaged median: (tcga_median + glass_median)/2 = ({tcga_median} + {glass_median})/2 = {threshold}")
        self.tcga_clin['label'] = (self.tcga_clin['ttr'] < threshold).astype(int)
        self.glass_clin['label'] = (self.glass_clin['ttr'] < threshold).astype(int)
        for i, ds in enumerate([self.tcga_clin, self.glass_clin]):
            print(f"{'TCGA' if i==0 else 'GLASS'}: {len(ds)} patients ({len(ds)-ds['label'].sum()} late, {ds['label'].sum()} early)")
        # 6. Positional rename: the lists must match the surviving column order exactly
        #    (pandas raises ValueError if the length differs).
        self.tcga_clin.columns = ['age', 'ethnicity', 'grade', 'idh_codel', 'race', 'sex', 'vital_status', 'ttr', 'histology_CNS5', 'label']
        self.glass_clin.columns = ['sex', 'age', 'os_status', 'grade', 'idh_codel', 'ttr', 'histology_CNS5', 'label']
        self.glass_clin['grade'] = self.glass_clin['grade'].map({'II': 'G2', 'III': 'G3', 'IV': 'G4'})
        # 7. Limit genomic data to genes shared between TCGA and GLASS (dropping constants); sorted for reproducibility
        self.mut_cols = sorted(set(self.tcga_mut.columns[self.tcga_mut.nunique() > 1]) & set(self.glass_mut.columns[self.glass_mut.nunique() > 1]))
        self.expr_cols = sorted(set(self.tcga_expr.columns[self.tcga_expr.nunique() > 1]) & set(self.glass_expr.columns[self.glass_expr.nunique() > 1]))
        self.tcga_mut, self.glass_mut = self.tcga_mut[self.mut_cols].copy(), self.glass_mut[self.mut_cols].copy()
        self.tcga_expr, self.glass_expr = self.tcga_expr[self.expr_cols].copy(), self.glass_expr[self.expr_cols].copy()
        self._get_mut_distribution()

    def _get_mut_distribution(self): # for plotting variant distribution
        mut_incl = [gene.replace('_mut','') for gene in self.mut_cols]
        self.tcga_raw_mut = self.tcga_raw_mut.query("approved_symbol in @mut_incl")
        self.tcga_raw_mut['Tumor_Sample_Barcode'] = self.tcga_raw_mut['Tumor_Sample_Barcode'] + '-01' # GBMLGG is all primary, verified
        self.tcga_mut_dist = self.tcga_raw_mut[['Tumor_Sample_Barcode','Variant_Classification']].merge(self.tcga_clin['label'], left_on='Tumor_Sample_Barcode', right_index=True).groupby(['Variant_Classification','label']).size().unstack(fill_value=0)
        self.glass_raw_mut = self.glass_raw_mut.query("approved_symbol in @mut_incl")
        self.glass_mut_dist = self.glass_raw_mut[['Tumor_Sample_Barcode','Variant_Classification']].merge(self.glass_clin['label'], left_on='Tumor_Sample_Barcode', right_index=True).groupby(['Variant_Classification','label']).size().unstack(fill_value=0)

    def _get_data(self, filename, dp, lower=False, clean=True, **kwargs):
        """Load a '#'-commented raw text table from *dp* via ``load_data``.

        Args:
            filename: file to read. dp: data directory.
            lower: lower-case every column name when true.
            clean: forwarded to ``load_data``. **kwargs: extra ``load_data`` options.
        """
        frame = load_data('raw txt', filename, data_path=dp, clean=clean, comment='#', **kwargs)
        if not lower:
            return frame
        frame.columns = [name.lower() for name in frame.columns]
        return frame

    def _set_idx(self, df, idx_col='sample_id', also_drop=[]):
        assert len(df) == df[idx_col].nunique(), f"Duplicate index values in {idx_col}"
        df.index = df[idx_col].tolist()
        return df.drop(columns = [idx_col] + also_drop)
    
    def _limit_df(self, df, included):
        return df.query("index in @included").dropna(axis=1, how='all').copy()

    def compare_training_and_testing(self, action, source, group1, group2, label1='Train', label2='Test', save_table=False, table_fn=''):
        """Build (or load) a table comparing two patient groups of one cohort, with p-values.

        Args:
            action: 'load' reads ``table1s/{table_fn}``; any other value computes the table.
            source: 'tcga' uses ``self.tcga_clin``, anything else ``self.glass_clin``.
            group1, group2: patient ids of the two groups (e.g. train and test).
            label1, label2: column names for the two groups.
            save_table: if true, write the computed table to ``table1s/{table_fn}``.
            table_fn: file name used by 'load' and *save_table*.

        Returns:
            DataFrame indexed by (variable, category/statistic) with one column per group.
            Categorical rows hold counts (missing values counted under 'NAN'). 'age' and 'ttr' get
            mean (std), median and IQR rows. Every variable also gets a (variable, 'p-value') row:
            Fisher's exact test for 2x2 categorical tables, chi-squared otherwise, and a two-sided
            Mann-Whitney U test for 'age'/'ttr'. Tests that cannot be computed report 'nan'.
        """
        if action == 'load':
            return pd.read_csv(f"table1s/{table_fn}").rename(columns={'Unnamed: 0': 'cat', 'Unnamed: 1': 'subcat'})
        xclin = self.tcga_clin.copy() if source == 'tcga' else self.glass_clin.copy()
        pts = {label1: group1, label2: group2}
        categorical_clin = ['label', 'grade', 'idh_codel', 'sex', 'histology_CNS5'] + (['race', 'ethnicity', 'vital_status'] if source == 'tcga' else ['os_status'])
        xclin_filled = xclin.fillna('NAN')
        # Sort the categories (by string form, since a column may mix numbers and 'NAN') so the row
        # order is reproducible; iterating a raw set of strings depends on hash randomization.
        mi = [(col, val) for col in categorical_clin for val in sorted(set(xclin_filled[col].tolist()), key=str)]
        clin_counts = pd.DataFrame(index=pd.MultiIndex.from_tuples(mi), columns=pts.keys())
        # Fill counts
        for pts_col, patients in pts.items():
            subdf = xclin_filled.loc[xclin_filled.index.isin(patients)].copy()
            for combo in mi:
                clin_counts.loc[combo, pts_col] = (subdf[combo[0]] == combo[1]).sum()
        # Add p-values for categorical variables
        for col in categorical_clin:
            table = pd.DataFrame({label: (xclin_filled.loc[xclin_filled.index.isin(group), col].value_counts()) for label, group in pts.items()}).fillna(0)
            if table.shape == (2, 2):  # Use Fisher's exact test for 2x2
                _, p = stats.fisher_exact(table)
            else:
                try:
                    _, p, _, _ = stats.chi2_contingency(table)
                except ValueError:  # e.g. an expected frequency of zero (one group is empty)
                    p = np.nan
            # The explicit ", :" makes this a new ROW (col, 'p-value'). A bare tuple key is read by
            # pandas as (row=col, column='p-value') and would add a stray 'p-value' column instead.
            clin_counts.loc[(col, 'p-value'), :] = f"{p:.3g}"
        # Numerical variables
        numerical_clin = ['age', 'ttr']
        mi = [(col, stat) for col in numerical_clin for stat in ['mean (std)', 'median', 'IQR (q1,q3)', 'p-value']]
        clin_stats = pd.DataFrame(index=pd.MultiIndex.from_tuples(mi), columns=pts.keys())
        for pts_col, patients in pts.items():
            subdf = xclin[xclin.index.isin(patients)].copy()
            for col in numerical_clin:
                clin_stats.loc[(col, 'mean (std)'), pts_col] = f"{round(subdf[col].mean(), 2)} ({round(subdf[col].std(), 2)})"
                clin_stats.loc[(col, 'median'), pts_col] = f"{round(subdf[col].median(), 2)}"
                q25, q75 = subdf[col].quantile(0.25), subdf[col].quantile(0.75)
                clin_stats.loc[(col, 'IQR (q1,q3)'), pts_col] = f"{round(q75 - q25, 2)} ({round(q25, 2)}, {round(q75, 2)})"
        # Compute Mann–Whitney U test for numerical variables
        for col in numerical_clin:
            x1 = xclin.loc[xclin.index.isin(group1), col].dropna()
            x2 = xclin.loc[xclin.index.isin(group2), col].dropna()
            try:
                _, p = stats.mannwhitneyu(x1, x2, alternative='two-sided')
            except ValueError:  # e.g. an empty group on older SciPy versions
                p = np.nan
            clin_stats.loc[(col, 'p-value'), :] = [f"{p:.3g}"] * len(pts)
        out = pd.concat([clin_counts, clin_stats], axis=0)
        if save_table:
            out.to_csv(f"table1s/{table_fn}")
        return out

    def get_table1(self, action, source, incl_pts, save_table=False, table_fn=''):
        """Build (or load) a "Table 1" for one cohort, stratified by outcome label and grade.

        Args:
            action: 'load' reads ``table1s/{table_fn}``; any other value computes the table.
            source: 'tcga' uses the TCGA clinical/mutation tables, anything else the GLASS ones.
            incl_pts: patient ids to include.
            save_table: if true, write the computed table to ``table1s/{table_fn}``.
            table_fn: file name used by 'load' and *save_table*.

        Returns:
            DataFrame with one column per '{label}: {grade}' group (grade in All/G2/G3/G4). Rows hold
            categorical counts ('NAN' = missing), age/ttr summaries, per-patient total mutation median
            and IQR, and the most frequently mutated gene ("None" when no patient in the group has a
            mutation).
        """
        if action == 'load':
            return pd.read_csv(f"table1s/{table_fn}").rename(columns={'Unnamed: 0': 'cat', 'Unnamed: 1': 'subcat'})
        pts = {}
        xclin, xmut = [self.tcga_clin.loc[incl_pts], self.tcga_mut.loc[incl_pts]] if source == 'tcga' else [self.glass_clin.loc[incl_pts], self.glass_mut.loc[incl_pts]]
        for label in [0, 1]:
            for grade in ['All', 'G2', 'G3', 'G4']:
                ids = xclin.query("label == @label").index.tolist() if grade == 'All' else xclin.query("label == @label and grade == @grade").index.tolist()
                pts[f'{label}: {grade}'] = ids
        categorical_clin = ['label', 'grade', 'idh_codel', 'sex', 'histology_CNS5'] + (['race', 'ethnicity', 'vital_status'] if source == 'tcga' else ['os_status'])
        xclin_filled = xclin.fillna('NAN')
        # Sorted (by string form) so row order does not depend on set/hash ordering
        mi = [(col, val) for col in categorical_clin for val in sorted(set(xclin_filled[col].tolist()), key=str)]
        clin_counts = pd.DataFrame(index=pd.MultiIndex.from_tuples(mi), columns=pts.keys())
        for pts_col, patients in pts.items():
            subdf = xclin_filled.loc[xclin_filled.index.isin(patients)].copy()
            for combo in mi:
                clin_counts.loc[combo, pts_col] = (subdf[combo[0]] == combo[1]).sum()
        numerical_clin = ['age', 'ttr']
        mi = [(col, cat) for col in numerical_clin for cat in ['mean (std)', 'median', 'IQR (q1,q3)']]
        clin_stats = pd.DataFrame(index=pd.MultiIndex.from_tuples(mi), columns=pts.keys())
        for pts_col, patients in pts.items():
            subdf = xclin[xclin.index.isin(patients)].copy()
            for col in numerical_clin:
                clin_stats.loc[(col, 'mean (std)'), pts_col] = f"{round(subdf[col].mean(), 2)} ({round(subdf[col].std(), 2)})"
                clin_stats.loc[(col, 'median'), pts_col] = f"{round(subdf[col].median(), 2)}"
                q25, q75 = subdf[col].quantile(0.25), subdf[col].quantile(0.75)
                clin_stats.loc[(col, 'IQR (q1,q3)'), pts_col] = f"{round(q75 - q25, 2)} ({round(q25, 2)}, {round(q75, 2)})"
        mi = [('TOTAL MUTS', 'median'), ('TOTAL MUTS', 'IQR (q1,q3)'), ('TOP GENE', 'most frequent')]
        mut_stats = pd.DataFrame(index=pd.MultiIndex.from_tuples(mi), columns=pts.keys()) # total mutation stats and the most freq mutated gene per group
        for pts_col, patients in pts.items():
            subdf = xmut.loc[xmut.index.isin(patients)].copy()
            sums = subdf.sum(axis=1)
            mut_stats.loc[('TOTAL MUTS', 'median'), pts_col] = f"{round(sums.median(), 2)}"
            q25, q75 = sums.quantile(0.25), sums.quantile(0.75)
            mut_stats.loc[('TOTAL MUTS', 'IQR (q1,q3)'), pts_col] = f"{round(q75 - q25, 2)} ({round(q25, 2)}, {round(q75, 2)})"
            gene_counts = (subdf > 0).sum()
            # With no mutated patient (e.g. an empty label/grade group) idxmax() would just return the
            # first column with N=0, which is misleading; report "None" instead.
            if gene_counts.empty or gene_counts.max() == 0:
                mut_stats.loc[('TOP GENE', 'most frequent'), pts_col] = "None"
            else:
                mut_stats.loc[('TOP GENE', 'most frequent'), pts_col] = f"{gene_counts.idxmax()} (N={gene_counts.max()})"
        table1 = pd.concat([clin_counts, clin_stats, mut_stats], axis=0)
        if save_table:
            table1.to_csv(f"table1s/{table_fn}")
        return table1
    
    def standardize_data(self):
        """Encode the clinical variables and give TCGA and GLASS identical, ordered columns.

        Steps:
            - Drop cohort-specific clinical columns and ``ttr``.
            - Map sex/grade to integers and one-hot encode ``idh_codel`` / ``histology_CNS5``.
            - Impose a fixed clinical column order.
            - Sort expression and mutation columns alphabetically.
        Prints an error line for every modality whose TCGA and GLASS columns still differ.
        """
        self.tcga_clin.drop(columns=['ethnicity', 'race', 'vital_status', 'ttr'], inplace=True)
        self.glass_clin.drop(columns=['os_status', 'ttr'], inplace=True)
        encodings = {'sex': {'Female': 0, 'Male': 1}, 'grade': {'G2': 0, 'G3': 1, 'G4': 2}}
        clin_order = ['age', 'grade', 'sex', 'IDHmut-codel', 'IDHmut-noncodel', 'IDHwt', 'astrocytoma', 'astrocytoma_wt', 'oligodendroglioma', 'glioblastoma', 'label']
        encoded = [self._encode_clin(frame, encodings) for frame in (self.tcga_clin, self.glass_clin)]
        self.tcga_clin, self.glass_clin = encoded
        dummied = [self._apply_dummies(frame) for frame in (self.tcga_clin, self.glass_clin)]
        self.tcga_clin, self.glass_clin = dummied
        self.tcga_clin = self.tcga_clin[clin_order].copy()
        self.glass_clin = self.glass_clin[clin_order].copy()

        def alphabetical(frame):
            return frame[sorted(frame.columns)].copy()

        self.tcga_expr, self.glass_expr = alphabetical(self.tcga_expr), alphabetical(self.glass_expr)
        self.tcga_mut, self.glass_mut = alphabetical(self.tcga_mut), alphabetical(self.glass_mut)
        modality_pairs = (
            (self.tcga_clin, self.glass_clin),
            (self.tcga_expr, self.glass_expr),
            (self.tcga_mut, self.glass_mut),
        )
        for tcga_frame, glass_frame in modality_pairs:
            if tcga_frame.columns.tolist() != glass_frame.columns.tolist():
                print("ERROR: COLUMN MISMATCH BETWEEN TCGA AND GLASS")

    def _encode_clin(self, df, encodings):
        encoded_df = df.copy()
        for col, encoding_map in encodings.items():
            encoded_df[col] = encoded_df[col].map(encoding_map)
        return encoded_df

    def _apply_dummies(self, df):
        fixed_names = {'Astrocytoma': 'astrocytoma', 'Astrocytoma WT': 'astrocytoma_wt', 'Glioblastoma': 'glioblastoma', 'Oligodendroglioma': 'oligodendroglioma'}
        return pd.get_dummies(df, columns=['idh_codel', 'histology_CNS5'], dtype=int, prefix='', prefix_sep='').rename(columns=fixed_names)

    def preprocess_data(self, tcga_presplit=True, glass_presplit=True, tcga_fn='', glass_fn='', dp='data/other/', remove_constants=True):
        """Merge and split both cohorts, then optionally drop features constant in TCGA training.

        Records the original per-modality feature names in ``self.og_cols``. Everything is fit on
        TCGA; GLASS training/validation sets are set aside only so they can feed the CORAL loss during
        training. When *remove_constants* is false, ``self.cols`` simply mirrors ``self.og_cols``.
        """
        self.og_cols = {
            'clin': [col for col in self.tcga_clin.columns.tolist() if col != 'label'],
            'expr': self.tcga_expr.columns.tolist(),
            'mut': self.tcga_mut.columns.tolist(),
        }
        cohort_settings = (('tcga', tcga_presplit, tcga_fn), ('glass', glass_presplit, glass_fn))
        for source, presplit, fn in cohort_settings:
            self._process_dataset(
                source,
                getattr(self, f'{source}_clin'),
                getattr(self, f'{source}_expr'),
                getattr(self, f'{source}_mut'),
                presplit, fn, dp,
            )
        if not remove_constants:
            self.cols = {kind: self.og_cols[kind] for kind in ['clin', 'expr', 'mut']}
            return
        self.remove_training_constants()  # decided from TCGA train only

    def _process_dataset(self, source, clin_df, expr_df, mut_df, presplit, fn, dp):
        df_all = clin_df.join(expr_df, how='inner').join(mut_df, how='inner')
        x = df_all.drop(columns=['label']).copy()
        y = df_all['label'].copy()
        setattr(self, f"{source}_x", x)
        setattr(self, f"{source}_y", y)
        splits = self.load_splits(x, y, fn, dp) if presplit else self.split_data(x, y, fn, dp)
        for sname, split_data in zip(['x_train', 'x_val', 'x_test', 'y_train', 'y_val', 'y_test'], splits):
            setattr(self, f"{source}_{sname}", split_data)
    
    def split_data(self, x, y, fn, dp, size1=0.15, size2=0.176, random_state=None):
        """Split *x*/*y* into train/val/test, stratified jointly by label and grade.

        With the defaults, ``size1=0.15`` of the patients go to test and ``size2=0.176`` of the
        remainder to validation (0.15 / (1.0 - 0.15) ~= 0.176), giving roughly a 70/15/15 split.

        Args:
            x: feature table; must contain a 'grade' column.
            y: labels indexed like *x*.
            fn, dp: if *fn* is non-empty, the patient ids of each split are written as JSON to *fn* in *dp*.
            size1: test fraction of all patients.
            size2: validation fraction of the non-test patients.
            random_state: forwarded to both ``train_test_split`` calls; pass an int for reproducible
                splits. The default ``None`` keeps the previous unseeded behavior.

        Returns:
            x_train, x_val, x_test, y_train, y_val, y_test
        """
        # Stratify on the "label_grade" combination so every split keeps both distributions
        pseudo_y = y.astype(str) + '_' + x.loc[y.index, 'grade'].astype(str)
        x_temp, x_test, y_temp_pseudo, y_test_pseudo = train_test_split(x, pseudo_y, test_size=size1, stratify=pseudo_y, random_state=random_state)
        x_train, x_val, y_train_pseudo, y_val_pseudo = train_test_split(x_temp, y_temp_pseudo, test_size=size2, stratify=y_temp_pseudo, random_state=random_state)
        # Recover the real labels in the same row order as each split
        y_train, y_val, y_test = [y.loc[pseudo.index].copy() for pseudo in [y_train_pseudo, y_val_pseudo, y_test_pseudo]]
        if fn != '':
            write_file('json', fn, {label: split.index.tolist() for label, split in zip(['train', 'val', 'test'], [x_train, x_val, x_test])}, dp)
        return x_train, x_val, x_test, y_train, y_val, y_test

    def load_splits(self, x, y, fn, dp):
        """Rebuild train/val/test splits from the patient ids saved by ``split_data``.

        Returns:
            x_train, x_val, x_test, y_train, y_val, y_test -- copies, rows in the saved id order.
        """
        saved_ids = load_data('json', fn, dp)
        split_order = ('train', 'val', 'test')
        x_parts = tuple(x.loc[saved_ids[name]].copy() for name in split_order)
        y_parts = tuple(y.loc[saved_ids[name]].copy() for name in split_order)
        return x_parts + y_parts

    def remove_training_constants(self):
        """Drop features that are constant in the TCGA training split from every TCGA/GLASS split.

        Only ``tcga_x_train`` decides what counts as constant. ``self.cols`` is rebuilt from the
        surviving clinical/expression/mutation features, and the number removed per modality is printed.
        """
        train = self.tcga_x_train
        keep = train.columns[train.nunique() > 1]
        for source in ('tcga', 'glass'):
            for split in ('train', 'val', 'test'):
                attr = f'{source}_x_{split}'
                setattr(self, attr, getattr(self, attr)[keep].copy())
        self.cols = {kind: get_new_cols(keep, self.og_cols[kind]) for kind in ['clin', 'expr', 'mut']}
        removed = {kind: len(self.og_cols[kind]) - len(self.cols[kind]) for kind in self.cols}
        print(f"features removed: {removed['clin']} clin, {removed['expr']} expr, {removed['mut']} mut")

    def print_figure_tracker(self, d, path=(), final_cohort_only=False):
        """Recursively print the number of ids in every list/set of a nested tracker dict.

        Each leaf is printed as ``"key -> subkey -> ... : count"``. With ``final_cohort_only=True`` only
        ids still present in the final clinical tables (``self.tcga_clin`` and, when set,
        ``self.glass_clin``) are counted.
        """
        if isinstance(d, dict):
            for k, v in d.items():
                # Forward the flag; previously nested calls silently fell back to counting every id
                self.print_figure_tracker(v, path + (k,), final_cohort_only)
        elif isinstance(d, (list, set)):
            if final_cohort_only:
                # Include GLASS ids too; checking TCGA alone always reported 0 for GLASS lists
                final_ids = set(self.tcga_clin.index)
                glass_clin = getattr(self, 'glass_clin', None)
                if glass_clin is not None:
                    final_ids |= set(glass_clin.index)
                count = sum(1 for i in d if i in final_ids)
            else:
                count = len(d)
            print(" -> ".join(map(str, path)), ":", count)

    def get_class_weights(self, device):
        """Store inverse class-frequency weights for the TCGA and GLASS training labels.

        ``self.tcga_weights[k]`` and ``self.glass_weights[k]`` equal ``1 / count(label == k)`` in the
        respective training split, as float32 tensors on *device*. Labels must be non-negative
        integers (an ``np.bincount`` requirement).
        """
        def inverse_counts(labels):
            return torch.tensor(1. / np.bincount(labels), dtype=torch.float, device=device)

        self.tcga_weights = inverse_counts(self.tcga_y_train)
        self.glass_weights = inverse_counts(self.glass_y_train)


In [None]:
# Note: If using the provided data, glass/data_mutations_chunk1 and glass/data_mutations_chunk2 must first be concatenated (github file size restrictions)
# Build the data container from the configured paths (what the constructor loads/cleans is defined above this chunk),
# then encode the clinical variables and align TCGA/GLASS columns (sorted expression/mutation columns, fixed clinical order).
g = GLIOMA(paths)
g.standardize_data()

In [None]:
# g.tcga_raw_mut.to_csv('data/Supplementary Data 3.zip', compression={'method': 'zip', 'archive_name': 'tcga_variant_dist.csv'})
# g.glass_raw_mut.to_csv('data/Supplementary Data 4.zip', compression={'method': 'zip', 'archive_name': 'glass_variant_dist.csv'})

**2. PREPROCESS DATA**
- Split the TCGA and GLASS data 70/15/15 (the GLASS training/validation sets are used only for the CORAL loss: their labels are never seen, and the GLASS test set is completely withheld)  
    - Note: Nothing is ever "fit" on GLASS, only on TCGA. GLASS is split only so that, when the CORAL loss is applied, the model "sees" the GLASS training features but never their labels (this is what distinguishes it from ComBat, where TCGA never sees any GLASS data)  

- Batch-correct the genomic data within TCGA and, separately, within GLASS  

- Remove constants using tcga_x_train  

- Remove genes that are mutated in fewer than 2% of the TCGA training patients (done before the correlation analyses to reduce computational cost)  

- Remove highly correlated genomic features using tcga_x_train
    - Note: IDHmut-codel and IDHmut-noncodel are also dropped because they are perfectly correlated (r = 1.0) with the oligodendroglioma and astrocytoma indicators, respectively

In [None]:
# Reuse the saved train/val/test patient splits (JSON files in the default dp='data/other/') for both cohorts,
# then drop features that are constant in the TCGA training split.
g.preprocess_data(tcga_presplit=True, glass_presplit=True, tcga_fn='tcga_splits.json', glass_fn='glass_splits.json', remove_constants=True)

In [None]:
# mutation_freq = (g.tcga_x_train[g.cols['mut']] > 0).mean()
# sorted_freq = mutation_freq.sort_values()
# cumulative = np.arange(1, len(sorted_freq) + 1) / len(sorted_freq)
# sorted_freq.name = 'mutation_freq'
# sorted_freq.to_csv('data/Supplementary Data 5.zip', compression={'method': 'zip', 'archive_name': 'mutation_freqs.csv'})

In [None]:
# Drop genes mutated in fewer than 2% (f=0.02) of TCGA training patients (see the notes above),
# then transform the mutation features (presumably log1p, per the helper's name -- see preprocessing_utils).
remove_rarely_mutated_genes(g, f=0.02)
log1p_mutations(g)

In [None]:
# ComBat batch correction, run separately for the expression and mutation modalities
# (within TCGA and within GLASS, per the notes above; implementation lives in preprocessing_utils).
combat = COMBAT(g)
combat.correct_modality(g, 'expr')
combat.correct_modality(g, 'mut')

In [None]:
remove_nearly_constant(g)
# Rebuild each modality's feature list from the columns that survived, report how many were removed, then store it.
all_cols = g.tcga_x_train.columns.tolist()
surviving_cols = {kind: get_new_cols(all_cols, g.cols[kind]) for kind in ['clin', 'expr', 'mut']}
n_removed = {kind: len(g.cols[kind]) - len(surviving_cols[kind]) for kind in surviving_cols}
print(f"features removed: {n_removed['clin']} clin, {n_removed['expr']} expr, {n_removed['mut']} mut")
g.cols = surviving_cols

In [None]:
# genes removed that were correlated with known genes: 42 (332 pairs resolved)
# Must re-run, file sizes exceeded GitHub allotment
# Load the precomputed correlation matrix/pairs (correlation threshold 0.90) and the saved list of genes to remove.
# Pairs are resolved by correlation count first, then variance. IDHmut-codel / IDHmut-noncodel are dropped as well
# because they duplicate the oligodendroglioma / astrocytoma indicators.
ca = CorrelationAnalysis('load', g, 'corr_mx.pkl', data_path='data/other/', threshold=0.90, corr_pairs_fn='corr_pairs.pkl', calculate_pairs=False)
ca.remove_correlated_genes('load', g, fn='genes_to_remove.pkl', primary_method='corr count', secondary_method='variance', clin_features_to_drop=['IDHmut-codel', 'IDHmut-noncodel'])

**3. FEATURE SELECTION FOR EXPRESSION DATA**
- Conduct stability selection on the expression data  

    - Note: Removing rarely mutated genes already reduced the mutation features sufficiently. The goal here is only to make the number of expression features manageable, since the model itself contains a feature-selection layer.

In [None]:
# Stability selection on the TCGA training expression features only ('fpr' method; 'load' presumably reuses the
# saved results in data/other/expression_fs_fpr.pkl). Keep the features passing the selection threshold and
# restrict the datasets in g to them.
expr_fpr = ExpressionStabilitySelector('load', x=g.tcga_x_train[g.cols['expr']], y=g.tcga_y_train, fn='expression_fs_fpr.pkl', dp='data/other/', method='fpr')
expr_fpr.select_by_threshold()
expr_fpr.apply_to_dataset(g)

**4. SCALE DATA USING TCGA**

In [None]:
# Scale features using TCGA (see the heading above): presumably c/e/m = clinical/expression/mutation, i.e. standard
# scaling for clinical and expression features and min-max scaling for mutations.
apply_scaler(g, ctype='standard', etype='standard', mtype='minmax')

**5. MODEL**
- Training uses a CORAL loss (a Frobenius-norm distance between the TCGA and GLASS feature covariance matrices) to reduce feature drift between cohorts and improve generalizability. **The GLASS training/validation labels are never seen.**  

In [None]:
from model_utils import set_seed, generate_seeds, cycle_seeds, calc_metrics, earlystop_checkpoint, ModelParams, TrainParams, get_gene_reg, save_model_weights, load_saved_model, coral_loss_fn
from model_utils import GeneSelector, SelfAttention, CrossAttention, LearnedQueryAttentionPooling, ModalityEncoder, OutputClassifier, FusionEmbedding

# Reuse a `device` defined in an earlier cell if one exists; otherwise prefer CUDA when available.
if 'device' not in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Inverse class-frequency weights from the TCGA and GLASS training labels, placed on `device`.
g.get_class_weights(device)

class MultiModalNN(nn.Module):
    """Clinical + expression + mutation classifier with optional self/cross attention and fused pooling.

    All architectural choices come from ``params`` (presumably a ``ModelParams`` from model_utils):
    per-modality encoder settings (``params.clin`` / ``params.expr`` / ``params.mut``), attention
    (``params.att``), normalization switches (``params.norm``), fusion (``params.fusion``) and the
    classifier head (``params.final``). Expression and mutation inputs first pass through a learned
    ``GeneSelector``.
    """
    def __init__(self, params):
        # NOTE: layers are created in a fixed order. Under a seeded RNG that order determines the initial
        # weights, and the attribute names below become the state_dict keys of saved weights -- keep both stable.
        super(MultiModalNN, self).__init__()
        self.params = params
        # 1. Gene selection + modality-specific neural networks/encoders (project encoder output to final['n'][0] if needed)
        self.sel_e = GeneSelector(len(params.expr_features))
        self.sel_m = GeneSelector(len(params.mut_features))
        self.enc_c = ModalityEncoder(len(params.clin_features), params.clin['n'], params.clin['dropout'], params.clin['act'])
        self.enc_e = ModalityEncoder(len(params.expr_features), params.expr['n'], params.expr['dropout'], params.expr['act'])
        self.enc_m = ModalityEncoder(len(params.mut_features), params.mut['n'], params.mut['dropout'], params.mut['act'])
        # Projection layers exist only when an encoder's output width differs from the shared width final['n'][0]
        if params.clin['n'][-1] != params.final['n'][0]: self.proj_c = nn.Linear(params.clin['n'][-1], params.final['n'][0])
        if params.expr['n'][-1] != params.final['n'][0]: self.proj_e = nn.Linear(params.expr['n'][-1], params.final['n'][0])
        if params.mut['n'][-1] != params.final['n'][0]: self.proj_m = nn.Linear(params.mut['n'][-1], params.final['n'][0])
        # 2. Attention
        if params.att['self']['use']:
            self.self_c = SelfAttention(params.final['n'][0], params.att['num_heads'], dropout=params.att['self']['dropout'][0])
            self.self_e = SelfAttention(params.final['n'][0], params.att['num_heads'], dropout=params.att['self']['dropout'][1])
            self.self_m = SelfAttention(params.final['n'][0], params.att['num_heads'], dropout=params.att['self']['dropout'][2])
            # self attention normalization for Add & Norm
            if params.norm['add_and_norm']:
                self.self_c_norm = nn.LayerNorm(params.final['n'][0])
                self.self_e_norm = nn.LayerNorm(params.final['n'][0])
                self.self_m_norm = nn.LayerNorm(params.final['n'][0])
        if params.att['cross']['use']:
            # One module per modality pair (clin-expr, clin-mut, expr-mut)
            self.cross_ce = CrossAttention(params.final['n'][0], params.att['num_heads'], dropout=params.att['cross']['dropout'][0])
            self.cross_cm = CrossAttention(params.final['n'][0], params.att['num_heads'], dropout=params.att['cross']['dropout'][1])
            self.cross_em = CrossAttention(params.final['n'][0], params.att['num_heads'], dropout=params.att['cross']['dropout'][2])
        # 3. Pooling
        self.pool = FusionEmbedding(params.final['n'][0], params, params.fusion['type'], params.fusion['cross_embed_red'], params.fusion['dropout'])
        if params.norm['combined']:
            self.combined_norm = nn.LayerNorm(params.final['n'][0])
        # 4. Final classifier (raw logits)
        self.classifier = OutputClassifier(params.final['n'], params.final['dropout'], params.final['act'])
    def forward(self, x_c, x_e, x_m, return_attn_weights=False):
        """Run the clinical, expression and mutation inputs through the network.

        Args:
            x_c, x_e, x_m: clinical, expression and mutation feature tensors (assumed to be
                (batch, n_features) -- TODO confirm against the model_utils encoders).
            return_attn_weights: passed to the fusion pooling; when true the collected attention maps
                are returned, otherwise 'att_maps' is None.

        Returns:
            dict with 'y_logits' (raw logits flattened to 1-D), 'y_out' (sigmoid of the logits, used only
            for metrics) and 'att_maps'.
        """
        z_c1 = self.enc_c(x_c)
        z_e1 = self.enc_e(self.sel_e(x_e))
        z_m1 = self.enc_m(self.sel_m(x_m))
        # Same width checks as in __init__: project only where a projection layer was created
        if self.params.clin['n'][-1] != self.params.final['n'][0]: z_c1 = self.proj_c(z_c1)
        if self.params.expr['n'][-1] != self.params.final['n'][0]: z_e1 = self.proj_e(z_e1)
        if self.params.mut['n'][-1] != self.params.final['n'][0]: z_m1 = self.proj_m(z_m1)
        att_maps = {}
        if self.params.att['self']['use']: # Self-attention within each modality
            z_c2, w_c = self.self_c(z_c1)
            z_e2, w_e = self.self_e(z_e1)
            z_m2, w_m = self.self_m(z_m1)
            att_maps.update({'self_c':w_c, 'self_e':w_e, 'self_m':w_m})
            if self.params.norm['add_and_norm']: # Layer normalization and residual connection
                z_c2 = self.self_c_norm(z_c1 + z_c2)
                z_e2 = self.self_e_norm(z_e1 + z_e2)
                z_m2 = self.self_m_norm(z_m1 + z_m2)
        else: z_c2, z_e2, z_m2 = z_c1, z_e1, z_m1
        if self.params.att['cross']['use']: # Cross-modal attention
            # Each CrossAttention returns both directed outputs plus their attention weights
            z_ce1, z_ec1, w_ce, w_ec = self.cross_ce(z_c2, z_e2)
            z_cm1, z_mc1, w_cm, w_mc = self.cross_cm(z_c2, z_m2)
            z_em1, z_me1, w_em, w_me = self.cross_em(z_e2, z_m2)
            att_maps.update({'w_ce':w_ce, 'w_ec':w_ec, 'w_cm':w_cm, 'w_mc':w_mc, 'w_em':w_em, 'w_me':w_me})
            embeddings = [z_ce1, z_ec1, z_cm1, z_mc1, z_em1, z_me1]
        else: embeddings = [z_c2, z_e2, z_m2]
        z_fused, w_fused = self.pool(embeddings, return_attn_weights) # Pooling and fusion
        att_maps['fusion'] = w_fused
        if self.params.norm['combined']: 
            z_fused = self.combined_norm(z_fused)
        logits = self.classifier(z_fused).view(-1)
        y_out = torch.sigmoid(logits) # only for metrics
        return {'y_logits': logits, 'y_out': y_out, 'att_maps': att_maps if return_attn_weights else None}
    def get_domain_embeddings(self, x_e, x_m):
        """Return the expression and mutation embeddings used for CORAL.

        Each embedding is taken after gene selection, encoding and (when the widths differ) projection
        to final['n'][0]. The clinical branch, attention and fusion are not involved.
        """
        z_e = self.enc_e(self.sel_e(x_e))
        z_m = self.enc_m(self.sel_m(x_m))
        if self.params.expr['n'][-1] != self.params.final['n'][0]: z_e = self.proj_e(z_e)
        if self.params.mut['n'][-1] != self.params.final['n'][0]: z_m = self.proj_m(z_m)
        return z_e, z_m

In [None]:
class Trainer:
    """Train a multimodal classifier on labelled TCGA batches.

    The expression/mutation embeddings can optionally be aligned to unlabelled GLASS
    batches with a CORAL loss. The optimizer ('adam' or 'adamw') and the CyclicLR schedule
    come from ``params``. Loss terms, gradient clipping, early stopping and attention
    logging come from ``tparams``.
    """
    def __init__(self, model, params, tparams, device):
        # Per-class loss weights are read from tparams.class_weights when the loss is computed;
        # None means an unweighted loss (see _task_loss).
        self.model = model.to(device)
        self.tparams, self.device = tparams, device
        optimizer_cls = {'adam': optim.Adam, 'adamw': optim.AdamW}[params.config['opt']]
        self.optimizer = optimizer_cls(self.model.parameters(), params.config['lr'], **params.config['lr_kwargs'])
        # scheduler.step() is called once per optimizer step (i.e. per batch) in train_loop
        self.scheduler = CyclicLR(self.optimizer, base_lr=params.clr['base_lr'], max_lr=params.clr['max_lr'], **params.clr['kwargs'])

    def train(self, train_tcga_loader, train_glass_loader, val_tcga_loader, val_glass_loader, chkpt_fn='best_model'):
        """Run up to ``tparams.epochs`` epochs and return the per-epoch history.

        Returns:
            history: ``history['train'|'val'][name]`` is a list with one entry per epoch for
            every loss in ``tracker.avg_losses`` and every metric in ``tracker.metrics``. When
            ``tparams.get_val_att`` is set, ``history['val']['att']`` also holds the attention
            maps returned by ``validation_loop``.

        When ``tparams.es['use']`` is set, ``earlystop_checkpoint`` is fed the validation AUC
        (``es['metric'] == 'auc'``) or the validation total loss. Training stops once
        ``no_improve`` reaches ``es['patience']``. If ``tparams.restore_best`` is set, the
        weights at ``results/weights/{chkpt_fn}.pth`` are then reloaded (presumably written
        by ``earlystop_checkpoint``).
        """
        best_val = 0 if self.tparams.es['metric'] == 'auc' else 1e10  # AUC is maximised, loss minimised
        no_improve = 0
        history = defaultdict(lambda: defaultdict(list))
        for epoch in range(1, self.tparams.epochs + 1):
            train_tracker = self.train_loop(train_tcga_loader, train_glass_loader)
            self._log_tracker(history['train'], train_tracker)
            val_tracker, val_atts = self.validation_loop(val_tcga_loader, val_glass_loader)
            self._log_tracker(history['val'], val_tracker)
            if self.tparams.get_val_att:
                history['val']['att'].append(val_atts)
            for h, label in zip([history['train'], history['val']], ['Train', '  Val']):  # print epoch metrics
                print(f"{epoch}/{self.tparams.epochs} {label} - Loss: {h['total_loss'][-1]:.4f} ({h['task_loss'][-1]:.4f} + {h['gene_reg'][-1]:.4f} + {h['coral_loss'][-1]:.4f})  AUC: {h['auc'][-1]:.4f}  AUPRC: {h['auprc'][-1]:.4f}  ACC: {h['acc'][-1]:.4f}  PPV: {h['prec'][-1]:.4f}  TPR: {h['rec'][-1]:.4f}  TNR: {h['spec'][-1]:.4f}")
            if self.tparams.es['use']:
                check_val = val_tracker.metrics['auc'] if self.tparams.es['metric'] == 'auc' else val_tracker.avg_losses['total_loss']
                best_val, no_improve = earlystop_checkpoint(self.model, self.tparams, check_val, best_val, no_improve, chkpt_fn)
                if no_improve >= self.tparams.es['patience']:
                    print(f"Early stopping at epoch {epoch}\n")
                    if self.tparams.restore_best:
                        # map_location keeps the reload on this trainer's device
                        self.model.load_state_dict(torch.load(f"results/weights/{chkpt_fn}.pth", map_location=self.device))
                    break
        return history

    def train_loop(self, train_tcga_loader, train_glass_loader):
        """Run one training epoch over every TCGA batch and return the populated OutputTracker.

        GLASS batches are drawn only when ``tparams.use_coral`` is set. The GLASS loader is
        restarted whenever it runs out, rather than zipped with the TCGA loader. Zipping
        would end the epoch at the shorter loader and skip TCGA batches, while the metrics
        are still normalised by the full TCGA dataset size.
        """
        self.model.train()
        train_tracker = OutputTracker()
        glass_batches = self._cycle_batches(train_glass_loader) if self.tparams.use_coral else None
        for tcga_batch in train_tcga_loader:
            tcga_c, tcga_e, tcga_m, labels = self._to_device(tcga_batch)
            self.optimizer.zero_grad()
            out = self.model(tcga_c, tcga_e, tcga_m)
            task_loss = self._task_loss(out['y_logits'], labels)
            gene_reg = get_gene_reg(self.model, self.device, self.tparams.use_gene_reg, self.tparams.gene_reg_weight)
            if glass_batches is not None:
                _, glass_e, glass_m, _ = self._to_device(next(glass_batches))  # GLASS labels are unused
                coral_loss = self._coral_loss(tcga_e, tcga_m, glass_e, glass_m)
            else:
                coral_loss = torch.tensor(0.0, device=self.device)
            total_loss = task_loss + gene_reg + coral_loss
            total_loss.backward()
            if self.tparams.max_grad_norm:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.tparams.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()
            train_tracker(task_loss, gene_reg, coral_loss, total_loss, labels, out['y_logits'], out['y_out'])
        train_tracker.get_metrics(len(train_tcga_loader.dataset))
        return train_tracker

    def validation_loop(self, val_tcga_loader, val_glass_loader=None):
        """Evaluate on the TCGA validation loader without gradient tracking.

        The CORAL term is included when ``tparams.use_coral`` is set and a GLASS loader is
        given. The GLASS loader is cycled so that every TCGA batch gets a GLASS batch.

        Returns:
            (val_tracker, att_maps): ``att_maps`` is the attention-map dict of the LAST
            validation batch when ``tparams.get_val_att`` is set (None if the loader was
            empty); otherwise None.
        """
        self.model.eval()
        val_tracker = OutputTracker()
        use_coral = self.tparams.use_coral and val_glass_loader is not None
        glass_batches = self._cycle_batches(val_glass_loader) if use_coral else None
        att_maps = None
        with torch.no_grad():
            for batch in val_tcga_loader:
                clin, expr, mut, labels = self._to_device(batch)
                out = self.model(clin, expr, mut, self.tparams.get_val_att)
                task_loss = self._task_loss(out['y_logits'], labels)
                gene_reg = get_gene_reg(self.model, self.device, self.tparams.use_gene_reg, self.tparams.gene_reg_weight)
                if glass_batches is not None:
                    _, glass_expr, glass_mut, _ = self._to_device(next(glass_batches))
                    coral_loss = self._coral_loss(expr, mut, glass_expr, glass_mut)
                else:
                    coral_loss = torch.tensor(0.0, device=self.device)
                total_loss = task_loss + gene_reg + coral_loss
                val_tracker(task_loss, gene_reg, coral_loss, total_loss, labels, out['y_logits'], out['y_out'])
                att_maps = out['att_maps']
        val_tracker.get_metrics(len(val_tcga_loader.dataset))
        return val_tracker, (att_maps if self.tparams.get_val_att else None)

    # ---- helpers ---------------------------------------------------------------------

    def _to_device(self, batch):
        """Cast every tensor in ``batch`` (labels included) to float32 on the trainer's device."""
        return [t.to(torch.float32).to(self.device) for t in batch]

    def _task_loss(self, logits, labels):
        """Binary cross-entropy on logits, weighted per sample by its class weight when class weights are set."""
        class_weights = self.tparams.class_weights
        weight = None if class_weights is None else class_weights[labels.long()]
        return F.binary_cross_entropy_with_logits(logits, labels, weight=weight)

    def _coral_loss(self, tcga_e, tcga_m, glass_e, glass_m):
        """Weighted sum of the expression and mutation CORAL losses between TCGA and GLASS embeddings."""
        z_e_tcga, z_m_tcga = self.model.get_domain_embeddings(tcga_e, tcga_m)
        z_e_glass, z_m_glass = self.model.get_domain_embeddings(glass_e, glass_m)
        return self.tparams.coral_weight * (coral_loss_fn(z_e_tcga, z_e_glass) + coral_loss_fn(z_m_tcga, z_m_glass))

    @staticmethod
    def _log_tracker(split_history, tracker):
        """Append this epoch's average losses and metrics from ``tracker`` to ``split_history``."""
        for tracker_dict in (tracker.avg_losses, tracker.metrics):
            for name, value in tracker_dict.items():
                split_history[name].append(value)

    @staticmethod
    def _cycle_batches(loader):
        """Yield batches from ``loader`` endlessly.

        A fresh pass over ``loader`` starts each time it is exhausted, so a shuffling
        DataLoader is re-shuffled on every pass.

        Raises:
            ValueError: if a full pass over ``loader`` yields no batches.
        """
        while True:
            empty = True
            for batch in loader:
                empty = False
                yield batch
            if empty:
                raise ValueError("GLASS loader produced no batches; cannot compute the CORAL loss")

def evaluate(model, test_loader, tparams, device, pred_thresh=0.5, return_attn_weights=False, plot_metrics=False):
    """Evaluate ``model`` on ``test_loader`` without gradients, print a summary line, and return the results.

    No CORAL alignment is done at test time: the CORAL term is recorded as 0 and is not
    part of the total loss. When ``tparams.class_weights`` is None the task loss is unweighted.

    Returns:
        (test_tracker, att_maps): ``att_maps`` is the attention-map dict of the LAST batch
        when ``return_attn_weights`` is True (None if the loader was empty); otherwise None.
    """
    model.eval()
    test_tracker = OutputTracker(thresh=pred_thresh, plot_metrics=plot_metrics)
    class_weights = tparams.class_weights
    att_maps = None
    with torch.no_grad():
        for batch in test_loader:
            clin, expr, mut, labels = [t.to(torch.float32).to(device) for t in batch]
            out = model(clin, expr, mut, return_attn_weights)
            weight = None if class_weights is None else class_weights[labels.long()]
            task_loss = F.binary_cross_entropy_with_logits(out['y_logits'], labels, weight=weight)
            gene_reg = get_gene_reg(model, device, tparams.use_gene_reg, tparams.gene_reg_weight)
            total_loss = task_loss + gene_reg
            coral_loss = torch.tensor(0.0, device=device)  # no domain alignment at test time
            test_tracker(task_loss, gene_reg, coral_loss, total_loss, labels, out['y_logits'], out['y_out'])
            att_maps = out['att_maps']
    test_tracker.get_metrics(len(test_loader.dataset))
    print(f"Loss: {test_tracker.avg_losses['total_loss']:.3f} ({test_tracker.avg_losses['task_loss']:.3f} + {test_tracker.avg_losses['gene_reg']:.3f})  AUC: {test_tracker.metrics['auc']:.4f}  AUPRC: {test_tracker.metrics['auprc']:.4f}  ACC: {test_tracker.metrics['acc']:.4f}  PPV: {test_tracker.metrics['prec']:.4f}  TPR: {test_tracker.metrics['rec']:.4f}  TNR: {test_tracker.metrics['spec']:.4f}  CM: [tn={test_tracker.metrics['tn']}, fp={test_tracker.metrics['fp']}, fn={test_tracker.metrics['fn']}, tp={test_tracker.metrics['tp']}]")
    return test_tracker, (att_maps if return_attn_weights else None)

def train(source, g, params, tparams, device, best_fn='best_model', weightedsampler=False, deterministic=False, seed=None):
    """Build the TCGA/GLASS train and validation loaders, create a fresh MultiModalNN and train it.

    ``source`` and ``weightedsampler`` are accepted for interface compatibility but are not
    used. When ``seed`` is given, ``set_seed`` is called before the model is created,
    presumably so that weight initialisation and shuffling are reproducible.

    Returns:
        (model, train_hist): the trained model and the history dict from ``Trainer.train``.
    """
    def make_loader(x, y, shuffle):
        dataset = MultiModalDataset(x[params.clin_features], x[params.expr_features], x[params.mut_features], y, device)
        return DataLoader(dataset, batch_size=tparams.bs, shuffle=shuffle)

    tcga_train_loader = make_loader(g.tcga_x_train, g.tcga_y_train, True)
    tcga_val_loader = make_loader(g.tcga_x_val, g.tcga_y_val, False)
    glass_train_loader = make_loader(g.glass_x_train, g.glass_y_train, True)
    glass_val_loader = make_loader(g.glass_x_val, g.glass_y_val, False)
    if seed is not None:
        set_seed(seed, deterministic=deterministic)
    model = MultiModalNN(params).to(device)
    trainer = Trainer(model, params, tparams, device)
    history = trainer.train(tcga_train_loader, glass_train_loader, tcga_val_loader, glass_val_loader, best_fn)
    return model, history

def test(source, g, model, params, tparams, device, pred_thresh=0.5, return_attn_weights=False, plot_metrics=False):
    """Evaluate ``model`` on the held-out test split of ``source`` ('tcga' or 'glass') in one full-size batch.

    Returns:
        (test_tracker, att_maps) exactly as produced by ``evaluate``.
    """
    splits = {'tcga': (g.tcga_x_test, g.tcga_y_test), 'glass': (g.glass_x_test, g.glass_y_test)}
    x_test, y_test = splits[source]
    dataset = MultiModalDataset(x_test[params.clin_features], x_test[params.expr_features], x_test[params.mut_features], y_test, device)
    loader = DataLoader(dataset, batch_size=len(x_test), shuffle=False)
    return evaluate(model, loader, tparams, device, pred_thresh, return_attn_weights, plot_metrics)
    

In [None]:
def convert_numpy(obj):
    """Recursively convert NumPy values inside ``obj`` into JSON-serialisable Python objects.

    Conversions:
    - dicts (including defaultdicts) become plain dicts with converted values;
    - lists and tuples become lists;
    - ndarrays become nested lists via ``tolist()``;
    - NumPy integer, float and bool scalars become ``int``, ``float`` and ``bool``.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):  # also covers defaultdict, which subclasses dict
        return {k: convert_numpy(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy(i) for i in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):  # json cannot serialise np.bool_
        return bool(obj)
    return obj

class LUNAResults:
    """Load the saved LUNAR model and its attention ablations, then evaluate each on the TCGA and GLASS test sets.

    Construction loads and evaluates every model immediately. When ``save_results`` is set
    and ``fn`` is non-empty, the collected metrics are written as JSON to ``data/out/{fn}``.

    Attributes set during construction:
        lunar: the base model.
        lunar_tracker, lunar_att: TCGA test tracker and attention maps of the base model.
        glass_tracker, glass_att: GLASS test tracker and attention maps of the base model.
        lunar_{kind}*, glass_{kind}*, params1_{kind}: per-ablation equivalents.
        all_trackers: ``[dataset][display name] -> tracker``.
        stats_dict: ``[dataset][display name] -> tracker.metrics``.
    """
    def __init__(self, g, device, save_results=False, fn='lunar_results_dict'):
        self.g = g
        self.device = device
        self._init_params()
        self.name_map = {'base': 'LUNAR', 'satt': 'LUNAR-SAtt', 'catt': 'LUNAR-CAtt', 'natt': 'LUNAR-NAtt'}
        self.all_trackers = {'tcga': {}, 'glass': {}}
        self.stats_dict = defaultdict(lambda: defaultdict(dict))
        self._test_models()
        if save_results and fn != '':
            write_file('json', fn, convert_numpy(self.stats_dict), 'data/out/')

    def _init_params(self):
        """Define the training params and the model architecture params the saved models were built with."""
        # tparams1 / tparams2 differ only in class weights (TCGA vs GLASS); they are used for the TCGA and GLASS test losses respectively.
        self.tparams1 = TrainParams(epochs=20, bs=16, class_weights=self.g.tcga_weights, use_gene_reg=True, gene_reg_weight=1e-4, max_grad_norm=5.0, restore_best=True, es={'use':True, 'patience':3, 'metric':'loss'}, use_coral=True, coral_weight=.5)
        self.tparams2 = TrainParams(epochs=20, bs=16, class_weights=self.g.glass_weights, use_gene_reg=True, gene_reg_weight=1e-4, max_grad_norm=5.0, restore_best=True, es={'use':True, 'patience':3, 'metric':'loss'}, use_coral=True, coral_weight=.5)
        self.params1 = ModelParams(
            clin = {'n': [128, 128, 64], 'dropout': [0, .2, .25], 'act': 'relu'},
            expr = {'n': [128, 128, 64], 'dropout': [0, .2, .5], 'act': 'relu'},
            mut = {'n': [128, 128, 64], 'dropout': [0, .2, .5], 'act': 'relu'},
            final = {'n': [64, 32], 'dropout': [.2, .3], 'act': 'relu'},
            att = {'self': {'use': True, 'dropout': [0, 0, 0]}, 'cross': {'use': True, 'dropout': [0, 0, 0]}, 'num_heads': 2},
            norm = {'add_and_norm': False, 'combined': True},
            config = {'opt': 'adam', 'lr': .0005, 'lr_kwargs': {}},
            clr = {'base_lr': .0003, 'max_lr': .0005, 'kwargs': {'step_size_up': 36, 'mode': 'triangular', 'base_momentum': 0.8, 'max_momentum': 0.9}},
            fusion = {'type': 'LQA', 'cross_embed_red': 'avg', 'dropout': 0}
        )
        self.params1.select_features(self.g)

    def _test_models(self):
        """Evaluate the base model and the three attention ablations (previously observed TCGA results noted inline)."""
        self._base_lunar() # Loss: 0.032 (0.009 + 0.023)  AUC: 0.8284  AUPRC: 0.7659  ACC: 0.7241  PPV: 0.7500  TPR: 0.5000  TNR: 0.8824  CM: [tn=15, fp=2, fn=6, tp=6]
        self._lunar_ablation('satt', 'may7_lunar_self') # Loss: 0.032 (0.009 + 0.023)  AUC: 0.7745  AUPRC: 0.6789  ACC: 0.6897  PPV: 0.7143  TPR: 0.4167  TNR: 0.8824  CM: [tn=15, fp=2, fn=7, tp=5]
        self._lunar_ablation('catt', 'may7_lunar_cross') # Loss: 0.032 (0.009 + 0.023)  AUC: 0.7745  AUPRC: 0.7266  ACC: 0.6897  PPV: 0.7143  TPR: 0.4167  TNR: 0.8824  CM: [tn=15, fp=2, fn=7, tp=5]
        self._lunar_ablation('natt', 'may7_lunar_none') # Loss: 0.032 (0.009 + 0.023)  AUC: 0.7892  AUPRC: 0.7620  ACC: 0.6897  PPV: 0.7143  TPR: 0.4167  TNR: 0.8824  CM: [tn=15, fp=2, fn=7, tp=5]

    def _base_lunar(self):
        """Load the full (self + cross attention) model 'may7_lunar' and evaluate it."""
        self.lunar = load_saved_model('may7_lunar', self.params1, self.device, folder='data/models')
        self.lunar_tracker, self.lunar_att, self.glass_tracker, self.glass_att = self._test('base', self.lunar, self.params1)
        self._update_tracker_dict('base', self.lunar_tracker, self.glass_tracker)

    def _lunar_ablation(self, kind, filename):
        """Load and evaluate an attention ablation.

        ``kind`` selects what is kept: 'satt' keeps self-attention only (cross disabled),
        'catt' keeps cross-attention only (self disabled), and 'natt' disables both.
        """
        import copy  # not among the notebook's top-level imports
        abl_params = copy.deepcopy(self.params1)  # deep copy so the base params1 stay untouched
        if kind in ['satt', 'natt']: abl_params.att['cross']['use'] = False
        if kind in ['catt', 'natt']: abl_params.att['self']['use'] = False
        abl_model = load_saved_model(filename, abl_params, self.device, folder='data/models')
        t_trkr, t_att, g_trkr, g_att = self._test(kind, abl_model, abl_params)
        self._update_tracker_dict(kind, t_trkr, g_trkr)
        new_attribs = {f'params1_{kind}':abl_params, f'lunar_{kind}':abl_model, f'lunar_{kind}_tracker':t_trkr, f'lunar_{kind}_att':t_att, f'glass_{kind}_tracker':g_trkr, f'glass_{kind}_att':g_att}
        for aname, aval in new_attribs.items():
            setattr(self, aname, aval)

    def _test(self, kind, model, params):
        """Evaluate ``model`` on both test sets at threshold 0.5, with attention maps and metric plots.

        Returns:
            (tcga_tracker, tcga_att, glass_tracker, glass_att)
        """
        print(f'**** {self.name_map[kind]} ****')
        tcga_trkr, tcga_att = test('tcga', self.g, model, params, self.tparams1, self.device, pred_thresh=0.5, return_attn_weights=True, plot_metrics=True)
        glass_trkr, glass_att = test('glass', self.g, model, params, self.tparams2, self.device, pred_thresh=0.5, return_attn_weights=True, plot_metrics=True)
        return tcga_trkr, tcga_att, glass_trkr, glass_att

    def _update_tracker_dict(self, kind, tcga_trkr, glass_trkr):
        """Record both trackers and their metrics under the display name of ``kind``."""
        for ds, trkr in zip(['tcga', 'glass'], [tcga_trkr, glass_trkr]):
            self.all_trackers[ds][self.name_map[kind]] = trkr
            self.stats_dict[ds][self.name_map[kind]] = trkr.metrics

# Evaluate the base LUNAR model and its attention ablations on the TCGA and GLASS test sets;
# metrics are written to data/out/lunar_results_dict.
lunaresults = LUNAResults(g, device, save_results=True, fn='lunar_results_dict')

In [None]:
class AttentionBoxPlots:
    """Summarise and plot modality-level fusion attention weights for the TCGA and GLASS test sets.

    Args:
        tcga_att_tensor, glass_att_tensor: fusion attention tensors (``att_maps['fusion']``)
            for each test set; see ``_extract_data`` for the assumed layout.

    Attributes:
        att_df: Tukey box-plot statistics, one row per (modality, dataset). Rows are sorted
            by their '<Modality> <DATASET>' label; the label column itself is dropped.
        att_df_str: ``att_df`` as strings for the table (floats to 3 decimals, integer columns as whole numbers).
        row_labels: '<DATASET>: <Modality>' labels aligned with the rows of ``att_df``.
    """
    def __init__(self, tcga_att_tensor, glass_att_tensor):
        self.mods = ['Clinical', 'Expression', 'Mutation']
        self.tcga = self._extract_data(tcga_att_tensor, 'TCGA')
        self.glass = self._extract_data(glass_att_tensor, 'GLASS')
        stats_df = pd.concat([self.tcga['stats'], self.glass['stats']], ignore_index=True).sort_values('Label').reset_index(drop=True)
        # Derive table row labels from the data ('Clinical TCGA' -> 'TCGA: Clinical') so they
        # always follow the sorted row order instead of a hard-coded list.
        self.row_labels = [f"{dataset}: {mod}" for mod, dataset in (label.rsplit(' ', 1) for label in stats_df['Label'])]
        self.att_df = stats_df.drop(columns=['Label'])
        self._att_df_to_str()

    def _extract_data(self, tnsr, name):
        """Average attention over heads and compute box statistics per modality.

        Returns a dict with two keys:
        - 'data': one 1-D array of per-sample weights per modality;
        - 'stats': the ``_compute_box_stats`` DataFrame, with rows labelled '<Modality> <name>'.
        """
        # NOTE(review): assumes tnsr is (n_samples, n_heads, 1, n_modalities). After squeezing
        # dim 2 and averaging over heads (dim 1) the array is (n_samples, n_modalities).
        att_tnsr = tnsr.squeeze(2).mean(dim=1).cpu().numpy()
        att_data = [att_tnsr[:, i] for i in range(len(self.mods))]
        att_stats = self._compute_box_stats(att_data, [f'{mod} {name}' for mod in self.mods])
        return {'data': att_data, 'stats': att_stats}

    def _compute_box_stats(self, data_list, labels, whis=1.5, bootstrap=100):
        """Compute Tukey box-plot statistics for each array (whis=1.5 is Tukey's original definition).

        Statistics come from ``matplotlib.cbook.boxplot_stats``. IQR = Q3 - Q1, and the
        whiskers reach the most extreme data points within [Q1 - whis*IQR, Q3 + whis*IQR].
        ``bootstrap`` is forwarded to ``boxplot_stats``; it only affects the median
        confidence interval, which is not reported.

        Returns one row per array with the statistics rounded to 3 decimals plus the sample size.
        """
        bp_stats = boxplot_stats(data_list, whis=whis, labels=labels, bootstrap=bootstrap)
        rows = []
        for i, s in enumerate(bp_stats):
            row = {k: round(s[k], 3) if k != 'label' else s[k] for k in ['label', 'whislo', 'q1', 'med', 'q3', 'iqr', 'whishi']}
            row['n'] = len(data_list[i]) # add sample size
            rows.append(row)
        return pd.DataFrame(rows).rename(columns={'label':'Label', 'whislo':'Lower whisker', 'q1':'Q1', 'med':'Median', 'q3':'Q3', 'iqr':'IQR', 'whishi':'Upper whisker', 'n':'Sample size'})

    def _att_df_to_str(self):
        """Build ``att_df_str``: integer columns (e.g. sample size) as whole numbers, other numeric columns to 3 decimals."""
        df_str = self.att_df.copy()
        for col in df_str.columns:
            if pd.api.types.is_integer_dtype(df_str[col]):
                df_str[col] = df_str[col].map('{:d}'.format)
            elif pd.api.types.is_numeric_dtype(df_str[col]):
                df_str[col] = df_str[col].map('{:.3f}'.format)
        self.att_df_str = df_str

    def plot_fusion_attention(self, fsize=(6, 5), offset=0.15, dpi=1200, spacing=1.5):
        """Plot the modality-level fusion attention weights for TCGA and GLASS, with a statistics table below.

        Args:
            fsize: figure size in inches.
            offset: vertical offset separating the TCGA and GLASS boxes of each modality.
            dpi: figure resolution.
            spacing: vertical distance between modality groups.
        """
        fig = plt.figure(figsize=fsize, dpi=dpi)
        gs  = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.22)
        ax1 = fig.add_subplot(gs[0]) # box plot
        base_y = np.arange(1, len(self.mods)+1) * spacing
        bp_tcga = ax1.boxplot(self.tcga['data'], positions=base_y-offset, widths=0.35, vert=False, patch_artist=True, boxprops=dict(facecolor='#bd5090'), medianprops=dict(color='black'))
        bp_glass = ax1.boxplot(self.glass['data'], positions=base_y+offset, widths=0.35, vert=False, patch_artist=True, boxprops=dict(facecolor='#f96361'), medianprops=dict(color='black'))
        ax1.set_yticks(base_y)
        ax1.set_yticklabels(self.mods)
        ax1.set_xlabel("Attention Weight", fontsize=10)
        ax1.set_title('Modality-level Fusion Attention by Testing Set', fontweight='bold')
        ax1.legend([bp_tcga["boxes"][0], bp_glass["boxes"][0]], ["TCGA testing set", "GLASS testing set"], loc="best")
        ax2 = fig.add_subplot(gs[1]) # table
        ax2.axis('off')
        table = ax2.table(cellText=self.att_df_str.values, colLabels=self.att_df_str.columns, rowLabels=self.row_labels, rowLoc='right', loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(7.9)
        table.scale(1, 1.1)
        for (row, col), cell in table.get_celld().items():
            if row == 0 or col == -1: cell.set_facecolor("gainsboro")  # header row and row-label column
        plt.show()

# Box plots and summary table of the base LUNAR model's fusion attention weights on the TCGA and GLASS test sets
att_bps = AttentionBoxPlots(lunaresults.lunar_att['fusion'], lunaresults.glass_att['fusion'])
att_bps.plot_fusion_attention(fsize=(7.5, 5), offset=0.22, spacing=1.0)

# att_bps.att_df.to_csv("data/attention_boxplot_table.csv", index=False)
# def save_attention_csv(att_tensor, fname, mods=['Clinical', 'Expression', 'Mutation']):
#     att_arr = att_tensor.squeeze(2).mean(dim=1).cpu().numpy()  # shape: [n, 3]
#     df = pd.DataFrame(att_arr, columns=mods)
#     df.to_csv(fname, index=False)
# save_attention_csv(lunaresults.lunar_att['fusion'], 'data/tcga_attention_weights.csv')
# save_attention_csv(lunaresults.glass_att['fusion'], 'data/glass_attention_weights.csv')
# ^^ all three zipped into Supplementary Data 6.zip

In [None]:
class SHAPWrapper(torch.nn.Module):
    """Adapt the multimodal model's dict output to the plain-tensor interface SHAP explainers expect.

    The wrapped model is switched to eval mode when the wrapper is built. ``forward`` takes
    the same positional inputs as the model and returns only its logits as an ``[N, 1]`` tensor.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        model.eval()

    def forward(self, *inputs):
        logits = self.model(*inputs)['y_logits']
        # SHAP expects shape [N, 1]
        return logits.unsqueeze(1)

class SHAPAnalysis:
    """Compute (or load) DeepExplainer SHAP values for the TCGA test set and summarise them per modality.

    Args:
        action: 'load' reads a saved SHAP table from ``data/out/{fn}``. Any other value
            recomputes the table and saves it when ``fn`` is non-empty.
        g: data container providing ``tcga_x_train`` / ``tcga_x_test`` DataFrames.
        model: trained multimodal model (used only when recomputing).
        params: model params exposing ``features`` ([clin, expr, mut] column lists) and the
            ``clin_features`` / ``expr_features`` / ``mut_features`` lists.
        device: torch device for the explainer inputs.
        fn: file name used for loading/saving the SHAP table.
    """
    def __init__(self, action, g, model, params, device, fn=''):
        self.g = g  # kept so plot_results uses this instance's data rather than a global
        self.device = device
        self.params = params
        if action == 'load': self.shap_df = load_data('pd', fn, 'data/out/')
        else: self._run(g, model, fn)
        self.get_stats()

    def _run(self, g, model, fn):
        """Explain the TCGA test set with DeepExplainer and store the result in ``shap_df``.

        ``shap_df`` has one column per feature and one row per test patient.
        """
        def get_tensor_data(df_x):
            return torch.tensor(df_x.values, dtype=torch.float32).to(self.device)
        def format_shap_vals(svals, idxs):
            frames = []
            for vals, cols in zip(svals, self.params.features):
                vals = np.asarray(vals)
                # NOTE(review): some shap versions append a trailing output axis (n, features, 1) for single-output models
                if vals.ndim == 3 and vals.shape[-1] == 1:
                    vals = vals[..., 0]
                frames.append(pd.DataFrame(vals, columns=cols, index=idxs))
            return pd.concat(frames, axis=1)
        # TCGA was the labelled source domain during training, so its training set is the
        # background (not GLASS); hence only TCGA can be explained.
        background = [get_tensor_data(g.tcga_x_train[f]) for f in self.params.features] # for reference
        tcga_inputs = [get_tensor_data(g.tcga_x_test[f]) for f in self.params.features] # to explain (tcga)
        wrapper = SHAPWrapper(model)
        explainer = shap.DeepExplainer(wrapper, background)
        self.shap_df = format_shap_vals(explainer.shap_values(tcga_inputs), g.tcga_x_test.index)
        if fn != '': write_file('pd', fn, self.shap_df, 'data/out/')

    def get_stats(self):
        """Summarise absolute SHAP values per modality.

        Sets three attributes:
        - ``mod_means``: mean |SHAP| per feature, one column per modality.
        - ``mod_imp``: per modality, the mean/median of the feature means ('f_mean', 'f_med')
          and the mean/median of the per-patient modality totals ('t_mean', 't_med').
        - ``avgs``: mean |SHAP| of every feature, sorted in descending order.
        """
        mod_imp, mod_means = {}, {}
        for i, modality in enumerate(['clin', 'expr', 'mut']):
            abs_svals = self.shap_df[self.params.features[i]].abs()
            feature_means = abs_svals.mean() # mean importance per feature in modality X (feature-wise stats across patients)
            row_sums = abs_svals.sum(axis=1) # total importance of modality X per sample/patient (patient-wise stats across features)
            mod_means[modality] = feature_means
            mod_imp[modality] = {
                'f_mean': feature_means.mean(), # mean feature importance for modality X
                'f_med': feature_means.median(), # median feature importance for modality X
                't_mean': row_sums.mean(), # average total importance of modality X
                't_med': row_sums.median() # median total importance of modality X
            }
        self.mod_means = pd.DataFrame(mod_means)
        self.mod_imp = pd.DataFrame.from_dict(mod_imp, orient='index')
        self.avgs = pd.DataFrame(self.shap_df.abs().mean(), columns=['Average']).sort_values(by='Average', ascending=False) # total averages

    def print_modality_importance(self):
        """Print the per-modality feature-level and total importance summaries from ``mod_imp``."""
        modality_names = {'clin': 'Clinical', 'expr': 'Expression', 'mut': 'Mutation'}
        for modality, row in self.mod_imp.iterrows():
            full_name = modality_names.get(modality, modality.capitalize())
            f_mean, f_med, t_mean, t_med = [f"{row[k]:.5f}" for k in ['f_mean', 'f_med', 't_mean', 't_med']]
            print(f"{full_name:<12} feature importance:    mean = {f_mean:<10} \tmedian = {f_med}")
            print(f"{'':<12} total importance:      mean = {t_mean:<10} \tmedian = {t_med}")

    def plot_results(self, max_display=20, dpi=1200, w=15, h=10, p_lower_margin_increase=0):
        """Draw a SHAP dot (beeswarm) summary plot for the TCGA test set, with display-friendly feature names.

        Args:
            max_display: number of top features to show.
            dpi, w, h: figure resolution and size in inches.
            p_lower_margin_increase: fraction of the y-range added below the plot (0 leaves it unchanged).
        """
        def clean_clinical(clinf):
            if clinf == 'IDHwt': return 'IDH-wildtype'
            elif clinf == 'astrocytoma_wt': return 'Astrocytoma wildtype'
            else: return clinf.replace('_', ' ').capitalize()
        f = [self.params.clin_features, self.params.expr_features, self.params.mut_features]
        features = [*f[0], *f[1], *f[2]]
        clean = [clean_clinical(i) for i in f[0]] + [i.replace('_expr', ' expression') for i in f[1]] + [i.replace('_mut', ' mutations') for i in f[2]]
        plt.figure(figsize=(w, h), dpi=dpi)
        # Select SHAP columns by name so they line up with the feature values and cleaned names
        shap.summary_plot(self.shap_df[features].values, features=self.g.tcga_x_test[features], feature_names=clean, max_display=max_display, show=False, plot_type='dot')
        plt.title('TCGA: SHAP Feature Importance', pad=10, color='black', fontsize='large', fontweight='bold')
        plt.xlabel(plt.gca().get_xlabel(), labelpad=10, color='black', fontsize='large')
        plt.xticks(color='black')
        ax = plt.gca()
        ax.tick_params(axis='y', pad=5, reset=True, color='black', labelsize=11, right=False)
        if p_lower_margin_increase != 0:
            y_min, y_max = ax.get_ylim()
            ax.set_ylim(y_min - (y_max - y_min) * p_lower_margin_increase, y_max)
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color('black')
        for collection in ax.collections:
            collection.set_sizes([18.5])
        plt.gcf().set_size_inches(w, h)
        plt.show()

# Load previously computed TCGA SHAP values from data/out/tcga_shap and plot the 20 most important features
shap_analysis = SHAPAnalysis('load', g, lunaresults.lunar, lunaresults.params1, device, "tcga_shap")
shap_analysis.plot_results(w=9, h=7.25, max_display=20)

# shap_analysis.shap_df.to_csv('data/Supplementary Data 2.zip', compression={'method': 'zip', 'archive_name': 'tcga_shap.csv'})

In [None]:
class MLbaselines:
    """Fit classical ML baselines on TCGA (train + val) and score them on the TCGA and GLASS test sets.

    With ``action='load'``, previously saved results are read from ``{path}{fn}`` instead,
    and ``save`` has no effect. The per-dataset result tables ``tcga_df`` / ``glass_df``
    are sorted by AUC and then AUPRC, both descending.
    """
    _CURVE_KEYS = ('prc_prec', 'prc_rec', 'fpr', 'tpr')  # per-threshold curve arrays, excluded from the tables

    def __init__(self, action, g, params, path='data/out/', fn='baselines_results_dict', save=False, decimalpts=3):
        # Explicit submodule imports: `import sklearn` alone does not guarantee that
        # sklearn.svm / sklearn.neighbors / sklearn.neural_network are loaded.
        from sklearn.svm import LinearSVC
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.neural_network import MLPClassifier
        self.decimalpts = decimalpts
        self.features = params.clin_features + params.expr_features + params.mut_features
        self.models = {
            'lr': sklearn.linear_model.LogisticRegression(max_iter=1000),
            'lsvc': LinearSVC(penalty='l1', max_iter=1000, dual=False),
            'knn': KNeighborsClassifier(),
            'xgb': XGBClassifier(),
            'mlp': MLPClassifier(hidden_layer_sizes=(256, 128, 64))
        }
        self.sd = load_data('json', fn, path) if action == 'load' else self._run_models(g, path, fn, save)
        self.tcga_df, self.glass_df = [self._dict_to_df(self.sd[ds]) for ds in ['tcga', 'glass']]
        self._order_results()

    def _run_models(self, g, path, fn, save):
        """Fit every model on TCGA train+val and evaluate on both test sets; optionally save the results as JSON."""
        xtrain, ytrain = pd.concat([g.tcga_x_train, g.tcga_x_val]), pd.concat([g.tcga_y_train, g.tcga_y_val])
        stats_dict = defaultdict(lambda: defaultdict(dict))
        for name, model in self.models.items():
            model.fit(xtrain[self.features], ytrain)
            stats_dict['tcga'][name] = self._evaluate(model, g.tcga_x_test, g.tcga_y_test)
            stats_dict['glass'][name] = self._evaluate(model, g.glass_x_test, g.glass_y_test)
        if save: write_file('json', fn, convert_numpy(stats_dict), path)
        return stats_dict

    def _evaluate(self, model, xtest, ytest):
        """Compute metrics for one fitted model.

        ``results['fallback']`` is set to True when the predictions or scores are constant.
        """
        y_pred = model.predict(xtest[self.features])
        # Models without predict_proba (LinearSVC) are scored with decision_function
        y_out = model.predict_proba(xtest[self.features])[:, 1] if hasattr(model, 'predict_proba') else model.decision_function(xtest[self.features])
        results = calc_metrics(ytest, y_out, y_pred, plot_metrics=True)
        if len(np.unique(y_pred)) == 1 or np.allclose(y_out, y_out[0]):
            print(f"!! {model.__class__.__name__} degenerate output. Number of unique y_out = {len(np.unique(y_out))}, Number of unique y_pred = {len(np.unique(y_pred))}")
            results['fallback'] = True
        else: results['fallback'] = False
        return results

    def _dict_to_df(self, sd):
        """Tabulate ``sd[model][metric]`` as a DataFrame (rows = models) of rounded scalar metrics."""
        first_model = next(iter(sd.values()))
        df_metrics = [m for m in first_model if m not in self._CURVE_KEYS]
        return pd.DataFrame({m: [round(sd[mdl][m], self.decimalpts) for mdl in sd] for m in df_metrics}, index=list(sd))

    def _order_results(self):
        """Sort both tables by AUC then AUPRC (descending) and reorder ``sd`` to match."""
        self.tcga_df = self.tcga_df.sort_values(by=['auc','auprc'], ascending=False)
        self.sd['tcga'] = {key: self.sd['tcga'][key] for key in self.tcga_df.index}
        self.glass_df = self.glass_df.sort_values(by=['auc','auprc'], ascending=False)
        self.sd['glass'] = {key: self.sd['glass'][key] for key in self.glass_df.index}

    def add_lunar(self, all_trackers):
        """Add LUNAR trackers (``all_trackers[dataset][name] -> tracker``) to the tables and ``sd``, then re-sort."""
        for ds, subdict in all_trackers.items():
            for lunar_name, trkr in subdict.items():
                new_row = {k: round(v, self.decimalpts) for k, v in trkr.metrics.items() if k not in self._CURVE_KEYS}
                if ds == 'tcga': self.tcga_df.loc[lunar_name] = new_row
                else: self.glass_df.loc[lunar_name] = new_row
                self.sd[ds][lunar_name] = trkr.metrics
        self._order_results()

# Load saved baseline results and merge in the LUNAR variants' test trackers.
# NOTE: save=True has no effect with action='load' (results are only written when models are re-run).
baselines = MLbaselines('load', g, lunaresults.params1, fn='baselines_results_dict', save=True, decimalpts=4)
baselines.add_lunar(lunaresults.all_trackers)

In [None]:
# Legend display names, keyed by the model keys used in the results dicts.
names = {
    'LUNAR': 'LUNAR', 'LUNAR-CAtt': 'LUNAR-CAtt', 'LUNAR-SAtt': 'LUNAR-SAtt', 'LUNAR-NAtt': 'LUNAR-NAtt',
    'xgb': 'XGBoost', 'lsvc': 'Linear SVC', 'lr': 'Log Reg', 'mlp': 'MLP', 'knn': 'K-Nearest',
}

style1 = dict(lw=2, solid_joinstyle='round')                          # solid-line styling
style2 = dict(lw=2, dash_joinstyle='round', dash_capstyle='round')    # dashed-line styling (paired with a 'dashes' pattern)

# Per-model matplotlib line properties: LUNAR variants in shades of pink, baselines in distinct colours.
line_props = {
    'LUNAR':      {'color': '#FC1C7D', **style1},
    'LUNAR-NAtt': {'color': '#FC3D9E', 'dashes': [3, 1.7], **style2},
    'LUNAR-CAtt': {'color': '#FD6BCB', **style1},
    'LUNAR-SAtt': {'color': '#fda0ff', 'dashes': [2, 2], **style2},
    'xgb':        {'color': '#FFD700', **style1},
    'lsvc':       {'color': '#6600FF', **style1},
    'lr':         {'color': '#02B902', **style1},
    'mlp':        {'color': '#0073E6', **style1},
    'knn':        {'color': '#686868', **style1},
}

def plot_roc_prc_stacked(sd, dpi=1200, fsize=(11, 11), wspace=0.23):
    """Draw ROC and precision-recall curves for the TCGA (top) and GLASS (bottom) test sets.

    Each dataset gets a row with two curve panels (ROC left, PR right) and a thin
    legend-only axis beneath each panel. Models flagged ``fallback`` (degenerate output)
    are drawn as grey dashed lines and marked with ``*`` in the legend.

    Args:
        sd: nested dict ``sd[source][model_key]`` for sources 'tcga' and 'glass'. Each entry
            holds 'auc', 'auprc', 'fpr', 'tpr', 'prc_rec', 'prc_prec' and optionally 'fallback'.
            Model keys must exist in the module-level ``names`` dict and, unless
            fallback, in ``line_props``.
        dpi, fsize: figure resolution and size in inches.
        wspace: horizontal spacing between the two panel columns.
    """
    fig = plt.figure(figsize=fsize, dpi=dpi)
    grid = GridSpec(4, 2, height_ratios=[4, .9, 4, .9], wspace=wspace)
    # Row-major cells: 0-3 are the TCGA ROC/PR panels and their legend strips, 4-7 the GLASS ones.
    axes = [fig.add_subplot(grid[idx]) for idx in range(8)]

    def draw_curves(ax, model_stats, metric, x_key, y_key, fallback_drawstyle):
        """Plot one curve per model in ascending ``metric`` order (the legend later reverses it)."""
        for key, stats in sorted(model_stats.items(), key=lambda kv: kv[1][metric]):
            degenerate = stats.get('fallback', False)
            if degenerate:
                props, drawstyle = {'color':'#CCCCCC', 'lw':1.5, 'linestyle':'--'}, fallback_drawstyle
            else:
                props, drawstyle = line_props[key], 'steps'
            label = names[key] + ("*" if degenerate else "") + f": {round(stats[metric]*100,2)}%"
            ax.plot(stats[x_key], stats[y_key], **props, drawstyle=drawstyle, label=label)

    # (metric, x key, y key, fallback drawstyle, x label, y label, curve name) for the ROC and PR panels
    panel_specs = (
        ('auc', 'fpr', 'tpr', None, 'False Positive Rate', 'True Positive Rate', 'ROC'),
        ('auprc', 'prc_rec', 'prc_prec', 'steps', 'Recall', 'Precision', 'Precision-Recall'),
    )
    for row, source in enumerate(['tcga', 'glass']):
        roc_ax, prc_ax, roc_leg, prc_leg = axes[4 * row:4 * row + 4]
        chance = roc_ax.plot([0, 1], [0, 1], color='#B2B2B2', lw=1.75, zorder=1)[0]  # random-chance diagonal
        roc_ax.legend([chance], ['Random chance'], loc='lower right', fontsize=7, frameon=False)
        for ax, spec in zip((roc_ax, prc_ax), panel_specs):
            draw_curves(ax, sd[source], *spec[:4])
        for ax, leg, spec in zip((roc_ax, prc_ax), (roc_leg, prc_leg), panel_specs):
            xlabel, ylabel, curve_name = spec[4:]
            ax.set_xlabel(xlabel, fontsize=9, fontweight=540)
            ax.set_ylabel(ylabel, fontsize=9, fontweight=540)
            ax.set_title(f"LUNAR: {source.upper()} {curve_name} Curve", fontsize=10, fontweight='bold')
            ax.tick_params(axis='both', labelsize=7.5)
            handles, labels = ax.get_legend_handles_labels()
            leg.legend(handles, labels, loc='lower center', reverse=True, alignment='left', fontsize=7, handlelength=1.3, handletextpad=.7, labelspacing=.35, columnspacing=1.2, ncols=3, borderaxespad=0, edgecolor='gray', prop={'weight':530, 'size':7})
            leg.set_xticks([])
            leg.set_yticks([])
            leg.set_frame_on(False)
    plt.show()