In [3]:
import os
import shutil
import argparse
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear, Module
from torch.nn.utils import clip_grad_norm_
# import torch_geometric
# assert not torch_geometric.__version__.startswith('2'), 'Please use torch_geometric lower than version 2.0.0'
from torch_geometric.loader import DataLoader
from torch_scatter import scatter_add, scatter_softmax, scatter_sum

from models.surfgen import SurfGen
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.train import *
from utils.datasets.surfdata import SurfGenDataset
from time import time
from utils.train import get_model_loss

# --- Run configuration ------------------------------------------------------
# Root of the SurfGen checkout. Set the SURFGEN_HOME environment variable to
# run on another machine; the fallback is the path this notebook was written on.
base_path = os.environ.get('SURFGEN_HOME', '/home/haotian/Molecule_Generation/SurfGen')

# The defaults now point straight at the SurfGen files instead of being
# declared with unrelated paths and overwritten after parsing.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=os.path.join(base_path, 'configs/train_surf.yml'))
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--logdir', type=str, default=os.path.join(base_path, 'logs'))
# An explicit empty argv is parsed so that only the defaults above are used.
args = parser.parse_args([])

config = load_config(args.config)
# splitext keeps the full name when the file has no extension; slicing with
# rfind('.') returned -1 in that case and silently dropped the last character.
config_name = os.path.splitext(os.path.basename(args.config))[0]
seed_all(config.train.seed)

# Point the dataset section of the config at the local data files.
config.dataset.path = os.path.join(base_path, 'data/crossdocked_pocket10')
config.dataset.split = os.path.join(base_path, 'data/split_by_name.pt')

log_dir = get_new_log_dir(args.logdir, prefix=config_name)
ckpt_dir = os.path.join(log_dir, 'checkpoints')

In [4]:
# --- Data transform pipeline -------------------------------------------------
protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()
masking = get_mask(config.train.transform.mask)
# The composer needs both feature widths plus the encoder's kNN setting.
composer = AtomComposer(protein_featurizer.feature_dim, ligand_featurizer.feature_dim, config.model.encoder.knn)
edge_sampler = EdgeSample(config.train.transform.edgesampler)
cfg_ctr = config.train.transform.contrastive
contrastive_sampler = ContrastiveSample(cfg_ctr.num_real, cfg_ctr.num_fake, cfg_ctr.pos_real_std, cfg_ctr.pos_fake_std, config.model.field.knn)

# Transforms are applied in list order.
transform = Compose([
    RefineData(),
    LigandCountNeighbors(),
    protein_featurizer,
    ligand_featurizer,
    masking,
    composer,

    FocalBuilder(),
    edge_sampler,
    contrastive_sampler,
])

# Build the dataset once. The original cell issued this identical call twice
# in a row, so the first result was immediately discarded.
dataset, subsets = get_dataset(
    config = config.dataset,
    transform = transform,
)

In [5]:
# --- Data loaders ------------------------------------------------------------
# `dataset` / `subsets` come from the previous cell; the identical
# get_dataset(...) call that used to be repeated here has been removed.
# Note that the validation set is taken from the 'test' split.
train_set, val_set = subsets['train'], subsets['test']
follow_batch = []
# 'ligand_nbh_list' is excluded from batch collation.
collate_exclude_keys = ['ligand_nbh_list']

# Endless shuffled iterator over the training set.
train_iterator = inf_iterator(DataLoader(
    train_set,
    batch_size = config.train.batch_size,
    shuffle = True,
    num_workers = config.train.num_workers,
    pin_memory = config.train.pin_memory,
    follow_batch = follow_batch,
    exclude_keys = collate_exclude_keys,
))
val_loader = DataLoader(
    val_set,
    config.train.batch_size,
    shuffle = False,
    follow_batch = follow_batch,
    exclude_keys = collate_exclude_keys,
)
# Unshuffled loader over the training set; follow_batch is now passed
# explicitly so that all three loaders are configured consistently.
train_loader = DataLoader(
    train_set,
    config.train.batch_size,
    shuffle = False,
    follow_batch = follow_batch,
    exclude_keys = collate_exclude_keys,
)

In [6]:
# Take the first sample of val_set; the prototyping cells below read its
# compose_* attributes.
data = val_set[0]

In [7]:
from torch_scatter import scatter_softmax

In [15]:
from models.invariant import VNLinear

In [21]:
# Build a two-channel vector feature by concatenating two copies of `pos`
# along a new axis 1.
# NOTE(review): `pos` is only assigned in a later cell (pos = data.compose_pos),
# so this cell does not run on a fresh kernel in file order.
vec_feature = torch.cat([pos.unsqueeze(1),pos.unsqueeze(1)], dim=1)

