# Baseline Clean Evaluation Notebook

This notebook is the master baseline for all evaluations that produce the final thesis results.

## Setup and Imports

In [2]:
import re
from typing import Dict, List, Tuple, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
from IPython.display import HTML, display
from scipy import stats
from tqdm import tqdm
import scienceplots


# Shared wandb public-API client. The fetch helpers in this notebook receive it
# through their `api` argument instead of creating a client of their own.
api = wandb.Api()

## General Configuration
(Should generally need no modifications)

### Matplotlib Configuration and Styling

In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Apply the 'science' and 'grid' style sheets first, then the notebook-wide
# figure size and font size on top of them.
plt.style.use(['science', 'grid'])
plt.rcParams.update({
    'figure.figsize': (8, 5),
    'font.size': 11,
})

# Named colour palette. Each of the four colour families comes in three shades:
# the base colour, a "_dark" variant and a "_pastel" variant.
colors = {
    # Primary blue, inspired by the "opinion" color in the original figures
    "primary": "#2f6bb6",
    "primary_dark": "#063D79",    # straight from the Uni Würzburg design manual
    "primary_pastel": "#8badd6",

    # Accent red, inspired by the "aspect" color in the original figures
    "accent_red": "#e94e4e",
    "accent_red_dark": "#a83434",
    "accent_red_pastel": "#f19999",

    # Tertiary orange, used for other span visualizations
    "tertiary": "#f5a142",
    "tertiary_dark": "#b5732d",
    "tertiary_pastel": "#fac589",

    # Classification green, inspired by the "sentiment" color in the original figures
    "classification": "#4ca64c",
    "classification_dark": "#2f6c2f",
    "classification_pastel": "#9fd19f",

    # Gray scale
    "gray_dark": "#4a4a4a",       # text and important elements
    "gray_medium": "#808080",     # secondary elements
    "gray_light": "#d3d3d3",      # backgrounds and dividers
    "black": "#000000",           # important elements

    # Background and utility colors
    "background": "#ffffff",      # white background
    "text": "#2b2b2b",            # dark gray text
    "divider": "#ececec"          # neutral gray dividers
}

# Register the palette with matplotlib. The default colour cycle walks through
# all base shades first, then all dark shades, then all pastel shades, and
# finally the three grays.
plt.rcParams.update({
    "axes.prop_cycle": plt.cycler(color=[ #type: ignore
        colors[f"{family}{shade}"]
        for shade in ("", "_dark", "_pastel")
        for family in ("primary", "accent_red", "tertiary", "classification")
    ] + [
        colors[gray] for gray in ("gray_dark", "gray_medium", "gray_light")
    ]),
    "figure.facecolor": colors["background"],
    "axes.facecolor": colors["background"],
    "axes.edgecolor": colors["black"],
    "axes.labelcolor": colors["text"],
    "xtick.color": colors["text"],
    "ytick.color": colors["text"],
    "text.color": colors["text"],
    "grid.color": colors["black"],
    "grid.alpha": 0.7,
})

### Baselines staying the same through all evaluations

In [3]:
# Fixed reference results, one table per reference system. Every table maps a
# dataset name to its {'P', 'R', 'F1'} scores; the tuples below are written in
# exactly that (P, R, F1) order.
# NOTE(review): the variable names suggest original-paper, state-of-the-art and
# LLM reference numbers -- confirm the sources before citing them.
paper_baseline = {
    dataset: dict(zip(('P', 'R', 'F1'), scores))
    for dataset, scores in {
        '14res': (65.52, 64.99, 65.25),
        '14lap': (61.41, 56.19, 58.69),
        '15res': (59.14, 59.38, 59.26),
        '16res': (66.60, 68.68, 67.62),
    }.items()
}
sota_baseline = {
    dataset: dict(zip(('P', 'R', 'F1'), scores))
    for dataset, scores in {
        '14res': (76.1, 75.08, 75.59),
        '14lap': (66.82, 60.68, 63.61),
        '15res': (66.50, 63.86, 65.15),
        '16res': (75.52, 74.14, 74.83),
    }.items()
}
llm_baselines = {
    dataset: dict(zip(('P', 'R', 'F1'), scores))
    for dataset, scores in {
        '14res': (74.65, 79.38, 76.94),
        '14lap': (61.67, 68.39, 64.86),
        '15res': (61.91, 73.40, 67.17),
        '16res': (71.38, 78.60, 74.81),
    }.items()
}

In [4]:
# Single source of truth: wandb summary key -> short metric name used in the
# result tables. The two lists below are derived from it, so their order always
# matches (precision, recall, F1).
METRIC_CONVERSION = {
    'test/tuple_prec_epoch': 'P',
    'test/tuple_rec_epoch': 'R',
    'test/tuple_f1_epoch': 'F1'
}
DATA_METRICS = list(METRIC_CONVERSION.keys())
EVAL_METRICS = list(METRIC_CONVERSION.values())

# Dataset names, in the order in which they appear in every table.
DATASETS = ['14res', '14lap', '15res', '16res']

## Core Functions

In [5]:
def fetch_run_data(api: wandb.Api, project_name: str, run_name: str, fraction: Optional[float] = None) -> pd.DataFrame:
    """
    Fetch the test metrics of all finished wandb runs that share a display name.

    Args:
        api: wandb API object
        project_name: Name of the wandb project
        run_name: Display name of the runs, without any fraction suffix
        fraction: Optional training data fraction. When given, runs are looked up as
            "<run_name>_Fraction_<fraction>". For fraction 1.0 the plain run_name is
            tried as a fallback if the suffixed name matches nothing.

    Returns:
        DataFrame with one row per run and the columns 'run_name', 'dataset' and one
        column per entry of DATA_METRICS. 'run_name' always carries the fraction suffix
        when a fraction was requested, even if the fallback lookup was used. The columns
        are present even when no run matched, so callers can filter on them safely.
    """
    display_run_name = run_name if fraction is None else f"{run_name}_Fraction_{fraction}"

    # Materialize the result once: the emptiness check and the loop below then
    # share the same list instead of walking the paginated result twice.
    runs = list(api.runs(project_name, {"display_name": display_run_name, "state": "finished"}))

    # Failsafe: If fraction is 1.0 and no runs found, try without the fraction suffix
    if fraction == 1.0 and not runs:
        runs = list(api.runs(project_name, {"display_name": run_name, "state": "finished"}))

    data = []
    for run in tqdm(runs, desc=f"Fetching data for {display_run_name}"):
        config = run.config
        summary = run.summary
        data.append({
            'run_name': display_run_name,
            'dataset': config.get('dataset', {}).get('name', 'unknown'),
            # Metrics missing from the run summary become NaN
            **{m: summary.get(m, np.nan) for m in DATA_METRICS}
        })

    if not data:
        # Keep the schema stable for "no runs found": a column-less frame would make
        # every later `frame['run_name']` lookup fail with a KeyError. The columns are
        # typed so that concatenating with non-empty results does not change dtypes.
        return pd.DataFrame({
            'run_name': pd.Series(dtype=object),
            'dataset': pd.Series(dtype=object),
            **{m: pd.Series(dtype=float) for m in DATA_METRICS}
        })
    return pd.DataFrame(data, columns=['run_name', 'dataset', *DATA_METRICS])

In [6]:
def fetch_fraction_data(api: wandb.Api, project_name: str, run_configs: List[Dict[str, str|int]], paper_baselines: List[Tuple[str, Dict[str, Dict[str, float]]]], baseline_name: str, fractions: List[float], raw_values: bool = False, baseline_paper_compare_index: Optional[int] = None) -> Dict[float, pd.DataFrame]:
    """
    Build one comparison table per training data fraction.

    Every fraction is handed to create_comparison_table together with the remaining
    arguments, which are passed through unchanged.

    Args:
        api: wandb API object
        project_name: Name of the wandb project
        run_configs: List of run configurations
        paper_baselines: List of paper baseline results
        baseline_name: Name of baseline run to compare against
        fractions: List of training data fractions to analyze
        raw_values: Whether to return raw run values or aggregated statistics
        baseline_paper_compare_index: Optional index for paper baseline comparison

    Returns:
        Dictionary mapping each fraction to its comparison table DataFrame
    """
    tables_by_fraction: Dict[float, pd.DataFrame] = {}
    for current_fraction in tqdm(fractions, desc="Fetching fraction data"):
        tables_by_fraction[current_fraction] = create_comparison_table(
            api=api,
            project_name=project_name,
            run_configs=run_configs,
            paper_baselines=paper_baselines,
            baseline_name=baseline_name,
            fraction=current_fraction,
            raw_values=raw_values,
            baseline_paper_compare_index=baseline_paper_compare_index
        )
    return tables_by_fraction


def create_comparison_table(api: wandb.Api, project_name: str, run_configs: List[Dict[str, str|int]], paper_baselines: List[Tuple[str, Dict[str, Dict[str, float]]]], baseline_name: str, fraction: Optional[float] = None, raw_values: bool = False, baseline_paper_compare_index: Optional[int] = None) -> pd.DataFrame:
    """
    Build a long-format comparison table of paper baselines and our own wandb runs.

    Args:
        api: wandb API object
        project_name: Name of the wandb project
        run_configs: List of run configurations. Each needs a 'run_name'; 'display_name'
            is optional and falls back to 'run_name'.
        paper_baselines: List of (name, {dataset: {'P', 'R', 'F1'}}) paper results
        baseline_name: 'run_name' of the run all other runs are compared against
        fraction: Optional training data fraction (selects the "_Fraction_<fraction>" runs)
        raw_values: If True, return one row per individual run instead of aggregates
        baseline_paper_compare_index: If given, the baseline run's diff is computed
            against paper_baselines[index] instead of against itself

    Returns:
        DataFrame with the columns 'Model', 'Dataset', 'P', 'R', 'F1'. In aggregated mode
        (raw_values=False) the metrics are means over runs and each metric additionally
        has a '<metric>_std' and '<metric>_diff' column; paper baselines get 0 for both.

    Raises:
        ValueError: If baseline_name is not the 'run_name' of any entry in run_configs
    """
    # Verify baseline_name exists in run_configs
    if not any(config['run_name'] == baseline_name for config in run_configs):
        raise ValueError(f"baseline_name '{baseline_name}' must be included in run_configs")

    all_data = pd.concat([fetch_run_data(api, project_name, config['run_name'], fraction) for config in tqdm(run_configs, desc="Fetching data")]) #type: ignore

    def rows_for(run_name, dataset):
        # fetch_run_data stores runs under the fraction-suffixed name
        stored_name = f"{run_name}_Fraction_{fraction}" if fraction is not None else run_name
        return all_data[(all_data['run_name'] == stored_name) & (all_data['dataset'] == dataset)]

    def model_label(config):
        # Same fallback in both modes: 'display_name' is optional
        return config.get('display_name', config['run_name'])

    results = []
    for dataset in DATASETS:
        # Paper baselines (a single row each, since there are no raw values)
        for paper_name, paper_scores in paper_baselines:
            paper_row = {'Model': f'{paper_name}', 'Dataset': dataset}
            paper_row.update({m: paper_scores[dataset][m] for m in EVAL_METRICS}) #type: ignore
            if not raw_values:
                paper_row.update({f'{m}_std': 0 for m in EVAL_METRICS}) #type: ignore
                paper_row.update({f'{m}_diff': 0 for m in EVAL_METRICS}) #type: ignore
            results.append(paper_row)

        if raw_values:
            # One row per individual run, no aggregation
            for config in run_configs:
                for _, run in rows_for(config['run_name'], dataset).iterrows():
                    results.append({
                        'Model': model_label(config),
                        'Dataset': dataset,
                        **{METRIC_CONVERSION[m]: run[m] for m in DATA_METRICS}
                    })
            continue

        # Mean metrics of our own baseline, the reference for all '<metric>_diff' values
        baseline_data = rows_for(baseline_name, dataset)
        baseline_metrics = {METRIC_CONVERSION[m]: baseline_data[m].dropna().mean() for m in DATA_METRICS}

        # All runs including baseline
        for config in run_configs:
            run_data = rows_for(config['run_name'], dataset)
            row = {'Model': model_label(config), 'Dataset': dataset}
            compare_to_paper = config['run_name'] == baseline_name and baseline_paper_compare_index is not None

            for m in DATA_METRICS:
                metric = METRIC_CONVERSION[m]
                values = run_data[m].dropna()
                row[metric] = values.mean() #type: ignore
                row[f'{metric}_std'] = values.std() #type: ignore

                if compare_to_paper:
                    # For baseline run, compare against paper baseline if index provided
                    _, paper_scores = paper_baselines[baseline_paper_compare_index] #type: ignore
                    reference = paper_scores[dataset][metric]
                else:
                    # For all other runs, compare against our baseline
                    reference = baseline_metrics[metric]
                row[f'{metric}_diff'] = row[metric] - reference #type: ignore

            results.append(row)

    return pd.DataFrame(results)

def format_value(value, std, diff, show_std=True, show_diff=True, is_paper_baseline=False):
    """
    Format a metric value with an optional "(σ std, ±diff)" annotation.

    Args:
        value: Metric value, printed with two decimals
        std: Standard deviation, shown as "σ <std>" when show_std is True
        diff: Difference to the reference, shown with an explicit sign when show_diff is True
        show_std: Whether to include the standard deviation
        show_diff: Whether to include the difference
        is_paper_baseline: If True, only the plain value is returned

    Returns:
        e.g. "65.25 (σ 0.40, +1.20)", "65.25 (σ 0.40)", "65.25 (+1.20)" or "65.25"
    """
    formatted = f"{value:.2f}"
    if is_paper_baseline:
        return formatted

    details = []
    if show_std:
        details.append(f"σ {std:.2f}")
    if show_diff:
        sign = '+' if diff >= 0 else '-'
        details.append(f"{sign}{abs(diff):.2f}")

    # Joining the parts avoids a stray space after "(" when only the diff is shown
    if details:
        formatted += f" ({', '.join(details)})"
    return formatted

In [7]:
def _latex_metric_cells(model_data, datasets, metrics, decorate):
    """Build the metric value cells of one model's main table row (without the model label)."""
    cells = []
    for dataset in datasets:
        if dataset == 'Avg':
            # Average over all rows of this model; 'Avg' never gets a trailing rule
            for metric in metrics:
                avg_value = model_data[metric].mean()
                cells.append(decorate(dataset, metric, avg_value, f"{avg_value:.2f}"))
            continue

        dataset_data = model_data[model_data['Dataset'] == dataset]
        if dataset_data.empty:
            cells.extend(["-"] * len(metrics))
        else:
            for metric in metrics:
                value = dataset_data[metric].values[0]
                cells.append(decorate(dataset, metric, value, f"{value:.2f}"))
        # Vertical rule after the last metric of every dataset except the final column group
        if dataset != datasets[-1]:
            cells[-1] += " \\vrule"
    return cells


def _latex_std_cells(model_data, datasets, metrics, is_baseline):
    """Build the standard-deviation cells shown underneath a model's main row."""
    def fmt(std):
        return f"{{\\scriptsize \\textit{{($\\sigma$ {std:.2f})}}}}"

    blanks = [""] * len(metrics)
    cells = []
    for dataset in datasets:
        if dataset == 'Avg':
            if is_baseline:
                cells.extend(blanks)
            else:
                cells.extend(fmt(model_data[f"{metric}_std"].mean()) for metric in metrics)
            continue

        dataset_data = model_data[model_data['Dataset'] == dataset]
        if dataset_data.empty or is_baseline:
            cells.extend(blanks)
        else:
            cells.extend(fmt(dataset_data[f"{metric}_std"].values[0]) for metric in metrics)
        if dataset != datasets[-1]:
            cells[-1] += " \\vrule"
    return cells


def create_latex_table(df: pd.DataFrame, caption: str, label: str, 
                       baselines: List[str] = [], bold_best: bool = False, 
                       groups: Optional[Dict[str, List[str]]] = None,
                       mode: str = 'full') -> str:
    """
    Create a LaTeX table for ACL formatted paper.

    Each model gets a main row with its P/R/F1 values and, unless it is a baseline
    with a single-line name, a second row with the standard deviations. A model name
    containing a literal "\\\\" is split there: the first part labels the main row,
    the second part labels the standard deviation row.

    Args:
        df: DataFrame with the columns 'Model', 'Dataset', 'P', 'R', 'F1' and, for all
            non-baseline models, 'P_std', 'R_std', 'F1_std'
        caption: Table caption
        label: Table label for referencing
        baselines: List of baseline model names (no standard deviations are printed for them)
        bold_best: Whether to bold best values. With groups, the best value per group is
            bold and the best value overall is additionally underlined.
        groups: Optional dictionary with model groups; only models listed in a group are printed
        mode: 'single' for single column with avg only, 'full' for full-width with all datasets

    Returns:
        LaTeX code for the table
    """
    row_indent = "            "
    metrics = ['P', 'R', 'F1']
    # Set datasets based on mode; the average column group is always last
    datasets = ['Avg'] if mode == 'single' else ['14res', '14lap', '15res', '16res', 'Avg']

    # If bold_best is True, find best values per dataset and metric (overall and per group)
    overall_best = {}
    group_best = {}
    if bold_best:
        for dataset in datasets:
            for metric in metrics:
                if dataset == 'Avg':
                    scores = df.groupby('Model')[metric].mean()
                    owners = scores.index
                else:
                    dataset_data = df[df['Dataset'] == dataset]
                    scores = dataset_data[metric]
                    owners = dataset_data['Model']
                overall_best[(dataset, metric)] = scores.max()
                if groups:
                    group_best[(dataset, metric)] = {
                        group_name: scores[owners.isin(group_models)].max()
                        for group_name, group_models in groups.items()
                    }

    def decorate_ungrouped(dataset, metric, value, text):
        # Bold if best overall
        if bold_best and abs(value - overall_best[(dataset, metric)]) < 1e-10:
            return f"\\textbf{{{text}}}"
        return text

    def group_decorator(group_name):
        def decorate(dataset, metric, value, text):
            if bold_best:
                # Bold if best in group
                if abs(value - group_best[(dataset, metric)][group_name]) < 1e-10:
                    text = f"\\textbf{{{text}}}"
                # Underline if best overall
                if abs(value - overall_best[(dataset, metric)]) < 1e-10:
                    text = f"\\underline{{{text}}}"
            return text
        return decorate

    latex = []

    def append_model(model, decorate):
        # Emits the main row, the optional std row and a closing \midrule for one model
        model_data = df[df['Model'] == model]
        is_baseline = model in baselines
        # Split model name if it contains line break marker
        model_parts = model.split("\\\\")

        main_row = [model_parts[0]] + _latex_metric_cells(model_data, datasets, metrics, decorate)
        latex.append(row_indent + " & ".join(main_row) + " \\\\")

        if not is_baseline or len(model_parts) > 1:
            std_label = model_parts[1] if len(model_parts) > 1 else ""
            std_row = [std_label] + _latex_std_cells(model_data, datasets, metrics, is_baseline)
            latex.append(row_indent + " & ".join(std_row) + " \\\\")

        latex.append(row_indent + "\\midrule")

    # Start table with appropriate environment based on mode
    if mode == 'full':
        latex.append("\\begin{table*}[t]")  # Full-width table
    else:
        latex.append("\\begin{table}[t]")   # Single-column table

    latex.append("    \\centering")
    latex.append("    \\setlength{\\tabcolsep}{2pt}")
    latex.append("    \\small")
    latex.append(f"    \\caption{{{caption}}}")

    # Only use adjustbox in full mode
    if mode == 'full':
        latex.append("    \\begin{adjustbox}{width=\\textwidth}")

    # Begin tabular
    latex.append("        \\begin{tabular}{l*{" + str(len(datasets) * 3) + "}{r}}")
    latex.append(row_indent + "\\toprule")

    # Header rows; the trailing "&" opens the empty model cell of the second header row
    header1 = ["\\multirow{2}{*}{Model}"]
    header1.extend(f"\\multicolumn{{3}}{{c}}{{{dataset}}}" for dataset in datasets)
    latex.append(row_indent + " & ".join(header1) + " \\\\ &")

    header2 = ["P & R & F1 \\vrule" if dataset != datasets[-1] else "P & R & F1" for dataset in datasets]
    latex.append(row_indent + " & ".join(header2) + " \\\\")
    latex.append(row_indent + "\\midrule")

    # Data rows
    if groups:
        for group_name, group_models in groups.items():
            # Add group header
            latex.append(f"            \\multicolumn{{{3*len(datasets)+1}}}{{l}}{{\\textbf{{{group_name}}}}} \\\\")
            latex.append(row_indent + "\\midrule")

            # Add models in this group, in their order of appearance in df
            decorate = group_decorator(group_name)
            for model in df['Model'].unique():
                if model in group_models:
                    append_model(model, decorate)
            # Add thicker rule between groups
            latex[-1] = row_indent + "\\midrule[1pt]"
    else:
        for model in df['Model'].unique():
            append_model(model, decorate_ungrouped)

    # End table
    latex[-1] = row_indent + "\\bottomrule"  # Replace last midrule with bottomrule
    latex.append("        \\end{tabular}")

    # Close adjustbox if in full mode
    if mode == 'full':
        latex.append("    \\end{adjustbox}")

    latex.append(f"    \\label{{{label}}}")

    # Close table environment with appropriate end tag
    if mode == 'full':
        latex.append("\\end{table*}")
    else:
        latex.append("\\end{table}")

    return "\n".join(latex)

