In [77]:
import os
import sys
import logging
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.nn import AttentionalAggregation
from torch.utils.data import random_split, WeightedRandomSampler
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error

from torch_geometric.nn import (
    GINEConv, 
    Set2Set, 
    TransformerConv
)
from torch_geometric.nn.norm import BatchNorm

# -----------------------------------------------------------
# Configuration
# -----------------------------------------------------------
# Run-wide settings, looked up by string key throughout this script.
CONFIG = dict(
    processed_dir='./processed_experimental12',  # dataset root directory
    processed_file_name='data.pt',               # processed file name reported by the dataset
    batch_size=1024,                             # graphs per mini-batch
    learning_rate=1e-4,                          # AdamW learning rate
    weight_decay=1e-4,                           # AdamW weight decay
    hidden_channels=512,                         # model width
    num_epochs=100,                              # maximum number of training epochs
    patience=75,                                 # intended early-stopping patience (epochs) -- verify it is consumed
    random_seed=42,                              # seed for RNGs and the train/val split
    best_model_path='best_model.pth',            # where the best state_dict is written
    dropout_p=0.4,                               # dropout probability inside the model
    scheduler_factor=0.5,                        # ReduceLROnPlateau multiplicative factor
    scheduler_patience=10,                       # ReduceLROnPlateau patience (epochs)
    grad_clip=1.0,                               # max gradient norm passed to train_epoch
    curriculum_alpha=0.5,                        # exponent used by the curriculum sampler
    num_layers=10,                               # number of message-passing layers
)

