In [1]:
import os
import shutil
import argparse
from tqdm.auto import tqdm
import torch
from torch.nn.utils import clip_grad_norm_
# import torch_geometric
# assert not torch_geometric.__version__.startswith('2'), 'Please use torch_geometric lower than version 2.0.0'
from torch_geometric.loader import DataLoader

# from models.surfgen import SurfGen
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.train import *
# from utils.datasets.surfdata import SurfGenDataset

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='/home/haotian/Molecule_Generation/SurfGen_release/configs/train.yml')
parser.add_argument('--device', type=str, default='gpu')
parser.add_argument('--logdir', type=str, default='/home/haotian/Molecule_Generation/SurfGen_release/logs')
args = parser.parse_args([])

config = load_config(args.config)
config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
seed_all(config.train.seed)
log_dir = get_new_log_dir(args.logdir, prefix=config_name)
ckpt_dir = os.path.join(log_dir, 'checkpoints')
logger = get_logger('train', log_dir)
# logger.info(args)
# logger.info(config)

In [3]:
protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()
masking = get_mask(config.train.transform.mask)
composer = AtomComposer(protein_featurizer.feature_dim, ligand_featurizer.feature_dim, config.model.encoder.knn)
edge_sampler = EdgeSample(config.train.transform.edgesampler)
cfg_ctr = config.train.transform.contrastive
contrastive_sampler = ContrastiveSample(cfg_ctr.num_real, cfg_ctr.num_fake, cfg_ctr.pos_real_std, cfg_ctr.pos_fake_std, config.model.field.knn)
transform = Compose([
    RefineData(),
    LigandCountNeighbors(),
    protein_featurizer,
    ligand_featurizer,
    masking,
    composer,

    FocalBuilder(),
    edge_sampler,
    contrastive_sampler,
])

dataset, subsets = get_dataset(
    config = config.dataset,
    transform = transform,
)

### dataloading

In [4]:
subsets['train']

<torch.utils.data.dataset.Subset at 0x7fdd55c887c0>

In [5]:
from utils.surface import geodesic_matrix, dst2knnedge, read_ply_geom
from torch_geometric.transforms import FaceToEdge

In [6]:
surf_base = '/home/haotian/Molecule_Generation/SurfGen/data/crossdock2020_surface_8'

In [7]:
data = dataset[0]

In [8]:
surf_file = osp.join(surf_base,data.protein_filename[:-6]+'_8.0_res_1.5.ply')
data_tmp = read_ply_geom(surf_file,read_face=True)
data_tmp = FaceToEdge()(data_tmp)
gds_mat = geodesic_matrix(data_tmp.pos,data_tmp.edge_index)

data.gds_mat = gds_mat
gds_knn_edge_index, gds_knn_edge_dist = dst2knnedge(data.gds_mat, num_knn=24)

In [9]:
def gds_edge_process(tri_edge_index,gds_knn_edge_index,num_nodes):
    id_tri_edge = tri_edge_index[0] * num_nodes + tri_edge_index[1]
    id_gds_knn_edge = gds_knn_edge_index[0] * num_nodes + gds_knn_edge_index[1]
    idx_edge = [torch.nonzero(id_gds_knn_edge == id_) for id_ in id_tri_edge]
    idx_edge = torch.tensor([a.squeeze() if len(a) > 0 else torch.tensor(-1) for a in idx_edge], dtype=torch.long)
    compose_gds_edge_type = torch.zeros(len(gds_knn_edge_index[0]), dtype=torch.long) 
    compose_gds_edge_type[idx_edge[idx_edge>=0]] = torch.ones_like(idx_edge[idx_edge>=0])
    gds_edge_sca = F.one_hot(compose_gds_edge_type)
    return gds_edge_sca

In [10]:
data