In [42]:
# Project the vector channels with the vector attention layer. The layer is
# named `vec_attn_net` where it is created; the original `vet_attn_net` is not
# defined anywhere in this notebook.
# NOTE(review): this rebinds vec_feature to its own projection, so the cell is
# not safe to run twice without re-running the cell that builds vec_feature.
vec_feature = vec_attn_net(vec_feature)

In [45]:
# One value per edge: elementwise product of the vector features gathered at
# edge_index[0] and edge_index[1], summed over the last two axes.
alpha_vec = (vec_feature[edge_index[0]] * vec_feature[edge_index[1]]).sum(-1).sum(-1)

In [49]:
# Squash the per-edge values with the Sigmoid module bound to `sigmoid`.
# NOTE(review): `sigmoid` is created in a later cell (sigmoid = Sigmoid()).
alpha_vec = sigmoid(alpha_vec)

In [54]:
node_sca.shape

torch.Size([297, 13])

In [62]:
# First row of edge_index, used as a scatter index in the scratch cells.
# NOTE(review): a later cell binds the same row to `edge_raw`.
edge_col = edge_index[0]

In [None]:
# Early scratch version of the scalar attention layer: 27 input features
# mapped to 16 hidden features.
# NOTE(review): `sca_attn_net` is bound again in the layer-setup cell with
# input width input_node_sca_dim*2+1, which is also 27.
node_attn_input = 27
hidden_sca = 16
sca_attn_net = nn.Linear(node_attn_input, hidden_sca)

# Per-node scalar features of the sample.
node_sca = data.compose_feature


In [169]:
# --- Prototype layers --------------------------------------------------------
# Statement order is kept as-is: each constructor draws its initial weights
# from the global RNG, so reordering would change the initialisation.
from torch.nn import Sigmoid
sigmoid = Sigmoid()
# Two vector channels per node (node_vec concatenates two copies of pos).
input_node_vec_dim = 2
# Matches the width of node_sca printed earlier: torch.Size([297, 13]).
input_node_sca_dim = 13
# One vector channel per edge (edge_vec has a singleton channel axis).
input_edge_vec_dim = 1
# Presumably the width of data.compose_knn_edge_feature - TODO confirm.
input_edge_sca_dim = 4
out_dim = 16

node_vec_net = VNLinear(input_node_vec_dim,out_dim)
node_sca_net = nn.Linear(input_node_sca_dim, out_dim)
edge_vec_net = VNLinear(input_edge_vec_dim, out_dim)
edge_sca_net = nn.Linear(input_edge_sca_dim, out_dim)
# Input width = scalar features of both edge endpoints plus one edge distance.
sca_attn_net = nn.Linear(input_node_sca_dim*2+1, out_dim)
vec_attn_net = VNLinear(input_node_vec_dim, out_dim)
# NOTE(review): GVPerceptronVN is not imported explicitly in this notebook;
# it presumably arrives through one of the star imports - verify.
mapper = GVPerceptronVN(out_dim,out_dim,out_dim,out_dim)

In [131]:
# --- Graph inputs taken from the sample --------------------------------------
edge_index = data.compose_knn_edge_index
pos = data.compose_pos
node_sca = data.compose_feature
edge_sca = data.compose_knn_edge_feature.float()

# Displacement between the two endpoints of every edge (row 0 minus row 1).
edge_displacement = pos[edge_index[0]] - pos[edge_index[1]]
# Edge length, one value per edge.
edge_dist = torch.norm(edge_displacement, dim=-1)
# Same displacement with a singleton channel axis: (n_edges, 1, 3).
edge_vec = edge_displacement.unsqueeze(-2)
# Two identical vector channels per node: (n_nodes, 2, 3).
node_vec = torch.stack([pos, pos], dim=1)


In [133]:
# --- Attention logits --------------------------------------------------------
# Scalar branch input: scalar features of both edge endpoints plus the edge
# length, concatenated per edge (width 13 + 13 + 1 = 27, the in_features of
# sca_attn_net).
alpha_sca_in = torch.cat(
    [node_sca[edge_index[0]], node_sca[edge_index[1]], edge_dist.unsqueeze(-1)],
    dim=-1,
)
# Fixed: the original called sca_attn_net(alpha), but `alpha` is never
# defined; the tensor built above is the intended input. A separate input name
# also keeps the cell safe to re-run.
alpha_sca = sca_attn_net(alpha_sca_in)