# -----------------------------------------------------------
# Logging & Utilities
# -----------------------------------------------------------
def setup_logging():
    """Configure the root logger to emit INFO-level records on stdout.

    Records are formatted as ``<timestamp> [<LEVEL>] <message>``.
    ``logging.basicConfig`` leaves the root logger untouched when it already
    has handlers, so calling this more than once is harmless.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[stdout_handler],
        datefmt='%Y-%m-%d %H:%M:%S',
        format='%(asctime)s [%(levelname)s] %(message)s',
        level=logging.INFO,
    )

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch. When
    CUDA is available, all GPU generators are seeded as well and cuDNN is
    switched to deterministic kernels with autotuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers the current device too, so a separate
        # torch.cuda.manual_seed call is not needed.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # The cuDNN autotuner may select different kernels from run to run;
        # turning it off is required for repeatable results.
        torch.backends.cudnn.benchmark = False

# -----------------------------------------------------------
# SpinSystemDataset
# -----------------------------------------------------------
class SpinSystemDataset(InMemoryDataset):
    """In-memory dataset backed by an already-processed file under ``root``.

    No downloading or processing is performed here: the constructor reads the
    first processed path and unpacks it into ``self.data`` and ``self.slices``.
    """

    def __init__(self, root='.', transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        processed_path = self.processed_paths[0]
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load processed files that come from a trusted source.
        loaded = torch.load(processed_path, weights_only=False)
        self.data, self.slices = loaded

    @property
    def raw_file_names(self):
        """No raw files are required."""
        return list()

    @property
    def processed_file_names(self):
        """Single processed file, named by CONFIG['processed_file_name']."""
        processed_name = CONFIG['processed_file_name']
        return [processed_name]

    def download(self):
        """Nothing to download."""
        return None

    def process(self):
        """Nothing to process."""
        return None

class ListDataset(InMemoryDataset):
    """Expose a plain sequence of data objects through the dataset interface.

    Length and indexing are delegated directly to the wrapped sequence.
    """

    def __init__(self, data_list):
        super().__init__()
        # NOTE(review): the base class may use the `_data_list` attribute name
        # internally -- confirm that reusing it here is harmless.
        self._data_list = data_list

    def __len__(self):
        wrapped = self._data_list
        return len(wrapped)

    def __getitem__(self, idx):
        wrapped = self._data_list
        return wrapped[idx]

# -----------------------------------------------------------
# ExperimentalGNN
# -----------------------------------------------------------
class ExperimentalGNN(nn.Module):
    """Graph network with two outputs per graph.

    Pipeline (see ``forward``):
      1. Node features are split by ``feature_indices`` into a spatial group
         and a quantum group, each embedded to ``hidden_channels // 2`` and
         concatenated with the raw angle, neighbor-count and mask columns.
      2. ``num_layers`` residual message-passing layers alternate between
         ``GINEConv`` (even index) and ``TransformerConv`` (odd index), each
         followed by ``BatchNorm``.
      3. Node states are pooled with ``Set2Set`` and ``AttentionalAggregation``
         and concatenated with embeddings of the per-graph scalars
         ``[nA / N, nB / N, N]`` and of ``N`` alone.
      4. A final MLP maps the combined vector to two outputs.

    The hard-coded split expects ``x`` to have 23 columns: 11 named columns
    plus 12 window features (``quantum_dim = 1 + 4 * 3``).

    Args:
        num_node_features: Total number of node features. Currently not used
            to size any layer; the layout in ``feature_indices`` is fixed.
        edge_attr_dim: Number of raw edge attributes.
        hidden_channels: Width of the node embedding.
        num_layers: Number of message-passing layers.
        dropout_p: Dropout probability used in the MLPs and TransformerConv.
    """

    def __init__(
        self,
        num_node_features,      # Total number of node features
        edge_attr_dim=3,        # [angle, correlation, normalized_distance]
        hidden_channels=64,
        num_layers=4,
        dropout_p=0.4
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout_p = dropout_p

        # Column layout of data.x (for clarity)
        self.feature_indices = {
            'position': slice(0, 2),           # x, y coordinates
            'rydberg_val': slice(2, 3),        # Site occupation/correlation value
            'mask': slice(3, 4),               # Subsystem mask
            'boundaries': slice(4, 8),         # Distances to boundaries
            'radial': slice(8, 9),             # Radial distance from center
            'angle': slice(9, 10),             # Angular position
            'neighbors': slice(10, 11),        # Normalized neighbor count
            'quantum_windows': slice(11, None) # 4 features x 3 windows
        }

        # NOTE: sub-modules are created in a fixed order so that seeded
        # parameter initialisation stays reproducible; do not reorder.

        # Spatial branch: position(2) + boundaries(4) + radial(1)
        spatial_dim = 2 + 4 + 1
        self.spatial_transform = nn.Sequential(
            nn.Linear(spatial_dim, hidden_channels // 2),
            nn.LayerNorm(hidden_channels // 2),
            nn.SiLU(),
        )

        # Quantum branch: rydberg(1) + (4 features x 3 windows)
        quantum_dim = 1 + (4 * 3)
        self.quantum_transform = nn.Sequential(
            nn.Linear(quantum_dim, hidden_channels // 2),
            nn.LayerNorm(hidden_channels // 2),
            nn.SiLU(),
        )

        # Edge attributes are embedded once and shared by every conv layer
        self.edge_transform = nn.Sequential(
            nn.Linear(edge_attr_dim, hidden_channels // 2),
            nn.LayerNorm(hidden_channels // 2),
            nn.SiLU(),
        )

        # Fuse both branches with the three raw columns (angle, neighbors, mask)
        self.feature_combiner = nn.Sequential(
            nn.Linear(hidden_channels + 3, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_channels, hidden_channels)
        )

        # Message passing layers (alternating GINE and TransformerConv)
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(num_layers):
            if i % 2 == 0:
                mp_mlp = nn.Sequential(
                    nn.Linear(hidden_channels, hidden_channels),
                    nn.LayerNorm(hidden_channels),
                    nn.SiLU(),
                    nn.Dropout(dropout_p),
                    nn.Linear(hidden_channels, hidden_channels)
                )
                conv = GINEConv(mp_mlp, edge_dim=hidden_channels // 2)
            else:
                # 4 heads of width hidden_channels // 4 keep the residual
                # stream at hidden_channels.
                conv = TransformerConv(
                    hidden_channels, hidden_channels // 4,
                    heads=4,
                    edge_dim=hidden_channels // 2,
                    dropout=dropout_p,
                    beta=True
                )
            self.convs.append(conv)
            self.norms.append(BatchNorm(hidden_channels))

        # Readouts
        self.set2set_readout = Set2Set(hidden_channels, processing_steps=4)
        # gate_nn is registered both here and inside global_attention, so its
        # parameters appear under two prefixes in the state_dict.
        self.gate_nn = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.SiLU(),
            nn.Linear(hidden_channels // 2, 1)
        )
        self.global_attention = AttentionalAggregation(gate_nn=self.gate_nn)

        # Embedding of the per-graph scalars [nA/N, nB/N, N]
        self.global_transform = nn.Sequential(
            nn.Linear(3, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_channels, hidden_channels)
        )

        # Size encoder (still present, even if you only train on 1 size)
        self.size_encoder = nn.Sequential(
            nn.Linear(1, hidden_channels // 2),
            nn.SiLU(),
            nn.Linear(hidden_channels // 2, hidden_channels // 2)
        )

        # Set2Set (2h) + attention readout and scalar embedding (2h together)
        # + size encoding (h/2)
        combined_dim = (2 * hidden_channels) + (2 * hidden_channels) + (hidden_channels // 2)
        self.final_mlp = nn.Sequential(
            nn.Linear(combined_dim, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.LayerNorm(hidden_channels // 2),
            nn.SiLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_channels // 2, 2)
        )

    def forward(self, data):
        """Run the network on a batch of graphs.

        Args:
            data: Batch object providing ``x``, ``edge_index``, ``edge_attr``,
                ``batch`` and the per-graph attributes ``system_size``,
                ``nA`` and ``nB`` (assumed to hold one value per graph --
                TODO confirm against the dataset builder).

        Returns:
            Tensor with two columns, one row per graph.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        idx = self.feature_indices

        # Split node features into the groups handled by each branch
        spatial_features = torch.cat([
            x[:, idx['position']],
            x[:, idx['boundaries']],
            x[:, idx['radial']]
        ], dim=1)

        quantum_features = torch.cat([
            x[:, idx['rydberg_val']],
            x[:, idx['quantum_windows']]
        ], dim=1)

        angle_features = x[:, idx['angle']]
        neighbor_features = x[:, idx['neighbors']]
        mask_features = x[:, idx['mask']]

        # Embed each group
        spatial_out = self.spatial_transform(spatial_features)
        quantum_out = self.quantum_transform(quantum_features)
        edge_features = self.edge_transform(edge_attr)

        # Combine node features
        h = self.feature_combiner(torch.cat([
            spatial_out, quantum_out,
            angle_features, neighbor_features, mask_features
        ], dim=1))

        # Residual message passing: h <- h + norm(conv(h))
        for conv, norm in zip(self.convs, self.norms):
            h_new = conv(h, edge_index, edge_features)
            h_new = norm(h_new)
            h = h + h_new

        # Readouts
        s2s = self.set2set_readout(h, batch)
        ga = self.global_attention(h, batch)

        # reshape(-1) instead of squeeze(-1): squeeze collapses a batch that
        # holds a single graph to a 0-dim tensor, which breaks the stack
        # below. The float cast keeps integer-typed sizes usable by nn.Linear.
        system_size = data.system_size.reshape(-1).float()
        nA = data.nA.reshape(-1).float()
        nB = data.nB.reshape(-1).float()

        global_feats = torch.stack([
            nA / system_size,
            nB / system_size,
            system_size,
        ], dim=1)

        gf_out = self.global_transform(global_feats)
        size_encoded = self.size_encoder(system_size.unsqueeze(-1))

        combined = torch.cat([s2s, ga, gf_out, size_encoded], dim=-1)
        out = self.final_mlp(combined)
        return out