In [8]:
def create_averaged_latex_table(
    df: pd.DataFrame, 
    caption: str, 
    label: str,
    groups: Dict[str, List[str]], # Dict mapping group names to list of models in that group
    paper_baseline: str = "Paper Baseline",
    reimpl_baseline: str = "Baseline",
    show_std: bool = True,
    show_diff: bool = True,
    bold_style: str = "none" # "none", "best", or "per_group"
) -> str:
    """
    Create a compact LaTeX table with each model's P/R/F1 averaged over all its rows in df.

    Args:
        df: DataFrame with the columns 'Model', 'P', 'R', 'F1' and optionally
            'P_std', 'R_std', 'F1_std' (a missing std column counts as 0)
        caption: Table caption
        label: Table label for referencing
        groups: Dict mapping group names to the models printed in that group, in order.
            An empty group name suppresses the group header line.
        paper_baseline: Model name of the paper baseline. It never gets std/diff
            annotations and is excluded from the best-value search and from bolding.
        reimpl_baseline: Model name that all diffs are computed against; its own diff
            is never printed
        show_std: Whether to annotate values with the averaged standard deviation
        show_diff: Whether to annotate values with the diff to reimpl_baseline
            (independent of show_std)
        bold_style: "none", "best" (bold the overall best) or "per_group" (bold the best
            per group and additionally underline it when it is also the overall best)

    Returns:
        LaTeX code for the table

    Raises:
        KeyError: If reimpl_baseline or a model listed in groups does not occur in df['Model']
    """
    metrics = ['P', 'R', 'F1']

    # Calculate averages and find best values
    model_avgs = {}
    model_stds = {}
    best_values = {m: -float('inf') for m in metrics}
    group_best_values = {group: {m: -float('inf') for m in metrics} for group in groups.keys()}

    for model in df['Model'].unique():
        model_data = df[df['Model'] == model]
        avgs = {m: model_data[m].mean() for m in metrics}
        stds = {m: model_data[f"{m}_std"].mean() if f"{m}_std" in model_data else 0 for m in metrics}
        model_avgs[model] = avgs
        model_stds[model] = stds

        # Update best values (excluding paper baseline)
        if model != paper_baseline:
            for m in metrics:
                best_values[m] = max(best_values[m], avgs[m])
                # Update group bests
                for group, models in groups.items():
                    if model in models:
                        group_best_values[group][m] = max(group_best_values[group][m], avgs[m])

    # Fail early with a message that names the problem instead of a bare KeyError
    if reimpl_baseline not in model_avgs:
        raise KeyError(f"reimpl_baseline '{reimpl_baseline}' does not occur in df['Model']")
    missing_models = [m for group_models in groups.values() for m in group_models if m not in model_avgs]
    if missing_models:
        raise KeyError(f"models listed in groups but missing from df['Model']: {missing_models}")

    # Calculate diffs against reimpl baseline
    baseline_avgs = model_avgs[reimpl_baseline]
    model_diffs = {
        model: {m: avgs[m] - baseline_avgs[m] for m in metrics}
        for model, avgs in model_avgs.items()
    }

    latex = []
    # Start table
    latex.append("\\begin{table}[!htbp]")
    latex.append("    \\centering")
    latex.append("    \\setlength{\\tabcolsep}{4pt}")
    latex.append(f"    \\caption{{{caption}}}")
    latex.append("    \\begin{tabular}{lccc}")
    latex.append("        \\toprule")

    # Header
    latex.append("        \\multirow{2}{*}{Model} & \\multicolumn{3}{c}{Average Performance} \\\\")
    latex.append("        \\cmidrule{2-4}")
    latex.append("        & P & R & F1 \\\\")
    latex.append("        \\midrule")

    # Data rows
    for group_name, group_models in groups.items():
        if group_name:  # Add group header if group has a name
            latex.append(f"        \\multicolumn{{4}}{{l}}{{\\textit{{{group_name}:}}}} \\\\")

        for model in group_models:
            avgs = model_avgs[model]
            stds = model_stds[model]
            diffs = model_diffs[model]

            row = [model if "\\textit" not in model else f"\\multicolumn{{1}}{{l}}{{{model}}}"]

            for metric in metrics:
                value = avgs[metric]
                formatted = f"{value:.2f}"

                # Handle different bolding styles
                if bold_style != "none" and model != paper_baseline:
                    if bold_style == "best" and abs(value - best_values[metric]) < 1e-10:
                        formatted = f"\\textbf{{{formatted}}}"
                    elif bold_style == "per_group":
                        # Bold if best in group
                        if abs(value - group_best_values[group_name][metric]) < 1e-10:
                            formatted = f"\\textbf{{{formatted}}}"
                            # Additionally underline if absolute best
                            if abs(value - best_values[metric]) < 1e-10:
                                formatted = f"\\underline{{{formatted}}}"

                # Add std/diff if needed. The two flags are independent, so a diff-only
                # annotation is possible; the paper baseline never gets an annotation.
                notes = []
                if model != paper_baseline:
                    if show_std:
                        notes.append(f"$\\sigma$ {stds[metric]:.2f}")
                    if show_diff and model != reimpl_baseline:
                        notes.append(f"{diffs[metric]:+.2f}")
                if notes:
                    formatted += f" {{\\scriptsize \\textit{{({', '.join(notes)})}}}} "

                row.append(formatted)

            latex.append("        " + " & ".join(row) + " \\\\")

        latex.append("        \\midrule")

    # End table: replace the last midrule with the bottomrule
    latex[-1] = "        \\bottomrule"
    latex.append("    \\end{tabular}")
    latex.append(f"    \\label{{{label}}}")
    latex.append("\\end{table}")

    return "\n".join(latex)


## Visualization Functions

In [9]:
def style_dataframe(df: pd.DataFrame, show_std: bool = True, show_diff: bool = True, caption: str = "Comparison Results for Triplet on the D20b dataset", baselines: List[str] = []) -> "pd.io.formats.style.Styler":
    """
    Render a long-format comparison table as a styled wide table (one row per model).

    Args:
        df: DataFrame with the columns 'Model', 'Dataset', 'P', 'R', 'F1' plus a
            '<metric>_std' and '<metric>_diff' column for every metric
        show_std: Whether the cells include the standard deviation
        show_diff: Whether the cells include the diff
        caption: Caption shown above the table
        baselines: Model names that are formatted as plain values without std/diff

    Returns:
        pandas Styler (not a DataFrame) with a (dataset, metric) column MultiIndex. Per
        column, the maximum is bold on gray and every value within MAX_ALLOWED_DIFF of
        the maximum is shaded gray.
    """
    datasets = ['14res', '14lap', '15res', '16res', 'Average']
    metrics = ['P', 'R', 'F1']
    # Values at most this far below the column maximum are shaded as "close to best"
    MAX_ALLOWED_DIFF = 0.5

    # Two parallel tables: formatted strings for display, raw numbers for highlighting
    styled_data = []
    numeric_data = []

    for model in df['Model'].unique():
        model_data = df[df['Model'] == model]
        is_paper_baseline = (model in baselines)
        row_data = {'Model': model}
        numeric_row = {'Model': model}

        for dataset in datasets:
            dataset_data = model_data if dataset == 'Average' else model_data[model_data['Dataset'] == dataset]
            for metric in metrics:
                key = f"{dataset}_{metric}"
                stat_columns = (metric, f"{metric}_std", f"{metric}_diff")
                if dataset == 'Average':
                    # Mean over all rows of this model
                    value, std, diff = (model_data[column].mean() for column in stat_columns)
                elif not dataset_data.empty:
                    value, std, diff = (dataset_data[column].values[0] for column in stat_columns)
                else:
                    # Model has no result for this dataset
                    row_data[key] = "-"
                    numeric_row[key] = np.nan
                    continue
                row_data[key] = format_value(value, std, diff, show_std, show_diff, is_paper_baseline)
                numeric_row[key] = value
        styled_data.append(row_data)
        numeric_data.append(numeric_row)

    styled_df = pd.DataFrame(styled_data)
    numeric_df = pd.DataFrame(numeric_data)

    column_tuples = [('Model', '')] + [(dataset, metric) for dataset in datasets for metric in metrics]
    styled_df.columns = pd.MultiIndex.from_tuples(column_tuples)
    numeric_df.columns = pd.MultiIndex.from_tuples(column_tuples)

    def highlight_max_and_close(s):
        # One CSS string per cell of a numeric column; NaN cells stay white
        max_value = s.max()
        cell_styles = []
        for score in s:
            if score == max_value:
                cell_styles.append('font-weight:bold; background-color: #d3d3d3; color: black;')
            elif (max_value - score) <= MAX_ALLOWED_DIFF:
                cell_styles.append('background-color: #d3d3d3; color: black;')
            else:
                cell_styles.append('background-color: white; color: black;')
        return cell_styles

    # The highlight is computed on numeric_df because styled_df only holds formatted strings
    return styled_df.style.set_properties(**{
        'text-align': 'center',
        'padding': '8px',
        'border': '1px solid #666',
        'background-color': 'white',
        'color': 'black'
    }).set_table_styles([ #type: ignore
        {'selector': 'th', 'props': [('background-color', '#e6e6e6'), ('font-weight', 'bold'), ('text-align', 'center'), ('color', 'black')]},
        {'selector': 'th.col_heading.level0', 'props': [('font-size', '1.1em'), ('border-right', '2px solid #666'), ('color', 'black')]},
        {'selector': 'th.col_heading.level1', 'props': [('font-size', '0.9em'), ('color', 'black')]},
        {'selector': 'td', 'props': [('border-right', '1px solid #666')]},
        {'selector': 'td:nth-child(3n+2)', 'props': [('border-right', '2px solid #666')]},
        {'selector': 'caption', 'props': [('caption-side', 'top'), ('font-size', '1.2em'), ('font-weight', 'bold'), ('color', 'black')]}
    ]).apply(lambda x: highlight_max_and_close(numeric_df.loc[:, x.name]), axis=0, subset=pd.IndexSlice[:, datasets]).set_caption(caption) #type: ignore

In [10]:
def plot_fraction_comparison(
    api: wandb.Api,
    project_name: str, 
    run_configs: list[dict[str, str | int]],
    fractions: list[float],
    *,
    show_error_bars: bool = True,
    show_all_metrics: bool = False,
    show_baseline_diff: bool = False,
    baseline_name: str = "Baseline",
    title: Optional[str] = None,
    figsize_single: Tuple[int, int] = (10, 4),
    figsize_all: Tuple[int, int] = (10, 3),
    legend_columns: int = 3,
    legend_fontsize: int = 11,
    legend_position: Tuple[float, float] = (0.5, 0.05),
    y_margin_factor: float = 0.1,
    x_limits: Tuple[float, float] = (0.1, 1.1),
    grid_alpha: float = 0.7,
    annotation_fontsize: int = 9,
    line_width: float = 2.5,
    print_raw_values: bool = False,
    positive_y_offset: int = 5,
    negative_y_offset: int = -12
) -> plt.Figure: #type: ignore
    """
    Plot line graphs showing how performance metrics change across different dataset fractions.
    Can show either just F1 score or all three metrics (P, R, F1) side by side.

    For every (fraction, configuration) pair the plotted value is the mean of the metric over
    all matching runs; the error bar is the mean of the per-dataset standard deviations.
    A model is only drawn when it has data for at least two fractions.

    Args:
        api: Weights & Biases API instance
        project_name: Name of the W&B project to fetch data from
        run_configs: List of config dicts with keys 'run_name', 'display_name', and optionally 'color'
            ('color' is a key of the module-level ``colors`` palette)
        fractions: List of dataset fractions to plot (between 0 and 1)
        show_error_bars: Whether to show error bars on the plot points
        show_all_metrics: Whether to show all metrics (P,R,F1) or just F1
        show_baseline_diff: If True, shows the difference from baseline for each point
        baseline_name: Display name of the baseline model. Its line is drawn on top of the
            others, and it is the reference when show_baseline_diff is True
        title: Optional title for the single-metric plot. If None and show_all_metrics is False,
            a default title is used. No title is set when show_all_metrics is True
        figsize_single: Figure size when showing single metric
        figsize_all: Figure size when showing all metrics
        legend_columns: Number of columns in the legend
        legend_fontsize: Font size for legend text
        legend_position: (x,y) coordinates for legend placement
        y_margin_factor: Factor to add margin to y-axis limits
        x_limits: (min, max) limits for x-axis
        grid_alpha: Transparency of grid lines
        annotation_fontsize: Font size for difference annotations
        line_width: Width of plotted lines
        print_raw_values: If True, prints the per-dataset mean and std of P, R and F1 for every
            configuration and fraction
        positive_y_offset: Offset (in points) of the first annotation above the highest point of
            a fraction; further annotations are stacked in steps of 10 points
        negative_y_offset: Offset (in points) of the first annotation below the lowest point of
            a fraction; further annotations are stacked in steps of 10 points
        
    Returns:
        matplotlib.pyplot.Figure: The generated plot figure

    Raises:
        ValueError: If no runs were found for any configuration and fraction
    """
    # Fetch the runs of every (fraction, configuration) pair into one frame
    all_data = pd.concat([
        fetch_run_data(api, project_name, config['run_name'], fraction=fraction)
        for fraction in tqdm(fractions, desc="Fetching fraction data") 
        for config in run_configs
    ]) #type: ignore

    if show_all_metrics:
        metrics = ['test/tuple_prec_epoch', 'test/tuple_rec_epoch', 'test/tuple_f1_epoch']
        metric_names = ['Precision', 'Recall', 'F1']
    else:
        metrics = ['test/tuple_f1_epoch']
        metric_names = ['F1']
    
    # Calculate metrics for each model and fraction
    results = []
    for fraction in fractions:
        for config in run_configs:
            run_name = config['run_name']
            # NOTE(review): both filters are substring matches, so a run_name that is a prefix
            # of another one (or a fraction such as 0.1 vs 0.15) selects the runs of both.
            # Confirm the W&B run names are unambiguous.
            fraction_data = all_data[
                (all_data['run_name'].str.contains(f"Fraction_{fraction}", regex=False)) & 
                (all_data['run_name'].str.contains(run_name, regex=False))
            ]
            
            # Skip if no data found for this configuration
            if len(fraction_data) == 0:
                print(f"No data found for {run_name} at fraction {fraction}")
                continue
            # Calculate metrics for each dataset's runs
            dataset_groups = fraction_data.groupby('dataset')
            
            # Print dataset-specific results in a nicely formatted table
            if print_raw_values:
                print(f"\n=== Results for {config['display_name']} at fraction {fraction} ===")
                print(f"{'Dataset':<15} {'Precision':<15} {'Recall':<15} {'F1':<15}")
                print("-" * 60)
                for dataset_name, group in dataset_groups:
                    prec_mean = group['test/tuple_prec_epoch'].mean()
                    prec_std = group['test/tuple_prec_epoch'].std()
                    rec_mean = group['test/tuple_rec_epoch'].mean()
                    rec_std = group['test/tuple_rec_epoch'].std()
                    f1_mean = group['test/tuple_f1_epoch'].mean()
                    f1_std = group['test/tuple_f1_epoch'].std()
                    print(f"{dataset_name:<15} {prec_mean:>6.3f} ± {prec_std:>4.3f}  {rec_mean:>6.3f} ± {rec_std:>4.3f}  {f1_mean:>6.3f} ± {f1_std:>4.3f}")
                print("-" * 60)
            
            # Average over all runs; spread is the mean of the per-dataset standard deviations
            for metric, metric_name in zip(metrics, metric_names):
                dataset_stds = [group[metric].std() for _, group in dataset_groups]
                results.append({
                    'Fraction': fraction,
                    'Model': config['display_name'], 
                    'Metric': metric_name,
                    'Average': fraction_data[metric].mean(),
                    'Std': np.mean(dataset_stds)
                })

    # Fail early with a clear message instead of a KeyError on the empty frame below
    if not results:
        raise ValueError("No data found for any run configuration and fraction")

    results_df = pd.DataFrame(results)
    palette_keys = list(colors.keys())

    def _model_color(model_name, model_idx):
        """Return the configured palette color of a model, or a positional fallback color."""
        config = next((c for c in run_configs if c['display_name'] == model_name), None)
        color_key = config.get('color', None) if config else None
        if color_key is not None:
            return colors[str(color_key)]
        # Step through the palette in threes so the fallback picks one variant per color family
        return colors[palette_keys[model_idx*3 % len(palette_keys)]]

    # Create the plot; squeeze=False keeps the axes iterable for the single-metric case too
    fig, axes_grid = plt.subplots(
        1, len(metric_names),
        figsize=figsize_all if show_all_metrics else figsize_single,
        squeeze=False
    )
    axes = axes_grid[0]
        
    # Find global min and max for y-axis scaling (shared by all panels)
    y_min = results_df['Average'].min() - results_df['Std'].max()
    y_max = results_df['Average'].max() + results_df['Std'].max()
    y_margin = (y_max - y_min) * y_margin_factor
    y_min -= y_margin
    y_max += y_margin
        
    for ax_idx, (metric_name, ax) in enumerate(zip(metric_names, axes)):
        metric_data = results_df[results_df['Metric'] == metric_name]

        # Resolve every model's color once so lines and annotations always agree
        model_order = list(metric_data['Model'].unique())
        model_colors = {name: _model_color(name, idx) for idx, name in enumerate(model_order)}

        # Per fraction: anchor points (lowest/highest average) and differences to the baseline
        fraction_anchors = {}
        fraction_diffs = {}
        if show_baseline_diff:
            baseline_data = metric_data[metric_data['Model'] == baseline_name].set_index('Fraction')['Average']
            for fraction in metric_data['Fraction'].unique():
                fraction_models = metric_data[metric_data['Fraction'] == fraction]
                fraction_anchors[fraction] = {
                    'min': fraction_models['Average'].min(),
                    'max': fraction_models['Average'].max()
                }
                fraction_diffs[fraction] = {'pos': [], 'neg': []}

                # Without baseline runs at this fraction there is nothing to compare against
                if fraction not in baseline_data.index:
                    continue
                baseline_value = float(baseline_data.loc[fraction]) #type: ignore

                for _, row in fraction_models.iterrows():
                    if row['Model'] == baseline_name:
                        continue
                    diff = float(row['Average']) - baseline_value
                    bucket = 'pos' if diff >= 0 else 'neg'
                    fraction_diffs[fraction][bucket].append((diff, row['Model'], model_colors[row['Model']]))

                # Smallest absolute difference ends up closest to the anchor point
                for bucket in ('pos', 'neg'):
                    fraction_diffs[fraction][bucket].sort(key=lambda x: abs(x[0]))
        
        for model, model_data in metric_data.groupby('Model'):
            # A line needs at least two fractions
            if len(model_data) < 2:
                continue

            plot_color = model_colors[model]
            # Keep the baseline on top of the other lines, whatever it is called
            layer = 10 if model == baseline_name else 5

            ax.plot(model_data['Fraction'], model_data['Average'],
                   marker='o', label=model if ax_idx == len(metric_names)-1 else None,
                   color=plot_color, linewidth=line_width,
                   zorder=layer)
            
            if show_error_bars:
                ax.errorbar(model_data["Fraction"], model_data["Average"],
                          yerr=model_data["Std"],
                          fmt='none', ecolor=plot_color, capsize=5,
                          zorder=layer)
                              
        # Add annotations after all lines are plotted: gains above the highest point,
        # losses below the lowest point, stacked in steps of 10 points
        if show_baseline_diff:
            placements = (
                ('pos', 'max', positive_y_offset, 10),
                ('neg', 'min', negative_y_offset, -10),
            )
            for fraction, diffs in fraction_diffs.items():
                for bucket, anchor, first_offset, step in placements:
                    for i, (diff, _, color) in enumerate(diffs[bucket]):
                        diff_text = f"{diff:+.1f}" if abs(diff) >= 0.1 else f"{diff:+.2f}"
                        ax.annotate(r"\textbf{" + diff_text + "}", 
                                  xy=(fraction, fraction_anchors[fraction][anchor]),
                                  xytext=(-13, first_offset + i * step), textcoords='offset points',
                                  fontsize=annotation_fontsize, fontweight='bold', color=color)

        ax.set_xlabel('Dataset Fraction')
        ax.set_ylabel(f'Average \\textbf{{{metric_name}}} Score')
        ax.grid(True, linestyle='--', alpha=grid_alpha, zorder=0)
        ax.set_ylim(y_min, y_max)
        ax.set_xlim(*x_limits)
    if not show_all_metrics and title is not None:
        axes[0].set_title(title)
    elif not show_all_metrics:
        axes[0].set_title('Enhanced Model Performance Across Different Dataset Fractions')
        
    # Place legend with handles and labels in config order
    handles, labels = axes[-1].get_legend_handles_labels()
    config_order = [c['display_name'] for c in run_configs]
    ordered_pairs = [(h, l) for l in config_order for h, lab in zip(handles, labels) if lab == l]
    ordered_handles, ordered_labels = zip(*ordered_pairs) if ordered_pairs else ([], [])
    
    fig.legend(ordered_handles, ordered_labels, 
              bbox_to_anchor=legend_position, loc='upper center',
              ncol=legend_columns, frameon=True, fontsize=legend_fontsize,
              edgecolor='black', fancybox=False)
    plt.tight_layout()
    
    return fig


