In [1]:
import pickle
import torch
import numpy as np


def calculate_vita_maerror(test_results_path, feature_names):
    """Load pickled ViTa test results and return per-sample absolute errors.

    Despite the name, no averaging is done here: each entry is the vector of
    ``|ground truth - prediction|`` for one target column.

    Args:
        test_results_path: Path to a pickle file holding a dict with the keys
            ``"gts"``, ``"preds"`` and ``"subjects_ids"``. ``gts`` and
            ``preds`` are torch tensors indexed as ``[:, i]`` (one column per
            target).
        feature_names: Sequence of display names; ``feature_names[i]`` labels
            column ``i`` of ``gts``. Each name is turned into a short key by
            lower-casing, dropping everything from the first ``-`` and the
            first ``(`` onwards, and removing spaces
            (e.g. ``"LVEDV (mL)"`` -> ``"lvedv"``).

    Returns:
        dict mapping each short key to a 1-D CPU tensor of absolute errors
        (NaN wherever the ground truth is NaN), plus ``"subjects_ids"`` copied
        from the results file.

    Raises:
        IndexError: if ``feature_names`` has fewer entries than ``gts`` has
            columns.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at result files produced by your own pipeline.
    with open(test_results_path, "rb") as file:
        test_results = pickle.load(file)

    # Detach / move to CPU so the errors are plain data tensors that downstream
    # NumPy / SciPy code can consume regardless of how they were saved.
    gts = test_results["gts"].detach().cpu()
    preds = test_results["preds"].detach().cpu()

    # Blank out predictions whose ground truth is missing so they show up as
    # NaN errors instead of being compared against nothing.
    preds = torch.where(torch.isnan(gts), torch.full_like(preds, float("nan")), preds)

    maerror_dict = {}
    for i in range(gts.shape[1]):
        key = feature_names[i].lower().split("-")[0].split("(")[0].replace(" ", "")
        # torch.abs (not np.abs): stays in torch instead of relying on NumPy's
        # ufunc round trip over a Tensor.
        maerror_dict[key] = torch.abs(gts[:, i] - preds[:, i])

    maerror_dict["subjects_ids"] = test_results["subjects_ids"]
    return maerror_dict

def calculate_baseline_maerror(model, feature_category, target_names, results_dir="results/tabular_features"):
    """Load one pickled result file per target and return per-sample absolute errors.

    Each target is read from
    ``{results_dir}/{model}_{feature_category}_{short_key}.pkl``, where
    ``short_key`` is derived from the target name (see below). Despite the
    function name, no averaging is done: every entry is the tensor
    ``|ground truth - prediction|``.

    Args:
        model: Model tag used as the file-name prefix (e.g. ``"resnet50"``).
        feature_category: Category tag used in the file name. When it equals
            ``"features"`` the parentheses are stripped but their content is
            kept (``"Body mass index (BMI)-2.0"`` -> ``"bodymassindexbmi"``);
            for any other value everything from the first ``(`` is dropped
            (``"LVEDV (mL)"`` -> ``"lvedv"``).
        target_names: Non-empty sequence of target display names.
        results_dir: Directory containing the result files. Defaults to the
            location the notebook has always used.

    Returns:
        dict mapping each short key to a CPU tensor of absolute errors (NaN
        wherever the ground truth is NaN), plus ``"subjects_ids"`` taken from
        the result files.

    Raises:
        ValueError: if ``target_names`` is empty (there would be no file to
            take ``subjects_ids`` from).
        AssertionError: if the result files do not all list the same subjects
            in the same order.
    """
    if len(target_names) == 0:
        raise ValueError("target_names must contain at least one target")

    maerror_dict = {}
    subj_ids = None
    for target_name in target_names:
        short_key = target_name.lower().split("-")[0]
        if feature_category == "features":
            short_key = short_key.replace("(", "").replace(")", "")
        else:
            short_key = short_key.split("(")[0]
        short_key = short_key.replace(" ", "")
        test_results_path = f"{results_dir}/{model}_{feature_category}_{short_key}.pkl"

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at result files produced by your own pipeline.
        with open(test_results_path, "rb") as file:
            test_results = pickle.load(file)

        # Detach / move to CPU so the errors are plain data tensors that
        # downstream NumPy / SciPy code can consume.
        gts = test_results["gts"].detach().cpu()
        preds = test_results["preds"].detach().cpu()

        # Every per-target file must describe the same subjects in the same
        # order, otherwise the columns of the returned dict are not comparable.
        if subj_ids is not None:
            assert (subj_ids == test_results["subjects_ids"]).all(), (
                f"subjects_ids in {test_results_path} differ from the previous target's"
            )
        subj_ids = test_results["subjects_ids"]

        # Blank out predictions whose ground truth is missing so they show up
        # as NaN errors.
        preds = torch.where(torch.isnan(gts), torch.full_like(preds, float("nan")), preds)
        # torch.abs (not np.abs): stays in torch instead of relying on NumPy's
        # ufunc round trip over a Tensor.
        maerror_dict[short_key] = torch.abs(gts - preds)

    maerror_dict["subjects_ids"] = subj_ids
    return maerror_dict


In [2]:
# Target names, in the column order of the corresponding result files.
# The *_vita lists label the ViTa results; the plain lists name the targets
# for which per-target baseline result files are loaded.

# --- ViTa targets ---
sax_all_vita = [
    "LVEDV (mL)",
    "LVESV (mL)",
    "LVSV (mL)",
    "LVEF (%)",
    "LVCO (L/min)",
    "LVM (g)",
    "RVEDV (mL)",
    "RVESV (mL)",
    "RVSV (mL)",
    "RVEF (%)",
]
lax_all_vita = [
    "LAV max (mL)",
    "LAV min (mL)",
    "LASV (mL)",
    "LAEF (%)",
    "RAV max (mL)",
    "RAV min (mL)",
    "RASV (mL)",
    "RAEF (%)",
]
features_all_vita = [
    "Systolic blood pressure-2.mean",
    "Diastolic blood pressure-2.mean",
    "Pulse rate-2.mean",
    "Body fat percentage-2.0",
    "Whole body water mass-2.0",
    "Body mass index (BMI)-2.0",
    "Waist circumference-2.0",
    "Height-2.0",
    "Weight-2.0",
    "Cardiac index-2.0",
    "Average heart rate-2.0",
    "Systolic brachial blood pressure during PWA-2.0",
    "End systolic pressure during PWA-2.0",
    "Stroke volume during PWA-2.0",
    "Mean arterial pressure during PWA-2.0",
    "Sleep duration-2.0",
]
features_age = ["Age when attended assessment centre-2.0"]

# --- Baseline targets ---
# SAX: the ViTa list without LVESV.
sax_all = [name for name in sax_all_vita if name != "LVESV (mL)"]
# LAX: same targets as ViTa (kept as an independent list object).
lax_all = list(lax_all_vita)
# Features: the ViTa list without five targets, with age appended at the end.
features_all = [
    name
    for name in features_all_vita
    if name
    not in (
        "Diastolic blood pressure-2.mean",
        "Whole body water mass-2.0",
        "Waist circumference-2.0",
        "Cardiac index-2.0",
        "Sleep duration-2.0",
    )
] + features_age

In [3]:
# --- ViTa (our model): one results file per target group ---
vita_sax_maerror = calculate_vita_maerror(
    test_results_path="results/tabular_features/vita_phenotype_sax_all.pkl",
    feature_names=sax_all_vita,
)
vita_lax_maerror = calculate_vita_maerror(
    test_results_path="results/tabular_features/vita_phenotype_lax_all.pkl",
    feature_names=lax_all_vita,
)
vita_features_maerror = calculate_vita_maerror(
    test_results_path="results/tabular_features/vita_features_all.pkl",
    feature_names=features_all_vita,
)
vita_age_maerror = calculate_vita_maerror(
    test_results_path="results/tabular_features/vita_features_agewhenattendedassessmentcentre.pkl",
    feature_names=features_age,
)

# Age lives in its own results file. Once both files are confirmed to list the
# same subjects in the same order, fold the age errors into the features dict.
assert (vita_age_maerror["subjects_ids"] == vita_features_maerror["subjects_ids"]).all()
age_key = next(iter(vita_age_maerror))  # first key inserted, i.e. the age target
vita_features_maerror[age_key] = vita_age_maerror[age_key]

# --- Baselines: one results file per (model, category, target) ---
resnet50_sax_maerror = calculate_baseline_maerror(
    model="resnet50", feature_category="phenotype_sax", target_names=sax_all
)
resnet50_lax_maerror = calculate_baseline_maerror(
    model="resnet50", feature_category="phenotype_lax", target_names=lax_all
)
resnet50_features_maerror = calculate_baseline_maerror(
    model="resnet50", feature_category="features", target_names=features_all
)

vit_sax_maerror = calculate_baseline_maerror(
    model="vit", feature_category="phenotype_sax", target_names=sax_all
)
vit_lax_maerror = calculate_baseline_maerror(
    model="vit", feature_category="phenotype_lax", target_names=lax_all
)

mae_sax_maerror = calculate_baseline_maerror(
    model="mae", feature_category="phenotype_sax", target_names=sax_all
)
mae_lax_maerror = calculate_baseline_maerror(
    model="mae", feature_category="phenotype_lax", target_names=lax_all
)

In [34]:
# Paired t-test (two-sided) on per-subject absolute errors: ResNet-50 baseline vs. ViTa.
# A line is printed for a target when the two error distributions differ
# significantly; the message does not say which model has the lower error.
from scipy.stats import ttest_rel

# One entry per printed section:
#   (title, ViTa errors, baseline errors,
#    ViTa keys to skip (not in the baseline target lists), ViTa key -> baseline key)
ttest_sections = [
    ("SAX", vita_sax_maerror, resnet50_sax_maerror, {"lvesv"}, {}),
    ("LAX", vita_lax_maerror, resnet50_lax_maerror, set(), {}),
    (
        "Features",
        vita_features_maerror,
        resnet50_features_maerror,
        {"diastolicbloodpressure", "wholebodywatermass", "waistcircumference", "cardiacindex", "sleepduration"},
        # The two loaders shorten "Body mass index (BMI)" differently:
        # ViTa drops "(BMI)", the "features" baseline keeps "bmi".
        {"bodymassindex": "bodymassindexbmi"},
    ),
]

for section, mae, baseline_mae, skipped_keys, baseline_key_for in ttest_sections:
    print(section)

    # A paired test only makes sense if row i is the same subject on both sides.
    assert np.array_equal(
        np.asarray(mae["subjects_ids"]).reshape(-1),
        np.asarray(baseline_mae["subjects_ids"]).reshape(-1),
    ), f"{section}: ViTa and baseline results are not aligned by subject"

    for k, v in mae.items():
        if k == "subjects_ids" or k in skipped_keys:
            continue

        # Flatten both sides first so a (N, 1) baseline column pairs with a (N,) ViTa column.
        vita_err = v.reshape(-1)
        baseline_err = baseline_mae[baseline_key_for.get(k, k)].reshape(-1)

        # Drop a pair if EITHER side is NaN; a NaN left on one side would turn
        # the p-value into NaN and silently suppress the result.
        mask = ~(vita_err.isnan() | baseline_err.isnan())
        x = baseline_err[mask]
        y = vita_err[mask]

        t_stat, p_value = ttest_rel(x.detach().cpu().numpy(), y.detach().cpu().numpy())

        if p_value < 0.01:
            print(f"{k} highly significant!")
        elif p_value < 0.05:
            print(f"{k} significant!")

SAX
lvedv highly significant!
lvef highly significant!
lvco highly significant!
rvedv highly significant!
rvesv highly significant!
rvsv highly significant!
LAX
lavmax highly significant!
lavmin highly significant!
lasv highly significant!
laef highly significant!
ravmax highly significant!
ravmin highly significant!
rasv highly significant!
raef highly significant!
Features
systolicbloodpressure highly significant!
pulserate highly significant!
bodyfatpercentage highly significant!
height highly significant!
weight highly significant!
systolicbrachialbloodpressureduringpwa highly significant!
endsystolicpressureduringpwa highly significant!
strokevolumeduringpwa highly significant!
meanarterialpressureduringpwa highly significant!
agewhenattendedassessmentcentre highly significant!
