# Interaction Transformer

In [1]:
!pip install transformers
!pip install pytdc

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
Collecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m100.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 kB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.12.1 tokenizers-0.13.2 transformers-4.26.1
Looking in indexes: https://pypi.

## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel
from tdc.multi_pred import DTI

## Data

In [3]:
data = DTI(name = 'BindingDB_Kd')
data.harmonize_affinities(mode = 'max_affinity')
split = data.get_split()

Found local copy...
Loading...
Done!
The scale is in original affinity scale, so we will take the minimum!
The original data has been updated!


In [4]:
targets = split["train"].loc[0:5,"Target"].tolist()
drugs = split["train"].loc[0:5,"Drug"].tolist()

In [5]:
len(targets), len(drugs)

(6, 6)

## Repeatable Interaction Block

In [6]:
class InteractionTransformerBlock(nn.Module):
    def __init__(self, num_heads, attn_dim):
        """
        In the constructor we instantiate four parameters and assign them as
        member parameters.
        """
        super().__init__()
        self.num_heads = num_heads
        self.alpha_multihead_attn = nn.MultiheadAttention(attn_dim, num_heads, batch_first=True)
        self.beta_multihead_attn = nn.MultiheadAttention(attn_dim, num_heads, batch_first=True)
        self.alpha_layer_norm_attn = nn.LayerNorm(attn_dim)
        self.beta_layer_norm_attn = nn.LayerNorm(attn_dim)
        self.alpha_feedforward = nn.Linear(attn_dim, attn_dim)
        self.beta_feedforward = nn.Linear(attn_dim, attn_dim)
        self.alpha_layer_norm_feedforward = nn.LayerNorm(attn_dim)
        self.beta_layer_norm_feedforward = nn.LayerNorm(attn_dim)
        
    def forward(self, x_alpha, x_beta, mask):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        mask_rep = torch.cat(self.num_heads*[mask])

        alpha_attn_output, alpha_attn_output_weights = self.alpha_multihead_attn(x_alpha, x_beta, x_beta, attn_mask=mask_rep)
        beta_attn_output, beta_attn_output_weights = self.beta_multihead_attn(x_beta, x_alpha, x_alpha, attn_mask=mask_rep.transpose(1,2))
    
        interactions = F.sigmoid(torch.bmm(alpha_attn_output, beta_attn_output.transpose(1,2))) * mask
        # interactions = F.sigmoid(alpha_attn_output_weights * beta_attn_output_weights.transpose(1,2)) * mask
        
        x_alpha = self.alpha_layer_norm_attn(x_alpha + torch.bmm(interactions, beta_attn_output))
        x_beta = self.beta_layer_norm_attn(x_beta + torch.bmm(interactions.transpose(1,2), alpha_attn_output))
        
        x_alpha = self.alpha_layer_norm_feedforward(x_alpha + self.alpha_feedforward(x_alpha))
        x_beta = self.beta_layer_norm_feedforward(x_beta + self.beta_feedforward(x_beta))
        
        return x_alpha, x_beta, interactions

## Full Model Prototype

In [7]:
class InteractionTransformerDTIModel(nn.Module):
    def __init__(self, prot_model_str, mol_model_str, num_blocks, num_heads, attn_dim):
        """
        In the constructor we instantiate four parameters and assign them as
        member parameters.
        """
        super().__init__()
        
        self.prot_tokenizer = AutoTokenizer.from_pretrained(prot_model_str)
        self.prot_model = AutoModel.from_pretrained(prot_model_str)
        self.prot_proj = nn.Linear(self.prot_model.pooler.dense.out_features, attn_dim, bias=False)
        
        self.mol_tokenizer = AutoTokenizer.from_pretrained(mol_model_str)
        self.mol_model = AutoModel.from_pretrained(mol_model_str)
        self.mol_proj = nn.Linear(self.mol_model.pooler.dense.out_features, attn_dim, bias=False)
        
        self.interaction_blocks = nn.ModuleList([InteractionTransformerBlock(num_heads, attn_dim) for _ in range(num_blocks)])
        
    def forward(self, targets, drugs):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        prot_inputs = self.prot_tokenizer(
            targets, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True, return_attention_mask=True
        )
        mol_inputs = self.mol_tokenizer(
            drugs, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True, return_attention_mask=True
        )
        
        x_prot = self.prot_proj(self.prot_model(**prot_inputs).last_hidden_state)
        x_mol = self.mol_proj(self.mol_model(**mol_inputs).last_hidden_state)
        
        mask_prot = prot_inputs["attention_mask"].float()
        mask_mol = mol_inputs["attention_mask"].float()
        mask = torch.bmm(mask_prot.unsqueeze(2), mask_mol.unsqueeze(1))
        
        for interaction_block in self.interaction_blocks:
            x_prot, x_mol, interactions = interaction_block(x_prot, x_mol, mask)
        
        return x_prot, x_mol, interactions

In [9]:
model = InteractionTransformerDTIModel(
    prot_model_str="facebook/esm2_t30_150M_UR50D",
    mol_model_str="DeepChem/ChemBERTa-77M-MLM",
    num_blocks=8,
    num_heads=4,
    attn_dim=512
)

x_targets, x_drugs, interactions = model(targets, drugs)
x_targets.shape, x_drugs.shape, interactions.shape

Downloading (…)okenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/595M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/esm2_t30_150M_UR50D were not used when initializing EsmModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/6.96k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/8.26k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/420 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/13.7M [00:00<?, ?B/s]

Some weights of the model checkpoint at DeepChem/ChemBERTa-77M-MLM were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MLM and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be 

(torch.Size([6, 538, 512]), torch.Size([6, 26, 512]), torch.Size([6, 538, 26]))

In [10]:
sim_labels = torch.randint(low=0, high=2, size=interactions.shape).float()
loss = F.binary_cross_entropy(interactions, sim_labels)
loss

tensor(27.9232, grad_fn=<BinaryCrossEntropyBackward0>)