In [None]:
import torch
import pandas as pd
import sqlite3
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax
from rdkit import Chem
from Bio import SeqIO
import random
import re
from transformers import T5EncoderModel, T5Tokenizer
import gzip  # <-- FIX 1: ADDED THIS IMPORT

# --- 1. System & Configuration ---

# --- Configuration (UPDATE THESE PATHS & ID) ---
CHEMPL_DB_PATH = 'DL_ENDSEM__DATASET/chembl_35/chembl_35_sqlite/chembl_35.db'
BLAST_FASTA_PATH = 'DL_ENDSEM__DATASET/chembl_35_blast.fa.gz'
TARGET_UNIPROT_ID = "P00533" # Example: EGFR Kinase

# --- Model Hyperparameters ---
Z_DIM = 100         # Latent noise dimension
ATOM_FEAT_DIM = 9   # Atom feature size
BOND_FEAT_DIM = 4   # Bond feature size (Single, Double, Triple, Aromatic)
EMBED_DIM = 128     # Hidden dimension for the Graph Transformer
T_EMBED_DIM = 1024  # Target embedding dimension (from ProtT5)
LAMBDA_GP = 10.0    # Gradient Penalty weight
MAX_NODES = 30      # Max atoms in generated molecules (for Generator tensor shape)
N_CRITIC = 5        # Discriminator training steps per Generator step
EPOCHS = 150       # Total epochs
BATCH_SIZE = 128

# --- CUDA Check ---
if torch.cuda.is_available():
    print("‚úÖ CUDA is available! GPU will be used for training.")
    print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')
else:
    print("‚ùå CUDA not found. Running on CPU.")
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")


# --- 2. Real Protein Embedding Generation ---

# --- FIX 2: REPLACED THIS FUNCTION ---
def load_target_sequence(fasta_path, uniprot_id):
    """Loads a protein sequence from a gzipped FASTA file."""
    print(f"Opening gzipped FASTA file: {fasta_path}")
    try:
        # 'rt' = read in text mode
        with gzip.open(fasta_path, "rt") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                if uniprot_id in record.id or uniprot_id in record.description:
                    return str(record.seq)
        print(f"Warning: Could not find sequence for {uniprot_id} in {fasta_path}")
        return None
    except FileNotFoundError:
        print(f"FATAL ERROR: FASTA file not found at {fasta_path}")
        raise
    except Exception as e:
        print(f"FATAL ERROR: Could not read FASTA file. Error: {e}")
        raise

def get_protein_embedding(sequence, device):
    """
    Generates a protein embedding using the pre-trained ProtT5 model.
    This replaces the "mock" random embedding.
    """
    print("Loading ProtT5 model... (This may take a moment)")
    # Load model and tokenizer
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc').to(device)
    model.eval() # Set to evaluation mode

    # Pre-process sequence: add spaces between AAs and handle rare AAs
    sequence_preprocessed = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))
    
    inputs = tokenizer(sequence_preprocessed, return_tensors="pt", max_length=1024, truncation=True).to(device)
    
    with torch.no_grad():
        embedding = model(**inputs).last_hidden_state
    
    # Average pool the embedding to get a single vector [1, 1024]
    protein_vec = embedding.mean(dim=1).squeeze(0) # Squeeze to [1024]
    print(f"Generated protein embedding of shape: {protein_vec.shape}")
    return protein_vec

# --- Generate the REAL Target Embedding ---
target_seq = load_target_sequence(BLAST_FASTA_PATH, TARGET_UNIPROT_ID)
if target_seq is None:
    raise ValueError(f"Target sequence for {TARGET_UNIPROT_ID} not found. Exiting.")

TARGET_EMBED = get_protein_embedding(target_seq, DEVICE)


# --- 3. Data Pipeline (Molecules -> Graphs) ---

def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    """
    Connects to ChEMBL DB and extracts SMILES for potent inhibitors of a given Uniprot ID.
    (Using your robust SQL query)
    """
    try:
        conn = sqlite3.connect(db_path)
        sql_query = f"""
        SELECT DISTINCT
            cs.canonical_smiles
        FROM
            activities acts
        JOIN assays a ON acts.assay_id = a.assay_id
        JOIN target_dictionary td ON a.tid = td.tid
        JOIN target_components tc ON td.tid = tc.tid
        JOIN component_sequences cseq ON tc.component_id = cseq.component_id
        JOIN compound_structures cs ON acts.molregno = cs.molregno
        WHERE
            cseq.accession = '{uniprot_id}' AND
            acts.standard_type = 'IC50' AND
            acts.standard_units = 'nM' AND
            acts.standard_relation = '=' AND
            acts.standard_value <= {potency_cutoff_nM}
        """
        df = pd.read_sql_query(sql_query, conn)
        conn.close()
        
        print(f"Found {len(df)} potent inhibitors for UniProt ID {uniprot_id}.")
        return df['canonical_smiles'].unique().tolist()
        
    except Exception as e:
        print(f"Error during database query. This is likely due to missing tables or a critical file path issue.")
        raise RuntimeError(f"Database Error: {e}. Please ensure the file is the full ChEMBL SQLite dump.") from e

def get_atom_features(atom):
    """Creates the 9-dimensional atom feature vector."""
    return [
        atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
        int(atom.GetHybridization() == Chem.HybridizationType.SP),
        int(atom.GetHybridization() == Chem.HybridizationType.SP2),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D2)
    ]

def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    x = torch.tensor([get_atom_features(atom) for atom in mol.GetAtoms()], dtype=torch.float).to(DEVICE)
    
    edge_indices, edge_attrs = [], []
    bond_types = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
    
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_type_oh = [int(bond.GetBondType() == t) for t in bond_types]
        edge_indices.extend([[i, j], [j, i]])
        edge_attrs.extend([bond_type_oh, bond_type_oh])

    if not edge_indices: return None
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().to(DEVICE)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float).to(DEVICE)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
real_data_list = [smiles_to_graph(s, TARGET_EMBED) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print("FATAL: No valid inhibitor data found. The model cannot be trained.")
    exit()

real_loader = DataLoader(real_data_list, batch_size=BATCH_SIZE, shuffle=True)
print(f"Prepared {len(real_data_list)} real graph samples for training.")


# --- 4. Model Architecture ---

# --- 4.1. Relational Graph Transformer Layer (No Changes) ---
class RelationalGraphTransformerLayer(MessagePassing):
    """Graph Transformer Layer with explicit edge/bond feature integration."""
    def __init__(self, in_channels, out_channels, edge_dim, heads=4, dropout=0.1, **kwargs):
        super().__init__(aggr='add', node_dim=0, **kwargs)
        self.out_channels = out_channels
        self.heads = heads
        
        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_k = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_v = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) 
        self.att_coeff = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels)) 
        self.lin_out = nn.Linear(heads * out_channels, out_channels)
        self.dropout = nn.Dropout(dropout)
        
        # Initialize attention coefficients (Xavier initialization)
        nn.init.xavier_uniform_(self.att_coeff)

    def forward(self, x, edge_index, edge_attr):
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)
        E = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)

        out = self.propagate(edge_index, Q=Q, K=K, V=V, E=E)
        out = out.view(-1, self.heads * self.out_channels)
        out = self.lin_out(out)
        return out

    def message(self, Q_i, K_j, V_j, E_k, index):
        try:
            is_empty = (E_k.size(0) == 0)
        except AttributeError:
            is_empty = True
        
        if is_empty:
            E_bias = torch.zeros(Q_i.size(0), self.heads, 1, device=self.lin_q.weight.device)
        else:
            E_bias = E_k.mean(dim=-1, keepdim=True) 

        QK_cat = torch.cat([Q_i, K_j], dim=-1)
        e_ij = (QK_cat * self.att_coeff).sum(dim=-1, keepdim=True)
        e_ij = e_ij + E_bias
        
        e_ij = F.leaky_relu(e_ij)
        alpha = softmax(e_ij, index)
        alpha = self.dropout(alpha)
        
        return V_j * alpha.view(-1, self.heads, 1)

# --- 4.2. Discriminator (No Changes) ---
class Discriminator(nn.Module):
    def __init__(self, node_features, edge_dim, t_embed_dim, embed_dim, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_d = node_features if i == 0 else embed_dim
            self.layers.append(RelationalGraphTransformerLayer(in_d, embed_dim, edge_dim))
        self.lin_final = nn.Linear(embed_dim + t_embed_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        t_embed = data.target_embed
        
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        
        graph_embed = global_mean_pool(x, batch)
        
        # Ensure t_embed is [batch_size, t_embed_dim]
        if t_embed.dim() > 2:
            t_embed = t_embed.squeeze(1) 
            
        final_input = torch.cat([graph_embed, t_embed], dim=1)
        
        return self.lin_final(final_input).squeeze(1)

# --- 4.3. Generator (FIXED: Outputs Logits) ---
class Generator(nn.Module):
    def __init__(self, z_dim, t_embed_dim, node_features, bond_features, max_nodes=MAX_NODES):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_features = node_features
        self.bond_features = bond_features
        
        self.lin_x = nn.Sequential(
            nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(),
            nn.Linear(256, max_nodes * node_features)
        )
        self.lin_adj = nn.Sequential(
            nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(),
            nn.Linear(256, max_nodes * max_nodes * bond_features) 
        )

    def forward(self, z, t_embed):
        zt = torch.cat([z, t_embed], dim=1)
        
        x_fake = self.lin_x(zt).view(-1, self.max_nodes, self.node_features)
        
        # Output raw logits for the adjacency matrix (NO SOFTMAX)
        adj_fake_logits = self.lin_adj(zt).view(-1, self.max_nodes, self.max_nodes, self.bond_features)
        
        # Note: We don't softmax x_fake as it contains continuous features (e.g., atomic num)
        
        return x_fake, adj_fake_logits

# --- Model Initialization ---
print("Initializing models...")
generator = Generator(Z_DIM, T_EMBED_DIM, ATOM_FEAT_DIM, BOND_FEAT_DIM).to(DEVICE)
discriminator = Discriminator(ATOM_FEAT_DIM, BOND_FEAT_DIM, T_EMBED_DIM, EMBED_DIM).to(DEVICE)

optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))


# --- 5. Training Utilities (FIXED) ---

# --- 5.1. Differentiable Graph Conversion (NEW) ---
def convert_fake_to_data_differentiable(x_fake_tensor, adj_fake_logits, t_embed_batch, device, temperature=0.5):
    """
    Converts Generator output to PyG Data using Gumbel-Softmax for the G-step.
    This IS differentiable, allowing the generator to learn bond formation.
    """
    batch_size = x_fake_tensor.size(0)
    data_list = []
    
    for i in range(batch_size):
        num_nodes = MAX_NODES
        x_i = x_fake_tensor[i, :, :]
        adj_i_logits = adj_fake_logits[i, :, :, :]
        
        edge_indices, edge_attrs_gumbel = [], []
        
        for r in range(num_nodes):
            for c in range(num_nodes):
                if r == c: continue
                
                # Gumbel-Softmax Sampling (Differentiable "argmax")
                bond_probs = F.gumbel_softmax(
                    adj_i_logits[r, c], 
                    tau=temperature, 
                    hard=True
                )
                
                edge_indices.append([r, c])
                edge_attrs_gumbel.append(bond_probs)

        if not edge_indices: continue

        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().to(device)
        edge_attr = torch.stack(edge_attrs_gumbel).to(device)

        data_list.append(Data(x=x_i, edge_index=edge_index, edge_attr=edge_attr, 
                              target_embed=t_embed_batch[i].unsqueeze(0)))
    
    if not data_list: return None
    # We return the .dataset attribute for direct use in GP calculation
    return DataLoader(data_list, batch_size=batch_size).dataset

# --- 5.2. Discrete Graph Conversion (NEW) ---
def convert_fake_to_data_discrete(x_fake_tensor, adj_fake_logits, t_embed_batch, device):
    """
    Converts Generator output to PyG Data using .argmax() for the D-step.
    This is NOT differentiable and is used when we don't need grads (faster).
    """
    batch_size = x_fake_tensor.size(0)
    data_list = []
    
    for i in range(batch_size):
        num_nodes = MAX_NODES
        x_i = x_fake_tensor[i, :, :].detach() # Detach all inputs
        adj_i_logits = adj_fake_logits[i, :, :, :].detach()
        
        edge_indices, edge_attrs_list = [], []
        
        for r in range(num_nodes):
            for c in range(num_nodes):
                if r == c: continue
                
                # Hard .argmax() sampling
                bond_type_index = adj_i_logits[r, c].argmax().item()
                bond_one_hot = F.one_hot(
                    torch.tensor(bond_type_index, device=device), 
                    num_classes=BOND_FEAT_DIM
                ).float()
                
                edge_indices.append([r, c])
                edge_attrs_list.append(bond_one_hot)

        if not edge_indices: continue
        
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().to(device)
        edge_attr = torch.stack(edge_attrs_list).to(device)

        data_list.append(Data(x=x_i, edge_index=edge_index, edge_attr=edge_attr, 
                              target_embed=t_embed_batch[i].unsqueeze(0)))
    
    if not data_list: return None
    return DataLoader(data_list, batch_size=batch_size).dataset

# --- 5.3. WGAN-GP Gradient Penalty (Preserved User's Fix) ---
def calculate_gradient_penalty(discriminator, real_data, fake_data, lambda_gp, device):
    """Calculates the Gradient Penalty on interpolated node features (X)."""
    
    real_x = real_data.x.detach()
    fake_x = fake_data.x.detach()
    real_x_size = real_x.size(0)
    
    # --- FIX: Match fake_data.x size to real_data.x size for interpolation ---
    if fake_x.size(0) > real_x_size:
        fake_x = fake_x[:real_x_size]
    elif fake_x.size(0) < real_x_size:
        padding = torch.zeros(real_x_size - fake_x.size(0), fake_x.size(1), device=device)
        fake_x = torch.cat([fake_x, padding], dim=0)

    # 1. Linear Interpolation
    alpha = torch.rand(real_x_size, 1).to(device) 
    interpolated_x = (alpha * real_x) + ((1 - alpha) * fake_x)
    interpolated_x.requires_grad_(True)

    # 2. Create interpolated Data object
    # We use the real data's structure (edges, batch) as a template
    interpolated_data = Data(x=interpolated_x, 
                             edge_index=real_data.edge_index, 
                             edge_attr=real_data.edge_attr, 
                             batch=real_data.batch, 
                             target_embed=real_data.target_embed)

    disc_interpolates = discriminator(interpolated_data)
    
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolated_x,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# --- 6. Main Training Loop (FIXED) ---
def run_wgan_gp_training(generator, discriminator, data_loader, epochs, n_critic):
    
    generator.train()
    discriminator.train()
    
    for epoch in range(1, epochs + 1):
        g_loss_sum, d_loss_sum = 0, 0
        
        for batch_idx, real_data in enumerate(data_loader):
            # --- THIS IS THE LINE YOU REQUESTED ---
            print(f"--- Epoch {epoch}, Processing Batch {batch_idx+1}/{len(data_loader)} ---")

            real_data = real_data.to(DEVICE)
            batch_size = real_data.num_graphs
            
            # Ensure target_embed_batch is [batch_size, T_EMBED_DIM]
            target_embed_batch = real_data.target_embed
            if target_embed_batch.dim() > 2:
                target_embed_batch = target_embed_batch.view(batch_size, T_EMBED_DIM)

            # 1. Train Discriminator (n_critic steps)
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                
                # Real Loss
                d_real = discriminator(real_data).mean()
                
                # Fake Loss
                z = torch.randn(batch_size, Z_DIM).to(DEVICE)
                x_fake, adj_fake_logits = generator(z, target_embed_batch)
                
                # --- D-step uses DISCRETE (non-differentiable) sampling ---
                fake_data_list = convert_fake_to_data_discrete(
                    x_fake.detach(), adj_fake_logits.detach(), target_embed_batch, DEVICE
                )
                if fake_data_list is None: continue
                # Convert list of Data objects back into a batched Data object
                fake_data_loader = DataLoader(fake_data_list, batch_size=batch_size)
                fake_data = next(iter(fake_data_loader)).to(DEVICE)

                d_fake = discriminator(fake_data).mean()
                
                # Gradient Penalty
                gp = calculate_gradient_penalty(discriminator, real_data, fake_data, LAMBDA_GP, DEVICE)
                
                d_loss = - (d_real - d_fake) + gp
                d_loss.backward()
                optimizer_D.step()
                d_loss_sum += d_loss.item()
            
            # 2. Train Generator (1 step)
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, Z_DIM).to(DEVICE)
            x_fake, adj_fake_logits = generator(z, target_embed_batch)
            
            # --- G-step uses DIFFERENTIABLE Gumbel-Softmax sampling ---
            fake_data_list = convert_fake_to_data_differentiable(
                x_fake, adj_fake_logits, target_embed_batch, DEVICE
            )
            if fake_data_list is None: continue
            fake_data_loader = DataLoader(fake_data_list, batch_size=batch_size)
            fake_data = next(iter(fake_data_loader)).to(DEVICE)
            
            # Generator Loss
            g_loss = - discriminator(fake_data).mean()
            g_loss.backward()
            optimizer_G.step()
            g_loss_sum += g_loss.item()
            
        avg_d_loss = d_loss_sum / len(data_loader) / n_critic
        avg_g_loss = g_loss_sum / len(data_loader)
        print(f"Epoch {epoch}/{EPOCHS} | D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f}")

