In [75]:
import os
import shutil
import argparse
from tqdm.auto import tqdm
import torch
from torch.nn.utils import clip_grad_norm_
# import torch_geometric
# assert not torch_geometric.__version__.startswith('2'), 'Please use torch_geometric lower than version 2.0.0'
from torch_geometric.loader import DataLoader

from models.surfgen import SurfGen
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.train import *
from utils.datasets.surfdata import SurfGenDataset
from time import time
from utils.train import get_model_loss

# Root of the SurfGen checkout. Every path below is derived from it, so the notebook
# can be pointed at another machine by exporting SURFGEN_BASE_PATH before starting the kernel.
base_path = os.environ.get('SURFGEN_BASE_PATH', '/home/haotian/molecules_confs/Protein_test/SurfGen')

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=os.path.join(base_path, 'configs/train_surf.yml'))
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--logdir', type=str, default=os.path.join(base_path, 'logs'))
# Parse an explicit empty argv so the defaults above are always used
# (the process command line is never consulted).
args = parser.parse_args([])

config = load_config(args.config)
# Config file name without its extension; splitext also behaves when there is no '.' in the name.
config_name = os.path.splitext(os.path.basename(args.config))[0]
seed_all(config.train.seed)

# Override the dataset locations from the YAML with paths under base_path.
config.dataset.path = os.path.join(base_path, 'data/crossdocked_pocket10')
config.dataset.split = os.path.join(base_path, 'data/split_by_name.pt')

log_dir = get_new_log_dir(args.logdir, prefix=config_name)
ckpt_dir = os.path.join(log_dir, 'checkpoints')

In [76]:
# Build the per-sample transform pipeline from the training config.
# The featurizers are created first because AtomComposer needs their feature_dim values.
protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()
masking = get_mask(config.train.transform.mask)
composer = AtomComposer(protein_featurizer.feature_dim, ligand_featurizer.feature_dim, config.model.encoder.knn)
edge_sampler = EdgeSample(config.train.transform.edgesampler)
cfg_ctr = config.train.transform.contrastive
contrastive_sampler = ContrastiveSample(cfg_ctr.num_real, cfg_ctr.num_fake, cfg_ctr.pos_real_std, cfg_ctr.pos_fake_std, config.model.field.knn)

# NOTE(review): Compose presumably applies these in list order — confirm before reordering.
transform = Compose([
    RefineData(),
    LigandCountNeighbors(),
    protein_featurizer,
    ligand_featurizer,
    masking,
    composer,
    FocalBuilder(),
    edge_sampler,
    contrastive_sampler,
])

# The dataset is loaded by the get_dataset(...) call at the top of the next cell.
# The two identical back-to-back loads that used to sit here only repeated that work
# and had their results overwritten, so they were removed.

In [77]:
# Load the dataset with the transform pipeline attached and pick the two splits used below.
dataset, subsets = get_dataset(config=config.dataset, transform=transform)
train_set = subsets['train']
val_set = subsets['test']  # the 'test' split serves as the validation set in this notebook

# Collation options shared by the loaders.
follow_batch = []
collate_exclude_keys = ['ligand_nbh_list']

# Endless, shuffled stream of training batches.
shuffled_train_loader = DataLoader(
    train_set,
    batch_size=config.train.batch_size,
    shuffle=True,
    num_workers=config.train.num_workers,
    pin_memory=config.train.pin_memory,
    follow_batch=follow_batch,
    exclude_keys=collate_exclude_keys,
)
train_iterator = inf_iterator(shuffled_train_loader)

# Deterministic (unshuffled) passes over the validation and training splits.
val_loader = DataLoader(
    val_set,
    batch_size=config.train.batch_size,
    shuffle=False,
    follow_batch=follow_batch,
    exclude_keys=collate_exclude_keys,
)
train_loader = DataLoader(
    train_set,
    batch_size=config.train.batch_size,
    shuffle=False,
    exclude_keys=collate_exclude_keys,
)

