In [None]:
import glob
import itertools
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import os
import pandas as pd
import pickle
import seaborn as sns
import shap
import statsmodels.api as sm
import torch
import warnings
from confidenceinterval import (
    tpr_score, 
    tnr_score,
)
from scipy.stats import spearmanr, pearsonr, skew, skewtest
from sklearn.metrics import mean_absolute_error

import sys
sys.path.append('..')
from libs.eval import opt_youden_j_binary
from libs.compare_roc_auc import delong_roc_test

# Widen pandas' display limits so the transposed metric tables render in full.
for _opt_name, _opt_value in {
    'display.max_rows': 500,
    'display.max_columns': 500,
    'display.width': 1000,
}.items():
    pd.set_option(_opt_name, _opt_value)

In [None]:
%matplotlib inline

# Start from matplotlib's built-in defaults, then apply notebook-wide settings.
plt.style.use('default')

DPI = 600  # resolution used by savefig
plt.rcParams.update({
    'figure.dpi': 150,
    'savefig.dpi': DPI,
    'figure.autolayout': False,
    'font.size': 8,
    'legend.fontsize': 9,
})
# NOTE(review): seaborn's set_context also sets font-size rcParams (including
# font.size and legend.fontsize), so the two font values above are likely
# overridden by this call — confirm the intended order.
sns.set_context('paper', font_scale=0.9)

In [None]:
%config InlineBackend.figure_formats = ['svg']

In [None]:
ukb_results = {}
ship_results = {}
for res in glob.glob('../*.csv'):
    if 'SHIP' in res and not '202506' in res:
        continue
    fname = '_'.join(res.split('/')[-1].split('.')[0].split('_')[1:])
    if 'SHIP' in res:
        ship_results[fname] = pd.read_csv(res)
    else:
        ukb_results[fname] = pd.read_csv(res)

In [None]:
# Alphabetically ordered names of the loaded result tables, per cohort.
UKB_KEYS, SHIP_KEYS = sorted(ukb_results), sorted(ship_results)

In [None]:
def _str_rounded_err(x_xs, n=3):
    x, (xs_l, xs_u) = x_xs
    return f"{np.round(x, n)} ({np.round(xs_l, n)},{np.round(xs_u, n)})"
def _print_rounded_err(x_xs, n=3):
    print(_str_rounded_err(x_xs, n))

In [None]:
def _clf_perf(y_true, y_proba):
    """Evaluate binary classification scores via ``opt_youden_j_binary``.

    Thin wrapper that fixes ``avg_method="macro"`` and returns whatever the
    helper returns. Presumably that is a nested tuple matching HEADER_CAT
    (threshold, then metrics with confidence bounds) — confirm in libs.eval.
    """
    perf = opt_youden_j_binary(y_true, y_proba, avg_method="macro")
    return perf

def _regr_perf(y_true, y_pred):
    """Summarise agreement between predicted and reference continuous values.

    Computes the Pearson correlation, fits an OLS regression of ``y_true`` on
    ``y_pred`` (with intercept), and estimates the mean absolute error with a
    percentile bootstrap interval (5000 resamples, 2.5th-97.5th percentiles).
    As a side effect it prints the mean signed error with its bootstrap
    interval, followed by the skewness of the signed errors.

    Resampling draws from NumPy's global RNG, so the intervals depend on the
    RNG state at call time.

    Returns
    -------
    tuple
        ``(pearsonr result, (adjusted R^2, OLS F-test p-value),
        (MAE, (lower, upper)))``.
    """
    n_resamples = 5000
    ci_percentiles = [2.5, 97.5]

    pears = pearsonr(y_pred, y_true)
    ols_fit = sm.OLS(y_true, sm.add_constant(y_pred)).fit()

    # Absolute errors -> MAE point estimate and bootstrap interval.
    abs_errors = np.abs(y_true - y_pred)
    n_obs = len(abs_errors)
    mae_draws = []
    for _ in range(n_resamples):
        resample = np.random.choice(abs_errors, size=n_obs, replace=True)
        mae_draws.append(np.mean(resample))
    mae_lower, mae_upper = np.percentile(np.array(mae_draws), ci_percentiles)
    mae = (mean_absolute_error(y_pred, y_true), (mae_lower, mae_upper))

    # Signed errors (bias): only reported on stdout, not returned.
    signed_errors = y_true - y_pred
    me_draws = []
    for _ in range(n_resamples):
        resample = np.random.choice(signed_errors, size=n_obs, replace=True)
        me_draws.append(np.mean(resample))
    me_lower, me_upper = np.percentile(np.array(me_draws), ci_percentiles)
    print(np.mean(signed_errors), me_lower, me_upper)
    print(skew(signed_errors))

    return pears, (ols_fit.rsquared_adj, ols_fit.f_pvalue), mae

In [None]:
def _flatten_tuple(t):
    for item in t:
        if isinstance(item, tuple):
            yield from _flatten_tuple(item)
        else:
            yield item
def _flatten(t):
    return [float(x) for x in _flatten_tuple(t)]

In [None]:
# Classification columns: the threshold, then each metric followed by its
# lower (_l) and upper (_u) bound.
HEADER_CAT = ['thresh'] + [
    metric + suffix
    for metric in ('sens', 'spec', 'f1', 'auc')
    for suffix in ('', '_l', '_u')
]
# Regression columns, in the order produced by flattening _regr_perf's result.
HEADER_REG = 'pears pears_pval adj_rsq f_pval mae mae_l mae_u'.split()
# Full column layout of the per-model result tables.
HEADER = ['model', *HEADER_CAT, *HEADER_REG]

# Loaders

In [None]:
from libs.ecg_ukb import ECGMode, ECGTarget, N_STEPS, N_CHANNELS, LEADS
from libs.model import FCN1D, FCN1DConfig
from train import load_data, SPLITS, get_model

In [None]:
SEED = 240302

# Seed torch's global RNG.
torch.manual_seed(SEED)
# Dedicated torch generator seeded with the same value.
TORCH_SEED = torch.Generator()
TORCH_SEED.manual_seed(SEED)
# Seed NumPy's legacy global RNG (used by the bootstrap in _regr_perf).
np.random.seed(SEED)

In [None]:
# Target and hyper-parameters passed to get_model to rebuild the network whose
# weights are stored in the checkpoint below.
TARGET = ECGTarget.INDEXED_MASS
config = {
    "batch_size": 32,
    #"optimizer_cls": optim.Adam,
    "learning_rate": 0.0005,
    "fcn_config": FCN1DConfig.WANG2016,
    "fcn_batch_norm": True,
    "fcn_max_pool": True,
    "fcn_conv_dropout": 0.4,
    "fcn_linear_dropout": 0.6,
    "excl_meta": False,
    #"soto2022": "resnet34",
}
checkpoint_path = '../UKB_iLVM.202503.pth'
config['target'] = TARGET

# Only the third loader returned by load_data (the test loader) is used here.
_, _, test_loader, _ = load_data(config['batch_size'], TARGET, SPLITS)
model = get_model(config)
# map_location='cpu' lets a checkpoint saved on a GPU be read on a CPU-only
# machine; load_state_dict then copies the weights onto whatever device the
# model's parameters live on.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
# Inference mode (disables dropout, uses running batch-norm statistics).
model.eval()

