<a href="https://colab.research.google.com/github/areias/viral-escape/blob/main/FluBERTa_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Escape prediction validation

We obtained experimentally validated causal escape mutations to HA H1 WSN33 from Doud et al. (1). The validation proceeds as follows:

* Make, in silico, all possible single-residue mutations to H1 WSN33.
* For each mutation, compute semantic change and grammaticality, and combine the two scores with the CSCS rank-based acquisition function.
* Rank all possible mutants by the value of the CSCS acquisition function.
* To assess enrichment of acquired escape mutants, construct a curve with the top n CSCS-acquired mutants on the x-axis and the number of those mutants that are also causal escape mutations on the y-axis. The area under this curve, normalized by the total possible area, is the normalized AUC metric for escape enrichment. It lies between 0 and 1; a value of 0.5 indicates random guessing and higher values indicate greater enrichment.
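The ranking and AUC steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code used later in the notebook: the helper names (`rank`, `cscs_acquisition`, `normalized_auc`) and the toy scores are hypothetical. The normalization follows the description above, dividing the trapezoidal area by (number of mutants) x (number of escape mutants).

```python
def rank(values):
    """1-based ascending ranks with ties averaged (hypothetical helper)."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # extend j over a run of tied values
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def cscs_acquisition(change, prob, beta=1.0):
    """rank(semantic change) + beta * rank(grammaticality); higher is better."""
    return [c + beta * p for c, p in zip(rank(change), rank(prob))]


def normalized_auc(acquisition, is_escape):
    """Area under the 'top n vs. escape mutants found' curve, normalized."""
    order = sorted(range(len(acquisition)), key=lambda i: -acquisition[i])
    hits, curve = 0, []
    for i in order:
        hits += bool(is_escape[i])
        curve.append(hits)
    n = len(curve)
    # trapezoidal rule over x = 1..n
    area = sum((curve[k] + curve[k + 1]) / 2 for k in range(n - 1))
    return area / (n * curve[-1])


# toy scores for four mutants; only the third one is an escape mutant
change = [0.1, 0.5, 0.9, 0.3]
prob = [0.2, 0.1, 0.8, 0.4]
escape = [False, False, True, False]

acq = cscs_acquisition(change, prob)
print(acq)
print(normalized_auc(acq, escape))
```

With this normalization, even a perfect ranking does not reach exactly 1 because the curve needs a few steps to climb.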

## 1. Check GPU and RAM specifications

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Apr  6 09:50:27 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 27.3 gigabytes of available RAM

You are using a high-RAM runtime!


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# install dependencies
! pip install anndata 

Collecting anndata
  Downloading anndata-0.8.0-py3-none-any.whl (96 kB)
Installing collected packages: anndata
Successfully installed anndata-0.8.0


In [5]:
! pip install scanpy

Collecting scanpy
  Downloading scanpy-1.9.1-py3-none-any.whl (2.0 MB)
Collecting matplotlib>=3.4
  Downloading matplotlib-3.5.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (11.2 MB)
Collecting umap-learn>=0.3.10
  Downloading umap-learn-0.5.2.tar.gz (86 kB)
Collecting session-info
  Downloading session_info-1.0.0.tar.gz (24 kB)
Collecting fonttools>=4.22.0
  Downloading fonttools-4.31.2-py3-none-any.whl (899 kB)
Collecting pynndescent>=0.5
  Downloading pynndescent-0.5.6.tar.gz (1.1 MB)
Collecting stdlib_list
  Downloading stdlib_list-0.8.0-py3-none-any.whl (63 kB)
Building wheels for collected package

In [6]:
! pip install bio


Collecting bio
  Downloading bio-1.3.6-py3-none-any.whl (273 kB)

In [7]:
AAs = [
        'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H',
        'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W',
        'Y', 'V', 'X', 'Z', 'J', 'U', 'B'
    ]
    
vocabulary = { aa: idx + 1 for idx, aa in enumerate(sorted(AAs)) }



In [8]:
len(vocabulary)

25

## 3. Load RoBERTa model

In [9]:
! pip install transformers

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.0-py3-none-any.whl (77 kB)
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
Collecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Fo

In [10]:
! pip install tokenizers



In [11]:
# Check that PyTorch sees it
import torch
torch.cuda.is_available()

True

In [12]:
from transformers import PreTrainedTokenizerFast

fast_tokenizer = PreTrainedTokenizerFast(tokenizer_file="drive/MyDrive/FluBERTa/tokenizer/tokenizer-flu.json",
                                         pad_token='<pad>',
                                         bos_token='<s>',
                                         eos_token='</s>',
                                         mask_token='<mask>',
                                         unk_token='<unk>',
                                         max_len=512,
                                         padding='max_length')


In [13]:
# load model
from transformers import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained("drive/MyDrive/FluBERTa/checkpoint-11940")
model.eval()

RobertaForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(10000, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNor

In [14]:
import pandas as pd 
import numpy as np

### Original HA H1 WSN33 sequence

In [15]:
# add mutation to path
import sys
sys.path.append('drive/MyDrive/viral-mutation/bin')

from escape import load_doud2018, load_lee2019

seq_to_mutate, escape_seqs = load_doud2018()


In [16]:
seq_to_mutate

Seq('MKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSH...ICI')

In [17]:
len(seq_to_mutate)

565

In [18]:
inputs = fast_tokenizer(str(seq_to_mutate), return_tensors='pt', truncation=True, max_length=512, padding='max_length')



In [19]:
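with torch.no_grad():
    preds = model(inputs.input_ids, output_hidden_states=True)
    token_logits = preds.logits
    # preds[1] is the tuple of hidden states; index 0 is the embedding-layer output.
    # Mean-pool over the token axis to get one 768-d vector per sequence.
    original_embedding = [x.detach().numpy().mean(axis=0) for x in preds[1][0]]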

In [20]:
original_embedding[0].shape

(768,)
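The `.mean(axis=0)` call above collapses a (tokens x hidden) matrix into a single vector per sequence. A minimal pure-Python sketch of that mean pooling follows; `mean_pool` and the toy vectors are hypothetical names for illustration. As written in the notebook, the sequence is padded to 512 tokens and no attention mask is applied, so padding positions are averaged in as well.

```python
def mean_pool(token_vectors):
    """Average per-token vectors (seq_len x hidden) into one vector (hidden)."""
    n_tokens = len(token_vectors)
    hidden = len(token_vectors[0])
    return [sum(vec[d] for vec in token_vectors) / n_tokens
            for d in range(hidden)]


# three tokens with a toy hidden size of 2 (the model above uses 768)
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled = mean_pool(tokens)
print(pooled)
```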

### Compute mutational probabilities of base sequence

In [41]:
len(seq_to_mutate)

565

In [42]:
# completed in 2:58hrs 
from collections import defaultdict
from scipy.spatial.distance import cityblock

for i in range(len(seq_to_mutate)):

    # initialize empty dicts
    seq_prob = defaultdict(dict)
    seq_change = defaultdict(dict)
 
    # mask sequence
    masked_seq=str(seq_to_mutate)[0:i] + fast_tokenizer.mask_token + str(seq_to_mutate)[i+1:]

    # get predicted probabilities
    inputs = fast_tokenizer(masked_seq, return_tensors='pt', truncation=True, max_length=512, padding='max_length')
    with torch.no_grad():
        token_logits=model(inputs.input_ids).logits

    # limit to probability of masked token
    mask_token_index = torch.where(inputs.input_ids == fast_tokenizer.mask_token_id)[1]
    logits = token_logits[0, mask_token_index, :].squeeze()
    prob = logits.softmax(dim=0)

    # probability of single mutations
    values, indices = prob.topk(k=1000, dim=0)
    for index, token in enumerate(fast_tokenizer.convert_ids_to_tokens(indices)):
        if len(token)==1:
            seq_prob[masked_seq.replace(fast_tokenizer.mask_token, token)] = values[index].item()


    # get embeddings for mutations
    X_batch=list(seq_prob.keys())
    inputs = fast_tokenizer(X_batch, return_tensors='pt', truncation=True, max_length=512, padding='max_length')
    with torch.no_grad():
        outputs=model(inputs.input_ids, output_hidden_states=True)
        sequence_embeddings=[x.detach().numpy().mean(axis=0)  for x in outputs[1][0]]

    # get l1 distance from original embedding 
    for index, embedding in enumerate(sequence_embeddings):
        seq_change[X_batch[index]] = cityblock(original_embedding[0],embedding)

    # save to file
    prob_df = pd.DataFrame.from_dict(seq_prob, orient='index').reset_index()
    prob_df.to_csv("drive/MyDrive/FluBERTa/sequence_probabilities.csv", mode='a', header=False, index=False)

    change_df = pd.DataFrame.from_dict(seq_change, orient='index').reset_index()
    change_df.to_csv("drive/MyDrive/FluBERTa/sequence_change.csv", mode='a', header=False, index=False)

    if i%10==0:
        print(i)



0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560


In [48]:
prob_df = pd.read_csv("drive/MyDrive/FluBERTa/sequence_probabilities.csv", header=None)


In [49]:
len(prob_df)

12361

In [61]:
import matplotlib.pyplot as plt
prob_df.head()

Unnamed: 0,0,1
0,MKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,0.210636
1,CKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,0.001616
2,TKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,0.000836
3,AKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,0.00067
4,KKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,0.00065


In [62]:
prob_df.hist()

array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7faa1ded6990>]],
      dtype=object)

ImportError: ignored

<Figure size 432x288 with 1 Axes>

In [51]:
change_df = pd.read_csv("drive/MyDrive/FluBERTa/sequence_change.csv",header=None)
len(change_df)

12361

In [52]:
change_df.head()

Unnamed: 0,0,1
0,MKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,0.0
1,CKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,4.041759
2,TKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,2.813788
3,AKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,2.920809
4,KKAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHS...,2.830333


In [None]:
masked_seq=str(seq_to_mutate)[0:i] + fast_tokenizer.mask_token + str(seq_to_mutate)[i+1:]
masked_seq

'<mask>KAKLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAINGITNKVNSVIEKMNTQFTAVGKEFNNLEKRMENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDLNVKNLYEKVKSQLKNNAKEIGNGCFEFYHKCDNECMESVRNGTYDYPKYSEESKLNREKIDGVKLESMGVYQILAIYSTVASSLVLLVSLGAISFWMCSNGSLQCRICI'
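The masking step can be illustrated on a toy sequence. This is a sketch with hypothetical helper names (`mask_position`, `single_site_variants`) and a made-up three-letter alphabet; like the loop above, it keeps the wild-type residue among the candidates rather than filtering it out.

```python
MASK = "<mask>"  # same mask token string as the tokenizer configured above


def mask_position(seq, i, mask_token=MASK):
    """Replace the residue at 0-based position i with the mask token."""
    return seq[:i] + mask_token + seq[i + 1:]


def single_site_variants(seq, i, alphabet):
    """Fill the masked position with every residue in `alphabet`."""
    masked = mask_position(seq, i)
    return [masked.replace(MASK, aa) for aa in alphabet]


toy = "MKAK"
masked = mask_position(toy, 0)
variants = single_site_variants(toy, 1, "AKR")
print(masked)
print(variants)
```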

In [None]:
inputs = fast_tokenizer(masked_seq, return_tensors='pt', truncation=True, max_length=512, padding='max_length')
with torch.no_grad():
    token_logits=model(inputs.input_ids).logits



In [None]:
# probability of single letter mutations

In [None]:
mask_token_index = torch.where(inputs.input_ids == fast_tokenizer.mask_token_id)[1]
mask_token_index

tensor([3])

In [None]:
logits = token_logits[0, mask_token_index, :].squeeze()
logits

tensor([ -2.2443, -12.3807,  -9.6227,  ...,  -3.0062,  -2.7548,  -1.2283])

In [None]:
prob = logits.softmax(dim=0)
prob

tensor([4.6908e-06, 1.8581e-10, 2.9298e-09,  ..., 2.1895e-06, 2.8153e-06,
        1.2956e-05])
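For reference, the softmax applied to the logits at the masked position turns arbitrary real scores into probabilities that sum to 1. A small worked example in plain Python, using made-up logits rather than the model's:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of real-valued scores."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.0])
print(probs)
print(sum(probs))
```

Larger logits map to larger probabilities, so the ordering of tokens is preserved.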

In [None]:
values, indices = prob.topk(k=1000, dim=0)

In [None]:
from collections import defaultdict
seq_probs=defaultdict(dict)

for index, token in enumerate(fast_tokenizer.convert_ids_to_tokens(indices)):
    if len(token)==1:
        seq_probs[str.replace(masked_seq,"<mask>", token)] = values[index].item()


In [None]:
seq_probs

defaultdict(dict,
            {'MKAALLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAINGITNKVNSVIEKMNTQFTAVGKEFNNLEKRMENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDLNVKNLYEKVKSQLKNNAKEIGNGCFEFYHKCDNECMESVRNGTYDYPKYSEESKLNREKIDGVKLESMGVYQILAIYSTVASSLVLLVSLGAISFWMCSNGSLQCRICI': 0.02782510779798031,
             'MKABLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWY

In [None]:
X_batch=list(seq_probs.keys())

In [None]:
%%time
inputs = fast_tokenizer(X_batch, return_tensors='pt', truncation=True, max_length=512, padding='max_length')




CPU times: user 10.8 ms, sys: 882 µs, total: 11.7 ms
Wall time: 8.35 ms




In [None]:
with torch.no_grad():
    outputs=model(inputs.input_ids, output_hidden_states=True)
    sequence_embeddings=[x.detach().numpy().mean(axis=0)  for x in outputs[1][0]]


In [None]:
sequence_embeddings[0].shape

(768,)

In [None]:
# l1 distance between original and mutants
# equal to manhattan distance 
from scipy.spatial.distance import cityblock
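The L1 (Manhattan, or city-block) distance is the sum of absolute coordinate differences between two vectors. A minimal stand-alone sketch, with the hypothetical helper `l1_distance` standing in for `cityblock` and toy vectors in place of the 768-d embeddings:

```python
def l1_distance(u, v):
    """Sum of absolute coordinate differences between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    return sum(abs(a - b) for a, b in zip(u, v))


wild_type = [1.0, 2.0, 3.0]
mutant = [2.0, 0.0, 3.5]
print(l1_distance(wild_type, mutant))     # |1-2| + |2-0| + |3-3.5|
print(l1_distance(wild_type, wild_type))  # identical vectors are at distance 0
```

The zero distance for identical vectors is why the wild-type row in `sequence_change.csv` has a semantic change of 0.0.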


In [None]:
l1_norm=defaultdict(dict)
for index, embedding in enumerate(sequence_embeddings):
    l1_norm[X_batch[index]] = cityblock(original_embedding[0], embedding)

In [None]:
l1_norm

defaultdict(dict,
            {'MKAALLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAINGITNKVNSVIEKMNTQFTAVGKEFNNLEKRMENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDLNVKNLYEKVKSQLKNNAKEIGNGCFEFYHKCDNECMESVRNGTYDYPKYSEESKLNREKIDGVKLESMGVYQILAIYSTVASSLVLLVSLGAISFWMCSNGSLQCRICI': 2.7880669,
             'MKABLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGS

In [None]:
seq_probs

defaultdict(dict,
            {'MKAALLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAINGITNKVNSVIEKMNTQFTAVGKEFNNLEKRMENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDLNVKNLYEKVKSQLKNNAKEIGNGCFEFYHKCDNECMESVRNGTYDYPKYSEESKLNREKIDGVKLESMGVYQILAIYSTVASSLVLLVSLGAISFWMCSNGSLQCRICI': 0.02782510779798031,
             'MKABLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWY

In [None]:
# rank by highest probability 
prob_rank={key: rank for rank, key in enumerate(sorted(seq_probs, key=seq_probs.get, reverse=True), 1)}

In [None]:
# rank by lowest l1 distance
l1_rank={key: rank for rank, key in enumerate(sorted(l1_norm, key=l1_norm.get, reverse=False), 1)}

In [None]:
# sanity check: l1_rank was built in ascending order of L1 distance,
# so its first key has rank 1
l1_rank[list(l1_rank.keys())[0]]

1

In [None]:
rank=defaultdict(dict)
for key in list(seq_probs.keys()):
    rank[key]=l1_rank[key] + prob_rank[key]

In [None]:
# top k
k=20
topk=sorted(rank, key=rank.get, reverse=False)[:k]
topk

['MKALLLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAINGITNKVNSVIEKMNTQFTAVGKEFNNLEKRMENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDLNVKNLYEKVKSQLKNNAKEIGNGCFEFYHKCDNECMESVRNGTYDYPKYSEESKLNREKIDGVKLESMGVYQILAIYSTVASSLVLLVSLGAISFWMCSNGSLQCRICI',
 'MKAALLVLLYAFVATDADTICIGYHANNSTDTVDTILEKNVAVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNITGWLLGNPECDSLLPARSWSYIVETPNSENGACYPGDLIDYEELREQLSSVSSLERFEIFPKESSWPNHTFNGVTVSCSHRGKSSFYRNLLWLTKKGDSYPKLTNSYVNNKGKEVLVLWGVHHPSSSDEQQSLYSNGNAYVSVASSNYNRRFTPEIAARPKVRDQHGRMNYYWTLLEPGDTIIFEATGNLIAPWYAFALSRGFESGIITSNASMHECNTKCQTPQGAINSNLPFQNIHPVTIGECPKYVRSTKLRMVTGLRNIPSIQYRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAINGITNKVNSVIEKMNTQFTAVGKEFNNLEKRMENLNKKVD

In [None]:
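for seq in topk:
    # a mutant is "viable" if it appears in the Doud et al. dataset and an
    # "escape" mutant if at least one of its measurements is significant
    is_viable = seq in escape_seqs
    is_escape = is_viable and sum(m['significant'] for m in escape_seqs[seq]) > 0
    print(is_viable, is_escape)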

True False
True False
True False
True False
True False
True False
True False
True False
True False
True False
False False
True False
True False
False False
True False
True False
True False
False False
True False
True False


In [None]:
probs, changes = [], []
with open(cache_fname, 'w') as of:
    fields = [ 'pos', 'wt', 'mut', 'prob', 'change',
                'is_viable', 'is_escape' ]
    of.write('\t'.join(fields) + '\n')
    for seq in seqs:
        prob = seq_prob[seq]
        change = seq_change[seq]
        mut = prob_seqs[seq][0]['word']
        pos = prob_seqs[seq][0]['pos']
        orig = seq_to_mutate[pos]
        is_viable = seq in escape_seqs
        is_escape = ((seq in escape_seqs) and
                        (sum([ m['significant']
                            for m in escape_seqs[seq] ]) > 0))
        fields = [ pos, orig, mut, prob, change, is_viable, is_escape ]
        of.write('\t'.join([ str(field) for field in fields ]) + '\n')
        probs.append(prob)
        changes.append(change)


In [None]:
beta=1.
if plot_acquisition:
    from cached_semantics import cached_escape



In [None]:
cache_fname

'target/flu/semantics/cache/analyze_semantics_flu_bilstm_512.txt'

In [None]:
if plot_acquisition:
    from cached_semantics import cached_escape
    cached_escape(cache_fname, beta,
                plot=plot_acquisition,
                namespace=plot_namespace)

In [None]:
# from cached_semantics.py
prob_cutoff=0.
beta=1.
plot_acquisition=True
plot_namespace='flu_h1'
cutoff=None

from escape import load_doud2018
if cutoff is None:
    wt_seq, seqs_escape = load_doud2018()


In [None]:
prob, change, escape_idx, viable_idx = [], [], [], []
with open(cache_fname) as f:
    f.readline()
    for line in f:
        fields = line.rstrip().split('\t')
        pos = int(fields[0])
        if 'rbd' in cache_fname:
            if pos < 330 or pos > 530:
                continue
        if fields[2] in { 'U', 'B', 'J', 'X', 'Z' }:
            continue
        aa_wt = fields[1]
        aa_mut = fields[2]
        assert(wt_seq[pos] == aa_wt)
        mut_seq = wt_seq[:pos] + aa_mut + wt_seq[pos+1:]
        if mut_seq not in seqs_escape:
            continue
        prob.append(float(fields[3]))
        change.append(float(fields[4]))
        viable_idx.append(fields[5] == 'True')
        escape_idx.append(
            (mut_seq in seqs_escape) and
            (sum([ m['significant']
                    for m in seqs_escape[mut_seq] ]) > 0)
        )
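To make the cache format concrete, the sketch below parses one tab-separated record and rebuilds the mutant sequence, mirroring the loop above. The function name `parse_record`, the toy wild-type sequence, and the example line are hypothetical; only the field order (`pos`, `wt`, `mut`, `prob`, `change`, `is_viable`, `is_escape`) comes from the notebook.

```python
def parse_record(line, wt_seq):
    """Parse one TSV cache line and rebuild the mutant sequence."""
    fields = line.rstrip().split('\t')
    pos, aa_wt, aa_mut = int(fields[0]), fields[1], fields[2]
    # the recorded wild-type residue must match the reference sequence
    if wt_seq[pos] != aa_wt:
        raise ValueError('wild-type mismatch at position {}'.format(pos))
    mut_seq = wt_seq[:pos] + aa_mut + wt_seq[pos + 1:]
    return {
        'mut_seq': mut_seq,
        'prob': float(fields[3]),
        'change': float(fields[4]),
        'is_viable': fields[5] == 'True',
    }


record = parse_record('1\tK\tR\t0.02\t3.5\tTrue\tFalse\n', 'MKAK')
print(record)
```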

In [None]:
import scipy.stats as ss
from sklearn.metrics import auc
from cached_semantics import compute_p

prob, orig_prob = np.array(prob), np.array(prob)
change, orig_change  = np.array(change), np.array(change)
escape_idx = np.array(escape_idx)
viable_idx = np.array(viable_idx)

acquisition = ss.rankdata(change) + (beta * ss.rankdata(prob))

pos_change_idx = change > 0

pos_change_escape_idx = np.logical_and(pos_change_idx, escape_idx)
escape_prob = prob[pos_change_escape_idx]
escape_change = change[pos_change_escape_idx]
prob = prob[pos_change_idx]
change = change[pos_change_idx]

log_prob, log_change = np.log10(prob), np.log10(change)
log_escape_prob, log_escape_change = (np.log10(escape_prob),
                                        np.log10(escape_change))

import os

plot=True
namespace='flu_h1'
if plot:
    # create the output directory if it does not exist yet
    os.makedirs('figures', exist_ok=True)

    plt.figure()
    plt.scatter(log_prob, log_change, c=acquisition[pos_change_idx],
                cmap='viridis', alpha=0.3)
    plt.scatter(log_escape_prob, log_escape_change, c='red',
                alpha=0.5, marker='x')
    plt.xlabel(r'$ \log_{10}(\hat{p}(x_i | \mathbf{x}_{[N] ∖ \{i\} })) $')
    plt.ylabel(r'$ \log_{10}(\Delta \mathbf{\hat{z}}) $')
    plt.savefig('figures/{}_acquisition.png'
                .format(namespace), dpi=300)
    plt.close()

    rand_idx = np.random.choice(len(prob), len(escape_prob))
    plt.figure()
    plt.scatter(log_prob, log_change, c=acquisition[pos_change_idx],
                cmap='viridis', alpha=0.3)
    plt.scatter(log_prob[rand_idx], log_change[rand_idx], c='red',
                alpha=0.5, marker='x')
    plt.xlabel(r'$ \log_{10}(\hat{p}(x_i | \mathbf{x}_{[N] ∖ \{i\} })) $')
    plt.ylabel(r'$ \log_{10}(\Delta \mathbf{\hat{z}}) $')
    plt.savefig('figures/{}_acquisition_rand.png'
                .format(namespace), dpi=300)
    plt.close()

if len(escape_prob) == 0:
    print('No escape mutations found.')
    #return

acq_argsort = ss.rankdata(-acquisition)
escape_rank_dist = acq_argsort[escape_idx]

size = len(prob)
print('Number of escape seqs: {} / {}'
        .format(len(escape_rank_dist), sum(escape_idx)))
print('Mean rank: {} / {}'.format(np.mean(escape_rank_dist), size))
print('Median rank: {} / {}'.format(np.median(escape_rank_dist), size))
print('Min rank: {} / {}'.format(np.min(escape_rank_dist), size))
print('Max rank: {} / {}'.format(np.max(escape_rank_dist), size))
print('Rank stdev: {} / {}'.format(np.std(escape_rank_dist), size))

max_consider = len(prob)
n_consider = np.array([ i + 1 for i in range(max_consider) ])

n_escape = np.array([ sum(escape_rank_dist <= i + 1)
                        for i in range(max_consider) ])
norm = max(n_consider) * max(n_escape)
norm_auc = auc(n_consider, n_escape) / norm

escape_rank_prob = ss.rankdata(-orig_prob)[escape_idx]
n_escape_prob = np.array([ sum(escape_rank_prob <= i + 1)
                            for i in range(max_consider) ])
norm_auc_prob = auc(n_consider, n_escape_prob) / norm

escape_rank_change = ss.rankdata(-orig_change)[escape_idx]
n_escape_change = np.array([ sum(escape_rank_change <= i + 1)
                                for i in range(max_consider) ])
norm_auc_change = auc(n_consider, n_escape_change) / norm

if plot:
    plt.figure()
    plt.plot(n_consider, n_escape)
    plt.plot(n_consider, n_escape_change, c='C0', linestyle='-.')
    plt.plot(n_consider, n_escape_prob, c='C0', linestyle=':')
    plt.plot(n_consider, n_consider * (len(escape_prob) / len(prob)),
                c='gray', linestyle='--')

    plt.xlabel(r'$ \log_{10}() $')
    plt.ylabel(r'$ \log_{10}(\Delta \mathbf{\hat{z}}) $')

    plt.legend([
        r'$ \Delta \mathbf{\hat{z}} + ' +
        r'\beta \hat{p}(x_i | \mathbf{x}_{[N] ∖ \{i\} }) $,' +
        (' AUC = {:.3f}'.format(norm_auc)),
        r'$  \Delta \mathbf{\hat{z}} $ only,' +
        (' AUC = {:.3f}'.format(norm_auc_change)),
        r'$ \hat{p}(x_i | \mathbf{x}_{[N] ∖ \{i\} }) $ only,' +
        (' AUC = {:.3f}'.format(norm_auc_prob)),
        'Random guessing, AUC = 0.500'
    ])
    plt.xlabel('Top N')
    plt.ylabel('Number of escape mutations in top N')
    plt.savefig('figures/{}_consider_escape.png'
                .format(namespace), dpi=300)
    plt.close()


print('Escape semantics, beta = {} [{}]'
        .format(beta, namespace))

norm_auc_p = compute_p(norm_auc, sum(escape_idx), len(escape_idx))

print('AUC (CSCS): {}, P = {}'.format(norm_auc, norm_auc_p))
print('AUC (semantic change only): {}'.format(norm_auc_change))
print('AUC (grammaticality only): {}'.format(norm_auc_prob))

print('{:.4g} (mean log prob), {:.4g} (mean log prob escape), '
        '{:.4g} (p-value)'
        .format(log_prob.mean(), log_escape_prob.mean(),
                ss.mannwhitneyu(log_prob, log_escape_prob,
                                alternative='two-sided')[1]))
print('{:.4g} (mean log change), {:.4g} (mean log change escape), '
        '{:.4g} (p-value)'
        .format(change.mean(), escape_change.mean(),
                ss.mannwhitneyu(change, escape_change,
                                alternative='two-sided')[1]))


Number of escape seqs: 170 / 170
Mean rank: 3281.8382352941176 / 10735
Median rank: 2551.75 / 10735
Min rank: 9.0 / 10735
Max rank: 10421.0 / 10735
Rank stdev: 2796.440226958164 / 10735
Escape semantics, beta = 1.0 [flu_h1]
AUC (CSCS): 0.6943165566179895, P = 0.0
AUC (semantic change only): 0.5365467547056084
AUC (grammaticality only): 0.7185742075125345
-4.292 (mean log prob), -3.139 (mean log prob escape), 1.189e-22 (p-value)
2026 (mean log change), 2228 (mean log change escape), 0.1015 (p-value)
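`compute_p` is imported from `cached_semantics` and its implementation is not shown in this notebook. One common way to attach a p-value to a normalized AUC is a permutation test: place the escape mutants at random ranks many times and count how often the random AUC is at least as large as the observed one. The sketch below assumes that approach; the function names, the number of permutations, and the toy sizes are all hypothetical.

```python
import random


def normalized_auc_from_flags(flags):
    """flags[i] is True if the mutant ranked i-th (best first) is an escape."""
    hits, curve = 0, []
    for flag in flags:
        hits += bool(flag)
        curve.append(hits)
    n = len(curve)
    # trapezoidal rule over x = 1..n
    area = sum((curve[k] + curve[k + 1]) / 2 for k in range(n - 1))
    return area / (n * curve[-1])


def permutation_p(observed_auc, n_escape, n_total, n_permutations=1000, seed=0):
    """Fraction of random placements with AUC >= the observed AUC."""
    rng = random.Random(seed)
    at_least = 0
    for _ in range(n_permutations):
        chosen = set(rng.sample(range(n_total), n_escape))
        flags = [i in chosen for i in range(n_total)]
        if normalized_auc_from_flags(flags) >= observed_auc:
            at_least += 1
    return at_least / n_permutations


# toy case: all 5 escape mutants are ranked ahead of 45 non-escape mutants
perfect = [True] * 5 + [False] * 45
observed = normalized_auc_from_flags(perfect)
p_value = permutation_p(observed, n_escape=5, n_total=50)
print(observed, p_value)
```

Under this kind of test, a reported value of exactly 0 would mean that none of the sampled permutations reached the observed AUC, so it is bounded below by the resolution of the permutation count rather than being literally zero.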


#### Flu escape prediction graphs

* X-axis is grammaticality
* Y-axis is semantic change

Red x's mark escape mutants; they are mostly located in the quadrant with high semantic change and high grammaticality (upper right).

![](https://raw.githubusercontent.com/areias/viral-escape/main/figures/flu_h1_acquisition.png) 


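The figure below suggests the model captures the grammaticality of a sequence (the model's output probability) better than its semantic representation: ranking by grammaticality alone gives a normalized AUC of 0.719, compared with 0.537 for semantic change alone and 0.694 for the combined CSCS score.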

> Internally, the language model constructs a semantic representation, or an “embedding,” for a given sequence (6), and the output of a language model encodes how well a particular token fits within the rules of the language, which we call “grammaticality”



![](https://raw.githubusercontent.com/areias/viral-escape/main/figures/flu_h1_consider_escape.png) 