# Setup



## Get files / dependencies

In [None]:
# Colab-only magic: switch the runtime to TensorFlow 1.x (presumably what the
# ProteInfer code imported in the next cell expects). Run it before anything
# imports TensorFlow.
%tensorflow_version 1

In [None]:
# Fetch the ProteInfer repository and work from inside it, so that its local
# helper modules (imported just below) resolve; then quietly install its
# requirements.
# NOTE(review): not idempotent - on a second run the clone fails because the
# directory exists, and `%cd proteinfer` is resolved relative to the new cwd.
!git clone https://github.com/google-research/proteinfer 

%cd proteinfer

!pip3 install -qr  requirements.txt

import pandas as pd
import tensorflow
import inference
import parenthood_lib
import baseline_utils,subprocess
import shlex
import tqdm 
import sklearn
import numpy as np
import utils
import colab_evaluation
import plotly.express as px

from plotnine import ggplot, geom_point, geom_point, geom_line, aes, stat_smooth, facet_wrap, xlim,coord_cartesian,theme_bw,labs,ggsave


In [None]:
# Download and unpack the zipped CNN model (used below only to read its label
# vocabulary), the EC label-parenthood file and the SwissProt random-split EC
# test FASTA. `-N` makes wget skip files that are already up to date.
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/models/zipped_models/noxpnd_cnn_swissprot_ec_random_swiss-cnn_for_swissprot_ec_random-13685140.tar.gz
!tar xzf noxpnd_cnn_swissprot_ec_random_swiss-cnn_for_swissprot_ec_random-13685140.tar.gz
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/colab_support/parenthood.json.gz
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/blast_baseline/fasta_files/SWISSPROT_RANDOM_EC/eval_test.fasta

## Load vocabulary and parenthood information

In [None]:
# The label vocabulary is stored as a variable inside the saved model; it is
# converted to str so labels can be used as plain dict keys below.
vocab = (
    inference.Inferrer(
        'noxpnd_cnn_swissprot_ec_random_swiss-cnn_for_swissprot_ec_random-13685140')
    .get_variable('label_vocab:0')
    .astype(str)
)
# Presumably maps each label to the labels it implies in the EC hierarchy;
# it is passed to the result-normalisation helpers below.
label_normalizer = parenthood_lib.get_applicable_label_dict('parenthood.json.gz')

## Define a helper function to download inference results

In [None]:
def download_inference_results(run_name, num_shards=64,
                               output_root='./inference_results'):
  """Download the sharded SwissProt inference results of one ProteInfer run.

  Shards are saved as `{output_root}/{run_name}/-NNNNN-of-MMMMM.predictions.gz`.
  Shards that already exist locally (and are non-empty) are skipped, so
  re-running the cell after a complete or partial download is cheap.

  Args:
    run_name: run directory in the public bucket, e.g. "ec_random_test".
    num_shards: number of prediction shards the run was written with.
    output_root: local directory under which `run_name/` is created.

  Raises:
    subprocess.CalledProcessError: if `wget` fails for any shard.
  """
  import os

  base_url = ('https://storage.googleapis.com/brain-genomics-public/research/'
              'proteins/proteinfer/swissprot_inference_results')
  output_dir = os.path.join(output_root, run_name)
  os.makedirs(output_dir, exist_ok=True)

  shard_names = ['-{:05d}-of-{:05d}.predictions.gz'.format(i, num_shards)
                 for i in range(num_shards)]
  for shard_name in tqdm.tqdm(shard_names, position=0, desc="Downloading"):
    destination = os.path.join(output_dir, shard_name)
    if os.path.exists(destination) and os.path.getsize(destination) > 0:
      continue
    # Download under a temporary name outside the run directory and move it
    # into place only on success: an interrupted download then never leaves a
    # truncated shard that would be skipped (or read) later.
    partial = os.path.join(output_root, run_name + shard_name + '.part')
    # argv list (no shell parsing) so unusual run names cannot split arguments.
    subprocess.check_output(
        ['wget', '-q', f'{base_url}/{run_name}/{shard_name}', '-O', partial])
    os.replace(partial, destination)

## Downloading predictions and getting them ready for analysis

