In [305]:
import os
import subprocess

import pandas as pd
import numpy as np
import torch

from tqdm import tqdm


# clean original skempi_v2: csv & pdb

In [306]:
import pandas as pd

raw_df = pd.read_csv("./datasets/skempi_v2/skempi_v2.csv", sep=';')
raw_df = raw_df[['#Pdb', 'Mutation(s)_cleaned', 'Affinity_wt_parsed', 'Affinity_mut_parsed', 'Temperature']]
col_mapping = {'#Pdb': 'pdb',
               'Mutation(s)_cleaned': 'mutation', 
               'Affinity_wt_parsed': 'affinity_wt', 
               'Affinity_mut_parsed': 'affinity_mut',
               "Temperature": "temp"}
raw_df.rename(columns=col_mapping, inplace=True)
raw_df["pdb"] = raw_df["pdb"].apply(lambda s: s.split('_')[0])
raw_df

Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp
0,1CSE,LI38G,1.120000e-12,5.260000e-11,294
1,1CSE,LI38S,1.120000e-12,8.330000e-12,294
2,1CSE,LI38P,1.120000e-12,1.020000e-07,294
3,1CSE,LI38I,1.120000e-12,1.720000e-10,294
4,1CSE,LI38D,1.120000e-12,1.920000e-09,294
...,...,...,...,...,...
7080,3QIB,KP8R,5.500000e-06,2.400000e-04,298
7081,3QIB,TP11A,5.500000e-06,1.100000e-03,298
7082,3QIB,TP11S,5.500000e-06,3.380000e-05,298
7083,3QIB,TP11N,5.500000e-06,4.340000e-05,298


### compute ddg & daffinity

In [309]:

import numpy as np
df = raw_df
df['temp'] = df['temp'].map(lambda temp: float(temp[:3]) if isinstance(temp, str) else temp)

R = float(8.314/4184)
df['ddg'] = R * df['temp'] * (np.log(df['affinity_mut']) - np.log(df['affinity_wt']))

df['daffinity'] = df['affinity_mut'] - df['affinity_wt']

# df = df[['pdb', 'mutation', 'ddg', 'daffinity', 'temp']]
df

Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
0,1CSE,LI38G,1.120000e-12,5.260000e-11,294.0,2.248833,5.148000e-11
1,1CSE,LI38S,1.120000e-12,8.330000e-12,294.0,1.172229,7.210000e-12
2,1CSE,LI38P,1.120000e-12,1.020000e-07,294.0,6.671276,1.019989e-07
3,1CSE,LI38I,1.120000e-12,1.720000e-10,294.0,2.940988,1.708800e-10
4,1CSE,LI38D,1.120000e-12,1.920000e-09,294.0,4.350434,1.918880e-09
...,...,...,...,...,...,...,...
7080,3QIB,KP8R,5.500000e-06,2.400000e-04,298.0,2.235909,2.345000e-04
7081,3QIB,TP11A,5.500000e-06,1.100000e-03,298.0,3.137419,1.094500e-03
7082,3QIB,TP11S,5.500000e-06,3.380000e-05,298.0,1.075181,2.830000e-05
7083,3QIB,TP11N,5.500000e-06,4.340000e-05,298.0,1.223219,3.790000e-05


### remove repeated wt-mut data

In [310]:

dataset = {}
for i in range(len(df)):
    wt = df.loc[i, 'pdb']
    mut = df.loc[i, 'mutation']
    if (wt, mut) not in dataset:
        dataset[(wt, mut)] = []
    dataset[(wt, mut)].append(i)
print(f"Num of wt-mut pair: {len(dataset)}")

drop_list = []
for wt_mut_pair in dataset:
    repeat_idxs = dataset[wt_mut_pair]
    if len(repeat_idxs) == 1:
        continue
    idx_298 = []
    for idx in repeat_idxs:
        if df.loc[idx, 'temp'] == 298:
            idx_298.append(idx)
    if len(idx_298) == 0:
        idx_298 = repeat_idxs
    
    ddg = np.mean([df.loc[idx, 'ddg'] for idx in idx_298])
    mean_daffinity = np.mean([df.loc[idx, 'daffinity'] for idx in idx_298])
    df.loc[repeat_idxs[0], 'ddg'] = ddg
    df.loc[repeat_idxs[0], 'daffinity'] = mean_daffinity
    
    drop_list += repeat_idxs[1:]

print(f"Num of redundant data: {len(drop_list)}")
df = df.drop(index=drop_list)
df.reset_index(drop=True, inplace=True)
# df = df[['pdb', 'mutation', 'ddg', 'daffinity']]
df

Num of wt-mut pair: 6185
Num of redundant data: 900


Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
0,1CSE,LI38G,1.120000e-12,5.260000e-11,294.0,2.248833,5.148000e-11
1,1CSE,LI38S,1.120000e-12,8.330000e-12,294.0,1.172229,7.210000e-12
2,1CSE,LI38P,1.120000e-12,1.020000e-07,294.0,6.671276,1.019989e-07
3,1CSE,LI38I,1.120000e-12,1.720000e-10,294.0,2.940988,1.708800e-10
4,1CSE,LI38D,1.120000e-12,1.920000e-09,294.0,4.350434,1.918880e-09
...,...,...,...,...,...,...,...
6180,3QIB,KP8R,5.500000e-06,2.400000e-04,298.0,2.235909,2.345000e-04
6181,3QIB,TP11A,5.500000e-06,1.100000e-03,298.0,3.137419,1.094500e-03
6182,3QIB,TP11S,5.500000e-06,3.380000e-05,298.0,1.075181,2.830000e-05
6183,3QIB,TP11N,5.500000e-06,4.340000e-05,298.0,1.223219,3.790000e-05


### clean original .pdb file

In [250]:
import os
from tqdm import tqdm
import glob
from Bio.PDB import PDBParser, PDBIO, PPBuilder
from Bio.PDB import Structure, Model, Chain, Residue, Atom

original_pdb_path = "./datasets/skempi_v2/PDBs/"
cleaned_pdb_path = "./datasets/skempi_v2/cleaned_PDBs/"
os.makedirs(cleaned_pdb_path, exist_ok=True)

illegal_original_pdb = {}
original_pdb_files = glob.glob(f"{original_pdb_path}/*.pdb")
for pdb_filepath in tqdm(original_pdb_files):
    pdb_name = os.path.basename(pdb_filepath).split('.')[0]

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_filepath)
    new_structure = Structure.Structure(pdb_name)

    model = structure[0]  # only keep one model 
    new_model = Model.Model(model.id)

    for chain in model:
        new_chain = Chain.Chain(chain.id)

        chain = [res for res in chain.get_residues() if res.id[0] == ' ']  # exclude HOH...
        for residue in chain:
            if set(['N', 'CA', 'C', 'O']).issubset(set(residue.child_dict.keys())):
                new_chain.add(residue)
            else:
                if not illegal_original_pdb.get(pdb_filepath):
                    illegal_original_pdb[pdb_filepath] = []
                illegal_original_pdb[pdb_filepath].append(residue.full_id)

                if not (residue.id[1] == 1 or residue.id[1] == len(chain)):
                    print(f"WARNING: delete in the middle {residue.full_id} of total len {len(chain)}")

        if len(new_chain) > 0:
            new_model.add(new_chain)
    if len(new_model) > 0:
        new_structure.add(new_model)

    io = PDBIO()
    io.set_structure(new_structure)
    io.save(f"{cleaned_pdb_path}/{pdb_name}.pdb")

len(illegal_original_pdb)

  0%|          | 0/345 [00:00<?, ?it/s]

  9%|▉         | 32/345 [00:02<00:25, 12.41it/s]



 32%|███▏      | 111/345 [00:08<00:15, 14.66it/s]



 51%|█████     | 176/345 [00:14<00:13, 12.27it/s]



 63%|██████▎   | 217/345 [00:19<00:13,  9.53it/s]



 78%|███████▊  | 270/345 [00:23<00:04, 16.36it/s]



 82%|████████▏ | 284/345 [00:24<00:05, 10.79it/s]



 94%|█████████▍| 324/345 [00:28<00:01, 14.66it/s]



100%|██████████| 345/345 [00:30<00:00, 11.45it/s]


24

### EvoEF2 generate mutant .pdb

In [9]:
import os
from tqdm import tqdm
import subprocess

EvoEF2_toolpath = "./tools/EvoEF2/EvoEF2"
cleaned_pdb_path = "./datasets/skempi_v2/cleaned_PDBs/"
mut_pdb_path = "./datasets/skempi_v2/mut_PDBs/"
os.makedirs(mut_pdb_path, exist_ok=True)

