In [None]:
import torch
import pandas as pd
import sqlite3
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as PyGDataLoader 
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, QED, AllChem
from Bio import SeqIO
import random
import re
from transformers import T5EncoderModel, T5Tokenizer
import gzip
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
import os 

# --- 1. System & Configuration ---

# --- Configuration (UPDATE THESE PATHS & ID) ---
CHEMPL_DB_PATH = 'DL_ENDSEM__DATASET/chembl_35/chembl_35_sqlite/chembl_35.db'
BLAST_FASTA_PATH = 'DL_ENDSEM__DATASET/chembl_35_blast.fa.gz'
TARGET_UNIPROT_ID = "P00533" # Example: EGFR Kinase

# --- Atom Definitions ---
ATOM_CLASSES = [6, 7, 8, 9, 15, 16, 17, 35, 53] # C, N, O, F, P, S, Cl, Br, I
ATOM_CLASSES_MAP = {num: i for i, num in enumerate(ATOM_CLASSES)}
ATOM_FEAT_DIM = len(ATOM_CLASSES) # Now 9
# Define valencies for our ATOM_CLASSES: [C, N, O, F, P, S, Cl, Br, I]
VALENCIES = [4, 3, 2, 1, 5, 6, 1, 1, 1]

# --- Bond Definitions ---
# Define bond types RDKit knows
BOND_CLASSES_RDKIT = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]
# Add a 5th "No Bond" class for the Generator
BOND_FEAT_DIM_GENERATOR = len(BOND_CLASSES_RDKIT) + 1 # Now 5
NO_BOND_IDX = len(BOND_CLASSES_RDKIT) # Index 4
# Discriminator only sees 4 bond types
BOND_FEAT_DIM_DISCRIMINATOR = len(BOND_CLASSES_RDKIT) # Back to 4
# Define bond orders for our 5 Generator bond types: [S, D, T, A, None]
BOND_ORDERS = [1.0, 2.0, 3.0, 1.5, 0.0]

# --- Model Hyperparameters ---
Z_DIM = 100          # Latent noise dimension
EMBED_DIM = 128      # Hidden dimension for the Graph Transformer
T_EMBED_DIM = 1024   # Target embedding dimension (from ProtT5)
LAMBDA_GP = 10.0     # Gradient Penalty weight
LAMBDA_VALENCY = 1.0 # Valency loss weight
MAX_NODES = 30       # Max atoms in generated molecules
N_CRITIC = 1         # Discriminator training steps per Generator step (1:1)
EPOCHS = 100         # Run for 100 epochs
BATCH_SIZE = 64      # Your increased batch size
CPU_WORKERS = 4      

# --- CUDA Check ---
if torch.cuda.is_available():
    print("‚úÖ CUDA is available! GPU will be used for training.")
    print(f"PyTorch CUDA Version: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')
else:
    print("‚ùå CUDA not found. Running on CPU.")
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

# --- Tensors (Define *after* DEVICE is set) ---
VALENCY_TENSOR = torch.tensor(VALENCIES, dtype=torch.float, device=DEVICE)
BOND_ORDER_TENSOR = torch.tensor(BOND_ORDERS, dtype=torch.float, device=DEVICE)


# --- 2. Real Protein Embedding Generation (No changes) ---
# ... (rest of your script follows) ..

# --- 2. Real Protein Embedding Generation (No changes) ---

def load_target_sequence(fasta_path, uniprot_id):
    # (No changes to this function)
    print(f"Opening gzipped FASTA file: {fasta_path}")
    try:
        with gzip.open(fasta_path, "rt") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                if uniprot_id in record.id or uniprot_id in record.description:
                    return str(record.seq)
            print(f"Warning: Could not find sequence for {uniprot_id} in {fasta_path}")
            return None
    except FileNotFoundError:
        print(f"FATAL ERROR: FASTA file not found at {fasta_path}")
        raise
    except Exception as e:
        print(f"FATAL ERROR: Could not read FASTA file. Error: {e}")
        raise

def get_protein_embedding(sequence, device):
    # (No changes to this function)
    print("Loading ProtT5 model... (This may take a moment)")
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc').to(device)
    model.eval() 
    sequence_preprocessed = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))
    inputs = tokenizer(sequence_preprocessed, return_tensors="pt", max_length=1024, truncation=True).to(device)
    with torch.no_grad():
        embedding = model(**inputs).last_hidden_state
    protein_vec = embedding.mean(dim=1).squeeze(0)
    print(f"Generated protein embedding of shape: {protein_vec.shape}")
    return protein_vec

# --- Generate the REAL Target Embedding ---
target_seq = load_target_sequence(BLAST_FASTA_PATH, TARGET_UNIPROT_ID)
if target_seq is None:
    raise ValueError(f"Target sequence for {TARGET_UNIPROT_ID} not found. Exiting.")
TARGET_EMBED = get_protein_embedding(target_seq, DEVICE)


# --- 3. Data Pipeline (Molecules -> Graphs) (v3 Logic) ---

def extract_potent_inhibitors(db_path, uniprot_id, potency_cutoff_nM=100):
    # (No changes to this function)
    try:
        conn = sqlite3.connect(db_path)
        sql_query = f"""
        SELECT DISTINCT cs.canonical_smiles
        FROM activities acts
        JOIN assays a ON acts.assay_id = a.assay_id
        JOIN target_dictionary td ON a.tid = td.tid
        JOIN target_components tc ON td.tid = tc.tid
        JOIN component_sequences cseq ON tc.component_id = cseq.component_id
        JOIN compound_structures cs ON acts.molregno = cs.molregno
        WHERE
            cseq.accession = '{uniprot_id}' AND
            acts.standard_type = 'IC50' AND
            acts.standard_units = 'nM' AND
            acts.standard_relation = '=' AND
            acts.standard_value <= {potency_cutoff_nM}
        """
        df = pd.read_sql_query(sql_query, conn)
        conn.close()
        print(f"Found {len(df)} potent inhibitors for UniProt ID {uniprot_id}.")
        return df['canonical_smiles'].unique().tolist()
    except Exception as e:
        print(f"Error during database query: {e}")
        raise

def get_atom_features(atom):
    """Creates a one-hot vector for the atom type."""
    atom_num = atom.GetAtomicNum()
    if atom_num not in ATOM_CLASSES_MAP:
        return None # Atom is not in our allowed list
        
    atom_index = ATOM_CLASSES_MAP[atom_num]
    atom_one_hot = torch.zeros(ATOM_FEAT_DIM, dtype=torch.float)
    atom_one_hot[atom_index] = 1.0
    return atom_one_hot

