<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_lqg_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib

In [None]:
#!/usr/bin/env python3
"""
train_lqg_ai.py

Physics-informed LQGAI pipeline:

 1. Synthetic SpinNetworkDataset sampling 5 spin parameters
 2. MLP with LayerNorm & Dropout
 3. PINN loss: MSE + residual of toy Hamiltonian constraint
 4. MC-Dropout inference for uncertainty
 5. Training loop with AdamW, ReduceLROnPlateau, early stopping
 6. Visualization of losses, true vs predicted scatter, and uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Dataset
# ------------------------------------------------------------------------------

def toy_hamiltonian_constraint(spins, evo):
    """
    Evaluate the toy Hamiltonian constraint for a batch.

    For every row the value is  sum_i spin_i^2 * evo_i,  where i runs over
    the columns of ``evo``. Spin columns beyond ``evo.size(1)`` are ignored.

    Args:
        spins: 2-D tensor, one row per sample; it is sliced to the first
               ``evo.size(1)`` columns.
        evo:   2-D tensor, one row per sample.

    Returns:
        1-D tensor holding one constraint value per sample.
    """
    n_components = evo.size(1)
    # keep only as many spins as there are evolution components
    squared_spins = spins[:, :n_components].pow(2)
    return torch.sum(squared_spins * evo, dim=1)

class SpinNetworkDataset(Dataset):
    """
    Synthetic (spins -> evolution) regression dataset.

    Inputs are 5 spin parameters per sample; targets are 3 values produced by
    a randomly initialised linear map of the spins plus Gaussian noise.
    Both inputs and targets are stored standardised (per-column zero mean,
    unit std); the statistics needed to undo that are kept in ``self.stats``
    under the keys 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Args:
        n_samples: number of samples to generate.
        seed: value passed to ``torch.manual_seed``. Note that this seeds the
              *global* torch RNG, not a private generator.
    """

    def __init__(self, n_samples=5000, seed=42):
        torch.manual_seed(seed)

        # 5 spin parameters per sample, uniform on [0.5, 2.5)
        raw_inputs = torch.rand(n_samples, 5) * 2.0 + 0.5

        # A freshly initialised linear layer serves as the "true" evolution
        projection = nn.Linear(5, 3)
        with torch.no_grad():
            clean_targets = projection(raw_inputs)

        # observation noise with standard deviation 0.05
        noise = torch.randn_like(clean_targets)
        raw_targets = clean_targets + 0.05 * noise

        # per-column statistics, taken before normalisation
        x_mean, x_std = raw_inputs.mean(0), raw_inputs.std(0)
        y_mean, y_std = raw_targets.mean(0), raw_targets.std(0)
        self.stats = {
            'X_mean': x_mean,
            'X_std':  x_std,
            'Y_mean': y_mean,
            'Y_std':  y_std,
        }

        # store the standardised tensors
        self.X = (raw_inputs - x_mean) / x_std
        self.Y = (raw_targets - y_mean) / y_std

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample_x = self.X[idx]
        sample_y = self.Y[idx]
        return sample_x, sample_y


# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------

class LQGAI(nn.Module):
    """
    Fully connected regressor.

    Every hidden stage is Linear -> LayerNorm -> ReLU -> Dropout; a final
    Linear layer maps to ``output_dim``. All layers live in ``self.net``
    (an ``nn.Sequential``) in exactly that order.

    Args:
        input_dim: number of input features.
        hidden_dims: iterable of hidden-layer widths; may be empty, in which
                     case the network is a single Linear layer.
        output_dim: number of outputs.
        p_drop: dropout probability used after every hidden stage.
    """

    def __init__(self, input_dim=5, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages = []
        # walk over consecutive (fan_in, fan_out) pairs of the hidden stack
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # output head: last hidden width (or input_dim if no hidden layers)
        stages.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the sequential stack to ``x`` and return its output."""
        return self.net(x)


# ------------------------------------------------------------------------------
# 3. Physics-Informed Loss
# ------------------------------------------------------------------------------

def physics_residual(pred, inp, stats_torch):
    """
    Mean squared violation of the toy Hamiltonian constraint.

    Inputs and predictions arrive in normalised units, so both are mapped
    back to raw units first; the constraint is then evaluated per sample and
    its square is averaged over the batch.

    Args:
        pred: normalised model output for the batch.
        inp: normalised model input for the batch.
        stats_torch: mapping with 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor: mean of the squared constraint values.
    """
    x_std, x_mean = stats_torch['X_std'], stats_torch['X_mean']
    y_std, y_mean = stats_torch['Y_std'], stats_torch['Y_mean']

    # undo the standardisation
    raw_spins = inp * x_std + x_mean
    raw_evo = pred * y_std + y_mean

    violation = toy_hamiltonian_constraint(raw_spins, raw_evo)
    return violation.pow(2).mean()