In [None]:
# Tiny threshold passed to the loader so that essentially every prediction
# score is kept for the threshold sweeps below.
min_decision_threshold = 1e-10
download_inference_results("ec_random_test")
predictions_df = colab_evaluation.get_normalized_inference_results(
    "inference_results/ec_random_test",
    vocab,
    label_normalizer,
    min_decision_threshold=min_decision_threshold,
)

In [None]:
# Tidy ground-truth table for the test split; the intermediate parsed FASTA
# result is not kept as a global.
ground_truth_df = colab_evaluation.make_tidy_df_from_ground_truth(
    baseline_utils.load_ground_truth('eval_test.fasta'))

# Analysis

Now we can get some statistics about our predictions. Let's start with a simple calculation of precision, recall and F1 for the whole dataset at a threshold of 0.5. 

What happens in different EC classes - is there differential performance?

In [None]:
def get_first_level_of_ec_hierarchy(ec):
    """Return the name of the top-level EC class that label `ec` belongs to.

    Only the part of `ec` before the first "." is looked at, e.g.
    "EC:3.1.-.-" -> "EC:3" -> "Hydrolases".

    Raises:
        KeyError: if that prefix is not one of "EC:1" ... "EC:7".
    """
    top_level_prefix, _, _ = ec.partition(".")
    return {
        "EC:1": "Oxidoreductases",
        "EC:2": "Transferases",
        "EC:3": "Hydrolases",
        "EC:4": "Lyases",
        "EC:5": "Isomerases",
        "EC:6": "Ligases",
        "EC:7": "Translocases",
    }[top_level_prefix]

# Label -> top-level EC class name, used to break the statistics down by class.
top_level_ec_grouping = dict(
    zip(vocab, map(get_first_level_of_ec_hierarchy, vocab)))

colab_evaluation.apply_threshold_and_return_stats(
    predictions_df, ground_truth_df, grouping=top_level_ec_grouping)

And what about at different levels of the EC hierarchy?

In [None]:
def get_level_of_hierarchy(ec):
    """Return the depth (1-4) of EC label `ec` in the EC hierarchy.

    Assumes the standard four-field EC notation in which unspecified trailing
    fields are written as "-": "EC:1.-.-.-" is level 1 and "EC:1.2.3.4" is
    level 4. The level is returned as an int so it plots on a numeric axis.
    """
    num_of_dashes = ec.count("-")
    return 4 - num_of_dashes
# Statistics grouped by the level of the EC hierarchy each label sits at.
level_of_hierachy_grouping = {x: get_level_of_hierarchy(x) for x in vocab}

level_data = colab_evaluation.apply_threshold_and_return_stats(
    predictions_df, ground_truth_df, grouping=level_of_hierachy_grouping)
# One point per level, joined by a line.
(ggplot(level_data, aes(x="group", y="f1"))
 + geom_point()
 + geom_line()
 + theme_bw()
 + labs(x="Level of hierarchy", y="F1 score")
 + coord_cartesian(ylim=[0.95, 0.99]))

Now let's try varying the threshold to generate a precision-recall curve.

In [None]:
# Precision / recall / F1 of the single CNN over a sweep of decision thresholds.
cnn_pr_data = colab_evaluation.get_pr_curve_df(
    predictions_df, ground_truth_df)

In [None]:
# Inspect the curve without its first row (label 0), which is also left out of
# every plot below.
cnn_pr_data.drop(0)

In [None]:
# Single-CNN precision-recall curve, coloured by F1 and zoomed in on high recall.
(ggplot(cnn_pr_data.drop(index=0),
        aes(x="recall", y="precision", color="f1"))
 + geom_line()
 + coord_cartesian(xlim=(0.96, 1))
 + theme_bw()
 + labs(x="Recall", y="Precision", color="F1 Score"))

What decision threshold maximises F1 score?

In [None]:
# The three thresholds with the highest F1.
cnn_pr_data.sort_values('f1', ascending=False).head(3)

Now let's have a look at PR curves for each different top level group.

# Load CNN ensemble predictions

In [None]:
# Same loading procedure as for the single CNN, for the CNN-ensemble run.
min_decision_threshold = 1e-10
download_inference_results("ec_random_test_ens")
ens_predictions_df = colab_evaluation.get_normalized_inference_results(
    "inference_results/ec_random_test_ens",
    vocab,
    label_normalizer,
    min_decision_threshold=min_decision_threshold,
)

