In [None]:
#https://colab.research.google.com/github/ml6team/quick-tips/blob/main/nlp/2021_04_22_shap_for_huggingface_transformers/explainable_transformers_using_shap.ipynb#scrollTo=7XgsySD-u_rN

import shap
import transformers

from transformers import (AutoTokenizer, 
                          AutoModelForSequenceClassification, 
                          TextClassificationPipeline)

import os
import pandas as pd
import numpy as np
import helpers

from sklearn.model_selection import train_test_split


  def _pt_shuffle_rec(i, indexes, index_mask, partition_tree, M, pos):
  def delta_minimization_order(all_masks, max_swap_size=100, num_passes=2):
  def _reverse_window(order, start, length):
  def _reverse_window_score_gain(masks, order, start, length):
  def _mask_delta_score(m1, m2):
  def identity(x):
  def _identity_inverse(x):
  def logit(x):
  def _logit_inverse(x):
  def _build_fixed_single_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link, linearizing_weights):
  def _build_fixed_multi_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link, linearizing_weights):
  def _init_masks(cluster_matrix, M, indices_row_pos, indptr):
  def _rec_fill_masks(cluster_matrix, indices_row_pos, indptr, indices, M, ind):
  def _single_delta_mask(dind, masked_inputs, last_mask, data, x, noop_code):
  def _delta_masking(masks, x, curr_delta_inds, varying_rows_out,


In [None]:
# Load ./data/cub.csv (first column used as the row index) and hold out 20% of the
# rows as a test split; the fixed random_state keeps the split reproducible.
df = pd.read_csv('./data/cub.csv', index_col=0)
df_train, df_test = train_test_split(
    df,
    test_size=0.2,
    random_state=1234,
)

In [None]:
# Restrict the held-out rows to a single species. `.copy()` makes this an independent
# frame, because later cells assign new columns to shap_df; assigning into a filtered
# slice of df_test can trigger pandas' SettingWithCopyWarning.
shap_df = df_test.loc[df_test['species'] == 'yeasts288c'].copy()

In [None]:
# `helpers` is also imported in the first cell; re-importing an already-imported
# module is a no-op in Python, and keeps this cell runnable on its own.
import helpers

# Derive a codon representation from the 'Sequence' column via the project helper.
# NOTE(review): presumably this adds the 'codons_cleaned' column read by later cells
# (space-separated codon strings?) — confirm against helpers.add_codons_to_df.
shap_df = helpers.add_codons_to_df(shap_df, 'Sequence')

In [None]:
# len() of every 'codons_cleaned' entry, kept both as a list and as a 'lengths'
# column so the next cell can filter on it.
lengths = shap_df['codons_cleaned'].map(len).tolist()
shap_df['lengths'] = lengths

In [None]:
# Keep only rows whose 'lengths' value is strictly below 1064.
# NOTE(review): 1064 is presumably chosen so inputs fit the model's maximum input
# length — confirm and consider promoting it to a named constant.
shap_df = shap_df.loc[shap_df['lengths'].lt(1064)]

In [None]:
# import matplotlib.pyplot as plt

# fig, ax = plt.subplots()
# N, bins, patches = ax.hist(np.log(df['median_exp']), edgecolor='white', linewidth=1, bins=30)

# for i in range(0,12):
#     patches[i].set_facecolor('#068cf9')
# # for i in range(7,14):    
# #     patches[i].set_facecolor('#a41daa')
# for i in range(12, len(patches)):
#     patches[i].set_facecolor('#ff0051')

# plt.title('Distribution of Median Expression')
# plt.xlabel('Log of Median Expression')
# plt.ylabel('Bin Count')
# plt.show()

In [None]:
# Reference expression levels over the full table: maximum, median and minimum of
# 'median_exp'. Series reductions skip NaN. Builtin max()/min() give order-dependent
# results when NaN is present, and np.median returns NaN if any value is missing.
pos = df['median_exp'].max()
neutral = df['median_exp'].median()
neg = df['median_exp'].min()

In [None]:
# Display the maximum 'median_exp' value.
pos

In [None]:
# Display the median 'median_exp' value.
neutral

In [None]:
# Display the minimum 'median_exp' value.
neg

In [None]:
# Codon string(s) of the row(s) whose 'median_exp' equals the maximum.
# When several rows match, they are joined with a space. An empty separator would fuse
# the last codon of one sequence with the first codon of the next ('TAG' + 'ATG' ->
# 'TAGATG'), producing a token that is not a codon.
# NOTE(review): assumes df already has a 'codons_cleaned' column of space-separated
# codon strings; in this notebook that column is only added to shap_df explicitly.
pos_seq = ' '.join(df.loc[df['median_exp'] == pos, 'codons_cleaned'].tolist())
pos_seq

In [None]:
# Codon string(s) of the row(s) whose 'median_exp' equals 14.0.
# When several rows match, they are joined with a space so consecutive sequences do
# not fuse into non-codon tokens such as 'TAGATG'.
# NOTE(review): 14.0 is hard-coded rather than derived from `neutral` (the median
# computed earlier); exact float equality with the median may match no rows, so
# confirm 14.0 is the intended mid-expression value.
# NOTE(review): assumes df already has a 'codons_cleaned' column of space-separated
# codon strings.
neutral_seq = ' '.join(df.loc[df['median_exp'] == 14.0, 'codons_cleaned'].tolist())
neutral_seq

In [None]:
# Codon string(s) of the row(s) whose 'median_exp' equals the minimum.
# When several rows match, they are joined with a space so consecutive sequences do
# not fuse into non-codon tokens such as 'TAGATG'.
# NOTE(review): assumes df already has a 'codons_cleaned' column of space-separated
# codon strings.
neg_seq = ' '.join(df.loc[df['median_exp'] == neg, 'codons_cleaned'].tolist())
neg_seq

In [None]:
# Binary model: load the codon tokenizer and the fine-tuned checkpoint (moved to CPU),
# then wrap them in a text-classification pipeline that reports every label's score.
tokenizer_name = "./tokenizers/codonBERT"
model_name = "./models/codonBERT-binary-large_1/checkpoint-127330"

# NOTE(review): `token_type_ids=False` is forwarded to the tokenizer constructor;
# confirm it has the intended effect for this tokenizer class.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token_type_ids=False)
model = AutoModelForSequenceClassification.from_pretrained(model_name).cpu()

pipe_binary = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): recent transformers releases deprecate return_all_scores in favour of
    # top_k=None; the output nesting may differ, so check SHAP still accepts it before switching.
    return_all_scores=True,
)

