In [None]:
# Mount Google Drive at /content/drive; every data, structure and model path below lives under it.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
!sudo apt install build-essential

In [None]:
# Install the project's requirements file from Drive, then the extra packages this notebook imports.
%pip install -r "/content/drive/MyDrive/Protein-binding/requirements.txt"
%pip install datasets mdtraj dssp
# numpy is installed with pip's cache disabled; pandas is pinned to 2.2.0.
%pip install numpy --no-cache-dir
%pip install pandas==2.2.0
%pip install pykan

In [None]:
import torch
import ast
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.nn.functional as F
import warnings
import mdtraj as md
import esm
import gc
import pickle
import torch.optim as optim

from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_add_pool
from Bio import SeqIO
from Bio import PDB
from Bio.PDB.DSSP import DSSP
from Bio.PDB import Selection

from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from torch_geometric.data import Data
from transformers import (AutoModelForTokenClassification, AutoTokenizer,
                          AutoModelForMaskedLM, DataCollatorForTokenClassification,
                           EsmForMaskedLM, EsmTokenizer, EsmModel, EsmForTokenClassification,
                           TrainingArguments, Trainer, TrainerCallback
                        )
from kan import KAN
from transformers.trainer_callback import ProgressCallback
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             matthews_corrcoef, roc_auc_score)
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from pprint import pprint
from datasets import Dataset
from datetime import datetime
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from glob import glob
from loguru import logger

warnings.filterwarnings("ignore", message="Ignoring unrecognized record 'END'")

### Preparing train-test dataset

In [None]:
# Load the development-set training table. The per-residue label columns are stored in the CSV
# as stringified Python lists, so each one is parsed back into a real list.
initial_train_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/development_set/full_grouped_train_binding_sites_df.csv")
for list_column in ('binding_sites', 'any_ligand_binding_sites', 'metal_binding',
                    'small_binding', 'nuclear_binding'):
    initial_train_df[list_column] = initial_train_df[list_column].apply(ast.literal_eval)

In [None]:
# Seed numpy's global RNG for reproducibility.
np.random.seed(42)
# Protein IDs dropped from the training table.
excluded_protein_id = ['Q9NZV6']

# Rows keep their original index labels; later cells address them positionally with .iloc.
train_df = initial_train_df[~initial_train_df['prot_id'].isin(excluded_protein_id)]

In [None]:
# Load the development-set test table (swap in the commented path for the independent set).
test_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/development_set/full_grouped_test_binding_sites_df.csv")
# test_df = pd.read_csv("/content/drive/MyDrive/Protein-binding/data/independent_set/grouped_test_46_new_binding_sites.csv")

# The label columns are stringified Python lists in the CSV; parse each back into a list.
for list_column in ('binding_sites', 'any_ligand_binding_sites', 'metal_binding',
                    'small_binding', 'nuclear_binding'):
    test_df[list_column] = test_df[list_column].apply(ast.literal_eval)

In [None]:
# Initial sequences
test_seq = test_df['sequence'].tolist()
test_labels = test_df['metal_binding'].tolist() # Edit labels

train_seq = train_df['sequence'].tolist()
train_labels = train_df['metal_binding'].tolist() # Edit labels

saved_model_name = '/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_metal_binding_20Apr_model.pth'
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### Tokenization and get embeddings from ESM-2 language model

In [None]:
# ESM-2 checkpoint used for both tokenization and embedding extraction.
pretrained_model = "facebook/esm2_t33_650M_UR50D"

tokenizer = EsmTokenizer.from_pretrained(pretrained_model)
# Token budget per sequence; longer sequences are truncated by the tokenizer call below.
max_sequence_length = 1000

# All training sequences are padded to a common length and returned as PyTorch tensors.
train_tokenized = tokenizer(train_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
# test_tokenized = tokenizer(test_seq, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

In [None]:
def get_embeddings_list(tokenized_dataset, batch_size, model_name="facebook/esm2_t33_650M_UR50D", device="cuda",
                       embedding_mode="multi_layer", num_layers=4, return_hidden_states=True, return_attentions=False):
    """
    Extract per-residue ESM-2 embeddings with multi-layer averaging or attention-guided mixing.

    Args:
        tokenized_dataset: Mapping with 'input_ids' and 'attention_mask' tensors of shape
            [num_sequences, max_length], and optionally 'sequence_id'.
        batch_size: Number of sequences per forward pass (must be >= 1).
        model_name: Pretrained ESM-2 model name.
        device: Device to run the model on ('cuda' or 'cpu').
        embedding_mode: 'multi_layer' averages the last `num_layers` hidden states;
            'attention_guided' re-weights the last hidden state with the last layer's
            head-averaged attention.
        num_layers: Number of trailing layers to average (multi_layer mode, must be >= 1).
        return_hidden_states: Kept for backward compatibility. Both modes need the hidden
            states, so they are always requested from the model regardless of this flag.
        return_attentions: Must be True for 'attention_guided' mode.

    Returns:
        List of dicts, one per sequence, with keys 'sequence_id' and 'embeddings'
        (numpy array [num_residues, embedding_dim], special tokens removed).

    Raises:
        ValueError: On an unknown embedding_mode, on attention_guided without
            return_attentions, or on a non-positive batch_size / num_layers.
    """
    # Validate everything up front, before the (expensive) model load.
    if embedding_mode == "attention_guided" and not return_attentions:
        raise ValueError("Attention-guided pooling requires return_attentions=True")
    if embedding_mode not in ("multi_layer", "attention_guided"):
        raise ValueError("embedding_mode must be 'multi_layer' or 'attention_guided'")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if embedding_mode == "multi_layer" and num_layers < 1:
        # hidden_states[-0:] would silently select every layer.
        raise ValueError("num_layers must be a positive integer")

    model = EsmModel.from_pretrained(model_name).to(device)
    model.eval()

    # Keep the full tokenized dataset where it is and move one batch at a time to the
    # device, so the accelerator only ever holds the current batch.
    ids_list = tokenized_dataset['input_ids']  # Shape: [num_sequences, max_length]
    attention_mask_list = tokenized_dataset['attention_mask']
    sequence_ids = tokenized_dataset.get('sequence_id', list(range(len(ids_list))))

    num_batches = (len(ids_list) + batch_size - 1) // batch_size
    embeddings_list = []

    with torch.no_grad():
        for i in tqdm(range(num_batches), total=num_batches):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, len(ids_list))

            batch_ids = ids_list[start_idx:end_idx].to(device)
            batch_attention_mask = attention_mask_list[start_idx:end_idx].to(device)
            batch_seq_ids = sequence_ids[start_idx:end_idx]

            outputs = model(
                input_ids=batch_ids,
                attention_mask=batch_attention_mask,
                output_hidden_states=True,
                output_attentions=(embedding_mode == "attention_guided")
            )

            if embedding_mode == "multi_layer":
                # Average the last N layers: each is (batch_size, seq_len, embedding_dim).
                hidden_states = outputs.hidden_states[-num_layers:]
                aggregated_embeddings = torch.stack(hidden_states, dim=0).mean(dim=0)
                # Alternative: concatenate layers instead of averaging
                # aggregated_embeddings = torch.cat(hidden_states, dim=-1)

                for j in range(aggregated_embeddings.shape[0]):
                    mask = batch_attention_mask[j].bool()
                    # Drop padding via the mask, then strip <cls> and <eos>.
                    seq_embeddings = aggregated_embeddings[j][mask][1:-1]
                    embeddings_list.append({
                        'sequence_id': batch_seq_ids[j],
                        'embeddings': seq_embeddings.cpu().numpy()
                    })

            else:  # attention_guided
                hidden_states = outputs.hidden_states[-1]  # (batch_size, seq_len, embedding_dim)
                attentions = outputs.attentions[-1]  # (batch_size, num_heads, seq_len, seq_len)

                # Handle sequences one at a time because valid lengths differ.
                for j in range(hidden_states.shape[0]):
                    mask = batch_attention_mask[j].bool()
                    seq_hidden = hidden_states[j][mask]  # (valid_len, embedding_dim)
                    seq_attention = attentions[j][:, mask][:, :, mask]  # (num_heads, valid_len, valid_len)
                    # Average over heads, then renormalize rows after padding columns were removed.
                    attention_weights = seq_attention.mean(dim=0)
                    attention_weights = attention_weights / (attention_weights.sum(dim=-1, keepdim=True) + 1e-8)
                    weighted_embedding = torch.matmul(attention_weights, seq_hidden)  # (valid_len, embedding_dim)
                    seq_embeddings = weighted_embedding[1:-1]  # Remove <cls>, <eos>
                    embeddings_list.append({
                        'sequence_id': batch_seq_ids[j],
                        'embeddings': seq_embeddings.cpu().numpy()
                    })

    return embeddings_list

