# Setup and global variables

In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import random
import json

from matplotlib.colors import ListedColormap
from tqdm import tqdm
from scipy.stats import spearmanr

from src.prioritization import *
from src.utils import gini

# When True, the notebook trains the heuristics, recomputes prioritizations and metrics,
# and writes them (plus the models) to disk; when False, it loads the cached artifacts instead.
RETRAIN_MODELS = True

# Task-defect pairs with fewer training submissions than this are masked out in some experiments.
MIN_TASK_DEFECT_SUBMISSIONS = 5
# Width of the <close ties> band as a fraction of the max score (0.05 = 5 %).
CLOSE_TIE_THRESHOLD = 0.05
# Two-colour map: the first colour is red, the second green.
BINARY_CMAP = ListedColormap(['red', 'green'])

In [None]:
# Select the configuration environment.
# NOTE(review): presumably read by the `config` module, so it is set before the import and
# the load_config() call below — confirm against config.py before reordering.
os.environ["CONFIG_ENV"] = "debug"
if False:  # manual switch: flip to True to run against the production configuration
    os.environ["CONFIG_ENV"] = "production"

from config import load_config
config = load_config()

DEBUG = config["DEBUG"]
RESOLUTION = config['DEFAULTS']['resolution']  # passed as `dpi` to every plt.savefig call

# input data
TRAINING_DATA_PATH = config['PATHS']['development_set']
VALIDATION_DATA_PATH = config['PATHS']['evaluation_set']
STORAGE_PATH = config['PATHS']['storage']

# output data
MODEL_OUTPUT_PATH = config['PATHS']['evaluation_trained_heuristics']
EVALUATION_PRIORITIZATIONS_CACHE = config['PATHS']['evaluation_prioritizations']
MODEL_METRICS_PATH = config['PATHS']['model_metrics']
# The `/` operator implies the configured paths are pathlib-style objects — TODO confirm.
IMAGE_DIR = config['PATHS']['images'] / 'heuristics'

# Make sure every output directory exists before any cell writes into it.
os.makedirs(MODEL_OUTPUT_PATH, exist_ok=True)
os.makedirs(EVALUATION_PRIORITIZATIONS_CACHE, exist_ok=True)
os.makedirs(MODEL_METRICS_PATH, exist_ok=True)
os.makedirs(IMAGE_DIR, exist_ok=True)

## Utils

