In [1]:
import pandas as pd
import numpy as np
import wandb
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display, HTML

# W&B location (entity, project) that every sweep in this notebook is read from
entity_name, project_name = "dive-ci", "CLadder"

# Client for the wandb public API, used by the functions below
api = wandb.Api()

def collect_data(sweep_id, dataset_names):
    """Fetch every run of a W&B sweep and tabulate its config and test scores.

    Args:
        sweep_id: ID of the sweep, looked up under
            ``{entity_name}/{project_name}``.
        dataset_names: Mapping from a summary metric key to the column name
            that score gets in the returned frame.

    Returns:
        pandas.DataFrame with one row per run: ``run_id``, ``run_name``, the
        selected config values, and one column per value of
        ``dataset_names``. A score that is absent from a run's summary is
        stored as NaN so that every row has the same columns.

    Side effects:
        Displays the resulting frame and prints one line for each score that
        is missing from a run's summary.
    """
    sweep = api.sweep(f"{entity_name}/{project_name}/{sweep_id}")

    runs_data = []
    for run in sweep.runs:
        # NOTE(review): `_json_dict` is a private attribute of the summary
        # object; a summary without it is treated as having no metrics.
        summary = run.summary._json_dict if hasattr(run.summary, "_json_dict") else {}
        config = run.config

        # Config keys are read as flat dotted names; absent keys become None.
        # ("training.num_train_epochs" is intentionally not collected.)
        run_info = {
            "run_id": run.id,
            "run_name": run.name,
            "reasoning": config.get("experiment.reasoning", None),
            "anonymize": config.get("dataset.anonymize", None),
            "learning_rate": config.get("training.learning_rate", None),
            "percent_train": config.get("dataset.percent_of_train_dataset", None),
            "max_steps": config.get("training.max_steps", None),
        }

        # Copy the score of each requested test dataset into its column.
        for data_name, score_key in dataset_names.items():
            if data_name in summary:
                run_info[score_key] = summary[data_name]
            else:
                # Keep the column present and name the offending run instead
                # of dumping its whole summary dict.
                run_info[score_key] = float("nan")
                print(f"KeyError: {data_name} not found in summary for run {run.id} ({run.name})")

        runs_data.append(run_info)

    df = pd.DataFrame(runs_data)
    display(df)
    return df

# Sweep IDs used earlier, kept for reference:
#   "ngq4k8jt" -> CLadder
#   "fho761r3" -> ProntoQA

# Summary metric key -> score column name, one mapping per benchmark
cladder_dataset2score_name = {
    f'cladder-v1-q-{split}': split
    for split in ('commonsense', 'anticommonsense', 'noncommonsense')
}
prontoqa_dataset2score_name = {
    'prontoqa': 'prontoqa',
    'prontoqa-anticommonsense': 'anticommonsense',
    'prontoqa-noncommonsense': 'noncommonsense',
}

# Pull both sweeps into DataFrames (each call also displays its frame)
cladder_df = collect_data('hovxsc44', cladder_dataset2score_name)
prontoqa_df = collect_data('6qew4692', prontoqa_dataset2score_name)

Unnamed: 0,run_id,run_name,reasoning,anonymize,learning_rate,percent_train,max_steps,commonsense,anticommonsense,noncommonsense
0,jsmejv8t,SFT/QWen-cladder-RFalse-DP0.899-s1600-Absorigi...,False,original,0.0003,0.899,1600,80.00,78.46,65.43
1,5rxek5e6,SFT/QWen-cladder-RFalse-DP0.899-s800-Absorigin...,False,original,0.0003,0.899,800,67.88,49.13,44.14
2,bhoobm5a,SFT/QWen-cladder-RFalse-DP0.899-s200-Absorigin...,False,original,0.0003,0.899,200,50.00,50.48,52.15
3,7x71ri47,SFT/QWen-cladder-RFalse-DP0.899-s50-Absorigina...,False,original,0.0003,0.899,50,51.54,22.50,21.29
4,0r5anv0l,SFT/QWen-cladder-RTrue-DP0.899-s3200-Absorigin...,True,original,0.0003,0.899,3200,93.65,94.13,86.33
...,...,...,...,...,...,...,...,...,...,...
331,yxlbcnef,SFT/QWen-cladder-RTrue-DP0.005-s800-AbsNone-lr...,True,,0.0003,0.005,800,64.42,58.56,53.42
332,7tr2sdzt,SFT/QWen-cladder-RTrue-DP0.005-s100-AbsNone-lr...,True,,0.0003,0.005,100,53.65,55.87,53.22
333,hnzoh43h,SFT/QWen-cladder-RTrue-DP0.005-s400-AbsNone-lr...,True,,0.0003,0.005,400,58.75,57.40,58.01
334,i3nzc1mg,SFT/QWen-cladder-RTrue-DP0.005-s200-AbsNone-lr...,True,,0.0003,0.005,200,56.35,57.60,54.79