In [None]:
# Dataset indices in the order the test loader yields them (the fourth element
# of each batch).
indices = list(itertools.chain.from_iterable(
    inds.detach().numpy() for _, _, _, inds in test_loader
))


def _record_field(field):
    """Return ``field`` of the dataset record for every index in ``indices``."""
    return [test_loader.dataset.get_record(i)[field] for i in indices]


# Per-sample record fields, aligned with `indices`.
inds_is_m = [sex == 1 for sex in _record_field('Sex')]
inds_lvh = _record_field('LVM.group')
inds_ilvm = _record_field('indexed.LVM')
inds_eid = _record_field('f.eid')

In [None]:
# Participant IDs that define the held-out test split.
TEST_FEIDS = pd.read_csv('../ukb_feids.test_split.list')['f.eid'].values
# Every listed ID must be present in the test loader. np.isin replaces the
# original per-ID scan of the Python list (quadratic in the test-set size).
_missing_feids = TEST_FEIDS[~np.isin(TEST_FEIDS, inds_eid)]
assert _missing_feids.size == 0, (
    f"{_missing_feids.size} f.eid value(s) from the test split list are missing from the test loader"
)

# UKB

In [None]:
# One entry per table row:
#   (model label, frame holding lvh_true and the score column, score column,
#    frame holding ilvm_true, frame holding ilvm_pred)
# The last two are None for models without a regression output.
# NOTE(review): the 'SVM LVH' row scores the FCN iLVM predictions against the
# benchmark targets (as in the original cell) — confirm this is intended and
# not a copy-paste slip.
_UKB_ROWS = [
    ('SVM LVH', 'benchmarks', 'lvh_proba', 'benchmarks', 'iLVM'),
    ('Cornell', 'benchmarks', 'lvh_pred_cornell_voltage', None, None),
    ('Sokolov-Lyon', 'benchmarks', 'lvh_pred_sokolov_lyon', None, None),
    ('FCN LVH', 'LVH', 'lvh_proba', None, None),
    ('FCN iLVM', 'iLVM', 'lvh_pred_cutoff', 'iLVM', 'iLVM'),
    ('FCN iLVM + LR', 'iLVM', 'lvh_proba_lr', 'iLVM', 'iLVM'),
    ('R34 LVH', 'LVH_r34', 'lvh_proba', None, None),
    ('R34 iLVM', 'iLVM_r34', 'lvh_pred_cutoff', 'iLVM_r34', 'iLVM_r34'),
    ('R34 iLVM + LR', 'iLVM_r34', 'lvh_proba_lr', 'iLVM_r34', 'iLVM_r34'),
]

recs = []
for _label, _clf_key, _score_col, _true_key, _pred_key in _UKB_ROWS:
    _clf_frame = ukb_results[_clf_key]
    # Classification metrics first, then regression metrics: both helpers may
    # consume the global RNG, so the call order is kept fixed.
    _row = [_label] + _flatten(
        _clf_perf(_clf_frame['lvh_true'].values, _clf_frame[_score_col].values)
    )
    if _true_key is None:
        _row += [None]*len(HEADER_REG)
    else:
        _row += _flatten(
            _regr_perf(
                ukb_results[_true_key]['ilvm_true'].values,
                ukb_results[_pred_key]['ilvm_pred'].values,
            )
        )
    recs.append(_row)

ukb_df = pd.DataFrame(recs, columns=HEADER)
ukb_df.round(4).T

In [None]:
# The paired test needs both score vectors in the same participant order.
# NOTE(review): only benchmarks vs. iLVM is checked here; the 'LVH' frame used
# in the first comparison is not — confirm it shares the same row order.
assert (ukb_results['benchmarks']['f.eid'] ==  ukb_results['iLVM']['eid']).all()

_svm_truth = ukb_results['benchmarks']['lvh_true'].values
_svm_scores = ukb_results['benchmarks']['lvh_proba'].values

# Compare the SVM benchmark against each FCN variant.
# presumably delong_roc_test returns log10(p), hence the 10** — confirm in
# libs.compare_roc_auc.
for _comparison, _rival_scores in (
    ('SVM - FCN LVH', ukb_results['LVH']['lvh_proba']),
    ('SVM - FCN LVM', ukb_results['iLVM']['lvh_pred_cutoff']),
    ('SVM - FCN LVM+LR', ukb_results['iLVM']['lvh_proba_lr']),
):
    print(_comparison, 10**delong_roc_test(_svm_truth, _svm_scores, _rival_scores.values))


In [None]:
# ECG-only ("ECGo") variants of the UKB models.
#
# The table is stored as `ukb_ecgonly_df`. The original cell assigned it to
# `ukb_df`, overwriting the full-model table; the figure cells further down
# select rows such as 'SVM LVH' and 'FCN iLVM + LR' from `ukb_df`, which do not
# exist in this table, so a top-to-bottom run broke them.
#
# One entry per row: (label, result-frame key, score column, has regression).
_UKB_ECGONLY_ROWS = [
    ('SVM LVH (ECGo)', 'benchmarks', 'lvh_proba_ecgonly', False),
    ('FCN LVH (ECGo)', 'LVH_ecgonly', 'lvh_proba', False),
    ('FCN iLVM (ECGo)', 'iLVM_ecgonly', 'lvh_pred_cutoff', True),
    ('FCN iLVM + LR (ECGo)', 'iLVM_ecgonly', 'lvh_proba_lr', True),
    ('R34 LVH (ECGo)', 'LVH_r34_ecgonly', 'lvh_proba', False),
    ('R34 iLVM (ECGo)', 'iLVM_r34_ecgonly', 'lvh_pred_cutoff', True),
    ('R34 iLVM + LR (ECGo)', 'iLVM_r34_ecgonly', 'lvh_proba_lr', True),
]

recs = []
for _label, _key, _score_col, _has_regression in _UKB_ECGONLY_ROWS:
    _frame = ukb_results[_key]
    # Classification metrics first, then regression metrics (same call order as
    # the hand-written rows, which matters if the helpers use the global RNG).
    _row = [_label] + _flatten(
        _clf_perf(_frame['lvh_true'].values, _frame[_score_col].values)
    )
    if _has_regression:
        _row += _flatten(_regr_perf(_frame['ilvm_true'].values, _frame['ilvm_pred'].values))
    else:
        _row += [None]*len(HEADER_REG)
    recs.append(_row)

ukb_ecgonly_df = pd.DataFrame(recs, columns=HEADER)
ukb_ecgonly_df.round(4).T

# SHIP

In [None]:
# One entry per row: (label, result-frame key, score column, has regression).
_SHIP_ROWS = [
    ('FCN LVH', 'LVH', 'lvh_proba', False),
    ('FCN iLVM', 'iLVM', 'lvh_pred_cutoff', True),
    ('FCN iLVM + LR', 'iLVM', 'lvh_proba_lr', True),
    ('R34 LVH', 'LVH_r34', 'lvh_proba', False),
    ('R34 iLVM', 'iLVM_r34', 'lvh_pred_cutoff', True),
    ('R34 iLVM + LR', 'iLVM_r34', 'lvh_proba_lr', True),
]