In [None]:
def task_and_defect_description(task, defect, items, defects, log, defect_log):
    """Generate an HTML display for a specific task and defect.

    Args:
        task: Index label of the task in ``items``.
        defect: Index label of the defect in ``defects``; also the column of
            ``defect_log`` that holds this defect's per-submission counts.
        items: Task table with ``name``, ``instructions`` and ``solution`` columns.
        defects: Defect table with ``defect name``, ``defect type``, ``severity``,
            ``description``, ``code example`` and ``code fix example`` columns.
        log: Submission log with ``item`` and ``answer`` columns.
        defect_log: Per-submission defect counts, combined with ``log`` row by row.

    Returns:
        An HTML fragment (str) with the task card, the defect card and one randomly
        chosen submission to the task that contains the defect, or the text
        'No submissions found' when there is none.
    """
    import html as html_lib  # function-scope import: only this helper needs it

    def as_code(value):
        """Escape a value so that it renders literally inside a <pre> element."""
        return html_lib.escape(str(value), quote=False)

    task_row = items.loc[task]
    defect_row = defects.loc[defect]

    # The defect log stores counts. Combining a boolean mask with raw counts via `&`
    # is a bitwise AND (True & 2 == 0), which silently drops even counts, so the
    # counts are turned into an explicit boolean first.
    has_defect = defect_log[defect] > 0
    submissions = log[(log["item"] == task) & has_defect]

    if len(submissions):
        picked = random.randrange(len(submissions))
        example_submission = as_code(submissions["answer"].iloc[picked])
    else:
        example_submission = 'No submissions found'

    # Only code snippets are escaped; names, instructions and descriptions are inserted
    # as-is because they may contain intentional markup.
    solution_code = as_code(task_row["solution"])
    defect_code = as_code(defect_row["code example"])
    defect_fix_code = as_code(defect_row["code fix example"])

    return f"""
    <div style="display: flex; justify-content: space-between; gap: 20px;">
        
        <div style="width: 48%; border: 1px solid #ccc; padding: 10px; border-radius: 5px;">
            <h3>{task_row["name"]}</h3>
            <div><strong>Instructions:</strong><br>{task_row["instructions"]}</div>
            <div><strong>Solution:</strong><br>
                <pre style="background-color: #2e2e2e; color: #f5f5f5; padding: 10px; border-radius: 5px; font-family: monospace;">{solution_code}</pre>
            </div>
        </div>
        
        
        <div style="width: 48%; border: 1px solid #ccc; padding: 10px; border-radius: 5px;">
            <h3>{defect_row["defect name"]}</h3>
            <div><strong>Defect Type:</strong> {defect_row["defect type"]}</div>
            <div><strong>Severity:</strong> {defect_row["severity"]}</div>
            <div><strong>Description:</strong><br>{defect_row["description"]}</div>
            
            <div style="display: flex; justify-content: space-between; margin-top: 20px;">
                <div style="width: 48%; padding: 10px;">
                    <strong>Code Example:</strong><br>
                    <pre style="background-color: #2e2e2e; color: #f5f5f5; padding: 10px; border-radius: 5px; font-family: monospace;">{defect_code}</pre>
                </div>
                <div style="width: 48%; padding: 10px;">
                    <strong>Code Fix Example:</strong><br>
                    <pre style="background-color: #2e2e2e; color: #f5f5f5; padding: 10px; border-radius: 5px; font-family: monospace;">{defect_fix_code}</pre>
                </div>
            </div>
        </div>
    </div>
    
    
    <div style="border: 1px solid #ccc; padding: 10px; margin-top: 20px; border-radius: 5px;">
        <strong>Example Submission:</strong><br>
        <pre style="background-color: #2e2e2e; color: #f5f5f5; padding: 10px; border-radius: 5px; font-family: monospace;">{example_submission}</pre>
    </div>
    """

***

# Loading data

In [None]:
def _load_split(split_dir):
    """Read the submission log and the defect log stored in one data-split directory.

    The defect-log column labels are cast to int after reading.
    """
    split_log = pd.read_csv(split_dir / 'log.csv', index_col=0, parse_dates=['time'])
    split_defect_log = pd.read_csv(split_dir / 'defect_log.csv', index_col=0)
    split_defect_log.columns = split_defect_log.columns.astype(int)
    return split_log, split_defect_log


# Static tables: tasks and defects, both indexed by their first CSV column.
items = pd.read_csv(STORAGE_PATH / 'items.csv', index_col=0)
defects = pd.read_csv(STORAGE_PATH / 'defects.csv', index_col=0)

# Development (training) and evaluation (test) logs.
train_log, train_defect_log = _load_split(TRAINING_DATA_PATH)
test_log, test_defect_log = _load_split(VALIDATION_DATA_PATH)

***

# Task filtering

## Task-defect pairs without minimal support

In [None]:
# Per task, count the training submissions in which each defect occurs at least once.
support_counts = (train_defect_log > 0).groupby(train_log["item"]).sum()

# True where a task-defect pair has too little support. Labels missing from the training
# data come out of the reindex as NaN, which astype(bool) turns into True as well.
insufficient_support = (
    (support_counts < MIN_TASK_DEFECT_SUBMISSIONS)
    .reindex(index=items.index, columns=defects.index)
    .astype(bool)
)

***

# Heuristics

In [None]:
data = (items, defects)

