In [1]:
import pandas as pd
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.utils import add_self_loops, to_undirected
import os
import numpy as np
from scipy.stats import pearsonr
import pickle
import warnings
warnings.filterwarnings('ignore')

# Atom property lookup used by create_enhanced_features, keyed by a node's
# 'attype' string. Each row lists (atomic_num, mass, electronegativity,
# vdw_radius); the numbers look like standard atomic masses (u), Pauling
# electronegativities and van der Waals radii (Angstrom) -- confirm the source.
# NOTE(review): 'CA', 'CZ' and 'OG' appear to be PDB atom names rather than
# element symbols; 'CA' is mapped to carbon here, so a calcium ion labelled
# 'CA' would receive carbon properties -- verify against the graph generator.
atom_property_dict = {
    symbol: dict(zip(('atomic_num', 'mass', 'electronegativity', 'vdw_radius'), row))
    for symbol, row in (
        ('H', (1, 1.008, 2.20, 1.20)),
        ('C', (6, 12.011, 2.55, 1.70)),
        ('N', (7, 14.007, 3.04, 1.55)),
        ('O', (8, 15.999, 3.44, 1.52)),
        ('P', (15, 30.974, 2.19, 1.80)),
        ('S', (16, 32.065, 2.58, 1.80)),
        ('F', (9, 18.998, 3.98, 1.47)),
        ('Cl', (17, 35.453, 3.16, 1.75)),
        ('Br', (35, 79.904, 2.96, 1.85)),
        ('I', (53, 126.904, 2.66, 1.98)),
        ('CA', (6, 12.011, 2.55, 1.70)),
        ('CZ', (6, 12.011, 2.55, 1.70)),
        ('OG', (8, 15.999, 3.44, 1.52)),
        ('ZN', (30, 65.38, 1.65, 1.39)),
        ('MG', (12, 24.305, 1.31, 1.73)),
        ('FE', (26, 55.845, 1.83, 1.72)),
        ('MN', (25, 54.938, 1.55, 1.73)),
        ('CU', (29, 63.546, 1.90, 1.40)),
    )
}

