# Compare predictions

Here, we compare our predictions with those of other authors.
We treat the combination of the validation and test datasets as ground truth and measure how well the translation initiation sites (TISs) reported by other authors agree with it.

## Prerequisites

This notebook requires:
- [hg38.fa]()
- [Our predictions](); either download them or generate them with `predict_5UTR.ipynb`
- "uORF_annotation_hg38.csv" from [Scholtz et al.](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0222459)
- "Supplemental_Data_Tables_.xlsx" from [McGillivray et al.](https://academic.oup.com/nar/article/46/7/3326/4942470)
- "elife-08890-supp1-v2.xlsx" from [Ji et al.](https://elifesciences.org/articles/08890)

Tables from the papers above can be obtained via [this link](https://drive.google.com/file/d/1o1YhRuF4Dp122NWSmcehiLTwf68jwLnK/view?usp=sharing). Unpack the files and place them into the `data` directory (or provide paths manually below).
```
data
|____hg38.fa
|____predictions_5UTR_v4.7.csv
|____others
| |____elife-08890-supp1-v2.xlsx
| |____uORF_annotation_hg38.csv
| |____Supplemental_Data_Tables_.xlsx
```
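Before running the notebook, it can help to confirm that the layout above is in place. The sketch below assumes the default `../data` location and the file names listed in the tree; `missing_inputs` and `EXPECTED` are hypothetical names, not part of uBERTa. Note that the paths cell further down reads the predictions from `XGB/predictions_5UTR_v4.7.csv`, so adjust `EXPECTED` to wherever you keep that file.

```python
from pathlib import Path

# Expected files relative to the data directory (taken from the tree above)
EXPECTED = [
    Path('hg38.fa'),
    Path('predictions_5UTR_v4.7.csv'),
    Path('others') / 'elife-08890-supp1-v2.xlsx',
    Path('others') / 'uORF_annotation_hg38.csv',
    Path('others') / 'Supplemental_Data_Tables_.xlsx',
]


def missing_inputs(base, expected=EXPECTED):
    """Return the expected paths that are not regular files under `base`."""
    base = Path(base)
    return [base / rel for rel in expected if not (base / rel).is_file()]


if __name__ == '__main__':
    # Report anything that still needs to be downloaded or unpacked
    for path in missing_inputs('../data'):
        print(f'Missing: {path}')
```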

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from more_itertools import unzip
from pyliftover import LiftOver
from sklearn.metrics import f1_score, recall_score, precision_score, balanced_accuracy_score
from tqdm.auto import tqdm
from uBERTa.base import VALID_START
from uBERTa.utils import Ref, reverse_complement

In [2]:
BASE = Path('../data')
BASE.mkdir(exist_ok=True)

REF = BASE / 'hg38.fa'
SCHOLTZ = BASE / 'others' / 'uORF_annotation_hg38.csv'
JI = BASE / 'others' / 'elife-08890-supp1-v2.xlsx'
GIL = BASE / 'others' / 'Supplemental_Data_Tables_.xlsx'
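# The layout above places the predictions directly in `data`; adjust if needed
PRED = BASE / 'XGB' / 'predictions_5UTR_v4.7.csv'

# Valid start codons from uBERTa.base; reassign here to use a custom set
VALID_START = VALID_START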

In [3]:
ref = Ref(REF)

In [29]:
def fetch_codons(df, ref):
    """
    Given a table with columns "Chrom", "Start", "Strand", fetch the sequence
        of the region [Start, Start + 3) for each row.
    For "-" strand sequences, take reverse complements.
    """
    handle_seq = lambda seq, strand: (
        seq.upper() if strand == '+' else reverse_complement(seq).upper())
    return [
        handle_seq(
            ref.fetch(chrom, start, start + 3), strand) 
        for chrom, start, strand in 
        tqdm(
            df[['Chrom', 'Start', 'Strand']].itertuples(index=False), 
            total=len(df), desc='Fetching')
    ]

def filter_codons(df, valid=VALID_START):
    """
    When both "Codon" and "CodonFetched" columns are present, keep only rows
        where the codon fetched from the reference matches the reported codon.
    Then keep only rows whose codon is in `valid`.
    """
    df = df.copy()
    if 'Codon' in df.columns and 'CodonFetched' in df.columns:
        idx = df.Codon != df.CodonFetched
        print(f'Non-matching codons: {idx.sum()}')
        df = df[~idx]
        df = df.drop(columns='CodonFetched')
    idx = ~df.Codon.isin(valid)
    print(f'Invalid start codons: {idx.sum()}')
    df = df[~idx]
    return df

def lift_starts(df):
    """
    Lift "Start" coordinates from hg19 to hg38.
    """
    def convert(*args):
        conv_list = lo.convert_coordinate(*args)
        if conv_list:
            return conv_list[0][1]
        return np.nan
    
    lo = LiftOver('hg19', 'hg38')
    return [
        convert(*x[1:]) for x in 
        df[['Chrom', 'Start', 'Strand']].itertuples()]

def offset_starts(df, tot_offset=-1, neg_offset=-2):
    """
    Offset start sites depending on a strand so they match our convention, 
    i.e., 0-based coordinates with "Start" pointing at the first nucleotide of a start codon.
    The latter is inverted along with the sequence when the strand is negative, 
    e.g., CAT 1,2,3 -> ATG 3,2,1 => Start=3
    """
    return [
        (end + neg_offset if strand == '-' else start) + tot_offset 
        for _, start, end, strand in 
        df[['Start', 'End', 'Strand']].itertuples()]

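def lift_and_fetch(df, ref):
    """
    Lift "Start" from hg19 to hg38, drop rows that fail to lift,
    and fetch the hg38 codon at each lifted start into "CodonFetched".
    """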
    df = df.copy()
    df['Start'] = lift_starts(df)
    df = df[~df.Start.isna()]
    df['Start'] = df['Start'].astype(int)
    df['CodonFetched'] = fetch_codons(df, ref)
    return df

def parse_ji(path, ref):
    df = pd.read_excel(
        path, sheet_name='uORF', usecols=[0, 2, 3, 4, 5], 
        names=['Gene', 'Chrom', 'Strand', 'Start', 'End']
    )
    # Drop alternative haplotypes and unplaced contigs (names containing "_")
    df = df[df.Chrom.apply(lambda x: '_' not in x)].copy()
    # Offset for the negative strand is -3, total offset is zero
    df['Start'] = offset_starts(df, 0, -3)
    df = lift_and_fetch(df, ref)
    df = df.rename(columns={'CodonFetched': 'Codon'}).drop(columns='End')
    df = filter_codons(df)
    return df

def parse_scholtz(path, ref):
    df = pd.read_csv(
        path, skiprows=1,
        usecols=[2, 3, 4, 5, 7, 8],
        names=['Chrom', 'Start', 'End', 'GeneID', 'Strand', 'Codon']
    )
    # Offset for the negative strand is -2, total offset is -1, no lifting is needed.
    df['Start'] = offset_starts(df)
    df['CodonFetched'] = fetch_codons(df, ref)
    df = filter_codons(df)
    df['GeneID'] = df['GeneID'].apply(lambda x: x.split('.')[0])
    df = df.drop(columns='End')
    return df

def parse_gil(path, ref):
    df = pd.read_excel(
        path, sheet_name='Supplemental_Table_4', 
        usecols=[0, 2, 3, 4, 5, 6], skiprows=3,
        names=['ID', 'Codon', 'Chrom', 'Strand', 'Start', 'End']
    )
    df['TranscriptID'] = df['ID'].apply(lambda x: x.split('.')[0])
    df = df.drop(columns='ID')
    idx = df.Strand == '-'
    # For + strand, offset by -1, for - strand, offset by -3
    df.loc[~idx, 'Start'] = df.loc[~idx, 'Start'] - 1
    df.loc[idx, 'Start'] = df.loc[idx, 'Start'] - 3
    # Lift coordinates and filter
    df = lift_and_fetch(df, ref)
    df = filter_codons(df)
    df = df.drop(columns='End')
    return df

def parse_pred(path):
    """
    Read the dataset with predictions and filter to Test and Val datasets.
    
    Offset start coordinates. 
    In our convention, we used 0-based coordinates with 
        "Start" pointing at the first nucleotide of a start codon.
    The latter was inverted along with the sequence when the strand is negative, 
        e.g., CAT 1,2,3 -> ATG 3,2,1 => Start=3
    Now, we offset (back) the start by -2 so that it always points 
        to the first nucleotide of the + strand (so Start=1 in the above example).
    """
    df = pd.read_csv(path).rename(columns={'SeqEnum': 'Start', 'Start': 'Codon'})
    df = df[df.Dataset == 'Test'].copy()
    idx = df.Strand == '-'
    df.loc[idx, 'Start'] = df.loc[idx, 'Start'] - 2
    return df

def annotate_predictions(df):
    """Label each row as TP, FN, FP, or TN from the "y_true" and "y_pred" columns."""
    df = df.copy()
    df['PredictionType'] = 'TP'
    df.loc[(df.y_true == 1) & (df.y_pred == 0), 'PredictionType'] = 'FN'
    df.loc[(df.y_true == 0) & (df.y_pred == 1), 'PredictionType'] = 'FP'
    df.loc[(df.y_true == 0) & (df.y_pred == 0), 'PredictionType'] = 'TN'
    return df

def calc_pred_scores(df):
    """Compute F1, precision, recall, balanced accuracy, and confusion counts."""
    y_pred = df['y_pred'].values
    y_true = df['y_true'].values
    fn, fp, tn, tp = map(
        lambda x: len(df[df.PredictionType == x]), 
        ['FN', 'FP', 'TN', 'TP'])
    return {
        'f1': f1_score(y_true, y_pred, zero_division=0), 
        'prc': precision_score(y_true, y_pred, zero_division=0), 
        'rec': recall_score(y_true, y_pred, zero_division=0),
        'bac': balanced_accuracy_score(y_true, y_pred),
        # 'roc_auc': roc_auc_score(y_true, y_prob),
        'FN': fn, 'FP': fp, 'TN': tn, 'TP': tp,
    }

def merge_and_score(
    df_pred, df_comp, df_comp_name, 
    on=['Chrom', 'Strand', 'Start', 'Codon']
):
    """
    Left-merge our labeled sites with another study's TISs and score that study.
    "y_true" is our label; "y_pred" is 1 if the other study reports the site.
    Only codons present in the other study are kept; scores are computed per codon.
    """
    df_comp = df_comp.copy()
    df_pred = df_pred.copy()
    df_comp['Dataset'] = df_comp_name
    df = df_pred.merge(
        df_comp, how='left', on=on, suffixes=['_pred', '_comp'])
    print(f'Merged size: {len(df)}')
    comp_codons = set(df_comp.Codon)
    df = df[df.Codon.isin(comp_codons)].copy()
    print(f'Filtered to {comp_codons} start codons: {len(df)}')
    df['y_pred'] = 1
    df.loc[df.Dataset_comp.isna(), 'y_pred'] = 0
    df = annotate_predictions(df)
    df = df.drop_duplicates(['Chrom', 'Strand', 'Codon', 'Start'])
    scores = {codon: calc_pred_scores(df[df.Codon == codon]) for codon in comp_codons}
    return df, scores

def unravel_scores(scores):
    """Flatten {dataset: {codon: {score: value}}} into (dataset, codon, score, value) tuples."""
    for ds_name, ds_vs in scores.items():
        for codon_name, codon_scores in ds_vs.items():
            for score_name, score_val in codon_scores.items():
                yield ds_name, codon_name, score_name, score_val
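The arithmetic in `offset_starts` can be checked on a toy sequence. The sketch below assumes, as the defaults `tot_offset=-1, neg_offset=-2` imply, that an external table stores each uORF as a 1-based, inclusive `[Start, End]` interval. On the `+` strand the start codon begins at `Start`, so its 0-based position is `Start - 1`; on the `-` strand it occupies the last three bases of the interval, so its leftmost (`+` strand) base is at `End - 3`. The toy chromosome, `to_leftmost_0based`, `fetch_codon`, and this local `reverse_complement` (a stand-in for `uBERTa.utils.reverse_complement`, handling uppercase ACGT only) are all hypothetical.

```python
COMPLEMENT = str.maketrans('ACGT', 'TGCA')


def reverse_complement(seq):
    """Reverse complement of an uppercase ACGT string (illustrative stand-in)."""
    return seq.translate(COMPLEMENT)[::-1]


def to_leftmost_0based(start, end, strand, tot_offset=-1, neg_offset=-2):
    """Same arithmetic as `offset_starts`, applied to a single interval."""
    return (end + neg_offset if strand == '-' else start) + tot_offset


def fetch_codon(chrom_seq, start, strand):
    """Fetch [start, start + 3) and orient it along `strand`, as fetch_codons does."""
    codon = chrom_seq[start:start + 3]
    return codon if strand == '+' else reverse_complement(codon)


# Toy chromosome with 0-based indices:
#   G G A T G C C A T T T
#   0 1 2 3 4 5 6 7 8 9 10
chrom = 'GGATGCCATTT'

# "+" strand uORF at 1-based [3, 11]; its ATG occupies 1-based 3..5
plus_start = to_leftmost_0based(3, 11, '+')
# "-" strand uORF at 1-based [1, 9]; its start codon occupies 1-based 7..9 (CAT)
minus_start = to_leftmost_0based(1, 9, '-')

print(plus_start, fetch_codon(chrom, plus_start, '+'))    # 2 ATG
print(minus_start, fetch_codon(chrom, minus_start, '-'))  # 6 ATG
```

Both strands end up with "Start" at the leftmost base, which is why a single `ref.fetch(chrom, start, start + 3)` call followed by a strand-dependent reverse complement works for every row.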

In [5]:
pred = parse_pred(PRED)
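As the `parse_pred` docstring explains, our predictions store a `-` strand start at the base that becomes the first nucleotide of the codon after reverse-complementing, i.e., the rightmost base on the `+` strand. Subtracting 2 moves it to the leftmost base, the same convention `offset_starts` produces for the external tables. A toy check, where `to_leftmost` and the sequence are hypothetical:

```python
def to_leftmost(start, strand):
    """Shift a '-' strand start from the rightmost to the leftmost + strand base."""
    return start - 2 if strand == '-' else start


# 0-based toy sequence: "CAT" at positions 0, 1, 2 reads as "ATG" on the "-" strand
chrom = 'CATCCC'
ours = 2                       # our convention: position of the A in ATG read on "-"
left = to_leftmost(ours, '-')  # leftmost + strand base of the codon
codon = chrom[left:left + 3][::-1].translate(str.maketrans('ACGT', 'TGCA'))
print(left, codon)  # 0 ATG
```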

In [18]:
scholtz = parse_scholtz(SCHOLTZ, ref)
gil = parse_gil(GIL, ref)
ji = parse_ji(JI, ref)

Fetching:   0%|          | 0/1933 [00:00<?, ?it/s]

Non-matching codons: 0
Invalid start codons: 0


Fetching:   0%|          | 0/188787 [00:00<?, ?it/s]

Non-matching codons: 792
Invalid start codons: 0


Fetching:   0%|          | 0/6614 [00:00<?, ?it/s]

Invalid start codons: 36
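The `Non-matching codons` and `Invalid start codons` counts above come from `filter_codons`: rows are kept only if the reported codon agrees with the one fetched from hg38 and belongs to the valid set. The same two-step filter on a made-up table, assuming a hypothetical `VALID` subset (the notebook uses `uBERTa.base.VALID_START`):

```python
import pandas as pd

VALID = {'ATG', 'CTG', 'GTG', 'TTG'}  # hypothetical subset for illustration

df = pd.DataFrame({
    'Start': [10, 20, 30, 40],
    'Codon': ['ATG', 'CTG', 'AAA', 'GTG'],         # codon reported by the authors
    'CodonFetched': ['ATG', 'GTG', 'AAA', 'GTG'],  # codon fetched from the reference
})

# Step 1: drop rows whose reported codon disagrees with the reference
mismatch = df.Codon != df.CodonFetched
df = df[~mismatch].drop(columns='CodonFetched')

# Step 2: drop codons outside the valid set
invalid = ~df.Codon.isin(VALID)
df = df[~invalid]

print(int(mismatch.sum()), int(invalid.sum()), df.Start.tolist())  # 1 1 [10, 40]
```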


In [30]:
ds_names = ['Scholtz', 'McGillivray', 'Ji']
merged_dfs, scores = map(
    list,
    unzip(merge_and_score(pred, ds, ds_name) for ds, ds_name in 
     zip([scholtz, gil, ji], ds_names)))
scores = {ds_name: s for ds_name, s in zip(ds_names, scores)}

Merged size: 30856
Filtered to {'ATG'} start codons: 2116
Merged size: 30960
Filtered to {'GTG', 'ATC', 'ACG', 'CTG', 'ATT', 'TTG', 'ATG', 'ATA'} start codons: 22177
Merged size: 30871
Filtered to {'GTG', 'ATC', 'CTG', 'TTG', 'ATG'} start codons: 16936
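In `merge_and_score`, the left merge keeps every site from our test set: `y_true` is our label, and `y_pred` is 1 only if the other study reports a TIS with the same chromosome, strand, start, and codon. The sketch below reproduces that labeling on made-up tables using `merge(..., indicator=True)`, an equivalent alternative to testing `Dataset_comp` for missing values. Sites reported only by the other study (`chr3` here) are ignored because the merge is anchored on our set, and codons the other study never reports (`CTG` here) are dropped, as in the notebook.

```python
import pandas as pd

KEYS = ['Chrom', 'Strand', 'Start', 'Codon']

ours = pd.DataFrame({
    'Chrom': ['chr1', 'chr1', 'chr2', 'chr2'],
    'Strand': ['+', '-', '+', '+'],
    'Start': [100, 200, 300, 400],
    'Codon': ['ATG', 'CTG', 'ATG', 'GTG'],
    'y_true': [1, 0, 1, 0],
})
theirs = pd.DataFrame({
    'Chrom': ['chr1', 'chr2', 'chr3'],
    'Strand': ['+', '+', '+'],
    'Start': [100, 400, 500],
    'Codon': ['ATG', 'GTG', 'ATG'],
})

df = ours.merge(theirs, how='left', on=KEYS, indicator=True)
# Keep only codons the other study reports at all
df = df[df.Codon.isin(set(theirs.Codon))].copy()
# A site counts as "predicted" if it also appears in the other study's table
df['y_pred'] = (df['_merge'] == 'both').astype(int)

# Same labels as annotate_predictions
labels = {(1, 1): 'TP', (1, 0): 'FN', (0, 1): 'FP', (0, 0): 'TN'}
df['PredictionType'] = [labels[t, p] for t, p in zip(df.y_true, df.y_pred)]
print(df.PredictionType.tolist())  # ['TP', 'FN', 'FP']
```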


In [31]:
df_scores = pd.DataFrame(
    unravel_scores(scores), 
    columns=['Dataset', 'Codon', 'ScoreType', 'ScoreVal']
).round(2).pivot(
    index=['Dataset', 'Codon'], columns='ScoreType', values='ScoreVal'
)

for c in ['FN', 'FP', 'TN', 'TP']:
    df_scores[c] = df_scores[c].astype(int)

df_scores['P'] = df_scores['TP'] + df_scores['FN']
df_scores['Total'] = (
    df_scores['TP'] + df_scores['FN'] + df_scores['FP'] + df_scores['TN']
)

df_scores = df_scores.reset_index().sort_values(
    ['Dataset', 'P'], ascending=[True, False]
).set_index(['Dataset', 'Codon'])[[
    'f1', 'prc', 'rec', 'bac', 'TN', 'FN', 'FP', 'TP', 'P', 'Total'
]]

df_scores

| Dataset | Codon | f1 | prc | rec | bac | TN | FN | FP | TP | P | Total |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Ji | CTG | 0.08 | 0.29 | 0.05 | 0.52 | 5653 | 185 | 22 | 9 | 194 | 5869 |
| Ji | ATG | 0.45 | 0.78 | 0.32 | 0.65 | 1931 | 116 | 15 | 54 | 170 | 2116 |
| Ji | GTG | 0.12 | 0.35 | 0.07 | 0.53 | 3787 | 79 | 11 | 6 | 85 | 3883 |
| Ji | TTG | 0.07 | 0.15 | 0.05 | 0.52 | 3050 | 42 | 11 | 2 | 44 | 3105 |
| Ji | ATC | 0.00 | 0.00 | 0.00 | 0.50 | 1923 | 25 | 0 | 0 | 25 | 1948 |
| McGillivray | CTG | 0.17 | 0.10 | 0.51 | 0.68 | 4839 | 96 | 836 | 98 | 194 | 5869 |
| McGillivray | ATG | 0.37 | 0.33 | 0.43 | 0.68 | 1796 | 97 | 150 | 73 | 170 | 2116 |
| McGillivray | GTG | 0.12 | 0.07 | 0.51 | 0.68 | 3225 | 42 | 573 | 43 | 85 | 3883 |
| McGillivray | ACG | 0.14 | 0.08 | 0.55 | 0.68 | 1150 | 20 | 267 | 24 | 44 | 1461 |
| McGillivray | TTG | 0.11 | 0.06 | 0.57 | 0.72 | 2654 | 19 | 407 | 25 | 44 | 3105 |
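The score columns follow directly from the confusion counts: precision is `TP / (TP + FP)`, recall is `TP / (TP + FN)`, F1 is `2TP / (2TP + FP + FN)`, and balanced accuracy is the mean of recall and specificity `TN / (TN + FP)`. Recomputing them by hand for the Ji ATG row reproduces the table values. This is a plain-Python check mirroring `zero_division=0`, not the scikit-learn calls used in `calc_pred_scores`; `scores_from_counts` is a hypothetical helper.

```python
def scores_from_counts(tp, fp, fn, tn):
    """Precision, recall, F1, and balanced accuracy from confusion counts."""
    prc = tp / (tp + fp) if tp + fp else 0.0  # 0 when nothing is predicted positive
    rec = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    spec = tn / (tn + fp) if tn + fp else 0.0
    bac = (rec + spec) / 2
    return {k: round(v, 2) for k, v in
            {'f1': f1, 'prc': prc, 'rec': rec, 'bac': bac}.items()}


# Ji, ATG: TN=1931, FN=116, FP=15, TP=54
print(scores_from_counts(tp=54, fp=15, fn=116, tn=1931))
# {'f1': 0.45, 'prc': 0.78, 'rec': 0.32, 'bac': 0.65}
```

The formulas also explain the contrast in the table: McGillivray et al. report many more sites (large FP), which raises recall but lowers precision, while Ji et al. report few (small FP and large FN), giving the opposite pattern.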


In [36]:
for n in ['McGillivray', 'Ji', 'Scholtz']:
    print(df_scores.loc[n].drop(columns=['P', 'Total']).to_latex())

\begin{tabular}{lrrrrrrrr}
\toprule
ScoreType &    f1 &   prc &   rec &   bac &    TN &  FN &   FP &  TP \\
Codon &       &       &       &       &       &     &      &     \\
\midrule
CTG   &  0.17 &  0.10 &  0.51 &  0.68 &  4839 &  96 &  836 &  98 \\
ATG   &  0.37 &  0.33 &  0.43 &  0.68 &  1796 &  97 &  150 &  73 \\
GTG   &  0.12 &  0.07 &  0.51 &  0.68 &  3225 &  42 &  573 &  43 \\
ACG   &  0.14 &  0.08 &  0.55 &  0.68 &  1150 &  20 &  267 &  24 \\
TTG   &  0.11 &  0.06 &  0.57 &  0.72 &  2654 &  19 &  407 &  25 \\
ATC   &  0.10 &  0.05 &  0.48 &  0.69 &  1713 &  13 &  210 &  12 \\
ATT   &  0.05 &  0.03 &  0.47 &  0.68 &  2113 &   8 &  248 &   7 \\
ATA   &  0.04 &  0.02 &  0.50 &  0.71 &  1206 &   2 &  105 &   2 \\
\bottomrule
\end{tabular}

\begin{tabular}{lrrrrrrrr}
\toprule
ScoreType &    f1 &   prc &   rec &   bac &    TN &   FN &  FP &  TP \\
Codon &       &       &       &       &       &      &     &     \\
\midrule
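CTG   &  0.08 &  0.29 &  0.05 &  0.52 &  5653 &  185 &  22 &   9 \\
ATG   &  0.45 &  0.78 &  0.32 &  0.65 &  1931 &  116 &  15 &  54 \\
GTG   &  0.12 &  0.35 &  0.07 &  0.53 &  3787 &   79 &  11 &   6 \\
TTG   &  0.07 &  0.15 &  0.05 &  0.52 &  3050 &   42 &  11 &   2 \\
ATC   &  0.00 &  0.00 &  0.00 &  0.50 &  1923 &   25 &   0 &   0 \\
\bottomrule
\end{tabular}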