recs = []
for _label, _key, _score_col, _has_regression in _SHIP_ROWS:
    _frame = ship_results[_key]
    # Classification metrics are computed before regression metrics, as in the
    # hand-written rows.
    _row = [_label] + _flatten(
        _clf_perf(_frame['lvh_true'].values, _frame[_score_col].values)
    )
    if _has_regression:
        _row += _flatten(_regr_perf(_frame['ilvm_true'].values, _frame['ilvm_pred'].values))
    else:
        _row += [None]*len(HEADER_REG)
    recs.append(_row)

ship_df = pd.DataFrame(recs, columns=HEADER)
ship_df.round(4).T

# Figures

## LVM predictions

In [None]:
# Two-panel scatter of predicted vs. CMR-derived indexed LV mass:
# left = UK Biobank, right = SHIP. Saved to ../figures/ilvm.reg.202506.svg.
fig, axs = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

# Pastel colours for all points, bright colours for the LVH subset.
cs_l = sns.color_palette('pastel', 2)
cs_d = sns.color_palette('bright', 2)

# --- Left panel: UK Biobank ---
df_fig = ukb_results['iLVM']
is_m = df_fig['is_m'].values == 1
is_f = df_fig['is_m'].values == 0
# NOTE(review): these masks are only valid boolean indexers if 'lvh_true' has
# boolean dtype; with a 0/1 integer column `&` yields integers — confirm.
is_lvh_m = is_m & df_fig['lvh_true']
is_lvh_f = is_f & df_fig['lvh_true']
y_true = df_fig['ilvm_true']
y_pred = df_fig['ilvm_pred']

# All warnings raised while drawing are suppressed.
with warnings.catch_warnings(action="ignore"):
    axs[0].scatter(y_pred[is_m], y_true[is_m], s=7, c=cs_l[1], alpha=.5, label='M, predicted LVM')
    axs[0].scatter(y_pred[is_f], y_true[is_f], s=7, c=cs_l[0], alpha=.5, label='F, predicted LVM')
    axs[0].scatter(y_pred[is_lvh_m], y_true[is_lvh_m], s=9, c=cs_d[1], alpha=.9, label='M, CMR-derived LVH')
    axs[0].scatter(y_pred[is_lvh_f], y_true[is_lvh_f], s=9, c=cs_d[0], alpha=.9, label='F, CMR-derived LVH')
    # Regression line only (scatter=False); points are drawn above.
    sns.regplot(x=y_pred, y=y_true, scatter=False, ax=axs[0], color='grey', line_kws={'lw':2})
axs[0].set_xlabel(r'Predicted indexed LV mass (g/m$^2$)', fontsize=10)
axs[0].set_ylabel(r'CMR-derived indexed LV mass (g/m$^2$)', fontsize=10)
axs[0].legend(loc="best", fontsize=9)

# --- Right panel: SHIP (same layout; larger, more opaque markers) ---
df_fig = ship_results['iLVM']
is_m = df_fig['is_m'].values == 1
is_f = df_fig['is_m'].values == 0
is_lvh_m = is_m & df_fig['lvh_true']
is_lvh_f = is_f & df_fig['lvh_true']
y_true = df_fig['ilvm_true']
y_pred = df_fig['ilvm_pred']

with warnings.catch_warnings(action="ignore"):
    axs[1].scatter(y_pred[is_m], y_true[is_m], s=9, c=cs_l[1], alpha=.95, label='M, predicted LVM')
    axs[1].scatter(y_pred[is_f], y_true[is_f], s=9, c=cs_l[0], alpha=.95, label='F, predicted LVM')
    # Note: LVH subsets are drawn F then M here (M then F in the left panel),
    # so the legend order differs between panels.
    axs[1].scatter(y_pred[is_lvh_f], y_true[is_lvh_f], s=11, c=cs_d[0], alpha=.95, label='F, CMR-derived LVH')
    axs[1].scatter(y_pred[is_lvh_m], y_true[is_lvh_m], s=11, c=cs_d[1], alpha=.95, label='M, CMR-derived LVH')
    sns.regplot(x=y_pred, y=y_true, scatter=False, ax=axs[1], color='grey', line_kws={'lw':2})
axs[1].set_xlabel(r'Predicted indexed LV mass (g/m$^2$)', fontsize=10)
axs[1].set_ylabel(r'CMR-derived indexed LV mass (g/m$^2$)', fontsize=10)
axs[1].legend(loc="best", fontsize=9)

# Sample sizes in the titles are hard-coded, not derived from the frames.
axs[0].set_title('UK Biobank (N=7,326)', fontsize=10)
axs[1].set_title('SHIP (N=285)', fontsize=10)

# NOTE(review): the figure was created with constrained_layout=True; calling
# tight_layout() as well mixes two layout mechanisms — confirm which is wanted.
fig.tight_layout()

plt.savefig('../figures/ilvm.reg.202506.svg', bbox_inches='tight')

## AUC comparison

In [None]:
# Bar chart of AUROC per model: left = UK Biobank, right = SHIP.
# Saved to ../figures/auc_bars.202506.svg.
fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharey=True, gridspec_kw={'width_ratios': [0.7, 0.3]})

# Models shown in each panel, in bar order.
ukb_models = ['Cornell', 'Sokolov-Lyon', 'SVM LVH', 'FCN LVH', 'FCN iLVM', 'FCN iLVM + LR']
ship_models = ['FCN LVH', 'FCN iLVM', 'FCN iLVM + LR']

# Tick labels; models without an entry keep their table name.
LABEL_MAP = {
    'SVM LVH': 'SVM',
    'FCN LVH': r'FCN$_{LVH}$',
    'FCN iLVM': r'FCN$_{LVM}$',
    'FCN iLVM + LR': r'FCN$_{LVM}$+LR',
}

n_compared = len(ukb_models)
cs = sns.color_palette("hls", n_colors=n_compared)


def _ordered_subset(df, models):
    """Return the rows of ``df`` whose model is in ``models``, as a new frame.

    The ``model`` column becomes a categorical whose category order is
    ``models`` (this fixes the bar order). Building a new frame with
    ``assign`` avoids assigning to a filtered slice, which triggered
    SettingWithCopyWarning in the original cell.
    """
    subset = df[df['model'].isin(models)]
    return subset.assign(model=pd.Categorical(subset['model'], categories=models))


ukb_df_comp = _ordered_subset(ukb_df, ukb_models)
ship_df_comp = _ordered_subset(ship_df, ship_models)

sns.barplot(data=ukb_df_comp, x='model', y='auc',
            palette=cs, width=0.7,
            ax=axs[0])
axs[0].tick_params(axis='x', rotation=45)
axs[0].set_xticklabels([LABEL_MAP.get(v, v) for v in ukb_models])
axs[0].set_ylim([0.4, 1.0])
axs[0].set_ylabel('AUROC')
axs[0].set_xlabel('Model')
axs[0].set_title('UK Biobank (N=7,326)')