In [None]:
# Threshold sweep for the CNN ensemble.
ens_cnn_pr_data = colab_evaluation.get_pr_curve_df(
    ens_predictions_df, ground_truth_df)

In [None]:
# The three ensemble thresholds with the highest F1.
ens_cnn_pr_data.sort_values('f1', ascending=False).head(3)

In [None]:
# Method labels for the comparison plot below. "Ensembled CNN" is the label the
# later comparisons use for this model, so legends stay consistent.
cnn_pr_data['method'] = "CNN"
ens_cnn_pr_data['method'] = "Ensembled CNN"

In [None]:
method_comparison = pd.concat([cnn_pr_data, ens_cnn_pr_data], ignore_index=True)
# Colour and linetype both encode the method; giving both scales the same
# legend title lets plotnine merge them into a single "Method" legend.
(ggplot(method_comparison,
        aes(x="recall", y="precision", color="method", linetype="method"))
 + geom_line()
 + coord_cartesian(xlim=(0.91, 1), ylim=(0.91, 1))
 + theme_bw()
 + labs(x="Recall", y="Precision", color="Method", linetype="Method"))


# Blast comparison

Let's do the same sort of analysis for a BLAST baseline.

In [None]:
# BLAST baseline inputs: the BLAST output table for the test sequences, the
# test FASTA (also fetched in the setup section; `-N` skips unchanged files)
# and the training FASTA.
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/blast_baseline/blast_output/random/blast_out_test.tsv
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/blast_baseline/fasta_files/SWISSPROT_RANDOM_EC/eval_test.fasta
!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/blast_baseline/fasta_files/SWISSPROT_RANDOM_EC/train.fasta
# Training-set (sequence id, label) table with the `gt` column dropped; BLAST
# hits are joined to it on the target sequence id to inherit its labels.
train_ground_truth = colab_evaluation.make_tidy_df_from_ground_truth(baseline_utils.load_ground_truth('train.fasta')).rename(columns={"up_id":"train_seq_id"}).drop(columns=["gt"])

In [None]:
blast_out = colab_evaluation.read_blast_table("blast_out_test.tsv")
# Give every hit the training labels of its target sequence, then rename the
# columns to the (up_id, label, value) convention used for the CNN predictions.
blast_df = (
    blast_out
    .merge(train_ground_truth, left_on="target", right_on="train_seq_id")
    .rename(columns={'bit_score': 'value', "query": "up_id"})
)

In [None]:
# NOTE(review): this threshold is not passed to get_pr_curve_df below, so it
# does not filter the BLAST hits used here.
min_decision_threshold = 0
blast_pr_data = colab_evaluation.get_pr_curve_df(blast_df, ground_truth_df)
blast_pr_data['method'] = 'BLAST'

In [None]:
cnn_pr_data['method'] = 'CNN'
ens_cnn_pr_data['method'] = 'Ensembled CNN'
# Each curve is plotted without its first row (label 0).
method_comparison = pd.concat(
    [curve.drop(index=0)
     for curve in (cnn_pr_data, ens_cnn_pr_data, blast_pr_data)],
    ignore_index=True,
)
(ggplot(method_comparison, aes(x="recall", y="precision", color="method"))
 + geom_line()
 + coord_cartesian(xlim=(0.90, 1), ylim=(0.90, 1))
 + theme_bw()
 + labs(x="Recall", y="Precision", color="Method"))


In [None]:
# Best F1 reached by each method over all thresholds. The string 'max' avoids
# pandas' deprecation of passing the builtin `max` to `agg`.
method_comparison.groupby("method")[['f1']].agg('max')

In [None]:
# Full row (threshold, precision, recall, ...) of each method's best-F1 point.
method_comparison.sort_values('f1', ascending=False).drop_duplicates(subset='method')

Let's investigate what's going on at the left-hand side of the graph, where the CNN and the ensembled CNN achieve greater precision than BLAST.

