Execute the original TankBind model, and try to find the differences in outputs between the old and the new model

In [1]:
%pwd

'/fs/gpfs41/lv11/fileset01/pool/pool-marsot'

In [1]:
%cd /fs/gpfs41/lv11/fileset01/pool/pool-marsot/tankbind_philip/TankBind/tankbind/

/fs/gpfs41/lv11/fileset01/pool/pool-marsot/tankbind_philip/TankBind/tankbind


In [2]:
import sys; sys.path.append("/fs/pool/pool-marsot")

In [3]:
import tankbind_philip.TankBind.tankbind as tankbind_og

In [4]:
from tankbind_philip.TankBind.tankbind.model import IaBNet_with_affinity

In [5]:
model_og = IaBNet_with_affinity()

Protein embedding

In [14]:
#replace model weights with original model weights
for (name1, param1), (name2, param2) in zip(model.protein_embedding.named_parameters(), model_og.conv_protein.named_parameters()):
    param1.data = param2.data

Compound embedding

In [14]:
import torch

In [20]:
model.compound_embedding

GINE(
  (gnn_node): GNN_node(
    (atom_encoder): AtomEncoder(
      (linear): Linear(in_features=128, out_features=128, bias=True)
    )
    (convs): ModuleList(
      (0-4): 5 x GINConv()
    )
    (batch_norms): ModuleList(
      (0-4): 5 x BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)

In [15]:
torch.__version__

'2.3.1'

Iteratively modifying the model

In [None]:
import torch
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.utils import to_dense_batch
from torch import nn
from torch.nn import Linear
import sys
import torch.nn as nn
from gvp import GVP, GVPConvLayer, LayerNorm, tuple_index
from torch.distributions import Categorical
from torch_scatter import scatter_mean
#from GATv2 import GAT
from GINv2 import GIN

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x



class GVP_embedding(nn.Module):
    '''
    Modified based on https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
    GVP-GNN for Model Quality Assessment as described in manuscript.
    
    Takes in protein structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a scalar score for
    each graph in the batch in a `torch.Tensor` of shape [n_nodes]
    
    Should be used with `gvp.data.ProteinGraphDataset`, or with generators
    of `torch_geometric.data.Batch` objects with the same attributes.
    
    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param node_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :seq_in: if `True`, sequences will also be passed in with
             the forward pass; otherwise, sequence information
             is assumed to be part of input node embeddings
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim, 
                 edge_in_dim, edge_h_dim,
                 seq_in=False, num_layers=3, drop_rate=0.1):

        super(GVP_embedding, self).__init__()
        
        if seq_in:
            self.W_s = nn.Embedding(20, 20)
            node_in_dim = (node_in_dim[0] + 20, node_in_dim[1])
        
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )

        self.layers = nn.ModuleList(
                GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate) 
            for _ in range(num_layers))
        
        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

    def forward(self, h_V, edge_index, h_E, seq):      
        '''
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param seq: if not `None`, int `torch.Tensor` of shape [num_nodes]
                    to be embedded and appended to `h_V`
        '''
        seq = self.W_s(seq)
        h_V = (torch.cat([h_V[0], seq], dim=-1), h_V[1])
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)

        return out


def get_pair_dis_one_hot(d, bin_size=2, bin_min=-1, bin_max=30):
    # without compute_mode='donot_use_mm_for_euclid_dist' could lead to wrong result.
    pair_dis = torch.cdist(d, d, compute_mode='donot_use_mm_for_euclid_dist')
    pair_dis[pair_dis>bin_max] = bin_max
    pair_dis_bin_index = torch.div(pair_dis - bin_min, bin_size, rounding_mode='floor').long()
    pair_dis_one_hot = torch.nn.functional.one_hot(pair_dis_bin_index, num_classes=16)
    return pair_dis_one_hot

class TriangleProteinToCompound(torch.nn.Module):
    def __init__(self, embedding_channels=256, c=128, hasgate=True):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.layernorm_c = torch.nn.LayerNorm(c)
        self.hasgate = hasgate
        if hasgate:
            self.gate_linear = Linear(embedding_channels, c)
        self.linear = Linear(embedding_channels, c)
        self.ending_gate_linear = Linear(embedding_channels, embedding_channels)
        self.linear_after_sum = Linear(c, embedding_channels)
    def forward(self, z, protein_pair, compound_pair, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        # z_mask of shape b, i, j, 1
        z = self.layernorm(z)
        if self.hasgate:
            ab = self.gate_linear(z).sigmoid() * self.linear(z) * z_mask
        else:
            ab = self.linear(z) * z_mask
        g = self.ending_gate_linear(z).sigmoid()
        block1 = torch.einsum("bikc,bkjc->bijc", protein_pair, ab)
        block2 = torch.einsum("bikc,bjkc->bijc", ab, compound_pair)
        z = g * self.linear_after_sum(self.layernorm_c(block1+block2)) * z_mask
        return z

class TriangleProteinToCompound_v2(torch.nn.Module):
    # separate left/right edges (block1/block2).
    def __init__(self, embedding_channels=256, c=128):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.layernorm_c = torch.nn.LayerNorm(c)

        self.gate_linear1 = Linear(embedding_channels, c)
        self.gate_linear2 = Linear(embedding_channels, c)

        self.linear1 = Linear(embedding_channels, c)
        self.linear2 = Linear(embedding_channels, c)

        self.ending_gate_linear = Linear(embedding_channels, embedding_channels)
        self.linear_after_sum = Linear(c, embedding_channels)
    def forward(self, z, protein_pair, compound_pair, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        z = self.layernorm(z)
        protein_pair = self.layernorm(protein_pair)
        compound_pair = self.layernorm(compound_pair)
 
        ab1 = self.gate_linear1(z).sigmoid() * self.linear1(z) * z_mask
        ab2 = self.gate_linear2(z).sigmoid() * self.linear2(z) * z_mask
        protein_pair = self.gate_linear2(protein_pair).sigmoid() * self.linear2(protein_pair)
        compound_pair = self.gate_linear1(compound_pair).sigmoid() * self.linear1(compound_pair)

        g = self.ending_gate_linear(z).sigmoid()
        block1 = torch.einsum("bikc,bkjc->bijc", protein_pair, ab1)
        block2 = torch.einsum("bikc,bjkc->bijc", ab2, compound_pair)
        # print(g.shape, block1.shape, block2.shape)
        z = g * self.linear_after_sum(self.layernorm_c(block1+block2)) * z_mask
        return z

class Self_Attention(nn.Module):
    def __init__(self, hidden_size,num_attention_heads=8,drop_rate=0.5):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dp = nn.Dropout(drop_rate)
        self.ln = nn.LayerNorm(hidden_size)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self,q,k,v,attention_mask=None,attention_weight=None):
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        attention_scores = torch.matmul(q, k.transpose(-1, -2))

        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # attention_probs = self.dp(attention_probs)
        if attention_weight is not None:
            attention_weight_sorted_sorted = torch.argsort(torch.argsort(-attention_weight,axis=-1),axis=-1)
            # if self.training:
            #     top_mask = (attention_weight_sorted_sorted<np.random.randint(28,45))
            # else:
            top_mask = (attention_weight_sorted_sorted<32)
            attention_probs = attention_probs * top_mask
            # attention_probs = attention_probs * attention_weight
            attention_probs = attention_probs / (torch.sum(attention_probs,dim=-1,keepdim=True) + 1e-5)
        # print(attention_probs.shape,v.shape)
        # attention_probs = self.dp(attention_probs)
        outputs = torch.matmul(attention_probs, v)

        outputs = outputs.permute(0, 2, 1, 3).contiguous()
        new_output_shape = outputs.size()[:-2] + (self.all_head_size,)
        outputs = outputs.view(*new_output_shape)
        outputs = self.ln(outputs)
        return outputs

class TriangleSelfAttentionRowWise(torch.nn.Module):
    # use the protein-compound matrix only.
    def __init__(self, embedding_channels=128, c=32, num_attention_heads=4):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = c
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # self.dp = nn.Dropout(drop_rate)
        # self.ln = nn.LayerNorm(hidden_size)

        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        # self.layernorm_c = torch.nn.LayerNorm(c)

        self.linear_q = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_k = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_v = Linear(embedding_channels, self.all_head_size, bias=False)
        # self.b = Linear(embedding_channels, h, bias=False)
        self.g = Linear(embedding_channels, self.all_head_size)
        self.final_linear = Linear(self.all_head_size, embedding_channels)

    def reshape_last_dim(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x

    def forward(self, z, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        # z_mask of shape b, i, j
        z = self.layernorm(z)
        p_length = z.shape[1]
        batch_n = z.shape[0]
        # new_z = torch.zeros(z.shape, device=z.device)
        z_i = z
        z_mask_i = z_mask.view((batch_n, p_length, 1, 1, -1))
        attention_mask_i = (1e9 * (z_mask_i.float() - 1.))
        # q, k, v of shape b, j, h, c
        q = self.reshape_last_dim(self.linear_q(z_i)) #  * (self.attention_head_size**(-0.5))
        k = self.reshape_last_dim(self.linear_k(z_i))
        v = self.reshape_last_dim(self.linear_v(z_i))
        logits = torch.einsum('biqhc,bikhc->bihqk', q, k) + attention_mask_i
        weights = nn.Softmax(dim=-1)(logits)
        # weights of shape b, h, j, j
        # attention_probs = self.dp(attention_probs)
        weighted_avg = torch.einsum('bihqk,bikhc->biqhc', weights, v)
        g = self.reshape_last_dim(self.g(z_i)).sigmoid()
        output = g * weighted_avg
        new_output_shape = output.size()[:-2] + (self.all_head_size,)
        output = output.view(*new_output_shape)
        # output of shape b, j, embedding.
        # z[:, i] = output
        z = output
        # print(g.shape, block1.shape, block2.shape)
        z = self.final_linear(z) * z_mask.unsqueeze(-1)
        return z


class Transition(torch.nn.Module):
    # separate left/right edges (block1/block2).
    def __init__(self, embedding_channels=256, n=4):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.linear1 = Linear(embedding_channels, n*embedding_channels)
        self.linear2 = Linear(n*embedding_channels, embedding_channels)
    def forward(self, z):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        z = self.layernorm(z)
        z = self.linear2((self.linear1(z)).relu())
        return z



class IaBNet_with_affinity_attention(torch.nn.Module):
    def __init__(self, hidden_channels=128, embedding_channels=128, c=128, mode=0, protein_embed_mode=1, compound_embed_mode=1, n_trigonometry_module_stack=5, protein_bin_max=30, readout_mode=2):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.protein_bin_max = protein_bin_max
        self.mode = mode
        self.protein_embed_mode = protein_embed_mode
        self.compound_embed_mode = compound_embed_mode
        self.n_trigonometry_module_stack = n_trigonometry_module_stack
        self.readout_mode = readout_mode

        if protein_embed_mode == 0:
            self.conv_protein = GNN(hidden_channels, embedding_channels)
            self.conv_compound = GNN(hidden_channels, embedding_channels)
        if protein_embed_mode == 1:
            self.conv_protein = GVP_embedding((6, 3), (embedding_channels, 16), 
                                              (32, 1), (32, 1), seq_in=True)
            

        if compound_embed_mode == 0:
            self.conv_compound = GNN(hidden_channels, embedding_channels)
        elif compound_embed_mode == 1:
            self.conv_compound = GIN(input_dim = 56, hidden_dims = [128,56,embedding_channels], edge_input_dim = 19, concat_hidden = False)

        if mode == 0:
            self.protein_pair_embedding = Linear(16, c)
            self.compound_pair_embedding = Linear(16, c)
            self.protein_to_compound_list = []
            self.protein_to_compound_list = nn.ModuleList([TriangleProteinToCompound_v2(embedding_channels=embedding_channels, c=c) for _ in range(n_trigonometry_module_stack)])
            self.triangle_self_attention_list = nn.ModuleList([TriangleSelfAttentionRowWise(embedding_channels=embedding_channels) for _ in range(n_trigonometry_module_stack)])
            self.tranistion = Transition(embedding_channels=embedding_channels, n=4)

        self.linear = Linear(embedding_channels, 1)
        self.linear_energy = Linear(embedding_channels, 1)
        if readout_mode == 2:
            self.gate_linear = Linear(embedding_channels, 1)
        # self.gate_linear = Linear(embedding_channels, 1)
        self.bias = torch.nn.Parameter(torch.ones(1))
        self.leaky = torch.nn.LeakyReLU()
        self.dropout = nn.Dropout2d(p=0.25)
    def forward(self, data):
        if self.protein_embed_mode == 0:
            x = data['protein'].x.float()
            edge_index = data[("protein", "p2p", "protein")].edge_index
            protein_batch = data['protein'].batch
            protein_out = self.conv_protein(x, edge_index)
        if self.protein_embed_mode == 1:
            nodes = (data['protein']['node_s'], data['protein']['node_v'])
            edges = (data[("protein", "p2p", "protein")]["edge_s"], data[("protein", "p2p", "protein")]["edge_v"])
            protein_batch = data['protein'].batch
            protein_out = self.conv_protein(nodes, data[("protein", "p2p", "protein")]["edge_index"], edges, data.seq)

        if self.compound_embed_mode == 0:
            compound_x = data['compound'].x.float()
            compound_edge_index = data[("compound", "c2c", "compound")].edge_index
            compound_batch = data['compound'].batch
            compound_out = self.conv_compound(compound_x, compound_edge_index)
        elif self.compound_embed_mode == 1:
            compound_x = data['compound'].x.float()
            compound_edge_index = data[("compound", "c2c", "compound")].edge_index.T
            # compound_edge_index = data[("compound", "c2c", "compound")].edge_index
            compound_edge_feature = data[("compound", "c2c", "compound")].edge_attr
            edge_weight = data[("compound", "c2c", "compound")].edge_weight
            compound_batch = data['compound'].batch
            # Enzo : print dimensions
            #print(f"{compound_edge_index.shape=}, {edge_weight.shape=}, {compound_edge_feature.shape=}, {compound_x.shape=}")
            compound_out = self.conv_compound(compound_edge_index,edge_weight,compound_edge_feature,compound_x.shape[0],compound_x)['node_feature']
    
        # protein_batch version could further process b matrix. better than for loop.
        # protein_out_batched of shape b, n, c
        protein_out_batched, protein_out_mask = to_dense_batch(protein_out, protein_batch)
        compound_out_batched, compound_out_mask = to_dense_batch(compound_out, compound_batch)

        node_xyz = data.node_xyz

        p_coords_batched, p_coords_mask = to_dense_batch(node_xyz, protein_batch)
        # c_coords_batched, c_coords_mask = to_dense_batch(coords, compound_batch)

        protein_pair = get_pair_dis_one_hot(p_coords_batched, bin_size=2, bin_min=-1, bin_max=self.protein_bin_max)
        # compound_pair = get_pair_dis_one_hot(c_coords_batched, bin_size=1, bin_min=-0.5, bin_max=15)
        compound_pair_batched, compound_pair_batched_mask = to_dense_batch(data.compound_pair, data.compound_pair_batch)
        batch_n = compound_pair_batched.shape[0]
        max_compound_size_square = compound_pair_batched.shape[1]
        max_compound_size = int(max_compound_size_square**0.5)
        assert (max_compound_size**2 - max_compound_size_square)**2 < 1e-4
        compound_pair = torch.zeros((batch_n, max_compound_size, max_compound_size, 16)).to(data.compound_pair.device)
        for i in range(batch_n):
            one = compound_pair_batched[i]
            compound_size_square = (data.compound_pair_batch == i).sum()
            compound_size = int(compound_size_square**0.5)
            compound_pair[i,:compound_size, :compound_size] = one[:compound_size_square].reshape(
                                                                (compound_size, compound_size, -1))
        protein_pair = self.protein_pair_embedding(protein_pair.float())
        compound_pair = self.compound_pair_embedding(compound_pair.float())
        # b = torch.einsum("bik,bjk->bij", protein_out_batched, compound_out_batched).flatten()

        protein_out_batched = self.layernorm(protein_out_batched)
        compound_out_batched = self.layernorm(compound_out_batched)
        # z of shape, b, protein_length, compound_length, channels.
        z = torch.einsum("bik,bjk->bijk", protein_out_batched, compound_out_batched)
        z_mask = torch.einsum("bi,bj->bij", protein_out_mask, compound_out_mask)
        # z = z * z_mask.unsqueeze(-1)
        # print(protein_pair.shape, compound_pair.shape, b.shape)
        if self.mode == 0:
            for _ in range(1):
                for i_module in range(self.n_trigonometry_module_stack):
                    z = z + self.dropout(self.protein_to_compound_list[i_module](z, protein_pair, compound_pair, z_mask.unsqueeze(-1)))
                    z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask))
                    z = self.tranistion(z)
        # batch_dim = z.shape[0]

        b = self.linear(z).squeeze(-1)
        y_pred = b[z_mask]
        y_pred = y_pred.sigmoid() * 10   # normalize to 0 to 10.
        if self.readout_mode == 0:
            pair_energy = self.linear_energy(z).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        if self.readout_mode == 1:
            # valid_interaction_z = (z * z_mask.unsqueeze(-1)).mean(axis=(1, 2))
            valid_interaction_z = (z * z_mask.unsqueeze(-1)).sum(axis=(1, 2)) / z_mask.sum(axis=(1, 2)).unsqueeze(-1)
            affinity_pred = self.linear_energy(valid_interaction_z).squeeze(-1)
            # print("z shape", z.shape, "z_mask shape", z_mask.shape,   "valid_interaction_z shape", valid_interaction_z.shape, "affinity_pred shape", affinity_pred.shape)
        if self.readout_mode == 2:
            pair_energy = (self.gate_linear(z).sigmoid() * self.linear_energy(z)).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        return y_pred, affinity_pred

In [26]:
model_og.triangle_self_attention_list[0]

TriangleSelfAttentionRowWise(
  (layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (linear_q): Linear(in_features=128, out_features=128, bias=False)
  (linear_k): Linear(in_features=128, out_features=128, bias=False)
  (linear_v): Linear(in_features=128, out_features=128, bias=False)
  (g): Linear(in_features=128, out_features=128, bias=True)
  (final_linear): Linear(in_features=128, out_features=128, bias=True)
)

In [9]:
from bindbind.models.models.tankbind_layers import *

In [6]:
from tankbind_philip.TankBind.tankbind.data import TankBindDataSet

In [7]:
add_noise_to_com = 0.0
# compoundMode = 1 is for GIN model.
#new_dataset = TankBindDataSet(f"{pre}/apr22_pdbbind_gvp_pocket_radius20", add_noise_to_com=add_noise_to_com)'
new_dataset = TankBindDataSet("/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset", add_noise_to_com=add_noise_to_com)
# modified by Enzo
# load compound features extracted using torchdrug.
# new_dataset.compound_dict = torch.load(f"{pre}/compound_dict.pt")
new_dataset.data = new_dataset.data.query("c_length < 100 and native_num_contact > 5").reset_index(drop=True)
d = new_dataset.data
only_native_train_index = d.query("use_compound_com and group =='train'").index.values
train = new_dataset[only_native_train_index]
train_index = d.query("group =='train'").index.values
train_after_warm_up = new_dataset[train_index]
# train = torch.utils.data.ConcatDataset([train1, train2])
valid_index = d.query("use_compound_com and group =='valid'").index.values
valid = new_dataset[valid_index]
test_index = d.query("use_compound_com and group =='test'").index.values
test = new_dataset[test_index]

#all_pocket_test_fileName = f"{pre}/apr23_testset_pdbbind_gvp_pocket_radius20/"
# added by Enzo
all_pocket_test_fileName = "/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/test_dataset"
all_pocket_test = TankBindDataSet(all_pocket_test_fileName)
#all_pocket_test.compound_dict = torch.load(f"{pre}/compound_dict.pt")
# added by Enzo
all_pocket_test.compound_dict = "/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/compound.pt"
# info is used to evaluate the test set. 
info = None

['/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/data.pt', '/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/protein.pt', '/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/compound.pt']
['/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/test_dataset/processed/data.pt', '/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/test_dataset/processed/protein.pt', '/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/test_dataset/processed/compound.pt']


In [36]:
from torch.utils.data import RandomSampler
from torch_geometric.loader import DataLoader
sampler = RandomSampler(train, replacement=True, num_samples=20000)
train_loader = DataLoader(train, batch_size=4, follow_batch=['x', 'compound_pair'], sampler=sampler, pin_memory=False, num_workers=2)

In [37]:
batch = next(iter(train_loader))

In [38]:
batch

HeteroDataBatch(
  dis_map=[13207],
  node_xyz=[539, 3],
  coords=[108, 3],
  y=[13207],
  seq=[539],
  affinity=[4],
  compound_pair=[4010, 16],
  compound_pair_batch=[4010],
  compound_pair_ptr=[5],
  pdb=[4],
  group=[4],
  real_affinity_mask=[4],
  real_y_mask=[13207],
  is_equivalent_native_pocket=[4],
  equivalent_native_y_mask=[13207],
  protein={
    node_s=[539, 6],
    node_v=[539, 3, 3],
    batch=[539],
    ptr=[5],
  },
  compound={
    x=[108, 56],
    x_batch=[108],
    x_ptr=[5],
    batch=[108],
    ptr=[5],
  },
  (protein, p2p, protein)={
    edge_index=[2, 12801],
    edge_s=[12801, 32],
    edge_v=[12801, 1, 3],
  },
  (compound, c2c, compound)={
    edge_index=[2, 218],
    edge_weight=[218],
    edge_attr=[218, 19],
  }
)

In [39]:
model_og

IaBNet_with_affinity(
  (layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (conv_protein): GVP_embedding(
    (W_s): Embedding(20, 20)
    (W_v): Sequential(
      (0): LayerNorm(
        (scalar_norm): LayerNorm((26,), eps=1e-05, elementwise_affine=True)
      )
      (1): GVP(
        (wh): Linear(in_features=3, out_features=16, bias=False)
        (ws): Linear(in_features=42, out_features=128, bias=True)
        (wv): Linear(in_features=16, out_features=16, bias=False)
      )
    )
    (W_e): Sequential(
      (0): LayerNorm(
        (scalar_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
      (1): GVP(
        (wh): Linear(in_features=1, out_features=1, bias=False)
        (ws): Linear(in_features=33, out_features=32, bias=True)
        (wv): Linear(in_features=1, out_features=1, bias=False)
      )
    )
    (layers): ModuleList(
      (0-2): 3 x GVPConvLayer(
        (conv): GVPConv()
        (norm): ModuleList(
          (0-1): 2 x LayerNo

In [40]:
data=batch

In [42]:
from torch_geometric.utils import to_dense_batch

In [44]:
def get_pair_dis_one_hot(d, bin_size=2, bin_min=-1, bin_max=30):
    # without compute_mode='donot_use_mm_for_euclid_dist' could lead to wrong result.
    pair_dis = torch.cdist(d, d, compute_mode='donot_use_mm_for_euclid_dist')
    pair_dis[pair_dis>bin_max] = bin_max
    pair_dis_bin_index = torch.div(pair_dis - bin_min, bin_size, rounding_mode='floor').long()
    pair_dis_one_hot = torch.nn.functional.one_hot(pair_dis_bin_index, num_classes=16)
    return pair_dis_one_hot

In [51]:
model_og.eval()

IaBNet_with_affinity(
  (layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (conv_protein): GVP_embedding(
    (W_s): Embedding(20, 20)
    (W_v): Sequential(
      (0): LayerNorm(
        (scalar_norm): LayerNorm((26,), eps=1e-05, elementwise_affine=True)
      )
      (1): GVP(
        (wh): Linear(in_features=3, out_features=16, bias=False)
        (ws): Linear(in_features=42, out_features=128, bias=True)
        (wv): Linear(in_features=16, out_features=16, bias=False)
      )
    )
    (W_e): Sequential(
      (0): LayerNorm(
        (scalar_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
      (1): GVP(
        (wh): Linear(in_features=1, out_features=1, bias=False)
        (ws): Linear(in_features=33, out_features=32, bias=True)
        (wv): Linear(in_features=1, out_features=1, bias=False)
      )
    )
    (layers): ModuleList(
      (0-2): 3 x GVPConvLayer(
        (conv): GVPConv()
        (norm): ModuleList(
          (0-1): 2 x LayerNo

In [46]:

nodes = (data['protein']['node_s'], data['protein']['node_v'])
edges = (data[("protein", "p2p", "protein")]["edge_s"], data[("protein", "p2p", "protein")]["edge_v"])
protein_batch = data['protein'].batch
protein_out = model_og.conv_protein(nodes, data[("protein", "p2p", "protein")]["edge_index"], edges, data.seq)


compound_x = data['compound'].x.float()
compound_edge_index = data[("compound", "c2c", "compound")].edge_index.T
compound_edge_feature = data[("compound", "c2c", "compound")].edge_attr
edge_weight = data[("compound", "c2c", "compound")].edge_weight
compound_batch = data['compound'].batch
# Enzo : print dimensions
#print(f"{compound_edge_index.shape=}, {edge_weight.shape=}, {compound_edge_feature.shape=}, {compound_x.shape=}")
compound_out = model_og.conv_compound(compound_edge_index,edge_weight,compound_edge_feature,compound_x.shape[0],compound_x)['node_feature']

protein_out_batched, protein_out_mask = to_dense_batch(protein_out, protein_batch)
compound_out_batched, compound_out_mask = to_dense_batch(compound_out, compound_batch)

node_xyz = data.node_xyz

p_coords_batched, p_coords_mask = to_dense_batch(node_xyz, protein_batch)

protein_pair = get_pair_dis_one_hot(p_coords_batched, bin_size=2, bin_min=-1, bin_max=model_og.protein_bin_max)
compound_pair_batched, compound_pair_batched_mask = to_dense_batch(data.compound_pair, data.compound_pair_batch)
batch_n = compound_pair_batched.shape[0]
max_compound_size_square = compound_pair_batched.shape[1]
max_compound_size = int(max_compound_size_square**0.5)
assert (max_compound_size**2 - max_compound_size_square)**2 < 1e-4
compound_pair = torch.zeros((batch_n, max_compound_size, max_compound_size, 16)).to(data.compound_pair.device)
for i in range(batch_n):
    one = compound_pair_batched[i]
    compound_size_square = (data.compound_pair_batch == i).sum()
    compound_size = int(compound_size_square**0.5)
    compound_pair[i,:compound_size, :compound_size] = one[:compound_size_square].reshape(
                                                        (compound_size, compound_size, -1))
protein_pair = model_og.protein_pair_embedding(protein_pair.float())
compound_pair = model_og.compound_pair_embedding(compound_pair.float())

protein_out_batched = model_og.layernorm(protein_out_batched)
compound_out_batched = model_og.layernorm(compound_out_batched)
# z of shape, b, protein_length, compound_length, channels.
z = torch.einsum("bik,bjk->bijk", protein_out_batched, compound_out_batched)
z_mask = torch.einsum("bi,bj->bij", protein_out_mask, compound_out_mask)


In [None]:


for i_module in range(self.n_trigonometry_module_stack):
    z = z + self.dropout(self.protein_to_compound_list[i_module](z, protein_pair, compound_pair, z_mask.unsqueeze(-1)))
    z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask))
    z = self.tranistion(z)


b = self.linear(z).squeeze(-1)
y_pred = b[z_mask]
y_pred = y_pred.sigmoid() * 10   # normalize to 0 to 10.
if self.readout_mode == 0:
    pair_energy = self.linear_energy(z).squeeze(-1) * z_mask
    affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
if self.readout_mode == 1:
    # valid_interaction_z = (z * z_mask.unsqueeze(-1)).mean(axis=(1, 2))
    valid_interaction_z = (z * z_mask.unsqueeze(-1)).sum(axis=(1, 2)) / z_mask.sum(axis=(1, 2)).unsqueeze(-1)
    affinity_pred = self.linear_energy(valid_interaction_z).squeeze(-1)
    # print("z shape", z.shape, "z_mask shape", z_mask.shape,   "valid_interaction_z shape", valid_interaction_z.shape, "affinity_pred shape", affinity_pred.shape)
if self.readout_mode == 2:
    pair_energy = (self.gate_linear(z).sigmoid() * self.linear_energy(z)).squeeze(-1) * z_mask
    affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
return y_pred, affinity_pred

In [48]:
y = model_og.triangle_self_attention_list[0](z, z_mask)

In [49]:
z.shape

torch.Size([4, 177, 55, 128])

In [50]:
z_mask.shape

torch.Size([4, 177, 55])

In [56]:
att0 = model_og.triangle_self_attention_list[0]

In [58]:
y = att0.layernorm(z)
p_length = y.shape[1]
batch_n = y.shape[0]
z_i = y
z_mask_i = z_mask.view((batch_n, p_length, 1, 1, -1))
attention_mask_i = (1e9 * (z_mask_i.float() - 1.))
q = att0.reshape_last_dim(att0.linear_q(z_i))
k = att0.reshape_last_dim(att0.linear_k(z_i))
v = att0.reshape_last_dim(att0.linear_v(z_i))
logits = torch.einsum('biqhc,bikhc->bihqk', q, k) + attention_mask_i
weights = nn.Softmax(dim=-1)(logits)
weighted_avg = torch.einsum('bihqk,bikhc->biqhc', weights, v)

In [59]:
q.shape

torch.Size([4, 177, 55, 4, 32])

Etape 1: padding et batching: verifier qu'on obtient le meme resultat

In [15]:
from torch_geometric.loader.dataloader import Collater
from torch_geometric.utils import to_dense_batch
import torch


class TankBindDataLoader(torch.utils.data.DataLoader):
    """Subclass of the torch DataLoader, in order to apply the collate function TankBindCollater."""
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 follow_batch=None,
                 exclude_keys=None,
                 make_divisible_by_8=True,
                 **kwargs):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.make_divisible_by_8=make_divisible_by_8
        super().__init__(dataset,
                         batch_size,
                         shuffle,
                         collate_fn=TankBindCollater(dataset, follow_batch, exclude_keys, make_divisible_by_8=self.make_divisible_by_8),
                         **kwargs)



class TankBindCollater(Collater):
    """Applies batching operations and computations of masks in place of the model, in order to avoid having to recompute it in the
    forward pass on GPU."""
    def __init__(self, dataset,
                 follow_batch=None,
                 exclude_keys=None,
                 make_divisible_by_8=True):
        super().__init__(dataset, follow_batch, exclude_keys)
        self.make_divisible_by_8 = make_divisible_by_8
    def __call__(self, batch):
        data = super().__call__(batch)
        if self.make_divisible_by_8:
            max_dim_divisible_by_8_protein = 8 * (torch.diff(data["protein"].ptr).max() // 8 + 1)
            max_dim_divisible_by_8_compound = 8 * (torch.diff(data["compound"].ptr).max() // 8 + 1)
        else:
            max_dim_divisible_by_8_protein = torch.diff(data["protein"].ptr).max()
            max_dim_divisible_by_8_compound = torch.diff(data["compound"].ptr).max()
        protein_coordinates_batched, _ = to_dense_batch(
            data.node_xyz, data["protein"].batch,
            max_num_nodes=max_dim_divisible_by_8_protein,
            )
        protein_pairwise_representation = get_pair_dis_index(
            protein_coordinates_batched,
            bin_size=2,
            bin_min=-1,
            bin_max=protein_bin_max,
            ) # shape [batch_n, max_protein_size, max_protein_size, 16]
        _compound_lengths = (data["compound"].ptr[1:] - data["compound"].ptr[:-1]) ** 2
        _total = torch.cumsum(_compound_lengths, 0)
        compound_pairwise_distance_batch = torch.zeros(
                _total[-1], dtype=torch.long
            )
        for i in range(len(_total) - 1):
            compound_pairwise_distance_batch[_total[i] : _total[i + 1]] = i + 1
        compound_pair_batched, compound_pair_batched_mask = to_dense_batch(
            data.compound_pair,
            data.compound_pair_batch,
            )
        compound_pairwise_representation = torch.zeros(
            (len(batch), max_dim_divisible_by_8_compound, max_dim_divisible_by_8_compound, 16),
            dtype=torch.float32,
            )
        for i in range(len(batch)):
            one = compound_pair_batched[i]
            compound_size_square = (compound_pairwise_distance_batch == i).sum()
            compound_size = int(compound_size_square**0.5)
            compound_pairwise_representation[i, :compound_size, :compound_size] = one[
                :compound_size_square
                ].reshape((compound_size, compound_size, -1))
        data.batch_n = len(batch)
        data.max_dim_divisible_by_8_protein = max_dim_divisible_by_8_protein
        data.max_dim_divisible_by_8_compound = max_dim_divisible_by_8_compound
        data["protein", "p2p", "protein"].pairwise_representation = protein_pairwise_representation
        data["compound", "p2p", "compound"].pairwise_representation = compound_pairwise_representation
        data["compound", "p2p", "compound"].pairwise_representation_mask = compound_pair_batched_mask
        return data




def get_pair_dis_index(d, bin_size=2, bin_min=-1, bin_max=30):
    """
    Computing pairwise distances and binning.
    """
    pair_dis = torch.cdist(d, d, compute_mode='donot_use_mm_for_euclid_dist')
    pair_dis[pair_dis>bin_max] = bin_max
    pair_dis_bin_index = torch.div(pair_dis - bin_min, bin_size, rounding_mode='floor').long()
    return pair_dis_bin_index

protein_bin_max = 30

In [16]:
train_loader = DataLoader(train, batch_size=4, follow_batch=['x', 'compound_pair'], pin_memory=False, num_workers=2, shuffle=False)

NameError: name 'DataLoader' is not defined

In [11]:
from bindbind.torch_datasets.tankbind_dataloader import TankBindDataLoader

In [17]:
train_loader_2 = TankBindDataLoader(train, batch_size=4, follow_batch=['x', 'compound_pair'], make_divisible_by_8=False, shuffle=False)

In [18]:
train_loader_3 = TankBindDataLoader(train, batch_size=4, follow_batch=['x', 'compound_pair'], make_divisible_by_8=True, shuffle=False)

In [108]:
batch = next(iter(train_loader))

In [19]:
batch_2 = next(iter(train_loader_2))

In [20]:
batch_3 = next(iter(train_loader_3))

Testing whether masking has an influence on the result

In [34]:
model_og.eval()
model_og.to('cpu')
result_2 = model_og(batch_2.to('cpu'))
result_3 = model_og(batch_3.to('cpu'))

In [35]:
result_3

(tensor([4.5944, 4.6123, 4.6162,  ..., 4.5521, 4.5225, 4.6108],
        grad_fn=<MulBackward0>),
 tensor([-0.2154, -0.8206, -0.2637, -0.1369], grad_fn=<LeakyReluBackward0>))

In [36]:
result_2

(tensor([4.5944, 4.6123, 4.6162,  ..., 4.5521, 4.5225, 4.6108],
        grad_fn=<MulBackward0>),
 tensor([-0.2154, -0.8206, -0.2637, -0.1369], grad_fn=<LeakyReluBackward0>))

In [37]:
result_2[0] == result_3[0]

tensor([True, True, True,  ..., True, True, True])

In [38]:
print(torch.linalg.norm(result_2[0] - result_3[0]))

tensor(0., grad_fn=<LinalgVectorNormBackward0>)


In [39]:
print(torch.linalg.norm(result_2[1] - result_3[1]))

tensor(0., grad_fn=<LinalgVectorNormBackward0>)


In [118]:
result_1[0].eq(result_2[0]).all()

tensor(True)

In [119]:
result_1[1].eq(result_2[1]).all()

tensor(True)

In [109]:
batch

HeteroDataBatch(
  dis_map=[9498],
  node_xyz=[675, 3],
  coords=[55, 3],
  y=[9498],
  seq=[675],
  affinity=[4],
  compound_pair=[853, 16],
  compound_pair_batch=[853],
  compound_pair_ptr=[5],
  pdb=[4],
  group=[4],
  real_affinity_mask=[4],
  real_y_mask=[9498],
  is_equivalent_native_pocket=[4],
  equivalent_native_y_mask=[9498],
  protein={
    node_s=[675, 6],
    node_v=[675, 3, 3],
    batch=[675],
    ptr=[5],
  },
  compound={
    x=[55, 56],
    x_batch=[55],
    x_ptr=[5],
    batch=[55],
    ptr=[5],
  },
  (protein, p2p, protein)={
    edge_index=[2, 16185],
    edge_s=[16185, 32],
    edge_v=[16185, 1, 3],
  },
  (compound, c2c, compound)={
    edge_index=[2, 110],
    edge_weight=[110],
    edge_attr=[110, 19],
  }
)

In [88]:
batch_2 = next(iter(train_loader_2))

In [89]:
batch_2

HeteroDataBatch(
  dis_map=[9498],
  node_xyz=[675, 3],
  coords=[55, 3],
  y=[9498],
  seq=[675],
  affinity=[4],
  compound_pair=[853, 16],
  compound_pair_batch=[853],
  compound_pair_ptr=[5],
  pdb=[4],
  group=[4],
  real_affinity_mask=[4],
  real_y_mask=[9498],
  is_equivalent_native_pocket=[4],
  equivalent_native_y_mask=[9498],
  batch_n=4,
  max_dim_divisible_by_8_protein=228,
  max_dim_divisible_by_8_compound=22,
  protein={
    node_s=[675, 6],
    node_v=[675, 3, 3],
    batch=[675],
    ptr=[5],
  },
  compound={
    x=[55, 56],
    x_batch=[55],
    x_ptr=[5],
    batch=[55],
    ptr=[5],
  },
  (protein, p2p, protein)={
    edge_index=[2, 16185],
    edge_s=[16185, 32],
    edge_v=[16185, 1, 3],
    pairwise_representation=[4, 228, 228],
  },
  (compound, c2c, compound)={
    edge_index=[2, 110],
    edge_weight=[110],
    edge_attr=[110, 19],
  },
  (compound, p2p, compound)={
    pairwise_representation=[4, 22, 22, 16],
    pairwise_representation_mask=[4, 484],
  }
)

In [74]:
batch_2

HeteroDataBatch(
  dis_map=[9498],
  node_xyz=[675, 3],
  coords=[55, 3],
  y=[9498],
  seq=[675],
  affinity=[4],
  compound_pair=[853, 16],
  compound_pair_batch=[853],
  compound_pair_ptr=[5],
  pdb=[4],
  group=[4],
  real_affinity_mask=[4],
  real_y_mask=[9498],
  is_equivalent_native_pocket=[4],
  equivalent_native_y_mask=[9498],
  batch_n=4,
  max_dim_divisible_by_8_protein=228,
  max_dim_divisible_by_8_compound=22,
  protein={
    node_s=[675, 6],
    node_v=[675, 3, 3],
    batch=[675],
    ptr=[5],
  },
  compound={
    x=[55, 56],
    x_batch=[55],
    x_ptr=[5],
    batch=[55],
    ptr=[5],
  },
  (protein, p2p, protein)={
    edge_index=[2, 16185],
    edge_s=[16185, 32],
    edge_v=[16185, 1, 3],
    pairwise_representation=[4, 228, 228],
  },
  (compound, c2c, compound)={
    edge_index=[2, 110],
    edge_weight=[110],
    edge_attr=[110, 19],
  },
  (compound, p2p, compound)={
    pairwise_representation=[4, 22, 22, 16],
    pairwise_representation_mask=[4, 484],
  }
)

In [40]:
data = batch_3

In [41]:
max_dim_divisible_by_8_protein = data.max_dim_divisible_by_8_protein
max_dim_divisible_by_8_compound = data.max_dim_divisible_by_8_compound

In [42]:
protein = model_og.conv_protein(
    h_V=(
        data["protein"]["node_s"],
        data["protein"]["node_v"],
    ),
    edge_index=data[("protein", "p2p", "protein")]["edge_index"],
    h_E=(
        data[("protein", "p2p", "protein")]["edge_s"],
        data[("protein", "p2p", "protein")]["edge_v"],
    ),
    seq=data.seq,
)

In [43]:
compound_node_features = data["compound"].x.float()
compound_edge_features = data["compound", "c2c", "compound"].edge_attr
compound_embedding_tensor = model_og.conv_compound(
    data["compound", "c2c", "compound"].edge_index.T,
    torch.ones(data["compound", "c2c", "compound"].edge_index.shape[1], dtype=torch.float32),
    compound_edge_features,
    compound_node_features.shape[0],
    compound_node_features,
)['node_feature']

In [44]:
protein_batched, protein_mask = to_dense_batch(protein, data["protein"].batch, max_num_nodes=max_dim_divisible_by_8_protein)
compound_batched, compound_mask = to_dense_batch(
    compound_embedding_tensor, data["compound"].batch,
    max_num_nodes=max_dim_divisible_by_8_compound,
)
protein_batched = model_og.layernorm(protein_batched)
compound_batched = model_og.layernorm(compound_batched)

In [45]:
protein_pairwise_representation = data["protein", "p2p", "protein"].pairwise_representation # shape [batch_n, max_protein_size, max_protein_size, 16]
compound_pairwise_representation = data["compound", "p2p", "compound"].pairwise_representation # shape [batch_n, max_compound_size, max_compound_size, 16]

In [46]:
new_layer = torch.nn.Embedding(16, 128)

In [48]:
protein_pairwise_representation_one_hot = torch.nn.functional.one_hot(protein_pairwise_representation, num_classes=16).float()

In [49]:
new_layer.weight = torch.nn.Parameter(model_og.protein_pair_embedding.weight.T+model_og.protein_pair_embedding.bias)

In [50]:
protein_pair_og = model_og.protein_pair_embedding(protein_pairwise_representation_one_hot)

In [52]:
protein_pair = new_layer(protein_pairwise_representation)
compound_pair = model_og.compound_pair_embedding(compound_pairwise_representation.float())

In [53]:
protein_pair.shape

torch.Size([4, 232, 232, 128])

In [54]:
######## MASKING ########
batch_n = data.batch_n
z_mask = torch.einsum("bi,bj->bij", protein_mask, compound_mask)
# BUG: hardcoded 4 heads...
z_mask_attention = torch.einsum("bik, bq-> biqk", z_mask, compound_mask).reshape(batch_n*protein_batched.shape[1], max_dim_divisible_by_8_compound, max_dim_divisible_by_8_compound).unsqueeze(1).expand(-1, 4, -1, -1).contiguous()
z_mask_attention = torch.where(z_mask_attention, 0.0, -10.0**6)
z_mask_flat = torch.arange(
    start=0, end=z_mask.numel(), device=model_og.device
).view(z_mask.shape)[z_mask]
protein_square_mask = torch.einsum("bi,bj->bij", protein_mask, protein_mask)

In [55]:
z = torch.einsum("bik,bjk->bijk", protein_batched, compound_batched)

In [61]:
z.shape

torch.Size([4, 232, 24, 128])

In [283]:
from torch.nn import Linear
class TriangleSelfAttentionRowWise(torch.nn.Module):
    # use the protein-compound matrix only.
    def __init__(self, embedding_channels=128, c=32, num_attention_heads=4):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = c
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # self.dp = nn.Dropout(drop_rate)
        # self.ln = nn.LayerNorm(hidden_size)

        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        # self.layernorm_c = torch.nn.LayerNorm(c)

        self.linear_q = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_k = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_v = Linear(embedding_channels, self.all_head_size, bias=False)
        # self.b = Linear(embedding_channels, h, bias=False)
        self.g = Linear(embedding_channels, self.all_head_size)
        self.final_linear = Linear(self.all_head_size, embedding_channels)

    def reshape_last_dim(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x

    def forward(self, z, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        # z_mask of shape b, i, j
        z = self.layernorm(z)
        p_length = z.shape[1]
        batch_n = z.shape[0]
        # new_z = torch.zeros(z.shape, device=z.device)
        z_i = z
        z_mask_i = z_mask.view((batch_n, p_length, 1, 1, -1))
        attention_mask_i = (1e9 * (z_mask_i.float() - 1.))
        # q, k, v of shape b, j, h, c
        q = self.reshape_last_dim(self.linear_q(z_i))  * (self.attention_head_size**(-0.5))
        k = self.reshape_last_dim(self.linear_k(z_i))
        v = self.reshape_last_dim(self.linear_v(z_i))
        logits = torch.einsum('biqhc,bikhc->bihqk', q, k) + attention_mask_i
        weights = nn.Softmax(dim=-1)(logits)
        # weights of shape b, h, j, j
        # attention_probs = self.dp(attention_probs)
        weighted_avg = torch.einsum('bihqk,bikhc->biqhc', weights, v)
        g = self.reshape_last_dim(self.g(z_i)).sigmoid()
        output = g * weighted_avg
        new_output_shape = output.size()[:-2] + (self.all_head_size,)
        output = output.view(*new_output_shape)
        # output of shape b, j, embedding.
        # z[:, i] = output
        z = output
        # print(g.shape, block1.shape, block2.shape)
        z = self.final_linear(z) * z_mask.unsqueeze(-1)
        return z
    
class FastTriangleSelfAttention(nn.Module):
    def __init__(self, embedding_channels, num_attention_heads):
        super().__init__()
        self.layernorm = nn.LayerNorm(embedding_channels, bias=False)
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = embedding_channels // num_attention_heads
        self.linear_qkv = nn.Linear(embedding_channels, 3*embedding_channels, bias=False)
        self.output_linear = nn.Linear(embedding_channels, embedding_channels)
        self.g = nn.Linear(embedding_channels, embedding_channels)
    def forward(self, z, z_mask_attention_float, z_mask):
        """
        Parameters
        ----------
        z: torch.Tensor of shape [batch, n_protein, n_compound, embedding_channels]
        z_mask: torch.Tensor of shape [batch*n_protein*num_attention_heads, n_compound, n_compound] saying which coefficients
            correspond to actual data. (we take this weird shape because scaled_dot_product_attention
            requires it). We take it to be float("-inf") where we want to mask.
        Returns
        -------
        """
        z = self.layernorm(z)
        batch_size, n_protein, n_compound, embedding_channels = z.shape
        z = z.reshape(batch_size*n_protein, n_compound, embedding_channels)
        q, k, v = self.linear_qkv(z).chunk(3, dim=-1)
        q = q.view(batch_size*n_protein, n_compound, self.num_attention_heads, self.attention_head_size).contiguous()
        k = k.view(batch_size*n_protein, n_compound, self.num_attention_heads, self.attention_head_size).contiguous()
        v = v.view(batch_size*n_protein, n_compound, self.num_attention_heads, self.attention_head_size).contiguous()
        attention_coefficients = xops.memory_efficient_attention(query=q,
                                                key=k,
                                                value=v,
                                                attn_bias=z_mask_attention_float.to("cuda:0")) # shape [batch*protein_nodes, compound_nodes, n_heads, embedding//n_heads]        

        attention_output = attention_coefficients.view(batch_size, n_protein, n_compound, embedding_channels)
        g = self.g(z).sigmoid()
        output = g * attention_output.view(batch_size*n_protein, n_compound, embedding_channels)

        output = self.output_linear(output.view(batch_size, n_protein, n_compound, embedding_channels))*z_mask.unsqueeze(-1).to('cuda:0')
        return output

In [321]:
fast_attention = FastTriangleSelfAttention(embedding_channels=128, num_attention_heads=4)
attention = TriangleSelfAttentionRowWise(embedding_channels=128, c=32, num_attention_heads=4)
fast_attention.layernorm.weight = attention.layernorm.weight
fast_attention.layernorm.bias = attention.layernorm.bias
fast_attention.linear_qkv.weight = torch.nn.Parameter(torch.cat([attention.linear_q.weight, attention.linear_k.weight, attention.linear_v.weight], dim=0))
fast_attention.output_linear.weight = attention.final_linear.weight
fast_attention.output_linear.bias = attention.final_linear.bias
fast_attention.g.weight = attention.g.weight
fast_attention.g.bias = attention.g.bias
attention.to('cuda:0')
fast_attention.to('cuda:0')

FastTriangleSelfAttention(
  (layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (linear_qkv): Linear(in_features=128, out_features=384, bias=False)
  (output_linear): Linear(in_features=128, out_features=128, bias=True)
  (g): Linear(in_features=128, out_features=128, bias=True)
)

In [322]:
z = z_2.clone()
z = z.to("cuda:0")
z = attention.layernorm(z)
p_length = z.shape[1]
batch_n = z.shape[0]
# new_z = torch.zeros(z.shape, device=z.device)
z_i = z
z_mask_i = z_mask.view((batch_n, p_length, 1, 1, -1))
attention_mask_i = (1e9 * (z_mask_i.float() - 1.))
# q, k, v of shape b, j, h, c
q = attention.reshape_last_dim(attention.linear_q(z_i))  * (attention.attention_head_size**(-0.5))
k = attention.reshape_last_dim(attention.linear_k(z_i))
v = attention.reshape_last_dim(attention.linear_v(z_i))
logits = torch.einsum('biqhc,bikhc->bihqk', q, k) + attention_mask_i.to('cuda:0')
weights = nn.Softmax(dim=-1)(logits)
# weights of shape b, h, j, j
# attention_probs = self.dp(attention_probs)
weighted_avg = torch.einsum('bihqk,bikhc->biqhc', weights, v)
g_1 = attention.reshape_last_dim(attention.g(z_i)).sigmoid()
output_1 = g_1 * weighted_avg
new_output_shape = output_1.size()[:-2] + (attention.all_head_size,)
output_1 = output_1.view(*new_output_shape)
# output of shape b, j, embedding.
# z[:, i] = output
z = output_1
# print(g.shape, block1.shape, block2.shape)
z_out = attention.final_linear(z) * z_mask.unsqueeze(-1).to("cuda:0")

In [320]:
fast_attention.g.weight == attention.g.weight

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')

In [323]:
z = z_2.clone().to("cuda:0")
z = fast_attention.layernorm(z)
batch_size, n_protein, n_compound, embedding_channels = z.shape
z = z.reshape(batch_size*n_protein, n_compound, embedding_channels)
q, k, v = fast_attention.linear_qkv(z).chunk(3, dim=-1)
q = q.view(batch_size*n_protein, n_compound, fast_attention.num_attention_heads, fast_attention.attention_head_size).contiguous()
k = k.view(batch_size*n_protein, n_compound, fast_attention.num_attention_heads, fast_attention.attention_head_size).contiguous()
v = v.view(batch_size*n_protein, n_compound, fast_attention.num_attention_heads, fast_attention.attention_head_size).contiguous()
attention_coefficients = xops.memory_efficient_attention(query=q,
                                        key=k,
                                        value=v,
                                        attn_bias=z_mask_attention_float.to("cuda:0")) # shape [batch*protein_nodes, compound_nodes, n_heads, embedding//n_heads]        

attention_output = attention_coefficients.view(batch_size, n_protein, n_compound, embedding_channels)
g = fast_attention.g(z).sigmoid()
output = g * attention_output.view(batch_size*n_protein, n_compound, embedding_channels)

output = fast_attention.output_linear(output.view(batch_size, n_protein, n_compound, embedding_channels))*z_mask.unsqueeze(-1).to('cuda:0')

In [316]:
z_mask.shape

torch.Size([4, 232, 24])

In [318]:
output.shape

torch.Size([4, 232, 24, 128])

In [324]:
output

tensor([[[[ 0.1219,  0.0946, -0.1399,  ...,  0.0886,  0.1805,  0.2867],
          [ 0.1213,  0.0935, -0.1294,  ...,  0.0775,  0.2057,  0.2919],
          [ 0.1147,  0.0873, -0.1234,  ...,  0.0750,  0.1941,  0.3020],
          ...,
          [ 0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0202,  0.1372,  0.1826,  ..., -0.2824, -0.0042,  0.0900],
          [ 0.0239,  0.1361,  0.1826,  ..., -0.2821, -0.0052,  0.1046],
          [ 0.0249,  0.1438,  0.1881,  ..., -0.2812, -0.0038,  0.0968],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000]],

         [[ 0.1441,  0.1901,  0.0691,  ..., -0.2150,  0.0659,  0.2961],
          [ 0.1295,  0.1872,  

In [328]:
(z_out - output).abs().max()

tensor(2.6822e-07, device='cuda:0', grad_fn=<MaxBackward1>)

In [301]:
(weighted_avg - attention_output.view(4, 232, 24, 4, 32)).norm()/attention_output.numel()

tensor(9.6849e-05, device='cuda:0', grad_fn=<DivBackward0>)

In [303]:
output.shape

torch.Size([4, 232, 24, 128])

In [280]:
z = z_2.clone()
fast_attention.to("cuda:0")
fast_attention(z.to("cuda:0"), z_mask_attention.to("cuda:0"), z_mask.to("cuda:0"))

tensor([[[[-4.1723e-01,  1.7940e-01, -3.4570e-01,  ..., -1.7574e-01,
            3.7142e-01, -2.3193e-02],
          [-4.1757e-01,  1.8025e-01, -3.4542e-01,  ..., -1.7599e-01,
            3.7168e-01, -2.3248e-02],
          [-4.1740e-01,  1.7964e-01, -3.4540e-01,  ..., -1.7634e-01,
            3.7122e-01, -2.3353e-02],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           -0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           -0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           -0.0000e+00, -0.0000e+00]],

         [[-9.5555e-02,  1.9026e-01,  6.7507e-02,  ..., -3.1940e-01,
            3.8911e-02, -2.0688e-02],
          [-9.4462e-02,  1.9031e-01,  6.7647e-02,  ..., -3.1976e-01,
            3.8055e-02, -2.1225e-02],
          [-9.5457e-02,  1.9091e-01,  6.7285e-02,  ..., -3.1949e-01,
            3.7964e-02, -2.0498e-02],
          ...,
     

In [281]:
z = z_2.clone()
attention.to("cuda:0")
attention(z.to("cuda:0"), z_mask.to("cuda:0"))

tensor([[[[-8.0932e-02,  4.2761e-01,  7.3466e-02,  ...,  1.0252e-01,
            3.5413e-01, -3.0100e-01],
          [-7.9785e-02,  4.2358e-01,  7.5949e-02,  ...,  7.9677e-02,
            3.4278e-01, -2.9386e-01],
          [-8.1421e-02,  4.4335e-01,  7.5089e-02,  ...,  8.1080e-02,
            3.4571e-01, -3.0940e-01],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00, -0.0000e+00]],

         [[-1.3750e-01,  1.9741e-01,  1.9737e-01,  ...,  8.0988e-02,
            3.8124e-01, -3.2311e-01],
          [-1.1666e-01,  2.1328e-01,  1.8591e-01,  ...,  8.9159e-02,
            3.5562e-01, -3.1623e-01],
          [-1.2569e-01,  2.1042e-01,  1.9157e-01,  ...,  9.5573e-02,
            3.6962e-01, -3.2571e-01],
          ...,
     

In [69]:
attention = model_og.triangle_self_attention_list[0]
y = z.clone()
y = attention.layernorm(y)
p_length = y.shape[1]
batch_n = y.shape[0]
z_i = y
q = attention.reshape_last_dim(attention.linear_q(z_i))
k = attention.reshape_last_dim(attention.linear_k(z_i))
v = attention.reshape_last_dim(attention.linear_v(z_i))
logits = torch.einsum('biqhc,bikhc->bihqk', q, k)

In [129]:
fast_attention.to('cuda:0')

FastTriangleSelfAttention(
  (layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (linear_qkv): Linear(in_features=128, out_features=384, bias=False)
  (output_linear): Linear(in_features=128, out_features=128, bias=True)
  (g): Linear(in_features=128, out_features=128, bias=True)
)

In [174]:
z = z_2.clone().to('cuda:0')
z_mask_attention = z_mask_attention.to('cuda:0')
z = fast_attention.layernorm(z)
batch_size, n_protein, n_compound, embedding_channels = z.shape
z = z.reshape(batch_size*n_protein, n_compound, embedding_channels)


In [181]:
(z[~z_mask.view(928, 24)]==0).all()

tensor(True, device='cuda:0')

In [182]:
q, k, v = fast_attention.linear_qkv(z).chunk(3, dim=-1)
q = q.view(batch_size*n_protein, n_compound, fast_attention.num_attention_heads, fast_attention.attention_head_size).contiguous()
k = k.view(batch_size*n_protein, n_compound, fast_attention.num_attention_heads, fast_attention.attention_head_size).contiguous()
v = v.view(batch_size*n_protein, n_compound, fast_attention.num_attention_heads, fast_attention.attention_head_size).contiguous()
attention_coefficients = xops.memory_efficient_attention(query=q,
                                        key=k,
                                        value=v,
                                        attn_bias=z_mask_attention) # shape [batch*protein_nodes, compound_nodes, n_heads, embedding//n_heads]        



In [190]:
attention_output = attention_coefficients.permute(0, 2, 1, 3).contiguous().view(batch_size, n_protein, n_compound, embedding_channels)

In [191]:
g = fast_attention.g(z).sigmoid()
output = g * attention_output.view(batch_size*n_protein, n_compound, embedding_channels)

output = fast_attention.output_linear(attention_output)*z_mask.unsqueeze(-1).to('cuda:0')

In [173]:
(output[~z_mask]==0).all()

tensor(True, device='cuda:0')

In [80]:
print(z[~z_mask].norm())

tensor(0., grad_fn=<LinalgVectorNormBackward0>)

In [81]:
z_2 = z.clone()

In [87]:
z = attention.layernorm(z)
p_length = z.shape[1]
batch_n = z.shape[0]
# new_z = torch.zeros(z.shape, device=z.device)
z_i = z
z_mask_i = z_mask.view((batch_n, p_length, 1, 1, -1))
attention_mask_i = (1e9 * (z_mask_i.float() - 1.))
# q, k, v of shape b, j, h, c


In [122]:
attention.layernorm.bias = torch.nn.Parameter(10*torch.ones(128))

In [123]:
print(z[~z_mask].norm())

tensor(1278.6993, grad_fn=<LinalgVectorNormBackward0>)


In [124]:
q = attention.reshape_last_dim(attention.linear_q(z_i)) #  * (self.attention_head_size**(-0.5))
k = attention.reshape_last_dim(attention.linear_k(z_i))
v = attention.reshape_last_dim(attention.linear_v(z_i))
logits = torch.einsum('biqhc,bikhc->bihqk', q, k) + attention_mask_i
weights = nn.Softmax(dim=-1)(logits) + attention_mask_i

In [119]:
_attention_mask_i = attention_mask_i.expand(4, 232, 4, 24, 24)

In [125]:
(logits[_attention_mask_i < 0]>-10).sum()

tensor(0)

In [113]:
weights.view(4*232, 4, 24, 24).shape

torch.Size([928, 4, 24, 24])

In [116]:
print(weights.view(4*232, 4, 24, 24)[(z_mask_attention>-1)].relu().norm())

tensor(62.8368, grad_fn=<LinalgVectorNormBackward0>)


In [None]:

# weights of shape b, h, j, j
# attention_probs = self.dp(attention_probs)
weighted_avg = torch.einsum('bihqk,bikhc->biqhc', weights, v)
g = attention.reshape_last_dim(attention.g(z_i)).sigmoid()
output = g * weighted_avg
new_output_shape = output.size()[:-2] + (attention.all_head_size,)
output = output.view(*new_output_shape)
# output of shape b, j, embedding.
# z[:, i] = output
z = output
# print(g.shape, block1.shape, block2.shape)
z = attention.final_linear(z) * z_mask.unsqueeze(-1)

In [76]:
z_cuda = z.to('cuda:0')
z_mask_cuda = z_mask.to('cuda:0')
z_mask_attention_float = z_mask_attention.to('cuda:0')

In [77]:
fast_attention(z_cuda, z_mask_attention_float, z_mask_cuda)

tensor([[[[-0.1252, -0.1378, -0.4370,  ...,  0.3072, -0.4999,  0.3409],
          [-0.1248, -0.1379, -0.4369,  ...,  0.3072, -0.4995,  0.3411],
          [-0.1254, -0.1377, -0.4361,  ...,  0.3080, -0.4998,  0.3407],
          ...,
          [-0.2171, -0.3801,  0.2941,  ..., -0.1717,  0.0874, -0.1296],
          [-0.1613, -0.3047,  0.2157,  ..., -0.1519,  0.1093, -0.1257],
          [-0.1613, -0.3047,  0.2157,  ..., -0.1519,  0.1093, -0.1257]],

         [[-0.2119, -0.3509, -0.4920,  ...,  0.1195,  0.3097,  0.1126],
          [-0.2116, -0.3504, -0.4922,  ...,  0.1198,  0.3095,  0.1128],
          [-0.2119, -0.3506, -0.4921,  ...,  0.1192,  0.3093,  0.1127],
          ...,
          [-0.0260,  0.2772,  0.1034,  ...,  0.0422, -0.0248, -0.1400],
          [-0.0063,  0.3141,  0.0646,  ...,  0.0381, -0.0311, -0.0800],
          [-0.0063,  0.3141,  0.0646,  ...,  0.0381, -0.0311, -0.0800]],

         [[-0.0747,  0.0261, -0.4693,  ...,  0.2404, -0.2333,  0.0390],
          [-0.0748,  0.0263, -

In [74]:
fast_attention(z, z_mask_attention_float=z_mask_attention, z_mask=z_mask)

NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(928, 24, 4, 32) (torch.float32)
     key         : shape=(928, 24, 4, 32) (torch.float32)
     value       : shape=(928, 24, 4, 32) (torch.float32)
     attn_bias   : <class 'torch.Tensor'>
     p           : 0.0
`flshattF@v2.5.6` is not supported because:
    device=cpu (supported: {'cuda'})
    dtype=torch.float32 (supported: {torch.float16, torch.bfloat16})
    attn_bias type is <class 'torch.Tensor'>
`cutlassF` is not supported because:
    device=cpu (supported: {'cuda'})
`smallkF` is not supported because:
    device=cpu (supported: {'cuda'})
    bias with non-zero stride not supported

In [67]:
z.shape

torch.Size([4, 232, 24, 128])

In [65]:
q.shape

torch.Size([4, 232, 24, 4, 32])

In [280]:
y = z.clone()
for i_module in range(model_og.n_trigonometry_module_stack):
    y = y + model_og.dropout(model_og.protein_to_compound_list[i_module](y, protein_pair, compound_pair, z_mask.unsqueeze(-1)))
    y = y + model_og.dropout(model_og.triangle_self_attention_list[i_module](y, z_mask))
    y = model_og.tranistion(y)

In [281]:
b = model_og.linear(y).squeeze(-1)

In [282]:
y_pred = b[z_mask]
y_pred = y_pred.sigmoid() * 10   # normalize to 0 to 10.
pair_energy = (model_og.gate_linear(y).sigmoid() * model_og.linear_energy(y)).squeeze(-1) * z_mask
affinity_pred = model_og.leaky(model_og.bias + ((pair_energy).sum(axis=(-1, -2))))

In [283]:
torch.abs(y_pred_2-y_pred).sum() # montre qu'on doit faire attention aux biais

tensor(0.0011, grad_fn=<SumBackward0>)

In [284]:
torch.abs(affinity_pred_2-affinity_pred).sum()

tensor(2.0981e-05, grad_fn=<SumBackward0>)

In [230]:
model_og.to("cpu")
y_pred_2, affinity_pred_2 = model_og(batch.to("cpu"))

In [228]:
model_og.to("cuda:1")
y_pred_2, affinity_pred_2 = model_og(batch.to("cuda:1"))
model_og.to("cpu")


OutOfMemoryError: CUDA out of memory. Tried to allocate 40.00 MiB. GPU  has a total capacity of 15.74 GiB of which 27.44 MiB is free. Process 23209 has 1.10 GiB memory in use. Including non-PyTorch memory, this process has 14.61 GiB memory in use. Of the allocated memory 14.34 GiB is allocated by PyTorch, and 129.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [197]:
# Original TankBind forward pass

In [243]:
nodes_2 = (data['protein']['node_s'], data['protein']['node_v'])
edges_2 = (data[("protein", "p2p", "protein")]["edge_s"], data[("protein", "p2p", "protein")]["edge_v"])
protein_batch_2 = data['protein'].batch
protein_out_2 = model_og.conv_protein(nodes_2, data[("protein", "p2p", "protein")]["edge_index"], edges_2, data.seq)

compound_x_2 = data['compound'].x.float()
compound_edge_index_2 = data[("compound", "c2c", "compound")].edge_index.T
compound_edge_feature_2 = data[("compound", "c2c", "compound")].edge_attr
edge_weight_2 = data[("compound", "c2c", "compound")].edge_weight
compound_batch_2 = data['compound'].batch
compound_out_2 = model_og.conv_compound(compound_edge_index_2, edge_weight_2, compound_edge_feature_2, compound_x_2.shape[0], compound_x_2)['node_feature']

protein_out_batched_2, protein_out_mask_2 = to_dense_batch(protein_out_2, protein_batch_2)
compound_out_batched_2, compound_out_mask_2 = to_dense_batch(compound_out_2, compound_batch_2)

node_xyz_2 = data.node_xyz

p_coords_batched_2, p_coords_mask_2 = to_dense_batch(node_xyz_2, protein_batch_2)

protein_pair_2 = get_pair_dis_one_hot(p_coords_batched_2, bin_size=2, bin_min=-1, bin_max=model_og.protein_bin_max)
compound_pair_batched_2, compound_pair_batched_mask_2 = to_dense_batch(data.compound_pair, data.compound_pair_batch)
batch_n_2 = compound_pair_batched_2.shape[0]
max_compound_size_square_2 = compound_pair_batched_2.shape[1]
max_compound_size_2 = int(max_compound_size_square_2**0.5)
assert (max_compound_size_2**2 - max_compound_size_square_2)**2 < 1e-4
compound_pair_2 = torch.zeros((batch_n_2, max_compound_size_2, max_compound_size_2, 16)).to(data.compound_pair.device)
for i in range(batch_n_2):
    one_2 = compound_pair_batched_2[i]
    compound_size_square_2 = (data.compound_pair_batch == i).sum()
    compound_size_2 = int(compound_size_square_2**0.5)
    compound_pair_2[i, :compound_size_2, :compound_size_2] = one_2[:compound_size_square_2].reshape(
                                                            (compound_size_2, compound_size_2, -1))

protein_pair_2 = model_og.protein_pair_embedding(protein_pair_2.float())
compound_pair_2 = model_og.compound_pair_embedding(compound_pair_2.float())

protein_out_batched_2 = model_og.layernorm(protein_out_batched_2)
compound_out_batched_2 = model_og.layernorm(compound_out_batched_2)

z_2 = torch.einsum("bik,bjk->bijk", protein_out_batched_2, compound_out_batched_2)
z_mask_2 = torch.einsum("bi,bj->bij", protein_out_mask_2, compound_out_mask_2)


In [216]:
torch.abs(compound_pair-compound_pair_2).sum()

tensor(0., grad_fn=<SumBackward0>)

In [252]:
torch.abs(protein_pair-protein_pair_2).sum()

tensor(0., grad_fn=<SumBackward0>)

In [232]:
torch.abs(protein_batched-protein_out_batched_2).sum()

tensor(0., grad_fn=<SumBackward0>)

In [233]:
torch.abs(compound_batched-compound_out_batched).sum()

tensor(0., grad_fn=<SumBackward0>)

In [234]:
torch.abs(z_2-z).sum()

tensor(0., grad_fn=<SumBackward0>)

In [236]:
z_mask.eq(z_mask_2).all()

tensor(True)

In [244]:
torch.abs(z-z_2).sum()

tensor(0., grad_fn=<SumBackward0>)

In [237]:
for i_module_2 in range(model_og.n_trigonometry_module_stack):
    z_2 = z_2 + model_og.dropout(model_og.protein_to_compound_list[i_module_2](z_2, protein_pair_2, compound_pair_2, z_mask_2.unsqueeze(-1)))
    z_2 = z_2 + model_og.dropout(model_og.triangle_self_attention_list[i_module_2](z_2, z_mask_2))
    z_2 = model_og.tranistion(z_2)


In [246]:
import torch

# Assuming the relevant variables are already defined and initialized:
# y, protein_pair, compound_pair, z_mask, z_2, protein_pair_2, compound_pair_2, z_mask_2

# Define a function to check if two variables are equal
def check_equality(var1, var2, var_name):
    if torch.equal(var1, var2):
        print(f"{var_name} are equal.")
    else:
        print(f"{var_name} are NOT equal.")

# Check the input variables
y = z.clone()
check_equality(y, z_2, "y and z_2")
check_equality(protein_pair, protein_pair_2, "protein_pair and protein_pair_2")
check_equality(compound_pair, compound_pair_2, "compound_pair and compound_pair_2")
check_equality(z_mask, z_mask_2, "z_mask and z_mask_2")

# Run the first code block
for i_module in range(model_og.n_trigonometry_module_stack):
    y = y + model_og.dropout(model_og.protein_to_compound_list[i_module](y, protein_pair, compound_pair, z_mask.unsqueeze(-1)))
    y = y + model_og.dropout(model_og.triangle_self_attention_list[i_module](y, z_mask))
    y = model_og.tranistion(y)

# Save the result of the first code block
output_y = y.clone()

# Run the second code block
for i_module_2 in range(model_og.n_trigonometry_module_stack):
    z_2 = z_2 + model_og.dropout(model_og.protein_to_compound_list[i_module_2](z_2, protein_pair_2, compound_pair_2, z_mask_2.unsqueeze(-1)))
    z_2 = z_2 + model_og.dropout(model_og.triangle_self_attention_list[i_module_2](z_2, z_mask_2))
    z_2 = model_og.tranistion(z_2)

# Save the result of the second code block
output_z_2 = z_2.clone()

# Check the outputs
check_equality(output_y, output_z_2, "Output y and Output z_2")


y and z_2 are equal.
protein_pair and protein_pair_2 are NOT equal.
compound_pair and compound_pair_2 are equal.
z_mask and z_mask_2 are equal.
Output y and Output z_2 are NOT equal.


In [253]:
torch.abs(protein_pair-protein_pair_2).sum()

tensor(0., grad_fn=<SumBackward0>)

In [242]:
torch.abs(z_2-y).sum()

tensor(231392.3594, grad_fn=<SumBackward0>)

In [None]:

b_2 = model_og.linear(z_2).squeeze(-1)
y_pred_2 = b_2[z_mask_2]
y_pred_2 = y_pred_2.sigmoid() * 10   # normalize to 0 to 10.

pair_energy_2 = (model_og.gate_linear(z_2).sigmoid() * model_og.linear_energy(z_2)).squeeze(-1) * z_mask_2
affinity_pred_2 = model_og.leaky(model_og.bias + ((pair_energy_2).sum(axis=(-1, -2))))
return y_pred_2, affinity_pred_2


8 July
- Test fast attention net
- Modify lightning training
- Test original model with 3000 examples

In [14]:
%cd /fs/pool/pool-marsot/tankbind_philip

/fs/gpfs41/lv11/fileset01/pool/pool-marsot/tankbind_philip


In [15]:
from model import IaBNet_with_affinity
from torch_geometric.loader.dataloader import Collater
from torch_geometric.utils import to_dense_batch
import torch


class TankBindDataLoader(torch.utils.data.DataLoader):
    """Subclass of the torch DataLoader, in order to apply the collate function TankBindCollater."""
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 follow_batch=None,
                 exclude_keys=None,
                 make_divisible_by_8=True,
                 **kwargs):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.make_divisible_by_8=make_divisible_by_8
        super().__init__(dataset,
                         batch_size,
                         shuffle,
                         collate_fn=TankBindCollater(dataset, follow_batch, exclude_keys, make_divisible_by_8=self.make_divisible_by_8),
                         **kwargs)



class TankBindCollater(Collater):
    """Applies batching operations and computations of masks in place of the model, in order to avoid having to recompute it in the
    forward pass on GPU."""
    def __init__(self, dataset,
                 follow_batch=None,
                 exclude_keys=None,
                 make_divisible_by_8=True):
        super().__init__(dataset, follow_batch, exclude_keys)
        self.make_divisible_by_8 = make_divisible_by_8
    def __call__(self, batch):
        data = super().__call__(batch)
        if self.make_divisible_by_8:
            max_dim_divisible_by_8_protein = 8 * (torch.diff(data["protein"].ptr).max() // 8 + 1)
            max_dim_divisible_by_8_compound = 8 * (torch.diff(data["compound"].ptr).max() // 8 + 1)
        else:
            max_dim_divisible_by_8_protein = torch.diff(data["protein"].ptr).max()
            max_dim_divisible_by_8_compound = torch.diff(data["compound"].ptr).max()
        protein_coordinates_batched, _ = to_dense_batch(
            data.node_xyz, data["protein"].batch,
            max_num_nodes=max_dim_divisible_by_8_protein,
            )
        protein_pairwise_representation = get_pair_dis_index(
            protein_coordinates_batched,
            bin_size=2,
            bin_min=-1,
            bin_max=protein_bin_max,
            ) # shape [batch_n, max_protein_size, max_protein_size, 16]
        _compound_lengths = (data["compound"].ptr[1:] - data["compound"].ptr[:-1]) ** 2
        _total = torch.cumsum(_compound_lengths, 0)
        compound_pairwise_distance_batch = torch.zeros(
                _total[-1], dtype=torch.long
            )
        for i in range(len(_total) - 1):
            compound_pairwise_distance_batch[_total[i] : _total[i + 1]] = i + 1
        compound_pair_batched, compound_pair_batched_mask = to_dense_batch(
            data.compound_pair,
            data.compound_pair_batch,
            )
        compound_pairwise_representation = torch.zeros(
            (len(batch), max_dim_divisible_by_8_compound, max_dim_divisible_by_8_compound, 16),
            dtype=torch.float32,
            )
        for i in range(len(batch)):
            one = compound_pair_batched[i]
            compound_size_square = (compound_pairwise_distance_batch == i).sum()
            compound_size = int(compound_size_square**0.5)
            compound_pairwise_representation[i, :compound_size, :compound_size] = one[
                :compound_size_square
                ].reshape((compound_size, compound_size, -1))
        data.batch_n = len(batch)
        data.max_dim_divisible_by_8_protein = max_dim_divisible_by_8_protein
        data.max_dim_divisible_by_8_compound = max_dim_divisible_by_8_compound
        data["protein", "p2p", "protein"].pairwise_representation = protein_pairwise_representation
        data["compound", "p2p", "compound"].pairwise_representation = compound_pairwise_representation
        data["compound", "p2p", "compound"].pairwise_representation_mask = compound_pair_batched_mask
        return data




def get_pair_dis_index(d, bin_size=2, bin_min=-1, bin_max=30):
    """
    Computing pairwise distances and binning.
    """
    pair_dis = torch.cdist(d, d, compute_mode='donot_use_mm_for_euclid_dist')
    pair_dis[pair_dis>bin_max] = bin_max
    pair_dis_bin_index = torch.div(pair_dis - bin_min, bin_size, rounding_mode='floor').long()
    return pair_dis_bin_index

protein_bin_max = 30

train_loader_3 = TankBindDataLoader(train, batch_size=4, follow_batch=['x', 'compound_pair'], make_divisible_by_8=True, shuffle=False)

In [None]:
import torch
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.utils import to_dense_batch
from torch import nn
from torch.nn import Linear
import sys
import torch.nn as nn
from gvp import GVP, GVPConvLayer, LayerNorm, tuple_index
from torch.distributions import Categorical
from torch_scatter import scatter_mean
#from GATv2 import GAT
from GINv2 import GIN

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x



class GVP_embedding(nn.Module):
    '''
    Modified based on https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
    GVP-GNN for Model Quality Assessment as described in manuscript.
    
    Takes in protein structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a scalar score for
    each graph in the batch in a `torch.Tensor` of shape [n_nodes]
    
    Should be used with `gvp.data.ProteinGraphDataset`, or with generators
    of `torch_geometric.data.Batch` objects with the same attributes.
    
    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param node_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :seq_in: if `True`, sequences will also be passed in with
             the forward pass; otherwise, sequence information
             is assumed to be part of input node embeddings
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim, 
                 edge_in_dim, edge_h_dim,
                 seq_in=False, num_layers=3, drop_rate=0.1):

        super(GVP_embedding, self).__init__()
        
        if seq_in:
            self.W_s = nn.Embedding(20, 20)
            node_in_dim = (node_in_dim[0] + 20, node_in_dim[1])
        
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )

        self.layers = nn.ModuleList(
                GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate) 
            for _ in range(num_layers))
        
        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

    def forward(self, h_V, edge_index, h_E, seq):      
        '''
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param seq: if not `None`, int `torch.Tensor` of shape [num_nodes]
                    to be embedded and appended to `h_V`
        '''
        seq = self.W_s(seq)
        h_V = (torch.cat([h_V[0], seq], dim=-1), h_V[1])
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)

        return out


