In [None]:
import polars as pl
import matplotlib.pyplot as plt
import numpy as np
plt.set_loglevel("info")

In [None]:
# --- Input / output locations -------------------------------------------------
# Tab-separated evaluation table to plot
df_file = "results/evaluation_all.tsv"
# Folder receiving the TSV summary and the PNG figures
out_folder = "results"

# Optional space-separated table, joined on 'filename', that supplies
# 'group' / 'label' (and optionally 'is_baseline') for each prediction file.
# With None, the filename is used as both group and label.
names_file = None

# --- Curve options ------------------------------------------------------------
# When True, the last column of `cols` (e.g. "pr" --> precision) is replaced by
# its running max/min so that the curves are monotonic, as in CAFA
cumulate = True

# When True (F-score metrics only), the curves are padded with the extreme
# points (recall=1, precision=0) and (recall=0, precision=1)
add_extreme_points = True

# Rows whose 'cov' value is below this threshold are discarded before plotting
coverage_threshold = 0.3

# --- Metric selection ---------------------------------------------------------
# `metric` picks the best threshold; `cols` are the [x, y] columns of the curve.
# Supported combinations:
#   metric         cols
#   'f'            ['rc', 'pr']
#   'f_w'          ['rc_w', 'pr_w']
#   'f_micro'      ['rc_micro', 'pr_micro']
#   'f_micro_w'    ['rc_micro_w', 'pr_micro_w']
#   's_w'          ['ru_w', 'mi_w']
metric = 'f_w'
cols = ['rc_w', 'pr_w']

In [None]:
# Human-readable axis labels, keyed by the column names of the evaluation table
axis_title_dict = {
    # Macro-averaged, unweighted
    'pr': 'Precision',
    'rc': 'Recall',
    'f': 'F-score',
    # Macro-averaged, weighted
    'pr_w': 'Weighted Precision',
    'rc_w': 'Weighted Recall',
    'f_w': 'Weighted F-score',
    # Misinformation / remaining uncertainty
    'mi': 'Misinformation (Unweighted)',
    'ru': 'Remaining Uncertainty (Unweighted)',
    'mi_w': 'Misinformation',
    'ru_w': 'Remaining Uncertainty',
    's': 'S-score',
    # Micro-averaged, unweighted
    'pr_micro': 'Precision (Micro)',
    'rc_micro': 'Recall (Micro)',
    'f_micro': 'F-score (Micro)',
    # Micro-averaged, weighted
    'pr_micro_w': 'Weighted Precision (Micro)',
    'rc_micro_w': 'Weighted Recall (Micro)',
    'f_micro_w': 'Weighted F-score (Micro)',
}

# Short ontology names used as figure titles, keyed by the namespace value
ontology_dict = {
    'biological_process': 'BPO',
    'molecular_function': 'MFO',
    'cellular_component': 'CCO',
}

In [None]:
# Load the tab-separated evaluation table configured in `df_file`
df = pl.read_csv(source=df_file, separator="\t")
df

In [None]:
# Attach 'group', 'label' and 'is_baseline' to every row (names file is optional)
if names_file is not None:
    # Space-separated table keyed by 'filename'
    methods = pl.read_csv(names_file, separator=" ")
    df = df.join(methods, on='filename', how='left')
    # Files that are missing from the names table fall back to their filename
    df = df.with_columns([
        pl.col(name).fill_null(pl.col('filename')) for name in ('group', 'label')
    ])
    # 'is_baseline' is optional in the names table; unknown entries count as False
    if 'is_baseline' in df.columns:
        df = df.with_columns(pl.col('is_baseline').fill_null(False))
    else:
        df = df.with_columns(pl.lit(False).alias('is_baseline'))
else:
    # No names table: the filename doubles as group and label, nothing is a baseline
    df = df.with_columns(
        [pl.col('filename').alias(name) for name in ('group', 'label')]
        + [pl.lit(False).alias('is_baseline')]
    )

# 'filename' is not used after this point
df = df.drop('filename')
df

In [None]:
# Keep only the rows whose coverage reaches the configured threshold
has_enough_coverage = pl.col('cov').ge(coverage_threshold)
df = df.filter(has_enough_coverage)
df

In [None]:
# Assign one color per group, cycling through the 'tab20' palette
cmap = plt.get_cmap('tab20')

