# PBkcat

## 背景介绍：
催化常数 kcat（catalytic constant，又称转换数 turnover number）：在底物浓度处于饱和状态下，一个酶分子（或一个酶活性位点）单位时间内转化的底物分子数。它定义了反应的最大化学转化速率，其值为 $k_{cat} = V_{max}/[E]_t$，其中 $[E]_t$ 为酶的总浓度。

查尔姆斯理工大学（Chalmers University of Technology）的研究团队提出了深度学习方法 DLKcat 来预测所有代谢酶与其底物的 kcat 值，只需要底物 SMILES 信息和酶的蛋白质序列作为输入，从而为任何物种提供通用的 kcat 预测工具。文章中使用 Pearson 相关系数（衡量预测值与实际值之间的线性相关程度）对模型进行评估，表现还不错，但是在实际测试中该方法预测出来的 kcat 值误差非常大。

为此，我们希望建立 PBkcat 模型：该模型使用 ProteinBERT 提取蛋白质氨基酸序列特征、使用 GNN 提取底物分子图特征，再使用 attention 机制对 kcat 值进行大规模预测。我们希望借此进一步探究蛋白质和底物分子与 kcat 值之间的联系，从而更加准确地预测 kcat 值。

## 相关算法介绍：

### ProteinBERT
2018年，Devlin等提出了基于深度双向Transformer的预训练语言模型BERT；2022年，Brandes等借鉴这一思路提出了专门面向蛋白质序列的预训练模型ProteinBERT，旨在以一种自然的方式捕获蛋白质的局部和全局表示。ProteinBERT在涵盖不同蛋白质属性（包括蛋白质结构、翻译后修饰和生物物理属性）的多个基准上获得了最先进的性能。

### Substructure-based graph neural network (sub-GNN)
对于底物的处理，我们打算通过GNN来提取底物的分子图特征。分子图是以原子为节点、化学键为边的图形表示：节点存储信息（标签），例如原子类型、电荷、多重性和质量，而边存储键级（bond order）。节点和边还都可以带有关于芳香性和立体异构的信息。

----------

In [None]:
#@title clone and download dependencies
# Fetch the project sources and run everything else from inside the checkout.
! git clone https://github.com/950288/PBkcat_test.git
%cd PBkcat_test
# Install the Python packages listed in the repository's dependencies.txt.
! pip install  -r dependencies.txt
# Download the pretrained ProteinBERT dump into ./preprocess, where the protein
# preprocessing cell loads it from (-nc skips the download if the file exists).
! wget -nc -P preprocess ftp://ftp.cs.huji.ac.il/users/nadavb/protein_bert/epoch_92400_sample_23500000.pkl 

In [None]:
#@title convert substrate smiles to fingerprints for training
# ! python preprocess/substrate.py

import json
from rdkit import Chem
from collections import defaultdict
import numpy as np
import pickle
import tqdm

# Global vocabularies shared by every molecule processed in this cell. Each
# defaultdict hands out the next unused integer ID (its current length) the
# first time a key is looked up, so equal keys always map to the same ID.
atom_dict = defaultdict(lambda: len(atom_dict))  # atom symbol or (symbol, 'aromatic') -> ID
bond_dict = defaultdict(lambda: len(bond_dict))  # str(bond type) -> ID
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))  # subgraph signature -> ID
edge_dict = defaultdict(lambda: len(edge_dict))  # ((node ID, node ID), edge ID) -> ID

radius = 2  # number of refinement rounds passed to extract_fingerprints
ngram = 3  # NOTE(review): not referenced in the visible code — confirm before removing

def create_atoms(mol):
    """Map every atom of ``mol`` to its integer ID in the global ``atom_dict``.

    Aromatic atoms are keyed as ``(symbol, 'aromatic')`` so that they receive
    an ID distinct from the plain symbol of the same element.  Returns a NumPy
    array with one ID per atom, in the order yielded by ``mol.GetAtoms()``.
    """
    aromatic_indices = {atom.GetIdx() for atom in mol.GetAromaticAtoms()}
    labels = [
        (atom.GetSymbol(), 'aromatic') if position in aromatic_indices
        else atom.GetSymbol()
        for position, atom in enumerate(mol.GetAtoms())
    ]
    return np.array([atom_dict[label] for label in labels])

