© 2024 Nokia
Licensed under the BSD 3 Clause Clear License  
SPDX-License-Identifier: BSD-3-Clause-Clear

In [63]:
import math
import os
import re

import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn_extra.cluster import KMedoids
import sklearn
import sklearn.manifold
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances_argmin_min
from scipy.spatial.distance import cdist, cityblock
from statsmodels.compat import scipy

# Number of target classes; also used as the number of palette colours below.
num_classes = 2
# Index into `model.layers` of the layer whose output is used as representation.
intermediate_layer = 7  # last CNN layer
# Global seaborn theme for every figure in this notebook.
# NOTE(review): 'Regural' looks like a misspelling of 'Regular' — confirm the
# installed font family name before changing it.
sns.set(
    style='white',
    context="poster",
    font='Work Sans Regural',
    palette=sns.color_palette(["#1F968BFF", "#FDE725FF", "#d1d3d4"], num_classes),
)

In [64]:
dataset = 'MIMIC'

# Checkpoint paths, built as <training family>/<dataset>/<run directory>/<file>.
models = [
    os.path.join(family, dataset, run_dir, checkpoint)
    for family, run_dir, checkpoint in [
        ('SimCLR', '20230403-120645_e100_esFalse_bs128_wTrue_f2_fl', 'simclr.frozen.0.75.hdf5'),
        ('SimCLR', '20230403-120904_e100_esFalse_bs128_wTrue_f1_m', 'simclr.frozen.0.76.hdf5'),
        ('SimCLR', '20230403-121322_e100_esFalse_bs128_wTrue_f0', 'simclr.frozen.0.77.hdf5'),
        ('SimCLR', '20230403-121747_e100_esFalse_bs128_wTrue_f1', 'simclr.frozen.0.79.hdf5'),
        ('SimCLR', '20230403-122120_e100_esFalse_bs128_wTrue_f2', 'simclr.frozen.0.69.hdf5'),
        ('SimCLR', '20230403-122444_e100_esFalse_bs128_wTrue_f3', 'simclr.frozen.0.78.hdf5'),
        ('Supervised', '20230403-133103_l2_e100_esFalse_bs128_wTrue', 'supervised.finetuned.0.81.hdf5'),
    ]
]

# One row per demographic attribute: (column name, best group, worst group).
# AGE is encoded as '1' for patients younger than 65.
columns, best_groups, worst_groups = map(list, zip(
    ('INSURANCE', 'Self Pay', 'Medicaid'),
    ('ETHNICITY', 'HISPANIC', 'BLACK'),
    ('GENDER', 'M', 'F'),
    ('AGE', '1', '0'),
    ('LANGUAGE', 'ENGL', 'OTHER'),
))

In [65]:
dataset = 'MESA'

# Checkpoint paths, built as <training family>/<dataset>/<run directory>/<file>.
models = [
    os.path.join(family, dataset, run_dir, checkpoint)
    for family, run_dir, checkpoint in [
        ('SimCLR', '20231201-084504_e200_esTrue_bs128_wTrue_f2_fl', 'simclr.frozen.16.12.hdf5'),
        ('SimCLR', '20231201-085552_e200_esTrue_bs128_wTrue_f1_m', 'simclr.frozen.16.93.hdf5'),
        ('SimCLR', '20231201-090723_e200_esTrue_bs128_wTrue_f0', 'simclr.frozen.6.26.hdf5'),
        ('SimCLR', '20231201-091643_e200_esTrue_bs128_wTrue_f1', 'simclr.frozen.9.05.hdf5'),
        ('SimCLR', '20231201-092528_e200_esTrue_bs128_wTrue_f2', 'simclr.frozen.11.65.hdf5'),
        ('SimCLR', '20231201-093337_e200_esTrue_bs128_wTrue_f3', 'simclr.frozen.25.36.hdf5'),
        ('Supervised', '20231204-074617_l2_e200_esTrue_bs64_wTrue', 'supervised.finetuned.0.59.hdf5'),
    ]
]

# One row per demographic attribute: (column name, best group, worst group).
columns, best_groups, worst_groups = map(list, zip(
    ('nsrr_sex', 'male', 'female'),
    ('nsrr_age_gt65', 'no', 'yes'),
    ('nsrr_race', 'white', 'asian'),
))

In [66]:
dataset = 'GLOBEM'