ProteinLigandData(
  protein_feature=[176, 4],
  protein_pos=[176, 3],
  ligand_element=[21],
  ligand_pos=[21, 3],
  ligand_bond_index=[2, 48],
  ligand_bond_type=[48],
  ligand_center_of_mass=[3],
  ligand_atom_feature=[21, 8],
  ligand_nbh_list={
    0=[3],
    1=[2],
    2=[3],
    3=[2],
    4=[2],
    5=[3],
    6=[2],
    7=[1],
    8=[3],
    9=[3],
    10=[3],
    11=[2],
    12=[2],
    13=[3],
    14=[2],
    15=[1],
    16=[2],
    17=[2],
    18=[2],
    19=[3],
    20=[2]
  },
  protein_filename='1B57_HUMAN_25_300_0/5u98_D_rec_5u98_1kx_lig_tt_min_0_pocket10.pdb',
  ligand_filename='1B57_HUMAN_25_300_0/5u98_D_rec_5u98_1kx_lig_tt_min_0.sdf',
  id=0,
  ligand_num_neighbors=[21],
  ligand_atom_valence=[21],
  ligand_atom_num_bonds=[21, 3],
  protein_surf_feature=[176, 5],
  ligand_atom_feature_full=[21, 13],
  context_idx=[8],
  masked_idx=[13],
  ligand_masked_element=[13],
  ligand_masked_pos=[13, 3],
  ligand_context_element=[8],
  ligand_context_feature_full=[8, 13],
  li

In [11]:
# basic -> besic+gds
num_nodes = data.protein_feature.shape[0]
dlny_edge_index = data_tmp.edge_index
gds_edge_sca = gds_edge_process(dlny_edge_index, gds_knn_edge_index, num_nodes=num_nodes)

data.dlny_edge_index = dlny_edge_index
data.gds_edge_sca = gds_edge_sca 
data.gds_knn_edge_index = gds_knn_edge_index
data.gds_dist = gds_knn_edge_dist

## embedding atoms

In [12]:
from models.embedding import  AtomEmbedding
protein_atom_feature_dim = protein_featurizer.feature_dim   # 5
ligand_atom_feature_dim = ligand_featurizer.feature_dim   # 13
emb_dim = [256, 64]#[config.hidden_channels, config.hidden_channels_vec]
protein_atom_emb = AtomEmbedding(protein_atom_feature_dim, 1, *emb_dim)
ligand_atom_emb = AtomEmbedding(ligand_atom_feature_dim, 1, *emb_dim)

In [13]:
from models.surfgen import embed_compose

compose_feature = data.compose_feature
compose_pos = data.compose_pos
idx_ligand = data.idx_ligand_ctx_in_compose
idx_protein = data.idx_protein_in_compose

h_compose = embed_compose(compose_feature, compose_pos, idx_ligand, idx_protein,
                                ligand_atom_emb, protein_atom_emb, emb_dim)
# h_compose[0].shape = torch.Size([182, 256])
# h_compose[1].shape = torch.Size([182, 64, 3])

In [14]:
class EdgeMapping(nn.Module):
    def __init__(self, edge_channels):
        super().__init__()
        self.nn = nn.Linear(in_features=1, out_features=edge_channels, bias=False)
    
    def forward(self, edge_vector):
        edge_vector = edge_vector / (torch.norm(edge_vector, p=2, dim=1, keepdim=True)+1e-7)
        expansion = self.nn(edge_vector.unsqueeze(-1)).transpose(1, -1)
        return expansion

In [15]:
from models.model_utils import  GaussianSmearing
from models.invariant import GVPerceptronVN, GVLinear, VNLinear
from math import  pi
from torch_scatter import scatter_sum
from models.interaction.geodesic import Geodesic_GNN
from models.interaction.geoattn import Geoattn_GNN

In [16]:
edge_index = data.gds_knn_edge_index
pos = data.protein_pos
edge_vector = pos[edge_index[0]] - pos[edge_index[1]]
edge_feature = data.gds_edge_sca
node_feats = [h_compose[0][data.idx_protein_in_compose], h_compose[1][data.idx_protein_in_compose]]
gds_dist = data.gds_dist

In [17]:
# detailed of representation learning 
layer1 = Geodesic_GNN()
layer2 = Geodesic_GNN()
out_sca, out_vec = layer1(node_feats, edge_feature, edge_vector, edge_index, gds_dist)
out_sca, out_vec = layer2([out_sca, out_vec], edge_feature, edge_vector, edge_index, gds_dist)

In [18]:
edge_dim = 64
num_edge_types= 2
cutoff = 10
node_sca_dim=256
node_vec_dim=64
hid_dim = 128
out_sca_dim = 256
out_vec_dim = 64

edge_expansion = EdgeMapping(edge_dim)
distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_dim - num_edge_types)

node_mapper = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim)
edge_mapper = GVLinear(edge_dim,edge_dim,node_sca_dim,node_vec_dim)

edge_sca_sca = nn.Linear(node_sca_dim, hid_dim)
node_sca_sca = nn.Linear(node_sca_dim, hid_dim)

edge_sca_vec = nn.Linear(node_sca_dim, hid_dim)
node_sca_vec = nn.Linear(node_sca_dim, hid_dim)
edge_vec_vec = VNLinear(node_vec_dim, hid_dim)
node_vec_vec = VNLinear(node_vec_dim, hid_dim)