def create_ijbonddict(mol):
    """Build the bond-annotated adjacency list of ``mol``.

    Returns a defaultdict that maps each bonded atom index to a list of
    ``(neighbor index, bond ID)`` tuples.  Every bond is recorded in both
    directions, and bond IDs come from the global ``bond_dict`` keyed by
    ``str(bond type)``.
    """
    neighbors_by_atom = defaultdict(list)
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_id = bond_dict[str(bond.GetBondType())]
        neighbors_by_atom[begin].append((end, bond_id))
        neighbors_by_atom[end].append((begin, bond_id))
    return neighbors_by_atom



def extract_fingerprints(atoms, i_jbond_dict, radius):
    """Extract r-radius subgraph IDs (fingerprints) with a Weisfeiler-Lehman scheme.

    Args:
        atoms: sequence of integer atom IDs, indexed by atom position.
        i_jbond_dict: mapping from atom index to a list of
            ``(neighbor index, bond ID)`` tuples.  Atoms without any bond may
            be missing from the mapping.
        radius: number of refinement rounds; ``0`` returns the IDs of the
            bare atoms.

    Returns:
        A NumPy array holding one fingerprint ID per atom.  IDs are drawn from
        the global ``fingerprint_dict``; refined edge IDs are registered in the
        global ``edge_dict``.
    """
    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict

        for _ in range(radius):
            # Relabel every node with its own ID plus the multiset of
            # (neighbor ID, edge ID) pairs.  The pairs are sorted so the
            # signature does not depend on the order in which the bonds were
            # listed; atoms without neighbors get an empty tuple.
            fingerprints = []
            for i in range(len(nodes)):
                neighbors = sorted(
                    (nodes[j], edge) for j, edge in i_jedge_dict.get(i, ())
                )
                fingerprints.append(fingerprint_dict[(nodes[i], tuple(neighbors))])
            nodes = fingerprints

            # Relabel every edge from the new IDs of its two end nodes.
            # Iterate over items() so each entry stays keyed by its real atom
            # index: the mapping's keys follow bond-traversal order, so the
            # position of a value is not the atom index.
            _i_jedge_dict = defaultdict(list)
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    _i_jedge_dict[i].append((j, edge_dict[(both_side, edge)]))
            i_jedge_dict = _i_jedge_dict

    return np.array(fingerprints)

def create_adjacency(mol):
    """Return the adjacency matrix RDKit computes for ``mol``."""
    return Chem.GetAdjacencyMatrix(mol)

def dump_dictionary(dictionary, filename):
    """Pickle ``dictionary`` to ``filename`` as a plain ``dict``.

    The mapping is copied into a built-in dict first, so a ``defaultdict``
    is stored without its default factory.
    """
    plain_copy = dict(dictionary)
    with open(filename, 'wb') as handle:
        pickle.dump(plain_copy, handle)

def save_array(array, filename):
    """Pickle ``array`` unchanged into the file at ``filename``."""
    with open(filename, 'wb') as handle:
        pickle.dump(array, handle)

