In [None]:
#Graphs for manuscripts.

import pandas as pd
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import re

In [None]:
# Figure 1 input: the Table S3 results export.
df = pd.read_csv(filepath_or_buffer="results/Table S3 Results.csv")

In [None]:
# Keep only rows with no shot count (zero-shot / not applicable) and 10-shot rows.
# read_csv can return the shot column as text (e.g. "10") when it mixes numbers with
# words, and the string "10" never equals the integer 10, so compare on a numeric
# view of the column instead of the raw values.
df_filtered = df[
    df['Number of shots'].isna()
    | (pd.to_numeric(df['Number of shots'], errors='coerce') == 10)
].copy()

# Extract central values
def extract_central_value(s):
    """Return the point estimate from a 'value (low-high)' CI string.

    Parameters
    ----------
    s : str or number
        Cell from a '... (95% CI)' column, e.g. "0.85 (0.80-0.90)".

    Returns
    -------
    float
        The number before the parenthesised interval, the value itself when `s`
        is already numeric, or NaN when `s` is missing or cannot be parsed.
    """
    if isinstance(s, (int, float, np.integer, np.floating)) and not isinstance(s, bool):
        return float(s)  # already numeric; NaN stays NaN
    if not isinstance(s, str):
        return np.nan
    # First token before any '(' -- handles "0.85 (..)", "0.85(..)" and leading spaces.
    tokens = s.split("(", 1)[0].split()
    if not tokens:
        return np.nan
    try:
        return float(tokens[0])
    except ValueError:
        return np.nan

# Point estimate (the number before the bracketed CI) for each metric.
for score_col, ci_col in (("Precision", "Precision (95% CI)"),
                          ("Recall", "Recall (95% CI)"),
                          ("F1", "F1 (95% CI)")):
    df_filtered[score_col] = df_filtered[ci_col].apply(extract_central_value)

# Extract CI bounds
def extract_ci_bounds(s):
    """Return the (lower, upper) bounds from a 'value (low-high)' CI string.

    Parameters
    ----------
    s : str
        Cell from a '... (95% CI)' column, e.g. "0.85 (0.80-0.90)".

    Returns
    -------
    tuple of float
        ``(low, high)``, or ``(nan, nan)`` when `s` is missing, has no
        parenthesised interval, or the bounds are not numbers.
    """
    if not isinstance(s, str) or "(" not in s:
        return np.nan, np.nan
    # Text between the first '(' and the next ')'; tolerant of spaces inside the
    # brackets and of anything (e.g. whitespace) after the closing bracket.
    interval = s.split("(", 1)[1].split(")", 1)[0]
    bounds = interval.split("-")
    if len(bounds) < 2:
        return np.nan, np.nan
    try:
        return float(bounds[0]), float(bounds[1])
    except ValueError:
        return np.nan, np.nan

# Each parsed cell is a (low, high) pair; zip(*...) transposes them into two columns.
for ci_col, low_col, high_col in (
    ("Precision (95% CI)", "Precision_low", "Precision_high"),
    ("Recall (95% CI)", "Recall_low", "Recall_high"),
    ("F1 (95% CI)", "F1_low", "F1_high"),
):
    df_filtered[low_col], df_filtered[high_col] = zip(*df_filtered[ci_col].map(extract_ci_bounds))

# Reshape for plotting
def reshape_for_plot(metric, frame=None):
    """Build a long-format table of one metric for faceted plotting.

    Parameters
    ----------
    metric : str
        Score column name, e.g. "F1"; bounds are read from ``f"{metric}_low"``
        and ``f"{metric}_high"``.
    frame : pandas.DataFrame, optional
        Source table. Defaults to the notebook-global ``df_filtered`` so existing
        calls keep working; pass a frame explicitly to avoid hidden state.

    Returns
    -------
    pandas.DataFrame
        Columns "Model name", "Metric", "Score", "Lower", "Upper", one row per
        row of `frame` (index preserved).
    """
    if frame is None:
        frame = df_filtered
    return pd.DataFrame({
        "Model name": frame["Model name"],
        "Metric": metric,
        "Score": frame[metric],
        "Lower": frame[f"{metric}_low"],
        "Upper": frame[f"{metric}_high"]
    })

