# In [1]:
import pandas as pd
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_add_pool
from torch_geometric.utils import add_self_loops, to_undirected
import os
import numpy as np
import pickle
import time
import traceback
from datetime import datetime
from copy import deepcopy
import random
import warnings
# Silence every Python warning for the rest of the session (notebook-wide
# setting); this also hides deprecation warnings raised by torch / PyG.
warnings.filterwarnings('ignore')

# Per-atom-type physical properties used to build node features.
# Keys are the 'attype' strings of graph JSON nodes; each entry gives the
# atomic number, atomic mass, electronegativity (values match the Pauling
# scale) and van der Waals radius (presumably in Angstrom - TODO confirm).
# Atom types missing from this table fall back to carbon values in
# create_enhanced_features().
atom_property_dict = {
    'H': {'atomic_num': 1, 'mass': 1.008, 'electronegativity': 2.20, 'vdw_radius': 1.20},
    'C': {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70},
    'N': {'atomic_num': 7, 'mass': 14.007, 'electronegativity': 3.04, 'vdw_radius': 1.55},
    'O': {'atomic_num': 8, 'mass': 15.999, 'electronegativity': 3.44, 'vdw_radius': 1.52},
    'P': {'atomic_num': 15, 'mass': 30.974, 'electronegativity': 2.19, 'vdw_radius': 1.80},
    'S': {'atomic_num': 16, 'mass': 32.065, 'electronegativity': 2.58, 'vdw_radius': 1.80},
    'F': {'atomic_num': 9, 'mass': 18.998, 'electronegativity': 3.98, 'vdw_radius': 1.47},
    'Cl': {'atomic_num': 17, 'mass': 35.453, 'electronegativity': 3.16, 'vdw_radius': 1.75},
    'Br': {'atomic_num': 35, 'mass': 79.904, 'electronegativity': 2.96, 'vdw_radius': 1.85},
    'I': {'atomic_num': 53, 'mass': 126.904, 'electronegativity': 2.66, 'vdw_radius': 1.98},
    # NOTE(review): 'CA', 'CZ' and 'OG' carry carbon/oxygen values, so they look like
    # PDB atom names (alpha carbon, zeta carbon, gamma oxygen) rather than elements -
    # i.e. 'CA' here is NOT calcium. Confirm against the graph generator.
    'CA': {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70},
    'CZ': {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70},
    'OG': {'atomic_num': 8, 'mass': 15.999, 'electronegativity': 3.44, 'vdw_radius': 1.52},
    # Metal ions, keyed in upper case.
    'ZN': {'atomic_num': 30, 'mass': 65.38, 'electronegativity': 1.65, 'vdw_radius': 1.39},
    'MG': {'atomic_num': 12, 'mass': 24.305, 'electronegativity': 1.31, 'vdw_radius': 1.73},
    'FE': {'atomic_num': 26, 'mass': 55.845, 'electronegativity': 1.83, 'vdw_radius': 1.72},
    'MN': {'atomic_num': 25, 'mass': 54.938, 'electronegativity': 1.55, 'vdw_radius': 1.73},
    'CU': {'atomic_num': 29, 'mass': 63.546, 'electronegativity': 1.90, 'vdw_radius': 1.40},
}

def load_csv(csv_path, use_half=False):
    """Read *csv_path* and drop rows whose ``Affinity_pK`` equals 0.

    When *use_half* is true, only the first half (integer division) of the
    remaining rows is returned and a short message is printed. Rows are
    taken from the top, not sampled at random.
    """
    frame = pd.read_csv(csv_path)
    frame = frame[frame['Affinity_pK'] != 0]
    if not use_half:
        return frame

    half_size = len(frame) // 2
    print(f"Using half dataset: {half_size} samples from {csv_path}")
    return frame.head(half_size)

def create_enhanced_features(node, atom_property_dict, graph_type='P'):
    """Return the 17-element float feature vector for one graph node.

    Args:
        node: Node dict; must contain 'attype' and may contain 'pl'
            ('P' = protein, 'L' = ligand).
        atom_property_dict: Mapping from atom type to a dict with keys
            'atomic_num', 'mass', 'electronegativity' and 'vdw_radius'.
            Unknown atom types fall back to carbon values.
        graph_type: 'P', 'L' or 'I'. Decides the protein/ligand flags when
            the node has no 'pl' tag, and always decides the interaction flag.

    Returns:
        list[float]: 8 scaled/derived physical properties, then 6 element or
        property indicators, then 3 origin flags (protein, ligand, interaction).
    """
    carbon_defaults = {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70}
    prop = atom_property_dict.get(node['attype'], carbon_defaults)
    z = prop['atomic_num']
    mass = prop['mass']
    chi = prop['electronegativity']
    radius = prop['vdw_radius']

    # A per-node 'pl' tag takes precedence over the graph type for protein/ligand origin.
    origin = node['pl'] if 'pl' in node else graph_type

    continuous = [
        z / 30.0,
        mass / 100.0,
        chi / 4.0,
        radius / 2.0,
        z ** 0.5 / 5.5,
        mass / z,
        1.0 / chi,
        radius ** 2,
    ]
    indicators = [
        z == 6,            # carbon
        z == 7,            # nitrogen
        z == 8,            # oxygen
        z == 16,           # sulfur
        z > 10,            # atomic number above neon
        chi > 3.0,         # strongly electronegative
        origin == 'P',
        origin == 'L',
        graph_type == 'I',
    ]
    return continuous + [1.0 if flag else 0.0 for flag in indicators]