# -----------------------------------------------------------
# PhysicalScaleAwareLoss
# -----------------------------------------------------------
class PhysicalScaleAwareLoss(nn.Module):
    """Size-weighted regression loss with a penalty for out-of-range entropy.

    The prediction is interpreted as log(S) / N: it is compared against
    ``log(target) / system_size`` and turned back into an entropy with
    ``exp(prediction * system_size)`` for the bound check.

    Args:
        base_size: Reference system size for the size weight.
        scaling_power: Exponent applied to ``system_size / base_size``.
        physics_weight: Multiplier of the bound-violation penalty.
    """

    def __init__(self, base_size=4, scaling_power=1.5, physics_weight=1.0):
        super().__init__()
        self.physics_weight = physics_weight
        self.scaling_power = scaling_power
        self.base_size = base_size

    def get_entropy_bounds(self, system_size, subsystem_size):
        """Return ``(lower, upper)`` entropy bounds per sample.

        The lower bound is zero; the upper bound is
        ``min(subsystem_size, system_size - subsystem_size) * log(2)``.
        """
        size_a = subsystem_size.float()
        size_b = (system_size - subsystem_size).float()
        ln2 = torch.log(torch.tensor(2.0, device=system_size.device))
        ceiling = torch.minimum(size_a, size_b) * ln2
        floor = torch.zeros_like(system_size, dtype=torch.float)
        return floor, ceiling

    def forward(self, pred_log_s_over_n, target, system_size, subsystem_size):
        """Return the mean of (size-weighted squared error + weighted penalty)."""
        # Absolute entropy implied by the prediction.
        # NOTE(review): this exponential is unbounded, so large predictions
        # can produce very large loss values -- consider whether a cap is wanted.
        entropy_hat = torch.exp(pred_log_s_over_n * system_size)

        floor, ceiling = self.get_entropy_bounds(system_size, subsystem_size)

        # Each term is the smooth-L1 distance between the prediction and its
        # projection onto the allowed side of the bound (zero when in range).
        below = F.smooth_l1_loss(
            torch.maximum(floor, entropy_hat), entropy_hat, reduction='none'
        )
        above = F.smooth_l1_loss(
            torch.minimum(ceiling, entropy_hat), entropy_hat, reduction='none'
        )
        bound_penalty = below + above

        # Squared error in the log(S) / N domain
        scaled_log_target = torch.log(target + 1e-10) / system_size
        squared_error = F.mse_loss(pred_log_s_over_n, scaled_log_target, reduction='none')

        # Larger systems receive proportionally more weight
        size_weight = (system_size.float() / self.base_size) ** self.scaling_power
        weighted_error = squared_error * size_weight

        per_sample = weighted_error + self.physics_weight * bound_penalty
        return per_sample.mean()

