In [1]:
import pandas as pd 
import numpy as np 
from dash_website import MAIN_CATEGORIES_TO_CATEGORIES, RENAME_DIMENSIONS, ALGORITHMS_RENDERING

# Parameters of the heatmap: which algorithm's scores to show ("best_algorithm" selects,
# per (category, dimension), the best-scoring one) and which main category of X variables to keep.
main_category = "All"
algorithm = "elastic_net"

# Scores table read from the multivariate XWAS results folder.
data_scores = pd.read_feather("../../data/xwas/multivariate_results/scores.feather")

In [16]:
from dash_website.utils.graphs.colorscale import get_colorscale
import plotly.graph_objs as go

# Work on a DataFrame whatever container `data_scores` currently is.
all_scores = pd.DataFrame(data_scores)

if algorithm == "best_algorithm":
    # Keep, for every (category, dimension) pair, the row with the highest r2.
    # Rows without an r2 are discarded first so idxmax never has to choose among only-NaN values,
    # and the index is reset so the labels returned by idxmax identify exactly one row each.
    # Selecting rows by label (instead of groupby.apply) keeps the "category" and "dimension"
    # columns regardless of how the installed pandas treats grouping columns inside apply.
    scored = all_scores.dropna(subset=["r2"]).reset_index(drop=True)
    best_rows = scored.groupby(by=["category", "dimension"])["r2"].idxmax()
    every_score = scored.loc[best_rows].reset_index(drop=True)
else:
    # A boolean mask always yields a DataFrame; set_index(...).loc[algorithm] collapses to a
    # Series (and breaks the column lookups below) when exactly one row matches.
    every_score = all_scores.loc[all_scores["algorithm"] == algorithm].reset_index(drop=True)
    if every_score.empty:
        # Same failure mode as the label lookup this replaces: unknown algorithm -> KeyError.
        raise KeyError(algorithm)

# Restrict to the categories that belong to the selected main category.
scores = every_score.loc[every_score["category"].isin(MAIN_CATEGORIES_TO_CATEGORIES[main_category])]

# r² as a (category x dimension) table, with dimensions shown under their display names.
r2_pivot = pd.pivot(scores, index="category", columns="dimension", values="r2")
r2_2d = r2_pivot.rename(columns=RENAME_DIMENSIONS)

# Append the mean over dimensions as a last column, then the mean over categories as a last row
# (the row is computed after the column exists, so it also covers the "average" column).
mean_per_category = r2_2d.T.mean()
r2_2d["average"] = mean_per_category
mean_per_dimension = r2_2d.mean()
r2_2d.loc["average"] = mean_per_dimension

# Rotate both axes by one position so the freshly appended "average" labels come first.
r2_rows_rotated = np.roll(r2_2d.index, 1)
r2_columns_rotated = np.roll(r2_2d.columns, 1)
r2_2d = r2_2d.reindex(index=r2_rows_rotated, columns=r2_columns_rotated)

# Standard deviation of each score as a (category x dimension) table.
std_2d = pd.pivot(scores, index="category", columns="dimension", values="std")

# The "average" row/column of this table holds the spread of the r² values being averaged.
# It is computed from a fresh r² pivot rather than from `r2_2d` because, at this point, `r2_2d`
#   * already contains its own "average" row and column, which would be counted as data, and
#   * carries the RENAME_DIMENSIONS column labels, which would not align with the raw dimension
#     labels of `std_2d` (label-aligned assignment leaves NaN wherever labels differ).
r2_raw = pd.pivot(scores, index="category", columns="dimension", values="r2")

# Spread over dimensions, one value per category.
std_2d["average"] = r2_raw.std(axis=1)

# Spread over categories, one value per dimension; the corner cell is the spread of the
# per-category mean r².
std_over_categories = r2_raw.std()
std_over_categories["average"] = r2_raw.mean(axis=1).std()
std_2d.loc["average"] = std_over_categories

# Rotate both axes by one position so the "average" labels come first, matching the r² table.
std_2d = std_2d.reindex(index=np.roll(std_2d.index, 1), columns=np.roll(std_2d.columns, 1))

# Sample size of each score as a (category x dimension) table.
sample_size_2d = pd.pivot(scores, index="category", columns="dimension", values="sample_size")

# NOTE: the cells labelled "average" hold totals here (sums), not means.
total_per_category = sample_size_2d.T.sum()
sample_size_2d["average"] = total_per_category
total_per_dimension = sample_size_2d.sum()
sample_size_2d.loc["average"] = total_per_dimension

# Rotate both axes by one position so the "average" labels come first.
sample_rows_rotated = np.roll(sample_size_2d.index, 1)
sample_columns_rotated = np.roll(sample_size_2d.columns, 1)
sample_size_2d = sample_size_2d.reindex(index=sample_rows_rotated, columns=sample_columns_rotated)

# Algorithm behind each score as a (category x dimension) table, mapped through ALGORITHMS_RENDERING.
algorithm_pivot = pd.pivot(scores, index="category", columns="dimension", values="algorithm")
algorithm_2d = algorithm_pivot.replace(ALGORITHMS_RENDERING)

# The aggregated cells do not come from a single algorithm: fill them with a placeholder.
no_algorithm_label = "No algorithm"
algorithm_2d["average"] = no_algorithm_label
algorithm_2d.loc["average"] = no_algorithm_label

