# SARS-CoV2 - ACE2

## mutSpike-ACE2: 

* 6M0J: WT
* 7EKF: Alpha(N501Y)
* 7EKG: Beta(K417N+E484K+N501Y)
* 7WBQ: Delta(L452R+T478K)
* ETA: Eta(E484K)

* 7WBP: Omicron 


### clean PDBs

In [26]:
import os
from Bio.PDB import PDBParser, PDBIO, PPBuilder
from Bio.PDB import Structure, Model, Chain, Residue, Atom

original_pdb_path = "./datasets/mutSpike_ACE2/PDBs/"
cleaned_pdb_path = "./datasets/mutSpike_ACE2/cleaned_PDBs/"
os.makedirs(cleaned_pdb_path, exist_ok=True)

pdb_name_list = ["6m0j", "7ekf", "7ekg", "7wbq", "eta"]
for pdb_name in pdb_name_list:
    pdb_filepath = f"{original_pdb_path}/{pdb_name}.pdb"

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_filepath)
    new_structure = Structure.Structure(pdb_name)

    model = structure[0]
    new_model = Model.Model(model.id)

    for chain in model:
        print(chain)
        new_chain = Chain.Chain(chain.id)

        chain = [res for res in chain.get_residues() if res.id[0] == ' ']  # exclude HOH...
        for residue in chain:
            # print(residue)
            if set(['N', 'CA', 'C', 'O']).issubset(set(residue.child_dict.keys())):
                new_chain.add(residue)
            else:
                if not (residue.id[1] == 1 or residue.id[1] == len(chain)):
                    print(f"WARNING: delete in the middle {residue.full_id} of total len {len(chain)}")

        if len(new_chain) > 0:
            new_model.add(new_chain)
    if len(new_model) > 0:
        new_structure.add(new_model)

    io = PDBIO()
    io.set_structure(new_structure)
    io.save(f"{cleaned_pdb_path}/{pdb_name.upper()}.pdb")


<Chain id=A>
<Chain id=E>


### clean csv

In [27]:
import pandas as pd

df = pd.read_csv("./datasets/mutSpike_ACE2/mutSpike-ACE2.csv")
df
# 201 * 20 * 5 = 20100

Unnamed: 0,target,wildtype,position,mutant,mutation,bind,delta_bind,n_bc_bind,n_libs_bind,bind_rep1,bind_rep2,bind_rep3,expr,delta_expr,n_bc_expr,n_libs_expr,expr_rep1,expr_rep2
0,Beta,N,331,A,N331A,9.41007,0.10297,6,3,9.31807,9.49961,9.41253,9.72294,-0.27844,3,2,9.63034,9.81554
1,Beta,N,331,C,N331C,9.11229,-0.19481,27,3,9.06925,9.09584,9.17180,9.36532,-0.63606,17,2,9.18919,9.54145
2,Beta,N,331,D,N331D,9.31717,0.01007,23,3,9.18921,9.38256,9.37976,9.77842,-0.22296,15,2,9.70504,9.85180
3,Beta,N,331,E,N331E,9.36899,0.06189,17,3,9.36041,9.42835,9.31821,10.06567,0.06430,10,2,9.93181,10.19954
4,Beta,N,331,F,N331F,9.17224,-0.13487,22,3,9.07594,9.17972,9.26104,9.42578,-0.57560,14,2,9.30185,9.54970
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20095,Delta,T,531,S,T531S,9.03510,-0.00015,24,2,9.00314,9.06706,,9.81783,0.02757,25,2,9.81554,9.82012
20096,Delta,T,531,T,T531T,9.03525,0.00000,36189,2,9.00711,9.06338,,9.79026,0.00000,37693,2,9.75058,9.82995
20097,Delta,T,531,V,T531V,9.07490,0.03965,27,2,9.07401,9.07580,,9.70583,-0.08443,27,2,9.61395,9.79772
20098,Delta,T,531,W,T531W,9.06536,0.03012,24,2,9.01741,9.11332,,9.77484,-0.01542,26,2,9.84511,9.70457


In [28]:

df_wt = df[df["target"] == "Wuhan-Hu-1"]
df_alpha = df[df["target"] == "N501Y"]
df_beta = df[df["target"] == "Beta"]
df_delta = df[df["target"] == "Delta"]
df_eta = df[df["target"] == "E484K"]

