In [1]:
!pip install transformers
!pip install accelerate peft
!pip install tqdm
!pip install rdkit
!pip install PyTDC



### Model Wrapper Classes

In [51]:
import torch.nn as nn
import torch

class MoELoraModel(nn.Module):
  def __init__(self, model, routing_network, config, num_experts):
    super().__init__()
    self.num_experts = num_experts
    self.experts_model = LoraModel(model, config, "default")
    self.experts_list = []

    for i in range(num_experts):
      self.experts_model.add_weighted_adapter(["default"],[1],f"expert_{i}")
      self.experts_list.append(f"expert_{i}")

    self.routing_network = routing_network

  def forward(self, inputs_dict):
    model_inputs = inputs_dict["model_inputs"]
    routing_network_inputs = inputs_dict["routing_network_inputs"]
    self.route(**routing_network_inputs)
    return self.experts_model(**model_inputs)

  def route(self, **inputs):
    logits = self.routing_network(**inputs)  # logits should be shape (num_experts, )
    expert_probabilities = self.topk_with_softmax(logits)
    chosen_expert = torch.multinomial(expert_probabilities, 1)
    self.choose_expert(chosen_expert.item())

  def topk_with_softmax(self, logits):
    values, indices = torch.topk(logits,2)
    ret = torch.zeros(self.num_experts,device=next(self.parameters()).device)
    values = values/torch.norm(values)  # todo: probably remove
    expert_weights = torch.softmax(values.float(), dim=1)
    ret[indices] = expert_weights
    return ret

  def choose_expert(self,expert_num):
    self.experts_model.disable_adapter_layers()
    self.experts_model.set_adapter(f"expert_{expert_num}")

In [176]:
class RoutingNetworkFromTransformer(nn.Module):
  def __init__(self, model, num_experts, embedding_dim=384):
    super().__init__()
    self.num_experts = num_experts
    self.last_layer = nn.Sequential(nn.Linear(embedding_dim, num_experts), nn.Softmax())
    self.model = model

  def forward(self, **inputs):
    outputs = self.model(**inputs)

    # Extract the hidden states
    if hasattr(outputs, "last_hidden_state"): # Depends on the pretrained model backbone
      hidden_states = outputs.last_hidden_state
    else:
      hidden_states = outputs.hidden_states[-1]

    # Aggregate hidden states to get a single vector representation (e.g., mean pooling)
    embeddings = torch.mean(hidden_states, dim=1)
    return self.last_layer(embeddings)

### Load Pretrained Drug Model

In [177]:
from transformers import RobertaModel, RobertaTokenizer
import torch

drug_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/ChemBERTaLM")
pretrained_drug_model = RobertaModel.from_pretrained("gokceuludogan/ChemBERTaLM", output_hidden_states=True)

Some weights of RobertaModel were not initialized from the model checkpoint at gokceuludogan/ChemBERTaLM and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Load Pre-trained Target Model

In [178]:
from transformers import AutoConfig,AutoModelForSequenceClassification, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained("facebook/esm2_t33_650M_UR50D", output_hidden_states=True)
target_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
pretrained_target_model = AutoModelForSequenceClassification.from_config(config)
num_lora_experts = 8

### Load Pre-Trained Routing Network Backbone

In [179]:
routing_network_backbone = AutoModel.from_pretrained("DeepChem/ChemBERTa-10M-MLM")
routing_network_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-10M-MLM")

routing_network = RoutingNetworkFromTransformer(routing_network_backbone, num_lora_experts)

Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-10M-MLM and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Create Proposed Model

In [180]:
from peft import LoraModel, LoraConfig
lora_config = LoraConfig(
    task_type="SEQ_CLS",
    target_modules=["query","value"],
    r=8,
    lora_alpha=32,
    lora_dropout=0.01,
)

In [181]:
drug_router = RoutingNetworkFromTransformer(pretrained_drug_model, num_lora_experts, embedding_dim=768)
target_router = RoutingNetworkFromTransformer(pretrained_target_model, num_lora_experts, embedding_dim=1280)