Unnamed: 0,run_id,run_name,reasoning,anonymize,learning_rate,percent_train,max_steps,prontoqa,anticommonsense,noncommonsense
0,8dere276,SFT/QWen-prontoqa-RFalse-DP0.899-s3200-Absrand...,False,random,0.00030,0.899,3200,99.50,93.00,95.75
1,ln5ldm4g,SFT/QWen-prontoqa-RFalse-DP0.899-s1600-Absrand...,False,random,0.00030,0.899,1600,99.00,92.25,95.25
2,22894wgs,SFT/QWen-prontoqa-RFalse-DP0.899-s800-Absrando...,False,random,0.00030,0.899,800,98.75,91.50,94.25
3,58wvgxuc,SFT/QWen-prontoqa-RFalse-DP0.899-s400-Absrando...,False,random,0.00030,0.899,400,93.75,86.00,90.50
4,27mwulp8,SFT/QWen-prontoqa-RFalse-DP0.899-s200-Absrando...,False,random,0.00030,0.899,200,92.00,85.75,90.25
...,...,...,...,...,...,...,...,...,...,...
331,lyrienuz,SFT/QWen-prontoqa-RTrue-DP0.025-s800-AbsNone-l...,True,,0.00015,0.025,800,96.75,71.00,78.25
332,pilksu49,SFT/QWen-prontoqa-RTrue-DP0.025-s200-AbsNone-l...,True,,0.00030,0.025,200,98.25,73.75,83.75
333,tw1ad27v,SFT/QWen-prontoqa-RTrue-DP0.025-s100-AbsNone-l...,True,,0.00030,0.025,100,99.00,72.50,83.75
334,u2dmtmp4,SFT/QWen-prontoqa-RTrue-DP0.025-s1600-AbsNone-...,True,,0.00015,0.025,1600,99.00,73.00,79.75


