<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_multiverse_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib

In [None]:
#!/usr/bin/env python3
"""
train_multiverse_ai.py

Physics‐informed MultiverseAI pipeline with separate Torch/NumPy stats:

 1. Synthetic dataset for (energy, prob, variance, coupling) →
    (transition_prob, stability_metric)
 2. PINN loss combining data MSE and residuals enforcing analytic toy‐physics
 3. MLP with LayerNorm & Dropout for uncertainty quantification
 4. MC‐Dropout inference at test time
 5. Training loop with AdamW, ReduceLROnPlateau, early stopping
 6. Plots: training curves, scatter true vs predicted, and uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Dataset
# ------------------------------------------------------------------------------

def analytic_transition_probability(E, v, var):
    """Toy closed-form transition probability, ``sigmoid(-(E * var))``.

    Args:
        E: Energy tensor.
        v: Unused by the formula; kept so existing three-argument call
            sites keep working.
        var: Variance tensor, combined element-wise with ``E``.

    Returns:
        Tensor holding the sigmoid of the negated product ``E * var``.
    """
    energy_times_variance = E * var
    return torch.sigmoid(-energy_times_variance)

def analytic_stability(p, var, cpl):
    """Toy closed-form stability metric, ``p * exp(-var * cpl)``.

    Args:
        p: Probability tensor that is scaled by the damping factor.
        var: Variance tensor.
        cpl: Coupling tensor, combined element-wise with ``var``.

    Returns:
        Tensor ``p`` multiplied element-wise by ``exp(-var * cpl)``.
    """
    damping = torch.exp(-var * cpl)
    return p * damping

class MultiverseDataset(Dataset):
    """Synthetic, z-score-normalized dataset for the toy multiverse problem.

    Inputs are the four columns (energy, prob, variance, coupling); targets
    are the two analytic quantities (transition probability, stability)
    with additive Gaussian noise scaled by 0.02.

    Attributes set by ``__init__``:
        stats: dict with keys ``'X_mean'``, ``'X_std'``, ``'Y_mean'``,
            ``'Y_std'`` holding per-column statistics of the raw data
            (computed with torch's default ``mean``/``std``).
        X: Normalized inputs, one row per sample.
        Y: Normalized targets, one row per sample.
    """

    def __init__(self, n_samples=8000, seed=0):
        """Generate ``n_samples`` rows.

        Note that this seeds torch's *global* RNG with ``seed``, so any
        random torch call made afterwards is affected as well.
        """
        torch.manual_seed(seed)

        # Draw order (E, p, var, cpl, then noise) fixes the generated values.
        energy = 10.0 * torch.rand(n_samples, 1)     # uniform on [0, 10)
        prob = torch.rand(n_samples, 1)              # uniform on [0, 1)
        variance = torch.rand(n_samples, 1)          # uniform on [0, 1)
        coupling = 5.0 * torch.rand(n_samples, 1)    # uniform on [0, 5)

        features = torch.cat((energy, prob, variance, coupling), dim=1)

        transition = analytic_transition_probability(energy, variance, variance)
        stability = analytic_stability(prob, variance, coupling)
        clean_targets = torch.cat((transition, stability), dim=1)
        targets = clean_targets + 0.02 * torch.randn(n_samples, 2)

        # Raw (un-normalized) statistics, kept as torch tensors so callers
        # can move them to a device or convert them to NumPy as needed.
        self.stats = {
            'X_mean': features.mean(dim=0),
            'X_std':  features.std(dim=0),
            'Y_mean': targets.mean(dim=0),
            'Y_std':  targets.std(dim=0),
        }

        # Z-score normalization of both inputs and targets.
        self.X = (features - self.stats['X_mean']) / self.stats['X_std']
        self.Y = (targets - self.stats['Y_mean']) / self.stats['Y_std']

    def __len__(self):
        """Return the number of samples."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(normalized_input, normalized_target)`` pair at ``idx``."""
        sample_x = self.X[idx]
        sample_y = self.Y[idx]
        return sample_x, sample_y


# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------

class MultiverseAI(nn.Module):
    """Fully connected regressor built from Linear/LayerNorm/ReLU/Dropout stages.

    Args:
        input_dim: Number of input features.
        hidden_dims: Iterable of hidden-layer widths; one
            Linear -> LayerNorm -> ReLU -> Dropout stage is created per entry.
        output_dim: Number of output values produced by the final Linear layer.
        p_drop: Dropout probability used in every hidden stage.
    """

    def __init__(self, input_dim=4, hidden_dims=(64,64), output_dim=2, p_drop=0.1):
        super().__init__()
        # widths[i] -> widths[i + 1] describes the i-th hidden stage.
        widths = [input_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Plain linear head: no normalization or activation on the outputs.
        stages.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the sequential stack to ``x`` and return its output."""
        output = self.net(x)
        return output


# ------------------------------------------------------------------------------
# 3. Physics‐Informed Loss
# ------------------------------------------------------------------------------

def physics_residual(pred, inp, stats_torch):
    """Mean-squared mismatch between predictions and the analytic formulas.

    Both ``pred`` and ``inp`` are mapped back from normalized to physical
    units with ``stats_torch``; the analytic targets are then recomputed
    from the physical inputs and compared with the physical predictions.

    Args:
        pred: Normalized model outputs; column 0 is the transition
            probability and column 1 the stability metric.
        inp: Normalized model inputs; columns 0-3 are
            (energy, prob, variance, coupling).
        stats_torch: Mapping with ``'X_mean'``, ``'X_std'``, ``'Y_mean'``
            and ``'Y_std'`` tensors used for de-normalization.

    Returns:
        Scalar tensor: the sum of the two per-output MSE terms.
    """
    mse = nn.MSELoss()

    # Undo the z-score normalization on outputs and inputs.
    outputs_phys = pred * stats_torch['Y_std'] + stats_torch['Y_mean']
    inputs_phys = inp * stats_torch['X_std'] + stats_torch['X_mean']

    energy, prob, variance, coupling = (inputs_phys[:, i] for i in range(4))

    # Analytic targets are treated as constants: no gradient flows through them.
    transition_ref = analytic_transition_probability(energy, variance, variance).detach()
    stability_ref = analytic_stability(prob, variance, coupling).detach()

    transition_term = mse(outputs_phys[:, 0], transition_ref)
    stability_term = mse(outputs_phys[:, 1], stability_ref)
    return transition_term + stability_term

def total_loss(pred, true, inp, stats_torch, lambda_phys=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: Normalized model predictions.
        true: Normalized targets.
        inp: Normalized inputs that produced ``pred``.
        stats_torch: Normalization statistics forwarded to
            ``physics_residual``.
        lambda_phys: Weight applied to the physics term.

    Returns:
        Tuple ``(total, data_mse, physics)`` where
        ``total = data_mse + lambda_phys * physics``.
    """
    data_term = nn.MSELoss()(pred, true)
    physics_term = physics_residual(pred, inp, stats_torch)
    combined = data_term + lambda_phys * physics_term
    return combined, data_term, physics_term


# ------------------------------------------------------------------------------
# 4. MC‐Dropout Inference
# ------------------------------------------------------------------------------

def mc_dropout_predict(model, x, n_samples=50):
    """Estimate predictive mean and spread with Monte Carlo dropout.

    The model is switched to training mode so dropout stays active, run
    ``n_samples`` times on ``x`` without gradient tracking, and then every
    submodule's previous train/eval flag is restored, so the caller's model
    is not left in training mode as a side effect.

    Args:
        model: ``torch.nn.Module`` to sample from.
        x: Input batch, passed to ``model`` unchanged.
        n_samples: Number of stochastic forward passes; must be at least 1.

    Returns:
        Tuple ``(mean, std)`` of NumPy arrays with the shape of a single
        model output: the element-wise mean and standard deviation
        (NumPy default, ``ddof=0``) across the passes.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Snapshot per-module flags so mixed train/eval configurations survive.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.train()  # activate dropout
    try:
        with torch.no_grad():
            draws = [model(x).cpu().numpy() for _ in range(n_samples)]
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    stacked = np.stack(draws, axis=0)
    return stacked.mean(axis=0), stacked.std(axis=0)


# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------

def train_model(model, train_loader, val_loader, stats_torch, device,
                lr=1e-3, weight_decay=1e-5, lambda_phys=1.0,
                max_epochs=200, patience=20):
    """Train ``model`` with the physics-informed loss and early stopping.

    Uses AdamW, halves the learning rate via ``ReduceLROnPlateau`` when the
    validation loss stalls for 5 epochs, and stops once the validation loss
    has not improved by more than 1e-6 for ``patience`` consecutive epochs.

    The best weights are kept in memory and restored before returning. They
    are also written to ``"best_multiverse_ai.pth"`` in the working
    directory on every improvement. If no epoch ever improves (for example
    the loss is NaN, or ``max_epochs`` is 0), the model keeps its current
    weights; no stale checkpoint from an earlier run is loaded.

    Args:
        model: Network to optimize; moved to ``device`` in place.
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader yielding ``(inputs, targets)`` validation batches.
        stats_torch: Normalization statistics forwarded to ``total_loss``.
        device: Device on which batches and the model are placed.
        lr: Initial AdamW learning rate.
        weight_decay: AdamW weight-decay coefficient.
        lambda_phys: Weight of the physics term in the total loss.
        max_epochs: Upper bound on the number of epochs.
        patience: Number of non-improving epochs tolerated before stopping.

    Returns:
        Dict with lists ``'train_loss'`` and ``'val_loss'`` holding the
        per-epoch, sample-weighted average of the total loss.
    """
    import copy  # local import: only needed for the in-memory snapshot

    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    best_state = None  # snapshot of the best weights seen in THIS run
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(1, max_epochs+1):
        # --- Training ---
        model.train()
        running_train = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats_torch, lambda_phys)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            running_train += loss.item() * xb.size(0)
        train_loss = running_train / len(train_loader.dataset)

        # --- Validation ---
        model.eval()
        running_val = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats_torch, lambda_phys)
                running_val += loss.item() * xb.size(0)
        val_loss = running_val / len(val_loader.dataset)

        scheduler.step(val_loss)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # Early stopping
        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # Deep copy: state_dict() tensors alias the live parameters and
            # would otherwise be overwritten by later optimizer steps.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), "best_multiverse_ai.pth")
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # Restore the best weights from memory rather than from the fixed file
    # path, which may be missing or left over from a previous run.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history


# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------

def plot_history(hist):
    """Plot training and validation loss per epoch and show the figure.

    Args:
        hist: Mapping with ``'train_loss'`` and ``'val_loss'`` sequences,
            one value per epoch.
    """
    curves = (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss'))
    for key, legend_label in curves:
        plt.plot(hist[key], label=legend_label)
    plt.title("Training Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(true_vals, pred_vals, name):
    """Scatter predictions against ground truth with a y = x reference line.

    Args:
        true_vals: Array of true values (x-axis); its min and max set the
            extent of the reference line.
        pred_vals: Array of predicted values (y-axis).
        name: Quantity name used in the axis labels and the title.
    """
    plt.scatter(true_vals, pred_vals, s=5, alpha=0.5)

    # Dashed red diagonal marks perfect agreement.
    lower = true_vals.min()
    upper = true_vals.max()
    diagonal = [lower, upper]
    plt.plot(diagonal, diagonal, 'r--')

    plt.title(f"{name}: True vs Pred")
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Pred {name}")
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device):
    """Show the MC-dropout standard deviation of output 0 over an (E, var) grid.

    Energy is swept over [0, 10] and variance over [0, 1] on a 100 x 100
    grid while prob and coupling are held fixed at 0.5 and 2.5.

    Args:
        model: Trained network passed to ``mc_dropout_predict``.
        stats_np: Mapping with NumPy arrays ``'X_mean'`` and ``'X_std'``;
            entries 0-3 are used for (energy, prob, variance, coupling).
        device: Device on which the grid tensor is placed.
    """
    E_vals = np.linspace(0, 10, 100)
    var_vals = np.linspace(0, 1, 100)
    P, C = 0.5, 2.5  # fixed prob and coupling for this 2-D slice

    EE, VV = np.meshgrid(E_vals, var_vals)

    # Normalize each raw feature plane with its own column statistics.
    mu, sigma = stats_np['X_mean'], stats_np['X_std']
    raw_planes = (EE, np.full_like(EE, P), VV, np.full_like(EE, C))
    normalized_planes = [
        (plane - mu[i]) / sigma[i] for i, plane in enumerate(raw_planes)
    ]
    Xgrid = np.stack(normalized_planes, axis=-1).reshape(-1, 4)

    X_tensor = torch.from_numpy(Xgrid).float().to(device)
    _, std = mc_dropout_predict(model, X_tensor, n_samples=50)
    # Column 0 is the first model output; fold it back onto the grid.
    std1 = std[:, 0].reshape(EE.shape)

    plt.pcolormesh(E_vals, var_vals, std1, shading='auto', cmap='viridis')
    plt.colorbar(label="Std of y1_pred")
    plt.title("Prediction Uncertainty (y1)")
    plt.xlabel("Energy")
    plt.ylabel("Variance")
    plt.show()


# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------

if __name__ == "__main__":
    # Prefer the GPU when one is visible, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Build the synthetic dataset (this also seeds torch's global RNG).
    dataset = MultiverseDataset(n_samples=8000)

    # Two views of the same normalization statistics:
    #   * torch tensors on the training device for the physics loss,
    #   * NumPy arrays on the host for plotting.
    stats_torch = {
        name: tensor.to(device)
        for name, tensor in dataset.stats.items()
    }
    stats_np = {
        name: tensor.cpu().numpy()
        for name, tensor in dataset.stats.items()
    }

    # 80/20 train/validation split.
    n_val = int(0.2 * len(dataset))
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=256)

    # Build the network and fit it.
    model = MultiverseAI().to(device)
    history = train_model(
        model, train_loader, val_loader,
        stats_torch, device,
        lr=1e-3, weight_decay=1e-5, lambda_phys=1.0,
        max_epochs=200, patience=20
    )

    # Loss curves.
    plot_history(history)

    # Predict on the full dataset and compare with the targets.
    X_all = dataset.X.to(device)
    with torch.no_grad():
        Y_pred_norm = model(X_all).cpu().numpy()
    Y_true_norm = dataset.Y.numpy()

    # Map normalized values back to physical units before plotting.
    y_std, y_mean = stats_np['Y_std'], stats_np['Y_mean']
    Y_pred = Y_pred_norm * y_std + y_mean
    Y_true = Y_true_norm * y_std + y_mean

    for column, title in enumerate(("Transition Probability", "Stability Metric")):
        plot_scatter(Y_true[:, column], Y_pred[:, column], title)

    # MC-dropout uncertainty map for the first output.
    plot_uncertainty_heatmap(model, stats_np, device)