In [52]:
import torch
import torch.nn as nn
from torch.distributions import Normal
import numpy as np # Will be used for generating points for plotting
from models.gatedgcn import GatedGCN
# Load the demo dataset. Later cells iterate it with `.items()` and read the
# 'prot', 'lig' and 'label' entries of every value.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so this must only ever be pointed at a trusted file.
data_path = './dataset/pdbbind/demo.pt'
data = torch.load(data_path,weights_only=False)
# Both graph encoders share every hyper-parameter except the edge feature width.
encoder_kwargs = dict(
    in_channels=41,
    num_hidden_channels=128,
    residual=True,
    dropout_rate=0.15,
    equivstable_pe=False,
    num_layers=6,
)
# Ligand encoder: 10 edge features.
ligmodel = GatedGCN(edge_features=10, **encoder_kwargs)
# Protein encoder: 5 edge features.
protmodel = GatedGCN(edge_features=5, **encoder_kwargs)
class MLP(nn.Module):
    """Linear -> BatchNorm1d -> ELU -> Dropout block.

    Args:
        in_channels: Half of the expected input width; the linear layer takes
            ``in_channels * 2`` features.
        hidden_dim: Output width of the linear layer and of the block.
        dropout_rate: Drop probability of the final dropout layer.
    """

    def __init__(self, in_channels, hidden_dim, dropout_rate):
        super().__init__()
        stages = [
            nn.Linear(2 * in_channels, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ELU(),
            nn.Dropout(p=dropout_rate),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x`` (last dimension ``in_channels * 2``)."""
        hidden = self.mlp(x)
        return hidden
class z_pi_layer(nn.Module):
    """Single linear head producing one raw score per mixture component.

    Args:
        hidden_dim: Width of the incoming features.
        n_gaussians: Number of output scores.
    """

    def __init__(self, hidden_dim=128, n_gaussians=10):
        super().__init__()
        self.z_pi = nn.Linear(in_features=hidden_dim, out_features=n_gaussians)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        scores = self.z_pi(x)
        return scores
class z_sigma_layer(nn.Module):
    """Single linear head producing one raw spread value per mixture component.

    Args:
        hidden_dim: Width of the incoming features.
        n_gaussians: Number of output values.
    """

    def __init__(self, hidden_dim=128, n_gaussians=10):
        super().__init__()
        self.z_sigma = nn.Linear(in_features=hidden_dim, out_features=n_gaussians)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        raw_sigma = self.z_sigma(x)
        return raw_sigma
class z_mu_layer(nn.Module):
    """Single linear head producing one raw mean value per mixture component.

    Args:
        hidden_dim: Width of the incoming features.
        n_gaussians: Number of output values.
    """

    def __init__(self, hidden_dim=128, n_gaussians=10):
        super().__init__()
        self.z_mu = nn.Linear(in_features=hidden_dim, out_features=n_gaussians)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        raw_mu = self.z_mu(x)
        return raw_mu
class atom_types_layer(nn.Module):
    """Single linear head mapping a feature vector to ``n`` raw class scores.

    Args:
        in_channels: Width of the incoming features.
        n: Number of output scores.
    """

    def __init__(self, in_channels=128, n=17):
        super().__init__()
        self.atom_types = nn.Linear(in_features=in_channels, out_features=n)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        scores = self.atom_types(x)
        return scores
class bond_types_layer(nn.Module):
    """Single linear head mapping a doubled-width feature vector to ``n`` raw scores.

    Args:
        in_channel: Half of the expected input width; the linear layer takes
            ``in_channel * 2`` features.
        n: Number of output scores.
    """

    def __init__(self, in_channel=128, n=4):
        super().__init__()
        self.bond_types = nn.Linear(in_features=2 * in_channel, out_features=n)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        scores = self.bond_types(x)
        return scores
# Instantiate the pairwise MLP and the prediction heads on 128-wide features.
# Construction order is kept so parameter initialisation draws are unchanged.
HIDDEN_DIM = 128
N_GAUSSIANS = 10
mlp = MLP(in_channels=HIDDEN_DIM, hidden_dim=HIDDEN_DIM, dropout_rate=0.15)
z_pi = z_pi_layer(hidden_dim=HIDDEN_DIM, n_gaussians=N_GAUSSIANS)
z_sigma = z_sigma_layer(hidden_dim=HIDDEN_DIM, n_gaussians=N_GAUSSIANS)
z_mu = z_mu_layer(hidden_dim=HIDDEN_DIM, n_gaussians=N_GAUSSIANS)
atom_types = atom_types_layer(in_channels=HIDDEN_DIM, n=17)
bond_types = bond_types_layer(in_channel=HIDDEN_DIM, n=4)

In [53]:
from torch_geometric.data import Batch

batchsize = 8

# Collect the first `batchsize` entries of the dataset (fewer if it is smaller).
prot_graphs, lig_graphs, labels = [], [], []
for n_taken, sample in enumerate(data.values(), start=1):
    prot_graphs.append(sample['prot'])
    lig_graphs.append(sample['lig'])
    labels.append(float(sample['label']))
    if n_taken == batchsize:
        break

# Collate the per-complex graphs and labels into batched objects.
prot_batch = Batch.from_data_list(prot_graphs)
lig_batch = Batch.from_data_list(lig_graphs)
label_batch = torch.tensor(labels)

# Encode ligand and protein graphs with their respective networks.
h_l = ligmodel(lig_batch)
h_t = protmodel(prot_batch)

In [54]:
from torch_scatter import scatter_add
from torch_geometric.utils import to_dense_batch

# Convert the batched node features and positions to dense, zero-padded
# tensors; the returned masks mark which entries are real nodes.
h_l_x, l_mask = to_dense_batch(h_l.x, h_l.batch, fill_value=0)
h_l_pos, _ = to_dense_batch(h_l.pos, h_l.batch, fill_value=0)
h_t_x, t_mask = to_dense_batch(h_t.x, h_t.batch, fill_value=0)
h_t_pos, _ = to_dense_batch(h_t.pos, h_t.batch, fill_value=0)

# Batch size, padded ligand length, feature width and padded protein length.
B, N_l, C_out = h_l_x.shape
N_t = h_t_x.shape[1]

In [55]:
# Insert a broadcast axis on each side so every ligand node can be paired
# with every protein node.
h_l_x_new = h_l_x[:, :, None, :]
h_t_x_new = h_t_x[:, None, :, :]

# Tile both to the full pair grid: [B, N_l, N_t, C_out].
h_l_x_new2 = h_l_x_new.repeat(1, 1, N_t, 1)
h_t_x_new2 = h_t_x_new.repeat(1, N_l, 1, 1)

# Concatenate the ligand and protein features of every pair.
C = torch.cat([h_l_x_new2, h_t_x_new2], dim=-1)

# A pair is valid only if both its ligand node and its protein node are real
# (non-padding) entries.
C_mask = l_mask.unsqueeze(2) & t_mask.unsqueeze(1)

# Keep only the valid pairs, flattened to one row per pair.
C_new = C[C_mask]
(C_new.shape, C.shape, C_mask.shape)

(torch.Size([18681, 256]),
 torch.Size([8, 49, 101, 256]),
 torch.Size([8, 49, 101]))

In [56]:
# Pass the valid pair features through the pairwise MLP.
C_new2 = mlp(C_new)

# Sample index of every valid pair: broadcast the batch index over the
# ligand/protein grid and select with the same mask used for the features.
C_batch = torch.tensor(range(B)).view(B, 1, 1)
C_batch_new = C_batch.expand(B, N_l, N_t)[C_mask]
C_batch_new

tensor([0, 0, 0,  ..., 7, 7, 7])

In [58]:
import torch.nn.functional as F

# Mixture weights: softmax over the component axis.
pi = F.softmax(z_pi(C_new2), dim=-1)
# ELU outputs are greater than -1, so the offsets keep sigma > 0.1 and mu > 0.
sigma = 1.1 + F.elu(z_sigma(C_new2))
mu = 1 + F.elu(z_mu(C_new2))
(pi.shape, sigma.shape, mu.shape)

(torch.Size([18681, 10]), torch.Size([18681, 10]), torch.Size([18681, 10]))

In [59]:
# Apply the atom-type head to the ligand node features, and the bond-type head
# to the concatenated features of the two end nodes of every ligand edge.
# NOTE(review): both assignments rebind the names of the head modules created
# earlier (`atom_types`, `bond_types`) to their output tensors, so running this
# cell a second time calls a tensor and fails; the modules must be re-created
# first. Consider distinct names for the outputs.
atom_types = atom_types(h_l.x)
bond_types = bond_types(torch.cat([h_l.x[h_l.edge_index[0]], h_l.x[h_l.edge_index[1]]], axis=1))  
atom_types.shape,bond_types.shape

(torch.Size([224, 17]), torch.Size([476, 4]))

In [61]:
# Display the atom-type and bond-type head outputs (tensors after the previous cell).
atom_types,bond_types

(tensor([[ 2.0631, -1.6908, -0.0284,  ..., -1.4472,  1.8577,  3.6292],
         [ 0.7991, -0.5083, -0.3321,  ..., -0.0709,  1.5080,  2.6271],
         [ 2.8874,  0.1504,  1.0486,  ..., -2.1823,  1.6224,  3.5007],
         ...,
         [ 2.8269, -0.3445,  1.3712,  ...,  1.9708,  0.6048,  5.2988],
         [ 1.6809, -1.6907, -0.4642,  ..., -1.2399,  2.0318,  2.0042],
         [ 3.5060, -0.2919,  1.8847,  ...,  0.5114, -0.5795,  3.9267]],
        grad_fn=<AddmmBackward0>),
 tensor([[-0.6892, -1.5002, -1.1214, -0.0823],
         [-1.2390, -1.5962, -4.0807,  0.2000],
         [-1.8250, -2.2734, -2.5978, -0.0464],
         ...,
         [ 0.2831, -1.8221,  2.1551, -1.0913],
         [-2.8723, -0.3263, -3.0241,  1.0681],
         [ 0.8589, -0.7306,  2.0615, -1.3952]], grad_fn=<AddmmBackward0>))

In [68]:
def compute_euclidean_distances_matrix(X, Y, atoms_per_residue=24):
    """Minimum distance from each point in ``X`` to each group of points in ``Y``.

    Pairwise Euclidean distances are computed in float64 through the expansion
    ``|x - y|^2 = |x|^2 + |y|^2 - 2 x.y``. The points in ``Y`` are then grouped
    into consecutive chunks of ``atoms_per_residue`` and the minimum over each
    chunk is taken.

    Args:
        X: Tensor of shape ``(B, N, D)``.
        Y: Tensor of shape ``(B, M * atoms_per_residue, D)``. Coordinates that
            are NaN yield a distance of 10000 so they never win the minimum.
        atoms_per_residue: Size of each consecutive group in ``Y``.

    Returns:
        The ``(values, indices)`` result of ``Tensor.min`` over the last axis;
        both have shape ``(B, N, M)``.
    """
    X = X.double()
    Y = Y.double()
    # Take the output shape from the input instead of notebook-level globals.
    n_batch, n_points = X.shape[0], X.shape[1]
    sq_dists = (
        -2 * torch.bmm(X, Y.transpose(1, 2))
        + torch.sum(Y ** 2, dim=-1).unsqueeze(1)
        + torch.sum(X ** 2, dim=-1).unsqueeze(-1)
    )
    # Rounding can push the squared distance of (near-)coincident points just
    # below zero; without this the square root would be NaN and the pair would
    # be reported as 10000 instead of ~0. NaN entries compare False and are
    # left untouched so missing coordinates still map to 10000 below.
    sq_dists = torch.where(sq_dists < 0, torch.zeros_like(sq_dists), sq_dists)
    dists = torch.nan_to_num(sq_dists.sqrt(), nan=10000)
    return dists.view(n_batch, n_points, -1, atoms_per_residue).min(dim=-1)
# Flatten the protein coordinates to (B, -1, 3) before the distance call.
prot_pos_flat = h_t_pos.view(B, -1, 3)
dist = compute_euclidean_distances_matrix(h_l_pos, prot_pos_flat)

In [69]:
# Inspect the shapes of the dense ligand and protein coordinate tensors.
h_l_pos.shape,h_t_pos.shape

(torch.Size([8, 49, 3]), torch.Size([8, 101, 24, 3]))

In [None]:
# Inspect the protein coordinates after flattening to (B, -1, 3).
h_t_pos.view(B,-1,3).shape

torch.Size([8, 2424, 3])

In [None]:
# NOTE(review): repeats the preceding inspection cell; candidate for removal.
h_t_pos.view(B,-1,3).shape

torch.Size([8, 2424, 3])

In [70]:
# NOTE(review): repeats the preceding inspection cell; candidate for removal.
h_t_pos.view(B,-1,3).shape

torch.Size([8, 2424, 3])

In [None]:
# NOTE(review): repeats the preceding inspection cell; candidate for removal.
h_t_pos.view(B,-1,3).shape

torch.Size([8, 2424, 3])

In [None]:
# NOTE(review): repeats the preceding inspection cell; candidate for removal.
h_t_pos.view(B,-1,3).shape

torch.Size([8, 2424, 3])

In [None]:
def mdn_loss_fn(pi, sigma, mu, y, eps1=1e-10, eps2=1e-10):
    """
    Calculates the Mixture Density Network loss.
    Args:
        pi (torch.Tensor): Mixture coefficients (batch_size, num_components).
        sigma (torch.Tensor): Standard deviations of Gaussian components (batch_size, num_components).
                           Must be positive.
        mu (torch.Tensor): Means of Gaussian components (batch_size, num_components).
        y (torch.Tensor): Target values (batch_size, 1) or (batch_size,).
        eps1 (float): Small epsilon for pi log stability.
        eps2 (float): Small epsilon for final probability log stability.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - loss: The negative log-likelihood loss for each sample in the batch (batch_size,).
            - prob: The probability of y for each sample (batch_size,).
    """
    # A 1-D target must become a column before broadcasting. Otherwise
    # expand_as aligns it with the component axis: it raises when
    # batch_size != num_components and silently pairs targets with the wrong
    # samples when the two happen to be equal.
    if y.dim() == 1:
        y = y.unsqueeze(-1)
    normal = Normal(mu, sigma)
    # Per-component log density of each sample's target.
    loglik = normal.log_prob(y.expand_as(normal.loc))
    # Mixture likelihood: sum over components of pi_k * N(y | mu_k, sigma_k).
    prob = (torch.log(pi + eps1) + loglik).exp().sum(dim=1)
    loss = -torch.log(prob + eps2)
    return loss, prob

In [64]:
# Inspect the shape of the ligand-protein pair validity mask.
C_mask.shape

torch.Size([8, 49, 101])