In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# (legend label, line colour) for each model, in the order the curves are
# drawn and listed in the legend. The order matches the CSV arguments of
# plot_metrics: qswin, qvit, swin, vit.
_MODEL_STYLES = (
    ("Quantum Swin Transformer", "purple"),
    ("Quantum Vision Transformer", "magenta"),
    ("Swin Transformer", "skyblue"),
    ("Vision Transformer", "orange"),
)


def _select_tag(df: pd.DataFrame, tag: str) -> pd.DataFrame:
    """Return the rows of *df* whose ``tag`` column equals *tag*."""
    return df[df["tag"] == tag]


def _save_comparison_plot(series, ylabel: str, ylim, title: str, filename) -> None:
    """Draw one curve per model on a single figure and save it to *filename*.

    Args:
        series: One DataFrame per entry of ``_MODEL_STYLES`` (same order),
            each providing ``step`` and ``value`` columns.
        ylabel: Label for the y-axis.
        ylim: ``(low, high)`` limits for the y-axis.
        title: Figure title.
        filename: Destination passed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(7, 6))
    try:
        # Solid, labelled curves first so that only these appear in the legend.
        # 'step' is shifted by one for display -- presumably the logged step
        # is a 0-based epoch index (TODO confirm against the logger).
        for rows, (label, color) in zip(series, _MODEL_STYLES):
            plt.plot(rows["step"] + 1, rows["value"], label=label, color=color, linewidth=3)

        # Thin dashed overlay of the same data (unlabelled, kept out of the legend).
        for rows, (_, color) in zip(series, _MODEL_STYLES):
            plt.plot(rows["step"] + 1, rows["value"], linestyle="--", color=color, linewidth=1.5)

        plt.title(title)
        plt.xlabel("Epoch", fontsize=12)
        plt.ylabel(ylabel, fontsize=12)
        plt.xlim(1, 10)  # fixed window: epochs 1 through 10
        plt.ylim(*ylim)
        plt.grid(True)
        plt.legend(fontsize=11, loc="lower right")
        plt.tight_layout()
        plt.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving failed, so it
        # does not stay registered with pyplot or bleed into the next plot.
        plt.close(fig)


def plot_metrics(qswin_csv, qvit_csv, swin_csv, vit_csv, seed: int, output_accuracy_filename="accuracy_comparison.png", output_auc_filename="auc_comparison.png") -> None:
    """
    Plots validation accuracy and validation AUC of four different models over epochs into separate files.

    Each CSV must provide the columns ``tag``, ``step`` and ``value``. Rows
    tagged ``val_accuracy`` feed the accuracy figure and rows tagged
    ``val_auc`` feed the AUC figure.

    Args:
        qswin_csv: Path to the CSV file for Quantum Swin Transformer.
        qvit_csv: Path to the CSV file for Quantum Vision Transformer.
        swin_csv: Path to the CSV file for Swin Transformer.
        vit_csv: Path to the CSV file for Vision Transformer.
        seed: Value shown (via ``str(seed)``) as the title of both figures.
        output_accuracy_filename: The name of the file to save the accuracy plot to.
        output_auc_filename: The name of the file to save the auc plot to.

    Raises:
        KeyError: If a CSV lacks the ``tag``, ``step`` or ``value`` column.
    """
    # Load the logs in the same order as _MODEL_STYLES.
    frames = [pd.read_csv(path) for path in (qswin_csv, qvit_csv, swin_csv, vit_csv)]

    # Select both metrics up front so a malformed CSV fails before any file is written.
    accuracy_series = [_select_tag(df, "val_accuracy") for df in frames]
    auc_series = [_select_tag(df, "val_auc") for df in frames]

    title = str(seed)
    _save_comparison_plot(accuracy_series, "Accuracy", (0.5, 1.0), title, output_accuracy_filename)
    _save_comparison_plot(auc_series, "AUC", (0.87, 1.0), title, output_auc_filename)

# Demo invocation: the four CSV logs are given as relative paths, so they are
# looked up in the current working directory.
qswin_csv, qvit_csv, swin_csv, vit_csv = (
    "qswin.csv",
    "qvit.csv",
    "swin.csv",
    "vit.csv",
)

# Writes both comparison figures under their default filenames.
plot_metrics(qswin_csv, qvit_csv, swin_csv, vit_csv, seed=3407)