In [None]:
# Release unreferenced Python objects and cached CUDA memory before loading the ESM-2 model.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Use the notebook-wide `device` (CPU fallback included) instead of a hard-coded "cuda",
# so this cell also runs on a runtime without a GPU.
train_embeddings = get_embeddings_list(train_tokenized, batch_size = 8, model_name = pretrained_model, device = device,
                                       embedding_mode = "multi_layer", num_layers=2, return_hidden_states = True, return_attentions = False)

# test_embeddings = get_embeddings_list(test_tokenized, batch_size = 8, model_name = pretrained_model, device = "cuda",
#                                        embedding_mode = "multi_layer", num_layers=2, return_hidden_states = True, return_attentions = False)

In [None]:
# with open('/content/drive/MyDrive/Protein-binding/metal_train_embeddings.pkl', 'wb') as f:
#        pickle.dump(train_embeddings, f)

# with open('/content/drive/MyDrive/Protein-binding/test_embeddings.pkl', 'wb') as f:
#     pickle.dump(test_embeddings, f)

In [None]:
# with open('/content/drive/MyDrive/Protein-binding/train_embeddings.pkl', 'rb') as f:
#        train_embeddings = pickle.load(f)

# with open('/content/drive/MyDrive/Protein-binding/test_embeddings.pkl', 'rb') as f:
#     test_embeddings = pickle.load(f)

### Features extraction

In [None]:
def get_structure(prot_id, pdb_file):
    """Parse `pdb_file` with Biopython's PDBParser and return the structure, labelled `prot_id`."""
    return PDB.PDBParser().get_structure(prot_id, pdb_file)

def extract_coordinates(structure):
    """
    Collect the C-alpha coordinate of every residue that has a "CA" atom.

    Residues are visited in model -> chain -> residue order; residues without a
    "CA" atom are skipped.

    Returns:
        list: One coordinate (as returned by the atom's get_coord()) per kept residue.
    """
    return [
        residue["CA"].get_coord()
        for model in structure
        for chain in model
        for residue in chain
        if "CA" in residue
    ]

def get_amino_acid_types(structure):
    """
    Return the one-letter code of every amino-acid residue in `structure`.

    Residues are visited in model -> chain -> residue order; anything `is_aa`
    rejects is skipped, and each three-letter residue name is converted with `seq1`.
    """
    return [
        seq1(residue.get_resname())
        for model in structure
        for chain in model
        for residue in chain
        if is_aa(residue)
    ]

def get_secondary_structure_mdtraj(pdb_file, sequence_length):
    """
    Compute DSSP secondary structure for a PDB file with mdtraj and one-hot encode it.

    Args:
        pdb_file (str): Path to the .PDB file.
        sequence_length (int): Length of the protein sequence; shorter DSSP output
            is zero-padded up to this many rows.

    Returns:
        dict: A dictionary with:
            - 'raw': List of DSSP codes for the first frame.
            - 'one_hot': float32 tensor with 4 columns: 'H', 'E', 'T', and 'C'-or-'NA'.
    """
    trajectory = md.load(pdb_file)

    # Only the first frame's assignment is used.
    ss = md.compute_dssp(trajectory)[0]

    # NOTE(review): compute_dssp is called with its defaults; if that is the simplified
    # 3-state alphabet, 'T' never occurs and the third column is always zero — confirm.
    encoded_rows = [
        [
            int(code == 'H'),
            int(code == 'E'),
            int(code == 'T'),
            int(code == 'C' or code == 'NA'),
        ]
        for code in ss
    ]
    ss_onehot = torch.tensor(encoded_rows, dtype=torch.float32)

    # Zero-pad so the tensor has at least `sequence_length` rows.
    missing_rows = sequence_length - ss_onehot.shape[0]
    if missing_rows > 0:
        filler = torch.zeros((missing_rows, ss_onehot.shape[1]), dtype=torch.float32)
        ss_onehot = torch.cat([ss_onehot, filler], dim=0)

    return {"raw": ss.tolist(), "one_hot": ss_onehot}