def mape_loss(pred, target, eps=1e-10):
    """Mean absolute percentage error, expressed as a fraction.

    Computes ``mean(|pred - target| / |target + eps|)``; the conventional
    factor of 100 is left out because it only rescales the value.

    Args:
        pred: Predicted values.
        target: Reference values.
        eps: Small constant added to the denominator.
    """
    relative_error = (pred - target) / (target + eps)
    return relative_error.abs().mean()

class MultiTaskLoss(nn.Module):
    """Weighted sum of three terms computed from a two-column prediction.

    Column 0 is used as log(S) / N: it is multiplied by ``system_size`` and
    exponentiated before being compared with the absolute targets.
    Column 1 is used as S / N.

    Args:
        physics_weight: Weight of the bound penalty inside PhysicalScaleAwareLoss.
        alpha_s_over_n: Weight of the MSE on column 1.
        alpha_mape: Weight of the MAPE on the absolute entropy from column 0.
    """

    def __init__(self, physics_weight=0.3, alpha_s_over_n=0.2, alpha_mape=0.1):
        super().__init__()
        self.phys_loss = PhysicalScaleAwareLoss(physics_weight=physics_weight)
        self.alpha_s_over_n = alpha_s_over_n
        self.alpha_mape = alpha_mape

    def forward(self, preds, targets, system_size, subsystem_size):
        """Return ``phys + alpha_s_over_n * mse + alpha_mape * mape``."""
        head_log, head_ratio = preds[:, 0], preds[:, 1]

        # Term 1: size-weighted, bound-aware loss on column 0
        bounded_term = self.phys_loss(head_log, targets, system_size, subsystem_size)

        # Term 2: MSE between column 1 and the true S / N
        ratio_target = targets / (system_size + 1e-10)
        ratio_term = F.mse_loss(head_ratio, ratio_target)

        # Term 3: MAPE between the absolute entropy from column 0 and S
        entropy_from_log = torch.exp(head_log * system_size)
        mape_term = mape_loss(entropy_from_log, targets)

        return bounded_term + self.alpha_s_over_n * ratio_term + self.alpha_mape * mape_term