In [None]:
def get_x_where_y_is_closest_to_z(df, x, y, z):
    """Return column `x` of the row whose column `y` is closest to `z`.

    The result is a one-element Series (indexed like `df`).
    """
    return df.iloc[(df[y] - z).abs().argsort()[:1]][x]


# Compare the ensembled CNN and BLAST at (approximately) the same recall. Each
# method's threshold comes from its *own* PR curve: the ensemble's scores are
# not on the same scale as the single CNN's, so the single-CNN curve must not be
# used to pick the ensemble's threshold.
target_recall = 0.96
cnn_threshold = float(
    get_x_where_y_is_closest_to_z(
        ens_cnn_pr_data, x="threshold", y="recall", z=target_recall).iloc[0])
blast_threshold = float(
    get_x_where_y_is_closest_to_z(
        blast_pr_data, x="threshold", y="recall", z=target_recall).iloc[0])

cnn_results = colab_evaluation.assign_tp_fp_fn(ens_predictions_df, ground_truth_df, cnn_threshold)

blast_results = colab_evaluation.assign_tp_fp_fn(blast_df, ground_truth_df, blast_threshold)

# Outer join so (protein, label) pairs seen by only one method are kept.
merged = cnn_results.merge(
    blast_results,
    how="outer",
    suffixes=("_ens_cnn", "_blast"),
    on=["label", "up_id", "gt"])


In [None]:
# Only `blast_df` had BLAST's `query` column renamed to `up_id`; `blast_out`
# still uses the original name, so rename it here before selecting columns.
blast_info = blast_out.rename(columns={"query": "up_id"})[['up_id', 'target', 'pc_identity']]

Let's list some of the BLAST false positives that the ensembled CNN does not make, in case we want to investigate what's going on.

In [None]:
# (protein, label) pairs that BLAST wrongly predicts (false positive) while the
# ensembled CNN does not, at the thresholds chosen above.
merged.query("fp_blast==True and fp_ens_cnn==False").head()

# An ensemble of BLAST and ensembled-CNNs

We've seen that the CNN ensemble and BLAST have different strengths: at lower recalls the CNN ensemble appears to achieve greater precision than BLAST, but BLAST reaches higher recall at lower precisions. Can we combine these approaches to get a predictor with the best of both worlds?

In [None]:

# Outer join of the two methods' predictions on (label, up_id): a pair predicted
# by only one method is kept, with NaN in the other method's columns. The
# shared `value` columns become value_ens_cnn / value_blast.
blast_and_cnn_ensemble = ens_predictions_df.merge(
    blast_df,
    on=["label", "up_id"],
    how="outer",
    suffixes=("_ens_cnn", "_blast"),
)

In [None]:
# A pair missing from one method's output gets score 0 for that method (the
# same product as before, since False * x == 0). Filling the score columns with
# 0.0 rather than False keeps them float instead of upcasting to object dtype.
# Every other missing value is filled with False, as before.
blast_and_cnn_ensemble = blast_and_cnn_ensemble.fillna(
    {'value_ens_cnn': 0.0, 'value_blast': 0.0}).fillna(False)

We will create a simple ensemble where the score of each (sequence, label) pair is the product of the probability assigned by the ensemble of neural networks and the BLAST bit score linking this sequence to a training example with this label. A pair found by only one of the two methods therefore gets a score of zero.

In [None]:
# Combined score: ensemble probability times BLAST bit score.
blast_and_cnn_ensemble['value'] = blast_and_cnn_ensemble['value_ens_cnn'].mul(
    blast_and_cnn_ensemble['value_blast'])

In [None]:
# Threshold sweep for the combined BLAST x ensemble score.
blast_and_cnn_ensemble_pr = colab_evaluation.get_pr_curve_df(blast_and_cnn_ensemble, ground_truth_df)

In [None]:
# Best F1 of the combined predictor.
blast_and_cnn_ensemble_pr['f1'].max()

In [None]:
# Legend label for the combined predictor (single spaces between words).
blast_and_cnn_ensemble_pr['method'] = 'Ensemble of BLAST with Ensembled-CNN'

