# In [None]:
import os
import sys

# This is to force the path to be on the same level as the dl_ba folder
sys.path.append("../..") 

from transformers import AutoTokenizer
import torch
from datasets import load_dataset

import time

from balm import common_utils
from balm.models.utils import load_trained_model, load_pretrained_pkd_bounds
from balm.configs import Configs
from balm.models import LipidBALM

# Device that hosts the model and every tokenized input ("cuda" when a GPU is used).
DEVICE = "cpu"

# Build the experiment configuration from the LipidBALM PEFT YAML file.
config_filepath = "../../configs/lipidbalm_peft.yaml"
configs = Configs(**common_utils.load_yaml(config_filepath))
model_configs = configs.model_configs

# Instantiate LipidBALM, pass it through load_trained_model in non-training
# mode, place it on DEVICE and switch it to evaluation mode.
model = load_trained_model(LipidBALM(model_configs), model_configs, is_training=False)
model.to(DEVICE)
model.eval()

# Lower and upper binding-affinity bounds loaded via the checkpoint path;
# they are used later to convert between cosine similarity and affinity.
pkd_lower_bound, pkd_upper_bound = load_pretrained_pkd_bounds(model_configs.checkpoint_path)

# One tokenizer per modality, each resolved from the name/path in the config.
protein_tokenizer = AutoTokenizer.from_pretrained(model_configs.protein_model_name_or_path)
lipid_tokenizer = AutoTokenizer.from_pretrained(model_configs.lipid_model_name_or_path)

# Custom Data: load your combined_binding_data.csv  
import pandas as pd

# Read the protein-lipid binding table from the current working directory.
binding_table_path = "combined_binding_data.csv"
df = pd.read_csv(binding_table_path)

# Preview the first five rows (the result is only displayed when this is the
# last expression of a notebook cell; in a plain script it is discarded).
df.head()

# Zero-shot predictions with the pretrained model.
# Each row of `df` is tokenized, scored by the model, and the resulting cosine
# similarity is converted back to a binding affinity using the pretrained bounds.
start = time.time()
predictions = []
labels = []
# No backward pass happens here, so gradient tracking is switched off for the
# whole loop to avoid building an autograd graph on every forward pass.
with torch.no_grad():
    for _, sample in df.iterrows():
        # Prepare input
        protein_inputs = protein_tokenizer(sample["ProteinSequence"], return_tensors="pt").to(DEVICE)
        lipid_inputs = lipid_tokenizer(sample["LipidSMILES"], return_tensors="pt").to(DEVICE)
        inputs = {
            "protein_input_ids": protein_inputs["input_ids"],
            "protein_attention_mask": protein_inputs["attention_mask"],
            "lipid_input_ids": lipid_inputs["input_ids"],
            "lipid_attention_mask": lipid_inputs["attention_mask"],
        }
        prediction = model(inputs)["cosine_similarity"]
        prediction = model.cosine_similarity_to_pkd(prediction, pkd_upper_bound=pkd_upper_bound, pkd_lower_bound=pkd_lower_bound)
        label = torch.tensor([sample["BindingAffinityValue"]])

        print(f"Predicted binding affinity: {prediction.item()} | True binding affinity: {label.item()}")
        predictions.append(prediction.item())
        labels.append(label.item())
print(f"Time taken for {len(df)} protein-lipid pairs: {time.time() - start}")

# Visualize results
from balm.metrics import get_ci, get_pearson, get_rmse, get_spearman
import seaborn as sns

# Convert the collected values to tensors once and score them with each metric.
label_tensor = torch.tensor(labels)
prediction_tensor = torch.tensor(predictions)

rmse = get_rmse(label_tensor, prediction_tensor)
pearson = get_pearson(label_tensor, prediction_tensor)
spearman = get_spearman(label_tensor, prediction_tensor)
ci = get_ci(label_tensor, prediction_tensor)

print(f"RMSE: {rmse}")
print(f"Pearson: {pearson}")
print(f"Spearman: {spearman}")
print(f"CI: {ci}")

# Scatter plot with a fitted regression line: experimental (x) vs. predicted (y).
ax = sns.regplot(x=labels, y=predictions)
ax.set(
    title="Protein-Lipid Binding Affinity Prediction",
    xlabel="Experimental Binding Affinity",
    ylabel="Predicted Binding Affinity",
)

# Few shot training
from sklearn.model_selection import train_test_split

# Few-shot split: 20% of the rows are used for fine-tuning and the remainder
# is held out for testing; the fixed seed keeps the partition reproducible.
train_data, test_data = train_test_split(
    df,
    random_state=1234,
    train_size=0.2,
)

# Define a function that applies the cosine similarity conversion to a single example
# This is VERY IMPORTANT since LipidBALM uses cosine similarity
def add_cosine_similarity(example, pkd_upper_bound, pkd_lower_bound):
    """Attach the cosine-similarity training target to a single example.

    The value stored under ``example['BindingAffinityValue']`` is linearly
    rescaled so that ``pkd_lower_bound`` maps to -1 and ``pkd_upper_bound``
    maps to 1, and the result is written to ``example['cosine_similarity']``.

    Args:
        example: Record supporting item access and assignment (for instance
            a dict or a pandas row). It is modified in place.
        pkd_upper_bound: Affinity value that corresponds to a similarity of 1.
        pkd_lower_bound: Affinity value that corresponds to a similarity of -1.

    Returns:
        The same ``example`` object, now carrying ``'cosine_similarity'``.

    Raises:
        ZeroDivisionError: If both bounds are equal, because the rescaling is
            undefined. Raised explicitly so that NumPy scalars do not silently
            yield inf/NaN targets.
    """
    pkd_range = pkd_upper_bound - pkd_lower_bound
    if pkd_range == 0:
        raise ZeroDivisionError(
            "pkd_upper_bound and pkd_lower_bound must differ to rescale the binding affinity"
        )

    example['cosine_similarity'] = (
        (example['BindingAffinityValue'] - pkd_lower_bound)
        / pkd_range
        * 2
        - 1
    )
    return example

