<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_wormhole_stability_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_wormhole_stability.py

Physics-informed pipeline for traversable wormhole stability:

1. Synthetic dataset of 6 wormhole parameters → 3 stability metrics
2. PINN residual enforcing throat flare-out condition
3. MLP with LayerNorm & Dropout for uncertainty quantification
4. MC-Dropout inference
5. Training loop with AdamW, ReduceLROnPlateau, early stopping
6. Visualization: loss, true vs pred scatter, uncertainty heatmap
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Wormhole Dataset
# ------------------------------------------------------------------------------
class WormholeDataset(Dataset):
    """Synthetic toy dataset mapping 6 wormhole parameters to 3 stability metrics.

    Input columns of ``X`` (in order): exotic_density, energy_density,
    curvature, pressure, charge, spin. Target columns of ``Y``: three toy
    stability metrics s1, s2, s3 with additive Gaussian noise (sigma 0.02).
    Both arrays are z-score normalised per column. The statistics are kept
    on the instance (``X_mean``, ``X_std``, ``Y_mean``, ``Y_std``) so callers
    can denormalise.

    Args:
        n_samples: number of rows to generate (must be >= 1).
        seed: seed for a *private* ``np.random.RandomState``. The global NumPy
            RNG is left untouched.

    Raises:
        ValueError: if ``n_samples < 1``.
    """

    def __init__(self, n_samples=6000, seed=0):
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        # A private legacy RandomState yields exactly the stream that
        # np.random.seed(seed) followed by np.random.* calls would. It does so
        # without reseeding the global RNG for the rest of the program.
        rng = np.random.RandomState(seed)

        def draw(lo, hi):
            # One (n_samples, 1) float32 column, uniform on [lo, hi).
            return rng.uniform(lo, hi, (n_samples, 1)).astype(np.float32)

        # Draw order matters for reproducibility: keep it fixed.
        ed = draw(0.1, 10.0)   # exotic_density
        en = draw(0.1, 10.0)   # energy_density
        cu = draw(-1.0, 1.0)   # curvature
        pr = draw(0.0, 5.0)    # pressure
        ch = draw(0.0, 2.0)    # charge
        sp = draw(-1.0, 1.0)   # spin
        X_raw = np.hstack([ed, en, cu, pr, ch, sp])

        # Supervised targets (toy closed-form relations)
        stability_1 = ed * np.exp(-cu) + 0.1 * pr
        stability_2 = np.sin(en) * sp + 0.05 * ch
        stability_3 = (ed + en) / (1 + np.abs(cu))
        Y_raw = np.hstack([stability_1, stability_2, stability_3]).astype(np.float32)
        Y_raw += 0.02 * rng.randn(*Y_raw.shape).astype(np.float32)

        # Per-column (1D) statistics; the epsilon guards against zero std.
        self.X_mean, self.X_std = X_raw.mean(0), X_raw.std(0) + 1e-6
        self.Y_mean, self.Y_std = Y_raw.mean(0), Y_raw.std(0) + 1e-6

        self.X = (X_raw - self.X_mean) / self.X_std
        self.Y = (Y_raw - self.Y_mean) / self.Y_std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Returns (x, y) float32 tensors that share memory with the arrays.
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class WormholeAI(nn.Module):
    """Fully connected regressor: [Linear -> LayerNorm -> ReLU -> Dropout]* -> Linear.

    The dropout layers make the network usable for MC-Dropout uncertainty
    estimates. All layers live in ``self.net`` (an ``nn.Sequential``), so
    ``state_dict`` keys are ``net.<index>.*``.

    Args:
        input_dim: number of input features.
        hidden_dims: width of each hidden block, in order (may be empty).
        output_dim: number of regression outputs.
        p_drop: dropout probability used in every hidden block.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        # Pair consecutive widths: (in, h1), (h1, h2), ...
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Map a batch of inputs to raw (normalised-space) predictions."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Residual
# ------------------------------------------------------------------------------

def physics_residual(pred, inp, stats):
    """Toy physics penalty evaluated on denormalised inputs and predictions.

    Inputs and predictions are mapped back to physical units with the tensors
    in ``stats``. The function then sums three mean-squared errors between each
    predicted metric and a closed-form "physics" target:

    * s1 vs. the flare-out proxy ``ed - cu**2``
    * s2 vs. ``sin(en) * sp``
    * s3 vs. ``(ed + en) / (1 + |cu|)``

    NOTE(review): the s1 and s2 targets differ from the formulas
    ``WormholeDataset`` uses to generate the data. That generator uses
    ``ed*exp(-cu) + 0.1*pr`` for s1 and adds ``0.05*ch`` for s2, so the data
    term and this term pull toward different optima. Also, this residual is
    measured in physical units while the data MSE is on normalised targets.
    Confirm both are intended.

    Args:
        pred: normalised predictions; columns 0..2 are s1, s2, s3.
        inp: normalised inputs; columns 0..5 are ed, en, cu, pr, ch, sp.
        stats: dict with 'X_mean', 'X_std', 'Y_mean', 'Y_std' tensors that
            broadcast against ``inp`` / ``pred``.

    Returns:
        Scalar tensor: sum of the three MSE terms.
    """
    mse = nn.functional.mse_loss
    phys_in = inp * stats['X_std'] + stats['X_mean']
    phys_out = pred * stats['Y_std'] + stats['Y_mean']
    ed, en, cu, pr, ch, sp = phys_in[:, :6].unbind(dim=1)
    s1, s2, s3 = phys_out[:, :3].unbind(dim=1)

    flare_target = ed - cu ** 2                  # stand-in for flare-out
    spin_target = torch.sin(en) * sp             # energy-spin stability
    combined_target = (ed + en) / (1 + cu.abs())
    return (mse(s1, flare_target)
            + mse(s2, spin_target)
            + mse(s3, combined_target))

def total_loss(pred, true, inp, stats, λ=1.0):
    """Combine the supervised MSE with the weighted physics residual.

    Args:
        pred: normalised model predictions.
        true: normalised targets, same shape as ``pred``.
        inp: normalised inputs used by the physics residual.
        stats: normalisation statistics forwarded to ``physics_residual``.
        λ: weight applied to the physics term.

    Returns:
        Tuple ``(total, data_mse, physics)``. ``total`` equals
        ``data_mse + λ * physics``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, inp, stats)
    combined = data_term + λ * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, T=50):
    """Monte-Carlo Dropout inference.

    Runs ``T`` stochastic forward passes with dropout active. Returns the
    per-element sample mean and (unbiased) standard deviation.

    Only dropout submodules are put in training mode. Every other module runs
    in eval mode, so layers with running statistics such as BatchNorm are not
    updated. Each module's original ``training`` flag is restored afterwards,
    even if a forward pass raises. The caller's model mode is therefore
    unchanged.

    Args:
        model: network containing dropout layers.
        x: input batch, already on the model's device.
        T: number of stochastic passes (must be >= 1).

    Returns:
        ``(mean, std)``, each shaped like ``model(x)``. With ``T == 1``, std is
        all zeros, because a single sample carries no spread information.

    Raises:
        ValueError: if ``T < 1``.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    # Snapshot every submodule's mode so it can be restored exactly.
    saved_modes = [(m, m.training) for m in model.modules()]
    model.eval()
    for m in model.modules():
        # _DropoutNd is the common base of Dropout, Dropout1d/2d/3d, AlphaDropout, ...
        if isinstance(m, nn.modules.dropout._DropoutNd):
            m.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(x) for _ in range(T)], dim=0)  # [T, *out]
    finally:
        for m, was_training in saved_modes:
            m.training = was_training

    mean = samples.mean(0)
    std = samples.std(0) if T > 1 else torch.zeros_like(mean)
    return mean, std

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def _run_epoch(model, loader, stats, device, λ, opt=None):
    """Run one pass over ``loader``: optimise when ``opt`` is given, else evaluate.

    Returns the sample-weighted mean of ``total_loss`` over the pass.
    Raises ValueError if the loader yields no samples.
    """
    training = opt is not None
    model.train(training)
    total, seen = 0.0, 0
    # Gradients are only needed for the optimisation pass.
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss, _, _ = total_loss(model(xb), yb, xb, stats, λ)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total += loss.item() * xb.size(0)
            seen += xb.size(0)
    if seen == 0:
        raise ValueError("loader yielded no samples")
    return total / seen


def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, λ=1.0, max_epochs=150, patience=15,
          ckpt_path="best_wormhole_ai.pth"):
    """Train ``model`` with AdamW, ReduceLROnPlateau and early stopping.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    evaluation pass over ``val_loader``, both scored with ``total_loss``. The
    weights with the lowest validation loss are kept in memory and restored
    into ``model`` before returning. They are also written to ``ckpt_path``
    whenever they improve, unless ``ckpt_path`` is None. If no epoch ever
    improves (e.g. the loss is NaN from the start), the model keeps its final
    weights rather than loading a stale checkpoint from disk.

    Args:
        model: network to train (moved to ``device``).
        train_loader, val_loader: loaders yielding ``(x, y)`` batches.
        stats: normalisation statistics forwarded to ``total_loss``.
        device: torch device to train on.
        lr, wd: AdamW learning rate and weight decay.
        λ: weight of the physics residual in the loss.
        max_epochs: upper bound on epochs.
        patience: epochs without improvement before stopping early.
        ckpt_path: where to save the best state_dict; None disables saving.

    Returns:
        ``{'train': [...], 'val': [...]}``: per-epoch mean losses.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min',
                                                 factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    best_state = None
    history = {'train': [], 'val': []}

    for epoch in range(1, max_epochs + 1):
        tr_loss = _run_epoch(model, train_loader, stats, device, λ, opt)
        val_loss = _run_epoch(model, val_loader, stats, device, λ)

        sched.step(val_loss)
        history['train'].append(tr_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {tr_loss:.4e} | Val {val_loss:.4e}")

        # Small tolerance so float noise doesn't count as an improvement.
        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # Detached copy: later optimiser steps must not alter the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if ckpt_path is not None:
                torch.save(model.state_dict(), ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(h):
    """Show the per-epoch train and validation loss curves.

    Args:
        h: dict with 'train' and 'val' lists of per-epoch losses.
    """
    plt.figure()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        plt.plot(h[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Curve")
    plt.show()

def plot_scatter(y_true, y_pred, name):
    """Scatter predicted vs. true values, with a dashed red y = x reference line.

    Args:
        y_true: 1-D array of true values (it also sets the reference-line span).
        y_pred: 1-D array of predictions, same length as ``y_true``.
        name: metric name used in the axis labels and title.
    """
    lo, hi = y_true.min(), y_true.max()
    plt.figure()
    plt.scatter(y_true, y_pred, s=5, alpha=0.6)
    diagonal = [lo, hi]
    plt.plot(diagonal, diagonal, 'r--')
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Pred {name}")
    plt.title(name)
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device, n_grid=100, T=100,
                             ed_range=(0.1, 10.0), en_range=(0.1, 10.0)):
    """Heatmap of MC-Dropout std of s1 over an exotic/energy density grid.

    Exotic density (column 0) and energy density (column 1) are swept over an
    ``n_grid`` x ``n_grid`` grid. All other inputs are held at their dataset
    means. The std is converted from normalised target units back to physical
    s1 units by scaling with ``stats_np['Y_std'][0]``.

    Args:
        model: trained network.
        stats_np: dict of numpy normalisation stats ('X_mean', 'X_std', 'Y_std').
        device: device the model lives on.
        n_grid: grid points per axis.
        T: number of MC-Dropout samples.
        ed_range, en_range: (low, high) sweep ranges in physical units.
    """
    ED, EN = np.meshgrid(np.linspace(*ed_range, n_grid),
                         np.linspace(*en_range, n_grid))
    # Every row starts at the feature means; then the two swept columns are overwritten.
    grid = np.tile(np.asarray(stats_np['X_mean'], dtype=np.float32), (ED.size, 1))
    grid[:, 0], grid[:, 1] = ED.ravel(), EN.ravel()

    Xn = torch.from_numpy((grid - stats_np['X_mean']) / stats_np['X_std']).float().to(device)
    _, std = mc_dropout_predict(model, Xn, T=T)
    # Predictions are normalised, so scale std back to physical s1 units.
    std0 = (std[:, 0].cpu().numpy() * stats_np['Y_std'][0]).reshape(ED.shape)

    plt.figure(figsize=(6, 5))
    plt.pcolormesh(ED, EN, std0, cmap='viridis', shading='auto')
    plt.colorbar(label='Std(s1)')
    plt.xlabel('Exotic Density')
    plt.ylabel('Energy Density')
    plt.title('Uncertainty Heatmap')
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Seed torch so that weight init, DataLoader shuffling and the split are
    # reproducible. The dataset uses its own numpy seed.
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ds = WormholeDataset(6000)
    stats_np = {
        'X_mean': ds.X_mean, 'X_std': ds.X_std,
        'Y_mean': ds.Y_mean, 'Y_std': ds.Y_std,
    }
    # The same statistics as tensors on the training device (for the loss).
    stats = {k: torch.tensor(v, device=device) for k, v in stats_np.items()}

    n_val = int(0.2 * len(ds))
    tr, va = random_split(ds, [len(ds) - n_val, n_val],
                          generator=torch.Generator().manual_seed(0))
    tr_ld = DataLoader(tr, batch_size=128, shuffle=True)
    va_ld = DataLoader(va, batch_size=256)

    model = WormholeAI().to(device)
    history = train(model, tr_ld, va_ld, stats, device)

    plot_history(history)

    # True-vs-pred scatter on the held-out validation rows only, so the plot
    # reflects generalisation rather than fit to training data.
    model.eval()
    val_idx = np.asarray(va.indices)
    X_val = torch.from_numpy(ds.X[val_idx]).float().to(device)
    with torch.no_grad():
        Yp = model(X_val).cpu().numpy()
    Yt = ds.Y[val_idx]
    for i, name in enumerate(['s1', 's2', 's3']):
        plot_scatter(Yt[:, i], Yp[:, i], name)

    plot_uncertainty_heatmap(model, stats_np, device)