# Series.unique() gives no ordering guarantee, so the groups are sorted to keep
# the group -> color assignment identical from one run to the next
unique_groups = df['group'].unique().sort().to_list()
group_to_color_idx = {g: i for i, g in enumerate(unique_groups)}

# Group name -> palette index -> RGB tuple (wraps around after len(cmap.colors) groups)
color_indices = df['group'].map_elements(lambda g: group_to_color_idx[g], return_dtype=pl.Int64)
colors = color_indices.map_elements(lambda idx: cmap.colors[idx % len(cmap.colors)], return_dtype=pl.Object)

# The RGB tuples are stored in an Object column named 'colors'
df = df.with_columns(colors.alias('colors'))
df

In [None]:
# For every (group, ns) keep the single row with the best value of `metric`:
# the highest value for the F-score family, the lowest value otherwise
rows_by_metric = pl.all().sort_by(metric)
if metric in ['f', 'f_w', 'f_micro', 'f_micro_w']:
    best_row = rows_by_metric.last()
else:
    best_row = rows_by_metric.first()
index_best = df.group_by(['group', 'ns']).agg(best_row)
index_best

In [None]:
# Keep every threshold of the best label selected for each (group, ns)
curve_keys = ['group', 'label', 'ns']
best_labels = index_best.select(curve_keys)
df_methods = df.join(best_labels, on=curve_keys, how='inner')
df_methods = df_methods.select(curve_keys + ['tau', 'cov', 'colors'] + cols + [metric])

# The cumulative step below (and the curve drawing later on) needs each curve
# in ascending threshold order. A join does not guarantee any row order, so the
# order is established explicitly here.
df_methods = df_methods.sort(curve_keys + ['tau'])

# Makes the curves monotonic: running max (F-score metrics) or running min
# (other metrics) on the last column of `cols`, e.g. "pr" --> precision.
# The window is one curve, i.e. one (group, label, ns) combination, matching
# the grouping used by the other per-curve computations in this notebook.
if cumulate:
    last_col = pl.col(cols[-1])
    if metric in ['f', 'f_w', 'f_micro', 'f_micro_w']:
        running = last_col.cum_max()
    else:
        running = last_col.cum_min()
    df_methods = df_methods.with_columns(running.over(curve_keys).alias(cols[-1]))

# Save to file (the 'colors' column is only used for plotting and is left out)
df_methods.drop('colors').write_csv('{}/fig_{}.tsv'.format(out_folder, metric), separator="\t", float_precision=3)
df_methods

In [None]:
# Pad each curve with a synthetic first and last point (improves the APS calculation)
def add_points(df_group):
    """Return `df_group` with one synthetic row prepended and one appended.

    The prepended row is a copy of the first row with tau=0.0, cols[0]=1.0 and
    cols[1]=0.0; the appended row is a copy of the last row with tau=1.1,
    cols[0]=0.0 and cols[1]=1.0. All other columns keep the values of the row
    they were copied from. Reads the notebook-level `cols` variable.
    """
    def synthetic_row(template, tau, x_value, y_value):
        # Copy of `template` with the threshold and both curve columns overwritten
        return template.with_columns([
            pl.lit(tau).alias('tau'),
            pl.lit(x_value).alias(cols[0]),
            pl.lit(y_value).alias(cols[1]),
        ])

    leading = synthetic_row(df_group.head(1), 0.0, 1.0, 0.0)
    trailing = synthetic_row(df_group.tail(1), 1.1, 0.0, 1.0)
    return pl.concat([leading, df_group, trailing])

# Extreme points are only added for the F-score metrics, one curve at a time;
# maintain_order keeps the curves in their current order
if add_extreme_points and metric.startswith('f'):
    curves = df_methods.group_by(['group', 'label', 'ns'], maintain_order=True)
    df_methods = curves.map_groups(add_points)
df_methods

In [None]:
# Best method/threshold per (group, ns), reduced to the columns used from here on
df_best = index_best.select(['group', 'label', 'ns', 'tau', 'cov', 'colors', *cols, metric])
df_best