def calculate_residue_distances(coordinates):
    """
    Calculate pairwise Euclidean distances between residue coordinates.

    The computation is vectorized with numpy broadcasting instead of a Python
    double loop, which matters for proteins with hundreds of residues.

    Args:
        coordinates (list): Per-residue coordinates (numpy arrays or plain sequences),
            all of the same dimensionality.

    Returns:
        np.ndarray: Symmetric float64 array [num_residues, num_residues] with a zero
        diagonal. An empty input yields an array of shape (0, 0).
    """
    num_residues = len(coordinates)
    if num_residues == 0:
        return np.zeros((0, 0))

    coords = np.asarray(coordinates, dtype=np.float64).reshape(num_residues, -1)
    # deltas[i, j] = coords[i] - coords[j]; the norm over the last axis is the distance.
    deltas = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((deltas ** 2).sum(axis=-1))

def get_dihedral_angles(pdb_file, sequence_length):
    """
    Compute per-residue phi and psi backbone angles (in degrees) from a PDB file.

    Phi does not exist for the first residue of a chain and psi does not exist for the
    last, so mdtraj returns fewer angles than residues. Each angle is therefore written
    to the row of the residue it belongs to (looked up through the atom indices mdtraj
    returns) instead of padding at the end, which would shift every phi value one
    residue early. Residues without an angle keep 0.

    Args:
        pdb_file (str): Path to the .PDB file.
        sequence_length (int): Minimum number of rows in the returned tensors.

    Returns:
        tuple: (phi_angles, psi_angles), each a float32 tensor of shape
        [max(sequence_length, n_residues), 1]. Only the first frame is used.
    """
    traj = md.load(pdb_file)
    phi_indices, phi_angles = md.compute_phi(traj)
    psi_indices, psi_angles = md.compute_psi(traj)

    num_rows = max(sequence_length, traj.n_residues)

    def _per_residue_column(atom_indices, angles_rad):
        """Scatter first-frame angles into a zero column indexed by residue."""
        column = torch.zeros((num_rows, 1), dtype=torch.float32)
        if len(atom_indices) == 0:
            return column
        # The second atom of each quadruplet (N for phi, CA for psi) belongs to the
        # residue the angle is defined for.
        rows = [traj.topology.atom(int(quad[1])).residue.index for quad in atom_indices]
        column[rows, 0] = torch.tensor(np.degrees(angles_rad[0]), dtype=torch.float32)
        return column

    return (_per_residue_column(phi_indices, phi_angles),
            _per_residue_column(psi_indices, psi_angles))


def get_b_factors(structure, sequence_length):
    """
    Collect the B-factor of each residue's C-alpha atom.

    Residues without a "CA" atom are skipped. If fewer than `sequence_length`
    values are found, the list is padded with 0.0 at the end.

    Returns:
        torch.Tensor: float32 tensor of shape [num_values, 1].
    """
    b_factors = [
        residue["CA"].get_bfactor()
        for model in structure
        for chain in model
        for residue in chain
        if "CA" in residue
    ]

    shortfall = sequence_length - len(b_factors)
    if shortfall > 0:
        b_factors.extend([0.0] * shortfall)

    return torch.tensor(b_factors, dtype=torch.float32).unsqueeze(1)



In [None]:
# Maximum solvent-accessible surface area per residue type, in square Angstroms,
# used as the denominator of the relative solvent accessibility (RSA).
MAX_SASA = {
    'ALA': 129.0, 'ARG': 274.0, 'ASN': 195.0, 'ASP': 193.0, 'CYS': 167.0,
    'GLU': 223.0, 'GLN': 225.0, 'GLY': 104.0, 'HIS': 224.0, 'ILE': 197.0,
    'LEU': 201.0, 'LYS': 236.0, 'MET': 224.0, 'PHE': 240.0, 'PRO': 159.0,
    'SER': 155.0, 'THR': 172.0, 'TRP': 285.0, 'TYR': 263.0, 'VAL': 174.0
}

def compute_rsa(structure_file, chain_id=None):
    """
    Compute relative solvent accessibility (RSA) for each residue in a protein structure.

    Args:
        structure_file (str): Path to the protein structure file (e.g., PDB file).
        chain_id (str, optional): Chain ID to restrict the analysis to. If None, all
            chains in the file are used.

    Returns:
        torch.Tensor: float32 tensor of shape [num_kept_residues, 1] with RSA capped at 1.0.
        Non-protein residues and residues without a MAX_SASA entry are skipped, so the
        tensor can be shorter than the sequence.

    Raises:
        ValueError: If `chain_id` is given but no chain with that ID exists.
    """
    # mdtraj works in nanometres, so shrake_rupley returns nm^2, while MAX_SASA is in A^2.
    nm2_to_angstrom2 = 100.0

    traj = md.load(structure_file)

    if chain_id:
        chain = next((c for c in traj.topology.chains if c.chain_id == chain_id), None)
        if chain is None:
            raise ValueError(f"Chain '{chain_id}' not found in {structure_file}")
        traj = traj.atom_slice([atom.index for atom in traj.topology.atoms if atom.residue.chain == chain])

    sasa = md.shrake_rupley(traj, mode='residue')[0]  # Shape: [n_residues], first frame, nm^2
    rsa_values = []
    for i, residue in enumerate(traj.topology.residues):
        if not residue.is_protein:
            continue
        res_name = residue.name
        max_sasa = MAX_SASA.get(res_name, 0.0)
        if max_sasa == 0.0:
            print(f"Warning: No max SASA value for residue {res_name}. Skipping.")
            continue
        # Convert to A^2 before normalizing; without this every RSA is 100x too small.
        rsa = min(float(sasa[i]) * nm2_to_angstrom2 / max_sasa, 1.0)
        rsa_values.append(rsa)

    return torch.tensor(rsa_values, dtype=torch.float32).unsqueeze(1)