msg_out = GVLinear(hid_dim, hid_dim, out_sca_dim, out_vec_dim)

resi_connecter = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim)
aggr_out = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim)

### Geodesic GNN is performed on the protein nodes --2 layers
### Geoattnn GNN is performed on the protein-lignad graphs -- 4layer

In [21]:
# edge_fetures expansion (from original)
edge_dist = torch.norm(edge_vector, dim=-1, p=2)
edge_sca_feat = torch.cat([distance_expansion(edge_dist), edge_feature], dim=-1)
edge_vec_feat = edge_expansion(edge_vector) 

In [22]:
## message passing
# first mapping (message)
edge_index_row = edge_index[0]
node_sca_feats, node_vec_feats = node_mapper(node_feats)
edge_sca_feat, edge_vec_feat = edge_mapper([edge_sca_feat, edge_vec_feat])
node_sca_feats, node_vec_feats = node_sca_feats[edge_index_row], node_vec_feats[edge_index_row]
# vec interacte with sca, edge interact with node
coeff = 0.5 * (torch.cos(edge_dist * pi / cutoff) + 1.0)
coeff = coeff * (edge_dist <= cutoff) * (edge_dist >= 0.0)
# compute the scalar message
msg_sca_emb = node_sca_sca(node_sca_feats) * edge_sca_sca(edge_sca_feat)
msg_sca_emb = msg_sca_emb * coeff.view(-1,1)

# compute the vector message
msg_vec_emb1 = node_vec_vec(node_vec_feats) * edge_sca_vec(edge_sca_feat).unsqueeze(-1)
msg_vec_emb2 = node_sca_vec(node_sca_feats).unsqueeze(-1) * edge_vec_vec(edge_vec_feat)
msg_vec_emb = msg_vec_emb1 + msg_vec_emb2
msg_vec_emb = msg_vec_emb * coeff.view(-1,1,1)

msg_sca_emb, msg_vec_emb = msg_out([msg_sca_emb, msg_vec_emb])

In [23]:
## aggregation
aggr_msg_sca = scatter_sum(msg_sca_emb, edge_index_row, dim=0, dim_size=num_nodes)
aggr_msg_vec = scatter_sum(msg_vec_emb, edge_index_row, dim=0, dim_size=num_nodes)
resi_sca, resi_vec = resi_connecter(node_feats)
out_sca = resi_sca + aggr_msg_sca
out_vec = resi_vec + aggr_msg_vec
out_sca, out_vec = aggr_out([out_sca, out_vec])

## Geoattn GNN

### prepare mapped features

In [24]:
from torch_scatter import scatter_softmax

In [25]:
edge_dim = 64
num_edge_types= 4
cutoff = 10
node_sca_dim=256
node_vec_dim=64
hid_dim = 128
out_sca_dim = 256
out_vec_dim = 64

edge_expansion = EdgeMapping(edge_dim)
distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_dim - num_edge_types)

node_mapper = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim)
edge_mapper = GVLinear(edge_dim,edge_dim,node_sca_dim,node_vec_dim)

edge_net = nn.Linear(node_sca_dim, hid_dim)
node_net = nn.Linear(node_sca_dim, hid_dim)

edge_sca_net = nn.Linear(node_sca_dim, hid_dim)
node_sca_net = nn.Linear(node_sca_dim, hid_dim)
edge_vec_net = VNLinear(node_vec_dim, hid_dim)
node_vec_net = VNLinear(node_vec_dim, hid_dim)


sca_attn_net = nn.Linear(node_sca_dim*2+1, hid_dim)
vec_attn_net = VNLinear(node_vec_dim, hid_dim)
softmax = scatter_softmax  
sigmoid = nn.Sigmoid()

msg_out = GVLinear(hid_dim, hid_dim, out_sca_dim, out_vec_dim)

resi_connecter = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim)
aggr_out = GVLinear(node_sca_dim,node_vec_dim,node_sca_dim,node_vec_dim)


In [26]:
h_ligpkt_sca = torch.cat([h_compose[0][data.idx_ligand_ctx_in_compose], out_sca], dim=0)
h_ligpkt_vec = torch.cat([h_compose[1][data.idx_ligand_ctx_in_compose], out_vec], dim=0)
node_feats = [h_ligpkt_sca,h_ligpkt_vec]