In [11]:
def plot_config_comparison(
    comparison_table: pd.DataFrame,
    groupings: Dict[str, List[str]],
    title: str = "Configuration Comparison",
    show_baseline: bool = True,
    baseline_name: str = "Reimpl. Baseline",
    figsize: Tuple[int, int] = (8, 5),
    rotation: int = 45,
    show_values: bool = True
) -> plt.Figure: #type: ignore
    """
    Plot a comparison of different model configurations showing P, R, and F1 scores.

    Every group becomes one cluster of three bars (P, R, F1). Bar height is the mean over the
    rows of ``comparison_table`` whose "Model" is listed for the group, and the error bar is
    their standard deviation.
    
    Args:
        comparison_table: DataFrame containing the results; needs the columns "Model", "P", "R" and "F1"
        groupings: Dict mapping group names to lists of run names
        title: Plot title (no title is set for None or an empty string)
        show_baseline: Whether to show baseline as reference lines
        baseline_name: Name of the baseline in the data
        figsize: Figure size
        rotation: X-tick label rotation
        show_values: Whether to show numeric values on bars

    Returns:
        The ``matplotlib.pyplot`` module with the new figure as its current figure.
        NOTE(review): the annotation promises a Figure, but the module has always been returned;
        this is kept so existing callers that chain pyplot calls keep working.

    Raises:
        ValueError: If none of the groups matches any row of ``comparison_table``
    """
    plt.figure(figsize=figsize)
    
    # Prepare data: one row of mean/std per non-empty group
    results = []
    metrics = ['P', 'R', 'F1']
    
    for group_name, config_names in groupings.items():
        group_data = comparison_table[comparison_table["Model"].isin(config_names)]
        
        if group_data.empty:
            print(f"Warning: No data found for group {group_name}")
            continue
            
        results.append({
            "Group": group_name,
            **{f"Mean_{m}": group_data[m].mean() for m in metrics},
            **{f"Std_{m}": group_data[m].std() for m in metrics}
        })
    
    results_df = pd.DataFrame(results)
    
    if results_df.empty:
        raise ValueError("No data found for any group")
    
    # Plot setup
    x = np.arange(len(results_df))
    width = 0.25  # Narrower bars to fit all three metrics
    
    # Calculate y-axis bounds with padding. A group with a single row has a NaN std, which
    # would make plain min()/max() order-dependent, so it counts as "no spread" here.
    bar_extents = pd.concat([
        results_df[f"Mean_{m}"] + sign * results_df[f"Std_{m}"].fillna(0)
        for m in metrics
        for sign in (1, -1)
    ])
    y_min = max(0, bar_extents.min() - 2)
    y_max = min(100, bar_extents.max() + 2)
    
    # Create the grouped bars, one bar series per metric
    for i, metric in enumerate(metrics):
        bars = plt.bar(
            x + (i - 1) * width,  # Center the groups
            results_df[f"Mean_{metric}"],
            width,
            yerr=results_df[f"Std_{metric}"],
            ecolor=colors["gray_medium"],
            capsize=5,
            color=f'C{i}',
            label=metric
        )
        
        # Add value labels if requested
        if show_values:
            for bar in bars:
                height = bar.get_height()
                plt.text(
                    bar.get_x() + bar.get_width()/2.,
                    height - 0.5,  # Offset text position down slightly
                    f'\\textbf{{{height:.1f}}}',
                    ha='center', 
                    va='top',  # Label hangs just below the top edge of the bar
                    fontsize=10
                )

    # Add baseline reference lines (same color cycle as the bars) if requested and available
    if show_baseline:
        baseline_data = comparison_table[comparison_table["Model"] == baseline_name]
        if not baseline_data.empty:
            for i, metric in enumerate(metrics):
                baseline_value = baseline_data[metric].mean()
                plt.axhline(
                    y=baseline_value,
                    color=f'C{i}',
                    linestyle='--',
                    alpha=1.0,
                    linewidth=1.5,
                    label=f'{metric} Baseline ({baseline_value:.2f})'
                )
    
    # Customize plot
    plt.xlabel("Configuration")
    plt.ylabel("Score")
    if title is not None and title != "":
        plt.title(title)
    plt.xticks(x, results_df["Group"], rotation=rotation, ha="center")
    # Reorder legend entries so each metric's two entries (bar and baseline) sit next to each other
    handles, labels = plt.gca().get_legend_handles_labels()
    order = []
    for i in range(len(metrics)):
        if show_baseline:
            order.append(i + len(metrics))
        order.append(i)
    # The baseline lines only exist when the baseline was found in the table; drop positions
    # without a legend entry instead of failing with an IndexError
    order = [idx for idx in order if idx < len(handles)]
    plt.legend([handles[i] for i in order], [labels[i] for i in order],
              bbox_to_anchor=(0.5, -0.1), loc='upper center', ncol=3)
    plt.ylim(y_min, y_max)
    plt.grid(True, linestyle="--", alpha=0.4)
    plt.tight_layout()
    
    return plt

### Additional Analysis

In [12]:
def run_significance_tests(raw_values: pd.DataFrame, 
                         baseline_name: Optional[str] = None, 
                         comparison_models: Optional[List[str]] = None,
                         show_per_dataset: bool = False,
                         alpha: float = 0.05) -> None:
    """Run significance tests for each metric between model pairs.

    For every metric in ``EVAL_METRICS`` and every model pair, a paired t-test is run per
    dataset in ``DATASETS`` and the per-dataset p-values are combined with Fisher's method.
    Results are printed as a table; nothing is returned.

    Datasets with fewer than two paired values, or whose t-test is undefined (NaN p-value,
    e.g. identical differences), are left out of the combination. A pair without any usable
    dataset is reported as N/A.

    Args:
        raw_values: DataFrame containing the results; needs the columns 'Model', 'Dataset'
            and one column per metric
        baseline_name: Optional name of baseline model to compare against
        comparison_models: Optional list of models to compare against baseline/each other
        show_per_dataset: If True, shows p-values for individual datasets
        alpha: Significance level for hypothesis testing (default: 0.05)
    """
    # unique() may return an ExtensionArray (string/categorical dtypes); a list works for both
    all_models = list(raw_values['Model'].unique())
    
    # Determine which models to compare
    if baseline_name is not None:
        baseline_models = [baseline_name]
        comparison_models = comparison_models or [m for m in all_models if m != baseline_name]
    else:
        baseline_models = comparison_models or all_models
        comparison_models = comparison_models or all_models
            
    for metric in EVAL_METRICS:
        print(f"\n{metric} Significance Tests")
        print("-" * 150)
        print(f"{'Model 1':<55} {'Model 2':<55} {'p-value':>10} {'Sig':>5}")
        print("-" * 150)
            
        for model1 in baseline_models:
            for model2 in comparison_models:
                if model1 == model2:
                    continue
                    
                p_values = []
                for dataset in DATASETS:
                    in_dataset = raw_values['Dataset'] == dataset
                    values1 = raw_values[(raw_values['Model'] == model1) & in_dataset][metric].to_numpy()
                    values2 = raw_values[(raw_values['Model'] == model2) & in_dataset][metric].to_numpy()

                    # A paired test needs at least two pairs; a single pair only yields NaN
                    min_len = min(len(values1), len(values2))
                    if min_len < 2:
                        continue

                    # NOTE(review): samples are paired by row position after truncating to the
                    # shorter series - assumes both models list their runs in the same order
                    _, p_value = stats.ttest_rel(values1[:min_len], values2[:min_len])

                    # Undefined test (e.g. zero variance of the differences): a NaN would turn
                    # the combined p-value into NaN as well, so skip it
                    if np.isnan(p_value):
                        continue
                    p_values.append(p_value)
                    
                    if show_per_dataset:
                        significance = "✓" if p_value < alpha else "✗" #type: ignore
                        print(f"{dataset:>8}: {p_value:>8.4f} {significance}")
                
                if p_values:
                    _, combined_p = stats.combine_pvalues(p_values, method='fisher')
                    combined_p = combined_p.item() if isinstance(combined_p, np.ndarray) else combined_p
                    significance = "✓" if combined_p < alpha else "✗" #type: ignore
                    print(f"{model1:<55} {model2:<55} {combined_p:>10.4f} {significance:>5}")
                else:
                    # Same column widths as the header so the N/A row stays aligned
                    print(f"{model1:<55} {model2:<55} {'N/A':>10} {'N/A':>5}")

In [13]:
def create_performance_heatmap(
    data: pd.DataFrame | Dict[float, pd.DataFrame],
    metric: str = 'F1',
    figsize: Tuple[int, int] = (12, 8),  # Keep original size since it's page-constrained
    cmap: str = 'RdYlBu',
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    annotate: bool = True,
    fmt: str = '.2f',
    title: Optional[str] = None,
    rotate_xlabels: int = 45,
    normalize_by_dataset: bool = False,
    show_normalized_values: bool = False,
    models: Optional[List[str]] = None,
    split_by: str = 'dataset',
    baseline_model: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    show_ylabels: bool = True,
    show_std: bool = False,
) -> plt.Figure: #type: ignore
    """
    Create a heatmap visualization of model performance across datasets or fractions.

    Rows are models (in reversed index order), columns are datasets or fractions, and each
    cell holds the mean of ``metric``. With split_by='fraction' the cell is the average of the
    per-dataset means (and of the per-dataset standard deviations).
    
    Args:
        data: DataFrame containing the results (split_by='dataset'), or a dictionary mapping
            fractions to such DataFrames (split_by='fraction'). The frames need the columns
            'Model', 'Dataset' and ``metric``
        metric: Metric to visualize ('P', 'R', or 'F1')
        figsize: Figure size as (width, height); only used when no ``ax`` is given
        cmap: Colormap to use
        vmin: Minimum value for colormap scaling
        vmax: Maximum value for colormap scaling
        annotate: Whether to show values in cells
        fmt: Format string for cell annotations
        title: Custom title (defaults to metric name if None)
        rotate_xlabels: Rotation angle for x-axis labels
        normalize_by_dataset: Whether to normalize scores within each dataset/fraction for coloring
        show_normalized_values: Whether to display normalized values instead of raw scores
        models: Optional list of models to include in comparison
        split_by: Whether to split by 'dataset' or 'fraction'
        baseline_model: When normalizing, center colors around this model's performance
        ax: Optional axes to plot on
        show_ylabels: Whether to show y-axis labels (model names)
        show_std: Whether to show standard deviation in annotations
    
    Returns:
        matplotlib Figure object that contains the heatmap axes

    Raises:
        ValueError: If split_by is invalid, data has the wrong type for split_by, or no
            data is left to plot
    """
    # Validate input type matches split_by parameter
    split_by = split_by.lower()
    if split_by not in ['dataset', 'fraction']:
        raise ValueError("split_by must be either 'dataset' or 'fraction'")
        
    if split_by == 'fraction':
        if not isinstance(data, dict):
            raise ValueError("For split_by='fraction', data must be a dictionary mapping fractions to DataFrames")
        # Stack the per-fraction frames, tagging every row with its fraction (inputs stay untouched)
        data = pd.concat(
            [df.assign(Fraction=fraction) for fraction, df in data.items()],
            axis=0
        )
    else:  # dataset
        if not isinstance(data, pd.DataFrame):
            raise ValueError("For split_by='dataset', data must be a single DataFrame")

    # Filter models if specified
    if models is not None:
        data = data[data['Model'].isin(models)]
        
    # Aggregate to one mean/std per (model, column) cell
    if split_by == 'fraction':
        grouping_col = 'Fraction'
        # For each model and fraction, first calculate mean and std per dataset
        dataset_stats = data.groupby(['Model', 'Fraction', 'Dataset'])[metric].agg(['mean', 'std']).reset_index()
        # Then average across datasets
        pivot_data = dataset_stats.groupby(['Model', 'Fraction']).agg({
            'mean': 'mean',  # Average of means across datasets
            'std': 'mean'    # Average of standard deviations across datasets
        }).reset_index()
        requested_order = sorted(pivot_data[grouping_col].unique())
    else:  # dataset
        grouping_col = 'Dataset'
        requested_order = DATASETS
        pivot_data = data.groupby(['Model', 'Dataset'])[metric].agg(['mean', 'std']).reset_index()
    
    # Pivot the mean values
    pivot_means = pivot_data.pivot_table(
        values='mean',
        index='Model',
        columns=grouping_col,
    )

    # Keep the requested column order, but only for columns that actually have data
    column_order = [col for col in requested_order if col in pivot_means.columns]
    if pivot_means.empty or not column_order:
        raise ValueError("No data available to plot for the given models and metric")

    # Order the columns and reverse the order of models
    pivot_means = pivot_means[column_order].iloc[::-1]

    # Pivot the std values if needed. pivot_table drops all-NaN rows/columns (e.g. a single
    # run per cell), so align the result to the layout of the means explicitly.
    if show_std:
        pivot_stds = pivot_data.pivot_table(
            values='std',
            index='Model',
            columns=grouping_col,
        ).reindex(index=pivot_means.index, columns=pivot_means.columns)
    
    # Store raw data for annotations if needed
    raw_data = pivot_means.copy()
    
    # Normalize if requested
    if normalize_by_dataset:
        if baseline_model is not None and baseline_model in pivot_means.index:
            # Center around baseline model's performance
            baseline_values = pivot_means.loc[baseline_model]
            pivot_means = pivot_means.subtract(baseline_values)
        else:
            # Standard normalization (z-score per column) if no baseline specified
            pivot_means = (pivot_means - pivot_means.mean()) / pivot_means.std()
    
    # Create figure if no axes provided
    if ax is None:
        plt.figure(figsize=figsize)
        ax = plt.gca()
    
    full_metric_names = {
        'P': 'Precision',
        'R': 'Recall',
        'F1': 'F1 Score'
    }

    # Prepare annotation text
    if annotate:
        if show_std:
            # Two-line LaTeX cell text: the raw mean in a larger font, the std in small italics below
            annotations = np.array([[f"\\Large {m:.2f}\n{{\\small \\textit{{($\\sigma$ {s:.2f})}}}}" for m, s in zip(row_means, row_stds)] 
                                  for row_means, row_stds in zip(raw_data.values, pivot_stds.values)])
        else:
            annotations = raw_data if not show_normalized_values else pivot_means
    else:
        annotations = False

    # Set up the heatmap with optimized font sizes for readability
    sns.heatmap(
        pivot_means,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        annot=annotations,
        fmt='' if show_std else fmt,  # std annotations are pre-formatted strings
        annot_kws={'ha': 'center', 'va': 'center', 'ma': 'center', 'size': 11},  # Optimized font size
        # Fall back to the raw metric name for metrics without a long name
        cbar_kws={'label': f'{"Normalized " if normalize_by_dataset else ""}\\textbf{{{full_metric_names.get(metric, metric)}}} Score'},
        center=0 if normalize_by_dataset else ((vmax + vmin)/2 if (vmin is not None and vmax is not None) else None),
        ax=ax
    )
    
    # Customize the plot with optimized font sizes
    split_type = "Datasets" if split_by == 'dataset' else "Fractions"
    title_suffix = f" (Normalized to {baseline_model})" if (normalize_by_dataset and baseline_model) else (" (Normalized Colors)" if normalize_by_dataset else "")
    ax.set_title(title if title is not None else f'{metric} Performance Across {split_type}{title_suffix}', fontsize=12, pad=20)
    ax.set_xlabel(split_type, fontsize=11, labelpad=10)
    ax.set_ylabel('Model' if show_ylabels else '', fontsize=11, labelpad=10)
    
    # Rotate x-axis labels with optimized font size
    plt.setp(ax.get_xticklabels(), rotation=rotate_xlabels, ha='right', fontsize=10)
    plt.setp(ax.get_yticklabels(), rotation=0, fontsize=10)
    
    # Hide y-axis labels if requested
    if not show_ylabels:
        ax.set_yticklabels([])
    
    # Return the figure that owns the axes; the current pyplot figure may be a different one
    # when the caller supplied `ax`
    return ax.figure

