# Expose AminoBERT's last encoder layer hidden states to the RGN

The goal is to do this outside of the tf.Estimator API.

We do this in cell 10.

The rest of the notebook validates that this new implementation produces exactly the same results as before, by numerically comparing its outputs against previously generated outputs for the Round 6 CASP14 sequences.

In [1]:
import sys
import os
import time
import copy
import subprocess
import shutil
import pickle
import random
import glob

from Bio import SeqIO
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append('../../')
import modeling
import tokenization
import optimization
import run_finetuning_and_prediction

%reload_ext autoreload
%autoreload 2

## Inference with tf.Estimator interface

The guts of this are not important. It is included here because this is what I've been using to run inference on sequences for CASP. Below, we'll validate that our Estimator-free reimplementation produces the same output.

In [2]:
def run_prediction(seqs, qfunc, checkpoint_file, wt_log_prob_mat=None,
                   return_seq_log_probs=True, return_seq_output=True,
                   clip_seq_level_outputs=True):
    """Run prediction-only inference through the tf.Estimator-based wrapper.

    Delegates to ``run_finetuning_and_prediction.run_model`` with training
    and evaluation switched off, and times the whole call.

    WARNING: the scratch directory ``'../../data/test/'`` is deleted (if it
    exists) at the start of every call.

    Args:
        seqs: Iterable of sequences; converted to a list before being
            forwarded as ``input_seqs``.
        qfunc: Forwarded as ``labels``.
        checkpoint_file: Forwarded as ``init_checkpoint``.
        wt_log_prob_mat: Forwarded unchanged.
        return_seq_log_probs: Forwarded unchanged.
        return_seq_output: Forwarded unchanged.
        clip_seq_level_outputs: Forwarded unchanged.

    Returns:
        Whatever ``run_model`` returns, with an extra ``'compute_time'``
        entry holding the elapsed wall-clock seconds.
    """
    t_begin = time.time()

    max_len = 1024
    scratch_dir = '../../data/test/'
    # Start from a clean output directory on every call.
    if os.path.exists(scratch_dir):
        shutil.rmtree(scratch_dir)

    aa_tokenizer = tokenization.FullTokenizer(
        k=1, token_to_replace_with_mask='X')

    run_kwargs = dict(
        input_seqs=list(seqs),
        labels=qfunc,
        max_seq_length=max_len,
        tokenizer=aa_tokenizer,
        bert_config_file='AminoBERT_config_v2.json',
        output_dir=scratch_dir,
        init_checkpoint=checkpoint_file,
        # Prediction only: no fine-tuning, no evaluation.
        do_training=False,
        do_evaluation=False,
        do_prediction=True,
        num_train_epochs=3,
        learning_rate=5e-5,
        warmup_proportion=0.1,
        train_batch_size=16,
        eval_batch_size=32,
        predict_batch_size=32,
        use_tpu=False,
        return_seq_log_probs=return_seq_log_probs,
        return_seq_output=return_seq_output,  # encoder_layers[-1]
        encoding_layer_for_seq_rep=[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
        wt_log_prob_mat=wt_log_prob_mat,
        clip_seq_level_outputs=clip_seq_level_outputs,
    )
    result = run_finetuning_and_prediction.run_model(**run_kwargs)

    result['compute_time'] = time.time() - t_begin
    return result

In [3]:
# When True, an 'M' is prepended to every sequence that does not already
# start with one (applied in the sequence-processing cell below).
PREPEND_M = True
# Directory that is globbed for the input '*.fa' files.
DATA_DIR = 'round_6/'

# Checkpoint used to initialize the model for both the tf.Estimator run and
# the Estimator-free re-implementation.
CHECKPOINT = os.path.join('checkpoint',
        'AminoBERT_runs_v2_uniparc_dataset_v2_5-1024_fresh_start_model.ckpt-1100000')

## Load and process sequences

In [4]:
def fasta_read(fasta_file):
    """Parse a FASTA file into parallel lists of record ids and sequences.

    Args:
        fasta_file: Path (or handle) passed to ``SeqIO.parse`` with the
            'fasta' format.

    Returns:
        A ``(headers, seqs)`` tuple of lists: ``headers[i]`` is the id of
        the i-th record and ``seqs[i]`` is its sequence as a plain string.
    """
    records = list(SeqIO.parse(fasta_file, 'fasta'))
    headers = [record.id for record in records]
    seqs = [str(record.seq) for record in records]
    return headers, seqs

In [5]:
# Sequences to predict structures for. 1 sequence per fasta.
fastas = glob.glob(os.path.join(DATA_DIR, '*.fa'))
print(len(fastas))
display(fastas)

# Read in sequences.
headers, seqs = zip(*[fasta_read(f) for f in fastas])

# Add a stop char to each sequence to be consistent
# with how the model was trained.
headers = [h[0] for h in headers]
seqs = [s[0] + '*' for s in seqs]

# Prepend an M. Again reflective of how the model
# was trained.
if PREPEND_M:
    for i in range(len(seqs)):
        if seqs[i][0] != 'M':
            seqs[i] = 'M' + seqs[i]
            
# Remove sequences that are too long for the model
mask = np.array([len(s) for s in seqs]) <= 1023
print('Sequences being removed due to length:', np.sum(~mask))
print('Sequences being removed:', np.array(headers)[~mask], np.array(seqs)[~mask])

seqs = list(np.array(seqs)[mask])
headers = list(np.array(headers)[mask])
fastas = list(np.array(fastas)[mask])

# Take a look at the seqs
display(seqs)
print([len(s) for s in seqs])

20


['round_6/T1088.fa',
 'round_6/T1075.fa',
 'round_6/T1072s1.fa',
 'round_6/T1082.fa',
 'round_6/T1085.fa',
 'round_6/T1089.fa',
 'round_6/T1073.fa',
 'round_6/T1070.fa',
 'round_6/T1079.fa',
 'round_6/T1083.fa',
 'round_6/T1078.fa',
 'round_6/T1077.fa',
 'round_6/T1087.fa',
 'round_6/T1084.fa',
 'round_6/T1086.fa',
 'round_6/T1074.fa',
 'round_6/T1071.fa',
 'round_6/T1080.fa',
 'round_6/T1076.fa',
 'round_6/T1090.fa']

Sequences being removed due to length: 0
Sequences being removed: [] []


['MDGKFTLGAGVGVVEHPYKQYDADVYPVPVISYESENFWFHGLGGGYYLWNDTNDKLSITAYWSPMYFKPGDSDSEQMRRLDKRKSTVMAGLSYVHNTPYGFLRTTIAGDTLDNSNGINWDLAWLYRYTNGNLTLTPGIGVEWNSDNQNEYYYGVSRHESRRSGMRSYDPDSSWNPYLELSANYRFLGDWSVYGVARYTRLSDEITDSPMVDKSWSGLISTGITYTF*',
 'METEQPEETFPNTETNGEFGKRPAEDMEEEQAFKRSRNTDEMVELRILLQSKNAGAVIGKGGKNIKALRTDYNASVSVPDSSGPERILSISADIETIGEILKKIIPTLEEGLQLPSPTATSQLPLESDAVECLNYQHYKGSDFDCELRLLIHQSLAGGIIGVKGAKIKELRENTQTTIKLFQECCPHSTDRVVLIGGKPDRVVECIKIILDLISESPIKGRAQPYDPNFYDETYDYGGFTMMFDDRRGRPVGFPMRGRGGFDRMPPGRGGRPMPPSRRDYDDMSPRRGPPPPPPGRGGRGGSRARNLPLPPPPPPRGGDLMAYDRRGRPGDRYDGMVGFSADETWDSAIDTWSPSEWQMAYEPQGGSGYDYSYAGGRGSYGDLGGPIITTQVTIPKDLAGSIIGKGGQRIKQIRHESGASIKIDEPLEGSEDRIITITGTQDQIQNAQYLLQNSVKQYSGKFF*',
 'MGSMGLYFSSLDSSIDILQKRAQELIENINKSRQKDHALMTNFRNSLKTKVSDLTEKLEERIYQIYNDHNKIIQEKLQEFTQKMAKISHLETELKQVCHSVE*',
 'MKKFIFATIFALASCAAQPAMAGYDKDLCEWSMTADQTEVETQIEADIMNIVKRDRPEMKAEVQKQLKSGGVMQYNYVLYCDKNFNNKNIIAEVVGE*',
 'MVILAKYNVEKDAAAKKSQAEKLEKNLLLGIENAEKSKDASLLVAAQDDYLQNSVLKQKSFDVQYQKTYAIYQKGDYAVAADQLK

[228, 464, 103, 98, 590, 405, 256, 336, 506, 100, 139, 160, 95, 74, 409, 203, 489, 923, 554, 195]


## Inference.

In [6]:
# run_prediction requires a labels argument; these are dummy values (one per
# sequence) and can be ignored.
n_seqs = len(seqs)
qfunc = np.random.randn(n_seqs)
inf_result = run_prediction(seqs=seqs, qfunc=qfunc, checkpoint_file=CHECKPOINT)

Featurizing input
INFO:tensorflow:Using config: {'_experimental_distribute': None, '_keep_checkpoint_every_n_hours': 10000, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f466378bb70>, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_is_chief': True, '_model_dir': '../../data/test/', '_num_worker_replicas': 1, '_master': '', '_service': None, '_keep_checkpoint_max': 5, '_device_fn': None, '_save_checkpoints_steps': 3, '_log_step_count_steps': None, '_task_type': 'worker', '_num_ps_replicas': 0, '_tpu_config': TPUConfig(iterations_per_loop=200, num_shards=8, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_save_checkpoints_secs': None, '_tf_random_seed': None, '_cluster': None, '_evaluation_master': '', '_global_id_in_cluster': 0, '_task_id': 0, '_save_summary_steps': 100, '_eval_distrib

INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_3/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflo

INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKP

## Ordinarily, save the contextual embeddings

But we won't do that here just to double check we can reproduce the results I generated in my other repo and handed off to Ratul & Mohammed.

```
print('Writing numpy arrays')
for j in range(len(seqs)):
    assert inf_result['predict']['seq_output'][j].shape[0] == len(seqs[j])
    assert headers[j] in fastas[j], (headers[j], fastas[j])

    outfile = fastas[j] + '.npy'
    np.save(outfile, inf_result['predict']['seq_output'][j])
```

The contextual embeddings are contained in `inf_result['predict']['seq_output']`

**IMPORTANT:** Instead of returning the full (max_seq_len=1024, 768) matrix of encoder outputs, we clip the encoder output to include only those token embeddings that correspond to the protein sequence.

So, the CLS output is clipped off, as well as all outputs after the sequence stop char.

Hence ...

In [7]:
# Each clipped embedding matrix must have exactly one row per character of
# its (processed) input sequence.
seq_outputs = inf_result['predict']['seq_output']
for j, seq in enumerate(seqs):
    assert seq_outputs[j].shape[0] == len(seq)

### Double check this tf.Estimator inference matches the previous tf.Estimator inference I did in my other repo.

In [8]:
# Compare this run's embeddings against the .npy file saved next to each fasta.
for i in range(len(seqs)):
    prev_result = np.load(fastas[i] + '.npy')
    curr_result = inf_result['predict']['seq_output'][i]

    # np.allclose broadcasts its inputs, so check the shapes explicitly first.
    assert curr_result.shape == prev_result.shape, (curr_result.shape, prev_result.shape)
    assert np.allclose(curr_result, prev_result)
    # Report the largest *absolute* deviation; the max of the signed
    # difference would hide negative deviations.
    print(np.amax(np.abs(curr_result - prev_result)))

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0


Cool, moving the code over to this new repo didn't affect anything.

# Re-implementation without tf.Estimator

In [9]:
# Number of token positions per model input; the placeholders below are
# declared with this width and the input features are padded to it.
MAX_SEQ_LENGTH = 1024
BERT_CONFIG_FILE = 'AminoBERT_config_v2.json'
USE_ONE_HOT_EMBEDDINGS = False
# NOTE(review): run_prediction builds its tokenizer with
# token_to_replace_with_mask='X', whereas this one relies on the default.
# Confirm both tokenize identically for sequences that contain 'X'.
TOKENIZER = tokenization.FullTokenizer(k=1)

# Model hyperparameters loaded from the JSON config file.
bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG_FILE)

### Tensorflow graph

In [10]:
# Start from an empty default graph so re-running this cell does not
# accumulate duplicate ops/variables.
tf.reset_default_graph()

# Model inputs: one row per sequence, MAX_SEQ_LENGTH token positions each.
input_ids = tf.placeholder(tf.int32, shape=(None, MAX_SEQ_LENGTH), name='input_ids')
input_mask = tf.placeholder(tf.int32, shape=(None, MAX_SEQ_LENGTH), name='input_mask')
token_type_ids = tf.placeholder(tf.int32, shape=(None, MAX_SEQ_LENGTH), name='token_type_ids')
# Not used anywhere in this notebook's graph; presumably reserved for
# targets of a downstream model -- confirm.
y_true = tf.placeholder(tf.float32, [None])

# Build the encoder in inference mode (is_training=False).
model = modeling.BertModel(
    config=bert_config,
    is_training=False,
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=token_type_ids,
    use_one_hot_embeddings=USE_ONE_HOT_EMBEDDINGS)

### Add RGN here ###
# e.g. rgn(model.sequence_output)

## Initialize from checkpoint
# Map the graph's trainable variables to the variables stored in CHECKPOINT.
(assignment_map, initialized_variable_names) = (
        modeling.get_assignment_map_from_checkpoint(
                tf.trainable_variables(), CHECKPOINT)
)

# Replaces variable initializers such that they initialize from the checkpoint
# when we call tf.global_variables_initializer()
tf.train.init_from_checkpoint(CHECKPOINT, assignment_map)

### Prep sequence input

In [11]:
# NOTE: this rebinds `seqs` to whatever check_seqs returns; later cells use
# the rebound value. Presumably check_seqs validates the sequences against
# max_seq_length -- confirm against its implementation.
seqs = run_finetuning_and_prediction.check_seqs(
        seqs, max_seq_length=MAX_SEQ_LENGTH) 

# Build the model input features, padded to MAX_SEQ_LENGTH. The run cell
# below reads the 'input_ids', 'input_mask' and 'segment_ids' entries.
# The second positional argument is passed as None (presumably the labels
# slot -- verify).
input_dict = run_finetuning_and_prediction.generate_input_features_from_seq_list(
        seqs, None, TOKENIZER, pad_to=MAX_SEQ_LENGTH)

### Run graph

In [12]:
# Assemble the inputs for all sequences up front; every sequence is fed in a
# single batch.
feed_dict = {
    input_ids: np.array(input_dict['input_ids']),
    input_mask: np.array(input_dict['input_mask']),
    token_type_ids: np.array(input_dict['segment_ids']),
}

session_config = tf.ConfigProto(log_device_placement=False)
with tf.Session(config=session_config) as sess:
    # Running the global initializer loads the checkpoint values for every
    # variable whose initializer was replaced by init_from_checkpoint; any
    # remaining variable gets its regular (random) initialization.
    sess.run(tf.global_variables_initializer())

    # Hidden states of the last encoder layer for every token position.
    seq_out_new = sess.run(model.sequence_output, feed_dict=feed_dict)

In [13]:
# Shape of the raw, unclipped sequence output returned by the session run.
print(seq_out_new.shape)

(20, 1024, 768)


As we can see, the sequence level output (hidden states from last encoder layer) contains hidden state vectors for all 1024 token positions. In order to compare to the results generated above, we'll need to trim to keep only vectors corresponding to residues in our seq.

In [14]:
for i in range(seq_out_new.shape[0]):

    # Keep only the rows belonging to the sequence: skip row 0 (CLS token)
    # and drop everything after the last sequence character.
    so_clip = seq_out_new[i][1:(len(seqs[i])+1),:] # Exclude CLS token
    estimator_out = inf_result['predict']['seq_output'][i]

    # np.allclose broadcasts its inputs, so check the shapes explicitly first.
    assert so_clip.shape == estimator_out.shape, (so_clip.shape, estimator_out.shape)
    assert np.allclose(so_clip, estimator_out)
    # Report the largest *absolute* deviation; the max of the signed
    # difference would hide negative deviations.
    print(np.amax(np.abs(so_clip - estimator_out)))

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0


Looks like our manually exposed sequence outputs match the ones generated when we run the inference through the tf.Estimator API.