def load_csv(csv_path, max_samples=None, use_half=False):
    """Read an affinity table and drop rows whose ``Affinity_pK`` is exactly 0.

    Args:
        csv_path: path of a CSV with (at least) an ``Affinity_pK`` column.
        max_samples: keep only the first ``max_samples`` rows; ignored when it
            is falsy (None or 0) or when ``use_half`` is set.
        use_half: keep only the first half of the filtered rows (at least one)
            and print how many were kept.

    Returns:
        The filtered, possibly truncated DataFrame, or an empty column-less
        DataFrame (after printing a warning) when no row survives the filter.
    """
    frame = pd.read_csv(csv_path)
    frame = frame.loc[frame['Affinity_pK'] != 0]

    if frame.empty:
        print(f"Warning: No valid data found in {csv_path}")
        return pd.DataFrame()

    if not use_half:
        return frame.head(max_samples) if max_samples else frame

    # At least one row survives here, so the floor of 1 only matters for len == 1.
    half_size = max(len(frame) // 2, 1)
    print(f"Using half dataset: {half_size} samples from {csv_path}")
    return frame.head(half_size)

# OPTIMIZED: Simplified normalization
def fast_normalize(features):
    """Standardize each column of ``features`` and clip the result to [-3, 3].

    Uses the population (biased) standard deviation, floored at 1e-6 so that
    constant columns map to 0 instead of dividing by zero. Inputs with fewer
    than two rows carry no spread information and become all zeros.
    """
    if features.size(0) < 2:
        return torch.zeros_like(features)

    centred = features - features.mean(dim=0, keepdim=True)
    spread = features.std(dim=0, keepdim=True, unbiased=False).clamp(min=1e-6)
    return (centred / spread).clamp(-3, 3)

def create_enhanced_features(node, atom_property_dict, graph_type='P'):
    """Return the 17-element feature list for one graph node.

    Atom properties are looked up by ``node['attype']``; unknown types fall
    back to carbon's properties. The list holds 8 scaled/derived properties,
    5 element / heavy-atom indicators, a high-electronegativity flag and three
    membership flags (protein, ligand, interaction graph).

    Args:
        node: dict with an ``'attype'`` key and optionally a ``'pl'`` label;
            when present, ``'pl'`` decides the protein/ligand flags.
        atom_property_dict: mapping of atom type to property dict.
        graph_type: 'P', 'L' or 'I'; used for the protein/ligand flags when the
            node has no ``'pl'`` label, and always for the interaction flag.
    """
    carbon_defaults = {'atomic_num': 6, 'mass': 12.011, 'electronegativity': 2.55, 'vdw_radius': 1.70}
    props = atom_property_dict.get(node['attype'], carbon_defaults)
    z = props['atomic_num']
    mass = props['mass']
    chi = props['electronegativity']
    radius = props['vdw_radius']

    # A node-level 'pl' label overrides the graph-level type for P/L membership.
    membership = node['pl'] if 'pl' in node else graph_type

    return [
        z / 30.0,
        mass / 100.0,
        chi / 4.0,
        radius / 2.0,
        z ** 0.5 / 5.5,
        mass / z,
        1.0 / chi,
        radius ** 2,
        float(z == 6),
        float(z == 7),
        float(z == 8),
        float(z == 16),
        float(z > 10),
        float(chi > 3.0),
        float(membership == 'P'),
        float(membership == 'L'),
        float(graph_type == 'I'),
    ]

def _edge_type_code(type_a, type_b):
    """Classify an edge by the P/L labels of its two endpoints.

    Returns 0 when both ends are 'P' (intra-protein), 1 when both are 'L'
    (intra-ligand) and 2 for every other pair.
    """
    if type_a == 'P' and type_b == 'P':
        return 0
    if type_a == 'L' and type_b == 'L':
        return 1
    return 2


def load_single_graph(pdb_id, base_path, graph_type):
    """Load one protein ('P'), ligand ('L') or interaction ('I') graph from JSON.

    Reads ``<base_path>/<pdb_id>/<pdb_id>_{protein|ligand|interaction}_graph.json``.
    Every node becomes the feature list from ``create_enhanced_features``.
    Every edge becomes ``[length/10, 1/length, exp(-length/2)]``, with length
    floored at 0.1. Edges are made undirected, and each feature row and edge
    type stays aligned with its column of ``edge_index``.

    Args:
        pdb_id: complex identifier; names both the sub-directory and the file.
        base_path: directory containing one sub-directory per ``pdb_id``.
        graph_type: 'P', 'L' or 'I'.

    Returns:
        A dict with:
          - ``node_features``: (N, F) tensor.
          - ``edge_index``: (2, E) tensor.
          - ``edge_features``: (E, 3) tensor.
          - ``edge_types``: (E,) tensor, 0/1/2 as in ``_edge_type_code``.
          - ``num_nodes``.
          - ``graph_type``.
          - ``node_types``: each node's 'pl' label, or ``graph_type`` when absent.

        Returns None instead when:
          - the graph type is unknown,
          - the file is missing,
          - the graph has no nodes, or
          - any node feature is NaN or inf.
    """
    file_stem = {'P': 'protein', 'L': 'ligand', 'I': 'interaction'}.get(graph_type)
    if file_stem is None:
        return None
    json_path = os.path.join(base_path, pdb_id, f'{pdb_id}_{file_stem}_graph.json')

    try:
        with open(json_path, 'r', encoding='utf-8') as file:
            graph = json.load(file)
    except FileNotFoundError:
        return None

    nodes = graph.get('nodes')
    if not nodes:
        return None

    node_features = []
    node_types = []
    for node in nodes:
        node_features.append(create_enhanced_features(node, atom_property_dict, graph_type))
        # Prefer the node's own P/L label; fall back to the graph's type.
        node_types.append(node['pl'] if 'pl' in node else graph_type)

    node_features = torch.tensor(node_features, dtype=torch.float)
    if torch.isnan(node_features).any() or torch.isinf(node_features).any():
        return None

    num_nodes = len(node_types)
    edge_pairs = []
    edge_rows = []
    for edge in graph.get('edges', []):
        id1, id2 = edge['id1'], edge['id2']
        if id1 is None or id2 is None:
            continue
        # Endpoints that are not nodes of this graph would become out-of-range
        # indices in edge_index and crash message passing later; drop them.
        if not (0 <= id1 < num_nodes and 0 <= id2 < num_nodes):
            continue
        length = max(edge['length'], 0.1)  # floor keeps 1/length bounded
        edge_pairs.append([id1, id2])
        edge_rows.append([length / 10.0, 1.0 / length, np.exp(-length / 2.0)])

    if edge_pairs:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        edge_features = torch.tensor(edge_rows, dtype=torch.float)
        # Mirror and coalesce the features together with the indices.
        # to_undirected(edge_index) alone sorts and deduplicates the edges,
        # which would leave the feature rows in their old, now-misaligned
        # order. Duplicate edges get their features averaged.
        edge_index, edge_features = to_undirected(
            edge_index, edge_attr=edge_features, num_nodes=num_nodes, reduce='mean')
        edge_types = [_edge_type_code(node_types[src], node_types[dst])
                      for src, dst in edge_index.t().tolist()]
    else:
        # No usable edges: one self-loop per node with constant features,
        # typed by the node's own label.
        edge_index = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
        edge_features = torch.ones(num_nodes, 3) * 0.5
        edge_types = [_edge_type_code(label, label) for label in node_types]

    return {
        'node_features': node_features,
        'edge_index': edge_index,
        'edge_features': edge_features,
        'edge_types': torch.tensor(edge_types, dtype=torch.long),
        'num_nodes': len(node_features),
        'graph_type': graph_type,
        'node_types': node_types
    }

def merge_graphs(graphs):
    """Concatenate component graph dicts into one disjoint-union graph.

    ``None`` entries are skipped. Each graph's ``edge_index`` is shifted by the
    number of nodes in the graphs placed before it.

    Returns:
        A 5-tuple ``(node_features, edge_index, edge_features, edge_types,
        graph_type_markers)``, where the markers list holds one
        ``graph['graph_type']`` entry per node. Returns None if no graph is
        present.
    """
    present = [graph for graph in graphs if graph is not None]
    if not present:
        return None

    offsets = []
    running_total = 0
    for graph in present:
        offsets.append(running_total)
        running_total += graph['num_nodes']

    merged_node_features = torch.cat([graph['node_features'] for graph in present], dim=0)
    merged_edge_index = torch.cat(
        [graph['edge_index'] + offset for graph, offset in zip(present, offsets)], dim=1)
    merged_edge_features = torch.cat([graph['edge_features'] for graph in present], dim=0)
    merged_edge_types = torch.cat([graph['edge_types'] for graph in present], dim=0)
    graph_type_markers = [graph['graph_type'] for graph in present for _ in range(graph['num_nodes'])]

    return merged_node_features, merged_edge_index, merged_edge_features, merged_edge_types, graph_type_markers

# OPTIMIZED: Precompute merged graphs during data loading
def precompute_combined_graph(pdb_id, base_path, combination):
    """Build one normalized PyG ``Data`` object for ``pdb_id`` and ``combination``.

    Steps:
      1. Load the graphs named by the letters 'P', 'L' and 'I' found in
         ``combination``, always in P, L, I order.
      2. Merge them into one disjoint union.
      3. Standardize and clip node and edge features with ``fast_normalize``.
      4. Append one self-loop per node.

    Edge type codes: 0 intra-protein, 1 intra-ligand, 2 other. A self-loop
    takes the code of its node's own P/L label (the node's 'pl' field when the
    JSON has one). ``prepare_synthetic_dataset`` uses the same rule for its
    self-loops, and ``load_single_graph`` uses the same per-node labels to type
    its edges.

    Returns None when none of the requested graphs could be loaded.
    """
    component_graphs = [load_single_graph(pdb_id, base_path, graph_type)
                        for graph_type in 'PLI' if graph_type in combination]

    merged_result = merge_graphs(component_graphs)
    if merged_result is None:
        return None
    node_features, edge_index, edge_features, edge_types, graph_type_markers = merged_result

    node_features = fast_normalize(node_features)
    num_nodes = node_features.size(0)
    edge_index, edge_attr = add_self_loops(edge_index, edge_features, num_nodes=num_nodes)

    # Per-node P/L labels, in the same order merge_graphs concatenated the nodes
    # (present graphs only, in load order). Interaction-graph nodes get their own
    # 'pl' label rather than the graph-level 'I' marker. A protein atom's
    # self-loop is therefore intra-protein, not inter.
    node_labels = [label for graph in component_graphs if graph is not None
                   for label in graph['node_types']]
    self_loop_types = torch.tensor(
        [0 if label == 'P' else 1 if label == 'L' else 2 for label in node_labels],
        dtype=torch.long,
    )
    edge_types = torch.cat([edge_types, self_loop_types], dim=0)

    if edge_attr.size(0) > 0:
        edge_attr = fast_normalize(edge_attr)

    # OPTIMIZATION: Cache edge type statistics (length-3 count vector per graph).
    edge_type_counts = torch.bincount(edge_types, minlength=3)

    data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
    data.edge_types = edge_types
    data.edge_type_counts = edge_type_counts  # Cached for routing
    data.graph_type_markers = graph_type_markers
    return data

def prepare_real_dataset_combined(df, base_path, combination):
    """Turn each CSV row into a labelled graph for ``combination``.

    Rows with a non-finite ``Affinity_pK``, and rows whose graph cannot be
    built by ``precompute_combined_graph``, are counted as failures and
    skipped. Each kept graph gets ``y`` (shape (1,)) and
    ``is_synthetic = False``. A summary line is printed.

    Returns:
        The list of ``Data`` objects, or [] when ``df`` is None or empty.
    """
    if df is None or len(df) == 0:
        print(f"  {combination}: No real data available")
        return []

    graphs = []
    n_failed = 0
    for _, row in df.iterrows():
        pdb_id = row['PDB_ID']
        affinity = row['Affinity_pK']

        graph = None
        if np.isfinite(affinity):
            graph = precompute_combined_graph(pdb_id, base_path, combination)
        if graph is None:
            n_failed += 1
            continue

        graph.y = torch.tensor([affinity], dtype=torch.float)
        graph.is_synthetic = False
        graphs.append(graph)

    print(f"  {combination}: {len(graphs)} real graphs loaded, {n_failed} failed")
    return graphs

def _load_synthetic_graph(graph_file, affinity_file):
    """Build one synthetic ``Data`` object from a graph/affinity pickle pair.

    Returns None when the sample is unusable; for structural problems the
    reason is printed first. Exceptions from unpickling or tensor conversion
    propagate to the caller.
    """
    # NOTE(review): pickle.load can execute arbitrary code. This is only safe for
    # locally generated, trusted files -- never point it at external data.
    with open(graph_file, 'rb') as f:
        graph_data = pickle.load(f)
    with open(affinity_file, 'rb') as f:
        affinity_data = pickle.load(f)

    affinity = affinity_data.get('affinity', None)
    if affinity is None or np.isnan(affinity) or np.isinf(affinity):
        return None

    if 'node_types' not in graph_data:
        print(f"❌ Missing node_types in {graph_file}")
        return None
    node_types = graph_data['node_types']
    if not node_types:
        print(f"❌ Empty node_types in {graph_file}")
        return None

    node_features = torch.tensor(graph_data['node_features'], dtype=torch.float)
    edge_index = torch.tensor(graph_data['edge_index'], dtype=torch.long)
    edge_attr = torch.tensor(graph_data['edge_features'], dtype=torch.float)

    num_nodes = node_features.size(0)
    if len(node_types) != num_nodes:
        print(f"❌ Node types mismatch in {graph_file}: {len(node_types)} types, {num_nodes} nodes")
        return None

    # Accept both (2, E) and (E, 2) layouts.
    if edge_index.size(0) != 2:
        edge_index = edge_index.t()
    if edge_index.dim() != 2 or edge_index.size(0) != 2:
        print(f"❌ Bad edge_index shape {tuple(edge_index.shape)} in {graph_file}")
        return None

    num_edges = edge_index.size(1)
    if edge_attr.size(0) != num_edges:
        print(f"❌ Edge features mismatch in {graph_file}: {edge_attr.size(0)} rows, {num_edges} edges")
        return None
    # Negative indices would silently wrap around; indices >= num_nodes crash later.
    if num_edges and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        print(f"❌ Edge index out of range in {graph_file}")
        return None

    # Per-node code: 0 = 'P', 1 = 'L', 2 = anything else. This is also the
    # type of the node's self-loop.
    node_codes = torch.tensor([0 if t == 'P' else 1 if t == 'L' else 2 for t in node_types],
                              dtype=torch.long)
    src_codes = node_codes[edge_index[0]]
    dst_codes = node_codes[edge_index[1]]
    # Intra-protein (0) / intra-ligand (1) only when both ends share that label;
    # every other pair is inter (2).
    is_intra = (src_codes == dst_codes) & (src_codes < 2)
    edge_types = torch.where(is_intra, src_codes, torch.full_like(src_codes, 2))

    node_features = fast_normalize(node_features)
    edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=num_nodes)
    edge_types = torch.cat([edge_types, node_codes], dim=0)

    if edge_attr.size(0) > 0:
        edge_attr = fast_normalize(edge_attr)

    edge_type_counts = torch.bincount(edge_types, minlength=3)

    data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
    data.edge_types = edge_types
    data.edge_type_counts = edge_type_counts
    data.graph_type_markers = node_types
    data.y = torch.tensor([affinity], dtype=torch.float)
    data.is_synthetic = True
    return data