# Vector branch: project the node vectors, then reduce the elementwise product
# of the two endpoint projections to one value per edge.
alpha_vec_hid = vec_attn_net(node_vec)
# NOTE(review): this value is unbounded here; the earlier scratch cell passes
# the same quantity through `sigmoid` before use - confirm whether that
# squashing is also wanted in this pipeline.
alpha_vec = (alpha_vec_hid[edge_index[0]] * alpha_vec_hid[edge_index[1]]).sum(-1).sum(-1)

In [137]:
# First row of edge_index; used below both to gather per-node tensors onto
# edges and as the scatter index for aggregation.
edge_raw = edge_index[0]

In [149]:
# Scalar projections for nodes and edges, used by the scalar aggregation cell.
# NOTE(review): same layer shapes as node_sca_net / edge_sca_net from the
# layer-setup cell, but these are separate, independently initialised layers.
node_net = nn.Linear(input_node_sca_dim, out_dim)
edge_net = nn.Linear(input_edge_sca_dim, out_dim)

In [153]:
alpha_sca.shape

torch.Size([7128, 16])

In [159]:
alpha_sca.shape

torch.Size([7128, 16])

In [162]:
scatter_softmax(alpha_sca, edge_raw, dim=0)

torch.Size([7128, 16])

In [174]:
# --- Scalar message aggregation ----------------------------------------------
# Node projection gathered onto edges through edge_raw.
node_term = node_net(node_sca)[edge_raw]
# Edge projection, already one row per edge.
edge_term = edge_net(edge_sca)
# Attention weights: softmax of the scalar logits within each edge_raw group.
attn_weight = scatter_softmax(alpha_sca, edge_raw, dim=0)
node_sca_feat = node_term * edge_term * attn_weight

# Sum the weighted edge messages per edge_raw index.
emb_sca = scatter_sum(node_sca_feat, edge_raw, dim=0)

In [166]:

# --- Vector message aggregation ----------------------------------------------
# Scalar projections get a trailing singleton axis so they broadcast against
# the vector projections.
node_sca_hid = node_sca_net(node_sca)[edge_raw].unsqueeze(-1)
edge_vec_hid = edge_vec_net(edge_vec)
node_vec_hid = node_vec_net(node_vec)[edge_raw]
edge_sca_hid =  edge_sca_net(edge_sca).unsqueeze(-1)

# Per-edge vector message: node scalars scale the edge vectors, and edge
# scalars scale the node vectors.
edge_vec_message = node_sca_hid * edge_vec_hid + node_vec_hid * edge_sca_hid
# Per-edge attention value broadcast over the two trailing axes.
edge_vec_gate = alpha_vec.unsqueeze(-1).unsqueeze(-1)
emb_vec = scatter_add(edge_vec_message * edge_vec_gate, edge_raw, dim=0)

In [172]:
emb_sca.shape

torch.Size([7128, 16])

In [173]:
emb_vec.shape

torch.Size([297, 16, 3])

In [175]:
mapper([emb_sca, emb_vec])

