In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv1D, GlobalAveragePooling1D, Dense
from tensorflow.keras import layers, Model, Input, regularizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
from tensorflow.keras.losses import Huber

# Set working directory
# Base directory for every input CSV, intermediate feature file, model
# checkpoint and plot used by this notebook. It is a relative path, so it is
# resolved against the kernel's current working directory.
working_dir = "DL-M"

In [None]:
# Input table of drugs and the two feature files this cell writes.
drug_file = os.path.join(working_dir, "Drugs.csv")
chemberta_output = os.path.join(working_dir, "ChemBERTa_drug_embeddings.csv")
morgan_output = os.path.join(working_dir, "Morgan_fingerprints.csv")

# Load drug data
# The processing loop further down reads the 'Ligand_SMILES' and 'DrugBank_ID'
# columns of this frame.
drugs_df = pd.read_csv(drug_file)

# Initialize ChemBERTa model
# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Tokenizer and encoder are loaded from the same checkpoint name; the encoder
# is moved to `device` so it lives where the token tensors are sent later.
tokenizer_chemberta = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model_chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device)

# Sliding window tokenization
def tokenize_smiles(smiles, max_length=512, stride=256):
    """Tokenize one SMILES string into overlapping windows of token ids.

    The string is tokenized WITHOUT truncation and the resulting id tensor is
    cut into windows of at most ``max_length`` tokens whose start positions
    are ``stride`` tokens apart. (The previous version tokenized with
    ``truncation=True``, which capped the ids at ``max_length`` so the
    windowing branch could never run and long inputs were silently cut.)

    Args:
        smiles: SMILES string to tokenize.
        max_length: maximum number of tokens per window.
        stride: distance between consecutive window starts; must not exceed
            ``max_length`` or tokens between windows would be skipped.

    Returns:
        A list with one ``(1, n_tokens)`` id tensor per window. Inputs that
        fit in a single window yield a one-element list. Windows are raw
        slices of the id tensor, so any special tokens the tokenizer adds
        appear only where they fall in the full sequence.

    Raises:
        ValueError: if ``max_length``/``stride`` are not positive or
            ``stride > max_length``.
    """
    if max_length <= 0 or stride <= 0:
        raise ValueError("max_length and stride must be positive")
    if stride > max_length:
        raise ValueError("stride must not exceed max_length, otherwise tokens are skipped")

    # Truncation stays off on purpose: windowing below is what bounds the length.
    tokens = tokenizer_chemberta(smiles, return_tensors="pt", truncation=False)
    input_ids = tokens["input_ids"]
    total_tokens = input_ids.shape[1]

    chunks = []
    for start in range(0, max(total_tokens, 1), stride):
        chunks.append(input_ids[:, start:start + max_length])
        # Stop as soon as a window reaches the end so no trailing window is
        # emitted that is fully contained in the previous one.
        if start + max_length >= total_tokens:
            break
    return chunks

# Extract ChemBERTa embeddings
def get_chemberta_embedding(smiles):
    """Return one fixed-size ChemBERTa vector for a SMILES string.

    Every token window produced by ``tokenize_smiles`` is run through
    ``model_chemberta`` with gradients disabled and mean-pooled over the
    token axis; the per-window vectors are then averaged (unweighted).

    Returns:
        1-D NumPy array with the encoder's hidden size.
    """
    def _pooled_window(window_ids):
        # Move the ids next to the model, encode, then average over tokens.
        window_ids = window_ids.to(device)
        with torch.no_grad():
            encoded = model_chemberta(input_ids=window_ids)
        return encoded.last_hidden_state.mean(dim=1).cpu().numpy()

    per_window = [_pooled_window(window) for window in tokenize_smiles(smiles)]
    return np.vstack(per_window).mean(axis=0)