def plot_all_metrics_heatmaps(
    data: pd.DataFrame | Dict[float, pd.DataFrame],
    metrics: List[str] = ['P', 'R', 'F1'],
    figsize: Tuple[int, int] = (12, 4),
    baseline_model: Optional[str] = None,
    full_title_input: Optional[str] = None,
    **kwargs
) -> plt.Figure: #type: ignore
    """
    Create heatmaps for multiple metrics side by side.

    One panel is drawn per metric by ``create_performance_heatmap``. Only the
    leftmost panel shows y-axis labels, and only one panel carries the shared
    title (the second panel, or the only panel when a single metric is given).

    Args:
        data: DataFrame or dictionary mapping fractions to DataFrames containing the results
        metrics: List of metrics to visualize (must not be empty)
        figsize: Size of the overall figure
        baseline_model: Forwarded to create_performance_heatmap; when truthy and no
            explicit title is given, it is also named in the default title
        full_title_input: Explicit shared title. When None, a default
            "Performance Across Datasets (Normalized ...)" title is used
        **kwargs: Additional arguments passed to create_performance_heatmap

    Returns:
        The figure holding all panels (it is also displayed via ``plt.show()``).

    Raises:
        ValueError: If ``metrics`` is empty.
    """
    if not metrics:
        raise ValueError("metrics must contain at least one metric name")

    # squeeze=False always returns a 2D array of Axes. Without it, a single
    # metric yields a bare Axes object and iterating over it fails.
    fig, axes_grid = plt.subplots(1, len(metrics), figsize=figsize, squeeze=False)
    axes = axes_grid[0]

    # The shared title sits on the second panel (the middle one for the default
    # three metrics); with a single panel it falls back to that panel.
    title_index = min(1, len(metrics) - 1)

    # NOTE(review): the default title always says "Normalized", regardless of
    # what normalization options are forwarded through **kwargs.
    if full_title_input is not None:
        shared_title = full_title_input
    elif baseline_model:
        shared_title = f"Performance Across Datasets (Normalized to {baseline_model})"
    else:
        shared_title = "Performance Across Datasets (Normalized)"

    for i, (ax, metric) in enumerate(zip(axes, metrics)):
        # Only show y-labels for leftmost plot
        show_ylabels = (i == 0)
        # An empty string (not None) is passed to the other panels so that they
        # do not fall back to create_performance_heatmap's own default title.
        full_title = shared_title if i == title_index else ""
        create_performance_heatmap(data, metric=metric, ax=ax, show_ylabels=show_ylabels, title=full_title, baseline_model=baseline_model, **kwargs)

    plt.tight_layout()
    plt.show()
    return fig


## Main Analysis

### Analysis specific configurations

In [14]:
# Project identifier handed to create_comparison_table together with the wandb API handle.
project_name = "bartabsa-reproduce"
# Run name of the reference experiment; later cells in this section re-assign it per sweep.
BASELINE_NAME = "Architecture_Cleanup_Baseline"

### Ablations

In [15]:
BASELINE_NAME = "Architecture_Cleanup_Baseline"
Reimpl_Baseline_Display = "Reimpl. Baseline"
prefix = "Architecture_Cleanup_"
decoder_configs: list[dict[str, str|int]] = [
    {"run_name": BASELINE_NAME, "display_name": Reimpl_Baseline_Display, "color": "tertiary"},
    {"run_name": f"{prefix}Normalization_Only", "display_name": "Normalization Only"},
    {"run_name": f"{prefix}Full_Enhanced", "display_name": "Full Enhanced"},
    {"run_name": f"{prefix}Full_Enhanced_Fixed", "display_name": "Full Enhanced Fixed"},
    {"run_name": f"{prefix}Full_Enhanced_Bart_Attention", "display_name": "Full Enhanced Bart Attention"},
    {"run_name": f"{prefix}Full_Enhanced_Bart_Attention_Fixed", "display_name": "Full Enhanced Bart Attention Fixed"},
    {"run_name": f"{prefix}Abl_Full_Enhanced_Value_and_DimNorm", "display_name": "Full Enhanced + Value & DimNorm"},
    {"run_name": f"{prefix}Abl_Full_Enhanced_Value_and_DimNorm_Fixed", "display_name": "Full Enhanced + Value & DimNorm Fixed"},
    {"run_name": f"{prefix}Abl_Full_Enhanced_RMS_Everywhere", "display_name": "Full Enhanced + RMS Everywhere"},
    {"run_name": f"{prefix}Abl_Full_Enhanced_RMS_Everywhere_Fixed", "display_name": "Full Enhanced + RMS Everywhere Fixed"},
    {"run_name": f"{prefix}Abl_Full_Enhanced_RMS_Nowhere", "display_name": "Full Enhanced + No RMS"},
    {"run_name": f"{prefix}Abl_Full_Enhanced_RMS_Nowhere_Fixed", "display_name": "Full Enhanced + No RMS Fixed"},
    {"run_name": f"{prefix}Full_Enhanced_NoGating_NoFinalNorm", "display_name": "Full Enhanced No Gating No Final Norm"},
    {"run_name": f"{prefix}Full_Enhanced_NoGating_FinalNorm", "display_name": "Full Enhanced No Gating Final Norm"},
    {"run_name": f"{prefix}Full_Enhanced_Bart_Attention_NoGating", "display_name": "Full Enhanced Bart Attention No Gating"},
    {"run_name": f"{prefix}Full_Enhanced_Bart_Attention_NoGating_NoFinalNorm", "display_name": "Full Enhanced Bart Attention No Gating No Final Norm"},
    {"run_name": f"{prefix}Nearly_Antons_Dream_Architecture", "display_name": "Nearly Anton's Dream Architecture"},
    {"run_name": f"{prefix}Antons_Dream_Architecture", "display_name": "Anton's Dream Architecture"},
]

fractions = [1.0]

In [16]:
# Build the comparison table for the decoder ablation sweep, with the paper's
# numbers as an extra "Paper Baseline" row.
subtasks_sweep_table = create_comparison_table(
    api,
    project_name,
    decoder_configs,
    [("Paper Baseline", paper_baseline)],
    BASELINE_NAME,
    fraction=1.0,
    baseline_paper_compare_index=0,
)

# Render with standard deviations and differences; in the recorded output each
# cell reads "mean (σ std, ±diff)".
styled_table = style_dataframe(
    subtasks_sweep_table,
    show_std=True,
    show_diff=True,
    baselines=["Paper Baseline"],
)
display(styled_table)

Fetching data for Architecture_Cleanup_Baseline_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 9371.70it/s]
Fetching data for Architecture_Cleanup_Normalization_Only_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 32564.47it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 51941.85it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Fixed_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 35469.80it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 85423.71it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_Fixed_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 45964.98it/s]
Fetching data for Architecture_Cleanup_Abl_Full_Enhanced_Value_and_DimNorm_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 40563.87it/s]
Fetching data for Architecture_Cleanup_Abl_Full_Enhanced_Value_and_DimNorm_Fixed_Fraction_1.0: 100%|██████████|

Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Reimpl. Baseline,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,Normalization Only,"76.10 (σ 1.65, -0.33)","73.33 (σ 2.01, -0.23)","74.66 (σ 1.67, -0.27)","66.13 (σ 2.59, -0.11)","59.21 (σ 4.70, -0.11)","62.44 (σ 3.69, -0.09)","63.77 (σ 3.16, -0.69)","62.02 (σ 1.55, +1.69)","62.87 (σ 2.28, +0.57)","69.45 (σ 1.45, +0.93)","68.55 (σ 2.53, +0.49)","68.96 (σ 1.44, +0.70)","68.86 (σ 2.21, -0.05)","65.78 (σ 2.70, +0.46)","67.23 (σ 2.27, +0.23)"
3,Full Enhanced,"76.86 (σ 2.04, +0.43)","75.03 (σ 1.46, +1.47)","75.91 (σ 1.70, +0.98)","68.91 (σ 1.96, +2.67)","62.27 (σ 1.38, +2.95)","65.40 (σ 1.47, +2.87)","65.81 (σ 2.15, +1.34)","62.76 (σ 1.95, +2.43)","64.21 (σ 1.27, +1.91)","72.39 (σ 1.00, +3.87)","70.37 (σ 1.20, +2.32)","71.35 (σ 0.72, +3.09)","70.99 (σ 1.79, +2.08)","67.61 (σ 1.50, +2.29)","69.22 (σ 1.29, +2.21)"
4,Full Enhanced Fixed,"77.76 (σ 1.04, +1.33)","75.81 (σ 1.75, +2.25)","76.74 (σ 1.09, +1.81)","68.27 (σ 1.29, +2.03)","61.98 (σ 1.45, +2.66)","64.96 (σ 1.08, +2.43)","65.33 (σ 2.08, +0.86)","63.27 (σ 1.21, +2.94)","64.27 (σ 1.62, +1.98)","69.82 (σ 1.27, +1.29)","69.24 (σ 1.60, +1.19)","69.51 (σ 1.26, +1.25)","70.29 (σ 1.42, +1.38)","67.58 (σ 1.50, +2.26)","68.87 (σ 1.26, +1.87)"
5,Full Enhanced Bart Attention,"77.06 (σ 1.82, +0.63)","74.77 (σ 1.50, +1.21)","75.88 (σ 1.59, +0.94)","67.67 (σ 2.50, +1.43)","59.64 (σ 2.73, +0.32)","63.38 (σ 2.51, +0.85)","64.47 (σ 0.77, +0.00)","61.70 (σ 2.00, +1.37)","63.04 (σ 1.41, +0.74)","72.87 (σ 1.72, +4.35)","70.92 (σ 1.11, +2.87)","71.88 (σ 1.39, +3.61)","70.52 (σ 1.70, +1.61)","66.76 (σ 1.84, +1.44)","68.54 (σ 1.72, +1.54)"
6,Full Enhanced Bart Attention Fixed,"77.70 (σ 0.99, +1.28)","75.88 (σ 1.38, +2.32)","76.76 (σ 1.18, +1.82)","68.56 (σ 1.01, +2.32)","62.13 (σ 2.20, +2.81)","65.16 (σ 1.50, +2.63)","64.67 (σ 3.39, +0.20)","62.91 (σ 2.32, +2.58)","63.74 (σ 2.42, +1.44)","70.02 (σ 1.95, +1.50)","69.54 (σ 1.77, +1.49)","69.77 (σ 1.78, +1.51)","70.24 (σ 1.83, +1.33)","67.61 (σ 1.92, +2.30)","68.86 (σ 1.72, +1.85)"
7,Full Enhanced + Value & DimNorm,"77.33 (σ 1.87, +0.90)","75.43 (σ 1.44, +1.87)","76.34 (σ 1.50, +1.40)","67.66 (σ 1.66, +1.43)","61.65 (σ 1.85, +2.34)","64.49 (σ 1.32, +1.96)","65.18 (σ 0.70, +0.71)","61.45 (σ 1.52, +1.12)","63.24 (σ 0.96, +0.94)","71.36 (σ 1.73, +2.84)","69.38 (σ 1.32, +1.32)","70.34 (σ 1.39, +2.08)","70.38 (σ 1.49, +1.47)","66.98 (σ 1.53, +1.66)","68.60 (σ 1.29, +1.60)"
8,Full Enhanced + Value & DimNorm Fixed,"77.07 (σ 1.63, +0.64)","75.46 (σ 1.23, +1.90)","76.22 (σ 0.98, +1.29)","67.96 (σ 2.18, +1.72)","62.17 (σ 3.06, +2.85)","64.90 (σ 2.17, +2.37)","65.11 (σ 1.30, +0.65)","62.12 (σ 2.09, +1.79)","63.56 (σ 1.36, +1.27)","69.98 (σ 2.85, +1.45)","68.31 (σ 2.65, +0.25)","69.11 (σ 2.58, +0.85)","70.03 (σ 1.99, +1.12)","67.01 (σ 2.26, +1.70)","68.45 (σ 1.77, +1.44)"
9,Full Enhanced + RMS Everywhere,"73.03 (σ 2.47, -3.39)","71.29 (σ 2.07, -2.27)","72.12 (σ 2.17, -2.82)","62.22 (σ 2.42, -4.02)","58.28 (σ 3.18, -1.04)","60.10 (σ 1.71, -2.43)","62.45 (σ 0.88, -2.02)","58.66 (σ 2.09, -1.68)","60.47 (σ 1.23, -1.83)","69.78 (σ 1.05, +1.25)","65.47 (σ 3.49, -2.58)","67.52 (σ 2.28, -0.74)","66.87 (σ 1.71, -2.04)","63.42 (σ 2.71, -1.89)","65.05 (σ 1.85, -1.95)"


In [17]:
# BART-attention ablations: the re-implemented baseline against every
# "Full Enhanced Bart Attention" variant.
BASELINE_NAME = "Architecture_Cleanup_Baseline"
Reimpl_Baseline_Display = "Reimpl. Baseline"
prefix = "Architecture_Cleanup_"

bart_decoder_configs: list[dict[str, str|int]] = [
    {"run_name": BASELINE_NAME, "display_name": Reimpl_Baseline_Display, "color": "tertiary"},
    # All variants share one stem; only the trailing part of the run name and
    # of the display name differs, listed as (run-name tail, display-name tail).
    *(
        {
            "run_name": f"{prefix}Full_Enhanced_Bart_Attention{run_tail}",
            "display_name": f"Full Enhanced Bart Attention{label_tail}",
        }
        for run_tail, label_tail in (
            ("", ""),
            ("_Fixed", " Fixed"),
            ("_NoGating", " No Gating"),
            ("_NoGating_NoFinalNorm", " No Gating No Final Norm"),
            ("_No_Enc_Gating", " No Enc Gating"),
            ("_RMS_Only", " RMS Only"),
            ("_No_RMS", " No RMS"),
            ("_No_Enc_Gating_No_Final_Norm", " No Enc Gating No Final Norm"),
            ("_No_Enc_Gating_No_Final_Norm_RMS_Enc_Norm", " No Enc Gating No Final Norm RMS Enc Norm"),
        )
    ),
]

fractions = [1.0]

# Fetch the runs and show the table with std-dev and diff annotations.
# NOTE: this overwrites subtasks_sweep_table / styled_table from the previous cell.
subtasks_sweep_table = create_comparison_table(
    api,
    project_name,
    bart_decoder_configs,
    [("Paper Baseline", paper_baseline)],
    BASELINE_NAME,
    fraction=1.0,
    baseline_paper_compare_index=0,
)
styled_table = style_dataframe(
    subtasks_sweep_table,
    show_std=True,
    show_diff=True,
    baselines=["Paper Baseline"],
)
display(styled_table)

Fetching data for Architecture_Cleanup_Baseline_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 129653.91it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 102550.22it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_Fixed_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 46733.19it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_NoGating_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 15972.22it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_NoGating_NoFinalNorm_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 22932.23it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_No_Enc_Gating_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 78179.01it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_RMS_Only_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 32627.80it/s]
Fetching data for Architectu

Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Reimpl. Baseline,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,Full Enhanced Bart Attention,"77.06 (σ 1.82, +0.63)","74.77 (σ 1.50, +1.21)","75.88 (σ 1.59, +0.94)","67.67 (σ 2.50, +1.43)","59.64 (σ 2.73, +0.32)","63.38 (σ 2.51, +0.85)","64.47 (σ 0.77, +0.00)","61.70 (σ 2.00, +1.37)","63.04 (σ 1.41, +0.74)","72.87 (σ 1.72, +4.35)","70.92 (σ 1.11, +2.87)","71.88 (σ 1.39, +3.61)","70.52 (σ 1.70, +1.61)","66.76 (σ 1.84, +1.44)","68.54 (σ 1.72, +1.54)"
3,Full Enhanced Bart Attention Fixed,"77.70 (σ 0.99, +1.28)","75.88 (σ 1.38, +2.32)","76.76 (σ 1.18, +1.82)","68.56 (σ 1.01, +2.32)","62.13 (σ 2.20, +2.81)","65.16 (σ 1.50, +2.63)","64.67 (σ 3.39, +0.20)","62.91 (σ 2.32, +2.58)","63.74 (σ 2.42, +1.44)","70.02 (σ 1.95, +1.50)","69.54 (σ 1.77, +1.49)","69.77 (σ 1.78, +1.51)","70.24 (σ 1.83, +1.33)","67.61 (σ 1.92, +2.30)","68.86 (σ 1.72, +1.85)"
4,Full Enhanced Bart Attention No Gating,"75.46 (σ 2.05, -0.97)","73.76 (σ 0.86, +0.20)","74.57 (σ 1.44, -0.36)","66.37 (σ 2.58, +0.13)","59.67 (σ 2.00, +0.35)","62.80 (σ 1.76, +0.27)","62.83 (σ 1.85, -1.64)","58.09 (σ 1.39, -2.24)","60.36 (σ 1.53, -1.94)","71.30 (σ 2.07, +2.78)","68.72 (σ 0.83, +0.67)","69.97 (σ 1.33, +1.71)","68.99 (σ 2.14, +0.08)","65.06 (σ 1.27, -0.26)","66.93 (σ 1.51, -0.08)"
5,Full Enhanced Bart Attention No Gating No Final Norm,"73.95 (σ 2.75, -2.47)","71.47 (σ 3.81, -2.09)","72.62 (σ 2.73, -2.31)","63.47 (σ 1.73, -2.77)","57.35 (σ 2.27, -1.97)","60.22 (σ 1.81, -2.31)","61.46 (σ 2.37, -3.01)","58.54 (σ 2.32, -1.79)","59.95 (σ 2.27, -2.35)","69.26 (σ 1.08, +0.74)","69.16 (σ 1.05, +1.11)","69.20 (σ 0.95, +0.94)","67.04 (σ 1.98, -1.88)","64.13 (σ 2.36, -1.19)","65.50 (σ 1.94, -1.51)"
6,Full Enhanced Bart Attention No Enc Gating,"77.52 (σ 1.20, +1.09)","75.74 (σ 1.20, +2.18)","76.59 (σ 1.09, +1.66)","68.90 (σ 2.39, +2.66)","61.11 (σ 3.48, +1.79)","64.72 (σ 2.58, +2.19)","63.69 (σ 0.76, -0.78)","62.78 (σ 1.15, +2.45)","63.22 (σ 0.75, +0.92)","71.81 (σ 2.48, +3.29)","70.71 (σ 1.91, +2.66)","71.24 (σ 2.01, +2.98)","70.48 (σ 1.71, +1.56)","67.59 (σ 1.93, +2.27)","68.94 (σ 1.61, +1.94)"
7,Full Enhanced Bart Attention RMS Only,"73.24 (σ 2.31, -3.19)","69.79 (σ 4.25, -3.77)","71.42 (σ 3.25, -3.51)","62.83 (σ 2.55, -3.41)","57.74 (σ 4.82, -1.57)","60.10 (σ 3.41, -2.43)","59.05 (σ 2.35, -5.41)","54.26 (σ 3.12, -6.07)","56.53 (σ 2.55, -5.77)","67.58 (σ 1.84, -0.94)","65.22 (σ 2.26, -2.83)","66.37 (σ 2.01, -1.89)","65.68 (σ 2.26, -3.24)","61.76 (σ 3.61, -3.56)","63.61 (σ 2.80, -3.40)"
8,Full Enhanced Bart Attention No RMS,"65.38 (σ 4.37, -11.04)","56.79 (σ 5.46, -16.77)","60.75 (σ 4.99, -14.18)","55.98 (σ 4.04, -10.26)","45.99 (σ 3.68, -13.33)","50.47 (σ 3.77, -12.06)","33.32 (σ 2.12, -31.14)","28.02 (σ 3.45, -32.31)","30.37 (σ 2.66, -31.93)","54.52 (σ 2.22, -14.00)","50.70 (σ 1.69, -17.36)","52.52 (σ 1.92, -15.74)","52.30 (σ 3.19, -16.61)","45.37 (σ 3.57, -19.94)","48.53 (σ 3.33, -18.48)"
9,Full Enhanced Bart Attention No Enc Gating No Final Norm,"76.54 (σ 1.06, +0.11)","74.48 (σ 2.92, +0.92)","75.43 (σ 1.22, +0.50)","67.91 (σ 3.42, +1.67)","61.08 (σ 1.63, +1.76)","64.27 (σ 1.85, +1.74)","62.97 (σ 2.35, -1.49)","61.42 (σ 1.56, +1.09)","62.17 (σ 1.89, -0.12)","70.73 (σ 2.48, +2.21)","66.89 (σ 2.54, -1.16)","68.75 (σ 2.49, +0.49)","69.54 (σ 2.33, +0.62)","65.97 (σ 2.16, +0.65)","67.66 (σ 1.86, +0.65)"


