<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_reality_editor_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_reality_editor_ai.py

Physics‐informed AI pipeline for RealityEditorAI:

1. Synthetic dataset:
   Inputs: Planck‐constant variation, gravitational strength, quantum‐fluct control,
           vacuum permittivity shift, cosmological constant tweak, dark‐matter coupling
   Targets: reality_stability, quantum_consistency, entropy_management
2. 1D normalization (float32)
3. MLP with LayerNorm & Dropout
4. Physics‐informed residual enforcing toy conservation & consistency laws
5. MC‐Dropout inference for uncertainty quantification
6. Training loop with AdamW, ReduceLROnPlateau, gradient clipping, early stopping
7. Visualizations: loss curves, scatter plots, uncertainty heatmap
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Reality Dataset
# ------------------------------------------------------------------------------
class RealityDataset(Dataset):
    """Synthetic toy-physics regression dataset for RealityEditorAI.

    Draws ``n_samples`` rows of six uniformly distributed input features and
    derives three targets from closed-form toy formulas. Gaussian noise with
    2% of each target column's standard deviation is added. Inputs and targets
    are then standardized column-wise and stored as float32 arrays.

    Args:
        n_samples: number of rows to generate.
        seed: seed for a private ``np.random.RandomState``. The same seed always
            yields the same data, and NumPy's global RNG state is left untouched.

    Attributes:
        X: standardized inputs, float32, shape (n_samples, 6).
        Y: standardized targets, float32, shape (n_samples, 3).
        X_mean, X_std, Y_mean, Y_std: float64 per-column statistics used for
            standardization (each std carries an additive 1e-8 floor).

    NOTE(review): lambda_cos (~1e-52) has a raw std far below the 1e-8 floor,
    so its standardized column is effectively zero. The entropy_management
    target (tiny values built from lambda_cos) collapses the same way, and
    1e-52 itself underflows to 0 in float32. Rescaling these quantities to
    O(1) units would be needed to make them informative; left unchanged here
    to preserve the generated data.
    """

    def __init__(self, n_samples=5000, seed=42):
        # A private legacy RandomState reproduces exactly the stream that
        # np.random.seed(seed) followed by np.random.* calls would produce,
        # without clobbering the process-wide NumPy RNG.
        rng = np.random.RandomState(seed)
        shape = (n_samples, 1)

        # Features (draw order is fixed for reproducibility):
        dh   = rng.uniform(-1e-5, 1e-5, shape)        # dh/h ∈ [-1e-5, 1e-5]
        Gs   = rng.uniform(0.9, 1.1, shape)           # G_strength ∈ [0.9, 1.1]
        qctl = rng.uniform(0.0, 1.0, shape)           # quantum_ctrl ∈ [0, 1]
        eps  = rng.uniform(-1e-6, 1e-6, shape)        # eps_shift ∈ [-1e-6, 1e-6]
        lamc = rng.uniform(0.9e-52, 1.1e-52, shape)   # lambda_cos ∈ [0.9e-52, 1.1e-52]
        dmc  = rng.uniform(0.0, 1.0, shape)           # dm_coupling ∈ [0, 1]

        X_raw = np.hstack([dh, Gs, qctl, eps, lamc, dmc]).astype(np.float64)

        # Toy-physics targets; eps0 guards each denominator against zero.
        eps0 = 1e-12
        # 1) reality_stability ≈ |dh| * Gs / (qctl + ε)
        stab = np.abs(dh) * Gs / (qctl + eps0)
        # 2) quantum_consistency ≈ qctl * sqrt(Gs) / (|dh| + ε)
        qc = qctl * np.sqrt(Gs) / (np.abs(dh) + eps0)
        # 3) entropy_management ≈ lamc * dmc / (|eps| + ε)
        em = lamc * dmc / (np.abs(eps) + eps0)

        Y_raw = np.hstack([stab, qc, em]).astype(np.float64)

        # Add Gaussian noise scaled to 2% of each target's spread.
        Y_raw += 0.02 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Column-wise normalization statistics (1e-8 floor avoids divide-by-zero).
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = X_raw.std(axis=0) + 1e-8
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = Y_raw.std(axis=0) + 1e-8

        # Standardize and cast to float32 for training.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Return the number of generated rows."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the standardized (input, target) pair at ``idx`` as tensors.

        The tensors share memory with ``self.X`` / ``self.Y`` (``torch.from_numpy``).
        """
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class RealityEditorAI(nn.Module):
    """Fully connected regressor: [Linear → LayerNorm → ReLU → Dropout] × k → Linear.

    Args:
        input_dim: number of input features.
        hidden_dims: widths of the hidden layers, in order (may be empty).
        output_dim: number of regression outputs.
        p_drop: dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        # Consecutive (fan_in, fan_out) pairs define each hidden stage.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Linear read-out from the last hidden width (or the input, if none).
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, output_dim) predictions."""
        out = self.net(x)
        return out

# ------------------------------------------------------------------------------
# 3. Physics‐Informed Residual
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Mean-squared mismatch between predictions and the noiseless toy-physics targets.

    Steps:
        1. Denormalize ``X`` with ``stats['X_mean']`` / ``stats['X_std']``.
        2. Recompute the three closed-form targets (the same formulas that
           RealityDataset uses, without noise).
        3. Re-standardize them with ``stats['Y_mean']`` / ``stats['Y_std']``.
        4. Return their MSE against ``pred``.

    Args:
        pred: normalized predictions, expected to match the (batch, 3) target shape.
        X: normalized inputs with exactly 6 columns
            (dh, Gs, qctl, eps, lamc, dmc).
        stats: dict with tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor (mean reduction).
    """
    tiny = 1e-12
    raw = X * stats['X_std'] + stats['X_mean']
    dh, Gs, qctl, eps, lamc, dmc = raw.t()
    abs_dh = torch.abs(dh)

    targets = torch.stack(
        (
            abs_dh * Gs / (qctl + tiny),                # reality_stability
            qctl * torch.sqrt(Gs) / (abs_dh + tiny),    # quantum_consistency
            lamc * dmc / (torch.abs(eps) + tiny),       # entropy_management
        ),
        dim=1,
    )
    targets_norm = (targets - stats['Y_mean']) / stats['Y_std']
    return nn.functional.mse_loss(pred, targets_norm)