def load_single_graph(pdb_id, base_path, graph_type):
    """Load one protein ('P'), ligand ('L') or interaction ('I') graph from JSON.

    Reads ``<base_path>/<pdb_id>/<pdb_id>_{protein|ligand|interaction}_graph.json``
    and converts it to tensors. Edges whose 'id1' or 'id2' is None are skipped.
    Node ids are used directly as row indices into the node-feature matrix
    (assumes they are 0-based positions in ``graph['nodes']`` - TODO confirm).

    Returns:
        A dict with 'node_features' (N x 17 float tensor), 'edge_index'
        (2 x E long tensor, symmetrized), 'edge_features' (E x 3 float tensor,
        row i describing column i of 'edge_index'), 'num_nodes', 'graph_type',
        'node_types', 'pdb_id' and 'combination'. Returns ``None`` for an
        unknown graph type, a missing file, or a graph without nodes.
    """
    file_kind = {'P': 'protein', 'L': 'ligand', 'I': 'interaction'}.get(graph_type)
    if file_kind is None:
        return None
    json_path = os.path.join(base_path, pdb_id, f'{pdb_id}_{file_kind}_graph.json')

    try:
        with open(json_path, 'r') as file:
            graph = json.load(file)
    except FileNotFoundError:
        return None

    if not graph['nodes']:
        return None

    node_features = []
    node_types = []
    for node in graph['nodes']:
        node_features.append(create_enhanced_features(node, atom_property_dict, graph_type))
        # Per-node 'pl' tag if present, otherwise the graph's own type.
        node_types.append(node['pl'] if 'pl' in node else graph_type)

    node_features = torch.tensor(node_features, dtype=torch.float)
    num_nodes = len(node_features)

    edge_pairs = []
    edge_rows = []
    for edge in graph['edges']:
        if edge['id1'] is None or edge['id2'] is None:
            continue
        # Clamp very short/zero lengths so 1/length stays finite.
        length = max(edge['length'], 0.1)
        edge_pairs.append([edge['id1'], edge['id2']])
        edge_rows.append([length / 10.0, 1.0 / length, np.exp(-length / 2.0)])

    if not edge_pairs:
        # No usable edges: one self-loop per node so message passing still works.
        edge_index = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
        edge_features = torch.ones(num_nodes, 3) * 0.5
    else:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        edge_features = torch.tensor(edge_rows, dtype=torch.float)
        # Symmetrize edges and features TOGETHER: to_undirected sorts and
        # de-duplicates the edge list, so features passed separately would no
        # longer line up with their edges. Duplicate entries (e.g. both
        # directions already listed in the JSON) are averaged, not summed.
        edge_index, edge_features = to_undirected(
            edge_index, edge_attr=edge_features, reduce='mean'
        )

    return {
        'node_features': node_features,
        'edge_index': edge_index,
        'edge_features': edge_features,
        'num_nodes': num_nodes,
        'graph_type': graph_type,
        'node_types': node_types,
        'pdb_id': pdb_id,
        'combination': graph_type
    }

def merge_graphs(graphs):
    """Stack several graph dicts into one disconnected graph.

    ``None`` entries are ignored. Node features, edge features and node types
    are concatenated in input order, and each graph's edge indices are shifted
    by the number of nodes that precede it. No edges are added between the
    input graphs. Returns ``None`` if no graph is present.
    """
    present = [g for g in graphs if g is not None]
    if not present:
        return None

    shifted_edges = []
    offset = 0
    for g in present:
        shifted_edges.append(g['edge_index'] + offset)
        offset += g['num_nodes']

    stacked_nodes = torch.cat([g['node_features'] for g in present], dim=0)
    return {
        'node_features': stacked_nodes,
        'edge_index': torch.cat(shifted_edges, dim=1),
        'edge_features': torch.cat([g['edge_features'] for g in present], dim=0),
        'node_types': [t for g in present for t in g['node_types']],
        'num_nodes': stacked_nodes.size(0),
    }

def load_combined_graph(pdb_id, base_path, combination):
    """Load the graphs selected by *combination* and merge them into one.

    *combination* is a string such as 'P', 'PL' or 'PLI'. Each of the letters
    'P', 'L' and 'I' it contains selects the protein, ligand or interaction
    graph, always loaded in P, L, I order. Graphs that fail to load are
    skipped by merge_graphs. Returns a pickle-friendly dict of plain Python
    lists, or ``None`` when nothing could be loaded.
    """
    selected_types = [kind for kind in ('P', 'L', 'I') if kind in combination]
    pieces = [load_single_graph(pdb_id, base_path, kind) for kind in selected_types]

    merged = merge_graphs(pieces)
    if merged is None:
        return None

    def as_list(key):
        return merged[key].numpy().tolist()

    return {
        'pdb_id': pdb_id,
        'combination': combination,
        'node_features': as_list('node_features'),
        'edge_index': as_list('edge_index'),
        'edge_features': as_list('edge_features'),
        'node_types': merged['node_types'],
        'num_nodes': merged['num_nodes'],
    }

def save_graph_data(graph_data, output_dir, dataset_name, affinity):
    """Pickle a graph and its affinity label under ``output_dir/dataset_name/pdb_id``.

    Writes ``<pdb_id>_<combination>.pkl`` (the graph dict as given) and
    ``<pdb_id>_affinity.pkl`` (``{'affinity': ..., 'pdb_id': ...}``). The
    affinity file is shared by all combinations and is rewritten on every call.
    """
    pdb_id = graph_data['pdb_id']
    combination = graph_data['combination']

    target_dir = os.path.join(output_dir, dataset_name, pdb_id)
    os.makedirs(target_dir, exist_ok=True)

    outputs = (
        (f'{pdb_id}_{combination}.pkl', graph_data),
        (f'{pdb_id}_affinity.pkl', {'affinity': affinity, 'pdb_id': pdb_id}),
    )
    for file_name, payload in outputs:
        with open(os.path.join(target_dir, file_name), 'wb') as handle:
            pickle.dump(payload, handle)