def calculate_depth(structure, sequence_length):
    """
    Compute, per residue, the minimum distance from any of its atoms to any atom of a
    different residue.

    NOTE(review): despite the name this is a nearest-neighbour distance, not a
    distance to the molecular surface — confirm it is the intended "depth" feature.

    Args:
        structure: Biopython structure.
        sequence_length (int): Number of rows in the output; residues beyond it are
            ignored and missing rows stay 0.0.

    Returns:
        torch.Tensor: float32 tensor of shape [sequence_length, 1]. A residue with no
        atoms, or with no atoms outside itself, gets `inf`.
    """
    atoms = Selection.unfold_entities(structure, "A")  # All atoms
    atom_coords = np.array([atom.coord for atom in atoms], dtype=np.float64).reshape(-1, 3)
    atom_owners = [atom.get_parent() for atom in atoms]

    depth_per_residue = {}
    for residue in Selection.unfold_entities(structure, "R"):
        res_id = (residue.get_parent().id, residue.id[1])
        own_coords = np.array([atom.coord for atom in residue], dtype=np.float64).reshape(-1, 3)
        # Same "belongs to another residue" test as a per-atom comparison, done once per residue.
        foreign_mask = np.array([owner != residue for owner in atom_owners], dtype=bool)
        foreign_coords = atom_coords[foreign_mask]

        if len(own_coords) == 0 or len(foreign_coords) == 0:
            min_dist = float("inf")
        else:
            # All own-atom x foreign-atom distances in one numpy operation.
            deltas = own_coords[:, None, :] - foreign_coords[None, :, :]
            min_dist = float(np.sqrt((deltas ** 2).sum(axis=-1)).min())
        depth_per_residue[res_id] = min_dist

    # Default of 0.0 for rows with no matching residue.
    depth_values = [0.0] * sequence_length
    for i, residue in enumerate(structure.get_residues()):
        if i >= sequence_length:
            # The structure has more residues than the sequence; stop instead of
            # indexing past the end of depth_values.
            break
        res_id = (residue.get_parent().id, residue.id[1])
        if res_id in depth_per_residue:
            depth_values[i] = depth_per_residue[res_id]

    return torch.tensor(depth_values, dtype=torch.float32).unsqueeze(1)



In [None]:
def fuse_features(esm2_embeddings, ss_onehot, phi_angles,
                  psi_angles, b_factors):
    """
    Concatenate ESM-2 embeddings with the structural features into one node-feature matrix.

    Every input is first truncated to the shortest row count among the five, so the
    result has that many rows.

    Args:
        esm2_embeddings (torch.Tensor): Shape [num_residues, embedding_dim]
        ss_onehot (torch.Tensor): Shape [num_residues, 4]
        phi_angles (torch.Tensor): Shape [num_residues, 1]
        psi_angles (torch.Tensor): Shape [num_residues, 1]
        b_factors (torch.Tensor): Shape [num_residues, 1]

    Returns:
        torch.Tensor: Fused node features, shape [min_rows, embedding_dim + 7]
        (1287 columns for 1280-dimensional embeddings).
    """
    # Shape dump kept for debugging.
    print(f"esm2_embeddings shape: {esm2_embeddings.shape}")
    print(f"ss_onehot shape: {ss_onehot.shape}")
    print(f"phi_angles shape: {phi_angles.shape}")
    print(f"psi_angles shape: {psi_angles.shape}")
    print(f"b_factors shape: {b_factors.shape}")

    # Column order of the fused matrix: embeddings, secondary structure, phi, psi, B-factor.
    feature_blocks = (esm2_embeddings, ss_onehot, phi_angles, psi_angles, b_factors)
    shared_rows = min(block.shape[0] for block in feature_blocks)

    return torch.cat([block[:shared_rows] for block in feature_blocks], dim=1)

In [None]:
def create_edge_features(distances, threshold=8.0):
    """
    Build a residue contact graph from a pairwise distance matrix.

    Two residues are connected (in both directions) when their distance is strictly
    between 0 and `threshold`. All such edges are kept: the indices come from the
    distance matrix itself, so they are always valid node indices and need no clipping.

    Args:
        distances (np.ndarray): Square array [num_residues, num_residues].
        threshold (float): Contact cutoff, in the same units as `distances`.

    Returns:
        tuple: (edge_index, edge_attr) where edge_index is a long tensor [2, num_edges]
        and edge_attr is a float32 tensor [num_edges, 2] holding the distance scaled by
        `threshold` and the sequence separation scaled by `num_residues`.
    """
    distances = np.asarray(distances)
    num_residues = distances.shape[0]
    contact_map = (distances < threshold) & (distances > 0)

    # Row-major list of (source, target) pairs that are in contact.
    src, dst = np.nonzero(contact_map)
    edge_index = torch.from_numpy(np.stack([src, dst])).long()

    # Edge attributes, both scaled to roughly 0-1.
    dists = torch.tensor(distances[src, dst], dtype=torch.float32) / threshold
    seq_seps = torch.abs(edge_index[0] - edge_index[1]).float() / num_residues

    edge_attr = torch.stack([dists, seq_seps], dim=1)  # Shape: [num_edges, 2]

    return edge_index, edge_attr


In [None]:
def create_graph_data(node_features, edge_index, edge_attr, labels):
    """
    Bundle one protein's tensors into a PyTorch Geometric Data object.

    Args:
        node_features (torch.Tensor): Shape [num_residues, num_features], fused node
            features (1287 columns for 1280-dimensional ESM-2 embeddings plus the 7
            structural columns).
        edge_index (torch.Tensor): Shape [2, num_edges], indices of connected nodes.
        edge_attr (torch.Tensor): Shape [num_edges, 2], edge features.
        labels (torch.Tensor): Shape [num_residues], binary labels (0 or 1) for binding sites.

    Returns:
        Data: Graph with fields x, edge_index, edge_attr and y.
    """
    graph_fields = {
        "x": node_features,        # Node features
        "edge_index": edge_index,  # Edge indices
        "edge_attr": edge_attr,    # Edge features
        "y": labels,               # Per-residue labels
    }
    return Data(**graph_fields)

### Test feature extraction

In [None]:
# Smoke test of the feature-extraction helpers on a single training protein.
sample_idx = 540
sample_prot_id = train_df.iloc[sample_idx]['prot_id']
sample_sequence = train_df.iloc[sample_idx]['sequence']
sample_sequence_len = len(train_df.iloc[sample_idx]['sequence'])
sample_labels = train_df.iloc[sample_idx]['any_ligand_binding_sites']
# One entry of get_embeddings_list's output: a dict with 'sequence_id' and 'embeddings'.
sample_train_embeddings = train_embeddings[sample_idx]

