In [51]:
import torch

if torch.cuda.is_available():
    print("✅ CUDA is available! GPU will be used for training.")
    print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    print("❌ CUDA not found. Running on CPU.")

✅ CUDA is available! GPU will be used for training.
PyTorch CUDA Version: 12.1
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU


In [52]:
import pandas as pd
import sqlite3
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax
from rdkit import Chem
from Bio import SeqIO
import random

# --- Configuration (UPDATE THESE PATHS & ID) ---
CHEMPL_DB_PATH = 'DL_ENDSEM__DATASET\chembl_35\chembl_35_sqlite\chembl_35.db'
BLAST_FASTA_PATH = 'DL_ENDSEM__DATASET\chembl_35_blast.fa.gz'
TARGET_UNIPROT_ID = "P00533" # Example: EGFR Kinase - CHANGE THIS TO YOUR TARGET'S UNIPROT ID

# --- Model Hyperparameters ---
Z_DIM = 100         # Latent noise dimension
ATOM_FEAT_DIM = 9   # Atom feature size
BOND_FEAT_DIM = 4   # Bond feature size (Single, Double, Triple, Aromatic)
EMBED_DIM = 128     # Hidden dimension for the Graph Transformer
T_EMBED_DIM = 768   # Target embedding dimension (Mocked)
LAMBDA_GP = 10.0    # Gradient Penalty weight
MAX_NODES = 30      # Max atoms in generated molecules (for Generator tensor shape)
N_CRITIC = 5        # Discriminator training steps per Generator step
EPOCHS = 100        # Total epochs
BATCH_SIZE = 16

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Mock Target Embedding (Replace with actual protein sequence embedding for T_EMBED_DIM)
TARGET_EMBED = torch.randn(T_EMBED_DIM).to(DEVICE)

Using device: cuda


In [53]:
def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    """
    Connects to ChEMBL DB and extracts SMILES for potent inhibitors of a given Uniprot ID.
    This version uses the most robust join by routing through target_components and
    component_sequences, where the accession (UniProt ID) is guaranteed to be found.
    """
    try:
        conn = sqlite3.connect(db_path)
        
        # The FIX: Joining multiple tables to hit the accession column in cseq
        sql_query = f"""
        SELECT DISTINCT
            cs.canonical_smiles
        FROM
            activities acts
        JOIN assays a ON acts.assay_id = a.assay_id
        JOIN target_dictionary td ON a.tid = td.tid
        
        -- Joining to link the target (td) to its components (tc)
        JOIN target_components tc ON td.tid = tc.tid
        -- Joining to get the sequence information where the accession is stored
        JOIN component_sequences cseq ON tc.component_id = cseq.component_id
        
        JOIN compound_structures cs ON acts.molregno = cs.molregno
        WHERE
            cseq.accession = '{uniprot_id}' AND  -- **UniProt ID is stored here**
            acts.standard_type = 'IC50' AND
            acts.standard_units = 'nM' AND
            acts.standard_relation = '=' AND
            acts.standard_value <= {potency_cutoff_nM}
        """
        df = pd.read_sql_query(sql_query, conn)
        conn.close()
        
        print(f"Found {len(df)} potent inhibitors for UniProt ID {uniprot_id}.")
        return df['canonical_smiles'].unique().tolist()
        
    except Exception as e:
        print(f"Error during database query. This is likely due to missing tables or a critical file path issue.")
        raise RuntimeError(f"Database Error: {e}. Please ensure the file is the full ChEMBL SQLite dump.") from e

# --- 2.2. Protein Sequence Loading (Just for context) ---
def load_target_sequence(fasta_path, uniprot_id):
    """Loads a protein sequence from a FASTA file."""
    for record in SeqIO.parse(fasta_path, "fasta"):
        if uniprot_id in record.id or uniprot_id in record.description:
            return str(record.seq)
    return None