# --- Execute Training ---
print("\n--- Starting WGAN-GP Training ---")
run_wgan_gp_training(generator, discriminator, real_loader, EPOCHS, N_CRITIC) 
print("\nTraining completed.")

‚úÖ CUDA is available! GPU will be used for training.
PyTorch CUDA Version: 12.1
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU
Using device: cuda
Opening gzipped FASTA file: DL_ENDSEM__DATASET/chembl_35_blast.fa.gz
Loading ProtT5 model... (This may take a moment)
Generated protein embedding of shape: torch.Size([1024])


In [1]:
import torch
import pandas as pd
import sqlite3
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax
from rdkit import Chem
from Bio import SeqIO
import random
import re
from transformers import T5EncoderModel, T5Tokenizer
import gzip

# --- 1. System & Configuration ---

# --- Configuration (UPDATE THESE PATHS & ID) ---
CHEMPL_DB_PATH = 'DL_ENDSEM__DATASET/chembl_35/chembl_35_sqlite/chembl_35.db'
BLAST_FASTA_PATH = 'DL_ENDSEM__DATASET/chembl_35_blast.fa.gz'
TARGET_UNIPROT_ID = "P00533" # Example: EGFR Kinase

# --- Model Hyperparameters ---
Z_DIM = 100         # Latent noise dimension
ATOM_FEAT_DIM = 9   # Atom feature size
BOND_FEAT_DIM = 4   # Bond feature size (Single, Double, Triple, Aromatic)
EMBED_DIM = 128     # Hidden dimension for the Graph Transformer
T_EMBED_DIM = 1024  # Target embedding dimension (from ProtT5)
LAMBDA_GP = 10.0    # Gradient Penalty weight
MAX_NODES = 30      # Max atoms in generated molecules (for Generator tensor shape)
N_CRITIC = 5        # Discriminator training steps per Generator step
EPOCHS = 10         # Total epochs
BATCH_SIZE = 64    # Your increased batch size

# --- OPTIMIZATION 1: Set num_workers based on your 16-core CPU ---
# Use ~half your cores to pre-fetch real data
CPU_WORKERS = 4  

# --- CUDA Check ---
if torch.cuda.is_available():
    print("‚úÖ CUDA is available! GPU will be used for training.")
    print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')
else:
    print("‚ùå CUDA not found. Running on CPU.")
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")


# --- 2. Real Protein Embedding Generation (No changes) ---

def load_target_sequence(fasta_path, uniprot_id):
    """Loads a protein sequence from a gzipped FASTA file."""
    print(f"Opening gzipped FASTA file: {fasta_path}")
    try:
        with gzip.open(fasta_path, "rt") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                if uniprot_id in record.id or uniprot_id in record.description:
                    return str(record.seq)
        print(f"Warning: Could not find sequence for {uniprot_id} in {fasta_path}")
        return None
    except FileNotFoundError:
        print(f"FATAL ERROR: FASTA file not found at {fasta_path}")
        raise
    except Exception as e:
        print(f"FATAL ERROR: Could not read FASTA file. Error: {e}")
        raise

def get_protein_embedding(sequence, device):
    """Generates a protein embedding using the pre-trained ProtT5 model."""
    print("Loading ProtT5 model... (This may take a moment)")
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc').to(device)
    model.eval() 

    sequence_preprocessed = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))
    inputs = tokenizer(sequence_preprocessed, return_tensors="pt", max_length=1024, truncation=True).to(device)
    
    with torch.no_grad():
        embedding = model(**inputs).last_hidden_state
    
    protein_vec = embedding.mean(dim=1).squeeze(0)
    print(f"Generated protein embedding of shape: {protein_vec.shape}")
    return protein_vec

# --- Generate the REAL Target Embedding ---
target_seq = load_target_sequence(BLAST_FASTA_PATH, TARGET_UNIPROT_ID)
if target_seq is None:
    raise ValueError(f"Target sequence for {TARGET_UNIPROT_ID} not found. Exiting.")
TARGET_EMBED = get_protein_embedding(target_seq, DEVICE)


# --- 3. Data Pipeline (Molecules -> Graphs) (No changes) ---

def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    """Connects to ChEMBL DB and extracts SMILES for potent inhibitors."""
    try:
        conn = sqlite3.connect(db_path)
        sql_query = f"""
        SELECT DISTINCT cs.canonical_smiles
        FROM activities acts
        JOIN assays a ON acts.assay_id = a.assay_id
        JOIN target_dictionary td ON a.tid = td.tid
        JOIN target_components tc ON td.tid = tc.tid
        JOIN component_sequences cseq ON tc.component_id = cseq.component_id
        JOIN compound_structures cs ON acts.molregno = cs.molregno
        WHERE
            cseq.accession = '{uniprot_id}' AND
            acts.standard_type = 'IC50' AND
            acts.standard_units = 'nM' AND
            acts.standard_relation = '=' AND
            acts.standard_value <= {potency_cutoff_nM}
        """
        df = pd.read_sql_query(sql_query, conn)
        conn.close()
        print(f"Found {len(df)} potent inhibitors for UniProt ID {uniprot_id}.")
        return df['canonical_smiles'].unique().tolist()
    except Exception as e:
        print(f"Error during database query: {e}")
        raise

def get_atom_features(atom):
    """Creates the 9-dimensional atom feature vector."""
    return [
        atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
        int(atom.GetHybridization() == Chem.HybridizationType.SP),
        int(atom.GetHybridization() == Chem.HybridizationType.SP2),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D2)
    ]

def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    x = torch.tensor([get_atom_features(atom) for atom in mol.GetAtoms()], dtype=torch.float)
    
    edge_indices, edge_attrs = [], []
    bond_types = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
    
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_type_oh = [int(bond.GetBondType() == t) for t in bond_types]
        edge_indices.extend([[i, j], [j, i]])
        edge_attrs.extend([bond_type_oh, bond_type_oh])

    if not edge_indices: return None
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float)

    # Note: We send to DEVICE in the data loader, not here.
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
# Note: We keep data on CPU first, for num_workers to be efficient.
real_data_list = [smiles_to_graph(s, TARGET_EMBED.cpu()) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print("FATAL: No valid inhibitor data found. The model cannot be trained.")
    exit()

# --- OPTIMIZATION 2: Added num_workers and pin_memory ---
real_loader = DataLoader(
    real_data_list, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=CPU_WORKERS, 
    pin_memory=True  # Speeds up CPU-to-GPU data transfer
)
print(f"Prepared {len(real_data_list)} real graph samples for training.")


# --- 4. Model Architecture (No changes) ---

# --- 4.1. Relational Graph Transformer Layer ---
class RelationalGraphTransformerLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, edge_dim, heads=4, dropout=0.1, **kwargs):
        super().__init__(aggr='add', node_dim=0, **kwargs)
        self.out_channels = out_channels
        self.heads = heads
        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_k = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_v = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) 
        self.att_coeff = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels)) 
        self.lin_out = nn.Linear(heads * out_channels, out_channels)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.att_coeff)

    def forward(self, x, edge_index, edge_attr):
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)
        E = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)
        out = self.propagate(edge_index, Q=Q, K=K, V=V, E=E)
        out = out.view(-1, self.heads * self.out_channels)
        out = self.lin_out(out)
        return out

    def message(self, Q_i, K_j, V_j, E_k, index):
        try: is_empty = (E_k.size(0) == 0)
        except AttributeError: is_empty = True
        
        if is_empty: E_bias = torch.zeros(Q_i.size(0), self.heads, 1, device=self.lin_q.weight.device)
        else: E_bias = E_k.mean(dim=-1, keepdim=True) 

        QK_cat = torch.cat([Q_i, K_j], dim=-1)
        e_ij = (QK_cat * self.att_coeff).sum(dim=-1, keepdim=True)
        e_ij = e_ij + E_bias
        e_ij = F.leaky_relu(e_ij)
        alpha = softmax(e_ij, index)
        alpha = self.dropout(alpha)
        return V_j * alpha.view(-1, self.heads, 1)

# --- 4.2. Discriminator ---
class Discriminator(nn.Module):
    def __init__(self, node_features, edge_dim, t_embed_dim, embed_dim, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_d = node_features if i == 0 else embed_dim
            self.layers.append(RelationalGraphTransformerLayer(in_d, embed_dim, edge_dim))
        self.lin_final = nn.Linear(embed_dim + t_embed_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        t_embed = data.target_embed
        
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        
        graph_embed = global_mean_pool(x, batch)
        if t_embed.dim() > 2: t_embed = t_embed.squeeze(1) 
        final_input = torch.cat([graph_embed, t_embed], dim=1)
        return self.lin_final(final_input).squeeze(1)

# --- 4.3. Generator ---
class Generator(nn.Module):
    def __init__(self, z_dim, t_embed_dim, node_features, bond_features, max_nodes=MAX_NODES):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_features = node_features
        self.bond_features = bond_features
        self.lin_x = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * node_features))
        self.lin_adj = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * max_nodes * bond_features))

    def forward(self, z, t_embed):
        zt = torch.cat([z, t_embed], dim=1)
        x_fake = self.lin_x(zt).view(-1, self.max_nodes, self.node_features)
        adj_fake_logits = self.lin_adj(zt).view(-1, self.max_nodes, self.max_nodes, self.bond_features)
        return x_fake, adj_fake_logits

# --- Model Initialization ---
print("Initializing models...")
generator = Generator(Z_DIM, T_EMBED_DIM, ATOM_FEAT_DIM, BOND_FEAT_DIM).to(DEVICE)
discriminator = Discriminator(ATOM_FEAT_DIM, BOND_FEAT_DIM, T_EMBED_DIM, EMBED_DIM).to(DEVICE)

optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))


# --- 5. Training Utilities (FIXED & VECTORIZED) ---

# --- OPTIMIZATION 3: Create a helper for the dense edge_index template ---
# We compute this once on CPU and move it to GPU,
# so it's not recomputed every batch.
N = MAX_NODES
# Create a [N, N] adjacency matrix with 1s everywhere except the diagonal
adj_template = (torch.ones(N, N) - torch.eye(N)).bool()
# Convert to sparse edge_index format [2, N*(N-1)]
EDGE_INDEX_TEMPLATE = adj_template.nonzero(as_tuple=False).t().contiguous().to(DEVICE)
# This template has N*(N-1) edges
NUM_DENSE_EDGES = EDGE_INDEX_TEMPLATE.size(1)

# --- OPTIMIZATION 4: Vectorized Fake Graph Generation ---
# This one function replaces both old convert_fake_to_data functions
# It's fully vectorized and runs on the GPU. No Python loops!

def convert_fake_to_data_vectorized(x_fake_tensor, adj_fake_logits, t_embed_batch, device, gumbel=False, temperature=0.5):
    """
    Converts Generator output to a single batched PyG Data object *on the GPU*.
    This is the core optimization, replacing the slow Python loops.
    """
    batch_size, num_nodes, _ = x_fake_tensor.shape
    
    # 1. Create Batched Node Features (x)
    # Reshape [B, N, F] -> [B*N, F]
    x_batched = x_fake_tensor.reshape(batch_size * num_nodes, -1)
    
    # 2. Create Batched Batch Index (batch)
    # Create [0, 0, ..., 1, 1, ..., B-1, B-1]
    batch_vec = torch.arange(batch_size, device=device).repeat_interleave(num_nodes)
    
    # 3. Create Batched Edge Index (edge_index)
    # Repeat the [2, N*(N-1)] template B times
    edge_index_batched = EDGE_INDEX_TEMPLATE.repeat(1, batch_size)
    # Create offsets: [0, 0, ..., N, N, ..., 2N, 2N, ...]
    offset = torch.arange(0, batch_size * num_nodes, num_nodes, device=device).repeat_interleave(NUM_DENSE_EDGES)
    # Add offsets to create the full [2, B*N*(N-1)] edge_index
    edge_index_batched = edge_index_batched + offset
    
    # 4. Create Batched Edge Attributes (edge_attr)
    # We need to gather the logits corresponding to our new edge_index
    # Get batch, row, and column indices from the batched edge_index
    batch_indices = edge_index_batched[0] // num_nodes
    row_indices = edge_index_batched[0] % num_nodes
    col_indices = edge_index_batched[1] % num_nodes
    
    # Gather the [B*N*(N-1), Bonds] logits
    adj_logits_flat = adj_fake_logits[batch_indices, row_indices, col_indices]
    
    # 5. Sample edge attributes
    if gumbel:
        # Differentiable Gumbel-Softmax for Generator step
        edge_attr_batched = F.gumbel_softmax(adj_logits_flat, tau=temperature, hard=True)
    else:
        # Non-differentiable argmax for Discriminator step
        bond_indices = torch.argmax(adj_logits_flat, dim=-1)
        edge_attr_batched = F.one_hot(bond_indices, num_classes=BOND_FEAT_DIM).float()
        
    # 6. Create the single, batched Data object
    fake_data = Data(
        x=x_batched,
        edge_index=edge_index_batched,
        edge_attr=edge_attr_batched,
        batch=batch_vec,
        target_embed=t_embed_batch  # Already [B, T_EMBED_DIM]
    )
    return fake_data


