<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_wormhole_network_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib

In [None]:
#!/usr/bin/env python3
"""
train_wormhole_network_ai.py

Physics‐informed WormholeNetworkAI pipeline:

 1. Synthetic dataset of 6 wormhole network parameters → 3 stability factors
 2. PINN loss: supervised MSE + toy “energy–topology” residual
 3. MLP with LayerNorm & Dropout for uncertainty
 4. MC‐Dropout inference to quantify predictive variance
 5. Training loop with AdamW, ReduceLROnPlateau, early stopping
 6. Visualizations: loss curves, true vs. predicted scatter, uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Dataset
# ------------------------------------------------------------------------------
class WormholeNetworkDataset(Dataset):
    """Synthetic, z-score-normalised regression dataset (6 inputs -> 3 targets).

    Inputs are drawn uniformly from [0.1, 10.0) and the targets are toy
    analytic functions of the input columns plus Gaussian noise
    (0.05 * standard normal):

        s0 = X[:, 0] * X[:, 1]
        s1 = X[:, 3] * X[:, 4]
        s2 = X[:, 2] + X[:, 5]

    Attributes:
        stats: dict with keys ``'X_mean'``, ``'X_std'``, ``'Y_mean'`` and
            ``'Y_std'`` holding the per-column statistics (torch tensors)
            that were used to normalise ``X`` and ``Y``.
        X: normalised inputs, one row per sample, 6 columns.
        Y: normalised targets, one row per sample, 3 columns.

    Args:
        n_samples: number of samples to generate.
        seed: seed for the private random generator; the same seed always
            yields the same dataset.
    """

    def __init__(self, n_samples=5000, seed=0):
        # A private RandomState produces exactly the same draws as the legacy
        # ``np.random.seed(seed)`` + ``np.random.*`` calls, but does not
        # reseed NumPy's global generator as a side effect of building a
        # dataset.
        rng = np.random.RandomState(seed)

        # Column roles consistent with the target formulas below:
        #   0 = exotic_energy, 1 = topo0, 2 = var0,
        #   3 = topo1,         4 = topo2, 5 = var1
        X_raw = rng.uniform(0.1, 10.0, size=(n_samples, 6)).astype(np.float32)

        # True stability factors (toy analytic):
        #   s0 = energy * topo0   -> columns 0 and 1
        #   s1 = topo1  * topo2   -> columns 3 and 4
        #   s2 = var0   + var1    -> columns 2 and 5
        s0 = X_raw[:, 0] * X_raw[:, 1]
        s1 = X_raw[:, 3] * X_raw[:, 4]
        s2 = X_raw[:, 2] + X_raw[:, 5]
        Y_raw = np.stack([s0, s1, s2], axis=1).astype(np.float32)
        # Additive observation noise on every target.
        Y_raw += 0.05 * rng.randn(*Y_raw.shape).astype(np.float32)

        # Convert to torch
        X_t = torch.from_numpy(X_raw)
        Y_t = torch.from_numpy(Y_raw)

        # Normalisation statistics are computed over the full dataset
        # (before any train/validation split done by the caller).
        self.stats = {
            'X_mean': X_t.mean(0), 'X_std': X_t.std(0),
            'Y_mean': Y_t.mean(0), 'Y_std': Y_t.std(0),
        }

        # Z-score normalise inputs and targets column-wise.
        self.X = (X_t - self.stats['X_mean']) / self.stats['X_std']
        self.Y = (Y_t - self.stats['Y_mean']) / self.stats['Y_std']

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(normalised_input, normalised_target)`` pair at ``idx``."""
        return self.X[idx], self.Y[idx]


# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class WormholeNetworkAI(nn.Module):
    """Fully connected regressor with LayerNorm and Dropout in every hidden stage.

    Each entry of ``hidden_dims`` contributes the stage
    ``Linear -> LayerNorm -> ReLU -> Dropout(p_drop)``; a final ``Linear``
    projects the last hidden width (or ``input_dim`` when there are no
    hidden stages) to ``output_dim``. All layers live in ``self.net``.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        # widths[i] -> widths[i + 1] describes the i-th hidden stage.
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Output head: plain linear map, no activation.
        stack.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its output."""
        return self.net(x)


# ------------------------------------------------------------------------------
# 3. Physics‐Informed Loss
# ------------------------------------------------------------------------------
def physics_residual(pred, inp, stats_torch):
    """
    Toy "physics" penalty on the sum of the de-normalised predictions.

    After undoing the z-score normalisation of both ``inp`` and ``pred``
    with ``stats_torch`` (keys ``'X_mean'``, ``'X_std'``, ``'Y_mean'``,
    ``'Y_std'``), the row-wise sum of the predictions is compared against

        X[:, 0]*X[:, 1] + X[:, 3]*X[:, 4] + (X[:, 2] + X[:, 5])

    and the mean squared difference over the batch is returned as a
    scalar tensor.
    """
    x_mean, x_std = stats_torch['X_mean'], stats_torch['X_std']
    y_mean, y_std = stats_torch['Y_mean'], stats_torch['Y_std']

    # Back to physical (un-normalised) units.
    raw_x = inp * x_std + x_mean
    raw_y = pred * y_std + y_mean

    # The three analytic contributions, one per target.
    first_product = raw_x[:, 0] * raw_x[:, 1]
    second_product = raw_x[:, 3] * raw_x[:, 4]
    plain_sum = raw_x[:, 2] + raw_x[:, 5]
    expected_total = first_product + second_product + plain_sum

    predicted_total = raw_y.sum(dim=1)
    return nn.functional.mse_loss(predicted_total, expected_total)