# Structure file for this protein, looked up by protein ID.
sample_structure_file = f"/content/drive/MyDrive/Protein-binding/esmFold_pdb_files/{sample_prot_id}.pdb"

sample_structure = get_structure(sample_prot_id, sample_structure_file)
sample_coordinates = extract_coordinates(sample_structure)
sample_distances = calculate_residue_distances(sample_coordinates)
sample_ss_one_hot = get_secondary_structure_mdtraj(sample_structure_file, sample_sequence_len)['one_hot'].to(device)
sample_phi_angles, sample_psi_angles = get_dihedral_angles(sample_structure_file, sample_sequence_len)
sample_b_factors = get_b_factors(sample_structure, sample_sequence_len).to(device)
sample_rsa_values = compute_rsa(sample_structure_file)
# Only referenced by the commented-out RSA experiment below.
min_length = 1

# print(sample_rsa_values)
# rsa_dict = {res_id: rsa for _, res_id, rsa in sample_rsa_values}
# sample_rsa_tensor = torch.zeros((min_length, 1), dtype=torch.float)
# print(sample_rsa_tensor)

# for i, res_id in enumerate(residue_ids[:min_length]):
#     sample_rsa_tensor[i, 0] = rsa_dict.get(res_id, 0.0)

# print(sample_prot_id)
# print(sample_sequence)
# print(len(sample_sequence))
# print(sample_b_factors.shape)
# print(sample_ss_one_hot.shape)
# print(sample_phi_angles.shape)
# print(sample_psi_angles.shape)
# print(sample_rsa_tensor.shape)

In [None]:
# edge_index, edge_attr = create_edge_features(sample_distances)
# node_features = fuse_features(sample_train_embeddings.to(device), sample_ss_one_hot,
#                               sample_phi_angles.to(device), sample_psi_angles.to(device), sample_b_factors)
# labels = torch.tensor(sample_labels, dtype = torch.long).to(device)

In [None]:
# graph_data = create_graph_data(node_features, edge_index, edge_attr, labels)

In [None]:
# print(len(train_df))
# print(len(train_labels))
# print(len(train_seq))

### Prepare data tensors

In [None]:
def get_graph_data(df, embeddings, device, label_column='any_ligand_binding_sites'):
    """
    Build one PyTorch Geometric graph per protein in `df`.

    Args:
        df (pd.DataFrame): Must contain 'prot_id', 'sequence' and `label_column`.
        embeddings (list): Output of get_embeddings_list, aligned row-by-row with `df`.
        device: Device every tensor of the resulting graphs is moved to.
        label_column (str): Column holding the per-residue binary labels. Defaults to
            the column that was previously hard-coded.

    Returns:
        list: Data objects for the proteins whose structure could be processed. Proteins
        whose structural features fail to load are reported and skipped, so the list can
        be shorter than `df`.
    """
    graph_data_list = []

    for idx in tqdm(range(len(df))):
        row = df.iloc[idx]
        prot_id = row['prot_id']
        sequence_len = len(row['sequence'])
        embedding = torch.tensor(embeddings[idx]['embeddings'])
        structure_file = f"/content/drive/MyDrive/Protein-binding/esmFold_pdb_files/{prot_id}.pdb"

        try:
            structure = get_structure(prot_id, structure_file)
            coordinates = extract_coordinates(structure)
            distances = calculate_residue_distances(coordinates)
            # depths = calculate_depth(structure, sequence_len)
            # rsa_values = compute_rsa(structure_file)
            ss_one_hot = get_secondary_structure_mdtraj(structure_file, sequence_len)['one_hot'].to(device)
            phi_angles, psi_angles = get_dihedral_angles(structure_file, sequence_len)
            b_factors = get_b_factors(structure, sequence_len).to(device)
        except Exception as e:
            # Skip this protein. Falling through would reuse the previous protein's
            # structural features (or hit a NameError on the first row).
            print(f"Error: skipping {prot_id}: {e}")
            continue

        edge_index, edge_attr = create_edge_features(distances)
        node_features = fuse_features(embedding.to(device), ss_one_hot,
                                      phi_angles.to(device), psi_angles.to(device),
                                      b_factors)
        labels = torch.tensor(row[label_column], dtype=torch.long)

        # Nodes and labels must line up one-to-one (e.g. the embedding is shorter than the
        # label list when the tokenizer truncated the sequence).
        num_nodes = min(node_features.shape[0], labels.shape[0])
        node_features = node_features[:num_nodes]
        labels = labels[:num_nodes].to(device)

        # Drop edges that touch a node removed by the truncation, and keep every tensor of
        # the graph on the same device.
        in_range = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
        edge_index = edge_index[:, in_range].to(device)
        edge_attr = edge_attr[in_range].to(device)

        graph_data_list.append(create_graph_data(node_features, edge_index, edge_attr, labels))

    return graph_data_list

In [None]:
# Build one graph per training protein; embeddings are matched to dataframe rows by position.
train_graphs_data = get_graph_data(train_df, train_embeddings, device)
# test_graphs_data = get_graph_data(test_df, test_embeddings, device)

### Create baseline GNN

In [None]:
class BindingSiteGCN(nn.Module):
    def __init__(self, node_dim, edge_dim=2, hidden_dim=512):
        """
        A baseline GCN for per-residue binding site prediction with a KAN classifier head.

        Args:
            node_dim (int): Dimension of node features (embedding dimension + 7 structural columns).
            edge_dim (int): Dimension of edge features.
            hidden_dim (int): Hidden dimension of the first GCN layer.
        """
        super(BindingSiteGCN, self).__init__()
        # Three GCN layers; GCNConv is used without edge features here.
        self.conv1 = GCNConv(node_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, 256)
        self.conv3 = GCNConv(256, 128)

        # Edge-feature projection. It is not used in forward() (the GCN layers are called
        # without edge features); the layer is kept so the module's parameters are unchanged.
        self.edge_lin = nn.Linear(edge_dim, hidden_dim)

        # Classifier head: 128 -> 16 linear reduction, then a KAN with 16 inputs and 2 outputs.
        self.pre_fc = nn.Linear(128, 16)
        self.fc = KAN([16, 8, 2], grid=2, k=2)

        # Activation and dropout
        self.activation_func = nn.LeakyReLU(negative_slope=0.1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, data):
        """Return logits of shape [num_nodes, 2] for the nodes of `data`."""
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index)
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv3(x, edge_index)
        x = self.activation_func(x)

        # The KAN is built for 16 input features, so the 128-dim GCN output has to go
        # through pre_fc first; feeding it directly is a shape mismatch.
        x = self.pre_fc(x)
        x = self.fc(x)  # Shape: [num_residues, 2]

        return x  # Logits for each residue

