In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
from transformers import BertModel, BertTokenizer, AutoModelForMaskedLM, AutoTokenizer
import re
import os
import requests
from tqdm.auto import tqdm
import pandas as pd
import numpy as np

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    protein_device = torch.device('cuda:0')
else:
    protein_device = torch.device('cpu')

# ProtBert tokenizer (case is preserved) and encoder, both from the same checkpoint name.
protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protein_model = BertModel.from_pretrained("Rostlab/prot_bert").to(protein_device)
# Inference mode; as the last expression of the cell this also echoes the model.
protein_model.eval()

In [None]:
file_path = './all.parquet'
df = pd.read_parquet(file_path)
seed = 42
# Cap the sample size at the table length: sampling without replacement raises
# when more rows are requested than the frame holds.
sampled_df = df.sample(n=min(100000, len(df)), random_state=seed)
# Column 1 is used as the protein sequence column.
protein_sequences = sampled_df.iloc[:, 1].tolist()
# Replace the rare residues U/Z/O/B with X, then separate residues with single
# spaces. The ProtBert vocabulary is one token per residue, so an unspaced
# sequence would be tokenized as a single unknown word. Existing whitespace is
# stripped first, which makes this a no-op for sequences that are already spaced.
protein_sequences = [
    " ".join(re.sub(r"[UZOB]", "X", re.sub(r"\s+", "", sequence)))
    for sequence in protein_sequences
]

In [None]:
# Embed the protein sequences in mini-batches and keep ONE mean-pooled vector
# per sequence, so the result can later be stacked into a 2-D feature matrix.
batch_size = 32
protein_features = []

for i in tqdm(range(0, len(protein_sequences), batch_size)):
    batch = protein_sequences[i:i+batch_size]

    # Tokenize. Pad only up to the longest sequence in this batch (still
    # truncated at 3200 tokens): padded positions are masked out and sliced
    # away below, so padding everything to 3200 only costs memory and time.
    ids = protein_tokenizer(
        batch,
        add_special_tokens=True,
        padding=True,
        max_length=3200,
        truncation=True,
        return_tensors='pt'
    )

    input_ids = ids['input_ids'].to(protein_device)
    attention_mask = ids['attention_mask'].to(protein_device)

    with torch.no_grad():
        outputs = protein_model(input_ids=input_ids, attention_mask=attention_mask)
        embedding = outputs[0]  # first model output: per-token hidden states

    for seq_num in range(len(embedding)):
        # Number of attended (non-padding) tokens, special tokens included.
        seq_len = int(attention_mask[seq_num].sum().item())
        # Drop the first and last attended positions (the special tokens the
        # tokenizer added) and keep the residue embeddings as a torch tensor
        # so that `.mean(dim=0)` is valid.
        seq_emd = embedding[seq_num][1:seq_len-1]
        if seq_emd.shape[0] == 0:
            # Empty sequence: averaging zero rows would give NaN and poison
            # training, so fall back to an all-zero vector.
            avg_emd = embedding.new_zeros(embedding.shape[-1])
        else:
            avg_emd = seq_emd.mean(dim=0)
        protein_features.append(avg_emd.cpu().numpy())

    # `outputs` also references the activations, so it has to be released too.
    del ids, input_ids, attention_mask, outputs, embedding
    torch.cuda.empty_cache()

In [None]:
# Collect the per-sequence feature list into a single NumPy array.
protein_features = np.asarray(protein_features)

In [None]:
# Checkpoint name shared by the ligand encoder and its tokenizer.
ligand_model_name = "DeepChem/ChemBERTa-10M-MLM"
# The two loads are independent of each other; the model is fetched first here.
ligand_model = AutoModelForMaskedLM.from_pretrained(ligand_model_name)
ligand_tokenizer = AutoTokenizer.from_pretrained(ligand_model_name)

In [None]:
# Run the ligand encoder on CUDA when available, otherwise on the CPU.
if torch.cuda.is_available():
    ligand_device = torch.device('cuda')