def total_loss(pred, true, inp, stats_torch, lambda_phys=1.0):
    """
    Combine the data-fit and physics terms.

    Args:
        pred: model output (normalised).
        true: regression target (normalised).
        inp: model input (normalised); forwarded to ``physics_residual``.
        stats_torch: normalisation statistics forwarded to ``physics_residual``.
        lambda_phys: weight applied to the physics term.

    Returns:
        Tuple ``(total, mse, phys)`` where ``total = mse + lambda_phys * phys``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, inp, stats_torch)
    combined = data_term + lambda_phys * physics_term
    return combined, data_term, physics_term


# ------------------------------------------------------------------------------
# 4. MC-Dropout Inference
# ------------------------------------------------------------------------------

def mc_dropout_predict(model, x, n_samples=50):
    """
    Monte-Carlo dropout inference.

    Runs ``n_samples`` forward passes with the model switched to training
    mode (so dropout layers stay stochastic) and summarises the draws.
    The training/eval flag of every sub-module is restored afterwards, so
    calling this function does not leave the model in training mode.

    Args:
        model: ``nn.Module`` to sample from.
        x: input passed to ``model`` unchanged on every pass.
        n_samples: number of stochastic forward passes; must be >= 1.

    Returns:
        Tuple ``(mean, std)`` of NumPy arrays, each with the shape of a
        single model output, taken across the ``n_samples`` draws.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    # remember each module's mode so it can be put back exactly as it was
    previous_modes = [(module, module.training) for module in model.modules()]
    model.train()  # keep dropout active
    try:
        with torch.no_grad():
            draws = [model(x).cpu().numpy() for _ in range(n_samples)]
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    arr = np.stack(draws, axis=0)
    return arr.mean(axis=0), arr.std(axis=0)


# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------