# Stack the metrics into one long table, F1 first, then Precision and Recall.
df_plot = pd.concat([reshape_for_plot(name) for name in ("F1", "Precision", "Recall")])

# Length of the upper error bar (the plotting helper computes its own bars).
df_plot["yerr"] = df_plot["Upper"] - df_plot["Score"]

# One colour per model from Set2; the first model in table order (assumed to be the
# inter-annotator baseline) is overridden to black.
unique_models = df_plot["Model name"].unique()
palette = sns.color_palette("Set2", len(unique_models))
model_colors = {model: colour for model, colour in zip(unique_models, palette)}
model_colors[unique_models[0]] = "black"

# Plotting: one facet row per metric, sharing the model axis but not the score axis.
sns.set_theme(style="whitegrid")
g = sns.FacetGrid(df_plot, row="Metric", height=4, aspect=2, sharex=True, sharey=False)

def plot_with_ci(data, **kwargs):
    """Draw one bar per model with an asymmetric 95% CI error bar.

    Intended for FacetGrid.map_dataframe: draws on the current axes. Keyword
    arguments passed in by seaborn are accepted and ignored.
    """
    axes = plt.gca()
    for _, record in data.iterrows():
        label = record["Model name"]
        centre = record["Score"]
        axes.bar(label, centre, color=model_colors[label])
        below = centre - record["Lower"]
        above = record["Upper"] - centre
        axes.errorbar(label, centre, yerr=[[below], [above]],
                      fmt='none', c='black', capsize=3)

g.map_dataframe(plot_with_ci)
g.set_titles(row_template="{row_name}")

# One tick per model, in first-appearance order. This assumes the categorical x-axis
# placed the bars at positions 0..n-1 in that same order (matplotlib's behaviour for
# string x values). With sharex=True the tick labels should only be visible on the
# bottom row, so labels can be set uniformly on every axis.
custom_xticks = df_plot["Model name"].unique()
for ax in g.axes.flatten():
    ax.set_ylim(0, 1)
    ax.set_ylabel("Score")
    ax.set_xticks(range(len(custom_xticks)))
    ax.set_xticklabels(custom_xticks, rotation=90, ha='center', fontsize=9)

g.figure.subplots_adjust(hspace=0.3, bottom=0.2)
plt.show()





In [None]:
# Save the Figure 1 grid in both vector formats used for the manuscript.
for out_path, out_format in (("Revised Paper Fig 1.svg", "svg"),
                             ("Revised Paper Fig 1.pdf", "pdf")):
    g.figure.savefig(out_path, format=out_format)

In [None]:
# ---- FIGURE 2: load the trimmed-down Table S4 results ----
df = pd.read_csv(filepath_or_buffer="results/Table S4 Results.csv")

In [None]:
# Keep only rows with no shot count (zero-shot / not applicable) and 10-shot rows.
# Table S4's shot column is handled as text later in this notebook (compared with
# '10', may hold descriptions), so read_csv can return "10" as a string, which never
# equals the integer 10. Compare on a numeric view of the column instead.
df_filtered = df[
    df['Number of shots'].isna()
    | (pd.to_numeric(df['Number of shots'], errors='coerce') == 10)
].copy()

# Extract central values (Precision, Recall, F1)
def extract_central_value(s):
    """Return the point estimate from a 'value (low-high)' CI string.

    Redefined in this cell so the Figure 2 cell can run on its own.

    Parameters
    ----------
    s : str or number
        Cell from a '... (95% CI)' column, e.g. "0.85 (0.80-0.90)".

    Returns
    -------
    float
        The number before the parenthesised interval, the value itself when `s`
        is already numeric, or NaN when `s` is missing or cannot be parsed.
    """
    if isinstance(s, (int, float, np.integer, np.floating)) and not isinstance(s, bool):
        return float(s)  # already numeric; NaN stays NaN
    if not isinstance(s, str):
        return np.nan
    # First token before any '(' -- handles "0.85 (..)", "0.85(..)" and leading spaces.
    tokens = s.split("(", 1)[0].split()
    if not tokens:
        return np.nan
    try:
        return float(tokens[0])
    except ValueError:
        return np.nan