# --- 5.3. WGAN-GP Gradient Penalty (No changes) ---
def calculate_gradient_penalty(discriminator, real_data, fake_data, lambda_gp, device):
    """Calculates the Gradient Penalty on interpolated node features (X)."""
    real_x = real_data.x.detach()
    fake_x = fake_data.x.detach()
    real_x_size = real_x.size(0)
    
    if fake_x.size(0) > real_x_size:
        fake_x = fake_x[:real_x_size]
    elif fake_x.size(0) < real_x_size:
        padding = torch.zeros(real_x_size - fake_x.size(0), fake_x.size(1), device=device)
        fake_x = torch.cat([fake_x, padding], dim=0)

    alpha = torch.rand(real_x_size, 1).to(device) 
    interpolated_x = (alpha * real_x) + ((1 - alpha) * fake_x)
    interpolated_x.requires_grad_(True)

    interpolated_data = Data(x=interpolated_x, 
                             edge_index=real_data.edge_index, 
                             edge_attr=real_data.edge_attr, 
                             batch=real_data.batch, 
                             target_embed=real_data.target_embed)

    disc_interpolates = discriminator(interpolated_data)
    
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolated_x,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# --- 6. Main Training Loop (FIXED & OPTIMIZED) ---
def run_wgan_gp_training(generator, discriminator, data_loader, epochs, n_critic):
    
    generator.train()
    discriminator.train()
    
    for epoch in range(1, epochs + 1):
        g_loss_sum, d_loss_sum = 0, 0
        
        # Use a for loop that automatically prints progress
        from tqdm import tqdm
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        
        for batch_idx, real_data in enumerate(progress_bar):
            # OPTIMIZATION: Move real_data to GPU here
            real_data = real_data.to(DEVICE)
            batch_size = real_data.num_graphs
            
            # target_embed is already [B, T_EMBED_DIM] from the loader
            target_embed_batch = real_data.target_embed

            # 1. Train Discriminator (n_critic steps)
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                
                # Real Loss
                d_real = discriminator(real_data).mean()
                
                # Fake Loss
                z = torch.randn(batch_size, Z_DIM).to(DEVICE)
                x_fake, adj_fake_logits = generator(z, target_embed_batch)
                
                # --- D-step: Use VECTORIZED function (non-differentiable) ---
                with torch.no_grad(): # Ensure no grads are computed here
                    fake_data = convert_fake_to_data_vectorized(
                        x_fake.detach(), adj_fake_logits.detach(), target_embed_batch, DEVICE, gumbel=False
                    )
                
                d_fake = discriminator(fake_data).mean()
                
                # Gradient Penalty
                gp = calculate_gradient_penalty(discriminator, real_data, fake_data, LAMBDA_GP, DEVICE)
                
                d_loss = - (d_real - d_fake) + gp
                d_loss.backward()
                optimizer_D.step()
                d_loss_sum += d_loss.item()
            
            # 2. Train Generator (1 step)
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, Z_DIM).to(DEVICE)
            x_fake, adj_fake_logits = generator(z, target_embed_batch)
            
            # --- G-step: Use VECTORIZED function (DIFFERENTIABLE) ---
            fake_data = convert_fake_to_data_vectorized(
                x_fake, adj_fake_logits, target_embed_batch, DEVICE, gumbel=True
            )
            
            # Generator Loss
            g_loss = - discriminator(fake_data).mean()
            g_loss.backward()
            optimizer_G.step()
            g_loss_sum += g_loss.item()
            
            # Update progress bar
            progress_bar.set_postfix(
                D_Loss=f"{(d_loss_sum / (batch_idx+1) / n_critic):.4f}", 
                G_Loss=f"{(g_loss_sum / (batch_idx+1)):.4f}"
            )
            
        avg_d_loss = d_loss_sum / len(data_loader) / n_critic
        avg_g_loss = g_loss_sum / len(data_loader)
        # TQDM handles the epoch printout
        # print(f"Epoch {epoch}/{EPOCHS} | D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f}")

# --- Execute Training ---
print("\n--- Starting WGAN-GP Training ---")
# Add 'tqdm' to your environment: pip install tqdm
run_wgan_gp_training(generator, discriminator, real_loader, EPOCHS, N_CRITIC) 
print("\nTraining completed.")

‚úÖ CUDA is available! GPU will be used for training.
PyTorch CUDA Version: 12.1
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU
Using device: cuda
Opening gzipped FASTA file: DL_ENDSEM__DATASET/chembl_35_blast.fa.gz
Loading ProtT5 model... (This may take a moment)


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Generated protein embedding of shape: torch.Size([1024])
Found 3989 potent inhibitors for UniProt ID P00533.




Prepared 1334 real graph samples for training.
Initializing models...

--- Starting WGAN-GP Training ---


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:21<00:00,  1.04s/it, D_Loss=-202.8050, G_Loss=-3.8393]
Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:19<00:00,  1.06it/s, D_Loss=-761.2867, G_Loss=-47.0787]
Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:27<00:00,  1.32s/it, D_Loss=-836.8307, G_Loss=-103.6194]
Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:28<00:00,  1.37s/it, D_Loss=-772.8533, G_Loss=-158.7001]
Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:28<00:00,  1.37s/it, D_Loss=-700.1403, G_Loss=-210.0796]
Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:20<00:00,  1.01it/s, D_Loss=-626.3723, G_Loss=-258.0512]
Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:26<00:00,  1.26s/it, D_Loss=-543.0943, G_Loss=-306.2921]
Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:28<00:00,  1.37s/it, D_Loss=


Training completed.





In [4]:
import torch
import pandas as pd
import sqlite3
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, QED, AllChem
from Bio import SeqIO
import random
import re
from transformers import T5EncoderModel, T5Tokenizer
import gzip
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

# --- 1. System & Configuration ---

# --- Configuration (UPDATE THESE PATHS & ID) ---
CHEMPL_DB_PATH = 'DL_ENDSEM__DATASET/chembl_35/chembl_35_sqlite/chembl_35.db'
BLAST_FASTA_PATH = 'DL_ENDSEM__DATASET/chembl_35_blast.fa.gz'
TARGET_UNIPROT_ID = "P00533" # Example: EGFR Kinase

# --- CHANGED: Define Atom Classes ---
# We will now classify atoms instead of regressing 9 features
# These are the atoms our model is allowed to generate.
ATOM_CLASSES = [6, 7, 8, 9, 15, 16, 17, 35, 53] # C, N, O, F, P, S, Cl, Br, I
ATOM_CLASSES_MAP = {num: i for i, num in enumerate(ATOM_CLASSES)} # Helper map

# --- Model Hyperparameters ---
Z_DIM = 100         # Latent noise dimension
ATOM_FEAT_DIM = len(ATOM_CLASSES) # --- CHANGED: Now 9 (for 9 classes) ---
BOND_FEAT_DIM = 4   # Bond feature size (Single, Double, Triple, Aromatic)
EMBED_DIM = 128     # Hidden dimension for the Graph Transformer
T_EMBED_DIM = 1024  # Target embedding dimension (from ProtT5)
LAMBDA_GP = 10.0    # Gradient Penalty weight
MAX_NODES = 30      # Max atoms in generated molecules (for Generator tensor shape)
N_CRITIC = 5        # Discriminator training steps per Generator step
EPOCHS = 100        # Total epochs
BATCH_SIZE = 64     # Your increased batch size

# --- OPTIMIZATION 1: Set num_workers based on your 16-core CPU ---
CPU_WORKERS = 4     

# --- NEW: Constants for Generation ---
BOND_TYPES_RDKIT = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]

# --- CUDA Check ---
if torch.cuda.is_available():
    print("‚úÖ CUDA is available! GPU will be used for training.")
    print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')
else:
    print("‚ùå CUDA not found. Running on CPU.")
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")


# --- 2. Real Protein Embedding Generation (No changes) ---

def load_target_sequence(fasta_path, uniprot_id):
    """Loads a protein sequence from a gzipped FASTA file."""
    print(f"Opening gzipped FASTA file: {fasta_path}")
    try:
        with gzip.open(fasta_path, "rt") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                if uniprot_id in record.id or uniprot_id in record.description:
                    return str(record.seq)
            print(f"Warning: Could not find sequence for {uniprot_id} in {fasta_path}")
            return None
    except FileNotFoundError:
        print(f"FATAL ERROR: FASTA file not found at {fasta_path}")
        raise
    except Exception as e:
        print(f"FATAL ERROR: Could not read FASTA file. Error: {e}")
        raise

def get_protein_embedding(sequence, device):
    """Generates a protein embedding using the pre-trained ProtT5 model."""
    print("Loading ProtT5 model... (This may take a moment)")
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc').to(device)
    model.eval() 

    sequence_preprocessed = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))
    inputs = tokenizer(sequence_preprocessed, return_tensors="pt", max_length=1024, truncation=True).to(device)
    
    with torch.no_grad():
        embedding = model(**inputs).last_hidden_state
    
    protein_vec = embedding.mean(dim=1).squeeze(0)
    print(f"Generated protein embedding of shape: {protein_vec.shape}")
    return protein_vec

# --- Generate the REAL Target Embedding ---
target_seq = load_target_sequence(BLAST_FASTA_PATH, TARGET_UNIPROT_ID)
if target_seq is None:
    raise ValueError(f"Target sequence for {TARGET_UNIPROT_ID} not found. Exiting.")
TARGET_EMBED = get_protein_embedding(target_seq, DEVICE)


# --- 3. Data Pipeline (Molecules -> Graphs) (CHANGES) ---

def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    """Connects to ChEMBL DB and extracts SMILES for potent inhibitors."""
    try:
        conn = sqlite3.connect(db_path)
        sql_query = f"""
        SELECT DISTINCT cs.canonical_smiles
        FROM activities acts
        JOIN assays a ON acts.assay_id = a.assay_id
        JOIN target_dictionary td ON a.tid = td.tid
        JOIN target_components tc ON td.tid = tc.tid
        JOIN component_sequences cseq ON tc.component_id = cseq.component_id
        JOIN compound_structures cs ON acts.molregno = cs.molregno
        WHERE
            cseq.accession = '{uniprot_id}' AND
            acts.standard_type = 'IC50' AND
            acts.standard_units = 'nM' AND
            acts.standard_relation = '=' AND
            acts.standard_value <= {potency_cutoff_nM}
        """
        df = pd.read_sql_query(sql_query, conn)
        conn.close()
        print(f"Found {len(df)} potent inhibitors for UniProt ID {uniprot_id}.")
        return df['canonical_smiles'].unique().tolist()
    except Exception as e:
        print(f"Error during database query: {e}")
        raise

# --- CHANGED: get_atom_features ---
def get_atom_features(atom):
    """Creates a one-hot vector for the atom type."""
    atom_num = atom.GetAtomicNum()
    if atom_num not in ATOM_CLASSES_MAP:
        return None # Atom is not in our allowed list
        
    atom_index = ATOM_CLASSES_MAP[atom_num]
    # Create a one-hot vector
    atom_one_hot = torch.zeros(ATOM_FEAT_DIM, dtype=torch.float)
    atom_one_hot[atom_index] = 1.0
    return atom_one_hot

# --- CHANGED: smiles_to_graph ---
def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    atom_features_list = []
    for atom in mol.GetAtoms():
        features = get_atom_features(atom)
        if features is None: # Skip molecule if it contains an invalid atom
            return None
        atom_features_list.append(features)

    # --- CHANGED: Stack the one-hot vectors ---
    x = torch.stack(atom_features_list)
    
    edge_indices, edge_attrs = [], []
    bond_types = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
    
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_type_oh = [int(bond.GetBondType() == t) for t in bond_types]
        edge_indices.extend([[i, j], [j, i]])
        edge_attrs.extend([bond_type_oh, bond_type_oh])

    if not edge_indices: return None
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float)

    # Note: We send to DEVICE in the data loader, not here.
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
# Note: We keep data on CPU first, for num_workers to be efficient.
real_data_list = [smiles_to_graph(s, TARGET_EMBED.cpu()) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print(f"FATAL: No valid inhibitor data found (or all were filtered out). Check ATOM_CLASSES.")
    exit()

# --- OPTIMIZATION 2: Added num_workers and pin_memory ---
real_loader = DataLoader(
    real_data_list, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=CPU_WORKERS, 
    pin_memory=True  # Speeds up CPU-to-GPU data transfer
)
print(f"Prepared {len(real_data_list)} real graph samples for training.")


# --- 4. Model Architecture (No changes) ---
# The models are already correctly parameterized to accept ATOM_FEAT_DIM.
# The *meaning* of ATOM_FEAT_DIM has changed, but the *shape* is the same.

# --- 4.1. Relational Graph Transformer Layer ---
class RelationalGraphTransformerLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, edge_dim, heads=4, dropout=0.1, **kwargs):
        super().__init__(aggr='add', node_dim=0, **kwargs)
        self.out_channels = out_channels
        self.heads = heads
        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_k = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_v = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) 
        self.att_coeff = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels)) 
        self.lin_out = nn.Linear(heads * out_channels, out_channels)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.att_coeff)

    def forward(self, x, edge_index, edge_attr):
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)
        E = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)
        out = self.propagate(edge_index, Q=Q, K=K, V=V, E=E)
        out = out.view(-1, self.heads * self.out_channels)
        out = self.lin_out(out)
        return out

    def message(self, Q_i, K_j, V_j, E_k, index):
        try: is_empty = (E_k.size(0) == 0)
        except AttributeError: is_empty = True
        
        if is_empty: E_bias = torch.zeros(Q_i.size(0), self.heads, 1, device=self.lin_q.weight.device)
        else: E_bias = E_k.mean(dim=-1, keepdim=True) 

        QK_cat = torch.cat([Q_i, K_j], dim=-1)
        e_ij = (QK_cat * self.att_coeff).sum(dim=-1, keepdim=True)
        e_ij = e_ij + E_bias
        e_ij = F.leaky_relu(e_ij)
        alpha = softmax(e_ij, index)
        alpha = self.dropout(alpha)
        return V_j * alpha.view(-1, self.heads, 1)

# --- 4.2. Discriminator ---
class Discriminator(nn.Module):
    def __init__(self, node_features, edge_dim, t_embed_dim, embed_dim, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_d = node_features if i == 0 else embed_dim
            self.layers.append(RelationalGraphTransformerLayer(in_d, embed_dim, edge_dim))
        self.lin_final = nn.Linear(embed_dim + t_embed_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        t_embed = data.target_embed
        
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        
        graph_embed = global_mean_pool(x, batch)
        if t_embed.dim() > 2: t_embed = t_embed.squeeze(1) 
        final_input = torch.cat([graph_embed, t_embed], dim=1)
        return self.lin_final(final_input).squeeze(1)

# --- 4.3. Generator ---
class Generator(nn.Module):
    def __init__(self, z_dim, t_embed_dim, node_features, bond_features, max_nodes=MAX_NODES):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_features = node_features
        self.bond_features = bond_features
        # --- CHANGED: lin_x now outputs logits for atom classes ---
        self.lin_x = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * node_features))
        self.lin_adj = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * max_nodes * bond_features))

    def forward(self, z, t_embed):
        zt = torch.cat([z, t_embed], dim=1)
        # --- CHANGED: x_fake is now x_fake_logits ---
        x_fake_logits = self.lin_x(zt).view(-1, self.max_nodes, self.node_features)
        adj_fake_logits = self.lin_adj(zt).view(-1, self.max_nodes, self.max_nodes, self.bond_features)
        return x_fake_logits, adj_fake_logits