def total_loss(pred, true, X, stats, lam=1.0):
    """Combine the data-fit MSE with the physics residual.

    Returns:
        Tuple ``(data_mse + lam * physics, data_mse, physics)`` of scalar tensors.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC‐Dropout for Uncertainty
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Monte-Carlo-dropout prediction: mean and spread over ``T`` stochastic passes.

    Behavior:
        - The model is switched to training mode so Dropout layers stay active.
        - ``T`` forward passes run under ``torch.no_grad()``.
        - Every submodule's train/eval flag from entry is restored afterwards,
          even if a forward pass raises. Callers are therefore not left with
          a model stuck in training mode.

    NOTE: ``model.train()`` also switches layers such as BatchNorm (if any)
    into training mode during sampling. LayerNorm is unaffected.

    Args:
        model: an ``nn.Module``. It needs Dropout layers for a non-zero spread.
        X: input batch passed straight to ``model``.
        T: number of stochastic forward passes; must be >= 1.

    Returns:
        ``(mean, std)`` over the T passes, each shaped like one model output.
        ``std`` uses torch's default (unbiased) estimator, so ``T == 1`` gives NaN.

    Raises:
        ValueError: if ``T < 1``.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    # Snapshot every submodule's flag (not just the root's) so mixed
    # train/eval configurations are restored exactly.
    prior_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        for module, was_training in prior_modes:
            module.training = was_training
    return samples.mean(0), samples.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lam=1.0, epochs=100, patience=10,
          ckpt_path="best_reality_editor_ai.pth"):
    """Train ``model`` with the physics-informed loss and early stopping.

    Training setup:
        - AdamW, with ReduceLROnPlateau (factor 0.5, patience 5) driven by
          validation loss.
        - Gradient-norm clipping at 1.0.
        - Early stopping after ``patience`` epochs without improvement.

    Checkpointing:
        - The best weights seen *in this run* are kept in memory and also
          written to ``ckpt_path``.
        - On exit, the model is restored from the in-memory copy. A stale
          checkpoint file from an earlier run can therefore never be loaded.
        - If no epoch ever improved (e.g. NaN losses, or ``epochs == 0``),
          the model keeps its final weights.

    Args:
        model: the network to train (moved to ``device``).
        train_loader / val_loader: DataLoaders yielding ``(x, y)`` batches.
        stats: normalization-stat tensors forwarded to ``total_loss``.
        device: torch device for batches and model.
        lr, wd: AdamW learning rate and weight decay.
        lam: weight of the physics residual in the total loss.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        ckpt_path: file the best state_dict is saved to.

    Returns:
        ``{'train': [...], 'val': [...]}``, the per-epoch mean total losses.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    best_state = None
    history = {'train': [], 'val': []}

    for ep in range(1, epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, stats, device, lam)
        val_loss = _evaluate(model, val_loader, stats, device, lam)

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {ep:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # A NaN val_loss compares False here, so it never counts as an improvement.
        if val_loss < best_val - 1e-8:
            best_val, wait = val_loss, 0
            # Clone so later optimizer steps don't mutate the saved snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    # Restore the best weights from this run (never a stale on-disk file).
    if best_state is not None:
        model.load_state_dict(best_state)
    return history


def _train_one_epoch(model, loader, optimizer, stats, device, lam):
    """Run one optimization pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    running = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        loss, _, _ = total_loss(model(xb), yb, xb, stats, lam)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient L2 norm to 1.0 for stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        running += loss.item() * xb.size(0)
    return running / len(loader.dataset)


def _evaluate(model, loader, stats, device, lam):
    """Return the sample-weighted mean total loss over ``loader`` (eval mode, no grad)."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss, _, _ = total_loss(model(xb), yb, xb, stats, lam)
            running += loss.item() * xb.size(0)
    return running / len(loader.dataset)

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(hist):
    """Show per-epoch training and validation loss curves from a history dict.

    ``hist`` must provide the 'train' and 'val' sequences, as returned by ``train``.
    """
    _, ax = plt.subplots()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        ax.plot(hist[key], label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

def plot_scatter(y_true, y_pred, name):
    """Scatter predictions against ground truth, with a dashed red y = x reference.

    The reference line spans ``[y_true.min(), y_true.max()]``; ``name`` is the title.
    """
    _, ax = plt.subplots()
    ax.scatter(y_true, y_pred, s=5, alpha=0.6)
    lo, hi = y_true.min(), y_true.max()
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_title(name)
    ax.set_xlabel("True")
    ax.set_ylabel("Pred")
    plt.show()

def plot_uncertainty(model, stats, device):
    """Heatmap of MC-dropout spread for the reality_stability output.

    How the grid is built:
        - Sweeps dh/h over [-1e-5, 1e-5] and quantum_ctrl over [0, 1]
          on a 100 × 100 grid.
        - The other four raw features are held at ``stats['X_mean']``.
        - The grid is standardized with ``stats`` before being passed to
          ``mc_dropout_predict`` (T=100).

    The plotted value is the std of the model's *normalized* output 0.
    """
    n_grid = 100
    dh_axis = np.linspace(-1e-5, 1e-5, n_grid, dtype=np.float32)
    qc_axis = np.linspace(0, 1.0, n_grid, dtype=np.float32)
    dh_mesh, qc_mesh = np.meshgrid(dh_axis, qc_axis)

    # Raw-feature grid: swept columns 0 and 2, the rest pinned to their means.
    grid = torch.zeros((dh_mesh.size, 6), device=device)
    grid[:, 0], grid[:, 2] = (
        torch.from_numpy(dh_mesh.ravel()).to(device),
        torch.from_numpy(qc_mesh.ravel()).to(device),
    )
    fixed = [1, 3, 4, 5]
    grid[:, fixed] = stats['X_mean'][fixed]

    normalized = (grid - stats['X_mean']) / stats['X_std']
    _, spread = mc_dropout_predict(model, normalized, T=100)
    spread_map = spread[:, 0].cpu().numpy().reshape(dh_mesh.shape)

    plt.figure(figsize=(6,5))
    mesh = plt.pcolormesh(dh_mesh, qc_mesh, spread_map, cmap='plasma', shading='auto')
    plt.colorbar(mesh, label='Std Reality Stability')
    plt.xlabel('dh/h variation')
    plt.ylabel('Quantum Control')
    plt.title('Uncertainty Heatmap: Reality Stability')
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load dataset
    ds = RealityDataset(n_samples=5000)

    # Normalization stats as float32 tensors on the training device.
    # NOTE(review): float32 cannot represent lambda_cos (~1e-52); its mean
    # underflows to 0 here, so physics_residual effectively sees lamc == 0.
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.float32, device=device),
        'X_std':  torch.tensor(ds.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.float32, device=device),
        'Y_std':  torch.tensor(ds.Y_std,  dtype=torch.float32, device=device),
    }

    # 80/20 split; a seeded generator makes it reproducible across runs.
    n_val = int(0.2 * len(ds))
    split_gen = torch.Generator().manual_seed(42)
    tr_ds, va_ds = random_split(ds, [len(ds) - n_val, n_val], generator=split_gen)
    tr_ld = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=256)

    # Model & train
    model   = RealityEditorAI(input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1)
    history = train(model, tr_ld, va_ld, stats, device,
                    lr=1e-3, wd=1e-5, lam=1.0, epochs=80, patience=10)

    # Plots
    plot_losses(history)

    # Scatter true vs pred on the held-out split only, with dropout disabled
    # explicitly rather than relying on whatever mode train() left behind.
    model.eval()
    val_idx = np.asarray(va_ds.indices)
    X_val = torch.from_numpy(ds.X[val_idx]).to(device)
    with torch.no_grad():
        Yp_norm = model(X_val).cpu().numpy()
    Yp = Yp_norm * ds.Y_std + ds.Y_mean
    Yt = ds.Y[val_idx] * ds.Y_std + ds.Y_mean

    names = ["Reality Stability", "Quantum Consistency", "Entropy Management"]
    for i, nm in enumerate(names):
        plot_scatter(Yt[:, i], Yp[:, i], nm)

    # Uncertainty heatmap
    plot_uncertainty(model, stats, device)