In [29]:
df_wt = df_wt[["target", "wildtype", "position", "mutant", "delta_bind"]]
df_alpha = df_alpha[["target", "wildtype", "position", "mutant", "delta_bind"]]
df_beta = df_beta[["target", "wildtype", "position", "mutant", "delta_bind"]]
df_delta = df_delta[["target", "wildtype", "position", "mutant", "delta_bind"]]
df_eta = df_eta[["target", "wildtype", "position", "mutant", "delta_bind"]]

df_alpha["target"] = "Alpha"
df_eta["target"] = "Eta"
df_alpha

Unnamed: 0,target,wildtype,position,mutant,delta_bind
8040,Alpha,N,331,A,
8041,Alpha,N,331,C,-0.36219
8042,Alpha,N,331,D,-0.08671
8043,Alpha,N,331,E,-0.02208
8044,Alpha,N,331,F,-0.36872
...,...,...,...,...,...
12055,Alpha,T,531,S,
12056,Alpha,T,531,T,0.00000
12057,Alpha,T,531,V,0.53498
12058,Alpha,T,531,W,-0.02847


In [31]:
df_wt["pdb"] = "6M0J"
df_alpha["pdb"] = "7EKF"
df_beta["pdb"] = "7EKG"
df_delta["pdb"] = "7WBQ"
df_eta["pdb"] = "ETA"
df_wt

Unnamed: 0,target,wildtype,position,mutant,delta_bind,pdb
12060,Wuhan-Hu-1,N,331,A,0.06027,6M0J
12061,Wuhan-Hu-1,N,331,C,-0.15567,6M0J
12062,Wuhan-Hu-1,N,331,D,-0.01751,6M0J
12063,Wuhan-Hu-1,N,331,E,0.15400,6M0J
12064,Wuhan-Hu-1,N,331,F,-0.11470,6M0J
...,...,...,...,...,...,...
16075,Wuhan-Hu-1,T,531,S,,6M0J
16076,Wuhan-Hu-1,T,531,T,0.00000,6M0J
16077,Wuhan-Hu-1,T,531,V,-0.01541,6M0J
16078,Wuhan-Hu-1,T,531,W,-0.00779,6M0J


In [32]:
# def insert_chain_id(mutation, chain_id):
#     return mutation[:1] + chain_id + mutation[1:]

# df_wt["mutation"] = df_wt["mutation"].apply(lambda x: insert_chain_id(x, "E"))
# df_alpha["mutation"] = df_alpha["mutation"].apply(lambda x: insert_chain_id(x, "B"))
# df_beta["mutation"] = df_beta["mutation"].apply(lambda x: insert_chain_id(x, "B"))
# df_delta["mutation"] = df_delta["mutation"].apply(lambda x: insert_chain_id(x, "B"))
# df_wt

df_wt["mutation"] = df_wt["wildtype"] + "E" + df_wt["position"].astype(str) + df_wt["mutant"]
df_alpha["mutation"] = df_alpha["wildtype"] + "B" + df_alpha["position"].astype(str) + df_alpha["mutant"]
df_beta["mutation"] = df_beta["wildtype"] + "B" + df_beta["position"].astype(str) + df_beta["mutant"]
df_delta["mutation"] = df_delta["wildtype"] + "B" + df_delta["position"].astype(str) + df_delta["mutant"]
df_eta["mutation"] = df_eta["wildtype"] + "E" + df_eta["position"].astype(str) + df_eta["mutant"]
df_eta

Unnamed: 0,target,wildtype,position,mutant,delta_bind,pdb,mutation
4020,Eta,N,331,A,-0.05404,ETA,NE331A
4021,Eta,N,331,C,-0.50425,ETA,NE331C
4022,Eta,N,331,D,-0.07810,ETA,NE331D
4023,Eta,N,331,E,0.06027,ETA,NE331E
4024,Eta,N,331,F,-0.23645,ETA,NE331F
...,...,...,...,...,...,...,...
8035,Eta,T,531,S,0.06807,ETA,TE531S
8036,Eta,T,531,T,0.00000,ETA,TE531T
8037,Eta,T,531,V,,ETA,TE531V
8038,Eta,T,531,W,0.00958,ETA,TE531W


In [33]:
info = pd.concat([df_wt, df_alpha, df_beta, df_delta, df_eta])
info


Unnamed: 0,target,wildtype,position,mutant,delta_bind,pdb,mutation
12060,Wuhan-Hu-1,N,331,A,0.06027,6M0J,NE331A
12061,Wuhan-Hu-1,N,331,C,-0.15567,6M0J,NE331C
12062,Wuhan-Hu-1,N,331,D,-0.01751,6M0J,NE331D
12063,Wuhan-Hu-1,N,331,E,0.15400,6M0J,NE331E
12064,Wuhan-Hu-1,N,331,F,-0.11470,6M0J,NE331F
...,...,...,...,...,...,...,...
8035,Eta,T,531,S,0.06807,ETA,TE531S
8036,Eta,T,531,T,0.00000,ETA,TE531T
8037,Eta,T,531,V,,ETA,TE531V
8038,Eta,T,531,W,0.00958,ETA,TE531W