In [78]:
# Take the first sample of val_set (the 'test' split) as the worked example for the cells below.
# NOTE(review): presumably the transform pipeline runs on access, so random transforms may make
# this sample differ between runs — confirm against the dataset implementation.
data = val_set[0]

In [79]:
import torch.nn.functional as F
from models.invariant import VNLinear,GVPerceptronVN

In [95]:
# Pull the ligand graph out of the sample.
pos = data.ligand_pos
nbh_list = data.ligand_nbh_list
node_feat = data.ligand_atom_feature_full.float()
edge_index, edge_type = data.ligand_bond_index, data.ligand_bond_type

# Bond types are shifted by one before one-hot encoding into 3 classes.
edge_feat = F.one_hot(edge_type - 1, num_classes=3).float()
# One displacement vector per bond, with a singleton channel axis inserted at dim 1.
edge_vec = (pos[edge_index[0]] - pos[edge_index[1]]).unsqueeze(1)

num_atoms = pos.size(0)
# Largest neighbour count over all atoms; it sets the padded width below.
max_dim = max(len(neighbours) for neighbours in nbh_list.values())

# Per-atom displacement to each neighbour, padded with 0.1 up to max_dim rows.
# masker holds 1 where a real neighbour was written and 0 on padding.
local_coords = torch.full((num_atoms, max_dim, 3), 0.1)
masker = torch.zeros(num_atoms, max_dim, 3)
for atom_idx in range(num_atoms):
    local_coord = pos[atom_idx] - pos[nbh_list[atom_idx]]
    n_neighbours = local_coord.shape[0]
    local_coords[atom_idx, :n_neighbours] = local_coord
    masker[atom_idx, :n_neighbours] = 1.0

In [104]:
# Ligand used for the fragmentation experiments below; the path is relative to the
# notebook's working directory.
mol2_path = '2z3h_A_rec_1wn6_bst_lig_tt_docked_3.mol2'
mol = Chem.MolFromMol2File(mol2_path)
# RDKit reports a parse failure by returning None rather than raising; fail here,
# at the load site, instead of with an opaque error in a later cell.
if mol is None:
    raise ValueError(f'RDKit could not parse ligand file: {mol2_path}')

In [139]:
from utils.chem import fragmentize_mol, remove_dummys_mol
from torch_scatter import scatter_add

In [106]:
# Fragment the ligand and keep only the first result for the experiments below.
# NOTE(review): the layout of each entry is defined by utils.chem.fragmentize_mol; the next
# cell reads element [1] of it — confirm what that element holds.
fragmentations = fragmentize_mol(mol)
fragmentation = fragmentations[0]

In [125]:
# Take element [1] of the chosen fragmentation, pass it through remove_dummys_mol and keep
# the first item of what that helper returns.
# NOTE(review): presumably this strips dummy attachment atoms — confirm in utils.chem.
combine_mols = remove_dummys_mol(fragmentation[1])[0]
# Split the result into its disconnected pieces, each returned as its own Mol object.
frags = Chem.GetMolFrags(combine_mols,asMols=True)

In [135]:
# Run remove_dummys_mol on the second piece and keep the first item it returns.
# NOTE(review): frags[0] is used as-is later while frags[1] gets this extra pass — confirm
# the asymmetry is intentional.
frag = remove_dummys_mol(frags[1])[0]

In [177]:
# Atom indices (in `mol`) matched by each fragment.
# GetSubstructMatch returns an empty tuple when there is no match; without an explicit dtype
# torch.tensor(()) is a float tensor, which cannot be used as an index later. Forcing int64
# keeps the no-match case a valid (empty) index tensor and leaves real matches unchanged.
motif_idx1 = torch.tensor(mol.GetSubstructMatch(frags[0]), dtype=torch.long)
motif_idx2 = torch.tensor(mol.GetSubstructMatch(frag), dtype=torch.long)