# --- 2.3. Graph Featurization (Double Bond Adjacency) ---
def get_atom_features(atom):
    """Creates the 9-dimensional atom feature vector."""
    return [
        atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
        int(atom.GetHybridization() == Chem.HybridizationType.SP),
        int(atom.GetHybridization() == Chem.HybridizationType.SP2),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D2)
    ]

def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    x = torch.tensor([get_atom_features(atom) for atom in mol.GetAtoms()], dtype=torch.float).to(DEVICE)
    
    edge_indices, edge_attrs = [], []
    bond_types = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
    
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        
        # Bond Feature One-Hot Encoding (BOND_FEAT_DIM = 4)
        bond_type_oh = [int(bond.GetBondType() == t) for t in bond_types]
        
        edge_indices.extend([[i, j], [j, i]])
        edge_attrs.extend([bond_type_oh, bond_type_oh])

    if not edge_indices: return None
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().to(DEVICE)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float).to(DEVICE)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
real_data_list = [smiles_to_graph(s, TARGET_EMBED) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print("Warning: No valid inhibitor data found. Exiting.")
    exit()

real_loader = DataLoader(real_data_list, batch_size=BATCH_SIZE, shuffle=True)
print(f"Prepared {len(real_data_list)} real graph samples for training.")

Found 3989 potent inhibitors for UniProt ID P00533.
Prepared 1334 real graph samples for training.


  real_loader = DataLoader(real_data_list, batch_size=BATCH_SIZE, shuffle=True)


In [54]:
# --- 2.1. ChEMBL Data Extraction (FIXED SQL QUERY) ---
def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    """
    Connects to ChEMBL DB and extracts SMILES for potent inhibitors of a given Uniprot ID.
    Uses the robust join through target_components and component_sequences to find the accession.
    """
    conn = sqlite3.connect(db_path)
    # The fix: Join through target_components and component_sequences tables
    sql_query = f"""
    SELECT DISTINCT
        cs.canonical_smiles
    FROM
        activities acts
    JOIN assays a ON acts.assay_id = a.assay_id
    JOIN target_dictionary td ON a.tid = td.tid
    JOIN target_components tc ON td.tid = tc.tid
    JOIN component_sequences cseq ON tc.component_id = cseq.component_id
    JOIN compound_structures cs ON acts.molregno = cs.molregno
    WHERE
        cseq.accession = '{uniprot_id}' AND  -- The accession is correctly found here
        acts.standard_type = 'IC50' AND
        acts.standard_units = 'nM' AND
        acts.standard_relation = '=' AND
        acts.standard_value <= {potency_cutoff_nM}
    """
    df = pd.read_sql_query(sql_query, conn)
    conn.close()
    
    print(f"Found {len(df)} potent inhibitors for Uniprot ID {uniprot_id}.")
    return df['canonical_smiles'].unique().tolist()

# --- 2.2. Protein Sequence Loading (Just for context) ---
def load_target_sequence(fasta_path, uniprot_id):
    """Loads a protein sequence from a FASTA file."""
    for record in SeqIO.parse(fasta_path, "fasta"):
        if uniprot_id in record.id or uniprot_id in record.description:
            return str(record.seq)
    return None

# --- 2.3. Graph Featurization (Double Bond Adjacency) ---
def get_atom_features(atom):
    """Creates the 9-dimensional atom feature vector."""
    return [
        atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
        int(atom.GetHybridization() == Chem.HybridizationType.SP),
        int(atom.GetHybridization() == Chem.HybridizationType.SP2),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D2)
    ]