# Heuristic classes to evaluate; every one is constructed from the same (items, defects) pair.
model_classes = (
    TaskCommonModel,
    TaskCharacteristicModel,
    StudentCommonModel,
    StudentCharacteristicModel,
    StudentEncounteredBeforeModel,
    DefectMultiplicityModel,
    SeverityModel,
)

instances = [model_cls(*data) for model_cls in model_classes]

# Registry keyed by each model's self-reported name.
models = {instance.get_model_name(): instance for instance in instances}

In [None]:
# Either fit every heuristic on the training logs or restore it from its pickle.
# The progress label reflects which of the two actually happens.
action = "Training" if RETRAIN_MODELS else "Loading"

# Iterate over a snapshot because the registry's values are rebound inside the loop.
for name, model in (pbar := tqdm(list(models.items()), desc=f"{action} Models")):
    pbar.set_description(f"{action} {name}")
    if RETRAIN_MODELS:
        model.update(train_log, train_defect_log)
    else:
        models[name] = model.load(MODEL_OUTPUT_PATH / f"{name}.pkl")

***

# Exploratory analysis

In [None]:
scores_cache_file = EVALUATION_PRIORITIZATIONS_CACHE / 'model_scores.csv'

if not RETRAIN_MODELS:
    # Reuse the scores written by an earlier retraining run.
    model_prioritizations = pd.read_csv(scores_cache_file)
else:
    # Replay the evaluation log in chronological order.
    test_log = test_log.sort_values(by=['time'])

    score_rows = []

    for idx, submission in tqdm(test_log.iterrows(), total=test_log.shape[0], desc="Calculating prioritizations"):
        defect_counts = test_defect_log.loc[idx]

        # Score first, with the state accumulated from earlier submissions only ...
        model_scores = {name: model._calculate_scores(submission, defect_counts) for name, model in models.items()}

        # ... keep one row per defect that is actually present in this submission ...
        present_defects = defect_counts[defect_counts > 0].index
        for defect in present_defects:
            row = {"submission id": idx, "defect id": defect}
            row.update({name: scores[defect] for name, scores in model_scores.items()})
            score_rows.append(row)

        # ... and only then let every model see the submission.
        for model in models.values():
            model.update(submission, defect_counts)

    model_prioritizations = pd.DataFrame(score_rows)
    model_prioritizations.to_csv(scores_cache_file, index=False)

##  Weight histograms

In [None]:
def plot_model_weight_histogram(model_name, model, image_dir, bins=60):
    """Save a log-scale histogram of a model's weights with its thresholds marked.

    Models whose ``get_model_weights()`` returns None are reported and skipped.
    The figure is written to ``image_dir`` under the lower-cased model name with
    spaces replaced by underscores.
    """
    weights = model.get_model_weights()
    if weights is None:
        print(f"Model {model_name} has no weights. Skipping...")
        return

    flat_weights = weights.values.flatten()
    threshold_lines = model.get_model_thresholds()

    fig, ax = plt.subplots(figsize=(12, 8), layout="constrained")

    # With at most ten distinct values, draw one bar per value instead of binning.
    distinct_count = len(np.unique(flat_weights))
    is_discrete = distinct_count <= 10
    bin_count = distinct_count if is_discrete else bins

    sns.histplot(data=flat_weights, bins=bin_count, ax=ax, discrete=is_discrete, shrink=0.8)

    for position in threshold_lines:
        ax.axvline(position, color='r', linestyle='--', linewidth=2)

    measure_name = model.get_measure_name()
    ax.set_yscale('log')
    ax.set_xlabel(measure_name)
    ax.set_ylabel('Frequency (Log Scale)')
    ax.set_title(f'Distribution of {measure_name}')

    file_stem = model_name.lower().replace(" ", "_")
    plt.savefig(image_dir / file_stem, dpi=RESOLUTION)
    plt.close()

In [None]:
model_weight_histograms_dir = IMAGE_DIR / 'model_weight_histograms'
os.makedirs(model_weight_histograms_dir, exist_ok=True)