if __name__ == "__main__":
    # Convert the substrate SMILES of every record into fingerprint IDs and an
    # adjacency matrix, then pickle the results and the ID vocabularies
    # under ./data.

    with open('./data/Kcat_combination_0918.json', 'r') as infile:
        Kcat_data = json.load(infile)

    compound_fingerprints = []
    adjacencies = []
    Kcats = []

    for index, data in enumerate(tqdm.tqdm(Kcat_data)):
        smiles = data['Smiles']
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # RDKit returns None for a SMILES string it cannot parse.  Skipping
            # the record would leave these lists out of step with the protein
            # representations built from the same JSON file, so stop with a
            # message that identifies the offending record.
            raise ValueError(
                'Record %d has a SMILES string RDKit cannot parse: %r'
                % (index, smiles)
            )
        mol = Chem.AddHs(mol)  # Add explicit hydrogens
        atoms = create_atoms(mol)  # Atom IDs

        i_jbond_dict = create_ijbonddict(mol)  # Graph structure
        fingerprints = extract_fingerprints(atoms, i_jbond_dict, radius)

        Kcats.append(float(data['Value']))
        compound_fingerprints.append(fingerprints)
        adjacencies.append(create_adjacency(mol))

    save_array(Kcats, './data/Kcats.pickle')
    save_array(compound_fingerprints, './data/compound_fingerprints.pickle')
    save_array(adjacencies, './data/adjacencies.pickle')
    dump_dictionary(atom_dict, './data/atom_dict.pickle')
    dump_dictionary(bond_dict, './data/bond_dict.pickle')
    dump_dictionary(fingerprint_dict, './data/fingerprint_dict.pickle')
    dump_dictionary(edge_dict, './data/edge_dict.pickle')

    print('compound_fingerprints, adjacencies, atom_dict, bond_dict, fingerprint_dict, edge_dict saved successfully!')

In [None]:
#@title preprocess protein
# ! python preprocess/proteinBERT_local_rep.py

from proteinbert import load_pretrained_model
import numpy as np
import pickle
import json

# Load the pretrained ProteinBERT dump stored under ./preprocess.
pretrained_model_generator, input_encoder = load_pretrained_model(
    local_model_dump_dir="./preprocess",
    local_model_dump_file_name='epoch_92400_sample_23500000.pkl',
)

# Filled with the per-residue representations produced further down.
local_representations = []

# Same records the substrate preprocessing reads, so both outputs line up.
with open('./data/Kcat_combination_0918.json', 'r') as infile:
    Kcat_data = json.load(infile)

def dump_dictionary(dictionary, filename):
    """Pickle ``dictionary`` to ``filename`` as a plain ``dict``.

    The mapping is copied into a built-in dict first, so a ``defaultdict``
    is stored without its default factory.
    """
    plain_copy = dict(dictionary)
    with open(filename, 'wb') as handle:
        pickle.dump(plain_copy, handle)

def save_array(array, filename):
    """Pickle ``array`` unchanged into the file at ``filename``."""
    with open(filename, 'wb') as handle:
        pickle.dump(array, handle)

# Gather every protein sequence and the length of the longest one.
sequences = [str(data['Sequence']) for data in Kcat_data]
if not sequences:
    raise ValueError('No sequences found in ./data/Kcat_combination_0918.json')
max_len = max(len(data['Sequence']) for data in Kcat_data)

# NOTE(review): the +2 presumably leaves room for the start/end tokens the
# ProteinBERT encoder adds — confirm against the proteinbert documentation.
max_len += 2

model = pretrained_model_generator.create_model(max_len)

# Encode and embed the sequences in chunks.  The per-chunk outputs are
# collected in a list and concatenated once at the end, instead of
# re-copying the whole accumulated array on every iteration.
step = 256
chunk_representations = []
for i in range(0, len(sequences), step):
    print(i, '/' , len(sequences))
    sequences_ = sequences[i:i+step]
    input_ids = input_encoder.encode_X(sequences_, max_len)
    # predict() returns two outputs; only the first (the local
    # representations) is kept.
    local_representations_, _ = model.predict(input_ids, batch_size=16)
    chunk_representations.append(local_representations_)

local_representations = np.concatenate(chunk_representations, axis=0)

print(local_representations.shape)
save_array(local_representations, './data/local_representations.pickle')

print('saved successfully!')

In [None]:
#@title train model
import model.model as model
import torch
import random
import timeit
import json

