In [None]:
import os 
import json

# Root directory of the evaluation outputs; results are looked up under
# <BASE_OUTPUT_DIR>/<model>/<checkpoint name>/.
BASE_OUTPUT_DIR = "/mnt/task_runtime/results"

# Names of the model runs to compare (used as sub-directory names and plot labels).
models = [
    "Qwen3-VL-2B-GRPO-MRI-600-think",
    "Qwen3-VL-2B-KDPO-MRI-600-think-tau-0",
    "Qwen3-VL-2B-KDPO-MRI-600-think-tau-1",
]

# Checkpoint steps 10, 20, ..., 100.
CHECKPOINTS_STEP = list(range(10, 101, 10))

# Checkpoint directory names derived from the steps, e.g. "checkpoint-10".
CHECKPOINTS_NAME = ["checkpoint-" + str(step) for step in CHECKPOINTS_STEP]

# Modality names searched for in the result file names.
MODALITIES = [
    "CT",
    "Dermoscopy",
    "Fundus",
    "MRI",
    "Microscopy",
    "OCT",
    "Ultrasound",
    "XRay",
]

# Keys read from each result JSON file.
METRICS = ["accuracy"]

def collect_results(model, checkpoint_name, modality, metrics=METRICS):
    """Load the requested metrics for one model / checkpoint / modality.

    Looks in ``BASE_OUTPUT_DIR/<model>/<checkpoint_name>`` for a result file
    whose name refers to ``modality`` and reads the given metric keys from it.

    File selection:
        1. Prefer a file whose name contains ``modality`` as a whole token
           (the name without extension, split on non-alphanumeric characters).
           This keeps "CT" from matching an "OCT" file.
        2. Otherwise fall back to a plain substring match on the file name.
        Candidates are sorted, so the pick is deterministic when several
        files match (``os.listdir`` returns entries in arbitrary order).

    Args:
        model (str): Model directory name under ``BASE_OUTPUT_DIR``.
        checkpoint_name (str): Checkpoint directory name, e.g. "checkpoint-10".
        modality (str): Modality to look for in the file names, e.g. "MRI".
        metrics (list): Keys to read from the JSON file.

    Returns:
        dict: ``{metric: value}`` for each requested metric, or ``{}`` when no
        file in the directory matches the modality (including an empty
        directory).

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
        KeyError: If a requested metric is missing from the JSON file.
    """
    model_path = os.path.join(BASE_OUTPUT_DIR, model, checkpoint_name)
    file_names = sorted(os.listdir(model_path))

    def _name_tokens(file_name):
        # "results_OCT.json" -> ["results", "OCT"]
        stem = os.path.splitext(file_name)[0]
        return "".join(ch if ch.isalnum() else " " for ch in stem).split()

    # Whole-token matches first; substring matches only as a fallback.
    matching = [name for name in file_names if modality in _name_tokens(name)]
    if not matching:
        matching = [name for name in file_names if modality in name]
    if not matching:
        return {}

    full_file_path = os.path.join(model_path, matching[0])
    with open(full_file_path, "r") as f:
        data = json.load(f)

    return {metric: data[metric] for metric in metrics}

def collect_all_results(model, checkpoint_name, modalities=MODALITIES, metrics=METRICS):
    """Gather the metrics of every modality for a single model checkpoint.

    Returns:
        dict: Maps each modality name to whatever ``collect_results`` returns
        for that modality.
    """
    return {
        modality_name: collect_results(model, checkpoint_name, modality_name, metrics)
        for modality_name in modalities
    }

def collect_all_results_for_all_checkpoints(model, checkpoints_name=CHECKPOINTS_NAME, modalities=MODALITIES, metrics=METRICS):
    """Gather the per-modality metrics of one model for every checkpoint.

    Returns:
        dict: Maps each checkpoint name to the per-modality dict produced by
        ``collect_all_results`` for that checkpoint.
    """
    return {
        ckpt: collect_all_results(model, ckpt, modalities, metrics)
        for ckpt in checkpoints_name
    }

def collect_all_results_for_all_models(models=models, checkpoints_name=CHECKPOINTS_NAME, modalities=MODALITIES, metrics=METRICS):
    """Gather the metrics of every model, checkpoint and modality.

    Returns:
        dict: Maps each model name to the per-checkpoint dict produced by
        ``collect_all_results_for_all_checkpoints`` for that model.
    """
    return {
        model_name: collect_all_results_for_all_checkpoints(
            model_name, checkpoints_name, modalities, metrics
        )
        for model_name in models
    }

# Collect everything into one nested dict
# (model -> checkpoint -> modality -> metric) and dump it as indented JSON.
results = collect_all_results_for_all_models(models)
print(json.dumps(results, indent=4))

In [None]:
# Inspect the nesting of `results` one level at a time.
print(results.keys())  # model names
print(results["Qwen3-VL-2B-GRPO-MRI-600-think"].keys())  # checkpoint names
print(results["Qwen3-VL-2B-GRPO-MRI-600-think"]["checkpoint-10"].keys())  # modalities
print(results["Qwen3-VL-2B-GRPO-MRI-600-think"]["checkpoint-10"]["MRI"])  # metrics dict for one modality

In [None]:
import matplotlib.pyplot as plt

def plot_metric_across_checkpoints(results, modality, metric, models=models, checkpoints_name=CHECKPOINTS_NAME):
    """
    Plot one metric for one modality across checkpoints, one line per model.

    Missing or non-numeric values are plotted as NaN, which leaves a gap in
    the line instead of raising.

    Args:
        results (dict): Nested mapping indexed as
            ``results[model][checkpoint][modality][metric]``.
        modality (str): The modality to plot, e.g. "Fundus".
        metric (str): The metric to plot, e.g. "accuracy".
        models (list): Model names to plot (keys of ``results``).
        checkpoints_name (list): Checkpoint keys to plot (str or int), in
            x-axis order. When every key ends in an integer step (e.g.
            "checkpoint-10" or 10) the x-axis is numeric; otherwise the keys
            are used as categorical labels.
    """
    checkpoints = list(checkpoints_name)

    # Build the x-axis from the checkpoints that were actually requested, so a
    # custom `checkpoints_name` always lines up with the y-values.
    try:
        x_values = [int(str(checkpoint).rsplit("-", 1)[-1]) for checkpoint in checkpoints]
    except ValueError:
        x_values = [str(checkpoint) for checkpoint in checkpoints]

    fig, ax = plt.subplots(figsize=(10, 6))

    for model in models:
        y_values = []
        for checkpoint in checkpoints:
            try:
                value = float(results[model][checkpoint][modality][metric])
            except (LookupError, TypeError, ValueError):
                # Missing entry or non-numeric value -> gap in the line.
                value = float("nan")
            y_values.append(value)
        ax.plot(x_values, y_values, marker='o', label=model)

    ax.set_title(f"{metric} on {modality} across checkpoints")
    ax.set_xlabel("Checkpoint")
    ax.set_ylabel(metric)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

# Example usage: plot accuracy on the MRI modality for every model, using the default checkpoints.
plot_metric_across_checkpoints(results, modality='MRI', metric='accuracy')