n_compared = len(ship_models)
# Give each SHIP model the colour it has in the UKB panel, looked up by name
# rather than by slicing the last three palette entries.
ship_palette = [cs[ukb_models.index(m)] for m in ship_models]
sns.barplot(data=ship_df_comp, x='model', y='auc',
            palette=ship_palette, width=0.7,
            ax=axs[1])
axs[1].tick_params(axis='x', rotation=45)
# Bars follow the category order (ship_models), so the labels must follow it
# too; the original used the table's row order here.
axs[1].set_xticklabels([LABEL_MAP.get(v, v) for v in ship_models])
axs[1].set_ylim([0.4, 1.0])
axs[1].set_xlabel('Model')
axs[1].set_title('SHIP (N=285)')

# Horizontal grid lines drawn behind the bars.
for ax in axs:
    ax.set_axisbelow(True)
    ax.yaxis.grid(True, alpha=0.5)

fig.tight_layout()

plt.savefig('../figures/auc_bars.202506.svg', bbox_inches='tight')

## ROCs

In [None]:
# ROC curves: left = UK Biobank (SVM, FCN iLVM + LR, FCN LVH), right = SHIP
# (FCN iLVM + LR, FCN LVH). Curves are built by sweeping a threshold grid;
# 'x' markers show the operating point stored in ukb_df / ship_df.
# Saved to ../figures/rocs.202506.svg.
cs = sns.color_palette('bright', 4)
cs_p = sns.color_palette('pastel', 4)

fig, axs = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

ax = axs[0]

# 100 evenly spaced probability thresholds in [0, 1].
p_thresh = np.linspace(0, 1, 100)

# For each curve: sensitivity/specificity at every threshold. Only element [0]
# of the tpr_score/tnr_score result is kept (presumably the point estimate,
# with the interval discarded — confirm against the confidenceinterval docs).
y_true = ukb_results['benchmarks']['lvh_true']
preds = ukb_results['benchmarks']['lvh_proba']
senss = np.array([tpr_score(y_true, preds >= t)[0] for t in p_thresh])
specs = np.array([tnr_score(y_true, preds >= t)[0] for t in p_thresh])
ax.plot(1-specs, senss, label='SVM', c=cs_p[0])

y_true = ukb_results['iLVM']['lvh_true']
preds = ukb_results['iLVM']['lvh_proba_lr']
senss = np.array([tpr_score(y_true, preds >= t)[0] for t in p_thresh])
specs = np.array([tnr_score(y_true, preds >= t)[0] for t in p_thresh])
ax.plot(1-specs, senss, label=r'FCN$_{iLVM}$ + LR', c=cs_p[1])

y_true = ukb_results['LVH']['lvh_true']
preds = ukb_results['LVH']['lvh_proba']
senss = np.array([tpr_score(y_true, preds >= t)[0] for t in p_thresh])
specs = np.array([tnr_score(y_true, preds >= t)[0] for t in p_thresh])
ax.plot(1-specs, senss, label=r'FCN$_{LVH}$', c=cs_p[2])

# Operating points from the UKB table. Note that `specs` now holds
# 1 - specificity (the x coordinate), not specificity.
specs = [
    1 - ukb_df[ukb_df.model == 'SVM LVH'].iloc[0].spec,
    1 - ukb_df[ukb_df.model == 'FCN iLVM + LR'].iloc[0].spec,
    1 - ukb_df[ukb_df.model == 'FCN LVH'].iloc[0].spec,
]
sens = [
    ukb_df[ukb_df.model == 'SVM LVH'].iloc[0].sens,
    ukb_df[ukb_df.model == 'FCN iLVM + LR'].iloc[0].sens,
    ukb_df[ukb_df.model == 'FCN LVH'].iloc[0].sens,
]

ax.scatter(specs, sens, c=[cs[0], cs[1], cs[2]], marker='x', zorder=100)

ax = axs[1]

# For SHIP the two table thresholds are added to the grid, so the curves
# include those exact thresholds. (This is not done for the UKB panel.)
p_thresh = np.insert(p_thresh, len(p_thresh), ship_df[ship_df.model == 'FCN LVH'].iloc[0].thresh)
p_thresh = np.sort(p_thresh)
p_thresh = np.insert(p_thresh, len(p_thresh), ship_df[ship_df.model == 'FCN iLVM + LR'].iloc[0].thresh)
p_thresh = np.sort(p_thresh)

y_true = ship_results['iLVM']['lvh_true']
preds = ship_results['iLVM']['lvh_proba_lr']
senss = np.array([tpr_score(y_true, preds >= t)[0] for t in p_thresh])
specs = np.array([tnr_score(y_true, preds >= t)[0] for t in p_thresh])
ax.plot(1-specs, senss, label=r'FCN$_{iLVM}$ + LR', c=cs_p[1])

y_true = ship_results['LVH']['lvh_true']
preds = ship_results['LVH']['lvh_proba']
senss = np.array([tpr_score(y_true, preds >= t)[0] for t in p_thresh])
specs = np.array([tnr_score(y_true, preds >= t)[0] for t in p_thresh])
ax.plot(1-specs, senss, label=r'FCN$_{LVH}$', c=cs_p[2])

# Operating points from the SHIP table (again as 1 - specificity).
specs = [
    1 - ship_df[ship_df.model == 'FCN iLVM + LR'].iloc[0].spec,
    1 - ship_df[ship_df.model == 'FCN LVH'].iloc[0].spec,
]
sens = [
    ship_df[ship_df.model == 'FCN iLVM + LR'].iloc[0].sens,
    ship_df[ship_df.model == 'FCN LVH'].iloc[0].sens,
]

ax.scatter(specs, sens, c=[cs[1], cs[2]], marker='x', zorder=100)

# Sample sizes in the titles are hard-coded.
axs[0].set_title('UK Biobank (N=7,326)', fontsize=10)
axs[1].set_title('SHIP (N=285)', fontsize=10)

axs[0].set_ylabel('Sensitivity', fontsize=10)

for ax in axs:
    # Chance diagonal.
    ax.plot([0, 1], [0, 1], ls='--', c='darkgray')
    ax.set_xlabel('1 - Specificity', fontsize=10)
    #ax.grid(ls='--', alpha=0.5)
    ax.set_xlim([-0.005, 1.005])
    ax.set_ylim([-0.005, 1.005])
    ax.legend(fontsize=10, loc='lower right')

plt.savefig('../figures/rocs.202506.svg', bbox_inches='tight')

## LR thresholding

In [None]:
# SECURITY: unpickling can execute arbitrary code contained in the file; only
# load pickles from a trusted source.
with open('../UKB_iLVM.202503.LR.pkl', 'rb') as file:
    # The loaded object is bound to two names (`fcn_ilvm_lr` and `lr`); later
    # cells call predict_proba on `fcn_ilvm_lr`.
    fcn_ilvm_lr = lr = pickle.load(file)

In [None]:
# Display the unpickled model object.
fcn_ilvm_lr