for i in tqdm(range(len(df))):
    wt = df.loc[i, 'pdb']
    mut_site = df.loc[i, 'mutation']
    name = wt.split('_')[0]
    mut_name = f"{wt}_{mut_site}"
    # print(f"{name} -> {mut_name}")

    with open(f"{mut_pdb_path}/{mut_name}.txt", 'w') as f:
        f.write(mut_site + ';')
    
    cmd = f"{EvoEF2_toolpath} --command=BuildMutant \
                --pdb={cleaned_pdb_path}/{name}.pdb --mutant_file={mut_pdb_path}/{mut_name}.txt && \
            mv {name}_Model_0001.pdb {mut_pdb_path}/{mut_name}.pdb && \
            rm {mut_pdb_path}/{mut_name}.txt"
    # print(cmd)
    subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL)

# besides buffer overflow, 1KBH would also fail because the illegal original .pdb and its' mutation sites.

  7%|▋         | 403/6193 [04:03<54:48,  1.76it/s]  *** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 404/6193 [04:03<45:44,  2.11it/s]*** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 405/6193 [04:03<37:11,  2.59it/s]*** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 407/6193 [04:04<38:28,  2.51it/s]*** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 408/6193 [04:04<31:57,  3.02it/s]*** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 409/6193 [04:04<27:15,  3.54it/s]*** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 411/6193 [04:05<38:47,  2.48it/s]*** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 412/6193 [04:05<32:36,  2.95it/s]*** buffer overflow detected ***: terminated
Aborted (core dumped)
  7%|▋         | 413/6193 [04:05<27:57,  3.45it/s]*** buffer o

### remove data that EvoEF2 failed to generate

In [311]:

import os

mut_pdb_path = "./datasets/skempi_v2/mut_PDBs/"

drop_list = []
for idx in range(len(df)):
    wt = df.loc[idx, 'pdb']
    mut = df.loc[idx, 'mutation']
    mut_pdb_filename = wt + '_' + mut + '.pdb'
    if not os.path.exists(mut_pdb_path + mut_pdb_filename):
        drop_list.append(idx)
        
print(f"Num of EvoEF2-fail-to-generate data: {len(drop_list)}") 
df = df.drop(index=drop_list)
df.reset_index(drop=True, inplace=True)
df

Num of EvoEF2-fail-to-generate data: 166


Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
0,1CSE,LI38G,1.120000e-12,5.260000e-11,294.0,2.248833,5.148000e-11
1,1CSE,LI38S,1.120000e-12,8.330000e-12,294.0,1.172229,7.210000e-12
2,1CSE,LI38P,1.120000e-12,1.020000e-07,294.0,6.671276,1.019989e-07
3,1CSE,LI38I,1.120000e-12,1.720000e-10,294.0,2.940988,1.708800e-10
4,1CSE,LI38D,1.120000e-12,1.920000e-09,294.0,4.350434,1.918880e-09
...,...,...,...,...,...,...,...
6014,3QIB,KP8R,5.500000e-06,2.400000e-04,298.0,2.235909,2.345000e-04
6015,3QIB,TP11A,5.500000e-06,1.100000e-03,298.0,3.137419,1.094500e-03
6016,3QIB,TP11S,5.500000e-06,3.380000e-05,298.0,1.075181,2.830000e-05
6017,3QIB,TP11N,5.500000e-06,4.340000e-05,298.0,1.223219,3.790000e-05


### remove NaN data

In [312]:

df = df.dropna()
df.reset_index(drop=True, inplace=True)
df

Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
0,1CSE,LI38G,1.120000e-12,5.260000e-11,294.0,2.248833,5.148000e-11
1,1CSE,LI38S,1.120000e-12,8.330000e-12,294.0,1.172229,7.210000e-12
2,1CSE,LI38P,1.120000e-12,1.020000e-07,294.0,6.671276,1.019989e-07
3,1CSE,LI38I,1.120000e-12,1.720000e-10,294.0,2.940988,1.708800e-10
4,1CSE,LI38D,1.120000e-12,1.920000e-09,294.0,4.350434,1.918880e-09
...,...,...,...,...,...,...,...
5748,3QIB,KP8R,5.500000e-06,2.400000e-04,298.0,2.235909,2.345000e-04
5749,3QIB,TP11A,5.500000e-06,1.100000e-03,298.0,3.137419,1.094500e-03
5750,3QIB,TP11S,5.500000e-06,3.380000e-05,298.0,1.075181,2.830000e-05
5751,3QIB,TP11N,5.500000e-06,4.340000e-05,298.0,1.223219,3.790000e-05


In [313]:
print(f"daffinity range: [{df['daffinity'].min()}, {df['daffinity'].max()}]")
print(f"ddg range: [{df['ddg'].min()}, {df['ddg'].max()}]")


daffinity range: [-0.000111, 0.04049885]
ddg range: [-12.221723141911879, 12.221723141911879]


In [None]:
df.iloc[[df['daffinity'].idxmin()]]

In [None]:
df.iloc[[df['daffinity'].idxmax()]]

In [254]:
df.iloc[[df['ddg'].idxmin()]]

Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
2676,3BTG,GI13K,6.7e-05,5.88e-14,295.0,-12.221723,-6.7e-05


In [255]:
df.iloc[[df['ddg'].idxmax()]]

Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
1178,2FTL,KI15G,5.88e-14,6.7e-05,295.0,12.221723,6.7e-05


In [256]:
df.iloc[[df['ddg'].abs().idxmin()]]

Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
52,2SIC,MI67K,1.8e-11,1.8e-11,298.0,0.0,0.0


### save processed dataset

In [314]:

import torch

df.to_csv("./data/skempi_v2/info.csv",index=False)
torch.save(df, "./data/skempi_v2/dataset.pt")


# prepare feature

In [315]:
import os

import numpy as np
import torch

from tqdm import tqdm

### load processed dataset

In [316]:
info = torch.load("./data/skempi_v2/dataset.pt")
info


Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
0,1CSE,LI38G,1.120000e-12,5.260000e-11,294.0,2.248833,5.148000e-11
1,1CSE,LI38S,1.120000e-12,8.330000e-12,294.0,1.172229,7.210000e-12
2,1CSE,LI38P,1.120000e-12,1.020000e-07,294.0,6.671276,1.019989e-07
3,1CSE,LI38I,1.120000e-12,1.720000e-10,294.0,2.940988,1.708800e-10
4,1CSE,LI38D,1.120000e-12,1.920000e-09,294.0,4.350434,1.918880e-09
...,...,...,...,...,...,...,...
5748,3QIB,KP8R,5.500000e-06,2.400000e-04,298.0,2.235909,2.345000e-04
5749,3QIB,TP11A,5.500000e-06,1.100000e-03,298.0,3.137419,1.094500e-03
5750,3QIB,TP11S,5.500000e-06,3.380000e-05,298.0,1.075181,2.830000e-05
5751,3QIB,TP11N,5.500000e-06,4.340000e-05,298.0,1.223219,3.790000e-05


### collect .pdb file

In [317]:

import subprocess

data_pdb_path = "./data/skempi_v2/pdb/"
os.makedirs(data_pdb_path, exist_ok=True)

os.system(f"cp datasets/skempi_v2/cleaned_PDBs/*.pdb {data_pdb_path}")
os.system(f"cp datasets/skempi_v2/mut_PDBs/*.pdb {data_pdb_path}")

print(f"Total .pdb file: {subprocess.check_output(f'ls {data_pdb_path} | wc -w', shell=True)}")


Total .pdb file: b'6364\n'


In [318]:

wt_name_list = [f"{pdb.split('_')[0]}" for pdb in set(info['pdb'])]
mut_name_list = [f"{mut}" for mut in set(info['pdb'] +'_'+ info['mutation'])]
duplicate_names = {name for name in wt_name_list if wt_name_list.count(name) > 1}
print(f"duplicate pdb: {duplicate_names}")

duplicate pdb: set()


### extract sequence & coordinate from .pdb file

In [18]:
from utils import *

from Bio.PDB import PDBParser