# Compute Morgan fingerprints
def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
    """Return the Morgan bit-vector fingerprint of a SMILES string as an array.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.
        radius: Morgan radius passed to RDKit.
        n_bits: length of the fingerprint.

    Returns:
        1-D NumPy array of length ``n_bits``. When RDKit cannot parse the
        string (``MolFromSmiles`` returns ``None``) an all-zero vector is
        returned instead of raising.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is not None:
        bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=n_bits)
        return np.array(bit_vector)
    # Unparseable SMILES: fall back to an empty fingerprint.
    return np.zeros(n_bits)

# Process drugs
# Collect one row per drug: [DrugBank_ID, feature_1, ..., feature_n].
drug_embeddings = []
drug_fingerprints = []

for _, row in tqdm(drugs_df.iterrows(), total=len(drugs_df)):
    smiles = row['Ligand_SMILES']
    drug_id = row['DrugBank_ID']
    
    chemberta_emb = get_chemberta_embedding(smiles)
    morgan_fp = get_morgan_fingerprint(smiles)
    
    drug_embeddings.append([drug_id] + chemberta_emb.tolist())
    drug_fingerprints.append([drug_id] + morgan_fp.tolist())

if not drug_embeddings:
    raise ValueError(f"No drugs found in {drug_file}; nothing to embed.")

# Convert to DataFrames
# Column counts are taken from the collected rows instead of from variables
# left over from the last loop iteration.
embedding_cols = [f"embedding_dim_{i+1}" for i in range(len(drug_embeddings[0]) - 1)]
fingerprint_cols = [f"fingerprint_{i+1}" for i in range(len(drug_fingerprints[0]) - 1)]
drug_embeddings_df = pd.DataFrame(drug_embeddings, columns=['DrugBank_ID'] + embedding_cols)
drug_fingerprints_df = pd.DataFrame(drug_fingerprints, columns=['DrugBank_ID'] + fingerprint_cols)

# Normalize features
# The scaled values are placed in freshly built float columns. Fingerprint
# bits are stored as integers, and writing float z-scores into integer columns
# through `.iloc[...] = ...` relies on an implicit dtype upcast that pandas
# has deprecated.
# NOTE(review): the statistics are computed over every drug in Drugs.csv,
# i.e. before any train/validation/test split exists.
scaler = StandardScaler()
drug_embeddings_df = pd.concat(
    [
        drug_embeddings_df[['DrugBank_ID']],
        pd.DataFrame(
            scaler.fit_transform(drug_embeddings_df[embedding_cols]),
            columns=embedding_cols,
            index=drug_embeddings_df.index,
        ),
    ],
    axis=1,
)
drug_fingerprints_df = pd.concat(
    [
        drug_fingerprints_df[['DrugBank_ID']],
        pd.DataFrame(
            scaler.fit_transform(drug_fingerprints_df[fingerprint_cols]),
            columns=fingerprint_cols,
            index=drug_fingerprints_df.index,
        ),
    ],
    axis=1,
)

# Save feature files
drug_embeddings_df.to_csv(chemberta_output, index=False)
drug_fingerprints_df.to_csv(morgan_output, index=False)

In [None]:
# Input table of proteins and the two feature files this cell writes.
protein_file = os.path.join(working_dir, "Proteins.csv")
protT5_output = os.path.join(working_dir, "ProtT5_protein_embeddings.csv")
cnn_output = os.path.join(working_dir, "CNN_protein_embeddings.csv")

# Load protein data
# The processing loop further down reads the 'Sequence' and 'UniProt_ID'
# columns of this frame.
proteins_df = pd.read_csv(protein_file)

# Initialize ProtT5 model
# Same device rule as the drug cell: first CUDA device if visible, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Only the T5 encoder is loaded (T5EncoderModel); it is moved to `device`.
tokenizer_protT5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model_protT5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device)

# Sliding window tokenization
def tokenize_sequence(sequence, max_length=512, stride=256):
    """Tokenize one protein sequence into overlapping windows of token ids.

    Two defects of the previous version are fixed here:

    * ProtT5 expects residues separated by single spaces with the rare
      residues U, Z, O and B replaced by X (see the model card); the raw
      string was passed unchanged before.
    * Tokenization used ``truncation=True``, so the ids never exceeded
      ``max_length`` and the windowing branch was unreachable; long proteins
      were silently cut. Truncation is now off and windowing bounds the size.

    Args:
        sequence: amino-acid string; any whitespace already present is removed
            before the residues are re-spaced.
        max_length: maximum number of tokens per window.
        stride: distance between consecutive window starts; must not exceed
            ``max_length`` or tokens between windows would be skipped.

    Returns:
        A list with one ``(1, n_tokens)`` id tensor per window.

    Raises:
        ValueError: if ``max_length``/``stride`` are not positive or
            ``stride > max_length``.
    """
    import re  # local import keeps this cell runnable on its own

    if max_length <= 0 or stride <= 0:
        raise ValueError("max_length and stride must be positive")
    if stride > max_length:
        raise ValueError("stride must not exceed max_length, otherwise tokens are skipped")

    residues = re.sub(r"[UZOB]", "X", "".join(sequence.split()))
    spaced_sequence = " ".join(residues)

    tokens = tokenizer_protT5(spaced_sequence, return_tensors="pt", truncation=False)
    input_ids = tokens["input_ids"]
    total_tokens = input_ids.shape[1]

    chunks = []
    for start in range(0, max(total_tokens, 1), stride):
        chunks.append(input_ids[:, start:start + max_length])
        # Stop once a window reaches the end; later windows would be subsets.
        if start + max_length >= total_tokens:
            break
    return chunks

# Extract ProtT5 embeddings
def get_protT5_embedding(sequence):
    """Return one fixed-size ProtT5 vector for a protein sequence.

    Each token window from ``tokenize_sequence`` is encoded by
    ``model_protT5`` with gradients disabled and mean-pooled over the token
    axis; the per-window vectors are then averaged (unweighted).

    Returns:
        1-D NumPy array with the encoder's hidden size.
    """
    window_vectors = []
    for window_ids in tokenize_sequence(sequence):
        on_device = window_ids.to(device)
        with torch.no_grad():
            hidden_states = model_protT5(input_ids=on_device).last_hidden_state
        # Average over the token axis, then hand the result to NumPy.
        window_vectors.append(hidden_states.mean(dim=1).cpu().numpy())
    stacked = np.vstack(window_vectors)
    return stacked.mean(axis=0)

# CNN Model for Non-Transformer Embeddings
def build_cnn_model(input_length=1000, embedding_dim=128):
    """Assemble the small 1-D CNN used to embed integer-coded sequences.

    Layout: one ``Conv1D`` (64 filters, kernel 3, ReLU) over an input of shape
    ``(input_length, 1)``, global average pooling over positions, then a ReLU
    ``Dense`` layer with ``embedding_dim`` units. The model is returned
    uncompiled with freshly initialized weights.
    """
    convolution = Conv1D(64, kernel_size=3, activation='relu', input_shape=(input_length, 1))
    pooling = GlobalAveragePooling1D()
    projection = Dense(embedding_dim, activation='relu')
    return Sequential([convolution, pooling, projection])

# The CNN is never trained, so its embedding is a fixed random function of the
# input. Seeding right before construction makes that function reproducible.
# Any other cell that embeds proteins for the same downstream model has to
# build an identical architecture from the same seed (or reuse these weights),
# otherwise its CNN features are unrelated to the ones produced here.
cnn_input_length = 1000
tf.keras.utils.set_random_seed(42)
cnn_model = build_cnn_model(input_length=cnn_input_length)

# Process proteins
protein_embeddings = []
cnn_embeddings = []

for _, row in tqdm(proteins_df.iterrows(), total=len(proteins_df)):
    sequence = row['Sequence']
    protein_id = row['UniProt_ID']
    
    protT5_emb = get_protT5_embedding(sequence)

    # Encode residues by code point, truncate to the CNN input length and
    # zero-pad shorter sequences so every protein has the exact
    # (1, cnn_input_length, 1) float32 shape the model was declared with.
    sequence_array = np.array([ord(aa) for aa in sequence[:cnn_input_length]], dtype=np.float32)
    if len(sequence_array) < cnn_input_length:
        sequence_array = np.pad(
            sequence_array, (0, cnn_input_length - len(sequence_array)),
            'constant', constant_values=0
        )
    sequence_array = sequence_array.reshape(1, cnn_input_length, 1)
    cnn_emb = cnn_model.predict(sequence_array, verbose=0).flatten()
    
    protein_embeddings.append([protein_id] + protT5_emb.tolist())
    cnn_embeddings.append([protein_id] + cnn_emb.tolist())

if not protein_embeddings:
    raise ValueError(f"No proteins found in {protein_file}; nothing to embed.")

# Convert to DataFrames
# Column counts come from the collected rows, not from leftover loop variables.
protT5_cols = [f"embedding_dim_{i+1}" for i in range(len(protein_embeddings[0]) - 1)]
cnn_cols = [f"cnn_dim_{i+1}" for i in range(len(cnn_embeddings[0]) - 1)]
protein_embeddings_df = pd.DataFrame(protein_embeddings, columns=['UniProt_ID'] + protT5_cols)
cnn_embeddings_df = pd.DataFrame(cnn_embeddings, columns=['UniProt_ID'] + cnn_cols)

# Normalize features
# NOTE(review): statistics are computed over every protein in Proteins.csv,
# i.e. before any train/validation/test split exists.
scaler = StandardScaler()
protein_embeddings_df.iloc[:, 1:] = scaler.fit_transform(protein_embeddings_df.iloc[:, 1:])
cnn_embeddings_df.iloc[:, 1:] = scaler.fit_transform(cnn_embeddings_df.iloc[:, 1:])

# Save feature files
protein_embeddings_df.to_csv(protT5_output, index=False)
cnn_embeddings_df.to_csv(cnn_output, index=False)

In [None]:
# Feature files written by the drug and protein cells, plus the table that
# links drugs to proteins with a measured Kd.
drug_chemberta_file = os.path.join(working_dir, "ChemBERTa_drug_embeddings.csv")
drug_morgan_file = os.path.join(working_dir, "Morgan_fingerprints.csv")
protein_prott5_file = os.path.join(working_dir, "ProtT5_protein_embeddings.csv")
protein_cnn_file = os.path.join(working_dir, "CNN_protein_embeddings.csv")
link_file = os.path.join(working_dir, "Link.csv")

# Destinations of the three splits produced at the end of this cell.
train_output = os.path.join(working_dir, "train.csv")
val_output = os.path.join(working_dir, "val.csv")
test_output = os.path.join(working_dir, "test.csv")

# Load embeddings
drug_chemberta_df = pd.read_csv(drug_chemberta_file)
drug_morgan_df = pd.read_csv(drug_morgan_file)
protein_prott5_df = pd.read_csv(protein_prott5_file)
protein_cnn_df = pd.read_csv(protein_cnn_file)
link_df = pd.read_csv(link_file)

# Merge drug embeddings
# Inner join: only drugs present in both feature files are kept.
drug_embeddings = drug_chemberta_df.merge(drug_morgan_df, on='DrugBank_ID', how='inner')

# Merge protein embeddings
# Inner join: only proteins present in both feature files are kept.
protein_embeddings = protein_prott5_df.merge(protein_cnn_df, on='UniProt_ID', how='inner')

# Merge with drug-protein interactions
# Inner joins again, so links whose drug or protein has no features are
# dropped without any message. Column order after these merges is: link
# columns, then drug features, then protein features.
dataset = link_df.merge(drug_embeddings, on='DrugBank_ID', how='inner')
dataset = dataset.merge(protein_embeddings, on='UniProt_ID', how='inner')

# Apply log-normalization on Kd values
# From here on the 'Kd(nM)' column holds log10(Kd + 1), not raw nM values.
dataset['Kd(nM)'] = np.log10(dataset['Kd(nM)'] + 1)

# Extract feature columns
# Everything except the two ID columns and the target is treated as a feature.
# NOTE(review): assumes Link.csv carries no columns besides the IDs and
# 'Kd(nM)'; any extra column would end up in the feature matrix — confirm.
features = dataset.drop(columns=['DrugBank_ID', 'UniProt_ID', 'Kd(nM)'])
target = dataset['Kd(nM)']

# Normalize features
# NOTE(review): the scaler is fitted on all rows here, before the
# train/validation/test split performed later in this cell, so validation and
# test rows contribute to the scaling statistics.
scaler = StandardScaler()
features = scaler.fit_transform(features)

# Ensure minimum bin size for stratified splitting
# Start from at most 10 quantile bins of the target and reduce the bin count
# until every bin holds at least 5 rows; the resulting bin labels are used to
# stratify the first split.
num_bins = min(10, len(np.unique(target)))
while num_bins > 1:
    discretizer = KBinsDiscretizer(n_bins=num_bins, encode='ordinal', strategy='quantile')
    binned_target = discretizer.fit_transform(target.values.reshape(-1, 1)).flatten()
    bin_counts = np.bincount(binned_target.astype(int))
    if np.min(bin_counts) >= 5:
        stratify_param = binned_target
        break
    else:
        num_bins -= 1

# No acceptable binning found: split without stratification.
if num_bins == 1:
    stratify_param = None
# Implement Multi-Head Attention for Feature Fusion
class MultiHeadAttention(nn.Module):
    """Wrap ``nn.MultiheadAttention`` for use on flat feature vectors.

    Each input row is treated as a sequence of length one: a singleton
    sequence axis is inserted, the row attends to itself (query = key =
    value), and the singleton axis is removed again, so the output has the
    same ``(batch, input_dim)`` shape as the input.

    Args:
        input_dim: size of each feature vector; passed as ``embed_dim``.
        num_heads: number of attention heads.
    """

    def __init__(self, input_dim, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, x):
        # (batch, input_dim) -> (batch, 1, input_dim)
        as_sequence = x.unsqueeze(dim=1)
        attended, _weights = self.attention(as_sequence, as_sequence, as_sequence)
        # (batch, 1, input_dim) -> (batch, input_dim)
        return attended.squeeze(dim=1)

# Convert features to PyTorch tensor
features_tensor = torch.tensor(features, dtype=torch.float32)

# Apply Multi-Head Attention Fusion
# The layer is never trained, so its output is a fixed random mixing of the
# input columns. Seeding right before construction makes that mixing identical
# on every run; without it train/val/test features change each time this cell
# is executed. Any other cell that fuses features for models trained on these
# files must build its layer from the same seed and input width, otherwise the
# models see differently mixed inputs.
torch.manual_seed(42)
attention_layer = MultiHeadAttention(input_dim=features.shape[1], num_heads=8)
with torch.no_grad():  # inference only: no autograd graph needed
    features_fused = attention_layer(features_tensor).numpy()

# Split into train and test sets
# 80 % train / 20 % hold-out, stratified on the quantile bins computed above
# (or unstratified when `stratify_param` is None).
train_features, test_features, train_target, test_target = train_test_split(
    features_fused, target, test_size=0.2, stratify=stratify_param, random_state=42
)

# Re-bin test set to avoid single-member classes in stratification
# Same search as for the full data set, but on the hold-out rows only and
# requiring at least 2 rows per bin, since the hold-out is halved next.
stratify_test = None
num_bins_test = min(10, len(np.unique(test_target)))
while num_bins_test > 1:
    discretizer_test = KBinsDiscretizer(n_bins=num_bins_test, encode='ordinal', strategy='quantile')
    binned_test_target = discretizer_test.fit_transform(test_target.values.reshape(-1, 1)).flatten()
    bin_counts_test = np.bincount(binned_test_target.astype(int))
    if np.min(bin_counts_test) >= 2:
        stratify_test = binned_test_target
        break
    num_bins_test -= 1

# Split into validation and test sets
# The hold-out is divided 50/50, giving 10 % validation and 10 % test overall.
val_features, test_features, val_target, test_target = train_test_split(
    test_features, test_target, test_size=0.5, stratify=stratify_test, random_state=42
)

# Save datasets
# Each file holds the fused feature columns (named 0..n-1 by pandas) followed
# by the log10-transformed target in a final 'Kd(nM)' column.
train_df = pd.DataFrame(train_features)
train_df['Kd(nM)'] = train_target.values
train_df.to_csv(train_output, index=False)

val_df = pd.DataFrame(val_features)
val_df['Kd(nM)'] = val_target.values
val_df.to_csv(val_output, index=False)

test_df = pd.DataFrame(test_features)
test_df['Kd(nM)'] = test_target.values
test_df.to_csv(test_output, index=False)

print("Final datasets saved successfully: train.csv, val.csv, test.csv")

In [None]:
############################################
# 1. Global Config & Paths
############################################
# Split files written by the dataset-assembly cell.
train_file    = os.path.join(working_dir, "train.csv")
val_file      = os.path.join(working_dir, "val.csv")
test_file     = os.path.join(working_dir, "test.csv")

# Dimensions (for your data: 2816 + 1152 = 3968)
# These are used positionally: the first `drug_dim` feature columns of each
# split feed the drug branch and the next `prot_dim` columns the protein branch.
# NOTE(review): the split files contain features AFTER the attention-fusion
# step of the dataset cell, so the columns may no longer be purely "drug" or
# purely "protein" features — confirm the positional split is intended.
drug_dim      = 2816  
prot_dim      = 1152  

# Hyperparams
batch_size    = 256
num_epochs    = 60
patience      = 8      # early-stopping patience, in epochs without val_loss improvement
learning_rate = 1e-4   # initial (peak) learning rate of the cosine schedule
l2_reg        = 1e-5   # L2 kernel-regularization strength
dropout_rate  = 0.2    # dropout used in the drug and protein sub-networks

# Ensemble
# Number of models trained with different seeds and averaged at test time.
ensemble_size = 3

# Output
best_model_basename = "best_model"   # checkpoints are named <basename>_seed<k>.keras
final_ensemble_file = os.path.join(working_dir, "ensemble_preds.npy")
final_plots_prefix  = "ensemble"     # filename prefix of the saved figures

############################################
# 2. Cosine Decay Callback
############################################
class CosineDecayRestarts(Callback):
    """
    Cosine decay with restarts for the learning rate.

    At the start of every epoch the optimizer's learning rate is set to a
    value on a half-cosine running from ``initial_lr`` (first epoch of a
    cycle) towards ``min_lr``. When a cycle completes, the schedule restarts
    at ``initial_lr`` and the cycle length is multiplied by ``restart_factor``.

    - initial_lr: learning rate at the start of each cycle
    - min_lr: lower bound the cosine decays towards
    - T_max: number of epochs per cycle (must be >= 1)
    - restart_factor: factor to expand T_max after each cycle

    Raises ValueError if ``T_max`` is smaller than 1.
    """
    def __init__(self, initial_lr=1e-4, min_lr=1e-6, T_max=10, restart_factor=1.5):
        super().__init__()
        if T_max < 1:
            raise ValueError("T_max must be at least 1 epoch")
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.T_max = T_max
        self.restart_factor = restart_factor
        self.epoch_since_restart = 0
        self.cycle_count = 0
        self._warned_not_assignable = False

    def on_epoch_begin(self, epoch, logs=None):
        cycle_progress = (self.epoch_since_restart % self.T_max) / float(self.T_max)
        cosine_term = 1.0 + np.cos(np.pi * cycle_progress)
        new_lr = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * cosine_term

        # Accept any variable-like learning rate (anything with `.assign`).
        # The previous isinstance(tf.Variable) test silently skipped the update
        # whenever the optimizer exposed a different variable class, leaving
        # the learning rate constant without any indication.
        lr_variable = getattr(self.model.optimizer, "learning_rate", None)
        if hasattr(lr_variable, "assign"):
            lr_variable.assign(new_lr)
        elif not self._warned_not_assignable:
            import warnings  # local import: only needed on this rare path
            warnings.warn(
                "CosineDecayRestarts: optimizer learning rate is not an assignable "
                "variable; the schedule is not being applied."
            )
            self._warned_not_assignable = True

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_since_restart += 1
        if self.epoch_since_restart >= self.T_max:
            self.epoch_since_restart = 0
            self.cycle_count += 1
            # Keep the cycle at least one epoch long so the division in
            # on_epoch_begin can never hit zero for restart_factor < 1.
            self.T_max = max(1, int(self.T_max * self.restart_factor))

############################################
# 3. Data Loading
############################################
def load_dataset(csv_path):
    """Load one split CSV and separate it into drug features, protein features and target.

    The file must contain exactly ``drug_dim + prot_dim`` feature columns plus
    a ``'Kd(nM)'`` target column. The target column is removed by name before
    the features are sliced by position, so it can sit anywhere in the file
    without leaking into the feature matrices (the previous version sliced
    first and relied on the target being the last column).

    Args:
        csv_path: path of the CSV file to read.

    Returns:
        Tuple ``(X_drug, X_prot, y)``: the first ``drug_dim`` feature columns,
        the following ``prot_dim`` feature columns, and the target as float32.

    Raises:
        ValueError: if the column count is wrong or ``'Kd(nM)'`` is missing.
    """
    df = pd.read_csv(csv_path)
    expected_cols = drug_dim + prot_dim + 1
    if df.shape[1] != expected_cols:
        raise ValueError(f"{csv_path} has {df.shape[1]} columns, expected {expected_cols}.")
    if "Kd(nM)" not in df.columns:
        raise ValueError(f"Column 'Kd(nM)' not found in {csv_path}.")

    feature_frame = df.drop(columns=["Kd(nM)"])
    X_drug = feature_frame.iloc[:, :drug_dim].values
    X_prot = feature_frame.iloc[:, drug_dim:drug_dim + prot_dim].values
    y      = df["Kd(nM)"].values.astype(np.float32)
    return X_drug, X_prot, y

# Read the three splits into module-level arrays used by training,
# validation (early stopping / checkpointing) and final evaluation.
print("Loading train data...")
X_drug_train, X_prot_train, y_train = load_dataset(train_file)
print("Loading validation data...")
X_drug_val, X_prot_val, y_val = load_dataset(val_file)
print("Loading test data...")
X_drug_test, X_prot_test, y_test = load_dataset(test_file)

############################################
# 4. Build Model
############################################
def build_model(
    drug_dim, 
    prot_dim, 
    hidden_drug=[1024, 512, 256],
    hidden_prot=[1024, 512, 256],
    hidden_merged=[1024, 512, 128, 64],
    dropout=0.2,
    l2_rate=1e-5,
    lr=1e-4
):
    """Build and compile the two-branch regression network.

    A drug branch and a protein branch (Dense -> BatchNorm -> ReLU -> Dropout
    per layer) are concatenated and fed to a merged branch: one projection to
    ``hidden_merged[0]`` units, one residual block of the same width, then one
    Dense/BatchNorm/LeakyReLU/Dropout stage per remaining entry of
    ``hidden_merged``, and a single linear output unit.

    ``hidden_merged`` used to be accepted but ignored (the widths were
    hard-coded); it is now honored. The default reproduces the previous
    architecture exactly, including layer creation order.

    Args:
        drug_dim: number of drug input features.
        prot_dim: number of protein input features.
        hidden_drug: layer widths of the drug branch.
        hidden_prot: layer widths of the protein branch.
        hidden_merged: first entry is the width of the merged projection and
            residual block; remaining entries are the reducing layers.
        dropout: dropout rate of the drug and protein branches. The merged
            branch keeps its fixed rate of 0.1.
        l2_rate: L2 kernel-regularization strength for all hidden Dense layers.
        lr: Adam learning rate.

    Returns:
        Compiled Keras model with inputs named ``"drug_input"`` and
        ``"protein_input"``, Huber loss (delta=1.0) and MAE metric.

    Raises:
        ValueError: if ``hidden_merged`` is empty.
    """
    if not hidden_merged:
        raise ValueError("hidden_merged must contain at least one layer width")
    merged_width = hidden_merged[0]
    reduction_widths = list(hidden_merged[1:])
    merged_dropout = 0.1  # fixed rate of the merged branch

    reg = regularizers.l2(l2_rate)

    def dense_stage(tensor, width, activation_cls, rate):
        """Dense -> BatchNorm -> activation -> Dropout."""
        tensor = layers.Dense(width, kernel_regularizer=reg)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        tensor = activation_cls()(tensor)
        return layers.Dropout(rate)(tensor)

    # Drug sub-network
    drug_input = Input(shape=(drug_dim,), name="drug_input")
    xA = drug_input
    for size in hidden_drug:
        xA = dense_stage(xA, size, layers.ReLU, dropout)

    # Protein sub-network
    prot_input = Input(shape=(prot_dim,), name="protein_input")
    xB = prot_input
    for size in hidden_prot:
        xB = dense_stage(xB, size, layers.ReLU, dropout)

    # Merge sub-networks
    merged = layers.Concatenate(axis=1)([xA, xB])

    # Merged branch: projection followed by a residual block of equal width
    # (equal width is required so the skip connection can be added).
    x = dense_stage(merged, merged_width, layers.LeakyReLU, merged_dropout)
    residual = dense_stage(x, merged_width, layers.LeakyReLU, merged_dropout)
    x = layers.add([x, residual])

    # Further layers to reduce dimensions
    for size in reduction_widths:
        x = dense_stage(x, size, layers.LeakyReLU, merged_dropout)

    # Output layer
    output = layers.Dense(1, activation='linear', name="Kd_output")(x)

    model = Model(inputs=[drug_input, prot_input], outputs=output)
    model.compile(
        optimizer=Adam(learning_rate=lr),
        loss=Huber(delta=1.0),
        metrics=["mae"]
    )
    return model

############################################
# 5. Training a Single Model
############################################
def train_single_model(seed=0):
    """Train one ensemble member and return it with its best checkpoint loaded.

    Seeds the random number generators, builds a fresh model from the
    module-level hyperparameters, and fits it on the module-level training
    arrays with cosine learning-rate restarts, early stopping on ``val_loss``
    and best-only checkpointing to
    ``<working_dir>/<best_model_basename>_seed<seed>.keras``.

    Args:
        seed: random seed for this ensemble member; also part of the
            checkpoint filename.

    Returns:
        Tuple ``(model, history)``: the model with the checkpointed weights
        loaded, and the Keras ``History`` object returned by ``fit``.
    """
    # Seeds Python's `random`, NumPy and TensorFlow in one call. The previous
    # np.random.seed + tf.random.set_seed pair left Python's `random` unseeded.
    tf.keras.utils.set_random_seed(seed)

    model = build_model(
        drug_dim=drug_dim,
        prot_dim=prot_dim,
        dropout=dropout_rate,
        l2_rate=l2_reg,
        lr=learning_rate
    )

    cos_callback = CosineDecayRestarts(
        initial_lr=learning_rate,
        min_lr=1e-6,
        T_max=10,
        restart_factor=1.5
    )
    early_stopping = EarlyStopping(monitor='val_loss', patience=patience, restore_best_weights=True)
    model_ckpt_path = os.path.join(working_dir, f"{best_model_basename}_seed{seed}.keras")
    checkpoint = ModelCheckpoint(
        model_ckpt_path,
        monitor='val_loss',
        save_best_only=True
    )

    history = model.fit(
        x={"drug_input": X_drug_train, "protein_input": X_prot_train},
        y=y_train,
        validation_data=(
            {"drug_input": X_drug_val, "protein_input": X_prot_val}, 
            y_val
        ),
        epochs=num_epochs,
        batch_size=batch_size,
        callbacks=[cos_callback, early_stopping, checkpoint],
        verbose=1
    )

    # Reload the best-val_loss weights from disk so the returned model matches
    # the file that later cells load by name.
    model.load_weights(model_ckpt_path)
    return model, history

############################################
# 6. Ensemble Training
############################################
# Train `ensemble_size` models that differ only in their random seed
# (0, 1, 2, ...). Histories are kept for the averaged loss plot; the trained
# models are kept for test-set prediction.
all_histories = []
models = []

for seed_idx in range(ensemble_size):
    print(f"\n=== Training Ensemble Model #{seed_idx+1}/{ensemble_size} (seed={seed_idx}) ===")
    trained_model, history = train_single_model(seed=seed_idx)
    all_histories.append(history)
    models.append(trained_model)

############################################
# 7. Ensemble Prediction on Test Set
############################################
print("\n=== Ensemble Prediction on Test Set ===")
# One flat prediction vector per ensemble member.
pred_list = []
for m in models:
    preds = m.predict({"drug_input": X_drug_test, "protein_input": X_prot_test}).flatten()
    pred_list.append(preds)

# The ensemble prediction is the element-wise mean over the members.
ensemble_preds = np.mean(pred_list, axis=0)
np.save(final_ensemble_file, ensemble_preds)

# Metrics are computed against `y_test` as stored in test.csv. The dataset
# cell writes the target after applying log10(Kd + 1), so these numbers are in
# that transformed space, not in raw nM.
mse  = mean_squared_error(y_test, ensemble_preds)
mae  = mean_absolute_error(y_test, ensemble_preds)
rmse = np.sqrt(mse)
r2   = r2_score(y_test, ensemble_preds)

print("\n=== Final Ensemble Test Results ===")
print(f"MSE:  {mse:.4f}")
print(f"MAE:  {mae:.4f}")
print(f"RMSE: {rmse:.4f}")
print(f"R²:   {r2:.4f}")

############################################
# 8. Plotting
############################################
# --- Averaged loss curves -------------------------------------------------
# Early stopping lets ensemble members run for different numbers of epochs,
# so each epoch is averaged only over the members that actually reached it.
max_epochs = max(len(h.history["loss"]) for h in all_histories)
avg_train_loss = np.zeros(max_epochs)
avg_val_loss   = np.zeros(max_epochs)
count_epochs   = np.zeros(max_epochs)

for hist in all_histories:
    epochs_run = len(hist.history["loss"])
    avg_train_loss[:epochs_run] += hist.history["loss"]
    avg_val_loss[:epochs_run]   += hist.history["val_loss"]
    count_epochs[:epochs_run]   += 1

reached = count_epochs > 0
avg_train_loss[reached] /= count_epochs[reached]
avg_val_loss[reached]   /= count_epochs[reached]

plt.figure(figsize=(8,5))
epochs_range = range(1, max_epochs+1)
plt.plot(epochs_range, avg_train_loss[:max_epochs], label="Train Loss (avg)", linewidth=2)
plt.plot(epochs_range, avg_val_loss[:max_epochs],   label="Val Loss (avg)", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("Huber Loss")
plt.title("Ensemble: Training vs Validation Loss (Averaged)")
plt.legend()
loss_curve_path = os.path.join(working_dir, f"{final_plots_prefix}_loss_curve.png")
plt.savefig(loss_curve_path, dpi=1200)
plt.show()

# --- Predicted vs actual ----------------------------------------------------
# Both axes are in the transformed target space: test.csv stores
# log10(Kd + 1), and the models are trained on that. The axis labels say so
# explicitly (they previously claimed raw "Kd(nM)").
plt.figure(figsize=(6,6))
plt.scatter(y_test, ensemble_preds, alpha=0.5)
plt.xlabel("Actual log10(Kd(nM) + 1)")
plt.ylabel("Predicted log10(Kd(nM) + 1)")
plt.title("Ensemble: Predicted vs Actual Kd")
pred_vs_act_path = os.path.join(working_dir, f"{final_plots_prefix}_predicted_vs_actual.png")
plt.savefig(pred_vs_act_path, dpi=1200)
plt.show()

# --- Residual histogram -----------------------------------------------------
residuals = y_test - ensemble_preds
plt.figure(figsize=(7,5))
plt.hist(residuals, bins=30, alpha=0.7)
plt.xlabel("Residual (Actual - Predicted), log10(Kd(nM) + 1)")
plt.ylabel("Frequency")
plt.title("Ensemble: Residual Distribution")
residual_plot_path = os.path.join(working_dir, f"{final_plots_prefix}_residual_plot.png")
plt.savefig(residual_plot_path, dpi=1200)
plt.show()

print("\nAll ensemble plots saved:")
print(loss_curve_path)
print(pred_vs_act_path)
print(residual_plot_path)
print("\nDone!")

In [None]:
######################################################
# 1. Configuration
######################################################
# Input CSVs
# New_drugs.csv is read for 'DrugBank_ID' and 'Ligand_SMILES';
# New_Proteins.csv for 'UniProt_ID' and 'Sequence' (see the loops below).
new_drugs_file    = os.path.join(working_dir, "New_drugs.csv")
new_protein_file  = os.path.join(working_dir, "New_Proteins.csv")

# auto-generate all pairs
# Written by Step A below as the full cross product of drugs and proteins.
new_pairs_file    = os.path.join(working_dir, "New_Pairs.csv")

# Output for intermediate embeddings
chemberta_drug_csv  = os.path.join(working_dir, "ChemBERTa_new_drug_embeddings.csv")
morgan_drug_csv     = os.path.join(working_dir, "Morgan_new_fingerprints.csv")
protT5_csv          = os.path.join(working_dir, "ProtT5_new_protein_embeddings.csv")
cnn_csv             = os.path.join(working_dir, "CNN_new_protein_embeddings.csv")

# Final fused features after multi-head attention
final_fused_csv     = os.path.join(working_dir, "New_Features_Fused.csv")

# Drug/protein dimension assumptions after merging
# Same values as in the training cell. They are only declared here; nothing in
# the original steps of this cell reads them.
drug_dim = 2816
prot_dim = 1152

######################################################
# 2. Step A: Create New_Pairs.csv (Cartesian product)
######################################################
print("=== Creating all possible drug–protein pairs ===")
df_drug = pd.read_csv(new_drugs_file)     # expects columns: [DrugBank_ID, ...]
df_prot = pd.read_csv(new_protein_file)   # expects columns: [UniProt_ID, ...]

# Cartesian product, drug-major: every protein is listed for the first drug,
# then every protein for the second drug, and so on.
pairs_list = [
    [d_id, p_id]
    for d_id in df_drug["DrugBank_ID"]
    for p_id in df_prot["UniProt_ID"]
]

pairs_df = pd.DataFrame(pairs_list, columns=["DrugBank_ID", "UniProt_ID"])
pairs_df.to_csv(new_pairs_file, index=False)
print(f"New_Pairs.csv created with {len(pairs_df)} rows.")

######################################################
# 3. Step B: Extract Drug Embeddings
######################################################
print("=== Extracting drug embeddings (ChemBERTa + Morgan) ===")
# Force device to CPU explicitly for PyTorch
# Unlike the training-time feature cells, no CUDA check is made here; every
# encoder in this cell runs on the CPU.
device = torch.device("cpu")

# Same ChemBERTa checkpoint as the training-time drug cell.
tokenizer_chem = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model_chem     = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device)

def tokenize_smiles(smiles, max_len=512, stride=256):
    """Tokenize one SMILES string into overlapping windows of token ids.

    The string is tokenized WITHOUT truncation and the id tensor is cut into
    windows of at most ``max_len`` tokens whose starts are ``stride`` tokens
    apart. (The previous version passed ``truncation=True``, which capped the
    ids at ``max_len`` so the windowing branch was unreachable and long inputs
    were silently cut.)

    Note: this redefines the ``tokenize_smiles`` of the training-time drug
    cell, using ``tokenizer_chem`` and the parameter name ``max_len``.

    Args:
        smiles: SMILES string to tokenize.
        max_len: maximum number of tokens per window.
        stride: distance between consecutive window starts; must not exceed
            ``max_len`` or tokens between windows would be skipped.

    Returns:
        A list with one ``(1, n_tokens)`` id tensor per window.

    Raises:
        ValueError: if ``max_len``/``stride`` are not positive or
            ``stride > max_len``.
    """
    if max_len <= 0 or stride <= 0:
        raise ValueError("max_len and stride must be positive")
    if stride > max_len:
        raise ValueError("stride must not exceed max_len, otherwise tokens are skipped")

    # Truncation stays off on purpose: windowing below is what bounds the length.
    input_ids = tokenizer_chem(smiles, return_tensors="pt", truncation=False)["input_ids"]
    n_tokens = input_ids.shape[1]

    chunks = []
    for i in range(0, max(n_tokens, 1), stride):
        chunks.append(input_ids[:, i:i + max_len])
        # Stop once a window reaches the end; later windows would be subsets.
        if i + max_len >= n_tokens:
            break
    return chunks

def get_chemberta_embedding(smiles):
    """Return one fixed-size ChemBERTa vector for a SMILES string.

    Each token window from ``tokenize_smiles`` is encoded by ``model_chem``
    with gradients disabled and mean-pooled over the token axis; the
    per-window vectors are then averaged (unweighted).

    Returns:
        1-D NumPy array with the encoder's hidden size.
    """
    def _encode(window):
        window = window.to(device)
        with torch.no_grad():
            result = model_chem(input_ids=window)
        # Average over tokens -> (1, hidden) array for this window.
        return result.last_hidden_state.mean(dim=1).cpu().numpy()

    window_embeddings = [_encode(window) for window in tokenize_smiles(smiles)]
    return np.vstack(window_embeddings).mean(axis=0)

def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
    """Return the Morgan bit-vector fingerprint of a SMILES string as an array.

    Returns a 1-D NumPy array of length ``n_bits``. If RDKit cannot parse the
    string (``Chem.MolFromSmiles`` returns ``None``) an all-zero vector is
    returned instead of raising.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        fingerprint = AllChem.GetMorganFingerprintAsBitVect(parsed, radius, nBits=n_bits)
        return np.array(fingerprint)
    # Unparseable SMILES: fall back to an empty fingerprint.
    return np.zeros(n_bits)

