<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_quantum_cosmology_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib

In [None]:
#!/usr/bin/env python3
"""
train_quantum_cosmology_ai.py

Physics-informed AI pipeline for early-universe quantum state evolution:

1. Synthetic dataset of 6 initial cosmological variables → 3 evolution metrics
2. PINN loss: supervised MSE + Friedmann‐equation residual
3. MLP with LayerNorm & Dropout for predictive uncertainty
4. MC-Dropout inference to quantify uncertainty
5. Training loop with AdamW, ReduceLROnPlateau, early stopping
6. Visualizations: loss curves, true vs. predicted scatter, uncertainty heatmap
"""

import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Dataset
# ------------------------------------------------------------------------------
class QuantumCosmologyDataset(Dataset):
    """Synthetic toy dataset: 6 initial cosmological variables -> 3 targets.

    Inputs (columns of ``X``), sampled uniformly:
        rho_rad, rho_mat, rho_de in [0.1, 1.0]; a0 in [0.5, 1.5];
        k in [-0.1, 0.1]; t in [0.1, 5.0].

    Targets (columns of ``Y``) are computed analytically, then perturbed with
    Gaussian noise of std 0.01:
        H0^2    = max(rho_rad + rho_mat + rho_de - k / a0^2, 1e-6)
        a(t)    = a0 * exp(H0 * t)
        H(t)    = H0
        Omega_k = -k / (a(t)^2 * H0^2)

    Both ``X`` and ``Y`` are z-score normalized per column. The statistics are
    kept as float32 arrays ``X_mean``/``X_std``/``Y_mean``/``Y_std``. A column
    with zero spread (e.g. when ``n_samples == 1``) gets std 1 instead of 0, so
    it normalizes to 0 rather than NaN.

    Args:
        n_samples: number of rows to generate.
        seed: seed for a private legacy ``np.random.RandomState``. The global
            NumPy RNG is not touched. The generated values are identical to
            what seeding the global legacy RNG with ``seed`` would produce.
    """

    def __init__(self, n_samples=5000, seed=0):
        # Private generator: same stream as np.random.seed(seed), but without
        # clobbering global NumPy RNG state for the rest of the program.
        rng = np.random.RandomState(seed)

        def draw(low, high):
            # One (n_samples, 1) float32 column. The call order below fixes
            # which part of the RNG stream each variable receives.
            return rng.uniform(low, high, size=(n_samples, 1)).astype(np.float32)

        def nonzero_std(arr):
            std = arr.std(0)
            # Constant columns would divide by zero; normalize them to 0 instead.
            return np.where(std > 0, std, np.float32(1.0)).astype(np.float32)

        rho_rad = draw(0.1, 1.0)
        rho_mat = draw(0.1, 1.0)
        rho_de  = draw(0.1, 1.0)
        a0      = draw(0.5, 1.5)
        k       = draw(-0.1, 0.1)
        t       = draw(0.1, 5.0)

        X_raw = np.hstack([rho_rad, rho_mat, rho_de, a0, k, t])

        # Analytic targets (toy model: constant H, exponential expansion).
        sum_rho = rho_rad + rho_mat + rho_de
        # Floor H0^2 so the sqrt and the Omega_k division stay finite.
        H0_sq = np.maximum(sum_rho - k / (a0**2), 1e-6)
        H0 = np.sqrt(H0_sq)

        a_t = a0 * np.exp(H0 * t)             # scale factor
        H_t = H0                              # Hubble parameter
        Omega_k = -k / (a_t**2 * H0_sq)       # curvature parameter

        Y_raw = np.hstack([a_t, H_t, Omega_k]).astype(np.float32)
        # small observation noise
        Y_raw += 0.01 * rng.randn(*Y_raw.shape).astype(np.float32)

        # stats for normalization (also used to denormalize predictions)
        self.X_mean, self.X_std = X_raw.mean(0), nonzero_std(X_raw)
        self.Y_mean, self.Y_std = Y_raw.mean(0), nonzero_std(Y_raw)

        self.X = (X_raw - self.X_mean) / self.X_std
        self.Y = (Y_raw - self.Y_mean) / self.Y_std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return the normalized ``(x, y)`` pair at ``idx`` as float32 tensors."""
        x = torch.from_numpy(self.X[idx])
        y = torch.from_numpy(self.Y[idx])
        return x, y

# ------------------------------------------------------------------------------
# 2. Model Architecture with Dropout
# ------------------------------------------------------------------------------
class QuantumCosmologyAI(nn.Module):
    """Plain MLP regressor with dropout.

    Each hidden width ``h`` in ``hidden_dims`` contributes the stage
    ``Linear -> LayerNorm -> ReLU -> Dropout(p_drop)``. A final ``Linear``
    projects to ``output_dim``. The dropout layers are what
    ``mc_dropout_predict`` samples over.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Output head from the last hidden width (or input_dim if no hidden layers).
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to a batch ``x`` of shape (..., input_dim)."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Loss
# ------------------------------------------------------------------------------
def physics_residual(pred, inp, stats, eps=1e-6):
    """Mean squared Friedmann-equation residual for a batch of predictions.

    Inputs and predictions are denormalized with ``stats``. The function then
    evaluates ``H_pred**2 - (rho_rad + rho_mat + rho_de - k / a_pred**2)``.

    Args:
        pred: normalized model outputs. Column 0 is a(t) and column 1 is H(t).
            Column 2 (Omega_k) is not used.
        inp: normalized inputs. Columns 0-2 are the densities and column 4 is
            k. a0 (col 3) and t (col 5) are not used.
        stats: dict of tensors ``'X_mean'``, ``'X_std'``, ``'Y_mean'``,
            ``'Y_std'`` used to undo the z-score normalization.
        eps: lower bound on ``a_pred**2``. This keeps the curvature term
            ``k / a**2`` finite when the predicted scale factor is ~0, which
            would otherwise yield inf/NaN losses and gradients.

    Returns:
        Scalar tensor: the mean of the squared residual over the batch.
    """
    # Denormalize only the columns the residual actually uses.
    x = inp[:, :5] * stats['X_std'][:5] + stats['X_mean'][:5]
    y = pred[:, :2] * stats['Y_std'][:2] + stats['Y_mean'][:2]

    rho_tot = x[:, 0] + x[:, 1] + x[:, 2]
    k = x[:, 4]
    a_pred, H_pred = y[:, 0], y[:, 1]

    # Friedmann equation residual: H^2 ≈ rho_total - k/a^2
    a_sq = (a_pred ** 2).clamp_min(eps)
    resid = H_pred**2 - (rho_tot - k / a_sq)
    return torch.mean(resid**2)