# Fast GAN Classes
class FastGraphAugmentationGenerator(nn.Module):
    """Lightweight generator predicting per-node and per-edge augmentation parameters.

    For every node it predicts a drop probability and a bounded feature
    perturbation. For every edge it does the same, using the concatenated
    embeddings of the edge's two endpoints. The graph itself is not modified;
    callers decide how to apply the returned values.
    """

    def __init__(self, node_features=17, edge_features=3, hidden_dim=64):
        # Plain super(): naming another class here raised at construction time.
        super().__init__()

        self.node_encoder = nn.Sequential(
            nn.Linear(node_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 32)
        )

        # Drop probabilities in [0, 1]. How aggressively they are applied
        # (e.g. any down-scaling) is decided by the caller.
        self.node_drop_predictor = nn.Sequential(
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
        # Edge heads see [emb(src) || emb(dst)], i.e. 2 * 32 = 64 inputs.
        self.edge_drop_predictor = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # Noise directions in [-1, 1]; scaled down in forward().
        self.node_noise_predictor = nn.Sequential(
            nn.Linear(32, node_features),
            nn.Tanh()
        )
        self.edge_noise_predictor = nn.Sequential(
            nn.Linear(64, edge_features),
            nn.Tanh()
        )

    def forward(self, node_features, edge_index, edge_features):
        """Predict augmentation parameters.

        Args:
            node_features: (N, node_features) float tensor.
            edge_index: (2, E) long tensor of endpoint indices (E may be 0).
            edge_features: Edge feature tensor. It is not read, and is kept for
                interface compatibility.

        Returns:
            node_drop_probs (N,), edge_drop_probs (E,),
            node_noise (N, node_features) with |value| <= 0.08,
            edge_noise (E, edge_features) with |value| <= 0.06.
        """
        node_emb = self.node_encoder(node_features)
        # squeeze(-1) rather than squeeze(): a single node must stay shape (1,).
        node_drop_probs = self.node_drop_predictor(node_emb).squeeze(-1)
        node_noise = self.node_noise_predictor(node_emb) * 0.08

        # Gather both endpoint embeddings for all edges at once. This also
        # handles E == 0 and keeps results on the input's device.
        edge_embs = torch.cat([node_emb[edge_index[0]], node_emb[edge_index[1]]], dim=-1)
        edge_drop_probs = self.edge_drop_predictor(edge_embs).squeeze(-1)
        edge_noise = self.edge_noise_predictor(edge_embs) * 0.06

        return node_drop_probs, edge_drop_probs, node_noise, edge_noise


# NOTE(review): FastMolecularGraphGAN instantiates ImprovedFastGraphAugmentationGenerator,
# which is not defined in the visible part of this notebook. This alias points that name
# at the class above. A later real definition would simply rebind the name.
ImprovedFastGraphAugmentationGenerator = FastGraphAugmentationGenerator

class FastGraphDiscriminator(nn.Module):
    """Small two-layer GCN critic that scores whole graphs in (0, 1).

    Node embeddings are mean-pooled per graph (``batch`` assigns each node to
    a graph) and passed through an MLP that ends in a sigmoid.
    ``edge_features`` is accepted in the constructor but not used.
    """

    def __init__(self, node_features=17, edge_features=3, hidden_dim=64):
        super().__init__()
        # Attribute names and construction order are kept as-is: they fix the
        # state_dict keys and the parameter-initialisation order.
        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, 32)

        head_layers = [
            nn.Linear(32, 64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, edge_index, batch):
        hidden = self.conv1(x, edge_index).relu()
        hidden = self.conv2(hidden, edge_index)
        pooled = global_mean_pool(hidden, batch)
        return self.classifier(pooled)

def prepare_real_graphs(df, base_path, combination, output_dir, dataset_name):
    """Build and save the *combination* graph for every row of *df*.

    Rows whose ``Affinity_pK`` is NaN or infinite, or whose graph cannot be
    loaded, are counted as failures. Progress is printed before and after.
    Returns the number of graphs saved.
    """
    print(f"  Processing {dataset_name}: {len(df)} entries for {combination}...")

    saved = 0
    skipped = 0
    for _, row in df.iterrows():
        pdb_id = row['PDB_ID']
        affinity = row['Affinity_pK']

        if not np.isfinite(affinity):
            skipped += 1
            continue

        graph_data = load_combined_graph(pdb_id, base_path, combination)
        if graph_data is None:
            skipped += 1
            continue

        save_graph_data(graph_data, output_dir, dataset_name, affinity)
        saved += 1

    print(f"  {dataset_name} {combination}: {saved} real graphs saved, {skipped} failed")
    return saved

# GAN Components for Graph Generation
class GraphEncoder(nn.Module):
    """GCN encoder mapping each graph in a batch to one ``latent_dim`` vector.

    The graph embedding concatenates (1) mean-pooled node embeddings, (2)
    attention-pooled node embeddings and (3) mean-pooled edge-feature
    embeddings, then projects them with an MLP.
    """

    def __init__(self, node_features=17, edge_features=3, hidden_dim=128, latent_dim=64):
        super(GraphEncoder, self).__init__()
        self.node_norm = nn.LayerNorm(node_features)
        self.edge_norm = nn.LayerNorm(edge_features)

        self.node_conv1 = GCNConv(node_features, hidden_dim)
        self.node_conv2 = GCNConv(hidden_dim, hidden_dim)
        self.node_conv3 = GCNConv(hidden_dim, latent_dim)

        # Named *_bn but implemented as LayerNorm; names kept for state_dict compatibility.
        self.node_bn1 = nn.LayerNorm(hidden_dim)
        self.node_bn2 = nn.LayerNorm(hidden_dim)

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_features, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, latent_dim)
        )

        # One unnormalised attention logit per node.
        self.attention = nn.Sequential(
            nn.Linear(latent_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

        self.graph_mlp = nn.Sequential(
            nn.Linear(latent_dim * 3, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, latent_dim)
        )

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode a batch of graphs.

        Args:
            x: (N, node_features) node features.
            edge_index: (2, E) long tensor.
            edge_attr: (E, edge_features) edge features; may be empty.
            batch: (N,) graph id of every node, with values 0..B-1.

        Returns:
            (B, latent_dim) tensor of graph embeddings.
        """
        # Per-group (segment) softmax from PyG; imported locally to keep this block self-contained.
        from torch_geometric.utils import softmax as segment_softmax

        num_graphs = int(batch.max().item()) + 1
        x = self.node_norm(x)

        h1 = F.relu(self.node_bn1(self.node_conv1(x, edge_index)))
        h2 = F.relu(self.node_bn2(self.node_conv2(h1, edge_index)))
        node_emb = self.node_conv3(h2, edge_index)

        node_pool_mean = global_mean_pool(node_emb, batch, size=num_graphs)

        # Normalise attention within each graph. A plain softmax over dim=0
        # would normalise across every node in the batch, shrinking each
        # graph's pooled vector by its share of the batch. The two are
        # identical for a single graph.
        att_weights = segment_softmax(self.attention(node_emb), batch, num_nodes=num_graphs)
        node_pool_att = global_add_pool(node_emb * att_weights, batch, size=num_graphs)

        if edge_attr.size(0) > 0:
            edge_emb = self.edge_mlp(self.edge_norm(edge_attr))
            edge_batch = batch[edge_index[0]]
            # size=num_graphs gives a zero row to every graph without edges,
            # wherever it sits in the batch (not only at the end).
            edge_pool = global_mean_pool(edge_emb, edge_batch, size=num_graphs)
        else:
            edge_pool = node_emb.new_zeros(num_graphs, self.edge_mlp[-1].out_features)

        graph_emb = torch.cat([node_pool_mean, node_pool_att, edge_pool], dim=1)
        return self.graph_mlp(graph_emb)

class GraphAugmentationGenerator(nn.Module):
    """Generator that perturbs node/edge features and scores every edge.

    Perturbations are tanh-bounded: at most 0.15 per node-feature entry and
    0.12 per edge-feature entry. Each edge receives a sigmoid keep-score
    computed from its two endpoints' augmented node features.
    """

    def __init__(self, node_features=17, edge_features=3, hidden_dim=128):
        super(GraphAugmentationGenerator, self).__init__()

        # Node feature augmentation (output in [-1, 1], scaled in forward()).
        self.node_augment = nn.Sequential(
            nn.Linear(node_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_features),
            nn.Tanh()
        )

        # Edge feature augmentation (output in [-1, 1], scaled in forward()).
        self.edge_augment = nn.Sequential(
            nn.Linear(edge_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, edge_features),
            nn.Tanh()
        )

        # Edge keep-score from [node_i || node_j].
        self.edge_prob_mlp = nn.Sequential(
            nn.Linear(node_features * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, node_features, edge_index, edge_features):
        """Augment a graph.

        Args:
            node_features: (N, node_features) float tensor.
            edge_index: (2, E) long tensor (E may be 0).
            edge_features: (E, edge_features) float tensor (may be empty).

        Returns:
            augmented_nodes (N, node_features), augmented_edges (same shape as
            edge_features; returned unchanged when empty) and edge_probs (E,).
        """
        augmented_nodes = node_features + self.node_augment(node_features) * 0.15

        if edge_features.size(0) > 0:
            augmented_edges = edge_features + self.edge_augment(edge_features) * 0.12
        else:
            augmented_edges = edge_features

        # Score all edges in one batched pass. This works for E == 0
        # (torch.stack of an empty list raises) and keeps an (E,) result when
        # E == 1 (squeeze() would return a 0-d tensor).
        endpoint_pairs = torch.cat(
            [augmented_nodes[edge_index[0]], augmented_nodes[edge_index[1]]], dim=-1
        )
        edge_probs = self.edge_prob_mlp(endpoint_pairs).squeeze(-1)

        return augmented_nodes, augmented_edges, edge_probs

class GraphDiscriminator(nn.Module):
    """Critic: GraphEncoder embedding -> spectrally-normalised MLP -> raw score.

    Returns an unbounded (B, 1) score (no sigmoid). MolecularGraphGAN uses
    it with Wasserstein-style losses.
    """

    def __init__(self, node_features=17, edge_features=3, hidden_dim=128):
        super().__init__()
        # Build the encoder first, then the head layers in the same order,
        # so parameter initialisation and state_dict keys are unchanged.
        self.encoder = GraphEncoder(node_features, edge_features, hidden_dim, 64)

        sn = nn.utils.spectral_norm
        self.classifier = nn.Sequential(
            sn(nn.Linear(64, 128)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            sn(nn.Linear(128, 64)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            sn(nn.Linear(64, 1)),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        return self.classifier(self.encoder(x, edge_index, edge_attr, batch))

class MolecularGraphGAN:
    """Per-graph GAN augmentation: fine-tune on one graph, then emit an augmented copy.

    The same generator, discriminator and optimizers are reused across calls,
    so training accumulates over every graph processed.
    """

    # Edges whose generator keep-score exceeds this value are retained.
    EDGE_KEEP_THRESHOLD = 0.7

    def __init__(self, device='cuda'):
        self.device = device
        self.generator = GraphAugmentationGenerator().to(device)
        self.discriminator = GraphDiscriminator().to(device)

        self.g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001)
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.002)

    def convert_graph_to_pyg(self, graph_data):
        """Convert a stored graph dict (nested lists) into a PyG ``Data`` object.

        ``edge_index`` may be stored as (2, E) or (E, 2) and is returned as
        (2, E). An empty edge list such as ``[]`` becomes a (2, 0) index.
        """
        node_features = torch.tensor(graph_data['node_features'], dtype=torch.float)
        edge_index = torch.tensor(graph_data['edge_index'], dtype=torch.long)
        edge_attr = torch.tensor(graph_data['edge_features'], dtype=torch.float)

        if edge_index.numel() == 0:
            # [] gives a 1-D (0,) tensor, which .t() would leave unchanged.
            edge_index = edge_index.reshape(2, 0)
        elif edge_index.size(0) != 2:
            edge_index = edge_index.t()

        return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

    def _augment(self, pyg_data):
        """Run the generator and keep only edges scoring above EDGE_KEEP_THRESHOLD.

        Returns (augmented node features, kept edge_index, kept edge features).
        """
        aug_nodes, aug_edges, edge_probs = self.generator(
            pyg_data.x, pyg_data.edge_index, pyg_data.edge_attr
        )
        # reshape(-1) guards against a 0-d score tensor for single-edge graphs.
        edge_mask = (edge_probs > self.EDGE_KEEP_THRESHOLD).reshape(-1)
        kept_edges = pyg_data.edge_index[:, edge_mask]
        kept_edge_attrs = aug_edges[edge_mask] if aug_edges.size(0) > 0 else aug_edges
        return aug_nodes, kept_edges, kept_edge_attrs

    def generate_synthetic_graph(self, original_graph, pdb_id, combination):
        """Fine-tune the GAN on *original_graph*, then return an augmented copy.

        The copy has perturbed node and edge features and keeps only the edges
        the generator scores above EDGE_KEEP_THRESHOLD. Its ``pdb_id`` gets a
        ``_synthetic`` suffix. Returns ``None`` if anything fails; the error is
        printed.
        """
        try:
            pyg_data = self.convert_graph_to_pyg(original_graph).to(self.device)

            # Train GAN on this specific graph (weights persist across calls).
            self.train_on_graph(pyg_data, epochs=30)

            self.generator.eval()
            with torch.no_grad():
                aug_nodes, kept_edges, kept_edge_attrs = self._augment(pyg_data)

            synthetic_graph = deepcopy(original_graph)
            synthetic_graph['node_features'] = aug_nodes.cpu().numpy().tolist()
            synthetic_graph['edge_index'] = kept_edges.cpu().numpy().tolist()
            synthetic_graph['edge_features'] = kept_edge_attrs.cpu().numpy().tolist()
            synthetic_graph['node_types'] = original_graph['node_types'].copy()
            synthetic_graph['pdb_id'] = f"{pdb_id}_synthetic"
            synthetic_graph['synthetic_method'] = 'GAN_Augmentation'
            synthetic_graph['combination'] = combination

            return synthetic_graph

        except Exception as e:
            print(f"    GAN error for {pdb_id}: {e}")
            return None

    def train_on_graph(self, pyg_data, epochs=30):
        """Adversarially fine-tune on a single graph.

        The critic is updated every epoch with ``mean(fake) - mean(real)``.
        The generator is updated every other epoch with ``-mean(fake)``.
        A step is skipped when no edge clears the keep threshold.
        """
        # generate_synthetic_graph leaves the generator in eval mode; restore training mode.
        self.generator.train()
        self.discriminator.train()

        batch = torch.zeros(pyg_data.x.size(0), dtype=torch.long, device=self.device)

        for epoch in range(epochs):
            # --- Critic step ---
            self.d_optimizer.zero_grad()
            real_score = self.discriminator(pyg_data.x, pyg_data.edge_index, pyg_data.edge_attr, batch)

            # Generator gradients are not needed for the critic update, so build
            # the fake sample without autograd history.
            with torch.no_grad():
                aug_nodes, fake_edges, fake_edge_attrs = self._augment(pyg_data)
            if fake_edges.size(1) > 0:
                fake_score = self.discriminator(aug_nodes, fake_edges, fake_edge_attrs, batch)
                d_loss = fake_score.mean() - real_score.mean()
                d_loss.backward()
                self.d_optimizer.step()

            # --- Generator step (every 2 epochs) ---
            if epoch % 2 == 0:
                self.g_optimizer.zero_grad()
                aug_nodes, fake_edges, fake_edge_attrs = self._augment(pyg_data)
                if fake_edges.size(1) > 0:
                    fake_score = self.discriminator(aug_nodes, fake_edges, fake_edge_attrs, batch)
                    g_loss = -fake_score.mean()
                    g_loss.backward()
                    self.g_optimizer.step()

class StructureAwareGAN:
    """Front-end offering GAN-based and rule-based (controlled) graph augmentation."""

    def __init__(self, device='cuda'):
        self.device = device
        self.gan = MolecularGraphGAN(device=device)

    def generate_synthetic_graph(self, original_graph, pdb_id, combination):
        """Delegate to the wrapped MolecularGraphGAN."""
        return self.gan.generate_synthetic_graph(original_graph, pdb_id, combination)

    def augment_molecular_graph(self, original_graph, pdb_id, combination):
        """Create a synthetic graph by controlled random perturbation.

        Steps:
            1. Add Gaussian noise (sigma drawn from 0.10-0.15) to every node feature.
            2. If there are more than 3 edge entries, drop 15-25% of them at random.
            3. If there are more than 2 nodes, try to add 10-15% new edges (at
               least one attempt). Each joins a random node pair not already
               connected in either direction; its features are the current mean
               edge feature plus noise.
            4. Add Gaussian noise (sigma drawn from 0.08-0.12) to all edge features.

        Node types are copied unchanged and *original_graph* is not mutated.
        Both ``random`` and ``numpy.random`` are used, so seed both for
        reproducibility.

        Returns:
            The synthetic graph dict (``pdb_id`` suffixed with ``_synthetic``), or
            ``None`` if the input has no edges, the result has no edges, or an
            error occurs (the traceback is printed).
        """
        try:
            synthetic_graph = deepcopy(original_graph)

            node_features = np.array(original_graph['node_features'])
            edge_index = original_graph['edge_index']
            edge_features = np.array(original_graph['edge_features']) if original_graph.get('edge_features') else None

            n_nodes = len(node_features)
            n_edges = len(edge_index[0]) if edge_index and len(edge_index) >= 2 else 0

            # Skip augmentation if no edges
            if n_edges == 0:
                print(f"    Warning: {pdb_id} has no edges, skipping augmentation")
                return None

            # 1. Node-feature noise (10-15%).
            noise_scale = random.uniform(0.10, 0.15)
            node_noise = np.random.normal(0, noise_scale, node_features.shape)
            augmented_nodes = node_features + node_noise

            # 2. Random edge removal (15-25%).
            if n_edges > 3:
                removal_rate = random.uniform(0.15, 0.25)
                edges_to_keep = n_edges - int(n_edges * removal_rate)

                edge_indices = list(range(n_edges))
                random.shuffle(edge_indices)
                keep_indices = edge_indices[:edges_to_keep]

                new_edge_index = [
                    [edge_index[0][i] for i in keep_indices],
                    [edge_index[1][i] for i in keep_indices]
                ]

                if edge_features is not None and len(edge_features) > 0:
                    new_edge_features = edge_features[keep_indices]
                else:
                    new_edge_features = edge_features
            else:
                # Copy so appends below never touch original_graph.
                new_edge_index = deepcopy(edge_index)
                new_edge_features = deepcopy(edge_features) if edge_features is not None else None

            # 3. Random edge addition (10-15%). Existing connections are kept in a
            # set of unordered pairs, so the duplicate check is O(1) instead of a
            # scan over every edge on each attempt.
            if n_nodes > 2:
                addition_rate = random.uniform(0.10, 0.15)
                edges_to_add = max(1, int(n_edges * addition_rate))
                existing_pairs = {
                    frozenset((a, b)) for a, b in zip(new_edge_index[0], new_edge_index[1])
                }

                for _ in range(edges_to_add):
                    for _attempt in range(10):
                        u, v = random.sample(range(n_nodes), 2)
                        pair = frozenset((u, v))
                        if pair in existing_pairs:
                            continue

                        new_edge_index[0].append(u)
                        new_edge_index[1].append(v)
                        existing_pairs.add(pair)

                        if new_edge_features is not None and len(new_edge_features) > 0:
                            # Mean of the current edge features (including edges added so far) plus noise.
                            avg_edge_feat = np.mean(new_edge_features, axis=0)
                            edge_noise = np.random.normal(0, 0.1, avg_edge_feat.shape)
                            new_edge_features = np.vstack([new_edge_features, avg_edge_feat + edge_noise])
                        break

            # 4. Edge-feature noise (8-12%).
            if new_edge_features is not None and len(new_edge_features) > 0:
                edge_noise_scale = random.uniform(0.08, 0.12)
                edge_noise = np.random.normal(0, edge_noise_scale, new_edge_features.shape)
                new_edge_features = new_edge_features + edge_noise

            new_n_edges = len(new_edge_index[0]) if new_edge_index else 0
            if new_n_edges == 0:
                print(f"    Error: {pdb_id} synthetic graph has no edges after augmentation")
                return None

            synthetic_graph['node_features'] = augmented_nodes.tolist()
            synthetic_graph['edge_index'] = new_edge_index
            synthetic_graph['edge_features'] = new_edge_features.tolist() if new_edge_features is not None else []

            # Nodes are neither added nor removed, so the original node types still apply.
            synthetic_graph['node_types'] = original_graph['node_types'].copy()

            synthetic_graph['pdb_id'] = f"{pdb_id}_synthetic"
            synthetic_graph['synthetic_method'] = 'Controlled_Augmentation'
            synthetic_graph['combination'] = combination

            return synthetic_graph

        except Exception as e:
            print(f"    Augmentation error for {pdb_id}: {e}")
            traceback.print_exc()
            return None

class FastMolecularGraphGAN:
    """Generator/discriminator pair that produces conservatively perturbed graph copies.

    The generator is called with (x, edge_index, edge_attr) and returns four tensors
    that this class uses as node drop probabilities, edge drop probabilities, node
    noise and edge noise. ``apply_conservative_augmentation`` turns them into a
    concrete augmented graph.
    """

    def __init__(self, device='cuda'):
        self.device = device
        self.generator = ImprovedFastGraphAugmentationGenerator().to(device)
        self.discriminator = FastGraphDiscriminator().to(device)  # Keep original discriminator
        self.g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.002)
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.001)

    def convert_graph_to_pyg(self, graph_data):
        """Convert a graph dict into a PyG ``Data`` object.

        Reads the ``node_features``, ``edge_index`` and ``edge_features`` keys.
        ``edge_index`` may be stored as 2 x E or as E x 2. It is transposed whenever
        its first dimension is not 2. An E x 2 list with exactly two edges is
        therefore ambiguous and is taken as 2 x E. An empty edge list is normalised
        to shape (2, 0), so downstream code can still call ``size(1)`` on it.
        """
        node_features = torch.tensor(graph_data['node_features'], dtype=torch.float)
        edge_index = torch.tensor(graph_data['edge_index'], dtype=torch.long)
        edge_attr = torch.tensor(graph_data['edge_features'], dtype=torch.float)
        if edge_index.numel() == 0:
            edge_index = edge_index.reshape(2, 0)
        elif edge_index.size(0) != 2:
            edge_index = edge_index.t()
        return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

    def apply_conservative_augmentation(self, node_features, edge_index, edge_features,
                                      node_drop_probs, edge_drop_probs, node_noise, edge_noise):
        """More conservative augmentation that preserves graph structure.

        Args:
            node_features: node feature tensor with one row per node.
            edge_index: (2, E) long tensor of source/target node indices.
            edge_features: edge feature tensor whose first dimension indexes edges,
                normally (E, D). Zero-width and empty tensors are accepted.
            node_drop_probs: one drop probability per node, shape (N,) or (N, 1).
            edge_drop_probs: per-edge drop probabilities, shape (E,) or (E, 1).
                Edges beyond its length use a default keep probability of 0.9.
            node_noise: additive noise indexable like ``node_features``.
            edge_noise: additive noise indexable like ``edge_features``.

        Returns:
            Tuple ``(kept_nodes, new_edge_index, new_edge_features, node_keep_mask)``.
            ``new_edge_index`` is re-indexed to the kept nodes and always has as many
            columns as ``new_edge_features`` has rows.
        """
        num_nodes = node_features.size(0)
        # Accept (N,) or (N, 1). An (N, 1) tensor would otherwise broadcast against
        # the (N,) random draw below into an (N, N) mask.
        node_drop_probs = node_drop_probs.reshape(-1)

        # Much less aggressive node dropping (max 10% instead of 30%).
        node_drop_threshold = 0.1
        node_keep_mask = torch.rand(num_nodes, device=self.device) > (node_drop_probs * node_drop_threshold)

        # Ensure we keep at least 70% of nodes.
        min_nodes_to_keep = max(1, int(num_nodes * 0.7))
        if node_keep_mask.sum() < min_nodes_to_keep:
            # Keep the nodes with the lowest drop probability.
            _, top_indices = torch.topk(1 - node_drop_probs, min_nodes_to_keep)
            node_keep_mask = torch.zeros_like(node_keep_mask, dtype=torch.bool)
            node_keep_mask[top_indices] = True

        kept_nodes = node_features[node_keep_mask] + node_noise[node_keep_mask]
        num_kept = kept_nodes.size(0)
        # Trailing shape of one edge's features: (D,) for 2-D features, () for 1-D.
        row_shape = tuple(edge_features.shape[1:])

        kept_edge_ids = self._sample_kept_edges(edge_index, node_keep_mask, edge_drop_probs)

        if kept_edge_ids.numel() > 0:
            # Map each old node index to its position among the kept nodes.
            new_position = torch.cumsum(node_keep_mask.long(), dim=0) - 1
            new_edge_index = new_position[edge_index[:, kept_edge_ids]]
            # Edges without a feature row get zero features.
            new_edge_features = edge_features.new_zeros((kept_edge_ids.numel(), *row_shape))
            has_row = kept_edge_ids < edge_features.size(0)
            ids = kept_edge_ids[has_row]
            new_edge_features[has_row] = edge_features[ids] + edge_noise[ids]
        elif num_kept > 1:
            # Every edge was dropped: keep the graph connected with a few random edges.
            new_edge_index, new_edge_features = self._random_fallback_edges(
                num_kept, edge_features, row_shape)
        else:
            new_edge_index = torch.empty((2, 0), dtype=torch.long, device=self.device)
            new_edge_features = edge_features.new_zeros((0, *row_shape))

        return kept_nodes, new_edge_index, new_edge_features, node_keep_mask

    def _sample_kept_edges(self, edge_index, node_keep_mask, edge_drop_probs):
        """Return indices of edges with both endpoints kept that also pass a keep draw.

        Edge ``i`` is kept with probability ``1 - 0.1 * edge_drop_probs[i]``, or
        0.9 when ``edge_drop_probs`` has no entry for it. Edges that reference
        out-of-range nodes are discarded. All draws are made in one vectorised call
        instead of a ``torch.rand(1).item()`` (a device sync on GPU) per edge.
        """
        num_edges = edge_index.size(1)
        num_nodes = node_keep_mask.size(0)
        if num_edges == 0 or num_nodes == 0:
            return torch.empty(0, dtype=torch.long, device=self.device)

        edge_drop_threshold = 0.1  # Reduced from 0.3
        keep_prob = torch.full((num_edges,), 0.9, device=self.device)  # Default keep probability
        num_probs = min(num_edges, len(edge_drop_probs))
        if num_probs > 0:
            # Accept (E,) or (E, 1) and take the single probability per edge.
            probs = edge_drop_probs[:num_probs].reshape(num_probs, -1)[:, 0]
            keep_prob[:num_probs] = 1.0 - probs.to(keep_prob) * edge_drop_threshold

        src, dst = edge_index[0], edge_index[1]
        in_range = (src >= 0) & (src < num_nodes) & (dst >= 0) & (dst < num_nodes)
        # Clamp only so the lookup is legal; out-of-range edges are masked by in_range.
        endpoints_kept = (node_keep_mask[src.clamp(0, num_nodes - 1)]
                          & node_keep_mask[dst.clamp(0, num_nodes - 1)])
        passed_draw = torch.rand(num_edges, device=self.device) < keep_prob
        return torch.nonzero(in_range & endpoints_kept & passed_draw, as_tuple=True)[0]

    def _random_fallback_edges(self, num_nodes, edge_features, row_shape):
        """Create ``min(3, num_nodes - 1)`` distinct random directed edges without self-loops.

        Each new edge gets the mean of all original edge features, or zeros when
        there are none. Every edge therefore has a feature row, even when the
        features have zero width.
        """
        num_new = min(3, num_nodes - 1)
        pairs = []
        while len(pairs) < num_new:
            pair = random.sample(range(num_nodes), 2)
            if pair not in pairs:
                pairs.append(pair)
        new_edge_index = torch.tensor(pairs, dtype=torch.long, device=self.device).t()
        if edge_features.size(0) > 0:
            filler = torch.mean(edge_features, dim=0)
        else:
            filler = edge_features.new_zeros(row_shape)
        new_edge_features = torch.stack([filler] * num_new)
        return new_edge_index, new_edge_features

    def generate_synthetic_graph(self, original_graph, pdb_id, combination):
        """Generate synthetic graph with better preservation of structure.

        Returns a deep copy of ``original_graph`` with augmented nodes, edges and
        edge features, plus ``node_types`` filtered to the kept nodes.
        Returns ``None`` if the graph has fewer than 3 nodes, if augmentation leaves
        no nodes, or if any step raises (the error and traceback are printed).
        """
        try:
            pyg_data = self.convert_graph_to_pyg(original_graph).to(self.device)

            # Skip if graph is too small.
            if pyg_data.x.size(0) < 3:
                return None

            self.generator.eval()
            with torch.no_grad():
                node_drop_probs, edge_drop_probs, node_noise, edge_noise = self.generator(
                    pyg_data.x, pyg_data.edge_index, pyg_data.edge_attr
                )

                aug_nodes, aug_edges, aug_edge_attrs, node_keep_mask = self.apply_conservative_augmentation(
                    pyg_data.x, pyg_data.edge_index, pyg_data.edge_attr,
                    node_drop_probs, edge_drop_probs, node_noise, edge_noise
                )

            # Check if we have a valid graph.
            if aug_nodes.size(0) == 0:
                print(f"    Warning: {pdb_id} - No nodes after augmentation")
                return None

            # Keep the node types of the surviving nodes, in their original order.
            original_node_types = original_graph['node_types']
            kept_indices = torch.where(node_keep_mask)[0].tolist()
            new_node_types = [original_node_types[i] for i in kept_indices if i < len(original_node_types)]

            # Ensure we have the right number of node types.
            if len(new_node_types) != aug_nodes.size(0):
                print(f"    Warning: {pdb_id} - Node type mismatch: {len(new_node_types)} types, {aug_nodes.size(0)} nodes")
                if len(new_node_types) < aug_nodes.size(0):
                    # Pad with the most common type ('P' if there are no types at all).
                    from collections import Counter
                    most_common = Counter(new_node_types).most_common(1)[0][0] if new_node_types else 'P'
                    new_node_types.extend([most_common] * (aug_nodes.size(0) - len(new_node_types)))
                else:
                    new_node_types = new_node_types[:aug_nodes.size(0)]

            # Create synthetic graph.
            synthetic_graph = deepcopy(original_graph)
            synthetic_graph['node_features'] = aug_nodes.cpu().numpy().tolist()
            synthetic_graph['edge_index'] = aug_edges.cpu().numpy().tolist()
            synthetic_graph['edge_features'] = aug_edge_attrs.cpu().numpy().tolist()
            synthetic_graph['node_types'] = new_node_types
            synthetic_graph['num_nodes'] = aug_nodes.size(0)
            synthetic_graph['pdb_id'] = f"{pdb_id}_synthetic"
            synthetic_graph['synthetic_method'] = 'Conservative_GAN_Augmentation'
            synthetic_graph['combination'] = combination

            return synthetic_graph

        except Exception as e:
            # Best effort: report and skip this graph.
            print(f"    GAN error for {pdb_id}: {e}")
            traceback.print_exc()
            return None

class ConservativeStructureAwareGAN:
    """Wrapper that trains an ``ImprovedFastMolecularGraphGAN`` lazily and delegates generation to it."""

    def __init__(self, device='cuda'):
        self.device = device
        self.gan = ImprovedFastMolecularGraphGAN(device=device)
        # Becomes True after the first successful call to train_on_combination.
        self.is_trained = False

    def train_on_combination(self, graph_batch, combination):
        """Train the wrapped GAN, at most once per instance.

        Uses the first 50 graphs of ``graph_batch`` for 5 epochs. Once
        ``is_trained`` is set, later calls do nothing, even for a different
        ``combination``. ``combination`` is only used in the progress message.
        """
        if self.is_trained:
            return
        print(f"    Training Conservative GAN for {combination}...")
        training_subset = graph_batch[:50]
        self.gan.batch_train(training_subset, epochs=5)  # Reduced epochs
        self.is_trained = True

    def generate_synthetic_graph(self, original_graph, pdb_id, combination):
        """Delegate to the wrapped GAN's ``generate_synthetic_graph``."""
        return self.gan.generate_synthetic_graph(original_graph, pdb_id, combination)

def generate_synthetic_graphs_controlled(real_graph_dir, dataset_name, combination, output_dir):
    """Generate synthetic graphs using controlled augmentation (not GAN).

    For every sub-directory ``<pdb_id>`` of ``<real_graph_dir>/<dataset_name>``:
    1. Load ``<pdb_id>_<combination>.pkl`` and ``<pdb_id>_affinity.pkl``.
    2. Augment the graph with ``StructureAwareGAN().augment_molecular_graph``.
    3. Pass the result to ``save_graph_data`` with split name ``<dataset_name>_synthetic``.

    Args:
        real_graph_dir: Root directory containing the real graph splits.
        dataset_name: Split sub-directory to read, e.g. ``'training'``.
        combination: Node-set combination suffix used in the graph file name.
        output_dir: Root directory passed to ``save_graph_data``.

    Returns:
        Number of synthetic graphs generated and saved. 0 if the split directory
        does not exist. Per-item failures are counted and printed, not raised.

    Security note: ``pickle.load`` can execute arbitrary code. Only point this at
    directories produced by this pipeline.
    """
    print(f"  Generating synthetic graphs for {combination} using Controlled Augmentation...")

    dataset_dir = os.path.join(real_graph_dir, dataset_name)

    if not os.path.exists(dataset_dir):
        print(f"    Dataset directory not found: {dataset_dir}")
        return 0

    # Sorted so the processing order, and therefore how the random augmentation
    # stream is consumed, does not depend on os.listdir's arbitrary ordering.
    pdb_dirs = sorted(d for d in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, d)))
    print(f"    Found {len(pdb_dirs)} PDB directories")

    # Built on first use, outside the per-item try, so constructor errors still propagate.
    augmenter = None

    generated_count = 0
    # Each failed item increments exactly one counter.
    # "empty_result" also covers graphs with fewer than 3 node feature rows.
    debug_failures = {"no_graph": 0, "no_affinity": 0, "generation_failed": 0, "empty_result": 0}

    for idx, pdb_dir in enumerate(pdb_dirs):
        if idx % 500 == 0:
            print(f"    Progress: {idx}/{len(pdb_dirs)}")

        pdb_path = os.path.join(dataset_dir, pdb_dir)
        original_path = os.path.join(pdb_path, f'{pdb_dir}_{combination}.pkl')
        affinity_path = os.path.join(pdb_path, f'{pdb_dir}_affinity.pkl')

        if not os.path.exists(original_path):
            debug_failures["no_graph"] += 1
            continue

        if not os.path.exists(affinity_path):
            debug_failures["no_affinity"] += 1
            continue

        try:
            with open(original_path, 'rb') as f:
                original_graph = pickle.load(f)
            with open(affinity_path, 'rb') as f:
                affinity_data = pickle.load(f)
            node_features = original_graph.get('node_features')
            # Explicit None/len test: `not node_features` raises for numpy arrays.
            too_small = node_features is None or len(node_features) < 3
        except Exception as e:
            print(f"    Error processing {pdb_dir}: {e}")
            debug_failures["generation_failed"] += 1
            continue

        if too_small:
            debug_failures["empty_result"] += 1
            continue

        if augmenter is None:
            # Use the working StructureAwareGAN from the original code.
            augmenter = StructureAwareGAN()

        try:
            synthetic_graph = augmenter.augment_molecular_graph(original_graph, pdb_dir, combination)
            synthetic_nodes = synthetic_graph.get('node_features') if synthetic_graph else None
            if synthetic_nodes is not None and len(synthetic_nodes) > 0:
                save_graph_data(synthetic_graph, output_dir, f"{dataset_name}_synthetic", affinity_data['affinity'])
                generated_count += 1
            else:
                debug_failures["generation_failed"] += 1
        except Exception as e:
            print(f"    Error processing {pdb_dir}: {e}")
            debug_failures["generation_failed"] += 1

    failed_count = sum(debug_failures.values())
    print(f"  {combination}: {generated_count} synthetic graphs generated, {failed_count} failed")
    print(f"    Failure breakdown: {debug_failures}")
    return generated_count