(tensor([[-1.3168e+04, -1.3868e+05,  1.7453e+06,  ...,  9.1794e+06,
          -5.0587e+04,  1.2143e+07],
         [-1.5051e+04, -1.4223e+05,  1.5296e+06,  ...,  9.9887e+06,
          -5.3473e+04,  1.2224e+07],
         [-1.1724e+04, -1.3098e+05,  1.7133e+06,  ...,  9.3352e+06,
          -4.8092e+04,  1.1505e+07],
         ...,
         [-1.6278e+04, -1.5842e+05,  2.0223e+06,  ...,  9.8644e+06,
          -5.8754e+04,  1.3841e+07],
         [-1.4820e+04, -1.4958e+05,  1.8701e+06,  ...,  9.1884e+06,
          -5.4672e+04,  1.3077e+07],
         [-1.4477e+04, -1.4683e+05,  1.8290e+06,  ...,  8.9941e+06,
          -5.3496e+04,  1.2857e+07]], grad_fn=<LeakyReluBackward0>),
 tensor([[[ 2.1365e+04,  3.2876e+04,  3.0362e+04],
          [-2.3366e+05, -4.8037e+05, -1.5488e+06],
          [ 2.6071e+06,  2.9411e+06, -3.9659e+06],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-2.9457e+05, -2.0850e+05,  3.8137e+05],
          [ 1.3829e+06,  2.4427e+06,  6.7921e+06]],
 


In [143]:
from torch_scatter import scatter_add, scatter_sum

In [145]:
# Same aggregation as emb_vec, but through scatter_sum - presumably kept to
# compare against the scatter_add result bound to `b` in the next cell.
a = scatter_sum((node_sca_hid * edge_vec_hid + node_vec_hid*edge_sca_hid)*alpha_vec.unsqueeze(-1).unsqueeze(-1), edge_raw, dim=0)

In [146]:
# Same expression as the one bound to `a`, aggregated with scatter_add
# instead of scatter_sum.
b = scatter_add((node_sca_hid * edge_vec_hid + node_vec_hid*edge_sca_hid)*alpha_vec.unsqueeze(-1).unsqueeze(-1), edge_raw, dim=0)

In [121]:
out.shape

torch.Size([297, 16, 3])

In [116]:
alpha_vec.shape

torch.Size([7128])

In [114]:
scatter_add(out, edge_col, dim=0)

torch.Size([297, 16, 3])

In [103]:
(node_sca_hid * edge_vec_hid).shape

torch.Size([7128, 16, 3])

In [95]:
vec_edge_vec_net(edge_vec).shape

torch.Size([7128, 24, 3])

In [None]:
class AtomEmbedding(Module):
    """Embed per-atom scalar features together with one 3-D vector per atom.

    Scalars pass through a linear layer. The vector is divided by
    ``vector_normalizer`` and expanded from one channel to ``out_vector``
    channels by a second linear layer.

    Args:
        in_scalar: number of leading scalar features that are embedded.
        in_vector: number of input vector channels; must be 1.
        out_scalar: width of the scalar embedding.
        out_vector: number of output vector channels.
        vector_normalizer: divisor applied to the input vector.
    """

    def __init__(self, in_scalar, in_vector,
                 out_scalar, out_vector, vector_normalizer=20.):
        super().__init__()
        # Exactly one input vector channel is supported.
        assert in_vector == 1
        self.in_scalar = in_scalar
        self.vector_normalizer = vector_normalizer
        self.emb_sca = Linear(in_scalar, out_scalar)
        self.emb_vec = Linear(in_vector, out_vector)

    def forward(self, scalar_input, vector_input):
        """Return ``(sca_emb, vec_emb)`` of shapes (b, f') and (b, f'', 3).

        Raises:
            AssertionError: if ``vector_input`` is not of shape (b, 3).
        """
        scaled_vector = vector_input / self.vector_normalizer
        assert scaled_vector.shape[1:] == (3, ), 'Not support. Only one vector can be input'

        # Only the first `in_scalar` columns are embedded: (b, f) -> (b, f').
        leading_scalars = scalar_input[:, :self.in_scalar]
        sca_emb = self.emb_sca(leading_scalars)

        # (b, 3) -> (b, 3, 1) -> (b, 3, f'') -> (b, f'', 3)
        per_axis = self.emb_vec(scaled_vector.unsqueeze(-1))
        vec_emb = per_axis.transpose(1, -1)
        return sca_emb, vec_emb

## atom embedding

In [28]:
# Stand-alone copy of AtomEmbedding.emb_vec: one input channel, ten outputs.
# NOTE(review): this rebinds `emb_vec`, which the vector aggregation cell
# uses for a tensor.
emb_vec = nn.Linear(1, 10)

In [29]:
    def forward(self, scalar_input, vector_input):
        # Scratch copy of AtomEmbedding.forward, pasted on its own so that the
        # following cells can replay its vector branch step by step.
        # NOTE(review): outside the class this binds a module-level `forward`
        # that still expects an AtomEmbedding-like `self`.
        vector_input = vector_input / self.vector_normalizer
        assert vector_input.shape[1:] == (3, ), 'Not support. Only one vector can be input'
        sca_emb = self.emb_sca(scalar_input[:, :self.in_scalar])  # b, f -> b, f'
        vec_emb = vector_input.unsqueeze(-1)  # b, 3 -> b, 3, 1
        vec_emb = self.emb_vec(vec_emb).transpose(1, -1)  # b, 3, 1 -> b, 3, f' -> b, f', 3
        return sca_emb, vec_emb

In [30]:
# Replays the unsqueeze step of AtomEmbedding.forward on the raw positions:
# adds a trailing singleton channel axis.
vec_emb = pos.unsqueeze(-1)

In [34]:
emb_vec(vec_emb).transpose(1, -1).shape

torch.Size([276, 10, 3])

In [35]:
# VNLinear counterpart of the Linear(1, 10) above: 1 input vector channel,
# 10 output vector channels.
test_net = VNLinear(1,10)

In [None]:
class VNLinear(nn.Module):
    """Linear layer that mixes the channels of vector-valued features.

    The weights act on the second-to-last axis (the channel axis), so every
    coordinate on the last axis is transformed by the same linear map.
    Extra positional and keyword arguments are forwarded to ``nn.Linear``.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.map_to_feat = nn.Linear(in_channels, out_channels, *args, **kwargs)

    def forward(self, x):
        '''
        x: point features of shape [..., in_channels, 3]
        returns: tensor of shape [..., out_channels, 3]
        '''
        # Move channels to the last axis, where nn.Linear operates.
        channels_last = x.transpose(-2, -1)
        mixed = self.map_to_feat(channels_last)
        # Restore the [..., channels, 3] layout.
        return mixed.transpose(-2, -1)

In [39]:
# Positions with a trailing singleton axis, i.e. a channels-last layout with
# one vector channel.
vec = pos.unsqueeze(-1)

In [54]:
# Two identical channels on the last axis; the shape printed further down is
# torch.Size([276, 3, 2]).
vec2 = torch.cat([vec,vec], dim=-1)

In [52]:
# Plain linear layer over two input channels, for comparison with VNLinear.
test_net2 = nn.Linear(2,10)

In [61]:
vec2.transpose(-2,-1).shape

torch.Size([276, 2, 3])

In [62]:
# Channels-first view of vec2; the same expression prints
# torch.Size([276, 2, 3]) in the previous cell.
# NOTE(review): rebinds `a`, which another cell uses for a scatter_sum result.
a = vec2.transpose(-2,-1)

In [73]:
class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    Maps a pair (scalar features, vector features) to a new pair. Either the
    input or the output may have zero vector channels, in which case that
    side is a single scalar tensor.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional;
                  defaults to max(n_vector_in, n_vector_out)
    :param activations: tuple of functions (scalar_act, vector_act);
                        either entry may be falsy to skip that activation
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, h_dim=None,
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super().__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate

        if not self.vi:
            # Scalar-only input: a single linear map on the scalars.
            self.ws = nn.Linear(self.si, self.so)
        else:
            self.h_dim = h_dim or max(self.vi, self.vo)
            # Vector maps carry no bias.
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            # Scalars are concatenated with h_dim vector norms before ws.
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate:
                    self.wsv = nn.Linear(self.so, self.vo)

        self.scalar_act, self.vector_act = activations
        # Empty parameter whose only use is to expose the module's device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if not self.vi:
            s = self.ws(x)
            if self.vo:
                # No input vectors: the output vectors are all zeros.
                v = torch.zeros(s.shape[0], self.vo, 3,
                                device=self.dummy_param.device)
        else:
            s, v = x
            # Put channels on the last axis so the linear maps mix channels.
            vh = self.wh(v.transpose(-1, -2))
            # Norm of every hidden vector channel feeds the scalar update.
            vn = _norm_no_nan(vh, axis=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                v = self.wv(vh).transpose(-1, -2)
                if self.vector_gate:
                    # Gate each output vector by a sigmoid of the new scalars.
                    gate_input = self.vector_act(s) if self.vector_act else s
                    gate = self.wsv(gate_input)
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act:
                    # Scale each vector by the activation of its own norm.
                    v_norm = _norm_no_nan(v, axis=-1, keepdims=True)
                    v = v * self.vector_act(v_norm)

        if self.scalar_act:
            s = self.scalar_act(s)

        return (s, v) if self.vo else s

In [76]:
# Reuse the positions as stand-in scalar features for the GVP experiment; the
# GVP built in the next cell takes 3 scalar inputs.
sca = pos

In [78]:
# GVP with 3 scalar / 2 vector input channels and 6 scalar / 4 vector output
# channels; h_dim is left unset, so it falls back to max(2, 4) = 4.
test = GVP(in_dims=(3,2),out_dims=(6,4))

In [103]:
vec2.shape

torch.Size([276, 3, 2])

In [88]:
# Same transpose of the last two axes that GVP.forward applies to its vector
# input: (276, 3, 2) -> (276, 2, 3), per the shapes printed around this cell.
vec2_trans = torch.transpose(vec2, -1, -2)

In [93]:
vec2_trans.shape

torch.Size([276, 2, 3])

In [90]:
# Stand-alone counterpart of GVP.wh (bias-free linear map over vector
# channels): 2 input channels, 10 hidden channels.
wh = nn.Linear(2, 10, bias=False)

In [98]:
vec2.shape

torch.Size([276, 3, 2])

In [99]:
# vec2 is already channels-last, i.e. the layout GVP.forward produces with its
# transpose, so wh applies directly: (276, 3, 2) -> (276, 3, 10).
vh = wh(vec2)    

In [100]:
vh.shape

torch.Size([276, 3, 10])

In [13]:
alpha_sca.shape

torch.Size([6624, 16])