In [None]:
from utils import *
import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
import pickle

In [None]:
class AttentionBlock(nn.Module):
    """Attention pooling over the sequence axis.

    Each position of a ``(batch, sequence, input_dim)`` input is scored by a
    small two-layer MLP, the scores are normalised with a softmax across the
    sequence axis, and the positions are summed using those weights, giving
    one ``(batch, input_dim)`` vector per example.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Scorer: input_dim -> 128 -> 1, i.e. one scalar score per position.
        scorer_layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)
        # dim=1 is the sequence axis, so the weights of each example sum to 1.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the attention-weighted sum of ``x`` along dim 1."""
        scores = self.attention(x)        # (batch, sequence, 1)
        weights = self.softmax(scores)    # normalised over the sequence axis
        # The trailing size-1 axis broadcasts across the embedding dimension;
        # summing over dim 1 then removes the sequence axis.
        pooled = (weights * x).sum(dim=1)
        return pooled

In [None]:
# ----------------- Step 5: Deep Learning Model -----------------
class BindingPredictor(nn.Module):
    """Score a protein/chemical pair from their embedding sequences.

    Each input is pooled to a single vector by its own ``AttentionBlock``;
    the two pooled vectors are concatenated and fed to a small MLP head that
    outputs one value per example.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.protein_attention = AttentionBlock(embedding_dim)
        self.chemical_attention = AttentionBlock(embedding_dim)
        # The head has no output activation, so it emits an unbounded raw score.
        head_layers = (
            nn.Linear(2 * embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.fc = nn.Sequential(*head_layers)

    def forward(self, protein_emb, chemical_emb):
        """Return a ``(batch_size, 1)`` tensor of raw scores."""
        pooled = [
            self.protein_attention(protein_emb),    # (batch_size, embedding_dim)
            self.chemical_attention(chemical_emb),  # (batch_size, embedding_dim)
        ]
        # Join the two pooled vectors along the feature axis before the head.
        return self.fc(torch.cat(pooled, dim=1))

In [None]:
# Rebuild the network and restore its trained weights from the checkpoint.
# NOTE(review): depending on the PyTorch version / `weights_only` setting,
# torch.load may unpickle arbitrary objects -- only load a trusted checkpoint.
model_path = "./classification_model_w_attention.pth"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
embedding_dim = 320  # the size of the embeddings of the "facebook/esm2_t6_8M_UR50D" model
model = BindingPredictor(embedding_dim).to(device)
# map_location puts the checkpoint tensors on whichever device is in use here.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

In [None]:
# Load the scaler that was fitted during training.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load a scaler.pkl that comes from a trusted source.
# NOTE(review): no cell visible here references `scaler` after this point --
# confirm whether model inputs or outputs are meant to be transformed with it.
scaler_path = "./scaler.pkl"
with open(scaler_path, "rb") as f:
    scaler = pickle.load(f)

In [None]:
# Inference pipeline
def predict_binding(uniprot_id, pubchem_cid):
    """Predict a binding score for one protein/compound pair.

    Uses the notebook-level ``model``, ``device`` and ``embedding_dim`` together
    with the ``get_protein_embedding`` and ``generate_random_projection`` helpers.

    Args:
        uniprot_id: Identifier passed to ``get_protein_embedding``.
        pubchem_cid: Identifier passed to ``generate_random_projection``.

    Returns:
        The model output as a Python float, or ``None`` if either embedding
        could not be generated (the error is printed).
    """
    # Both embedding helpers get the same best-effort handling, so the caller
    # only ever has to deal with "a float or None".
    try:
        protein_embedding = get_protein_embedding(uniprot_id)
    except Exception as e:
        print(f"Error generating protein embedding: {e}")
        return None

    try:
        chemical_embedding = generate_random_projection(pubchem_cid)
    except Exception as e:
        print(f"Error generating chemical embedding: {e}")
        return None

    # Add a leading batch axis of size 1: (1, sequence_length, embedding_dim).
    # NOTE(review): nothing is scaled here even though the notebook loads a
    # fitted `scaler` -- confirm whether it should be applied.
    protein_embedding = protein_embedding.reshape(1, -1, embedding_dim)
    chemical_embedding = chemical_embedding.reshape(1, -1, embedding_dim)

    # as_tensor accepts NumPy arrays and tensors alike, avoiding the
    # copy-construct warning torch.tensor emits when handed a tensor.
    protein_tensor = torch.as_tensor(protein_embedding, dtype=torch.float32, device=device)
    chemical_tensor = torch.as_tensor(chemical_embedding, dtype=torch.float32, device=device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(protein_tensor, chemical_tensor).squeeze()
        predicted_score = output.item()

    return predicted_score

In [None]:
# Example usage
uniprot_id = "P12345"  # example UniProt ID
pubchem_cid = "123456"  # example PubChem CID
predicted = predict_binding(uniprot_id, pubchem_cid)
# predict_binding returns None when it could not generate an embedding, so test
# the value stored above (`predicted_score` is not defined at notebook scope).
if predicted is not None:
    print(f"Predicted KIBA score: {predicted:.4f}")