# Add the 'cosine_similarity' target to both splits by applying the conversion
# row by row; the bounds are forwarded as extra positional arguments.
pkd_bounds = (pkd_upper_bound, pkd_lower_bound)
train_data = train_data.apply(add_cosine_similarity, axis=1, args=pkd_bounds)
test_data = test_data.apply(add_cosine_similarity, axis=1, args=pkd_bounds)

print(f"Number of train data: {len(train_data)}")
print(f"Number of test data: {len(test_data)}")

# Build a fresh LipidBALM instance for fine-tuning: same configuration as
# before, but passed through load_trained_model with is_training=True.
finetune_configs = configs.model_configs
model = load_trained_model(LipidBALM(finetune_configs), finetune_configs, is_training=True)
model.to(DEVICE)

# Training loop
from torch.optim import AdamW

# Number of full passes over the training split.
NUM_EPOCHS = 10

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(
    trainable_params,
    lr=configs.model_configs.model_hyperparameters.learning_rate,
)

def _make_training_inputs(row):
    """Tokenize one training row and bundle it with its cosine-similarity label."""
    protein_enc = protein_tokenizer(row["ProteinSequence"], return_tensors="pt").to(DEVICE)
    lipid_enc = lipid_tokenizer(row["LipidSMILES"], return_tensors="pt").to(DEVICE)
    # The label is wrapped in a float32 tensor and moved to the same device as the inputs.
    target = torch.tensor([row["cosine_similarity"]], dtype=torch.float32).to(DEVICE)
    return {
        "protein_input_ids": protein_enc["input_ids"],
        "protein_attention_mask": protein_enc["attention_mask"],
        "lipid_input_ids": lipid_enc["input_ids"],
        "lipid_attention_mask": lipid_enc["attention_mask"],
        "labels": target,
    }


# Fine-tune one sample at a time for NUM_EPOCHS passes over the training split.
start = time.time()
for epoch in range(NUM_EPOCHS):
    model.train()  # training mode is (re-)enabled at the start of every epoch
    running_loss = 0.0

    for _, row in train_data.iterrows():
        # Forward pass; the model returns the loss when labels are supplied.
        step_loss = model(_make_training_inputs(row))["loss"]

        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()

        running_loss += step_loss.item()

    # Report the mean per-sample loss of this epoch.
    avg_loss = running_loss / len(train_data)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {avg_loss:.4f}")

print("Training complete! Time taken: ", time.time() - start)

# Test the fine-tuned model on the held-out split.
model = model.eval()

# Restart the timer so the reported duration covers only this evaluation loop
# rather than everything since the training loop began.
start = time.time()
predictions = []
labels = []
# Evaluation needs no gradients; disabling them avoids building an autograd
# graph on every forward pass of the (now trainable) model.
with torch.no_grad():
    for _, sample in test_data.iterrows():
        # Prepare input
        protein_inputs = protein_tokenizer(sample["ProteinSequence"], return_tensors="pt").to(DEVICE)
        lipid_inputs = lipid_tokenizer(sample["LipidSMILES"], return_tensors="pt").to(DEVICE)
        inputs = {
            "protein_input_ids": protein_inputs["input_ids"],
            "protein_attention_mask": protein_inputs["attention_mask"],
            "lipid_input_ids": lipid_inputs["input_ids"],
            "lipid_attention_mask": lipid_inputs["attention_mask"],
        }
        prediction = model(inputs)["cosine_similarity"]
        # Convert the predicted cosine similarity back to a binding affinity.
        prediction = model.cosine_similarity_to_pkd(prediction, pkd_upper_bound=pkd_upper_bound, pkd_lower_bound=pkd_lower_bound)
        label = torch.tensor([sample["BindingAffinityValue"]])

        print(f"Predicted binding affinity: {prediction.item()} | True binding affinity: {label.item()}")
        predictions.append(prediction.item())
        labels.append(label.item())
print(f"Time taken for {len(test_data)} protein-lipid pairs: {time.time() - start}")

# Score the fine-tuned model: build the tensors once and reuse them for every metric.
label_tensor = torch.tensor(labels)
prediction_tensor = torch.tensor(predictions)

rmse = get_rmse(label_tensor, prediction_tensor)
pearson = get_pearson(label_tensor, prediction_tensor)
spearman = get_spearman(label_tensor, prediction_tensor)
ci = get_ci(label_tensor, prediction_tensor)

print(f"RMSE: {rmse}")
print(f"Pearson: {pearson}")
print(f"Spearman: {spearman}")
print(f"CI: {ci}")

# Scatter plot with a fitted regression line: experimental (x) vs. predicted (y).
ax = sns.regplot(x=labels, y=predictions)
ax.set(
    title="Fine-tuned Protein-Lipid Model",
    xlabel="Experimental Binding Affinity",
    ylabel="Predicted Binding Affinity",
)