drugs_df = pd.read_csv(new_drugs_file)
chemberta_list = []
morgan_list    = []

# One row per drug: [DrugBank_ID, feature_0, ..., feature_n-1].
for _, row in tqdm(drugs_df.iterrows(), total=len(drugs_df)):
    d_id   = row["DrugBank_ID"]
    smiles = row["Ligand_SMILES"]

    chem_emb = get_chemberta_embedding(smiles)
    morgan   = get_morgan_fingerprint(smiles)

    chemberta_list.append([d_id] + chem_emb.tolist())
    morgan_list.append([d_id] + morgan.tolist())

if not chemberta_list:
    raise ValueError(f"No drugs found in {new_drugs_file}; nothing to embed.")

chem_dim    = len(chemberta_list[0]) - 1
morgan_dim  = len(morgan_list[0])    - 1

chem_cols    = ["DrugBank_ID"] + [f"chem_{i}"   for i in range(chem_dim)]
morgan_cols  = ["DrugBank_ID"] + [f"morgan_{i}" for i in range(morgan_dim)]

chem_df   = pd.DataFrame(chemberta_list, columns=chem_cols)
morgan_df = pd.DataFrame(morgan_list,    columns=morgan_cols)

# Scale each
# The scaled values are placed in freshly built float columns. Fingerprint
# bits are stored as integers, and writing float z-scores into integer columns
# through `.iloc[...] = ...` relies on an implicit dtype upcast that pandas
# has deprecated.
# NOTE(review): both scalers are fitted on the NEW drugs only, so these
# features are standardized with different statistics than the training
# features were. With a single new drug every scaled value becomes 0.
# Reusing the scalers fitted at training time would avoid this, but they are
# not persisted anywhere in this notebook.
scalerA = StandardScaler()
chem_df = pd.concat(
    [
        chem_df[["DrugBank_ID"]],
        pd.DataFrame(
            scalerA.fit_transform(chem_df[chem_cols[1:]]),
            columns=chem_cols[1:],
            index=chem_df.index,
        ),
    ],
    axis=1,
)