if __name__ == "__main__":
    # Train the Kcat predictor on the preprocessed substrate and protein
    # features, then save the weights, the hyper-parameters and the
    # per-epoch metrics under ./model/output/.

    model_name = 'Kcat'

    # Hyper-parameters handed to model.KcatPrediction and written to the
    # args JSON file.
    args = {
        "dim" : 10,
        "layer_output" : 3,
        "layer_gnn" : 3,
        "layer_dnn" : 3,
        "lr" : 1e-3,
        "weight_decay": 1e-6,
        "epoch" : 100
    }

    file_model = './model/output/' + model_name
    file_MAEs  = './model/output/' + model_name + '-MAEs.csv'
    file_args  = './model/output/' + model_name + '-args.json'

    dir_input = './data/'
    compound_fingerprints = model.load_pickle(dir_input + 'compound_fingerprints.pickle')
    adjacencies = model.load_pickle(dir_input + 'adjacencies.pickle')
    proteins_local = model.load_pickle(dir_input + 'local_representations.pickle')
    fingerprint_dict = model.load_pickle(dir_input + 'fingerprint_dict.pickle')
    args['len_fingerprint'] = len(fingerprint_dict)
    Kcat = model.load_pickle(dir_input + 'Kcats.pickle')
    # The targets are stored as Python floats; an integer tensor would
    # silently truncate them (0.5 -> 0), so keep them as floats.
    Kcat = torch.tensor(Kcat, dtype=torch.float32)

    if not (len(compound_fingerprints) == len(adjacencies) == len(proteins_local) == len(Kcat)):
        # Raise instead of calling exit(): exit() would shut down a notebook
        # kernel and report success to a calling shell.
        raise ValueError('The length of compound_fingerprints, adjacencies and proteins are not equal !!!')

    # Shuffle, then split 80 % / 10 % / 10 % into train / dev / test.
    # NOTE(review): dataset_test is never evaluated below — the "*_test"
    # metrics are computed on dataset_dev.
    dataset = list(zip(compound_fingerprints, adjacencies, proteins_local, Kcat))
    random.shuffle(dataset)
    dataset_train, dataset_ = model.split_dataset(dataset, 0.8)
    dataset_dev, dataset_test = model.split_dataset(dataset_, 0.5)

    # CPU or GPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print('The code uses GPU !!!')
    else:
        device = torch.device('cpu')
        print('The code uses CPU !!!')

    torch.manual_seed(random.randint(1, 10000))
    Kcatpredictor = model.KcatPrediction(args, device).to(device)
    trainer = model.Trainer(Kcatpredictor)
    tester = model.Tester(Kcatpredictor)

    # Record the hyper-parameters before training starts.
    with open(file_args, 'w') as f:
        f.write(str(json.dumps(args)) + '\n')

    # Training loop: one pass over the training split and one evaluation on
    # the dev split per epoch.
    print('Training...')
    MAEs = []
    start = timeit.default_timer()
    for epoch in range(0, args["epoch"]):
        print('Epoch: %d / %d' % (epoch + 1, args["epoch"]))
        LOSS_train, RMSE_train, R2_train = trainer.train(dataset_train)
        LOSS_test, RMSE_test, R2_test = tester.test(dataset_dev)
        elapsed = timeit.default_timer() - start
        MAEs.append([epoch+1, elapsed, LOSS_train, RMSE_train, R2_train,
                     LOSS_test, RMSE_test, R2_test])

    # Save the trained weights and report the path actually written.
    torch.save(Kcatpredictor.state_dict(), file_model + ".pth")
    print('Model saved to %s' % (file_model + ".pth"))

    # Write one CSV row per epoch.  Values are joined with str() so that
    # non-builtin numeric scalars are written as numbers, not as their repr().
    with open(file_MAEs, 'w') as f:
        f.write('epoch, time, LOSS_train, RMSE_train, R2_train, LOSS_test, RMSE_test, R2_test\n')
        for row in MAEs:
            f.write(', '.join(str(value) for value in row) + '\n')
    print('MAEs saved to %s' % file_MAEs)