In [2]:
import os
import shutil
import argparse
from tqdm.auto import tqdm
import torch
from torch.nn.utils import clip_grad_norm_
# import torch_geometric
# assert not torch_geometric.__version__.startswith('2'), 'Please use torch_geometric lower than version 2.0.0'
from torch_geometric.loader import DataLoader

from models.surfgen import SurfGen
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.train import *
from utils.datasets.surfdata import SurfGenDataset
from time import time
from utils.train import get_model_loss


In [3]:
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='/home/haotian/molecules_confs/Protein_test/Pocket2Mol-main/configs/train.yml')
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--logdir', type=str, default='/home/haotian/molecules_confs/Protein_test/Pocket2Mol-main/logs')
args = parser.parse_args([])
base_path = '/home/haotian/Molecule_Generation/SurfGen'
args.config = os.path.join(base_path, 'configs/train_surf.yml')
args.logdir = os.path.join(base_path, 'logs')
config = load_config(args.config)
config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
seed_all(config.train.seed)
config.dataset.path = os.path.join(base_path, 'data/crossdocked_pocket10')
config.dataset.split = os.path.join(base_path, 'data/split_by_name.pt')
log_dir = get_new_log_dir(args.logdir, prefix=config_name)
ckpt_dir = os.path.join(log_dir, 'checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)
logger = get_logger('train', log_dir)
logger.info(args)
logger.info(config)

[2023-03-15 08:13:32,270::train::INFO] Namespace(config='/home/haotian/Molecule_Generation/SurfGen/configs/train_surf.yml', device='cpu', logdir='/home/haotian/Molecule_Generation/SurfGen/logs')
[2023-03-15 08:13:32,270::train::INFO] {'model': {'vn': 'vn', 'hidden_channels': 256, 'hidden_channels_vec': 64, 'encoder': {'name': 'cftfm', 'hidden_channels': 256, 'hidden_channels_vec': 64, 'edge_channels': 64, 'key_channels': 128, 'num_heads': 4, 'num_interactions': 6, 'cutoff': 10.0, 'knn': 24}, 'field': {'name': 'classifier', 'num_filters': 128, 'num_filters_vec': 32, 'edge_channels': 64, 'num_heads': 4, 'cutoff': 10.0, 'knn': 24}, 'position': {'num_filters': 128, 'n_component': 3}}, 'train': {'save': True, 'seed': 2021, 'use_apex': False, 'batch_size': 16, 'num_workers': 16, 'pin_memory': True, 'max_iters': 500000, 'val_freq': 5000, 'pos_noise_std': 0.1, 'max_grad_norm': 100.0, 'resume_train': False, 'ckpt_name': 'None', 'start_epoch': 'None', 'optimizer': {'type': 'adam', 'lr': 0.0002, 

In [4]:
protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()
masking = get_mask(config.train.transform.mask)
composer = AtomComposer(protein_featurizer.feature_dim, ligand_featurizer.feature_dim, config.model.encoder.knn)


In [5]:
edge_sampler = EdgeSample(config.train.transform.edgesampler)
cfg_ctr = config.train.transform.contrastive
contrastive_sampler = ContrastiveSample(cfg_ctr.num_real, cfg_ctr.num_fake, cfg_ctr.pos_real_std, cfg_ctr.pos_fake_std, config.model.field.knn)
transform = Compose([
    RefineData(),
    LigandCountNeighbors(),
    protein_featurizer,
    ligand_featurizer,
    masking,
    composer,

    FocalBuilder(),
    edge_sampler,
    contrastive_sampler,
])

dataset, subsets = get_dataset(
    config = config.dataset,
    transform = transform,
)
dataset, subsets = get_dataset(
    config = config.dataset,
    transform = transform,
)

In [12]:
import networkx as nx
from scipy.spatial import distance_matrix
from utils.protein_ligand import parse_sdf_file
from utils.surface import geodesic_matrix, read_ply_geom, dst2knn_graph
from torch_geometric.transforms import FaceToEdge
from torch_geometric.utils import geodesic_distance
from torch_geometric.data import Data

In [172]:
sdf_file = './covid_19/3cl_ligand.sdf'
surf_file = './covid_19/3cl_pocket_8.0_res_1.5.ply'
data = read_ply_geom(surf_file,read_face=True)
data = FaceToEdge()(data)
gds_mat = geodesic_matrix(data.pos,data.edge_index)
data.gds_mat = gds_mat

In [212]:
def dst2knnedge(dst_mat, num_knn=24, self_loop=False):
    knn_edge_index_src = []
    knn_edge_index_tgt = []
    knn_edge_dist = []
    num_nodes = dst_mat.shape[0]
    for node_idx in range(num_nodes):
        knn_edge_index_src.extend([node_idx]*num_knn)
        
        if self_loop:
            knn_edge_index_tgt.extend(np.argsort(dst_mat[node_idx])[:num_knn])
            knn_edge_dist.extend(np.sort(dst_mat[node_idx])[:num_knn])
        else:
            knn_edge_index_tgt.extend(np.argsort(dst_mat[node_idx])[1:num_knn+1])
            knn_edge_dist.extend(np.sort(dst_mat[node_idx])[1:num_knn+1])

    return torch.tensor(np.array([knn_edge_index_src,knn_edge_index_tgt])), torch.tensor(np.array(knn_edge_dist,dtype=np.float32))

In [213]:
gds_knn_edge_index, gds_knn_edge_dist = dst2knn_graph(data.gds_mat, num_knn=12)

In [214]:
tri_edge_index = data.edge_index

In [253]:
num_nodes = data.x.shape[0]

In [268]:
def gds_edge_process(tri_edge_index,gds_knn_edge_index,num_nodes):
    id_tri_edge = tri_edge_index[0] * num_nodes + tri_edge_index[1]
    id_gds_knn_edge = gds_knn_edge_index[0] * num_nodes + gds_knn_edge_index[1]
    idx_edge = [torch.nonzero(id_gds_knn_edge == id_) for id_ in id_tri_edge]
    idx_edge = torch.tensor([a.squeeze() if len(a) > 0 else torch.tensor(-1) for a in idx_edge], dtype=torch.long)
    compose_gds_edge_type = torch.zeros(len(gds_knn_edge_index[0]), dtype=torch.long) 
    compose_gds_edge_type[idx_edge[idx_edge>=0]] = torch.ones_like(idx_edge[idx_edge>=0])
    gds_edge_sca = F.one_hot(compose_gds_edge_type)
    return gds_edge_sca

In [269]:
gds_edge_sca = gds_edge_process(tri_edge_index, gds_knn_edge_index, num_nodes=num_nodes)

In [270]:
gds_edge_sca

tensor([[0, 1],
        [0, 1],
        [0, 1],
        ...,
        [1, 0],
        [1, 0],
        [1, 0]])

In [133]:
import torch.nn as nn

In [275]:
from models.common import GaussianSmearing

In [304]:
edge_channels = 6
edge_hidden = 16
num_edge_types = 2
cutoff = 10.
distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels - num_edge_types)

In [284]:
gds_edge_sca = gds_edge_sca.float()

In [309]:
geo_edge_sca = torch.cat([distance_expansion(gds_knn_edge_dist), gds_edge_sca], dim=-1)

In [310]:
geo_edge_sca.shape

torch.Size([2388, 6])

In [315]:
edge_SL = nn.Linear(in_features=edge_channels,out_features=edge_hidden)
edge_sca_feat = edge_SL(geo_edge_sca)

### vector mapping

In [329]:
edge_vec = (data.pos[gds_knn_edge_index[0]] -data.pos[gds_knn_edge_index[1]]).T

In [403]:
from models.invariant import VNLeakyReLU, VNLinear, GVPerceptronVN

In [327]:
VN_node = VNLinear(in_channels=4, out_channels=8)
VN_edge = VNLinear(in_channels=3, out_channels=6)

In [331]:
VN_edge(edge_vec)

torch.Size([6, 2388])

In [332]:
from torch.nn import Module

In [None]:
# 初始node_sca dim=4, edge_sca_dim=6
# node_vec_dim = 3, edge_vec_dim = 3 
class Geodesic_Message(Module):
    def __init__(self, node_sca_dim, node_vec_dim, edge_sca_dim, edge_vec_dim, out_sca, out_vec, cutoff=10.):
        super().__init__()
        edge_sca_sca = nn.Linear(edge_sca_dim, out_sca)
        
        edge_sca_vec = nn.Linear(edge_sca_dim, out_sca)
        node_sca_vec = nn.Linear(node_sca_dim, out_sca)
        edge_vec_vec = VNLinear(edge_vec_dim, out_vec)
        node_vec_vec = VNLinear(node_vec_dim, out_vec)
        
        encoder = GVPerceptronVN()
    
    def forward(node_feats, edge_feats, edge_index):
        

In [390]:
node_sca_dim = 4
node_vec_dim = 3
edge_sca_dim = 6
edge_vec_dim = 3
out_sca = 16
out_vec = 16
node_dim = 4

edge_sca_sca = nn.Linear(edge_sca_dim, node_dim)

edge_sca_vec = nn.Linear(edge_sca_dim, out_sca)
node_sca_vec = nn.Linear(node_sca_dim, out_sca)
edge_vec_vec = VNLinear(edge_vec_dim, out_vec)
node_vec_vec = VNLinear(node_vec_dim, out_vec)

In [422]:
encoder = GVPerceptronVN(node_dim,out_vec,out_sca,out_sca)

In [385]:
node_feats = (data.x, data.pos.reshape(-1,3,1)) 
edge_feats = (geo_edge_sca, edge_vec.reshape(-1,3,1))

In [340]:
edge_index = gds_knn_edge_index

In [356]:
edge_index_raw = edge_index[0]

In [414]:
msg_sca_emb = node_feats[0][edge_index_raw] * edge_sca_sca(edge_feats[0])

In [396]:
edge_sca_vec(edge_feats[0]).unsqueeze(-1).shape

torch.Size([2388, 16, 1])

In [415]:
msg_vec_emb1 = node_vec_vec(node_feats[1])[edge_index_raw] * edge_sca_vec(edge_feats[0]).unsqueeze(-1)
msg_vec_emb2 = node_sca_vec(node_feats[0])[edge_index_raw].unsqueeze(-1) * edge_vec_vec(edge_feats[1])
msg_vec_emb = msg_vec_emb1 + msg_vec_emb2

In [406]:
from torch_scatter import scatter_sum

In [416]:
num_nodes = node_feats[0].shape[0]
aggr_msg_sca = scatter_sum(msg_sca_emb, edge_index_raw, dim=0, dim_size=num_nodes)

In [417]:
aggr_msg_vec = scatter_sum(msg_vec_emb, edge_index_raw, dim=0, dim_size=num_nodes)

In [420]:
aggr_msg_sca.shape

torch.Size([199, 4])

In [419]:
aggr_msg_vec.shape

torch.Size([199, 16, 1])

In [424]:
encoder((aggr_msg_sca,aggr_msg_vec))[0]

tensor([[26.2894, -0.3607, 55.4214,  ..., 10.8450, 29.5320, 44.2554],
        [29.4692, -0.4008, 47.9377,  ...,  8.7002, 22.9243, 34.3272],
        [23.5407, -0.4631, 45.5106,  ..., 16.0130, 26.1938, 48.2630],
        ...,
        [26.4618, -0.3660, 59.3455,  ...,  6.3887, 29.1326, 34.0420],
        [28.3861, -0.3764, 59.3625,  ...,  7.4963, 30.8525, 35.3171],
        [26.9196, -0.3668, 59.2715,  ...,  6.5670, 29.5021, 34.1910]],
       grad_fn=<LeakyReluBackward0>)

In [409]:
aggr_msg_sca.shape

torch.Size([199, 16, 1])

In [410]:
aggr_msg_vec.shape

torch.Size([199, 16, 1])

In [387]:
node_vec_vec(node_feats[1])[edge_index_raw].shape

torch.Size([2388, 16, 1])

In [389]:
edge_sca_vec(edge_feats[0]).shape

torch.Size([2388, 32])

In [375]:
msg_sca_emb.shape

torch.Size([2388, 4])

In [376]:
node_feats[1].shape

torch.Size([199, 1, 3])

In [377]:
node_vec_vec

VNLinear(
  (map_to_feat): Linear(in_features=3, out_features=16, bias=True)
)

In [348]:
edge_sca_sca(edge_feats[0]).shape

torch.Size([2388, 32])

In [350]:
node_feats[0][edge_index[0]].shape

torch.Size([2388, 4])

In [352]:
node_feats[0][edge_index[0]]  edge_sca_sca(edge_feats[0])

RuntimeError: The size of tensor a (4) must match the size of tensor b (32) at non-singleton dimension 1

In [None]:
for batch in val_loader:
    break

In [None]:
surf_feature = batch.protein_surf_feature
surf_pos = batch.protein_pos

In [None]:
model = SurfGen(
    config.model, 
    num_classes = contrastive_sampler.num_elements,
    num_bond_types = edge_sampler.num_bond_types,
    protein_atom_feature_dim = protein_featurizer.feature_dim,
    ligand_atom_feature_dim = ligand_featurizer.feature_dim,
).to(args.device)
print('Num of parameters is {0:.4}M'.format(np.sum([p.numel() for p in model.parameters()]) /100000 ))


Num of parameters is 37.06M


In [None]:
def embed_compose(compose_feature, compose_pos, idx_ligand, idx_protein,
                                      ligand_atom_emb, protein_atom_emb,
                                      emb_dim):

    h_ligand = ligand_atom_emb(compose_feature[idx_ligand], compose_pos[idx_ligand])
    h_protein = protein_atom_emb(compose_feature[idx_protein], compose_pos[idx_protein])
    
    h_sca = torch.zeros([len(compose_pos), emb_dim[0]],).to(h_ligand[0])
    h_vec = torch.zeros([len(compose_pos), emb_dim[1], 3],).to(h_ligand[1])
    h_sca[idx_ligand], h_sca[idx_protein] = h_ligand[0], h_protein[0]
    h_vec[idx_ligand], h_vec[idx_protein] = h_ligand[1], h_protein[1]
    return [h_sca, h_vec]



In [None]:
compose_feature = batch.compose_feature.float()
compose_pos = batch.compose_pos
idx_ligand = batch.idx_ligand_ctx_in_compose
idx_protein = batch.idx_protein_in_compose

h_compose = embed_compose(compose_feature, compose_pos, idx_ligand, idx_protein,
                                model.ligand_atom_emb, model.protein_atom_emb, model.emb_dim)

In [None]:
compose_knn_edge_index = batch.compose_knn_edge_index
compose_knn_edge_feature = batch.compose_knn_edge_feature

In [None]:
compose_knn_edge_index

tensor([[   0,    0,    0,  ..., 3923, 3923, 3923],
        [   2,  240,    1,  ..., 3800, 3785, 3872]])

In [None]:
data = train_set[100]

In [None]:
pos = data.compose_pos
dis_mat = distance_matrix(pos, pos)
src, tgt = dst2knn_graph(dis_mat)

In [None]:
torch.tensor(tgt) == data.compose_knn_edge_index[1]

tensor([True, True, True,  ..., True, True, True])

In [None]:
data.compose_knn_edge_index

tensor([[  0,   0,   0,  ..., 188, 188, 188],
        [  6,  56,  10,  ..., 157,  54,   8]])

In [None]:
data.compose_knn_edge_feature

tensor([[0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]])

In [None]:
compose_knn_edge_feature = torch.cat([
    torch.ones([len(data.compose_knn_edge_index[0]), 1], dtype=torch.long),
    torch.zeros([len(data.compose_knn_edge_index[0]), 3], dtype=torch.long),
], dim=-1) 

In [None]:
compose_knn_edge_feature

tensor([[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]])

In [None]:
knn=24
len_ligand_ctx = len(data.ligand_context_pos)
len_compose = len_ligand_ctx + len(data.protein_pos)

In [None]:
len_ligand_ctx

14

In [None]:
len_compose

188

In [None]:
data.compose_knn_edge_index.shape

torch.Size([2, 4512])

In [None]:
id_ligand_ctx_edge = data.ligand_context_bond_index[0] * len_compose + data.ligand_context_bond_index[1]

In [None]:
id_ligand_ctx_edge

tensor([ 564,    3,    2,    1,  376,  382, 1130, 1136, 1510, 1509,  948,  949,
         941, 1697, 1702, 1703, 1889, 1887, 1326, 1320,  759,  753,  188,  193,
         192, 2077, 2081, 2080, 2455, 2267])

In [None]:
data.compose_knn_edge_index[0, :len_ligand_ctx*knn] * len_compose + data.compose_knn_edge_index[1, :len_ligand_ctx*knn]
#前半段是ligand node(knn个)

tensor([   3,    2,    1,    6,    4,    5,   88,   74,   67,  182,    8,  150,
         163,  128,   81,  118,  101,  160,  176,   91,  183,  159,    7,  153,
         192,  193,  188,  191,  195,  190,  196,  316,  197,  194,  198,  347,
         306,  341,  262,  365,  331,  338,  255,  348,  364,  269,  351,  251,
         376,  382,  477,  379,  464,  384,  467,  377,  443,  403,  381,  552,
         558,  397,  531,  559,  504,  450,  526,  539,  380,  495,  431,  457,
         564,  727,  746,  645,  566,  565,  714,  652,  724,  568,  638,  682,
         631,  570,  569,  717,  655,  692,  665,  572,  747,  723,  571,  707,
         753,  759,  905,  895,  762,  757,  752,  911,  878,  870,  755,  761,
         912,  929,  880,  790,  826,  833,  902,  754,  904,  760,  915,  805,
         948,  941,  949, 1068,  946,  950,  944,  940,  951,  942,  947,  995,
         975,  952, 1116,  953, 1117, 1099, 1067, 1003,  943, 1094,  977, 1014,
        1130, 1136, 1149, 1155, 1304, 11

In [None]:
id_compose_edge = data.compose_knn_edge_index[0, :len_ligand_ctx*knn] * len_compose + data.compose_knn_edge_index[1, :len_ligand_ctx*knn]
id_ligand_ctx_edge = data.ligand_context_bond_index[0] * len_compose + data.ligand_context_bond_index[1]

In [None]:
data.ligand_context_bond_index[0]

tensor([ 3,  0,  0,  0,  2,  2,  6,  6,  8,  8,  5,  5,  5,  9,  9,  9, 10, 10,
         7,  7,  4,  4,  1,  1,  1, 11, 11, 11, 13, 12])

In [None]:
id_compose_edge = data.compose_knn_edge_index[0, :len_ligand_ctx*knn] * len_compose + data.compose_knn_edge_index[1, :len_ligand_ctx*knn]
id_ligand_ctx_edge = data.ligand_context_bond_index[0] * len_compose + data.ligand_context_bond_index[1]
idx_edge = [torch.nonzero(id_compose_edge == id_) for id_ in id_ligand_ctx_edge]

In [None]:
torch.nonzero(id_ligand_ctx_edge[0] == id_compose_edge)

tensor([[72]])

In [None]:
idx_edge = torch.tensor([a.squeeze() if len(a) > 0 else torch.tensor(-1) for a in idx_edge], dtype=torch.long)

In [None]:
idx_edge

tensor([ 72,   0,   1,   2,  48,  49, 144, 145, 193, 192, 120, 122, 121, 217,
        216, 218, 241, 240, 169, 168,  97,  96,  26,  25,  24, 266, 265, 264,
        312, 288])

In [None]:
data.ligand_context_bond_type

tensor([2, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1, 2,
        1, 1, 2, 2, 2, 2])

In [None]:
data.compose_knn_edge_feature[idx_edge[idx_edge>=0]] = F.one_hot(data.ligand_context_bond_type[idx_edge>=0], num_classes=4) 

In [None]:
data.compose_knn_edge_feature

tensor([[0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]])