scalerB = StandardScaler()
morgan_df = pd.concat(
    [
        morgan_df[["DrugBank_ID"]],
        pd.DataFrame(
            scalerB.fit_transform(morgan_df[morgan_cols[1:]]),
            columns=morgan_cols[1:],
            index=morgan_df.index,
        ),
    ],
    axis=1,
)

chem_df.to_csv(chemberta_drug_csv, index=False)
morgan_df.to_csv(morgan_drug_csv,   index=False)

######################################################
# 4. Step C: Extract Protein Embeddings
######################################################
print("=== Extracting protein embeddings (ProtT5 + CNN) ===")
# Same ProtT5 checkpoint as the training-time protein cell; only the encoder
# is loaded and it is placed on `device` (set to CPU earlier in this cell).
tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model_t5     = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device)

def tokenize_protein(seq, max_len=512, stride=256):
    """Tokenize one protein sequence into overlapping windows of token ids.

    Two defects of the previous version are fixed here:

    * ProtT5 expects residues separated by single spaces with the rare
      residues U, Z, O and B replaced by X (see the model card); the raw
      string was passed unchanged before.
    * Tokenization used ``truncation=True``, so the ids never exceeded
      ``max_len`` and the windowing branch was unreachable; long proteins were
      silently cut. Truncation is now off and windowing bounds the size.

    Args:
        seq: amino-acid string; whitespace already present is removed before
            the residues are re-spaced.
        max_len: maximum number of tokens per window.
        stride: distance between consecutive window starts; must not exceed
            ``max_len`` or tokens between windows would be skipped.

    Returns:
        A list with one ``(1, n_tokens)`` id tensor per window.

    Raises:
        ValueError: if ``max_len``/``stride`` are not positive or
            ``stride > max_len``.
    """
    import re  # local import keeps this cell runnable on its own

    if max_len <= 0 or stride <= 0:
        raise ValueError("max_len and stride must be positive")
    if stride > max_len:
        raise ValueError("stride must not exceed max_len, otherwise tokens are skipped")

    residues = re.sub(r"[UZOB]", "X", "".join(seq.split()))
    spaced_seq = " ".join(residues)

    input_ids = tokenizer_t5(spaced_seq, return_tensors="pt", truncation=False)["input_ids"]
    n_tokens = input_ids.shape[1]

    chunks = []
    for i in range(0, max(n_tokens, 1), stride):
        chunks.append(input_ids[:, i:i + max_len])
        # Stop once a window reaches the end; later windows would be subsets.
        if i + max_len >= n_tokens:
            break
    return chunks

