In [5]:
import plotly.express as px
import pandas as pd

# Per-epoch training histories exported for each model (one CSV per model).
history_res = pd.read_csv("./history_resnet.csv")
history_vit = pd.read_csv("./history_vit.csv")

# Fail fast with a clear message if an export lacks a column the plots use;
# otherwise concat would silently fill NaN and the plotting cell would break later.
REQUIRED_COLUMNS = {"epoch", "train_loss", "test_loss", "train_acc", "test_acc"}
for label, frame in {"ResNet": history_res, "ViT": history_vit}.items():
    missing = REQUIRED_COLUMNS - set(frame.columns)
    assert not missing, f"{label} history is missing columns: {sorted(missing)}"

# Tag each run so the models can be told apart after concatenation.
history_res["model"] = "ResNet"
history_vit["model"] = "ViT"

# ignore_index=True: each CSV has its own 0..n-1 index, so a plain concat would
# leave duplicate index labels in the combined frame.
history_all = pd.concat([history_res, history_vit], ignore_index=True)

In [10]:
# Both figures share colours, layout and train/test line styles, so the plotting
# logic lives in one helper instead of two copy-pasted blocks.
MODEL_COLORS = {"ResNet": "blue", "ViT": "red"}


def plot_train_test(history: pd.DataFrame, metric: str, value_name: str, title: str):
    """Plot per-epoch train vs. test curves of one metric for every model.

    Parameters
    ----------
    history : pd.DataFrame
        Wide history frame with columns ``epoch``, ``model``,
        ``train_<metric>`` and ``test_<metric>``.
    metric : str
        Column suffix to plot, e.g. ``"loss"`` or ``"acc"``.
    value_name : str
        Name of the value column in the long frame; its capitalised form is
        used as the y-axis title.
    title : str
        Figure title.

    Returns
    -------
    (pd.DataFrame, plotly Figure)
        The long-format frame (``epoch``, ``model``, ``dataset``,
        ``<value_name>``) and the styled figure (not yet shown).
    """
    train_col, test_col = f"train_{metric}", f"test_{metric}"
    long_df = history.melt(
        id_vars=["epoch", "model"],
        value_vars=[train_col, test_col],
        var_name="dataset",
        value_name=value_name,
    )
    fig = px.line(
        long_df,
        x="epoch",
        y=value_name,
        color="model",
        line_dash="dataset",
        # Pin the styles explicitly: with plotly's default dash sequence the
        # second category would be drawn *dotted*, not dashed.
        line_dash_map={train_col: "solid", test_col: "dash"},  # solid = train, dashed = test
        title=title,
        markers=True,
        color_discrete_map=MODEL_COLORS,
    )
    fig.update_layout(
        template="plotly_white",
        legend_title_text="Model / Dataset",
        yaxis_title=value_name.capitalize(),
        xaxis_title="Epoch",
    )
    return long_df, fig


# --- LOSS COMPARISON (train + test) ---
loss_df, fig_loss = plot_train_test(
    history_all, "loss", "loss", "Loss Comparison (ResNet vs ViT)"
)
fig_loss.show()

# --- ACCURACY COMPARISON (train + test) ---
acc_df, fig_acc = plot_train_test(
    history_all, "acc", "accuracy", "Accuracy Comparison (ResNet vs ViT)"
)
# Fractions fit [0, 1]; if the CSVs store percentages instead, widen to [0, 100]
# rather than clipping every point off the top of the axis.
acc_upper = 1 if acc_df["accuracy"].max() <= 1 else 100
fig_acc.update_yaxes(range=[0, acc_upper])
fig_acc.show()