In [23]:
# Label used in subplot titles for each score column; the 'prontoqa'
# column is labelled 'Commonsense' as well.
testset_name_mapping = dict(
    commonsense='Commonsense',
    anticommonsense='Anticommonsense',
    noncommonsense='Noncommonsense',
    prontoqa='Commonsense',
)
# Create comparison plots
def create_anonymize_comparison_plots(df, dataset2score_name, available_pct, datasetname):
    """Plot test score against training steps for each anonymization setting.

    One figure is shown per distinct learning rate in ``df``. Each figure is
    a grid with one row per entry of ``available_pct`` and one column per
    test-score column. Every subplot holds one line per ``anonymize`` value
    (``"original"`` is skipped), built only from runs with
    ``reasoning == True``; scores of runs sharing the same ``max_steps`` are
    averaged.

    Args:
        df: Runs frame with columns ``percent_train``, ``learning_rate``,
            ``anonymize``, ``reasoning``, ``max_steps`` and the score columns.
        dataset2score_name: Mapping whose values name the score columns to
            plot; only those present in ``df`` are used.
        available_pct: Training fractions to plot, one subplot row each.
            They are matched against ``percent_train`` by exact equality.
        datasetname: Benchmark name used in the figure title.

    Returns:
        None. Figures are rendered with ``fig.show()`` and a short summary of
        the data is printed.
    """
    # Identify test dataset columns
    test_dataset_columns = [col for col in df.columns if col in dataset2score_name.values()]
    training_percentages = sorted(df['percent_train'].unique())
    learning_rates = sorted(df['learning_rate'].unique())
    # Drop missing values and sort by string form so a column that mixes
    # strings with None/NaN does not make sorted() raise.
    anonymize_list = sorted(df['anonymize'].dropna().unique(), key=str)

    print(learning_rates)
    print(f"Found {len(df)} runs with {len(test_dataset_columns)} test datasets")
    print(f"Test datasets: {[col.replace('score_', '') for col in test_dataset_columns]}")
    print(f"Training percentages: {training_percentages}")

    num_test_sets = len(test_dataset_columns)
    num_percentages = len(available_pct)

    # A subplot grid needs at least one row and one column.
    if num_test_sets == 0 or num_percentages == 0:
        print("Nothing to plot: no test dataset columns or no training percentages given")
        return

    # One colour per known anonymization setting, shared by all figures.
    colors = {
        "null": "rgb(31, 119, 180)",
        "random": "rgb(255, 127, 14)",
        "order": "rgb(44, 160, 44)",
        "original": "rgb(214, 39, 40)",
    }

    for lr in learning_rates:
        print(f"Creating comparison plots for learning rate: {lr}")

        fig = make_subplots(
            rows=num_percentages,
            cols=num_test_sets,
            # Fall back to the raw column name when no display label is registered.
            subplot_titles=[f"{testset_name_mapping.get(col, col)} - {round(pct*100)}% Training"
                            for pct in available_pct for col in test_dataset_columns],
            vertical_spacing=0.1,
            horizontal_spacing=0.05
        )

        # Rows for this figure: current learning rate, reasoning runs only.
        df_lr = df[(df['learning_rate'] == lr) & (df['reasoning'].eq(True))]

        # Settings that already own a legend entry in this figure. Tracking
        # them (instead of only labelling the first subplot) keeps a setting
        # in the legend even when it has no data in the top-left panel.
        legend_shown = set()

        for i, pct in enumerate(available_pct):
            # The percentage filter is the same for every column of this row.
            df_pct = df_lr[df_lr['percent_train'] == pct]

            for j, test_col in enumerate(test_dataset_columns):
                for anon in anonymize_list:
                    if anon == "original":
                        continue
                    df_group = df_pct[df_pct['anonymize'] == anon]

                    # Skip if no data for this combination
                    if df_group.empty:
                        continue

                    # Mean score per max_steps value
                    step_scores = df_group.groupby('max_steps')[test_col].mean().reset_index()

                    # Unknown settings get plotly's default colour instead of a KeyError.
                    color_key = str(anon)
                    line_style = dict(color=colors[color_key]) if color_key in colors else dict()

                    fig.add_trace(
                        go.Scatter(
                            x=step_scores['max_steps'],
                            y=step_scores[test_col],
                            mode='lines+markers',
                            name=f"CAPT={anon}",
                            line=line_style,
                            legendgroup=f"CAPT={anon}",
                            showlegend=anon not in legend_shown
                        ),
                        row=i+1, col=j+1
                    )
                    legend_shown.add(anon)

                # Axis titles only on the bottom row / left column.
                if i == num_percentages-1:
                    # The x values are `max_steps`, i.e. training steps, not epochs.
                    fig.update_xaxes(title_text="Training Steps", row=i+1, col=j+1)
                if j == 0:
                    fig.update_yaxes(title_text="Score", row=i+1, col=j+1)

        # Update layout
        fig.update_layout(
            height=250*num_percentages,
            width=400*num_test_sets,
            title_text=f"{datasetname} Ablation Study",
            legend_title="CAPT Setting",
            margin=dict(t=50, b=20, l=20, r=20),
        )
        fig.show()

# Ablation figures for CLadder at three training fractions
create_anonymize_comparison_plots(
    df=cladder_df,
    dataset2score_name=cladder_dataset2score_name,
    available_pct=[0.01, 0.02, 0.899],
    datasetname="CLadder",
)
# Ablation figures for PrOntoQA at three training fractions
create_anonymize_comparison_plots(
    df=prontoqa_df,
    dataset2score_name=prontoqa_dataset2score_name,
    available_pct=[0.025, 0.05, 0.899],
    datasetname="PrOntoQA",
)

[0.0003]
Found 336 runs with 3 test datasets
Test datasets: ['commonsense', 'anticommonsense', 'noncommonsense']
Training percentages: [0.005, 0.01, 0.02, 0.05, 0.1, 0.899]
Creating comparison plots for learning rate: 0.0003


[0.00015, 0.0003]
Found 336 runs with 3 test datasets
Test datasets: ['prontoqa', 'anticommonsense', 'noncommonsense']
Training percentages: [0.025, 0.05, 0.1, 0.899]
Creating comparison plots for learning rate: 0.00015


Creating comparison plots for learning rate: 0.0003


In [5]:
# Runs with max_steps == 3200 and percent_train == 0.899, restricted to the
# "null" and "random" anonymization settings.
display(
    cladder_df.loc[
        lambda runs: runs["max_steps"].eq(3200)
        & runs["anonymize"].isin(["null", "random"])
        & runs["percent_train"].eq(0.899)
    ]
)
# Same selection for PrOntoQA, additionally limited to learning rate 0.00015.
display(
    prontoqa_df.loc[
        lambda runs: runs["max_steps"].eq(3200)
        & runs["anonymize"].isin(["null", "random"])
        & runs["percent_train"].eq(0.899)
        & runs["learning_rate"].eq(0.00015)
    ]
)