# --- Model Initialization ---
print("Initializing models...")
generator = Generator(Z_DIM, T_EMBED_DIM, ATOM_FEAT_DIM, BOND_FEAT_DIM).to(DEVICE)
discriminator = Discriminator(ATOM_FEAT_DIM, BOND_FEAT_DIM, T_EMBED_DIM, EMBED_DIM).to(DEVICE)

optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))


# --- 5. Training Utilities (CHANGES) ---

# --- OPTIMIZATION 3: Create a helper for the dense edge_index template ---
N = MAX_NODES
adj_template = (torch.ones(N, N) - torch.eye(N)).bool()
EDGE_INDEX_TEMPLATE = adj_template.nonzero(as_tuple=False).t().contiguous().to(DEVICE)
NUM_DENSE_EDGES = EDGE_INDEX_TEMPLATE.size(1)

# --- OPTIMIZATION 4: Vectorized Fake Graph Generation ---
# --- CHANGED: Now applies Gumbel-Softmax to node features (x) as well ---
def convert_fake_to_data_vectorized(x_fake_logits, adj_fake_logits, t_embed_batch, device, gumbel=False, temperature=0.5):
    """
    Converts Generator output (logits) to a single batched PyG Data object.
    Applies Gumbel-Softmax to both nodes and edges.
    """
    batch_size, num_nodes, _ = x_fake_logits.shape
    
    # 1. Create Batched Node Features (x)
    # --- CHANGED: Apply Gumbel-Softmax to node logits ---
    if gumbel:
        x_fake_tensor = F.gumbel_softmax(x_fake_logits, tau=temperature, hard=True)
    else:
        # For discriminator, use discrete argmax
        x_indices = torch.argmax(x_fake_logits, dim=-1)
        x_fake_tensor = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()
    
    # Reshape [B, N, F] -> [B*N, F]
    x_batched = x_fake_tensor.reshape(batch_size * num_nodes, -1)
    
    # 2. Create Batched Batch Index (batch)
    batch_vec = torch.arange(batch_size, device=device).repeat_interleave(num_nodes)
    
    # 3. Create Batched Edge Index (edge_index)
    edge_index_batched = EDGE_INDEX_TEMPLATE.repeat(1, batch_size)
    offset = torch.arange(0, batch_size * num_nodes, num_nodes, device=device).repeat_interleave(NUM_DENSE_EDGES)
    edge_index_batched = edge_index_batched + offset
    
    # 4. Create Batched Edge Attributes (edge_attr)
    batch_indices = edge_index_batched[0] // num_nodes
    row_indices = edge_index_batched[0] % num_nodes
    col_indices = edge_index_batched[1] % num_nodes
    
    adj_logits_flat = adj_fake_logits[batch_indices, row_indices, col_indices]
    
    # 5. Sample edge attributes
    if gumbel:
        # Differentiable Gumbel-Softmax for Generator step
        edge_attr_batched = F.gumbel_softmax(adj_logits_flat, tau=temperature, hard=True)
    else:
        # Non-differentiable argmax for Discriminator step
        bond_indices = torch.argmax(adj_logits_flat, dim=-1)
        edge_attr_batched = F.one_hot(bond_indices, num_classes=BOND_FEAT_DIM).float()
        
    # 6. Create the single, batched Data object
    fake_data = Data(
        x=x_batched,
        edge_index=edge_index_batched,
        edge_attr=edge_attr_batched,
        batch=batch_vec,
        target_embed=t_embed_batch 
    )
    return fake_data


# --- 5.3. WGAN-GP Gradient Penalty (No changes) ---
def calculate_gradient_penalty(discriminator, real_data, fake_data, lambda_gp, device):
    """Calculates the Gradient Penalty on interpolated node features (X)."""
    real_x = real_data.x.detach()
    fake_x = fake_data.x.detach()
    real_x_size = real_x.size(0)
    
    if fake_x.size(0) > real_x_size:
        fake_x = fake_x[:real_x_size]
    elif fake_x.size(0) < real_x_size:
        padding = torch.zeros(real_x_size - fake_x.size(0), fake_x.size(1), device=device)
        fake_x = torch.cat([fake_x, padding], dim=0)

    # --- CHANGED: Alpha must now match the dimensions of x ---
    alpha = torch.rand(real_x_size, 1).to(device) 
    # Ensure alpha broadcasts correctly: [real_x_size, 1]
    
    interpolated_x = (alpha * real_x) + ((1 - alpha) * fake_x)
    interpolated_x.requires_grad_(True)

    interpolated_data = Data(x=interpolated_x, 
                             edge_index=real_data.edge_index, 
                             edge_attr=real_data.edge_attr, 
                             batch=real_data.batch, 
                             target_embed=real_data.target_embed)

    disc_interpolates = discriminator(interpolated_data)
    
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolated_x,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# --- 6. Main Training Loop (FIXED & OPTIMIZED) ---
def run_wgan_gp_training(generator, discriminator, data_loader, epochs, n_critic):
    
    generator.train()
    discriminator.train()
    
    for epoch in range(1, epochs + 1):
        g_loss_sum, d_loss_sum = 0, 0
        
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        
        for batch_idx, real_data in enumerate(progress_bar):
            real_data = real_data.to(DEVICE)
            batch_size = real_data.num_graphs
            
            target_embed_batch = real_data.target_embed

            # 1. Train Discriminator (n_critic steps)
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                
                # Real Loss
                d_real = discriminator(real_data).mean()
                
                # Fake Loss
                z = torch.randn(batch_size, Z_DIM).to(DEVICE)
                # --- CHANGED: Generator now outputs logits ---
                x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
                
                with torch.no_grad():
                    fake_data = convert_fake_to_data_vectorized(
                        x_fake_logits.detach(), adj_fake_logits.detach(), 
                        target_embed_batch, DEVICE, gumbel=False
                    )
                
                d_fake = discriminator(fake_data).mean()
                
                gp = calculate_gradient_penalty(discriminator, real_data, fake_data, LAMBDA_GP, DEVICE)
                
                d_loss = - (d_real - d_fake) + gp
                d_loss.backward()
                optimizer_D.step()
                d_loss_sum += d_loss.item()
            
            # 2. Train Generator (1 step)
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, Z_DIM).to(DEVICE)
            # --- CHANGED: Generator now outputs logits ---
            x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
            
            # --- G-step: Use VECTORIZED function (DIFFERENTIABLE) ---
            fake_data = convert_fake_to_data_vectorized(
                x_fake_logits, adj_fake_logits, 
                target_embed_batch, DEVICE, gumbel=True
            )
            
            g_loss = - discriminator(fake_data).mean()
            g_loss.backward()
            optimizer_G.step()
            g_loss_sum += g_loss.item()
            
            progress_bar.set_postfix(
                D_Loss=f"{(d_loss_sum / (batch_idx+1) / n_critic):.4f}", 
                G_Loss=f"{(g_loss_sum / (batch_idx+1)):.4f}"
            )

# --- 7. --- CHANGED: Generation & SMILES Conversion ---

def tensors_to_smiles(x_fake_one_hot, adj_fake_logits, bond_threshold=0.5):
    """
    Converts raw generator tensor output (one-hot nodes) into SMILES strings.
    """
    # --- CHANGED: x_fake is now one-hot, find the class index ---
    x_fake_indices = torch.argmax(x_fake_one_hot, dim=-1).cpu().detach()
    adj_fake_logits = adj_fake_logits.cpu().detach()
    
    adj_probs = F.softmax(adj_fake_logits, dim=-1)
    adj_bond_probs_max, adj_bond_type_idx = torch.max(adj_probs, dim=-1)
    
    batch_size = x_fake_indices.size(0)
    generated_smiles = []
    generated_mols = []
    
    for i in range(batch_size):
        mol = Chem.RWMol()
        atom_map = {} # Map from tensor index (0..MAX_NODES-1) to RDKit atom index
        
        # 1. Add atoms
        for j in range(MAX_NODES):
            # --- CHANGED: Get atom type from ATOM_CLASSES list ---
            atom_idx = x_fake_indices[i, j].item()
            atom_num = ATOM_CLASSES[atom_idx]
            
            # --- CHANGED: No longer need to check validity, but
            # we can use atom_num=6 (Carbon) as a "padding" atom
            # and only add non-Carbon atoms to avoid tiny fragments.
            # This is a heuristic. A better way is to learn a "stop" token.
            # For now, let's just add all atoms.
            
            atom = Chem.Atom(atom_num)
            rdkit_idx = mol.AddAtom(atom)
            atom_map[j] = rdkit_idx
                
        # 2. Add bonds
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                prob = adj_bond_probs_max[i, j, k].item()
                
                # Apply threshold to create sparsity
                if prob > bond_threshold:
                    bond_type_idx = adj_bond_type_idx[i, j, k].item()
                    bond_type = BOND_TYPES_RDKIT[bond_type_idx]
                    
                    mol.AddBond(atom_map[j], atom_map[k], bond_type)
        
        # 3. Sanitize and Convert
        try:
            Chem.SanitizeMol(mol)
            smi = Chem.MolToSmiles(mol)
            # --- CHANGED: Filter out disconnected fragments ---
            if '.' in smi:
                generated_smiles.append(None) # Invalid fragment
                generated_mols.append(None)
            else:
                generated_smiles.append(smi)
                generated_mols.append(mol)
        except Exception as e:
            generated_smiles.append(None) # Invalid molecule
            generated_mols.append(None)

    valid_smiles = [s for s in generated_smiles if s is not None]
    valid_mols = [m for m in generated_mols if m is not None]
    
    return valid_smiles, valid_mols, generated_smiles

# --- 8. --- NEW: Performance Metrics & Plotting ---

def calculate_and_plot_metrics(generator, target_embed, real_smiles_list, num_to_generate, device):
    """
    Generates molecules and calculates Validity, Uniqueness, Novelty,
    and plots property distributions.
    """
    print("\n--- Starting Generation & Evaluation ---")
    warnings.filterwarnings('ignore', '.*Implicit valence.*') # Suppress RDKit warnings
    
    generator.eval() # Set generator to evaluation mode
    
    real_mols = [Chem.MolFromSmiles(s) for s in real_smiles_list]
    real_mols = [m for m in real_mols if m is not None]
    real_smiles_set = set(real_smiles_list)
    
    all_valid_smiles = []
    all_valid_mols = []
    num_generated = 0
    total_attempts = 0 # Track total attempts

    print(f"Generating {num_to_generate} *valid* molecules for evaluation...")
    with torch.no_grad():
        # --- CHANGED: Loop until we have enough *valid* molecules ---
        while len(all_valid_smiles) < num_to_generate:
            batch_size = BATCH_SIZE
            total_attempts += batch_size

            z = torch.randn(batch_size, Z_DIM).to(device)
            t_embed_batch = target_embed.unsqueeze(0).repeat(batch_size, 1)
            
            # --- CHANGED: Generator outputs logits ---
            x_fake_logits, adj_fake_logits = generator(z, t_embed_batch)
            
            # --- CHANGED: Use argmax (not Gumbel) for final generation ---
            x_indices = torch.argmax(x_fake_logits, dim=-1)
            x_fake_one_hot = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()
            
            valid_smiles, valid_mols, _ = tensors_to_smiles(x_fake_one_hot, adj_fake_logits)
            
            all_valid_smiles.extend(valid_smiles)
            all_valid_mols.extend(valid_mols)
            
            print(f"Generated: {len(all_valid_smiles)}/{num_to_generate} valid molecules...", end='\r')
            
            if total_attempts > num_to_generate * 50 and not all_valid_smiles:
                 print("\nError: Generated too many molecules with 0 validity. Stopping.")
                 break
            if total_attempts > num_to_generate * 10: # Safety break
                 print(f"\nWarning: Low validity. Stopping generation at {len(all_valid_smiles)} molecules.")
                 break


    print("\nGeneration complete. Calculating metrics...")
    
    # --- 1. Calculate Metrics ---
    
    # Validity
    if total_attempts == 0: total_attempts = 1 # avoid divide by zero
    validity = len(all_valid_smiles) / total_attempts
    
    # Uniqueness
    if len(all_valid_smiles) > 0:
        uniqueness = len(set(all_valid_smiles)) / len(all_valid_smiles)
    else:
        uniqueness = 0.0
        
    # Novelty
    if len(all_valid_smiles) > 0:
        unique_valid_smiles = set(all_valid_smiles)
        novel_smiles = unique_valid_smiles - real_smiles_set
        novelty = len(novel_smiles) / len(unique_valid_smiles)
    else:
        novelty = 0.0

    print("\n--- Generative Performance Metrics ---")
    print(f"Total Attempts: {total_attempts}")
    print(f"Total Valid Generated: {len(all_valid_smiles)}")
    print(f"‚úÖ Validity:     {validity * 100:.2f}%")
    print(f"üß¨ Uniqueness:   {uniqueness * 100:.2f}%")
    print(f"‚≠ê Novelty:      {novelty * 100:.2f}%")
    print("----------------------------------------")
    
    if not all_valid_mols:
        print("No valid molecules generated. Skipping plots.")
        return

    # --- 2. Calculate Properties ---
    props_real = {
        'MolWt': [Descriptors.MolWt(m) for m in real_mols],
        'LogP': [Descriptors.MolLogP(m) for m in real_mols],
        'QED': [QED.qed(m) for m in real_mols]
    }
    
    props_fake = {
        'MolWt': [Descriptors.MolWt(m) for m in all_valid_mols],
        'LogP': [Descriptors.MolLogP(m) for m in all_valid_mols],
        'QED': [QED.qed(m) for m in all_valid_mols]
    }

    # --- 3. Plot Distributions ---
    print("Generating property distribution plots...")
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    plot_titles = ['Molecular Weight (MolWt)', 'LogP', 'Quantitative Esimation of Drug-likeness (QED)']
    prop_keys = ['MolWt', 'LogP', 'QED']
    
    for ax, title, key in zip(axes, plot_titles, prop_keys):
        ax.hist(props_real[key], bins=50, alpha=0.7, label='Real (Training)', color='blue', density=True)
        ax.hist(props_fake[key], bins=50, alpha=0.7, label='Generated (Fake)', color='red', density=True)
        ax.set_title(title)
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        
    plt.suptitle(f"Property Distributions (Real vs. Generated) for {TARGET_UNIPROT_ID}", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(f"property_plots_{TARGET_UNIPROT_ID}_FIXED.png")
    print(f"Plots saved to property_plots_{TARGET_UNIPROT_ID}_FIXED.png")
    plt.show()
    
    warnings.filterwarnings('default', '.*Implicit valence.*') # Restore warnings


# --- 9. --- Main Execution (Train & Evaluate) ---

# --- Execute Training ---
print("\n--- Starting WGAN-GP Training (FIXED) ---")
run_wgan_gp_training(generator, discriminator, real_loader, EPOCHS, N_CRITIC) 
print("\nTraining completed.")

# --- Execute Evaluation ---
num_to_eval = len(inhibitor_smiles) 
calculate_and_plot_metrics(generator, TARGET_EMBED, inhibitor_smiles, num_to_eval, DEVICE)

‚úÖ CUDA is available! GPU will be used for training.
PyTorch CUDA Version: 12.1
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU
Using device: cuda
Opening gzipped FASTA file: DL_ENDSEM__DATASET/chembl_35_blast.fa.gz
Loading ProtT5 model... (This may take a moment)
Generated protein embedding of shape: torch.Size([1024])
Found 3989 potent inhibitors for UniProt ID P00533.
Prepared 1332 real graph samples for training.
Initializing models...

--- Starting WGAN-GP Training (FIXED) ---


Epoch 1/100:  52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 11/21 [00:30<00:28,  2.81s/it, D_Loss=7.2260, G_Loss=0.0266]


KeyboardInterrupt: 

In [6]:
import torch
import pandas as pd
import sqlite3
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
# --- IMPORT THE PyG DATALOADER FOR BATCHING ---
from torch_geometric.loader import DataLoader as PyGDataLoader 
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, QED, AllChem
from Bio import SeqIO
import random
import re
from transformers import T5EncoderModel, T5Tokenizer
import gzip
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

# --- 1. System & Configuration ---

# --- Configuration (UPDATE THESE PATHS & ID) ---
CHEMPL_DB_PATH = 'DL_ENDSEM__DATASET/chembl_35/chembl_35_sqlite/chembl_35.db'
BLAST_FASTA_PATH = 'DL_ENDSEM__DATASET/chembl_35_blast.fa.gz'
TARGET_UNIPROT_ID = "P00533" # Example: EGFR Kinase

# --- ATOM_CLASSES from your (v2) script ---
# This is the list that gave you 9 batches
ATOM_CLASSES = [6, 7, 8, 9, 15, 16, 17, 35, 53] # C, N, O, F, P, S, Cl, Br, I
ATOM_CLASSES_MAP = {num: i for i, num in enumerate(ATOM_CLASSES)}
ATOM_FEAT_DIM = len(ATOM_CLASSES) # Now 9

# Define bond types RDKit knows
BOND_CLASSES_RDKIT = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
# --- Add a 5th "No Bond" class for the Generator ---
BOND_FEAT_DIM_GENERATOR = len(BOND_CLASSES_RDKIT) + 1 # Now 5
NO_BOND_IDX = len(BOND_CLASSES_RDKIT) # Index 4

# --- Discriminator only sees 4 bond types ---
BOND_FEAT_DIM_DISCRIMINATOR = len(BOND_CLASSES_RDKIT) # Back to 4

# --- Model Hyperparameters ---
Z_DIM = 100         # Latent noise dimension
EMBED_DIM = 128     # Hidden dimension for the Graph Transformer
T_EMBED_DIM = 1024  # Target embedding dimension (from ProtT5)
LAMBDA_GP = 10.0    # Gradient Penalty weight
MAX_NODES = 30      # Max atoms in generated molecules (for Generator tensor shape)
N_CRITIC = 5        # Discriminator training steps per Generator step
EPOCHS = 100        # Run for 100 epochs
BATCH_SIZE = 64     # Your increased batch size
CPU_WORKERS = 4     

# --- CUDA Check ---
if torch.cuda.is_available():
    print("‚úÖ CUDA is available! GPU will be used for training.")
    print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')
else:
    print("‚ùå CUDA not found. Running on CPU.")
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")


# --- 2. Real Protein Embedding Generation (No changes) ---

def load_target_sequence(fasta_path, uniprot_id):
    # (No changes to this function)
    print(f"Opening gzipped FASTA file: {fasta_path}")
    try:
        with gzip.open(fasta_path, "rt") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                if uniprot_id in record.id or uniprot_id in record.description:
                    return str(record.seq)
            print(f"Warning: Could not find sequence for {uniprot_id} in {fasta_path}")
            return None
    except FileNotFoundError:
        print(f"FATAL ERROR: FASTA file not found at {fasta_path}")
        raise
    except Exception as e:
        print(f"FATAL ERROR: Could not read FASTA file. Error: {e}")
        raise

def get_protein_embedding(sequence, device):
    # (No changes to this function)
    print("Loading ProtT5 model... (This may take a moment)")
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc').to(device)
    model.eval() 
    sequence_preprocessed = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))
    inputs = tokenizer(sequence_preprocessed, return_tensors="pt", max_length=1024, truncation=True).to(device)
    with torch.no_grad():
        embedding = model(**inputs).last_hidden_state
    protein_vec = embedding.mean(dim=1).squeeze(0)
    print(f"Generated protein embedding of shape: {protein_vec.shape}")
    return protein_vec

