<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_universe_creation_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_universe_creation_ai.py

Physics-informed AI pipeline for UniverseCreationAI with NaN‐safe training.

1. Synthetic dataset with range prints
2. Float32 normalization and dtype consistency
3. MLP with LayerNorm, Dropout, ReLU
4. Physics-informed residual with denominator clamps
5. MC-Dropout for uncertainty quantification
6. Training loop: AdamW, ReduceLROnPlateau, gradient clipping, NaN checks, early stopping
7. Safe checkpoint load
8. Visualizations: losses, scatter, uncertainty heatmap
"""

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------------------
# 1. Synthetic Universe Dataset
# ------------------------------------------------------------------------------
class UniverseDataset(Dataset):
    """Synthetic dataset of six input features and three toy targets.

    Feature columns, in order: ``ve, si, ir, de, md, cp``, each drawn from a
    uniform range. Target columns, in order: ``us = ir / (ve + 1e-12)``,
    ``ec = de * md`` and ``qc = cp / si``, with 1% (of per-column std)
    Gaussian noise added.

    Attributes:
        X_mean, X_std, Y_mean, Y_std: float64 per-column statistics of the
            raw (un-normalized) features / targets.
        X, Y: float32 arrays standardized with those statistics.

    Args:
        n_samples: Number of rows to generate.
        seed: Seed for the private random generator.
    """

    def __init__(self, n_samples=5000, seed=123):
        # Private generator: yields the same stream as np.random.seed(seed)
        # followed by np.random.* calls, but does not reseed NumPy's global
        # RNG as a side effect of constructing a dataset.
        rng = np.random.RandomState(seed)
        shape = (n_samples, 1)

        # Feature ranges
        ve = rng.uniform(1e-12, 1e-10, shape)
        si = rng.uniform(1e50,   1e52,   shape)
        ir = rng.uniform(50,     70,     shape)
        de = rng.uniform(0.6,    0.8,    shape)
        md = rng.uniform(0.2,    0.4,    shape)
        cp = rng.uniform(0.0,    1.0,    shape)

        X_raw = np.hstack([ve, si, ir, de, md, cp]).astype(np.float64)

        # Toy cosmic-law targets
        eps = 1e-12
        us = ir / (ve + eps)
        ec = de * md
        qc = cp * si**(-1)

        Y_raw = np.hstack([us, ec, qc]).astype(np.float64)
        Y_raw += 0.01 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Compute stats. The std guard only replaces an exactly-zero std:
        # an additive floor such as 1e-8 is far larger than the std of the
        # tiny-scale columns (ve ~1e-11, qc ~1e-51) and would leave them
        # effectively un-normalized.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = self._safe_std(X_raw)
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = self._safe_std(Y_raw)

        # Standardize in float64, then store as float32
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

        # Print ranges for debug
        print(f"X range: {self.X.min():.3e} to {self.X.max():.3e}")
        print(f"Y range: {self.Y.min():.3e} to {self.Y.max():.3e}")
        print(f"X_std stats: {self.X_std}")
        print(f"Y_std stats: {self.Y_std}")

    @staticmethod
    def _safe_std(arr):
        """Return the per-column std of *arr*, with zero entries replaced by 1."""
        std = arr.std(axis=0)
        return np.where(std > 0, std, 1.0)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, targets)`` pair at *idx* as torch tensors."""
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class UniverseCreationAI(nn.Module):
    """Fully connected regressor.

    Each hidden width in ``hidden_dims`` contributes a
    Linear -> LayerNorm -> ReLU -> Dropout group; a final Linear layer maps
    the last hidden width to ``output_dim``. The layers live in ``self.net``.

    Args:
        input_dim: Size of the input feature vector.
        hidden_dims: Widths of the hidden layers, in order.
        output_dim: Size of the output vector.
        p_drop: Dropout probability used after every hidden activation.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stack = []
        # Walk consecutive (in, out) width pairs to build the hidden groups.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend(self._hidden_group(fan_in, fan_out, p_drop))
        stack.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stack)

    @staticmethod
    def _hidden_group(fan_in, fan_out, p_drop):
        """Return the four modules that make up one hidden layer."""
        return [
            nn.Linear(fan_in, fan_out),
            nn.LayerNorm(fan_out),
            nn.ReLU(),
            nn.Dropout(p_drop),
        ]

    def forward(self, x):
        """Apply the layer stack to *x* and return its output."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Residual & Total Loss (with clamps)
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats, eps=1e-12):
    """MSE between *pred* and the normalized closed-form targets for *X*.

    The features are de-normalized, the three toy laws
    (``ir / (ve + eps)``, ``de * md``, ``cp / si``) are evaluated, and the
    result is re-normalized with the target statistics before comparison.

    Args:
        pred: Model output in normalized target space, shape (batch, 3).
        X: Normalized features, shape (batch, 6), columns
            ``ve, si, ir, de, md, cp``.
        stats: Mapping with tensors ``'X_mean'``, ``'X_std'``, ``'Y_mean'``,
            ``'Y_std'``. Pass them as float64 when the raw feature scales do
            not fit in float32 (e.g. values around 1e51).
        eps: Offset added to ``ve`` in the first law. The default matches the
            1e-12 used when the dataset generates that target; a larger value
            dwarfs ``ve`` (~1e-11) and makes this target disagree with the
            data by orders of magnitude.

    Returns:
        Scalar loss tensor in ``pred``'s dtype.
    """
    # Work in float64: the raw scales span roughly 1e-12 .. 1e52, which
    # float32 cannot represent. Only the final normalized target is cast back.
    f64 = torch.float64
    X_mean, X_std = stats['X_mean'].to(f64), stats['X_std'].to(f64)
    Y_mean, Y_std = stats['Y_mean'].to(f64), stats['Y_std'].to(f64)

    # Denormalize
    X_den = X.to(f64) * X_std + X_mean
    ve, si, ir, de, md, cp = X_den.unbind(dim=1)

    us_t = ir / torch.clamp(ve + eps, min=eps)
    ec_t = de * md
    qc_t = cp * si.pow(-1).clamp(max=1e12)
    Yt   = torch.stack([us_t, ec_t, qc_t], dim=1)

    # Cast to pred's dtype so the loss (and its backward) sees matching dtypes.
    Yt_norm = ((Yt - Y_mean) / Y_std).to(pred.dtype)
    return nn.functional.mse_loss(pred, Yt_norm)

