<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_wormhole_comm_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib

In [None]:
#!/usr/bin/env python3
"""
train_wormhole_comm_ai.py

Physics-informed AI pipeline for optimizing wormhole communication stability:

1. Synthetic dataset of 5 wormhole parameters → 3 stability metrics
2. Per-feature z-score normalization with 1D mean/std arrays (avoids broadcasting shape mismatches)
3. MLP with LayerNorm & Dropout to capture uncertainty
4. Physics‐informed residual enforcing empirical time‐delay law
5. MC-Dropout inference for predictive confidence
6. Training loop with AdamW, LR scheduler & early stopping
7. Visualizations: loss curves, true vs predicted, uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Wormhole Dataset
# ------------------------------------------------------------------------------
class WormholeCommDataset(Dataset):
    """Synthetic mapping from 5 wormhole parameters to 3 stability metrics.

    Input columns of ``X`` (in order), drawn uniformly:
        d_exotic in [0.1, 10), r_throat in [1, 10), k_curv in [-0.5, 0.5),
        z_factor in [0.5, 5), sigma_aniso in [0, 1).

    Target columns of ``Y`` (in order), each with N(0, 0.02^2) noise added:
        td  = r_throat / (d_exotic * (1 + z_factor) + 1e-6)
        att = exp(-d_exotic * sigma_aniso)
        st  = d_exotic / (|k_curv| + 1)

    Both ``X`` and ``Y`` are z-scored column-wise. The 1D statistics are
    exposed as ``X_mean``/``X_std`` (length 5) and ``Y_mean``/``Y_std``
    (length 3); a 1e-6 offset is added to each std to avoid division by zero.

    Args:
        n_samples: number of synthetic rows to generate.
        seed: seed for a private NumPy ``RandomState``. NumPy's global RNG is
            left untouched, but the draws are identical to calling
            ``np.random.seed(seed)`` followed by the same sampling calls.
    """

    def __init__(self, n_samples=5000, seed=42):
        # Private legacy generator: same MT19937 seeding as np.random.seed,
        # so the data are unchanged, without clobbering global RNG state.
        rng = np.random.RandomState(seed)

        def draw(lo, hi):
            # One float32 column of uniform samples.
            return rng.uniform(lo, hi, (n_samples, 1)).astype(np.float32)

        # Draw order matters for reproducibility; keep it fixed.
        d_exotic = draw(0.1, 10.0)
        r_throat = draw(1.0, 10.0)
        k_curv = draw(-0.5, 0.5)
        z_factor = draw(0.5, 5.0)
        sigma_aniso = draw(0.0, 1.0)

        X_raw = np.hstack([d_exotic, r_throat, k_curv, z_factor, sigma_aniso])

        # Empirical targets
        td = (r_throat / (d_exotic * (1 + z_factor) + 1e-6)).astype(np.float32)
        att = np.exp(-d_exotic * sigma_aniso).astype(np.float32)
        st = (d_exotic / (np.abs(k_curv) + 1.0)).astype(np.float32)

        Y_raw = np.hstack([td, att, st])
        Y_raw += 0.02 * rng.randn(*Y_raw.shape).astype(np.float32)

        # 1D per-column normalization stats
        self.X_mean = X_raw.mean(axis=0)
        self.X_std = X_raw.std(axis=0) + 1e-6
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std = Y_raw.std(axis=0) + 1e-6

        # Normalized arrays served by __getitem__
        self.X = (X_raw - self.X_mean) / self.X_std
        self.Y = (Y_raw - self.Y_mean) / self.Y_std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Row views wrapped as tensors (shares memory with self.X / self.Y).
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

# ------------------------------------------------------------------------------
# 2. MLP with Dropout & LayerNorm
# ------------------------------------------------------------------------------
class WormholeCommAI(nn.Module):
    """Plain MLP regressor: [Linear -> LayerNorm -> ReLU -> Dropout] per hidden width, then a Linear head.

    Args:
        input_dim: number of input features.
        hidden_dims: widths of the hidden layers, in order (may be empty).
        output_dim: number of regression outputs.
        p_drop: dropout probability used after every hidden layer.

    All layers live in ``self.net`` (an ``nn.Sequential``), so state-dict keys
    look like ``net.0.weight``.
    """

    def __init__(self, input_dim=5, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        # Consecutive (fan_in, fan_out) pairs define each hidden stage.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Linear read-out from the last hidden width (or the input if none).
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        out = self.net(x)
        return out

# ------------------------------------------------------------------------------
# 3. Physics-Informed Residual Loss
# ------------------------------------------------------------------------------
def physics_residual(pred, inp, stats, normalize=False):
    """Mean squared violation of the empirical time-delay law.

    Both ``inp`` and ``pred`` are denormalized. The first prediction column
    (time delay) is then compared against
    ``r_throat / (d_exotic * (1 + z_factor) + 1e-6)``, which is the same
    formula ``WormholeCommDataset`` uses to generate the ``td`` target.

    Args:
        pred: normalized model outputs; column 0 is read as time delay.
        inp: normalized inputs; columns 0, 1, 3 are d_exotic, r_throat, z_factor.
        stats: dict of tensors ``X_mean``, ``X_std``, ``Y_mean``, ``Y_std``
            on the same device as ``pred``/``inp``.
        normalize: if False (default, original behavior), the residual is in
            raw time-delay units, squared. Its scale can then dwarf the
            standardized data MSE it is usually added to. If True, the
            residual is divided by ``Y_std[0]``, which puts it in the same
            standardized units as the data term.

    Returns:
        Scalar tensor: mean of the squared residuals over the batch.
    """
    # Denormalize inputs & prediction
    X = inp * stats['X_std'] + stats['X_mean']
    d_ex, r_th, z = X[:, 0], X[:, 1], X[:, 3]
    y_std0 = stats['Y_std'][0]
    td_pred = pred[:, 0] * y_std0 + stats['Y_mean'][0]

    # Empirical law (1e-6 guards against a zero denominator)
    t_phys = r_th / (d_ex * (1 + z) + 1e-6)
    resid = td_pred - t_phys
    if normalize:
        resid = resid / y_std0
    return torch.mean(resid ** 2)

def total_loss(pred, true, inp, stats, lambda_phys=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred, true: normalized predictions and targets of equal shape.
        inp: normalized inputs that produced ``pred``.
        stats: normalization tensors forwarded to ``physics_residual``.
        lambda_phys: weight applied to the physics term.

    Returns:
        Tuple ``(total, mse, phys)`` where ``total = mse + lambda_phys * phys``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, inp, stats)
    combined = data_term + lambda_phys * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, n_samples=50):
    """Monte-Carlo dropout: mean and std over repeated stochastic forward passes.

    Only dropout layers are switched to training mode; every other layer
    runs in eval mode (so e.g. BatchNorm running stats are not updated).
    Each submodule's original train/eval flag is restored afterwards, even if
    a forward pass raises.

    Args:
        model: ``nn.Module`` containing dropout layers.
        x: input tensor, passed to ``model`` unchanged.
        n_samples: number of stochastic passes (must be >= 1).

    Returns:
        ``(mean, std)`` NumPy arrays, each with the shape of one ``model(x)`` output.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    saved_modes = [(m, m.training) for m in model.modules()]
    model.eval()
    for m in model.modules():
        # _DropoutNd is the common base of torch's dropout layers.
        if isinstance(m, nn.modules.dropout._DropoutNd):
            m.train()
    try:
        with torch.no_grad():
            draws = [model(x).cpu().numpy() for _ in range(n_samples)]
    finally:
        for m, mode in saved_modes:
            m.training = mode

    arr = np.stack(draws, 0)
    return arr.mean(0), arr.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lambda_phys=1.0,
          epochs=100, patience=10, ckpt_path="best_wormhole_comm_ai.pth"):
    """Fit ``model`` with AdamW, ReduceLROnPlateau and early stopping.

    Each epoch runs one optimization pass over ``train_loader`` and one
    evaluation pass over ``val_loader``. Both are scored with ``total_loss``
    (data MSE + ``lambda_phys`` * physics residual). The LR is halved after
    5 epochs without validation improvement. Training stops once ``patience``
    consecutive epochs fail to improve the best validation loss by more
    than 1e-6.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader, val_loader: loaders yielding ``(x, y)`` batches.
        stats: normalization tensors forwarded to ``total_loss``.
        device: torch device the batches are moved to.
        lr, wd: AdamW learning rate and weight decay.
        lambda_phys: weight of the physics residual term.
        epochs: maximum number of epochs.
        patience: early-stopping patience, in epochs.
        ckpt_path: file the best weights are saved to (default unchanged).

    Returns:
        ``{'train': [...], 'val': [...]}``: per-epoch losses, each a
        sample-weighted average of the batch losses.

    On return, ``model`` holds the best-validation weights from *this* call.
    They are kept in memory, so a stale checkpoint file from a previous run
    is never loaded. If no epoch improved (e.g. ``epochs=0`` or NaN losses),
    the current weights are kept instead of failing on a missing file.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    best_state = None
    history = {'train': [], 'val': []}

    for ep in range(1, epochs + 1):
        # -- train
        model.train()
        t_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, lambda_phys)
            opt.zero_grad()
            loss.backward()
            opt.step()
            t_loss += loss.item() * xb.size(0)
        t_loss /= len(train_loader.dataset)

        # -- validate
        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats, lambda_phys)
                v_loss += loss.item() * xb.size(0)
        v_loss /= len(val_loader.dataset)

        sched.step(v_loss)
        history['train'].append(t_loss)
        history['val'].append(v_loss)
        print(f"Epoch {ep:03d} | Train {t_loss:.4e} | Val {v_loss:.4e}")

        # early stopping (NaN never compares as an improvement)
        if v_loss < best_val - 1e-6:
            best_val, wait = v_loss, 0
            # Detached copy so later optimizer steps don't mutate it.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(hist):
    """Draw train and validation loss curves.

    Args:
        hist: dict with per-epoch loss lists under ``'train'`` and ``'val'``.
    """
    fig, ax = plt.subplots()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        ax.plot(hist[key], label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.set_title("Training Curve")
    fig.tight_layout()
    plt.show()

def plot_scatter(true, pred, name):
    """Parity plot of ``pred`` against ``true`` with a red dashed y = x line.

    Args:
        true, pred: 1D arrays of equal length.
        name: quantity name used in the axis labels and title.
    """
    # The reference diagonal spans the range of the true values.
    lo, hi = true.min(), true.max()
    fig, ax = plt.subplots()
    ax.scatter(true, pred, s=5, alpha=0.5)
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_xlabel(f"True {name}")
    ax.set_ylabel(f"Pred {name}")
    ax.set_title(name)
    fig.tight_layout()
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device, n_grid=100,
                             d_range=(0.1, 10.0), r_range=(0.1, 10.0),
                             n_samples=50):
    """Heatmap of MC-dropout std of the predicted time delay.

    The grid spans exotic density (x axis) and throat radius (y axis). The
    other inputs are held fixed: k_curv = 0, and z_factor and sigma_aniso at
    their dataset means. The std is rescaled by ``Y_std[0]`` so the colorbar
    is in physical time-delay units rather than standardized units. The
    model's train/eval mode is restored afterwards.

    Note: ``WormholeCommDataset`` draws r_throat from [1, 10). The default
    ``r_range`` starts at 0.1, so the bottom of the map is extrapolation.

    Args:
        model: trained network.
        stats_np: dict of NumPy arrays ``X_mean``, ``X_std``, ``Y_mean``, ``Y_std``.
        device: torch device for inference.
        n_grid: number of grid points per axis.
        d_range, r_range: (min, max) of exotic density / throat radius.
        n_samples: number of MC-dropout passes.
    """
    d_vals = np.linspace(d_range[0], d_range[1], n_grid)
    r_vals = np.linspace(r_range[0], r_range[1], n_grid)
    D, R = np.meshgrid(d_vals, r_vals)
    grid = np.zeros((D.size, 5), dtype=np.float32)
    grid[:, 0], grid[:, 1] = D.ravel(), R.ravel()
    grid[:, 2] = 0.0
    grid[:, 3] = stats_np['X_mean'][3]
    grid[:, 4] = stats_np['X_mean'][4]

    # Cast keeps the tensor float32 even if the stats arrays are float64.
    Xn = ((grid - stats_np['X_mean']) / stats_np['X_std']).astype(np.float32)
    Xt = torch.from_numpy(Xn).to(device)
    was_training = model.training
    try:
        _, std = mc_dropout_predict(model, Xt, n_samples=n_samples)
    finally:
        model.train(was_training)
    # Convert the std from standardized units back to time-delay units.
    std_map = (std[:, 0] * stats_np['Y_std'][0]).reshape(D.shape)

    plt.figure(figsize=(6, 5))
    plt.pcolormesh(d_vals, r_vals, std_map, cmap='plasma', shading='auto')
    plt.colorbar(label="Std(time_delay)")
    plt.xlabel("Exotic Density"); plt.ylabel("Throat Radius")
    plt.title("Uncertainty in Time Delay")
    plt.tight_layout(); plt.show()

# ------------------------------------------------------------------------------
# 7. Main
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry: build data, train, then visualize on held-out samples.
    torch.manual_seed(42)  # reproducible weight init and loader shuffling
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ds = WormholeCommDataset(n_samples=8000)
    stat_keys = ('X_mean', 'X_std', 'Y_mean', 'Y_std')
    stats = {k: torch.tensor(getattr(ds, k), device=device) for k in stat_keys}
    stats_np = {k: getattr(ds, k) for k in stat_keys}

    # Split & loaders (seeded generator -> same split on every run)
    n_val = int(0.2 * len(ds))
    split_gen = torch.Generator().manual_seed(42)
    tr_ds, va_ds = random_split(ds, [len(ds) - n_val, n_val], generator=split_gen)
    tr_ld = DataLoader(tr_ds, batch_size=64, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=128)

    # Build, train & visualize
    model = WormholeCommAI().to(device)
    history = train(model, tr_ld, va_ld, stats, device)

    plot_losses(history)

    # Parity plots on validation samples only; training rows would overstate accuracy.
    val_idx = np.asarray(va_ds.indices)
    X_val = torch.from_numpy(ds.X[val_idx]).float().to(device)
    model.eval()  # deterministic inference (dropout off)
    with torch.no_grad():
        Yp_n = model(X_val).cpu().numpy()
    Yp = Yp_n * ds.Y_std + ds.Y_mean
    Yt = ds.Y[val_idx] * ds.Y_std + ds.Y_mean
    names = ["Time Delay", "Attenuation", "Stability Margin"]
    for i, nm in enumerate(names):
        plot_scatter(Yt[:, i], Yp[:, i], nm)

    # uncertainty heatmap
    plot_uncertainty_heatmap(model, stats_np, device)