def total_loss(pred, true, inp, stats_torch, lambda_phys=1.0):
    """Combine the supervised MSE with the weighted physics residual.

    Returns a 3-tuple ``(combined, mse, phys)`` where ``mse`` is the mean
    squared error between ``pred`` and ``true``, ``phys`` is
    ``physics_residual(pred, inp, stats_torch)`` and
    ``combined = mse + lambda_phys * phys``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, inp, stats_torch)
    combined = data_term + lambda_phys * physics_term
    return combined, data_term, physics_term


# ------------------------------------------------------------------------------
# 4. MC‐Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, n_samples=50):
    """Estimate predictive mean and spread with Monte Carlo dropout.

    The model is temporarily switched to training mode so that its dropout
    layers stay stochastic, ``n_samples`` forward passes are run on ``x``
    without gradient tracking, and the per-element mean and standard
    deviation across those passes are returned.

    The training/eval flag of every sub-module is restored afterwards, so
    calling this on a model in eval mode leaves it in eval mode.

    Args:
        model: the ``nn.Module`` to sample from.
        x: input tensor, already on the same device as ``model``.
        n_samples: number of stochastic forward passes (must be >= 1).

    Returns:
        ``(mean, std)`` as NumPy arrays, each with the shape of a single
        model output. ``std`` is NumPy's default ``ndarray.std``.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    # Remember each module's mode: model.train() flips all of them, and the
    # caller should get the model back exactly as it was handed in.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            draws = [model(x).cpu().numpy() for _ in range(n_samples)]
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    stacked = np.stack(draws, axis=0)
    return stacked.mean(axis=0), stacked.std(axis=0)


# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def train_model(model, train_loader, val_loader, stats_torch, device,
                lr=1e-3, wd=1e-5, lambda_phys=1.0,
                max_epochs=200, patience=20):
    """Train ``model`` with the physics-informed loss and early stopping.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    evaluation pass over ``val_loader`` using ``total_loss``. The learning
    rate is halved by ``ReduceLROnPlateau`` after 5 epochs without
    validation improvement, and training stops once ``patience``
    consecutive epochs fail to improve the best validation loss by more
    than 1e-6.

    Side effects: prints one progress line per epoch, and writes the best
    weights to ``best_wormhole_net_ai.pth`` (relative path) every time the
    validation loss improves.

    On return the model holds the best weights seen during *this* call.
    They are restored from an in-memory copy rather than from the
    checkpoint file, so a run in which no epoch improved (for example a
    NaN validation loss, or ``max_epochs=0``) keeps its current weights
    instead of failing on a missing file or silently loading a stale
    checkpoint left by an earlier run.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: yields ``(inputs, targets)`` training batches.
        val_loader: yields ``(inputs, targets)`` validation batches.
        stats_torch: normalisation statistics forwarded to ``total_loss``.
        device: device on which batches and the model are placed.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lambda_phys: weight of the physics residual in the total loss.
        max_epochs: upper bound on the number of epochs.
        patience: early-stopping patience in epochs.

    Returns:
        dict with keys ``'train_loss'`` and ``'val_loss'``, each a list of
        per-epoch sample-weighted average losses.
    """
    import copy  # local import: only needed for the best-weights snapshot

    checkpoint_path = "best_wormhole_net_ai.pth"

    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min',
                                                factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    best_state = None  # snapshot of the best weights seen in this call
    history = {'train_loss':[], 'val_loss':[]}

    for epoch in range(1, max_epochs+1):
        # Training
        model.train()
        tr_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats_torch, lambda_phys)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch figure is a per-sample mean
            # even when the last batch is smaller.
            tr_loss += loss.item() * xb.size(0)
        tr_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats_torch, lambda_phys)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_loader.dataset)

        sched.step(val_loss)
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {tr_loss:.4e} | Val {val_loss:.4e}")

        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # Keep an independent copy: state_dict() tensors alias the live
            # parameters and would otherwise keep changing as training goes on.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), checkpoint_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # Restore only weights produced by this run; never read a checkpoint
    # file that this call did not write.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history


# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(h):
    """Show the per-epoch training and validation loss curves.

    ``h`` is a mapping with the keys ``'train_loss'`` and ``'val_loss'``,
    each holding a sequence of loss values.
    """
    plt.figure()
    for key, legend_name in (('train_loss', 'Train'), ('val_loss', 'Val')):
        plt.plot(h[key], label=legend_name)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Curve")
    plt.tight_layout()
    plt.show()