In [None]:
class BindingSiteGAT(nn.Module):
    def __init__(self, node_dim=1287, hidden_dim=512, heads=4):
        """
        A GAT-based model for per-residue binding site prediction with a KAN classifier head.

        Args:
            node_dim (int): Dimension of node features (e.g., 1287 for ESM-2 + 7 features).
            hidden_dim (int): Per-head hidden dimension of the first GAT layer.
            heads (int): Number of attention heads in the first two GAT layers.
        """
        super(BindingSiteGAT, self).__init__()
        # The first two layers concatenate their heads, so the next layer's input
        # width is (per-head width * heads).
        self.conv1 = GATConv(node_dim, hidden_dim, heads=heads, concat=True)
        self.conv2 = GATConv(hidden_dim * heads, 256, heads=heads, concat=True)
        self.conv3 = GATConv(256 * heads, 128, heads=1, concat=False)  # Single head for final layer

        # Classifier head: 128 -> 16 linear reduction, then a KAN with 16 inputs and 2 outputs.
        self.pre_fc = nn.Linear(128, 16)
        self.fc = KAN([16, 8, 2], grid=2, k=2)

        # Activation and dropout
        self.activation_func = nn.LeakyReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, data):
        """Return logits of shape [num_nodes, 2] for the nodes of `data`."""
        x, edge_index = data.x, data.edge_index  # edge_attr is not passed to the GAT layers

        x = self.conv1(x, edge_index)
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index)
        x = self.activation_func(x)
        x = self.dropout(x)

        x = self.conv3(x, edge_index)
        x = self.activation_func(x)

        # The KAN is built for 16 input features, so the 128-dim GAT output has to go
        # through pre_fc first; feeding it directly is a shape mismatch.
        x = self.pre_fc(x)
        x = self.fc(x)  # Shape: [num_residues, 2]

        return x  # Logits for each residue

In [None]:
# Hold out 5% of the training graphs for validation; the fixed random_state makes the split repeatable.
train_graphs, val_graphs = train_test_split(
    train_graphs_data,
    test_size=0.05,
    random_state=42,
    shuffle=True
)

print(f"Number of training graphs: {len(train_graphs)}")
print(f"Number of validation graphs: {len(val_graphs)}")
# print(f"Number of testing graphs: {len(test_graphs_data)}")

In [None]:
# Sanity check: report every training graph whose node count differs from its label count.
for idx, train_graph in enumerate(train_graphs):
    n_nodes, n_labels = train_graph.x.shape[0], train_graph.y.shape[0]
    if n_nodes == n_labels:
        continue
    print(f"Abnormal at index: {idx}")
    print(f"Train graph's input shape: {n_nodes} and labels shape: {n_labels}")

In [None]:
# Sanity check: report every validation graph whose node count differs from its label count.
for idx, val_graph in enumerate(val_graphs):
    if val_graph.x.shape[0] != val_graph.y.shape[0]:
        print(f"Abnormal at index: {idx}")
        # The message names the validation set; it previously said "Train graph's".
        print(f"Validation graph's input shape: {val_graph['x'].shape[0]} and labels shape: {val_graph['y'].shape[0]}")

### Model Training

In [None]:
# Create DataLoaders
batch_size = 16

train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
# test_loader = DataLoader(test_graphs_data, batch_size=batch_size, shuffle=False)

### Customized loss function

In [None]:
class WeightedCrossEntropyLoss(nn.Module):
    """Two-class cross-entropy where class 1 is weighted by `pos_weight` and class 0 by 1.0."""

    def __init__(self, pos_weight):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, logits, labels):
        """Return the weighted mean cross-entropy of `logits` [N, 2] against `labels` [N]."""
        # Per-class weights [class 0, class 1], placed on the logits' device.
        class_weights = torch.tensor([1.0, self.pos_weight]).to(logits.device)
        return nn.functional.cross_entropy(logits, labels, weight=class_weights)