# -----------------------------------------------------------
# Curriculum Sampler
# -----------------------------------------------------------
def get_curriculum_sampler(dataset, epoch, max_epochs, alpha=0.5, base_size=4):
    """Build a sampler whose bias toward larger systems grows during training.

    Every sample starts with weight 1. As training progresses, a term
    proportional to ``(size - base_size) / (max_size - base_size)`` is added,
    scaled by ``progress ** alpha`` where ``progress`` ramps linearly from 0
    to 1 over the first 70% of ``max_epochs``. With a single system size all
    weights stay equal.

    Args:
        dataset: Iterable of items exposing a one-element ``system_size`` tensor.
        epoch: Current epoch index (0-based).
        max_epochs: Total number of epochs planned.
        alpha: Exponent applied to the progress ramp.
        base_size: System size that receives no extra weight.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(dataset)`` samples with
        replacement, or ``None`` when the dataset is empty.
    """
    # This walks the whole dataset, so it costs one full pass per call.
    system_sizes = [data.system_size.item() for data in dataset]
    if not system_sizes:
        return None

    max_size = max(system_sizes)
    ramp_epochs = max_epochs * 0.7
    # A zero-length schedule used to raise ZeroDivisionError; treat it as a
    # ramp that has already completed.
    progress = min(1.0, epoch / ramp_epochs) if ramp_epochs > 0 else 1.0
    emphasis = progress ** alpha
    span = max_size - base_size + 1e-6  # epsilon avoids 0/0 for a single size

    weights = []
    for size in system_sizes:
        w = 1.0 + emphasis * ((size - base_size) / span)
        # Sampling weights must stay positive
        weights.append(w if w > 0 else 1e-6)

    return WeightedRandomSampler(weights=weights, num_samples=len(dataset), replacement=True)