def prepare_synthetic_dataset(synthetic_dir, combination):
    """Load every synthetic sample for ``combination`` from ``synthetic_dir``.

    Each sub-directory ``<id>`` should contain ``<id>_<combination>.pkl`` (graph)
    and ``<id>_affinity.pkl``. A sample is counted as failed when:
      - either file is missing,
      - it fails validation in ``_load_synthetic_graph``, or
      - loading raises an exception (the exception is printed).

    Sub-directories are processed in sorted order, so the resulting dataset
    order is reproducible. Prints a summary line.

    Returns:
        The list of ``Data`` objects (``is_synthetic = True``); [] when
        ``synthetic_dir`` does not exist.
    """
    data_list = []
    failed_count = 0

    if not os.path.exists(synthetic_dir):
        print(f"Synthetic directory not found: {synthetic_dir}")
        return []

    pdb_dirs = sorted(d for d in os.listdir(synthetic_dir)
                      if os.path.isdir(os.path.join(synthetic_dir, d)))

    for pdb_dir in pdb_dirs:
        graph_file = os.path.join(synthetic_dir, pdb_dir, f'{pdb_dir}_{combination}.pkl')
        affinity_file = os.path.join(synthetic_dir, pdb_dir, f'{pdb_dir}_affinity.pkl')

        if not (os.path.exists(graph_file) and os.path.exists(affinity_file)):
            failed_count += 1
            continue

        try:
            data = _load_synthetic_graph(graph_file, affinity_file)
        except Exception as e:
            # Best-effort loading: report the bad sample and keep going.
            print(f"❌ Error loading {graph_file}: {e}")
            failed_count += 1
            continue

        if data is None:
            failed_count += 1
            continue
        data_list.append(data)

    print(f"  {combination}: {len(data_list)} synthetic graphs loaded, {failed_count} failed")
    return data_list

def prepare_combined_training_dataset(real_train_csv, real_val_csv, real_data_path,
                                      synthetic_train_dir, synthetic_val_dir, combination, use_half=False):
    """Assemble train/validation graph lists that mix real and synthetic samples.

    Real samples come from the train/val CSVs via ``load_csv`` and
    ``prepare_real_dataset_combined``. ``use_half`` is forwarded to
    ``load_csv`` only, so it truncates the real splits; the synthetic
    directories are always loaded in full by ``prepare_synthetic_dataset``.

    Returns:
        ``(train_list, val_list)``; each list holds the real graphs followed by
        the synthetic graphs.
    """
    print(f"Preparing {'HALF' if use_half else 'FULL'} combined training dataset for {combination}...")

    # Same call order as before: both CSVs, then both real sets, then both synthetic sets.
    train_frame, val_frame = [load_csv(path, use_half=use_half) for path in (real_train_csv, real_val_csv)]
    real_train, real_val = [prepare_real_dataset_combined(frame, real_data_path, combination)
                            for frame in (train_frame, val_frame)]
    synth_train, synth_val = [prepare_synthetic_dataset(folder, combination)
                              for folder in (synthetic_train_dir, synthetic_val_dir)]

    train_data = real_train + synth_train
    val_data = real_val + synth_val

    print(f"  Train: {len(real_train)} real + {len(synth_train)} synthetic = {len(train_data)}")
    print(f"  Val: {len(real_val)} real + {len(synth_val)} synthetic = {len(val_data)}")

    return train_data, val_data