# --- Generate the REAL Target Embedding ---
target_seq = load_target_sequence(BLAST_FASTA_PATH, TARGET_UNIPROT_ID)
if target_seq is None:
    raise ValueError(f"Target sequence for {TARGET_UNIPROT_ID} not found. Exiting.")
TARGET_EMBED = get_protein_embedding(target_seq, DEVICE)


# --- 3. Data Pipeline (Molecules -> Graphs) (v3 Logic) ---

def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    # (No changes to this function)
    try:
        conn = sqlite3.connect(db_path)
        sql_query = f"""
        SELECT DISTINCT cs.canonical_smiles
        FROM activities acts
        JOIN assays a ON acts.assay_id = a.assay_id
        JOIN target_dictionary td ON a.tid = td.tid
        JOIN target_components tc ON td.tid = tc.tid
        JOIN component_sequences cseq ON tc.component_id = cseq.component_id
        JOIN compound_structures cs ON acts.molregno = cs.molregno
        WHERE
            cseq.accession = '{uniprot_id}' AND
            acts.standard_type = 'IC50' AND
            acts.standard_units = 'nM' AND
            acts.standard_relation = '=' AND
            acts.standard_value <= {potency_cutoff_nM}
        """
        df = pd.read_sql_query(sql_query, conn)
        conn.close()
        print(f"Found {len(df)} potent inhibitors for UniProt ID {uniprot_id}.")
        return df['canonical_smiles'].unique().tolist()
    except Exception as e:
        print(f"Error during database query: {e}")
        raise

def get_atom_features(atom):
    """Creates a one-hot vector for the atom type."""
    atom_num = atom.GetAtomicNum()
    if atom_num not in ATOM_CLASSES_MAP:
        return None # Atom is not in our allowed list
        
    atom_index = ATOM_CLASSES_MAP[atom_num]
    atom_one_hot = torch.zeros(ATOM_FEAT_DIM, dtype=torch.float)
    atom_one_hot[atom_index] = 1.0
    return atom_one_hot

def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    atom_features_list = []
    for atom in mol.GetAtoms():
        features = get_atom_features(atom)
        if features is None: # Skip molecule if it contains an invalid atom
            return None
        atom_features_list.append(features)

    if not atom_features_list:
        return None
    x = torch.stack(atom_features_list)
    
    edge_indices, edge_attrs = [], []
    # --- Use the 4-class bond list ---
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Create a 4-dim one-hot vector
        bond_type_oh = [int(bond.GetBondType() == t) for t in BOND_CLASSES_RDKIT]
        
        # --- Ensure bond type is one we recognize ---
        if sum(bond_type_oh) == 1: #i.e., it's S, D, T, or Aromatic
            edge_indices.extend([[i, j], [j, i]])
            edge_attrs.extend([bond_type_oh, bond_type_oh])

    # --- Return None if molecule has no *recognized* bonds ---
    if not edge_indices: 
        return None
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float) # Shape: [num_bonds, 4]

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
real_data_list = [smiles_to_graph(s, TARGET_EMBED.cpu()) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print(f"FATAL: No valid inhibitor data found (or all were filtered out). Check ATOM_CLASSES and BOND_CLASSES.")
    exit()

# --- USE THE STANDARD PyG DATALOADER ---
real_loader = PyGDataLoader(
    real_data_list, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=CPU_WORKERS, 
    pin_memory=True 
)
print(f"Prepared {len(real_data_list)} real graph samples for training.")


# --- 4. Model Architecture (v3 Logic) ---

# --- 4.1. Relational Graph Transformer Layer (No changes) ---
class RelationalGraphTransformerLayer(MessagePassing):
    # (No changes to this class)
    def __init__(self, in_channels, out_channels, edge_dim, heads=4, dropout=0.1, **kwargs):
        super().__init__(aggr='add', node_dim=0, **kwargs)
        self.out_channels = out_channels
        self.heads = heads
        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_k = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_v = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) 
        self.att_coeff = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels)) 
        self.lin_out = nn.Linear(heads * out_channels, out_channels)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.att_coeff)

    def forward(self, x, edge_index, edge_attr):
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)
        E = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)
        out = self.propagate(edge_index, Q=Q, K=K, V=V, E=E)
        out = out.view(-1, self.heads * self.out_channels)
        out = self.lin_out(out)
        return out

    def message(self, Q_i, K_j, V_j, E_k, index):
        try: is_empty = (E_k.size(0) == 0)
        except AttributeError: is_empty = True
        if is_empty: E_bias = torch.zeros(Q_i.size(0), self.heads, 1, device=self.lin_q.weight.device)
        else: E_bias = E_k.mean(dim=-1, keepdim=True) 
        QK_cat = torch.cat([Q_i, K_j], dim=-1)
        e_ij = (QK_cat * self.att_coeff).sum(dim=-1, keepdim=True)
        e_ij = e_ij + E_bias
        e_ij = F.leaky_relu(e_ij)
        alpha = softmax(e_ij, index)
        alpha = self.dropout(alpha)
        return V_j * alpha.view(-1, self.heads, 1)

# --- 4.2. Discriminator (v3 Logic) ---
class Discriminator(nn.Module):
    # --- edge_dim is BOND_FEAT_DIM_DISCRIMINATOR (4) ---
    def __init__(self, node_features, edge_dim, t_embed_dim, embed_dim, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_d = node_features if i == 0 else embed_dim
            # --- This layer now expects edge_dim = 4 ---
            self.layers.append(RelationalGraphTransformerLayer(in_d, embed_dim, edge_dim))
        self.lin_final = nn.Linear(embed_dim + t_embed_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        t_embed = data.target_embed
        
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        
        graph_embed = global_mean_pool(x, batch)
        if t_embed.dim() > 2: t_embed = t_embed.squeeze(1) 
        final_input = torch.cat([graph_embed, t_embed], dim=1)
        return self.lin_final(final_input).squeeze(1)

# --- 4.3. Generator (v3 Logic) ---
class Generator(nn.Module):
    # --- bond_features is BOND_FEAT_DIM_GENERATOR (5) ---
    def __init__(self, z_dim, t_embed_dim, node_features, bond_features, max_nodes=MAX_NODES):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_features = node_features
        self.bond_features = bond_features
        self.lin_x = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * node_features))
        # --- This layer now outputs 5 features per bond ---
        self.lin_adj = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * max_nodes * bond_features))

    def forward(self, z, t_embed):
        zt = torch.cat([z, t_embed], dim=1)
        x_fake_logits = self.lin_x(zt).view(-1, self.max_nodes, self.node_features)
        adj_fake_logits = self.lin_adj(zt).view(-1, self.max_nodes, self.max_nodes, self.bond_features)
        return x_fake_logits, adj_fake_logits

# --- Model Initialization ---
print("Initializing models...")
# --- Pass the correct dimensions ---
generator = Generator(Z_DIM, T_EMBED_DIM, ATOM_FEAT_DIM, BOND_FEAT_DIM_GENERATOR).to(DEVICE)
discriminator = Discriminator(ATOM_FEAT_DIM, BOND_FEAT_DIM_DISCRIMINATOR, T_EMBED_DIM, EMBED_DIM).to(DEVICE)

# --- !!!!!!!!!!!!!!!!!!!!!!!!!!! ---
# --- 1st CHANGE: Lowered Learning Rate ---
# --- !!!!!!!!!!!!!!!!!!!!!!!!!!! ---
optimizer_G = optim.Adam(generator.parameters(), lr=1e-5, betas=(0.5, 0.9)) # Was 1e-4
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-5, betas=(0.5, 0.9)) # Was 1e-4


# --- 5. Training Utilities (v3 Logic) ---

# --- 5.1. Sparse Graph Conversion ---
def convert_fake_to_SPARSE_data_vectorized(x_fake_logits, adj_fake_logits, t_embed_batch, device, gumbel=False, temperature=0.5):
    """
    Converts Generator output (logits) to a BATCH of sparse PyG Data objects.
    """
    batch_size = x_fake_logits.size(0)
    data_list = []

    # 1. Sample nodes (still vectorized)
    if gumbel:
        x_fake_tensor = F.gumbel_softmax(x_fake_logits, tau=temperature, hard=True)
    else:
        x_indices = torch.argmax(x_fake_logits, dim=-1)
        x_fake_tensor = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()

    # 2. Sample bonds (still vectorized)
    if gumbel:
        adj_fake_tensor = F.gumbel_softmax(adj_fake_logits, tau=temperature, hard=True)
    else:
        adj_indices = torch.argmax(adj_fake_logits, dim=-1)
        adj_fake_tensor = F.one_hot(adj_indices, num_classes=BOND_FEAT_DIM_GENERATOR).float()
    
    # 3. Loop over batch to build sparse graphs
    for i in range(batch_size):
        x = x_fake_tensor[i] # Shape [N, ATOM_FEAT_DIM]
        adj_full = adj_fake_tensor[i] # Shape [N, N, BOND_FEAT_DIM_GENERATOR]
        
        edge_indices = []
        edge_attrs = []

        # Iterate over upper triangle
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                bond_logits = adj_full[j, k] # Shape [5]
                bond_type_idx = torch.argmax(bond_logits).item()
                
                # --- Check if it's NOT a "No Bond" ---
                if bond_type_idx != NO_BOND_IDX:
                    # Add this edge
                    edge_indices.extend([[j, k], [k, j]])
                    
                    # --- Create 4-dim one-hot vector ---
                    bond_attr = torch.zeros(BOND_FEAT_DIM_DISCRIMINATOR, device=device)
                    # Handle case where index might be out of bounds if something is wrong
                    if 0 <= bond_type_idx < BOND_FEAT_DIM_DISCRIMINATOR:
                         bond_attr[bond_type_idx] = 1.0
                    edge_attrs.extend([bond_attr, bond_attr])

        if not edge_indices:
            # No bonds were formed, create a dummy to avoid errors
            edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
            edge_attr = torch.empty((0, BOND_FEAT_DIM_DISCRIMINATOR), dtype=torch.float, device=device)
        else:
            edge_index = torch.tensor(edge_indices, dtype=torch.long, device=device).t().contiguous()
            edge_attr = torch.stack(edge_attrs)

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            target_embed=t_embed_batch[i].unsqueeze(0) # Embed for this single graph
        )
        data_list.append(data)

    # 4. Re-batch the sparse graphs
    
    # Use a loader to properly collate the list of Data objects into a Batch object
    temp_loader = PyGDataLoader(data_list, batch_size=batch_size, shuffle=False)
    batch = next(iter(temp_loader))
    
    # Squeeze the target_embed back to [B, T_EMBED_DIM]
    batch.target_embed = batch.target_embed.squeeze(1)
    
    return batch.to(device)