# -----------------------------------------------------------
# Training & Evaluation
# -----------------------------------------------------------
def train_epoch(model, loader, optimizer, criterion, device, clip_grad=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network called as ``model(batch)``.
        loader: Iterable of batches exposing ``to``, ``nA``, ``system_size``,
            ``y`` and ``num_graphs``.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Callable ``(preds, targets, system_size, subsystem_size)``
            returning a scalar loss.
        device: Device every batch is moved to.
        clip_grad: Maximum gradient norm, or ``None`` to skip clipping.

    Returns:
        Mean loss per graph over the epoch (0.0 when the loader is empty).
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        preds = model(data)
        # reshape(-1) rather than squeeze(): squeeze turns a batch holding a
        # single graph into 0-dim tensors, which then broadcast or fail
        # downstream. Assumes one scalar per graph for each attribute --
        # TODO confirm against the dataset builder.
        subsystem_size = data.nA.reshape(-1)
        system_size_ = data.system_size.reshape(-1)
        targets_ = data.y.reshape(-1)

        loss = criterion(preds, targets_, system_size_, subsystem_size)
        loss.backward()

        if clip_grad is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        optimizer.step()

        # Weight each batch by its graph count so a short last batch does
        # not skew the epoch average.
        batch_size = data.num_graphs
        n_samples += batch_size
        total_loss += loss.item() * batch_size

    return total_loss / n_samples if n_samples > 0 else 0.0

@torch.no_grad()
def evaluate(model, loader, criterion, device, name='Eval'):
    """Evaluate ``model`` on ``loader`` and log summary metrics.

    Column 0 of the model output is converted to an absolute entropy with
    ``exp(pred * system_size)``; MSE, MAE and MAPE (in percent) are computed
    between that value and the targets.

    Args:
        model: Network called as ``model(batch)``; output has two columns.
        loader: Iterable of batches exposing ``to``, ``nA``, ``system_size``,
            ``y`` and ``num_graphs``.
        criterion: Callable ``(preds, targets, system_size, subsystem_size)``
            returning a scalar loss.
        device: Device every batch is moved to.
        name: Label used in the log output.

    Returns:
        Dict with keys ``loss``, ``mse_abs``, ``mae_abs``, ``mape_abs``,
        ``predictions_abs``, ``predictions_s_over_n`` and ``targets``. For an
        empty loader the loss is 0.0, the metrics are NaN and the arrays are
        empty.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0

    # Per-batch predictions and targets, concatenated after the loop
    all_preds_abs = []
    all_preds_s_over_n = []
    all_targets = []

    for data in loader:
        data = data.to(device)
        preds = model(data)

        # reshape(-1) rather than squeeze(): a batch holding a single graph
        # would otherwise become 0-dim and make torch.cat below raise.
        subsystem_size = data.nA.reshape(-1)
        system_size_ = data.system_size.reshape(-1)
        targets_ = data.y.reshape(-1)

        loss = criterion(preds, targets_, system_size_, subsystem_size)
        batch_size = data.num_graphs
        total_loss += loss.item() * batch_size
        n_samples += batch_size

        # Column 0 -> absolute entropy
        pred_entropy_abs = torch.exp(preds[:, 0] * system_size_)

        all_preds_abs.append(pred_entropy_abs.cpu())
        all_preds_s_over_n.append(preds[:, 1].cpu())
        all_targets.append(targets_.cpu())

    mean_loss = total_loss / n_samples if n_samples > 0 else 0.0

    if not all_targets:
        # torch.cat([]) raises, so report an explicit "no data" result.
        logging.warning(f"{name}: loader produced no batches; metrics are undefined.")
        empty = np.empty(0, dtype=np.float32)
        return {
            'loss': mean_loss,
            'mse_abs': float('nan'),
            'mae_abs': float('nan'),
            'mape_abs': float('nan'),
            'predictions_abs': empty,
            'predictions_s_over_n': empty.copy(),
            'targets': empty.copy()
        }

    all_preds_abs = torch.cat(all_preds_abs).numpy()
    all_preds_s_over_n = torch.cat(all_preds_s_over_n).numpy()
    all_targets = torch.cat(all_targets).numpy()

    # All three metrics derive from one error vector, accumulated in float64
    targets64 = all_targets.astype(np.float64)
    errors = all_preds_abs.astype(np.float64) - targets64
    mse_abs = float(np.mean(errors ** 2))
    mae_abs = float(np.mean(np.abs(errors)))
    # Small epsilon is added in the denominator to avoid division by zero
    mape_abs = float(np.mean(np.abs(errors / (targets64 + 1e-10))) * 100)

    logging.info(f"\n{name} Summary:")
    logging.info(f"  Loss: {mean_loss:.6f}")
    logging.info(f"  MSE (Absolute Entropy): {mse_abs:.6f}")
    logging.info(f"  MAE (Absolute Entropy): {mae_abs:.6f}")
    logging.info(f"  MAPE (Absolute Entropy): {mape_abs:.2f}%")

    return {
        'loss': mean_loss,
        'mse_abs': mse_abs,
        'mae_abs': mae_abs,
        'mape_abs': mape_abs,
        'predictions_abs': all_preds_abs,            # exp(column 0 * N)
        'predictions_s_over_n': all_preds_s_over_n,  # column 1
        'targets': all_targets                       # actual S
    }


# -----------------------------------------------------------
# Main
# -----------------------------------------------------------
def main():
    """Train ExperimentalGNN on the processed dataset and report validation metrics.

    Steps: configure logging and seeds, load the dataset, split it 80/20 with
    a seeded generator, build model / loss / optimizer / scheduler, train for
    up to ``CONFIG['num_epochs']`` epochs while checkpointing the best
    validation loss, stop early after ``CONFIG['patience']`` epochs without
    improvement, then reload the best checkpoint for a final validation pass.
    """
    setup_logging()
    set_seed(CONFIG['random_seed'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load dataset
    dataset = SpinSystemDataset(root=CONFIG['processed_dir'])
    if len(dataset) == 0:
        logging.error("Loaded dataset is empty. Exiting.")
        return

    # Standard 80/20 train/val split, reproducible through the seeded generator
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(CONFIG['random_seed'])
    )

    # Peek at one sample to read the feature dimensions
    sample_data = next(iter(DataLoader(train_dataset, batch_size=1)))
    model = ExperimentalGNN(
        num_node_features=sample_data.x.size(1),
        edge_attr_dim=sample_data.edge_attr.size(1),
        hidden_channels=CONFIG['hidden_channels'],
        dropout_p=CONFIG['dropout_p'],
        num_layers=CONFIG['num_layers']
    ).to(device)

    criterion = MultiTaskLoss(physics_weight=0.3, alpha_s_over_n=0.2)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=CONFIG['scheduler_factor'],
        patience=CONFIG['scheduler_patience']
    )

    def create_dataloader(ds, epoch, shuffle=False):
        """Build a loader; with shuffle=True the curriculum sampler is used when available."""
        sampler = None
        if shuffle:
            sampler = get_curriculum_sampler(
                ds, epoch, CONFIG['num_epochs'], alpha=CONFIG['curriculum_alpha']
            )
        # Shuffle only when requested and no sampler took over. Previously
        # the validation loader was shuffled because `shuffle` was ignored.
        return DataLoader(
            ds,
            batch_size=CONFIG['batch_size'],
            sampler=sampler,
            shuffle=(shuffle and sampler is None)
        )

    # The validation loader does not depend on the epoch, so build it once
    val_loader = create_dataloader(val_dataset, epoch=0, shuffle=False)

    best_val_loss = float('inf')
    epochs_without_improvement = 0

    # Training loop
    for epoch in range(CONFIG['num_epochs']):
        logging.info(f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
        # The curriculum weights depend on the epoch, so rebuild each time
        train_loader = create_dataloader(train_dataset, epoch, shuffle=True)

        # Train
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion,
            device, clip_grad=CONFIG['grad_clip']
        )
        logging.info(f"  Training Loss: {train_loss:.6f}")

        # Validate
        val_metrics = evaluate(model, val_loader, criterion, device, name='Validation')
        val_loss = val_metrics['loss']
        scheduler.step(val_loss)

        # Check if best
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), CONFIG['best_model_path'])
            logging.info(f"  [Info] New best model saved (val_loss={best_val_loss:.6f})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= CONFIG['patience']:
                logging.info(
                    f"  [Info] Early stopping: no improvement for "
                    f"{epochs_without_improvement} epochs."
                )
                break

    if best_val_loss == float('inf'):
        # Nothing was saved in this run (e.g. every validation loss was NaN);
        # do not load a stale checkpoint left by a previous run.
        logging.warning("No checkpoint was saved during this run; skipping final validation.")
        return

    logging.info("Training complete. Loading best model for final validation...")
    # The checkpoint is a plain state_dict, so tensors-only loading suffices
    state_dict = torch.load(CONFIG['best_model_path'], map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    _ = evaluate(model, val_loader, criterion, device, name='Final Validation')

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


2025-01-26 00:17:02 [INFO] Epoch 1/100
2025-01-26 00:18:56 [INFO]   Training Loss: 3884401148.076604
2025-01-26 00:19:12 [INFO] 
Validation Summary:
2025-01-26 00:19:12 [INFO]   Loss: 5.884071
2025-01-26 00:19:12 [INFO]   MSE (Absolute Entropy): 0.387598
2025-01-26 00:19:12 [INFO]   MAE (Absolute Entropy): 0.469329
2025-01-26 00:19:12 [INFO]   MAPE (Absolute Entropy): 1042.18%
2025-01-26 00:19:12 [INFO]   [Info] New best model saved (val_loss=5.884071)
2025-01-26 00:19:12 [INFO] Epoch 2/100
2025-01-26 00:20:33 [INFO]   Training Loss: 1171.676932
2025-01-26 00:20:45 [INFO] 
Validation Summary:
2025-01-26 00:20:45 [INFO]   Loss: 4.377482
2025-01-26 00:20:45 [INFO]   MSE (Absolute Entropy): 0.404510
2025-01-26 00:20:45 [INFO]   MAE (Absolute Entropy): 0.483312
2025-01-26 00:20:45 [INFO]   MAPE (Absolute Entropy): 95.07%
2025-01-26 00:20:45 [INFO]   [Info] New best model saved (val_loss=4.377482)
2025-01-26 00:20:45 [INFO] Epoch 3/100
2025-01-26 00:22:01 [INFO]   Training Loss: 53.469386
2

KeyboardInterrupt: 