# Checkpoint paths, built as <training family>/<dataset>/<run directory>/<file>.
models = [
    os.path.join(family, dataset, run_dir, checkpoint)
    for family, run_dir, checkpoint in [
        ('SimCLR', '20231128-152529_e200_esTrue_bs128_wTrue_f2_fl', 'simclr.frozen.1.41.hdf5'),
        ('SimCLR', '20231128-152400_e200_esTrue_bs128_wTrue_f1_m', 'simclr.frozen.1.58.hdf5'),
        ('SimCLR', '20231128-152548_e200_esTrue_bs128_wTrue_f0', 'simclr.frozen.1.56.hdf5'),
        ('SimCLR', '20231128-152601_e200_esTrue_bs128_wTrue_f1', 'simclr.frozen.1.08.hdf5'),
        ('SimCLR', '20231128-152630_e200_esTrue_bs128_wTrue_f2', 'simclr.frozen.1.19.hdf5'),
        ('SimCLR', '20231128-152654_e200_esTrue_bs128_wTrue_f3', 'simclr.frozen.1.71.hdf5'),
        ('Supervised', '20231115-091914_l2_e200_esTrue_bs64_wTrue', 'supervised.finetuned.hdf5'),
    ]
]

# One row per demographic attribute: (column name, best group, worst group).
columns, best_groups, worst_groups = map(list, zip(
    ('gender', 'Male', 'Female'),
    ('race', 'White', 'Biracial'),
    ('disability', 'No', 'Yes'),
))

In [67]:
# read demographics
test = (np.load(os.path.join('SimCLR', dataset, 'test_x.npy')),
           np.load(os.path.join('SimCLR', dataset, 'test_y.npy')))
test_listfile = pd.read_csv(os.path.join('..', '..', 'datasets', dataset, 'test_listfile.csv'))
test_listfile.head()

Unnamed: 0,PID,gender,age,race,generation,disability,year
0,1221.0,1.0,21.0,Asian,1.0,1.0,4
1,1221.0,1.0,21.0,Asian,1.0,1.0,4
2,1221.0,1.0,21.0,Asian,1.0,1.0,4
3,1221.0,1.0,21.0,Asian,1.0,1.0,4
4,1221.0,1.0,21.0,Asian,1.0,1.0,4


In [68]:
# MIMIC and MESA keep demographics in a separate table that is left-joined onto
# the listfile; any other dataset is left untouched by this cell.
if dataset == 'MIMIC':
    subjects = pd.read_csv(
        os.path.join('..', '..', 'datasets', dataset, 'demographics_rich.csv')
    )
    # The subject id is the run of digits at the start of each `stay` value.
    regex = r"(?:^\d+)"
    test_listfile.loc[:, "SUBJECT_ID"] = test_listfile.stay.map(
        lambda stay: re.search(regex, stay).group(0)
    )
    test_listfile.SUBJECT_ID = test_listfile.SUBJECT_ID.astype(int)
    test_listfile.drop(columns=['stay'], inplace=True)
    test_listfile = pd.merge(test_listfile, subjects, how='left', on='SUBJECT_ID')
elif dataset == 'MESA':
    # The MESA demographics file is semicolon-separated and keyed by `mesaid`.
    subjects = pd.read_csv(
        os.path.join('..', '..', 'datasets', dataset, 'demographics.csv'),
        sep=';',
    )
    test_listfile = pd.merge(test_listfile, subjects, how='left', on='mesaid')

test_listfile.head()

Unnamed: 0,PID,gender,age,race,generation,disability,year
0,1221.0,1.0,21.0,Asian,1.0,1.0,4
1,1221.0,1.0,21.0,Asian,1.0,1.0,4
2,1221.0,1.0,21.0,Asian,1.0,1.0,4
3,1221.0,1.0,21.0,Asian,1.0,1.0,4
4,1221.0,1.0,21.0,Asian,1.0,1.0,4


In [69]:
def transform_mimic(test_listfile):
    """Coarsen the MIMIC demographic columns and return the same frame.

    The frame is modified in place:
    - LANGUAGE keeps 'ENGL' and maps every other value to 'OTHER';
    - ETHNICITY becomes the first of WHITE, BLACK, HISPANIC, ASIAN found as a
      substring of the raw value, or 'OTHER' when none is found;
    - AGE is cast to str.
    """
    def keep_english(lang):
        if lang == 'ENGL':
            return lang
        return 'OTHER'

    def coarse_ethnicity(raw):
        # The order matters: the first label contained in the raw value wins.
        for label in ("WHITE", "BLACK", "HISPANIC", "ASIAN"):
            if label in raw:
                return label
        return "OTHER"

    test_listfile.LANGUAGE = test_listfile.LANGUAGE.map(keep_english)
    test_listfile.ETHNICITY = test_listfile.ETHNICITY.map(coarse_ethnicity)
    test_listfile.AGE = test_listfile.AGE.astype(str)
    return test_listfile