In [None]:
# Multi-class model: same tokenizer, different fine-tuned checkpoint (moved to CPU).
# This rebinds `tokenizer`, `model` and the path variables; pipe_binary keeps
# references to the objects loaded in the previous cell.
tokenizer_name = "./tokenizers/codonBERT"
model_name = "./models/codonBERT-multi-large_1/checkpoint-127330"

# NOTE(review): `token_type_ids=False` is forwarded to the tokenizer constructor;
# confirm it has the intended effect for this tokenizer class.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token_type_ids=False)
model = AutoModelForSequenceClassification.from_pretrained(model_name).cpu()

pipe_multi = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): recent transformers releases deprecate return_all_scores in favour of
    # top_k=None; the output nesting may differ, so check SHAP still accepts it before switching.
    return_all_scores=True,
)

In [None]:
def score_and_visualize(texts, pipe):
    """Explain ``pipe``'s outputs on ``texts`` with SHAP and draw a beeswarm summary.

    Args:
        texts: collection of input strings handed directly to the SHAP explainer
            (e.g. a pandas Series of codon strings).
        pipe: text-classification pipeline wrapped by ``shap.Explainer``.
    """
    explanation = shap.Explainer(pipe)(texts)

    # NOTE(review): beeswarm expects one value per feature per sample. With a pipeline
    # that returns all label scores and variable-length texts, it may reject this
    # Explanation; consider plotting a single output instead, e.g.
    # shap.plots.bar(explanation[:, :, k].mean(0)).
    shap.plots.beeswarm(explanation)

In [None]:
# SHAP summary of the binary model over the filtered held-out sequences.
score_and_visualize(texts=shap_df['codons_cleaned'], pipe=pipe_binary)

In [None]:
# SHAP summary of the multi-class model over the filtered held-out sequences.
score_and_visualize(texts=shap_df['codons_cleaned'], pipe=pipe_multi)

In [None]:
def score_and_visualize(texts, pipe=None):
    """Predict and explain a single input, then bar-plot its SHAP values.

    This redefines the batch version of ``score_and_visualize`` from an earlier
    cell; every call after this cell uses this single-input version.

    Args:
        texts: one input string (e.g. space-separated codons). It is wrapped in a
            one-element list before being passed to the pipeline and the explainer.
        pipe: text-classification pipeline to explain. Defaults to ``pipe_binary``,
            looked up when the function is called, so single-argument calls work.

    Prints the pipeline's prediction for ``texts``; returns None.
    """
    if pipe is None:
        # NOTE(review): default chosen so the single-argument calls in later cells run;
        # pass pipe_multi explicitly if those explanations were meant for the multi-class model.
        pipe = pipe_binary

    batch = [texts]
    prediction = pipe(batch)
    print(prediction[0])

    explainer = shap.Explainer(pipe)
    shap_values = explainer(batch)

    shap.plots.bar(shap_values[0])

In [None]:
# positive example
score_and_visualize(texts='ATG ACA CGC GTT CAA TTT AAA CAC CAC CAT CAT CAC CAT CAT CCT GAC TAG')

In [None]:
# negative example
score_and_visualize(texts='ATG CCA GTC AAC AGC GTC CTT TGC CAT TTT TCT TCC GAC TTT TCA TTG GGC CTC ATA TTG CAA GAT ATC TAA')

In [None]:
# positive example (same input string as the earlier positive cell)
score_and_visualize(texts='ATG ACA CGC GTT CAA TTT AAA CAC CAC CAT CAT CAC CAT CAT CCT GAC TAG')

In [None]:
# neutral low example
# NOTE(review): this input appears to be two sequences concatenated: stop codon TAG is
# directly followed by start codon ATG. They were previously fused into the non-codon
# token 'TAGATG'; a space now separates them so every token is a single codon.
# Consider explaining each sequence on its own instead.
score_and_visualize('ATG AAC GCG GCG ATA TTG AAA ATT CGC TTC AGT TTC CAA GGA TTT CTT GAA AGA GAT TAG ATG TGC CGT CGA GAG AGG TGG TTG GTA CGG TAG')

In [None]:
# neutral medium example
score_and_visualize(texts='ATG GTT GAT CCA TAC TGG ATG GCA AAC TTC AAC TGT TAG')

In [None]:
# neutral high example
score_and_visualize(texts='ATG TTT GAA CTG TTT GTT ACA AAT GAC ATC TCT CAC TCC TGA')

In [None]:
# negative example (same input string as the earlier negative cell)
score_and_visualize(texts='ATG CCA GTC AAC AGC GTC CTT TGC CAT TTT TCT TCC GAC TTT TCA TTG GGC CTC ATA TTG CAA GAT ATC TAA')