In [None]:
import matplotlib.pyplot as plt
import os

from src.utils import init_notebook

# Project-specific notebook setup (helper imported from src.utils; its
# effects are not visible in this notebook).
init_notebook()

# --- Figure configuration -------------------------------------------------
# Output file for the combined 2x2 figure.
save_path = "data/results/diagrams/generation_plots.pdf"
# e.g. "upper right", "lower center", "center right", etc.
# NOTE(review): the plotting cell currently passes explicit axes coordinates
# to ax.legend() instead of this value -- confirm whether it is still needed.
legend_position = "right"
legend_fontsize = 8              # Legend text size in points
figsize = (6, 4)                 # (width, height) in inches; half-page / 2-column layout


# Create the output directory if it doesn't exist.
# os.path.dirname() returns "" for a bare filename (e.g. "plots.pdf") and
# os.makedirs("") raises FileNotFoundError, so only create a directory when
# the path actually has a directory component.
output_dir = os.path.dirname(save_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Data
# x-axis positions shared by every panel.
steps = [1, 2, 3, 4]
# Per-panel y-axis label and title; index i describes panel i of the 2x2 grid.
metrics = ["CosSim", "MRR", "MRR", "F1 Micro"]
titles = ["V1: Symptoms Extraction", "V2: Diagnoses Prediction", "V3: Lab-based Reranking", "V4: ICD Codes Prediction"]

# data[model][i] is the series drawn in panel i (same order as `titles`),
# holding one value per entry in `steps`. The dict keys are used verbatim as
# legend labels.
# NOTE(review): "Qwen3B-32B" looks like it may be a typo for "Qwen3-32B" --
# confirm before publishing, since the key appears in the figure legend.
data = {
    "Qwen3B-32B": [
        [0.680380282, 0.700690141, 0.721, 0.721],
        [0.652444444, 0.709533333, 0.725844444, 0.734],
        [0.819000106, 0.888072404, 0.907807346, 0.927542289],
        [0.46746, 0.477, 0.477, 0.477]
    ],
    "MedGemma-27B": [
        [0.656238806, 0.666492537, 0.676746269, 0.687],
        [0.560625, 0.6825, 0.706875, 0.715],
        [0.671311249, 0.789244306, 0.816459627, 0.825531401],
        [0.328930233, 0.386976744, 0.406325581, 0.416]
    ],
    "Llama-3.1-70B-Instruct": [
        [0.681493151, 0.711123288, 0.711123288, 0.721],
        [0.661333333, 0.7192, 0.735733333, 0.744],
        [0.809675889, 0.857870883, 0.87714888, 0.886787879],
        [0.296613636, 0.373159091, 0.401863636, 0.421]
    ]
}

# Fail fast on hand-editing mistakes: every model needs exactly one series per
# panel and one value per step. Without this, a mismatch only surfaces later
# as a matplotlib shape error, or is silently ignored (extra series are never
# plotted).
assert len(metrics) == len(titles), "metrics and titles must describe the same panels"
for checked_model, checked_series_list in data.items():
    assert len(checked_series_list) == len(titles), (
        f"{checked_model}: expected {len(titles)} series, got {len(checked_series_list)}"
    )
    for checked_title, checked_series in zip(titles, checked_series_list):
        assert len(checked_series) == len(steps), (
            f"{checked_model} / {checked_title}: expected {len(steps)} values, "
            f"got {len(checked_series)}"
        )

In [None]:
# Plot: a 2x2 grid with one panel per entry in `titles` and one line per model.
fig, axes = plt.subplots(2, 2, figsize=figsize)
fig.subplots_adjust(hspace=0.3, wspace=0.2)

# Panels 2 and 3 (zero-based indices 1 and 2) are drawn on a common y-range,
# so gather every value plotted in either of them.
shared_scale_panels = (1, 2)
legend_panel = 1
y_values_2_3 = [
    value
    for panel in shared_scale_panels
    for model_series in data.values()
    for value in model_series[panel]
]

y_min, y_max = min(y_values_2_3), max(y_values_2_3)
# Pad the range by 5% on each side so markers do not touch the frame
margin = (y_max - y_min) * 0.05
y_limits = (y_min - margin, y_max + margin)


for idx, ax in enumerate(axes.flat):
    metric, title = metrics[idx], titles[idx]

    # One line per model, labelled with the model name for the legend
    for model, model_series in data.items():
        ax.plot(steps, model_series[idx], marker='o', label=model)

    # Titles, labels, ticks and grid
    ax.set_title(title, fontsize=10)
    ax.set_ylabel(metric)
    ax.set_xticks(steps)
    ax.grid(True, linestyle='--', linewidth=0.5)

    # Panels 2 and 3 share the y-range computed above
    if idx in shared_scale_panels:
        ax.set_ylim(y_limits)

    # A single legend, drawn in the second panel; the location is given as
    # explicit (x, y) axes-fraction coordinates of the legend's lower-left corner
    if idx == legend_panel:
        ax.legend(loc=(0.2, 0.58), fontsize=legend_fontsize)

# Save as PDF
# NOTE(review): tight_layout recomputes the subplot spacing, so the
# subplots_adjust values set above are presumably overridden -- confirm.
plt.tight_layout()
plt.savefig(save_path, format='pdf', bbox_inches="tight")
print(f"Saved plot to {save_path}")
plt.show()