def plot_scatter(true_vals, pred_vals, label):
    """Scatter predicted against true values with a dashed identity line.

    The red dashed reference line spans the minimum and maximum of
    ``true_vals``; ``label`` is used in both axis labels and the title.
    """
    plt.figure()
    plt.scatter(true_vals, pred_vals, s=5, alpha=0.5)

    # Perfect predictions would fall on y = x over the range of the truth.
    lo = true_vals.min()
    hi = true_vals.max()
    plt.plot([lo, hi], [lo, hi], 'r--')

    plt.xlabel(f"True {label}")
    plt.ylabel(f"Pred {label}")
    plt.title(f"{label}: True vs Pred")
    plt.tight_layout()
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device):
    """Plot the MC-dropout standard deviation of output 0 on a 2-D grid.

    Input columns 0 and 1 (labelled "Exotic Energy" and "Topology₀" on the
    axes) are swept over a 100 x 100 grid on [0.1, 10]; columns 2-5 are
    held at their mean from ``stats_np['X_mean']``. The grid is normalised
    with ``stats_np``, pushed through ``mc_dropout_predict`` with 100
    passes, and the resulting standard deviation of the first output is
    drawn as a colour mesh.
    """
    x_mean = stats_np['X_mean']

    # Axes of the sweep and the corresponding 2-D coordinate matrices.
    energy_axis = np.linspace(0.1, 10, 100)
    topo_axis = np.linspace(0.1, 10, 100)
    EE, T0 = np.meshgrid(energy_axis, topo_axis)

    # One row per grid point: non-swept columns at their mean, swept
    # columns filled from the flattened meshgrid.
    grid = np.zeros((EE.size, 6), dtype=np.float32)
    grid[:, 2:6] = x_mean[2:6]
    grid[:, 0] = EE.ravel()
    grid[:, 1] = T0.ravel()

    # The model expects z-scored inputs.
    normalised = (grid - x_mean) / stats_np['X_std']
    batch = torch.from_numpy(normalised).to(device).float()

    _, spread = mc_dropout_predict(model, batch, n_samples=100)
    spread0 = spread[:, 0].reshape(EE.shape)

    plt.figure(figsize=(6, 5))
    plt.pcolormesh(EE, T0, spread0, shading='auto', cmap='viridis')
    plt.colorbar(label="Std of stability₀")
    plt.xlabel("Exotic Energy")
    plt.ylabel("Topology₀")
    plt.title("Uncertainty Heatmap (stability₀)")
    plt.tight_layout()
    plt.show()


# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
# Script entry point: build the synthetic dataset, train the network, then
# show the loss curves, the true-vs-predicted scatters and the uncertainty map.
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 8000 synthetic samples; seed is left at the dataset's default.
    ds = WormholeNetworkDataset(n_samples=8000)
    # Normalisation stats as tensors on `device`, consumed by the loss.
    stats_torch = {k:v.to(device) for k,v in ds.stats.items()}
    # The same stats as NumPy arrays, consumed by the plotting helpers.
    stats_np    = {k:v.cpu().numpy() for k,v in ds.stats.items()}

    # 80/20 train/validation split. No generator is passed to random_split,
    # so the split is not seeded by this script.
    n_val = int(0.2 * len(ds))
    n_trn = len(ds) - n_val
    trn, val = random_split(ds, [n_trn, n_val])
    # Only the training loader shuffles; validation order is irrelevant.
    trn_ld = DataLoader(trn, batch_size=128, shuffle=True)
    val_ld = DataLoader(val, batch_size=256)

    # Default architecture (see WormholeNetworkAI's constructor defaults).
    model = WormholeNetworkAI().to(device)
    history = train_model(
        model, trn_ld, val_ld,
        stats_torch, device,
        lr=1e-3, wd=1e-5, lambda_phys=1.0,
        max_epochs=200, patience=20
    )

    plot_history(history)

    # Scatter true vs predicted for each stability factor. Predictions are
    # made on the whole dataset (training and validation samples together).
    # NOTE(review): there is no explicit model.eval() here; this relies on
    # the mode train_model leaves the model in - confirm that is intended.
    X_all = ds.X.to(device)
    with torch.no_grad():
        Y_pred_norm = model(X_all).cpu().numpy()
    Y_true_norm = ds.Y.numpy()

    # Undo the target normalisation so the plots are in original units.
    Y_pred = Y_pred_norm * stats_np['Y_std'] + stats_np['Y_mean']
    Y_true = Y_true_norm * stats_np['Y_std'] + stats_np['Y_mean']

    # One figure per output column.
    for i, name in enumerate(["stability₀","stability₁","stability₂"]):
        plot_scatter(Y_true[:,i], Y_pred[:,i], name)

    # MC-dropout spread of the first output over input columns 0 and 1.
    plot_uncertainty_heatmap(model, stats_np, device)