data_seq_path = "./data/skempi_v2/seq/"
data_coord_path = "./data/skempi_v2/coord/"
os.makedirs(data_seq_path, exist_ok=True)
os.makedirs(data_coord_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    
    pdb_filepath = f"{data_pdb_path}/{name}.pdb"
    parser = PDBParser(QUIET=True)
    struct = parser.get_structure(name, pdb_filepath)
    res_list = get_clean_res_list(struct.get_residues(), verbose=False, ensure_ca_exist=True)
    # ensure all res contains N, CA, C and O
    res_list = [res for res in res_list if (('N' in res) and ('CA' in res) and ('C' in res) and ('O' in res))]

    # extract sequence
    seq = "".join([three_to_one.get(res.resname) for res in res_list])

    # extract coordinate
    coord = []
    for res in res_list:
        res_coord = []
        R_group = []
        for atom in res:
            if atom.get_name() in ['N', 'CA', 'C', 'O']:
                res_coord.append(atom.get_coord())
            else:
                R_group.append(atom.get_coord())

        if len(R_group) == 0:
            R_group.append(res['CA'].get_coord())
        R_group = np.array(R_group).mean(axis=0)
        res_coord.append(R_group)
        coord.append(res_coord)
    coord = np.array(coord)  # convert list directly to tensor would be rather slow, suggest use ndarray as transition
    coord = torch.tensor(coord, dtype=torch.float32)

    # save to file
    seq_to_file = f"{data_seq_path}/{name}.txt"
    coord_to_file = f"{data_coord_path}/{name}.pt"
    with open(seq_to_file, "w") as seq_file:
        seq_file.write(seq)
    torch.save(coord, coord_to_file)


100%|██████████| 6191/6191 [09:12<00:00, 11.21it/s]


### extract ProtTrans feature from sequence

In [7]:
import gc
from tqdm import tqdm
import torch
from transformers import T5Tokenizer, T5EncoderModel

In [8]:
data_seq_path = "./data/skempi_v2/seq/"
ProtTrans_toolpath = "./tools/Prot-T5-XL-U50/"
gpu = '3'

# Load the vocabulary and ProtT5-XL-UniRef50 Model
tokenizer = T5Tokenizer.from_pretrained(ProtTrans_toolpath, do_lower_case=False)
model = T5EncoderModel.from_pretrained(ProtTrans_toolpath)
gc.collect()

# Load the model into the GPU if avilabile and switch to inference mode
device = torch.device('cuda:' + gpu if torch.cuda.is_available() and gpu else 'cpu')
model = model.to(device)
model = model.eval()


Some weights of the model checkpoint at ./tools/Prot-T5-XL-U50/ were not used when initializing T5EncoderModel: ['decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.14.layer.1.layer_norm.weight', 'decoder.block.22.layer.0.layer_norm.weight', 'decoder.block.20.layer.0.SelfAttention.v.weight', 'decoder.block.9.layer.1.EncDecAttention.v.weight', 'decoder.block.18.layer.0.SelfAttention.q.weight', 'decoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.23.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.2.DenseReluDense.wo.weight', 'decoder.block.19.layer.1.EncDecAttention.o.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.16.layer.0.SelfAttention.q.weight', 'decoder.block.14.layer.0.layer_norm.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'decoder.block.12.layer.1.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttentio

In [9]:
data_ProtTrans_raw_path = "./data/skempi_v2/ProtTrans_raw/"
os.makedirs(data_ProtTrans_raw_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    with open(f"{data_seq_path}/{name}.txt") as seq_file:
        seq = seq_file.readline()
    batch_name_list = [name]
    batch_seq_list = [" ".join(list(seq))]
    # print(len(seq))
    # print(batch_name_list)
    # print(batch_seq_list)

    # Tokenize, encode sequences and load it into the GPU if possibile
    ids = tokenizer.batch_encode_plus(batch_seq_list, add_special_tokens=True, padding=True)
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    # Extracting sequences' features and load it into the CPU if needed
    with torch.no_grad():
        embedding = model(input_ids=input_ids,attention_mask=attention_mask)
    embedding = embedding.last_hidden_state.cpu()

    # Remove padding (\<pad>) and special tokens (\</s>) that is added by ProtT5-XL-UniRef50 model
    for seq_num in range(len(embedding)):
        seq_len = (attention_mask[seq_num] == 1).sum()
        seq_emb = embedding[seq_num][:seq_len-1]
        # print(f"truncate padding: {embedding[seq_num].shape} -> {seq_emb.shape}")
        ProtTrans_to_file = f"{data_ProtTrans_raw_path}/{batch_name_list[seq_num]}.npy"
        np.save(ProtTrans_to_file, seq_emb)


100%|██████████| 6191/6191 [28:21<00:00,  3.64it/s]  


#### normalize raw ProtTrans

In [10]:
Max_protrans = []
Min_protrans = []
for name in tqdm(wt_name_list + mut_name_list):
    raw_protrans = np.load(f"{data_ProtTrans_raw_path}/{name}.npy")
    Max_protrans.append(np.max(raw_protrans, axis = 0))
    Min_protrans.append(np.min(raw_protrans, axis = 0))

Min_protrans = np.min(np.array(Min_protrans), axis = 0)
Max_protrans = np.max(np.array(Max_protrans), axis = 0)
print(Min_protrans)
print(Max_protrans)

np.save("./data/skempi_v2/Max_ProtTrans_repr.npy", Max_protrans)
np.save("./data/skempi_v2/Min_ProtTrans_repr.npy", Min_protrans)

100%|██████████| 6191/6191 [05:15<00:00, 19.62it/s]  


[-1.1695224  -0.67864954 -0.9903961  ... -0.9388567  -0.9177465
 -0.93719786]
[1.1407886 0.7556648 1.0245636 ... 0.9903357 1.1121918 0.9022764]


In [11]:
Max_protrans = np.load("./data/skempi_v2/Max_ProtTrans_repr.npy")
Min_protrans = np.load("./data/skempi_v2/Min_ProtTrans_repr.npy")

data_ProtTrans_path = "./data/skempi_v2/ProtTrans/"
os.makedirs(data_ProtTrans_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    raw_protrans = np.load(f"{data_ProtTrans_raw_path}/{name}.npy")
    protrans = (raw_protrans - Min_protrans) / (Max_protrans - Min_protrans)
    ProtTrans_to_file = f"{data_ProtTrans_path}/{name}.pt"
    torch.save(torch.tensor(protrans, dtype = torch.float32), ProtTrans_to_file)


100%|██████████| 6191/6191 [07:43<00:00, 13.35it/s]  


### extract DSSP feature from .pdb file

#### correct format of mut.pdb: col of Occupancy(55 - 60) should be "{:.2f}"

In [13]:
data_pdb_path = "./data/skempi_v2/pdb/"

for name in tqdm(mut_name_list):
    pdb_filepath = f"{data_pdb_path}/{name}.pdb"

    with open(pdb_filepath, "r") as f:
        lines = f.readlines()

    for i in range(len(lines)):
        if lines[i].split()[0] == "REMARK":
            continue
        lines[i] = lines[i][:57] + '.00' + lines[i][60:]

    with open(pdb_filepath, "w") as f:
        f.writelines(lines)


100%|██████████| 5847/5847 [06:25<00:00, 15.17it/s]


In [14]:
from utils import *

data_pdb_path = "./data/skempi_v2/pdb/"
data_seq_path = "./data/skempi_v2/seq/"
dssp_toolpath = "./tools/mkdssp"

data_DSSP_path = "./data/skempi_v2/DSSP"
os.makedirs(data_DSSP_path, exist_ok=True)

for name in tqdm(wt_name_list + mut_name_list):
    pdb_filepath = f"{data_pdb_path}/{name}.pdb"
    with open(f"{data_seq_path}/{name}.txt") as seq_file:
        seq = seq_file.readline()

    DSSP_to_file = f"{data_DSSP_path}/{name}.dssp"
    dssp_cmd = f"{dssp_toolpath} -i {pdb_filepath} -o {DSSP_to_file}"
    os.system(dssp_cmd)

    try:
        dssp_seq, dssp_matrix = process_dssp(DSSP_to_file)
        # dssp_seq: likely equal to original sequence
        # dssp_matrix (list<ndarray>): list of (1, 9) vector, length of dssp_seq
        if dssp_seq != seq:
            dssp_matrix = match_dssp(dssp_seq, dssp_matrix, seq)
        
        DSSP_to_file = f"{data_DSSP_path}/{name}.pt"
        torch.save(torch.tensor(np.array(dssp_matrix), dtype = torch.float32), DSSP_to_file)
        # shape(AA_len, 9)
        # os.system("rm {DSSP_to_file}"")
    except:
        print(f"Wrong entry prompt: $ {dssp_cmd}")
        continue


100%|██████████| 6191/6191 [16:33<00:00,  6.23it/s]  


# validate wt & mut feature match

In [319]:
from utils import match_wt2mut
# 1Y3B @282
# 1Y4A @275
# 2FTL @223

match_wt2mut("1Y3B", "1Y3B_SI41E", "./data/skempi_v2/")
match_wt2mut("1Y4A", "1Y4A_SI40E,RI39M", "./data/skempi_v2/")
match_wt2mut("2FTL", "2FTL_PI13A", "./data/skempi_v2/")

verify sequence whether match: 
wt: 344 - mut: 344
321: S -> E
verify coordinate whether match: 
wt: torch.Size([344, 5, 3]) - mut: torch.Size([344, 5, 3])
verify ProtTrans feature whether match: 
wt: torch.Size([344, 1024]) - mut: torch.Size([344, 1024])
verify DSSP feature whether match: 
wt: torch.Size([344, 9]) - mut: torch.Size([344, 9])
verify sequence whether match: 
wt: 337 - mut: 337
313: R -> M
314: S -> E
verify coordinate whether match: 
wt: torch.Size([337, 5, 3]) - mut: torch.Size([337, 5, 3])
verify ProtTrans feature whether match: 
wt: torch.Size([337, 1024]) - mut: torch.Size([337, 1024])
verify DSSP feature whether match: 
wt: torch.Size([337, 9]) - mut: torch.Size([337, 9])
verify sequence whether match: 
wt: 279 - mut: 279
234: P -> A
verify coordinate whether match: 
wt: torch.Size([279, 5, 3]) - mut: torch.Size([279, 5, 3])
verify ProtTrans feature whether match: 
wt: torch.Size([279, 1024]) - mut: torch.Size([279, 1024])
verify DSSP feature whether match: 
wt: to

# split dataset

In [18]:
import random
import numpy as np
import torch

info = torch.load("./data/skempi_v2/dataset.pt")
info


Unnamed: 0,pdb,mutation,affinity_wt,affinity_mut,temp,ddg,daffinity
0,1CSE,LI38G,1.120000e-12,5.260000e-11,294.0,2.248833,5.148000e-11
1,1CSE,LI38S,1.120000e-12,8.330000e-12,294.0,1.172229,7.210000e-12
2,1CSE,LI38P,1.120000e-12,1.020000e-07,294.0,6.671276,1.019989e-07
3,1CSE,LI38I,1.120000e-12,1.720000e-10,294.0,2.940988,1.708800e-10
4,1CSE,LI38D,1.120000e-12,1.920000e-09,294.0,4.350434,1.918880e-09
...,...,...,...,...,...,...,...
5748,3QIB,KP8R,5.500000e-06,2.400000e-04,298.0,2.235909,2.345000e-04
5749,3QIB,TP11A,5.500000e-06,1.100000e-03,298.0,3.137419,1.094500e-03
5750,3QIB,TP11S,5.500000e-06,3.380000e-05,298.0,1.075181,2.830000e-05
5751,3QIB,TP11N,5.500000e-06,4.340000e-05,298.0,1.223219,3.790000e-05


In [321]:
# change Dataframe column name
info.rename(columns={'pdb': 'wt_name'}, inplace=True)
info.rename(columns={'mutation': 'mut_name'}, inplace=True)
info["mut_name"] = info["wt_name"] + "_" + info["mut_name"]

info.drop(columns=['affinity_wt'], inplace=True)
info.drop(columns=['affinity_mut'], inplace=True)
info.drop(columns=['temp'], inplace=True)

info["target"] = info["ddg"]
torch.save(info, "./data/skempi_v2/dataset_unshuffled.pt")
info

Unnamed: 0,wt_name,mut_name,ddg,daffinity,target
0,1CSE,1CSE_LI38G,2.248833,5.148000e-11,2.248833
1,1CSE,1CSE_LI38S,1.172229,7.210000e-12,1.172229
2,1CSE,1CSE_LI38P,6.671276,1.019989e-07,6.671276
3,1CSE,1CSE_LI38I,2.940988,1.708800e-10,2.940988
4,1CSE,1CSE_LI38D,4.350434,1.918880e-09,4.350434
...,...,...,...,...,...
5748,3QIB,3QIB_KP8R,2.235909,2.345000e-04,2.235909
5749,3QIB,3QIB_TP11A,3.137419,1.094500e-03,3.137419
5750,3QIB,3QIB_TP11S,1.075181,2.830000e-05,1.075181
5751,3QIB,3QIB_TP11N,1.223219,3.790000e-05,1.223219


### mutation-level split

In [322]:
# shuffle
info_shuffled = info.sample(frac=1, random_state=42).reset_index(drop=True)

# split folds
folds_num = 10
index = list(range(len(info_shuffled)))
index_k_split = np.array_split(index, folds_num)
index_k_split = [np.full(len(index_k_split[i]), i) for i in range(folds_num)]
index_k_split = np.concatenate(index_k_split)
info_shuffled["split"] = index_k_split

torch.save(info_shuffled, "./data/skempi_v2/dataset_processed_mutation_level.pt")
info_shuffled

Unnamed: 0,wt_name,mut_name,ddg,daffinity,target,split
0,1DAN,"1DAN_TT55A,FU50A",4.496988,3.847400e-09,4.496988,0
1,3F1S,3F1S_DA36A,2.184384,8.190000e-08,2.184384,0
2,1R0R,"1R0R_EI5D,AI10V,TI12S,RI16M,NI23S,GI27N,NI31D,...",0.478587,3.730000e-11,0.478587,0
3,2B2X,"2B2X_VH50T,EH64K,QL28S,YL52N,SL91R",3.770126,8.485400e-07,3.770126,0
4,3KBH,3KBH_DA20A,0.872968,2.290000e-07,0.872968,0
...,...,...,...,...,...,...
5748,3BT1,3BT1_VU222A,-0.484935,-2.600000e-10,-0.484935,9
5749,4RS1,"4RS1_EA25K,DA28K",4.120828,1.998100e-06,4.120828,9
5750,4NKQ,"4NKQ_DC99K,RB80D",0.176364,1.700000e-10,0.176364,9
5751,1K8R,1K8R_RA41A,1.226585,1.891000e-06,1.226585,9


### structure-level split

In [None]:
print(len(set(info["wt_name"])))

info["split"] = -1
info

In [325]:
# split folds
folds_num = 10

i = 0
step = 1
pdb_cnt = info["wt_name"].value_counts()
for pdb, cnt in pdb_cnt.items():
    if i == folds_num or i == -1:
        step = -step
        i+=step

    info.loc[info["wt_name"] == pdb, "split"] = i
    # print(f"{pdb}: {cnt}")
    i+=step

torch.save(info, "./data/skempi_v2/dataset_processed_structure_level.pt")
info

Unnamed: 0,wt_name,mut_name,ddg,daffinity,target,split
0,1CSE,1CSE_LI38G,2.248833,5.148000e-11,2.248833,1
1,1CSE,1CSE_LI38S,1.172229,7.210000e-12,1.172229,1
2,1CSE,1CSE_LI38P,6.671276,1.019989e-07,6.671276,1
3,1CSE,1CSE_LI38I,2.940988,1.708800e-10,2.940988,1
4,1CSE,1CSE_LI38D,4.350434,1.918880e-09,4.350434,1
...,...,...,...,...,...,...
5748,3QIB,3QIB_KP8R,2.235909,2.345000e-04,2.235909,5
5749,3QIB,3QIB_TP11A,3.137419,1.094500e-03,3.137419,5
5750,3QIB,3QIB_TP11S,1.075181,2.830000e-05,1.075181,5
5751,3QIB,3QIB_TP11N,1.223219,3.790000e-05,1.223219,5


In [326]:

grouped = info.groupby('split')

# 创建一个包含所有组的列表
groups = [group for _, group in grouped]
for g in groups:
    print(f'Split {g.iloc[0]["split"]} size: {len(g)}')

Split 0 size: 634
Split 1 size: 632
Split 2 size: 624
Split 3 size: 606
Split 4 size: 607
Split 5 size: 566
Split 6 size: 545
Split 7 size: 513
Split 8 size: 517
Split 9 size: 509


# process subset

### s4169

In [3]:
import pandas as pd
import numpy as np
import torch

s4169 = pd.read_csv("./datasets/skempi_v2/S4169.csv")
s4169


Unnamed: 0,protein,Partners(A_B),pa,pb,mutation,DDG,mode
0,1E50,A_B,x,x,A:D8A,0.402,forward
1,1E50,A_B,x,x,A:K86M,-0.320,forward
2,1E50,A_B,x,x,A:M48A,-1.320,forward
3,1E50,A_B,x,x,A:N11A,-0.678,forward
4,1E50,A_B,x,x,A:N51A,-2.433,forward
...,...,...,...,...,...,...,...
4164,5XCO,A_B,x,x,B:P6A,-1.079,forward
4165,5XCO,A_B,x,x,B:S10A,-1.286,forward
4166,5XCO,A_B,x,x,B:V14A,0.000,forward
4167,5XCO,A_B,x,x,B:Y11A,-1.900,forward


In [4]:
def tidy_mutation(mutation):
    chain = mutation.split(':')[0]
    transition = mutation.split(':')[1]

    mutation = transition[0] + chain + transition[1:]
    return mutation

s = "A:K86M"
tidy_mutation(s)

'KA86M'

In [5]:

s4169.rename(columns={'protein': 'wt_name'}, inplace=True)
s4169.rename(columns={'mutation': 'mut_name'}, inplace=True)
# s4169["wt_name"] = s4169["wt_name"] + '_' + s4169["Partners(A_B)"]
s4169["mut_name"] = s4169["mut_name"].apply(tidy_mutation)
s4169["mut_name"] = s4169["wt_name"] + '_' + s4169["mut_name"]

s4169.drop(columns=['Partners(A_B)', 'pa','pb', 'mode'], inplace=True)
s4169["target"] = s4169["DDG"]
s4169

Unnamed: 0,wt_name,mut_name,DDG,target
0,1E50,1E50_DA8A,0.402,0.402
1,1E50,1E50_KA86M,-0.320,-0.320
2,1E50,1E50_MA48A,-1.320,-1.320
3,1E50,1E50_NA11A,-0.678,-0.678
4,1E50,1E50_NA51A,-2.433,-2.433
...,...,...,...,...
4164,5XCO,5XCO_PB6A,-1.079,-1.079
4165,5XCO,5XCO_SB10A,-1.286,-1.286
4166,5XCO,5XCO_VB14A,0.000,0.000
4167,5XCO,5XCO_YB11A,-1.900,-1.900


In [6]:
# drop invalid items
import os

drop_list = []
for idx in range(len(s4169)):
    # wt = s4169.loc[idx, 'pdb']
    mut_name = s4169.loc[idx, 'mut_name']
    mut_pdb_filename = mut_name + '.txt'
    if not os.path.exists(f"./data/skempi_v2/seq/{mut_pdb_filename}"):
        drop_list.append(idx)
print(f"items to drop: {s4169.loc[drop_list]}")
s4169 = s4169.drop(index=drop_list)
s4169.reset_index(drop=True, inplace=True)
s4169

items to drop:      wt_name      mut_name    DDG  target
311     1AO7     1AO7_YC8A -3.168  -3.168
343     1BD2     1BD2_YC5A -3.408  -3.408
433     1C1Y    1C1Y_DA38A -2.620  -2.620
1142    1JCK    1JCK_NB23A -2.077  -2.077
1144    1JCK   1JCK_QB210A -2.077  -2.077
1307    1KBH   1KBH_LA930W -0.851  -0.851
1308    1KBH   1KBH_QA896W -0.138  -0.138
1309    1KBH  1KBH_YB1172W -0.596  -0.596
1359    1LFD    1LFD_DB38A -3.747  -3.747
1388    1LFD    1LFD_YB32W  0.211   0.211
2157    2AJF   2AJF_TE159S -1.824  -1.824
3731    4G0N    4G0N_DA38A -3.194  -3.194


Unnamed: 0,wt_name,mut_name,DDG,target
0,1E50,1E50_DA8A,0.402,0.402
1,1E50,1E50_KA86M,-0.320,-0.320
2,1E50,1E50_MA48A,-1.320,-1.320
3,1E50,1E50_NA11A,-0.678,-0.678
4,1E50,1E50_NA51A,-2.433,-2.433
...,...,...,...,...
4152,5XCO,5XCO_PB6A,-1.079,-1.079
4153,5XCO,5XCO_SB10A,-1.286,-1.286
4154,5XCO,5XCO_VB14A,0.000,0.000
4155,5XCO,5XCO_YB11A,-1.900,-1.900


In [50]:
subset = s4169

# mutation-level split
# shuffle
subset_shuffled = subset.sample(frac=1, random_state=42).reset_index(drop=True)

# split folds
folds_num = 10
index = list(range(len(subset_shuffled)))
index_k_split = np.array_split(index, folds_num)
index_k_split = [np.full(len(index_k_split[i]), i) for i in range(folds_num)]
index_k_split = np.concatenate(index_k_split)
subset_shuffled["split"] = index_k_split

torch.save(subset_shuffled, "./data/skempi_v2/dataset_s4169_mutation_level.pt")
subset_shuffled

Unnamed: 0,wt_name,mut_name,DDG,target,split
0,2SIC,2SIC_MI67E,-0.796,-0.796,0
1,1AO7,1AO7_YC5A,-1.533,-1.533,0
2,2NYY,2NYY_KA895A,0.141,0.141,0
3,3BT1,3BT1_MU4A,-0.037,-0.037,0
4,1C4Z,1C4Z_KD6E,-1.555,-1.555,0
...,...,...,...,...,...
4152,3SGB,3SGB_AI9M,-0.089,-0.089,9
4153,1C4Z,1C4Z_KD97A,-1.302,-1.302,9
4154,3NGB,3NGB_RH54A,0.255,0.255,9
4155,4I77,4I77_WH52A,-2.115,-2.115,9


In [51]:
subset = s4169

# structure-level split
print(len(set(subset["wt_name"])))

subset["split"] = -1

# split folds
folds_num = 10

i = 0
step = 1
pdb_cnt = subset["wt_name"].value_counts()
for pdb, cnt in pdb_cnt.items():
    if i == folds_num or i == -1:
        step = -step
        i+=step

    subset.loc[subset["wt_name"] == pdb, "split"] = i
    # print(f"{pdb}: {cnt}")
    i+=step

torch.save(subset, "./data/skempi_v2/dataset_s4169_structure_level.pt")
subset

318


Unnamed: 0,wt_name,mut_name,DDG,target,split
0,1E50,1E50_DA8A,0.402,0.402,7
1,1E50,1E50_KA86M,-0.320,-0.320,7
2,1E50,1E50_MA48A,-1.320,-1.320,7
3,1E50,1E50_NA11A,-0.678,-0.678,7
4,1E50,1E50_NA51A,-2.433,-2.433,7
...,...,...,...,...,...
4152,5XCO,5XCO_PB6A,-1.079,-1.079,7
4153,5XCO,5XCO_SB10A,-1.286,-1.286,7
4154,5XCO,5XCO_VB14A,0.000,0.000,7
4155,5XCO,5XCO_YB11A,-1.900,-1.900,7


In [52]:

grouped = subset.groupby('split')

# 创建一个包含所有组的列表
groups = [group for _, group in grouped]
for g in groups:
    print(f'Split {g.iloc[0]["split"]} size: {len(g)}')

Split 0 size: 504
Split 1 size: 456
Split 2 size: 455
Split 3 size: 456
Split 4 size: 458
Split 5 size: 410
Split 6 size: 357
Split 7 size: 356
Split 8 size: 356
Split 9 size: 349


### s8338

In [53]:
# structure-level split
s8338 = subset.copy()

s8338["wt_name"], s8338["mut_name"] = s8338["mut_name"], s8338["wt_name"]
s8338["DDG"] = -s8338["DDG"]
s8338["target"] = -s8338["target"]
s8338 = pd.concat([subset, s8338], ignore_index=True)

torch.save(s8338, "./data/skempi_v2/dataset_s8338_structure_level.pt")
s8338

Unnamed: 0,wt_name,mut_name,DDG,target,split
0,1E50,1E50_DA8A,0.402,0.402,7
1,1E50,1E50_KA86M,-0.320,-0.320,7
2,1E50,1E50_MA48A,-1.320,-1.320,7
3,1E50,1E50_NA11A,-0.678,-0.678,7
4,1E50,1E50_NA51A,-2.433,-2.433,7
...,...,...,...,...,...
8309,5XCO_PB6A,5XCO,1.079,1.079,7
8310,5XCO_SB10A,5XCO,1.286,1.286,7
8311,5XCO_VB14A,5XCO,-0.000,-0.000,7
8312,5XCO_YB11A,5XCO,1.900,1.900,7


In [54]:

grouped = s8338.groupby('split')

# 创建一个包含所有组的列表
groups = [group for _, group in grouped]
for g in groups:
    print(f'Split {g.iloc[0]["split"]} size: {len(g)}')

Split 0 size: 1008
Split 1 size: 912
Split 2 size: 910
Split 3 size: 912
Split 4 size: 916
Split 5 size: 820
Split 6 size: 714
Split 7 size: 712
Split 8 size: 712
Split 9 size: 698


In [7]:
subset = s4169

# mutation-level split
# shuffle
subset_shuffled = subset.sample(frac=1, random_state=4200).reset_index(drop=True)

# split folds
folds_num = 10
index = list(range(len(subset_shuffled)))
index_k_split = np.array_split(index, folds_num)
index_k_split = [np.full(len(index_k_split[i]), i) for i in range(folds_num)]
index_k_split = np.concatenate(index_k_split)
subset_shuffled["split"] = index_k_split

s4169_reversed = subset_shuffled.copy()
s4169_reversed["wt_name"], s4169_reversed["mut_name"] = s4169_reversed["mut_name"], s4169_reversed["wt_name"]
s4169_reversed["DDG"] = -s4169_reversed["DDG"]
s4169_reversed["target"] = -s4169_reversed["target"]
s8338 = pd.concat([subset_shuffled, s4169_reversed], ignore_index=True)

torch.save(s8338, "./data/skempi_v2/dataset_s8338_mutation_level.pt")
# s8338.to_csv("./data/skempi_v2/dataset_s8338_mutation_level.csv", index=False)
s8338


Unnamed: 0,wt_name,mut_name,DDG,target,split
0,3SGB,3SGB_EI13K,-0.169,-0.169,0
1,2FTL,2FTL_TI11A,-4.093,-4.093,0
2,1A22,1A22_FA165A,-0.411,-0.411,0
3,1A22,1A22_RB166A,-0.056,-0.056,0
4,1DAN,1DAN_WU68F,-0.123,-0.123,0
...,...,...,...,...,...
8309,3BT1_DU22A,3BT1,0.232,0.232,9
8310,1K8R_RA41A,1K8R,1.180,1.180,9
8311,1VFB_SB28E,1VFB,-0.104,-0.104,9
8312,3SGB_NI30E,3SGB,0.963,0.963,9


### m1707

In [29]:
m1707 = pd.read_csv("./datasets/skempi_v2/M1707.csv")
m1707

Unnamed: 0,PDB id,Partner1,DDE_vdw,Mutation(s)_PDB,Mutation(s)_cleaned,DDGexp,Label,MutaBind2,DDE_vdw.1,DDG_solv,DDG_fold,SA_com_wt,SA_part_wt,CS,dE_vdw_wt
0,1A22,A_B,0.9647,"DA11A,MA14V,HA18Q,RA19H,FA25A,QA29K,EA33R","DA11A,MA14V,HA18Q,RA19H,FA25A,QA29K,EA33R",0.75,forward,0.97,0.9647,-0.5058,0.4267,-0.5577,-0.1895,0.6154,-0.6276
1,1A22,A_B,0.6049,"NA12H,FA25L","NA12H,FA25L",0.84,forward,0.81,0.6049,0.0131,-1.3867,-0.2188,0.0067,0.9308,0.0119
2,1A22,A_B,-0.1449,"NA12R,MA14V,LA15V,RA16L,RA19Y","NA12R,MA14V,LA15V,RA16L,RA19Y",1.68,forward,1.67,-0.1449,0.5161,0.8117,-0.7707,-0.0641,0.9119,-0.4371
3,1A22,A_B,1.0421,"NA12R,MA14V,LA15V,RA16L,RA19Y,FA25S,DA26E,QA29...","NA12R,MA14V,LA15V,RA16L,RA19Y,FA25S,DA26E,QA29...",1.40,forward,1.62,1.0421,-0.3728,0.8401,-0.4965,-0.1383,0.5519,-0.6493
4,1A22,A_B,0.7230,"QA22N,FA25S,DA26E,QA29S,EA30Q,EA33K","QA22N,FA25S,DA26E,QA29S,EA30Q,EA33K",-0.09,forward,0.73,0.7230,-0.5699,0.0625,-0.4992,-0.0637,0.7109,-0.4647
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1702,4G0N,A_B,-0.5875,"FB61W,NB71R,VB88I","FB8W,NB18R,VB35I",-0.36,reverse,-0.75,-0.5875,0.5925,0.5858,0.0315,0.2279,-1.6484,-0.7917
1703,4G0N,A_B,0.6227,"FB61W,RB67L,NB71R,VB88I","FB8W,RB14L,NB18R,VB35I",-0.60,reverse,-0.78,0.6227,-0.0514,1.0229,-0.0750,0.2502,-2.3153,-1.0766
1704,4G0N,A_B,0.3294,"FB61W,RB67L,VB69E,NB71R,KB84R,VB88I","FB8W,RB14L,VB16E,NB18R,KB31R,VB35I",-3.60,reverse,-1.02,0.3294,0.0614,0.5816,0.3131,0.1210,-2.3634,-0.9024
1705,4G0N,A_B,0.2245,"FB61W,VB69E,NB71R,VB88I","FB8W,VB16E,NB18R,VB35I",-1.82,reverse,-1.22,0.2245,0.4388,-0.8488,0.3134,0.4395,-2.1040,-0.5284


In [30]:

m1707.rename(columns={'PDB id': 'wt_name'}, inplace=True)
m1707.rename(columns={'Mutation(s)_cleaned': 'mut_name'}, inplace=True)
m1707["mut_name"] = m1707["wt_name"] + '_' + m1707["mut_name"]

m1707 = m1707.loc[:, ['wt_name', 'mut_name', 'DDGexp', 'Label']]
m1707["target"] = m1707["DDGexp"]
m1707

Unnamed: 0,wt_name,mut_name,DDGexp,Label,target
0,1A22,"1A22_DA11A,MA14V,HA18Q,RA19H,FA25A,QA29K,EA33R",0.75,forward,0.75
1,1A22,"1A22_NA12H,FA25L",0.84,forward,0.84
2,1A22,"1A22_NA12R,MA14V,LA15V,RA16L,RA19Y",1.68,forward,1.68
3,1A22,"1A22_NA12R,MA14V,LA15V,RA16L,RA19Y,FA25S,DA26E...",1.40,forward,1.40
4,1A22,"1A22_QA22N,FA25S,DA26E,QA29S,EA30Q,EA33K",-0.09,forward,-0.09
...,...,...,...,...,...
1702,4G0N,"4G0N_FB8W,NB18R,VB35I",-0.36,reverse,-0.36
1703,4G0N,"4G0N_FB8W,RB14L,NB18R,VB35I",-0.60,reverse,-0.60
1704,4G0N,"4G0N_FB8W,RB14L,VB16E,NB18R,KB31R,VB35I",-3.60,reverse,-3.60
1705,4G0N,"4G0N_FB8W,VB16E,NB18R,VB35I",-1.82,reverse,-1.82


In [31]:
# drop invalid items
import os

drop_list = []
for idx in range(len(m1707)):
    mut_name = m1707.loc[idx, 'mut_name']
    mut_pdb_filename = mut_name + '.txt'
    if not os.path.exists(f"./data/skempi_v2/seq/{mut_pdb_filename}"):
        drop_list.append(idx)
print(f"items to drop: {m1707.loc[drop_list]}")
m1707 = m1707.drop(index=drop_list)
m1707.reset_index(drop=True, inplace=True)
m1707

items to drop:        wt_name                                           mut_name  DDGexp  \
3         1A22  1A22_NA12R,MA14V,LA15V,RA16L,RA19Y,FA25S,DA26E...    1.40   
10        1A22  1A22_SA57T,TA60A,SA62T,NA63G,RA64K,EA65D,TA67A...    1.68   
79        1BP3  1BP3_SA57T,TA60A,SA62T,NA63G,RA64K,EA65D,TA67A...    2.30   
171   1.00E+50                               1.00E+50_DA52A,EA53A    0.39   
354       1KBH                               1KBH_YB1172W,IB1126V    0.81   
...        ...                                                ...     ...   
1560      1A22  1A22_NA12R,MA14V,LA15V,RA16L,RA19Y,FA25S,DA26E...   -1.40   
1566      1A22  1A22_SA57T,TA60A,SA62T,NA63G,RA64K,EA65D,TA67A...   -1.68   
1593      1BP3  1BP3_SA57T,TA60A,SA62T,NA63G,RA64K,EA65D,TA67A...   -2.30   
1632      1PPF  1PPF_EI10D,AI15V,TI17S,EI19D,RI21M,KI29T,GI32N...   -4.53   
1638      1PPF  1PPF_EI10D,AI15V,YI20D,RI21M,KI29I,GI32H,NI36Y...   -4.86   

        Label  target  
3     forward    1.40  
10    forwar

Unnamed: 0,wt_name,mut_name,DDGexp,Label,target
0,1A22,"1A22_DA11A,MA14V,HA18Q,RA19H,FA25A,QA29K,EA33R",0.75,forward,0.75
1,1A22,"1A22_NA12H,FA25L",0.84,forward,0.84
2,1A22,"1A22_NA12R,MA14V,LA15V,RA16L,RA19Y",1.68,forward,1.68
3,1A22,"1A22_QA22N,FA25S,DA26E,QA29S,EA30Q,EA33K",-0.09,forward,-0.09
4,1A22,"1A22_FA25A,YA42A,QA46A",0.20,forward,0.20
...,...,...,...,...,...
1597,4G0N,"4G0N_FB8W,NB18R,VB35I",-0.36,reverse,-0.36
1598,4G0N,"4G0N_FB8W,RB14L,NB18R,VB35I",-0.60,reverse,-0.60
1599,4G0N,"4G0N_FB8W,RB14L,VB16E,NB18R,KB31R,VB35I",-3.60,reverse,-3.60
1600,4G0N,"4G0N_FB8W,VB16E,NB18R,VB35I",-1.82,reverse,-1.82


In [32]:
# mutation-level split

forward = m1707[m1707["Label"] == "forward"].reset_index(drop=True)
reverse = m1707[m1707["Label"] == "reverse"].reset_index(drop=True)
print(f"forward: {len(forward)}; revers: {len(reverse)}")


# shuffle
subset = forward
subset_shuffled = subset.sample(frac=1, random_state=42).reset_index(drop=True)

# split folds
folds_num = 10
index = list(range(len(subset_shuffled)))
index_k_split = np.array_split(index, folds_num)
index_k_split = [np.full(len(index_k_split[i]), i) for i in range(folds_num)]
index_k_split = np.concatenate(index_k_split)
subset_shuffled["split"] = index_k_split
forward = subset_shuffled
forward


forward: 1237; revers: 365


Unnamed: 0,wt_name,mut_name,DDGexp,Label,target,split
0,1YCS,"1YCS_MA37L,VA107A,NA143Y,NA172D,RA153S,HA72R",0.50,forward,0.50,0
1,1EMV,"1EMV_FB86A,YA53A",8.26,forward,8.26,0
2,3KUD,"3KUD_KB30A,FB6W",2.17,forward,2.17,0
3,3NGB,"3NGB_AH57G,VH58T,PH63K,VH74T",1.67,forward,1.67,0
4,3S9D,"3S9D_EB66A,RA120A",3.63,forward,3.63,0
...,...,...,...,...,...,...
1232,3SGB,"3SGB_KI7T,PI8E,AI9Y",-3.82,forward,-3.82,9
1233,3VR6,"3VR6_LD386A,LE386A,LF389A",1.14,forward,1.14,9
1234,4G0N,"4G0N_VB16A,EA37A",2.80,forward,2.80,9
1235,3IDX,"3IDX_WG13M,MG247A,CG14W,CG129V",-0.65,forward,-0.65,9


In [33]:

reverse["split"] = -1
for row in range(len(reverse)):
    mut_name = reverse.loc[row, "mut_name"]
    split = forward[forward["mut_name"] == mut_name]["split"].item()
    reverse.loc[row, "split"] = split

reverse["wt_name"], reverse["mut_name"] = reverse["mut_name"], reverse["wt_name"]
reverse

Unnamed: 0,wt_name,mut_name,DDGexp,Label,target,split
0,"1A4Y_KB40G,YA434A,DA435A",1A4Y,-6.21,reverse,-6.21,7
1,"1A4Y_KB40G,YA434A,YA437A",1A4Y,-9.12,reverse,-9.12,7
2,"1A4Y_KB40G,YA434F",1A4Y,-6.66,reverse,-6.66,2
3,"1A4Y_RB5A,YA434A",1A4Y,-6.98,reverse,-6.98,9
4,"1A4Y_RB5A,YA434A,DA435A",1A4Y,-10.10,reverse,-10.10,5
...,...,...,...,...,...,...
360,"4G0N_FB8W,NB18R,VB35I",4G0N,-0.36,reverse,-0.36,2
361,"4G0N_FB8W,RB14L,NB18R,VB35I",4G0N,-0.60,reverse,-0.60,2
362,"4G0N_FB8W,RB14L,VB16E,NB18R,KB31R,VB35I",4G0N,-3.60,reverse,-3.60,0
363,"4G0N_FB8W,VB16E,NB18R,VB35I",4G0N,-1.82,reverse,-1.82,2


In [45]:
m1707_mutation = pd.concat([forward, reverse], ignore_index=True)
m1707_mutation.sort_values(by=['split', 'Label'], inplace=True)
m1707_mutation.drop(columns=['Label'], inplace=True)

torch.save(m1707_mutation, "./data/skempi_v2/dataset_m1707_mutation_level.pt")
# m1707_mutation.to_csv("./data/skempi_v2/dataset_m1707_mutation_level.csv", index=False)
m1707_mutation

Unnamed: 0,wt_name,mut_name,DDGexp,target,split
0,1YCS,"1YCS_MA37L,VA107A,NA143Y,NA172D,RA153S,HA72R",0.50,0.50,0
1,1EMV,"1EMV_FB86A,YA53A",8.26,8.26,0
2,3KUD,"3KUD_KB30A,FB6W",2.17,2.17,0
3,3NGB,"3NGB_AH57G,VH58T,PH63K,VH74T",1.67,1.67,0
4,3S9D,"3S9D_EB66A,RA120A",3.63,3.63,0
...,...,...,...,...,...
1572,"2VN5_SB42A,TB43F,AB14S,LB15T",2VN5,-0.36,-0.36,9
1578,"3KUD_KB30A,FB6W,NB16R",3KUD,-1.20,-1.20,9
1579,"3KUD_KB30A,FB6W,NB16R,VB33I",3KUD,-1.95,-1.95,9
1590,"3MZW_NB18T,AB49S",3MZW,-0.09,-0.09,9


In [36]:
# structure-level split
print(len(set(m1707["wt_name"])))

m1707["split"] = -1

# split folds
folds_num = 10

i = 0
step = 1
pdb_cnt = m1707["wt_name"].value_counts()
for pdb, cnt in pdb_cnt.items():
    if i == folds_num or i == -1:
        step = -step
        i+=step

    m1707.loc[m1707["wt_name"] == pdb, "split"] = i
    # print(f"{pdb}: {cnt}")
    i+=step

m1707

117


Unnamed: 0,wt_name,mut_name,DDGexp,Label,target,split
0,1A22,"1A22_DA11A,MA14V,HA18Q,RA19H,FA25A,QA29K,EA33R",0.75,forward,0.75,8
1,1A22,"1A22_NA12H,FA25L",0.84,forward,0.84,8
2,1A22,"1A22_NA12R,MA14V,LA15V,RA16L,RA19Y",1.68,forward,1.68,8
3,1A22,"1A22_QA22N,FA25S,DA26E,QA29S,EA30Q,EA33K",-0.09,forward,-0.09,8
4,1A22,"1A22_FA25A,YA42A,QA46A",0.20,forward,0.20,8
...,...,...,...,...,...,...
1597,4G0N,"4G0N_FB8W,NB18R,VB35I",-0.36,reverse,-0.36,5
1598,4G0N,"4G0N_FB8W,RB14L,NB18R,VB35I",-0.60,reverse,-0.60,5
1599,4G0N,"4G0N_FB8W,RB14L,VB16E,NB18R,KB31R,VB35I",-3.60,reverse,-3.60,5
1600,4G0N,"4G0N_FB8W,VB16E,NB18R,VB35I",-1.82,reverse,-1.82,5


In [37]:
reverse_rows = m1707["Label"] == "reverse"
m1707.loc[reverse_rows,['wt_name', 'mut_name']] = m1707.loc[reverse_rows,['mut_name', 'wt_name']].values
m1707.drop(columns=['Label'], inplace=True)

torch.save(m1707, "./data/skempi_v2/dataset_m1707_structure_level.pt")
m1707

Unnamed: 0,wt_name,mut_name,DDGexp,target,split
0,1A22,"1A22_DA11A,MA14V,HA18Q,RA19H,FA25A,QA29K,EA33R",0.75,0.75,8
1,1A22,"1A22_NA12H,FA25L",0.84,0.84,8
2,1A22,"1A22_NA12R,MA14V,LA15V,RA16L,RA19Y",1.68,1.68,8
3,1A22,"1A22_QA22N,FA25S,DA26E,QA29S,EA30Q,EA33K",-0.09,-0.09,8
4,1A22,"1A22_FA25A,YA42A,QA46A",0.20,0.20,8
...,...,...,...,...,...
1597,"4G0N_FB8W,NB18R,VB35I",4G0N,-0.36,-0.36,5
1598,"4G0N_FB8W,RB14L,NB18R,VB35I",4G0N,-0.60,-0.60,5
1599,"4G0N_FB8W,RB14L,VB16E,NB18R,KB31R,VB35I",4G0N,-3.60,-3.60,5
1600,"4G0N_FB8W,VB16E,NB18R,VB35I",4G0N,-1.82,-1.82,5


In [38]:

grouped = m1707.groupby('split')

# 创建一个包含所有组的列表
groups = [group for _, group in grouped]
for g in groups:
    print(f'Split {g.iloc[0]["split"]} size: {len(g)}')

Split 0 size: 215
Split 1 size: 178
Split 2 size: 171
Split 3 size: 163
Split 4 size: 148
Split 5 size: 147
Split 6 size: 145
Split 7 size: 145
Split 8 size: 143
Split 9 size: 147


In [8]:
import torch
import pandas as pd

s4169_mutation = torch.load("./data/skempi_v2/dataset_s4169_mutation_level.pt")
s4169_structure = torch.load("./data/skempi_v2/dataset_s4169_structure_level.pt")
s8338_mutation = torch.load("./data/skempi_v2/dataset_s8338_mutation_level.pt")
s8338_structure = torch.load("./data/skempi_v2/dataset_s8338_structure_level.pt")
m1707_mutation = torch.load("./data/skempi_v2/dataset_m1707_mutation_level.pt")
m1707_structure = torch.load("./data/skempi_v2/dataset_m1707_structure_level.pt")


In [47]:

s4169_mutation.to_csv("./data/skempi_v2/s4169_mutation.csv",index=False)
s4169_mutation

Unnamed: 0,wt_name,mut_name,DDG,target,split
0,2SIC,2SIC_MI67E,-0.796,-0.796,0
1,1AO7,1AO7_YC5A,-1.533,-1.533,0
2,2NYY,2NYY_KA895A,0.141,0.141,0
3,3BT1,3BT1_MU4A,-0.037,-0.037,0
4,1C4Z,1C4Z_KD6E,-1.555,-1.555,0
...,...,...,...,...,...
4152,3SGB,3SGB_AI9M,-0.089,-0.089,9
4153,1C4Z,1C4Z_KD97A,-1.302,-1.302,9
4154,3NGB,3NGB_RH54A,0.255,0.255,9
4155,4I77,4I77_WH52A,-2.115,-2.115,9


In [48]:
s4169_structure.to_csv("./data/skempi_v2/s4169_structure.csv",index=False)
s4169_structure


Unnamed: 0,wt_name,mut_name,DDG,target,split
0,1E50,1E50_DA8A,0.402,0.402,7
1,1E50,1E50_KA86M,-0.320,-0.320,7
2,1E50,1E50_MA48A,-1.320,-1.320,7
3,1E50,1E50_NA11A,-0.678,-0.678,7
4,1E50,1E50_NA51A,-2.433,-2.433,7
...,...,...,...,...,...
4152,5XCO,5XCO_PB6A,-1.079,-1.079,7
4153,5XCO,5XCO_SB10A,-1.286,-1.286,7
4154,5XCO,5XCO_VB14A,0.000,0.000,7
4155,5XCO,5XCO_YB11A,-1.900,-1.900,7


In [9]:

s8338_mutation.to_csv("./data/skempi_v2/s8338_mutation.csv",index=False)
s8338_mutation


Unnamed: 0,wt_name,mut_name,DDG,target,split
0,3SGB,3SGB_EI13K,-0.169,-0.169,0
1,2FTL,2FTL_TI11A,-4.093,-4.093,0
2,1A22,1A22_FA165A,-0.411,-0.411,0
3,1A22,1A22_RB166A,-0.056,-0.056,0
4,1DAN,1DAN_WU68F,-0.123,-0.123,0
...,...,...,...,...,...
8309,3BT1_DU22A,3BT1,0.232,0.232,9
8310,1K8R_RA41A,1K8R,1.180,1.180,9
8311,1VFB_SB28E,1VFB,-0.104,-0.104,9
8312,3SGB_NI30E,3SGB,0.963,0.963,9


In [50]:

s8338_structure.to_csv("./data/skempi_v2/s8338_structure.csv",index=False)
s8338_structure


Unnamed: 0,wt_name,mut_name,DDG,target,split
0,1E50,1E50_DA8A,0.402,0.402,7
1,1E50,1E50_KA86M,-0.320,-0.320,7
2,1E50,1E50_MA48A,-1.320,-1.320,7
3,1E50,1E50_NA11A,-0.678,-0.678,7
4,1E50,1E50_NA51A,-2.433,-2.433,7
...,...,...,...,...,...
8309,5XCO_PB6A,5XCO,1.079,1.079,7
8310,5XCO_SB10A,5XCO,1.286,1.286,7
8311,5XCO_VB14A,5XCO,-0.000,-0.000,7
8312,5XCO_YB11A,5XCO,1.900,1.900,7


In [51]:

m1707_mutation.to_csv("./data/skempi_v2/m1707_mutation.csv",index=False)
m1707_mutation


Unnamed: 0,wt_name,mut_name,DDGexp,target,split
0,1YCS,"1YCS_MA37L,VA107A,NA143Y,NA172D,RA153S,HA72R",0.50,0.50,0
1,1EMV,"1EMV_FB86A,YA53A",8.26,8.26,0
2,3KUD,"3KUD_KB30A,FB6W",2.17,2.17,0
3,3NGB,"3NGB_AH57G,VH58T,PH63K,VH74T",1.67,1.67,0
4,3S9D,"3S9D_EB66A,RA120A",3.63,3.63,0
...,...,...,...,...,...
1572,"2VN5_SB42A,TB43F,AB14S,LB15T",2VN5,-0.36,-0.36,9
1578,"3KUD_KB30A,FB6W,NB16R",3KUD,-1.20,-1.20,9
1579,"3KUD_KB30A,FB6W,NB16R,VB33I",3KUD,-1.95,-1.95,9
1590,"3MZW_NB18T,AB49S",3MZW,-0.09,-0.09,9


In [52]:

m1707_structure.to_csv("./data/skempi_v2/m1707_structure.csv",index=False)
m1707_structure

Unnamed: 0,wt_name,mut_name,DDGexp,target,split
0,1A22,"1A22_DA11A,MA14V,HA18Q,RA19H,FA25A,QA29K,EA33R",0.75,0.75,8
1,1A22,"1A22_NA12H,FA25L",0.84,0.84,8
2,1A22,"1A22_NA12R,MA14V,LA15V,RA16L,RA19Y",1.68,1.68,8
3,1A22,"1A22_QA22N,FA25S,DA26E,QA29S,EA30Q,EA33K",-0.09,-0.09,8
4,1A22,"1A22_FA25A,YA42A,QA46A",0.20,0.20,8
...,...,...,...,...,...
1597,"4G0N_FB8W,NB18R,VB35I",4G0N,-0.36,-0.36,5
1598,"4G0N_FB8W,RB14L,NB18R,VB35I",4G0N,-0.60,-0.60,5
1599,"4G0N_FB8W,RB14L,VB16E,NB18R,KB31R,VB35I",4G0N,-3.60,-3.60,5
1600,"4G0N_FB8W,VB16E,NB18R,VB35I",4G0N,-1.82,-1.82,5


In [17]:
m1707_structure[m1707_structure["split"] == 0]

for i in range(10):
    split = m1707_structure[m1707_structure["split"] == i]
    print(f"Split {i} size: {len(split)} pdb_cnt: {(set(split['wt_name'].str.split('_').str[0]))}")

Split 0 size: 215 pdb_cnt: {'1HE8', '1B41', '4JEU', '4YH7', '1DQJ', '4RA0', '4KRL', '5CYK', '4RS1', '1GC1', '1JTG'}
Split 1 size: 178 pdb_cnt: {'4UYP', '3BP8', '1VFB', '1MQ8', '1C1Y', '1A4Y', '2QJA', '5UFE', '3AAA', '1PPF', '1YQV'}
Split 2 size: 171 pdb_cnt: {'1FCC', '3KUD', '4UYQ', '1AHW', '2QJ9', '3HFM', '4KRP', '1BRS', '1YY9', '3BE1', '3LZF'}
Split 3 size: 163 pdb_cnt: {'5M2O', '1MHP', '1BP3', '4E6K', '4JPK', '1AK4', '3M63', '3SGB', '4J2L', '4MYW', '2QJB', '1DVF'}
Split 4 size: 148 pdb_cnt: {'4K71', '2G2U', '1Z7X', '1B2S', '3G6D', '1K8R', '1R0R', '4B0M', '1CZ8', '1WQJ', '3MZG', '1N8Z'}
Split 5 size: 147 pdb_cnt: {'4GU0', '4NKQ', '2PCC', '3NGB', '3MZW', '2O3B', '1FSS', '1GL1', '3BDY', '4G0N', '1QAB', '1B2U'}
Split 6 size: 145 pdb_cnt: {'5UFQ', '1MAH', '2B2X', '3L5X', '1MLC', '2KSO', '3VR6', '1GL0', '2NZ9', '4GNK', '1XXM', '1B3S'}
Split 7 size: 145 pdb_cnt: {'1JRH', '1TM1', '1GUA', '2NY7', '4X4M', '1Y4A', '2NOJ', '3S9D', '2NYY', '1C4Z', '1REW', '2J0T'}
Split 8 size: 143 pdb_cnt: {'1A2

In [44]:
len(m1707_structure[m1707_structure['wt_name'].str.split('_').str[0] == '1YQV'])

2