def get_protT5_embedding(seq):
    """Return one fixed-size ProtT5 vector for a protein sequence.

    Each token window from ``tokenize_protein`` is encoded by ``model_t5``
    with gradients disabled and mean-pooled over the token axis; the
    per-window vectors are then averaged (unweighted).

    Returns:
        1-D NumPy array with the encoder's hidden size.
    """
    pooled_windows = []
    for window in tokenize_protein(seq):
        window_on_device = window.to(device)
        with torch.no_grad():
            states = model_t5(input_ids=window_on_device).last_hidden_state
        # Token-axis mean, converted to a (1, hidden) NumPy array.
        pooled_windows.append(states.mean(dim=1).cpu().numpy())
    return np.vstack(pooled_windows).mean(axis=0)

# CNN for ASCII
# This encoder is never trained, so its embedding is a fixed random function
# of the input. The prediction models were fitted on CNN features produced by
# another randomly initialized encoder; unless both are built from the same
# seed with the same layer order (or the weights are saved and reloaded), the
# features computed here are unrelated to the ones seen during training.
# Seeding right before construction makes the weights reproducible.
tf.keras.utils.set_random_seed(42)
cnn_model = Sequential([
    Conv1D(64, kernel_size=3, activation='relu', input_shape=(1000, 1)),
    GlobalAveragePooling1D(),
    Dense(128, activation='relu')
])