# Point estimate (the number before the bracketed CI) for each metric.
for score_col, ci_col in (("Precision", "Precision (95% CI)"),
                          ("Recall", "Recall (95% CI)"),
                          ("F1", "F1 (95% CI)")):
    df_filtered[score_col] = df_filtered[ci_col].apply(extract_central_value)

# Extract CI bounds (Lower and Upper)
def extract_ci_bounds(s):
    """Return the (lower, upper) bounds from a 'value (low-high)' CI string.

    Redefined in this cell so the Figure 2 cell can run on its own.

    Parameters
    ----------
    s : str
        Cell from a '... (95% CI)' column, e.g. "0.85 (0.80-0.90)".

    Returns
    -------
    tuple of float
        ``(low, high)``, or ``(nan, nan)`` when `s` is missing, has no
        parenthesised interval, or the bounds are not numbers.
    """
    if not isinstance(s, str) or "(" not in s:
        return np.nan, np.nan
    # Text between the first '(' and the next ')'; tolerant of spaces inside the
    # brackets and of anything (e.g. whitespace) after the closing bracket.
    interval = s.split("(", 1)[1].split(")", 1)[0]
    bounds = interval.split("-")
    if len(bounds) < 2:
        return np.nan, np.nan
    try:
        return float(bounds[0]), float(bounds[1])
    except ValueError:
        return np.nan, np.nan

# Each parsed cell is a (low, high) pair; zip(*...) transposes them into two columns.
for ci_col, low_col, high_col in (
    ("Precision (95% CI)", "Precision_low", "Precision_high"),
    ("Recall (95% CI)", "Recall_low", "Recall_high"),
    ("F1 (95% CI)", "F1_low", "F1_high"),
):
    df_filtered[low_col], df_filtered[high_col] = zip(*df_filtered[ci_col].map(extract_ci_bounds))

# Reshape for plotting
def reshape_for_plot(metric, frame=None):
    """Build a long-format table of one metric, keeping the dataset of each row.

    Parameters
    ----------
    metric : str
        Score column name, e.g. "F1"; bounds are read from ``f"{metric}_low"``
        and ``f"{metric}_high"``.
    frame : pandas.DataFrame, optional
        Source table (must have "Model name" and "Dataset"). Defaults to the
        notebook-global ``df_filtered`` so existing calls keep working.

    Returns
    -------
    pandas.DataFrame
        Columns "Model name", "Dataset", "Metric", "Score", "Lower", "Upper",
        one row per row of `frame` (index preserved).
    """
    if frame is None:
        frame = df_filtered
    return pd.DataFrame({
        "Model name": frame["Model name"],
        "Dataset": frame["Dataset"],
        "Metric": metric,
        "Score": frame[metric],
        "Lower": frame[f"{metric}_low"],
        "Upper": frame[f"{metric}_high"]
    })

# Stack F1, Precision and Recall into one long table for the facet grid.
df_plot = pd.concat([reshape_for_plot(name) for name in ("F1", "Precision", "Recall")])

# Length of the upper error bar (the plotting helper computes its own bars).
df_plot["yerr"] = df_plot["Upper"] - df_plot["Score"]

# One colour per model from Set2; the first model in table order (assumed to be the
# inter-annotator baseline) is overridden to black.
unique_models = df_plot["Model name"].unique()
palette = sns.color_palette("Set2", len(unique_models))
model_colors = {model: colour for model, colour in zip(unique_models, palette)}
model_colors[unique_models[0]] = "black"