In [None]:
cnn_pr_data['method'] = 'CNN'
ens_cnn_pr_data['method'] = 'Ensembled CNN'
# All four methods, each curve without its first row (label 0).
method_comparison = pd.concat(
    [curve.drop(index=0)
     for curve in (cnn_pr_data, ens_cnn_pr_data, blast_pr_data,
                   blast_and_cnn_ensemble_pr)],
    ignore_index=True,
)
(ggplot(method_comparison, aes(x="recall", y="precision", color="method"))
 + geom_line()
 + coord_cartesian(xlim=(0.93, 1), ylim=(0.93, 1))
 + theme_bw()
 + labs(x="Recall", y="Precision", color="Method"))


In [None]:
# Drop the recall == 1.0 points before plotting (and before the summary below).
method_comparison = method_comparison.query("recall != 1.0")
fig = px.line(method_comparison, x="recall", y="precision", color="method")
fig.update_layout(template="plotly_white", title="Precision-recall by method")
fig.update_xaxes(range=(0.95, 1))
fig.update_yaxes(range=(0.95, 1))
fig.show()
# Save the figure spec; named so it does not shadow the stdlib `json` module.
method_figure_json = fig.to_json(pretty=True)
with open("method.json", "w") as out_file:
  out_file.write(method_figure_json)

In [None]:
# Best F1 per method (recall == 1.0 points excluded above). The string 'max'
# avoids pandas' deprecation of passing the builtin `max` to `agg`.
method_comparison.groupby("method")[['f1']].agg('max')

# Bootstrapping

## Defining functions

In [None]:
import collections
def get_bootstrapped_pr_curves(predictions_df, ground_truth_df, grouping=None, n=100,
                               method_label=None, sample_with_replacement=True,
                               min_decision_threshold=1e-10, seed=None):
  """Bootstrap precision-recall curves by resampling proteins.

  For each replicate, `len(unique up_ids)` proteins are drawn from the union of
  predicted and ground-truth proteins. Every prediction / ground-truth row of a
  drawn protein is kept once per draw, and the PR curve is recomputed with
  `colab_evaluation.get_pr_curve_df`. Repeated draws of one protein get distinct
  ids ("<up_id>-<k>") so that they count as separate proteins.

  Args:
    predictions_df: predictions with at least `up_id`, `label`, `value` columns.
    ground_truth_df: ground truth with at least `up_id`, `label`, `gt` columns.
    grouping: optional label grouping forwarded to `get_pr_curve_df`.
    n: number of bootstrap replicates.
    method_label: stored in a `type` column of every returned curve.
    sample_with_replacement: if False, each replicate is a permutation of all
      proteins, i.e. it reproduces the full-data curve.
    min_decision_threshold: predictions with `value` at or below this are
      discarded before resampling.
    seed: optional seed for a private random generator. When None the global
      `np.random` state is used, so `np.random.seed` still controls results.

  Returns:
    A list of `n` PR-curve DataFrames.
  """
  joined = predictions_df[predictions_df.value > min_decision_threshold].merge(
      ground_truth_df, on=['up_id', 'label'], how='outer')
  unique_up_ids = joined['up_id'].unique()
  rng = np.random if seed is None else np.random.default_rng(seed)

  pr_samples = []
  for _ in tqdm.tqdm(range(n)):
    sampled_up_ids = rng.choice(
        unique_up_ids, len(unique_up_ids), replace=sample_with_replacement)
    joined_sampled = _expand_bootstrap_sample(joined, sampled_up_ids)

    pred = joined_sampled[joined_sampled['value'].notna()][['up_id', 'label', 'value']]
    gt = joined_sampled[joined_sampled['gt'].notna()][['up_id', 'label', 'gt']]

    pr_curves = colab_evaluation.get_pr_curve_df(pred, gt, grouping=grouping)
    # Force the threshold == 0 point to precision/F1 of 0.
    pr_curves.loc[pr_curves['threshold'] == 0.0, 'precision'] = 0
    pr_curves.loc[pr_curves['threshold'] == 0.0, 'f1'] = 0
    pr_curves['type'] = method_label
    pr_samples.append(pr_curves)
  return pr_samples


