In [1]:
import pandas as pd
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import add_self_loops, to_undirected
import os
import numpy as np
from scipy.stats import pearsonr
import pickle
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Simplified atom property dictionary
atom_property_dict = {
    'H': {'atomic_num': 1, 'mass': 1.008, 'electronegativity': 2.20, 'vdw_radius': 1.20},
    'C': {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70},
    'N': {'atomic_num': 7, 'mass': 14.007, 'electronegativity': 3.04, 'vdw_radius': 1.55},
    'O': {'atomic_num': 8, 'mass': 15.999, 'electronegativity': 3.44, 'vdw_radius': 1.52},
    'P': {'atomic_num': 15, 'mass': 30.974, 'electronegativity': 2.19, 'vdw_radius': 1.80},
    'S': {'atomic_num': 16, 'mass': 32.065, 'electronegativity': 2.58, 'vdw_radius': 1.80},
    'F': {'atomic_num': 9, 'mass': 18.998, 'electronegativity': 3.98, 'vdw_radius': 1.47},
    'Cl': {'atomic_num': 17, 'mass': 35.453, 'electronegativity': 3.16, 'vdw_radius': 1.75},
    'Br': {'atomic_num': 35, 'mass': 79.904, 'electronegativity': 2.96, 'vdw_radius': 1.85},
    'I': {'atomic_num': 53, 'mass': 126.904, 'electronegativity': 2.66, 'vdw_radius': 1.98},
    'CA': {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70},
    'CZ': {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70},
    'OG': {'atomic_num': 8, 'mass': 15.999, 'electronegativity': 3.44, 'vdw_radius': 1.52},
    'ZN': {'atomic_num': 30, 'mass': 65.38, 'electronegativity': 1.65, 'vdw_radius': 1.39},
    'MG': {'atomic_num': 12, 'mass': 24.305, 'electronegativity': 1.31, 'vdw_radius': 1.73},
    'FE': {'atomic_num': 26, 'mass': 55.845, 'electronegativity': 1.83, 'vdw_radius': 1.72},
    'MN': {'atomic_num': 25, 'mass': 54.938, 'electronegativity': 1.55, 'vdw_radius': 1.73},
    'CU': {'atomic_num': 29, 'mass': 63.546, 'electronegativity': 1.90, 'vdw_radius': 1.40},
}

def load_csv(csv_path):
    """Load CSV and sample half the data"""
    df = pd.read_csv(csv_path)
    df = df[df['Affinity_pK'] != 0]
    return df

def create_basic_features(node, atom_property_dict):
    """Create basic atomic features (deliberately simple)"""
    atom_type = node['attype']
    prop = atom_property_dict.get(atom_type, 
                                 {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70})
    
    # Only basic features - no complex encoding
    features = [
        prop['atomic_num'],
        prop['mass'],
        prop['electronegativity'],
        prop['vdw_radius']
    ]
    return features

def load_single_graph(pdb_id, base_path, graph_type):
    """Load a single real graph with basic processing"""
    if graph_type == 'P':
        json_path = os.path.join(base_path, pdb_id, f'{pdb_id}_protein_graph.json')
    elif graph_type == 'L':
        json_path = os.path.join(base_path, pdb_id, f'{pdb_id}_ligand_graph.json')
    elif graph_type == 'I':
        json_path = os.path.join(base_path, pdb_id, f'{pdb_id}_interaction_graph.json')
    else:
        return None
    
    try:
        with open(json_path, 'r') as file:
            graph = json.load(file)
    except FileNotFoundError:
        return None

    if not graph['nodes']:
        return None

    # Create basic node features
    node_features = []
    for node in graph['nodes']:
        features = create_basic_features(node, atom_property_dict)
        node_features.append(features)

    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Basic edge processing
    edge_index = []
    for edge in graph['edges']:
        if edge['id1'] is not None and edge['id2'] is not None:
            edge_index.append([edge['id1'], edge['id2']])

    if not edge_index:
        num_nodes = len(node_features)
        edge_index = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
    else:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_index = to_undirected(edge_index)

    return {
        'node_features': node_features,
        'edge_index': edge_index,
        'num_nodes': len(node_features)
    }