def transform_mesa(test_listfile):
    """Add the `nsrr_age_gt65` column derived from `nsrr_age`; return the frame.

    Ages strictly below 65 are labelled 'no'; every other value (including
    exactly 65) is labelled 'yes'. The frame is modified in place.
    """
    def age_flag(age):
        if age < 65:
            return 'no'
        return 'yes'

    test_listfile.loc[:, 'nsrr_age_gt65'] = test_listfile.nsrr_age.map(age_flag)
    return test_listfile

def transform_globem(test_listfile):
    """Replace GLOBEM demographic codes with readable labels; return the frame.

    The frame is modified in place:
    - gender codes 1-4 get a label, any other value becomes 'Other';
    - disability is 'Yes' for code 1 and 'No' for everything else;
    - missing race values become 'Other', present values are kept.
    """
    # NOTE: 'Genderqueer ' carries a trailing space in the existing labels; it
    # is kept unchanged so group names stay identical.
    gender_labels = ((1, 'Male'), (2, 'Female'), (3, 'Transgender'), (4, 'Genderqueer '))

    def gender_label(code):
        for known_code, label in gender_labels:
            if code == known_code:
                return label
        return 'Other'

    def disability_label(code):
        if code == 1:
            return 'Yes'
        return 'No'

    def race_label(value):
        if pd.isna(value):
            return 'Other'
        return value

    test_listfile.gender = test_listfile.gender.map(gender_label)
    test_listfile.disability = test_listfile.disability.map(disability_label)
    test_listfile.race = test_listfile.race.map(race_label)
    return test_listfile

# Pick the dataset-specific transform; anything that is not MIMIC or MESA is
# handled by the GLOBEM transform.
test_listfile = {
    'MIMIC': transform_mimic,
    'MESA': transform_mesa,
}.get(dataset, transform_globem)(test_listfile)

test_listfile.head()

Unnamed: 0,PID,gender,age,race,generation,disability,year
0,1221.0,Male,21.0,Asian,1.0,Yes,4
1,1221.0,Male,21.0,Asian,1.0,Yes,4
2,1221.0,Male,21.0,Asian,1.0,Yes,4
3,1221.0,Male,21.0,Asian,1.0,Yes,4
4,1221.0,Male,21.0,Asian,1.0,Yes,4


## Get predictions and representations

In [70]:
model_file_name = models[1]
model_name = 'ssl'

# Start from a clean Keras session, then load the checkpoint for inference only.
tf.keras.backend.clear_session()
model = tf.keras.models.load_model(model_file_name, compile=False)

# predictions: the class with the highest score per test row
np_test = tuple(
    np.load(os.path.join('SimCLR', dataset, file_name))
    for file_name in ('test_x.npy', 'test_y.npy')
)
probs = model.predict(np_test[0])
predictions = probs.argmax(axis=1)
test_listfile.loc[:, "y_pred"] = predictions

# representations: truncate the network at `intermediate_layer` and run the
# test inputs through it.
layer_model = tf.keras.Model(
    inputs=model.inputs,
    outputs=model.layers[intermediate_layer].output,
)
# NOTE: `model` is rebound to the truncated network too, so from here on it no
# longer refers to the full classifier loaded above.
model = layer_model
intermediate_representation = layer_model.predict(test[0], batch_size=600)
print(intermediate_representation.shape)
layer_model.summary()

(1083, 96)
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           [(None, 28, 1390)]        0         
_________________________________________________________________
conv1d (Conv1D)              (None, 28, 32)            1067552   
_________________________________________________________________
dropout (Dropout)            (None, 28, 32)            0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 28, 64)            32832     
_________________________________________________________________
dropout_1 (Dropout)          (None, 28, 64)            0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 28, 96)            49248     
_________________________________________________________________
dropout_2 (Dropout)          (None, 28, 96)

## Get distances