else:
    ligand_device = torch.device('cpu')
ligand_model = ligand_model.to(ligand_device)
# Inference mode; as the last expression of the cell this also echoes the model.
ligand_model.eval()

In [None]:
# NOTE(review): this tokenizer comes from a different checkpoint than
# `ligand_model`; confirm that both share the same vocabulary / token ids.
smiles_tokenizer_name = "DeepChem/SmilesTokenizer_PubChem_1M"
smiles_tokenizer = AutoTokenizer.from_pretrained(smiles_tokenizer_name)

In [None]:
# Column 2 of the sampled frame is used as the ligand SMILES column.
smiles_column = sampled_df.iloc[:, 2]
smiles_list = smiles_column.tolist()

In [None]:
# Encode every SMILES string to a fixed width of 278 tokens (padded or
# truncated as needed) and return the result as PyTorch tensors.
smiles_encoding_options = dict(
    padding='max_length',
    truncation=True,
    max_length=278,
    return_tensors='pt',
)
tokenized_smiles = smiles_tokenizer(smiles_list, **smiles_encoding_options)

In [None]:
# Embed the tokenized SMILES in mini-batches and keep one mean-pooled vector
# per molecule.
batch_size = 32
ligand_features = []

# Work on the encoded tensors directly: `len()` of the tokenizer output counts
# its fields rather than the molecules, and slicing it does not return
# per-sample dicts, so the loop is driven by the tensor's first dimension.
all_input_ids = tokenized_smiles['input_ids']
all_attention_mask = tokenized_smiles['attention_mask']
num_smiles = all_input_ids.shape[0]

for i in tqdm(range(0, num_smiles, batch_size), desc="Extracting features"):
    input_ids = all_input_ids[i:i+batch_size].to(ligand_device)
    attention_mask = all_attention_mask[i:i+batch_size].to(ligand_device)

    with torch.no_grad():
        outputs = ligand_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

    # Hidden states of the final layer, one vector per token position.
    last_hidden_state = outputs.hidden_states[-1]

    # Mean-pool over attended tokens only. Every row is padded to a fixed
    # length, so a plain mean over the token axis would mix in the padding
    # positions of short molecules.
    token_mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    token_sum = (last_hidden_state * token_mask).sum(dim=1)
    token_count = token_mask.sum(dim=1).clamp(min=1)  # guard against an all-padding row
    batch_features = (token_sum / token_count).cpu().numpy()
    ligand_features.extend(batch_features)

    # `outputs` also references the activations, so it has to be released too.
    del input_ids, attention_mask, last_hidden_state, outputs
    torch.cuda.empty_cache()

In [None]:
# Collect the per-molecule feature list into a single NumPy array.
ligand_features = np.asarray(ligand_features)

In [None]:
def feature_transformation(X, W):
    """Project the rows of ``X`` with the matrix ``W``.

    Args:
        X: left operand of the matrix product.
        W: right operand of the matrix product.

    Returns:
        The matrix product ``X @ W``.
    """
    projected = X @ W
    return projected

In [None]:
def l2_normalize(X, eps=1e-12):
    """Scale every row of ``X`` to unit Euclidean length.

    Args:
        X: 2-D array whose rows are normalised independently.
        eps: lower bound applied to each row norm before dividing, so that an
            all-zero row maps to zeros instead of NaN (division by zero).

    Returns:
        Array of the same shape as ``X`` with each row divided by its L2 norm
        (or by ``eps`` when the norm is smaller than that).
    """
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.maximum(row_norms, eps)