def smiles_to_graph(smiles, target_embed):
    """Converts SMILES to a PyG Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol: return None
    if mol.GetNumAtoms() > MAX_NODES: return None

    atom_features_list = []
    for atom in mol.GetAtoms():
        features = get_atom_features(atom)
        if features is None: # Skip molecule if it contains an invalid atom
            return None
        atom_features_list.append(features)

    if not atom_features_list:
        return None
    x = torch.stack(atom_features_list)
    
    edge_indices, edge_attrs = [], []
    # --- Use the 4-class bond list ---
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Create a 4-dim one-hot vector
        bond_type_oh = [int(bond.GetBondType() == t) for t in BOND_CLASSES_RDKIT]
        
        # --- Ensure bond type is one we recognize ---
        if sum(bond_type_oh) == 1: #i.e., it's S, D, T, or Aromatic
            edge_indices.extend([[i, j], [j, i]])
            edge_attrs.extend([bond_type_oh, bond_type_oh])

    # --- Return None if molecule has no *recognized* bonds ---
    if not edge_indices and mol.GetNumAtoms() > 1: # Only allow single-atom "graphs"
        return None
    
    if not edge_indices:
        # Handle single-atom molecule (e.g., [O-2])
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, BOND_FEAT_DIM_DISCRIMINATOR), dtype=torch.float)
    else:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attrs, dtype=torch.float) # Shape: [num_bonds, 4]

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 
                target_embed=target_embed.unsqueeze(0))
    return data

# --- Data Pipeline Execution ---

inhibitor_smiles = extract_potent_inhibitors(CHEMPL_DB_PATH, TARGET_UNIPROT_ID, potency_cutoff_nM=5000)
real_data_list = [smiles_to_graph(s, TARGET_EMBED.cpu()) for s in inhibitor_smiles]
real_data_list = [d for d in real_data_list if d is not None]

if not real_data_list:
    print(f"FATAL: No valid inhibitor data found (or all were filtered out). Check ATOM_CLASSES and BOND_CLASSES.")
    exit()

# --- USE THE STANDARD PyG DATALOADER ---
real_loader = PyGDataLoader(
    real_data_list, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=CPU_WORKERS, 
    pin_memory=True 
)
print(f"Prepared {len(real_data_list)} real graph samples for training.")


# --- 4. Model Architecture (v3 Logic) ---

# --- 4.1. Relational Graph Transformer Layer (MODIFIED) ---
class RelationalGraphTransformerLayer(MessagePassing):
    # --- CHANGED: Added residual connection & robust message passing ---
    def __init__(self, in_channels, out_channels, edge_dim, heads=4, dropout=0.1, **kwargs):
        super().__init__(aggr='add', node_dim=0, **kwargs)
        self.out_channels = out_channels
        self.heads = heads
        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_k = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_v = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) 
        self.att_coeff = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels)) 
        self.lin_out = nn.Linear(heads * out_channels, out_channels)
        self.dropout = nn.Dropout(dropout)
        
        # --- CHANGED: Add a skip connection linear layer ---
        # This is CRITICAL for stability and to handle nodes/graphs with 0 edges.
        if in_channels == out_channels:
            self.lin_skip = nn.Identity()
        else:
            self.lin_skip = nn.Linear(in_channels, out_channels)
            
        nn.init.xavier_uniform_(self.att_coeff)

    def forward(self, x, edge_index, edge_attr):
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)
        
        # --- Handle empty edge_attr tensor ---
        if edge_attr.size(0) > 0:
             E = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)
        else:
             # Create empty E on the correct device with correct shape
             E = torch.empty((0, self.heads, self.out_channels), device=x.device)

        out = self.propagate(edge_index, Q=Q, K=K, V=V, E_k=E) # <--- Renamed E=E to E_k=E
        out = out.view(-1, self.heads * self.out_channels)
        out = self.lin_out(out)
        
        # --- CHANGED: Add the residual (skip) connection ---
        x_skip = self.lin_skip(x)
        out = out + x_skip 
        
        return out

    def message(self, Q_i, K_j, V_j, E_k, index):
        # --- CHANGED: Replaced brittle try/except with robust check ---
        # E_k is the tensor of edge features [num_edges_in_batch, heads, out_channels]
        if E_k is None or E_k.size(0) == 0:
            # If there are no edges, E_bias is zero.
            # Q_i has shape [num_edges, heads, out_channels], so Q_i.size(0) is num_edges
            E_bias = torch.zeros(Q_i.size(0), self.heads, 1, device=Q_i.device)
        else:
            E_bias = E_k.mean(dim=-1, keepdim=True) 
            
        QK_cat = torch.cat([Q_i, K_j], dim=-1)
        e_ij = (QK_cat * self.att_coeff).sum(dim=-1, keepdim=True)
        e_ij = e_ij + E_bias
        e_ij = F.leaky_relu(e_ij)
        alpha = softmax(e_ij, index)
        alpha = self.dropout(alpha)
        return V_j * alpha.view(-1, self.heads, 1)

# --- 4.2. Discriminator (v3 Logic) ---
class Discriminator(nn.Module):
    # --- edge_dim is BOND_FEAT_DIM_DISCRIMINATOR (4) ---
    def __init__(self, node_features, edge_dim, t_embed_dim, embed_dim, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_d = node_features if i == 0 else embed_dim
            # --- This layer now expects edge_dim = 4 ---
            self.layers.append(RelationalGraphTransformerLayer(in_d, embed_dim, edge_dim))
        self.lin_final = nn.Linear(embed_dim + t_embed_dim, 1)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        # --- Handle empty batch gracefully ---
        if data.num_graphs == 0 or x.size(0) == 0:
             # Return a single 0 for the batch, requires grad
             return torch.tensor([0.0], device=x.device, requires_grad=True)

        t_embed = data.target_embed
        
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        
        graph_embed = global_mean_pool(x, batch)
        
        # Ensure t_embed is [B, D]
        if t_embed.dim() == 3 and t_embed.size(1) == 1: 
            t_embed = t_embed.squeeze(1) 
        elif t_embed.dim() == 1: # Single item in batch
            t_embed = t_embed.unsqueeze(0)
        
        # Handle mismatch if global_mean_pool returns empty
        if graph_embed.size(0) == 0:
            return torch.tensor([0.0], device=x.device, requires_grad=True)

        # Ensure batch sizes match (t_embed might be [1, D] broadcasted)
        if t_embed.size(0) == 1 and graph_embed.size(0) > 1:
            t_embed = t_embed.repeat(graph_embed.size(0), 1)

        final_input = torch.cat([graph_embed, t_embed], dim=1)
        return self.lin_final(final_input).squeeze(1)

# --- 4.3. Generator (v3 Logic) ---
class Generator(nn.Module):
    # --- bond_features is BOND_FEAT_DIM_GENERATOR (5) ---
    def __init__(self, z_dim, t_embed_dim, node_features, bond_features, max_nodes=MAX_NODES):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_features = node_features
        self.bond_features = bond_features
        self.lin_x = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * node_features))
        # --- This layer now outputs 5 features per bond ---
        self.lin_adj = nn.Sequential(nn.Linear(z_dim + t_embed_dim, 256), nn.ReLU(), nn.Linear(256, max_nodes * max_nodes * bond_features))

    def forward(self, z, t_embed):
        zt = torch.cat([z, t_embed], dim=1)
        x_fake_logits = self.lin_x(zt).view(-1, self.max_nodes, self.node_features)
        adj_fake_logits = self.lin_adj(zt).view(-1, self.max_nodes, self.max_nodes, self.bond_features)
        return x_fake_logits, adj_fake_logits

# --- Model Initialization ---
print("Initializing models...")
# --- Pass the correct dimensions ---
generator = Generator(Z_DIM, T_EMBED_DIM, ATOM_FEAT_DIM, BOND_FEAT_DIM_GENERATOR).to(DEVICE)
discriminator = Discriminator(ATOM_FEAT_DIM, BOND_FEAT_DIM_DISCRIMINATOR, T_EMBED_DIM, EMBED_DIM).to(DEVICE)

# --- Learning Rate (as in your script) ---
optimizer_G = optim.Adam(generator.parameters(), lr=5e-5, betas=(0.5, 0.9)) # <-- 5x FASTER
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-5, betas=(0.5, 0.9)) # <-- 2x SLOWER


# --- 5. Training Utilities (v3 Logic) ---

# --- 5.1. Sparse Graph Conversion (FIXED) ---
def convert_fake_to_SPARSE_data_vectorized(x_fake_logits, adj_fake_logits, t_embed_batch, device, gumbel=False, temperature=0.5):
    """
    Converts Generator output (logits) to a BATCH of sparse PyG Data objects.
    """
    batch_size = x_fake_logits.size(0)
    data_list = []

    # 1. Sample nodes (still vectorized)
    if gumbel:
        x_fake_tensor = F.gumbel_softmax(x_fake_logits, tau=temperature, hard=True)
    else:
        x_indices = torch.argmax(x_fake_logits, dim=-1)
        x_fake_tensor = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()

    # 2. Sample bonds (still vectorized)
    # --- !! FIXED: Apply Gumbel to bonds if in training (gumbel=True) !! ---
    if gumbel:
        adj_fake_tensor = F.gumbel_softmax(adj_fake_logits, tau=temperature, hard=True)
        # adj_fake_tensor is [B, N, N, 5] (one-hot)
    else:
        adj_indices = torch.argmax(adj_fake_logits, dim=-1)
        adj_fake_tensor = F.one_hot(adj_indices, num_classes=BOND_FEAT_DIM_GENERATOR).float()
        # adj_fake_tensor is [B, N, N, 5] (one-hot)
    
    # 3. Loop over batch to build sparse graphs
    for i in range(batch_size):
        x = x_fake_tensor[i] # Shape [N, ATOM_FEAT_DIM]
        adj_full_one_hot = adj_fake_tensor[i] # Shape [N, N, 5] (one-hot)
        
        edge_indices = []
        edge_attrs = []

        # Iterate over upper triangle
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                
                # --- CHANGED: Get bond type from one-hot tensor ---
                bond_one_hot = adj_full_one_hot[j, k] # Shape [5]
                bond_type_idx = torch.argmax(bond_one_hot).item()
                
                # --- Check if it's NOT a "No Bond" ---
                if bond_type_idx != NO_BOND_IDX:
                    # Add this edge
                    edge_indices.extend([[j, k], [k, j]])
                    
                    # --- Create 4-dim one-hot vector ---
                    bond_attr = torch.zeros(BOND_FEAT_DIM_DISCRIMINATOR, device=device)
                    # We know bond_type_idx is 0, 1, 2, or 3
                    bond_attr[bond_type_idx] = 1.0
                    edge_attrs.extend([bond_attr, bond_attr])

        if not edge_indices:
            # No bonds were formed, create a dummy
            edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
            edge_attr = torch.empty((0, BOND_FEAT_DIM_DISCRIMINATOR), dtype=torch.float, device=device)
        else:
            edge_index = torch.tensor(edge_indices, dtype=torch.long, device=device).t().contiguous()
            edge_attr = torch.stack(edge_attrs)

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            target_embed=t_embed_batch[i].unsqueeze(0) # Embed for this single graph
        )
        data_list.append(data)

    # 4. Re-batch the sparse graphs
    
    # --- CHANGED: Use Batch.from_data_list (more efficient) ---
    if not data_list:
        return Batch() # Return an empty batch object
        
    batch = Batch.from_data_list(data_list)
    
    # Squeeze the target_embed back to [B, T_EMBED_DIM]
    if batch.target_embed.dim() == 3:
        batch.target_embed = batch.target_embed.squeeze(1)
    
    return batch.to(device)


# --- 5.2. WGAN-GP Gradient Penalty (v3 Logic) ---
def calculate_gradient_penalty(discriminator, real_data, fake_data, lambda_gp, device):
    """
    Interpolates on the GLOBAL graph embeddings.
    """
    
    # --- Get graph embeddings first ---
    # We don't need eval/train switching if we only use .lin_final
    
    # --- Get Real graph_embed ---
    real_x, real_edge_index, real_edge_attr, real_batch = real_data.x, real_data.edge_index, real_data.edge_attr, real_data.batch
    real_t_embed = real_data.target_embed
    with torch.no_grad(): # Don't need grads for this part
        for layer in discriminator.layers:
            real_x = layer(real_x, real_edge_index, real_edge_attr)
            real_x = F.relu(real_x)
    real_graph_embed = global_mean_pool(real_x, real_batch)
    
    # --- Get Fake graph_embed ---
    fake_x, fake_edge_index, fake_edge_attr, fake_batch = fake_data.x, fake_data.edge_index, fake_data.edge_attr, fake_data.batch
    if fake_data.num_graphs == 0 or fake_x.size(0) == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)
        
    fake_t_embed = fake_data.target_embed
    with torch.no_grad(): # Don't need grads for this part
        for layer in discriminator.layers:
            fake_x = layer(fake_x, fake_edge_index, fake_edge_attr)
            fake_x = F.relu(fake_x)
    fake_graph_embed = global_mean_pool(fake_x, fake_batch)
    
    
    # Match batch sizes if they differ
    batch_size = min(real_graph_embed.size(0), fake_graph_embed.size(0))
    if batch_size == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)
        
    real_graph_embed = real_graph_embed[:batch_size]
    fake_graph_embed = fake_graph_embed[:batch_size]
    real_t_embed = real_t_embed[:batch_size]

    # --- Interpolate on graph_embed ---
    alpha = torch.rand(batch_size, 1).to(device)
    interpolated_embed = (alpha * real_graph_embed.detach()) + ((1 - alpha) * fake_graph_embed.detach())
    interpolated_embed.requires_grad_(True)
    
    # --- Combine with target and pass to *final layer only* ---
    final_input = torch.cat([interpolated_embed, real_t_embed.detach()], dim=1)
    disc_interpolates = discriminator.lin_final(final_input).squeeze(1)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolated_embed, # Grad w.r.t. interpolated_embed
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# --- 6. Main Training Loop (MODIFIED FOR CHECKPOINTING) ---
# --- 6. Main Training Loop (CORRECTED INDENTATION) ---
def run_wgan_gp_training(generator, discriminator, data_loader, epochs, n_critic, target_id, resume_from_checkpoint=None):
    
    generator.train()
    discriminator.train()
    
    # --- Create checkpoint directory ---
    CHECKPOINT_DIR = "checkpoints"
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    
    # --- History lists to store epoch averages ---
    d_loss_history = []
    g_loss_history = []
    d_real_history = []
    d_fake_history = []
    
    start_epoch = 1
    
    # --- Logic to resume from checkpoint ---
    if resume_from_checkpoint:
        if os.path.exists(resume_from_checkpoint):
            print(f"Loading checkpoint: {resume_from_checkpoint}")
            checkpoint = torch.load(resume_from_checkpoint, map_location=DEVICE)
            
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
            optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
            
            start_epoch = checkpoint['epoch'] + 1
            
            # Load histories
            d_loss_history = checkpoint.get('d_loss_history', [])
            g_loss_history = checkpoint.get('g_loss_history', [])
            d_real_history = checkpoint.get('d_real_history', [])
            d_fake_history = checkpoint.get('d_fake_history', [])
            
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print(f"Warning: Checkpoint file not found, starting from scratch: {resume_from_checkpoint}")

    
    for epoch in range(start_epoch, epochs + 1):
        # --- Reset sums for each epoch ---
        g_loss_sum, d_loss_sum = 0, 0
        d_real_sum, d_fake_sum = 0, 0
        
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        
        # --- THIS IS THE BATCH LOOP ---
        for batch_idx, real_data in enumerate(progress_bar):
            real_data = real_data.to(DEVICE)
            
            if real_data.num_graphs < 2: # Need at least 2 for GP
                continue
                
            batch_size = real_data.num_graphs
            target_embed_batch = real_data.target_embed

            # 1. Train Discriminator (n_critic steps)
            for _ in range(n_critic):
                optimizer_D.zero_grad()
                
                d_real = discriminator(real_data).mean()
                
                z = torch.randn(batch_size, Z_DIM).to(DEVICE)
                x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
                
                with torch.no_grad():
                    fake_data = convert_fake_to_SPARSE_data_vectorized(
                        x_fake_logits.detach(), adj_fake_logits.detach(), 
                        target_embed_batch, DEVICE, gumbel=False
                    )
                
                if fake_data.num_graphs == 0:
                    continue
                
                d_fake = discriminator(fake_data).mean()
                gp = calculate_gradient_penalty(discriminator, real_data, fake_data, LAMBDA_GP, DEVICE)
                
                d_loss = (d_fake - d_real) + gp
                
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                optimizer_D.step()
                
                d_loss_sum += d_loss.item()
                d_real_sum += d_real.item() 
                d_fake_sum += d_fake.item()
            
            # 2. Train Generator (1 step)
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, Z_DIM).to(DEVICE)
            x_fake_logits, adj_fake_logits = generator(z, target_embed_batch)
            
            # --- FIXED: We need the SOFT tensors for valency loss ---
            x_fake_soft = F.gumbel_softmax(x_fake_logits, tau=0.5, hard=False)
            adj_fake_soft = F.gumbel_softmax(adj_fake_logits, tau=0.5, hard=False)
            
            # --- Create the "hard" graph for the discriminator ---
            fake_data =fake_data = convert_fake_to_SPARSE_data_vectorized(
                x_fake_logits.detach(), adj_fake_logits.detach(), 
                target_embed_batch, DEVICE, gumbel=True
            )
            
            if fake_data.num_graphs == 0:
                continue

            # --- CALCULATE NEW LOSSES ---
            critic_loss = - discriminator(fake_data).mean()
            valency_loss = calculate_valency_loss(x_fake_soft, adj_fake_soft)
            g_loss = critic_loss + (LAMBDA_VALENCY * valency_loss)
            
            g_loss.backward()
            optimizer_G.step()
            
            g_loss_sum += g_loss.item()
                
            # --- Update postfix (THIS IS IN THE RIGHT PLACE) ---
            num_batches = batch_idx + 1
            progress_bar.set_postfix(
                D_Loss=f"{(d_loss_sum / num_batches / n_critic):.4f}", 
                G_Loss=f"{(g_loss_sum / num_batches):.4f}",
                D_Real=f"{(d_real_sum / num_batches / n_critic):.4f}",
                D_Fake=f"{(d_fake_sum / num_batches / n_critic):.4f}"
            )
        # --- END OF THE BATCH LOOP ---

        # --- !! MOVED: This logic now runs ONCE PER EPOCH !! ---
        num_batches_total = len(data_loader)
        if num_batches_total > 0:
            d_loss_history.append(d_loss_sum / num_batches_total / n_critic)
            g_loss_history.append(g_loss_sum / num_batches_total)
            d_real_history.append(d_real_sum / num_batches_total / n_critic)
            d_fake_history.append(d_fake_sum / num_batches_total / n_critic)
        else:
             d_loss_history.append(0)
             g_loss_history.append(0)
             d_real_history.append(0)
             d_fake_history.append(0)

        # --- !! MOVED: This logic now runs ONCE PER EPOCH !! ---
        if epoch % 3 == 0:
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{target_id}_epoch_{epoch}.pth")
            print(f"\nSaving checkpoint to {checkpoint_path}...")
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'd_loss_history': d_loss_history,
                'g_loss_history': g_loss_history,
                'd_real_history': d_real_history,
                'd_fake_history': d_fake_history
            }, checkpoint_path)
        # --- End of checkpointing logic ---

    # --- ADDED: Return the histories ---
    return d_loss_history, g_loss_history, d_real_history, d_fake_history
# --- 5.3. NEW: Differentiable Valency Loss ---
def calculate_valency_loss(x_fake_soft, adj_fake_soft):
    """
    Calculates a differentiable valency loss.
    x_fake_soft: [B, N, 9] (from Gumbel-Softmax, hard=False)
    adj_fake_soft: [B, N, N, 5] (from Gumbel-Softmax, hard=False)
    """
    
    # 1. Calculate the "expected max valency" for each atom
    # (x_fake_soft * VALENCY_TENSOR) -> [B, N, 9]
    # .sum(dim=-1) -> [B, N]
    expected_max_val = (x_fake_soft * VALENCY_TENSOR).sum(dim=-1)
    
    # 2. Calculate the "expected bond order" for each edge
    # (adj_fake_soft * BOND_ORDER_TENSOR) -> [B, N, N, 5]
    # .sum(dim=-1) -> [B, N, N]
    expected_bond_orders = (adj_fake_soft * BOND_ORDER_TENSOR).sum(dim=-1)
    
    # 3. Sum bond orders for each atom to get its "current valency"
    # .sum(dim=-1) sums over all 'k' for each 'j'
    current_valency = expected_bond_orders.sum(dim=-1)
    
    # 4. Calculate the error (how much we are *over* the max valency)
    # We only penalize *over-bonding*, so we use relu
    valency_error = F.relu(current_valency - expected_max_val)
    
    # 5. Return the mean error
    return valency_error.mean()

# --- 7. Generation & SMILES Conversion (v3 Logic) ---

def tensors_to_smiles(x_fake_one_hot, adj_fake_logits):
    """
    Converts raw generator tensor output (one-hot nodes) into SMILES strings.
    """
    # --- Get atomic number indices from one-hot nodes ---
    x_fake_indices = torch.argmax(x_fake_one_hot, dim=-1).cpu().detach()
    adj_fake_logits = adj_fake_logits.cpu().detach()
    
    # --- Get bond indices from 5-class logits ---
    adj_bond_type_idx = torch.argmax(adj_fake_logits, dim=-1)
    
    batch_size = x_fake_indices.size(0)
    generated_smiles = []
    generated_mols = []
    
    for i in range(batch_size):
        mol = Chem.RWMol()
        atom_map = {} # Map from tensor index (0..MAX_NODES-1) to RDKit atom index
        
        # 1. Add atoms
        for j in range(MAX_NODES):
            atom_idx = x_fake_indices[i, j].item()
            atom_num = ATOM_CLASSES[atom_idx]
            
            atom = Chem.Atom(atom_num)
            rdkit_idx = mol.AddAtom(atom)
            atom_map[j] = rdkit_idx
                
        # 2. Add bonds
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                # --- Get bond type from 5-class indices ---
                bond_type_idx = adj_bond_type_idx[i, j, k].item()
                
                # --- Add bond IF NOT "No Bond" ---
                if bond_type_idx != NO_BOND_IDX:
                    # --- Check if bond_type_idx is valid for RDKit bonds ---
                    if 0 <= bond_type_idx < len(BOND_CLASSES_RDKIT):
                        bond_type = BOND_CLASSES_RDKIT[bond_type_idx]
                        mol.AddBond(atom_map[j], atom_map[k], bond_type)
        
        # 3. Sanitize and Convert
        try:
            Chem.SanitizeMol(mol)
            smi = Chem.MolToSmiles(mol)
            
            # --- Filter out disconnected fragments ---
            if '.' in smi:
                generated_smiles.append(None) # Invalid fragment
                generated_mols.append(None)
            else:
                generated_smiles.append(smi)
                generated_mols.append(mol)
        except Exception as e:
            # print(f"RDKit Error: {e}") # Uncomment for debugging
            generated_smiles.append(None) # Invalid molecule
            generated_mols.append(None)

    valid_smiles = [s for s in generated_smiles if s is not None]
    valid_mols = [m for m in generated_mols if m is not None]
    
    return valid_smiles, valid_mols, generated_smiles

# --- 8. Performance Metrics & Plotting (v3 Logic) ---

def calculate_and_plot_metrics(generator, target_embed, real_smiles_list, num_to_generate, device):
    """
    Generates molecules and calculates Validity, Uniqueness, Novelty,
    and plots property distributions.
    """
    print("\n--- Starting Generation & Evaluation ---")
    warnings.filterwarnings('ignore', '.*Implicit valence.*') # Suppress RDKit warnings
    
    generator.eval() # Set generator to evaluation mode
    
    real_mols = [Chem.MolFromSmiles(s) for s in real_smiles_list]
    real_mols = [m for m in real_mols if m is not None]
    real_smiles_set = set(real_smiles_list)
    
    all_valid_smiles = []
    all_valid_mols = []
    total_attempts = 0 # Track total attempts

    print(f"Generating {num_to_generate} *valid* molecules for evaluation...")
    with torch.no_grad():
        while len(all_valid_smiles) < num_to_generate:
            batch_size = BATCH_SIZE
            total_attempts += batch_size

            z = torch.randn(batch_size, Z_DIM).to(device)
            t_embed_batch = target_embed.unsqueeze(0).repeat(batch_size, 1)
            
            x_fake_logits, adj_fake_logits = generator(z, t_embed_batch)
            
            # --- Use argmax (not Gumbel) for final generation ---
            x_indices = torch.argmax(x_fake_logits, dim=-1)
            x_fake_one_hot = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()
            
            # --- Pass 5-dim bond logits to smiles converter ---
            valid_smiles, valid_mols, _ = tensors_to_smiles(x_fake_one_hot, adj_fake_logits)
            
            all_valid_smiles.extend(valid_smiles)
            all_valid_mols.extend(valid_mols)
            
            print(f"Generated: {len(all_valid_smiles)}/{num_to_generate} valid molecules...", end='\r')
            
            if total_attempts > num_to_generate * 50 and not all_valid_smiles:
                    print("\nError: Generated too many molecules with 0 validity. Stopping.")
                    break
            if total_attempts > num_to_generate * 10 and (len(all_valid_smiles) < num_to_generate): 
                # If we've tried 10x and still not enough, stop early
                if len(all_valid_smiles) < (num_to_generate / 10):
                    print(f"\nWarning: Very low validity. Stopping generation at {len(all_valid_smiles)} molecules.")
                    break


    print("\nGeneration complete. Calculating metrics...")
    
    # --- 1. Calculate Metrics ---
    
    if total_attempts == 0: total_attempts = 1
    validity = len(all_valid_smiles) / total_attempts
    
    if len(all_valid_smiles) > 0:
        uniqueness = len(set(all_valid_smiles)) / len(all_valid_smiles)
    else:
        uniqueness = 0.0
        
    if len(all_valid_smiles) > 0:
        unique_valid_smiles = set(all_valid_smiles)
        novel_smiles = unique_valid_smiles - real_smiles_set
        novelty = len(novel_smiles) / len(unique_valid_smiles)
    else:
        novelty = 0.0

    print("\n--- Generative Performance Metrics ---")
    print(f"Total Attempts: {total_attempts}")
    print(f"Total Valid Generated: {len(all_valid_smiles)}")
    print(f"‚úÖ Validity:       {validity * 100:.2f}%")
    print(f"üß¨ Uniqueness:     {uniqueness * 100:.2f}%")
    print(f"‚≠ê Novelty:        {novelty * 100:.2f}%")
    print("----------------------------------------")
    
    if not all_valid_mols:
        print("No valid molecules generated. Skipping plots.")
        return

    # --- 2. Calculate Properties ---
    props_real = {
        'MolWt': [Descriptors.MolWt(m) for m in real_mols],
        'LogP': [Descriptors.MolLogP(m) for m in real_mols],
        'QED': [QED.qed(m) for m in real_mols]
    }
    
    props_fake = {
        'MolWt': [Descriptors.MolWt(m) for m in all_valid_mols],
        'LogP': [Descriptors.MolLogP(m) for m in all_valid_mols],
        'QED': [QED.qed(m) for m in all_valid_mols]
    }

    # --- 3. Plot Distributions ---
    print("Generating property distribution plots...")
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    plot_titles = ['Molecular Weight (MolWt)', 'LogP', 'Quantitative Esimation of Drug-likeness (QED)']
    prop_keys = ['MolWt', 'LogP', 'QED']
    
    for ax, title, key in zip(axes, plot_titles, prop_keys):
        ax.hist(props_real[key], bins=50, alpha=0.7, label='Real (Training)', color='blue', density=True)
        ax.hist(props_fake[key], bins=50, alpha=0.7, label='Generated (Fake)', color='red', density=True)
        ax.set_title(title)
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        
    plt.suptitle(f"Property Distributions (Real vs. Generated) for {TARGET_UNIPROT_ID}", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(f"property_plots_{TARGET_UNIPROT_ID}_v3.1_checkpointed.png")
    print(f"Plots saved to property_plots_{TARGET_UNIPROT_ID}_v3.1_checkpointed.png")
    plt.show()
    
    warnings.filterwarnings('default', '.*Implicit valence.*') # Restore warnings

# --- NEW FUNCTION (insert before Section 9) ---
def plot_training_losses(d_loss_hist, g_loss_hist, d_real_hist, d_fake_hist, target_id):
    """
    Plots the training history of WGAN-GP losses and scores.
    """
    print("Generating training loss plots...")
    if not g_loss_hist: # Check if histories are empty
        print("No loss history found. Skipping loss plots.")
        return
        
    epochs_range = range(1, len(g_loss_hist) + 1)
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
    
    # --- Plot 1: G_Loss vs D_Loss ---
    ax1.plot(epochs_range, g_loss_hist, label='Generator Loss (G_Loss)', color='blue')
    ax1.plot(epochs_range, d_loss_hist, label='Discriminator Loss (D_Loss)', color='red')
    ax1.set_title(f"Generator & Discriminator Losses for {target_id}")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(True)
    
    # --- Plot 2: D(Real) vs D(Fake) ---
    ax2.plot(epochs_range, d_real_hist, label='Avg. D(Real) Score', color='green')
    ax2.plot(epochs_range, d_fake_hist, label='Avg. D(Fake) Score', color='orange')
    ax2.set_title(f"Critic Scores for {target_id}")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Score")
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.savefig(f"training_loss_plots_{target_id}_v3.1_checkpointed.png")
    print(f"Loss plots saved to training_loss_plots_{target_id}_v3.1_checkpointed.png")
    plt.show()

# --- 9. --- Main Execution (Train & Evaluate) ---

# --- !! MODIFIED: Set resume path here !! ---
# Set to None to start a new training run
# Set to a valid path to resume, e.g., "checkpoints/P00533_epoch_3.pth"
RESUME_CHECKPOINT_PATH = "checkpoints/P00533_epoch_39.pth" # <-- CHANGE THIS

# --- Execute Training ---
print("\n--- Starting WGAN-GP Training (v3.1 - Sparse & Stabilized) ---")
# --- MODIFIED: Capture the returned histories & pass resume path ---
d_loss_hist, g_loss_hist, d_real_hist, d_fake_hist = run_wgan_gp_training(
    generator, 
    discriminator, 
    real_loader, 
    EPOCHS, 
    N_CRITIC,
    TARGET_UNIPROT_ID, # <-- Pass the ID for checkpoint filenames
    resume_from_checkpoint=RESUME_CHECKPOINT_PATH # <-- Pass the resume path
) 
print("\nTraining completed.")

# --- ADDED: Call the new plotting function ---
plot_training_losses(d_loss_hist, g_loss_hist, d_real_hist, d_fake_hist, TARGET_UNIPROT_ID)

# --- Execute Evaluation ---
# This will evaluate the FINAL model state after training/resuming
num_to_eval = len(real_data_list) # --- Use count of *filtered* real data ---
calculate_and_plot_metrics(generator, TARGET_EMBED, inhibitor_smiles, num_to_eval, DEVICE)

# --- To plot from a specific checkpoint (e.g., epoch 30) ---
# 1. You would *comment out* the training run and evaluation above.
# 2. You would uncomment and run code like this:

# print("\n--- Loading Checkpoint for Evaluation ONLY ---")
# CHECKPOINT_TO_EVAL = "checkpoints/P00533_epoch_30.pth" 
# if os.path.exists(CHECKPOINT_TO_EVAL):
#     checkpoint = torch.load(CHECKPOINT_TO_EVAL, map_location=DEVICE)
    
#     # Load model state
#     generator.load_state_dict(checkpoint['generator_state_dict'])
    
#     # Plot metrics from this checkpoint's generator
#     print(f"Evaluating generator from epoch {checkpoint['epoch']}...")
#     calculate_and_plot_metrics(generator, TARGET_EMBED, inhibitor_smiles, num_to_eval, DEVICE)
    
#     # Plot loss history *up to* this checkpoint
#     print(f"Plotting loss history up to epoch {checkpoint['epoch']}...")
#     plot_training_losses(
#         checkpoint['d_loss_history'],
#         checkpoint['g_loss_history'],
#         checkpoint['d_real_history'],
#         checkpoint['d_fake_history'],
#         TARGET_UNIPROT_ID
#     )
# else:
#     print(f"Checkpoint not found: {CHECKPOINT_TO_EVAL}")

‚úÖ CUDA is available! GPU will be used for training.
PyTorch CUDA Version: 12.1
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU
Using device: cuda
Opening gzipped FASTA file: DL_ENDSEM__DATASET/chembl_35_blast.fa.gz
Loading ProtT5 model... (This may take a moment)
Generated protein embedding of shape: torch.Size([1024])
Found 6727 potent inhibitors for UniProt ID P00533.
Prepared 2612 real graph samples for training.
Initializing models...

--- Starting WGAN-GP Training (v3.1 - Sparse & Stabilized) ---


Epoch 1/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:00<00:00,  4.39s/it, D_Fake=0.0785, D_Loss=6.3134, D_Real=0.1290, G_Loss=422.3418]
Epoch 2/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:52<00:00,  4.22s/it, D_Fake=0.0750, D_Loss=6.2448, D_Real=0.1615, G_Loss=420.2953]
Epoch 3/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:00<00:00,  4.40s/it, D_Fake=0.0736, D_Loss=6.1762, D_Real=0.1961, G_Loss=417.7917]



Saving checkpoint to checkpoints\P00533_epoch_3.pth...


Epoch 4/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:56<00:00,  4.29s/it, D_Fake=0.0738, D_Loss=6.1063, D_Real=0.2337, G_Loss=415.3689]
Epoch 5/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:56<00:00,  4.31s/it, D_Fake=0.0738, D_Loss=6.0334, D_Real=0.2740, G_Loss=412.4175]
Epoch 6/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:02<00:00,  4.46s/it, D_Fake=0.0707, D_Loss=5.9543, D_Real=0.3174, G_Loss=409.2957]



Saving checkpoint to checkpoints\P00533_epoch_6.pth...


Epoch 7/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:00<00:00,  4.39s/it, D_Fake=0.0634, D_Loss=5.8679, D_Real=0.3639, G_Loss=405.4732]
Epoch 8/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:59<00:00,  4.37s/it, D_Fake=0.0528, D_Loss=5.7743, D_Real=0.4143, G_Loss=401.0146]
Epoch 9/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:58<00:00,  4.36s/it, D_Fake=0.0366, D_Loss=5.6705, D_Real=0.4694, G_Loss=396.2563]



Saving checkpoint to checkpoints\P00533_epoch_9.pth...


Epoch 10/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:58<00:00,  4.36s/it, D_Fake=0.0090, D_Loss=5.5496, D_Real=0.5302, G_Loss=391.2940]
Epoch 11/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:56<00:00,  4.30s/it, D_Fake=-0.0193, D_Loss=5.4220, D_Real=0.5969, G_Loss=385.5860]
Epoch 12/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:26<00:00,  5.05s/it, D_Fake=-0.0584, D_Loss=5.2756, D_Real=0.6717, G_Loss=379.4465]



Saving checkpoint to checkpoints\P00533_epoch_12.pth...


Epoch 13/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:23<00:00,  6.44s/it, D_Fake=-0.1063, D_Loss=5.1100, D_Real=0.7568, G_Loss=372.7449]
Epoch 14/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:25<00:00,  6.48s/it, D_Fake=-0.1587, D_Loss=4.9298, D_Real=0.8522, G_Loss=365.3006]
Epoch 15/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:30<00:00,  6.59s/it, D_Fake=-0.2215, D_Loss=4.7254, D_Real=0.9614, G_Loss=357.7542]



Saving checkpoint to checkpoints\P00533_epoch_15.pth...


Epoch 16/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:32<00:00,  6.65s/it, D_Fake=-0.2899, D_Loss=4.5031, D_Real=1.0829, G_Loss=349.8709]
Epoch 17/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:26<00:00,  6.51s/it, D_Fake=-0.3700, D_Loss=4.2543, D_Real=1.2192, G_Loss=340.6096]
Epoch 18/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:23<00:00,  6.43s/it, D_Fake=-0.4641, D_Loss=3.9753, D_Real=1.3718, G_Loss=331.6261]



Saving checkpoint to checkpoints\P00533_epoch_18.pth...


Epoch 19/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:22<00:00,  6.39s/it, D_Fake=-0.5653, D_Loss=3.6684, D_Real=1.5452, G_Loss=321.8462]
Epoch 20/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:34<00:00,  6.70s/it, D_Fake=-0.6721, D_Loss=3.3369, D_Real=1.7376, G_Loss=311.6981]
Epoch 21/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:22<00:00,  6.41s/it, D_Fake=-0.7978, D_Loss=2.9639, D_Real=1.9526, G_Loss=301.4414]



Saving checkpoint to checkpoints\P00533_epoch_21.pth...


Epoch 22/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:16<00:00,  6.26s/it, D_Fake=-0.9233, D_Loss=2.5713, D_Real=2.1876, G_Loss=290.2246]
Epoch 23/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:23<00:00,  6.43s/it, D_Fake=-1.0758, D_Loss=2.1222, D_Real=2.4520, G_Loss=279.1951]
Epoch 24/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:19<00:00,  6.33s/it, D_Fake=-1.2311, D_Loss=1.6402, D_Real=2.7466, G_Loss=267.6787]



Saving checkpoint to checkpoints\P00533_epoch_24.pth...


Epoch 25/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:09<00:00,  6.07s/it, D_Fake=-1.4153, D_Loss=1.0994, D_Real=3.0712, G_Loss=255.7196]
Epoch 26/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:12<00:00,  6.16s/it, D_Fake=-1.6114, D_Loss=0.5098, D_Real=3.4327, G_Loss=243.7436]
Epoch 27/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [04:08<00:00,  6.05s/it, D_Fake=-1.8269, D_Loss=-0.1352, D_Real=3.8302, G_Loss=231.5846]



Saving checkpoint to checkpoints\P00533_epoch_27.pth...


Epoch 28/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:53<00:00,  5.70s/it, D_Fake=-2.0838, D_Loss=-0.8640, D_Real=4.2702, G_Loss=218.7500]
Epoch 29/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:46<00:00,  5.54s/it, D_Fake=-2.3540, D_Loss=-1.6545, D_Real=4.7586, G_Loss=205.6269]
Epoch 30/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:46<00:00,  5.54s/it, D_Fake=-2.6489, D_Loss=-2.5252, D_Real=5.3026, G_Loss=192.4258]



Saving checkpoint to checkpoints\P00533_epoch_30.pth...


Epoch 31/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:41<00:00,  5.39s/it, D_Fake=-2.9247, D_Loss=-3.4320, D_Real=5.9020, G_Loss=179.0131]
Epoch 32/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:38<00:00,  5.32s/it, D_Fake=-3.1658, D_Loss=-4.3609, D_Real=6.5580, G_Loss=165.7755]
Epoch 33/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:42<00:00,  5.42s/it, D_Fake=-3.2347, D_Loss=-5.1959, D_Real=7.2925, G_Loss=151.6629]



Saving checkpoint to checkpoints\P00533_epoch_33.pth...


Epoch 34/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:38<00:00,  5.33s/it, D_Fake=-3.0599, D_Loss=-5.8637, D_Real=8.1036, G_Loss=138.6071]
Epoch 35/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:32<00:00,  5.19s/it, D_Fake=-2.6266, D_Loss=-6.3577, D_Real=8.9996, G_Loss=124.5354]
Epoch 36/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:19<00:00,  4.86s/it, D_Fake=-2.0827, D_Loss=-6.8263, D_Real=9.9810, G_Loss=111.3535]



Saving checkpoint to checkpoints\P00533_epoch_36.pth...


Epoch 37/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:20<00:00,  4.90s/it, D_Fake=-1.5849, D_Loss=-7.4504, D_Real=11.0719, G_Loss=98.7935] 
Epoch 38/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:25<00:00,  5.01s/it, D_Fake=-1.1493, D_Loss=-8.2244, D_Real=12.2506, G_Loss=86.0722]
Epoch 39/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:21<00:00,  4.91s/it, D_Fake=-0.7783, D_Loss=-9.2058, D_Real=13.5722, G_Loss=73.9501]



Saving checkpoint to checkpoints\P00533_epoch_39.pth...


Epoch 40/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:25<00:00,  5.02s/it, D_Fake=-0.4598, D_Loss=-10.3516, D_Real=15.0059, G_Loss=62.4002]
Epoch 41/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:19<00:00,  4.87s/it, D_Fake=-0.2108, D_Loss=-11.6765, D_Real=16.5491, G_Loss=51.2625]
Epoch 42/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:20<00:00,  4.89s/it, D_Fake=-0.0311, D_Loss=-13.2513, D_Real=18.2731, G_Loss=41.4048]



Saving checkpoint to checkpoints\P00533_epoch_42.pth...


Epoch 43/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:18<00:00,  4.83s/it, D_Fake=0.0555, D_Loss=-15.0720, D_Real=20.1496, G_Loss=32.0571]
Epoch 44/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:26<00:00,  5.03s/it, D_Fake=0.0782, D_Loss=-17.1005, D_Real=22.1700, G_Loss=24.2391]
Epoch 45/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:26<00:00,  5.04s/it, D_Fake=0.0812, D_Loss=-19.3529, D_Real=24.3948, G_Loss=17.9597]



Saving checkpoint to checkpoints\P00533_epoch_45.pth...


Epoch 46/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [03:11<00:00,  4.66s/it, D_Fake=0.0818, D_Loss=-21.7847, D_Real=26.7966, G_Loss=13.0526]
Epoch 47/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:51<00:00,  4.19s/it, D_Fake=0.0823, D_Loss=-24.4242, D_Real=29.4061, G_Loss=9.4442] 
Epoch 48/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [02:30<00:00,  3.66s/it, D_Fake=0.0828, D_Loss=-27.2789, D_Real=32.2307, G_Loss=7.1539]



Saving checkpoint to checkpoints\P00533_epoch_48.pth...


Epoch 49/100:   0%|          | 0/41 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [11]:
import torch
import torch.nn.functional as F
from rdkit import Chem
import io
import sys

# --- SET THIS to your latest checkpoint ---
CHECKPOINT_TO_LOAD = "checkpoints/P00533_epoch_30.pth"
NUM_TO_DIAGNOSE = 100

print(f"Loading checkpoint: {CHECKPOINT_TO_LOAD}")
try:
    checkpoint = torch.load(CHECKPOINT_TO_LOAD, map_location=DEVICE)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()

    print(f"Generating {NUM_TO_DIAGNOSE} molecules for diagnosis...")
    with torch.no_grad():
        z = torch.randn(NUM_TO_DIAGNOSE, Z_DIM).to(DEVICE)
        t_embed_batch = TARGET_EMBED.unsqueeze(0).repeat(NUM_TO_DIAGNOSE, 1)
        x_fake_logits, adj_fake_logits = generator(z, t_embed_batch)

        x_indices = torch.argmax(x_fake_logits, dim=-1)
        x_fake_one_hot = F.one_hot(x_indices, num_classes=ATOM_FEAT_DIM).float()
        adj_bond_type_idx = torch.argmax(adj_fake_logits, dim=-1)

    print("\n--- DIAGNOSTIC REPORT ---")
    valid_count = 0
    error_counts = {}
    
    # Capture stderr to see RDKit errors
    stderr_capture = io.StringIO()
    original_stderr = sys.stderr
    sys.stderr = stderr_capture

    for i in range(NUM_TO_DIAGNOSE):
        mol = Chem.RWMol()
        atom_map = {}
        for j in range(MAX_NODES):
            atom_idx = x_indices[i, j].item()
            atom = Chem.Atom(ATOM_CLASSES[atom_idx])
            atom_map[j] = mol.AddAtom(atom)
        for j in range(MAX_NODES):
            for k in range(j + 1, MAX_NODES):
                bond_idx = adj_bond_type_idx[i, j, k].item()
                if bond_idx != NO_BOND_IDX and 0 <= bond_idx < len(BOND_CLASSES_RDKIT):
                    mol.AddBond(atom_map[j], atom_map[k], BOND_CLASSES_RDKIT[bond_idx])
        
        try:
            Chem.SanitizeMol(mol)
            smi = Chem.MolToSmiles(mol)
            if '.' in smi:
                error = "Disconnected Fragments"
            else:
                valid_count += 1
                error = "None (Valid)"
        except Exception:
            # Catch the RDKit error from stderr
            error_msg = stderr_capture.getvalue().split('\n')[-2] if stderr_capture.getvalue() else "Unknown RDKit Error"
            # Simplify common errors for counting
            if "valence" in error_msg.lower(): error = "Valence Error"
            elif "kekulize" in error_msg.lower(): error = "Ring/Kekulization Error"
            else: error = "Other RDKit Error"

        error_counts[error] = error_counts.get(error, 0) + 1
    
    # Restore stderr
    sys.stderr = original_stderr
    
    print(f"Total generated: {NUM_TO_DIAGNOSE}")
    print(f"Valid molecules: {valid_count}")
    print("\nFailure Modes:")
    for error, count in error_counts.items():
        print(f" - {error}: {count}")

except Exception as e:
    print(f"\nERROR: {e}")

Loading checkpoint: checkpoints/P00533_epoch_30.pth
Generating 100 molecules for diagnosis...

--- DIAGNOSTIC REPORT ---
Total generated: 100
Valid molecules: 1

Failure Modes:
 - Valence Error: 94
 - Disconnected Fragments: 5
 - None (Valid): 1
