In [None]:
import pandas as pd
import plotly.express as px
import plotly.io as pio
from aim import Repo

# Every plotly figure created after this line uses the minimal "simple_white" theme.
pio.templates.default = "simple_white"

In [None]:
# Aim experiment store, relative to the notebook's working directory.
AIM_REPO_DIR = '../.aim/'
repo = Repo.from_path(AIM_REPO_DIR)

In [None]:
# Values of run.hparams.training.checkpoint for the runs compared in this notebook.
# Keeping them as a list (instead of inside a backslash-continued string literal,
# which also embeds the continuation line's indentation into the query) makes
# adding/removing a run a one-line, quote-safe edit.
CHECKPOINTS = [
    'data/06_models/model_dmd_v2.pt',
    'data/06_models/model_cine_v8_simtag_v1_dmd_v1.pt',
    'data/06_models/model_cine_v8_tag_v1_dmd_v2.pt',
    'data/06_models/model_cine_v8_dmd_v0.pt',
    'data/06_models/model_dmd_v1.pt',
    'data/06_models/model_cine_v4_dmd_v0.pt',
    'data/06_models/model_cine_v6_simtag_v1_dmd_v1.pt',
    'data/06_models/model_cine_v6_tag_v1_dmd_v2.pt',
]

# repr() of a list of str renders a Python list literal, e.g. ['a.pt', 'b.pt'],
# which is what the `in [...]` query expression expects.
metrics = repo.query_metrics(
    query=f"run.hparams.training.checkpoint in {CHECKPOINTS!r}"
)

In [None]:
# Collect the queried metrics into a pandas DataFrame. include_run=True is
# presumably what adds the run.* columns (e.g. run.hparams.training.checkpoint)
# that later cells read — NOTE(review): confirm against the aim docs.
df = metrics.dataframe(include_run=True)

In [None]:
def parse_strategy(checkpoint_path: str):
    """Classify a run by the training strategy encoded in its checkpoint path.

    Markers are tested in priority order and the first substring match wins;
    'simtag' must be tested before 'tag' because every 'simtag' path also
    contains 'tag'.

    Args:
        checkpoint_path: value of ``run.hparams.training.checkpoint``.

    Returns:
        'Physics-driven', 'CycleGAN' or 'Cine' for the matching marker, or
        'Scratch' when no marker occurs in the path.
    """
    ordered_markers = (
        ('simtag', 'Physics-driven'),
        ('tag', 'CycleGAN'),
        ('cine', 'Cine'),
    )
    return next(
        (label for marker, label in ordered_markers if marker in checkpoint_path),
        'Scratch',
    )

In [None]:
# Integer epochs for the x axis. With errors='ignore' pandas returns the frame
# unchanged if the cast fails (e.g. when 'epoch' contains NaN).
epoch_dtypes = {'epoch': int}
df = df.astype(epoch_dtypes, errors='ignore')

In [None]:
# Label every row with the strategy derived from its run's checkpoint path.
checkpoint_col = df['run.hparams.training.checkpoint']
df['strategy'] = checkpoint_col.map(parse_strategy)

In [None]:
# Short column names; these become the facet / legend titles in the figure.
# Rebinding instead of inplace=True keeps the cell chain-friendly and avoids
# mutating a frame that an earlier cell may have displayed.
df = df.rename(columns={
    'metric.name': 'metric',
    'run.hparams.training.model_type': 'architecture',
    'metric.context.subset': 'split',
})

In [None]:
# Map model class names to the labels shown in the figure.
# Assign the result back rather than calling replace(..., inplace=True) on
# df['architecture']: that is chained assignment, and under pandas Copy-on-Write
# (warned in 2.2, default in 3.0) it only modifies a temporary Series, leaving
# df unchanged.
df['architecture'] = df['architecture'].replace({
    'DynUNet': 'nnUnet',
    'SegResNetVAE': 'ResNetVAE',
})

In [None]:
# Discrete palette; plotly express assigns these to the 'strategy' values in
# order of first appearance in the plotted frame.
colors = ['#023047', '#219EBC', '#FB8500', '#C44536']

# plotly express orders facets, colours and dash styles by first appearance in
# the data, so this sort fixes the layout (e.g. 'dice' sorts before 'loss').
sort_kwargs = dict(
    by=['split', 'strategy', 'architecture', 'metric', 'epoch'],
    ascending=[True, False, False, True, False],
)

fig = px.line(
    df[df['metric'].isin(['dice', 'loss'])].sort_values(**sort_kwargs),
    x='epoch', y='value',
    facet_row='metric', facet_col='architecture',
    color='strategy', line_dash='split',
    color_discrete_sequence=colors, facet_col_spacing=0.03,
)

# Unlink the facet y axes so each metric row can get its own range and scale.
fig.update_yaxes(matches=None)
top_h_legend = dict(orientation='h', yanchor="bottom", y=1.1)
fig.update_layout(legend=top_h_legend)

fig.update_xaxes(range=[0, 150], dtick=25)

# px numbers facet rows from the bottom: row=2 is the top row (DSC here) and
# row=1 the bottom row (loss). Range/scale are applied to the whole row so every
# architecture column gets them, however many there are; the axis title (and, as
# before, the DSC tick spacing) go on the first column only.
fig.update_yaxes(range=[0, 1], row=2)
fig.update_yaxes(title_text='DSC (↑)', dtick=.1, row=2, col=1)

# On a log axis plotly expects the range in log10 units: 10**-3.5 .. 10**-1.6.
fig.update_yaxes(range=[-3.5, -1.6], type="log", row=1)
fig.update_yaxes(title_text='Loss (↓)', row=1, col=1)

# width / height = 1.62, roughly the golden ratio.
fig.update_layout(height=800 / 1.62, width=800)
fig.show()

In [None]:
# Static export (plotly infers the format from the .pdf extension; needs kaleido).
FIGURE_PATH = "../../figures/training-convergence.pdf"
fig.write_image(FIGURE_PATH)