In [18]:
# Large-model ablations. From here on BASELINE_NAME points at the large
# baseline run, so cells executed afterwards see this value.
BASELINE_NAME = "Architecture_Cleanup_Baseline_Large"
Reimpl_Baseline_Display = "Baseline Large"
prefix = "Architecture_Cleanup_"

# NOTE(review): the name bart_decoder_configs is reused from the previous cell
# although this list holds the large-model variants.
bart_decoder_configs: list[dict[str, str|int]] = [
    {"run_name": BASELINE_NAME, "display_name": Reimpl_Baseline_Display, "color": "tertiary"},
    # Remaining variants as (run-name suffix, display name); order is preserved.
    *(
        {"run_name": prefix + run_suffix, "display_name": label}
        for run_suffix, label in (
            ("Enc_Norm_Large", "Enc Norm Large"),
            ("Full_Enhanced_Model_Large", "Full Enhanced Model Large"),
            ("Full_Enhanced_Model_Large_No_Enc_Gating", "Full Enhanced Model Large No Enc Gating"),
            ("Full_Enhanced_Model_Large_No_Dec_Gating", "Full Enhanced Model Large No Dec Gating"),
            ("Full_Enhanced_Model_Large_No_Final_Norm", "Full Enhanced Model Large No Final Norm"),
            ("Full_Enhanced_Model_Large_NormalLayerNorm", "Full Enhanced Model Large Normal Layer Norm"),
        )
    ),
]

fractions = [1.0]

# Fetch the runs and show the table with std-dev and diff annotations.
subtasks_sweep_table = create_comparison_table(
    api,
    project_name,
    bart_decoder_configs,
    [("Paper Baseline", paper_baseline)],
    BASELINE_NAME,
    fraction=1.0,
    baseline_paper_compare_index=0,
)
styled_table = style_dataframe(
    subtasks_sweep_table,
    show_std=True,
    show_diff=True,
    baselines=["Paper Baseline"],
)
display(styled_table)

