In [None]:
%cd /fs/pool/pool-marsot/tankbind_philip/TankBind/tankbind

In [None]:
import sys; sys.path.append("/fs/pool/pool-marsot")

In [None]:
from data import TankBindDataSet

In [None]:
# Load the preprocessed PDBbind 2020 dataset used throughout this notebook.
# NOTE(review): absolute cluster path — consider moving it to a configurable DATA_DIR constant.
dataset = TankBindDataSet("/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset")

In [None]:
from model import TankBindDataLoader

In [None]:
import torch
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.utils import to_dense_batch
from torch import nn
from torch.nn import Linear
import sys
import torch.nn as nn
from gvp import GVP, GVPConvLayer, LayerNorm, tuple_index
from torch.distributions import Categorical
from torch_scatter import scatter_mean
#from GATv2 import GAT
from GINv2 import GIN
import xformers.ops as xops

class GNN(torch.nn.Module):
    """Two-layer GraphSAGE node encoder.

    Both convolutions are built with lazy input sizes ``(-1, -1)``, so the input
    feature width is inferred on the first forward pass.

    :param hidden_channels: width of the intermediate node representation
    :param out_channels: width of the returned node embeddings
    """
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2; the final layer has no activation."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)



class GVP_embedding(nn.Module):
    '''
    Modified based on https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
    GVP-GNN encoder producing per-node embeddings.

    Takes node features ``h_V`` and edge features ``h_E`` as ``(scalar, vector)``
    tuples plus ``edge_index`` and returns the output of the final GVP layer,
    which is built with vector output width 0 — presumably a tensor of shape
    [n_nodes, node_h_dim[0]] (confirm against the ``gvp`` package).

    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param seq_in: if `True`, an integer sequence (20 token types) is embedded
                   to 20 features and concatenated to the node scalar features
                   in the forward pass; otherwise, sequence information is
                   assumed to be part of input node embeddings
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim, 
                 edge_in_dim, edge_h_dim,
                 seq_in=False, num_layers=3, drop_rate=0.1):

        super(GVP_embedding, self).__init__()
        
        if seq_in:
            self.W_s = nn.Embedding(20, 20)
            # The embedded sequence is concatenated to the scalar node channels.
            node_in_dim = (node_in_dim[0] + 20, node_in_dim[1])
        
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )

        self.layers = nn.ModuleList(
                GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate) 
            for _ in range(num_layers))
        
        ns, _ = node_h_dim
        # Vector output width 0: only scalar channels are returned.
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

    def forward(self, h_V, edge_index, h_E, seq=None):
        '''
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param seq: int `torch.Tensor` of shape [num_nodes]. Required when the
                    module was built with ``seq_in=True`` (it is embedded and
                    appended to the scalar part of `h_V`); ignored otherwise.
        :raises ValueError: if built with ``seq_in=True`` and `seq` is None
        '''
        # ``W_s`` only exists when the module was built with seq_in=True.
        if hasattr(self, "W_s"):
            if seq is None:
                raise ValueError("GVP_embedding was built with seq_in=True; `seq` must be provided")
            seq_embedding = self.W_s(seq)
            h_V = (torch.cat([h_V[0], seq_embedding], dim=-1), h_V[1])
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)

        return out


def get_pair_dis_one_hot(d, bin_size=2, bin_min=-1, bin_max=30, num_classes=16):
    """One-hot encode binned pairwise Euclidean distances between the points of ``d``.

    Distances above ``bin_max`` are clipped to ``bin_max``. Each distance is then
    mapped to bin ``floor((dist - bin_min) / bin_size)``.

    :param d: float coordinates; ``torch.cdist`` treats the last two dims as
        (points, coordinates), e.g. [batch, n_points, 3]
    :param bin_size: width of each distance bin
    :param bin_min: left edge of bin 0
    :param bin_max: distances are clipped to this value before binning
    :param num_classes: width of the one-hot encoding. The default of 16 fits the
        default bins, whose largest index is floor((30 - (-1)) / 2) = 15. It must be
        larger than the largest bin index the chosen bin parameters can produce.
    :returns: long tensor of shape ``d.shape[:-1] + (d.shape[-2], num_classes)``
    """
    # without compute_mode='donot_use_mm_for_euclid_dist' could lead to wrong result.
    pair_dis = torch.cdist(d, d, compute_mode='donot_use_mm_for_euclid_dist')
    pair_dis = pair_dis.clamp(max=bin_max)
    pair_dis_bin_index = torch.div(pair_dis - bin_min, bin_size, rounding_mode='floor').long()
    return torch.nn.functional.one_hot(pair_dis_bin_index, num_classes=num_classes)

class TriangleProteinToCompound(torch.nn.Module):
    """Triangle-style update of the protein-compound pair tensor ``z``.

    ``z`` is layer-normalized and projected to ``c`` channels (optionally
    sigmoid-gated) and masked. It is then combined with the protein-protein
    and compound-compound pair tensors through two einsum contractions. The sum
    is normalized, projected back to ``embedding_channels``, gated and masked.

    :param embedding_channels: channel width of ``z``
    :param c: channel width of the projected edges and of the pair tensors
    :param hasgate: whether to sigmoid-gate the projection of ``z``
    """
    def __init__(self, embedding_channels=256, c=128, hasgate=True):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.layernorm_c = torch.nn.LayerNorm(c)
        self.hasgate = hasgate
        if hasgate:
            self.gate_linear = Linear(embedding_channels, c)
        self.linear = Linear(embedding_channels, c)
        self.ending_gate_linear = Linear(embedding_channels, embedding_channels)
        self.linear_after_sum = Linear(c, embedding_channels)

    def forward(self, z, protein_pair, compound_pair, z_mask):
        """
        :param z: [b, i, j, embedding_channels], i = protein nodes, j = compound nodes
        :param protein_pair: [b, i, i, c]
        :param compound_pair: [b, j, j, c]
        :param z_mask: [b, i, j, 1], zero on padded pairs
        :returns: update of shape [b, i, j, embedding_channels], zero where masked
        """
        z_normed = self.layernorm(z)
        projected = self.linear(z_normed)
        if self.hasgate:
            projected = self.gate_linear(z_normed).sigmoid() * projected
        edges = projected * z_mask
        output_gate = torch.sigmoid(self.ending_gate_linear(z_normed))
        # Contract over protein nodes (via protein_pair) and over compound nodes (via compound_pair).
        through_protein = torch.einsum("bikc,bkjc->bijc", protein_pair, edges)
        through_compound = torch.einsum("bikc,bjkc->bijc", edges, compound_pair)
        update = self.linear_after_sum(self.layernorm_c(through_protein + through_compound))
        return output_gate * update * z_mask

class TriangleProteinToCompound_v2(torch.nn.Module):
    # separate left/right edges (block1/block2).
    def __init__(self, embedding_channels=256, c=128):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels, bias=False)
        self.layernorm_c = torch.nn.LayerNorm(c, bias=False)

        # self.gate_linear1 = Linear(embedding_channels, c)
        # self.gate_linear2 = Linear(embedding_channels, c)
        # modification by Enzo to remove biases. (hypothesis: biases make the outputs dependent on padding)
        self.gate_linear1 = Linear(embedding_channels, c, bias=False)
        self.gate_linear2 = Linear(embedding_channels, c, bias=False)

        self.linear1 = Linear(embedding_channels, c)
        self.linear2 = Linear(embedding_channels, c)

        self.ending_gate_linear = Linear(embedding_channels, embedding_channels)
        self.linear_after_sum = Linear(c, embedding_channels)
    def forward(self, z, protein_pair, compound_pair, z_mask):
        # z of shape b, i, j, embedding_channels, where i is protein dim, j is compound dim.
        z = self.layernorm(z)
        protein_pair = self.layernorm(protein_pair)
        compound_pair = self.layernorm(compound_pair)
 
        ab1 = self.gate_linear1(z).sigmoid() * self.linear1(z) * z_mask
        ab2 = self.gate_linear2(z).sigmoid() * self.linear2(z) * z_mask
        protein_pair = self.gate_linear2(protein_pair).sigmoid() * self.linear2(protein_pair)
        compound_pair = self.gate_linear1(compound_pair).sigmoid() * self.linear1(compound_pair)

        g = self.ending_gate_linear(z).sigmoid()
        block1 = torch.einsum("bikc,bkjc->bijc", protein_pair, ab1)
        block2 = torch.einsum("bikc,bjkc->bijc", ab2, compound_pair)
        # print(g.shape, block1.shape, block2.shape)
        z = g * self.linear_after_sum(self.layernorm_c(block1+block2)) * z_mask
        return z

# class Self_Attention(nn.Module):
#     def __init__(self, hidden_size,num_attention_heads=8,drop_rate=0.5):
#         super().__init__()
#         self.num_attention_heads = num_attention_heads
#         self.attention_head_size = int(hidden_size / num_attention_heads)
#         self.all_head_size = self.num_attention_heads * self.attention_head_size
#         self.dp = nn.Dropout(drop_rate)
#         self.ln = nn.LayerNorm(hidden_size)

#     def transpose_for_scores(self, x):
#         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
#         x = x.view(*new_x_shape)
#         return x.permute(0, 2, 1, 3)

#     def forward(self,q,k,v,attention_mask=None,attention_weight=None):
#         q = self.transpose_for_scores(q)
#         k = self.transpose_for_scores(k)
#         v = self.transpose_for_scores(v)
#         attention_scores = torch.matmul(q, k.transpose(-1, -2))

#         attention_probs = nn.Softmax(dim=-1)(attention_scores)
#         # attention_probs = self.dp(attention_probs)
#         if attention_weight is not None:
#             attention_weight_sorted_sorted = torch.argsort(torch.argsort(-attention_weight,axis=-1),axis=-1)
#             # if self.training:
#             #     top_mask = (attention_weight_sorted_sorted<np.random.randint(28,45))
#             # else:
#             top_mask = (attention_weight_sorted_sorted<32)
#             attention_probs = attention_probs * top_mask
#             # attention_probs = attention_probs * attention_weight
#             attention_probs = attention_probs / (torch.sum(attention_probs,dim=-1,keepdim=True) + 1e-5)
#         # print(attention_probs.shape,v.shape)
#         # attention_probs = self.dp(attention_probs)
#         outputs = torch.matmul(attention_probs, v)

#         outputs = outputs.permute(0, 2, 1, 3).contiguous()
#         new_output_shape = outputs.size()[:-2] + (self.all_head_size,)
#         outputs = outputs.view(*new_output_shape)
#         outputs = self.ln(outputs)
#         return outputs


class FastTriangleSelfAttention(nn.Module):
    """Gated row-wise self-attention over the compound axis of the pair tensor.

    For every (batch, protein-node) row of ``z``, the compound positions attend to
    one another using ``xformers.ops.memory_efficient_attention``. The result is
    gated by ``sigmoid(g(z))``, projected by ``output_linear`` and zeroed outside
    ``z_mask``.

    :param embedding_channels: channel width of ``z``; must be divisible by
        ``num_attention_heads``
    :param num_attention_heads: number of attention heads
    :raises ValueError: if ``embedding_channels`` is not divisible by ``num_attention_heads``
    """
    def __init__(self, embedding_channels, num_attention_heads):
        super().__init__()
        if embedding_channels % num_attention_heads != 0:
            raise ValueError(
                f"embedding_channels ({embedding_channels}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.layernorm = nn.LayerNorm(embedding_channels, bias=False)
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = embedding_channels // num_attention_heads
        self.linear_qkv = nn.Linear(embedding_channels, 3*embedding_channels, bias=False)
        self.output_linear = nn.Linear(embedding_channels, embedding_channels)
        self.g = nn.Linear(embedding_channels, embedding_channels)

    def forward(self, z, z_mask_attention_float, z_mask):
        """
        Parameters
        ----------
        z: torch.Tensor of shape [batch, n_protein, n_compound, embedding_channels]
        z_mask_attention_float: additive attention bias, passed as ``attn_bias`` to
            ``xops.memory_efficient_attention``. It is 0.0 where attention is allowed and
            a large negative value where it is masked. The model in this file builds it
            with shape [batch*n_protein, num_attention_heads, n_compound, n_compound].
            NOTE(review): xformers requires a finite/aligned bias here, which is why the
            collater pads node counts to multiples of 8 — confirm for your xformers version.
        z_mask: torch.Tensor of shape [batch, n_protein, n_compound]; nonzero where the
            (protein, compound) pair corresponds to actual data.

        Returns
        -------
        torch.Tensor of shape [batch, n_protein, n_compound, embedding_channels],
        zero wherever ``z_mask`` is zero.
        """
        z = self.layernorm(z)
        batch_size, n_protein, n_compound, embedding_channels = z.shape
        n_rows = batch_size * n_protein
        head_shape = (n_rows, n_compound, self.num_attention_heads, self.attention_head_size)

        z = z.reshape(n_rows, n_compound, embedding_channels)
        q, k, v = self.linear_qkv(z).chunk(3, dim=-1)
        q = q.reshape(head_shape).contiguous()
        k = k.reshape(head_shape).contiguous()
        v = v.reshape(head_shape).contiguous()
        # Keep the bias on the same device as the activations (was hard-coded to "cuda:0").
        attention_coefficients = xops.memory_efficient_attention(
            query=q,
            key=k,
            value=v,
            attn_bias=z_mask_attention_float.to(z.device),
        )  # shape [batch*n_protein, n_compound, n_heads, embedding//n_heads]

        attention_output = attention_coefficients.reshape(n_rows, n_compound, embedding_channels)
        g = self.g(z).sigmoid()
        output = (g * attention_output).view(batch_size, n_protein, n_compound, embedding_channels)

        return self.output_linear(output) * z_mask.unsqueeze(-1).to(z.device)

class TriangleSelfAttentionRowWise(torch.nn.Module):
    """Gated row-wise self-attention using the protein-compound matrix only.

    For every protein row ``i`` of ``z``, compound positions attend to each other.
    Padded compound keys are excluded through an additive ``-1e9`` bias. The
    attention logits are not scaled by ``1/sqrt(c)``.

    :param embedding_channels: channel width of ``z``
    :param c: per-head channel width
    :param num_attention_heads: number of heads
    """
    def __init__(self, embedding_channels=128, c=32, num_attention_heads=4):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = c
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.layernorm = torch.nn.LayerNorm(embedding_channels)

        self.linear_q = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_k = Linear(embedding_channels, self.all_head_size, bias=False)
        self.linear_v = Linear(embedding_channels, self.all_head_size, bias=False)
        self.g = Linear(embedding_channels, self.all_head_size)
        self.final_linear = Linear(self.all_head_size, embedding_channels)

    def reshape_last_dim(self, x):
        """Split the last dim into (num_attention_heads, attention_head_size)."""
        return x.view(*x.shape[:-1], self.num_attention_heads, self.attention_head_size)

    def forward(self, z, z_mask):
        """
        :param z: [b, i, j, embedding_channels], i = protein nodes, j = compound nodes
        :param z_mask: [b, i, j], truthy on real pairs
        :returns: [b, i, j, embedding_channels], zero where ``z_mask`` is falsy
        """
        z = self.layernorm(z)
        batch_n, p_length = z.shape[0], z.shape[1]
        # Additive key bias of shape b, i, 1, 1, k: 0 for real keys, -1e9 for padding.
        key_bias = (z_mask.view((batch_n, p_length, 1, 1, -1)).float() - 1.) * 1e9

        # q, k, v of shape b, i, j, h, c
        q = self.reshape_last_dim(self.linear_q(z))
        k = self.reshape_last_dim(self.linear_k(z))
        v = self.reshape_last_dim(self.linear_v(z))

        scores = torch.einsum('biqhc,bikhc->bihqk', q, k) + key_bias
        attention = torch.softmax(scores, dim=-1)
        mixed = torch.einsum('bihqk,bikhc->biqhc', attention, v)

        gated = self.reshape_last_dim(self.g(z)).sigmoid() * mixed
        merged = gated.reshape(*gated.shape[:-2], self.all_head_size)
        return self.final_linear(merged) * z_mask.unsqueeze(-1)


class Transition(torch.nn.Module):
    """Position-wise feed-forward block: LayerNorm -> Linear(E, n*E) -> ReLU -> Linear(n*E, E).

    There is no residual connection inside this module; its output replaces the input.

    :param embedding_channels: channel width E of the input and output
    :param n: expansion factor of the hidden layer
    """
    def __init__(self, embedding_channels=256, n=4):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.linear1 = Linear(embedding_channels, n*embedding_channels)
        self.linear2 = Linear(n*embedding_channels, embedding_channels)

    def forward(self, z):
        """Apply the block to the last dimension of ``z`` (e.g. [b, i, j, E])."""
        expanded = torch.relu(self.linear1(self.layernorm(z)))
        return self.linear2(expanded)



class IaBNet_with_affinity(torch.nn.Module):
    """TankBind-style model predicting a protein-compound pair map and a per-complex affinity.

    Proteins are embedded with a GVP-GNN (``protein_embed_mode=1``) or GraphSAGE (``0``).
    Compounds are embedded with GIN (``compound_embed_mode=1``) or GraphSAGE (``0``).
    Their outer product forms the pair tensor ``z`` of shape
    [batch, n_protein, n_compound, embedding_channels]. With ``mode=0``, ``z`` is refined by
    ``n_trigonometry_module_stack`` rounds of triangle update, row-wise self-attention and
    a shared transition block.

    :param hidden_channels: hidden width of the GraphSAGE encoders (embed mode 0)
    :param embedding_channels: width of node embeddings and of ``z``
    :param c: width of the pair embeddings fed to the triangle updates.
        ``TriangleProteinToCompound_v2`` applies a LayerNorm of width ``embedding_channels``
        to them, so ``c`` should equal ``embedding_channels``.
    :param fast_attention: use ``FastTriangleSelfAttention`` (xformers) instead of
        ``TriangleSelfAttentionRowWise``
    :param mode: 0 enables the pair embeddings and the trigonometry stack; other values skip them
    :param protein_embed_mode: 0 = GraphSAGE, 1 = GVP
    :param compound_embed_mode: 0 = GraphSAGE, 1 = GIN
    :param n_trigonometry_module_stack: number of stacked refinement rounds
    :param protein_bin_max: stored on the module only; distance binning happens in the data collater
    :param readout_mode: affinity readout — 0 = summed pair energies, 1 = mean-pooled pair
        features, 2 = sigmoid-gated summed pair energies
    """
    def __init__(self, hidden_channels=128, embedding_channels=128, c=128, fast_attention=True, mode=0, protein_embed_mode=1, compound_embed_mode=1, n_trigonometry_module_stack=5, protein_bin_max=30, readout_mode=2):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.protein_bin_max = protein_bin_max
        self.mode = mode
        self.protein_embed_mode = protein_embed_mode
        self.compound_embed_mode = compound_embed_mode
        self.n_trigonometry_module_stack = n_trigonometry_module_stack
        self.readout_mode = readout_mode
        # Must match num_attention_heads of FastTriangleSelfAttention below.
        self.n_heads=4
        # Added by Enzo
        self.fast_attention = fast_attention
        if protein_embed_mode == 0:
            self.conv_protein = GNN(hidden_channels, embedding_channels)
            self.conv_compound = GNN(hidden_channels, embedding_channels)
        if protein_embed_mode == 1:
            self.conv_protein = GVP_embedding((6, 3), (embedding_channels, 16), 
                                              (32, 1), (32, 1), seq_in=True)

        if compound_embed_mode == 0:
            self.conv_compound = GNN(hidden_channels, embedding_channels)
        elif compound_embed_mode == 1:
            self.conv_compound = GIN(input_dim = 56, hidden_dims = [128,56,embedding_channels], edge_input_dim = 19, concat_hidden = False)

        if mode == 0:
            self.protein_pair_embedding = nn.Embedding(16, c)
            self.compound_pair_embedding = nn.Linear(16, c)
            self.protein_to_compound_list = nn.ModuleList([TriangleProteinToCompound_v2(embedding_channels=embedding_channels, c=c) for _ in range(n_trigonometry_module_stack)])
            if fast_attention:
                self.triangle_self_attention_list = nn.ModuleList([FastTriangleSelfAttention(embedding_channels=embedding_channels, num_attention_heads=4) for _ in range(n_trigonometry_module_stack)])
            else:
                self.triangle_self_attention_list = nn.ModuleList([TriangleSelfAttentionRowWise(embedding_channels=embedding_channels) for _ in range(n_trigonometry_module_stack)])
            # NOTE: attribute name is misspelled but kept for state_dict compatibility.
            self.tranistion = Transition(embedding_channels=embedding_channels, n=4)

        self.linear = Linear(embedding_channels, 1)
        self.linear_energy = Linear(embedding_channels, 1)
        if readout_mode == 2:
            self.gate_linear = Linear(embedding_channels, 1)
        self.bias = torch.nn.Parameter(torch.ones(1))
        self.leaky = torch.nn.LeakyReLU()
        self.dropout = nn.Dropout2d(p=0.25)

    def forward(self, data):
        """Run the model on a collated batch.

        ``data`` is expected to come from ``TankBindCollater``. It reads
        ``max_dim_divisible_by_8_protein/compound``, ``batch_n`` and the
        ``pairwise_representation`` fields that the collater attaches.

        :returns: ``(y_pred, affinity_pred)``.
            ``y_pred`` holds one value per real (protein, compound) pair, flattened in
            row-major order and squashed into (0, 10) by ``sigmoid * 10``.
            ``affinity_pred`` has one value per complex.
        """
        # Added by Enzo: padded node counts computed by the collater.
        max_protein_nodes = data.max_dim_divisible_by_8_protein
        max_compound_nodes = data.max_dim_divisible_by_8_compound

        # --- Node embeddings ---
        protein_batch = data['protein'].batch
        if self.protein_embed_mode == 0:
            x = data['protein'].x.float()
            edge_index = data[("protein", "p2p", "protein")].edge_index
            protein_out = self.conv_protein(x, edge_index)
        if self.protein_embed_mode == 1:
            p2p = data[("protein", "p2p", "protein")]
            nodes = (data['protein']['node_s'], data['protein']['node_v'])
            edges = (p2p["edge_s"], p2p["edge_v"])
            protein_out = self.conv_protein(nodes, p2p["edge_index"], edges, data.seq)

        compound_batch = data['compound'].batch
        c2c = data[("compound", "c2c", "compound")]
        if self.compound_embed_mode == 0:
            compound_x = data['compound'].x.float()
            compound_out = self.conv_compound(compound_x, c2c.edge_index)
        elif self.compound_embed_mode == 1:
            compound_x = data['compound'].x.float()
            # GIN is called with edge_index transposed (assumes PyG-style [2, num_edges] input).
            compound_out = self.conv_compound(c2c.edge_index.T, c2c.edge_weight, c2c.edge_attr,
                                              compound_x.shape[0], compound_x)['node_feature']

        # --- Dense batching and masks ---
        protein_out_batched, protein_out_mask = to_dense_batch(protein_out, protein_batch, max_num_nodes=max_protein_nodes)
        compound_out_batched, compound_out_mask = to_dense_batch(compound_out, compound_batch, max_num_nodes=max_compound_nodes)
        batch_n = data.batch_n
        n_protein_nodes = protein_out_batched.shape[1]
        # z_mask[b, i, j] is True iff protein node i and compound node j are both real.
        z_mask = torch.einsum("bi,bj->bij", protein_out_mask, compound_out_mask)

        protein_out_batched = self.layernorm(protein_out_batched)
        compound_out_batched = self.layernorm(compound_out_batched)
        # z of shape b, protein_length, compound_length, channels.
        z = torch.einsum("bik,bjk->bijk", protein_out_batched, compound_out_batched)

        # --- Trigonometry stack (mode 0 only; the pair embeddings only exist in that mode) ---
        if self.mode == 0:
            protein_pair = self.protein_pair_embedding(data["protein", "p2p", "protein"].pairwise_representation)
            compound_pair = self.compound_pair_embedding(data["compound", "p2p", "compound"].pairwise_representation.float())
            if self.fast_attention:
                # Additive bias for xformers, shape [batch*n_protein, n_heads, n_compound, n_compound]:
                # 0 where both the (i, key) pair and the query compound node are real, -1e6 elsewhere.
                z_mask_attention = (
                    torch.einsum("bik,bq->biqk", z_mask, compound_out_mask)
                    .reshape(batch_n * n_protein_nodes, max_compound_nodes, max_compound_nodes)
                    .unsqueeze(1)
                    .expand(-1, self.n_heads, -1, -1)
                    .contiguous()
                )
                z_mask_attention = torch.where(z_mask_attention, 0.0, -10.0**6)
            for i_module in range(self.n_trigonometry_module_stack):
                z = z + self.dropout(self.protein_to_compound_list[i_module](z, protein_pair, compound_pair, z_mask.unsqueeze(-1)))
                if self.fast_attention:
                    z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask_attention, z_mask))
                else:
                    z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask))
                z = self.tranistion(z)

        # --- Readouts ---
        b = self.linear(z).squeeze(-1)
        # Boolean-mask indexing keeps the real pairs in row-major order.
        y_pred = b[z_mask]
        y_pred = y_pred.sigmoid() * 10   # normalize to 0 to 10.
        if self.readout_mode == 0:
            pair_energy = self.linear_energy(z).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        if self.readout_mode == 1:
            valid_interaction_z = (z * z_mask.unsqueeze(-1)).sum(axis=(1, 2)) / z_mask.sum(axis=(1, 2)).unsqueeze(-1)
            affinity_pred = self.linear_energy(valid_interaction_z).squeeze(-1)
        if self.readout_mode == 2:
            pair_energy = (self.gate_linear(z).sigmoid() * self.linear_energy(z)).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        return y_pred, affinity_pred
# Added by Enzo
from torch_geometric.loader.dataloader import Collater
from torch_geometric.utils import to_dense_batch
import torch


class TankBindDataLoader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` that collates with ``TankBindCollater``.

    The collater computes padding sizes, binned protein pairwise distances and dense
    compound pair features at batching time, instead of in the model's forward pass.

    :param dataset: dataset of PyG data samples
    :param batch_size: number of complexes per batch
    :param shuffle: whether to reshuffle every epoch
    :param follow_batch: forwarded to the PyG collater
    :param exclude_keys: forwarded to the PyG collater
    :param make_divisible_by_8: pad node counts to a multiple of 8
    :param kwargs: any other ``DataLoader`` argument except ``collate_fn``
    """
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 follow_batch=None,
                 exclude_keys=None,
                 make_divisible_by_8=True,
                 **kwargs):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.make_divisible_by_8 = make_divisible_by_8
        collater = TankBindCollater(dataset, follow_batch, exclude_keys,
                                    make_divisible_by_8=self.make_divisible_by_8)
        super().__init__(dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         collate_fn=collater,
                         **kwargs)



class TankBindCollater(Collater):
    """Applies batching operations and computations of masks in place of the model, in order to avoid having to recompute it in the
    forward pass on GPU.

    On top of PyG's ``Collater`` output, ``__call__`` attaches:

    * ``batch_n``: number of complexes in the batch;
    * ``max_dim_divisible_by_8_protein`` / ``max_dim_divisible_by_8_compound``: the node count
      of the largest graph, rounded up to a multiple of 8 when ``make_divisible_by_8``;
    * ``data["protein", "p2p", "protein"].pairwise_representation``: long tensor
      [batch_n, max_protein, max_protein] of binned pairwise node distances;
    * ``data["compound", "p2p", "compound"].pairwise_representation``: float32 tensor
      [batch_n, max_compound, max_compound, 16] of dense compound pair features;
    * ``data["compound", "p2p", "compound"].pairwise_representation_mask``: the mask returned
      by ``to_dense_batch`` over the flattened compound pairs.

    :param make_divisible_by_8: pad node counts to a multiple of 8
    :param protein_bin_max: clip value for protein distances before binning. ``None`` (the
        default) uses the module-level ``protein_bin_max`` at collate time.
    """
    def __init__(self, dataset,
                 follow_batch=None,
                 exclude_keys=None,
                 make_divisible_by_8=True,
                 protein_bin_max=None):
        super().__init__(dataset, follow_batch, exclude_keys)
        self.make_divisible_by_8 = make_divisible_by_8
        self.protein_bin_max = protein_bin_max

    def __call__(self, batch):
        data = super().__call__(batch)
        compound_sizes = torch.diff(data["compound"].ptr)
        max_dim_divisible_by_8_protein = torch.diff(data["protein"].ptr).max()
        max_dim_divisible_by_8_compound = compound_sizes.max()
        if self.make_divisible_by_8:
            # Smallest multiple of 8 that fits the largest graph (no extra block when it
            # is already a multiple of 8).
            max_dim_divisible_by_8_protein = (max_dim_divisible_by_8_protein + 7) // 8 * 8
            max_dim_divisible_by_8_compound = (max_dim_divisible_by_8_compound + 7) // 8 * 8

        protein_coordinates_batched, _ = to_dense_batch(
            data.node_xyz, data["protein"].batch,
            max_num_nodes=max_dim_divisible_by_8_protein,
            )
        bin_max = protein_bin_max if self.protein_bin_max is None else self.protein_bin_max
        protein_pairwise_representation = get_pair_dis_index(
            protein_coordinates_batched,
            bin_size=2,
            bin_min=-1,
            bin_max=bin_max,
            ) # long tensor of shape [batch_n, max_protein_size, max_protein_size]

        compound_pair_batched, compound_pair_batched_mask = to_dense_batch(
            data.compound_pair,
            data.compound_pair_batch,
            )
        compound_pairwise_representation = torch.zeros(
            (len(batch), max_dim_divisible_by_8_compound, max_dim_divisible_by_8_compound, 16),
            dtype=torch.float32,
            )
        # Compound i stores its n_i * n_i pair features flat; unflatten them into an n_i x n_i grid.
        for i, n_atoms in enumerate(compound_sizes.tolist()):
            compound_pairwise_representation[i, :n_atoms, :n_atoms] = compound_pair_batched[i][
                :n_atoms * n_atoms
                ].reshape((n_atoms, n_atoms, -1))

        data.batch_n = len(batch)
        data.max_dim_divisible_by_8_protein = max_dim_divisible_by_8_protein
        data.max_dim_divisible_by_8_compound = max_dim_divisible_by_8_compound
        data["protein", "p2p", "protein"].pairwise_representation = protein_pairwise_representation
        data["compound", "p2p", "compound"].pairwise_representation = compound_pairwise_representation
        data["compound", "p2p", "compound"].pairwise_representation_mask = compound_pair_batched_mask
        return data




def get_pair_dis_index(d, bin_size=2, bin_min=-1, bin_max=30):
    """
    Computing pairwise distances and binning.

    Each pairwise Euclidean distance between the points of ``d`` is clipped to
    ``bin_max`` and mapped to ``floor((dist - bin_min) / bin_size)``.

    :param d: float coordinates; ``torch.cdist`` treats the last two dims as
        (points, coordinates), e.g. [batch, n_points, 3]
    :returns: long tensor of shape ``d.shape[:-1] + (d.shape[-2],)``
    """
    # The matmul-based euclidean path can be inaccurate, so force the direct computation.
    clipped = torch.cdist(d, d, compute_mode='donot_use_mm_for_euclid_dist').clamp(max=bin_max)
    shifted = clipped - bin_min
    return torch.div(shifted, bin_size, rounding_mode='floor').long()

# Clip value applied to protein pairwise distances before binning; read by
# TankBindCollater.__call__ at collate time (same units as data.node_xyz).
protein_bin_max = 30


def get_model(mode, logging, device):
    """Build the model configuration selected by ``mode`` and move it to ``device``.

    :param mode: configuration id; only 0 (default ``IaBNet_with_affinity``) is defined
    :param logging: logger-like object with an ``info`` method (e.g. the ``logging`` module)
    :param device: torch device the model is moved to
    :returns: the constructed model
    :raises ValueError: if ``mode`` is not a known configuration
    """
    if mode == 0:
        logging.info("5 stack, readout2, pred dis map add self attention and GVP embed, compound model GIN")
        return IaBNet_with_affinity().to(device)
    raise ValueError(f"Unknown model mode {mode!r}; supported modes: 0")


In [None]:
class IaBNet_with_affinity(torch.nn.Module):
    """TankBind-style protein/compound interaction network.

    Embeds the protein graph (GVP embedding or a SAGE-based GNN) and the
    compound graph (GIN or a SAGE-based GNN), builds a pairwise
    protein x compound representation ``z`` and refines it with a stack of
    trigonometry modules (protein->compound triangle update, triangle
    self-attention, transition). ``forward`` returns a per-pair prediction and
    a per-complex affinity.

    Notes:
        * ``self.tranistion`` is misspelled but is a ``state_dict`` key; it is
          kept so existing checkpoints still load.
        * ``forward`` expects the collate step to attach ``batch_n``,
          ``max_dim_divisible_by_8_protein``/``_compound`` and the padded
          pairwise representations to the batch.
    """

    def __init__(self, hidden_channels=128, embedding_channels=128, c=128, fast_attention=True, mode=0, protein_embed_mode=1, compound_embed_mode=1, n_trigonometry_module_stack=5, protein_bin_max=30, readout_mode=2):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.protein_bin_max = protein_bin_max
        self.mode = mode
        self.protein_embed_mode = protein_embed_mode
        self.compound_embed_mode = compound_embed_mode
        self.n_trigonometry_module_stack = n_trigonometry_module_stack
        self.readout_mode = readout_mode
        # Used in forward() to expand the attention mask per head; must match
        # num_attention_heads of FastTriangleSelfAttention below.
        self.n_heads = 4
        # Added by Enzo: select the fused (fast) triangle self-attention.
        self.fast_attention = fast_attention
        if protein_embed_mode == 0:
            self.conv_protein = GNN(hidden_channels, embedding_channels)
            self.conv_compound = GNN(hidden_channels, embedding_channels)
        if protein_embed_mode == 1:
            self.conv_protein = GVP_embedding((6, 3), (embedding_channels, 16),
                                              (32, 1), (32, 1), seq_in=True)

        if compound_embed_mode == 0:
            self.conv_compound = GNN(hidden_channels, embedding_channels)
        elif compound_embed_mode == 1:
            self.conv_compound = GIN(input_dim = 56, hidden_dims = [128,56,embedding_channels], edge_input_dim = 19, concat_hidden = False)

        if mode == 0:
            # Protein pair features are integer indices (< 16, presumably distance
            # bins); compound pair features are 16-dim vectors.
            self.protein_pair_embedding = nn.Embedding(16, c)
            self.compound_pair_embedding = Linear(16, c)
            self.protein_to_compound_list = nn.ModuleList([TriangleProteinToCompound_v2(embedding_channels=embedding_channels, c=c) for _ in range(n_trigonometry_module_stack)])
            if fast_attention:
                self.triangle_self_attention_list = nn.ModuleList([FastTriangleSelfAttention(embedding_channels=embedding_channels, num_attention_heads=4) for _ in range(n_trigonometry_module_stack)])
            else:
                self.triangle_self_attention_list = nn.ModuleList([TriangleSelfAttentionRowWise(embedding_channels=embedding_channels) for _ in range(n_trigonometry_module_stack)])
            # Misspelled on purpose: renaming would change the state_dict key.
            self.tranistion = Transition(embedding_channels=embedding_channels, n=4)

        self.linear = Linear(embedding_channels, 1)
        self.linear_energy = Linear(embedding_channels, 1)
        if readout_mode == 2:
            self.gate_linear = Linear(embedding_channels, 1)
        self.bias = torch.nn.Parameter(torch.ones(1))
        self.leaky = torch.nn.LeakyReLU()
        self.dropout = nn.Dropout2d(p=0.25)

    def forward(self, data):
        """Run the model on a collated heterogeneous batch.

        :param data: collated batch carrying protein/compound graph features,
            ``batch_n``, ``max_dim_divisible_by_8_protein``/``_compound`` and
            the padded ``pairwise_representation`` tensors.
        :returns: ``(y_pred, affinity_pred)``. ``y_pred`` has one value in
            ``(0, 10)`` (``sigmoid * 10``) per valid (protein node, compound
            node) pair, flattened complex by complex. ``affinity_pred`` has one
            value per complex.
        """
        # Added by Enzo: padded sizes computed by the collate function.
        max_dim_divisible_by_8_protein = data.max_dim_divisible_by_8_protein
        max_dim_divisible_by_8_compound = data.max_dim_divisible_by_8_compound
        if self.protein_embed_mode == 0:
            x = data['protein'].x.float()
            edge_index = data[("protein", "p2p", "protein")].edge_index
            protein_batch = data['protein'].batch
            protein_out = self.conv_protein(x, edge_index)
        if self.protein_embed_mode == 1:
            nodes = (data['protein']['node_s'], data['protein']['node_v'])
            edges = (data[("protein", "p2p", "protein")]["edge_s"], data[("protein", "p2p", "protein")]["edge_v"])
            protein_batch = data['protein'].batch
            protein_out = self.conv_protein(nodes, data[("protein", "p2p", "protein")]["edge_index"], edges, data.seq)

        if self.compound_embed_mode == 0:
            compound_x = data['compound'].x.float()
            compound_edge_index = data[("compound", "c2c", "compound")].edge_index
            compound_batch = data['compound'].batch
            compound_out = self.conv_compound(compound_x, compound_edge_index)
        elif self.compound_embed_mode == 1:
            compound_x = data['compound'].x.float()
            # This GIN takes the edge list transposed (one row per edge).
            compound_edge_index = data[("compound", "c2c", "compound")].edge_index.T
            compound_edge_feature = data[("compound", "c2c", "compound")].edge_attr
            edge_weight = data[("compound", "c2c", "compound")].edge_weight
            compound_batch = data['compound'].batch
            compound_out = self.conv_compound(compound_edge_index,edge_weight,compound_edge_feature,compound_x.shape[0],compound_x)['node_feature']

        # Pad node embeddings to the collate-provided sizes: shape (b, n, c).
        protein_out_batched, protein_out_mask = to_dense_batch(protein_out, protein_batch, max_num_nodes=max_dim_divisible_by_8_protein)
        compound_out_batched, compound_out_mask = to_dense_batch(compound_out, compound_batch, max_num_nodes=max_dim_divisible_by_8_compound)
        batch_n = data.batch_n
        # z_mask[b, i, j]: protein node i and compound node j both exist in complex b.
        z_mask = torch.einsum("bi,bj->bij", protein_out_mask, compound_out_mask)
        # Additive bias for row-wise attention over compound nodes: one
        # (query, key) matrix per (complex, protein node), repeated per head;
        # 0 where protein node, query and key are all valid, -1e6 elsewhere.
        z_mask_attention = torch.einsum("bik, bq-> biqk", z_mask, compound_out_mask).reshape(batch_n*protein_out_batched.shape[1], max_dim_divisible_by_8_compound, max_dim_divisible_by_8_compound).unsqueeze(1).expand(-1, self.n_heads, -1, -1).contiguous()
        z_mask_attention = torch.where(z_mask_attention, 0.0, -10.0**6)
        # Flat positions of valid pairs. A plain nn.Module has no `.device`
        # attribute, so the device is taken from the mask itself.
        z_mask_flat = torch.arange(
            start=0, end=z_mask.numel(), device=z_mask.device
        ).view(z_mask.shape)[z_mask]

        protein_pair = self.protein_pair_embedding(data["protein", "p2p", "protein"].pairwise_representation)
        compound_pair = self.compound_pair_embedding(data["compound", "p2p", "compound"].pairwise_representation.float())

        protein_out_batched = self.layernorm(protein_out_batched)
        compound_out_batched = self.layernorm(compound_out_batched)
        # z of shape (b, protein_length, compound_length, channels).
        z = torch.einsum("bik,bjk->bijk", protein_out_batched, compound_out_batched)

        if self.mode == 0:
            for i_module in range(self.n_trigonometry_module_stack):
                z = z + self.dropout(self.protein_to_compound_list[i_module](z, protein_pair, compound_pair, z_mask.unsqueeze(-1)))
                if self.fast_attention:
                    z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask_attention, z_mask))
                else:
                    z = z + self.dropout(self.triangle_self_attention_list[i_module](z, z_mask))
                z = self.tranistion(z)

        b = self.linear(z).squeeze(-1)
        y_pred = b.flatten()[z_mask_flat]
        y_pred = y_pred.sigmoid() * 10   # normalize to 0 to 10.
        if self.readout_mode == 0:
            pair_energy = self.linear_energy(z).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        if self.readout_mode == 1:
            # Masked mean of z over valid pairs, then a linear head.
            valid_interaction_z = (z * z_mask.unsqueeze(-1)).sum(axis=(1, 2)) / z_mask.sum(axis=(1, 2)).unsqueeze(-1)
            affinity_pred = self.linear_energy(valid_interaction_z).squeeze(-1)
        if self.readout_mode == 2:
            # Gated per-pair energies summed over valid pairs.
            pair_energy = (self.gate_linear(z).sigmoid() * self.linear_energy(z)).squeeze(-1) * z_mask
            affinity_pred = self.leaky(self.bias + ((pair_energy).sum(axis=(-1, -2))))
        return y_pred, affinity_pred

In [None]:
from torch.utils.data import RandomSampler

In [None]:
sampler = RandomSampler(dataset, replacement=True, num_samples=1000)

In [None]:
dataloader = TankBindDataLoader(dataset, batch_size=4, follow_batch=['x', 'compound_pair'], sampler = sampler)

In [None]:
batch = next(iter(dataloader))

In [None]:
model = IaBNet_with_affinity()

In [None]:
model.to("cuda:0")
batch_cuda = batch.to("cuda:0")

In [None]:
model(batch_cuda)

Étapes importantes :
- Refaire le script Lightning + wandb.
- Vérifier que l'évaluation du modèle fonctionne.

In [None]:
def get_data(addNoise=None,
             dataset_path="/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset",
             all_pocket_test_path="/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/test_dataset",
             all_pocket_compound_dict_path="/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/compound.pt"):
    """Load the TankBind datasets and split them into train/valid/test.

    Rows are first restricted to ``c_length < 100 and native_num_contact > 5``.

    :param addNoise: converted to float and forwarded as
        ``TankBindDataSet(add_noise_to_com=...)``; falsy values (``None``, ``0``)
        are passed as ``None``.
    :param dataset_path: root of the main processed dataset.
    :param all_pocket_test_path: root of the all-pocket test dataset.
    :param all_pocket_compound_dict_path: ``torch.save``-d compound dictionary
        attached to the all-pocket test set.
    :returns: ``(train, train_after_warm_up, valid, test, all_pocket_test, info)``.
        ``train``/``valid``/``test`` keep only ``use_compound_com`` rows of their
        group, ``train_after_warm_up`` keeps every training row, and ``info`` is
        ``None``.
    """
    add_noise_to_com = float(addNoise) if addNoise else None

    new_dataset = TankBindDataSet(dataset_path, add_noise_to_com=add_noise_to_com)
    new_dataset.data = new_dataset.data.query("c_length < 100 and native_num_contact > 5").reset_index(drop=True)
    d = new_dataset.data
    only_native_train_index = d.query("use_compound_com and group =='train'").index.values
    train = new_dataset[only_native_train_index]
    train_index = d.query("group =='train'").index.values
    train_after_warm_up = new_dataset[train_index]
    valid_index = d.query("use_compound_com and group =='valid'").index.values
    valid = new_dataset[valid_index]
    test_index = d.query("use_compound_com and group =='test'").index.values
    test = new_dataset[test_index]

    all_pocket_test = TankBindDataSet(all_pocket_test_path)
    # compound_dict is used as a mapping (.items()/.keys()) elsewhere, so it must
    # hold the loaded dictionary, not the path to the file that stores it.
    all_pocket_test.compound_dict = torch.load(all_pocket_compound_dict_path)
    info = None
    return train, train_after_warm_up, valid, test, all_pocket_test, info

In [None]:
batch

In [None]:
%cd /fs/pool/pool-marsot/tankbind_philip/TankBind/tankbind

In [None]:
import torch
import sys
sys.path.append("/fs/pool/pool-marsot/")
import wandb
from datetime import datetime
from utils import *
from tqdm import tqdm
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M")
# One wandb run per execution of this cell, named after its start time.
run = wandb.init(project="TankBind", name=f"{timestamp}")


device = torch.device("cuda:0")
# Fresh model on CPU; the training loop moves it to `device` (in place, so the
# optimizer below keeps referencing the same parameters).
model = IaBNet_with_affinity()
# addNoise=5 is forwarded to TankBindDataSet(add_noise_to_com=5.0).
train, train_after_warm_up, valid, test, all_pocket_test, info = get_data(addNoise=5)
# Each pass over a loader draws 20000 samples with replacement.
sampler = RandomSampler(train, replacement=True, num_samples=20000)
train_loader = TankBindDataLoader(train, batch_size=5, follow_batch=['x', 'compound_pair'], sampler = sampler)
sampler_2 = RandomSampler(train_after_warm_up, replacement=True, num_samples=20000)
train_after_warmup_loader = TankBindDataLoader(train_after_warm_up, batch_size=5, follow_batch=['x', 'compound_pair'], sampler = sampler_2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.MSELoss()
# NOTE(review): the training loop uses `my_affinity_criterion` (star-imported
# from utils), not this MSE criterion.
affinity_criterion = nn.MSELoss()


In [None]:

# Distance-prediction training: the contact head is regressed onto `dis_map`,
# so `y_pred` coming out of the model is already a distance in (0, 10).
relative_k = 0.01  # weight of the affinity loss relative to the contact loss

for epoch in range(200):
    model.train()
    model.to(device)
    y_list = []
    y_pred_list = []
    affinity_list = []
    affinity_pred_list = []
    batch_loss = 0.0
    affinity_batch_loss = 0.0
    data_it = tqdm(train_after_warmup_loader)

    for data in data_it:
        data = data.to(device)
        optimizer.zero_grad()
        y_pred, affinity_pred = model(data)
        y = data.y
        affinity = data.affinity
        dis_map = data.dis_map
        # Only pairs from (equivalent) native pockets contribute to the contact loss.
        y_pred = y_pred[data.equivalent_native_y_mask]
        y = y[data.equivalent_native_y_mask]
        dis_map = dis_map[data.equivalent_native_y_mask]

        if len(dis_map) > 0:
            contact_loss = criterion(y_pred, dis_map)
        else:
            contact_loss = torch.zeros((), device=dis_map.device)
        # No extra `.sigmoid()` here: the model output is already sigmoid(.)*10,
        # and a second sigmoid corrupted the predictions collected for metrics.

        native_pocket_mask = data.is_equivalent_native_pocket
        affinity_loss = relative_k * my_affinity_criterion(affinity_pred,
                                                           affinity,
                                                           native_pocket_mask, decoy_gap=1.0)

        loss = contact_loss + affinity_loss
        wandb.log({"contact_loss": contact_loss.item(), "affinity_loss": affinity_loss.item(), "loss": loss.item()})
        loss.backward()
        optimizer.step()
        batch_loss += len(y_pred) * contact_loss.item()
        affinity_batch_loss += len(affinity_pred) * affinity_loss.item()
        y_list.append(y)
        y_pred_list.append(y_pred.detach())
        affinity_list.append(affinity)
        affinity_pred_list.append(affinity_pred.detach())

    y = torch.cat(y_list)
    y_pred = torch.cat(y_pred_list)
    # This loop always runs in distance-prediction mode (it used to branch on an
    # undefined `args.pred_dis`): map predicted distances to contact-like scores.
    y_pred = torch.clip(1 - (y_pred / 10.0), min=1e-6, max=0.99999)
    contact_threshold = 0.2  # threshold on the converted scores, for contact metrics

    affinity = torch.cat(affinity_list)
    affinity_pred = torch.cat(affinity_pred_list)
    metrics = {"loss": batch_loss / max(len(y_pred), 1) + affinity_batch_loss / max(len(affinity_pred), 1)}
    wandb.log({"epoch": epoch, "train_epoch_loss": metrics["loss"]})

In [None]:
%cd /fs/pool/pool-marsot/tankbind_philip_base/tankbind_philip/tankbind

In [None]:
import sys; sys.path.append("/fs/pool/pool-marsot")

In [None]:
val_dataset[0]

In [None]:
mask = ~(val_dataset.data)["pdb"].str.endswith('_c')

In [None]:
mask

In [None]:
val_dataset.data["pdb"]

In [None]:
val_dataset.data.reset_index(drop=True)

In [None]:
import numpy as np
import torch
from tqdm import tqdm
import os
import rdkit.Chem as Chem
from bindbind.datasets.processing.ligand_features.tankbind_ligand_features import read_molecule, create_tankbind_ligand_features, get_LAS_distance_constraint_mask
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
import torch_geometric
from bindbind.datasets.processing.ligand_features.tankbind_ligand_features import read_molecule, create_tankbind_ligand_features, get_LAS_distance_constraint_mask
from bindbind.experiments.ablations.regular.metrics.helper import compute_RMSD, write_with_new_coords, generate_sdf_from_smiles_using_rdkit, get_info_pred_distance, simple_custom_description, distribute_function
import pandas as pd
from tankbind_philip_base.tankbind_philip.tankbind.data import TankBindDataSet
from tankbind_philip.TankBind.tankbind.data import TankBindDataLoader
def evaluate_model_val(model,
                       batch_size=8,
                       num_workers=8,
                       val_dataset_path="/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/val_dataset",
                       full_dataset_path="/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/dataset",
                       rdkit_folder="/fs/pool/pool-marsot/tankbind_philip/TankBind/data/rdkit_folder",
                       renumbered_ligands_folder="/fs/pool/pool-marsot/tankbind_philip/TankBind/data/renumber_atom_index_same_as_smiles",
                       recompute=False,
                       ):
    """Placeholder for validation-set evaluation.

    NOTE(review): this stub ignores every argument and only prints "hi". The
    evaluation steps are written out as top-level code in this same cell, and
    later cells import a complete ``evaluate_model_val`` from another module,
    which replaces this definition in the notebook namespace.
    """
    print("hi")

val_dataset_path="/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/val_dataset"
full_dataset_path="/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/dataset"
recompute=False
rdkit_folder="/fs/pool/pool-marsot/tankbind_philip/TankBind/data/rdkit_folder"
renumbered_ligands_folder="/fs/pool/pool-marsot/tankbind_philip/TankBind/data/renumber_atom_index_same_as_smiles"
num_workers=8
batch_size=8

# --- Validation dataset: build it once from the full dataset, then reuse the processed copy.
if not os.path.exists(f"{val_dataset_path}/processed"):
    dataset = TankBindDataSet(full_dataset_path)
    # Build both masks from `dataset.data`: `val_dataset` does not exist yet on this
    # branch (NameError on a fresh kernel, or a mask from a stale frame otherwise).
    is_valid = dataset.data["group"] == "valid"
    not_c_entry = ~dataset.data["pdb"].str.endswith('_c')
    val_data = dataset.data[is_valid & not_c_entry]
    val_names = set(val_data["protein_name"].unique().tolist())
    val_compound_dict = {name:item for (name, item) in dataset.compound_dict.items() if name in val_names}
    val_protein_dict = {name:item for (name, item) in dataset.protein_dict.items() if name in val_names}
    val_dataset = TankBindDataSet(val_dataset_path, data=val_data,
                                    compound_dict=val_compound_dict,
                                    protein_dict=val_protein_dict)
else:
    val_dataset = TankBindDataSet(val_dataset_path)
# torch.nn.Module has no `.device` attribute; read the device from the parameters.
device = next(model.parameters()).device
model.eval()
val = val_dataset.data["protein_name"].unique().tolist()

# --- RDKit-based ligand features for every validation protein (cached on disk).
if recompute or not os.path.exists(f"{rdkit_folder}/compound_dict_based_on_rdkit.pt"):
    compound_dict = {}
    print("generating compound dictionary")
    for protein_name in tqdm(val):
        mol, _ = read_molecule(f"{renumbered_ligands_folder}/{protein_name}.sdf", None)
        smiles = Chem.MolToSmiles(mol)
        rdkit_mol_path = f"{rdkit_folder}/{protein_name}.sdf"
        generate_sdf_from_smiles_using_rdkit(smiles, rdkit_mol_path, shift_dis=0.0)
        mol, _ = read_molecule(rdkit_mol_path, None)
        compound_dict[protein_name] = create_tankbind_ligand_features(rdkit_mol_path, None, has_LAS_mask=True)
    torch.save(compound_dict, f"{rdkit_folder}/compound_dict_based_on_rdkit.pt")
else:
    compound_dict = torch.load(f"{rdkit_folder}/compound_dict_based_on_rdkit.pt")

# --- Inference: per-complex pair predictions and affinities.
data_loader = TankBindDataLoader(val_dataset,follow_batch=["protein_nodes_xyz", "coords", "y_pred", "y", "LAS_distance_constraint_mask"], batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
affinity_pred_list = []
y_pred_list = []
for data in tqdm(data_loader):
    data = data.to(device)
    # y_pred concatenates one (n_protein_nodes * n_compound_nodes) block per complex,
    # in batch order; split it back into per-complex tensors.
    pair_counts = (torch.diff(data["protein"].ptr) * torch.diff(data["compound"].ptr)).tolist()
    with torch.no_grad():
        y_pred, affinity_pred = model(data)
    affinity_pred_list.append(affinity_pred.detach().cpu())
    y_pred_list.extend(block.detach().cpu() for block in torch.split(y_pred, pair_counts))
affinity_pred_list = torch.cat(affinity_pred_list)

# --- For each protein keep the pocket with the highest predicted affinity.
# Work on a copy so val_dataset.data is not mutated with prediction columns.
output_info_chosen = val_dataset.data.copy()
output_info_chosen["affinity"] = affinity_pred_list.numpy()
output_info_chosen['dataset_index'] = range(len(output_info_chosen))

chosen = output_info_chosen.loc[output_info_chosen.groupby(['protein_name'], sort=False)['affinity'].agg('idxmax')].reset_index()


In [None]:
device = "cpu"
# Per-protein lookups: ligand coordinates from the RDKit-based compound_dict and
# element 0 of each val_dataset.protein_dict entry (presumably node coordinates).
compound_coordinates_dict = {}
protein_coordinates_dict = {}
for name in val:
    compound_coordinates_dict[name] = compound_dict[name]["tankbind_ligand_atom_coordinates"]
    protein_coordinates_dict[name] = val_dataset.protein_dict[name][0]
max_compound_nodes = 0
max_protein_nodes = 0
list_mols = []
list_complexes = []
for idx, line in tqdm(chosen.iterrows(), total=chosen.shape[0]):
    protein_name = line['protein_name']
    dataset_index = line['dataset_index']

    # Materialise the sample once: coords, node_xyz and dis_map must come from the
    # same item, and __getitem__ may be costly (and noisy if add_noise_to_com is set).
    sample = val_dataset[dataset_index]
    coords = sample.coords
    protein_node_coordinates = sample.node_xyz
    n_compound = coords.shape[0]
    n_protein = protein_node_coordinates.shape[0]
    y_pred = y_pred_list[dataset_index]
    y = sample.dis_map
    rdkit_mol_path = f"{rdkit_folder}/{protein_name}.sdf"
    mol, _ = read_molecule(rdkit_mol_path, None)
    LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool().flatten()
    max_compound_nodes = max(max_compound_nodes, n_compound)
    max_protein_nodes = max(max_protein_nodes, n_protein)
    cplx = HeteroData()
    cplx.protein_name = protein_name
    cplx.protein_nodes_xyz = protein_node_coordinates
    cplx.coords = coords
    cplx.y_pred = y_pred
    cplx.y = y
    cplx.LAS_distance_constraint_mask = LAS_distance_constraint_mask
    list_complexes.append(cplx)
    list_mols.append(mol)

# All chosen complexes in a single batch, in `chosen` order.
dataloader = DataLoader(list_complexes, batch_size=chosen.shape[0], shuffle=False,
                        follow_batch=["protein_nodes_xyz", "coords", "y_pred", "y", "LAS_distance_constraint_mask"])

batch = next(iter(dataloader))
# Intra-ligand reference distances, flattened complex by complex.
# NOTE(review): these come from `batch.coords` (val_dataset coordinates), not from
# the RDKit conformer in compound_dict; confirm which one the constraint should use.
coords_batched, coords_mask = torch_geometric.utils.to_dense_batch(batch.coords, batch.coords_batch)
coords_pair_mask = torch.einsum("ij,ik->ijk", coords_mask, coords_mask)
compound_pair_dis_constraint = torch.cdist(coords_batched, coords_batched)[coords_pair_mask]
batch.compound_pair_dis_constraint = compound_pair_dis_constraint

In [None]:
from tqdm.notebook import tqdm
def distance_loss_function(epoch, x, batch):
    """Loss used to fit compound coordinates ``x`` to predicted distances.

    Both terms are averaged within each complex and then summed over the batch:

    * interaction: mean of |min(d(protein node, compound atom), 10) - y_pred|
      over that complex's protein x compound pairs;
    * configuration: mean over that complex's compound atom pairs of
      |d(atom, atom) - compound_pair_dis_constraint| + 2 * relu(1.22 - d(atom, atom)).

    Total loss = interaction + 0.05 * configuration.

    NOTE(review): ``batch.LAS_distance_constraint_mask`` only contributes its
    ``_ptr`` (per-complex segment boundaries); the mask values themselves are
    not applied to the distance-constraint term.

    :param epoch: optimisation step; unused, kept for the
                  ``loss_function(epoch, x, batch)`` calling convention.
    :param x: flat compound coordinates, aligned with ``batch.coords``.
    :param batch: collated complexes with protein_nodes_xyz, y_pred,
                  compound_pair_dis_constraint and the matching ``_batch``/``_ptr`` fields.
    :returns: ``(loss, (interaction_sum, configuration_sum))``; the pair is detached.
    """
    to_dense = torch_geometric.utils.to_dense_batch
    segment = torch_geometric.utils.segment

    protein_xyz, protein_mask = to_dense(batch.protein_nodes_xyz, batch.protein_nodes_xyz_batch)
    ligand_xyz, ligand_mask = to_dense(x, batch.coords_batch)

    # Interaction term: capped protein-ligand distances vs. predicted distances.
    pl_mask = protein_mask[:, :, None] & ligand_mask[:, None, :]
    pl_dist = torch.cdist(protein_xyz, ligand_xyz).clamp(max=10)[pl_mask]
    per_complex_interaction = segment((pl_dist - batch.y_pred).abs(), batch.y_pred_ptr, reduce="mean")

    # Configuration term: keep intra-ligand distances near the reference and
    # penalise atom pairs closer than 1.22.
    ll_mask = ligand_mask[:, :, None] & ligand_mask[:, None, :]
    ll_dist = torch.cdist(ligand_xyz, ligand_xyz)[ll_mask]
    pair_penalty = (ll_dist - batch.compound_pair_dis_constraint).abs() + 2 * (1.22 - ll_dist).relu()
    per_complex_configuration = segment(pair_penalty, batch.LAS_distance_constraint_mask_ptr, reduce="mean")

    interaction_total = per_complex_interaction.sum()
    configuration_total = per_complex_configuration.sum()
    # Weighting from Enzo's modification (reported to reach 20 percent).
    loss = interaction_total + 5e-2 * configuration_total
    return loss, (interaction_total.detach(), configuration_total.detach())



def distance_optimize_compound_coords(batch, total_epoch=5000, loss_function=distance_loss_function, LAS_distance_constraint_mask=None, mode=0, show_progress=False, device="cuda:0"):
    """Optimise compound coordinates so they agree with the predicted distances.

    Each complex's compound atoms are initialised uniformly in a cube of side 10
    centred on the mean of that complex's protein node coordinates, then updated
    with Adam (lr=1) on ``loss_function`` for ``total_epoch`` steps.

    :param batch: collated complexes (protein_nodes_xyz, coords and their
                  ``_batch``/``_ptr`` fields); moved to ``device``.
    :param total_epoch: number of optimisation steps.
    :param loss_function: ``(epoch, x, batch) -> (loss, (interaction, configuration))``.
    :param LAS_distance_constraint_mask: unused; kept for interface compatibility.
    :param mode: unused; kept for interface compatibility.
    :param show_progress: display a tqdm bar with the current loss.
    :param device: device the optimisation runs on.
    :returns: ``(x, loss_list, rmsd_list)``. ``x`` holds the optimised
        coordinates (still requiring grad). ``loss_list`` holds the total loss
        per step. ``rmsd_list`` holds the per-complex RMSD to ``batch.coords``
        after each step, flattened step-major (one entry per complex per step).
    """
    batch = batch.to(device)

    # Per-complex protein centre, broadcast to every compound atom of that complex.
    c_pred = torch_geometric.utils.segment(batch.protein_nodes_xyz, batch.protein_nodes_xyz_ptr, reduce="mean")
    c_pred = c_pred[batch.coords_batch]
    # Uniform offsets in [-5, 5) per axis. `2 * u - 1` needs u ~ U[0, 1); the
    # previous `randn_like` gave N(-1, 4) per axis, i.e. a -5 shift and std 10.
    x = (5 * (2 * torch.rand_like(batch.coords) - 1) + c_pred.reshape(-1, 3)).detach().clone().requires_grad_(True)
    # Enzo's modification (reported to reach 20 percent): Adam with lr=1.
    optimizer = torch.optim.Adam([x], lr=1)
    loss_list = []
    rmsd_list = []
    epochs = tqdm(range(total_epoch)) if show_progress else range(total_epoch)
    for epoch in epochs:
        optimizer.zero_grad()
        loss, _ = loss_function(epoch, x, batch)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_list.append(loss_value)

        # One RMSD value per complex for the current coordinates.
        rmsd = compute_RMSD_batch(batch.coords, x.detach(), batch.coords_ptr)
        rmsd_list += rmsd.detach().cpu().tolist()

        if show_progress:
            epochs.set_postfix(loss=loss_value)

    return x, loss_list, rmsd_list

def get_info_pred_distance(batch, n_repeat=1, total_epoch=5000, mode=0, show_progress=False):
    """Run the coordinate optimisation ``n_repeat`` times and tabulate the results.

    :param batch: collated complexes passed to ``distance_optimize_compound_coords``.
    :param n_repeat: number of independent random restarts.
    :param total_epoch: optimisation steps per restart.
    :param mode: forwarded to ``distance_optimize_compound_coords``.
    :param show_progress: show a tqdm bar over the restarts.
    :returns: DataFrame with one row per (repeat, complex) and these columns:
        ``repeat``; ``rmsd``, the final-step RMSD of that complex (NaN if no
        step ran); ``loss``, the final total loss of the whole batch, shared by
        its complexes (0 if no step ran); and ``coords``, a numpy array of that
        complex's optimised coordinates.
    """
    info = []
    it = tqdm(range(n_repeat)) if show_progress else range(n_repeat)
    ptr = batch.coords_ptr.tolist()
    n_complexes = len(ptr) - 1
    for repeat in it:
        x, loss_list, rmsd_list = distance_optimize_compound_coords(batch, mode=mode, total_epoch=total_epoch, show_progress=False)
        # rmsd_list is step-major with one entry per complex per step, so the last
        # `n_complexes` entries are the final RMSDs. (It used to be sliced with atom
        # pointers, which mixed complexes and optimisation steps.)
        if n_complexes and len(rmsd_list) >= n_complexes:
            final_rmsds = rmsd_list[-n_complexes:]
        else:
            final_rmsds = [float("nan")] * n_complexes
        # Explicit empty-history fallback instead of a bare `except:`.
        final_loss = float(loss_list[-1]) if loss_list else 0
        coords = x.detach().cpu().numpy()
        for i in range(n_complexes):
            info.append([repeat, final_rmsds[i], final_loss, coords[ptr[i]:ptr[i + 1]]])
    return pd.DataFrame(info, columns=['repeat', 'rmsd', 'loss', 'coords'])

In [None]:
def compute_RMSD_batch(a, b, ptr):
    """Per-segment RMSD between two row-aligned coordinate tensors.

    :param a: coordinates, one row per atom (last dim holds x/y/z).
    :param b: coordinates aligned row-by-row with ``a``.
    :param ptr: segment boundaries (``len = n_segments + 1``).
    :returns: one value per segment: sqrt(mean over its atoms of the squared
              distance between ``a`` and ``b``).
    """
    squared_dist = (a - b).pow(2).sum(dim=-1)
    mean_squared_dist = torch_geometric.utils.segment(squared_dist, ptr, reduce="mean")
    return mean_squared_dist.sqrt()

In [None]:

# Fit coordinates for all chosen complexes in one batched optimisation run.
pred_dist_info = get_info_pred_distance(batch,
                            n_repeat=1, show_progress=False)

# Write each complex's optimised coordinates into a copy of its RDKit molecule.
# With n_repeat=1, row `idx` of pred_dist_info matches row `idx` of `chosen`
# (list_complexes was built in `chosen` order and loaded unshuffled). `.iloc[idx]`
# relies on `chosen` having a 0..n-1 index (it was reset_index'ed).
for idx, line in tqdm(chosen.iterrows(), total=chosen.shape[0]):
    protein_name = line['protein_name']
    toFile = f'{rdkit_folder}/{protein_name}_tankbind_chosen.sdf'
    new_coords = pred_dist_info['coords'].iloc[idx].astype(np.double)
    write_with_new_coords(list_mols[idx], new_coords, toFile)






In [None]:
(chosen["num_contact"]==chosen["native_num_contact"]).values.mean()

In [None]:
total_epoch = 5000
loss_function=distance_loss_function

In [None]:

def _canonical_heavy_atom_positions(molecule):
    """Heavy-atom coordinates of ``molecule`` in canonical SMILES atom order.

    ``Chem.MolToSmiles`` is called for its side effect of setting the private
    ``_smilesAtomOutputOrder`` property, which drives the renumbering, so the
    true and predicted poses are compared in the same atom order.
    """
    Chem.MolToSmiles(molecule)
    order = list(molecule.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
    heavy = Chem.RemoveHs(Chem.RenumberAtoms(molecule, order))
    return np.array(heavy.GetConformer().GetPositions())


ligand_metrics = []
for idx, line in tqdm(chosen.iterrows(), total=chosen.shape[0]):
    protein_name = line['protein_name']
    mol, _ = read_molecule(f"{renumbered_ligands_folder}/{protein_name}.sdf", None)
    # tankbind_chosen.sdf holds the predicted coordinates written by write_with_new_coords
    mol_pred, _ = read_molecule(f"{rdkit_folder}/{protein_name}_tankbind_chosen.sdf", None)

    true_ligand_pos = _canonical_heavy_atom_positions(mol)
    mol_pred_pos = _canonical_heavy_atom_positions(mol_pred)

    # RMSD without superposition, plus the distance between the two centroids.
    rmsd = np.sqrt(((true_ligand_pos - mol_pred_pos) ** 2).sum(axis=1).mean(axis=0))
    com_dist = compute_RMSD(mol_pred_pos.mean(axis=0), true_ligand_pos.mean(axis=0))
    ligand_metrics.append([protein_name, rmsd, com_dist,])

d = pd.DataFrame(ligand_metrics, columns=['pdb', 'TankBind_RMSD', 'TankBind_COM_DIST',])


In [None]:
simple_custom_description(d)

In [None]:
model.to("cuda:0")

In [None]:
val_dataset.protein_dict['1zsb']

In [None]:
val_dataset.compound_dict.keys()

In [None]:
next(iter(data_loader))

In [None]:
import torch
torch.save(compound_dict, f"{rdkit_folder}/compound_dict_based_on_rdkit.pt")

In [None]:
full_dataset_path

In [None]:
full_dataset_path="/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/dataset"
dataset = TankBindDataSet(full_dataset_path)
val_data = dataset.data[(dataset.data["group"]=="valid") & (~(dataset.data["use_compound_com"]))]

In [None]:
evaluate_model_val(model)

In [None]:
dataset.protein_dict.keys()

In [None]:
dataset.protein_dict['3zzf'][0]

In [None]:
for item in dataset.protein_dict['3zzf']:
    print(item.shape)

In [None]:
dataset.data

In [None]:
import pickle as pkl
# Load a previously pickled compound-coordinates dictionary.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted location.
with open("/fs/pool/pool-marsot/bindbind/datasets/data/equibind_dataset/compound_coordinates_dict.pkl", "rb") as f:
    compound_coordinates_dict = pkl.load(f)

In [None]:
dataset.compound_dict.keys()

In [None]:
val_names = dataset.data[(dataset.data["group"]=="valid")]["protein_name"].unique()

In [None]:
# Map each validation protein to element 0 of its compound_dict entry
# (presumably the ligand coordinates, given the variable name).
compound_coordinates_dict = {name: dataset.compound_dict[name][0] for name in val_names}

In [None]:
compound_coordinates_dict.keys()

In [None]:
from bindbind.experiments.ablations.regular.metrics.metrics_fast import evaluate_model_val    

In [1]:
%cd /fs/pool/pool-marsot/tankbind_philip_base/tankbind_philip/tankbind/

/fs/gpfs41/lv11/fileset01/pool/pool-marsot/tankbind_philip_base/tankbind_philip/tankbind


In [2]:
import sys; sys.path.append("/fs/pool/pool-marsot/")

In [3]:
from tankbind_philip_base.tankbind_philip.tankbind.model import IaBNet_with_affinity as old_model

In [4]:
import torch
model = old_model().to("cuda:0")
model.load_state_dict(torch.load(f"/fs/pool/pool-marsot/tankbind_philip/TankBind/tankbind/result/2024_07_03_21_19/model_119.pt"))

<All keys matched successfully>

In [5]:
model.to("cuda:0")

IaBNet_with_affinity(
  (layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (conv_protein): GVP_embedding(
    (W_s): Embedding(20, 20)
    (W_v): Sequential(
      (0): LayerNorm(
        (scalar_norm): LayerNorm((26,), eps=1e-05, elementwise_affine=True)
      )
      (1): GVP(
        (wh): Linear(in_features=3, out_features=16, bias=False)
        (ws): Linear(in_features=42, out_features=128, bias=True)
        (wv): Linear(in_features=16, out_features=16, bias=False)
      )
    )
    (W_e): Sequential(
      (0): LayerNorm(
        (scalar_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
      (1): GVP(
        (wh): Linear(in_features=1, out_features=1, bias=False)
        (ws): Linear(in_features=33, out_features=32, bias=True)
        (wv): Linear(in_features=1, out_features=1, bias=False)
      )
    )
    (layers): ModuleList(
      (0-2): 3 x GVPConvLayer(
        (conv): GVPConv()
        (norm): ModuleList(
          (0-1): 2 x LayerNo

In [6]:
from tankbind_philip.TankBind.tankbind.evaluation_fast import evaluate_model_val
df=evaluate_model_val(model, batch_size=4,)

['/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/val_dataset/processed/data.pt', '/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/val_dataset/processed/protein.pt', '/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/val_dataset/processed/compound.pt']


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 790/790 [01:42<00:00,  7.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 959/959 [00:16<00:00, 58.70it/s]
144.33 466.43 167.65 16.88: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [03:13<00:00, 25.81it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 959/959 [00:01<00:00, 811.83it/s]
 22%|███████████████████████████████████▏                                                                                                                           | 212/959 [00:00<00:02, 288.13it/s][

In [8]:
from tankbind_philip.TankBind.tankbind.helper import simple_custom_description

In [9]:
simple_custom_description(df)

Unnamed: 0,index,mean,25%,50%,75%,5A,2A,median
0,TankBind_RMSD,16.88944,2.402756,13.864611,30.82978,36.704901,20.333681,13.864611
1,TankBind_COM_DIST,15.480724,1.190346,11.036407,30.165105,42.857143,34.410845,11.036407


In [13]:
from tankbind_philip.TankBind.tankbind.model import IaBNet_with_affinity as old_model_2

In [14]:
model_2 = old_model_2().to("cuda:0")

In [1]:
1+1

2

In [5]:
import torch
info = torch.load("/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/dataset/processed/data.pt")

In [14]:
info["protein_name"].unique()

array(['3zzf', '3gww', '1w8l', ..., '1avd', '2xui', '2avi'], dtype=object)

In [15]:
proteins = torch.load("/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/dataset/processed/protein.pt")

In [16]:
len(proteins)

19420