# OPTIMIZED: Reduced routing iterations and cached statistics
class OptimizedEdgeTypeAwareCapsuleLayer(nn.Module):
    """Dynamic-routing layer with three capsules per graph.

    Capsules 0, 1 and 2 correspond to the intra-protein, intra-ligand and inter
    edge-type codes. Each node is projected into every capsule space by its own
    bias-free linear map.

    Routing logits start from a per-graph prior derived from that graph's
    edge-type counts. They are then refined by agreement for
    ``num_iterations`` rounds; every round except the last updates the logits.
    """

    def __init__(self, input_dim, capsule_dim=32, num_iterations=2):  # Reduced from 3 to 2
        super(OptimizedEdgeTypeAwareCapsuleLayer, self).__init__()
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        self.input_dim = input_dim
        self.capsule_dim = capsule_dim
        self.num_iterations = num_iterations

        self.W_intra_protein = nn.Linear(input_dim, capsule_dim, bias=False)
        self.W_intra_ligand = nn.Linear(input_dim, capsule_dim, bias=False)
        self.W_inter_connection = nn.Linear(input_dim, capsule_dim, bias=False)

        # Routing coefficients (num_nodes, 3) from the last forward call, on CPU.
        self.routing_coefficients = None

    def squash(self, s):
        """Squash along the last dim: direction kept, length ~ |s|^2 / (1 + |s|^2)."""
        s_norm = torch.norm(s, dim=-1, keepdim=True)
        scale = (s_norm**2 / (1 + s_norm**2))
        return scale * s / (s_norm + 1e-8)

    @staticmethod
    def _per_graph_counts(edge_type_counts, batch_size, device):
        """Reshape edge-type counts into one (intra_protein, intra_ligand, inter) row per graph.

        A PyG batch concatenates each graph's length-3 ``edge_type_counts`` into
        one flat tensor of length ``3 * num_graphs``. A single row is broadcast
        to every graph.
        """
        counts = edge_type_counts.to(device).reshape(-1, 3)
        if counts.size(0) == 1:
            return counts.expand(batch_size, 3)
        if counts.size(0) != batch_size:
            raise ValueError(
                f"edge_type_counts has {counts.size(0)} rows of 3 but the batch has {batch_size} graphs")
        return counts

    @staticmethod
    def _edge_type_bias(counts, inter_bonus, paired_intra_bonus, solo_intra_bonus, solo_threshold):
        """Return a (num_graphs, 3) routing bias computed from per-graph edge-type counts.

        Graphs that contain inter edges get:
          - ``inter_bonus`` on the inter capsule;
          - ``paired_intra_bonus`` on each intra capsule whose edge type is present.

        Graphs without inter edges get ``solo_intra_bonus`` on each intra
        capsule whose count exceeds ``solo_threshold``.
        """
        has_inter = counts[:, 2] > 0
        intra = counts[:, :2]
        intra_bias = torch.where(
            has_inter.unsqueeze(1),
            (intra > 0).float() * paired_intra_bonus,
            (intra > solo_threshold).float() * solo_intra_bonus,
        )
        inter_bias = has_inter.float().unsqueeze(1) * inter_bonus
        return torch.cat([intra_bias, inter_bias], dim=1)

    def forward(self, x, edge_index, edge_types, batch, edge_type_counts=None):
        """Route node embeddings into three capsules per graph.

        Args:
            x: (num_nodes, input_dim) node embeddings.
            edge_index, edge_types: accepted for interface compatibility; unused.
            batch: (num_nodes,) graph id of each node.
            edge_type_counts: per-graph counts, either flat ``(3 * num_graphs,)``
                as produced by PyG batching or ``(num_graphs, 3)``. A single row
                is applied to every graph. None means a uniform prior.

        Returns:
            ``(s, c)``:
              - ``s``: squashed capsules of shape (num_graphs, 3, capsule_dim).
              - ``c``: last-iteration routing coefficients (num_nodes, 3), on CPU.
        """
        batch_size = batch.max().item() + 1
        device = x.device

        # u[n, k] is node n projected into capsule k's space: (num_nodes, 3, capsule_dim).
        u = torch.stack(
            [self.W_intra_protein(x), self.W_intra_ligand(x), self.W_inter_connection(x)], dim=1)

        if edge_type_counts is not None:
            counts = self._per_graph_counts(edge_type_counts, batch_size, device)
            # Every node uses its own graph's prior and agreement bonus.
            b = self._edge_type_bias(counts, 4.0, 2.0, 2.5, 5)[batch]
            agreement_bonus = self._edge_type_bias(counts, 2.5, 0.8, 1.2, 3)[batch]
        else:
            b = torch.zeros(x.size(0), 3, device=device)
            agreement_bonus = None

        for iteration in range(self.num_iterations):
            c = F.softmax(b, dim=-1)
            # Per graph, sum the coefficient-weighted node predictions for each capsule.
            weighted = c.unsqueeze(-1) * u
            s = torch.zeros(batch_size, 3, self.capsule_dim, device=device, dtype=weighted.dtype)
            s = self.squash(s.index_add(0, batch, weighted))

            if iteration < self.num_iterations - 1:
                agreement = (u * s[batch]).sum(dim=-1)
                if agreement_bonus is not None:
                    agreement = agreement + agreement_bonus
                b = b + agreement

        self.routing_coefficients = c.detach().cpu()
        return s, self.routing_coefficients