In [27]:
h_ligpkt_sca = torch.cat([h_compose[0][data.idx_ligand_ctx_in_compose], out_sca], dim=0)
h_ligpkt_vec = torch.cat([h_compose[1][data.idx_ligand_ctx_in_compose], out_vec], dim=0)
node_feats = [h_ligpkt_sca,h_ligpkt_vec]
pos = data.compose_pos
edge_index = data.compose_knn_edge_index
edge_feature = data.compose_knn_edge_feature
edge_vector = pos[edge_index[0]] - pos[edge_index[1]]

layer4 = Geoattn_GNN()
layer3 = Geoattn_GNN()
out_sca, out_evc = layer3(node_feats, edge_feature, edge_vector, edge_index)
out_sca, out_evc = layer4([out_sca, out_evc], edge_feature, edge_vector, edge_index)

In [28]:
# edge_fetures expansion (from original)
edge_dist = torch.norm(edge_vector, dim=-1, p=2)
edge_sca_feat = torch.cat([distance_expansion(edge_dist), edge_feature], dim=-1)
edge_vec_feat = edge_expansion(edge_vector) 

In [29]:
## message passing
# first mapping (message)
edge_index_row = edge_index[0]
node_sca_feats, node_vec_feats = node_mapper(node_feats)
edge_sca_feat, edge_vec_feat = edge_mapper([edge_sca_feat, edge_vec_feat])
node_sca_feats, node_vec_feats = node_sca_feats[edge_index_row], node_vec_feats[edge_index_row]

In [30]:
# compute the attention score \alpha_ij and A_ij
alpha_sca = torch.cat([node_sca_feats[edge_index[0]], node_sca_feats[edge_index[1]], edge_dist.unsqueeze(-1)], dim=-1)
alpha_sca = sca_attn_net(alpha_sca)
alpha_sca = softmax(alpha_sca,edge_index_row,dim=0)

alpha_vec_hid = vec_attn_net(node_vec_feats)
alpha_vec = (alpha_vec_hid[edge_index[0]] * alpha_vec_hid[edge_index[1]]).sum(-1).sum(-1)
alpha_vec = sigmoid(alpha_vec)

In [31]:
alpha_sca.shape

torch.Size([8832, 128])

In [34]:
# the scalar feats
node_sca_feat =  node_net(node_sca_feats)[edge_index_row] * edge_net(edge_sca_feat) 
# the equivariant interaction between node feature and edge feature
node_sca_hid = node_sca_net(node_sca_feats)[edge_index_row].unsqueeze(-1)
edge_vec_hid = edge_vec_net(edge_vec_feat)
node_vec_hid = node_vec_net(node_vec_feats)[edge_index_row]
edge_sca_hid =  edge_sca_net(edge_sca_feat).unsqueeze(-1)

In [35]:
coeff = 0.5 * (torch.cos(edge_dist * pi / cutoff) + 1.0)
coeff = coeff * (edge_dist <= cutoff) * (edge_dist >= 0.0)

In [36]:
node_sca_feat.shape

torch.Size([8832, 128])

In [37]:
alpha_sca.shape

torch.Size([8832, 128])

In [40]:
msg_sca = node_sca_feat * alpha_sca 
msg_vec = (node_sca_hid * edge_vec_hid + node_vec_hid*edge_sca_hid)*alpha_vec.unsqueeze(-1).unsqueeze(-1)

In [41]:
msg_sca,msg_vec = msg_out([msg_sca,msg_vec])

In [42]:

from models.interaction.geoattn import  Geoattn_GNN

In [45]:

# vec interacte with sca, edge interact with node
coeff = 0.5 * (torch.cos(edge_dist * pi / cutoff) + 1.0)
coeff = coeff * (edge_dist <= cutoff) * (edge_dist >= 0.0)
# compute the scalar message
msg_sca_emb = node_sca_sca(node_sca_feats) * edge_sca_sca(edge_sca_feat)
msg_sca_emb = msg_sca_emb * coeff.view(-1,1)

# compute the vector message
msg_vec_emb1 = node_vec_vec(node_vec_feats) * edge_sca_vec(edge_sca_feat).unsqueeze(-1)
msg_vec_emb2 = node_sca_vec(node_sca_feats).unsqueeze(-1) * edge_vec_vec(edge_vec_feat)
msg_vec_emb = msg_vec_emb1 + msg_vec_emb2
msg_vec_emb = msg_vec_emb * coeff.view(-1,1,1)

msg_sca_emb, msg_vec_emb = msg_out([msg_sca_emb, msg_vec_emb])

In [46]:
## aggregation
aggr_msg_sca = scatter_sum(msg_sca, edge_index_row, dim=0, dim_size=num_nodes)
aggr_msg_vec = scatter_sum(msg_vec, edge_index_row, dim=0, dim_size=num_nodes)
resi_sca, resi_vec = resi_connecter(node_feats)
out_sca = resi_sca + aggr_msg_sca
out_vec = resi_vec + aggr_msg_vec
out_sca, out_vec = aggr_out([out_sca, out_vec])