def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    x = torch.tensor([get_atom_features(atom) for atom in mol.GetAtoms()], dtype=torch.float).to(DEVICE)
    
    edge_indices, edge_attrs = [], []
    bond_types = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
    
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        
        # Bond Feature One-Hot Encoding (BOND_FEAT_DIM = 4)
        bond_type_oh = [int(bond.GetBondType() == t) for t in bond_types]
        
        edge_indices.extend([[i, j], [j, i]])
        edge_attrs.extend([bond_type_oh, bond_type_oh])

    if not edge_indices: return None
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().to(DEVICE)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float).to(DEVICE)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
real_data_list = [smiles_to_graph(s, TARGET_EMBED) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print("Warning: No valid inhibitor data found. Exiting.")
    exit()

real_loader = DataLoader(real_data_list, batch_size=BATCH_SIZE, shuffle=True)
print(f"Prepared {len(real_data_list)} real graph samples for training.")

Found 3989 potent inhibitors for Uniprot ID P00533.
Prepared 1334 real graph samples for training.


  real_loader = DataLoader(real_data_list, batch_size=BATCH_SIZE, shuffle=True)


In [55]:
# --- 3.1. Relational Graph Transformer Layer (Core GNN Component) ---
class RelationalGraphTransformerLayer(MessagePassing):
    """Graph Transformer Layer with explicit edge/bond feature integration."""
    def __init__(self, in_channels, out_channels, edge_dim, heads=4, dropout=0.1, **kwargs):
        super().__init__(aggr='add', node_dim=0, **kwargs)
        self.out_channels = out_channels
        self.heads = heads
        
        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_k = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_v = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) 
        self.att_coeff = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels)) 
        self.lin_out = nn.Linear(heads * out_channels, out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr):
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)
        E = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)

        out = self.propagate(edge_index, Q=Q, K=K, V=V, E=E)
        out = out.view(-1, self.heads * self.out_channels)
        out = self.lin_out(out)
        return out

    def message(self, Q_i, K_j, V_j, E_k, index):
        
        # 1. The FINAL Robust Guard Clause (using try-except to catch the '_empty' object)
        try:
            # Attempt to use the tensor method (.size) to check for emptiness
            is_empty = (E_k.size(0) == 0)
        except AttributeError:
            # If AttributeError is raised, E_k is the problematic '_empty' object.
            is_empty = True
        
        if is_empty:
            # If empty (no messages/edges for this step), set the edge bias contribution to zero.
            # We must infer the correct shape for the zero tensor using Q_i and heads.
            E_bias = torch.zeros(Q_i.size(0), self.heads, 1, device=self.lin_q.weight.device)
        else:
            # If not empty, calculate the bias contribution as intended.
            E_bias = E_k.mean(dim=-1, keepdim=True) 

        # Compute raw attention score e_ij: 
        QK_cat = torch.cat([Q_i, K_j], dim=-1) # (E, H, 2*out)
        e_ij = (QK_cat * self.att_coeff).sum(dim=-1, keepdim=True) # (E, H, 1)

        # 2. Add Bond Feature Bias (Use the safe E_bias)
        e_ij = e_ij + E_bias
        
        e_ij = torch.nn.functional.leaky_relu(e_ij)
        alpha = softmax(e_ij, index)
        self.dropout(alpha)
        
        return V_j * alpha.view(-1, self.heads, 1)