for model_name, model in tqdm(models.items(), desc="Plotting Model Weight Histograms"):
    if model_name != 'Defect Multiplicity':
        plot_model_weight_histogram(model_name, model, model_weight_histograms_dir)
        continue

    # The histogram of the model weights is not well-defined for the Defect Multiplicity model.
    # Instead, approximate it with the scores of synthetic submissions in which every defect
    # occurs 1, 2 or 3 times.
    defect_counts = [1, 2, 3]
    values = pd.DataFrame(
        [model._calculate_scores(None, pd.Series(count, index=model.defects.index)) for count in defect_counts],
        index=defect_counts,
    ).values.flatten()
    thresholds = model.get_model_thresholds()

    fig, ax = plt.subplots(figsize=(12, 8), layout="constrained")

    # This plot is always a continuous 60-bin histogram, so the unique-value count that
    # used to be computed here (and never used) has been dropped.
    sns.histplot(data=values, bins=60, ax=ax, shrink=0.8)

    for threshold in thresholds:
        ax.axvline(threshold, color='r', linestyle='--', linewidth=2)

    ax.set_yscale('log')
    ax.set_xlabel(model.get_measure_name())
    ax.set_ylabel('Frequency (Log Scale)')
    ax.set_title(f'Approximate Distribution of {model.get_measure_name()}')

    plt.savefig(model_weight_histograms_dir / model_name.lower().replace(" ", "_"), dpi=RESOLUTION)
    plt.close()

## Prediction histograms

In [None]:
def plot_model_score_histogram(
    model_name,
    values,
    thresholds,
    image_dir,
    bins=60,
):
    """Save a log-scale histogram of one model's cached prioritization scores.

    ``values`` is converted to an array; an empty input is reported and skipped.
    Each entry of ``thresholds`` is drawn as a dashed red vertical line. The
    figure is written to ``image_dir`` as ``<model_name lower-cased, spaces -> _>.png``.
    """
    scores = np.asarray(values)
    if scores.size == 0:
        print(f"Model {model_name}: no values available — skipping.")
        return

    # At most ten distinct scores: one bar per value; otherwise `bins` bins.
    distinct_count = len(np.unique(scores))
    is_discrete = distinct_count <= 10
    bin_count = distinct_count if is_discrete else bins

    fig, ax = plt.subplots(figsize=(12, 8), layout="constrained")
    sns.histplot(data=scores, bins=bin_count, ax=ax, discrete=is_discrete, shrink=0.8)

    for position in thresholds:
        ax.axvline(position, color="r", linestyle="--", linewidth=2)

    ax.set_yscale('log')
    ax.set_xlabel(model_name)
    ax.set_ylabel("Frequency (log scale)")
    ax.set_title(f"Distribution of Scores for {model_name}")

    file_name = f"{model_name.lower().replace(' ', '_')}.png"
    plt.savefig(image_dir / file_name, dpi=RESOLUTION)
    plt.close()

In [None]:
model_score_histograms_dir = IMAGE_DIR / 'model_score_histograms'
os.makedirs(model_score_histograms_dir, exist_ok=True)

# One histogram per model, drawn from that model's column of cached scores.
for model_name, model in models.items():
    plot_model_score_histogram(
        model_name=model_name,
        values=model_prioritizations[model_name].values,
        thresholds=model.get_model_thresholds(),
        image_dir=model_score_histograms_dir,
    )

## Task-defect weight maps