def _expand_bootstrap_sample(joined, sampled_up_ids):
  """Repeat every row of `joined` once per draw of its protein, renaming copies.

  The k-th copy (k = 0, 1, ...) of each row gets up_id "<up_id>-<k>", so all
  rows belonging to the same draw of a protein share one new id. This works
  whatever the column order of `joined`, and even when a (up_id, label) pair
  occurs in several rows (e.g. several BLAST hits). Column dtypes are kept.
  """
  draws_per_up_id = collections.Counter(sampled_up_ids)
  repeats = np.array([draws_per_up_id[u] for u in joined['up_id']], dtype=int)
  row_positions = np.repeat(np.arange(len(joined)), repeats)
  # Position of each output row inside its run of copies: 0, 1, ..., repeats-1.
  run_starts = np.repeat(np.cumsum(repeats) - repeats, repeats)
  copy_index = np.arange(len(row_positions)) - run_starts

  expanded = joined.iloc[row_positions].reset_index(drop=True)
  expanded['up_id'] = [f'{u}-{k}' for u, k in zip(expanded['up_id'], copy_index)]
  return expanded

## Perform calculations

In [None]:
# Number of bootstrap replicates per method.
n = 100
non_ensembled_prs = get_bootstrapped_pr_curves(
    predictions_df, ground_truth_df, n=n, method_label="CNN")
ensembled_prs = get_bootstrapped_pr_curves(
    ens_predictions_df, ground_truth_df, n=n, method_label="Ensemble")
blast_prs = get_bootstrapped_pr_curves(
    blast_df, ground_truth_df, n=n, method_label="Blast")
blast_and_cnn_ensemble_prs = get_bootstrapped_pr_curves(
    blast_and_cnn_ensemble, ground_truth_df, n=n,
    method_label="Blast/CNN-ensemble")



## Interpolate curves

In [None]:
from scipy.interpolate import interp1d


def create_interpolated_df(single_curve, recall_min=0.95, recall_max=1.0,
                           num_points=5001):
  """Resample one PR curve onto a fixed, evenly spaced recall grid.

  Precision is linearly interpolated as a function of recall. Grid points
  outside the curve's recall range get NaN. Putting every bootstrap curve on
  the same grid lets them be aggregated per recall value.

  Args:
    single_curve: DataFrame with `recall`, `precision`, `type` and `group`
      columns; `type` and `group` are taken from its first row.
    recall_min, recall_max: inclusive bounds of the recall grid.
    num_points: number of grid points.

  Returns:
    DataFrame with columns type, group, precision, recall.
  """
  precision_at = interp1d(single_curve.recall, single_curve.precision,
                          bounds_error=False)
  recall = np.linspace(recall_min, recall_max, num_points)
  return pd.DataFrame({
      "type": single_curve['type'].iloc[0],
      "group": single_curve['group'].iloc[0],
      "precision": precision_at(recall),
      "recall": recall,
  })

In [None]:

# Put every bootstrap replicate of every method onto the common recall grid,
# one interpolated curve per (replicate, group).
# NOTE(review): `all` shadows the Python builtin; it is kept because the
# band-computation cell reads it by this name.
curves = [ensembled_prs, non_ensembled_prs, blast_and_cnn_ensemble_prs, blast_prs]
dfs = [
    create_interpolated_df(df_group)
    for curve_set in curves
    for replicate in curve_set
    for _, df_group in replicate.groupby("group")
]
all = pd.concat(dfs)


In [None]:

curves = [ensembled_prs, non_ensembled_prs, blast_and_cnn_ensemble_prs, blast_prs]


def create_f1(single_curve):
  """One-row summary of a curve: its type, its group and its best (max) F1."""
  return pd.DataFrame(
      {
          "type": single_curve['type'].iloc[0],
          "group": single_curve['group'].iloc[0],
          "f1": single_curve['f1'].max(),
      },
      index=[0],
  )


# Best F1 of every bootstrap replicate, per method and group.
dfs = [
    create_f1(df_group)
    for curve_set in curves
    for replicate in curve_set
    for _, df_group in replicate.groupby("group")
]
f1 = pd.concat(dfs)


def lower_func(x):
    """2.5th percentile: lower end of a 95% bootstrap interval."""
    return x.quantile(0.025)


def upper_func(x):
    """97.5th percentile: upper end of a 95% bootstrap interval."""
    return x.quantile(0.975)