def load_combined_graph(pdb_id, base_path, combination):
    """Load and combine real graphs with robust processing"""
    graphs_to_load = []
    
    if 'P' in combination:
        graphs_to_load.append('P')
    if 'L' in combination:
        graphs_to_load.append('L')
    if 'I' in combination:
        graphs_to_load.append('I')
    
    loaded_graphs = []
    for graph_type in graphs_to_load:
        graph = load_single_graph(pdb_id, base_path, graph_type)
        loaded_graphs.append(graph)
    
    # Merge graphs
    all_node_features = []
    all_edge_indices = []
    node_offset = 0
    
    for graph in loaded_graphs:
        if graph is None:
            continue
            
        all_node_features.append(graph['node_features'])
        adjusted_edge_index = graph['edge_index'] + node_offset
        all_edge_indices.append(adjusted_edge_index)
        node_offset += graph['num_nodes']
    
    if not all_node_features:
        return None
    
    merged_node_features = torch.cat(all_node_features, dim=0)
    merged_edge_index = torch.cat(all_edge_indices, dim=1) if all_edge_indices else torch.empty((2, 0), dtype=torch.long)
    
    # Robust normalization to prevent NaN
    if torch.isnan(merged_node_features).any() or torch.isinf(merged_node_features).any():
        return None
    
    mean = merged_node_features.mean(dim=0, keepdim=True)
    std = merged_node_features.std(dim=0, keepdim=True)
    
    # Prevent division by zero
    std = torch.where(std < 1e-8, torch.ones_like(std), std)
    merged_node_features = (merged_node_features - mean) / std
    
    # Clamp to prevent extreme values
    merged_node_features = torch.clamp(merged_node_features, min=-10, max=10)
    
    merged_edge_index, _ = add_self_loops(merged_edge_index, num_nodes=merged_node_features.size(0))
    
    return Data(x=merged_node_features, edge_index=merged_edge_index)

def prepare_real_dataset_combined(df, base_path, combination):
    """Prepare real dataset"""
    data_list = []
    failed_count = 0
    
    for _, row in df.iterrows():
        pdb_id, affinity = row['PDB_ID'], row['Affinity_pK']
        
        if np.isnan(affinity) or np.isinf(affinity):
            failed_count += 1
            continue
            
        data = load_combined_graph(pdb_id, base_path, combination)
        if data is not None:
            data.y = torch.tensor([affinity], dtype=torch.float)
            data.is_synthetic = False
            data_list.append(data)
        else:
            failed_count += 1
    
    print(f"  Real {combination}: {len(data_list)} graphs loaded, {failed_count} failed")
    return data_list

def load_synthetic_graph_simple(pdb_id, synthetic_dir, combination):
    """Load synthetic graph with robust processing"""
    graph_file = os.path.join(synthetic_dir, pdb_id, f'{pdb_id}_{combination}.pkl')
    
    if not os.path.exists(graph_file):
        return None
    
    try:
        with open(graph_file, 'rb') as f:
            graph_data = pickle.load(f)
        
        # Extract basic features only
        node_features = torch.tensor(graph_data['node_features'], dtype=torch.float)
        edge_index = torch.tensor(graph_data['edge_index'], dtype=torch.long)
        
        # Basic validation
        if torch.isnan(node_features).any() or torch.isinf(node_features).any():
            return None
        
        # Simplify features to match real data format
        if node_features.size(1) > 4:
            node_features = node_features[:, :4]  # Take only first 4 features
            # node_features = node_features
        
        # Ensure edge_index format
        if edge_index.size(0) != 2:
            edge_index = edge_index.t()
        
        # Filter invalid edges
        if edge_index.size(1) > 0:
            valid_edges = (edge_index[0] < node_features.size(0)) & (edge_index[1] < node_features.size(0))
            edge_index = edge_index[:, valid_edges]
        
        # Handle empty edges
        if edge_index.size(1) == 0:
            num_nodes = node_features.size(0)
            edge_index = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
        
        # Robust normalization
        if torch.isnan(node_features).any() or torch.isinf(node_features).any():
            return None
        
        mean = node_features.mean(dim=0, keepdim=True)
        std = node_features.std(dim=0, keepdim=True)
        
        # Prevent division by zero
        std = torch.where(std < 1e-8, torch.ones_like(std), std)
        node_features = (node_features - mean) / std
        
        # Clamp to prevent extreme values
        node_features = torch.clamp(node_features, min=-10, max=10)
        
        # Add self-loops
        edge_index, _ = add_self_loops(edge_index, num_nodes=node_features.size(0))
        
        return Data(x=node_features, edge_index=edge_index)
    
    except Exception as e:
        return None