def get_cnn_embedding(seq):
    """Embed a sequence with ``cnn_model`` from its character code points.

    The first 1000 characters are converted with ``ord`` into a float32
    vector, shorter sequences are zero-padded on the right to length 1000,
    and the ``(1, 1000, 1)`` array is passed through ``cnn_model.predict``.

    Returns:
        The first (only) row of the model's prediction.
    """
    codes = np.zeros(1000, dtype=np.float32)
    head = seq[:1000]
    # Fill the leading positions; the remainder stays zero (padding).
    codes[:len(head)] = [ord(ch) for ch in head]
    batch = codes.reshape(1, 1000, 1)
    return cnn_model.predict(batch, verbose=0)[0]

prot_df = pd.read_csv(new_protein_file)
t5_list  = []
cnn_list = []

# One row per protein: [UniProt_ID, feature_0, ..., feature_n-1].
for _, row in tqdm(prot_df.iterrows(), total=len(prot_df)):
    p_id, seq = row["UniProt_ID"], row["Sequence"]

    t5_emb = get_protT5_embedding(seq)
    c_emb  = get_cnn_embedding(seq)

    t5_list.append([p_id, *t5_emb.tolist()])
    cnn_list.append([p_id, *c_emb.tolist()])

# Feature counts are read off the first collected row of each list.
t5_dim  = len(t5_list[0]) - 1
cnn_dim = len(cnn_list[0]) - 1

