<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_universe_sim_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_universe_sim_ai.py

Physics-informed AI pipeline for UniverseSimAI:

1. Synthetic dataset of 6 cosmic parameters → 3 evolution metrics
2. Overflow-safe 1D normalization
3. Physics-informed residual enforcing toy Friedmann & galaxy/entropy laws
4. MLP with LayerNorm & Dropout
5. MC-Dropout inference for uncertainty quantification
6. Training loop with AdamW, ReduceLROnPlateau, gradient clipping, early stopping
7. Visualizations: training curves, scatter plots, uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Universe Simulation Dataset
# ------------------------------------------------------------------------------
class UniverseSimDataset(Dataset):
    """Synthetic dataset mapping 6 cosmological inputs to 3 toy evolution metrics.

    Input columns of ``X`` (in order), each drawn uniformly:
        ΩΛ ∈ [0.6, 0.8], Ωm ∈ [0.2, 0.4], H0 ∈ [60, 80],
        k ∈ [-0.1, 0.1], a ∈ [0.5, 1.5], t ∈ [0, 14].

    Target columns of ``Y`` (in order): expansion rate, galaxy index and
    entropy evolution, computed from closed-form toy formulas and perturbed
    with Gaussian noise scaled to 2% of each target column's std.

    Inputs and targets are standardized column-wise and stored as float32
    arrays ``X`` and ``Y``.  The statistics used for standardization stay
    available as ``X_mean``, ``X_std``, ``Y_mean`` and ``Y_std`` (NumPy
    float64 vectors) so callers can de-normalize.

    Args:
        n_samples: number of rows to generate.
        seed: seed of the private random generator used for sampling.
    """

    def __init__(self, n_samples=7000, seed=1):
        # Use a private legacy generator instead of ``np.random.seed``:
        # it produces the same draws for a given seed but does not reseed
        # the process-wide NumPy RNG behind the caller's back.
        rng = np.random.RandomState(seed)
        column = (n_samples, 1)

        # Draw order matters for reproducibility: ΩΛ, Ωm, H0, k, a, t.
        omega_l = rng.uniform(0.6, 0.8, column)
        omega_m = rng.uniform(0.2, 0.4, column)
        hubble = rng.uniform(60, 80, column)
        curvature = rng.uniform(-0.1, 0.1, column)
        scale = rng.uniform(0.5, 1.5, column)
        age = rng.uniform(0.0, 14.0, column)

        X_raw = np.hstack(
            [omega_l, omega_m, hubble, curvature, scale, age]
        ).astype(np.float64)

        # Targets (toy physics).  The 1e-12 keeps the radicand strictly
        # positive; with the sampling ranges above it is already > 0.
        er = hubble * np.sqrt(
            omega_l + omega_m / (scale**3) + curvature / (scale**2) + 1e-12
        )                                              # expansion rate
        gfi = omega_m * scale / (1.0 + omega_l)        # galaxy index
        se = age * (1.0 + omega_l) / (1.0 + omega_m)   # entropy evolution

        Y_raw = np.hstack([er, gfi, se]).astype(np.float64)
        # Additive noise: 2% of each column's standard deviation.
        Y_raw += 0.02 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Per-column normalization statistics; the epsilon guards against
        # division by zero for a constant column.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std = X_raw.std(axis=0) + 1e-6
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std = Y_raw.std(axis=0) + 1e-6

        # Standardize in float64, then store as float32 for the model.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the standardized ``(inputs, targets)`` tensors at ``idx``."""
        return (
            torch.from_numpy(self.X[idx]),
            torch.from_numpy(self.Y[idx])
        )

# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class UniverseSimAI(nn.Module):
    """Fully connected regressor with LayerNorm and Dropout in every hidden stage.

    Each entry of ``hidden_dims`` contributes one stage of
    Linear → LayerNorm → ReLU → Dropout; a final Linear layer maps the last
    hidden width to ``output_dim``.  All layers live in ``self.net``.

    Args:
        input_dim: number of input features.
        hidden_dims: widths of the hidden stages, in order.
        output_dim: number of regression outputs.
        p_drop: dropout probability used in every hidden stage.
    """

    def __init__(self, input_dim=6, hidden_dims=(32,32), output_dim=3, p_drop=0.1):
        super().__init__()
        # widths[i] -> widths[i + 1] describes the i-th hidden stage.
        widths = [input_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Output head operates on the last hidden width (or on the raw
        # input when no hidden stage was requested).
        stages.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw outputs."""
        out = self.net(x)
        return out

# ------------------------------------------------------------------------------
# 3. Physics-Informed Residual and Loss
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Return the summed MSE between de-normalized predictions and the toy laws.

    Both ``X`` and ``pred`` are mapped back to physical units with the
    statistics in ``stats`` (keys ``'X_mean'``, ``'X_std'``, ``'Y_mean'``,
    ``'Y_std'``).  The closed-form expansion rate, galaxy index and entropy
    evolution are then recomputed from the inputs and compared with the
    three predicted columns; the three mean-squared errors are added.

    Args:
        pred: standardized model outputs with 3 columns.
        X: standardized inputs with 6 columns (ΩΛ, Ωm, H0, k, a, t).
        stats: mapping holding the normalization tensors.

    Returns:
        Scalar tensor: sum of the three per-quantity MSE terms.
    """
    mse = nn.functional.mse_loss

    # Undo the standardization of inputs and predictions.
    inputs = X * stats['X_std'] + stats['X_mean']
    outputs = pred * stats['Y_std'] + stats['Y_mean']

    # Transposing makes each physical quantity one row to unpack.
    omega_l, omega_m, hubble, curvature, scale, age = inputs.t()
    er_pred, gfi_pred, se_pred = outputs.t()

    # Closed-form targets; 1e-12 keeps the square-root argument off zero.
    radicand = omega_l + omega_m/(scale**3) + curvature/(scale**2) + 1e-12
    er_law = hubble * torch.sqrt(radicand)
    gfi_law = omega_m * scale / (1.0 + omega_l)
    se_law = age * (1.0 + omega_l)/(1.0 + omega_m)

    residual = mse(er_pred, er_law) + mse(gfi_pred, gfi_law)
    residual = residual + mse(se_pred, se_law)
    return residual

def total_loss(pred, true, X, stats, lam=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: model outputs (standardized).
        true: reference targets (standardized).
        X: standardized inputs, forwarded to ``physics_residual``.
        stats: normalization tensors, forwarded to ``physics_residual``.
        lam: weight applied to the physics term.

    Returns:
        Tuple ``(combined, data_term, physics_term)`` where
        ``combined = data_term + lam * physics_term``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, T=50):
    """Estimate predictive mean and spread with Monte-Carlo dropout.

    The model is switched to training mode so that dropout stays active,
    ``T`` stochastic forward passes are run without gradient tracking, and
    the element-wise mean and standard deviation over the passes are
    returned.  The ``training`` flag of every sub-module is restored before
    returning (also when a forward pass raises), so calling this helper
    does not leave a model that was in eval mode stuck in train mode.

    NOTE(review): ``model.train()`` also activates any other layer whose
    behavior depends on the mode; confirm that is acceptable for models
    that contain such layers.

    Args:
        model: module containing dropout layers.
        x: input batch passed unchanged to ``model``.
        T: number of stochastic forward passes.

    Returns:
        Tuple ``(mean, std)`` reduced over the ``T`` passes; each has the
        shape of a single model output.
    """
    # Remember each module's mode individually so mixed train/eval
    # configurations are restored exactly.
    prior_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(x) for _ in range(T)], dim=0)
    finally:
        for module, was_training in prior_modes:
            module.training = was_training
    return samples.mean(0), samples.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lam=1.0, epochs=150, patience=10,
          ckpt_path="best_universe_sim_ai.pth"):
    """Train ``model`` with the physics-informed loss and early stopping.

    Uses AdamW, gradient-norm clipping at 1.0 and a ReduceLROnPlateau
    scheduler driven by the validation loss.  Whenever the validation loss
    improves, the weights are snapshotted in memory and also written to
    ``ckpt_path``.  After training, the best snapshot is loaded back into
    ``model``.  If no epoch ever improved (for example the validation loss
    was NaN throughout, or ``epochs`` is 0), the model keeps its current
    weights instead of reading a checkpoint file that may be missing or
    left over from an earlier run.

    Args:
        model: network to optimize; moved to ``device``.
        train_loader: loader yielding ``(inputs, targets)`` training batches.
        val_loader: loader yielding ``(inputs, targets)`` validation batches.
        stats: normalization tensors forwarded to ``total_loss``.
        device: device on which batches and the model are placed.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: weight of the physics term in ``total_loss``.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        ckpt_path: file that receives the best ``state_dict``.

    Returns:
        Dict with per-epoch average losses under ``'train'`` and ``'val'``.
    """
    model.to(device)
    opt   = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    best_state = None  # in-memory copy of the best weights seen so far
    history = {'train':[], 'val':[]}

    for ep in range(1, epochs+1):
        # ---- optimization pass ----
        model.train()
        t_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred   = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, lam)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            t_loss += loss.item() * xb.size(0)
        t_loss /= len(train_loader.dataset)

        # ---- validation pass (dropout off, no gradients) ----
        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                batch_loss, _, _ = total_loss(model(xb), yb, xb, stats, lam)
                v_loss += batch_loss.item() * xb.size(0)
        v_loss /= len(val_loader.dataset)

        sched.step(v_loss)
        history['train'].append(t_loss)
        history['val'].append(v_loss)
        print(f"Epoch {ep:03d} | Train {t_loss:.4e} | Val {v_loss:.4e}")

        # The small margin avoids counting float noise as an improvement.
        if v_loss + 1e-8 < best_val:
            best_val, wait = v_loss, 0
            # Clone: state_dict() returns references to the live tensors,
            # which later optimizer steps would overwrite.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    # Restore from memory rather than from disk: this cannot pick up a
    # stale file and needs no device remapping.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(hist):
    """Plot the per-epoch training and validation loss curves.

    Args:
        hist: mapping with sequences under the keys ``'train'`` and ``'val'``.
    """
    plt.figure()
    # Same two curves, drawn in the same order: train first, then val.
    for key, label in (('train', 'Train'), ('val', 'Val')):
        plt.plot(hist[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(y_t, y_p, name):
    """Scatter predictions against reference values with an identity line.

    Args:
        y_t: reference values (x-axis); must provide ``min()`` and ``max()``.
        y_p: predicted values (y-axis).
        name: plot title.
    """
    # End points of the dashed y = x guide, spanning the reference range.
    lo = y_t.min()
    hi = y_t.max()
    diagonal = [lo, hi]

    plt.figure()
    plt.scatter(y_t, y_p, s=5, alpha=0.6)
    plt.plot(diagonal, diagonal, 'r--')
    plt.title(name)
    plt.xlabel("True")
    plt.ylabel("Pred")
    plt.show()

def plot_uncertainty(model, stats, device):
    """Draw a heatmap of MC-dropout spread over a (Ωm, a) grid.

    A 100 x 100 grid sweeps Ωm over [0.2, 0.4] and the scale factor over
    [0.5, 1.5]; the remaining four inputs are fixed at their mean from
    ``stats['X_mean']``.  The grid is standardized, passed through
    ``mc_dropout_predict`` with 100 passes, and the standard deviation of
    the first output column is plotted.  The model output is not
    de-normalized here, so the plotted spread is in standardized units.

    Args:
        model: trained network handed to ``mc_dropout_predict``.
        stats: mapping with the ``'X_mean'`` and ``'X_std'`` tensors.
        device: device on which the grid tensor is created.
    """
    x_mean, x_std = stats['X_mean'], stats['X_std']
    matter_col, scale_col = 1, 4   # positions of Ωm and a among the inputs

    Mg, Ag = np.meshgrid(
        np.linspace(0.2, 0.4, 100).astype(np.float32),
        np.linspace(0.5, 1.5, 100).astype(np.float32),
    )

    # One row per grid point, using the same dtype as the statistics.
    grid = torch.zeros((Mg.size, 6), device=device, dtype=x_mean.dtype)
    # Inputs that are not swept (ΩΛ, H0, k, t) stay at their mean value.
    for col in (0, 2, 3, 5):
        grid[:, col] = x_mean[col]
    grid[:, matter_col] = torch.from_numpy(Mg.ravel()).to(device)
    grid[:, scale_col] = torch.from_numpy(Ag.ravel()).to(device)

    standardized = (grid - x_mean) / x_std
    _, spread = mc_dropout_predict(model, standardized, T=100)
    # Column 0 of the output is the expansion rate.
    std_map = spread[:, 0].cpu().numpy().reshape(Mg.shape)

    plt.figure(figsize=(6,5))
    plt.pcolormesh(Mg, Ag, std_map, cmap='viridis', shading='auto')
    plt.colorbar(label='Std Expansion Rate')
    plt.xlabel('Matter Density Ωm')
    plt.ylabel('Scale Factor a')
    plt.title('Uncertainty Heatmap: Expansion Rate')
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Prefer the GPU when one is visible to PyTorch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ds = UniverseSimDataset(7000)
    # Normalization statistics as float32 tensors on the training device.
    stats = {
        key: torch.tensor(getattr(ds, key), dtype=torch.float32, device=device)
        for key in ('X_mean', 'X_std', 'Y_mean', 'Y_std')
    }

    # 80/20 train/validation split.
    n_val = int(0.2 * len(ds))
    n_train = len(ds) - n_val
    tr, va = random_split(ds, [n_train, n_val])
    tr_loader = DataLoader(tr, batch_size=128, shuffle=True)
    va_loader = DataLoader(va, batch_size=256)

    model   = UniverseSimAI().to(device)
    history = train(model, tr_loader, va_loader, stats, device)

    plot_history(history)

    # Predicted vs. reference scatter plots over the full dataset
    # (both in standardized units).
    with torch.no_grad():
        X_all = torch.from_numpy(ds.X).to(device)
        Yp    = model(X_all).cpu().numpy()
    titles = ['Expansion Rate','Galaxy Formation','Entropy Evolution']
    for i, nm in enumerate(titles):
        plot_scatter(ds.Y[:, i], Yp[:, i], nm)

    plot_uncertainty(model, stats, device)