def train_model(model, train_loader, val_loader, stats_torch, device,
                lr=1e-3, weight_decay=1e-5, lambda_phys=1.0,
                max_epochs=200, patience=20):
    """
    Train ``model`` on the combined data + physics loss with early stopping.

    Each epoch performs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``. The learning rate is halved by
    ``ReduceLROnPlateau`` after 5 epochs without validation improvement, and
    training stops once ``patience`` consecutive epochs fail to improve the
    best validation loss by more than 1e-6.

    The best weights are kept in memory and also written to
    "best_lqg_ai.pth" in the current working directory. On return the model
    holds the best weights; if no epoch ever improved (e.g. the validation
    loss was NaN throughout, or ``max_epochs`` is 0) the model is left as it
    is instead of reading a missing or stale checkpoint file.

    Args:
        model: network to train; moved to ``device``.
        train_loader: loader yielding ``(x, y)`` batches for optimisation.
        val_loader: loader yielding ``(x, y)`` batches for validation.
        stats_torch: normalisation statistics forwarded to ``total_loss``.
                     They are not moved here, so they must already be on
                     ``device``.
        device: device on which batches and the model are placed.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        lambda_phys: weight of the physics term in ``total_loss``.
        max_epochs: upper bound on the number of epochs.
        patience: early-stopping patience in epochs.

    Returns:
        Dict with keys 'train_loss' and 'val_loss', each a list with one
        float per completed epoch.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    best_state = None  # in-memory copy of the best weights seen in THIS run
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(1, max_epochs+1):
        # training
        model.train()
        train_accum = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats_torch, lambda_phys)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # weight by batch size so a short final batch is averaged correctly
            train_accum += loss.item() * xb.size(0)
        train_loss = train_accum / len(train_loader.dataset)

        # validation
        model.eval()
        val_accum = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats_torch, lambda_phys)
                val_accum += loss.item() * xb.size(0)
        val_loss = val_accum / len(val_loader.dataset)

        scheduler.step(val_loss)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # state_dict() returns live references, so clone before keeping
            best_state = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }
            torch.save(best_state, "best_lqg_ai.pth")
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # Restore from memory rather than from disk: the file may not exist, or
    # may belong to an earlier run, when no epoch improved.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history


# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------

def plot_history(hist):
    """
    Draw the training and validation loss curves and show the figure.

    Args:
        hist: mapping with sequences under 'train_loss' and 'val_loss'.
    """
    curves = (
        ('train_loss', 'Train Loss'),
        ('val_loss', 'Val Loss'),
    )
    for key, label in curves:
        plt.plot(hist[key], label=label)

    plt.title("Training Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(true_vals, pred_vals, name):
    """
    Scatter predicted against true values, with a dashed red y = x line.

    Args:
        true_vals: array of reference values (x axis); its min and max set
                   the extent of the diagonal.
        pred_vals: array of predicted values (y axis).
        name: quantity name used in the axis labels and the title.
    """
    lo = true_vals.min()
    hi = true_vals.max()
    diagonal = [lo, hi]

    plt.scatter(true_vals, pred_vals, s=5, alpha=0.5)
    # perfect predictions would fall on this line
    plt.plot(diagonal, diagonal, 'r--')
    plt.title(f"{name}: True vs Pred")
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Pred {name}")
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device, dim_indices=(0,1)):
    """
    Plot the MC-dropout standard deviation of output 0 over a 2-D input slice.

    Two input dimensions are swept over a 100 x 100 grid spanning [-2, 2] in
    *normalised* units (i.e. standard deviations around the mean); every
    other input dimension is held at 0, the normalised mean. Because the
    grid is already expressed in the model's normalised input space it is
    fed to the model directly; normalising it a second time would evaluate
    the model away from the region shown on the axes.

    Args:
        model: trained network, sampled through ``mc_dropout_predict``.
        stats_np: normalisation statistics; only the length of 'X_mean' is
                  used, to size the input vectors.
        device: device on which the grid tensor is placed.
        dim_indices: pair of input-dimension indices to sweep (x axis,
                     y axis).
    """
    idx0, idx1 = dim_indices
    vals0 = np.linspace(-2, 2, 100)
    vals1 = np.linspace(-2, 2, 100)

    G0, G1 = np.meshgrid(vals0, vals1)

    # one row per grid point; unswept dimensions stay at the normalised mean
    n_features = len(stats_np['X_mean'])
    grid = np.zeros((G0.size, n_features))
    grid[:, idx0] = G0.ravel()
    grid[:, idx1] = G1.ravel()

    # grid is already in normalised coordinates, so no further scaling
    Xtorch = torch.from_numpy(grid).float().to(device)

    _, std = mc_dropout_predict(model, Xtorch, n_samples=50)
    heat   = std[:,0].reshape(G0.shape)

    plt.pcolormesh(vals0, vals1, heat, shading='auto', cmap='magma')
    plt.colorbar(label="Std of evo_0")
    plt.xlabel(f"Norm Spin {idx0}")
    plt.ylabel(f"Norm Spin {idx1}")
    plt.title("Uncertainty Heatmap (evo_0)")
    plt.show()


# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------

# Script entry point: build the synthetic dataset, train the model, then show
# the loss curves, a true-vs-predicted scatter for output 0 and an
# MC-dropout uncertainty heatmap.
if __name__ == "__main__":
    # Prefer CUDA when it is available, otherwise run on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data
    # SpinNetworkDataset calls torch.manual_seed (default seed 42), so the
    # random_split and model initialisation below draw from that seeded
    # global RNG.
    dataset = SpinNetworkDataset(n_samples=5000)
    # Same normalisation statistics in two forms: tensors on `device` for the
    # loss, NumPy arrays for plotting / de-normalising on the host.
    stats_torch = {k: v.to(device)         for k, v in dataset.stats.items()}
    stats_np    = {k: v.cpu().numpy()      for k, v in dataset.stats.items()}

    # Split dataset
    # 20% of the samples are held out for validation, the rest are for training
    n_val   = int(0.2 * len(dataset))
    n_trn   = len(dataset) - n_val
    trn_ds, val_ds = random_split(dataset, [n_trn, n_val])
    # only the training loader shuffles
    trn_loader     = DataLoader(trn_ds, batch_size=64, shuffle=True)
    val_loader     = DataLoader(val_ds,   batch_size=128)

    # Build, train, and visualize
    model   = LQGAI().to(device)
    history = train_model(
        model, trn_loader, val_loader,
        stats_torch, device,
        lr=1e-3, weight_decay=1e-5,
        lambda_phys=1.0, max_epochs=200, patience=20
    )

    plot_history(history)

    # True vs Pred for evo_0
    # NOTE: predictions are made on the full dataset (training and validation
    # samples together), not on the validation split alone.
    Xall = dataset.X.to(device)
    with torch.no_grad():
        Ypred_norm = model(Xall).cpu().numpy()
    Ytrue_norm   = dataset.Y.numpy()
    # map column 0 of predictions and targets back to raw (un-normalised) units
    evo0_pred = Ypred_norm[:, 0] * stats_np['Y_std'][0] + stats_np['Y_mean'][0]
    evo0_true = Ytrue_norm[:, 0]   * stats_np['Y_std'][0] + stats_np['Y_mean'][0]
    plot_scatter(evo0_true, evo0_pred, "Evolution Factor 0")

    # Uncertainty heatmap over spin dims 0 & 1
    plot_uncertainty_heatmap(model, stats_np, device, dim_indices=(0,1))