In [182]:
drug_model = MoELoraModel(pretrained_drug_model, drug_router, lora_config, num_lora_experts)
target_model = MoELoraModel(pretrained_target_model, target_router, lora_config, num_lora_experts)
regressor = nn.Sequential(nn.Linear(2048,128),nn.ReLU(),nn.Linear(128,1))

In [183]:
# for param in model.experts_model.classifier.parameters():
#   param.requires_grad = True

### Setup Optimizers

In [184]:
combined_parameters = list(drug_model.parameters()) + list(target_model.parameters()) + list(regressor.parameters())
optimizer = torch.optim.AdamW(combined_parameters, lr=5e-5)

### Embedding Helper Functions

In [185]:
def get_embeddings(tokenizer,model, input,device):
    # Encode the SMILES sequence
    encoded_input = tokenizer(input, return_tensors="pt").to(device)
    # routing_network_inputs = router_tokenizer(input, return_tensors="pt").to(device)

    inputs_dict = {
        "model_inputs" : encoded_input,
        "routing_network_inputs": encoded_input
    }

    # Get model outputs
    with torch.no_grad():
        outputs = model(inputs_dict)

    # Extract the hidden states
    if hasattr(outputs, "last_hidden_state"): # Depends on the pretrained model backbone
      hidden_states = outputs.last_hidden_state
    else:
      hidden_states = outputs.hidden_states[-1]

    # Aggregate hidden states to get a single vector representation (e.g., mean pooling)
    embeddings = torch.mean(hidden_states, dim=1)
    return embeddings

In [186]:
mse_loss_fn = nn.MSELoss()

### Trainer

In [187]:
from tqdm import tqdm

def train(
    drug_model,
    target_model,
    regressor,
    drug_tokenizer,
    target_tokenizer,
    router_tokenizer,
    optimizer,
    data_loader,
    get_embeddings_fn,
    loss_fn,
    device,
    num_epochs
):
  for epoch in range(num_epochs):
    for input in tqdm(data_loader, desc=f"Training epoch {epoch}"):
      drug_smiles = input['drug']
      target_seq = input['target']
      target_affinity = torch.tensor(input['affinity'],dtype=torch.float).to(device)

      drug_embeddings = get_embeddings_fn(drug_tokenizer,drug_model,drug_smiles,device)
      target_embeddings = get_embeddings_fn(target_tokenizer,target_model,target_seq,device)

      all_embeds = torch.cat((drug_embeddings, target_embeddings), dim=1)

      predicted_affinity = regressor(all_embeds)

      loss = loss_fn(predicted_affinity,target_affinity)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

In [188]:
from tdc.multi_pred import DTI
data = DTI(name = 'DAVIS')
split = data.get_split()

Found local copy...
Loading...
Done!


In [189]:
from torch.utils.data import Dataset, DataLoader
class DavisDataset(Dataset):
    def __init__(self, data):
        # Assuming data is a dictionary with 'train', 'valid', 'test' splits
        # Concatenate training, validation, and test sets if needed
        # Or you can adjust the code to use only one of the splits
        self.data = data['train']
        self.drug = self.data['Drug']
        self.target = self.data['Target']
        self.affinity = self.data['Y']

        # Here, additional preprocessing can be done (e.g., tokenization)

    def __len__(self):
        return len(self.affinity)

    def __getitem__(self, idx):
        return {
            "drug": self.drug[idx],
            "target": self.target[idx],
            "affinity": torch.tensor(self.affinity[idx], dtype=torch.float)
        }

In [190]:
davis_dataset = DavisDataset(split)
dataloader = DataLoader(davis_dataset, batch_size=1, shuffle=True)

In [191]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

drug_model.to(device)
target_model.to(device)
regressor.to(device)

Sequential(
  (0): Linear(in_features=2048, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=1, bias=True)
)

In [None]:
train(
    drug_model,
    target_model,
    regressor,
    drug_tokenizer,
    target_tokenizer,
    routing_network_tokenizer,
    optimizer,
    dataloader,
    get_embeddings,
    mse_loss_fn,
    device,
    5
)

Training epoch 0:   1%|          | 206/18041 [00:59<1:18:27,  3.79it/s]

In [None]:
# Check device of a parameter after calling .to(device)
print(next(drug_model.parameters()).device)  # Should output 'cuda' if a GPU is available, otherwise 'cpu'