In [1]:
%load_ext autoreload
%autoreload 2
%cd "python/LatentEvolution"
%ls

/home/hew/python/LatentEvolution
[0m[01;34mcache[0m/  [01;34mdata[0m/  [01;34mfigure[0m/  [01;34mframework[0m/  main.py  [01;34mscript[0m/  [01;34mtemp[0m/


In [2]:
import numpy as np
import pandas as pd
import torch
from scipy.stats import pearsonr, spearmanr

from framework.utils.lightning.trainer_utils import LitInference
from script.task_02_ProteinVAE.ProteinVAE.sequence_data_module import SequenceDataModule
from script.task_02_ProteinVAE.ProteinVAE.sequence_lightning_module import SequenceLightningModule

root_path: /home/hew/python/LatentEvolution
framework_path: /home/hew/python/LatentEvolution/framework


## load train data

In [3]:
version = 2
diffusion_model_train_data = f'./script/task_04_LatentRegression/version_{version}_train_data.pt'
train_data = torch.load(diffusion_model_train_data)
embeddings = train_data['z']
ddG = train_data['ddG']
dS = train_data['dS']

In [4]:
min(ddG), max(dS)

(-5.5497, 0.078)

## model predict

In [5]:
log_dir = "./script/task_02_ProteinVAE/lightning_logs/"
version = 2
epoch = '95, loss=35.447, ce=18.697, reg=10.702, mse=6.049, ddG=0.537, dS=0.861.ckpt'
inferencer = LitInference(SequenceLightningModule, SequenceDataModule, log_dir, version, epoch)
tokenizer = inferencer.pl_data_module.tokenizer
device = torch.device('cuda:0')
model = inferencer.model
model = model.eval()
model

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42




ProteinVAE(
  (encoder_transformer): ESMTransformer(
    (embed_tokens): Embedding(33, 128, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=128, out_features=128, bias=True)
          (v_proj): Linear(in_features=128, out_features=128, bias=True)
          (q_proj): Linear(in_features=128, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (final_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=128, out_features=128, bias=True)

In [6]:
# input latent vectors, output generated sequences
def decode_z(z):
    logits = model.decode(z)[:, 1:-1]
    tokens = torch.argmax(logits, dim=-1)
    sequences = tokenizer.decode(tokens)
    return sequences


# input amino acid sequences, output attributes
def model_forward(seqs):
    inferencer.set_batch_size(128)
    predictions = inferencer.predict(seqs)
    true_tokens = torch.concat([batch['tokens'] for batch in predictions])
    recon_tokens = torch.concat([batch['recon_tokens'] for batch in predictions])
    predict_z = torch.concat([batch['z'] for batch in predictions])
    predict_mean = torch.concat([batch['mean'] for batch in predictions])
    predict_logvar = torch.concat([batch['logvar'] for batch in predictions])
    predict_ddG = torch.concat([batch['pred_ddG'] for batch in predictions])
    predict_dS = torch.concat([batch['pred_dS'] for batch in predictions])
    true_ddG = torch.concat([batch['ddG'] for batch in predictions])
    true_dS = torch.concat([batch['dS'] for batch in predictions])
    return predict_ddG, predict_dS, true_ddG, true_dS

In [7]:
z = embeddings[:2].to(device)
gen_seqs = decode_z(z)
gen_seqs

100%|██████████| 2/2 [00:00<00:00, 790.41it/s]


['STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWLAFLKEQSTLAQMYPLEEIQNLTVKLQLQALQ',
 'STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLDEIQNLTVKLQLQALQ']

In [8]:
seqs = [
    'STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDQWSAFLKEQSTIAQMYPLQEIKNLTVKLVLQALQ',
    'STEEEQAKTFLDKFNHEAEDYFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLRALL',
    'STIEEQAKIFLDKFNHEAEDLFYQSSLRSFNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLQALQ'
]
predict_ddG, predict_dS, true_ddG, true_dS = model_forward(seqs)
predict_ddG, predict_dS, true_ddG, true_dS

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


prepare_data predict_data kwargs {'predict_data': ['STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDQWSAFLKEQSTIAQMYPLQEIKNLTVKLVLQALQ', 'STEEEQAKTFLDKFNHEAEDYFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLRALL', 'STIEEQAKIFLDKFNHEAEDLFYQSSLRSFNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLQALQ']}
len(self.predict_index) 3
[len self.predict_dataset] 3


Predicting: 0it [00:00, ?it/s]

(tensor([0.9730, 0.9791, 0.8270]),
 tensor([-0.9108, -0.8529, -0.9006]),
 tensor([nan, nan, nan]),
 tensor([nan, nan, nan]))

In [9]:
df = pd.read_csv('./data/ACE2_variants_2k/cooked/ACE2_variants_2k.csv')
df

Unnamed: 0,index,name,partition,length,sequence,structure,graph,ddG,dS
0,0,0,,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,-1.0838,0.017
1,1,1,,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,-0.0154,0.017
2,10,10,,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,-0.8987,0.017
3,100,100,,83,STIEEQAKTFLDKFNHDAEDLFYQSFLASWNYNTNITEENVQNMNN...,,,-1.1936,0.017
4,1000,1000,,83,SDIEEQAKTFLDKFNHEAEDLFYQSSLAYWNYNTNITEENVQNMGN...,,,-2.5357,0.017
...,...,...,...,...,...,...,...,...,...
2399,995,995,,83,STIEEQAKTFLDKFNHEAEDLFYQSDLARWNYNTNITEENVQNMNN...,,,-0.4275,0.017
2400,996,996,,83,STIEEQAKTFLDKFNHEAEDLFYQSSLASWWYNTNITEENVQNMNN...,,,-1.6738,0.017
2401,997,997,,83,STIEEQAKTFLDKFNHEAEDLFYQMSLASWNYNTNITEENVQNMNN...,,,-1.9267,0.017
2402,998,998,,83,SDIEEQAKMFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,,,-1.3416,0.017


In [10]:
seqs = df.sequence.tolist()
true_ddG = df.ddG.tolist()
true_dS = df.dS.tolist()
predict_ddG, predict_dS, _, _ = model_forward(seqs)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


prepare_data predict_data kwargs {'predict_data': ['STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWLAFLKEQSTLAQMYPLEEIQNLTVKLQLQALQ', 'STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLDEIQNLTVKLQLQALQ', 'STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAIDVWEAFLKEQSTLAQMYPLQEIQNLTVKLQLWALQ', 'STIEEQAKTFLDKFNHDAEDLFYQSFLASWNYNTNITEENVQNMNNARDKWSAFLKEQSTLAQMYPLQEIDNLTVKLQLQALQ', 'SDIEEQAKTFLDKFNHEAEDLFYQSSLAYWNYNTNITEENVQNMGNAGDKWSAFLKEQSTLAQMYPLQEIQNLTPKLQLQALQ', 'STIEEQAKTFMDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWEAFLKEQSTLAQMYPLQEIQNLTVKLQLQALQ', 'STIEEQAETFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLWEMSTLAQMYPLQEIQNLTVKLALQALQ', 'FTIEEQAKTFLDKFNHEAEDLFYQSSLAEWNYNTNITEENVQNMNNAGDKWSAFLIEQSTLAQMYPLQEIQNLTVKLQLQALQ', 'SEIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNAAGDKWSAFLKEQSTLAQMYPLQEIQNLPVKLQLQALQ', 'STIEEQAKTFLDMFNEEAEDLFYQSSLAIWNYNTNITEENVQNMNNAGDKWSAFLKEISTYAQMYPLQEIQNLTVKLQLQALQ', 'STIEEQAKTFLDKFNHEAEDLFYQSSLARWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQ

Predicting: 0it [00:00, ?it/s]

In [11]:
predict_ddG

tensor([0.7948, 0.9767, 0.7347,  ..., 0.6900, 0.5558, 0.4590])

In [12]:
pearsonr(predict_ddG, true_ddG)[0], pearsonr(predict_dS, true_dS)[0]

(0.8245244368982572, 0.8985297411531991)

In [13]:
spearmanr(predict_ddG, true_ddG)[0], spearmanr(predict_dS, true_dS)[0]

(0.7472712859713769, 0.5157004220257524)

In [14]:
sorted_true_ddG_indices = np.argsort(true_ddG)
sorted_true_ddG = np.array(true_ddG)[sorted_true_ddG_indices]
corresponding_predict_ddG = np.array(predict_ddG)[sorted_true_ddG_indices]
np.concatenate([sorted_true_ddG.reshape(-1, 1), corresponding_predict_ddG.reshape(-1, 1)], axis=1)

array([[-5.54970000e+00, -8.95914555e-01],
       [-5.37360000e+00, -7.41741717e-01],
       [-4.86450000e+00, -1.74640000e-01],
       ...,
       [-4.00000000e-03,  9.90709305e-01],
       [-2.60000000e-03,  9.69277203e-01],
       [-7.00000000e-04,  9.33099389e-01]])

In [15]:
sorted_true_dS_indices = np.argsort(true_dS)
sorted_true_dS = np.array(true_dS)[sorted_true_dS_indices]
corresponding_predict_dS = np.array(predict_dS)[sorted_true_dS_indices]
np.concatenate([sorted_true_dS.reshape(-1, 1), corresponding_predict_dS.reshape(-1, 1)], axis=1)

array([[ 0.005     , -0.97757703],
       [ 0.005     , -0.94977134],
       [ 0.005     , -0.97581601],
       ...,
       [ 0.078     ,  0.93465209],
       [ 0.078     ,  0.93038553],
       [ 0.078     ,  0.83474034]])

## train

## generation

min ddG: -6.0, lower
max dS: 0.086, higher


1. train diffusion model
2. modify ddG, dS, gen z
3. model.decode(z) -> GenSet (100000)
4. model(GenSet) -> pred_ddG, pred_dS
5. rank GenSet by pred_ddG, pred_dS
6. select top 1000
7. evaluate using FoldX. Protein-Sol