def get_pair_dis_one_hot(d, bin_size=2, bin_min=-1, bin_max=30):
    # without compute_mode='donot_use_mm_for_euclid_dist' could lead to wrong result.
    pair_dis = torch.cdist(d, d, compute_mode='donot_use_mm_for_euclid_dist')
    pair_dis[pair_dis>bin_max] = bin_max
    pair_dis_bin_index = torch.div(pair_dis - bin_min, bin_size, rounding_mode='floor').long()
    pair_dis_one_hot = torch.nn.functional.one_hot(pair_dis_bin_index, num_classes=16)
    return pair_dis_one_hot

class TriangleProteinToCompound(torch.nn.Module):
    def __init__(self, embedding_channels=256, c=128, hasgate=True):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.layernorm_c = torch.nn.LayerNorm(c)
        self.hasgate = hasgate
        if hasgate:
            self.gate_linear = Linear(embedding_channels, c)
        self.linear = Linear(embedding_channels, c)
        self.ending_gate_linear = Linear(embedding_channels, embedding_channels)
        self.linear_after_sum = Linear(c, embedding_channels)
    def forward(self, z, protein_pair, compound_pair, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        # z_mask of shape b, i, j, 1
        z = self.layernorm(z)
        if self.hasgate:
            ab = self.gate_linear(z).sigmoid() * self.linear(z) * z_mask
        else:
            ab = self.linear(z) * z_mask
        g = self.ending_gate_linear(z).sigmoid()
        block1 = torch.einsum("bikc,bkjc->bijc", protein_pair, ab)
        block2 = torch.einsum("bikc,bjkc->bijc", ab, compound_pair)
        z = g * self.linear_after_sum(self.layernorm_c(block1+block2)) * z_mask
        return z

class TriangleProteinToCompound_v2(torch.nn.Module):
    # separate left/right edges (block1/block2).
    def __init__(self, embedding_channels=256, c=128):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels, bias=False)
        self.layernorm_c = torch.nn.LayerNorm(c, bias=False)

        # self.gate_linear1 = Linear(embedding_channels, c)
        # self.gate_linear2 = Linear(embedding_channels, c)
        # modification by Enzo to remove biases. (hypothesis: biases make the outputs dependent on padding)
        self.gate_linear1 = Linear(embedding_channels, c, bias=False)
        self.gate_linear2 = Linear(embedding_channels, c, bias=False)

        self.linear1 = Linear(embedding_channels, c)
        self.linear2 = Linear(embedding_channels, c)

        self.ending_gate_linear = Linear(embedding_channels, embedding_channels)
        self.linear_after_sum = Linear(c, embedding_channels)
    def forward(self, z, protein_pair, compound_pair, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        z = self.layernorm(z)
        protein_pair = self.layernorm(protein_pair)
        compound_pair = self.layernorm(compound_pair)
 
        ab1 = self.gate_linear1(z).sigmoid() * self.linear1(z) * z_mask
        ab2 = self.gate_linear2(z).sigmoid() * self.linear2(z) * z_mask
        protein_pair = self.gate_linear2(protein_pair).sigmoid() * self.linear2(protein_pair)
        compound_pair = self.gate_linear1(compound_pair).sigmoid() * self.linear1(compound_pair)

        g = self.ending_gate_linear(z).sigmoid()
        block1 = torch.einsum("bikc,bkjc->bijc", protein_pair, ab1)
        block2 = torch.einsum("bikc,bjkc->bijc", ab2, compound_pair)
        # print(g.shape, block1.shape, block2.shape)
        z = g * self.linear_after_sum(self.layernorm_c(block1+block2)) * z_mask
        return z

# class Self_Attention(nn.Module):
#     def __init__(self, hidden_size,num_attention_heads=8,drop_rate=0.5):
#         super().__init__()
#         self.num_attention_heads = num_attention_heads
#         self.attention_head_size = int(hidden_size / num_attention_heads)
#         self.all_head_size = self.num_attention_heads * self.attention_head_size
#         self.dp = nn.Dropout(drop_rate)
#         self.ln = nn.LayerNorm(hidden_size)

#     def transpose_for_scores(self, x):
#         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
#         x = x.view(*new_x_shape)
#         return x.permute(0, 2, 1, 3)

#     def forward(self,q,k,v,attention_mask=None,attention_weight=None):
#         q = self.transpose_for_scores(q)
#         k = self.transpose_for_scores(k)
#         v = self.transpose_for_scores(v)
#         attention_scores = torch.matmul(q, k.transpose(-1, -2))

#         attention_probs = nn.Softmax(dim=-1)(attention_scores)
#         # attention_probs = self.dp(attention_probs)
#         if attention_weight is not None:
#             attention_weight_sorted_sorted = torch.argsort(torch.argsort(-attention_weight,axis=-1),axis=-1)
#             # if self.training:
#             #     top_mask = (attention_weight_sorted_sorted<np.random.randint(28,45))
#             # else:
#             top_mask = (attention_weight_sorted_sorted<32)
#             attention_probs = attention_probs * top_mask
#             # attention_probs = attention_probs * attention_weight
#             attention_probs = attention_probs / (torch.sum(attention_probs,dim=-1,keepdim=True) + 1e-5)
#         # print(attention_probs.shape,v.shape)
#         # attention_probs = self.dp(attention_probs)
#         outputs = torch.matmul(attention_probs, v)

#         outputs = outputs.permute(0, 2, 1, 3).contiguous()
#         new_output_shape = outputs.size()[:-2] + (self.all_head_size,)
#         outputs = outputs.view(*new_output_shape)
#         outputs = self.ln(outputs)
#         return outputs


class FastTriangleSelfAttention(nn.Module):
    def __init__(self, embedding_channels, num_attention_heads):
        super().__init__()
        self.layernorm = nn.LayerNorm(embedding_channels, bias=False)
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = embedding_channels // num_attention_heads
        self.linear_qkv = nn.Linear(embedding_channels, 3*embedding_channels, bias=False)
        self.output_linear = nn.Linear(embedding_channels, embedding_channels)
        self.g = nn.Linear(embedding_channels, embedding_channels)
    def forward(self, z, z_mask_attention_float, z_mask):
        """
        Parameters
        ----------
        z: torch.Tensor of shape [batch, n_protein, n_compound, embedding_channels]
        z_mask: torch.Tensor of shape [batch*n_protein*num_attention_heads, n_compound, n_compound] saying which coefficients
            correspond to actual data. (we take this weird shape because scaled_dot_product_attention
            requires it). We take it to be float("-inf") where we want to mask.
        Returns
        -------
        """
        z = self.layernorm(z)
        batch_size, n_protein, n_compound, embedding_channels = z.shape
        z = z.reshape(batch_size*n_protein, n_compound, embedding_channels)
        q, k, v = self.linear_qkv(z).chunk(3, dim=-1)
        q = q.view(batch_size*n_protein, n_compound, self.num_attention_heads, self.attention_head_size).contiguous()
        k = k.view(batch_size*n_protein, n_compound, self.num_attention_heads, self.attention_head_size).contiguous()
        v = v.view(batch_size*n_protein, n_compound, self.num_attention_heads, self.attention_head_size).contiguous()
        attention_coefficients = xops.memory_efficient_attention(query=q,
                                                key=k,
                                                value=v,
                                                attn_bias=z_mask_attention_float.to("cuda:0")) # shape [batch*protein_nodes, compound_nodes, n_heads, embedding//n_heads]        

        attention_output = attention_coefficients.view(batch_size, n_protein, n_compound, embedding_channels)
        g = self.g(z).sigmoid()
        output = g * attention_output.view(batch_size*n_protein, n_compound, embedding_channels)

        output = self.output_linear(output.view(batch_size, n_protein, n_compound, embedding_channels))*z_mask.unsqueeze(-1).to('cuda:0')
        return output

class TriangleSelfAttentionRowWise(torch.nn.Module):
    # use the protein-compound matrix only.
    def __init__(self, embedding_channels=128, c=32, num_attention_heads=4):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = c
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # self.dp = nn.Dropout(drop_rate)
        # self.ln = nn.LayerNorm(hidden_size)

        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        # self.layernorm_c = torch.nn.LayerNorm(c)

        self.linear_q = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_k = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_v = Linear(embedding_channels, self.all_head_size, bias=False)
        # self.b = Linear(embedding_channels, h, bias=False)
        self.g = Linear(embedding_channels, self.all_head_size)
        self.final_linear = Linear(self.all_head_size, embedding_channels)

    def reshape_last_dim(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x

    def forward(self, z, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        # z_mask of shape b, i, j
        z = self.layernorm(z)
        p_length = z.shape[1]
        batch_n = z.shape[0]
        # new_z = torch.zeros(z.shape, device=z.device)
        z_i = z
        z_mask_i = z_mask.view((batch_n, p_length, 1, 1, -1))
        attention_mask_i = (1e9 * (z_mask_i.float() - 1.))
        # q, k, v of shape b, j, h, c
        q = self.reshape_last_dim(self.linear_q(z_i)) #  * (self.attention_head_size**(-0.5))
        k = self.reshape_last_dim(self.linear_k(z_i))
        v = self.reshape_last_dim(self.linear_v(z_i))
        logits = torch.einsum('biqhc,bikhc->bihqk', q, k) + attention_mask_i
        weights = nn.Softmax(dim=-1)(logits)
        # weights of shape b, h, j, j
        # attention_probs = self.dp(attention_probs)
        weighted_avg = torch.einsum('bihqk,bikhc->biqhc', weights, v)
        g = self.reshape_last_dim(self.g(z_i)).sigmoid()
        output = g * weighted_avg
        new_output_shape = output.size()[:-2] + (self.all_head_size,)
        output = output.view(*new_output_shape)
        # output of shape b, j, embedding.
        # z[:, i] = output
        z = output
        # print(g.shape, block1.shape, block2.shape)
        z = self.final_linear(z) * z_mask.unsqueeze(-1)
        return z


class Transition(torch.nn.Module):
    # separate left/right edges (block1/block2).
    def __init__(self, embedding_channels=256, n=4):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.linear1 = Linear(embedding_channels, n*embedding_channels)
        self.linear2 = Linear(n*embedding_channels, embedding_channels)
    def forward(self, z):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        z = self.layernorm(z)
        z = self.linear2((self.linear1(z)).relu())
        return z



class IaBNet_with_affinity(torch.nn.Module):
    def __init__(self, hidden_channels=128, embedding_channels=128, c=128, fast_attention=True, mode=0, protein_embed_mode=1, compound_embed_mode=1, n_trigonometry_module_stack=5, protein_bin_max=30, readout_mode=2):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.protein_bin_max = protein_bin_max
        self.mode = mode
        self.protein_embed_mode = protein_embed_mode
        self.compound_embed_mode = compound_embed_mode
        self.n_trigonometry_module_stack = n_trigonometry_module_stack
        self.readout_mode = readout_mode
        # Added by Enzo
        self.fast_attention = fast_attention
        if protein_embed_mode == 0:
            self.conv_protein = GNN(hidden_channels, embedding_channels)
            self.conv_compound = GNN(hidden_channels, embedding_channels)
            # self.conv_protein = SAGEConv((-1, -1), embedding_channels)
            # self.conv_compound = SAGEConv((-1, -1), embedding_channels)
        if protein_embed_mode == 1:
            self.conv_protein = GVP_embedding((6, 3), (embedding_channels, 16), 
                                              (32, 1), (32, 1), seq_in=True)
            

        if compound_embed_mode == 0:
            self.conv_compound = GNN(hidden_channels, embedding_channels)
        elif compound_embed_mode == 1:
            self.conv_compound = GIN(input_dim = 56, hidden_dims = [128,56,embedding_channels], edge_input_dim = 19, concat_hidden = False)

        if mode == 0:
            self.protein_pair_embedding = nn.Embedding(16, c)
            self.compound_pair_embedding = Linear(16, c)
            self.protein_to_compound_list = []
            self.protein_to_compound_list = nn.ModuleList([TriangleProteinToCompound_v2(embedding_channels=embedding_channels, c=c) for _ in range(n_trigonometry_module_stack)])
            if fast_attention:
                self.triangle_self_attention_list = nn.ModuleList([FastTriangleSelfAttention(embedding_channels=embedding_channels, num_attention_heads=4) for _ in range(n_trigonometry_module_stack)])
            else:
                self.triangle_self_attention_list = nn.ModuleList([TriangleSelfAttentionRowWise(embedding_channels=embedding_channels) for _ in range(n_trigonometry_module_stack)])
            self.tranistion = Transition(embedding_channels=embedding_channels, n=4)

        self.linear = Linear(embedding_channels, 1)
        self.linear_energy = Linear(embedding_channels, 1)
        if readout_mode == 2:
            self.gate_linear = Linear(embedding_channels, 1)
        # self.gate_linear = Linear(embedding_channels, 1)
        self.bias = torch.nn.Parameter(torch.ones(1))
        self.leaky = torch.nn.LeakyReLU()
        self.dropout = nn.Dropout2d(p=0.25)
    def forward(self, data):
        # Added by Enzo
        max_dim_divisible_by_8_protein = data.max_dim_divisible_by_8_protein
        max_dim_divisible_by_8_compound = data.max_dim_divisible_by_8_compound
        if self.protein_embed_mode == 0:
            x = data['protein'].x.float()
            edge_index = data[("protein", "p2p", "protein")].edge_index
            protein_batch = data['protein'].batch
            protein_out = self.conv_protein(x, edge_index)
        if self.protein_embed_mode == 1:
            nodes = (data['protein']['node_s'], data['protein']['node_v'])
            edges = (data[("protein", "p2p", "protein")]["edge_s"], data[("protein", "p2p", "protein")]["edge_v"])
            protein_batch = data['protein'].batch
            protein_out = self.conv_protein(nodes, data[("protein", "p2p", "protein")]["edge_index"], edges, data.seq)

        if self.compound_embed_mode == 0:
            compound_x = data['compound'].x.float()
            compound_edge_index = data[("compound", "c2c", "compound")].edge_index
            compound_batch = data['compound'].batch
            compound_out = self.conv_compound(compound_x, compound_edge_index)
        elif self.compound_embed_mode == 1:
            compound_x = data['compound'].x.float()
            compound_edge_index = data[("compound", "c2c", "compound")].edge_index.T
            # compound_edge_index = data[("compound", "c2c", "compound")].edge_index
            compound_edge_feature = data[("compound", "c2c", "compound")].edge_attr
            edge_weight = data[("compound", "c2c", "compound")].edge_weight
            compound_batch = data['compound'].batch
            # Enzo : print dimensions
            #print(f"{compound_edge_index.shape=}, {edge_weight.shape=}, {compound_edge_feature.shape=}, {compound_x.shape=}")
            compound_out = self.conv_compound(compound_edge_index,edge_weight,compound_edge_feature,compound_x.shape[0],compound_x)['node_feature']
    
        # protein_batch version could further process b matrix. better than for loop.
        # protein_out_batched of shape b, n, c
        protein_out_batched, protein_out_mask = to_dense_batch(protein_out, protein_batch, max_num_nodes=max_dim_divisible_by_8_protein)
        compound_out_batched, compound_out_mask = to_dense_batch(compound_out, compound_batch, max_num_nodes=max_dim_divisible_by_8_compound)
        batch_n = data.batch_n
        z_mask = torch.einsum("bi,bj->bij", protein_out_mask, compound_out_mask)
        z_mask_attention = torch.einsum("bik, bq-> biqk", z_mask, compound_out_mask).reshape(batch_n*protein_out_batched.shape[1], max_dim_divisible_by_8_compound, max_dim_divisible_by_8_compound).unsqueeze(1).expand(-1, self.n_heads, -1, -1).contiguous()
        z_mask_attention = torch.where(z_mask_attention, 0.0, -10.0**6)
        z_mask_flat = torch.arange(
            start=0, end=z_mask.numel(), device=self.device
        ).view(z_mask.shape)[z_mask]
        protein_square_mask = torch.einsum("bi,bj->bij", protein_out_mask, protein_out_mask)
        node_xyz = data.node_xyz

        p_coords_batched, p_coords_mask = to_dense_batch(node_xyz, protein_batch)
        # c_coords_batched, c_coords_mask = to_dense_batch(coords, compound_batch)

        protein_pair = data["protein", "p2p", "protein"].pairwise_representation
        
        # compound_pair = get_pair_dis_one_hot(c_coords_batched, bin_size=1, bin_min=-0.5, bin_max=15)
        compound_pair_batched, compound_pair_batched_mask = data["compound", "p2p", "compound"].pairwise_representation, data["compound", "p2p", "compound"].pairwise_representation_mask

        batch_n = compound_pair_batched.shape[0]
        # max_compound_size_square = compound_pair_batched.shape[1]
        # max_compound_size = int(max_compound_size_square**0.5)
        # assert (max_compound_size**2 - max_compound_size_square)**2 < 1e-4
        # compound_pair = torch.zeros((batch_n, max_compound_size, max_compound_size, 16)).to(data.compound_pair.device)
        # for i in range(batch_n):
        #     one = compound_pair_batched[i]
        #     compound_size_square = (data.compound_pair_batch == i).sum()
        #     compound_size = int(compound_size_square**0.5)
        #     compound_pair[i,:compound_size, :compound_size] = one[:compound_size_square].reshape(
        #                                                         (compound_size, compound_size, -1))
        protein_pair = self.protein_pair_embedding(protein_pair)
        compound_pair = self.compound_pair_embedding(data["compound", "p2p", "compound"].pairwise_representation.float())
        # b = torch.einsum("bik,bjk->bij", protein_out_batched, compound_out_batched).flatten()

        protein_out_batched = self.layernorm(protein_out_batched)
        compound_out_batched = self.layernorm(compound_out_batched)
        # z of shape, b, protein_length, compound_length, channels.
        z = torch.einsum("bik,bjk->bijk", protein_out_batched, compound_out_batched)
        # z_mask = torch.einsum("bi,bj->bij", protein_out_mask, compound_out_mask)
        # z = z * z_mask.unsqueeze(-1)
        # print(protein_pair.shape, compound_pair.shape, b.shape)
        if self.mode == 0:
            for _ in range(1):
                for i_module in range(self.n_trigonometry_module_stack):
                    z = z + self.dropout(self.protein_to_compound_list[i_module](z, protein_pair, compound_pair, z_mask.unsqueeze(-1)))
                    if self.fast_attention:
                        z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask_attention, z_mask))
                    else:
                        z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask))
                    z = self.tranistion(z)
        # batch_dim = z.shape[0]

        b = self.linear(z).squeeze(-1)
        y_pred = b.flatten()[z_mask_flat]
        y_pred = y_pred.sigmoid() * 10   # normalize to 0 to 10.
        if self.readout_mode == 0:
            pair_energy = self.linear_energy(z).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        if self.readout_mode == 1:
            # valid_interaction_z = (z * z_mask.unsqueeze(-1)).mean(axis=(1, 2))
            valid_interaction_z = (z * z_mask.unsqueeze(-1)).sum(axis=(1, 2)) / z_mask.sum(axis=(1, 2)).unsqueeze(-1)
            affinity_pred = self.linear_energy(valid_interaction_z).squeeze(-1)
            # print("z shape", z.shape, "z_mask shape", z_mask.shape,   "valid_interaction_z shape", valid_interaction_z.shape, "affinity_pred shape", affinity_pred.shape)
        if self.readout_mode == 2:
            pair_energy = (self.gate_linear(z).sigmoid() * self.linear_energy(z)).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        return y_pred, affinity_pred