## Comparing ESM-based models and RNASamba models for predicting coding and noncoding transcripts
__Keith Cheveralls__<br>
__March 2024__<br>

This notebook documents the visualizations used to compare the performance of ESM-based models and RNASamba models trained to predict whether transcripts are coding or noncoding. The comparison was motivated by the development of an approach that uses ESM embeddings to identify sORFs for the [peptigate pipeline](https://github.com/Arcadia-Science/peptigate).

The predictions from ESM-based models and RNASamba models on which this notebook depends were generated outside of this notebook. Predictions from ESM-based models were generated using the commands namespaced under the `plmutils orf-classification` CLI. Predictions from RNASamba models were generated using the script found in the `/scripts/rnasamba` subdirectory of this repo. The CLI commands that were used are briefly documented in the sections below. 

In [None]:
import io
import pathlib

import arcadia_pycolor as apc
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from plmutils.models import calc_metrics

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

### Dataset metadata
The metadata associated with the 16 species used for these comparisons is included below for completeness. Note that the plots in this notebook label species using the `species_id` defined in this metadata (rather than the full species name).

In [None]:
metadata_csv_content = """
species_id	species_common_name	root_url	genome_name	cdna_endpoint	ncrna_endpoint	genome_abbreviation
hsap	human	https://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/	Homo_sapiens.GRCh38	cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz	ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz	GRCh38
scer	yeast	https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/saccharomyces_cerevisiae/	Saccharomyces_cerevisiae.R64-1-1	cdna/Saccharomyces_cerevisiae.R64-1-1.cdna.all.fa.gz	ncrna/Saccharomyces_cerevisiae.R64-1-1.ncrna.fa.gz	R64-1-1
cele	worm	https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/caenorhabditis_elegans/	Caenorhabditis_elegans.WBcel235	cdna/Caenorhabditis_elegans.WBcel235.cdna.all.fa.gz	ncrna/Caenorhabditis_elegans.WBcel235.ncrna.fa.gz	WBcel235
atha	arabadopsis	https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/arabidopsis_thaliana/	Arabidopsis_thaliana.TAIR10	cdna/Arabidopsis_thaliana.TAIR10.cdna.all.fa.gz	ncrna/Arabidopsis_thaliana.TAIR10.ncrna.fa.gz	TAIR10
dmel	drosophila	https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/drosophila_melanogaster/	Drosophila_melanogaster.BDGP6.46	cdna/Drosophila_melanogaster.BDGP6.46.cdna.all.fa.gz	ncrna/Drosophila_melanogaster.BDGP6.46.ncrna.fa.gz	BDGP6.46
ddis	dictyostelium_discoideum	https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/dictyostelium_discoideum/	Dictyostelium_discoideum.dicty_2.7	cdna/Dictyostelium_discoideum.dicty_2.7.cdna.all.fa.gz	ncrna/Dictyostelium_discoideum.dicty_2.7.ncrna.fa.gz	dicty_2.7
mmus	mouse	https://ftp.ensembl.org/pub/release-111/fasta/mus_musculus/	Mus_musculus.GRCm39	cdna/Mus_musculus.GRCm39.cdna.all.fa.gz	ncrna/Mus_musculus.GRCm39.ncrna.fa.gz	GRCm39
drer	zebrafish	https://ftp.ensembl.org/pub/release-111/fasta/danio_rerio/	Danio_rerio.GRCz11	cdna/Danio_rerio.GRCz11.cdna.all.fa.gz	ncrna/Danio_rerio.GRCz11.ncrna.fa.gz	GRCz11
ggal	chicken	https://ftp.ensembl.org/pub/release-111/fasta/gallus_gallus/	Gallus_gallus.bGalGal1.mat.broiler.GRCg7b	cdna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.cdna.all.fa.gz	ncrna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.ncrna.fa.gz	bGalGal1.mat.broiler.GRCg7b
oind	rice	https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/oryza_indica/	Oryza_indica.ASM465v1	cdna/Oryza_indica.ASM465v1.cdna.all.fa.gz	ncrna/Oryza_indica.ASM465v1.ncrna.fa.gz	ASM465v1
zmay	maize	https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/zea_mays/	Zea_mays.Zm-B73-REFERENCE-NAM-5.0	cdna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.cdna.all.fa.gz	ncrna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.ncrna.fa.gz	Zm-B73-REFERENCE-NAM-5.0
xtro	frog	https://ftp.ensembl.org/pub/release-111/fasta/xenopus_tropicalis/	Xenopus_tropicalis.UCB_Xtro_10.0	cdna/Xenopus_tropicalis.UCB_Xtro_10.0.cdna.all.fa.gz	ncrna/Xenopus_tropicalis.UCB_Xtro_10.0.ncrna.fa.gz	UCB_Xtro_10.0
rnor	rat	https://ftp.ensembl.org/pub/release-111/fasta/rattus_norvegicus/	Rattus_norvegicus.mRatBN7.2	cdna/Rattus_norvegicus.mRatBN7.2.cdna.all.fa.gz	ncrna/Rattus_norvegicus.mRatBN7.2.ncrna.fa.gz	mRatBN7
amel	honeybee	https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/apis_mellifera/	Apis_mellifera.Amel_HAv3.1	cdna/Apis_mellifera.Amel_HAv3.1.cdna.all.fa.gz	ncrna/Apis_mellifera.Amel_HAv3.1.ncrna.fa.gz	Amel_HAv3.1
spom	fission_yeast	https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/schizosaccharomyces_pombe/	Schizosaccharomyces_pombe.ASM294v2	cdna/Schizosaccharomyces_pombe.ASM294v2.cdna.all.fa.gz	ncrna/Schizosaccharomyces_pombe.ASM294v2.ncrna.fa.gz	ASM294v2
tthe	tetrahymena	https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/tetrahymena_thermophila/	Tetrahymena_thermophila.JCVI-TTA1-2.2	cdna/Tetrahymena_thermophila.JCVI-TTA1-2.2.cdna.all.fa.gz	ncrna/Tetrahymena_thermophila.JCVI-TTA1-2.2.ncrna.fa.gz	JCVI-TTA1-2.2
"""  # noqa: E501

metadata = pd.read_csv(io.StringIO(metadata_csv_content), sep="\t")
metadata.head()

### Heatmap plotting functions
These are functions used later in the notebook to generate heatmap visualizations of the matrices of model performance metrics for all pairs of training and test species. 

In [None]:
apc.mpl._load_fonts("../Fonts")
apc.mpl._load_styles()

In [None]:
def plot_heatmap(
    df,
    column="accuracy",
    model_name="unknown",
    ax=None,
    colormap_type="sequential",
    **heatmap_kwargs,
):
    """
    Plot the values in the given column as a square heatmap of training vs test species
    (with training species on the x-axis and test species on the y-axis).

    Note: "training species" is the species used to train the model and "test species"
    is the species used to test each trained model.
    """
    if colormap_type == "sequential":
        gradient = apc.gradients.teals
    elif colormap_type == "diverging":
        gradient = apc.gradients.aster_canary
    else:
        raise ValueError(f"Unknown colormap type: {colormap_type}")

    # reverse the gradient so that the highest values are darkest.
    colors = [
        (1 - value, color.hex_code)
        for value, color in zip(gradient.values[::-1], gradient.colors[::-1], strict=True)
    ]
    colormap = matplotlib.colors.LinearSegmentedColormap.from_list("", colors=colors)

    df = df.pivot(index="test_species_id", columns="training_species_id", values=column)

    if ax is None:
        plt.figure(figsize=(8, 6))
        ax = plt.gca()

    sns.heatmap(
        df,
        cmap=colormap,
        annot=False,
        annot_kws={"size": 6},
        fmt=".1f",
        square=True,
        ax=ax,
        cbar_kws={"shrink": 0.7},
        **heatmap_kwargs,
    )

    name = column.replace("_", " ")
    if name.lower() == "mcc":
        name = name.upper()
    else:
        name = name[0].upper() + name[1:]

    ax.set_xlabel("Training species")
    ax.set_ylabel("Test species")
    ax.set_title(f"{name} | {model_name}", fontdict={"family": "Suisse Int'l"})
    ax.tick_params(axis="both", which="both", pad=5, size=0)

    apc.mpl.monospace_ticklabels(font="Suisse Int'l", axis=ax)
    apc.mpl.autostyle(axis=ax, cbar=True, cat=None)

In [None]:
def plot_heatmaps(df_left, df_right, column, model_names):
    """
    Plot a row of three heatmaps: one for the left dataframe, one for the right dataframe,
    and the third (the rightmost) for the difference between the two (right minus left).
    """
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(42, 14))

    df_merged = pd.merge(df_left, df_right, on=("training_species_id", "test_species_id"))
    df_merged[column] = df_merged[f"{column}_y"] - df_merged[f"{column}_x"]

    plot_heatmap(df_left, column=column, model_name=model_names[0], ax=axs[0], vmin=0, vmax=1)
    plot_heatmap(df_right, column=column, model_name=model_names[1], ax=axs[1], vmin=0, vmax=1)
    plot_heatmap(
        df_merged,
        column=column,
        model_name="difference",
        ax=axs[2],
        vmin=-1,
        vmax=1,
        colormap_type="diverging",
    )
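
The sign convention of the third heatmap (right minus left, restricted to species pairs present in both dataframes) can be sketched without pandas. This is a minimal illustration that assumes metrics are stored as lists of dicts. The helper name `difference_by_pair` and the toy values are hypothetical and not part of the notebook.

```python
def difference_by_pair(left_records, right_records, column):
    """
    Inner-join two lists of metric records on (training species, test species)
    and return the difference of `column` for each shared pair (right minus left).
    """

    def pair(record):
        return (record["training_species_id"], record["test_species_id"])

    left_by_pair = {pair(record): record[column] for record in left_records}
    return {
        pair(record): record[column] - left_by_pair[pair(record)]
        for record in right_records
        if pair(record) in left_by_pair
    }


# Toy values for illustration only.
left = [{"training_species_id": "hsap", "test_species_id": "scer", "mcc": 0.5}]
right = [
    {"training_species_id": "hsap", "test_species_id": "scer", "mcc": 0.75},
    {"training_species_id": "hsap", "test_species_id": "cele", "mcc": 0.25},
]
differences = difference_by_pair(left, right, column="mcc")
```

Positive values mean the right-hand model scored higher for that pair. Pairs missing from either input are dropped, which mirrors the default inner join of `pd.merge`.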

### ESM-based model predictions

These predictions were generated using the `plmutils orf-prediction` CLI. 

First, download the Ensembl datasets listed in the user-provided metadata CSV file (see above for the file used with this notebook):
```
plmutils orf-prediction download-data \
    output/data/ensembl-dataset-metadata.tsv \
    output/data/
```
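
As a rough illustration of how the metadata columns fit together, the sketch below builds download URLs by joining `root_url` with `cdna_endpoint` and `ncrna_endpoint`. That joining rule is an assumption based on the column names, not a description of what `download-data` does internally. The helper `build_download_urls` is hypothetical.

```python
import csv
import io

# One row copied from the metadata table above, trimmed to the columns needed here.
METADATA_TSV = (
    "species_id\troot_url\tcdna_endpoint\tncrna_endpoint\n"
    "hsap\thttps://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/\t"
    "cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz\t"
    "ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz\n"
)


def build_download_urls(metadata_tsv):
    """Hypothetical helper: map each species_id to its cDNA and ncRNA URLs."""
    urls = {}
    for row in csv.DictReader(io.StringIO(metadata_tsv), delimiter="\t"):
        # Assumption: each endpoint is a path relative to the root URL.
        root = row["root_url"].rstrip("/")
        urls[row["species_id"]] = {
            "cdna": f"{root}/{row['cdna_endpoint']}",
            "ncrna": f"{root}/{row['ncrna_endpoint']}",
        }
    return urls


urls = build_download_urls(METADATA_TSV)
print(urls["hsap"]["cdna"])
```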

Next, construct deduplicated sets of coding and noncoding transcripts. Deduplication is achieved by clustering transcripts by sequence identity and retaining only one representative sequence from each cluster.
```
plmutils orf-prediction construct-data \
    output/data/ensembl-dataset-metadata.tsv \
    output/data/ \
    --subsample-factor 1
```
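
The idea of deduplicating by clustering can be illustrated with a toy greedy procedure. This is only a sketch of the general technique: the notebook does not say which clustering tool or identity threshold `construct-data` uses. The function `greedy_dedup`, the `min_identity` value, and the use of `difflib` as a similarity measure are all assumptions made for illustration.

```python
from difflib import SequenceMatcher


def greedy_dedup(sequences, min_identity=0.8):
    """
    Toy greedy clustering: visit sequences from longest to shortest and keep a
    sequence as a new cluster representative only if its similarity to every
    representative kept so far is below `min_identity`.
    """
    representatives = []
    for seq_id, seq in sorted(sequences.items(), key=lambda item: -len(item[1])):
        is_redundant = any(
            SequenceMatcher(None, seq, rep_seq, autojunk=False).ratio() >= min_identity
            for _, rep_seq in representatives
        )
        if not is_redundant:
            representatives.append((seq_id, seq))
    return [seq_id for seq_id, _ in representatives]


# Toy transcripts: tx2 is a truncated copy of tx1, tx3 is unrelated.
transcripts = {
    "tx1": "ATGGCCATTGTAATGGGCCGCTGAAAGGGTGCCCGATAG",
    "tx2": "ATGGCCATTGTAATGGGCCGCTGAAAGGGTGCCCGA",
    "tx3": "TTTTTTTTTTCCCCCCCCCCGGGGGGGGGG",
}
kept = greedy_dedup(transcripts)
```

Here `tx2` is absorbed into the cluster represented by `tx1`, so only `tx1` and `tx3` are retained. A real pipeline would use a dedicated sequence-clustering tool, since this pairwise comparison scales quadratically.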

Next, find putative ORFs from coding and noncoding transcripts, retain only the longest putative ORF from each transcript, and generate the embedding of the protein sequence for which it codes:
```
plmutils orf-prediction translate-and-embed \
    output/data/processed/final/coding-dedup-ssx1/transcripts

plmutils orf-prediction translate-and-embed \
    output/data/processed/final/noncoding-dedup-ssx1/transcripts   
```
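
The "longest putative ORF" step can be sketched as follows. This minimal version assumes that an ORF starts at an ATG, ends at the first in-frame stop codon, and is searched for only in the three forward reading frames. The notebook does not state the exact rules used by `translate-and-embed`, so these are assumptions, and `longest_orf` is a hypothetical helper.

```python
from itertools import product

# Standard genetic code, with codons enumerated in TCAG order.
BASES = "TCAG"
AMINO_ACIDS = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
CODON_TABLE = {
    "".join(codon): aa for codon, aa in zip(product(BASES, repeat=3), AMINO_ACIDS)
}


def longest_orf(transcript):
    """
    Return the longest peptide (start codon through the codon before the stop)
    found in any of the three forward reading frames, or "" if there is none.
    """
    transcript = transcript.upper().replace("U", "T")
    best = ""
    for frame in range(3):
        peptide = None  # None means we are not currently inside an ORF.
        for pos in range(frame, len(transcript) - 2, 3):
            aa = CODON_TABLE.get(transcript[pos : pos + 3], "X")
            if peptide is None:
                if aa == "M":
                    peptide = "M"
            elif aa == "*":
                if len(peptide) > len(best):
                    best = peptide
                peptide = None
            else:
                peptide += aa
    return best


peptide = longest_orf("GGATGAAATTTGGGTAACCATGTGACC")
```

In this toy transcript the third reading frame contains `ATG AAA TTT GGG TAA`, which yields the peptide `MKFG`. ORFs that lack an in-frame stop codon are ignored by this sketch.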

Finally, train models using these embeddings to predict whether a given ORF originated from a coding or noncoding transcript. A separate model is trained on each species and then used to make predictions for every species. This results in a matrix of model performance metrics for all pairs of species (one used to train the model, the other to evaluate it). The `--output-dirpath` in the command below corresponds to the directories passed to the `calc_metrics_from_smallesm_results` function defined below. (This command was run manually twice, once without and once with `--max-length 100`, to train models on all ORFs and on only sORFs, respectively.)
```
plmutils orf-prediction train-and-evaluate \
    --coding-dirpath output/data/processed/final/coding-dedup-ssx1/embeddings/esm2_t6_8M_UR50D \
    --noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/embeddings/esm2_t6_8M_UR50D \
    --output-dirpath output/data/esm-model-results-ssx1-all
```
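
The all-pairs design (one model per training species, each evaluated on every species) can be sketched with a toy nearest-centroid classifier. The classifier, the function names, and the two-dimensional "embeddings" are stand-ins chosen for illustration. They are not the model or data used by `train-and-evaluate`.

```python
from itertools import product


def fit_centroids(embeddings, labels):
    """Toy classifier: the mean embedding of each class."""
    centroids = {}
    for label in set(labels):
        rows = [emb for emb, lab in zip(embeddings, labels) if lab == label]
        centroids[label] = [sum(column) / len(rows) for column in zip(*rows)]
    return centroids


def predict(centroids, embedding):
    """Return the label of the closest centroid (squared Euclidean distance)."""

    def squared_distance(centroid):
        return sum((a - b) ** 2 for a, b in zip(embedding, centroid))

    return min(centroids, key=lambda label: squared_distance(centroids[label]))


def evaluate_all_pairs(datasets):
    """Return one record per (training species, test species) pair."""
    models = {species: fit_centroids(*data) for species, data in datasets.items()}
    records = []
    for train_species, test_species in product(datasets, repeat=2):
        embeddings, labels = datasets[test_species]
        num_correct = sum(
            predict(models[train_species], emb) == lab
            for emb, lab in zip(embeddings, labels)
        )
        records.append(
            {
                "training_species_id": train_species,
                "test_species_id": test_species,
                "accuracy": num_correct / len(labels),
            }
        )
    return records


# Toy data for illustration only.
datasets = {
    "hsap": (
        [[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.8, 1.0]],
        ["noncoding", "noncoding", "coding", "coding"],
    ),
    "scer": ([[0.1, 0.2], [0.9, 1.1]], ["noncoding", "coding"]),
}
records = evaluate_all_pairs(datasets)
```

The resulting long-form records (one row per species pair) have the same shape as the metrics dataframes that are pivoted into heatmaps in this notebook. Note that this toy evaluates the diagonal entries on the training data itself, which a real pipeline would avoid by holding out test data.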



In [None]:
def calc_metrics_from_smallesm_results(results_dirpath, max_length=None):
    """
    Calculate classification metrics from ESM-based model results.
    """
    all_metrics = []
    prediction_filepaths = pathlib.Path(results_dirpath).glob("*.csv")
    for prediction_filepath in prediction_filepaths:
        df = pd.read_csv(prediction_filepath)

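        # optionally restrict the evaluation to short peptides
        # (those strictly shorter than `max_length` amino acids).
        if max_length is not None:
            df = df.loc[df.sequence_length < max_length]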

        metrics = calc_metrics(
            y_true=(df.true_label == "coding"),
            y_pred_proba=df.predicted_probability.values,
        )
        metrics["training_species_id"] = df.iloc[0].training_species_id
        metrics["test_species_id"] = df.iloc[0].testing_species_id
        metrics["num_coding"] = (df.true_label == "coding").sum()
        metrics["num_noncoding"] = (df.true_label != "coding").sum()

        all_metrics.append(metrics)
    df = pd.DataFrame(all_metrics)
    df["true_negative_rate"] = df.num_true_negative / df.num_noncoding
    return df
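
For reference, the three metrics plotted in this notebook (MCC, recall, and the true negative rate) can be computed from confusion-matrix counts as sketched below. The implementation of `calc_metrics` is not shown here, so this is an independent sketch of the standard definitions rather than a copy of that function. The name `confusion_metrics` is hypothetical.

```python
import math


def confusion_metrics(num_tp, num_fp, num_tn, num_fn):
    """Compute recall, true negative rate, and MCC from confusion-matrix counts."""
    # recall (true positive rate): fraction of coding sequences predicted as coding.
    recall = num_tp / (num_tp + num_fn)
    # true negative rate: fraction of noncoding sequences predicted as noncoding.
    true_negative_rate = num_tn / (num_tn + num_fp)
    # Matthews correlation coefficient; defined as 0 when the denominator is 0.
    denominator = math.sqrt(
        (num_tp + num_fp) * (num_tp + num_fn) * (num_tn + num_fp) * (num_tn + num_fn)
    )
    mcc = (num_tp * num_tn - num_fp * num_fn) / denominator if denominator > 0 else 0.0
    return {"recall": recall, "true_negative_rate": true_negative_rate, "mcc": mcc}


# Toy counts for an imbalanced dataset: 100 coding and 1000 noncoding sequences.
imbalanced = confusion_metrics(num_tp=90, num_fp=100, num_tn=900, num_fn=10)
perfect = confusion_metrics(num_tp=5, num_fp=0, num_tn=5, num_fn=0)
```

In the imbalanced toy example, recall and the true negative rate are both 0.9, but the MCC is noticeably lower because the false positives outnumber the true positives. This sensitivity to class imbalance is one reason MCC is a useful summary metric here.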

In [None]:
metrics_esm_trained_all_eval_all = calc_metrics_from_smallesm_results(
    "../output/results/2024-03-01-esm-model-results-ssx1-all/",
    max_length=None,
)
metrics_esm_trained_all_eval_short = calc_metrics_from_smallesm_results(
    "../output/results/2024-03-01-esm-model-results-ssx1-all/",
    max_length=100,
)
metrics_esm_trained_short_eval_all = calc_metrics_from_smallesm_results(
    "../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/",
    max_length=None,
)
metrics_esm_trained_short_eval_short = calc_metrics_from_smallesm_results(
    "../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/",
    max_length=100,
)

In [None]:
metrics_esm_trained_all_eval_all.head()

#### Compare ESM-based models trained on all ORFs and only sORFs

In [None]:
# models trained on either all ORFs or only sORFs and evaluated on only sORFs.
plot_heatmaps(
    metrics_esm_trained_all_eval_short,
    metrics_esm_trained_short_eval_short,
    column="mcc",
    model_names=("ESM-based (trained all, eval short)", "ESM-based (trained short, eval short)"),
)

In [None]:
# models trained only on sORFs and evaluated on all or only sORFs.
plot_heatmaps(
    metrics_esm_trained_short_eval_all,
    metrics_esm_trained_short_eval_short,
    column="mcc",
    model_names=("ESM-based (trained short, eval all)", "ESM-based (trained short, eval short)"),
)

In [None]:
# models trained on all ORFs or only sORFs, but evaluated on all sequences.
plot_heatmaps(
    metrics_esm_trained_all_eval_all,
    metrics_esm_trained_short_eval_all,
    column="mcc",
    model_names=("ESM-based (trained all, eval all)", "ESM-based (trained short, eval all)"),
)

### RNASamba predictions

These predictions were generated by the script `plm-utils/scripts/rnasamba/train_and_evaluate.py` using the same datasets of deduplicated coding and noncoding transcripts generated by the `plmutils orf-prediction construct-data` command described above.

To train RNASamba models on all sequences:
```
python scripts/rnasamba-comparison/train_and_evaluate.py \
--coding-dirpath output/data/processed/final/coding-dedup-ssx1/transcripts \
--noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/transcripts \
--output-dirpath 2024-02-28-rnasamba-results-ssx1-all
```

To train RNASamba models on transcripts corresponding to sORFs:
```
python scripts/rnasamba-comparison/train_and_evaluate.py \
--coding-dirpath output/data/processed/final/coding-dedup-ssx1/transcripts \
--noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/transcripts \
--output-dirpath output/data/2024-02-28-rnasamba-results-ssx1-max-peptide-length-100 \
--max-length 100
```
The `--output-dirpath` above corresponds to the directory passed to the `calc_metrics_from_rnasamba_results` function below.

In [None]:
def calc_metrics_from_rnasamba_results(rnasamba_results_dirpath):
    """
    Aggregate the results from RNASamba models trained in the script
    `scripts/rnasamba-comparison/train_and_evaluate.py`.
    """
    all_metrics = []
    dirpaths = [p for p in rnasamba_results_dirpath.glob("trained-on*") if p.is_dir()]
    for dirpath in dirpaths:
        # dirnames are of the form 'trained-on-{species_id}-filtered'.
        training_species_id = dirpath.stem.split("-")[2]

        prediction_filepaths = dirpath.glob("*.tsv")
        for prediction_filepath in prediction_filepaths:
            # filenames are of the form '{species_id}-preds.tsv'
            # (the files are comma-separated despite the extension).
            test_species_id = prediction_filepath.stem.split("-")[0]

            df = pd.read_csv(prediction_filepath, sep=",")
            metrics = calc_metrics(
                y_true=(df.true_label == "coding"), y_pred_proba=df.coding_score.values
            )
            metrics["training_species_id"] = training_species_id
            metrics["test_species_id"] = test_species_id
            metrics["num_coding"] = (df.true_label == "coding").sum()
            metrics["num_noncoding"] = (df.true_label != "coding").sum()

            all_metrics.append(metrics)

    df = pd.DataFrame(all_metrics)
    df["true_negative_rate"] = df.num_true_negative / df.num_noncoding
    return df

In [None]:
# models trained and tested on all transcripts.
rnasamba_results_dirpath_all = pathlib.Path(
    "../output/results/2024-02-23-rnasamba-models-clustered-ssx3/"
)

# models trained and tested only on transcripts whose longest ORFs are sORFs.
rnasamba_results_dirpath_short = pathlib.Path(
    "../output/results/2024-02-28-rnasamba-results-ssx1-max-peptide-length-100/"
)

metrics_rs_trained_all_eval_all = calc_metrics_from_rnasamba_results(rnasamba_results_dirpath_all)
metrics_rs_trained_short_eval_short = calc_metrics_from_rnasamba_results(
    rnasamba_results_dirpath_short
)

#### Compare RNASamba models trained on all transcripts or only on transcripts with sORFs

In [None]:
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_rs_trained_short_eval_short,
    column="mcc",
    model_names=("RNASamba (all)", "RNASamba (short)"),
)

### Compare RNASamba and ESM-based models

These are the most important plots in this notebook. They compare the performance of ESM-based models to RNASamba models by plotting heatmaps of the performance metrics side by side.

#### Models trained and evaluated on all transcripts (for RNASamba) or ORFs (for ESM-based)

In [None]:
# overall performance (MCC metric)
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column="mcc",
    model_names=("RNASamba (all)", "plm-utils (all)"),
)
plt.savefig(
    "figures/2024-05-31-mcc-rnasamba-vs-plmutils-all-transcripts.pdf", dpi=72, bbox_inches="tight"
)

In [None]:
# recall (also the true positive rate, or num_true_positive / num_coding)
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column="recall",
    model_names=("RNASamba (all)", "ESM-based (all)"),
)

In [None]:
# the true negative rate.
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column="true_negative_rate",
    model_names=("RNASamba (all)", "ESM-based (all)"),
)

#### Models trained only on short sequences (< 100aa)

For RNASamba, this means the models were trained only on transcripts whose longest ORF was an sORF (less than 100aa long). 

Note that the class imbalance in this case is severe: most species do not have many coding transcripts whose longest ORF is an sORF. This likely explains, at least in part, why the RNASamba models perform so poorly, because we do not compensate for the class imbalance when training them (whereas we do compensate for it when training the ESM-based models).
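
One common way to compensate for class imbalance is to weight each class inversely to its frequency. The sketch below illustrates that heuristic, `n_samples / (n_classes * class_count)`, which is the same formula scikit-learn uses for `class_weight="balanced"`. Whether the ESM-based training uses exactly this scheme is not stated in this notebook, so treat it as an assumption. The label counts are toy values.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    return {
        label: len(labels) / (len(counts) * count) for label, count in counts.items()
    }


# Toy example: 10 coding and 990 noncoding sequences.
labels = ["coding"] * 10 + ["noncoding"] * 990
weights = balanced_class_weights(labels)
```

With these weights, each class contributes the same total weight to the training loss (here, 500 per class), so the rare coding class is not overwhelmed by the noncoding class.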

In [None]:
plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column="mcc",
    model_names=("RNASamba (short)", "plm-utils (short)"),
)
plt.savefig("figures/2024-05-31-mcc-rnasamba-vs-plmutils-short.pdf", dpi=72, bbox_inches="tight")

In [None]:
plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column="recall",
    model_names=("RNASamba (short)", "ESM-based (short)"),
)

In [None]:
plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column="true_negative_rate",
    model_names=("RNASamba (short)", "ESM-based (short)"),
)

### Aside: blasting against peptipedia

We were curious whether some of the false positives from the ESM-based models represented genuine sORFs from lncRNAs (which are annotated as noncoding). To examine this, we blasted all of the putative ORFs against peptipedia and plotted the distribution of the minimum evalues of putative sORFs for which the ESM-based model made either true positive or false positive predictions. If the model correctly identifies genuine sORFs from lncRNAs, we'd expect to see an enrichment of low evalues among the false positives.

The command `plmutils orf-classification blast-peptipedia` was used to generate the directories of blast results that are loaded and concatenated by the `concat_blast_results` function below.

In [None]:
def concat_smallesm_results(results_dirpath):
    """
    Load and concatenate the predictions from ESM-based models.
    """
    dfs = []
    prediction_filepaths = pathlib.Path(results_dirpath).glob("*.csv")
    for prediction_filepath in prediction_filepaths:
        dfs.append(pd.read_csv(prediction_filepath))

    return pd.concat(dfs)

In [None]:
# predictions from models trained on all putative ORFs.
esm_trained_all_preds = concat_smallesm_results(
    "../output/results/2024-03-01-esm-model-results-ssx1-all/"
)

In [None]:
# predictions from models trained on short peptides (< 100aa).
esm_trained_short_preds = concat_smallesm_results(
    "../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/"
)

In [None]:
esm_trained_all_preds.shape, esm_trained_short_preds.shape

In [None]:
esm_trained_short_preds.head()

In [None]:
# count the number of peptides from coding and noncoding transcripts to make sure
# that the class imbalance between coding and noncoding is not too severe.
# (we only need to look at preds from one model, since each model is tested with all species).
hsap_preds = esm_trained_all_preds.loc[esm_trained_all_preds.training_species_id == "hsap"].copy()
pd.merge(
    hsap_preds.groupby(["testing_species_id", "true_label"]).count().sequence_id,
    (
        hsap_preds.loc[hsap_preds.sequence_length < 100]
        .groupby(["testing_species_id", "true_label"])
        .count()
        .sequence_id
    ),
    left_index=True,
    right_index=True,
    suffixes=("_all", "_short"),
)

In [None]:
def concat_blast_results(dirpaths):
    """
    Aggregate the blast results generated by `plmutils orf-classification blast-peptipedia`.
    """
    blast_results_columns = (
        "qseqid sseqid full_sseq pident length qlen slen mismatch gapopen qstart qend sstart send evalue bitscore"  # noqa: E501
    ).split(" ")

    dfs = []
    for dirpath in dirpaths:
        filepaths = pathlib.Path(dirpath).glob("*.tsv")
        for filepath in filepaths:
            try:
                df = pd.read_csv(filepath, sep="\t")
            except Exception:
                # skip files that cannot be parsed (e.g., empty files,
                # for which pandas raises an error).
                continue
            df.columns = blast_results_columns
            dfs.append(df)
    return pd.concat(dfs)

In [None]:
blast_results = concat_blast_results(
    [
        "../output/data/processed/final/coding-dedup-ssx1/blast-peptipedia-results/",
        "../output/data/processed/final/noncoding-dedup-ssx1/blast-peptipedia-results/",
    ]
)

In [None]:
# use the log of the evalue for readability.
blast_results["evalue"] = np.log(blast_results.evalue)

# for each peptide, we only need the minimum evalue across all of its hits.
min_evalues = blast_results.groupby("qseqid").evalue.min().reset_index()
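
The reduction above (log-transform, then take the minimum per query) can be sketched in plain Python. One detail to keep in mind is that an evalue of exactly 0 has no finite logarithm. `np.log` returns `-inf` for it (with a warning), and such values fall outside any finite histogram bins. The helper `min_log_evalue_per_query` and the toy hits are hypothetical.

```python
import math


def min_log_evalue_per_query(hits):
    """
    Given (qseqid, evalue) pairs, return the minimum natural-log evalue per query.
    """
    best = {}
    for qseqid, evalue in hits:
        # An evalue of exactly 0 has no finite log; map it to -inf explicitly.
        log_evalue = math.log(evalue) if evalue > 0 else -math.inf
        if qseqid not in best or log_evalue < best[qseqid]:
            best[qseqid] = log_evalue
    return best


# Toy hits for illustration only.
hits = [("orf1", 1e-30), ("orf1", 1e-5), ("orf2", 0.5), ("orf3", 0.0)]
best = min_log_evalue_per_query(hits)
```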

In [None]:
# merge the minimum evalues with the model predictions.
esm_trained_short_preds_w_evalues = pd.merge(
    esm_trained_short_preds, min_evalues, left_on="sequence_id", right_on="qseqid", how="inner"
)

esm_trained_all_preds_w_evalues = pd.merge(
    esm_trained_all_preds, min_evalues, left_on="sequence_id", right_on="qseqid", how="inner"
)

In [None]:
esm_trained_short_preds_w_evalues_short_only = esm_trained_short_preds_w_evalues.loc[
    esm_trained_short_preds_w_evalues.sequence_length < 100
].copy()

In [None]:
# sanity-check: count the number of peptides that had hits in peptipedia.
(
    esm_trained_short_preds_w_evalues_short_only
    # we only need to look at one model
    .loc[esm_trained_short_preds_w_evalues_short_only.training_species_id == "hsap"]
    .groupby(["testing_species_id", "true_label"])
    .count()[["sequence_id"]]
)

#### Histograms of evalues for coding and noncoding transcripts

These histograms were used to determine whether the false positives were enriched for peptides that had hits in peptipedia, which would suggest that they correspond to genuine sORFs from lncRNAs (and are therefore not actually false positives).
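
As a complement to visual inspection, one simple way to quantify such an enrichment is to compare the fraction of strong hits (log evalues below some cutoff) between false positives and true negatives. The sketch below uses made-up values and a hypothetical cutoff. It is not a result from this notebook.

```python
def fraction_below(log_evalues, threshold):
    """Fraction of values strictly below `threshold` (NaN if the list is empty)."""
    if not log_evalues:
        return float("nan")
    return sum(value < threshold for value in log_evalues) / len(log_evalues)


# Made-up minimum log evalues, for illustration only.
false_positives = [-80.0, -45.0, -12.0, -3.0]
true_negatives = [-20.0, -6.0, -4.0, -2.0]

# Hypothetical cutoff, roughly the natural log of an evalue of 1e-10.
threshold = -23.0

fraction_fp = fraction_below(false_positives, threshold)
fraction_tn = fraction_below(true_negatives, threshold)
```

If the false positives really include genuine sORFs, `fraction_fp` should be clearly larger than `fraction_tn`. With small sample sizes, the difference should be interpreted cautiously.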

In [None]:
# we only look at preds for short peptides from the human dataset
# because it is one of the few datasets with a decent number of short peptides
# that are from noncoding transcripts and have peptipedia hits.
preds = esm_trained_all_preds_w_evalues.loc[
    (esm_trained_all_preds_w_evalues.training_species_id == "hsap")
    & (esm_trained_all_preds_w_evalues.testing_species_id == "hsap")
    & (esm_trained_all_preds_w_evalues.sequence_length < 100)
]

fig, axs = plt.subplots(1, 2, figsize=(16, 6))

min_min_evalue = -150
bins = np.arange(min_min_evalue, 0, -min_min_evalue / 30)
kwargs = dict(bins=bins, density=False, alpha=0.5)

# left axis: coding transcripts
ax = axs[0]
ax.hist(
    preds[(preds.true_label == "coding") & (preds.predicted_probability > 0.5)].evalue,
    label="True positives",
    color="blue",
    **kwargs,
)
ax.hist(
    preds[(preds.true_label == "coding") & (preds.predicted_probability < 0.5)].evalue,
    label="False negatives",
    color="red",
    **kwargs,
)
ax.legend()
ax.set_xlabel("Minimum log evalue")
ax.set_ylabel("Count")
ax.set_title("Coding transcripts")

# right axis: noncoding transcripts
ax = axs[1]
ax.hist(
    preds[(preds.true_label == "noncoding") & (preds.predicted_probability < 0.5)].evalue,
    label="True negatives",
    color="blue",
    **kwargs,
)
_ = ax.hist(
    preds[(preds.true_label == "noncoding") & (preds.predicted_probability > 0.5)].evalue,
    label="False positives",
    color="red",
    **kwargs,
)
ax.legend()
ax.set_title("Noncoding transcripts")