class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by (1 - p_true)^gamma and a per-class alpha weight."""

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha  # Weight applied where the label is 1; (1 - alpha) elsewhere
        self.gamma = gamma  # Focusing exponent
        self.reduction = reduction

    def forward(self, logits, labels):
        """
        Compute the focal loss of `logits` [N, C] against integer `labels` [N].

        Returns a scalar for reduction 'mean' or 'sum', otherwise the per-sample losses.
        """
        # Unreduced cross-entropy, one value per sample.
        per_sample_ce = nn.functional.cross_entropy(logits, labels, reduction='none')

        # Probability the model assigns to each sample's true class.
        class_probs = torch.softmax(logits, dim=-1)
        row_ids = torch.arange(class_probs.size(0), device=class_probs.device)
        p_true = class_probs[row_ids, labels]

        # Down-weight samples the model already classifies confidently.
        modulating_factor = (1 - p_true) ** self.gamma

        # alpha for label 1, (1 - alpha) for everything else.
        alpha_weight = torch.where(labels == 1, self.alpha, 1.0 - self.alpha).to(logits.device)

        loss = alpha_weight * modulating_factor * per_sample_ce

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

class PositionAwareLoss(nn.Module):
    """Weighted cross-entropy + focal loss, plus an optional neighbourhood penalty.

    Args:
        pos_weight: Class-1 weight for the weighted cross-entropy term.
        position_weight: Multiplier applied to the position term.
        alpha: Alpha passed to the focal loss.
        gamma: Gamma passed to the focal loss.
    """

    def __init__(self, pos_weight, position_weight, alpha=0.25, gamma=2.0):
        super().__init__()
        self.weighted_ce = WeightedCrossEntropyLoss(pos_weight)
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma, reduction='mean')
        self.position_weight = position_weight

    def forward(self, logits, labels, batch=None):
        """Return base loss + position_weight * position term.

        Args:
            logits: Two class scores per node.
            labels: Integer label per node (1 = positive).
            batch: Optional batch object exposing `num_graphs` and `num_nodes`.
                The position term is only computed when it is given and holds
                exactly one graph; otherwise the term is zero.

        Returns:
            Scalar loss tensor.
        """
        # Base loss (weighted cross-entropy + focal loss)
        base_loss = self.weighted_ce(logits, labels) + self.focal_loss(logits, labels)

        position_loss = logits.new_zeros(())

        # Index i-1 / i+1 are treated as neighbours, which is only meaningful
        # within one graph, hence the single-graph restriction.
        if batch is not None and batch.num_graphs == 1:
            num_nodes = batch.num_nodes
            binding_probs = torch.softmax(logits, dim=-1)[:num_nodes, 1]
            is_site = labels[:num_nodes] == 1

            # Interior nodes (1 .. num_nodes-2) that are positive themselves or
            # sit directly next to a positive node.
            near_site = is_site[1:-1] | is_site[:-2] | is_site[2:]

            # |p(binding) - target| per interior node, summed over the selected
            # nodes. Vectorized replacement for the per-node Python loop.
            miss = torch.abs(binding_probs[1:-1] - is_site[1:-1].float())
            position_loss = (miss * near_site).sum()

        return base_loss + self.position_weight * position_loss

In [None]:
# Node feature width: second dimension of the first training embedding, plus 7.
# NOTE(review): the 7 extra features are presumably appended where the graphs are built (not visible here) — confirm the count.
input_train_embedding_dim = torch.tensor(train_embeddings[0]['embeddings']).shape[1] + 7

model = BindingSiteGCN(node_dim = input_train_embedding_dim, edge_dim=2, hidden_dim=512).to(device)
# model = BindingSiteGAT(node_dim = input_train_embedding_dim, hidden_dim=512, heads=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# class_weights = torch.tensor([0.2, 0.8]).to(device)
pos_weight = 5.0  # class-1 weight handed to PositionAwareLoss for its weighted cross-entropy term
position_weight = 1  # multiplier on PositionAwareLoss's position term
# NOTE(review): this standalone FocalLoss is not referenced by the visible training cells; `criterion` builds its own instance.
focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
# criterion = WeightedCrossEntropyLoss(pos_weight=pos_weight)
criterion = PositionAwareLoss(pos_weight=pos_weight, position_weight=position_weight, alpha=0.25, gamma=2.0)

In [None]:
def evaluate(model, data_loader, device):
    """Score `model` on every batch of `data_loader`.

    The model is switched to eval mode and is NOT switched back; callers that
    keep training must call `model.train()` again.

    Args:
        model: Callable taking a batch and returning two class scores per node.
        data_loader: Iterable of batches exposing `.to(device)` and `.y`.
        device: Device each batch is moved to before the forward pass.

    Returns:
        dict with keys "precision", "recall", "f1" (positive class,
        `average='binary'`), "auc" and "mcc". "auc" is NaN when the labels
        contain fewer than two classes, because ROC AUC is undefined there.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            out = model(batch)

            # Class-1 probability feeds AUC; argmax gives the hard prediction.
            all_probs.extend(torch.softmax(out, dim=1)[:, 1].cpu().tolist())
            all_preds.extend(torch.argmax(out, dim=1).cpu().tolist())
            all_labels.extend(batch.y.cpu().tolist())

    # zero_division=0 keeps the value sklearn already returns for undefined
    # precision/recall (0) but without a warning per call, which is common
    # in early epochs when nothing is predicted positive.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', zero_division=0
    )

    # ROC AUC is undefined when only one class is present; report NaN instead
    # of letting the whole evaluation fail.
    if len(set(all_labels)) < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(all_labels, all_probs)

    mcc = matthews_corrcoef(all_labels, all_preds)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
        "mcc": mcc
    }

In [None]:
# Housekeeping before training: only the Python-level collection is active;
# the CUDA cache / half-precision lines below are left disabled.
gc.collect()  # Force garbage collection to potentially free up memory
# torch.cuda.empty_cache()  # Empty the CUDA cache
# model = model.half()  # Convert model parameters to half-precision
# torch.cuda.synchronize()

In [None]:
# Train for a fixed number of epochs, validating after each one and keeping a
# snapshot of the weights that achieved the best validation F1.
num_epochs = 50
best_val_f1 = 0
best_model_state = None