In [71]:
def get_distance(test_listfile, intermediate_representation, column, worst_group, rest_groups=None, best_group=None, intra_group=False):
    """Measure Manhattan distances between demographic groups in representation space.

    Each group is summarised by its medoid (KMedoids with a single cluster and
    the 'manhattan' metric) and groups are compared by the city-block distance
    between medoids.

    Parameters
    ----------
    test_listfile : pandas.DataFrame
        One row per test sample; row i is taken to describe row i of
        `intermediate_representation` (rows are matched by position).
    intermediate_representation : array-like, indexable as [rows, :]
        Representation vectors, one row per test sample.
    column : str
        Demographic column of `test_listfile` that defines the groups.
    worst_group : str
        Group value whose medoid is the reference point.
    rest_groups : list, optional
        Groups to compare against `worst_group`; used when `best_group` is not
        given.
    best_group : str, optional
        When given, only the distance between `worst_group` and this group is
        computed.
    intra_group : bool, default False
        When True, additionally print the mean pairwise city-block distance
        inside every group of `column`, and their average.

    Returns
    -------
    float
        Distance between the worst and best medoids, when `best_group` is given.
    tuple of (float, float)
        `(total_dist, avg_dist)` over `rest_groups` otherwise.

    Raises
    ------
    ValueError
        If a requested group has no rows, or if neither `best_group` nor a
        non-empty `rest_groups` is supplied.
    """
    # Compare labels as strings and by exact equality: substring/regex matching
    # (`str.contains`) would, for example, also select 'female' rows for 'male'.
    labels = test_listfile[column].astype(str)

    def group_representations(group):
        # Positional row numbers, so a non-default DataFrame index cannot
        # misalign the labels with the representation matrix.
        rows = np.flatnonzero((labels == str(group)).to_numpy())
        if rows.size == 0:
            raise ValueError("No rows with {} == {!r}".format(column, group))
        return intermediate_representation[rows, :]

    def group_medoid(group):
        centers = KMedoids(n_clusters=1, metric='manhattan').fit(group_representations(group)).cluster_centers_
        # Take the single centre as a 1-D vector, which is what the vector
        # distance function below expects.
        return np.asarray(centers)[0]

    worst_group_medoid = group_medoid(worst_group)

    if intra_group:
        # Use the `column` argument (not a notebook-level loop variable) so the
        # function works regardless of what the caller's globals contain.
        groups = test_listfile[column].value_counts().keys().tolist()
        intra_dist_total = 0
        for group in groups:
            group_reps = group_representations(group)
            intra_dist = cdist(group_reps, group_reps, metric='cityblock').mean()
            intra_dist_total += intra_dist
            print("Intra-group ({}) distance: {}".format(group, intra_dist))
        intra_dist_avg = intra_dist_total / len(groups)
        print("Average Intra-group ({}) distance: {}".format(column, intra_dist_avg))

    if best_group is not None:
        dist = cityblock(worst_group_medoid, group_medoid(best_group))
        print("Distance ({}) between {} and {}: {}".format(column, best_group, worst_group, dist))
        return dist

    if not rest_groups:
        raise ValueError("Provide either best_group or a non-empty rest_groups")

    total_dist = 0
    for group in rest_groups:
        total_dist += cityblock(worst_group_medoid, group_medoid(group))
    avg_dist = total_dist / len(rest_groups)
    print("Total Distance ({}) between {} and rest: {}".format(column, worst_group, total_dist))
    print("Average Distance ({}) between {} and rest: {}".format(column, worst_group, avg_dist))
    return total_dist, avg_dist

In [72]:
# For every demographic column, report how far the worst-performing group's
# medoid lies from all other groups, and from the best-performing group.
for idx, col in enumerate(columns):
    worst_group = worst_groups[idx]

    # worst vs rest groups: every observed value of the column except the worst one
    rest_groups = test_listfile[col].value_counts().index.tolist()
    rest_groups.remove(worst_group)
    rest = get_distance(
        test_listfile, intermediate_representation, col, worst_group,
        rest_groups=rest_groups,
    )

    # worst vs best group
    best = get_distance(
        test_listfile, intermediate_representation, col, worst_group,
        best_group=best_groups[idx], intra_group=False,
    )

    print("----------------------------------------------------")

Total Distance (gender) between Female and rest: 25.525755882263184
Average Distance (gender) between Female and rest: 6.381438970565796
Distance (gender) between Male and Female: 4.617440223693848
----------------------------------------------------
Total Distance (race) between Biracial and rest: 27.62401795387268
Average Distance (race) between Biracial and rest: 6.90600448846817
Distance (race) between White and Biracial: 3.9554522037506104
----------------------------------------------------
Total Distance (disability) between Yes and rest: 4.885066032409668
Average Distance (disability) between Yes and rest: 4.885066032409668
Distance (disability) between No and Yes: 4.885066032409668
----------------------------------------------------
