In [1]:
# !pip install PyTDC

### Imports & Hyperprameters

In [2]:
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from transformers import BertModel,BertTokenizer
import time
from datetime import datetime
import os
import sys
import random
import string
import requests
from tqdm.auto import tqdm
import wandb

### Load encoders

In [3]:
device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
freeze_prot_encoder = False
freeze_disease_encoder = False
batch_size = 6

prot_encoder_path = "Rostlab/prot_bert"
text_encoder_path = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

prot_tokenizer = BertTokenizer.from_pretrained(prot_encoder_path, do_lower_case=False )
prot_model = BertModel.from_pretrained(prot_encoder_path)
prot_model = prot_model.to(device)

text_tokenizer = BertTokenizer.from_pretrained(text_encoder_path)
text_model = BertModel.from_pretrained(text_encoder_path)
text_model = text_model.to(device)

### Model Class

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GDANet(torch.nn.Module):
    def __init__(self, prot_encoder, disease_encoder, prot_out_dim=1024, disease_out_dim=768, drop_out=0, freeze_prot_encoder=True,  freeze_disease_encoder=True):
        super(GDANet, self).__init__()
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        self.freeze_encoders(freeze_prot_encoder,freeze_disease_encoder)
        
        self.reg = nn.Sequential(
            nn.Linear(prot_out_dim + disease_out_dim, 1024),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Linear(512, 1)
        )
    
    def freeze_encoders(self,freeze_prot_encoder,freeze_disease_encoder):
        if freeze_prot_encoder:
            print(f"freeze_prot_encoder:{freeze_prot_encoder}")
            for param in self.prot_encoder.parameters():
                param.requires_grad = False
        else:
            for param in self.disease_encoder.parameters():
                param.requires_grad = True
                
        if freeze_disease_encoder:
            print(f"freeze_disease_encoder:{freeze_disease_encoder}")
            for param in self.disease_encoder.parameters():
                param.requires_grad = False
        else:
            for param in self.disease_encoder.parameters():
                param.requires_grad = True
                
    def forward(self, x1, x2):
        x1 = self.prot_encoder(x1)[0][:,0]
        x2 = self.disease_encoder(x2)[0][:,0]
        x = torch.cat((x1, x2), 1)
        x = self.reg(x)
        return x

### Data Processor

In [5]:
class DisGeNETProcessor():
    def __init__(self, data_dir="./data"):
        from tdc.multi_pred import GDA
        data = GDA(name = 'DisGeNET',path=data_dir)
        self.datasets = data.get_split()

    def get_train_examples(self):
        return self.datasets["train"]["Gene"].values,self.datasets["train"]["Disease"].values,self.datasets["train"]["Y"].values

    def get_dev_examples(self):
        return self.datasets["valid"]["Gene"].values,self.datasets["valid"]["Disease"].values,self.datasets["valid"]["Y"].values

    def get_test_examples(self):
        return self.datasets["test"]["Gene"].values,self.datasets["test"]["Disease"].values,self.datasets["test"]["Y"].values
    
def convert_examples_to_tokens(examples, prot_tokenizer, text_tokenizer, max_seq_length=512, test=False):

    first_sentences = []
    second_sentences = []
    labels = []
    for gene, disease,label in zip(*examples):
        first_sentences.append([" ".join(gene)])
        second_sentences.append([disease])
        labels.append([label])

    # Flatten out
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])
    if test:
        first_sentences = first_sentences[:1024]
        second_sentences = second_sentences[:1024]
        labels = labels[:1024]
    
    print("start tokenizing ...")
    # Tokenize
    prot_tokens = prot_tokenizer(
        first_sentences,
        truncation=True,
        max_length=max_seq_length,
        padding="max_length",
    )["input_ids"]
    # Tokenize
    text_tokens = text_tokenizer(
        second_sentences,
        truncation=True,
        max_length=max_seq_length,
        padding="max_length",
    )["input_ids"]
    
    print("finish tokenizing ...")
    
    inputs = {
    }
    inputs["prot_tokens"] = prot_tokens
    inputs["text_tokens"] = text_tokens
    inputs["labels"] = labels
    return inputs

def convert_tokens_to_tensors(tokens, device='cpu'):
    input_dict = {}
    prot_inputs = torch.tensor(tokens["prot_tokens"], dtype=torch.long, device=device)
    text_inputs = torch.tensor(tokens["text_tokens"], dtype=torch.long, device=device)
    labels_inputs = torch.tensor(tokens["labels"], dtype=torch.float, device=device)
    input_dict["prot_input"] = prot_inputs
    input_dict["text_inputs"] = text_inputs
    input_dict["label_inputs"] = labels_inputs
    return input_dict

# Test
# disGeNET = DisGeNETProcessor()
# examples = disGeNET.get_train_examples() 
# tokens = convert_examples_to_tokens(examples, prot_tokenizer, text_tokenizer, max_seq_length=5,test=True)
# inputs = convert_tokens_to_tensors(tokens, device)

### Model Trainer

In [6]:
def train_an_epoch(model,train_dataloader,optimizer, loss):
    model.train()
    t_loss = 0 
    for step, batch in enumerate(train_dataloader):
        prot_input, text_inputs, label_inputs = batch
        optimizer.zero_grad()
        out = model(prot_input, text_inputs)
        output = loss(out, label_inputs)
        t_loss += output.item()
        output.backward()
        optimizer.step()
    return t_loss