# --- 5.2. WGAN-GP Gradient Penalty (v3 Logic) ---
def calculate_gradient_penalty(discriminator, real_data, fake_data, lambda_gp, device):
    """
    Interpolates on the GLOBAL graph embeddings.
    """
    
    # --- Get graph embeddings first ---
    discriminator.eval() # Freeze discriminator for this part
    
    real_x, real_edge_index, real_edge_attr, real_batch = real_data.x, real_data.edge_index, real_data.edge_attr, real_data.batch
    real_t_embed = real_data.target_embed
    for layer in discriminator.layers:
        real_x = layer(real_x, real_edge_index, real_edge_attr)
        real_x = F.relu(real_x)
    real_graph_embed = global_mean_pool(real_x, real_batch)
    
    fake_x, fake_edge_index, fake_edge_attr, fake_batch = fake_data.x, fake_data.edge_index, fake_data.edge_attr, fake_data.batch
    # --- Handle case where fake batch might be empty ---
    if fake_data.num_graphs == 0:
        discriminator.train()
        # Return a 0 penalty if there's nothing to compare
        return torch.tensor(0.0, device=device, requires_grad=True)
        
    fake_t_embed = fake_data.target_embed
    for layer in discriminator.layers:
        fake_x = layer(fake_x, fake_edge_index, fake_edge_attr)
        fake_x = F.relu(fake_x)
    fake_graph_embed = global_mean_pool(fake_x, fake_batch)
    
    discriminator.train() # Unfreeze
    
    # Match batch sizes if they differ
    batch_size = min(real_graph_embed.size(0), fake_graph_embed.size(0))
    if batch_size == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)
        
    real_graph_embed = real_graph_embed[:batch_size]
    fake_graph_embed = fake_graph_embed[:batch_size]
    real_t_embed = real_t_embed[:batch_size]

    # --- Interpolate on graph_embed ---
    alpha = torch.rand(batch_size, 1).to(device)
    interpolated_embed = (alpha * real_graph_embed) + ((1 - alpha) * fake_graph_embed)
    interpolated_embed.requires_grad_(True)
    
    # --- Combine with target and pass to *final layer only* ---
    final_input = torch.cat([interpolated_embed, real_t_embed], dim=1)
    disc_interpolates = discriminator.lin_final(final_input).squeeze(1)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolated_embed, # Grad w.r.t. interpolated_embed
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# --- 6. Main Training Loop (v3 Logic + STABILIZATION) ---
def run_wgan_gp_training(generator, discriminator, data_loader, epochs, n_critic):
    
    generator.train()
    discriminator.train()
    
    for epoch in range(1, epochs + 1):
        g_loss_sum, d_loss_sum = 0, 0
        
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        
        for batch_idx, real_data in enumerate(progress_bar):
            real_data = real_data.to(DEVICE)
            
            # --- Handle small final batch ---
            if real_data.num_graphs < 2: # Need at least 2 for GP
                print("Warning: Skipping batch with < 2 graphs.")
                continue
                
            batch_size = real_data.num_graphs
            target_embed_batch = real_data.target_embed

            # 1. Train Discriminator (n_critic steps)
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                
                # Real Loss
                d_real = discriminator(real_data).mean()
                
                # Fake Loss
                z = torch.randn(batch_size, Z_DIM).to(DEVICE)
                x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
                
                # --- Use new SPARSIFYING function ---
                with torch.no_grad():
                    fake_data = convert_fake_to_SPARSE_data_vectorized(
                        x_fake_logits.detach(), adj_fake_logits.detach(), 
                        target_embed_batch, DEVICE, gumbel=False
                    )
                
                if fake_data.num_graphs == 0:
                    print("Warning: Fake data batch was empty. Skipping D-step.")
                    continue
                
                d_fake = discriminator(fake_data).mean()
                
                gp = calculate_gradient_penalty(discriminator, real_data, fake_data, LAMBDA_GP, DEVICE)
                
                d_loss = - (d_real - d_fake) + gp
                d_loss.backward()
                
                # --- !!!!!!!!!!!!!!!!!!!!!!!!!!! ---
                # --- 2nd CHANGE: Add Gradient Clipping ---
                # --- !!!!!!!!!!!!!!!!!!!!!!!!!!! ---
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                
                optimizer_D.step()
                d_loss_sum += d_loss.item()
            
            # 2. Train Generator (1 step)
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, Z_DIM).to(DEVICE)
            x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
            
            fake_data = convert_fake_to_SPARSE_data_vectorized(
                x_fake_logits, adj_fake_logits, 
                target_embed_batch, DEVICE, gumbel=True
            )
            
            if fake_data.num_graphs == 0:
                print("Warning: Fake data batch was empty. Skipping G-step.")
                continue

            g_loss = - discriminator(fake_data).mean()
            g_loss.backward()
            optimizer_G.step()
            g_loss_sum += g_loss.item()
            
            progress_bar.set_postfix(
                D_Loss=f"{(d_loss_sum / (batch_idx+1) / n_critic):.4f}", 
                G_Loss=f"{(g_loss_sum / (batch_idx+1)):.4f}"
            )

# --- 7. Generation & SMILES Conversion (v3 Logic) ---

def tensors_to_smiles(x_fake_one_hot, adj_fake_logits):
    """
    Converts raw generator tensor output (one-hot nodes) into SMILES strings.
    """
    # --- Get atomic number indices from one-hot nodes ---
    x_fake_indices = torch.argmax(x_fake_one_hot, dim=-1).cpu().detach()
    adj_fake_logits = adj_fake_logits.cpu().detach()
    
    # --- Get bond indices from 5-class logits ---
    adj_bond_type_idx = torch.argmax(adj_fake_logits, dim=-1)
    
    batch_size = x_fake_indices.size(0)
    generated_smiles = []
    generated_mols = []
    
    for i in range(batch_size):
        mol = Chem.RWMol()
        atom_map = {} # Map from tensor index (0..MAX_NODES-1) to RDKit atom index
        
        # 1. Add atoms
        for j in range(MAX_NODES):
            atom_idx = x_fake_indices[i, j].item()
            atom_num = ATOM_CLASSES[atom_idx]
            
            atom = Chem.Atom(atom_num)
            rdkit_idx = mol.AddAtom(atom)
            atom_map[j] = rdkit_idx
                
        # 2. Add bonds
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                # --- Get bond type from 5-class indices ---
                bond_type_idx = adj_bond_type_idx[i, j, k].item()
                
                # --- Add bond IF NOT "No Bond" ---
                if bond_type_idx != NO_BOND_IDX:
                    # --- Check if bond_type_idx is valid for RDKit bonds ---
                    if 0 <= bond_type_idx < len(BOND_CLASSES_RDKIT):
                        bond_type = BOND_CLASSES_RDKIT[bond_type_idx]
                        mol.AddBond(atom_map[j], atom_map[k], bond_type)
        
        # 3. Sanitize and Convert
        try:
            Chem.SanitizeMol(mol)
            smi = Chem.MolToSmiles(mol)
            
            # --- Filter out disconnected fragments ---
            if '.' in smi:
                generated_smiles.append(None) # Invalid fragment
                generated_mols.append(None)
            else:
                generated_smiles.append(smi)
                generated_mols.append(mol)
        except Exception as e:
            # print(f"RDKit Error: {e}") # Uncomment for debugging
            generated_smiles.append(None) # Invalid molecule
            generated_mols.append(None)

    valid_smiles = [s for s in generated_smiles if s is not None]
    valid_mols = [m for m in generated_mols if m is not None]
    
    return valid_smiles, valid_mols, generated_smiles

# --- 8. Performance Metrics & Plotting (v3 Logic) ---

def calculate_and_plot_metrics(generator, target_embed, real_smiles_list, num_to_generate, device):
    """
    Generates molecules and calculates Validity, Uniqueness, Novelty,
    and plots property distributions.
    """
    print("\n--- Starting Generation & Evaluation ---")
    warnings.filterwarnings('ignore', '.*Implicit valence.*') # Suppress RDKit warnings
    
    generator.eval() # Set generator to evaluation mode
    
    real_mols = [Chem.MolFromSmiles(s) for s in real_smiles_list]
    real_mols = [m for m in real_mols if m is not None]
    real_smiles_set = set(real_smiles_list)
    
    all_valid_smiles = []
    all_valid_mols = []
    total_attempts = 0 # Track total attempts

    print(f"Generating {num_to_generate} *valid* molecules for evaluation...")
    with torch.no_grad():
        while len(all_valid_smiles) < num_to_generate:
            batch_size = BATCH_SIZE
            total_attempts += batch_size

            z = torch.randn(batch_size, Z_DIM).to(device)
            t_embed_batch = target_embed.unsqueeze(0).repeat(batch_size, 1)
            
            x_fake_logits, adj_fake_logits = generator(z, t_embed_batch)
            
            # --- Use argmax (not Gumbel) for final generation ---
            x_indices = torch.argmax(x_fake_logits, dim=-1)
            x_fake_one_hot = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()
            
            # --- Pass 5-dim bond logits to smiles converter ---
            valid_smiles, valid_mols, _ = tensors_to_smiles(x_fake_one_hot, adj_fake_logits)
            
            all_valid_smiles.extend(valid_smiles)
            all_valid_mols.extend(valid_mols)
            
            print(f"Generated: {len(all_valid_smiles)}/{num_to_generate} valid molecules...", end='\r')
            
            if total_attempts > num_to_generate * 50 and not all_valid_smiles:
                 print("\nError: Generated too many molecules with 0 validity. Stopping.")
                 break
            if total_attempts > num_to_generate * 10 and len(all_valid_smiles) < num_to_generate: 
                 print(f"\nWarning: Low validity. Stopping generation at {len(all_valid_smiles)} molecules.")
                 break

    print("\nGeneration complete. Calculating metrics...")
    
    # --- 1. Calculate Metrics ---
    
    if total_attempts == 0: total_attempts = 1
    validity = len(all_valid_smiles) / total_attempts
    
    if len(all_valid_smiles) > 0:
        uniqueness = len(set(all_valid_smiles)) / len(all_valid_smiles)
    else:
        uniqueness = 0.0
        
    if len(all_valid_smiles) > 0:
        unique_valid_smiles = set(all_valid_smiles)
        novel_smiles = unique_valid_smiles - real_smiles_set
        novelty = len(novel_smiles) / len(unique_valid_smiles)
    else:
        novelty = 0.0

    print("\n--- Generative Performance Metrics ---")
    print(f"Total Attempts: {total_attempts}")
    print(f"Total Valid Generated: {len(all_valid_smiles)}")
    print(f"‚úÖ Validity:     {validity * 100:.2f}%")
    print(f"üß¨ Uniqueness:   {uniqueness * 100:.2f}%")
    print(f"‚≠ê Novelty:      {novelty * 100:.2f}%")
    print("----------------------------------------")
    
    if not all_valid_mols:
        print("No valid molecules generated. Skipping plots.")
        return

    # --- 2. Calculate Properties ---
    props_real = {
        'MolWt': [Descriptors.MolWt(m) for m in real_mols],
        'LogP': [Descriptors.MolLogP(m) for m in real_mols],
        'QED': [QED.qed(m) for m in real_mols]
    }
    
    props_fake = {
        'MolWt': [Descriptors.MolWt(m) for m in all_valid_mols],
        'LogP': [Descriptors.MolLogP(m) for m in all_valid_mols],
        'QED': [QED.qed(m) for m in all_valid_mols]
    }

    # --- 3. Plot Distributions ---
    print("Generating property distribution plots...")
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    plot_titles = ['Molecular Weight (MolWt)', 'LogP', 'Quantitative Esimation of Drug-likeness (QED)']
    prop_keys = ['MolWt', 'LogP', 'QED']
    
    for ax, title, key in zip(axes, plot_titles, prop_keys):
        ax.hist(props_real[key], bins=50, alpha=0.7, label='Real (Training)', color='blue', density=True)
        ax.hist(props_fake[key], bins=50, alpha=0.7, label='Generated (Fake)', color='red', density=True)
        ax.set_title(title)
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        
    plt.suptitle(f"Property Distributions (Real vs. Generated) for {TARGET_UNIPROT_ID}", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(f"property_plots_{TARGET_UNIPROT_ID}_v3.1.png")
    print(f"Plots saved to property_plots_{TARGET_UNIPROT_ID}_v3.1.png")
    plt.show()
    
    warnings.filterwarnings('default', '.*Implicit valence.*') # Restore warnings


# --- 9. --- Main Execution (Train & Evaluate) ---

# --- Execute Training ---
print("\n--- Starting WGAN-GP Training (v3.1 - Sparse & Stabilized) ---")
run_wgan_gp_training(generator, discriminator, real_loader, EPOCHS, N_CRITIC) 
print("\nTraining completed.")

# --- Execute Evaluation ---
num_to_eval = len(real_data_list) # --- Use count of *filtered* real data ---
calculate_and_plot_metrics(generator, TARGET_EMBED, inhibitor_smiles, num_to_eval, DEVICE)

‚úÖ CUDA is available! GPU will be used for training.
PyTorch CUDA Version: 12.1
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU
Using device: cuda
Opening gzipped FASTA file: DL_ENDSEM__DATASET/chembl_35_blast.fa.gz
Loading ProtT5 model... (This may take a moment)
Generated protein embedding of shape: torch.Size([1024])
Found 3989 potent inhibitors for UniProt ID P00533.
Prepared 1332 real graph samples for training.
Initializing models...

--- Starting WGAN-GP Training (v3.1 - Sparse & Stabilized) ---


Epoch 1/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [06:41<00:00, 19.10s/it, D_Loss=6.1672, G_Loss=-0.0204]
Epoch 2/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [06:49<00:00, 19.52s/it, D_Loss=5.9477, G_Loss=-0.0136]
Epoch 3/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [06:29<00:00, 18.55s/it, D_Loss=5.6340, G_Loss=0.0049] 
Epoch 4/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [07:01<00:00, 20.09s/it, D_Loss=5.1151, G_Loss=0.0406]
Epoch 5/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [07:11<00:00, 20.53s/it, D_Loss=4.1920, G_Loss=0.0994]
Epoch 6/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [07:33<00:00, 21.61s/it, D_Loss=2.5161, G_Loss=0.1862]
Epoch 7/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [07:18<00:00, 20.88s/it, D_Loss=-0.4970, G_Loss=0.3029]
Epoch 8/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [07:12<00:00, 20.60s/it, D_Loss=-5.6417, G_Loss=0.4704]
Epoch 9/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [07:23<00:00, 21.12s/it, D_Loss=-13.8790, G_Loss=0.

KeyboardInterrupt: 

In [1]:
import torch
import pandas as pd
import sqlite3
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
# --- IMPORT THE PyG DATALOADER FOR BATCHING ---
from torch_geometric.loader import DataLoader as PyGDataLoader 
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, QED, AllChem
from Bio import SeqIO
import random
import re
from transformers import T5EncoderModel, T5Tokenizer
import gzip
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

# --- 1. System & Configuration ---

# --- Configuration (UPDATE THESE PATHS & ID) ---
CHEMPL_DB_PATH = 'DL_ENDSEM__DATASET/chembl_35/chembl_35_sqlite/chembl_35.db'
BLAST_FASTA_PATH = 'DL_ENDSEM__DATASET/chembl_35_blast.fa.gz'
TARGET_UNIPROT_ID = "P00533" # Example: EGFR Kinase

# --- ATOM_CLASSES from your (v2) script ---
# This is the list that gave you 9 batches
ATOM_CLASSES = [6, 7, 8, 9, 15, 16, 17, 35, 53] # C, N, O, F, P, S, Cl, Br, I
ATOM_CLASSES_MAP = {num: i for i, num in enumerate(ATOM_CLASSES)}
ATOM_FEAT_DIM = len(ATOM_CLASSES) # Now 9

# Define bond types RDKit knows
BOND_CLASSES_RDKIT = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
# --- Add a 5th "No Bond" class for the Generator ---
BOND_FEAT_DIM_GENERATOR = len(BOND_CLASSES_RDKIT) + 1 # Now 5
NO_BOND_IDX = len(BOND_CLASSES_RDKIT) # Index 4

# --- Discriminator only sees 4 bond types ---
BOND_FEAT_DIM_DISCRIMINATOR = len(BOND_CLASSES_RDKIT) # Back to 4

# --- Model Hyperparameters ---
Z_DIM = 100         # Latent noise dimension
EMBED_DIM = 128     # Hidden dimension for the Graph Transformer
T_EMBED_DIM = 1024  # Target embedding dimension (from ProtT5)
LAMBDA_GP = 10.0    # Gradient Penalty weight
MAX_NODES = 30      # Max atoms in generated molecules (for Generator tensor shape)
N_CRITIC = 3        # Discriminator training steps per Generator step
EPOCHS = 100        # Run for 100 epochs
BATCH_SIZE = 64     # Your increased batch size
CPU_WORKERS = 4     

# --- CUDA Check ---
if torch.cuda.is_available():
    print("‚úÖ CUDA is available! GPU will be used for training.")
    print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')
else:
    print("‚ùå CUDA not found. Running on CPU.")
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")


# --- 2. Real Protein Embedding Generation (No changes) ---

def load_target_sequence(fasta_path, uniprot_id):
    # (No changes to this function)
    print(f"Opening gzipped FASTA file: {fasta_path}")
    try:
        with gzip.open(fasta_path, "rt") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                if uniprot_id in record.id or uniprot_id in record.description:
                    return str(record.seq)
            print(f"Warning: Could not find sequence for {uniprot_id} in {fasta_path}")
            return None
    except FileNotFoundError:
        print(f"FATAL ERROR: FASTA file not found at {fasta_path}")
        raise
    except Exception as e:
        print(f"FATAL ERROR: Could not read FASTA file. Error: {e}")
        raise

def get_protein_embedding(sequence, device):
    # (No changes to this function)
    print("Loading ProtT5 model... (This may take a moment)")
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc').to(device)
    model.eval() 
    sequence_preprocessed = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))
    inputs = tokenizer(sequence_preprocessed, return_tensors="pt", max_length=1024, truncation=True).to(device)
    with torch.no_grad():
        embedding = model(**inputs).last_hidden_state
    protein_vec = embedding.mean(dim=1).squeeze(0)
    print(f"Generated protein embedding of shape: {protein_vec.shape}")
    return protein_vec

