In [1]:
# !pip install PyTDC

### Imports & Hyperparameters

In [1]:
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from transformers import BertModel,BertTokenizer
import time
from datetime import timedelta
import os
import requests
from tqdm.auto import tqdm

# Mini-batch size handed to the training DataLoader.
batch_size = 24
# Learning rate handed to the Adam optimizer in the training cell.
lr = 1e-4

### Load encoders

In [2]:
# device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

# Protein-sequence encoder; do_lower_case=False disables lower-casing in the tokenizer.
prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
prot_model = BertModel.from_pretrained("Rostlab/prot_bert")
prot_model = prot_model.to(device)

# Disease-text encoder. Later cells pass `text_tokenizer` to
# convert_examples_to_tokens and `text_model` to GDANet, so both names must be
# defined here for the notebook to survive Restart Kernel -> Run All.
text_tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
text_model = BertModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
text_model = text_model.to(device)

In [3]:
# Bare expression: shows the loaded protein encoder's module tree as the cell output.
prot_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30, 1024, padding_idx=0)
    (position_embeddings): Embedding(40000, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.0, inplace=False

### Model Class

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GDANet(torch.nn.Module):
    """Regression model that scores a (protein, disease) pair.

    Two encoders embed the inputs independently; their embeddings are
    concatenated and passed through a small MLP (`self.reg`) that outputs a
    single value per example.

    Args:
        prot_encoder: Module applied to the first input of `forward`.
        disease_encoder: Module applied to the second input of `forward`.
        prot_out_dim: Embedding width produced by `prot_encoder`.
        disease_out_dim: Embedding width produced by `disease_encoder`.
        drop_out: Dropout probability used inside the MLP head.
        freeze_prot_encoder: If True, `prot_encoder` parameters get
            `requires_grad = False`; otherwise they are made trainable.
        freeze_disease_encoder: Same switch for `disease_encoder`.
    """

    def __init__(self, prot_encoder, disease_encoder, prot_out_dim=1024, disease_out_dim=768, drop_out=0, freeze_prot_encoder=True,  freeze_disease_encoder=True):
        super().__init__()
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        self.freeze_encoders(freeze_prot_encoder, freeze_disease_encoder)

        # MLP head: (prot_out_dim + disease_out_dim) -> 1024 -> 512 -> 1.
        self.reg = nn.Sequential(
            nn.Linear(prot_out_dim + disease_out_dim, 1024),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Linear(512, 1)
        )

    @staticmethod
    def _set_trainable(encoder, trainable):
        """Set `requires_grad` to `trainable` on every parameter of `encoder`."""
        for param in encoder.parameters():
            param.requires_grad = trainable

    def freeze_encoders(self,freeze_prot_encoder,freeze_disease_encoder):
        """Freeze or unfreeze each encoder according to its own flag.

        A message is printed only for encoders that are being frozen.
        """
        if freeze_prot_encoder:
            print(f"freeze_prot_encoder:{freeze_prot_encoder}")
        # Fix: the unfreeze path previously touched the disease encoder, so the
        # protein encoder could never be switched back to trainable.
        self._set_trainable(self.prot_encoder, not freeze_prot_encoder)

        if freeze_disease_encoder:
            print(f"freeze_disease_encoder:{freeze_disease_encoder}")
        self._set_trainable(self.disease_encoder, not freeze_disease_encoder)

    def forward(self, x1, x2):
        """Return the MLP output for a batch of (protein, disease) inputs.

        Args:
            x1: Input passed as the only argument to `prot_encoder`.
            x2: Input passed as the only argument to `disease_encoder`.

        Returns:
            Output of `self.reg` on the concatenated embeddings; the final
            layer has one output feature.
        """
        # Take the first element of each encoder's output and keep index 0
        # along dim 1 (presumably the [CLS] position of the last hidden
        # state -- TODO confirm against the encoder's output type).
        # NOTE(review): only token ids are passed, no attention mask; since
        # inputs are padded to max_length upstream, padded positions may be
        # attended to -- confirm whether that is intended.
        prot_embedding = self.prot_encoder(x1)[0][:, 0]
        disease_embedding = self.disease_encoder(x2)[0][:, 0]
        joint = torch.cat((prot_embedding, disease_embedding), 1)
        return self.reg(joint)

In [5]:
class DisGeNETProcessor():
    """Loads the TDC 'DisGeNET' GDA dataset and exposes its splits.

    `self.datasets` holds the result of `get_split()`; it is indexed by the
    split names "train", "valid" and "test", each of which is indexed by the
    columns "Gene", "Disease" and "Y".
    """

    # Columns returned, in order, by every get_*_examples accessor.
    _COLUMNS = ("Gene", "Disease", "Y")

    def __init__(self, data_dir="./data"):
        """Fetch the dataset into/from `data_dir` and keep its split."""
        from tdc.multi_pred import GDA
        gda = GDA(name = 'DisGeNET',path=data_dir)
        self.datasets = gda.get_split()

    def _columns_of(self, split):
        """Return the (Gene, Disease, Y) `.values` arrays of one split."""
        return tuple(self.datasets[split][column].values for column in self._COLUMNS)

    def get_train_examples(self):
        """Return the (Gene, Disease, Y) value arrays of the "train" split."""
        return self._columns_of("train")

    def get_dev_examples(self):
        """Return the (Gene, Disease, Y) value arrays of the "valid" split."""
        return self._columns_of("valid")

    def get_test_examples(self):
        """Return the (Gene, Disease, Y) value arrays of the "test" split."""
        return self._columns_of("test")
    
def convert_examples_to_tokens(examples, prot_tokenizer, text_tokenizer, max_seq_length=512, test=False, test_size=2048):
    """Tokenize (gene, disease, label) examples for the two encoders.

    Args:
        examples: Three parallel sequences (genes, diseases, labels), e.g. the
            tuple returned by `DisGeNETProcessor.get_*_examples()`.
        prot_tokenizer: Callable applied to the list of gene strings; must
            return a mapping with an "input_ids" entry.
        text_tokenizer: Callable applied to the list of disease strings; must
            return a mapping with an "input_ids" entry.
        max_seq_length: Passed as `max_length`; inputs are truncated and
            padded ("max_length") to this size by the tokenizers.
        test: If True, only the first `test_size` examples are kept.
        test_size: Number of examples kept when `test` is True.

    Returns:
        dict with "prot_tokens" and "text_tokens" (the tokenizers'
        "input_ids") and "labels" (one single-element list per example).
    """
    rows = list(zip(*examples))
    if test:
        rows = rows[:test_size]

    # Build the flat lists directly; the previous wrap-then-`sum(..., [])`
    # flatten was quadratic in the number of examples.
    # Each gene string is split into space-separated characters.
    first_sentences = [" ".join(gene) for gene, _, _ in rows]
    second_sentences = [disease for _, disease, _ in rows]
    # Each label stays wrapped in its own list, giving one column per example.
    labels = [[label] for _, _, label in rows]

    print("start tokenizing ...")
    prot_tokens = prot_tokenizer(
        first_sentences,
        truncation=True,
        max_length=max_seq_length,
        padding="max_length",
    )["input_ids"]
    text_tokens = text_tokenizer(
        second_sentences,
        truncation=True,
        max_length=max_seq_length,
        padding="max_length",
    )["input_ids"]
    print("finish tokenizing ...")

    return {
        "prot_tokens": prot_tokens,
        "text_tokens": text_tokens,
        "labels": labels,
    }

def convert_tokens_to_tensors(tokens, device='cpu'):
    """Turn the token dict into tensors placed on `device`.

    Reads "prot_tokens" and "text_tokens" (converted to torch.long) and
    "labels" (converted to torch.float) from `tokens`, and returns them under
    the keys "prot_input", "text_inputs" and "label_inputs" respectively.
    """
    def as_tensor(key, dtype):
        return torch.tensor(tokens[key], dtype=dtype, device=device)

    return {
        "prot_input": as_tensor("prot_tokens", torch.long),
        "text_inputs": as_tensor("text_tokens", torch.long),
        "label_inputs": as_tensor("labels", torch.float),
    }

# Test
# disGeNET = DisGeNETProcessor()
# examples = disGeNET.get_train_examples() 
# tokens = convert_examples_to_tokens(examples, prot_tokenizer, text_tokenizer, max_seq_length=5,test=True)
# inputs = convert_tokens_to_tensors(tokens, device)

In [6]:
def train_an_epoch(model,train_dataloader,optimizer, loss):
    """Run one optimisation pass over `train_dataloader`.

    Each batch must unpack into (protein input, text input, labels). For every
    batch the gradients are cleared, the loss is computed and back-propagated,
    and the optimizer takes one step.

    Returns:
        The sum of the per-batch loss values (not averaged).
    """
    running_total = 0
    for prot_batch, text_batch, target_batch in train_dataloader:
        optimizer.zero_grad()
        batch_loss = loss(model(prot_batch, text_batch), target_batch)
        running_total += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    return running_total

def evaluate(model,test_dataloader, loss=None):
    """Compute the summed loss of `model` over `test_dataloader` without training.

    The model is switched to eval mode and gradients are disabled; no
    optimizer step is taken, so the model's parameters are left untouched.

    Args:
        model: Callable as `model(prot_input, text_inputs)`.
        test_dataloader: Iterable of (protein input, text input, labels) batches.
        loss: Criterion called as `loss(predictions, labels)`. Defaults to
            `torch.nn.MSELoss()`, the criterion used by the training cell.

    Returns:
        The sum of the per-batch loss values (not averaged).
    """
    criterion = torch.nn.MSELoss() if loss is None else loss
    model.eval()
    t_loss = 0
    # Fix: iterate the dataloader that was passed in (not the global training
    # loader) and never call backward()/optimizer.step() during evaluation.
    with torch.no_grad():
        for prot_input, text_inputs, label_inputs in test_dataloader:
            out = model(prot_input, text_inputs)
            t_loss += criterion(out, label_inputs).item()
    return t_loss

# Tokenize the DisGeNET train split; test=True keeps only a leading subset of
# the examples (see convert_examples_to_tokens).
disGeNET = DisGeNETProcessor()
examples = disGeNET.get_train_examples()
tokens = convert_examples_to_tokens(examples, prot_tokenizer, text_tokenizer, test=True)
inputs = convert_tokens_to_tensors(tokens, device)

# Wrap the three tensors in a dataset and draw batches in random order.
train_data = TensorDataset(
    *(inputs[key] for key in ("prot_input", "text_inputs", "label_inputs"))
)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data, batch_size=batch_size, sampler=train_sampler
)

# Protein encoder frozen, disease encoder left trainable.
model = GDANet(prot_model, text_model, freeze_prot_encoder=True, freeze_disease_encoder=False)
model = model.to(device)
# model = GDANet(prot_model, text_model, freeze_prot_encoder=True, freeze_disease_encoder=True).to(device)


Found local copy...
Loading...
Done!


start tokenizing ...
finish tokenizing ...
freeze_prot_encoder:True


In [None]:
# Training hyperparameters local to this cell.
num_epochs = 200
weight_decay = 5e-4

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
loss = nn.MSELoss()
for epoch in tqdm(range(num_epochs), desc="Training"):
    # train_an_epoch does not set the mode itself, so enable train mode here.
    model.train()
    # Reuse the helper instead of an inline copy of its body; it returns the
    # summed per-batch loss for the epoch.
    t_loss = train_an_epoch(model, train_dataloader, optimizer, loss)
    print(t_loss)