t5_cols  = ["UniProt_ID"] + [f"t5_{i}" for i in range(t5_dim)]
cnn_cols = ["UniProt_ID"] + [f"cnn_{i}" for i in range(cnn_dim)]

t5_new_df  = pd.DataFrame(t5_list, columns=t5_cols)
cnn_new_df = pd.DataFrame(cnn_list, columns=cnn_cols)

# Standardize every feature column (all columns after the ID) with a scaler
# fitted on these new proteins.
scalerC = StandardScaler()
t5_scaled = scalerC.fit_transform(t5_new_df.iloc[:, 1:])
t5_new_df.iloc[:, 1:] = t5_scaled

scalerD = StandardScaler()
cnn_scaled = scalerD.fit_transform(cnn_new_df.iloc[:, 1:])
cnn_new_df.iloc[:, 1:] = cnn_scaled

t5_new_df.to_csv(protT5_csv, index=False)
cnn_new_df.to_csv(cnn_csv, index=False)

######################################################
# 5. Step D: Merge All + Create Final Feature Matrix
######################################################
print("=== Merging drug + protein embeddings with New_Pairs.csv ===")
pairs_df     = pd.read_csv(new_pairs_file)
drug_chem_df = pd.read_csv(chemberta_drug_csv)
drug_morg_df = pd.read_csv(morgan_drug_csv)

prot_t5_df  = pd.read_csv(protT5_csv)
prot_cnn_df = pd.read_csv(cnn_csv)

# Merge drug embeddings => shape (Ndrugs, 1 + dims)
drug_all = drug_chem_df.merge(drug_morg_df, on="DrugBank_ID", how="inner")
# Merge protein embeddings => shape (Nprots, 1 + dims)
prot_all = prot_t5_df.merge(prot_cnn_df, on="UniProt_ID", how="inner")

# Now merge with pairs => final shape: [DrugBank_ID, UniProt_ID, ~2816 drug cols, ~1152 prot cols]
merge1 = pairs_df.merge(drug_all, on="DrugBank_ID", how="inner")
merge2 = merge1.merge(prot_all, on="UniProt_ID", how="inner")
print("Final merged shape:", merge2.shape)

# Inner joins drop pairs without features and duplicate IDs multiply rows;
# either way the result no longer lines up with New_Pairs.csv, so say so.
if len(merge2) != len(pairs_df):
    print(f"Warning: New_Pairs.csv has {len(pairs_df)} rows but the merged table has {len(merge2)}.")

# Feature matrix
feat_cols = [c for c in merge2.columns if c not in ["DrugBank_ID", "UniProt_ID"]]
X_merged  = merge2[feat_cols].values  # shape (N, 3968)

# Fail here, with a readable message, rather than later inside the prediction
# models when the feature width does not match what they were trained on.
expected_features = drug_dim + prot_dim
if X_merged.shape[1] != expected_features:
    raise ValueError(
        f"Merged feature matrix has {X_merged.shape[1]} columns, "
        f"expected {expected_features} (drug_dim={drug_dim} + prot_dim={prot_dim})."
    )

# NOTE(review): the training pipeline applies one more StandardScaler to the
# merged feature matrix before attention fusion; no such step exists here, so
# the fusion input is scaled differently from training — confirm and align.

######################################################
# 6. Step E: Multi-Head Attention Fusion
######################################################
print("=== Applying Multi-Head Attention across columns ===")

