In [7]:
import polars as pl
import sklearn

from udonpred_benchmarking.constants import DATA_DIR
from udonpred_benchmarking.plots import METHODS, QUARTILE_BOUNDARIES

In [20]:
TEST_SET = "caid3"  # set to "trizod" to evaluate on the TriZOD test set

residue_df = pl.read_csv(DATA_DIR / f"{TEST_SET}/per_residue_predictions.csv")
BINARY_SCORE_NAME = "pscores_binary" if TEST_SET == "trizod" else "CAID3_binary"
# For CAID3 the binary label column also serves as the "continuous" reference
# (used for Spearman and RMSE below); presumably no continuous labels exist for
# CAID3 — TODO confirm.
CONTINUOUS_SCORE_NAME = "pscores" if TEST_SET == "trizod" else "CAID3_binary"
PROTEIN_COLUMN_NAME = "ID" if TEST_SET == "trizod" else "protein"

# Methods to leave out of this run, e.g. ["ODiNPred"]. Filtering here instead of
# calling METHODS.remove(...) leaves the imported constant unchanged, so the cell
# stays re-runnable (a second .remove() of the same name raises ValueError).
EXCLUDED_METHODS: list[str] = []


def per_protein_metric(
    frame: pl.DataFrame,
    method: str,
    label_column: str,
    metric_fn,
    metric_name: str,
) -> pl.DataFrame:
    """Score one method's predictions separately for every protein.

    Args:
        frame: per-residue table containing ``PROTEIN_COLUMN_NAME``,
            ``label_column`` and ``f"{method}_continuous"``.
        method: method name; predictions are read from ``f"{method}_continuous"``.
        label_column: reference column passed to ``metric_fn`` as ``y_true``.
        metric_fn: scikit-learn style callable ``metric_fn(y_true, y_pred)``.
        metric_name: name of the resulting score column.

    Returns:
        One row per protein with columns ``PROTEIN_COLUMN_NAME``, ``metric_name``
        and ``method``.
    """
    return (
        frame
        .group_by(PROTEIN_COLUMN_NAME, maintain_order=True)
        .agg(
            pl.map_groups(
                exprs=[label_column, f"{method}_continuous"],
                function=lambda series: metric_fn(series[0], series[1]),
            ).alias(metric_name)
        )
        .with_columns(
            # The per-group result arrives wrapped in a list column; unwrap it.
            pl.col(metric_name).list.explode(),
            pl.lit(method).alias("method"),
        )
    )


spearman_parts = []
aps_parts = []
rmse_parts = []
for method in [m for m in METHODS if m not in EXCLUDED_METHODS]:
    spearman_parts.append(
        residue_df
        .group_by(PROTEIN_COLUMN_NAME, maintain_order=True)
        .agg(
            pl.corr(CONTINUOUS_SCORE_NAME, f"{method}_continuous", method="spearman")
            .alias("spearman")
        )
        .with_columns(
            # A constant input (e.g. every residue of a protein carries the same
            # label) yields a NaN correlation; such proteins are scored as 0.
            pl.col("spearman").fill_nan(0),
            pl.lit(method).alias("method"),
        )
    )
    aps_parts.append(
        per_protein_metric(
            residue_df, method, BINARY_SCORE_NAME,
            sklearn.metrics.average_precision_score, "aps",
        )
    )
    rmse_parts.append(
        per_protein_metric(
            residue_df, method, CONTINUOUS_SCORE_NAME,
            sklearn.metrics.root_mean_squared_error, "rmse",
        )
    )

spearman_scores = pl.concat(spearman_parts)
aps_scores = pl.concat(aps_parts)
rmse_scores = pl.concat(rmse_parts)

In [21]:
def summarize_metric(scores: pl.DataFrame, metric: str) -> pl.DataFrame:
    """Median, min and max of ``metric`` across all methods, one row per protein.

    Output columns: ``PROTEIN_COLUMN_NAME``, ``{metric}_median``,
    ``{metric}_min``, ``{metric}_max``.
    """
    return scores.group_by(PROTEIN_COLUMN_NAME, maintain_order=True).agg(
        pl.col(metric).median().alias(f"{metric}_median"),
        pl.col(metric).min().alias(f"{metric}_min"),
        pl.col(metric).max().alias(f"{metric}_max"),
    )


# Mean reference label per protein. Only this one column is averaged (instead of
# every column of residue_df). For CAID3 the reference is the binary label, so
# this is the fraction of positive residues, assuming 0/1 coding — TODO confirm.
mean_label_per_protein = residue_df.group_by(PROTEIN_COLUMN_NAME, maintain_order=True).agg(
    pl.col(CONTINUOUS_SCORE_NAME).mean().alias(f"{CONTINUOUS_SCORE_NAME}_mean")
)

protein_stats = (
    summarize_metric(spearman_scores, "spearman")
    .join(mean_label_per_protein, on=PROTEIN_COLUMN_NAME)
    .join(summarize_metric(rmse_scores, "rmse"), on=PROTEIN_COLUMN_NAME)
    .join(summarize_metric(aps_scores, "aps"), on=PROTEIN_COLUMN_NAME)
)
protein_stats