In [None]:
def plot_task_weight_heatmap(model_name, model, image_dir, defects, items, mask=None, normalize_cmap=False):
    """Plot a heatmap of task-defect model weights.

    Args:
        model_name: Display name of the model; also used to build the file name.
        model: Object providing ``get_model_weights()`` and ``get_measure_name()``.
        image_dir: Directory the figure is saved into.
        defects: Defect table; its ``display name`` column labels the x axis.
        items: Task table; its ``display name`` column labels the y axis.
        mask: Optional boolean frame; cells where it is True are blanked out (NaN).
            When given, the file name gets a ``_masked`` suffix.
        normalize_cmap: If True, use a diverging colour map symmetric around zero
            (limits of at least +/-1); otherwise a sequential one.

    Returns:
        None. Models without weights are reported and skipped.
    """
    model_weights = model.get_model_weights()

    # Check for missing weights before touching them; copying first would raise on None.
    if model_weights is None:
        print(f"Model {model_name} has no weights. Skipping...")
        return

    # Work on a copy so that masking never alters the model's own weights.
    model_weights = model_weights.copy()

    defect_names = defects['display name'].loc[model_weights.columns]
    task_names = items['display name'].loc[model_weights.index]

    fig, ax = plt.subplots(figsize=(12, 17), layout="constrained")

    if normalize_cmap:
        # Colour limits are taken from the unmasked weights.
        lim = max(np.abs(model_weights.values).max(), 1)
        kwargs = {'cmap': 'coolwarm', 'vmin': -lim, 'vmax': lim}
    else:
        kwargs = {'cmap': 'Reds'}

    # Only mask when a mask was supplied: `model_weights[None] = np.nan` is not a no-op.
    if mask is not None:
        model_weights[mask] = np.nan

    sns.heatmap(
        model_weights, ax=ax, xticklabels=defect_names, yticklabels=task_names, cbar=True,
        cbar_kws={'label': model.get_measure_name()}, **kwargs
    )

    ax.tick_params(axis='x', labelsize=7)
    ax.tick_params(axis='y', labelsize=8)
    plt.title(model.get_measure_name())
    plt.xlabel("Defects")
    plt.ylabel("Tasks")
    title = model_name.lower().replace(" ", "_")
    plt.savefig(image_dir / (title if mask is None else title + "_masked"), dpi=RESOLUTION)
    plt.close()

In [None]:
task_defect_maps_dir = IMAGE_DIR / 'task_defect_weight_maps'
os.makedirs(task_defect_maps_dir, exist_ok=True)

# Unmasked heatmaps for every task-context model.
for model_name, model in models.items():
    if model.get_context_type() != "task":
        print(f"Model {model_name} is not a task model. Skipping...")
        continue

    # Only the Task Characteristic model is drawn with the symmetric colour scale.
    plot_task_weight_heatmap(
        model_name, model, task_defect_maps_dir, defects, items,
        normalize_cmap=(model_name == "Task Characteristic"),
    )

In [None]:
# Same heatmaps again, with task-defect pairs lacking minimal support blanked out.
for model_name, model in models.items():
    if model.get_context_type() != "task":
        print(f"Model {model_name} is not a task model. Skipping...")
        continue

    plot_task_weight_heatmap(
        model_name, model, task_defect_maps_dir, defects, items,
        normalize_cmap=(model_name == "Task Characteristic"),
        mask=insufficient_support,
    )

In [None]:
# Output directory for the figures produced by the quantitative-analysis cells below.
model_metrics_dir = IMAGE_DIR / 'model_metrics'
os.makedirs(model_metrics_dir, exist_ok=True)

***

# Quantitative analysis