def evaluate(model,test_dataloader, metric):
    model.eval()
    metric_val = 0 
    with torch.no_grad():
        for step, batch in enumerate(test_dataloader):
            prot_input, text_inputs, label_inputs = batch
            optimizer.zero_grad()
            out = model(prot_input, text_inputs)
            output = metric(out, label_inputs)
            metric_val += output.item()
    return metric_val/(step+1)

disGeNET = DisGeNETProcessor()
# Train data_loader
examples = disGeNET.get_train_examples() 
# tokens = convert_examples_to_tokens(examples, prot_tokenizer, text_tokenizer, test=True)  ## Trun the test off if doing the real training
tokens = convert_examples_to_tokens(examples, prot_tokenizer, text_tokenizer)

inputs = convert_tokens_to_tensors(tokens, device)
train_data = TensorDataset(
    inputs["prot_input"], inputs["text_inputs"], inputs["label_inputs"]
)
train_sampler = RandomSampler(train_data)

# Validation data_loader
valid_examples = disGeNET.get_dev_examples() 
# valid_tokens = convert_examples_to_tokens(valid_examples, prot_tokenizer, text_tokenizer, test=True)  ## Trun the test off if doing the real training
valid_tokens = convert_examples_to_tokens(valid_examples, prot_tokenizer, text_tokenizer)

valid_inputs = convert_tokens_to_tensors(valid_tokens, device)
valid_data = TensorDataset(
    valid_inputs["prot_input"], valid_inputs["text_inputs"], valid_inputs["label_inputs"]
)
valid_sampler = RandomSampler(valid_data)

# Test data_loader
test_examples = disGeNET.get_test_examples() 
# test_tokens = convert_examples_to_tokens(test_examples, prot_tokenizer, text_tokenizer, test=True)  ## Trun the test off if doing the real training
test_tokens = convert_examples_to_tokens(test_examples, prot_tokenizer, text_tokenizer)

test_inputs = convert_tokens_to_tensors(test_tokens, device)
test_data = TensorDataset(
    test_inputs["prot_input"], test_inputs["text_inputs"], test_inputs["label_inputs"]
)
test_sampler = RandomSampler(test_data)

Found local copy...
Loading...
Done!


start tokenizing ...
finish tokenizing ...
start tokenizing ...
finish tokenizing ...
start tokenizing ...
finish tokenizing ...


In [None]:
for lr in [1e-4]:
    # lr = 1e-4
    patience = 5
    dropout = 0.0
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    random_str = "".join([random.choice(string.ascii_lowercase) for n in range(6)])
    best_model_dir = f"../model/model_{timestamp_str}_{random_str}/"

    if not os.path.exists(best_model_dir):
        os.makedirs(best_model_dir, exist_ok = False) 

    wandb.init(project="protbert")
    args = {}
    args["batch_size"] = batch_size
    args["lr"] = lr
    args["dropout"] = dropout
    args["best_model_dir"] = best_model_dir
    args["prot_encoder_path"] = prot_encoder_path
    args["batch_size"] = batch_size
    args["patience"] = patience
    args["best_model_dir"] = best_model_dir
    args["freeze_prot_encoder"] = freeze_prot_encoder
    args["freeze_disease_encoder"] = freeze_disease_encoder
    args["device"] = device


    train_dataloader = DataLoader(
        train_data, sampler=train_sampler, batch_size=batch_size
    )
    valid_dataloader = DataLoader(
        valid_data, sampler=valid_sampler, batch_size=batch_size
    )
    test_dataloader = DataLoader(
        test_data, sampler=test_sampler, batch_size=batch_size
    )
    wandb.config.update(args)
    model = GDANet(prot_model, text_model, freeze_prot_encoder=True, freeze_disease_encoder=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    loss = nn.MSELoss()
    metric = nn.MSELoss(reduction='sum')
    best_dev_mse = sys.maxsize
    unimproved_iters = 0
    for epoch in tqdm(range(200), desc="Training"):
        epoch_loss = train_an_epoch(model, train_dataloader, optimizer, loss)
        print("epoch_loss:\t",epoch_loss)
        valid_mse = evaluate(model,valid_dataloader,metric)
        test_mse = evaluate(model,test_dataloader,metric)
        print("valid_mse:\t",valid_mse)
        print("test_mse:\t",test_mse)
        wandb.log({'epoch':epoch,'epoch_loss': epoch_loss, 'valid_mse': valid_mse, 'test_mse': test_mse})
        # Update validation results
        if valid_mse < best_dev_mse:
            unimproved_iters = 0
            best_dev_mse = valid_mse
            torch.save(model, best_model_dir + "model.bin")
        else:
            unimproved_iters += 1
            if unimproved_iters >= patience:
                early_stop = True
                tqdm.write(f"Early Stopped on Epoch: {epoch}, Best Dev MSE: {best_dev_mse}")
                break
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmengzaiqiao[0m (use `wandb login --relogin` to force relogin)


freeze_prot_encoder:True
freeze_disease_encoder:True


Training:   0%|          | 0/200 [00:00<?, ?it/s]