# --- 3.2. Discriminator (Graph Encoder) ---
class Discriminator(nn.Module):
    def __init__(self, node_features, edge_dim, t_embed_dim, embed_dim, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_d = node_features if i == 0 else embed_dim
            self.layers.append(RelationalGraphTransformerLayer(in_d, embed_dim, edge_dim))
        self.lin_final = nn.Linear(embed_dim + t_embed_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        t_embed = data.target_embed
        
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = torch.relu(x)
        
        graph_embed = global_mean_pool(x, batch)
        final_input = torch.cat([graph_embed, t_embed.squeeze(1)], dim=1)
        
        return self.lin_final(final_input).squeeze(1)

# --- 3.3. Generator (Graph Decoder - MOCK) ---
class Generator(nn.Module):
    def __init__(self, z_dim, t_embed_dim, node_features, bond_features, max_nodes=MAX_NODES):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_features = node_features
        self.bond_features = bond_features
        
        self.lin_x = nn.Sequential(
            nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(),
            nn.Linear(256, max_nodes * node_features)
        )
        self.lin_adj = nn.Sequential(
            nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(),
            nn.Linear(256, max_nodes * max_nodes * bond_features) 
        )

    def forward(self, z, t_embed):
        zt = torch.cat([z, t_embed], dim=1)
        x_fake = self.lin_x(zt).view(-1, self.max_nodes, self.node_features)
        adj_fake = self.lin_adj(zt).view(-1, self.max_nodes, self.max_nodes, self.bond_features)
        adj_fake = torch.softmax(adj_fake, dim=-1)
        
        return x_fake, adj_fake

# --- Model Initialization ---
generator = Generator(Z_DIM, T_EMBED_DIM, ATOM_FEAT_DIM, BOND_FEAT_DIM).to(DEVICE)
discriminator = Discriminator(ATOM_FEAT_DIM, BOND_FEAT_DIM, T_EMBED_DIM, EMBED_DIM).to(DEVICE)

optimizer_G = optim.Adam(generator.parameters(), lr=1e-4)
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4)

In [56]:
# --- 4.1. Mock Graph Conversion for Discriminator ---
def convert_fake_to_data(x_fake_tensor, adj_fake_tensor, t_embed_batch, device):
    """Mocks the conversion of Generator tensor output into a PyG Data list for the Discriminator."""
    batch_size = x_fake_tensor.size(0)
    data_list = []
    
    for i in range(batch_size):
        num_nodes = MAX_NODES
        x_i = x_fake_tensor[i, :, :]
        
        # Create full graph edge index (for simplicity in GP calculation)
        edge_indices = [[r, c] for r in range(num_nodes) for c in range(num_nodes) if r != c]
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().to(device)
        
        # Use the most probable bond type for the edge feature
        adj_i = adj_fake_tensor[i, :, :, :]
        edge_attr_indices = [adj_i[r, c].argmax().item() for r, c in edge_indices]
        edge_attr = nn.functional.one_hot(torch.tensor(edge_attr_indices), 
                                          num_classes=BOND_FEAT_DIM).float().to(device)

        data_list.append(Data(x=x_i, edge_index=edge_index, edge_attr=edge_attr, 
                              target_embed=t_embed_batch[i].unsqueeze(0)))
    
    return DataLoader(data_list, batch_size=batch_size).dataset

# --- 4.2. WGAN-GP Gradient Penalty Calculation (Simplified) ---
def calculate_gradient_penalty(discriminator, real_data, fake_data, lambda_gp, device):
    """Calculates the Gradient Penalty on interpolated node features (X)."""
    
    # --- FIX: Match fake_data.x size to real_data.x size for interpolation ---
    real_x_size = real_data.x.size(0)
    fake_x = fake_data.x.detach()
    
    if fake_x.size(0) > real_x_size:
        # Truncate fake nodes if the generated batch is larger
        fake_x = fake_x[:real_x_size]
    elif fake_x.size(0) < real_x_size:
        # Pad fake nodes if the generated batch is smaller
        # Pad with zeros to match the number of nodes in the real batch
        padding = torch.zeros(real_x_size - fake_x.size(0), fake_x.size(1), device=device)
        fake_x = torch.cat([fake_x, padding], dim=0)

    # 1. Linear Interpolation (Now guaranteed to work as both tensors are the same size)
    alpha = torch.rand(real_x_size, 1).to(device) 
    interpolated_x = (alpha * real_data.x.detach()) + ((1 - alpha) * fake_x)
    interpolated_x.requires_grad_(True)
    # ------------------------------------------------------------------------

    # Create interpolated Data object (using real graph structure for batch context)
    interpolated_data = Data(x=interpolated_x, 
                             edge_index=real_data.edge_index, 
                             edge_attr=real_data.edge_attr, 
                             batch=real_data.batch, 
                             target_embed=real_data.target_embed)

    disc_interpolates = discriminator(interpolated_data)
    
    # Compute Gradients
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolated_x,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True
    )[0]

    # Calculate Penalty
    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# --- 4.3. Main Training Loop ---