In [None]:
# UKB iLVM results and the sex / LVH masks used by the LR-threshold figure.
df_fig = ukb_results['iLVM']
y_true, y_pred = df_fig['ilvm_true'], df_fig['ilvm_pred']
# Sex masks (is_m == 1 -> male, is_m == 0 -> female).
is_m, is_f = (df_fig['is_m'].values == 1), (df_fig['is_m'].values == 0)
# Sex-specific subsets flagged by the 'lvh_true' column.
is_lvh_m, is_lvh_f = (is_m & df_fig['lvh_true']), (is_f & df_fig['lvh_true'])

In [None]:
# Boolean LVH call per sample: LR probability above the threshold stored in the
# 'FCN iLVM + LR' row of ukb_df.
# NOTE(review): this uses a strict `>`, while the ROC cell thresholds with
# `>=` — confirm which convention opt_youden_j_binary uses.
y_pred_lvh = (df_fig['lvh_proba_lr'] > ukb_df[ukb_df.model == 'FCN iLVM + LR'].iloc[0].thresh).values

In [None]:
# Left: LR probability curves per sex with the points where they cross 0.5.
# Right: predicted vs. CMR-derived indexed LV mass with the samples classified
# as LVH highlighted. Saved to ../figures/lr.202506.svg.
fig, axs = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

# Note: the 'Blues' palette in `cs` is not used in this cell.
cs = sns.color_palette('Blues', 3)
cs_l = sns.color_palette('pastel', 2)
cs_d = sns.color_palette('bright', 2)

# Grid of 300 predicted indexed-LVM values from 30 to 150.
Xs_test = np.linspace(30, 150, 300)
# Two-column inputs: a constant first column (1 for M, 0 for F) and the grid
# value. assumes the LR was fitted on features ordered [is_m, predicted iLVM]
# — TODO confirm against the training code.
Xs_test_m = np.vstack([np.ones(Xs_test.shape), Xs_test]).T
Xs_test_f = np.vstack([np.zeros(Xs_test.shape), Xs_test]).T

# Probability of the second class (column 1) along the grid.
clf_loss_f = fcn_ilvm_lr.predict_proba(Xs_test_f)[:, 1]
clf_loss_m = fcn_ilvm_lr.predict_proba(Xs_test_m)[:, 1]

# Grid value whose probability is closest to 0.5 (resolution limited by the grid).
clf_bound_m = Xs_test[np.abs(clf_loss_m-0.5).argmin()]
clf_bound_f = Xs_test[np.abs(clf_loss_f-0.5).argmin()]

axs[0].plot(Xs_test, clf_loss_f, label="LR fit (F)", color=cs_l[0], alpha=0.8, linewidth=4)
axs[0].plot(Xs_test, clf_loss_m, label="LR fit (M)", color=cs_l[1], alpha=0.8, linewidth=4)
axs[0].axvline(clf_bound_f, color=cs_d[0], ls='--', label="F, LVH cut-off")
axs[0].axvline(clf_bound_m, color=cs_d[1], ls='--', label="M, LVH cut-off")
# Only the 30-80 part of the grid is shown.
axs[0].set_xlim([30, 80])
axs[0].set_ylim([-.1, 1.1])
axs[0].set_yticks([0,1])
axs[0].set_xlabel(r'Predicted indexed LV mass (g/m$^2$)', fontsize=10)
axs[0].set_ylabel('LVH classification', fontsize=10)
axs[0].legend(loc="best", fontsize=10)

# Right panel: all samples in light grey, then the predicted-LVH samples per
# sex on top. Uses y_pred, y_true, is_f, is_m and y_pred_lvh from earlier cells.
axs[1].scatter(y_pred, y_true, s=7, c='lightgrey', alpha=.25)
sns.regplot(x=y_pred, y=y_true, scatter=False, ax=axs[1], color='grey', line_kws={'lw':2})
axs[1].scatter(
    y_pred[is_f & y_pred_lvh],
    y_true[is_f & y_pred_lvh],
    s=7, c=cs_l[0], alpha=.9, label='F, predicted LVH'
)
axs[1].scatter(
    y_pred[is_m & y_pred_lvh],
    y_true[is_m & y_pred_lvh],
    s=7, c=cs_l[1], alpha=.9, label='M, predicted LVH'
)

axs[1].axvline(clf_bound_f, color=cs_d[0], ls='--', label="F, LVH cut-off")
axs[1].axvline(clf_bound_m, color=cs_d[1], ls='--', label="M, LVH cut-off")

axs[1].set_xlabel(r'Predicted indexed LV mass (g/m$^2$)', fontsize=10)
axs[1].set_ylabel(r'CMR-derived indexed LV mass (g/m$^2$)', fontsize=10)
axs[1].legend(loc="best", fontsize=10)

plt.savefig('../figures/lr.202506.svg', bbox_inches='tight')

In [None]:
# Display the female and male cut-offs found on the grid in the figure cell.
clf_bound_f, clf_bound_m

# SHAP UKB

In [None]:
def to_shap_input(batches):
    """Concatenate per-batch pairs into the two-tensor list passed to SHAP.

    ``batches`` is a sequence of ``[first, second]`` tensor pairs. The result
    is ``[all firsts concatenated, all seconds concatenated]``, each joined
    along dimension 0 in batch order.
    """
    firsts = [pair[0] for pair in batches]
    seconds = [pair[1] for pair in batches]
    return [torch.concatenate(firsts), torch.concatenate(seconds)]

In [None]:
# Re-derive the UKB iLVM frame, targets and masks for the SHAP section (same
# definitions as in the LR-thresholding section), so this section does not
# depend on what earlier figure cells left in these names.
df_fig = ukb_results['iLVM']
_sex_flag = df_fig['is_m'].values
is_m = _sex_flag == 1
is_f = _sex_flag == 0
_lvh_flag = df_fig['lvh_true']
is_lvh_m = is_m & _lvh_flag
is_lvh_f = is_f & _lvh_flag
y_true = df_fig['ilvm_true']
y_pred = df_fig['ilvm_pred']

In [None]:
# Sex-specific 5th and 95th percentiles of the predicted indexed LVM.
p5_lvm_m, p95_lvm_m = np.percentile(y_pred[is_m], [5, 95])
p5_lvm_f, p95_lvm_f = np.percentile(y_pred[is_f], [5, 95])

# Samples at or beyond the tails of their own sex's distribution.
low_mask = (is_m & (y_pred <= p5_lvm_m)) | (is_f & (y_pred <= p5_lvm_f))
high_mask = (is_m & (y_pred >= p95_lvm_m)) | (is_f & (y_pred >= p95_lvm_f))
low_high_mask = low_mask | high_mask

# Row positions in the results table.
# NOTE(review): these are later matched against the indices yielded by
# test_loader, which assumes results row i corresponds to dataset index i —
# confirm.
all_indices = np.arange(0, len(y_pred))
sample_inds = all_indices[low_high_mask]

# Background set for SHAP: 1000 distinct samples drawn from the non-tail rows
# with a generator seeded by SEED.
_background_pool = all_indices[~low_high_mask]
_background_rng = np.random.default_rng(seed=SEED)
background_inds = _background_rng.choice(_background_pool, size=1000, replace=False)
assert set(sample_inds).isdisjoint(set(background_inds))