In [None]:
class ProteinLigandModel(nn.Module):
    """Feed-forward regressor over a flat feature vector.

    The input is batch-normalised, then passed through two hidden layers
    (512 and 64 units, each followed by ReLU and dropout with p=0.2) and a
    final linear layer producing a single value per sample.
    """

    def __init__(self, input_dim):
        """Create the layers for inputs with ``input_dim`` features."""
        super().__init__()
        # Same attribute names and creation order as before, so parameter
        # names and initialisation order are unaffected.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=1)
        self.dropout = nn.Dropout(p=0.2)
        self.norm = nn.BatchNorm1d(num_features=input_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return one prediction per row of ``x`` (output has a trailing dimension of 1)."""
        normalized = self.norm(x)
        hidden = self.dropout(self.relu(self.fc1(normalized)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden)

In [None]:
# Column 3 of the sampled frame is used as the regression target.
target_column = sampled_df.iloc[:, 3]
dissociation_constants = target_column.values

In [None]:
# Fixed random square projections for each modality. The matrix sizes are read
# from the feature arrays instead of being hard-coded, so a checkpoint with a
# different hidden width cannot break the matrix product, and the generator is
# seeded so the projection is the same on every run.
rng = np.random.default_rng(seed)
protein_dim = protein_features.shape[1]
ligand_dim = ligand_features.shape[1]
W_protein = rng.standard_normal((protein_dim, protein_dim))
W_ligand = rng.standard_normal((ligand_dim, ligand_dim))

protein_features_transformed = feature_transformation(protein_features, W_protein)
ligand_features_transformed = feature_transformation(ligand_features, W_ligand)

# Row-wise unit length, so both modalities enter the model on the same scale.
protein_features_normalized = l2_normalize(protein_features_transformed)
ligand_features_normalized = l2_normalize(ligand_features_transformed)

In [None]:
# One row per protein–ligand pair: protein columns first, ligand columns after.
X = np.hstack([protein_features_normalized, ligand_features_normalized])
# Regression target aligned row-by-row with X.
y = dissociation_constants

In [None]:
# Hold out 10% of the rows for validation, with a fixed split seed.
split_arrays = train_test_split(X, y, test_size=0.1, random_state=103)
X_train, X_val, y_train, y_val = split_arrays

In [None]:
# Convert the NumPy splits to float32 tensors; targets get a trailing
# dimension of size 1 so they line up with the model's (batch, 1) output.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(dim=1)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(dim=1)

In [None]:
# Pair features with targets and serve them in shuffled mini-batches of 256.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=256)

In [None]:
# Seed torch before the model is created so weight initialisation, dropout
# masks and DataLoader shuffling are repeatable across runs.
torch.manual_seed(seed)

# Regressor sized to the concatenated protein + ligand feature width.
model = ProteinLigandModel(X.shape[1])
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Scheduler driven by a metric that should decrease ('min'), patience of 5 steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

In [None]:
# Train for a fixed number of epochs; after each epoch, score the model on the
# validation tensors, pass that loss to the scheduler and print it.
num_epochs = 50
for epoch in range(num_epochs):
    # Training mode for the optimisation pass over the training batches.
    model.train()
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        # Cap the total gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    
    # Evaluation mode, no gradient tracking: the whole validation set is
    # scored in a single forward pass.
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val_tensor)
        val_loss = criterion(val_outputs, y_val_tensor)
    
    # The scheduler is stepped once per epoch with the validation loss.
    scheduler.step(val_loss)
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss.item()}')

In [None]:
# Score the validation set once more in evaluation mode, without gradients.
model.eval()
with torch.no_grad():
    val_predictions = model(X_val_tensor)

# Flatten the (n, 1) predictions to a 1-D array; the ground truth is the
# untouched NumPy validation target.
y_pred = val_predictions.numpy().flatten()
y_true = y_val

In [None]:
# Validation metrics. The p-value gets a real name: assigning to `_` would
# shadow IPython's "last output" variable for the rest of the session.
pearson_corr, pearson_p_value = pearsonr(y_true, y_pred)
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
mae = mean_absolute_error(y_true, y_pred)

# Display the results as the cell output instead of computing them silently.
pd.Series({"pearson_r": pearson_corr, "rmse": rmse, "mae": mae}, name="validation")