# --- Generate the REAL Target Embedding ---
target_seq = load_target_sequence(BLAST_FASTA_PATH, TARGET_UNIPROT_ID)
if target_seq is None:
    raise ValueError(f"Target sequence for {TARGET_UNIPROT_ID} not found. Exiting.")
TARGET_EMBED = get_protein_embedding(target_seq, DEVICE)


# --- 3. Data Pipeline (Molecules -> Graphs) (v3 Logic) ---

def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    # (No changes to this function)
    try:
        conn = sqlite3.connect(db_path)
        sql_query = f"""
        SELECT DISTINCT cs.canonical_smiles
        FROM activities acts
        JOIN assays a ON acts.assay_id = a.assay_id
        JOIN target_dictionary td ON a.tid = td.tid
        JOIN target_components tc ON td.tid = tc.tid
        JOIN component_sequences cseq ON tc.component_id = cseq.component_id
        JOIN compound_structures cs ON acts.molregno = cs.molregno
        WHERE
            cseq.accession = '{uniprot_id}' AND
            acts.standard_type = 'IC50' AND
            acts.standard_units = 'nM' AND
            acts.standard_relation = '=' AND
            acts.standard_value <= {potency_cutoff_nM}
        """
        df = pd.read_sql_query(sql_query, conn)
        conn.close()
        print(f"Found {len(df)} potent inhibitors for UniProt ID {uniprot_id}.")
        return df['canonical_smiles'].unique().tolist()
    except Exception as e:
        print(f"Error during database query: {e}")
        raise

def get_atom_features(atom):
    """Creates a one-hot vector for the atom type."""
    atom_num = atom.GetAtomicNum()
    if atom_num not in ATOM_CLASSES_MAP:
        return None # Atom is not in our allowed list
        
    atom_index = ATOM_CLASSES_MAP[atom_num]
    atom_one_hot = torch.zeros(ATOM_FEAT_DIM, dtype=torch.float)
    atom_one_hot[atom_index] = 1.0
    return atom_one_hot

def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    atom_features_list = []
    for atom in mol.GetAtoms():
        features = get_atom_features(atom)
        if features is None: # Skip molecule if it contains an invalid atom
            return None
        atom_features_list.append(features)

    if not atom_features_list:
        return None
    x = torch.stack(atom_features_list)
    
    edge_indices, edge_attrs = [], []
    # --- Use the 4-class bond list ---
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Create a 4-dim one-hot vector
        bond_type_oh = [int(bond.GetBondType() == t) for t in BOND_CLASSES_RDKIT]
        
        # --- Ensure bond type is one we recognize ---
        if sum(bond_type_oh) == 1: #i.e., it's S, D, T, or Aromatic
            edge_indices.extend([[i, j], [j, i]])
            edge_attrs.extend([bond_type_oh, bond_type_oh])

    # --- Return None if molecule has no *recognized* bonds ---
    if not edge_indices: 
        return None
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float) # Shape: [num_bonds, 4]

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---
inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID)
real_data_list = [smiles_to_graph(s, TARGET_EMBED.cpu()) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print(f"FATAL: No valid inhibitor data found (or all were filtered out). Check ATOM_CLASSES and BOND_CLASSES.")
    exit()

# --- USE THE STANDARD PyG DATALOADER ---
real_loader = PyGDataLoader(
    real_data_list, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=CPU_WORKERS, 
    pin_memory=True 
)
print(f"Prepared {len(real_data_list)} real graph samples for training.")


# --- 4. Model Architecture (v3 Logic) ---

# --- 4.1. Relational Graph Transformer Layer (No changes) ---
class RelationalGraphTransformerLayer(MessagePassing):
    # (No changes to this class)
    def __init__(self, in_channels, out_channels, edge_dim, heads=4, dropout=0.1, **kwargs):
        super().__init__(aggr='add', node_dim=0, **kwargs)
        self.out_channels = out_channels
        self.heads = heads
        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_k = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_v = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) 
        self.att_coeff = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels)) 
        self.lin_out = nn.Linear(heads * out_channels, out_channels)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.att_coeff)

    def forward(self, x, edge_index, edge_attr):
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)
        E = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)
        out = self.propagate(edge_index, Q=Q, K=K, V=V, E=E)
        out = out.view(-1, self.heads * self.out_channels)
        out = self.lin_out(out)
        return out

    def message(self, Q_i, K_j, V_j, E_k, index):
        try: is_empty = (E_k.size(0) == 0)
        except AttributeError: is_empty = True
        if is_empty: E_bias = torch.zeros(Q_i.size(0), self.heads, 1, device=self.lin_q.weight.device)
        else: E_bias = E_k.mean(dim=-1, keepdim=True) 
        QK_cat = torch.cat([Q_i, K_j], dim=-1)
        e_ij = (QK_cat * self.att_coeff).sum(dim=-1, keepdim=True)
        e_ij = e_ij + E_bias
        e_ij = F.leaky_relu(e_ij)
        alpha = softmax(e_ij, index)
        alpha = self.dropout(alpha)
        return V_j * alpha.view(-1, self.heads, 1)

# --- 4.2. Discriminator (v3 Logic) ---
class Discriminator(nn.Module):
    # --- edge_dim is BOND_FEAT_DIM_DISCRIMINATOR (4) ---
    def __init__(self, node_features, edge_dim, t_embed_dim, embed_dim, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_d = node_features if i == 0 else embed_dim
            # --- This layer now expects edge_dim = 4 ---
            self.layers.append(RelationalGraphTransformerLayer(in_d, embed_dim, edge_dim))
        self.lin_final = nn.Linear(embed_dim + t_embed_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        t_embed = data.target_embed
        
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        
        graph_embed = global_mean_pool(x, batch)
        if t_embed.dim() > 2: t_embed = t_embed.squeeze(1) 
        final_input = torch.cat([graph_embed, t_embed], dim=1)
        return self.lin_final(final_input).squeeze(1)

# --- 4.3. Generator (v3 Logic) ---
class Generator(nn.Module):
    # --- bond_features is BOND_FEAT_DIM_GENERATOR (5) ---
    def __init__(self, z_dim, t_embed_dim, node_features, bond_features, max_nodes=MAX_NODES):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_features = node_features
        self.bond_features = bond_features
        self.lin_x = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * node_features))
        # --- This layer now outputs 5 features per bond ---
        self.lin_adj = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * max_nodes * bond_features))

    def forward(self, z, t_embed):
        zt = torch.cat([z, t_embed], dim=1)
        x_fake_logits = self.lin_x(zt).view(-1, self.max_nodes, self.node_features)
        adj_fake_logits = self.lin_adj(zt).view(-1, self.max_nodes, self.max_nodes, self.bond_features)
        return x_fake_logits, adj_fake_logits

# --- Model Initialization ---
print("Initializing models...")
# --- Pass the correct dimensions ---
generator = Generator(Z_DIM, T_EMBED_DIM, ATOM_FEAT_DIM, BOND_FEAT_DIM_GENERATOR).to(DEVICE)
discriminator = Discriminator(ATOM_FEAT_DIM, BOND_FEAT_DIM_DISCRIMINATOR, T_EMBED_DIM, EMBED_DIM).to(DEVICE)

# --- !!!!!!!!!!!!!!!!!!!!!!!!!!! ---
# --- 1st CHANGE: Lowered Learning Rate ---
# --- !!!!!!!!!!!!!!!!!!!!!!!!!!! ---
optimizer_G = optim.Adam(generator.parameters(), lr=1e-5, betas=(0.5, 0.9)) # Was 1e-4
optimizer_D = optim.Adam(discriminator.parameters(), lr=5e-6, betas=(0.5, 0.9)) # <-- REDUCE D's LR (e.g., by 2-5x)


# --- 5. Training Utilities (v3 Logic) ---

# --- 5.1. Sparse Graph Conversion ---
def convert_fake_to_SPARSE_data_vectorized(x_fake_logits, adj_fake_logits, t_embed_batch, device, gumbel=False, temperature=0.5):
    """
    Converts Generator output (logits) to a BATCH of sparse PyG Data objects.
    """
    batch_size = x_fake_logits.size(0)
    data_list = []

    # 1. Sample nodes (still vectorized)
    if gumbel:
        x_fake_tensor = F.gumbel_softmax(x_fake_logits, tau=temperature, hard=True)
    else:
        x_indices = torch.argmax(x_fake_logits, dim=-1)
        x_fake_tensor = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()

    # 2. Sample bonds (still vectorized)
    if gumbel:
        adj_fake_tensor = F.gumbel_softmax(adj_fake_logits, tau=temperature, hard=True)
    else:
        adj_indices = torch.argmax(adj_fake_logits, dim=-1)
        adj_fake_tensor = F.one_hot(adj_indices, num_classes=BOND_FEAT_DIM_GENERATOR).float()
    
    # 3. Loop over batch to build sparse graphs
    for i in range(batch_size):
        x = x_fake_tensor[i] # Shape [N, ATOM_FEAT_DIM]
        adj_full = adj_fake_tensor[i] # Shape [N, N, BOND_FEAT_DIM_GENERATOR]
        
        edge_indices = []
        edge_attrs = []

        # Iterate over upper triangle
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                bond_logits = adj_full[j, k] # Shape [5]
                bond_type_idx = torch.argmax(bond_logits).item()
                
                # --- Check if it's NOT a "No Bond" ---
                if bond_type_idx != NO_BOND_IDX:
                    # Add this edge
                    edge_indices.extend([[j, k], [k, j]])
                    
                    # --- Create 4-dim one-hot vector ---
                    bond_attr = torch.zeros(BOND_FEAT_DIM_DISCRIMINATOR, device=device)
                    # Handle case where index might be out of bounds if something is wrong
                    if 0 <= bond_type_idx < BOND_FEAT_DIM_DISCRIMINATOR:
                         bond_attr[bond_type_idx] = 1.0
                    edge_attrs.extend([bond_attr, bond_attr])

        if not edge_indices:
            # No bonds were formed, create a dummy to avoid errors
            edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
            edge_attr = torch.empty((0, BOND_FEAT_DIM_DISCRIMINATOR), dtype=torch.float, device=device)
        else:
            edge_index = torch.tensor(edge_indices, dtype=torch.long, device=device).t().contiguous()
            edge_attr = torch.stack(edge_attrs)

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            target_embed=t_embed_batch[i].unsqueeze(0) # Embed for this single graph
        )
        data_list.append(data)

    # 4. Re-batch the sparse graphs
    
    # Use a loader to properly collate the list of Data objects into a Batch object
    temp_loader = PyGDataLoader(data_list, batch_size=batch_size, shuffle=False)
    batch = next(iter(temp_loader))
    
    # Squeeze the target_embed back to [B, T_EMBED_DIM]
    batch.target_embed = batch.target_embed.squeeze(1)
    
    return batch.to(device)