In [None]:
# Walk the test loader once and collect the inputs of the background samples
# and of the samples to explain.
shap_background = []
shap_test = []
# Running counts of collected background / test samples.
bb, tt = 0, 0
# Dataset indices of the collected test samples, in collection order (i.e. the
# row order of shap_test after concatenation).
test_iis = []
for i, (X_batch, X_meta_batch, y_batch, inds) in enumerate(test_loader):
    # Positions within this batch that belong to each set.
    control_inds = np.argwhere(np.isin(inds, background_inds)).flatten()
    test_inds = np.argwhere(np.isin(inds, sample_inds)).flatten()
    if control_inds.shape[0] != 0:
        bb += control_inds.shape[0]
        # Axis 1 is squeezed out of both inputs (assumes it has size 1 — TODO confirm).
        shap_background.append([np.squeeze(X_batch[control_inds], axis=1), np.squeeze(X_meta_batch[control_inds], axis=1)])
    if test_inds.shape[0] != 0:
        tt += test_inds.shape[0]
        test_iis.extend(inds[test_inds.tolist()].numpy())
        shap_test.append([np.squeeze(X_batch[test_inds], axis=1), np.squeeze(X_meta_batch[test_inds], axis=1)]) 
# Concatenate the per-batch pairs into [first inputs, second inputs].
shap_background = to_shap_input(shap_background)
shap_test = to_shap_input(shap_test)

In [None]:
# Attribute the model output to both inputs, using the background samples as
# the reference set.
e = shap.GradientExplainer(model, shap_background)
# Later cells index the result as shap_values[0] and shap_values[1], matching
# the two entries of shap_test.
shap_values = e.shap_values(shap_test)

## SHAP - Clinical variables

In [None]:
# Names of the last 16 columns of the dataset's underlying frame.
# NOTE(review): presumably these are the clinical covariates in model-input
# order; this relies on the private `_df` attribute and on a hard-coded count
# of 16 — confirm against the dataset class.
meta_features = test_loader.dataset._df.columns.values[-16:]
meta_features

In [None]:
# Human-readable feature labels passed to the SHAP summary plot as
# feature_names.
# NOTE(review): the 16 entries are written by hand and must stay in the same
# order as `meta_features` — verify against that cell's output.
meta_features_tidy = [
    'Hypertension',
    'Sex (M=1)',
    'Diabetes',
    'Hypercholesterolaemia',
    'Alcohol intake',
    'Ethnicity (W.EU.=0, Other=1)',
    'Smoking (Never=1)',
    'Smoking (Previous=1)',
    'Smoking (Current=1)',
    'Avg. Systolic BP',
    'Avg. Diastolic BP',
    'Age',
    'BMI',
    'Non-HDL cholesterol',
    'Total cholesterol',
    'Ventricular rate',
]
    

In [None]:
# Sanity check: shapes of the second-input attributions and inputs, plus the
# first eight feature labels.
shap_values[1].shape, shap_test[1].shape, meta_features_tidy[:8]

In [None]:
# SHAP summary (dot) plot for the second model input, limited to 8 features
# (max_display=8). show=False keeps the figure open so it can be restyled and
# saved below.
shap.summary_plot(
    shap_values[1].squeeze(), shap_test[1],
    feature_names=meta_features_tidy,
    plot_type="dot", show=False,
    plot_size=[9, 3.5], use_log_scale=False,
    max_display=8,
)
ax = plt.gca()
ax.set_xlabel("SHAP value (signed impact on indexed LVM)", fontsize=10)
ax.tick_params(axis='both', which='major', labelsize=10)
# assumes the second axes of the figure is the colour bar added by
# summary_plot — TODO confirm for the installed shap version.
cb_ax = plt.gcf().axes[1] 
cb_ax.tick_params(labelsize=9)
cb_ax.set_ylabel("Feature value", fontsize=11)  

plt.savefig('../figures/shap_clin.small.svg', bbox_inches='tight')

## SHAP - ECG

In [None]:
from scipy.signal import savgol_filter
from scipy.ndimage import gaussian_filter1d
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize, PowerNorm, SymLogNorm

In [None]:
# Diverging colormap used to colour ECG traces by attribution. It is the
# default of the `cs` parameter of the plotting helpers below, so this cell
# must run before those functions are defined.
cs_ext = sns.color_palette("coolwarm", as_cmap=True)
cs_ext

In [None]:
def smooth_savgol(xs, window_length=5, polyorder=2):
    """Smooth ``xs`` with a Savitzky-Golay filter (scipy's ``savgol_filter``)."""
    smoothed = savgol_filter(xs, window_length, polyorder)
    return smoothed
def smooth_gauss(xs, sigma=2):
    """Smooth ``xs`` with a 1-D Gaussian filter of standard deviation ``sigma``."""
    smoothed = gaussian_filter1d(xs, sigma)
    return smoothed
def smooth_moving_average(xs, window_size=5):
    """Smooth ``xs`` with a centred moving average of ``window_size`` samples.

    Uses ``np.convolve`` in 'same' mode, so the output has the length of the
    input and values near the ends are averaged against implicit zeros.
    """
    kernel = np.ones(window_size) / window_size
    return np.convolve(xs, kernel, mode='same')

In [None]:
# Shape of the first-input attributions for a single sample (index 1).
shap_values[0][1].shape