def load_affinity_data(pdb_id, synthetic_dir):
    """Load affinity data for synthetic graph"""
    affinity_file = os.path.join(synthetic_dir, pdb_id, f'{pdb_id}_affinity.pkl')
    
    if not os.path.exists(affinity_file):
        return None
    
    try:
        with open(affinity_file, 'rb') as f:
            affinity_data = pickle.load(f)
        return affinity_data.get('affinity', None)
    except:
        return None

def prepare_synthetic_dataset_half(synthetic_dir, combination):
    """Prepare synthetic dataset with half data sampling"""
    data_list = []
    failed_count = 0
    
    if not os.path.exists(synthetic_dir):
        print(f"Synthetic directory not found: {synthetic_dir}")
        return []
    
    # Get all PDB directories and sample half
    pdb_dirs = [d for d in os.listdir(synthetic_dir) if os.path.isdir(os.path.join(synthetic_dir, d))]
    # pdb_dirs = np.random.choice(pdb_dirs, size=len(pdb_dirs)//2, replace=False).tolist()  # Half data
    
    for pdb_dir in pdb_dirs:
        # Load affinity
        affinity = load_affinity_data(pdb_dir, synthetic_dir)
        if affinity is None or np.isnan(affinity) or np.isinf(affinity):
            failed_count += 1
            continue
        
        # Load graph
        data = load_synthetic_graph_simple(pdb_dir, synthetic_dir, combination)
        if data is not None:
            data.y = torch.tensor([affinity], dtype=torch.float)
            data.is_synthetic = True
            data_list.append(data)
        else:
            failed_count += 1
    
    print(f"  Synthetic {combination}: {len(data_list)} graphs loaded, {failed_count} failed")
    return data_list

def prepare_combined_training_dataset_half(real_train_csv, real_val_csv, real_data_path, 
                                          synthetic_train_dir, synthetic_val_dir, combination):
    """Combine real and synthetic datasets (half data)"""
    print(f"Preparing combined dataset for {combination}...")
    
    # Load real data (already halved in load_csv)
    real_train_df = load_csv(real_train_csv)
    real_val_df = load_csv(real_val_csv)
    
    real_train_data = prepare_real_dataset_combined(real_train_df, real_data_path, combination)
    real_val_data = prepare_real_dataset_combined(real_val_df, real_data_path, combination)
    
    # Load synthetic data (halved)
    synthetic_train_data = prepare_synthetic_dataset_half(synthetic_train_dir, combination)
    synthetic_val_data = prepare_synthetic_dataset_half(synthetic_val_dir, combination)
    
    # Combine
    combined_train_data = real_train_data + synthetic_train_data
    combined_val_data = real_val_data + synthetic_val_data
    
    print(f"  Train: {len(real_train_data)} real + {len(synthetic_train_data)} synthetic = {len(combined_train_data)}")
    print(f"  Val: {len(real_val_data)} real + {len(synthetic_val_data)} synthetic = {len(combined_val_data)}")
    
    return combined_train_data, combined_val_data