class OptimizedEdgeAwareCapsuleGNN(nn.Module):
    """GCN encoder, then the three-capsule routing layer, then an MLP regressor.

    Args:
        input_dim: node feature width (17 matches ``create_enhanced_features``).
        hidden_dim: width of the input projection and GCN layers.
        num_layers: number of GCNConv layers.
        capsule_dim: size of each of the three capsules. The predictor's input
            width is derived from it (``3 * capsule_dim``), so the two cannot
            drift apart.
    """

    def __init__(self, input_dim=17, hidden_dim=64, num_layers=2, capsule_dim=32):
        super(OptimizedEdgeAwareCapsuleGNN, self).__init__()

        self.hidden_dim = hidden_dim
        self.capsule_dim = capsule_dim

        # Registration order is unchanged, so parameter initialization order is the same as before.
        self.convs = nn.ModuleList()
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        for _ in range(num_layers):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))

        self.capsule_layer = OptimizedEdgeTypeAwareCapsuleLayer(hidden_dim, capsule_dim=capsule_dim)

        self.predictor = nn.Sequential(
            nn.Linear(3 * capsule_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.dropout = nn.Dropout(0.1)
        self.apply(self._init_weights)

        # Routing coefficients from the most recent forward pass (see get_routing_analysis).
        self.last_routing_coefficients = None

    def _init_weights(self, module):
        """Kaiming-normal init (fan_out, ReLU) for nn.Linear weights; zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x, edge_index, edge_attr, batch, edge_types=None, edge_type_counts=None):
        """Predict one value per graph.

        ``edge_attr`` is accepted but not used: the GCN layers see only
        ``edge_index``. ``edge_types`` (all zeros when None) and
        ``edge_type_counts`` are forwarded to the capsule layer.

        Returns:
            A tensor of shape (num_graphs, 1).
        """
        x = self.input_proj(x)
        x = F.relu(x)

        for i, conv in enumerate(self.convs):
            residual = x
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)

            # Residual from the second layer on (the first layer's input is the projection).
            if i > 0:
                x = x + residual

        if edge_types is None:
            edge_types = torch.zeros(edge_index.size(1), dtype=torch.long, device=edge_index.device)

        capsule_outputs, routing_coeffs = self.capsule_layer(x, edge_index, edge_types, batch, edge_type_counts)
        self.last_routing_coefficients = routing_coeffs

        batch_size = capsule_outputs.size(0)
        flattened = capsule_outputs.view(batch_size, -1)
        output = self.predictor(flattened)

        return output

    def get_routing_analysis(self, combination):
        """Return the mean routing weight per capsule from the last forward pass, or None."""
        if self.last_routing_coefficients is None:
            return None

        routing = self.last_routing_coefficients

        analysis = {
            'intra_protein_attention': routing[:, 0].mean().item(),
            'intra_ligand_attention': routing[:, 1].mean().item(),
            'inter_connection_attention': routing[:, 2].mean().item(),
            'combination': combination,
            'has_interaction': 'I' in combination
        }

        return analysis

# OPTIMIZED: Training with gradient accumulation and smaller batch size
def train_optimized_model(model, train_loader, val_loader, combination, epochs=150, device='cuda'):
    """Train ``model`` with AdamW and MSE, keeping the weights with the lowest validation loss.

    Training details:
      - Gradients are accumulated over ``accumulation_steps`` batches and
        clipped to norm 1.0 before each optimizer step.
      - The learning rate follows a warm-restart cosine schedule, stepped once
        per epoch.
      - Training stops early once the validation loss has not improved for
        ``patience`` epochs (checked from epoch 8 on).
      - Routing statistics from ``model.get_routing_analysis`` are sampled
        every 5 epochs.

    Side effect: writes ``optimized_edge_aware_capsule_<combination>_model.pth``
    to the current working directory.

    Returns:
        ``(model, routing_stats)``: the model with the best-validation weights
        loaded, and the list of sampled routing dicts.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-6)

    best_val_loss = float('inf')
    best_model_state = None
    patience = 8
    patience_counter = 0

    routing_stats = []
    accumulation_steps = 2  # Gradient accumulation

    # Guard the per-epoch averages against empty loaders.
    num_train_batches = max(len(train_loader), 1)
    num_val_batches = max(len(val_loader), 1)

    def _predict(batch):
        # Flatten to 1-D so a batch holding a single graph still lines up with batch.y.
        edge_types = getattr(batch, 'edge_types', None)
        edge_type_counts = getattr(batch, 'edge_type_counts', None)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch,
                     edge_types, edge_type_counts)
        return pred.reshape(-1)

    def _optimizer_step():
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()
        pending_grads = False

        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            loss = criterion(_predict(batch), batch.y.reshape(-1))

            loss = loss / accumulation_steps  # Scale loss for accumulation
            loss.backward()
            pending_grads = True

            if (batch_idx + 1) % accumulation_steps == 0:
                _optimizer_step()
                pending_grads = False

            total_loss += loss.item() * accumulation_steps

        # Flush a trailing partial accumulation window only. Stepping with no
        # pending gradients would still move the weights (AdamW momentum and
        # weight decay).
        if pending_grads:
            _optimizer_step()

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                val_loss += criterion(_predict(batch), batch.y.reshape(-1)).item()

        avg_val_loss = val_loss / num_val_batches
        scheduler.step()

        if epoch % 5 == 0:  # Reduced frequency
            routing_analysis = model.get_routing_analysis(combination)
            if routing_analysis:
                routing_analysis['epoch'] = epoch
                routing_stats.append(routing_analysis)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() tensors share storage with the live parameters, so a
            # shallow dict copy would be overwritten by later optimizer steps.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch >= 8 and patience_counter >= patience:
            print(f"    Early stopping at epoch {epoch}")
            break

        if epoch % 5 == 0:  # Reduced logging frequency
            print(f"    Epoch {epoch}: Train Loss={total_loss/num_train_batches:.4f}, Val Loss={avg_val_loss:.4f}")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Save the trained model
    model_save_path = f"optimized_edge_aware_capsule_{combination}_model.pth"
    torch.save({
        'model_state_dict': model.state_dict(),
        'combination': combination,
        'input_dim': model.input_proj.in_features,
        'hidden_dim': model.hidden_dim,
        'best_val_loss': best_val_loss,
        'routing_stats': routing_stats
    }, model_save_path)
    print(f"    Model saved to {model_save_path}")

    return model, routing_stats