In [99]:
# Input widths are read off the tensors the layers will consume, so they cannot drift
# from the data: for the sample inspected in this notebook they evaluate to 13, 1 and 3.
input_node_vec_dim = max_dim                 # padded neighbour count per atom (channel axis of local_coords)
input_node_sca_dim = node_feat.shape[-1]     # scalar features per atom
input_edge_vec_dim = edge_vec.shape[1]       # vector channels per bond
input_edge_sca_dim = edge_feat.shape[-1]     # one-hot bond classes
out_dim = 16                                 # shared hidden width for every projection below

# Construction order is kept as-is: nn.Linear draws its initial weights from the global RNG,
# so reordering would change the values obtained under a fixed seed.
node_vec_net = VNLinear(input_node_vec_dim,out_dim)
node_sca_net = nn.Linear(input_node_sca_dim, out_dim)
edge_vec_net = VNLinear(input_edge_vec_dim, out_dim)
edge_sca_net = nn.Linear(input_edge_sca_dim, out_dim)
# NOTE(review): node_net / edge_net have the same shapes as node_sca_net / edge_sca_net and
# are not used in the visible cells — confirm whether they are still needed.
node_net = nn.Linear(input_node_sca_dim, out_dim)
edge_net = nn.Linear(input_edge_sca_dim, out_dim)
edge_mapper = GVPerceptronVN(out_dim,out_dim,out_dim,out_dim)
node_mapper = GVPerceptronVN(out_dim,out_dim,out_dim,out_dim)

In [98]:
# First row of edge_index: one endpoint atom for every bond.
edge_raw = edge_index[0]
# Project node and edge inputs (vector and scalar parts separately) to the hidden width.
# NOTE(review): local_coords still holds its 0.1 padding rows here and `masker` is not
# applied — confirm the padding is meant to reach node_vec_net.
node_vec_hid = node_vec_net(local_coords) 
node_sca_hid = node_sca_net(node_feat)
edge_vec_hid = edge_vec_net(edge_vec)
edge_sca_hid = edge_sca_net(edge_feat)

# Gather, for every bond, the hidden scalar features of its edge_raw endpoint.
msg_sca_j = node_sca_hid[edge_raw]


In [147]:
# NOTE(review): `motif_idx` is not assigned in any visible cell (only motif_idx1 / motif_idx2
# are), so this inspection presumably relies on leftover kernel state and would fail after
# Restart & Run All — confirm and point it at the intended variable.
motif_idx

tensor([13, 10, 14, 11, 12,  6,  7,  8,  9])

In [148]:
# Sanity check: size of the ligand node-feature matrix built from data.ligand_atom_feature_full.
node_feat.shape

torch.Size([31, 13])

In [163]:
# Debug probe: indexing a tensor yields a 0-dim tensor, not a Python int.
# NOTE(review): depends on `motif_idx`, which no visible cell defines.
type(motif_idx[0])

torch.Tensor

In [161]:
# Debug probe: elements of torch.ones_like(...) are likewise 0-dim tensors.
# NOTE(review): depends on `motif_idx`, which no visible cell defines.
type(torch.ones_like(motif_idx)[0])

torch.Tensor

In [176]:
### Build the motif index: give every ligand atom the id of the motif it belongs to

In [178]:
# Per-atom motif id: every atom starts at 0, then each motif stamps its list position
# onto the atoms it matched.
# NOTE(review): atoms matched by no motif keep id 0 and are therefore indistinguishable
# from members of the first motif — confirm the motifs cover every atom.
motifs = [motif_idx1, motif_idx2]
motif_index = torch.zeros(num_atoms).long()
for motif_id, atom_ids in enumerate(motifs):
    motif_index[atom_ids] = motif_id

In [180]:
### motif pooling techniques
# Sum pooling: rows of node_feat that share a motif_index value are added together along
# dim 0, giving one pooled feature row per motif id.
scatter_add(node_feat, motif_index, dim=0)

tensor([[13.,  5.,  4.,  0.,  0.,  0.,  0., 22., 45., 51., 39.,  6.,  0.],
        [ 4.,  3.,  2.,  0.,  0.,  0.,  0.,  9., 19., 23., 15.,  4.,  0.]])