for epoch in range(num_epochs):
    # evaluate() leaves the model in eval mode, so re-enable training mode each epoch.
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        # NOTE(review): `batch` is not passed to the criterion, so
        # PositionAwareLoss's position term (which needs `batch`) is inactive
        # and the loss is weighted cross-entropy + focal loss only.
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    val_metrics = evaluate(model, eval_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Training Loss: {avg_train_loss:.4f}")
    print(f"Validation Precision: {val_metrics['precision']:.4f}")
    print(f"Validation Recall: {val_metrics['recall']:.4f}")
    print(f"Validation F1-Score: {val_metrics['f1']:.4f}")
    print(f"Validation AUC-ROC: {val_metrics['auc']:.4f}")

    if val_metrics['f1'] > best_val_f1:
        best_val_f1 = val_metrics['f1']
        # state_dict() returns references to the live parameters, which later
        # optimizer steps overwrite in place. Clone so the snapshot really is
        # the best epoch rather than whatever the final epoch produced.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        print("Best validation F1-score improved! Saving model state.")

In [None]:
# Load the state dict recorded by the training loop, if one was recorded.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# NOTE(review): pickle.load can execute arbitrary code from the file — only
# load loader pickles that you created yourself.
with open('/content/drive/MyDrive/Protein-binding/data/test300_loader.pkl', 'rb') as f:
    saved_test_loader = pickle.load(f)

# Score the model on the unpickled test loader.
model.eval()
test_metrics = evaluate(model, saved_test_loader, device)

In [None]:
# Show the metrics dict computed on saved_test_loader.
print(test_metrics)

In [None]:
# NOTE(review): saved_model_name is not defined in the visible cells — presumably set earlier in the notebook; confirm before running.
print(saved_model_name)

In [None]:
# Persist the recorded weights. If no epoch ever improved validation F1 the
# state is still None, and saving it would overwrite any existing checkpoint
# at this path with an unusable file.
if best_model_state is not None:
    torch.save(best_model_state, saved_model_name)
else:
    print("No best model state was recorded; nothing saved.")

In [None]:
# with open('/content/drive/MyDrive/Protein-binding/data/test300_loader.pkl', 'wb') as f:
#     pickle.dump(test_loader, f)

### Save and load model weights

In [None]:
def create_labels(data_loader):
    """Collect the ground-truth labels of every batch in `data_loader`.

    Args:
        data_loader: Iterable of batches exposing a label tensor as `.y`.

    Returns:
        Flat list of per-node labels (NumPy scalars) in loader order.
    """
    all_labels = []
    for batch in data_loader:
        # Labels are only read, so there is no need to move the whole batch to
        # the compute device first; just bring `y` to host memory.
        all_labels.extend(batch.y.detach().cpu().numpy())
    return all_labels

In [None]:
def inference(model_name, data_loader):
    """Load a BindingSiteGCN checkpoint and return its hard predictions.

    Relies on the notebook-level `input_train_embedding_dim` and `device`.

    Args:
        model_name: Path to a state dict saved with `torch.save`.
        data_loader: Iterable of batches exposing `.to(device)`.

    Returns:
        Flat list of argmax class predictions (NumPy scalars), one per node,
        in loader order.
    """
    model_instance = BindingSiteGCN(node_dim=input_train_embedding_dim, edge_dim=2, hidden_dim=512).to(device)
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only
    # runtime (and vice versa) instead of failing while deserializing.
    model_instance.load_state_dict(torch.load(model_name, map_location=device))
    model_instance.eval()

    all_preds = []
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            out = model_instance(batch)
            all_preds.extend(torch.argmax(out, dim=1).cpu().numpy())

    return all_preds

# Run the three checkpoints (named for nuclear, metal and small binding) over the same test loader.
nuclear_model_name = "/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_nuclear_binding_20Apr_model.pth"
nuclear_preds = inference(nuclear_model_name, saved_test_loader)

metal_model_name = "/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_metal_binding_20Apr_model.pth"
metal_preds = inference(metal_model_name, saved_test_loader)

small_model_name = "/content/drive/MyDrive/Protein-binding/trained_models/GCN_KAN_for_small_binding_20Apr_model.pth"
small_preds = inference(small_model_name, saved_test_loader)

# Element-wise maximum across the three prediction lists; with 0/1 predictions
# a node is marked positive when any of the three models predicts 1.
merged_preds_list = [max(a, b, c) for a, b, c in zip(nuclear_preds, metal_preds, small_preds)]
# NOTE(review): this prints one entry per node, which can be very long.
print(merged_preds_list)

In [None]:
# print(len(merged_preds_list))
# Ground-truth labels in the same loader order as the predictions from inference().
all_labels = create_labels(saved_test_loader)
print(len(all_labels))

In [None]:
# Score each model's predictions, and the merged predictions, against the same
# ground-truth labels (positive class only, average='binary').
metal_precision, metal_recall, metal_f1, _ = precision_recall_fscore_support(all_labels, metal_preds, average='binary')
metal_mcc = matthews_corrcoef(all_labels, metal_preds)

nuclear_precision, nuclear_recall, nuclear_f1, _ = precision_recall_fscore_support(all_labels, nuclear_preds, average='binary')
nuclear_mcc = matthews_corrcoef(all_labels, nuclear_preds)

small_precision, small_recall, small_f1, _ = precision_recall_fscore_support(all_labels, small_preds, average='binary')
small_mcc = matthews_corrcoef(all_labels, small_preds)

overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(all_labels, merged_preds_list, average='binary')
overall_mcc = matthews_corrcoef(all_labels, merged_preds_list)


print(f"Metal Precision: {metal_precision:.4f}")
print(f"Metal Recall: {metal_recall:.4f}")
print(f"Metal F1-Score: {metal_f1:.4f}")
print(f"Metal MCC score: {metal_mcc:.4f}")
print("\n")

print(f"Nuclear Precision: {nuclear_precision:.4f}")
print(f"Nuclear Recall: {nuclear_recall:.4f}")
print(f"Nuclear F1-Score: {nuclear_f1:.4f}")
print(f"Nuclear MCC score: {nuclear_mcc:.4f}")
print("\n")

print(f"Small Precision: {small_precision:.4f}")
print(f"Small Recall: {small_recall:.4f}")
print(f"Small F1-Score: {small_f1:.4f}")
print(f"Small MCC score: {small_mcc:.4f}")
print("\n")

print(f"Overall Precision: {overall_precision:.4f}")
# Label fixed: this line previously read "Overall Recoverall".
print(f"Overall Recall: {overall_recall:.4f}")
print(f"Overall F1-Score: {overall_f1:.4f}")
print(f"Overall MCC score: {overall_mcc:.4f}")

### Error analysis

In [None]:
# all_preds, all_labels, all_probs = [], [], []

# with torch.no_grad():
#     for test_batch in test_loader:
#         test_batch = test_batch.to(device)
#         out = model(test_batch)

#         probs = torch.softmax(out, dim=1)[:, 1]
#         preds = torch.argmax(out, dim=1)

#         all_preds.extend(preds.cpu().numpy())
#         all_labels.extend(batch.y.cpu().numpy())
#         all_probs.extend(probs.cpu().numpy())

In [None]:
# all_test_prot_sequences = test_df['sequence'].tolist()
# all_sequences = []
# for prot_seq in all_test_prot_sequences:
#     all_sequences.extend(list(prot_seq))

In [None]:
# from collections import Counter
# from copy import deepcopy

# test_aa_counter = Counter(all_sequences)
# false_negatives_dict = {}

# for prob, label, pred, acid in zip(all_probs, all_labels, all_preds, all_sequences):
#     if label == 1 and pred == 0:
#         print(f"Acid: {acid} and its probability: {prob}")

#         if acid not in false_negatives_dict:
#             false_negatives_dict[acid] = 1
#         else:
#             false_negatives_dict[acid] += 1

In [None]:
# false_negatives_percentage_dict = deepcopy(false_negatives_dict)

# for acid in false_negatives_percentage_dict:
#     false_negatives_percentage_dict[acid] /= test_aa_counter[acid]

In [None]:
# # print(false_negatives_percentage_dict)
# sorted_false_negatives_percentage_dict = dict(sorted(false_negatives_percentage_dict.items(), key=lambda item: item[1], reverse=True))
# print(sorted_false_negatives_percentage_dict)

In [None]:
# sorted_false_negatives_percentage_dict