def test_optimized_model(model, test_loader, combination, device='cuda'):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Assumes ``model`` is already on ``device``; only the batches are moved.

    Returns:
        ``(predictions, targets, rp, rmse, routing_analyses)``:
          - ``predictions``: 1-D numpy array of model outputs, in loader order.
          - ``targets``: 1-D numpy array of ``batch.y`` values, aligned with predictions.
          - ``rp``: Pearson correlation. It is 0.0 when undefined or
            uninformative: fewer than two samples, prediction std <= 0.01, or
            constant targets.
          - ``rmse``: root-mean-square error (nan for an empty loader).
          - ``routing_analyses``: one ``model.get_routing_analysis(combination)``
            result per batch, for batches where it returned something truthy.
    """
    model.eval()
    predictions = []
    targets = []
    routing_analyses = []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            edge_types = getattr(batch, 'edge_types', None)
            edge_type_counts = getattr(batch, 'edge_type_counts', None)
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch,
                         edge_types, edge_type_counts)

            # Flatten both sides so single-graph batches need no special casing.
            predictions.extend(pred.reshape(-1).cpu().numpy())
            targets.extend(batch.y.reshape(-1).cpu().numpy())

            routing_analysis = model.get_routing_analysis(combination)
            if routing_analysis:
                routing_analyses.append(routing_analysis)

    predictions = np.array(predictions)
    targets = np.array(targets)

    # pearsonr returns nan when either input is constant; report 0.0 in that
    # case, the same as for near-constant predictions.
    if len(predictions) > 1 and predictions.std() > 0.01 and targets.std() > 0:
        rp, _ = pearsonr(predictions, targets)
    else:
        rp = 0.0

    rmse = np.sqrt(np.mean((predictions - targets) ** 2)) if len(predictions) else float('nan')

    return predictions, targets, rp, rmse, routing_analyses


# Despite its name this is not a pytest test; stop pytest collecting it if the
# function is ever imported into a test module.
test_optimized_model.__test__ = False

def load_optimized_model(model_path, device='cuda'):
    """Load a saved optimized edge-aware capsule model.

    Expects the checkpoint layout written by ``train_optimized_model``:
    ``model_state_dict``, ``input_dim``, ``hidden_dim``, ``combination`` and
    ``best_val_loss``. The checkpoint does not record ``num_layers``, so 2 is
    assumed.

    NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.

    Returns:
        ``(model, checkpoint)``: the model moved to ``device``, and the raw
        checkpoint dict.
    """
    checkpoint = torch.load(model_path, map_location=device)

    model = OptimizedEdgeAwareCapsuleGNN(
        input_dim=checkpoint['input_dim'],
        hidden_dim=checkpoint['hidden_dim'],
        num_layers=2,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    print(f"Loaded optimized model for combination: {checkpoint['combination']}")
    print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")

    return model, checkpoint

def analyze_optimized_edge_importance(model, test_loader, combination, device='cuda'):
    """Analyze which edge types the optimized model focuses on.

    Runs ``model`` over ``test_loader`` and collects, per capsule
    (intra_protein / intra_ligand / inter_connection):
      - the routing coefficients the model stores in
        ``last_routing_coefficients`` after each forward pass;
      - the total edge-type counts over all graphs.

    Returns:
        A dict with:
          - ``combination``;
          - ``avg_attention``: mean routing weight per capsule, nan if none
            was recorded;
          - ``edge_counts``;
          - ``attention_vs_count_ratio``: mean attention divided by
            ``max(count, 1)``.
    """
    keys = ('intra_protein', 'intra_ligand', 'inter_connection')
    model.eval()

    edge_type_attention = {key: [] for key in keys}
    edge_type_counts = {key: 0 for key in keys}

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            edge_types = getattr(batch, 'edge_types', None)
            edge_type_counts_batch = getattr(batch, 'edge_type_counts', None)

            # Run only for the side effect of refreshing model.last_routing_coefficients.
            model(batch.x, batch.edge_index, batch.edge_attr, batch.batch,
                  edge_types, edge_type_counts_batch)

            routing = getattr(model, 'last_routing_coefficients', None)
            if routing is not None:
                for col, key in enumerate(keys):
                    edge_type_attention[key].extend(routing[:, col].tolist())

            if edge_type_counts_batch is not None:
                # PyG batching concatenates each graph's 3 counts, so sum over
                # every graph in the batch, not just the first three entries.
                totals = edge_type_counts_batch.reshape(-1, 3).sum(dim=0)
                for col, key in enumerate(keys):
                    edge_type_counts[key] += totals[col].item()

    # Empty lists give nan explicitly instead of a numpy "mean of empty slice" warning.
    avg_attention = {
        key: float(np.mean(values)) if values else float('nan')
        for key, values in edge_type_attention.items()
    }

    analysis = {
        'combination': combination,
        'avg_attention': avg_attention,
        'edge_counts': edge_type_counts,
        'attention_vs_count_ratio': {
            key: avg_attention[key] / max(edge_type_counts[key], 1) for key in keys
        }
    }

    return analysis

def _empty_edge_stats():
    """Return a fresh zeroed counter for the three edge categories tracked by ``main``."""
    return {'intra_protein': 0, 'intra_ligand': 0, 'inter_connection': 0}


def _average_routing(routing_records):
    """Average per-sample routing attention records into one value per edge type.

    ``routing_records`` is the routing list returned by ``test_optimized_model``.
    Each entry must provide ``intra_protein_attention``, ``intra_ligand_attention``
    and ``inter_connection_attention`` keys. The list must be non-empty; callers
    check this first.
    """
    return {
        'intra_protein': np.mean([r['intra_protein_attention'] for r in routing_records]),
        'intra_ligand': np.mean([r['intra_ligand_attention'] for r in routing_records]),
        'inter_connection': np.mean([r['inter_connection_attention'] for r in routing_records]),
    }


def _has_core_result(results, combination):
    """Return True when ``combination`` trained successfully and was scored on the core set.

    Failed runs are stored with ``train_samples == 0``. Runs without core data have
    ``core_samples == 0``. Both carry placeholder zero metrics that must not be
    compared against real scores.
    """
    r = results.get(combination)
    return r is not None and r['train_samples'] > 0 and r['core_samples'] > 0


def main():
    """Train and evaluate one edge-type aware capsule model per graph combination.

    For every combination in ``combinations`` this function:
      1. builds the combined real + synthetic train/validation sets;
      2. loads the core (2016) and holdout (2019) test sets;
      3. trains ``OptimizedEdgeAwareCapsuleGNN`` via ``train_optimized_model``;
      4. scores the trained model on the core and holdout sets via
         ``test_optimized_model``.

    It then prints these summary tables:
      - metrics;
      - the edge-type distribution check;
      - routing attention;
      - with/without inter-connection comparisons.

    Finally it writes ``optimized_edge_aware_capsule_results.csv`` to the current
    working directory.

    An exception while processing one combination is reported with a traceback.
    That combination is recorded with zero-filled placeholder results, and the
    remaining combinations still run.
    """
    print("🚀 OPTIMIZED EDGE-TYPE AWARE CAPSULE NETWORK")
    print("="*65)
    
    # File paths - UPDATE THESE TO YOUR ACTUAL PATHS
    real_train_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\training_set_with_affinity.csv'
    real_val_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\validation_set_with_affinity.csv'
    real_data_path = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\dataset'
    
    synthetic_train_dir = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\complete_graphs_20250709_163209\\training_synthetic'
    synthetic_val_dir = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\complete_graphs_20250709_163209\\validation_synthetic'
    
    core_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\core_set_with_affinity.csv'
    holdout_csv = 'D:\\PhD\\Chapter_4\\Code2\\pdbbind\\pdb_ids_Affinity\\hold_out_set_with_affinity.csv'
    
    print("⚡ OPTIMIZATION FEATURES:")
    print("✅ Reduced routing iterations (3→2)")
    print("✅ Cached edge type statistics")
    print("✅ Simplified normalization")
    print("✅ Precomputed merged graphs")
    print("✅ Gradient accumulation (batch_size=4, accumulate=2)")
    print("✅ Reduced logging frequency")
    
    # Single source of truth for which combinations run; every summary below iterates it.
    combinations = ['P', 'L', 'I', 'PL', 'PI', 'LI', 'PLI']
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    results = {}
    routing_analyses = {}
    saved_models = {}
    
    for combination in combinations:
        print(f"\n{'='*20} OPTIMIZED {combination} {'='*20}")
        model = trained_model = None
        
        try:
            combined_train_data, combined_val_data = prepare_combined_training_dataset(
                real_train_csv, real_val_csv, real_data_path,
                synthetic_train_dir, synthetic_val_dir, combination,
                use_half=False
            )
            
            if len(combined_train_data) == 0 or len(combined_val_data) == 0:
                print(f"  Insufficient combined data for {combination}, skipping...")
                continue
            
            core_df = load_csv(core_csv, use_half=False)
            holdout_df = load_csv(holdout_csv, use_half=False)
            core_data = prepare_real_dataset_combined(core_df, real_data_path, combination)
            holdout_data = prepare_real_dataset_combined(holdout_df, real_data_path, combination)
            
            # OPTIMIZED: Smaller batch size with gradient accumulation
            train_loader = DataLoader(combined_train_data, batch_size=4, shuffle=True)
            val_loader = DataLoader(combined_val_data, batch_size=4)
            core_loader = DataLoader(core_data, batch_size=4) if core_data else None
            holdout_loader = DataLoader(holdout_data, batch_size=4) if holdout_data else None
            
            input_dim = combined_train_data[0].x.size(1)
            print(f"  Input dimension: {input_dim}")
            
            print("Training Optimized Edge-Type Aware Capsule Network...")
            model = OptimizedEdgeAwareCapsuleGNN(input_dim=input_dim, hidden_dim=64, num_layers=2)
            print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
            
            # Analyze edge types in training data
            edge_type_stats = _empty_edge_stats()
            for data in combined_train_data:
                if hasattr(data, 'edge_type_counts'):
                    edge_type_stats['intra_protein'] += data.edge_type_counts[0].item()
                    edge_type_stats['intra_ligand'] += data.edge_type_counts[1].item()
                    edge_type_stats['inter_connection'] += data.edge_type_counts[2].item()
            
            print(f"  Edge type distribution: Intra-P={edge_type_stats['intra_protein']}, "
                  f"Intra-L={edge_type_stats['intra_ligand']}, Inter={edge_type_stats['inter_connection']}")
            
            expected_inter = 'I' in combination
            actual_inter = edge_type_stats['inter_connection'] > 0
            print(f"  Inter-connection check: Expected={expected_inter}, Actual={actual_inter}")
            
            trained_model, routing_stats = train_optimized_model(
                model, train_loader, val_loader, combination, 
                epochs=150, device=device
            )
            
            saved_models[combination] = f"optimized_edge_aware_capsule_{combination}_model.pth"
            
            # Test the model; metrics stay 0 when a test set is unavailable.
            core_rp = core_rmse = holdout_rp = holdout_rmse = 0
            
            if core_loader is not None:
                print("Testing on 2016 Core Set...")
                core_preds, core_targets, core_rp, core_rmse, core_routing = test_optimized_model(
                    trained_model, core_loader, combination, device)
                print(f"  Core Set - Rp: {core_rp:.3f}, RMSE: {core_rmse:.3f}")
                
                if core_routing:
                    avg_routing = _average_routing(core_routing)
                    print(f"  Core Routing - Intra-P: {avg_routing['intra_protein']:.3f}, "
                          f"Intra-L: {avg_routing['intra_ligand']:.3f}, Inter: {avg_routing['inter_connection']:.3f}")
                    # The routing summary table reports core-set routing only.
                    routing_analyses[combination] = avg_routing
            
            if holdout_loader is not None:
                print("Testing on 2019 Holdout Set...")
                holdout_preds, holdout_targets, holdout_rp, holdout_rmse, holdout_routing = test_optimized_model(
                    trained_model, holdout_loader, combination, device)
                print(f"  Holdout Set - Rp: {holdout_rp:.3f}, RMSE: {holdout_rmse:.3f}")
                
                if holdout_routing:
                    holdout_avg_routing = _average_routing(holdout_routing)
                    print(f"  Holdout Routing - Intra-P: {holdout_avg_routing['intra_protein']:.3f}, "
                          f"Intra-L: {holdout_avg_routing['intra_ligand']:.3f}, Inter: {holdout_avg_routing['inter_connection']:.3f}")
            
            if core_loader is None and holdout_loader is None:
                routing_analyses[combination] = _empty_edge_stats()
            
            results[combination] = {
                'core_rp': core_rp,
                'core_rmse': core_rmse,
                'holdout_rp': holdout_rp,
                'holdout_rmse': holdout_rmse,
                'routing_stats': routing_stats,
                'train_samples': len(combined_train_data),
                'val_samples': len(combined_val_data),
                'core_samples': len(core_data) if core_data else 0,
                'holdout_samples': len(holdout_data) if holdout_data else 0,
                'edge_stats': edge_type_stats,
                'expected_inter': expected_inter,
                'actual_inter': actual_inter
            }
            
        except Exception as e:
            print(f"Error processing {combination}: {e}")
            # Placeholder record; train_samples == 0 marks it as failed for the summaries below.
            results[combination] = {
                'core_rp': 0, 'core_rmse': 0, 'holdout_rp': 0, 'holdout_rmse': 0,
                'routing_stats': [], 'train_samples': 0, 'val_samples': 0, 
                'core_samples': 0, 'holdout_samples': 0,
                'edge_stats': _empty_edge_stats(),
                'expected_inter': False, 'actual_inter': False
            }
            routing_analyses[combination] = _empty_edge_stats()
            import traceback
            traceback.print_exc()
        finally:
            # Clean up GPU memory on success AND failure: drop model references
            # first so the caching allocator can actually release them.
            model = trained_model = None
            torch.cuda.empty_cache()
    
    # Print results
    print(f"\n{'='*80}")
    print("OPTIMIZED EDGE-TYPE AWARE CAPSULE NETWORK RESULTS")
    print(f"{'='*80}")
    print(f"{'Model':<6} {'Core Rp':<10} {'Core RMSE':<12} {'Holdout Rp':<12} {'Holdout RMSE':<14} {'Train Size':<12}")
    print("-" * 85)
    
    for combination in combinations:
        if combination in results:
            r = results[combination]
            print(f"{combination:<6} {r['core_rp']:<10.3f} {r['core_rmse']:<12.3f} "
                  f"{r['holdout_rp']:<12.3f} {r['holdout_rmse']:<14.3f} {r['train_samples']:<12}")
    
    print(f"{'='*85}")
    
    # Edge-type validation (iterates the same list that was trained above)
    print(f"\n🔧 OPTIMIZED EDGE-TYPE DISTRIBUTION VALIDATION:")
    for combination in combinations:
        if combination in results and results[combination]['train_samples'] > 0:
            r = results[combination]
            edge_stats = r['edge_stats']
            expected_inter = r['expected_inter']
            actual_inter = r['actual_inter']
            
            print(f"{combination}: Intra-P={edge_stats['intra_protein']}, Intra-L={edge_stats['intra_ligand']}, Inter={edge_stats['inter_connection']}")
            if expected_inter == actual_inter:
                print(f"  ✅ Edge distribution correct (Expected Inter={expected_inter}, Got={actual_inter})")
            else:
                print(f"  ⚠️  Edge mismatch: Expected Inter={expected_inter}, Got Inter={actual_inter}")
    
    # Routing analysis
    print("\n🔍 OPTIMIZED ROUTING COEFFICIENT ANALYSIS")
    print("="*60)
    print(f"{'Model':<6} {'Intra-P':<10} {'Intra-L':<10} {'Inter-PL':<12} {'Dominant':<10}")
    print("-" * 60)
    
    short_names = {'intra_protein': 'Intra-P', 'intra_ligand': 'Intra-L', 'inter_connection': 'Inter-PL'}
    for combination in combinations:
        if combination in routing_analyses:
            routing = routing_analyses[combination]
            if any(routing.values()):
                dominant_short = short_names[max(routing, key=routing.get)]
            else:
                # All-zero placeholders (failed run / no test set): nothing dominates.
                dominant_short = 'N/A'
            
            print(f"{combination:<6} {routing['intra_protein']:<10.3f} {routing['intra_ligand']:<10.3f} "
                  f"{routing['inter_connection']:<12.3f} {dominant_short:<10}")
    
    print("\n📈 OPTIMIZED INTER-CONNECTION EFFECTIVENESS ANALYSIS")
    print("="*75)
    
    comparisons = [
        ('P', 'PI', 'P vs PI (Inter-connection benefit)'),
        ('L', 'LI', 'L vs LI (Inter-connection benefit)'), 
        ('PL', 'PLI', 'PL vs PLI (Inter-connection benefit)')
    ]
    
    print(f"{'Comparison':<35} {'Without Inter':<12} {'With Inter':<12} {'ΔRp':<8} {'ΔInter-attn':<12} {'Effect':<8}")
    print("-" * 95)
    
    valid_comparisons = 0
    significant_gains = 0
    
    for base, inter_version, label in comparisons:
        # Only compare runs that trained and were scored on the core set; failed or
        # core-less runs hold placeholder zeros that would fake a GAIN or LOSS.
        if not (_has_core_result(results, base) and _has_core_result(results, inter_version)):
            print(f"{label:<35} {'N/A':<12} {'N/A':<12} {'N/A':<8} {'N/A':<12} {'MISSING':<8}")
            continue
        
        base_rp = results[base]['core_rp']
        inter_rp = results[inter_version]['core_rp']
        delta_rp = inter_rp - base_rp
        
        base_inter_attn = routing_analyses.get(base, {}).get('inter_connection', 0)
        inter_inter_attn = routing_analyses.get(inter_version, {}).get('inter_connection', 0)
        delta_inter_attn = inter_inter_attn - base_inter_attn
        
        base_rmse = results[base]['core_rmse']
        inter_rmse = results[inter_version]['core_rmse']
        delta_rmse = base_rmse - inter_rmse  # positive = inter-connection lowered RMSE
        
        # Gain if either metric improves noticeably; loss only if both degrade.
        is_gain = delta_rp > 0.01 or delta_rmse > 0.1
        is_loss = delta_rp < -0.01 and delta_rmse < -0.1
        effect = "✅ GAIN" if is_gain else "❌ LOSS" if is_loss else "➖ NEUTRAL"
        
        print(f"{label:<35} {base_rp:<12.3f} {inter_rp:<12.3f} {delta_rp:<8.3f} {delta_inter_attn:<12.3f} {effect:<8}")
        print(f"{'    (RMSE comparison)':<35} {base_rmse:<12.3f} {inter_rmse:<12.3f} {delta_rmse:<8.3f}")
        
        valid_comparisons += 1
        if is_gain:
            significant_gains += 1
    
    # Summary analysis
    print(f"\n📊 OPTIMIZED INTER-CONNECTION IMPACT SUMMARY:")
    if valid_comparisons > 0:
        print(f"Significant improvements: {significant_gains}/{valid_comparisons}")
        print(f"Inter-connection success rate: {significant_gains/valid_comparisons*100:.1f}%")
    else:
        print("No valid comparisons available")
    
    # Save results
    print("\n💾 Saving optimized results...")
    
    test_results_data = []
    for combination in combinations:
        if combination in results:
            r = results[combination]
            routing = routing_analyses.get(combination, _empty_edge_stats())
            
            test_results_data.append({
                'combination': combination,
                'core_rp': r['core_rp'],
                'core_rmse': r['core_rmse'],
                'holdout_rp': r['holdout_rp'],
                'holdout_rmse': r['holdout_rmse'],
                'train_samples': r['train_samples'],
                'val_samples': r['val_samples'],
                'core_samples': r['core_samples'],
                'holdout_samples': r['holdout_samples'],
                'intra_protein_attention': routing['intra_protein'],
                'intra_ligand_attention': routing['intra_ligand'],
                'inter_connection_attention': routing['inter_connection'],
                'intra_protein_edges': r['edge_stats']['intra_protein'],
                'intra_ligand_edges': r['edge_stats']['intra_ligand'],
                'inter_connection_edges': r['edge_stats']['inter_connection'],
                'expected_inter': r['expected_inter'],
                'actual_inter': r['actual_inter'],
                'model_path': saved_models.get(combination, 'N/A')
            })
    
    results_df = pd.DataFrame(test_results_data)
    results_df.to_csv('optimized_edge_aware_capsule_results.csv', index=False)
    
    print("✅ Results saved to 'optimized_edge_aware_capsule_results.csv'")
    print("✅ Models saved with naming: 'optimized_edge_aware_capsule_{combination}_model.pth'")
    
    print("\n🎉 OPTIMIZED EXECUTION COMPLETED!")
    print(f"📊 {len(results)} models tested with optimizations")
    print(f"⚡ Training speed significantly improved")
    print(f"🔍 Capsule architecture and edge-type awareness preserved")
    print(f"💾 All models saved for future use")
    
    print("\n📋 OPTIMIZATION SUMMARY:")
    print("- ⚡ Routing iterations: 3→2 (33% speedup)")
    print("- 📊 Edge type statistics cached (eliminates recomputation)")
    print("- 🔄 Simplified normalization (faster)")
    print("- 🗂️  Precomputed merged graphs (no on-demand merging)")
    print("- 🎯 Gradient accumulation (effective batch_size=8 with memory=4)")
    print("- 📝 Reduced logging frequency (less I/O overhead)")
    print("- 🧠 Capsule architecture fully preserved")
    print("- 🔗 Edge-type awareness intact")
    
    print("\n📖 OPTIMIZED MODEL LOADING:")
    print("```python")
    print("model, checkpoint = load_optimized_model('optimized_edge_aware_capsule_PLI_model.pth')")
    print("analysis = analyze_optimized_edge_importance(model, test_loader, 'PLI')")
    print("```")

if __name__ == "__main__":
    # Script entry point: run the full pipeline and, on a fatal error,
    # print a checklist of things to inspect plus the full traceback.
    import traceback

    _debugging_checklist = (
        "1. ✓ Check if optimizations maintain model accuracy",
        "2. ✓ Verify cached edge statistics are correct",
        "3. ✓ Ensure gradient accumulation works properly",
        "4. ✓ Monitor memory usage with smaller batches",
        "5. ✓ Validate capsule routing still functions",
    )

    try:
        main()
    except Exception as fatal_error:
        print(f"\n❌ CRITICAL ERROR: {fatal_error}")
        print("\nOPTIMIZATION DEBUGGING:")
        for checklist_item in _debugging_checklist:
            print(checklist_item)
        traceback.print_exc()

🚀 OPTIMIZED EDGE-TYPE AWARE CAPSULE NETWORK
⚡ OPTIMIZATION FEATURES:
✅ Reduced routing iterations (3→2)
✅ Cached edge type statistics
✅ Simplified normalization
✅ Precomputed merged graphs
✅ Gradient accumulation (batch_size=4, accumulate=2)
✅ Reduced logging frequency
Using device: cuda

Preparing FULL combined training dataset for P...
  P: 9312 real graphs loaded, 350 failed
  P: 871 real graphs loaded, 32 failed
  P: 9287 synthetic graphs loaded, 25 failed
  P: 871 synthetic graphs loaded, 0 failed
  Train: 9312 real + 9287 synthetic = 18599
  Val: 871 real + 871 synthetic = 1742
  P: 249 real graphs loaded, 8 failed
  P: 3232 real graphs loaded, 161 failed
  Input dimension: 17
Training Optimized Edge-Type Aware Capsule Network...
Model parameters: 23,937
  Edge type distribution: Intra-P=4415508, Intra-L=0, Inter=0
  Inter-connection check: Expected=False, Actual=False
    Epoch 0: Train Loss=3.6519, Val Loss=3.3622
    Epoch 5: Train Loss=2.9193, Val Loss=2.8863
    Epoch 10: Tr