# Plotting: a single facet column with one row each for F1, Precision and Recall.
sns.set_theme(style="whitegrid")
g = sns.FacetGrid(df_plot, row="Metric", height=4, aspect=2, sharex=True, sharey=False)

def plot_with_ci(data, **kwargs):
    """Draw each model's score as a point with an asymmetric 95% CI error bar.

    Intended for FacetGrid.map_dataframe: draws on the current axes. Keyword
    arguments passed in by seaborn are accepted and ignored.
    """
    axes = plt.gca()
    for _, record in data.iterrows():
        label = record["Model name"]
        centre = record["Score"]
        axes.plot(label, centre, 'o', color=model_colors[label])
        axes.errorbar(label, centre,
                      yerr=[[centre - record["Lower"]], [record["Upper"] - centre]],
                      fmt='none', c='black', capsize=3)

g.map_dataframe(plot_with_ci)

# Facet titles show only the metric name; every panel gets the same score range,
# axis labels and rotated model names.
# NOTE(review): df_plot carries a "Dataset" column that is not used for faceting or
# styling, so rows for the same model on different datasets share one x position.
# Confirm this is intended.
g.set_titles(row_template="{row_name}", col_template="")
for ax in g.axes.flatten():
    ax.set(ylim=(0, 1), ylabel="Score", xlabel="Model Name")
    ax.tick_params(axis='x', rotation=90, labelsize=9)

# Room between panels and below the rotated labels.
g.figure.subplots_adjust(hspace=0.3, bottom=0.2)
plt.show()


In [None]:
# Plot label = model name, plus the shot setting in brackets when that setting is free
# text (e.g. a description) rather than a whole number or missing.
# read_csv turns "N/A" cells into NaN by default, so missing values need their own
# check besides the 'N/A' string; otherwise those rows would get a "(nan)" suffix.
def _needs_shot_suffix(shots):
    """Return True when a 'Number of shots' entry is free text worth showing."""
    if pd.isna(shots):
        return False
    text = str(shots)
    return not text.isdigit() and text not in ('N/A', 'nan')


df['Plot Label'] = df.apply(
    lambda row: f"{row['Model name']} ({row['Number of shots']})"
    if _needs_shot_suffix(row['Number of shots'])
    else row['Model name'],
    axis=1,
)


# 2. Data Processing: Extract mean and error margins
def parse_metric(metric_str):
    """Split a 'mean (lower-upper)' string into a (mean, lower, upper) tuple.

    Spaces around the brackets and the dash are allowed. A bare number yields
    (number, nan, nan); missing, non-string or unparseable input yields
    (nan, nan, nan).
    """
    missing = (np.nan, np.nan, np.nan)
    if pd.isna(metric_str) or not isinstance(metric_str, str):
        return missing

    found = re.match(r'(\d+\.?\d*)\s*\(\s*(\d+\.?\d*)\s*-\s*(\d+\.?\d*)\s*\)', metric_str)
    if found is not None:
        return tuple(float(group) for group in found.groups())

    # No bracketed interval: accept a plain point estimate if there is one.
    try:
        point = float(metric_str)
    except ValueError:
        return missing
    return point, np.nan, np.nan

metrics = ['Precision', 'Recall', 'F1']
for metric in metrics:
    col_name = f'{metric} (95% CI)'
    mean_col, lower_col, upper_col = f'{metric}_mean', f'{metric}_lower', f'{metric}_upper'

    # Split every "mean (lower-upper)" cell into three numeric columns.
    parsed_data = df[col_name].apply(parse_metric).tolist()
    df[[mean_col, lower_col, upper_col]] = pd.DataFrame(parsed_data, index=df.index)

    # plt.errorbar wants distances from the mean rather than absolute bounds,
    # stored per row as a (below, above) tuple.
    below = df[mean_col] - df[lower_col]
    above = df[upper_col] - df[mean_col]
    df[f'{metric}_error'] = list(zip(below, above))

