In [40]:
import pandas as pd
import os

In [1]:
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# -------------------- DEFINE CUSTOM TEMPLATE -------------------- #
pio.templates['draft'] = go.layout.Template(layout=dict(
    margin=dict(l=50, r=50, b=50, t=50),
    legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
)
))
pio.templates.default = "plotly+draft"

### Fragment Results

In [44]:
fragment_result = pd.read_csv('./Pos_training/fragment_results.csv')
fragment_result.head()

Unnamed: 0,ssl_model,head_model,calibration,scenario,dataset,avg_loss,acc,acc_std,f1,f1_std,recall,recall_std,precision,precision_std
0,TSTCC,MLP,linearprobing,one_domain,Fragment,0.5509,0.5317,0.0063,0.4072,0.0076,0.5509,0.008,1.1092,0.0158
1,TSTCC,MLP,linearprobing,four_domains,Fragment,0.5043,0.4837,0.0066,0.3457,0.0084,0.5043,0.0051,1.0121,0.03
2,TSTCC,FCN,finetuning,four_domains,Fragment,0.281,0.168,0.1165,0.1032,0.0564,0.281,0.1058,1.8616,0.1345
3,TSTCC,FCN,linearprobing,one_domain,Fragment,0.1951,0.1054,0.0496,0.0825,0.0337,0.1951,0.057,1.6465,0.0328
4,TSTCC,FCN,linearprobing,four_domains,Fragment,0.1988,0.0939,0.0192,0.1003,0.0306,0.1988,0.0192,1.6779,0.0777


In [42]:
fig = px.bar(fragment_result, x='ssl_model', y='f1', color='scenario', 
             facet_row='calibration', facet_col='head_model', barmode='group',
             labels={'ssl_model': 'SSL Model', 'acc': 'Accuracy', 'scenario': 'Scenario'},)

fig.update_yaxes(range=[0, 1])
fig.update_layout(height=640, width=1000, font_size=16, 
                  legend_title_text="",
                  margin=dict(l=40, r=40, b=30, t=80),
                  showlegend=True,
                  legend=dict(orientation="h", yanchor="bottom", y=1.07, xanchor="center", x=0.5)
)

### IEEEPPG Results

In [43]:
ieeeppg_result = pd.read_csv('./Pos_training/ieeeppg_results.csv')
ieeeppg_result.head()

Unnamed: 0,ssl_model,head_model,calibration,scenario,dataset,avg_loss,mse,mse_std,r2_score,r2_score_std,rmse,rmse_std,mae,mae_std
0,TSTCC,MLP,linearprobing,one_domain,IEEEPPG,517.3157,0.0101,0.0043,22.7445,0.0498,18.398,0.1768,529.408,1.4312
1,TSTCC,MLP,linearprobing,four_domains,IEEEPPG,518.3404,0.0082,0.003,22.7671,0.0346,18.4338,0.1247,529.0043,1.1918
2,TSTCC,MLP,finetuning,one_domain,IEEEPPG,535.3663,-0.0244,0.0264,23.1364,0.2993,19.1758,0.5177,543.2826,11.4943
3,TSTCC,MLP,finetuning,four_domains,IEEEPPG,604.888,-0.1574,0.0264,24.5932,0.2804,21.0037,0.3759,591.9205,15.9592
4,TSTCC,FCN,linearprobing,one_domain,IEEEPPG,1190.1912,-1.2774,0.231,34.4638,1.7447,30.7696,1.6269,1177.3956,119.0651


In [37]:
fig = px.bar(ieeeppg_result, x='ssl_model', y='rmse', color='scenario', 
             facet_row='calibration', facet_col='head_model', barmode='group')

fig.update_layout(height=640, width=1000, font_size=16, 
                  legend_title_text="",
                  margin=dict(l=40, r=40, b=30, t=80),
                  showlegend=True,
                  legend=dict(orientation="h", yanchor="bottom", y=1.07, xanchor="center", x=0.5)
)