def total_loss(pred, true, inp, stats, lambda_phys=1.0):
    """Combine the supervised and physics-informed loss terms.

    Returns the 3-tuple ``(total, mse, phys)``:
      * ``mse``   - mean squared error between ``pred`` and ``true``;
      * ``phys``  - ``physics_residual(pred, inp, stats)``;
      * ``total`` - ``mse + lambda_phys * phys``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, inp, stats)
    return data_term + lambda_phys * physics_term, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, n_samples=100):
    """Monte-Carlo-dropout prediction: mean and std over stochastic passes.

    For the sampling passes only dropout layers are switched to training
    mode. Every other module (e.g. BatchNorm) stays in eval mode, so it uses
    its running statistics and does not update them. Running-stat updates
    happen in train mode even under ``torch.no_grad()``. Each submodule's
    original train/eval flag is restored afterwards, even if a forward pass
    raises, so the caller's model mode is unchanged.

    Args:
        model: network to sample.
        x: input batch, already on the model's device.
        n_samples: number of stochastic forward passes (must be >= 1).

    Returns:
        ``(mean, std)`` NumPy arrays, each with the shape of a single
        ``model(x)`` output, reduced over the ``n_samples`` passes.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Remember every submodule's flag so mixed train/eval states survive.
    saved_modes = [(m, m.training) for m in model.modules()]
    try:
        model.eval()
        for m in model.modules():
            # _DropoutNd is torch's common base of Dropout, Dropout1d/2d/3d,
            # AlphaDropout and FeatureAlphaDropout.
            if isinstance(m, nn.modules.dropout._DropoutNd):
                m.train()
        with torch.no_grad():
            preds = [model(x).cpu().numpy() for _ in range(n_samples)]
    finally:
        for m, mode in saved_modes:
            m.training = mode

    arr = np.stack(preds, axis=0)
    return arr.mean(0), arr.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lambda_phys=1.0,
          max_epochs=100, patience=10,
          ckpt_path="best_quantum_cosmology_ai.pth"):
    """Fit ``model`` with the physics-informed loss, LR decay and early stopping.

    Optimizes ``total_loss`` with AdamW. The learning rate is halved after 5
    epochs without validation improvement. Training stops once the validation
    loss has not improved by more than 1e-6 for ``patience`` consecutive
    epochs. Whenever it does improve, the weights are snapshotted in memory
    and also written to ``ckpt_path``. At the end, the best snapshot from
    *this* run is loaded back into ``model``. ``ckpt_path`` is never read back,
    so a stale file left by an earlier run cannot be loaded by mistake. If no
    epoch improves (e.g. every validation loss is NaN, or ``max_epochs == 0``),
    the model simply keeps its final weights.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: DataLoader yielding ``(x, y)`` training batches.
        val_loader: DataLoader yielding ``(x, y)`` validation batches.
        stats: normalization-statistics tensors forwarded to ``total_loss``.
        device: device for the model and the batches.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lambda_phys: weight of the physics-residual term in the loss.
        max_epochs: upper bound on the number of epochs.
        patience: epochs without improvement before stopping early.
        ckpt_path: file the best ``state_dict`` is saved to.

    Returns:
        ``{'train': [...], 'val': [...]}``: per-epoch, sample-weighted mean losses.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min',
                                                 factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    best_state = None  # best weights seen during this call, kept in memory
    history = {'train': [], 'val': []}

    for epoch in range(1, max_epochs + 1):
        # Training
        model.train()
        tr_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, lambda_phys)
            opt.zero_grad()
            loss.backward()
            opt.step()
            tr_loss += loss.item() * xb.size(0)
        tr_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        vl_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats, lambda_phys)
                vl_loss += loss.item() * xb.size(0)
        vl_loss /= len(val_loader.dataset)

        sched.step(vl_loss)
        history['train'].append(tr_loss)
        history['val'].append(vl_loss)
        print(f"Epoch {epoch:03d} | Train {tr_loss:.4e} | Val {vl_loss:.4e}")

        # Early stopping (a NaN val loss never counts as an improvement)
        if vl_loss < best_val - 1e-6:
            best_val, wait = vl_loss, 0
            # Deep copy so later optimizer steps cannot mutate the snapshot.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(history):
    """Draw the per-epoch training and validation loss curves stored in ``history``."""
    plt.figure()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        plt.plot(history[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Curve")
    plt.tight_layout()
    plt.show()

def plot_scatter(true_vals, pred_vals, name):
    """Scatter predictions against ground truth for one output, with a y = x guide line."""
    plt.figure()
    plt.scatter(true_vals, pred_vals, s=5, alpha=0.5)
    # Dashed red identity line spanning the range of the true values.
    diag = [true_vals.min(), true_vals.max()]
    plt.plot(diag, diag, 'r--')
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Pred {name}")
    plt.title(f"{name}: True vs Pred")
    plt.tight_layout()
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device, n_grid=100, n_mc=100, rho_rad=0.5):
    """Heatmap of MC-dropout uncertainty in the predicted scale factor a(t).

    The grid is ``n_grid`` x ``n_grid``:
      * rho_mat (x-axis) and rho_de (y-axis) are swept over [0.1, 1.0];
      * rho_rad is held at ``rho_rad``;
      * a0, k and t are held at their dataset means from ``stats_np``.

    The raw grid is normalized with ``stats_np`` and passed through
    ``mc_dropout_predict`` with ``n_mc`` passes. The plotted value is the
    std of a(t) in physical (denormalized) units. Since denormalization is
    affine, this is the normalized-space std times ``stats_np['Y_std'][0]``.
    That matches the "Std(a(t))" colorbar label and the denormalized
    scatter plots.

    Args:
        model: trained network.
        stats_np: dict of NumPy arrays ``X_mean``, ``X_std``, ``Y_std``
            (``Y_mean`` is not needed for a spread).
        device: device the model lives on.
        n_grid: points per axis of the (rho_mat, rho_de) grid.
        n_mc: number of MC-dropout forward passes.
        rho_rad: fixed radiation density for the whole grid.
    """
    vals = np.linspace(0.1, 1.0, n_grid)
    GM, GD = np.meshgrid(vals, vals)
    grid = np.zeros((GM.size, 6), dtype=np.float32)
    grid[:, 0] = rho_rad
    grid[:, 1] = GM.ravel()
    grid[:, 2] = GD.ravel()
    # a0, k, t fixed at their dataset means
    grid[:, 3:6] = stats_np['X_mean'][3:6]

    Xn = (grid - stats_np['X_mean']) / stats_np['X_std']
    Xt = torch.from_numpy(Xn).float().to(device)
    _, std = mc_dropout_predict(model, Xt, n_samples=n_mc)
    # Rows of std_map follow rho_de (y), columns follow rho_mat (x).
    std_map = std[:, 0].reshape(GM.shape) * stats_np['Y_std'][0]

    plt.figure(figsize=(6,5))
    plt.pcolormesh(vals, vals, std_map, cmap='viridis', shading='auto')
    plt.colorbar(label="Std(a(t))")
    plt.xlabel("rho_mat"); plt.ylabel("rho_de")
    plt.title("Uncertainty Heatmap for Scale Factor")
    plt.tight_layout(); plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fix torch RNG (weight init, shuffling, dropout) for reproducible runs.
    torch.manual_seed(0)

    ds = QuantumCosmologyDataset(n_samples=8000)
    # Normalization stats as NumPy arrays (plot helpers) and as tensors on
    # `device` (physics residual inside the loss).
    stats_np = {
        'X_mean': ds.X_mean,
        'X_std':  ds.X_std,
        'Y_mean': ds.Y_mean,
        'Y_std':  ds.Y_std,
    }
    stats = {key: torch.tensor(val, device=device) for key, val in stats_np.items()}

    # train/val split, with its own seeded generator so it is reproducible
    n_val = int(0.2 * len(ds))
    tr_ds, va_ds = random_split(ds, [len(ds) - n_val, n_val],
                                generator=torch.Generator().manual_seed(0))
    tr_ld = DataLoader(tr_ds, batch_size=64, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=128)

    # build & train
    model   = QuantumCosmologyAI().to(device)
    history = train(model, tr_ld, va_ld, stats, device)

    # plots
    plot_losses(history)

    # True-vs-predicted scatter on the held-out validation split only, with
    # dropout disabled and both targets and predictions denormalized.
    val_idx = np.asarray(va_ds.indices)
    model.eval()
    with torch.no_grad():
        X_val = torch.from_numpy(ds.X[val_idx]).float().to(device)
        Y_pred_norm = model(X_val).cpu().numpy()
    Y_pred = Y_pred_norm * ds.Y_std + ds.Y_mean
    Y_true = ds.Y[val_idx] * ds.Y_std + ds.Y_mean

    for i, name in enumerate(["Scale Factor a(t)", "Hubble H(t)", "Omega_k(t)"]):
        plot_scatter(Y_true[:, i], Y_pred[:, i], name)

    # uncertainty heatmap
    plot_uncertainty_heatmap(model, stats_np, device)