<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_multiversal_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_multiversal_ai.py

Physics‐informed AI pipeline for MultiversalAI:

1. Synthetic dataset of 6 “multiversal decision” features → 3 decision metrics
2. Float32 normalization and dtype consistency
3. MLP with LayerNorm & Dropout (accepts int hidden_dims)
4. Physics‐informed residual enforcing toy “multiverse” laws in normalized space
5. MC‐Dropout for uncertainty quantification
6. Training loop with AdamW, ReduceLROnPlateau, gradient clipping, NaN checks, early stopping
7. Safe checkpoint load
8. Visualizations: loss curves, scatter plots, uncertainty heatmap
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------------------
# 1. Synthetic Multiversal Decision Dataset
# ------------------------------------------------------------------------------
class MultiversalDataset(Dataset):
    """Synthetic regression dataset: 6 "multiversal decision" features -> 3 targets.

    Raw features (columns, in order):
        ev_p    event probability    ~ U[0, 1]
        ev_v    event variance       ~ U[0, 1]
        ent_b   entropy balance      ~ U[0.1, 10.0]
        caus_m  causality metric     ~ U[0, 1]
        temp_c  temporal coherence   ~ U[0, 1]
        res_i   resource index       ~ U[1, 100]

    Targets (columns, in order):
        reality_stability = (ev_p * (1 - ev_v) + caus_m) * temp_c / (ent_b + eps)
        decision_impact   = ev_p * ev_v * res_i / (1 + ent_b)
        entropy_cost      = ent_b * decision_impact / (res_i + eps)

    Additive Gaussian noise is applied to the targets. Its per-column std is
    1% of that target column's std. Inputs and targets are z-scored per
    column and stored as float32. The float64 statistics are kept on the
    instance (``X_mean``, ``X_std``, ``Y_mean``, ``Y_std``) so predictions
    can be mapped back to physical units.

    Args:
        n_samples: number of samples to generate.
        seed: seed for a private ``np.random.RandomState``. The global NumPy
            RNG is left untouched. The generated data is identical to what
            ``np.random.seed(seed)`` followed by global draws would produce.
    """

    def __init__(self, n_samples=6000, seed=0):
        # Private generator: same stream as legacy global seeding, but it does
        # not clobber np.random's global state for the caller.
        rng = np.random.RandomState(seed)

        # Draw order matters for reproducibility; keep it fixed.
        ev_p   = rng.rand(n_samples, 1)
        ev_v   = rng.rand(n_samples, 1)
        ent_b  = rng.uniform(0.1, 10.0, (n_samples, 1))
        caus_m = rng.rand(n_samples, 1)
        temp_c = rng.rand(n_samples, 1)
        res_i  = rng.uniform(1.0, 100.0, (n_samples, 1))

        X_raw = np.hstack([ev_p, ev_v, ent_b, caus_m, temp_c, res_i])

        # Toy "multiverse" targets.
        eps = 1e-6
        rs = (ev_p * (1 - ev_v) + caus_m) * temp_c / (ent_b + eps)  # reality_stability
        di = ev_p * ev_v * res_i / (1 + ent_b)                       # decision_impact
        ec = ent_b * di / (res_i + eps)                              # entropy_cost

        Y_raw = np.hstack([rs, di, ec]).astype(np.float64)
        # Additive noise: per-column std = 1% of that target's std
        # (not 1% relative noise per sample).
        Y_raw += 0.01 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Normalisation statistics (1e-8 guards against zero-variance columns).
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = X_raw.std(axis=0) + 1e-8
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = Y_raw.std(axis=0) + 1e-8

        # Standardise, then store as float32 to match the model's dtype.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

        print(f"X range: {self.X.min():.3e}–{self.X.max():.3e}")
        print(f"Y range: {self.Y.min():.3e}–{self.Y.max():.3e}")

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the normalised ``(x, y)`` pair at ``idx`` as float32 tensors."""
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

# ------------------------------------------------------------------------------
# 2. Model Definition (accepts int hidden_dims)
# ------------------------------------------------------------------------------
class MultiversalAI(nn.Module):
    """MLP regressor built from [Linear -> LayerNorm -> ReLU -> Dropout] stages plus a linear head.

    Args:
        input_dim: number of input features.
        hidden_dims: width of each hidden stage. A bare ``int`` means a single
            hidden stage of that width.
        output_dim: number of regression outputs.
        p_drop: dropout probability applied at the end of every hidden stage.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = (hidden_dims,) if isinstance(hidden_dims, int) else hidden_dims

        stages = []
        fan_in = input_dim
        for width in widths:
            stages.extend(self._hidden_stage(fan_in, width, p_drop))
            fan_in = width

        # Everything lives in one Sequential named `net` so state_dict keys
        # stay "net.<idx>.*".
        self.net = nn.Sequential(*stages, nn.Linear(fan_in, output_dim))

    @staticmethod
    def _hidden_stage(fan_in, width, p_drop):
        """Build one hidden stage: affine map, LayerNorm, ReLU, dropout."""
        return [nn.Linear(fan_in, width), nn.LayerNorm(width), nn.ReLU(), nn.Dropout(p_drop)]

    def forward(self, x):
        """Map a batch of normalised features to normalised predictions."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Residual & Total Loss
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats, eps=1e-6):
    """MSE between ``pred`` and the analytic toy "multiverse" laws, in normalised space.

    The normalised inputs are mapped back to physical units. The three
    targets are then recomputed noise-free with the same closed-form
    formulas the synthetic dataset uses, re-normalised with the target
    statistics, and compared with ``pred``.

    Args:
        pred: normalised predictions, shape (batch, 3).
        X: normalised inputs, shape (batch, 6), columns ordered
            ev_p, ev_v, ent_b, caus_m, temp_c, res_i.
        stats: dict with tensors 'X_mean', 'X_std' (broadcastable over the 6
            input columns) and 'Y_mean', 'Y_std' (over the 3 outputs).
            Presumably on the same device/dtype as ``pred``; verify at the
            call site.
        eps: denominator guard. Defaults to 1e-6 to match the eps used when
            the dataset targets are generated, so the "law" and the data agree.

    Returns:
        Scalar tensor with the mean squared residual.
    """
    # Back to physical units so the closed-form laws apply.
    X_den = X * stats['X_std'] + stats['X_mean']
    ev_p, ev_v, ent_b, caus_m, temp_c, res_i = X_den.unbind(dim=1)

    # clamp keeps denominators positive even if denormalised inputs drift <= 0.
    rs_t = (ev_p * (1 - ev_v) + caus_m) * temp_c / torch.clamp(ent_b + eps, min=eps)
    di_t = ev_p * ev_v * res_i / (1.0 + ent_b)
    ec_t = ent_b * di_t / torch.clamp(res_i + eps, min=eps)

    Yt = torch.stack([rs_t, di_t, ec_t], dim=1)
    Yt_norm = (Yt - stats['Y_mean']) / stats['Y_std']
    return nn.MSELoss()(pred, Yt_norm)

def total_loss(pred, true, X, stats, lam=1.0):
    """Supervised MSE plus ``lam``-weighted physics residual.

    Args:
        pred: normalised model predictions.
        true: normalised ground-truth targets.
        X: normalised inputs the predictions were made from.
        stats: normalisation tensors forwarded to ``physics_residual``.
        lam: weight of the physics term.

    Returns:
        Tuple ``(total, mse, phys)`` where ``total = mse + lam * phys``. The
        two components are returned separately for logging.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Uncertainty Quantification
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Monte-Carlo dropout: mean and std over ``T`` stochastic forward passes.

    Only dropout layers are switched to training mode; every other module is
    run in eval mode. This stops layers whose train-mode forward mutates
    state (e.g. BatchNorm running statistics) from being changed. Every
    module's train/eval flag is restored on exit, even if a forward pass
    raises, so the caller's model mode is unchanged.

    Args:
        model: network containing dropout layers.
        X: input batch passed unchanged to ``model``.
        T: number of stochastic passes; must be >= 1.

    Returns:
        ``(mean, std)`` over the ``T`` passes, each shaped like ``model(X)``.
        ``std`` is the unbiased sample std. For ``T == 1`` it is all zeros
        rather than NaN.

    Raises:
        ValueError: if ``T < 1``.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                     nn.AlphaDropout, nn.FeatureAlphaDropout)
    saved_modes = [(m, m.training) for m in model.modules()]
    try:
        model.eval()
        for m in model.modules():
            if isinstance(m, dropout_types):
                m.train()
        with torch.no_grad():
            stacked = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        # Restore per-module flags exactly (modes may have been mixed).
        for m, mode in saved_modes:
            m.training = mode

    mean = stacked.mean(0)
    std = stacked.std(0) if T > 1 else torch.zeros_like(mean)
    return mean, std