# --- 5.2. WGAN-GP Gradient Penalty (v3 Logic) ---
def calculate_gradient_penalty(discriminator, real_data, fake_data, lambda_gp, device):
    """
    Interpolates on the GLOBAL graph embeddings.
    """
    
    # --- Get graph embeddings first ---
    discriminator.eval() # Freeze discriminator for this part
    
    real_x, real_edge_index, real_edge_attr, real_batch = real_data.x, real_data.edge_index, real_data.edge_attr, real_data.batch
    real_t_embed = real_data.target_embed
    for layer in discriminator.layers:
        real_x = layer(real_x, real_edge_index, real_edge_attr)
        real_x = F.relu(real_x)
    real_graph_embed = global_mean_pool(real_x, real_batch)
    
    fake_x, fake_edge_index, fake_edge_attr, fake_batch = fake_data.x, fake_data.edge_index, fake_data.edge_attr, fake_data.batch
    # --- Handle case where fake batch might be empty ---
    if fake_data.num_graphs == 0:
        discriminator.train()
        # Return a 0 penalty if there's nothing to compare
        return torch.tensor(0.0, device=device, requires_grad=True)
        
    fake_t_embed = fake_data.target_embed
    for layer in discriminator.layers:
        fake_x = layer(fake_x, fake_edge_index, fake_edge_attr)
        fake_x = F.relu(fake_x)
    fake_graph_embed = global_mean_pool(fake_x, fake_batch)
    
    discriminator.train() # Unfreeze
    
    # Match batch sizes if they differ
    batch_size = min(real_graph_embed.size(0), fake_graph_embed.size(0))
    if batch_size == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)
        
    real_graph_embed = real_graph_embed[:batch_size]
    fake_graph_embed = fake_graph_embed[:batch_size]
    real_t_embed = real_t_embed[:batch_size]

    # --- Interpolate on graph_embed ---
    alpha = torch.rand(batch_size, 1).to(device)
    interpolated_embed = (alpha * real_graph_embed) + ((1 - alpha) * fake_graph_embed)
    interpolated_embed.requires_grad_(True)
    
    # --- Combine with target and pass to *final layer only* ---
    final_input = torch.cat([interpolated_embed, real_t_embed], dim=1)
    disc_interpolates = discriminator.lin_final(final_input).squeeze(1)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolated_embed, # Grad w.r.t. interpolated_embed
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# --- 6. Main Training Loop (v3 Logic + STABILIZATION) ---
# --- 6. Main Training Loop (v3 Logic + STABILIZATION) ---
def run_wgan_gp_training(generator, discriminator, data_loader, epochs, n_critic):
    
    generator.train()
    discriminator.train()
    
    # --- ADDED: History lists to store epoch averages ---
    d_loss_history = []
    g_loss_history = []
    d_real_history = []
    d_fake_history = []
    
    for epoch in range(1, epochs + 1):
        # --- Reset sums for each epoch ---
        g_loss_sum, d_loss_sum = 0, 0
        d_real_sum, d_fake_sum = 0, 0
        
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        
        for batch_idx, real_data in enumerate(progress_bar):
            real_data = real_data.to(DEVICE)
            
            if real_data.num_graphs < 2: # Need at least 2 for GP
                continue
                
            batch_size = real_data.num_graphs
            target_embed_batch = real_data.target_embed

            # 1. Train Discriminator (n_critic steps)
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                
                d_real = discriminator(real_data).mean()
                
                z = torch.randn(batch_size, Z_DIM).to(DEVICE)
                x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
                
                with torch.no_grad():
                    fake_data = convert_fake_to_SPARSE_data_vectorized(
                        x_fake_logits.detach(), adj_fake_logits.detach(), 
                        target_embed_batch, DEVICE, gumbel=False
                    )
                
                if fake_data.num_graphs == 0:
                    continue
                
                d_fake = discriminator(fake_data).mean()
                gp = calculate_gradient_penalty(discriminator, real_data, fake_data, LAMBDA_GP, DEVICE)
                d_loss = - (d_real - d_fake) + gp
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                optimizer_D.step()
                
                d_loss_sum += d_loss.item()
                d_real_sum += d_real.item() 
                d_fake_sum += d_fake.item()
            
            # 2. Train Generator (1 step)
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, Z_DIM).to(DEVICE)
            x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
            
            fake_data = convert_fake_to_SPARSE_data_vectorized(
                x_fake_logits, adj_fake_logits, 
                target_embed_batch, DEVICE, gumbel=True
            )
            
            if fake_data.num_graphs == 0:
                continue

            g_loss = - discriminator(fake_data).mean()
            g_loss.backward()
            optimizer_G.step()
            g_loss_sum += g_loss.item()
            
            # --- Update postfix (calculates average for batches *so far*) ---
            num_batches = batch_idx + 1
            progress_bar.set_postfix(
                D_Loss=f"{(d_loss_sum / num_batches / n_critic):.4f}", 
                G_Loss=f"{(g_loss_sum / num_batches):.4f}",
                D_Real=f"{(d_real_sum / num_batches / n_critic):.4f}",
                D_Fake=f"{(d_fake_sum / num_batches / n_critic):.4f}"
            )
        
        # --- ADDED: Store the final average for the completed epoch ---
        num_batches_total = len(data_loader)
        d_loss_history.append(d_loss_sum / num_batches_total / n_critic)
        g_loss_history.append(g_loss_sum / num_batches_total)
        d_real_history.append(d_real_sum / num_batches_total / n_critic)
        d_fake_history.append(d_fake_sum / num_batches_total / n_critic)

    # --- ADDED: Return the histories ---
    return d_loss_history, g_loss_history, d_real_history, d_fake_history
# --- 7. Generation & SMILES Conversion (v3 Logic) ---

def tensors_to_smiles(x_fake_one_hot, adj_fake_logits):
    """
    Converts raw generator tensor output (one-hot nodes) into SMILES strings.
    """
    # --- Get atomic number indices from one-hot nodes ---
    x_fake_indices = torch.argmax(x_fake_one_hot, dim=-1).cpu().detach()
    adj_fake_logits = adj_fake_logits.cpu().detach()
    
    # --- Get bond indices from 5-class logits ---
    adj_bond_type_idx = torch.argmax(adj_fake_logits, dim=-1)
    
    batch_size = x_fake_indices.size(0)
    generated_smiles = []
    generated_mols = []
    
    for i in range(batch_size):
        mol = Chem.RWMol()
        atom_map = {} # Map from tensor index (0..MAX_NODES-1) to RDKit atom index
        
        # 1. Add atoms
        for j in range(MAX_NODES):
            atom_idx = x_fake_indices[i, j].item()
            atom_num = ATOM_CLASSES[atom_idx]
            
            atom = Chem.Atom(atom_num)
            rdkit_idx = mol.AddAtom(atom)
            atom_map[j] = rdkit_idx
                
        # 2. Add bonds
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                # --- Get bond type from 5-class indices ---
                bond_type_idx = adj_bond_type_idx[i, j, k].item()
                
                # --- Add bond IF NOT "No Bond" ---
                if bond_type_idx != NO_BOND_IDX:
                    # --- Check if bond_type_idx is valid for RDKit bonds ---
                    if 0 <= bond_type_idx < len(BOND_CLASSES_RDKIT):
                        bond_type = BOND_CLASSES_RDKIT[bond_type_idx]
                        mol.AddBond(atom_map[j], atom_map[k], bond_type)
        
        # 3. Sanitize and Convert
        try:
            Chem.SanitizeMol(mol)
            smi = Chem.MolToSmiles(mol)
            
            # --- Filter out disconnected fragments ---
            if '.' in smi:
                generated_smiles.append(None) # Invalid fragment
                generated_mols.append(None)
            else:
                generated_smiles.append(smi)
                generated_mols.append(mol)
        except Exception as e:
            # print(f"RDKit Error: {e}") # Uncomment for debugging
            generated_smiles.append(None) # Invalid molecule
            generated_mols.append(None)

    valid_smiles = [s for s in generated_smiles if s is not None]
    valid_mols = [m for m in generated_mols if m is not None]
    
    return valid_smiles, valid_mols, generated_smiles

# --- 8. Performance Metrics & Plotting (v3 Logic) ---

def calculate_and_plot_metrics(generator, target_embed, real_smiles_list, num_to_generate, device):
    """
    Generates molecules and calculates Validity, Uniqueness, Novelty,
    and plots property distributions.
    """
    print("\n--- Starting Generation & Evaluation ---")
    warnings.filterwarnings('ignore', '.*Implicit valence.*') # Suppress RDKit warnings
    
    generator.eval() # Set generator to evaluation mode
    
    real_mols = [Chem.MolFromSmiles(s) for s in real_smiles_list]
    real_mols = [m for m in real_mols if m is not None]
    real_smiles_set = set(real_smiles_list)
    
    all_valid_smiles = []
    all_valid_mols = []
    total_attempts = 0 # Track total attempts

    print(f"Generating {num_to_generate} *valid* molecules for evaluation...")
    with torch.no_grad():
        while len(all_valid_smiles) < num_to_generate:
            batch_size = BATCH_SIZE
            total_attempts += batch_size

            z = torch.randn(batch_size, Z_DIM).to(device)
            t_embed_batch = target_embed.unsqueeze(0).repeat(batch_size, 1)
            
            x_fake_logits, adj_fake_logits = generator(z, t_embed_batch)
            
            # --- Use argmax (not Gumbel) for final generation ---
            x_indices = torch.argmax(x_fake_logits, dim=-1)
            x_fake_one_hot = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()
            
            # --- Pass 5-dim bond logits to smiles converter ---
            valid_smiles, valid_mols, _ = tensors_to_smiles(x_fake_one_hot, adj_fake_logits)
            
            all_valid_smiles.extend(valid_smiles)
            all_valid_mols.extend(valid_mols)
            
            print(f"Generated: {len(all_valid_smiles)}/{num_to_generate} valid molecules...", end='\r')
            
            if total_attempts > num_to_generate * 50 and not all_valid_smiles:
                 print("\nError: Generated too many molecules with 0 validity. Stopping.")
                 break
            if total_attempts > num_to_generate * 10 and len(all_valid_smiles) < num_to_generate: 
                 print(f"\nWarning: Low validity. Stopping generation at {len(all_valid_smiles)} molecules.")
                 break

    print("\nGeneration complete. Calculating metrics...")
    
    # --- 1. Calculate Metrics ---
    
    if total_attempts == 0: total_attempts = 1
    validity = len(all_valid_smiles) / total_attempts
    
    if len(all_valid_smiles) > 0:
        uniqueness = len(set(all_valid_smiles)) / len(all_valid_smiles)
    else:
        uniqueness = 0.0
        
    if len(all_valid_smiles) > 0:
        unique_valid_smiles = set(all_valid_smiles)
        novel_smiles = unique_valid_smiles - real_smiles_set
        novelty = len(novel_smiles) / len(unique_valid_smiles)
    else:
        novelty = 0.0

    print("\n--- Generative Performance Metrics ---")
    print(f"Total Attempts: {total_attempts}")
    print(f"Total Valid Generated: {len(all_valid_smiles)}")
    print(f"‚úÖ Validity:     {validity * 100:.2f}%")
    print(f"üß¨ Uniqueness:   {uniqueness * 100:.2f}%")
    print(f"‚≠ê Novelty:      {novelty * 100:.2f}%")
    print("----------------------------------------")
    
    if not all_valid_mols:
        print("No valid molecules generated. Skipping plots.")
        return

    # --- 2. Calculate Properties ---
    props_real = {
        'MolWt': [Descriptors.MolWt(m) for m in real_mols],
        'LogP': [Descriptors.MolLogP(m) for m in real_mols],
        'QED': [QED.qed(m) for m in real_mols]
    }
    
    props_fake = {
        'MolWt': [Descriptors.MolWt(m) for m in all_valid_mols],
        'LogP': [Descriptors.MolLogP(m) for m in all_valid_mols],
        'QED': [QED.qed(m) for m in all_valid_mols]
    }

    # --- 3. Plot Distributions ---
    print("Generating property distribution plots...")
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    plot_titles = ['Molecular Weight (MolWt)', 'LogP', 'Quantitative Esimation of Drug-likeness (QED)']
    prop_keys = ['MolWt', 'LogP', 'QED']
    
    for ax, title, key in zip(axes, plot_titles, prop_keys):
        ax.hist(props_real[key], bins=50, alpha=0.7, label='Real (Training)', color='blue', density=True)
        ax.hist(props_fake[key], bins=50, alpha=0.7, label='Generated (Fake)', color='red', density=True)
        ax.set_title(title)
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        
    plt.suptitle(f"Property Distributions (Real vs. Generated) for {TARGET_UNIPROT_ID}", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(f"property_plots_{TARGET_UNIPROT_ID}_v3.1.png")
    print(f"Plots saved to property_plots_{TARGET_UNIPROT_ID}_v3.1.png")
    plt.show()
    
    warnings.filterwarnings('default', '.*Implicit valence.*') # Restore warnings

# --- NEW FUNCTION (insert before Section 9) ---
def plot_training_losses(d_loss_hist, g_loss_hist, d_real_hist, d_fake_hist, target_id):
    """
    Plots the training history of WGAN-GP losses and scores.
    """
    print("Generating training loss plots...")
    epochs_range = range(1, len(g_loss_hist) + 1)
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
    
    # --- Plot 1: G_Loss vs D_Loss ---
    ax1.plot(epochs_range, g_loss_hist, label='Generator Loss (G_Loss)', color='blue')
    ax1.plot(epochs_range, d_loss_hist, label='Discriminator Loss (D_Loss)', color='red')
    ax1.set_title(f"Generator & Discriminator Losses for {target_id}")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(True)
    
    # --- Plot 2: D(Real) vs D(Fake) ---
    ax2.plot(epochs_range, d_real_hist, label='Avg. D(Real) Score', color='green')
    ax2.plot(epochs_range, d_fake_hist, label='Avg. D(Fake) Score', color='orange')
    ax2.set_title(f"Critic Scores for {target_id}")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Score")
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.savefig(f"training_loss_plots_{target_id}_v3.1.png")
    print(f"Loss plots saved to training_loss_plots_{target_id}_v3.1.png")
    plt.show()
# --- 9. --- Main Execution (Train & Evaluate) ---

# --- 9. --- Main Execution (Train & Evaluate) ---

# --- Execute Training ---
print("\n--- Starting WGAN-GP Training (v3.1 - Sparse & Stabilized) ---")
# --- MODIFIED: Capture the returned histories ---
d_loss_hist, g_loss_hist, d_real_hist, d_fake_hist = run_wgan_gp_training(
    generator, 
    discriminator, 
    real_loader, 
    EPOCHS, 
    N_CRITIC
) 
print("\nTraining completed.")

# --- ADDED: Call the new plotting function ---
plot_training_losses(d_loss_hist, g_loss_hist, d_real_hist, d_fake_hist, TARGET_UNIPROT_ID)

# --- Execute Evaluation ---
num_to_eval = len(real_data_list) # --- Use count of *filtered* real data ---
calculate_and_plot_metrics(generator, TARGET_EMBED, inhibitor_smiles, num_to_eval, DEVICE)

‚úÖ CUDA is available! GPU will be used for training.
PyTorch CUDA Version: 12.1
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU
Using device: cuda
Opening gzipped FASTA file: DL_ENDSEM__DATASET/chembl_35_blast.fa.gz
Loading ProtT5 model... (This may take a moment)


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Generated protein embedding of shape: torch.Size([1024])
Found 3989 potent inhibitors for UniProt ID P00533.
Prepared 1332 real graph samples for training.
Initializing models...

--- Starting WGAN-GP Training (v3.1 - Sparse & Stabilized) ---


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 1/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [05:12<00:00, 14.87s/it, D_Fake=0.0184, D_Loss=6.2638, D_Real=0.0207, G_Loss=-0.0184]
Epoch 2/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [04:51<00:00, 13.89s/it, D_Fake=0.0190, D_Loss=6.2063, D_Real=0.0279, G_Loss=-0.0185]
Epoch 3/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [04:58<00:00, 14.23s/it, D_Fake=0.0199, D_Loss=6.1459, D_Real=0.0385, G_Loss=-0.0182]
Epoch 4/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [05:07<00:00, 14.64s/it, D_Fake=0.0211, D_Loss=6.0822, D_Real=0.0527, G_Loss=-0.0170]
Epoch 5/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [05:13<00:00, 14.92s/it, D_Fake=0.0229, D_Loss=6.0150, D_Real=0.0708, G_Loss=-0.0147]
Epoch 6/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [05:18<00:00, 15.15s/it, D_Fake=0.0252, D_Loss=5.9433, D_Real=0.0943, G_Loss=-0.0109]
Epoch 7/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñ

KeyboardInterrupt: 

In [2]:
# --- ADDED: Call the new plotting function ---
plot_training_losses(d_loss_hist, g_loss_hist, d_real_hist, d_fake_hist, TARGET_UNIPROT_ID)

# --- Execute Evaluation ---
num_to_eval = len(real_data_list) # --- Use count of *filtered* real data ---
calculate_and_plot_metrics(generator, TARGET_EMBED, inhibitor_smiles, num_to_eval, DEVICE)

NameError: name 'd_loss_hist' is not defined