f1_data = (
    f1.groupby(['type', 'group'])
    .agg(lower=("f1", lower_func), upper=("f1", upper_func))
    .reset_index()
)
f1_data

In [None]:
# Best F1 of each bootstrap replicate: one row per (method, group, replicate).
f1

In [None]:
# Pointwise 95% bootstrap band (2.5th-97.5th percentile) of precision at each
# grid recall value, per method and group. Inline quantiles avoid redefining
# the percentile helper functions yet again.
for_graph = (
    all.groupby(['type', 'group', 'recall'])
    .agg(lower=("precision", lambda p: p.quantile(0.025)),
         upper=("precision", lambda p: p.quantile(0.975)))
    .reset_index()
)


In [None]:
# Point-estimate curve per method: a single draw without replacement is a
# permutation of all proteins, i.e. the full test set.
single_curve_inputs = [
    (predictions_df, "CNN"),
    (ens_predictions_df, "Ensemble"),
    (blast_df, "Blast"),
    (blast_and_cnn_ensemble, "Blast/CNN-ensemble"),
]
all_single = pd.concat([
    get_bootstrapped_pr_curves(method_predictions, ground_truth_df, n=1,
                               sample_with_replacement=False,
                               method_label=method_label)[0]
    for method_predictions, method_label in single_curve_inputs
])


## Plot bootstrap curves

In [None]:
import plotly.graph_objects as go

# Empty figure; the bootstrap bands and point-estimate curves are added below.
fig = go.Figure()

def get_color(index, transparent):
  """Return a Plotly "rgba(r, g, b, a)" colour string for a method label.

  `index` is one of 'CNN', 'Ensemble', 'Blast/CNN-ensemble' or 'Blast'.
  A truthy `transparent` gives alpha 0.2 (used for the bootstrap bands),
  otherwise alpha is 1. Raises KeyError for an unknown label.
  """
  red, green, blue = {
      'CNN': [150, 0, 0],
      'Ensemble': [0, 125, 125],
      'Blast/CNN-ensemble': [0, 200, 0],
      'Blast': [125, 0, 255],
  }[index]
  alpha = 0.2 if transparent else 1
  return f"rgba({red}, {green}, {blue}, {alpha})"



# Shaded 95% bootstrap band per method: an invisible upper line followed by an
# invisible lower line filled up to it ('tonexty'). Colours come from
# get_color.
for the_type, new in for_graph.groupby('type'):
  fig.add_trace(go.Scatter(
      x=new['recall'],
      y=new['upper'],
      mode='lines',
      showlegend=False,
      line=dict(width=0.0, color=get_color(the_type, False)),
      name="",
      hoverinfo='skip',
  ))
  fig.add_trace(go.Scatter(
      x=new['recall'],
      y=new['lower'],
      name=the_type,
      hoverinfo='skip',
      showlegend=False,
      line=dict(width=0.0, color=get_color(the_type, False)),
      fill='tonexty',
      fillcolor=get_color(the_type, True),
  ))

# Full-test-set curve per method, drawn on top of its band.
for the_type, new in all_single.groupby('type'):
  fig.add_trace(go.Scatter(
      x=new['recall'],
      y=new['precision'],
      name=the_type,
      line=dict(width=1, color=get_color(the_type, False)),
  ))

fig.update_xaxes(title="Recall", range=[0.95, 1])
fig.update_yaxes(title="Precision", range=[0.95, 1])
fig.update_layout(
    template="plotly_white",
    legend_title_text='Method',
    title="Precision and recall by method",
)

fig.show()




In [None]:

# 95% bootstrap interval of the best F1 per method and group. `f1_data` was
# already computed from these same bootstrap curves by the identical cell
# earlier in this section, so it is shown again instead of being recomputed.
f1_data

## Examine effect of number of training examples on performance



In [None]:
def resample_with_replacement(df):
  """Return a bootstrap resample of `df`: same number of rows, drawn with replacement.

  Uses the global `np.random` state.
  """
  n_rows = len(df)
  return df.iloc[np.random.randint(n_rows, size=n_rows)]