# Rotate both axes by one position so the "average" labels come first.
algorithm_rows_rotated = np.roll(algorithm_2d.index, 1)
algorithm_columns_rotated = np.roll(algorithm_2d.columns, 1)
algorithm_2d = algorithm_2d.reindex(index=algorithm_rows_rotated, columns=algorithm_columns_rotated)

# Stack the three companion tables along a third axis so each heatmap cell carries
# [standard deviation, sample size, algorithm] as its customdata.
customdata = np.dstack([std_2d, sample_size_2d, algorithm_2d])

# Hover text: x / y / z come from the heatmap itself, the rest from the customdata entries above.
hovertemplate = (
    "Aging dimension: %{x} <br>"
    "X subcategory: %{y} <br>"
    "r²: %{z:.3f} <br>"
    "Standard deviation: %{customdata[0]:.3f} <br>"
    "Sample size: %{customdata[1]} <br>"
    "Algorithm: %{customdata[2]} <br>"
    "<extra></extra>"
)

# r² heatmap: dimensions on the x axis, categories on the y axis.
heatmap = go.Heatmap(
    z=r2_2d,
    x=r2_2d.columns,
    y=r2_2d.index,
    customdata=customdata,
    hovertemplate=hovertemplate,
    colorscale=get_colorscale(r2_2d),
)

fig = go.Figure(heatmap)

# Fixed width; the height grows with the rows-to-columns ratio of the table but never drops
# below the width.
n_rows, n_columns = r2_2d.shape
figure_height = int(1000 * max(1, n_rows / n_columns))

fig.update_layout(
    width=1000,
    height=figure_height,
    xaxis=dict(title="Aging dimension", tickangle=90, showgrid=False),
    yaxis=dict(title="X subcategory", showgrid=False),
)

In [7]:
# Scratch cell: repeats the "average" aggregation of the heatmap cell on `r2_2d` in place
# (mean over dimensions as a column, then mean over categories as a row).
per_category_mean = r2_2d.T.mean()
r2_2d["average"] = per_category_mean
per_dimension_mean = r2_2d.mean()
r2_2d.loc["average"] = per_dimension_mean

In [8]:
# Show the r² table (categories as rows, dimensions as columns) as the cell output.
r2_2d

dimension,Abdomen,AbdomenLiver,AbdomenPancreas,Arterial,ArterialCarotids,ArterialPulseWaveAnalysis,Biochemistry,BiochemistryBlood,BiochemistryUrine,BloodCells,...,MusculoskeletalHips,MusculoskeletalKnees,MusculoskeletalScalars,MusculoskeletalSpine,PhysicalActivity,*,*instances01,*instances1.5x,*instances23,average
category,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Alcohol,0.015810,0.015784,0.008746,0.023028,0.019950,0.055598,0.040478,0.031560,0.218659,0.129791,...,0.016795,0.009108,0.062248,0.009951,0.051058,0.047176,0.026629,0.050885,0.031379,0.037235
Anthropometry,0.021865,0.019217,0.016015,0.024227,0.009548,0.101672,0.048326,0.047549,0.225601,0.134395,...,0.033926,0.051292,0.412845,0.018600,0.058358,0.053265,0.049909,0.058050,0.034012,0.062268
ArterialStiffness,0.011503,0.010805,0.006528,0.058374,0.013561,0.214711,0.049120,0.028162,0.277007,0.161318,...,0.014313,0.009109,0.070681,0.007416,0.062048,0.054198,0.035150,0.061920,0.028233,0.046186
BloodCount,0.021468,0.018266,0.013907,0.015727,0.009776,0.066681,0.056816,0.045099,0.216962,0.435308,...,0.016381,0.013732,0.066580,0.012363,0.051689,0.048623,0.032853,0.051353,0.017232,0.049075
BloodPressure,0.012589,0.013663,0.006291,0.128262,0.054220,0.122564,0.037794,0.023890,0.221641,0.127281,...,0.015461,0.010018,0.058588,0.008992,0.051646,0.048002,0.050453,0.051518,0.045617,0.045403
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
medical_diagnoses_W,0.009780,0.010290,0.004793,0.005108,0.004213,0.049563,0.038237,0.024489,0.217265,0.127603,...,0.013323,0.008317,0.057223,0.006900,0.044472,0.041213,0.014074,0.044525,0.017786,0.031223
medical_diagnoses_X,0.009998,0.010372,0.005568,0.005763,0.004933,0.049516,0.038024,0.024184,0.217119,0.127679,...,0.013292,0.008923,0.055114,0.006761,0.044045,0.040740,0.016916,0.043868,0.018274,0.031328
medical_diagnoses_Y,0.009723,0.010388,0.004816,0.005497,0.004021,0.049938,0.040074,0.025784,0.218107,0.128809,...,0.012980,0.008282,0.056019,0.006673,0.046358,0.042947,0.017887,0.046218,0.018076,0.031633
medical_diagnoses_Z,0.012800,0.012065,0.007540,0.008458,0.007058,0.054444,0.052750,0.037691,0.220935,0.134447,...,0.015227,0.009940,0.061380,0.009366,0.054771,0.050407,0.019436,0.054691,0.015814,0.036079