In [34]:

info = info[(info["position"] >= 333) & (info["position"] <= 526)]
info

Unnamed: 0,target,wildtype,position,mutant,delta_bind,pdb,mutation
12100,Wuhan-Hu-1,T,333,A,-0.02454,6M0J,TE333A
12101,Wuhan-Hu-1,T,333,C,-0.33095,6M0J,TE333C
12102,Wuhan-Hu-1,T,333,D,0.05865,6M0J,TE333D
12103,Wuhan-Hu-1,T,333,E,0.05119,6M0J,TE333E
12104,Wuhan-Hu-1,T,333,F,-0.37864,6M0J,TE333F
...,...,...,...,...,...,...,...
7935,Eta,G,526,S,-0.23774,ETA,GE526S
7936,Eta,G,526,T,-0.15366,ETA,GE526T
7937,Eta,G,526,V,-0.20412,ETA,GE526V
7938,Eta,G,526,W,-0.28501,ETA,GE526W


In [35]:
import numpy as np

# compute ddg
R = float(8.314/4184)
T = 300

info["ddg"] = R * T * (np.log(10) * -info["delta_bind"])
info

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  info["ddg"] = R * T * (np.log(10) * -info["delta_bind"])


Unnamed: 0,target,wildtype,position,mutant,delta_bind,pdb,mutation,ddg
12100,Wuhan-Hu-1,T,333,A,-0.02454,6M0J,TE333A,0.033684
12101,Wuhan-Hu-1,T,333,C,-0.33095,6M0J,TE333C,0.454274
12102,Wuhan-Hu-1,T,333,D,0.05865,6M0J,TE333D,-0.080505
12103,Wuhan-Hu-1,T,333,E,0.05119,6M0J,TE333E,-0.070265
12104,Wuhan-Hu-1,T,333,F,-0.37864,6M0J,TE333F,0.519735
...,...,...,...,...,...,...,...,...
7935,Eta,G,526,S,-0.23774,ETA,GE526S,0.326330
7936,Eta,G,526,T,-0.15366,ETA,GE526T,0.210919
7937,Eta,G,526,V,-0.20412,ETA,GE526V,0.280182
7938,Eta,G,526,W,-0.28501,ETA,GE526W,0.391215


In [37]:
col_order = ["target", "pdb", "mutation", "delta_bind", "ddg"]
info = info[col_order]

info.dropna(inplace=True)
info.reset_index(drop=True, inplace=True)
info

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  info.dropna(inplace=True)


Unnamed: 0,target,pdb,mutation,delta_bind,ddg
0,Wuhan-Hu-1,6M0J,TE333A,-0.02454,0.033684
1,Wuhan-Hu-1,6M0J,TE333C,-0.33095,0.454274
2,Wuhan-Hu-1,6M0J,TE333D,0.05865,-0.080505
3,Wuhan-Hu-1,6M0J,TE333E,0.05119,-0.070265
4,Wuhan-Hu-1,6M0J,TE333F,-0.37864,0.519735
...,...,...,...,...,...
19217,Eta,ETA,GE526S,-0.23774,0.326330
19218,Eta,ETA,GE526T,-0.15366,0.210919
19219,Eta,ETA,GE526V,-0.20412,0.280182
19220,Eta,ETA,GE526W,-0.28501,0.391215


In [38]:

info_wt = info[info["target"] == "Wuhan-Hu-1"]
info_alpha = info[info["target"] == "Alpha"]
info_beta = info[info["target"] == "Beta"]
info_delta = info[info["target"] == "Delta"]
info_eta = info[info["target"] == "Eta"]

info_wt.reset_index(drop=True, inplace=True)
info_alpha.reset_index(drop=True, inplace=True)
info_beta.reset_index(drop=True, inplace=True)
info_delta.reset_index(drop=True, inplace=True)
info_eta.reset_index(drop=True, inplace=True)
info_eta