def bootstrap(df, n=100):
  """Bootstrap the per-`count_cut` statistics of a per-label tp/fp/fn table.

  Draws `n` row-level resamples of `df` and concatenates the
  `colab_evaluation.stats_by_group` result computed on each one.
  """
  return pd.concat([
      colab_evaluation.stats_by_group(
          resample_with_replacement(df).groupby('count_cut'))
      for _ in tqdm.tqdm(range(n), position=0)
  ])




In [None]:
# Number of training sequences annotated with each label.
train_counts = (
    train_ground_truth
    .groupby("label", as_index=False)
    .count()
    .rename(columns={"train_seq_id": "count"})
)

In [None]:
# Per-label TP/FP/FN for the single CNN at a fixed threshold (0.625205,
# presumably its best-F1 threshold), joined with each label's number of
# training examples and binned by that count.
# NOTE(review): labels with no training examples get a NaN count and fall
# outside every bin (the first bin is (0, 5]), so they are in no group.
both = colab_evaluation.assign_tp_fp_fn(predictions_df, ground_truth_df, 0.625205)
both = both.merge(train_counts, on="label", how="outer")
both['count_cut'] = pd.cut(both['count'], bins=(0, 5, 10, 20, 40, 100, 1000, 500000))
# Same number of bootstrap replicates as for the other methods.
bootstrapped_data = bootstrap(both, n=100)
bootstrapped_data['count_cut_str'] = bootstrapped_data['count_cut'].astype(str)

In [None]:
# Method label for the box plot.
bootstrapped_data = bootstrapped_data.assign(type="CNN")

In [None]:
# Per-label TP/FP/FN for BLAST at a fixed bit-score threshold (60.5), joined
# with each label's number of training examples and binned by that count.
# NOTE(review): labels with no training examples get a NaN count and fall
# outside every bin (the first bin is (0, 5]), so they are in no group.
both = colab_evaluation.assign_tp_fp_fn(blast_df, ground_truth_df, 60.5)
both = both.merge(train_counts, on="label", how="outer")
both['count_cut'] = pd.cut(both['count'], bins=(0, 5, 10, 20, 40, 100, 1000, 500000))

bootstrapped_data_blast = bootstrap(both, n=100)
bootstrapped_data_blast['type'] = "BLAST"

In [None]:
# Per-label TP/FP/FN for the CNN ensemble at a fixed threshold (0.25), joined
# with each label's number of training examples and binned by that count.
# NOTE(review): labels with no training examples get a NaN count and fall
# outside every bin (the first bin is (0, 5]), so they are in no group.
both = colab_evaluation.assign_tp_fp_fn(ens_predictions_df, ground_truth_df, 0.25)
both = both.merge(train_counts, on="label", how="outer")
both['count_cut'] = pd.cut(both['count'], bins=(0, 5, 10, 20, 40, 100, 1000, 500000))

bootstrapped_data_ens = bootstrap(both, n=100)
bootstrapped_data_ens['type'] = "Ensembled CNNs"

In [None]:
# Per-label TP/FP/FN for the BLAST x ensemble score at a fixed threshold
# (0.17), joined with each label's number of training examples and binned.
# NOTE(review): labels with no training examples get a NaN count and fall
# outside every bin (the first bin is (0, 5]), so they are in no group.
both = colab_evaluation.assign_tp_fp_fn(blast_and_cnn_ensemble, ground_truth_df, 0.17)
both = both.merge(train_counts, on="label", how="outer")
both['count_cut'] = pd.cut(both['count'], bins=(0, 5, 10, 20, 40, 100, 1000, 500000))

bootstrapped_data_combo = bootstrap(both, n=100)
bootstrapped_data_combo['type'] = "Ensembled CNNs with BLAST"


In [None]:
bootstrapped_merge = pd.concat(
    [bootstrapped_data_blast, bootstrapped_data, bootstrapped_data_ens,
     bootstrapped_data_combo],
    ignore_index=True,
)

bootstrapped_merge['count_cut_str'] = bootstrapped_merge['count_cut'].astype(str)
# Distribution of bootstrap F1 per training-set-size bin, one box per method.
fig = px.box(
    bootstrapped_merge,
    x="count_cut_str",
    y="f1",
    color="type",
    width=700,
    labels={"count_cut_str": "Number of training examples per label", "f1": "F1"},
    template="simple_white",
)
fig.show()