def total_loss(pred, true, X, stats, lam=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: Model output in normalized target space.
        true: Normalized ground-truth targets.
        X: Normalized input features, forwarded to ``physics_residual``.
        stats: Normalization statistics, forwarded to ``physics_residual``.
        lam: Weight applied to the physics term.

    Returns:
        Tuple ``(data_term + lam * physics_term, data_term, physics_term)``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout for Uncertainty Quantification
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Run *T* stochastic forward passes and summarize them.

    The model is switched to training mode so dropout stays active, evaluated
    *T* times under ``torch.no_grad()``, and then every sub-module is put
    back into the mode it had on entry.

    Args:
        model: Module to evaluate.
        X: Input batch passed unchanged to ``model``.
        T: Number of forward passes.

    Returns:
        Tuple ``(mean, std)`` taken over the *T* passes (dimension 0 of the
        stacked predictions).
    """
    # Remember each module's mode so the caller's model is not left in
    # training mode (which would keep dropout active in later inference).
    saved_modes = [(module, module.training) for module in model.modules()]
    # NOTE: train() flips every mode-dependent layer, not only dropout.
    model.train()
    try:
        with torch.no_grad():
            preds = [model(X) for _ in range(T)]
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    arr = torch.stack(preds, dim=0)
    return arr.mean(0), arr.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop with NaN Checks
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-4, wd=1e-5, lam=1.0, epochs=100, patience=10,
          ckpt_path="best_universe_ai.pth"):
    """Train *model* with AdamW, LR-on-plateau scheduling and early stopping.

    Each epoch runs one pass over ``train_loader`` (with gradient-norm
    clipping at 1.0) and one evaluation pass over ``val_loader``. Whenever the
    validation loss improves, the weights are saved to ``ckpt_path``; after
    the loop the best weights saved during this call are loaded back.

    Args:
        model: Network to optimize; moved to ``device``.
        train_loader: Iterable of ``(features, targets)`` training batches.
        val_loader: Iterable of ``(features, targets)`` validation batches.
        stats: Normalization statistics forwarded to ``total_loss``.
        device: Device on which batches and the model are placed.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: Weight of the physics term in ``total_loss``.
        epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping.
        ckpt_path: File used to store the best weights.

    Returns:
        Dict with lists ``'train'`` and ``'val'`` holding the per-epoch
        sample-weighted mean losses. If a non-finite training loss is seen,
        training stops immediately and the history collected so far is
        returned without reloading a checkpoint.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    # Tracks whether *this* call wrote ckpt_path, so a stale file left by an
    # earlier run (possibly a different architecture) is never loaded.
    saved_checkpoint = False
    history = {'train': [], 'val': []}

    for ep in range(1, epochs+1):
        # Training
        model.train()
        run_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, lam)

            # isfinite catches +/-Inf as well as NaN; either would corrupt
            # the weights on the next optimizer step.
            if not torch.isfinite(loss):
                print(f"NaN/Inf detected at epoch {ep}, batch aborting.")
                return history

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Weight by batch size so the epoch mean is per-sample.
            run_loss += loss.item() * xb.size(0)
        train_loss = run_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_run = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats, lam)
                val_run += loss.item() * xb.size(0)
        val_loss = val_run / len(val_loader.dataset)

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {ep:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # Checkpoint
        if val_loss < best_val - 1e-8:
            best_val, wait = val_loss, 0
            torch.save(model.state_dict(), ckpt_path)
            saved_checkpoint = True
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    # Safe load: only restore weights written during this call.
    if saved_checkpoint and os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print("Warning: no checkpoint found, using last model state.")

    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(hist):
    """Plot the ``'train'`` and ``'val'`` loss curves from *hist* in one figure."""
    plt.figure()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        plt.plot(hist[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(y_true, y_pred, name):
    """Scatter *y_pred* against *y_true* with a dashed red identity line.

    The identity line spans the minimum and maximum of *y_true*; *name* is
    used as the figure title.
    """
    lo = y_true.min()
    hi = y_true.max()
    diagonal = [lo, hi]

    plt.figure()
    plt.scatter(y_true, y_pred, s=5, alpha=0.6)
    plt.plot(diagonal, diagonal, 'r--')
    plt.title(name)
    plt.xlabel("True")
    plt.ylabel("Pred")
    plt.show()

def plot_uncertainty(model, stats, device):
    """Draw a heatmap of MC-dropout std for output 0 over a (ve, ir) grid.

    Feature 0 (``ve``) and feature 2 (``ir``) are swept over a 100 x 100
    grid; the remaining four features are held at their mean.

    Args:
        model: Network passed to ``mc_dropout_predict``.
        stats: Mapping with tensors ``'X_mean'`` and ``'X_std'`` used to
            normalize the two swept features.
        device: Device on which the grid is built.
    """
    ve_vals = np.linspace(1e-12, 1e-10, 100, dtype=np.float32)
    ir_vals = np.linspace(50, 70,        100, dtype=np.float32)
    VE, IR   = np.meshgrid(ve_vals, ir_vals)
    pts = VE.size

    ve_flat = torch.from_numpy(VE.ravel()).to(device)
    ir_flat = torch.from_numpy(IR.ravel()).to(device)

    # Build the grid directly in normalized space. A feature held at its mean
    # normalizes to exactly 0, so those columns stay zero. Writing the raw
    # mean and then computing (mean - mean) / std would give inf - inf = NaN
    # for any column whose statistics overflowed to inf in float32.
    Xn = torch.zeros((pts, 6), device=device)
    Xn[:, 0] = (ve_flat - stats['X_mean'][0]) / stats['X_std'][0]
    Xn[:, 2] = (ir_flat - stats['X_mean'][2]) / stats['X_std'][2]

    _, std = mc_dropout_predict(model, Xn, T=100)
    U = std[:,0].cpu().numpy().reshape(VE.shape)

    plt.figure(figsize=(6,5))
    plt.pcolormesh(VE, IR, U, cmap='viridis', shading='auto')
    plt.colorbar(label='Std Stability')
    plt.xlabel('Vacuum Energy Density'); plt.ylabel('Inflation Rate')
    plt.title('Uncertainty Heatmap: Universe Stability'); plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The dataset is generated from a fixed seed; seed torch as well so the
    # train/val split, weight init and dropout masks are reproducible too.
    torch.manual_seed(123)

    # Dataset & stats
    ds = UniverseDataset(n_samples=5000)
    # NOTE(review): the `si` feature is drawn from [1e50, 1e52], far above
    # float32's ~3.4e38 maximum, so its X_mean / X_std entries become inf in
    # these float32 tensors. Any float32 de-normalization of that column
    # yields inf/NaN; float64 statistics are needed to represent it.
    stats = {
        key: torch.tensor(getattr(ds, key), dtype=torch.float32, device=device)
        for key in ('X_mean', 'X_std', 'Y_mean', 'Y_std')
    }

    # Data loaders (80/20 train/validation split)
    n_val = int(0.2 * len(ds))
    tr_ds, va_ds = random_split(ds, [len(ds)-n_val, n_val])
    tr_ld = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=256)

    # Model & train
    model   = UniverseCreationAI(input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1)
    history = train(model, tr_ld, va_ld, stats, device,
                    lr=1e-4, wd=1e-5, lam=1.0, epochs=100, patience=10)

    # Visualize
    plot_losses(history)

    # train() can return while the model is still in training mode (early
    # abort on a bad loss); force eval so dropout is off for point predictions.
    model.eval()
    X_all = torch.from_numpy(ds.X).to(device)
    with torch.no_grad():
        Yp_norm = model(X_all).cpu().numpy()
    # De-normalize with the float64 NumPy statistics kept on the dataset.
    Yp = Yp_norm * ds.Y_std + ds.Y_mean
    Yt = ds.Y    * ds.Y_std + ds.Y_mean

    names = ["Universe Stability", "Expansion Consistency", "Quantum Coherence"]
    for i, nm in enumerate(names):
        plot_scatter(Yt[:,i], Yp[:,i], nm)

    plot_uncertainty(model, stats, device)