def main():
    """Run the complete real + synthetic graph generation pipeline.

    For every node-set combination, this does two things:
    1. Builds real graphs for the training, validation, core and holdout splits
       with ``prepare_real_graphs``.
    2. Writes augmented copies of the training and validation graphs with
       ``generate_synthetic_graphs_controlled``.

    Everything is written under a new timestamped output directory. A per-split
    directory count and the expected layout are printed at the end.
    """
    banner = "=" * 60
    print("🔬 COMPLETE GRAPH GENERATION PIPELINE")
    print(banner)

    # Input locations (hard-coded Windows paths).
    pdbbind_root = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind'
    affinity_dir = f'{pdbbind_root}\\pdb_ids_Affinity'
    real_data_path = f'{pdbbind_root}\\dataset'
    split_csvs = {
        'training': f'{affinity_dir}\\training_set_with_affinity.csv',
        'validation': f'{affinity_dir}\\validation_set_with_affinity.csv',
        'core': f'{affinity_dir}\\core_set_with_affinity.csv',
        'holdout': f'{affinity_dir}\\hold_out_set_with_affinity.csv',
    }

    # Each run writes into its own timestamped directory.
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f'{pdbbind_root}\\complete_graphs_{run_stamp}'
    os.makedirs(output_dir, exist_ok=True)

    print(f"Output directory: {output_dir}")

    print("\n📊 Loading datasets...")
    datasets = {split: load_csv(path, use_half=False) for split, path in split_csvs.items()}
    print(f"Train: {len(datasets['training'])}, Val: {len(datasets['validation'])}, "
          f"Core: {len(datasets['core'])}, Holdout: {len(datasets['holdout'])}")

    combinations = ['P', 'L', 'I', 'PL', 'PI', 'LI', 'PLI']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    started = time.time()
    rule = "=" * 30

    for combination in combinations:
        print(f"\n{rule} {combination} {rule}")

        # Step 1: real graphs for every non-empty split.
        print("🔧 Generating real graphs from JSON files...")
        for split, frame in datasets.items():
            if len(frame) > 0:
                prepare_real_graphs(frame, real_data_path, combination, output_dir, split)

        # Step 2: synthetic graphs for the training and validation splits only.
        print("🤖 Generating synthetic graphs using augmentation...")
        for split in ('training', 'validation'):
            generate_synthetic_graphs_controlled(output_dir, split, combination, output_dir)

    elapsed = time.time() - started
    print(f"\n{banner}")
    print("✅ COMPLETE PIPELINE FINISHED!")
    print(f"Total time: {elapsed/60:.1f} minutes")
    print(f"Output saved to: {output_dir}")

    output_splits = ['training', 'validation', 'core', 'holdout', 'training_synthetic', 'validation_synthetic']
    for split in output_splits:
        split_path = os.path.join(output_dir, split)
        if not os.path.exists(split_path):
            continue
        n_dirs = sum(1 for entry in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, entry)))
        print(f"{split}: {n_dirs} PDB directories")

    print("\n📁 Directory structure created:")
    print(f"  {output_dir}/")
    last = len(output_splits) - 1
    for position, split in enumerate(output_splits):
        connector = "└──" if position == last else "├──"
        print(f"    {connector} {split}/")
    print("\nEach PDB directory contains:")
    print("  - {pdb_id}_{combination}.pkl (graph data)")
    print("  - {pdb_id}_affinity.pkl (binding affinity)")

# Run the full pipeline only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

🔬 COMPLETE GRAPH GENERATION PIPELINE
Output directory: D:\PhD\Chapter_4\Code2\pdbbind\complete_graphs_20250709_163209

📊 Loading datasets...
Train: 9662, Val: 903, Core: 257, Holdout: 3393
Using device: cuda

🔧 Generating real graphs from JSON files...
  Processing training: 9662 entries for P...
  training P: 9312 real graphs saved, 350 failed
  Processing validation: 903 entries for P...
  validation P: 871 real graphs saved, 32 failed
  Processing core: 257 entries for P...
  core P: 249 real graphs saved, 8 failed
  Processing holdout: 3393 entries for P...
  holdout P: 3232 real graphs saved, 161 failed
🤖 Generating synthetic graphs using augmentation...
  Generating synthetic graphs for P using Controlled Augmentation...
    Found 9312 PDB directories
    Progress: 0/9312
    Progress: 500/9312
    Progress: 1000/9312
    Progress: 1500/9312
    Progress: 2000/9312
    Progress: 2500/9312
    Progress: 3000/9312
    Progress: 3500/9312
    Progress: 4000/9312
    Progress: 4500/9