def run_wgan_gp_training(generator, discriminator, data_loader, epochs, n_critic):
    
    generator.train()
    discriminator.train()
    
    for epoch in range(1, epochs + 1):
        g_loss_sum, d_loss_sum = 0, 0
        
        for real_data in data_loader:
            real_data.to(DEVICE)
            batch_size = real_data.num_graphs
            target_embed_batch = real_data.target_embed.squeeze(1)

            # 1. Train Discriminator (n_critic steps)
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                
                # Real Loss
                d_real = discriminator(real_data).mean()
                
                # Fake Loss
                z = torch.randn(batch_size, Z_DIM).to(DEVICE)
                x_fake, adj_fake = generator(z, target_embed_batch)
                
                # Convert generated tensors to PyG Data
                fake_data_list = convert_fake_to_data(x_fake.detach(), adj_fake.detach(), target_embed_batch, DEVICE)
                fake_data = next(iter(DataLoader(fake_data_list, batch_size=batch_size))).to(DEVICE)
                
                d_fake = discriminator(fake_data).mean()
                
                # Gradient Penalty
                gp = calculate_gradient_penalty(discriminator, real_data, fake_data, LAMBDA_GP, DEVICE)
                
                # Discriminator Loss
                d_loss = - (d_real - d_fake) + gp
                d_loss.backward()
                optimizer_D.step()
                d_loss_sum += d_loss.item()
            
            # 2. Train Generator (1 step)
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, Z_DIM).to(DEVICE)
            x_fake, adj_fake = generator(z, target_embed_batch)
            
            fake_data_list = convert_fake_to_data(x_fake, adj_fake, target_embed_batch, DEVICE)
            fake_data = next(iter(DataLoader(fake_data_list, batch_size=batch_size))).to(DEVICE)
            
            # Generator Loss
            g_loss = - discriminator(fake_data).mean()
            g_loss.backward()
            optimizer_G.step()
            g_loss_sum += g_loss.item()

        if epoch % 1 == 0:
            avg_d_loss = d_loss_sum / len(data_loader) / n_critic
            avg_g_loss = g_loss_sum / len(data_loader)
            print(f"Epoch {epoch}/{EPOCHS} | D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f}")

# --- Execute Training ---
# UNCOMMENT THE LINE BELOW TO START TRAINING
# run_wgan_gp_training(generator, discriminator, real_loader, EPOCHS, N_CRITIC) 
print("\nCode is complete and ready to run. Ensure your file paths are correct.")


Code is complete and ready to run. Ensure your file paths are correct.


In [None]:
# --- Execute Training ---
print("\n--- Starting WGAN-GP Training ---")
run_wgan_gp_training(generator, discriminator, real_loader, EPOCHS, N_CRITIC) 
print("\nTraining completed.")
"There was a mode collapse we need to fix that"


--- Starting WGAN-GP Training ---


  return DataLoader(data_list, batch_size=batch_size).dataset
  fake_data = next(iter(DataLoader(fake_data_list, batch_size=batch_size))).to(DEVICE)
  fake_data = next(iter(DataLoader(fake_data_list, batch_size=batch_size))).to(DEVICE)


Epoch 1/100 | D Loss: -52.3648 | G Loss: 78.6632
Epoch 2/100 | D Loss: -121.5836 | G Loss: 7.7575
Epoch 3/100 | D Loss: -156.0709 | G Loss: -3.7328


KeyboardInterrupt: 

In [58]:
# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
real_data_list = [smiles_to_graph(s, TARGET_EMBED) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

# This is the line that reports the dataset size:
print(f"Prepared {len(real_data_list)} real graph samples for training.")

# Your dataset size is:
dataset_size = len(real_data_list)
print(f"The final dataset size is: {dataset_size} molecules.")

Found 3989 potent inhibitors for Uniprot ID P00533.
Prepared 1334 real graph samples for training.
The final dataset size is: 1334 molecules.