protein,spearman_median,spearman_min,spearman_max,CAID3_binary_mean,rmse_median,rmse_min,rmse_max,aps_median,aps_min,aps_max
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""DP04091""",0.41828,0.385178,0.432311,0.06867,0.208937,0.175383,0.269408,0.675382,0.400426,0.879889
"""DP04064""",0.0,0.0,0.0,1.0,0.823573,0.664326,0.904342,1.0,1.0,1.0
"""DP04159""",0.0,0.0,0.0,1.0,0.315075,0.116184,0.694807,1.0,1.0,1.0
"""DP04118""",0.623831,0.498085,0.693141,0.292994,0.384841,0.380714,0.408817,0.799554,0.64972,0.876482
"""DP04273""",0.0,0.0,0.0,1.0,0.408454,0.29084,0.719202,1.0,1.0,1.0
…,…,…,…,…,…,…,…,…,…,…
"""DP03815""",0.406703,0.381153,0.409843,0.059459,0.179259,0.101379,0.260844,0.901868,0.71505,0.994545
"""DP04169""",0.18522,0.159687,0.307074,0.090508,0.419087,0.315358,0.521214,0.13028,0.120146,0.415783
"""DP03938""",0.0,0.0,0.0,1.0,0.628215,0.593296,0.672007,1.0,1.0,1.0
"""DP04371""",0.485631,0.4511,0.547039,0.113636,0.283165,0.211496,0.302886,0.579965,0.453046,0.980909


In [22]:
# Combine the three per-protein metric tables, then reshape to long format:
# one row per (protein, method, metric) with the score in "value".
metric_keys = [PROTEIN_COLUMN_NAME, "method"]
per_protein_wide = (
    spearman_scores
    .join(rmse_scores, on=metric_keys)
    .join(aps_scores, on=metric_keys, how="left")
)
protein_df = per_protein_wide.unpivot(
    index=metric_keys,
    variable_name="metric",
    value_name="value",
)
protein_df

protein,method,metric,value
str,str,str,f64
"""DP04378""","""UdonPred""","""spearman""",0.379525
"""DP04064""","""UdonPred""","""spearman""",0.0
"""DP03804""","""UdonPred""","""spearman""",0.559341
"""DP04232""","""UdonPred""","""spearman""",0.484382
"""DP04122""","""UdonPred""","""spearman""",0.456177
…,…,…,…
"""DP04237""","""AlphaFold3-pLDDT""","""aps""",0.733073
"""DP04203""","""AlphaFold3-pLDDT""","""aps""",0.065223
"""DP03804""","""AlphaFold3-pLDDT""","""aps""",0.997486
"""DP03909""","""AlphaFold3-pLDDT""","""aps""",0.834199


In [23]:
# Attach per-protein context to every (protein, method, metric) row: the mean
# reference label from protein_stats and the residue count ("len").
mean_score_column = f"{CONTINUOUS_SCORE_NAME}_mean"

protein_df = (
    protein_df
    # Start from the long-format columns only, so re-running this cell does not
    # join the same columns a second time (which would add *_right columns).
    .select([PROTEIN_COLUMN_NAME, "method", "metric", "value"])
    .join(
        protein_stats.select([PROTEIN_COLUMN_NAME, mean_score_column]),
        on=PROTEIN_COLUMN_NAME,
    )
    .join(
        residue_df.group_by(PROTEIN_COLUMN_NAME).len(),
        on=PROTEIN_COLUMN_NAME,
    )
)
if TEST_SET == "trizod":
    # The joined column is named after CONTINUOUS_SCORE_NAME ("pscores_mean" for
    # TriZOD); a hard-coded "pscore_mean" does not exist in this frame.
    mean_score = pl.col(mean_score_column)
    protein_df = protein_df.with_columns(
        pl.when(mean_score < .33).then(pl.lit("[0, .33)"))
        .when(mean_score <= .67).then(pl.lit("[.33, .67]"))
        .otherwise(pl.lit("(.67, 1]"))
        .alias("pscore_mean_category_thirds"),
        # QUARTILE_BOUNDARIES presumably holds the midpoint-interpolated 25/50/75%
        # quantiles of the per-protein mean score, i.e.
        # [protein_stats[mean_score_column].quantile(q, interpolation="midpoint")
        #  for q in (.25, .5, .75, 1)] — TODO confirm in udonpred_benchmarking.plots.
        pl.when(mean_score < QUARTILE_BOUNDARIES[0]).then(1)
        .when(mean_score < QUARTILE_BOUNDARIES[1]).then(2)
        .when(mean_score < QUARTILE_BOUNDARIES[2]).then(3)
        .otherwise(4)
        .alias("pscore_mean_quartile"),
    )
protein_df

protein,method,metric,value,CAID3_binary_mean,len
str,str,str,f64,f64,u32
"""DP04378""","""UdonPred""","""spearman""",0.379525,0.287736,848
"""DP04064""","""UdonPred""","""spearman""",0.0,1.0,15
"""DP03804""","""UdonPred""","""spearman""",0.559341,0.792899,338
"""DP04232""","""UdonPred""","""spearman""",0.484382,0.516667,60
"""DP04122""","""UdonPred""","""spearman""",0.456177,0.076923,260
…,…,…,…,…,…
"""DP04237""","""AlphaFold3-pLDDT""","""aps""",0.733073,0.202532,79
"""DP04203""","""AlphaFold3-pLDDT""","""aps""",0.065223,0.105263,95
"""DP03804""","""AlphaFold3-pLDDT""","""aps""",0.997486,0.792899,338
"""DP03909""","""AlphaFold3-pLDDT""","""aps""",0.834199,0.054054,370


In [19]:
# Save both per-protein tables in the same directory the per-residue input was
# read from. Run this only after the cells above, so the files match TEST_SET.
output_dir = DATA_DIR / TEST_SET
for table, file_name in (
    (protein_df, "per_protein_performance.csv"),
    (protein_stats, "per_protein_stats.csv"),
):
    table.write_csv(output_dir / file_name)