Unnamed: 0,target,pdb,mutation,delta_bind,ddg
0,Eta,ETA,TE333A,-0.13683,0.187818
1,Eta,ETA,TE333C,-0.27489,0.377324
2,Eta,ETA,TE333E,-0.04957,0.068042
3,Eta,ETA,TE333F,0.16668,-0.228791
4,Eta,ETA,TE333G,0.04187,-0.057472
...,...,...,...,...,...
3804,Eta,ETA,GE526S,-0.23774,0.326330
3805,Eta,ETA,GE526T,-0.15366,0.210919
3806,Eta,ETA,GE526V,-0.20412,0.280182
3807,Eta,ETA,GE526W,-0.28501,0.391215


### EvoEF2 generate mutant .pdb

In [43]:
import os
from tqdm import tqdm
import subprocess

EvoEF2_toolpath = "./tools/EvoEF2/EvoEF2"
cleaned_pdb_path = "./datasets/mutSpike_ACE2/cleaned_PDBs/"
mut_pdb_path = "./datasets/mutSpike_ACE2/mut_PDBs/"
os.makedirs(mut_pdb_path, exist_ok=True)

for i in tqdm(range(len(info_eta))):
    wt = info_eta.loc[i, 'pdb']
    mut_site = info_eta.loc[i, 'mutation']
    name = wt.split('_')[0]
    mut_name = f"{wt}_{mut_site}"
    # print(f"{name} -> {mut_name}")

    with open(f"{mut_pdb_path}/{mut_name}.txt", 'w') as f:
        f.write(mut_site + ';')
    
    cmd = f"{EvoEF2_toolpath} --command=BuildMutant \
                --pdb={cleaned_pdb_path}/{name}.pdb --mutant_file={mut_pdb_path}/{mut_name}.txt && \
            mv {name}_Model_0001.pdb {mut_pdb_path}/{mut_name}.pdb && \
            rm {mut_pdb_path}/{mut_name}.txt"
    # print(cmd)
    subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL)

# besides buffer overflow, 1KBH would also fail because the illegal original .pdb and its' mutation sites.

  0%|          | 0/3809 [00:00<?, ?it/s]

100%|██████████| 3809/3809 [35:29<00:00,  1.79it/s]


In [44]:
info

Unnamed: 0,target,pdb,mutation,delta_bind,ddg
0,Wuhan-Hu-1,6M0J,TE333A,-0.02454,0.033684
1,Wuhan-Hu-1,6M0J,TE333C,-0.33095,0.454274
2,Wuhan-Hu-1,6M0J,TE333D,0.05865,-0.080505
3,Wuhan-Hu-1,6M0J,TE333E,0.05119,-0.070265
4,Wuhan-Hu-1,6M0J,TE333F,-0.37864,0.519735
...,...,...,...,...,...
19217,Eta,ETA,GE526S,-0.23774,0.326330
19218,Eta,ETA,GE526T,-0.15366,0.210919
19219,Eta,ETA,GE526V,-0.20412,0.280182
19220,Eta,ETA,GE526W,-0.28501,0.391215


### save processed dataset

In [48]:

import torch
os.makedirs("./data/mutSpike_ACE2", exist_ok=True)

info.to_csv("./data/mutSpike_ACE2/info.csv",index=False)
torch.save(info, "./data/mutSpike_ACE2/dataset.pt")


# prepare feature

In [49]:
import os

import numpy as np
import torch

from tqdm import tqdm

### load processed dataset

In [71]:
info = torch.load("./data/mutSpike_ACE2/dataset.pt")
info

Unnamed: 0,target,pdb,mutation,delta_bind,ddg
0,Wuhan-Hu-1,6M0J,TE333A,-0.02454,0.033684
1,Wuhan-Hu-1,6M0J,TE333C,-0.33095,0.454274
2,Wuhan-Hu-1,6M0J,TE333D,0.05865,-0.080505
3,Wuhan-Hu-1,6M0J,TE333E,0.05119,-0.070265
4,Wuhan-Hu-1,6M0J,TE333F,-0.37864,0.519735
...,...,...,...,...,...
19217,Eta,ETA,GE526S,-0.23774,0.326330
19218,Eta,ETA,GE526T,-0.15366,0.210919
19219,Eta,ETA,GE526V,-0.20412,0.280182
19220,Eta,ETA,GE526W,-0.28501,0.391215


### collect .pdb files

In [53]:
import subprocess

data_pdb_path = "./data/mutSpike_ACE2/pdb/"
os.makedirs(data_pdb_path, exist_ok=True)

os.system(f"cp ./datasets/mutSpike_ACE2/cleaned_PDBs/*.pdb {data_pdb_path}")
os.system(f"cp ./datasets/mutSpike_ACE2/mut_PDBs/*.pdb {data_pdb_path}")