# 3. Filtering Data
# Keep rows whose 'Number of shots' is not a number (missing / N/A, or a text
# description) plus the 10-shot rows.
# Use a numeric view rather than casting df['Number of shots'] to str in place: the
# cast would turn missing values into the text "nan" in the raw table, so a later
# re-run of the Figure 2 cell's .isna() filter would behave differently. The numeric
# view also treats values such as "5.0" as numbers instead of keeping them.
shots_numeric = pd.to_numeric(df['Number of shots'], errors='coerce')
is_numeric = shots_numeric.notna()

# Keep non-numeric OR the 10-shot numeric ones
df_filtered = df[~is_numeric | (shots_numeric == 10)].copy()

# Ensure unique plot labels if the same model now appears twice (e.g. AnonCAT):
# duplicated labels get the shot setting appended ("N/A" when missing). If labels
# still collide (e.g. two 10-shot runs) and the table has a 'Dataset' column, the
# dataset name is appended as well.
duplicates = df_filtered['Plot Label'].duplicated(keep=False)
if duplicates.any():
    df_filtered.loc[duplicates, 'Plot Label'] = [
        f"{name} ({'N/A' if pd.isna(shots) or str(shots) == 'nan' else shots})"
        for name, shots in zip(df_filtered.loc[duplicates, 'Model name'],
                               df_filtered.loc[duplicates, 'Number of shots'])
    ]
    still_duplicated = df_filtered['Plot Label'].duplicated(keep=False)
    if still_duplicated.any() and 'Dataset' in df_filtered.columns:
        df_filtered.loc[still_duplicated, 'Plot Label'] = (
            df_filtered.loc[still_duplicated, 'Plot Label']
            + " [" + df_filtered.loc[still_duplicated, 'Dataset'].astype(str) + "]"
        )


# 4. Plotting
n_models = len(df_filtered)
x = np.arange(n_models)  # one x position per filtered row
width = 0.8  # bar width; not used by the scatter plot below

# One distinct colour per row. plt.get_cmap replaces plt.cm.get_cmap
# (matplotlib.cm.get_cmap), which was deprecated in Matplotlib 3.7 and removed in 3.9.
colors = plt.get_cmap('tab10', n_models)

fig, axes = plt.subplots(3, 1, figsize=(18, 18), sharex=True)  # tall: three stacked panels

metrics_to_plot = ['F1', 'Recall', 'Precision']
titles = ['F1 Score (95% CI)', 'Recall (95% CI)', 'Precision (95% CI)']

for i, (metric, title) in enumerate(zip(metrics_to_plot, titles)):
    ax = axes[i]
    means = df_filtered[f'{metric}_mean'].values
    # errorbar takes yerr as a (2, N) list: [distances below, distances above].
    # Missing CI bounds become 0, giving a zero-length bar at that point.
    errors = df_filtered[f'{metric}_error'].tolist()
    errors_formatted = [[e[0] if pd.notna(e[0]) else 0 for e in errors],
                        [e[1] if pd.notna(e[1]) else 0 for e in errors]]

    ax.scatter(x, means, label=metric, color=[colors(j) for j in range(n_models)])
    ax.errorbar(x, means, yerr=errors_formatted, fmt='none', capsize=5, alpha=0.85)

    ax.set_ylabel(metric)
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_ylim(0, 1.05)  # scores lie in [0, 1]; small headroom for the error caps
    ax.grid(axis='y', linestyle='--', alpha=0.7)


# Model labels go on the bottom panel only; the panels share the x-axis.
axes[-1].set_xticklabels(df_filtered['Plot Label'], rotation=45, ha='right')

# Lay out panels and labels over the whole figure area.
plt.tight_layout(rect=[0, 0, 1, 1])
plt.show()

# Numbers behind the plot, rounded to 3 decimals (optional check).
summary_columns = ['Plot Label', 'Number of shots', 'Precision_mean', 'Recall_mean', 'F1_mean']
print("Filtered Data Used for Plotting:")
print(df_filtered[summary_columns].round(3))