class SimpleMPNN(MessagePassing):
    """Simple MPNN layer"""
    def __init__(self, in_channels, out_channels):
        super(SimpleMPNN, self).__init__(aggr='mean')
        self.mlp = nn.Sequential(
            nn.Linear(in_channels * 2, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
        
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_i, x_j):
        return self.mlp(torch.cat([x_i, x_j], dim=1))

class SimpleCombinedBaseline(nn.Module):
    """Simple MPNN baseline for combined data"""
    def __init__(self, input_dim=4, hidden_dim=64, num_layers=2):
        super(SimpleCombinedBaseline, self).__init__()
        
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        
        self.mpnn_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.mpnn_layers.append(SimpleMPNN(hidden_dim, hidden_dim))
        
        # Simple predictor
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        
    def forward(self, x, edge_index, batch):
        x = self.input_proj(x)
        x = F.relu(x)
        
        for mpnn in self.mpnn_layers:
            x = mpnn(x, edge_index)
            x = F.relu(x)
        
        x = global_mean_pool(x, batch)
        x = self.predictor(x)
        
        return x

def train_model_simple(model, train_loader, val_loader, epochs=100, device='cuda'):
    """Train model with robust setup"""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    
    best_val_loss = float('inf')
    best_model_state = None
    
    for epoch in range(epochs):
        # Training
        model.train()
        total_loss = 0
        valid_batches = 0
        
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            
            # Check for NaN in input
            if torch.isnan(batch.x).any() or torch.isnan(batch.y).any():
                continue
                
            pred = model(batch.x, batch.edge_index, batch.batch).squeeze()
            
            # Check for NaN in prediction
            if torch.isnan(pred).any():
                continue
                
            loss = criterion(pred, batch.y)
            
            # Check for NaN in loss
            if torch.isnan(loss):
                continue
            
            loss.backward()
            
            # Clip gradients to prevent explosion
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            total_loss += loss.item()
            valid_batches += 1
        
        # Validation
        model.eval()
        val_loss = 0
        val_count = 0
        
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                
                # Check for NaN in input
                if torch.isnan(batch.x).any() or torch.isnan(batch.y).any():
                    continue
                    
                pred = model(batch.x, batch.edge_index, batch.batch).squeeze()
                
                # Check for NaN in prediction
                if torch.isnan(pred).any():
                    continue
                    
                loss = criterion(pred, batch.y)
                
                # Check for NaN in loss
                if torch.isnan(loss):
                    continue
                    
                val_loss += loss.item()
                val_count += 1
        
        avg_val_loss = val_loss / max(val_count, 1)
        avg_train_loss = total_loss / max(valid_batches, 1)
        
        if avg_val_loss < best_val_loss and not np.isnan(avg_val_loss):
            best_val_loss = avg_val_loss
            best_model_state = model.state_dict().copy()
        
        if epoch % 25 == 0:
            print(f"    Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")
    
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model

def test_model(model, test_loader, device='cuda'):
    """Test model and return metrics"""
    model.eval()
    predictions = []
    targets = []
    
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred = model(batch.x, batch.edge_index, batch.batch).squeeze()
            
            if pred.dim() == 0:
                pred = pred.unsqueeze(0)
            if batch.y.dim() == 0:
                batch.y = batch.y.unsqueeze(0)
            
            predictions.extend(pred.cpu().numpy())
            targets.extend(batch.y.cpu().numpy())
    
    predictions = np.array(predictions)
    targets = np.array(targets)
    
    # Calculate metrics
    if len(predictions) > 1 and predictions.std() > 0.01:
        rp, _ = pearsonr(predictions, targets)
    else:
        rp = 0.0
    
    rmse = np.sqrt(np.mean((predictions - targets) ** 2))
    
    return predictions, targets, rp, rmse

def save_model_and_results(model, results, combination, save_dir="saved_models"):
    """Save trained model and results"""
    
    # Create save directory with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_dir = os.path.join(save_dir, f"GAN_MPNN_baseline_models_{timestamp}")
    os.makedirs(model_dir, exist_ok=True)
    
    # Save model state dict
    model_path = os.path.join(model_dir, f"model_{combination}.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'combination': combination,
        'results': results,
        'model_config': {
            'input_dim': 4,
            'hidden_dim': 64,
            'num_layers': 2
        }
    }, model_path)
    
    print(f"  Model saved: {model_path}")
    return model_path