print(f"Total .pdb file: {subprocess.check_output(f'ls {data_pdb_path} | wc -w', shell=True)}")


Total .pdb file: b'19227\n'


In [74]:
wt_name_list = [f"{pdb.split('_')[0]}" for pdb in set(info['pdb'])]
mut_name_list = [f"{mut}" for mut in set(info['pdb'] +'_'+ info['mutation'])]

### extract sequence & coordinate from .pdb file

In [62]:
from utils import *

from Bio.PDB import PDBParser

data_seq_path = "./data/mutSpike_ACE2/seq/"
data_coord_path = "./data/mutSpike_ACE2/coord/"
os.makedirs(data_seq_path, exist_ok=True)
os.makedirs(data_coord_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    
    pdb_filepath = f"{data_pdb_path}/{name}.pdb"
    parser = PDBParser(QUIET=True)
    struct = parser.get_structure(name, pdb_filepath)
    res_list = get_clean_res_list(struct.get_residues(), verbose=False, ensure_ca_exist=True)
    # ensure all res contains N, CA, C and O
    res_list = [res for res in res_list if (('N' in res) and ('CA' in res) and ('C' in res) and ('O' in res))]

    # extract sequence
    seq = "".join([three_to_one.get(res.resname) for res in res_list])

    # extract coordinate
    coord = []
    for res in res_list:
        res_coord = []
        R_group = []
        for atom in res:
            if atom.get_name() in ['N', 'CA', 'C', 'O']:
                res_coord.append(atom.get_coord())
            else:
                R_group.append(atom.get_coord())

        if len(R_group) == 0:
            R_group.append(res['CA'].get_coord())
        R_group = np.array(R_group).mean(axis=0)
        res_coord.append(R_group)
        coord.append(res_coord)
    coord = np.array(coord)  # convert list directly to tensor would be rather slow, suggest use ndarray as transition
    coord = torch.tensor(coord, dtype=torch.float32)

    # save to file
    seq_to_file = f"{data_seq_path}/{name}.txt"
    coord_to_file = f"{data_coord_path}/{name}.pt"
    with open(seq_to_file, "w") as seq_file:
        seq_file.write(seq)
    torch.save(coord, coord_to_file)


100%|██████████| 3810/3810 [09:02<00:00,  7.02it/s]


### extract ProtTrans feature from sequence

In [63]:
import gc
from tqdm import tqdm
import torch
from transformers import T5Tokenizer, T5EncoderModel

In [64]:
data_seq_path = "./data/mutSpike_ACE2/seq/"
ProtTrans_toolpath = "./tools/Prot-T5-XL-U50/"
gpu = '0'

# Load the vocabulary and ProtT5-XL-UniRef50 Model
tokenizer = T5Tokenizer.from_pretrained(ProtTrans_toolpath, do_lower_case=False)
model = T5EncoderModel.from_pretrained(ProtTrans_toolpath)
gc.collect()

# Load the model into the GPU if avilabile and switch to inference mode
device = torch.device('cuda:' + gpu if torch.cuda.is_available() and gpu else 'cpu')
model = model.to(device)
model = model.eval()


Some weights of the model checkpoint at ./tools/Prot-T5-XL-U50/ were not used when initializing T5EncoderModel: ['decoder.block.6.layer.2.DenseReluDense.wi.weight', 'decoder.block.8.layer.0.SelfAttention.v.weight', 'decoder.block.16.layer.1.EncDecAttention.v.weight', 'decoder.block.22.layer.0.SelfAttention.o.weight', 'decoder.block.13.layer.1.EncDecAttention.q.weight', 'decoder.block.23.layer.2.DenseReluDense.wo.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.18.layer.0.layer_norm.weight', 'decoder.block.11.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.2.DenseReluDense.wo.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decoder.block.20.layer.1.EncDecAttention.v.weight', 'decoder.block.14.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.1.layer_norm.weight', 'decoder.block.21.layer.0.SelfAttention.q.weight', 'decoder.block.10.layer.1.EncDecAttention.q.weight', 'decoder.block.17.layer.0.SelfAttention.o.weight', 'decoder.block.19.la

In [65]:
data_ProtTrans_raw_path = "./data/mutSpike_ACE2/ProtTrans_raw/"
os.makedirs(data_ProtTrans_raw_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    with open(f"{data_seq_path}/{name}.txt") as seq_file:
        seq = seq_file.readline()
    batch_name_list = [name]
    batch_seq_list = [" ".join(list(seq))]
    # print(len(seq))
    # print(batch_name_list)
    # print(batch_seq_list)

    # Tokenize, encode sequences and load it into the GPU if possibile
    ids = tokenizer.batch_encode_plus(batch_seq_list, add_special_tokens=True, padding=True)
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    # Extracting sequences' features and load it into the CPU if needed
    with torch.no_grad():
        embedding = model(input_ids=input_ids,attention_mask=attention_mask)
    embedding = embedding.last_hidden_state.cpu()

    # Remove padding (\<pad>) and special tokens (\</s>) that is added by ProtT5-XL-UniRef50 model
    for seq_num in range(len(embedding)):
        seq_len = (attention_mask[seq_num] == 1).sum()
        seq_emb = embedding[seq_num][:seq_len-1]
        # print(f"truncate padding: {embedding[seq_num].shape} -> {seq_emb.shape}")
        ProtTrans_to_file = f"{data_ProtTrans_raw_path}/{batch_name_list[seq_num]}.npy"
        np.save(ProtTrans_to_file, seq_emb)


100%|██████████| 3810/3810 [10:33<00:00,  6.01it/s]


#### normalize raw ProtTrans

In [76]:
Max_protrans = []
Min_protrans = []
for name in tqdm(wt_name_list + mut_name_list):
    raw_protrans = np.load(f"{data_ProtTrans_raw_path}/{name}.npy")
    Max_protrans.append(np.max(raw_protrans, axis = 0))
    Min_protrans.append(np.min(raw_protrans, axis = 0))

Min_protrans = np.min(np.array(Min_protrans), axis = 0)
Max_protrans = np.max(np.array(Max_protrans), axis = 0)
print(Min_protrans)
print(Max_protrans)

np.save("./data/mutSpike_ACE2/Max_ProtTrans_repr.npy", Max_protrans)
np.save("./data/mutSpike_ACE2/Min_ProtTrans_repr.npy", Min_protrans)

100%|██████████| 19227/19227 [11:25<00:00, 28.04it/s]


[-0.7307469  -0.44854504 -0.61657476 ... -0.85375977 -0.701555
 -1.0113736 ]
[0.64754593 0.5798953  0.7366022  ... 0.6467273  0.691037   0.72902286]


In [79]:
Max_protrans = np.load("./data/mutSpike_ACE2/Max_ProtTrans_repr.npy")
Min_protrans = np.load("./data/mutSpike_ACE2/Min_ProtTrans_repr.npy")

data_ProtTrans_path = "./data/mutSpike_ACE2/ProtTrans/"
os.makedirs(data_ProtTrans_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    raw_protrans = np.load(f"{data_ProtTrans_raw_path}/{name}.npy")
    protrans = (raw_protrans - Min_protrans) / (Max_protrans - Min_protrans)
    ProtTrans_to_file = f"{data_ProtTrans_path}/{name}.pt"
    torch.save(torch.tensor(protrans, dtype = torch.float32), ProtTrans_to_file)


  0%|          | 0/19227 [00:00<?, ?it/s]

100%|██████████| 19227/19227 [10:10<00:00, 31.48it/s]  


### extract DSSP feature from .pdb file

#### correct format of mut.pdb: col of Occupancy(55 - 60) should be "{:.2f}"

In [66]:
data_pdb_path = "./data/mutSpike_ACE2/pdb/"

for name in tqdm(mut_name_list):
    pdb_filepath = f"{data_pdb_path}/{name}.pdb"

    with open(pdb_filepath, "r") as f:
        lines = f.readlines()

    for i in range(len(lines)):
        if lines[i].split()[0] == "REMARK":
            continue
        lines[i] = lines[i][:57] + '.00' + lines[i][60:]

    with open(pdb_filepath, "w") as f:
        f.writelines(lines)


100%|██████████| 3809/3809 [01:23<00:00, 45.87it/s]


In [67]:
from utils import *

data_pdb_path = "./data/mutSpike_ACE2/pdb/"
data_seq_path = "./data/mutSpike_ACE2/seq/"
dssp_toolpath = "./tools/mkdssp"

data_DSSP_path = "./data/mutSpike_ACE2/DSSP"
os.makedirs(data_DSSP_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    pdb_filepath = f"{data_pdb_path}/{name}.pdb"
    with open(f"{data_seq_path}/{name}.txt") as seq_file:
        seq = seq_file.readline()

    DSSP_to_file = f"{data_DSSP_path}/{name}.dssp"
    dssp_cmd = f"{dssp_toolpath} -i {pdb_filepath} -o {DSSP_to_file}"
    os.system(dssp_cmd)

    try:
        dssp_seq, dssp_matrix = process_dssp(DSSP_to_file)
        # dssp_seq: likely equal to original sequence
        # dssp_matrix (list<ndarray>): list of (1, 9) vector, length of dssp_seq
        if dssp_seq != seq:
            dssp_matrix = match_dssp(dssp_seq, dssp_matrix, seq)
        
        DSSP_to_file = f"{data_DSSP_path}/{name}.pt"
        torch.save(torch.tensor(np.array(dssp_matrix), dtype = torch.float32), DSSP_to_file)
        # shape(AA_len, 9)
        # os.system("rm {DSSP_to_file}"")
    except:
        print(f"Wrong entry prompt: $ {dssp_cmd}")
        continue


100%|██████████| 3810/3810 [06:51<00:00,  9.25it/s]


# validate wt & mut feature match

In [80]:
from utils import match_wt2mut

match_wt2mut("6M0J", "6M0J_AE344H", "./data/mutSpike_ACE2/")
match_wt2mut("7EKF", "7EKF_SB459H", "./data/mutSpike_ACE2/")
match_wt2mut("7EKG", "7EKG_SB349T", "./data/mutSpike_ACE2/")
match_wt2mut("7WBQ", "7WBQ_YB508M", "./data/mutSpike_ACE2/")


verify sequence whether match: 
wt: 791 - mut: 791
609: A -> H
verify coordinate whether match: 
wt: torch.Size([791, 5, 3]) - mut: torch.Size([791, 5, 3])
verify ProtTrans feature whether match: 
wt: torch.Size([791, 1024]) - mut: torch.Size([791, 1024])
verify DSSP feature whether match: 
wt: torch.Size([791, 9]) - mut: torch.Size([791, 9])
verify sequence whether match: 
wt: 791 - mut: 791
723: S -> H
verify coordinate whether match: 
wt: torch.Size([791, 5, 3]) - mut: torch.Size([791, 5, 3])
verify ProtTrans feature whether match: 
wt: torch.Size([791, 1024]) - mut: torch.Size([791, 1024])
verify DSSP feature whether match: 
wt: torch.Size([791, 9]) - mut: torch.Size([791, 9])
verify sequence whether match: 
wt: 791 - mut: 791
613: S -> T
verify coordinate whether match: 
wt: torch.Size([791, 5, 3]) - mut: torch.Size([791, 5, 3])
verify ProtTrans feature whether match: 
wt: torch.Size([791, 1024]) - mut: torch.Size([791, 1024])
verify DSSP feature whether match: 
wt: torch.Size([79

# prepare dataset

In [10]:
import random
import numpy as np
import torch

info = torch.load("./data/mutSpike_ACE2/dataset.pt")
info


Unnamed: 0,target,pdb,mutation,delta_bind,ddg
0,Wuhan-Hu-1,6M0J,TE333A,-0.02454,0.033684
1,Wuhan-Hu-1,6M0J,TE333C,-0.33095,0.454274
2,Wuhan-Hu-1,6M0J,TE333D,0.05865,-0.080505
3,Wuhan-Hu-1,6M0J,TE333E,0.05119,-0.070265
4,Wuhan-Hu-1,6M0J,TE333F,-0.37864,0.519735
...,...,...,...,...,...
19217,Eta,ETA,GE526S,-0.23774,0.326330
19218,Eta,ETA,GE526T,-0.15366,0.210919
19219,Eta,ETA,GE526V,-0.20412,0.280182
19220,Eta,ETA,GE526W,-0.28501,0.391215


In [11]:
# change Dataframe column name
info.rename(columns={'target': 'variant'}, inplace=True)
info.rename(columns={'pdb': 'wt_name'}, inplace=True)
info.rename(columns={'mutation': 'mut_name'}, inplace=True)
info["mut_name"] = info["wt_name"] + "_" + info["mut_name"]

info["target"] = info["ddg"]
# info.drop(columns=['variant', 'delta_bind'], inplace=True)
torch.save(info, "./data/mutSpike_ACE2/dataset_unshuffled.pt")
info

Unnamed: 0,variant,wt_name,mut_name,delta_bind,ddg,target
0,Wuhan-Hu-1,6M0J,6M0J_TE333A,-0.02454,0.033684,0.033684
1,Wuhan-Hu-1,6M0J,6M0J_TE333C,-0.33095,0.454274,0.454274
2,Wuhan-Hu-1,6M0J,6M0J_TE333D,0.05865,-0.080505,-0.080505
3,Wuhan-Hu-1,6M0J,6M0J_TE333E,0.05119,-0.070265,-0.070265
4,Wuhan-Hu-1,6M0J,6M0J_TE333F,-0.37864,0.519735,0.519735
...,...,...,...,...,...,...
19217,Eta,ETA,ETA_GE526S,-0.23774,0.326330,0.326330
19218,Eta,ETA,ETA_GE526T,-0.15366,0.210919,0.210919
19219,Eta,ETA,ETA_GE526V,-0.20412,0.280182,0.280182
19220,Eta,ETA,ETA_GE526W,-0.28501,0.391215,0.391215


In [8]:
# shuffle
info_shuffled = info.sample(frac=1, random_state=42).reset_index(drop=True)

# split folds
folds_num = 10
index = list(range(len(info_shuffled)))
index_k_split = np.array_split(index, folds_num)
index_k_split = [np.full(len(index_k_split[i]), i) for i in range(folds_num)]
index_k_split = np.concatenate(index_k_split)
info_shuffled["split"] = index_k_split

torch.save(info_shuffled, "./data/mutSpike_ACE2/dataset_processed.pt")
info_shuffled

Unnamed: 0,variant,wt_name,mut_name,delta_bind,ddg,target,split
0,Wuhan-Hu-1,6M0J,6M0J_LE513D,-3.60047,4.942134,-3.60047,0
1,Beta,7EKG,7EKG_RB403F,-3.85938,5.297523,-3.85938,0
2,Alpha,7EKF,7EKF_NB460G,-0.22761,0.312426,-0.22761,0
3,Delta,7WBQ,7WBQ_YB473Q,-2.06549,2.835165,-2.06549,0
4,Delta,7WBQ,7WBQ_EB471R,-0.05830,0.080025,-0.05830,0
...,...,...,...,...,...,...,...
19217,Beta,7EKG,7EKG_SB514M,-0.17101,0.234734,-0.17101,9
19218,Delta,7WBQ,7WBQ_NB354M,-0.13728,0.188435,-0.13728,9
19219,Alpha,7EKF,7EKF_AB411K,-0.51635,0.708760,-0.51635,9
19220,Wuhan-Hu-1,6M0J,6M0J_TE376C,-0.34863,0.478542,-0.34863,9


In [4]:

info_wt = info[info["wt_name"] == "6M0J"]
info_alpha = info[info["wt_name"] == "7EKF"]
info_beta = info[info["wt_name"] == "7EKG"]
info_delta = info[info["wt_name"] == "7WBQ"]
info_eta = info[info["wt_name"] == "ETA"]

info_wt_shuffled = info_wt.sample(frac=1, random_state=42).reset_index(drop=True)
info_alpha_shuffled = info_alpha.sample(frac=1, random_state=42).reset_index(drop=True)
info_beta_shuffled = info_beta.sample(frac=1, random_state=42).reset_index(drop=True)
info_delta_shuffled = info_delta.sample(frac=1, random_state=42).reset_index(drop=True)
info_eta_shuffled = info_eta.sample(frac=1, random_state=42).reset_index(drop=True)

info_eta_shuffled

Unnamed: 0,variant,wt_name,mut_name,delta_bind,ddg,target
0,Eta,ETA,ETA_TE385F,-0.06822,0.093641,0.093641
1,Eta,ETA,ETA_LE513Q,-2.62926,3.609016,3.609016
2,Eta,ETA,ETA_AE520T,-0.04349,0.059696,0.059696
3,Eta,ETA,ETA_CE379L,-1.92026,2.635817,2.635817
4,Eta,ETA,ETA_QE498E,-1.30736,1.794529,1.794529
...,...,...,...,...,...,...
3804,Eta,ETA,ETA_CE391I,-0.24490,0.336158,0.336158
3805,Eta,ETA,ETA_SE399N,-1.04003,1.427582,1.427582
3806,Eta,ETA,ETA_FE377P,-1.88035,2.581036,2.581036
3807,Eta,ETA,ETA_VE511H,-2.92224,4.011171,4.011171


In [5]:
torch.save(info_wt_shuffled, "./data/mutSpike_ACE2/dataset_wt.pt")
torch.save(info_alpha_shuffled, "./data/mutSpike_ACE2/dataset_alpha.pt")
torch.save(info_beta_shuffled, "./data/mutSpike_ACE2/dataset_beta.pt")
torch.save(info_delta_shuffled, "./data/mutSpike_ACE2/dataset_delta.pt")
torch.save(info_eta_shuffled, "./data/mutSpike_ACE2/dataset_eta.pt")