In [None]:
def _plot_example(ind, shap_vals, ecgs, title="", norm_range=(-0.2, 0.2), cs=cs_ext, sigma=4, ylims=None):
    """Plot one ECG, one lead per row, coloured by its smoothed attributions.

    Creates a new figure with ``N_CHANNELS`` stacked axes. Each lead is drawn
    as a line whose colour follows the Gaussian-smoothed attribution values of
    that lead; a colour bar is added to the right of the figure. Nothing is
    returned and the figure is not saved.

    Parameters
    ----------
    ind : index of the sample in ``shap_vals`` and ``ecgs``.
    shap_vals : attribution values; ``shap_vals[ind][i]`` is flattened to 1-D
        for lead ``i``.
    ecgs : ECG signals; ``ecgs[ind][i]`` holds the ``N_STEPS`` samples of lead ``i``.
    title : title placed above the first lead.
    norm_range : ``(vmin, vmax)`` colour range, clipped outside. ``None`` (or
        any falsy value) lets each lead auto-scale its colours.
    cs : colormap.
    sigma : standard deviation of the Gaussian smoothing applied to attributions.
    ylims : optional ``(ymin, ymax)`` applied to every lead.
    """
    shap_values_eg = shap_vals[ind]
    ecg_eg = ecgs[ind]

    fig, axs_col = plt.subplots(N_CHANNELS, 1, figsize=(6, 12))
    # Build the norm only when a range is given. Previously
    # Normalize(*norm_range) ran unconditionally and raised TypeError for
    # norm_range=None, so the auto-scaled branch could never be reached.
    norm = Normalize(*norm_range, clip=True) if norm_range else None
    formatter = mticker.FormatStrFormatter('%.1f')  # one decimal place
    xs = list(range(N_STEPS))
    for i in range(N_CHANNELS):
        # Turn the trace into consecutive two-point segments so each segment
        # can carry its own colour.
        points = np.array([xs, ecg_eg[i]]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        agg_channel = smooth_gauss(shap_values_eg[i].reshape(-1), sigma=sigma)
        # norm=None is LineCollection's default, i.e. auto-scaling.
        lc = LineCollection(segments, cmap=cs, norm=norm)
        lc.set_array(agg_channel)
        lc.set_linewidth(3)

        line = axs_col[i].add_collection(lc)
        axs_col[i].autoscale()

        if i == 0:
            axs_col[i].set_title(title, fontsize=11, pad=10)

        #axs_col[i].set_ylabel(f"Lead {LEADS[i]}", fontsize=11, rotation=0, labelpad=30)
        axs_col[i].set_ylabel('Amplitude (mV)', fontsize=9)
        axs_col[i].yaxis.set_major_formatter(formatter)

        if ylims:
            axs_col[i].set_ylim(*ylims)
        axs_col[i].set_xlim(0, 400)

        # Tick positions are sample indices; the labels are twice the index
        # (presumably 2 ms per sample — confirm the sampling rate).
        axs_col[i].set_xticks([0, 100, 200, 300, 400], [0, 200, 400, 600, 800])

    axs_col[-1].set_xlabel('Time (ms)')

    # Colour bar in its own axes to the right of the plot area, taken from the
    # last lead's collection.
    cbar_ax = fig.add_axes([1.025, 0.725, 0.025, 0.2])
    fig.colorbar(line, cax=cbar_ax, label='Integrated gradient (approx. SHAP)')

    fig.tight_layout()

    # Lead names are placed after tight_layout so they use the final axes positions.
    for i, label in enumerate(LEADS):
        y_mid = (axs_col[i].get_position().ymin + axs_col[i].get_position().ymax)/2
        x_left = axs_col[i].get_position().xmin
        fig.text(x_left  - 0.12, y_mid, f"Lead {label}", va='center', rotation='vertical', fontsize=11)

    #plt.savefig('figures/ilvm+ilvedv.shap_ecg.val.example.low.svg', bbox_inches='tight')

In [None]:
# One row per explained sample: its index ('feid_ind') and the 'indexed.LVM'
# value stored in the dataset record.
# NOTE(review): this is the recorded value, not the model prediction, whereas
# the samples were selected on predicted percentiles — keep in mind when
# reading the example titles below.
shap_lvm_df = pd.DataFrame([{'feid_ind': i, 'indexed.LVM': test_loader.dataset.get_record(i)['indexed.LVM']} for i in sample_inds])

In [None]:
# Sort from highest to lowest recorded indexed LVM.
shap_lvm_df = shap_lvm_df.sort_values(by='indexed.LVM', ascending=False)

In [None]:
# Samples with the highest recorded indexed LVM.
shap_lvm_df.head()

In [None]:
# Samples with the lowest recorded indexed LVM.
shap_lvm_df.tail()

In [None]:
# Positions of two hand-picked samples (indices 1882 and 6500) within
# sample_inds.
# NOTE(review): the positions are then used to index shap_values / shap_test,
# whose rows follow the loader's iteration order — this matches sample_inds
# only if the loader yields indices in ascending order; confirm.
np.argwhere(sample_inds == 1882), np.argwhere(sample_inds == 6500)

In [None]:
# Example with low indexed LVM. Position 213 and the value in the title are
# hard-coded (presumably taken from the lookup cell above — confirm after any
# change to the data or selection).
_plot_example(213, shap_values[0], shap_test[0], title=r"Low predicted indexed LVM, 25.81g/m$^2$", norm_range=(-.3, .3), sigma=4)# ylims=(-1.0, 1.0))

In [None]:
# Example with high indexed LVM. Position 659 and the value in the title are
# hard-coded (presumably taken from the lookup cell above — confirm after any
# change to the data or selection).
_plot_example(659, shap_values[0], shap_test[0], title=r"High predicted indexed LVM, 100.51g/m$^2$", norm_range=(-.3, .3), sigma=4)# ylims=(-1.0, 1.0))

In [None]:
def _plot_example_multi(axs_col, ind, shap_vals, ecgs, title="", norm_range=(-0.2, 0.2), cs=cs_ext, sigma=4, ylabel=None, ylims=None):
    """Draw one ECG onto existing axes, coloured by its smoothed attributions.

    Like ``_plot_example`` but draws into the supplied column of axes instead
    of creating a figure, and adds neither colour bar nor lead names.

    Parameters
    ----------
    axs_col : sequence of at least ``N_CHANNELS`` axes, one per lead.
    ind : index of the sample in ``shap_vals`` and ``ecgs``.
    shap_vals : attribution values; ``shap_vals[ind][i]`` is flattened to 1-D
        for lead ``i``.
    ecgs : ECG signals; ``ecgs[ind][i]`` holds the ``N_STEPS`` samples of lead ``i``.
    title : title placed above the first lead.
    norm_range : ``(vmin, vmax)`` colour range, clipped outside. ``None`` (or
        any falsy value) lets each lead auto-scale its colours.
    cs : colormap.
    sigma : standard deviation of the Gaussian smoothing applied to attributions.
    ylabel : optional y-axis label applied to every lead.
    ylims : optional ``(ymin, ymax)`` applied to every lead.

    Returns
    -------
    The collection added for the last lead, usable as the mappable of a colour bar.
    """
    shap_values_eg = shap_vals[ind]
    ecg_eg = ecgs[ind]

    # Build the norm only when a range is given. Previously
    # Normalize(*norm_range) ran unconditionally and raised TypeError for
    # norm_range=None, so the auto-scaled branch could never be reached.
    norm = Normalize(*norm_range, clip=True) if norm_range else None
    formatter = mticker.FormatStrFormatter('%.1f')  # one decimal place
    xs = list(range(N_STEPS))
    for i in range(N_CHANNELS):
        # Consecutive two-point segments, so each can carry its own colour.
        points = np.array([xs, ecg_eg[i]]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        agg_channel = smooth_gauss(shap_values_eg[i].reshape(-1), sigma=sigma)
        # norm=None is LineCollection's default, i.e. auto-scaling.
        lc = LineCollection(segments, cmap=cs, norm=norm)
        lc.set_array(agg_channel)
        lc.set_linewidth(2)

        line = axs_col[i].add_collection(lc)
        axs_col[i].autoscale()

        if i == 0:
            axs_col[i].set_title(title, pad=10) #fontsize=11,

        #axs_col[i].set_ylabel(f"Lead {LEADS[i]}", fontsize=11, rotation=0, labelpad=30)
        if ylabel:
            axs_col[i].set_ylabel(ylabel, fontsize=9)
        axs_col[i].yaxis.set_major_formatter(formatter)

        if ylims:
            axs_col[i].set_ylim(*ylims)
        axs_col[i].set_xlim(0, 400)

        # Tick positions are sample indices; the labels are twice the index
        # (presumably 2 ms per sample — confirm the sampling rate).
        axs_col[i].set_xticks([0, 100, 200, 300, 400], [0, 200, 400, 600, 800])

    axs_col[-1].set_xlabel('Time (ms)')
    return line

In [None]:
# Side-by-side attribution-coloured ECGs: left column = low example
# (position 213), right column = high example (position 659).
# Saved to ../figures/shap_ecg.svg.
fig, axs = plt.subplots(N_CHANNELS, 2, figsize=(7, 10))

line1 = _plot_example_multi([ax[0] for ax in axs], 213, shap_values[0], shap_test[0], title=r"Low predicted indexed LVM, 25.81g/m$^2$",
                   norm_range=(-.3, .3), sigma=4, ylabel='Amplitude (mV)')
line2 = _plot_example_multi([ax[1] for ax in axs], 659, shap_values[0], shap_test[0], title=r"High predicted indexed LVM, 100.51g/m$^2$",
                   norm_range=(-.3, .3), sigma=4)

# Per-lead minimum and maximum over both examples, so each row shares one
# y-range.
ymins = np.vstack([
    shap_test[0][213].numpy().min(axis=1),
    shap_test[0][659].numpy().min(axis=1)
]).min(axis=0)
ymaxs = np.vstack([
    shap_test[0][213].numpy().max(axis=1),
    shap_test[0][659].numpy().max(axis=1)
]).max(axis=0)

# Apply the shared range per row with a 0.1 margin on each side.
for ax_row, (min_y, max_y) in zip(axs, np.vstack([ymins, ymaxs]).T):
    for ax in ax_row:
        ax.set_ylim(min_y - 0.1, max_y + 0.1)

# One colour bar for both columns (both use the same norm_range), placed to
# the right of the plot area.
cbar_ax = fig.add_axes([1.025, 0.725, 0.025, 0.2])
fig.colorbar(line2, cax=cbar_ax, label='Integrated gradient (approx. SHAP)')

fig.tight_layout()

# Lead names left of the first column, positioned after tight_layout so they
# use the final axes positions.
axs_col = [ax[0] for ax in axs]
for i, label in enumerate(LEADS):
    y_mid = (axs_col[i].get_position().ymin + axs_col[i].get_position().ymax)/2
    x_left = axs_col[i].get_position().xmin
    fig.text(x_left  - 0.12, y_mid, f"Lead {label}", va='center', rotation='vertical', fontsize=11)

plt.savefig('../figures/shap_ecg.svg', bbox_inches='tight')

## Averaged ECGs

In [None]:
# Split the explained samples into the low-tail and high-tail groups, using the
# positions where low_mask / high_mask are set.
low_ecg_df = shap_lvm_df[shap_lvm_df['feid_ind'].isin(np.argwhere(low_mask).flatten())]
high_ecg_df = shap_lvm_df[shap_lvm_df['feid_ind'].isin(np.argwhere(high_mask).flatten())]

In [None]:
# Rows of shap_test are in the order the loader yielded the samples, which is
# recorded in test_iis. Row positions are therefore looked up against test_iis;
# the original looked them up against sample_inds, which gives the same answer
# only when the loader yields indices in ascending order.
_shap_row_ids = np.asarray(test_iis)
assert len(_shap_row_ids) == shap_test[0].shape[0]

# Vectorised membership test (replaces a per-element scan with `in`).
low_inds = np.flatnonzero(np.isin(_shap_row_ids, low_ecg_df['feid_ind'].values))
low_ecgs = shap_test[0][low_inds].reshape(-1, 8, 400).numpy()

high_inds = np.flatnonzero(np.isin(_shap_row_ids, high_ecg_df['feid_ind'].values))
high_ecgs = shap_test[0][high_inds].reshape(-1, 8, 400).numpy()

# No sample may be in both groups.
assert not set(low_inds).intersection(set(high_inds))

In [None]:
# Two colours for the group plots: cs_hl[0] is used for the high group,
# cs_hl[1] for the low group.
cs_hl = sns.color_palette("Set1", 2)
cs_hl

In [None]:
# Mean ECG +/- 1 SD per lead: left column = low group, right column = high
# group. Saved to ../figures/mean_ecgs.svg.
# Dimensions are taken from the high group; n_samples is that group's size.
n_samples, n_leads, n_steps = high_ecgs.shape
time = np.arange(n_steps)

fig, axs = plt.subplots(n_leads, 2, figsize=(7, 10), sharex=True)

formatter = mticker.FormatStrFormatter('%.1f')

for i in range(n_leads):
    # Pointwise mean and standard deviation across the low-group samples.
    mean_ecg_l = np.mean(low_ecgs[:, i, :], axis=0)
    std_ecg_l = np.std(low_ecgs[:, i, :], axis=0)
    
    axs[i][0].plot(time, mean_ecg_l, color=cs_hl[1], label='Mean ECG')
    axs[i][0].fill_between(time, mean_ecg_l - std_ecg_l, mean_ecg_l + std_ecg_l, color=cs_hl[1], alpha=0.2, label='±1 SD')

    # Same for the high-group samples.
    mean_ecg_u = np.mean(high_ecgs[:, i, :], axis=0)
    std_ecg_u = np.std(high_ecgs[:, i, :], axis=0)
    
    axs[i][1].plot(time, mean_ecg_u, color=cs_hl[0], label='Mean ECG')
    axs[i][1].fill_between(time, mean_ecg_u - std_ecg_u, mean_ecg_u + std_ecg_u, color=cs_hl[0], alpha=0.2, label='±1 SD')


    # Shared y-range for the row: the extremes of both +/- 1 SD bands.
    min_y = min(min(mean_ecg_l - std_ecg_l), min(mean_ecg_u - std_ecg_u))
    max_y = max(max(mean_ecg_l + std_ecg_l), max(mean_ecg_u + std_ecg_u))

    for ax in axs[i]:
        ax.set_xlim([0, 400])
        ax.set_ylim([min_y - 0.1, max_y + 0.1])
        # Zero-amplitude reference line.
        ax.axhline(0, c='gray', ls='--', alpha=0.9)
    
    axs[i][0].set_ylabel('Amplitude (mV)', fontsize=9)
    axs[i][0].yaxis.set_major_formatter(formatter)
    axs[i][1].yaxis.set_major_formatter(formatter)

    # Tick positions are sample indices; the labels are twice the index
    # (presumably 2 ms per sample — confirm the sampling rate). Ticks are set
    # on the left column only; the x axis is shared (sharex=True).
    axs[i][0].set_xticks([0, 100, 200, 300, 400], [0, 200, 400, 600, 800])

axs[-1][0].set_xlabel("Time (ms)")
axs[-1][1].set_xlabel("Time (ms)")
plt.tight_layout()

# Lead names left of the first column, positioned after tight_layout so they
# use the final axes positions.
for i, label in enumerate(LEADS):
    y_mid = (axs[i][0].get_position().ymin + axs[i][0].get_position().ymax)/2
    x_left = axs[i][0].get_position().xmin
    fig.text(x_left  - 0.1, y_mid, f"Lead {label}", va='center', rotation='vertical', fontsize=11)

axs[0][0].set_title('Low predicted indexed LVM')
axs[0][1].set_title('High predicted indexed LVM')

plt.savefig('../figures/mean_ecgs.svg', bbox_inches='tight')