In [None]:
if RETRAIN_MODELS:
    model_names = list(models.keys())

    # Per-model lists with one entry per evaluated submission.
    metrics = {
        'exact_ties': {name: [] for name in model_names},
        'close_ties': {name: [] for name in model_names},
        'gini':       {name: [] for name in model_names},
    }

    # Per-submission Spearman correlations for every ordered pair of models.
    model_correlations = {
        name1: {name2: [] for name2 in model_names}
        for name1 in model_names
    }

    # Group by submission id to reconstruct each submission's defect set.
    submission_groups = model_prioritizations.groupby("submission id")

    for submission_id, group in tqdm(submission_groups, desc="Computing metrics from cached prioritizations"):
        # A ranking needs at least two defects; single-defect submissions are skipped.
        if group.shape[0] < 2:
            continue

        model_scores = {name: group[name].values.astype(float) for name in model_names}

        for name, scores in model_scores.items():
            max_score = scores.max()

            # Exact tie: the top score is shared by more than one defect.
            metrics['exact_ties'][name].append(int((scores == max_score).sum() > 1))

            # Close tie: more than one defect lies within a margin below the top score.
            # The margin uses the magnitude of the top score; with a negative maximum a
            # signed margin would flip and exclude the maximum itself.
            close_margin = abs(max_score) * CLOSE_TIE_THRESHOLD
            metrics['close_ties'][name].append(
                int((scores >= (max_score - close_margin)).sum() > 1)
            )

            metrics['gini'][name].append(gini(scores))

        # Inter-model Spearman correlations: each unordered pair once, stored in both directions.
        for pos, name_i in enumerate(model_names):
            for name_j in model_names[pos + 1:]:
                arr_i = model_scores[name_i]
                arr_j = model_scores[name_j]

                # Spearman's rho is undefined when either score vector is constant.
                if np.all(arr_i == arr_i[0]) or np.all(arr_j == arr_j[0]):
                    rho = np.nan
                else:
                    rho, _ = spearmanr(arr_i, arr_j)

                model_correlations[name_i][name_j].append(rho)
                model_correlations[name_j][name_i].append(rho)

    # Cache the results so that later runs can skip the computation.
    with open(MODEL_METRICS_PATH / "metrics.json", "w") as f:
        json.dump(metrics, f)

    with open(MODEL_METRICS_PATH / "model_correlations.json", "w") as f:
        json.dump(model_correlations, f)
else:
    with open(MODEL_METRICS_PATH / "metrics.json", "r") as f:
        metrics = json.load(f)
    with open(MODEL_METRICS_PATH / "model_correlations.json", "r") as f:
        model_correlations = json.load(f)

In [None]:
# --- Final Aggregation and Formatting ---
# Collapse every per-submission metric list into one mean per model.
results = {
    f'avg_{metric_name}': pd.Series(
        {name: np.mean(values) for name, values in per_model.items()}
    )
    for metric_name, per_model in metrics.items()
}

separator = "-" * 20
print(f"Average Exact Ties: \n{results['avg_exact_ties']}")
print(separator)
print(f"Average Close Ties: \n{results['avg_close_ties']}")
print(separator)
print(f"Average Gini Coefficients: \n{results['avg_gini']}")

## Decisiveness - Ties

In [None]:
# One row per model, one column per tie type. The close-tie column holds only the part
# that goes beyond exact ties, so that the stacked bar adds up to the close-tie average.
plot_data = pd.DataFrame({
    'Exact Ties': results['avg_exact_ties'],
    'Close Ties': results['avg_close_ties'] - results['avg_exact_ties'],
}).rename_axis(columns='Tie Type')

ax = plot_data.plot(
    kind='bar',
    stacked=True,
    figsize=(12, 8),
    colormap='tab10',
)

plt.title("Average Number of Exact and Close Ties per Model", fontsize=16)
plt.ylabel("Average Count", fontsize=12)
plt.xlabel("Prioritization Model", fontsize=12)
plt.xticks(rotation=45, ha="right")

plt.legend(title='Tie Type')
plt.tight_layout()
plt.savefig(model_metrics_dir / 'ties_bar_plot.png', dpi=RESOLUTION)
plt.close()

## Decisiveness - Gini

In [None]:
# Long format: one row per (model, per-submission Gini coefficient).
plot_data = pd.DataFrame([
    {'Model': model_name, 'Gini Coefficient': gini_value}
    for model_name, gini_list in metrics['gini'].items()
    for gini_value in gini_list
])

plt.figure(figsize=(12, 8))

box_options = {'notch': True, 'palette': 'viridis', 'hue': 'Model', 'legend': False}
sns.boxplot(x='Model', y='Gini Coefficient', data=plot_data, **box_options)