class MyMHA(nn.Module):
    """Self-attention over flat feature vectors via ``nn.MultiheadAttention``.

    Every row is treated as a one-step sequence that attends to itself
    (query = key = value), so the output keeps the input's
    ``(batch, input_dim)`` shape.

    Args:
        input_dim: size of each feature vector; passed as ``embed_dim``.
        num_heads: number of attention heads.
    """

    def __init__(self, input_dim, num_heads=8):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, x):
        # (batch, input_dim) -> (batch, 1, input_dim)
        one_step = x.unsqueeze(dim=1)
        mixed, _attn_weights = self.mha(one_step, one_step, one_step)
        # Drop the singleton sequence axis again.
        return mixed.squeeze(dim=1)

# The fusion layer is never trained, so its output is a fixed random mixing of
# the input columns. The prediction models were fitted on features mixed by a
# layer created in the dataset-assembly cell; this layer only reproduces that
# mixing if both are constructed from the same seed and input width. Seeding
# right before construction also makes this cell's output reproducible.
torch.manual_seed(42)
attn_layer = MyMHA(input_dim=X_merged.shape[1], num_heads=8)
with torch.no_grad():  # inference only: no autograd graph needed
    fused_tensor = torch.tensor(X_merged, dtype=torch.float32)
    fused_out    = attn_layer(fused_tensor).numpy()

print("Fused feature shape:", fused_out.shape)

# Save final to CSV
fused_cols = [f"feature_{i}" for i in range(fused_out.shape[1])]
fused_df   = pd.DataFrame(fused_out, columns=fused_cols)
# Keep IDs for reference
# Rows of `fused_out` are in the same order as `merge2`, so the IDs can be
# attached positionally.
fused_df["DrugBank_ID"] = merge2["DrugBank_ID"].values
fused_df["UniProt_ID"]  = merge2["UniProt_ID"].values

fused_df.to_csv(final_fused_csv, index=False)
print(f"Saved final fused features to {final_fused_csv}")
print("Feature extraction + attention done!")

In [None]:
############################################################
# 1. Configuration
############################################################

# Fused features from the first script
# i.e. the file written at the end of the new-data feature-extraction cell.
final_fused_csv = os.path.join(working_dir, "New_Features_Fused.csv")

# Ensemble of 3 seeds
# These names match the checkpoints written by the training cell
# (<best_model_basename>_seed<k>.keras with seeds 0, 1, 2).
ensemble_models = [
    os.path.join(working_dir, "best_model_seed0.keras"),
    os.path.join(working_dir, "best_model_seed1.keras"),
    os.path.join(working_dir, "best_model_seed2.keras"),
]

# Final outputs
pred_csv        = os.path.join(working_dir, "New_Pairs_PredictedKd.csv")   # all pairs
top20_csv       = os.path.join(working_dir, "Top20_LowestKd.csv")          # 20 lowest predicted Kd
hist_plot       = os.path.join(working_dir, "Kd_Distribution_NewData.png")
bar_plot        = os.path.join(working_dir, "Top20_NewData.png")

# Drug/protein dimensions (these must match what the model expects)
# Used to cut the fused feature matrix positionally into the two model inputs.
drug_dim = 2816
prot_dim = 1152

############################################################
# 2. Load Fused Features
############################################################
fused_df = pd.read_csv(final_fused_csv)
# Every column except the two ID columns is a feature, in file order.
feat_cols = [c for c in fused_df.columns if c not in ["DrugBank_ID", "UniProt_ID"]]
X_fused = fused_df[feat_cols].values

# Keep IDs for reference
id_drug = fused_df["DrugBank_ID"].values
id_prot = fused_df["UniProt_ID"].values

print("Loaded fused features shape:", X_fused.shape)

# The models take exactly drug_dim + prot_dim features. Check that up front so
# a mismatch is reported clearly instead of surfacing as a shape error inside
# model.predict (or, for extra columns, silently widening the protein input).
expected_features = drug_dim + prot_dim
if X_fused.shape[1] != expected_features:
    raise ValueError(
        f"{final_fused_csv} has {X_fused.shape[1]} feature columns, "
        f"expected {expected_features} (drug_dim={drug_dim} + prot_dim={prot_dim})."
    )

# Split into drug and protein inputs for the model
X_drug_fused = X_fused[:, :drug_dim]                      # First 2816 columns → Drug features
X_prot_fused = X_fused[:, drug_dim:drug_dim + prot_dim]   # Last 1152 columns → Protein features

############################################################
# 3. Ensemble Prediction
############################################################
# Predict with every checkpoint that exists on disk; missing ones are skipped
# with a warning so a partial ensemble can still be used.
pred_list = []
for model_path in ensemble_models:
    if not os.path.isfile(model_path):
        print(f"Warning: {model_path} not found. Skipping!")
        continue
    print(f"Loading {model_path}")
    # compile=False: the model is only used for inference here.
    model = load_model(model_path, compile=False)

    # Pass them as two separate inputs
    preds = model.predict({"drug_input": X_drug_fused, "protein_input": X_prot_fused}, verbose=0).flatten()
    pred_list.append(preds)

# Raise instead of calling exit(): inside a notebook exit() terminates the
# kernel, discarding all state, and it does not stop a non-interactive run
# with a meaningful error.
if len(pred_list) == 0:
    raise FileNotFoundError(
        "No models loaded; none of these checkpoints exist: " + ", ".join(ensemble_models)
    )

# Element-wise mean over the ensemble members that were found.
ensemble_preds = np.mean(pred_list, axis=0)

# If you trained with log10(Kd+1), invert:
kd_preds = 10**ensemble_preds - 1

############################################################
# 4. Save Full Predictions
############################################################
# One row per drug–protein pair, in the same order as the fused feature file.
# 'Predicted_Kd' holds the back-transformed value (10**prediction - 1).
results_df = pd.DataFrame({
    "DrugBank_ID": id_drug,
    "UniProt_ID":  id_prot,
    "Predicted_Kd": kd_preds
})
results_df.to_csv(pred_csv, index=False)
print(f"Saved predictions to {pred_csv}")

############################################################
# 5. Top 20 + Plots
############################################################
# The 20 pairs with the lowest predicted Kd, ascending.
top20_df = results_df.sort_values("Predicted_Kd", ascending=True).head(20)
top20_df.to_csv(top20_csv, index=False)
print(f"Saved top-20 to {top20_csv}")

# Distribution plot
plt.figure(figsize=(7,5))
plt.hist(kd_preds, bins=30, alpha=0.7)
plt.title("Predicted Kd Distribution (New Data)")
plt.xlabel("Predicted Kd (nM)")
plt.ylabel("Frequency")
plt.savefig(hist_plot, dpi=1200)
plt.show()

# Top 20 bar chart
# One horizontal bar per pair, labelled "<drug>-<protein>"; the y axis is
# inverted so the lowest Kd is drawn at the top.
bar_positions = np.arange(len(top20_df))
labels = [
    f"{drug}-{protein}"
    for drug, protein in zip(top20_df["DrugBank_ID"], top20_df["UniProt_ID"])
]
plt.figure(figsize=(10,6))
plt.barh(bar_positions, top20_df["Predicted_Kd"], align="center")
plt.yticks(bar_positions, labels)
plt.xlabel("Predicted Kd (nM)")
plt.title("Top 20 Lowest Predicted Kd (New Data)")
plt.gca().invert_yaxis()
plt.savefig(bar_plot, dpi=1200)
plt.show()

print("Plots saved:")
for saved_path in (hist_plot, bar_plot):
    print(" ", saved_path)
print("\nPrediction complete!")