# ------------------------------------------------------------------------------
# 5. Training Loop with NaN Safety & Early Stopping
# ------------------------------------------------------------------------------
def _train_one_epoch(model, loader, optimizer, stats, device, lam):
    """Run one optimisation pass over ``loader``.

    Returns:
        The sample-weighted mean total loss, or ``None`` as soon as a batch
        produces a non-finite (NaN/inf) loss. That batch is not applied.
    """
    model.train()
    running = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        loss, _, _ = total_loss(model(xb), yb, xb, stats, lam)
        if not torch.isfinite(loss):
            return None
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        running += loss.item() * xb.size(0)
    return running / len(loader.dataset)


def _evaluate(model, loader, stats, device, lam):
    """Return the sample-weighted mean total loss over ``loader`` (eval mode, no grad)."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss, _, _ = total_loss(model(xb), yb, xb, stats, lam)
            running += loss.item() * xb.size(0)
    return running / len(loader.dataset)


def train(model, train_loader, val_loader, stats, device,
          lr=1e-4, wd=1e-5, lam=1.0, epochs=100, patience=10,
          ckpt_path="best_multiversal_ai.pth"):
    """Train ``model`` on data + physics loss with early stopping on validation loss.

    Optimiser and schedule:
      * AdamW.
      * ReduceLROnPlateau on the validation loss (factor 0.5, patience 5).
      * Gradient-norm clipping at 1.0.

    Checkpointing: whenever the validation loss improves, the weights are
    saved to ``ckpt_path`` and a copy is kept in memory.

    On exit the best weights from *this* run are restored into ``model``,
    which is left in eval mode. "Exit" means normal completion, early stop,
    or a non-finite training loss. Stale checkpoint files from earlier runs
    are never loaded.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: yields normalised ``(x, y)`` training batches.
        val_loader: yields normalised ``(x, y)`` validation batches.
        stats: normalisation tensors forwarded to ``total_loss``.
        device: torch device for the model and batches.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: weight of the physics residual in the total loss.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        ckpt_path: file the best weights are written to.

    Returns:
        dict with per-epoch mean losses: ``{'train': [...], 'val': [...]}``.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     factor=0.5,
                                                     patience=5)

    best_val, wait = float('inf'), 0
    best_state = None  # best weights seen in THIS run
    history = {'train': [], 'val': []}

    for ep in range(1, epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, stats, device, lam)
        if train_loss is None:
            # Stop, but still fall through to restore the best weights.
            print(f"Non-finite loss at epoch {ep}. Aborting.")
            break

        val_loss = _evaluate(model, val_loader, stats, device, lam)
        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {ep:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # Checkpoint & early stop
        if val_loss < best_val - 1e-8:
            best_val, wait = val_loss, 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("No checkpoint found; using last model state.")
    model.eval()

    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(history):
    """Plot per-epoch training and validation losses on a fresh figure and show it.

    Args:
        history: mapping with list-valued keys 'train' and 'val'.
    """
    _, ax = plt.subplots()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        ax.plot(history[key], label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

def plot_scatter(y_true, y_pred, title):
    """Parity plot of predictions vs. ground truth with a dashed y = x reference.

    The reference line spans the range of ``y_true``.

    Args:
        y_true: ground-truth values (array-like supporting ``.min()``/``.max()``).
        y_pred: predicted values, same length as ``y_true``.
        title: figure title.
    """
    _, ax = plt.subplots()
    ax.scatter(y_true, y_pred, s=5, alpha=0.6)
    lo = y_true.min()
    hi = y_true.max()
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_title(title)
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    plt.show()

def plot_uncertainty(model, stats, device, resolution=100, T=100):
    """Heatmap of MC-dropout uncertainty in reality stability over (ev_p, ev_v).

    ev_p (feature 0) and ev_v (feature 1) sweep [0, 1] on a
    ``resolution`` x ``resolution`` grid. The other four features are held at
    their means from ``stats['X_mean']``.

    The std of the normalised stability output over ``T`` stochastic passes
    is multiplied by ``stats['Y_std'][0]``. This puts the colour scale in
    physical units, matching the physical-unit axes.

    Args:
        model: trained network with dropout layers.
        stats: normalisation tensors 'X_mean', 'X_std', 'Y_std'.
        device: torch device for the grid.
        resolution: grid points per axis.
        T: number of MC-dropout passes.
    """
    p_axis = np.linspace(0, 1, resolution, dtype=np.float32)
    v_axis = np.linspace(0, 1, resolution, dtype=np.float32)
    P, V = np.meshgrid(p_axis, v_axis)

    # Every row starts at the feature means; then overwrite the swept columns.
    x_mean = stats['X_mean'].to(device=device, dtype=torch.float32)
    grid = x_mean.repeat(P.size, 1)
    grid[:, 0] = torch.from_numpy(P.ravel()).to(device)
    grid[:, 1] = torch.from_numpy(V.ravel()).to(device)

    Xn = (grid - stats['X_mean']) / stats['X_std']
    _, std = mc_dropout_predict(model, Xn, T=T)
    # Normalised std -> physical units of reality_stability.
    sigma = (std[:, 0] * stats['Y_std'][0]).cpu().numpy().reshape(P.shape)

    plt.figure(figsize=(6,5))
    plt.pcolormesh(P, V, sigma, cmap='magma', shading='auto')
    plt.colorbar(label="Std Stability")
    plt.xlabel("Event Probability")
    plt.ylabel("Event Variance")
    plt.title("Uncertainty Heatmap: Reality Stability")
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Seed torch so weight init and loader shuffling are reproducible.
    torch.manual_seed(0)

    # Prepare dataset & normalisation stats (float32 tensors on the device)
    ds = MultiversalDataset(n_samples=6000)
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.float32, device=device),
        'X_std' : torch.tensor(ds.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.float32, device=device),
        'Y_std' : torch.tensor(ds.Y_std,  dtype=torch.float32, device=device),
    }

    # Data loaders: 80/20 split with a dedicated seeded generator for reproducibility
    n_val = int(0.2 * len(ds))
    tr_ds, va_ds = random_split(ds, [len(ds) - n_val, n_val],
                                generator=torch.Generator().manual_seed(0))
    tr_ld = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=256, shuffle=False)

    # Model instantiation with single int hidden_dim
    model = MultiversalAI(input_dim=6, hidden_dims=32, output_dim=3).to(device)

    # Train
    history = train(model, tr_ld, va_ld, stats, device,
                    lr=1e-4, wd=1e-5, lam=1.0,
                    epochs=100, patience=10)

    # Visualize
    plot_history(history)

    # Parity plots on the held-out split only; training points would overstate accuracy.
    val_idx = np.asarray(va_ds.indices)
    X_val = torch.from_numpy(ds.X[val_idx]).to(device)
    model.eval()  # deterministic inference regardless of the mode train() left behind
    with torch.no_grad():
        Yp_norm = model(X_val).cpu().numpy()
    Yp = Yp_norm * ds.Y_std + ds.Y_mean
    Yt = ds.Y[val_idx] * ds.Y_std + ds.Y_mean
    names = ["Reality Stability", "Decision Impact", "Entropy Cost"]
    for i, nm in enumerate(names):
        plot_scatter(Yt[:, i], Yp[:, i], nm)

    # Uncertainty heatmap for stability
    plot_uncertainty(model, stats, device)