In [None]:
# Average precision score (APS), computed for the F-score metrics only
if metric.startswith('f'):
    x_col, y_col = pl.col(cols[0]), pl.col(cols[1])
    # Previous minus current value of the first curve column; null on the
    # first row of each curve, which the sum below skips
    x_step = x_col.shift(1) - x_col
    # One APS value per curve, with the points taken in ascending tau order
    aps_df = (
        df_methods
        .sort(['group', 'label', 'ns', 'tau'])
        .group_by(['group', 'label', 'ns'])
        .agg((x_step * y_col).sum().alias('aps'))
    )
    df_best = df_best.join(aps_df, on=['group', 'label', 'ns'], how='left')
df_best

In [None]:
# Highest coverage reached by each curve over all of its thresholds
max_cov_df = (
    df_methods
    .group_by(['group', 'label', 'ns'])
    .agg(pl.col('cov').max().alias('max_cov'))
)
df_best = df_best.join(max_cov_df, on=['group', 'label', 'ns'], how='left')
df_best

In [None]:
# Legend text: "<label> (<METRIC>=<value>[ APS=<value>] C=<max coverage>)"
plot_label = pl.col('label') + (' (' + metric.upper() + '=') + pl.col(metric).round(3).cast(pl.Utf8)
# The APS part is only present when the 'aps' column exists
if 'aps' in df_best.columns:
    plot_label = plot_label + ' APS=' + pl.col('aps').round(3).cast(pl.Utf8)
plot_label = plot_label + ' C=' + pl.col('max_cov').round(3).cast(pl.Utf8) + ')'

df_best = df_best.with_columns(plot_label.alias('plot_label'))
df_best

In [None]:
# Generate the figures: one per namespace, each saved as a PNG in `out_folder`
plt.rcParams.update({'font.size': 22, 'legend.fontsize': 18})

# Grid for the F-score contour lines: Z = 2 * X * Y / (X + Y) on 0.01 .. 0.99
x = np.arange(0.01, 1, 0.01)
y = np.arange(0.01, 1, 0.01)
X, Y = np.meshgrid(x, y)
Z = 2 * X * Y / (X + Y)

is_f_metric = metric.startswith('f')

# unique() has no guaranteed order; sorting makes the figure order reproducible
for ns in df_best['ns'].unique().sort().to_list():
    df_g = df_best.filter(pl.col('ns') == ns)
    fig, ax = plt.subplots(figsize=(15, 15))

    # Contour lines. At the moment they are provided only for the F-score
    if is_f_metric:
        CS = ax.contour(X, Y, Z, np.arange(0.1, 1.0, 0.1), colors='gray')
        ax.clabel(CS, inline=True)

    # Best method first: highest metric for F-scores, lowest otherwise;
    # ties are broken by the highest max coverage
    df_g = df_g.sort([metric, 'max_cov'], descending=[is_f_metric, True])

    # Iterate methods
    for i, row in enumerate(df_g.iter_rows(named=True)):
        # All thresholds of this method in the current namespace
        data = df_methods.filter(
            (pl.col('group') == row['group']) &
            (pl.col('label') == row['label']) &
            (pl.col('ns') == ns)
        )

        # Precision-recall or mi-ru curves; better-ranked methods are drawn on top
        ax.plot(data[cols[0]].to_numpy(), data[cols[1]].to_numpy(),
                color=row['colors'], label=row['plot_label'], lw=2, zorder=500-i)

        # F-max or S-min dots: a hollow ring with a smaller filled dot inside
        ax.plot(row[cols[0]], row[cols[1]], color=row['colors'], marker='o', markersize=12, mfc='none', zorder=1000-i)
        ax.plot(row[cols[0]], row[cols[1]], color=row['colors'], marker='o', markersize=6, zorder=1000-i)

    # Set axes limit on this figure's own axes rather than via pyplot's current axes
    if is_f_metric:
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

    # Set titles (unknown namespaces are shown as-is)
    ax.set_title(ontology_dict.get(ns, ns), pad=20)
    ax.set_xlabel(axis_title_dict[cols[0]], labelpad=20)
    ax.set_ylabel(axis_title_dict[cols[1]], labelpad=20)

    # Legend, with thicker line samples so the colors are easy to tell apart
    leg = ax.legend(markerscale=6)
    for legobj in leg.get_lines():
        legobj.set_linewidth(10.0)

    # Save this figure on disk
    fig.savefig("{}/fig_{}_{}.png".format(out_folder, metric, ns), bbox_inches='tight', dpi=300, transparent=True)