Fetching data for Architecture_Cleanup_Baseline_Large_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 54471.48it/s]
Fetching data for Architecture_Cleanup_Enc_Norm_Large_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 74698.20it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 72253.30it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_No_Enc_Gating_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 60831.09it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_No_Dec_Gating_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 81601.25it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_No_Final_Norm_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 31595.51it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_NormalLayerNorm_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 45148.59it/s]
Fetching data: 100%|██████████| 7/7 [00:02<00:00,  2.3

Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Baseline Large,"79.13 (σ 1.92, +13.61)","75.79 (σ 4.42, +10.80)","77.32 (σ 2.24, +12.07)","56.33 (σ 31.54, -5.08)","52.09 (σ 29.18, -4.10)","54.11 (σ 30.30, -4.58)","64.63 (σ 7.51, +5.49)","64.18 (σ 8.56, +4.80)","64.35 (σ 7.79, +5.09)","74.99 (σ 3.34, +8.39)","75.13 (σ 4.46, +6.45)","75.05 (σ 3.78, +7.43)","68.77 (σ 11.08, +5.60)","66.80 (σ 11.65, +4.49)","67.71 (σ 11.03, +5.00)"
2,Enc Norm Large,"76.51 (σ 6.09, -2.63)","75.76 (σ 5.51, -0.03)","76.10 (σ 5.72, -1.23)","67.95 (σ 10.81, +11.62)","62.57 (σ 9.05, +10.49)","65.12 (σ 9.79, +11.01)","66.78 (σ 0.81, +2.15)","68.47 (σ 1.22, +4.28)","67.60 (σ 0.82, +3.25)","74.21 (σ 0.73, -0.79)","76.34 (σ 1.60, +1.21)","75.25 (σ 1.10, +0.20)","71.36 (σ 4.61, +2.59)","70.78 (σ 4.34, +3.98)","71.02 (σ 4.36, +3.31)"
3,Full Enhanced Model Large,"80.26 (σ 1.21, +1.13)","81.43 (σ 1.48, +5.64)","80.82 (σ 1.02, +3.50)","70.85 (σ 0.73, +14.51)","64.94 (σ 1.42, +12.86)","67.75 (σ 0.64, +13.63)","67.44 (σ 1.39, +2.82)","67.63 (σ 1.76, +3.45)","67.52 (σ 1.14, +3.17)","76.08 (σ 2.32, +1.09)","77.10 (σ 1.70, +1.96)","76.58 (σ 1.94, +1.53)","73.66 (σ 1.41, +4.89)","72.77 (σ 1.59, +5.98)","73.16 (σ 1.18, +5.46)"
4,Full Enhanced Model Large No Enc Gating,"80.02 (σ 2.69, +0.89)","80.12 (σ 2.24, +4.33)","80.05 (σ 2.35, +2.72)","71.93 (σ 1.40, +15.59)","67.60 (σ 2.13, +15.52)","69.66 (σ 0.61, +15.54)","68.22 (σ 1.22, +3.60)","68.02 (σ 2.53, +3.84)","68.09 (σ 1.39, +3.74)","74.54 (σ 1.41, -0.46)","75.53 (σ 3.93, +0.39)","74.99 (σ 2.58, -0.05)","73.68 (σ 1.68, +4.90)","72.82 (σ 2.71, +6.02)","73.20 (σ 1.73, +5.49)"
5,Full Enhanced Model Large No Dec Gating,"77.36 (σ 1.59, -1.77)","77.91 (σ 1.52, +2.12)","77.59 (σ 0.78, +0.27)","51.23 (σ 31.00, -5.10)","46.39 (σ 27.37, -5.70)","48.66 (σ 29.05, -5.46)","52.00 (σ 29.14, -12.63)","49.97 (σ 28.39, -14.21)","50.93 (σ 28.71, -13.42)","68.22 (σ 7.62, -6.78)","68.13 (σ 6.62, -7.00)","68.16 (σ 7.07, -6.89)","62.20 (σ 17.34, -6.57)","60.60 (σ 15.98, -6.20)","61.33 (σ 16.40, -6.38)"
6,Full Enhanced Model Large No Final Norm,"80.72 (σ 0.66, +1.59)","81.05 (σ 1.27, +5.26)","80.87 (σ 0.89, +3.55)","70.34 (σ 2.80, +14.01)","64.73 (σ 2.00, +12.65)","67.41 (σ 2.33, +13.29)","66.39 (σ 2.60, +1.76)","66.47 (σ 4.48, +2.28)","66.38 (σ 3.16, +2.03)","75.66 (σ 2.23, +0.67)","78.32 (σ 1.34, +3.19)","76.95 (σ 1.67, +1.91)","73.28 (σ 2.07, +4.51)","72.64 (σ 2.27, +5.85)","72.90 (σ 2.01, +5.19)"
7,Full Enhanced Model Large Normal Layer Norm,"78.34 (σ 1.81, -0.80)","77.28 (σ 2.68, +1.49)","77.78 (σ 2.21, +0.45)","66.36 (σ 9.36, +10.02)","62.46 (σ 9.45, +10.37)","64.33 (σ 9.39, +10.22)","67.72 (σ 1.98, +3.10)","68.27 (σ 1.84, +4.08)","67.98 (σ 1.62, +3.63)","74.05 (σ 2.98, -0.94)","74.08 (σ 5.85, -1.05)","74.03 (σ 4.37, -1.02)","71.62 (σ 4.03, +2.84)","70.52 (σ 4.95, +3.72)","71.03 (σ 4.40, +3.32)"


- `use_final_layer_norm` seems to consistently make the results somewhat worse.


In [19]:
# Encoder/decoder pairing sweep: the BART baseline and its enhanced variant,
# followed by every other pairing in a plain and an "Enhanced" flavour.
BASELINE_NAME = "Seq2Seq_Tests_BART_Baseline"
Reimpl_Baseline_Display = "BART Baseline"
prefix = "Seq2Seq_Tests_"

# NOTE(review): bart_decoder_configs is reused here for the Seq2Seq pairings.
bart_decoder_configs: list[dict[str, str|int]] = [
    {"run_name": BASELINE_NAME, "display_name": Reimpl_Baseline_Display, "color": "tertiary"},
    {"run_name": f"{prefix}BART_Baseline_Enhanced", "display_name": "BART Baseline Enhanced"},
    # For each pairing emit the plain run first, then its "_Enhanced" twin.
    *(
        {
            "run_name": f"{prefix}{pairing}{run_tail}",
            "display_name": f"{pairing}{label_tail}",
        }
        for pairing in (
            "Roberta2Roberta",
            "Roberta2GPT2",
            "BartLarge",
            "Bert2Bert",
            "Bert2GPT2",
            "RobertaLarge2RobertaLarge",
            "RobertaLarge2GPT2Medium",
            "GPT2GPT-Base",
            "GPT2GPT-Medium",
            "GPT2GPT-Large",
            "GPT2GPT-XL",
        )
        for run_tail, label_tail in (("", ""), ("_Enhanced", " Enhanced"))
    ),
]

fractions = [1.0]

# Fetch the runs and show the table with std-dev and diff annotations.
subtasks_sweep_table = create_comparison_table(
    api,
    project_name,
    bart_decoder_configs,
    [("Paper Baseline", paper_baseline)],
    BASELINE_NAME,
    fraction=1.0,
    baseline_paper_compare_index=0,
)
styled_table = style_dataframe(
    subtasks_sweep_table,
    show_std=True,
    show_diff=True,
    baselines=["Paper Baseline"],
)
display(styled_table)

Fetching data for Seq2Seq_Tests_BART_Baseline_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 76678.32it/s]
Fetching data for Seq2Seq_Tests_BART_Baseline_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 62091.84it/s]
Fetching data for Seq2Seq_Tests_Roberta2Roberta_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 53601.33it/s]
Fetching data for Seq2Seq_Tests_Roberta2Roberta_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 48573.29it/s]
Fetching data for Seq2Seq_Tests_Roberta2GPT2_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 75709.46it/s]
Fetching data for Seq2Seq_Tests_Roberta2GPT2_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 62322.50it/s]
Fetching data for Seq2Seq_Tests_BartLarge_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 56987.83it/s]
Fetching data for Seq2Seq_Tests_BartLarge_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 42625.04it/s]
Fetching data for Seq2Seq_Tests_Bert2Bert_Fraction_1.0: 100%|██████████| 20/20 [00

Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,BART Baseline,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,BART Baseline Enhanced,"77.76 (σ 1.04, +1.33)","75.81 (σ 1.75, +2.25)","76.74 (σ 1.09, +1.81)","68.27 (σ 1.29, +2.03)","61.98 (σ 1.45, +2.66)","64.96 (σ 1.08, +2.43)","65.33 (σ 2.08, +0.86)","63.27 (σ 1.21, +2.94)","64.27 (σ 1.62, +1.98)","69.82 (σ 1.27, +1.29)","69.24 (σ 1.60, +1.19)","69.51 (σ 1.26, +1.25)","70.29 (σ 1.42, +1.38)","67.58 (σ 1.50, +2.26)","68.87 (σ 1.26, +1.87)"
3,Roberta2Roberta,"69.87 (σ 15.63, -6.56)","63.52 (σ 21.86, -10.04)","66.17 (σ 19.72, -8.76)","40.84 (σ 37.34, -25.40)","35.39 (σ 32.52, -23.93)","37.89 (σ 34.72, -24.64)","62.01 (σ 9.87, -2.45)","60.65 (σ 9.37, +0.32)","61.29 (σ 9.49, -1.01)","58.09 (σ 32.52, -10.43)","56.61 (σ 31.73, -11.44)","57.30 (σ 32.06, -10.96)","57.70 (σ 23.84, -11.21)","54.04 (σ 23.87, -11.27)","55.66 (σ 24.00, -11.34)"
4,Roberta2Roberta Enhanced,"76.81 (σ 2.19, +0.38)","73.02 (σ 1.82, -0.54)","74.82 (σ 0.91, -0.12)","67.18 (σ 4.50, +0.94)","57.35 (σ 8.88, -1.96)","61.73 (σ 7.17, -0.80)","62.00 (σ 9.59, -2.46)","61.01 (σ 9.41, +0.68)","61.49 (σ 9.49, -0.80)","69.68 (σ 8.42, +1.16)","66.73 (σ 10.87, -1.32)","68.12 (σ 9.75, -0.14)","68.92 (σ 6.17, +0.01)","64.53 (σ 7.75, -0.79)","66.54 (σ 6.83, -0.47)"
5,Roberta2GPT2,"74.21 (σ 2.95, -2.21)","69.00 (σ 3.89, -4.56)","71.47 (σ 3.28, -3.46)","65.45 (σ 3.05, -0.79)","53.85 (σ 2.77, -5.47)","59.06 (σ 2.77, -3.47)","60.16 (σ 1.19, -4.30)","56.82 (σ 1.31, -3.51)","58.43 (σ 1.11, -3.87)","70.72 (σ 0.66, +2.20)","67.45 (σ 1.46, -0.60)","69.04 (σ 1.07, +0.77)","67.64 (σ 1.96, -1.28)","61.78 (σ 2.36, -3.53)","64.50 (σ 2.05, -2.51)"
6,Roberta2GPT2 Enhanced,"75.54 (σ 1.41, -0.89)","71.13 (σ 1.48, -2.43)","73.23 (σ 1.31, -1.70)","66.50 (σ 1.15, +0.26)","59.27 (σ 1.60, -0.05)","62.64 (σ 0.75, +0.11)","63.96 (σ 3.63, -0.51)","60.83 (σ 3.19, +0.50)","62.34 (σ 3.37, +0.04)","72.21 (σ 1.72, +3.69)","70.91 (σ 1.54, +2.86)","71.55 (σ 1.54, +3.29)","69.55 (σ 1.98, +0.64)","65.54 (σ 1.95, +0.22)","67.44 (σ 1.74, +0.43)"
7,BartLarge,"79.13 (σ 1.92, +2.71)","75.79 (σ 4.42, +2.23)","77.32 (σ 2.24, +2.39)","56.33 (σ 31.54, -9.90)","52.09 (σ 29.18, -7.23)","54.11 (σ 30.30, -8.42)","64.63 (σ 7.51, +0.16)","64.18 (σ 8.56, +3.85)","64.35 (σ 7.79, +2.05)","74.99 (σ 3.34, +6.47)","75.13 (σ 4.46, +7.08)","75.05 (σ 3.78, +6.79)","68.77 (σ 11.08, -0.14)","66.80 (σ 11.65, +1.48)","67.71 (σ 11.03, +0.70)"
8,BartLarge Enhanced,"78.85 (σ 1.51, +2.42)","79.47 (σ 0.52, +5.91)","79.13 (σ 0.84, +4.20)","68.90 (σ 2.83, +2.66)","66.09 (σ 1.47, +6.77)","67.43 (σ 1.53, +4.90)","67.00 (σ 1.32, +2.53)","67.58 (σ 1.68, +7.25)","67.27 (σ 1.20, +4.98)","73.56 (σ 5.33, +5.04)","74.63 (σ 6.09, +6.57)","74.07 (σ 5.63, +5.81)","72.08 (σ 2.75, +3.16)","71.94 (σ 2.44, +6.63)","71.98 (σ 2.30, +4.97)"
9,Bert2Bert,"71.13 (σ 2.58, -5.29)","67.36 (σ 1.77, -6.20)","69.17 (σ 2.05, -5.76)","61.60 (σ 2.36, -4.64)","53.13 (σ 4.32, -6.19)","56.93 (σ 2.55, -5.60)","62.33 (σ 2.71, -2.14)","57.06 (σ 2.62, -3.27)","59.56 (σ 2.55, -2.74)","70.05 (σ 1.81, +1.53)","66.98 (σ 1.35, -1.07)","68.46 (σ 1.30, +0.20)","66.28 (σ 2.36, -2.63)","61.13 (σ 2.51, -4.18)","63.53 (σ 2.11, -3.48)"


In [20]:
# Export the comparison table as LaTeX source.
# NOTE: subtasks_sweep_table is re-assigned by several cells above, so this
# exports whichever comparison cell was executed last (the Seq2Seq sweep in
# the recorded output).
# Presumably copied so create_latex_table cannot alter the displayed table.
cloned_table = subtasks_sweep_table.copy()

# print() rather than a bare expression, so the LaTeX is emitted verbatim.
print(
    create_latex_table(
        cloned_table,
        caption="WIP (Auto Export): Different Encoder-Decoder Configurations",
        label="tab:encoder-decoder-configs",
        baselines=["Paper Baseline"],
        bold_best=True,
        mode="single",
    )
)

\begin{table}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{WIP (Auto Export): Different Encoder-Decoder Configurations}
        \begin{tabular}{l*{3}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \\
            \midrule
            Paper Baseline & 63.17 & 62.31 & 62.70 \\
            \midrule
            BART Baseline & 68.91 & 65.31 & 67.01 \\
             & {\scriptsize \textit{($\sigma$ 2.13)}} & {\scriptsize \textit{($\sigma$ 3.17)}} & {\scriptsize \textit{($\sigma$ 2.44)}} \\
            \midrule
            BART Baseline Enhanced & 70.29 & 67.58 & 68.87 \\
             & {\scriptsize \textit{($\sigma$ 1.42)}} & {\scriptsize \textit{($\sigma$ 1.50)}} & {\scriptsize \textit{($\sigma$ 1.26)}} \\
            \midrule
            Roberta2Roberta & 57.70 & 54.04 & 55.66 \\
             & {\scriptsize \textit{($\sigma$ 23.84)}} & {\scriptsize \textit{($\sigma$ 23.87)}} & {\scriptsize \texti

In [21]:
# Weight Randomization Experiments
# Compares the baseline against runs with randomized encoder, decoder, or both.
prefix = "Weight_Randomization_"

weight_randomization_configs: list[dict[str, str|int]] = [
    {"run_name": f"{prefix}Baseline", "display_name": "Baseline", "color": "tertiary"},
    # One entry per randomized component.
    *(
        {"run_name": f"{prefix}Randomize_{component}", "display_name": f"Randomize {component}"}
        for component in ("Encoder", "Decoder", "Both")
    ),
]

fractions = [1.0]

# Unlike the cells above, the baseline run name is passed directly and
# BASELINE_NAME is left untouched.
# NOTE(review): the recorded output shows 40 fetched runs for Randomize_Encoder
# and Randomize_Both versus 20 for the other two; confirm this is intended.
weight_randomization_table = create_comparison_table(
    api,
    project_name,
    weight_randomization_configs,
    [("Paper Baseline", paper_baseline)],
    f"{prefix}Baseline",
    fraction=1.0,
    baseline_paper_compare_index=0,
)
styled_table = style_dataframe(
    weight_randomization_table,
    show_std=True,
    show_diff=True,
    baselines=["Paper Baseline"],
)
display(styled_table)

Fetching data:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching data for Weight_Randomization_Baseline_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 56833.39it/s]
Fetching data for Weight_Randomization_Randomize_Encoder_Fraction_1.0: 100%|██████████| 40/40 [00:00<00:00, 65896.37it/s]
Fetching data for Weight_Randomization_Randomize_Decoder_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 70611.18it/s]
Fetching data for Weight_Randomization_Randomize_Both_Fraction_1.0: 100%|██████████| 40/40 [00:00<00:00, 87701.08it/s]
Fetching data: 100%|██████████| 4/4 [00:01<00:00,  2.97it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Baseline,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,Randomize Encoder,"28.34 (σ 8.32, -48.09)","22.59 (σ 7.01, -50.97)","25.11 (σ 7.57, -49.83)","16.84 (σ 14.77, -49.40)","11.45 (σ 10.39, -47.87)","13.60 (σ 12.14, -48.93)","11.11 (σ 12.24, -53.36)","7.63 (σ 8.44, -52.70)","9.02 (σ 9.93, -53.27)","17.12 (σ 14.95, -51.40)","12.01 (σ 10.34, -56.04)","14.05 (σ 12.12, -54.21)","18.35 (σ 12.57, -50.56)","13.42 (σ 9.04, -51.90)","15.45 (σ 10.44, -51.56)"
3,Randomize Decoder,"73.86 (σ 1.09, -2.57)","68.98 (σ 1.05, -4.58)","71.32 (σ 1.07, -3.62)","62.37 (σ 1.87, -3.87)","54.24 (σ 2.85, -5.08)","58.00 (σ 2.37, -4.53)","58.62 (σ 1.38, -5.84)","56.39 (σ 3.00, -3.94)","57.46 (σ 2.19, -4.84)","67.43 (σ 1.72, -1.09)","66.67 (σ 2.21, -1.38)","67.04 (σ 1.90, -1.22)","65.57 (σ 1.51, -3.34)","61.57 (σ 2.28, -3.74)","63.46 (σ 1.88, -3.55)"
4,Randomize Both,"42.21 (σ 2.64, -34.21)","31.22 (σ 1.98, -42.34)","35.85 (σ 1.92, -39.09)","29.84 (σ 0.82, -36.40)","22.03 (σ 1.83, -37.29)","25.30 (σ 1.38, -37.23)","32.53 (σ 3.08, -31.93)","24.69 (σ 1.45, -35.64)","28.05 (σ 2.00, -34.25)","41.52 (σ 3.77, -27.00)","32.32 (σ 1.20, -35.73)","36.30 (σ 2.02, -31.96)","36.53 (σ 2.57, -32.39)","27.57 (σ 1.62, -37.75)","31.38 (σ 1.83, -35.63)"


In [22]:
# Emit the LaTeX source for the weight-randomization comparison.
# The renderer is handed a copy of the DataFrame rather than the original
# (presumably to shield `weight_randomization_table` from in-place edits — TODO confirm).
# mode="single" selects the compact layout: in the recorded output only the
# averaged P/R/F1 columns are rendered.
cloned_table = weight_randomization_table.copy()
print(
    create_latex_table(
        cloned_table,
        caption="Weight Randomization Experiments",
        label="tab:weight-randomization",
        baselines=["Paper Baseline"],
        bold_best=True,
        mode="single",
    )
)

\begin{table}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{Weight Randomization Experiments}
        \begin{tabular}{l*{3}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \\
            \midrule
            Paper Baseline & 63.17 & 62.31 & 62.70 \\
            \midrule
            Baseline & \textbf{68.91} & \textbf{65.31} & \textbf{67.01} \\
             & {\scriptsize \textit{($\sigma$ 2.13)}} & {\scriptsize \textit{($\sigma$ 3.17)}} & {\scriptsize \textit{($\sigma$ 2.44)}} \\
            \midrule
            Randomize Encoder & 18.35 & 13.42 & 15.45 \\
             & {\scriptsize \textit{($\sigma$ 12.57)}} & {\scriptsize \textit{($\sigma$ 9.04)}} & {\scriptsize \textit{($\sigma$ 10.44)}} \\
            \midrule
            Randomize Decoder & 65.57 & 61.57 & 63.46 \\
             & {\scriptsize \textit{($\sigma$ 1.51)}} & {\scriptsize \textit{($\sigma$ 2.28)}} & {\scriptsize \textit{($\sig

In [33]:
# MLP ablation: the enhanced reimplementation versus the baseline variant without the MLP.
#
# NOTE: this cell rebinds the notebook-wide names `prefix`, `BASELINE_NAME` and
# `Reimpl_Baseline_Display`. The "For the Paper" cells further down read
# `Reimpl_Baseline_Display` and re-assign `BASELINE_NAME`, so keep the cell order in mind.
prefix = "Architecture_Cleanup_"
BASELINE_NAME = f"{prefix}Full_Enhanced"
Reimpl_Baseline_Display = "Reimpl. Baseline"

# One entry per run group: `run_name` is the group to fetch, `display_name` the row label.
# Run names are derived from `prefix` instead of repeating the literal (the original
# used an f-string without any placeholder for the second entry).
mlp_whatup_configs: list[dict[str, str|int]] = [
    {"run_name": BASELINE_NAME, "display_name": Reimpl_Baseline_Display, "color": "tertiary"},
    {"run_name": f"{prefix}Baseline_NoMLP", "display_name": "No MLP"},
]

# Not referenced anywhere in this cell (the call below passes fraction=1.0 directly);
# kept because other cells may read it — TODO confirm and remove if unused.
fractions = [1.0]

# With fraction=1.0 the fetch log shows the "<run_name>_Fraction_1.0" groups being loaded.
mlp_table = create_comparison_table(
    api,
    project_name,
    mlp_whatup_configs,
    [("Paper Baseline", paper_baseline)],
    BASELINE_NAME,
    fraction=1.0,
    baseline_paper_compare_index=0,
)
styled_table = style_dataframe(mlp_table, show_std=True, show_diff=True, baselines=["Paper Baseline"])
display(styled_table)

Fetching data for Architecture_Cleanup_Full_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 67108.86it/s]
Fetching data for Architecture_Cleanup_Baseline_NoMLP_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 27395.85it/s]
Fetching data: 100%|██████████| 2/2 [00:00<00:00,  4.02it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Reimpl. Baseline,"76.86 (σ 2.04, +11.34)","75.03 (σ 1.46, +10.04)","75.91 (σ 1.70, +10.66)","68.91 (σ 1.96, +7.50)","62.27 (σ 1.38, +6.08)","65.40 (σ 1.47, +6.71)","65.81 (σ 2.15, +6.67)","62.76 (σ 1.95, +3.38)","64.21 (σ 1.27, +4.95)","72.39 (σ 1.00, +5.79)","70.37 (σ 1.20, +1.69)","71.35 (σ 0.72, +3.73)","70.99 (σ 1.79, +7.82)","67.61 (σ 1.50, +5.30)","69.22 (σ 1.29, +6.51)"
2,No MLP,"74.54 (σ 1.16, -2.32)","70.13 (σ 2.78, -4.90)","72.23 (σ 1.93, -3.68)","64.36 (σ 2.99, -4.55)","55.10 (σ 5.61, -7.16)","59.31 (σ 4.40, -6.09)","63.12 (σ 1.16, -2.69)","59.04 (σ 0.94, -3.72)","61.00 (σ 0.99, -3.20)","69.49 (σ 1.49, -2.90)","65.28 (σ 2.29, -5.08)","67.31 (σ 1.85, -4.04)","67.88 (σ 1.70, -3.11)","62.39 (σ 2.90, -5.22)","64.96 (σ 2.29, -4.25)"


#### For the Paper

In [24]:
# Shared settings for the paper tables that follow.
# Run-group name of the re-implemented baseline; passed as the reference run below.
BASELINE_NAME = "Architecture_Cleanup_Baseline"
# Project identifier handed to create_comparison_table together with the wandb API handle.
project_name = "bartabsa-reproduce"

In [25]:
# Reimplementation vs. the numbers reported in the paper: a single run group
# (the re-implemented baseline) is compared against the "Paper Baseline" row.
reimplementation_configs = [
    dict(run_name=BASELINE_NAME, display_name=Reimpl_Baseline_Display),
]

# Build the comparison DataFrame. No `fraction` is passed here; the fetch log
# shows the run group being loaded under its plain name.
reimplementation_df = create_comparison_table(
    api,
    project_name,
    reimplementation_configs,
    [("Paper Baseline", paper_baseline)],
    BASELINE_NAME,
    baseline_paper_compare_index=0,
)

# Rich notebook rendering with standard deviations and deltas.
display(
    style_dataframe(
        reimplementation_df,
        show_std=True,
        show_diff=True,
        baselines=["Paper Baseline"],
    )
)



Fetching data for Architecture_Cleanup_Baseline: 100%|██████████| 20/20 [00:00<00:00, 183157.38it/s]
Fetching data: 100%|██████████| 1/1 [00:00<00:00, 225.46it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Reimpl. Baseline,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"


In [26]:
# LaTeX source for the reimplementation-vs-paper table.
# `mode` is intentionally not passed: the recorded output is the wide `table*`
# layout with P/R/F1 for every dataset plus the average. Other cells in this
# notebook pass mode="single" to obtain the compact, average-only layout.
print(
    create_latex_table(
        reimplementation_df,
        label="tab:reimplementation_results",
        caption="\\textit{work in progress:} Reimplementation Results Compared to Original BARTABSA. \\note{reimplementation significantly outperforms reported values; consistent improvements across all datasets; very strong baseline for our enhanced version}",
        baselines=["Paper Baseline"],
        bold_best=True,
    )
)

\begin{table*}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{\textit{work in progress:} Reimplementation Results Compared to Original BARTABSA. \note{reimplementation significantly outperforms reported values; consistent improvements across all datasets; very strong baseline for our enhanced version}}
    \begin{adjustbox}{width=\textwidth}
        \begin{tabular}{l*{15}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{14res} & \multicolumn{3}{c}{14lap} & \multicolumn{3}{c}{15res} & \multicolumn{3}{c}{16res} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \\
            \midrule
            Paper Baseline & 65.52 & 64.99 & 65.25 \vrule & 61.41 & 56.19 & 58.69 \vrule & 59.14 & 59.38 & 59.26 \vrule & 66.60 & \textbf{68.68} & 67.62 \vrule & 63.17 & 62.31 & 62.70 \\
            \midrule
            Reimpl. Baseline & \textbf{76.43} & \textbf{73.5

In [27]:
# Architecture ablation: every row is compared against the re-implemented baseline.
# BASELINE_NAME is re-assigned here so the cell does not depend on whichever
# earlier cell last touched it.
BASELINE_NAME = "Architecture_Cleanup_Baseline"

# Rows appear in the table in the order listed here: (run group, row label).
architecture_configs = [
    {"run_name": run, "display_name": shown}
    for run, shown in (
        (BASELINE_NAME, "Baseline"),
        ("Architecture_Cleanup_Normalization_Only", "+ Normalization"),
        ("Architecture_Cleanup_Full_Enhanced", "+ Attention"),
        ("Architecture_Cleanup_Full_Enhanced_Fixed", "+ Final Norm (spotty)"),
        ("Architecture_Cleanup_Full_Enhanced_Bart_Attention_Fixed", "w/ BART Attention instead"),
    )
]

# Comparison DataFrame (paper row first, then the configured runs).
architecture_df = create_comparison_table(
    api,
    project_name,
    architecture_configs,
    [("Paper Baseline", paper_baseline)],
    BASELINE_NAME,
    baseline_paper_compare_index=0,
)

# Notebook rendering with standard deviations and deltas.
display(
    style_dataframe(
        architecture_df,
        show_std=True,
        show_diff=True,
        baselines=["Paper Baseline"],
    )
)

# LaTeX source. `mode` is left at its default, which in the recorded output
# yields the wide per-dataset `table*` layout.
# NOTE(review): in the recorded output "+ Normalization" changes the average F1 by
# only +0.23 and "+ Attention" has the highest average F1 (69.22 vs. 68.86 for the
# BART-attention variant). The \note in the caption states otherwise — reconcile
# the text with the numbers before this leaves "work in progress".
print(
    create_latex_table(
        architecture_df,
        label="tab:architecture_improvements",
        caption="\\textit{work in progress:} Impact of Architecture Enhancements on Performance. \\note{normalization provides substantial improvements; full enhanced model with gating and attention mechanisms further boosts performance; BART attention variant provides best overall results}",
        baselines=["Paper Baseline"],
        bold_best=True,
    )
)

Fetching data for Architecture_Cleanup_Baseline: 100%|██████████| 20/20 [00:00<00:00, 186828.69it/s]
Fetching data for Architecture_Cleanup_Normalization_Only: 100%|██████████| 20/20 [00:00<00:00, 135082.25it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced: 100%|██████████| 20/20 [00:00<00:00, 234318.66it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Fixed: 100%|██████████| 20/20 [00:00<00:00, 178861.58it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Bart_Attention_Fixed: 100%|██████████| 20/20 [00:00<00:00, 223101.28it/s]
Fetching data: 100%|██████████| 5/5 [00:00<00:00, 789.50it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Baseline,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,+ Normalization,"76.10 (σ 1.65, -0.33)","73.33 (σ 2.01, -0.23)","74.66 (σ 1.67, -0.27)","66.13 (σ 2.59, -0.11)","59.21 (σ 4.70, -0.11)","62.44 (σ 3.69, -0.09)","63.77 (σ 3.16, -0.69)","62.02 (σ 1.55, +1.69)","62.87 (σ 2.28, +0.57)","69.45 (σ 1.45, +0.93)","68.55 (σ 2.53, +0.49)","68.96 (σ 1.44, +0.70)","68.86 (σ 2.21, -0.05)","65.78 (σ 2.70, +0.46)","67.23 (σ 2.27, +0.23)"
3,+ Attention,"76.86 (σ 2.04, +0.43)","75.03 (σ 1.46, +1.47)","75.91 (σ 1.70, +0.98)","68.91 (σ 1.96, +2.67)","62.27 (σ 1.38, +2.95)","65.40 (σ 1.47, +2.87)","65.81 (σ 2.15, +1.34)","62.76 (σ 1.95, +2.43)","64.21 (σ 1.27, +1.91)","72.39 (σ 1.00, +3.87)","70.37 (σ 1.20, +2.32)","71.35 (σ 0.72, +3.09)","70.99 (σ 1.79, +2.08)","67.61 (σ 1.50, +2.29)","69.22 (σ 1.29, +2.21)"
4,+ Final Norm (spotty),"77.76 (σ 1.04, +1.33)","75.81 (σ 1.75, +2.25)","76.74 (σ 1.09, +1.81)","68.27 (σ 1.29, +2.03)","61.98 (σ 1.45, +2.66)","64.96 (σ 1.08, +2.43)","65.33 (σ 2.08, +0.86)","63.27 (σ 1.21, +2.94)","64.27 (σ 1.62, +1.98)","69.82 (σ 1.27, +1.29)","69.24 (σ 1.60, +1.19)","69.51 (σ 1.26, +1.25)","70.29 (σ 1.42, +1.38)","67.58 (σ 1.50, +2.26)","68.87 (σ 1.26, +1.87)"
5,w/ BART Attention instead,"77.70 (σ 0.99, +1.28)","75.88 (σ 1.38, +2.32)","76.76 (σ 1.18, +1.82)","68.56 (σ 1.01, +2.32)","62.13 (σ 2.20, +2.81)","65.16 (σ 1.50, +2.63)","64.67 (σ 3.39, +0.20)","62.91 (σ 2.32, +2.58)","63.74 (σ 2.42, +1.44)","70.02 (σ 1.95, +1.50)","69.54 (σ 1.77, +1.49)","69.77 (σ 1.78, +1.51)","70.24 (σ 1.83, +1.33)","67.61 (σ 1.92, +2.30)","68.86 (σ 1.72, +1.85)"


\begin{table*}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{\textit{work in progress:} Impact of Architecture Enhancements on Performance. \note{normalization provides substantial improvements; full enhanced model with gating and attention mechanisms further boosts performance; BART attention variant provides best overall results}}
    \begin{adjustbox}{width=\textwidth}
        \begin{tabular}{l*{15}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{14res} & \multicolumn{3}{c}{14lap} & \multicolumn{3}{c}{15res} & \multicolumn{3}{c}{16res} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \\
            \midrule
            Paper Baseline & 65.52 & 64.99 & 65.25 \vrule & 61.41 & 56.19 & 58.69 \vrule & 59.14 & 59.38 & 59.26 \vrule & 66.60 & 68.68 & 67.62 \vrule & 63.17 & 62.31 & 62.70 \\
            \midrule
            Baseline & 76.43 & 73.56 

In [28]:
# Model-size comparison: BART Base vs. BART Large, each with and without the enhancements.
# The plain BART Base run group serves as the reference row.
BART_BASE_BASELINE = "Seq2Seq_Tests_BART_Baseline_Fraction_1.0"

# (run group, row label) in display order.
scaling_configs = [
    {"run_name": run, "display_name": shown}
    for run, shown in (
        (BART_BASE_BASELINE, "BART Base"),
        ("Seq2Seq_Tests_BART_Baseline_Enhanced_Fraction_1.0", "BART Base Enhanced"),
        ("Seq2Seq_Tests_BartLarge_Fraction_1.0", "BART Large"),
        ("Architecture_Cleanup_Full_Enhanced_Model_Large", "BART Large Enhanced"),
    )
]

# Comparison DataFrame (paper row first, then the configured runs).
scaling_df = create_comparison_table(
    api,
    project_name,
    scaling_configs,
    [("Paper Baseline", paper_baseline)],
    baseline_name=BART_BASE_BASELINE,
    baseline_paper_compare_index=0,
)

# Notebook rendering with standard deviations and deltas.
display(
    style_dataframe(
        scaling_df,
        show_std=True,
        show_diff=True,
        baselines=["Paper Baseline"],
    )
)

# LaTeX source (default mode: wide per-dataset layout in the recorded output).
# NOTE(review): the recorded output shows plain "BART Large" at 54.11 F1 on 14lap
# (σ 30.30) and an average F1 of 67.71 (σ 11.03) vs. 67.01 for "BART Base"; the
# caption's "scaling from base to large yields consistent improvements" should be
# re-checked against these numbers.
print(
    create_latex_table(
        scaling_df,
        label="tab:bart_model_comparison",
        caption="\\textit{work in progress:} Performance Comparison Between BART Base and BART Large Models. \\note{enhanced architecture benefits both model sizes; BART Large with enhancements provides best overall performance; scaling from base to large yields consistent improvements}",
        baselines=["Paper Baseline"],
        bold_best=True,
    )
)

Fetching data for Seq2Seq_Tests_BART_Baseline_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 77243.17it/s]
Fetching data for Seq2Seq_Tests_BART_Baseline_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 212369.82it/s]
Fetching data for Seq2Seq_Tests_BartLarge_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 232371.41it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large: 100%|██████████| 20/20 [00:00<00:00, 178861.58it/s]
Fetching data: 100%|██████████| 4/4 [00:00<00:00, 601.40it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,BART Base,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,BART Base Enhanced,"77.76 (σ 1.04, +1.33)","75.81 (σ 1.75, +2.25)","76.74 (σ 1.09, +1.81)","68.27 (σ 1.29, +2.03)","61.98 (σ 1.45, +2.66)","64.96 (σ 1.08, +2.43)","65.33 (σ 2.08, +0.86)","63.27 (σ 1.21, +2.94)","64.27 (σ 1.62, +1.98)","69.82 (σ 1.27, +1.29)","69.24 (σ 1.60, +1.19)","69.51 (σ 1.26, +1.25)","70.29 (σ 1.42, +1.38)","67.58 (σ 1.50, +2.26)","68.87 (σ 1.26, +1.87)"
3,BART Large,"79.13 (σ 1.92, +2.71)","75.79 (σ 4.42, +2.23)","77.32 (σ 2.24, +2.39)","56.33 (σ 31.54, -9.90)","52.09 (σ 29.18, -7.23)","54.11 (σ 30.30, -8.42)","64.63 (σ 7.51, +0.16)","64.18 (σ 8.56, +3.85)","64.35 (σ 7.79, +2.05)","74.99 (σ 3.34, +6.47)","75.13 (σ 4.46, +7.08)","75.05 (σ 3.78, +6.79)","68.77 (σ 11.08, -0.14)","66.80 (σ 11.65, +1.48)","67.71 (σ 11.03, +0.70)"
4,BART Large Enhanced,"80.26 (σ 1.21, +3.84)","81.43 (σ 1.48, +7.87)","80.82 (σ 1.02, +5.88)","70.85 (σ 0.73, +4.61)","64.94 (σ 1.42, +5.62)","67.75 (σ 0.64, +5.22)","67.44 (σ 1.39, +2.98)","67.63 (σ 1.76, +7.30)","67.52 (σ 1.14, +5.22)","76.08 (σ 2.32, +7.56)","77.10 (σ 1.70, +9.04)","76.58 (σ 1.94, +8.32)","73.66 (σ 1.41, +4.75)","72.77 (σ 1.59, +7.46)","73.16 (σ 1.18, +6.16)"


\begin{table*}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{\textit{work in progress:} Performance Comparison Between BART Base and BART Large Models. \note{enhanced architecture benefits both model sizes; BART Large with enhancements provides best overall performance; scaling from base to large yields consistent improvements}}
    \begin{adjustbox}{width=\textwidth}
        \begin{tabular}{l*{15}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{14res} & \multicolumn{3}{c}{14lap} & \multicolumn{3}{c}{15res} & \multicolumn{3}{c}{16res} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \\
            \midrule
            Paper Baseline & 65.52 & 64.99 & 65.25 \vrule & 61.41 & 56.19 & 58.69 \vrule & 59.14 & 59.38 & 59.26 \vrule & 66.60 & 68.68 & 67.62 \vrule & 63.17 & 62.31 & 62.70 \\
            \midrule
            BART Base & 76.43 & 73.56 & 7

In [29]:
# Weight-randomization study for the paper: fully pretrained model vs. variants
# with a randomly initialised encoder, decoder, or both.
RAND_BASELINE = "Weight_Randomization_Baseline_Fraction_1.0"

# Row definitions in display order; the first entry is the reference run.
randomization_configs = [
    dict(run_name=RAND_BASELINE, display_name="Full Pretrained"),
    dict(run_name="Weight_Randomization_Randomize_Encoder_Fraction_1.0", display_name="Random Encoder"),
    dict(run_name="Weight_Randomization_Randomize_Decoder_Fraction_1.0", display_name="Random Decoder"),
    dict(run_name="Weight_Randomization_Randomize_Both_Fraction_1.0", display_name="Random Both"),
]

# Comparison DataFrame (paper row first, then the configured runs).
randomization_df = create_comparison_table(
    api,
    project_name,
    randomization_configs,
    [("Paper Baseline", paper_baseline)],
    baseline_name=RAND_BASELINE,
    baseline_paper_compare_index=0,
)

# Notebook rendering with standard deviations and deltas.
display(
    style_dataframe(
        randomization_df,
        show_std=True,
        show_diff=True,
        baselines=["Paper Baseline"],
    )
)

# LaTeX source in the compact layout (mode="single": average-only columns in the recorded output).
print(
    create_latex_table(
        randomization_df,
        label="tab:weight_randomization",
        caption="\\textit{work in progress:} Impact of Weight Randomization on Model Performance. \\note{encoder randomization causes dramatic performance drop; decoder randomization has minimal impact; results demonstrate encoder quality is critical while decoder contributes minimally; explains why decoder-pretrained models like GPT-2 underperform}",
        baselines=["Paper Baseline"],
        bold_best=True,
        mode="single",
    )
)

Fetching data for Weight_Randomization_Baseline_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 133364.20it/s]
Fetching data for Weight_Randomization_Randomize_Encoder_Fraction_1.0: 100%|██████████| 40/40 [00:00<00:00, 277768.48it/s]
Fetching data for Weight_Randomization_Randomize_Decoder_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 197844.53it/s]
Fetching data for Weight_Randomization_Randomize_Both_Fraction_1.0: 100%|██████████| 40/40 [00:00<00:00, 233991.85it/s]
Fetching data: 100%|██████████| 4/4 [00:00<00:00, 711.74it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Full Pretrained,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,Random Encoder,"28.34 (σ 8.32, -48.09)","22.59 (σ 7.01, -50.97)","25.11 (σ 7.57, -49.83)","16.84 (σ 14.77, -49.40)","11.45 (σ 10.39, -47.87)","13.60 (σ 12.14, -48.93)","11.11 (σ 12.24, -53.36)","7.63 (σ 8.44, -52.70)","9.02 (σ 9.93, -53.27)","17.12 (σ 14.95, -51.40)","12.01 (σ 10.34, -56.04)","14.05 (σ 12.12, -54.21)","18.35 (σ 12.57, -50.56)","13.42 (σ 9.04, -51.90)","15.45 (σ 10.44, -51.56)"
3,Random Decoder,"73.86 (σ 1.09, -2.57)","68.98 (σ 1.05, -4.58)","71.32 (σ 1.07, -3.62)","62.37 (σ 1.87, -3.87)","54.24 (σ 2.85, -5.08)","58.00 (σ 2.37, -4.53)","58.62 (σ 1.38, -5.84)","56.39 (σ 3.00, -3.94)","57.46 (σ 2.19, -4.84)","67.43 (σ 1.72, -1.09)","66.67 (σ 2.21, -1.38)","67.04 (σ 1.90, -1.22)","65.57 (σ 1.51, -3.34)","61.57 (σ 2.28, -3.74)","63.46 (σ 1.88, -3.55)"
4,Random Both,"42.21 (σ 2.64, -34.21)","31.22 (σ 1.98, -42.34)","35.85 (σ 1.92, -39.09)","29.84 (σ 0.82, -36.40)","22.03 (σ 1.83, -37.29)","25.30 (σ 1.38, -37.23)","32.53 (σ 3.08, -31.93)","24.69 (σ 1.45, -35.64)","28.05 (σ 2.00, -34.25)","41.52 (σ 3.77, -27.00)","32.32 (σ 1.20, -35.73)","36.30 (σ 2.02, -31.96)","36.53 (σ 2.57, -32.39)","27.57 (σ 1.62, -37.75)","31.38 (σ 1.83, -35.63)"


\begin{table}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{\textit{work in progress:} Impact of Weight Randomization on Model Performance. \note{encoder randomization causes dramatic performance drop; decoder randomization has minimal impact; results demonstrate encoder quality is critical while decoder contributes minimally; explains why decoder-pretrained models like GPT-2 underperform}}
        \begin{tabular}{l*{3}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \\
            \midrule
            Paper Baseline & 63.17 & 62.31 & 62.70 \\
            \midrule
            Full Pretrained & \textbf{68.91} & \textbf{65.31} & \textbf{67.01} \\
             & {\scriptsize \textit{($\sigma$ 2.13)}} & {\scriptsize \textit{($\sigma$ 3.17)}} & {\scriptsize \textit{($\sigma$ 2.44)}} \\
            \midrule
            Random Encoder & 18.35 & 13.42 & 15.45 \\
             & {\scriptsize \textit{($\

In [30]:
# GPT-2 size sweep, compared against the re-implemented BART baseline.
# BASELINE_NAME is re-assigned so the cell is independent of earlier rebinding.
BASELINE_NAME = "Architecture_Cleanup_Baseline"

# Reference row first, then one row per GPT-2 size (run groups follow a common naming scheme).
gpt2_scaling_configs = [{"run_name": BASELINE_NAME, "display_name": "Baseline"}] + [
    {
        "run_name": f"Seq2Seq_Tests_GPT2GPT-{size}_Fraction_1.0",
        "display_name": f"GPT-2 {size}",
    }
    for size in ("Base", "Medium", "Large", "XL")
]

# Comparison DataFrame (paper row first, then the configured runs).
gpt2_scaling_df = create_comparison_table(
    api,
    project_name,
    gpt2_scaling_configs,
    [("Paper Baseline", paper_baseline)],
    baseline_name=BASELINE_NAME,
    baseline_paper_compare_index=0,
)

# Notebook rendering with standard deviations and deltas.
display(
    style_dataframe(
        gpt2_scaling_df,
        show_std=True,
        show_diff=True,
        baselines=["Paper Baseline"],
    )
)

# LaTeX source in the compact layout (mode="single").
# NOTE(review): the caption's \note mentions the "enhanced architecture", but no
# enhanced GPT-2 run group is part of the configuration above — confirm the text
# still matches what this table shows.
print(
    create_latex_table(
        gpt2_scaling_df,
        label="tab:gpt2_scaling",
        caption="\\textit{work in progress:} Performance of GPT-2 Models with Scaling. \\note{minimal scaling benefits observed across GPT-2 sizes; even 1.5B parameter GPT-2 XL underperforms BART-base; enhanced architecture provides consistent improvements but cannot overcome fundamental limitations; further confirms encoder quality is more important than model size for this task}",
        baselines=["Paper Baseline"],
        bold_best=True,
        mode="single",
    )
)

Fetching data for Architecture_Cleanup_Baseline: 100%|██████████| 20/20 [00:00<00:00, 52891.60it/s]
Fetching data for Seq2Seq_Tests_GPT2GPT-Base_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 221920.85it/s]
Fetching data for Seq2Seq_Tests_GPT2GPT-Medium_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 31847.41it/s]
Fetching data for Seq2Seq_Tests_GPT2GPT-Large_Fraction_1.0: 100%|██████████| 28/28 [00:00<00:00, 254200.24it/s]
Fetching data for Seq2Seq_Tests_GPT2GPT-XL_Fraction_1.0: 100%|██████████| 24/24 [00:00<00:00, 247329.97it/s]
Fetching data: 100%|██████████| 5/5 [00:00<00:00, 494.32it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Baseline,"76.43 (σ 1.70, +10.91)","73.56 (σ 3.41, +8.57)","74.94 (σ 2.50, +9.69)","66.24 (σ 2.62, +4.83)","59.32 (σ 4.18, +3.13)","62.53 (σ 3.13, +3.84)","64.46 (σ 1.56, +5.32)","60.33 (σ 3.26, +0.95)","62.30 (σ 2.30, +3.04)","68.52 (σ 2.64, +1.92)","68.05 (σ 1.83, -0.63)","68.26 (σ 1.84, +0.64)","68.91 (σ 2.13, +5.75)","65.31 (σ 3.17, +3.00)","67.01 (σ 2.44, +4.30)"
2,GPT-2 Base,"64.14 (σ 1.26, -12.28)","57.36 (σ 1.66, -16.20)","60.52 (σ 0.87, -14.41)","54.77 (σ 2.96, -11.47)","46.17 (σ 1.53, -13.14)","50.08 (σ 2.03, -12.45)","52.92 (σ 2.22, -11.55)","45.48 (σ 2.91, -14.85)","48.88 (σ 2.32, -13.41)","63.87 (σ 2.98, -4.65)","59.42 (σ 1.12, -8.63)","61.54 (σ 1.71, -6.73)","58.92 (σ 2.35, -9.99)","52.11 (σ 1.81, -13.20)","55.26 (σ 1.73, -11.75)"
3,GPT-2 Medium,"66.41 (σ 1.42, -10.01)","58.91 (σ 2.61, -14.65)","62.40 (σ 1.83, -12.53)","55.63 (σ 3.74, -10.61)","47.00 (σ 3.23, -12.32)","50.92 (σ 3.34, -11.61)","55.51 (σ 1.70, -8.95)","49.84 (σ 2.42, -10.49)","52.49 (σ 1.77, -9.81)","62.68 (σ 1.18, -5.84)","60.20 (σ 2.44, -7.86)","61.40 (σ 1.81, -6.86)","60.06 (σ 2.01, -8.85)","53.99 (σ 2.68, -11.33)","56.80 (σ 2.19, -10.20)"
4,GPT-2 Large,"64.07 (σ 1.73, -12.35)","58.62 (σ 1.87, -14.94)","61.18 (σ 1.05, -13.75)","56.54 (σ 2.66, -9.70)","48.89 (σ 1.53, -10.43)","52.40 (σ 1.56, -10.13)","53.77 (σ 3.33, -10.69)","45.83 (σ 3.02, -14.50)","49.47 (σ 3.10, -12.83)","64.75 (σ 1.39, -3.78)","62.01 (σ 1.40, -6.04)","63.34 (σ 1.32, -4.92)","59.78 (σ 2.28, -9.13)","53.84 (σ 1.96, -11.48)","56.60 (σ 1.76, -10.41)"
5,GPT-2 XL,"67.43 (σ 0.71, -9.00)","60.33 (σ 2.69, -13.23)","63.65 (σ 1.73, -11.29)","54.56 (σ 2.51, -11.68)","48.72 (σ 3.03, -10.60)","51.44 (σ 2.52, -11.09)","54.65 (σ 2.77, -9.81)","47.25 (σ 1.28, -13.08)","50.65 (σ 1.53, -11.64)","65.14 (σ 1.71, -3.39)","60.66 (σ 3.39, -7.39)","62.77 (σ 2.16, -5.50)","60.44 (σ 1.92, -8.47)","54.24 (σ 2.60, -11.07)","57.13 (σ 1.99, -9.88)"


\begin{table}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{\textit{work in progress:} Performance of GPT-2 Models with Scaling. \note{minimal scaling benefits observed across GPT-2 sizes; even 1.5B parameter GPT-2 XL underperforms BART-base; enhanced architecture provides consistent improvements but cannot overcome fundamental limitations; further confirms encoder quality is more important than model size for this task}}
        \begin{tabular}{l*{3}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \\
            \midrule
            Paper Baseline & 63.17 & 62.31 & 62.70 \\
            \midrule
            Baseline & \textbf{68.91} & \textbf{65.31} & \textbf{67.01} \\
             & {\scriptsize \textit{($\sigma$ 2.13)}} & {\scriptsize \textit{($\sigma$ 3.17)}} & {\scriptsize \textit{($\sigma$ 2.44)}} \\
            \midrule
            GPT-2 Base & 58.92 & 52.11 & 55.26 \\
             & {\s

In [31]:

# Cross-architecture comparison: mixed encoder/decoder pairings ("scrambled" models)
# measured against the enhanced BART baseline.
# NOTE: this cell rebinds the notebook-wide name `prefix`.
SCRAMBLED_BASELINE = "Seq2Seq_Tests_BART_Baseline_Enhanced_Fraction_1.0"
prefix = "Seq2Seq_Tests_"

# One entry per run group, in display order. The two "+ Enhancements" labels differ
# only by a trailing period, presumably to keep the display names distinct — TODO confirm
# whether rows are keyed by display name.
# NOTE(review): in the recorded output "RobertaLarge2RobertaLarge (Enhanced)" scores
# 0.00 on three of the four datasets, which looks like failed or collapsed runs — verify
# before reporting.
scrambled_models_configs = [
    {"run_name": SCRAMBLED_BASELINE, "display_name": "Enhanced BART Baseline"},
    {"run_name": f"{prefix}Roberta2Roberta_Fraction_1.0", "display_name": "Roberta2Roberta"},
    {"run_name": f"{prefix}Roberta2Roberta_Enhanced_Fraction_1.0", "display_name": "+ Enhancements"},
    {"run_name": f"{prefix}Roberta2GPT2_Fraction_1.0", "display_name": "Roberta2GPT2"},
    {"run_name": f"{prefix}Roberta2GPT2_Enhanced_Fraction_1.0", "display_name": "+ Enhancements."},
    {"run_name": f"{prefix}Bert2Bert_Enhanced_Fraction_1.0", "display_name": "Bert2Bert (Enhanced)"},
    {"run_name": f"{prefix}Bert2GPT2_Enhanced_Fraction_1.0", "display_name": "Bert2GPT2 (Enhanced)"},
    {"run_name": f"{prefix}RobertaLarge2RobertaLarge_Enhanced_Fraction_1.0", "display_name": "RobertaLarge2RobertaLarge (Enhanced)"},
    {"run_name": f"{prefix}RobertaLarge2GPT2Medium_Enhanced_Fraction_1.0", "display_name": "RobertaLarge2GPT2Medium (Enhanced)"},
]
# Backward-compatible alias for the original, misspelled variable name.
scambled_models_configs = scrambled_models_configs

# Comparison DataFrame (paper row first, then the configured runs).
scrambled_models_df = create_comparison_table(
    api, 
    project_name, 
    scrambled_models_configs,
    [("Paper Baseline", paper_baseline)],
    baseline_name=SCRAMBLED_BASELINE,
    baseline_paper_compare_index=0
)

# Notebook rendering with standard deviations and deltas.
display(style_dataframe(scrambled_models_df, show_std=True, show_diff=True, baselines=["Paper Baseline"]))

# LaTeX source in the compact layout (mode="single").
# NOTE(review): the caption's \note refers to "BART Large with enhancements", but no
# BART Large run group is configured in this cell — confirm the text matches the table.
print(create_latex_table(
    scrambled_models_df, 
    caption="\\textit{work in progress:} Performance Across Different Model Architectures and Sizes. \\note{enhanced architecture benefits all model types; BART Large with enhancements provides best overall performance; encoder-focused models (BART/RoBERTa) outperform decoder-focused models (GPT-2); scaling benefits encoder-focused models more than decoder-focused ones}", 
    label="tab:model_comparison",
    baselines=["Paper Baseline"],
    bold_best=True,
    mode="single"
))

Fetching data for Seq2Seq_Tests_BART_Baseline_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 200205.44it/s]
Fetching data for Seq2Seq_Tests_Roberta2Roberta_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 118987.35it/s]
Fetching data for Seq2Seq_Tests_Roberta2Roberta_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 230456.26it/s]
Fetching data for Seq2Seq_Tests_Roberta2GPT2_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 217321.45it/s]
Fetching data for Seq2Seq_Tests_Roberta2GPT2_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 229824.88it/s]
Fetching data for Seq2Seq_Tests_Bert2Bert_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 227333.55it/s]
Fetching data for Seq2Seq_Tests_Bert2GPT2_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 140043.54it/s]
Fetching data for Seq2Seq_Tests_RobertaLarge2RobertaLarge_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 200205.44it/s]
Fetching data for Seq2Seq_Tests_RobertaLarge

Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Enhanced BART Baseline,"77.76 (σ 1.04, +12.24)","75.81 (σ 1.75, +10.82)","76.74 (σ 1.09, +11.49)","68.27 (σ 1.29, +6.86)","61.98 (σ 1.45, +5.79)","64.96 (σ 1.08, +6.27)","65.33 (σ 2.08, +6.19)","63.27 (σ 1.21, +3.89)","64.27 (σ 1.62, +5.01)","69.82 (σ 1.27, +3.22)","69.24 (σ 1.60, +0.56)","69.51 (σ 1.26, +1.89)","70.29 (σ 1.42, +7.13)","67.58 (σ 1.50, +5.27)","68.87 (σ 1.26, +6.17)"
2,Roberta2Roberta,"69.87 (σ 15.63, -7.89)","63.52 (σ 21.86, -12.29)","66.17 (σ 19.72, -10.57)","40.84 (σ 37.34, -27.43)","35.39 (σ 32.52, -26.59)","37.89 (σ 34.72, -27.07)","62.01 (σ 9.87, -3.32)","60.65 (σ 9.37, -2.62)","61.29 (σ 9.49, -2.99)","58.09 (σ 32.52, -11.72)","56.61 (σ 31.73, -12.64)","57.30 (σ 32.06, -12.21)","57.70 (σ 23.84, -12.59)","54.04 (σ 23.87, -13.54)","55.66 (σ 24.00, -13.21)"
3,+ Enhancements,"76.81 (σ 2.19, -0.95)","73.02 (σ 1.82, -2.78)","74.82 (σ 0.91, -1.93)","67.18 (σ 4.50, -1.09)","57.35 (σ 8.88, -4.63)","61.73 (σ 7.17, -3.23)","62.00 (σ 9.59, -3.33)","61.01 (σ 9.41, -2.26)","61.49 (σ 9.49, -2.78)","69.68 (σ 8.42, -0.14)","66.73 (σ 10.87, -2.51)","68.12 (σ 9.75, -1.39)","68.92 (σ 6.17, -1.38)","64.53 (σ 7.75, -3.05)","66.54 (σ 6.83, -2.33)"
4,Roberta2GPT2,"74.21 (σ 2.95, -3.54)","69.00 (σ 3.89, -6.81)","71.47 (σ 3.28, -5.27)","65.45 (σ 3.05, -2.82)","53.85 (σ 2.77, -8.13)","59.06 (σ 2.77, -5.90)","60.16 (σ 1.19, -5.17)","56.82 (σ 1.31, -6.45)","58.43 (σ 1.11, -5.85)","70.72 (σ 0.66, +0.90)","67.45 (σ 1.46, -1.79)","69.04 (σ 1.07, -0.48)","67.64 (σ 1.96, -2.66)","61.78 (σ 2.36, -5.80)","64.50 (σ 2.05, -4.37)"
5,+ Enhancements.,"75.54 (σ 1.41, -2.22)","71.13 (σ 1.48, -4.68)","73.23 (σ 1.31, -3.51)","66.50 (σ 1.15, -1.78)","59.27 (σ 1.60, -2.71)","62.64 (σ 0.75, -2.32)","63.96 (σ 3.63, -1.37)","60.83 (σ 3.19, -2.44)","62.34 (σ 3.37, -1.94)","72.21 (σ 1.72, +2.40)","70.91 (σ 1.54, +1.67)","71.55 (σ 1.54, +2.03)","69.55 (σ 1.98, -0.74)","65.54 (σ 1.95, -2.04)","67.44 (σ 1.74, -1.43)"
6,Bert2Bert (Enhanced),"72.68 (σ 0.31, -5.07)","67.65 (σ 1.55, -8.16)","70.05 (σ 0.89, -6.70)","61.71 (σ 2.54, -6.56)","52.78 (σ 2.30, -9.20)","56.87 (σ 2.25, -8.09)","61.12 (σ 2.62, -4.21)","57.18 (σ 2.15, -6.09)","59.05 (σ 2.00, -5.22)","69.26 (σ 1.02, -0.55)","67.47 (σ 1.80, -1.77)","68.34 (σ 1.37, -1.17)","66.19 (σ 1.62, -4.10)","61.27 (σ 1.95, -6.31)","63.58 (σ 1.63, -5.29)"
7,Bert2GPT2 (Enhanced),"72.50 (σ 2.30, -5.25)","65.25 (σ 1.28, -10.56)","68.64 (σ 1.22, -8.10)","62.35 (σ 2.25, -5.93)","51.14 (σ 1.22, -10.84)","56.17 (σ 1.52, -8.79)","58.59 (σ 1.51, -6.74)","50.15 (σ 1.22, -13.12)","54.03 (σ 1.27, -10.25)","68.87 (σ 1.21, -0.95)","64.69 (σ 1.68, -4.55)","66.70 (σ 1.39, -2.81)","65.58 (σ 1.82, -4.72)","57.81 (σ 1.35, -9.77)","61.38 (σ 1.35, -7.49)"
8,RobertaLarge2RobertaLarge (Enhanced),"0.00 (σ 0.00, -77.76)","0.00 (σ 0.00, -75.81)","0.00 (σ 0.00, -76.74)","0.00 (σ 0.00, -68.27)","0.00 (σ 0.00, -61.98)","0.00 (σ 0.00, -64.96)","0.00 (σ 0.00, -65.33)","0.00 (σ 0.00, -63.27)","0.00 (σ 0.00, -64.27)","13.63 (σ 30.47, -56.19)","3.82 (σ 8.54, -65.43)","5.96 (σ 13.32, -63.56)","3.41 (σ 7.62, -66.89)","0.95 (σ 2.13, -66.62)","1.49 (σ 3.33, -67.38)"
9,RobertaLarge2GPT2Medium (Enhanced),"75.84 (σ 0.89, -1.91)","71.33 (σ 1.60, -4.48)","73.49 (σ 1.24, -3.25)","65.69 (σ 1.75, -2.58)","58.87 (σ 2.41, -3.11)","62.07 (σ 1.95, -2.89)","62.99 (σ 1.66, -2.34)","59.78 (σ 2.28, -3.49)","61.32 (σ 1.71, -2.96)","73.95 (σ 0.41, +4.13)","71.10 (σ 2.78, +1.85)","72.46 (σ 1.39, +2.95)","69.62 (σ 1.18, -0.68)","65.27 (σ 2.27, -2.31)","67.33 (σ 1.57, -1.54)"


\begin{table}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{\textit{work in progress:} Performance Across Different Model Architectures and Sizes. \note{enhanced architecture benefits all model types; BART Large with enhancements provides best overall performance; encoder-focused models (BART/RoBERTa) outperform decoder-focused models (GPT-2); scaling benefits encoder-focused models more than decoder-focused ones}}
        \begin{tabular}{l*{3}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \\
            \midrule
            Paper Baseline & 63.17 & 62.31 & 62.70 \\
            \midrule
            Enhanced BART Baseline & \textbf{70.29} & \textbf{67.58} & \textbf{68.87} \\
             & {\scriptsize \textit{($\sigma$ 1.42)}} & {\scriptsize \textit{($\sigma$ 1.50)}} & {\scriptsize \textit{($\sigma$ 1.26)}} \\
            \midrule
            Roberta2Roberta & 57.70 & 54.04 & 55.66 \\
      

In [32]:
# --- BART-Large ablation study ------------------------------------------------
# Builds the comparison table for the BART-Large runs and renders it twice:
# as a styled DataFrame in the notebook and as a LaTeX table on stdout.

# Run name of the large baseline; it is passed below as `baseline_name`.
LARGE_BASELINE_NAME = "Architecture_Cleanup_Baseline_Large"
Reimpl_Baseline_Display = "Baseline Large"
# Shared run-name prefix for most of the runs in this study.
prefix = "Architecture_Cleanup_"
bart_large_configs: list[dict[str, str|int]] = [
    {"run_name": LARGE_BASELINE_NAME, "display_name": Reimpl_Baseline_Display, "color": "tertiary"},
    {"run_name": f"{prefix}Full_Enhanced_Model_Large_No_EncoderNorm", "display_name": "No Enc Norm"},
    # This run intentionally does NOT use `prefix` (it comes from the
    # "Seq2Seq_Tests_" series), so it is a plain literal rather than an f-string.
    # NOTE(review): the run name does not mention encoder gating — confirm that
    # "No Enc Gating" is the right label for this run.
    {"run_name": "Seq2Seq_Tests_BartLarge_Enhanced_Fraction_1.0", "display_name": "No Enc Gating"},
    # NOTE(review): run name says "Enc_Norm" while the label says "No Attention" —
    # confirm the mapping against the training configs.
    {"run_name": f"{prefix}Enc_Norm_Large", "display_name": "No Attention"},
    {"run_name": f"{prefix}Full_Enhanced_Model_Large_No_Dec_Gating", "display_name": "No Dec Gating"},
    {"run_name": f"{prefix}Full_Enhanced_Model_Large_No_Final_Norm", "display_name": "No Final Norm"},
    {"run_name": f"{prefix}Full_Enhanced_Model_Large", "display_name": "Full Enhanced Model"},
]

# Generate the table.
# Judging from the rendered output, the deltas of the ablation rows are taken
# relative to the `baseline_name` run, while that baseline row itself is compared
# against the static baseline at index `baseline_paper_compare_index`
# (here: "Paper Baseline") — TODO confirm against `create_comparison_table`.
bart_large_df = create_comparison_table(
    api,
    project_name,
    bart_large_configs,
    [("Paper Baseline", paper_baseline)],
    baseline_name=LARGE_BASELINE_NAME,
    baseline_paper_compare_index=0
)

# Rich notebook rendering including standard deviations and deltas.
display(style_dataframe(bart_large_df, show_std=True, show_diff=True, baselines=["Paper Baseline"]))

# Create the LaTeX table (printed as text so it can be copied into the thesis).
print(create_latex_table(
    bart_large_df,
    caption="\\textit{work in progress:} Model performance and ablation study on BART-Large.",
    label="tab:architecture_improvements",
    baselines=["Paper Baseline"],
    bold_best=True,
    mode="full"
))

Fetching data for Architecture_Cleanup_Baseline_Large: 100%|██████████| 20/20 [00:00<00:00, 206108.30it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_No_EncoderNorm: 100%|██████████| 20/20 [00:00<00:00, 72440.48it/s]
Fetching data for Seq2Seq_Tests_BartLarge_Enhanced_Fraction_1.0: 100%|██████████| 20/20 [00:00<00:00, 235635.06it/s]
Fetching data for Architecture_Cleanup_Enc_Norm_Large: 100%|██████████| 20/20 [00:00<00:00, 202135.13it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_No_Dec_Gating: 100%|██████████| 20/20 [00:00<00:00, 235635.06it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large_No_Final_Norm: 100%|██████████| 20/20 [00:00<00:00, 238312.73it/s]
Fetching data for Architecture_Cleanup_Full_Enhanced_Model_Large: 100%|██████████| 20/20 [00:00<00:00, 206108.30it/s]
Fetching data: 100%|██████████| 7/7 [00:00<00:00, 29.86it/s]


Unnamed: 0_level_0,Model,14res,14res,14res,14lap,14lap,14lap,15res,15res,15res,16res,16res,16res,Average,Average,Average
Unnamed: 0_level_1,Unnamed: 1_level_1,P,R,F1,P,R,F1,P,R,F1,P,R,F1,P,R,F1
0,Paper Baseline,65.52,64.99,65.25,61.41,56.19,58.69,59.14,59.38,59.26,66.60,68.68,67.62,63.17,62.31,62.70
1,Baseline Large,"79.13 (σ 1.92, +13.61)","75.79 (σ 4.42, +10.80)","77.32 (σ 2.24, +12.07)","56.33 (σ 31.54, -5.08)","52.09 (σ 29.18, -4.10)","54.11 (σ 30.30, -4.58)","64.63 (σ 7.51, +5.49)","64.18 (σ 8.56, +4.80)","64.35 (σ 7.79, +5.09)","74.99 (σ 3.34, +8.39)","75.13 (σ 4.46, +6.45)","75.05 (σ 3.78, +7.43)","68.77 (σ 11.08, +5.60)","66.80 (σ 11.65, +4.49)","67.71 (σ 11.03, +5.00)"
2,No Enc Norm,"77.00 (σ 2.47, -2.13)","77.20 (σ 4.65, +1.41)","77.07 (σ 3.55, -0.25)","69.73 (σ 2.37, +13.40)","64.23 (σ 1.43, +12.15)","66.84 (σ 1.54, +12.73)","51.26 (σ 29.92, -13.37)","49.74 (σ 29.12, -14.45)","50.48 (σ 29.50, -13.87)","70.12 (σ 6.59, -4.88)","70.40 (σ 7.51, -4.73)","70.25 (σ 7.05, -4.80)","67.03 (σ 10.34, -1.74)","65.39 (σ 10.68, -1.41)","66.16 (σ 10.41, -1.55)"
3,No Enc Gating,"78.85 (σ 1.51, -0.28)","79.47 (σ 0.52, +3.68)","79.13 (σ 0.84, +1.81)","68.90 (σ 2.83, +12.57)","66.09 (σ 1.47, +14.01)","67.43 (σ 1.53, +13.31)","67.00 (σ 1.32, +2.37)","67.58 (σ 1.68, +3.40)","67.27 (σ 1.20, +2.93)","73.56 (σ 5.33, -1.43)","74.63 (σ 6.09, -0.51)","74.07 (σ 5.63, -0.97)","72.08 (σ 2.75, +3.31)","71.94 (σ 2.44, +5.14)","71.98 (σ 2.30, +4.27)"
4,No Attention,"76.51 (σ 6.09, -2.63)","75.76 (σ 5.51, -0.03)","76.10 (σ 5.72, -1.23)","67.95 (σ 10.81, +11.62)","62.57 (σ 9.05, +10.49)","65.12 (σ 9.79, +11.01)","66.78 (σ 0.81, +2.15)","68.47 (σ 1.22, +4.28)","67.60 (σ 0.82, +3.25)","74.21 (σ 0.73, -0.79)","76.34 (σ 1.60, +1.21)","75.25 (σ 1.10, +0.20)","71.36 (σ 4.61, +2.59)","70.78 (σ 4.34, +3.98)","71.02 (σ 4.36, +3.31)"
5,No Dec Gating,"77.36 (σ 1.59, -1.77)","77.91 (σ 1.52, +2.12)","77.59 (σ 0.78, +0.27)","51.23 (σ 31.00, -5.10)","46.39 (σ 27.37, -5.70)","48.66 (σ 29.05, -5.46)","52.00 (σ 29.14, -12.63)","49.97 (σ 28.39, -14.21)","50.93 (σ 28.71, -13.42)","68.22 (σ 7.62, -6.78)","68.13 (σ 6.62, -7.00)","68.16 (σ 7.07, -6.89)","62.20 (σ 17.34, -6.57)","60.60 (σ 15.98, -6.20)","61.33 (σ 16.40, -6.38)"
6,No Final Norm,"80.72 (σ 0.66, +1.59)","81.05 (σ 1.27, +5.26)","80.87 (σ 0.89, +3.55)","70.34 (σ 2.80, +14.01)","64.73 (σ 2.00, +12.65)","67.41 (σ 2.33, +13.29)","66.39 (σ 2.60, +1.76)","66.47 (σ 4.48, +2.28)","66.38 (σ 3.16, +2.03)","75.66 (σ 2.23, +0.67)","78.32 (σ 1.34, +3.19)","76.95 (σ 1.67, +1.91)","73.28 (σ 2.07, +4.51)","72.64 (σ 2.27, +5.85)","72.90 (σ 2.01, +5.19)"
7,Full Enhanced Model,"80.26 (σ 1.21, +1.13)","81.43 (σ 1.48, +5.64)","80.82 (σ 1.02, +3.50)","70.85 (σ 0.73, +14.51)","64.94 (σ 1.42, +12.86)","67.75 (σ 0.64, +13.63)","67.44 (σ 1.39, +2.82)","67.63 (σ 1.76, +3.45)","67.52 (σ 1.14, +3.17)","76.08 (σ 2.32, +1.09)","77.10 (σ 1.70, +1.96)","76.58 (σ 1.94, +1.53)","73.66 (σ 1.41, +4.89)","72.77 (σ 1.59, +5.98)","73.16 (σ 1.18, +5.46)"


\begin{table*}[t]
    \centering
    \setlength{\tabcolsep}{2pt}
    \small
    \caption{\textit{work in progress:} Model performance and ablation study on BART-Large.}
    \begin{adjustbox}{width=\textwidth}
        \begin{tabular}{l*{15}{r}}
            \toprule
            \multirow{2}{*}{Model} & \multicolumn{3}{c}{14res} & \multicolumn{3}{c}{14lap} & \multicolumn{3}{c}{15res} & \multicolumn{3}{c}{16res} & \multicolumn{3}{c}{Avg} \\ &
            P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \vrule & P & R & F1 \\
            \midrule
            Paper Baseline & 65.52 & 64.99 & 65.25 \vrule & 61.41 & 56.19 & 58.69 \vrule & 59.14 & 59.38 & 59.26 \vrule & 66.60 & 68.68 & 67.62 \vrule & 63.17 & 62.31 & 62.70 \\
            \midrule
            Baseline Large & 79.13 & 75.79 & 77.32 \vrule & 56.33 & 52.09 & 54.11 \vrule & 64.63 & 64.18 & 64.35 \vrule & 74.99 & 75.13 & 75.05 \vrule & 68.77 & 66.80 & 67.71 \\
             & {\scriptsize \textit{($\sigma$ 1.92)}}