def load_saved_model(model_path, device='cuda'):
    """Load a saved model"""
    checkpoint = torch.load(model_path, map_location=device)
    
    # Recreate model with saved config
    config = checkpoint['model_config']
    model = SimpleBaseline(
        input_dim=config['input_dim'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers']
    )
    
    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    return model, checkpoint['results']

def main():
    print("🧬 SIMPLE MPNN BASELINE (COMBINED REAL+SYNTHETIC)")
    print("="*60)
    
    # File paths - UPDATE THESE
    real_train_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\training_set_with_affinity.csv'
    real_val_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\validation_set_with_affinity.csv'
    real_data_path = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\dataset'
    
    synthetic_train_dir = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\complete_graphs_20250709_163209\\training_synthetic'
    synthetic_val_dir = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\complete_graphs_20250709_163209\\validation_synthetic'
    
    # Test data paths
    core_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\core_set_with_affinity.csv'
    holdout_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\hold_out_set_with_affinity.csv'
    
    # Load test datasets (half data)
    print("Loading test datasets...")
    core_df = load_csv(core_csv)
    holdout_df = load_csv(holdout_csv)
    
    print(f"Core: {len(core_df)}, Holdout: {len(holdout_df)}")
    
    combinations = ['P', 'L', 'I', 'PL', 'PI', 'LI', 'PLI']
    
    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Results storage
    results = {}
    
    for combination in combinations:
        print(f"\n{'='*20} COMBINED {combination} {'='*20}")
        
        # Prepare combined training datasets
        combined_train_data, combined_val_data = prepare_combined_training_dataset_half(
            real_train_csv, real_val_csv, real_data_path,
            synthetic_train_dir, synthetic_val_dir, combination
        )
        
        if len(combined_train_data) == 0 or len(combined_val_data) == 0:
            print(f"  Insufficient data for {combination}, skipping...")
            continue
        
        # Prepare test datasets
        print("Preparing test datasets...")
        core_data = prepare_real_dataset_combined(core_df, real_data_path, combination)
        holdout_data = prepare_real_dataset_combined(holdout_df, real_data_path, combination)
        
        # Create data loaders
        train_loader = DataLoader(combined_train_data, batch_size=64, shuffle=True)
        val_loader = DataLoader(combined_val_data, batch_size=64)
        core_loader = DataLoader(core_data, batch_size=64) if core_data else None
        holdout_loader = DataLoader(holdout_data, batch_size=64) if holdout_data else None
        
        # Train model
        print("Training model...")
        input_dim = combined_train_data[0].x.size(1)
        model = SimpleCombinedBaseline(input_dim=input_dim, hidden_dim=64, num_layers=2)
        trained_model = train_model_simple(model, train_loader, val_loader, epochs=100, device=device)
        
        # Test on core set
        if core_loader:
            print("Testing on 2016 core set...")
            core_preds, core_targets, core_rp, core_rmse = test_model(trained_model, core_loader, device)
            print(f"  Core Set - Rp: {core_rp:.3f}, RMSE: {core_rmse:.3f}")
        else:
            core_rp, core_rmse = 0, 0
        
        # Test on holdout set
        if holdout_loader:
            print("Testing on 2019 hold-out set...")
            holdout_preds, holdout_targets, holdout_rp, holdout_rmse = test_model(trained_model, holdout_loader, device)
            print(f"  Holdout Set - Rp: {holdout_rp:.3f}, RMSE: {holdout_rmse:.3f}")
        else:
            holdout_rp, holdout_rmse = 0, 0
        
        # Store results
        combination_results = {
            'core_rp': core_rp,
            'core_rmse': core_rmse,
            'holdout_rp': holdout_rp,
            'holdout_rmse': holdout_rmse
        }
        results[combination] = combination_results
        
        # Save model and results
        save_model_and_results(trained_model, combination_results, combination)        
    
    # Print results
    print(f"\n{'='*70}")
    print("COMBINED BASELINE RESULTS (Real + Synthetic)")
    print(f"{'='*70}")
    print(f"{'Model':<6} {'2016 core set':<25} {'2019 hold-out set':<25}")
    print(f"{'':6} {'Rp':<12} {'RMSE':<12} {'Rp':<12} {'RMSE':<12}")
    print("-" * 70)
    
    for combination in combinations:
        if combination in results:
            r = results[combination]
            print(f"{combination:<6} {r['core_rp']:<12.3f} {r['core_rmse']:<12.3f} {r['holdout_rp']:<12.3f} {r['holdout_rmse']:<12.3f}")
    
    print(f"{'='*70}")
    print("Note: Simple MPNN baseline with combined real+synthetic data")

if __name__ == "__main__":
    main()

🧬 SIMPLE MPNN BASELINE (COMBINED REAL+SYNTHETIC)
Loading test datasets...
Core: 257, Holdout: 3393
Using device: cuda

Preparing combined dataset for P...
  Real P: 9312 graphs loaded, 350 failed
  Real P: 871 graphs loaded, 32 failed
  Synthetic P: 9287 graphs loaded, 25 failed
  Synthetic P: 871 graphs loaded, 0 failed
  Train: 9312 real + 9287 synthetic = 18599
  Val: 871 real + 871 synthetic = 1742
Preparing test datasets...
  Real P: 249 graphs loaded, 8 failed
  Real P: 3232 graphs loaded, 161 failed
Training model...
    Epoch 0: Train Loss=4.7890, Val Loss=3.7244
    Epoch 25: Train Loss=3.1310, Val Loss=3.5064
    Epoch 50: Train Loss=3.0816, Val Loss=3.5660
    Epoch 75: Train Loss=3.0931, Val Loss=3.4774
Testing on 2016 core set...
  Core Set - Rp: 0.232, RMSE: 2.012
Testing on 2019 hold-out set...
  Holdout Set - Rp: 0.261, RMSE: 1.711
  Model saved: saved_models\GAN_MPNN_baseline_models_20250709_185820\model_P.pth

Preparing combined dataset for L...
  Real L: 9312 graphs 