plt.title("Distribution of Gini Coefficients by Prioritization Model", fontsize=16)
plt.xlabel("Prioritization Model", fontsize=12)
plt.ylabel("Gini Coefficient", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.ylim(0, 1)
plt.tight_layout()

plt.savefig(model_metrics_dir / 'gini_box_plot.png', dpi=RESOLUTION)
plt.close()

## Inter-Model Agreement

In [None]:
model_names = list(models.keys())
n_models = len(model_names)

# Symmetric matrix of mean per-submission correlations; the diagonal stays NaN.
correlation_matrix = np.full((n_models, n_models), np.nan)

for i in range(n_models):
    for j in range(i + 1, n_models):
        # NaN entries (undefined correlations) are ignored by the mean.
        mean_rho = np.nanmean(model_correlations[model_names[i]][model_names[j]])
        correlation_matrix[i, j] = mean_rho
        correlation_matrix[j, i] = mean_rho

correlation_matrix = pd.DataFrame(correlation_matrix, index=model_names, columns=model_names)

plt.figure(figsize=(10, 8))

heatmap_options = {
    'annot': True,
    'cmap': 'coolwarm',
    'fmt': ".2f",
    'linewidths': .5,
    'cbar_kws': {'label': "Average Spearman's Correlation"},
    'vmin': -1,
    'vmax': 1,
}
sns.heatmap(correlation_matrix, **heatmap_options)

plt.title("Inter-Model Prioritization Agreement (Spearman's Rho)", fontsize=16)
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()

plt.savefig(model_metrics_dir / 'model_correlation_heatmap.png', dpi=RESOLUTION)
plt.close()

## Agreement with baseline

In [None]:
# Per-submission correlations of every model with the 'Naive Severity' entry, in long format.
baseline_correlations = model_correlations['Naive Severity']
plot_data = pd.DataFrame([
    {'Model': other_name, 'Spearman_Rho': rho}
    for other_name, rho_list in baseline_correlations.items()
    for rho in rho_list
])

plt.figure(figsize=(10, 8))

box_options = {'palette': 'rocket', 'hue': 'Model', 'legend': False, 'notch': True}
sns.boxplot(x='Model', y='Spearman_Rho', data=plot_data, **box_options)

plt.title("Distribution of Spearman's Rho by Prioritization Model", fontsize=16)
plt.xlabel("Prioritization Model", fontsize=12)
plt.ylabel("Spearman's Rho", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()

plt.savefig(model_metrics_dir / 'agreement_with_baseline_box_plot.png', dpi=RESOLUTION)
plt.close()

## Context sensitivity

In [None]:
# For every model: a square matrix of Spearman correlations between the rows of its weights.
matrices = {}
for name, model in models.items():
    print(f"Calculating correlation matrix for {name}")

    weights = model.get_model_weights()

    # Models without weights, or with a single weight row, have no context to compare:
    # represent them as perfect agreement between all tasks.
    if weights is None or weights.shape[0] == 1:
        matrices[name] = pd.DataFrame(1.0, index=items.index, columns=items.index)
        continue

    # Spearman's rho is undefined for constant rows, so they are left out of the computation.
    is_constant_row = weights.apply(lambda row: np.all(row == row.iloc[0]), axis=1)
    filtered_weights = weights[~is_constant_row]

    correlation_matrix = filtered_weights.T.corr(method='spearman')

    # Restore the full row set; the rows dropped above come back as NaN and stay NaN.
    correlation_matrix = correlation_matrix.reindex(index=weights.index, columns=weights.index)

    # Only the diagonal is set to 1.0 (a row always agrees with itself). This is done on an
    # explicit copy because `.values` is not guaranteed to be a writable view of the frame
    # (it is read-only under pandas Copy-on-Write).
    corr_values = correlation_matrix.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(corr_values, 1.0)
    correlation_matrix = pd.DataFrame(corr_values, index=weights.index, columns=weights.index)

    matrices[name] = correlation_matrix

In [None]:
task_matrices = {
    name: matrix for name, matrix in matrices.items() if models[name].get_context_type() == 'task'
}

ncols = 2
n_panels = len(task_matrices)
nrows = (n_panels + ncols - 1) // ncols

if n_panels == 0:
    # plt.subplots cannot build a grid with zero rows.
    print("No task-context models found. Skipping task-to-task agreement heatmaps...")
else:
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(10 * ncols, 9 * nrows),
        sharex=True, sharey=True,
        squeeze=False,  # always a 2-D array, whatever the grid size
        gridspec_kw={'wspace': 0.05, 'hspace': 0.05},  # Fine-tune the spacing
    )
    axes = axes.flatten()

    for i, (name, correlation_matrix) in enumerate(task_matrices.items()):
        ax = axes[i]

        sns.heatmap(
            correlation_matrix.astype(float),
            ax=ax,
            cmap='coolwarm',
            vmin=-1,
            vmax=1,
            linewidths=.5,
            cbar=False,
        )

        ax.set_title(f"Task-to-Task Agreement for {name}", fontsize=18)

        # Axis labels only on the outer edge of the grid: x on the bottom row, y on the
        # left column. Derived from the grid shape so that any number of panels works.
        if i // ncols == nrows - 1:
            ax.set_xlabel('Task ID', fontsize=12)
        else:
            ax.set_xlabel('')
            ax.set_xticklabels([])
        if i % ncols == 0:
            ax.set_ylabel('Task ID', fontsize=12)
        else:
            ax.set_ylabel('')
            ax.set_yticklabels([])

    # Hide grid cells that received no heatmap (odd number of panels).
    for spare_ax in axes[n_panels:]:
        spare_ax.set_visible(False)

    # Create a single color bar for the entire figure, taken from the last drawn heatmap.
    fig.subplots_adjust(right=0.85)
    cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
    last_heatmap = axes[n_panels - 1].collections[0]
    cbar = fig.colorbar(last_heatmap, cax=cbar_ax)
    cbar.set_label("Spearman's Correlation", fontsize=16)

    plt.savefig(model_metrics_dir / 'task_model_sensitivity_heatmaps.png', dpi=RESOLUTION)
    plt.close()

In [None]:
def _upper_triangle_std(matrix):
    """Return the standard deviation of the non-NaN entries strictly above the diagonal."""
    above_diagonal = np.triu(np.ones(matrix.shape), k=1).astype(bool)
    # where() blanks everything outside the mask and stack() drops the resulting NaNs.
    return matrix.where(above_diagonal).stack().values.std()


# Spread of the pairwise correlations, one value per model.
std_devs = {name: _upper_triangle_std(matrix) for name, matrix in matrices.items()}

std_df = pd.DataFrame(std_devs.items(), columns=['Model', 'Standard Deviation'])

plt.figure(figsize=(10, 6))
bar_options = {'palette': 'viridis', 'hue': 'Model', 'legend': False}
sns.barplot(x='Model', y='Standard Deviation', data=std_df, **bar_options)

plt.title("Standard Deviation of Inter-Task Correlation", fontsize=16)
plt.ylabel("Standard Deviation of Spearman's Rho", fontsize=12)
plt.xlabel("Prioritization Model", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()

plt.savefig(model_metrics_dir / 'sensitivity_bar_plot.png', dpi=RESOLUTION)
plt.close()

***

# Export models

In [None]:
# Persist the models only when they were (re)trained in this run; otherwise the pickles
# on disk are the ones that were loaded. The progress label names the save step.
if RETRAIN_MODELS:
    for name, model in (pbar := tqdm(models.items(), desc="Saving Models")):
        pbar.set_description(f"Saving {name}")
        model.save(MODEL_OUTPUT_PATH / f"{name}.pkl")