Unnamed: 0,run_id,run_name,reasoning,anonymize,learning_rate,percent_train,max_steps,commonsense,anticommonsense,noncommonsense
127,tcmmbrxk,SFT/QWen-cladder-RFalse-DP0.899-s3200-Absrando...,False,random,0.0003,0.899,3200,87.6,89.42,80.76
132,xorwfc7l,SFT/QWen-cladder-RTrue-DP0.899-s3200-Absrandom...,True,random,0.0003,0.899,3200,93.08,94.62,88.57
251,1itnnntm,SFT/QWen-cladder-RFalse-DP0.899-s3200-AbsNone-...,False,,0.0003,0.899,3200,90.29,86.54,80.08
261,cud211nj,SFT/QWen-cladder-RTrue-DP0.899-s3200-AbsNone-l...,True,,0.0003,0.899,3200,95.38,95.67,88.09


Unnamed: 0,run_id,run_name,reasoning,anonymize,learning_rate,percent_train,max_steps,prontoqa,anticommonsense,noncommonsense
7,d8zzlopc,SFT/QWen-prontoqa-RFalse-DP0.899-s3200-Absrand...,False,random,0.00015,0.899,3200,99.25,91.5,95.5
21,uufgeeg3,SFT/QWen-prontoqa-RTrue-DP0.899-s3200-Absrando...,True,random,0.00015,0.899,3200,99.5,93.0,96.75
231,flhnk7lq,SFT/QWen-prontoqa-RFalse-DP0.899-s3200-AbsNone...,False,,0.00015,0.899,3200,100.0,58.0,77.25
245,zo6kn9yi,SFT/QWen-prontoqa-RTrue-DP0.899-s3200-AbsNone-...,True,,0.00015,0.899,3200,100.0,74.0,83.75


In [5]:
# Runs with max_steps == 3200 and percent_train == 0.02, restricted to the
# "null" and "random" anonymization settings.
display(
    cladder_df.loc[
        lambda runs: runs["max_steps"].eq(3200)
        & runs["anonymize"].isin(["null", "random"])
        & runs["percent_train"].eq(0.02)
    ]
)
# PrOntoQA counterpart: percent_train == 0.025 and learning rate 0.00015.
display(
    prontoqa_df.loc[
        lambda runs: runs["max_steps"].eq(3200)
        & runs["anonymize"].isin(["null", "random"])
        & runs["percent_train"].eq(0.025)
        & runs["learning_rate"].eq(0.00015)
    ]
)


Unnamed: 0,run_id,run_name,reasoning,anonymize,learning_rate,percent_train,max_steps,commonsense,anticommonsense,noncommonsense
161,kqg54cjt,SFT/QWen-cladder-RFalse-DP0.02-s3200-Absrandom...,False,random,0.0003,0.02,3200,71.35,63.75,62.5
168,vzidorm3,SFT/QWen-cladder-RTrue-DP0.02-s3200-Absrandom-...,True,random,0.0003,0.02,3200,79.04,79.9,73.73
294,406dphom,SFT/QWen-cladder-RFalse-DP0.02-s3200-AbsNone-l...,False,,0.0003,0.02,3200,70.48,68.85,66.11
301,unt5tlj3,SFT/QWen-cladder-RTrue-DP0.02-s3200-AbsNone-lr...,True,,0.0003,0.02,3200,68.65,68.85,66.11


Unnamed: 0,run_id,run_name,reasoning,anonymize,learning_rate,percent_train,max_steps,prontoqa,anticommonsense,noncommonsense
91,w23hmhlf,SFT/QWen-prontoqa-RFalse-DP0.025-s3200-Absrand...,False,random,0.00015,0.025,3200,89.0,83.75,88.75
105,66lh4dxi,SFT/QWen-prontoqa-RTrue-DP0.025-s3200-Absrando...,True,random,0.00015,0.025,3200,87.5,82.5,81.0
315,nnvj293b,SFT/QWen-prontoqa-RFalse-DP0.025-s3200-AbsNone...,False,,0.00015,0.025,3200,99.75,61.5,71.25
329,64kfsf57,SFT/QWen-prontoqa-RTrue-DP0.025-s3200-AbsNone-...,True,,0.00015,0.025,3200,99.5,70.75,79.0
