# Compare predictions

Here, we compare our predictions with those of other authors.
We treat the combination of the validation and test datasets as ground truth and check how well other authors predicted translation initiation sites (TISs) in these data.
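In other words, every candidate site in our validation and test data carries a ground-truth label, and a site counts as "predicted" by another study if it appears in that study's table. A minimal sketch of this labeling, using hypothetical toy sites (the coordinates and the `reported` set are made up for illustration):

```python
from collections import Counter

# Hypothetical candidate sites from our Val/Test data, keyed by
# (chrom, strand, start, codon), with ground-truth labels (1 = TIS).
candidates = {
    ('chr1', '+', 100, 'ATG'): 1,
    ('chr1', '+', 250, 'CTG'): 0,
    ('chr2', '-', 40, 'ATG'): 1,
    ('chr2', '-', 90, 'GTG'): 0,
}
# Hypothetical sites reported by another study, in the same coordinate convention.
reported = {('chr1', '+', 100, 'ATG'), ('chr2', '-', 90, 'GTG'), ('chr3', '+', 7, 'ATG')}


def prediction_type(y_true, y_pred):
    """Map a (ground truth, prediction) pair to TP/FN/FP/TN."""
    return {(1, 1): 'TP', (1, 0): 'FN', (0, 1): 'FP', (0, 0): 'TN'}[(y_true, y_pred)]


# A candidate site is "predicted" if the other study lists it.
counts = Counter(
    prediction_type(y_true, int(site in reported))
    for site, y_true in candidates.items()
)
print(counts)
```

Sites reported by another study but absent from our candidate set (here, the `chr3` site) are ignored, which mirrors the left merge used later in this notebook.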

## Prerequisites

This notebook requires:
- [hg38.fa]()
- [Our predictions](); either download them or generate them with `predict_5UTR.ipynb`
- "uORF_annotation_hg38.csv" from [Scholtz et al.](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0222459)
- "Supplemental_Data_Tables_.xlsx" from [McGillivray et al.](https://academic.oup.com/nar/article/46/7/3326/4942470)
- "elife-08890-supp1-v2.xlsx" from [Ji et al.](https://elifesciences.org/articles/08890)

Tables from the papers above can be obtained via [this link](https://drive.google.com/file/d/1o1YhRuF4Dp122NWSmcehiLTwf68jwLnK/view?usp=sharing). Unpack the archive and place the files into the `data` directory as shown below (or set the paths manually in the code below).
```
data
|____hg38.fa
|____predictions_5UTR.tsv
|____others
| |____elife-08890-supp1-v2.xlsx
| |____uORF_annotation_hg38.csv
| |____Supplemental_Data_Tables_.xlsx
```
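Before running the cells below, it may help to check that this layout is in place. A small sketch; `missing_inputs` and `EXPECTED_FILES` are hypothetical names (not part of uBERTa), and the file names should be adjusted if yours differ:

```python
from pathlib import Path

# Files expected by this notebook, relative to the data directory (layout above).
EXPECTED_FILES = [
    'hg38.fa',
    'predictions_5UTR.tsv',
    'others/elife-08890-supp1-v2.xlsx',
    'others/uORF_annotation_hg38.csv',
    'others/Supplemental_Data_Tables_.xlsx',
]


def missing_inputs(base='../data'):
    """Return the expected files that are not present under `base`."""
    base = Path(base)
    return [name for name in EXPECTED_FILES if not (base / name).is_file()]


missing = missing_inputs()
if missing:
    print('Missing files:', *missing, sep='\n  ')
```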

In [30]:
from pathlib import Path

import numpy as np
import pandas as pd
from more_itertools import unzip
from pyliftover import LiftOver
from sklearn.metrics import f1_score, recall_score, precision_score
from tqdm.auto import tqdm
from uBERTa.utils import Ref, reverse_complement

In [2]:
BASE = Path('../data')
BASE.mkdir(exist_ok=True)

REF = BASE / 'hg38.fa'
SCHOLTZ = BASE / 'others' / 'uORF_annotation_hg38.csv'
JI = BASE / 'others' / 'elife-08890-supp1-v2.xlsx'
GIL = BASE / 'others' / 'Supplemental_Data_Tables_.xlsx'
PRED = BASE / 'predictions_5UTR.tsv'  # name as in the layout above; adjust if your file differs

# Start codons considered valid: ATG and the near-cognate codons analyzed here
VALID_START = ('ACG', 'ATC', 'ATG', 'ATT', 'CTG', 'GTG')

In [3]:
ref = Ref(REF)

In [67]:
def fetch_codons(df, ref):
    """
    Given a table with columns "Chrom", "Start", "Strand", fetch the sequence
        of the half-open region [Start, Start + 3) for each row.
    For "-" strand sequences, take reverse complements.
    """
    handle_seq = lambda seq, strand: (
        seq.upper() if strand == '+' else reverse_complement(seq).upper())
    return [
        handle_seq(
            ref.fetch(chrom, start, start + 3), strand) 
        for _, chrom, start, strand in 
        tqdm(
            df[['Chrom', 'Start', 'Strand']].itertuples(), 
            total=len(df), desc='Fetching')
    ]

def lift_starts(df):
    """
    Lift "Start" coordinates from hg19 to hg38 with pyliftover, which works with
        0-based positions. Positions that cannot be lifted become NaN.
    """
    def convert(*args):
        conv_list = lo.convert_coordinate(*args)
        if conv_list:
            return conv_list[0][1]
        return np.nan
    
    lo = LiftOver('hg19', 'hg38')
    return [
        convert(*x[1:]) for x in 
        df[['Chrom', 'Start', 'Strand']].itertuples()]

def offset_starts(df, tot_offset=-1, neg_offset=-2):
    """
    Offset start sites depending on a strand so they match our convention, 
    i.e., 0-based coordinates with "Start" pointing at the first nucleotide of a start codon.
    The latter is inverted along with the sequence when the strand is negative, 
    e.g., CAT 1,2,3 -> ATG 3,2,1 => Start=3
    """
    return [
        (end + neg_offset if strand == '-' else start) + tot_offset 
        for _, start, end, strand in 
        df[['Start', 'End', 'Strand']].itertuples()]

def lift_and_fetch(df, ref):
    df = df.copy()
    df['Start'] = lift_starts(df)
    df = df[~df.Start.isna()]
    df['Start'] = df['Start'].astype(int)
    df['CodonFetched'] = fetch_codons(df, ref)
    return df

def filter_codons(df, valid=VALID_START):
    """
    When both "Codon" and "CodonFetched" columns are present, keep only rows where
        the fetched codon matches the reported one, then drop "CodonFetched".
    Keep only rows whose codon is among the `valid` codons.
    """
    df = df.copy()
    if 'Codon' in df.columns and 'CodonFetched' in df.columns:
        idx = df.Codon != df.CodonFetched
        print(f'Non-matching codons: {idx.sum()}')
        df = df[~idx]
        df = df.drop(columns='CodonFetched')
    idx = ~df.Codon.isin(valid)
    print(f'Invalid start codons: {idx.sum()}')
    df = df[~idx]
    return df

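def calc_pred_scores(y_true, y_pred):
    # zero_division=0 reports undefined scores (e.g., no positive predictions)
    # as 0 without a warning; the default ("warn") returns the same values
    return {
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'prc': precision_score(y_true, y_pred, zero_division=0),
        'rec': recall_score(y_true, y_pred, zero_division=0)
    }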

def parse_ji(path, ref):
    df = pd.read_excel(
        path, sheet_name='uORF', usecols=[0, 2, 3, 4, 5], 
        names=['Gene', 'Chrom', 'Strand', 'Start', 'End']
    )
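    # Drop alternative-haplotype, random and unplaced contigs (names containing '_')
    df = df[df.Chrom.apply(lambda x: '_' not in x)].copy()
    # Offset for the negative strand is -3, total offset is zero
    # (i.e., the input is treated as 0-based and half-open)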
    df['Start'] = offset_starts(df, 0, -3)
    df = lift_and_fetch(df, ref)
    df = df.rename(columns={'CodonFetched': 'Codon'}).drop(columns='End')
    df = filter_codons(df)
    return df

def parse_scholtz(path, ref):
    df = pd.read_csv(
        path, skiprows=1,
        usecols=[2, 3, 4, 5, 7, 8],
        names=['Chrom', 'Start', 'End', 'GeneID', 'Strand', 'Codon']
    )
    # Offset for the negative strand is -2, total offset is -1, no lifting is needed.
    df['Start'] = offset_starts(df)
    df['CodonFetched'] = fetch_codons(df, ref)
    df = filter_codons(df)
    df['GeneID'] = df['GeneID'].apply(lambda x: x.split('.')[0])
    df = df.drop(columns='End')
    return df

def parse_gil(path, ref):
    df = pd.read_excel(
        path, sheet_name='Supplemental_Table_4', 
        usecols=[0, 2, 3, 4, 5, 6], skiprows=3,
        names=['ID', 'Codon', 'Chrom', 'Strand', 'Start', 'End']
    )
    df['TranscriptID'] = df['ID'].apply(lambda x: x.split('.')[0])
    df = df.drop(columns='ID')
    idx = df.Strand == '-'
    # For + strand, offset by -1, for - strand, offset by -3
    df.loc[~idx, 'Start'] = df.loc[~idx, 'Start'] - 1
    df.loc[idx, 'Start'] = df.loc[idx, 'Start'] - 3
    # Lift coordinates and filter
    df = lift_and_fetch(df, ref)
    df = filter_codons(df)
    df = df.drop(columns='End')
    return df

def parse_pred(path):
    """
    Read the dataset with predictions and filter to Test and Val datasets.
    
    Offset start coordinates. 
    In our convention, we used 0-based coordinates with 
        "Start" pointing at the first nucleotide of a start codon.
    The latter was inverted along with the sequence when the strand is negative, 
        e.g., CAT 1,2,3 -> ATG 3,2,1 => Start=3
    Now, we offset (back) the start by -2 so that it always points 
        to the first nucleotide of the + strand (so Start=1 in the above example).
    """
    df = pd.read_csv(path, sep='\t')
    df = df[df.Dataset.isin(['Test', 'Val'])]
    idx = df.Strand == '-'
    df.loc[idx, 'Start'] = df.loc[idx, 'Start'] - 2
    return df

def annotate_predictions(df):
    df = df.copy()
    df['PredictionType'] = 'TP'
    df.loc[(df.y_true == 1) & (df.y_pred == 0), 'PredictionType'] = 'FN'
    df.loc[(df.y_true == 0) & (df.y_pred == 1), 'PredictionType'] = 'FP'
    df.loc[(df.y_true == 0) & (df.y_pred == 0), 'PredictionType'] = 'TN'
    return df

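def merge_and_score(
    df_pred, df_comp, df_comp_name,
    on=('Chrom', 'Strand', 'Start', 'Codon')
):
    df_comp = df_comp.copy()
    df_pred = df_pred.copy()
    df_comp['Dataset'] = df_comp_name
    # Left merge: every predicted site is kept; sites matched in df_comp get y_pred = 1 below.
    # A site listed several times in df_comp yields several merged rows.
    df = df_pred.merge(
        df_comp, how='left', on=list(on), suffixes=['_pred', '_comp'])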
    print(f'Merged size: {len(df)}')
    comp_codons = set(df_comp.Codon)
    df = df[df.Codon.isin(comp_codons)]
    print(f'Filtered to {comp_codons} start codons: {len(df)}')
    df['y_pred'] = 1
    df.loc[df.Dataset_comp.isna(), 'y_pred'] = 0
    df = annotate_predictions(df)
    idx_of_codons = ((codon, df.Codon == codon) for codon in comp_codons)
    scores = {codon: calc_pred_scores(df[idx].y_true, df[idx].y_pred) 
              for codon, idx in idx_of_codons}
    return df, scores

def unravel_scores(scores):
    """Flatten {dataset: {codon: {score: value}}} into (dataset, codon, score, value) tuples."""
    for ds_name, ds_vs in scores.items():
        for codon_name, codon_scores in ds_vs.items():
            for score_name, score_val in codon_scores.items():
                yield ds_name, codon_name, score_name, score_val
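The helpers above bring every table to one convention: 0-based coordinates with "Start" at the leftmost (+ strand) base of the start codon, so that fetching [Start, Start + 3) and reverse-complementing on the "-" strand yields the codon. The sketch below replays this arithmetic on a toy sequence. It assumes 1-based, fully closed input coordinates (what the default offsets of `offset_starts` imply); the helper names are hypothetical, and `reverse_complement` is a stand-in for the one imported from `uBERTa.utils`.

```python
COMPLEMENT = str.maketrans('ACGTacgt', 'TGCAtgca')


def reverse_complement(seq):
    """Stand-in for uBERTa.utils.reverse_complement."""
    return seq.translate(COMPLEMENT)[::-1]


def to_leftmost_0based(start, end, strand):
    """1-based, fully closed coordinates -> 0-based leftmost codon base.

    Mirrors offset_starts' defaults (tot_offset=-1, neg_offset=-2):
    on "+" the codon starts at Start; on "-" it ends at End.
    """
    return end - 3 if strand == '-' else start - 1


def pred_to_leftmost(start, strand):
    """Mirrors parse_pred: on "-", our Start points at the rightmost base."""
    return start - 2 if strand == '-' else start


def fetch_codon(seq, start, strand):
    """Fetch [start, start + 3) and reverse-complement on the "-" strand."""
    codon = seq[start:start + 3].upper()
    return reverse_complement(codon) if strand == '-' else codon


# Toy chromosome: ATG at 1-based 3..5 ("+"); CAT (= ATG on "-") at 1-based 10..12.
chrom = 'ccATGaaaaCATgg'

plus = to_leftmost_0based(3, 9, '+')     # ORF 3..9 on "+"
minus = to_leftmost_0based(1, 12, '-')   # ORF 1..12 on "-", start codon at 10..12
print(plus, fetch_codon(chrom, plus, '+'))    # 2 ATG
print(minus, fetch_codon(chrom, minus, '-'))  # 9 ATG
print(pred_to_leftmost(11, '-'))              # 9: the same site from our predictions
```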

In [5]:
scholtz = parse_scholtz(SCHOLTZ, ref)
gil = parse_gil(GIL, ref)
ji = parse_ji(JI, ref)
pred = parse_pred(PRED)

Non-matching codons: 0
Invalid start codons: 0


Non-matching codons: 792
Invalid start codons: 33166


Invalid start codons: 563


In [69]:
ds_names = ['Scholtz', 'McGillivray', 'Ji']
merged_dfs, scores = map(
    list,
    unzip(merge_and_score(pred, ds, ds_name) for ds, ds_name in 
     zip([scholtz, gil, ji], ds_names)))
scores = {ds_name: s for ds_name, s in zip(ds_names, scores)}

Merged size: 36183
Filtered to {'ATG'} start codons: 4337
Merged size: 36359
Filtered to {'GTG', 'ACG', 'ATC', 'ATG', 'CTG', 'ATT'} start codons: 36359
Merged size: 36187
Filtered to {'GTG', 'CTG', 'ATC', 'ATG'} start codons: 28482
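Although `pred` is the left table in all three merges, the merged sizes differ (36183, 36359 and 36187). A left merge adds rows only when a site matches several rows of the right table, so some comparison tables apparently list the same site more than once (for example, once per transcript), and such sites would then be counted more than once in the scores. The toy sketch below shows the row-count arithmetic with hypothetical site keys; in pandas, one option would be to de-duplicate with `df_comp.drop_duplicates(subset=on)` before merging.

```python
from collections import Counter


def left_join_size(left_keys, right_keys):
    """Rows produced by a left join on a single key: each left row appears once
    per matching right row, or once (with missing values) if nothing matches."""
    right_counts = Counter(right_keys)
    return sum(max(1, right_counts[key]) for key in left_keys)


predicted = ['s1', 's2', 's3']          # toy predicted sites
reported = ['s1', 's1', 's2', 's9']     # toy table listing s1 twice (e.g., two transcripts)
print(left_join_size(predicted, reported))        # 4: s1 appears twice
print(left_join_size(predicted, set(reported)))   # 3: one row per predicted site
```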




In [70]:
df_scores = pd.DataFrame(
    unravel_scores(scores), 
    columns=['Dataset', 'Codon', 'Score', 'Value']
).sort_values(
    ['Dataset', 'Codon']
)
df_scores['Value'] = df_scores['Value'].round(2)
df_scores = df_scores.pivot(index=['Dataset', 'Codon'], columns=['Score'], values=['Value'])

In [71]:
df_scores

| Dataset | Codon | f1 | prc | rec |
|---|---|---|---|---|
| Ji | ATC | 0.00 | 0.00 | 0.00 |
| Ji | ATG | 0.40 | 0.62 | 0.29 |
| Ji | CTG | 0.06 | 0.19 | 0.04 |
| Ji | GTG | 0.09 | 0.22 | 0.06 |
| McGillivray | ACG | 0.12 | 0.07 | 0.53 |
| McGillivray | ATC | 0.09 | 0.05 | 0.55 |
| McGillivray | ATG | 0.33 | 0.29 | 0.40 |
| McGillivray | ATT | 0.08 | 0.05 | 0.54 |
| McGillivray | CTG | 0.16 | 0.09 | 0.49 |
| McGillivray | GTG | 0.11 | 0.06 | 0.50 |
| Scholtz | ATG | 0.28 | 0.64 | 0.18 |


In [72]:
print(df_scores.to_latex())

\begin{tabular}{llrrr}
\toprule
        & {} & \multicolumn{3}{l}{Value} \\
        & Score &    f1 &   prc &   rec \\
Dataset & Codon &       &       &       \\
\midrule
Ji & ATC &  0.00 &  0.00 &  0.00 \\
        & ATG &  0.40 &  0.62 &  0.29 \\
        & CTG &  0.06 &  0.19 &  0.04 \\
        & GTG &  0.09 &  0.22 &  0.06 \\
McGillivray & ACG &  0.12 &  0.07 &  0.53 \\
        & ATC &  0.09 &  0.05 &  0.55 \\
        & ATG &  0.33 &  0.29 &  0.40 \\
        & ATT &  0.08 &  0.05 &  0.54 \\
        & CTG &  0.16 &  0.09 &  0.49 \\
        & GTG &  0.11 &  0.06 &  0.50 \\
Scholtz & ATG &  0.28 &  0.64 &  0.18 \\
\bottomrule
\end{tabular}





In [75]:
for name, _df in zip(ds_names, merged_dfs):
    counts = _df[
        ['Codon', 'PredictionType', 'Start']
    ].sort_values(
        'Codon'
    ).groupby(
        ['Codon', 'PredictionType'], as_index=False
    ).count().pivot(index=['Codon'], columns=['PredictionType'], values=['Start'])
    print(name, counts, sep='\n', end='\n\n')

Scholtz
               Start              
PredictionType    FN  FP    TN  TP
Codon                             
ATG              267  33  3978  59

McGillivray
               Start                 
PredictionType    FN    FP    TN   TP
Codon                                
ACG               37   552  2271   42
ATC               20   462  3597   24
ATG              197   331  3689  132
ATT               18   438  4358   21
CTG              191  1745  9967  181
GTG               78  1184  6747   77

Ji
                Start                     
PredictionType     FN    FP       TN    TP
Codon                                     
ATC              40.0   NaN   4049.0   NaN
ATG             231.0  57.0   3954.0  95.0
CTG             352.0  54.0  11592.0  13.0
GTG             141.0  31.0   7864.0   9.0
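The scores in the table above follow directly from these counts: precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 = 2TP / (2TP + FP + FN). A quick check with a few rows copied from the outputs above; 0/0 is reported as 0, which is why Ji's ATC row, with no positive predictions, scores 0 across the board. `scores_from_counts` is a hypothetical helper, not part of the notebook.

```python
def scores_from_counts(tp, fp, fn):
    """Precision, recall and F1 from confusion counts; 0/0 is reported as 0.0."""
    prc = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    return {'f1': f1, 'prc': prc, 'rec': rec}


# (TP, FP, FN) copied from the count tables above
counts = {
    ('Scholtz', 'ATG'): (59, 33, 267),
    ('McGillivray', 'ATG'): (132, 331, 197),
    ('McGillivray', 'CTG'): (181, 1745, 191),
    ('Ji', 'ATC'): (0, 0, 40),
}
for (dataset, codon), (tp, fp, fn) in counts.items():
    scores = scores_from_counts(tp, fp, fn)
    print(dataset, codon, {name: round(value, 2) for name, value in scores.items()})
```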