RuntimeError: index 176 is out of bounds for dimension 0 with size 176

## Interaction Module

In [29]:
class InteractionModule(nn.Module):
    def __init__(self, node_sca_dim=256, node_vec_dim=64, edge_dim=64,hid_dim=128,num_geodesic=2, \
        num_geoattn=4, k=24, cutoff=10.):

        super().__init__()

        self.node_sca_dim = node_sca_dim
        self.node_vec_dim = node_vec_dim
        self.edge_dim = edge_dim 
        self.hid_dim = hid_dim 
        self.num_geodesic = num_geodesic
        self.num_geoattn = num_geoattn
        self.k = k
        self.cutoff = cutoff

        self.interactions = ModuleList()
        for _ in range(num_geodesic):
            block = Geodesic_GNN(
                node_sca_dim=node_sca_dim,
                node_vec_dim=node_vec_dim,
                hid_dim = hid_dim,
                edge_dim = edge_dim,
                num_edge_types=2, 
                out_sca_dim=node_sca_dim,
                out_vec_dim=node_vec_dim,
                cutoff=cutoff
            )
            self.interactions.append(block)

        for _ in range(num_geoattn):
            block = Geoattn_GNN(
                node_sca_dim=node_sca_dim,
                node_vec_dim=node_vec_dim,
                hid_dim = hid_dim,
                edge_dim = edge_dim,
                num_edge_types=4, 
                out_sca_dim=node_sca_dim,
                out_vec_dim=node_vec_dim,
                cutoff=cutoff
            )
            self.interactions.append(block)

    @property
    def out_sca(self):
        return self.hidden_channels[0]
    
    @property
    def out_vec(self):
        return self.hidden_channels[1]

    def forward(self, node_attr, pos, idx_ligand, idx_surface, gds_edge_index, gds_edge_feature, gds_dis, geom_edge_index, geom_edge_feature):
        
        h_surface_sca = node_attr[0][idx_surface]
        h_surface_vec = node_attr[1][idx_surface]
        gds_edge_vec = pos[idx_protein][gds_knn_edge_index[0]]-pos[idx_protein][gds_knn_edge_index[1]]

        for geodesic_block in self.interactions[:self.num_geodesic]:
            delta_h = geodesic_block([h_surface_sca,h_surface_vec], gds_edge_feature, gds_edge_vec, gds_edge_index, gds_dis)
            h_surface_sca = h_surface_sca + delta_h[0]
            h_surface_vec = h_surface_vec + delta_h[1]

        h_ligpkt_sca = torch.cat([node_attr[0][data.idx_ligand], h_surface_sca], dim=0)
        h_ligpkt_vec = torch.cat([node_attr[1][data.idx_ligand], h_surface_vec], dim=0)
        geom_edge_vec = pos[geom_edge_index[0]] - pos[geom_edge_index[1]]

        for geoattn_block in self.interactions[self.num_geoattn:]:
            delta_h = geoattn_block([h_ligpkt_sca,h_ligpkt_vec], geom_edge_feature, geom_edge_vec, geom_edge_index)
            h_ligpkt_sca = h_ligpkt_sca + delta_h[0]
            h_ligpkt_vec = h_ligpkt_vec + delta_h[1]

        return [h_ligpkt_sca, h_ligpkt_vec]

In [None]:
h_ligpkt_sca = torch.cat([h_compose[0][data.idx_ligand_ctx_in_compose], out_sca], dim=0)
h_ligpkt_vec = torch.cat([h_compose[1][data.idx_ligand_ctx_in_compose], out_vec], dim=0)
node_feats = [h_ligpkt_sca,h_ligpkt_vec]
pos = data.compose_pos
edge_index = data.compose_knn_edge_index
edge_feature = data.compose_knn_edge_feature
edge_vector = pos[edge_index[0]] - pos[edge_index[1]]

layer4 = Geoattn_GNN()
layer3 = Geoattn_GNN()
out_sca, out_evc = layer3(node_feats, edge_feature, edge_vector, edge_index)
out_sca, out_evc = layer4([out_sca, out_evc], edge_feature, edge_vector, edge_index)

In [None]:
out_sca, out_vec = layer1(node_feats, edge_feature, edge_vector, edge_index, gds_dist)
out_sca, out_vec = layer2([out_sca, out_vec], edge_feature, edge_vector, edge_index, gds_dist)