<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_reality_construct_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_reality_construct_ai.py

End-to-end pipeline for RealityConstructAI:
1. Synthetic dataset of 6 inputs → 3 targets
2. Float32 normalization and dtype consistency
3. MLP with LayerNorm, Dropout & ReLU
4. Physics-informed residual enforcing toy universe laws
5. MC-Dropout for uncertainty quantification
6. Training loop with AdamW, ReduceLROnPlateau, gradient clipping, NaN checks, early stopping
7. Safe checkpoint loading
8. Visualizations: training history, true vs. predicted scatter, uncertainty map
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------------------
# 1. Synthetic Universe Dataset
# ------------------------------------------------------------------------------
class UniverseDataset(Dataset):
    """Synthetic "toy universe" regression dataset: 6 inputs -> 3 targets.

    Inputs (columns of ``X``, in order) are drawn independently:

    * ``REX`` ~ U(0.5, 2.0)
    * ``SCL`` ~ U[0, 1)
    * ``ENT`` ~ U(0.1, 5.0)
    * ``A``, ``B``, ``C`` ~ U(-3, 3)

    Targets (columns of ``Y``, in order) follow the toy laws

    * ``STAB = REX**2 / (1 + ENT)``
    * ``SUST = exp(-A * B) * (1 + SCL)``
    * ``EXP  = (SCL + C) * sqrt(REX)``

    plus Gaussian noise whose scale is 1% of each target column's standard
    deviation. Inputs and targets are standardized with float64 statistics
    and stored as float32.

    Attributes:
        X_mean, X_std: per-column input statistics (float64, length 6).
        Y_mean, Y_std: per-column target statistics (float64, length 3).
        X: standardized inputs, float32, shape ``(n_samples, 6)``.
        Y: standardized targets, float32, shape ``(n_samples, 3)``.

    Args:
        n_samples: number of rows to generate.
        seed: seed for the dataset's private random stream.
    """

    def __init__(self, n_samples=5000, seed=123):
        # A private legacy RandomState yields exactly the stream that
        # ``np.random.seed(seed)`` followed by module-level draws would, but
        # does not reseed NumPy's global RNG behind the caller's back.
        rng = np.random.RandomState(seed)

        # Inputs in double precision
        REX = rng.uniform(0.5, 2.0, (n_samples, 1))
        SCL = rng.rand(n_samples, 1)
        ENT = rng.uniform(0.1, 5.0, (n_samples, 1))
        A = rng.uniform(-3, 3, (n_samples, 1))
        B = rng.uniform(-3, 3, (n_samples, 1))
        C = rng.uniform(-3, 3, (n_samples, 1))

        X_raw = np.hstack([REX, SCL, ENT, A, B, C]).astype(np.float64)

        # Toy universe laws
        STAB = REX**2 / (1 + ENT)
        SUST = np.exp(-A * B) * (1 + SCL)
        EXP = (SCL + C) * np.sqrt(REX)

        Y_raw = np.hstack([STAB, SUST, EXP]).astype(np.float64)
        # Additive noise, scaled per target column to 1% of its spread.
        Y_raw += 0.01 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Normalization stats in float64; the epsilon guards against a
        # zero standard deviation in the division below.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std = X_raw.std(axis=0) + 1e-8
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std = Y_raw.std(axis=0) + 1e-8

        # Standardize and cast to float32
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx`` as float32 tensors.

        The tensors are created with ``torch.from_numpy`` and therefore share
        memory with the underlying arrays.
        """
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])


# ------------------------------------------------------------------------------
# 2. RealityConstructAI Model
# ------------------------------------------------------------------------------
class RealityConstructAI(nn.Module):
    """Fully connected regressor.

    Every hidden width contributes one stage of
    ``Linear -> LayerNorm -> ReLU -> Dropout``; a final ``Linear`` maps the
    last hidden width (or ``input_dim`` when there are no hidden stages) to
    ``output_dim``. All layers live in ``self.net``, an ``nn.Sequential``.

    Args:
        input_dim: number of input features.
        hidden_dims: iterable of hidden-layer widths.
        output_dim: number of outputs.
        p_drop: dropout probability used in every hidden stage.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        # Consecutive pairs of this list are the (in, out) sizes of each stage.
        widths = [input_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ]
        # Output head
        stages.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        out = self.net(x)
        return out


# ------------------------------------------------------------------------------
# 3. Physics-Informed Residual
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Mean-squared error between ``pred`` and the noise-free toy-law targets.

    The standardized inputs are mapped back to physical units, the three toy
    laws are evaluated on them, the result is standardized with the target
    statistics, and the MSE against ``pred`` is returned.

    Args:
        pred: model output in standardized target space, shape ``(..., 3)``.
        X: standardized inputs, shape ``(..., 6)`` with feature order
            ``REX, SCL, ENT, A, B, C`` on the last axis. Any number of
            leading dimensions (including none) is accepted.
        stats: mapping with tensors ``'X_mean'``, ``'X_std'``, ``'Y_mean'``
            and ``'Y_std'`` that broadcast against the last axis.

    Returns:
        Scalar tensor: mean of the squared differences.

    Note:
        ``sqrt(REX)`` yields NaN if a de-normalized ``REX`` is negative.
    """
    X_den = X * stats['X_std'] + stats['X_mean']
    # Split along the feature axis so batched, unbatched and multi-dim
    # inputs are all handled the same way.
    REX, SCL, ENT, A, B, C = X_den.unbind(dim=-1)

    STAB_t = REX**2 / (1 + ENT)
    SUST_t = torch.exp(-A * B) * (1 + SCL)
    EXP_t = (SCL + C) * torch.sqrt(REX)

    Yt = torch.stack([STAB_t, SUST_t, EXP_t], dim=-1)
    Yt_n = (Yt - stats['Y_mean']) / stats['Y_std']
    return nn.functional.mse_loss(pred, Yt_n)


# ------------------------------------------------------------------------------
# 4. Combined Loss Function
# ------------------------------------------------------------------------------
def total_loss(pred, true, X, stats, λ=0.7):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: model output.
        true: target tensor compared against ``pred``.
        X: inputs forwarded to ``physics_residual``.
        stats: normalization statistics forwarded to ``physics_residual``.
        λ: weight applied to the physics term.

    Returns:
        Tuple ``(data_mse + λ * physics, data_mse, physics)``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + λ * physics_term
    return combined, data_term, physics_term


# ------------------------------------------------------------------------------
# 5. MC-Dropout Uncertainty
# ------------------------------------------------------------------------------
def mc_dropout(model, X, T=50):
    """Estimate predictive mean and spread with Monte-Carlo dropout.

    The model is switched to training mode so dropout is active, evaluated
    ``T`` times on ``X`` without gradient tracking, and the per-element mean
    and standard deviation over those passes are returned. Every
    sub-module's original train/eval flag is restored afterwards, even if a
    forward pass raises, so the call does not leave the model stochastic.

    Args:
        model: an ``nn.Module``.
        X: input tensor passed unchanged to ``model``.
        T: number of stochastic forward passes. The standard deviation uses
            ``torch.std``'s default estimator, so ``T`` should be at least 2
            for a meaningful spread.

    Returns:
        Tuple ``(mean, std)``, each shaped like a single model output.
    """
    # Snapshot per-module flags: restoring only the top-level mode would
    # clobber sub-modules that were deliberately kept in a different mode.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)])
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    return samples.mean(dim=0), samples.std(dim=0)


# ------------------------------------------------------------------------------
# 6. Training Loop with Safety & Checkpointing
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=3e-4, wd=1e-5, λ=0.7, epochs=100, patience=10,
          ckpt_path="best_reality_ai.pth"):
    """Fit ``model`` with AdamW, plateau LR decay, clipping and early stopping.

    Each epoch runs one optimisation pass over ``train_loader`` (gradient
    norm clipped to 1.0) and one evaluation pass over ``val_loader``, both
    using ``total_loss``. The validation loss drives a ``ReduceLROnPlateau``
    scheduler (factor 0.5, patience 5), checkpointing, and early stopping.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: iterable of ``(X, Y)`` training batches.
        val_loader: iterable of ``(X, Y)`` validation batches.
        stats: normalization statistics forwarded to ``total_loss``.
        device: device the model and batches are moved to.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        λ: weight of the physics term in ``total_loss``.
        epochs: maximum number of epochs.
        patience: epochs without a validation improvement (of more than
            1e-6) tolerated before stopping early.
        ckpt_path: file the best ``state_dict`` is written to.

    Returns:
        ``{'train': [...], 'val': [...]}`` with one sample-weighted mean
        loss per completed epoch.

    Side effects:
        Prints one progress line per epoch and writes ``ckpt_path`` whenever
        the validation loss improves. On return the model holds the best
        weights saved *by this call*, if any; a pre-existing file at
        ``ckpt_path`` from an earlier run is never loaded.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    saved_checkpoint = False   # True once this run has written ckpt_path
    diverged = False
    history = {'train': [], 'val': []}

    for epoch in range(1, epochs+1):
        # Training
        model.train()
        running_train, n_train = 0.0, 0
        for Xb, Yb in train_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            pred = model(Xb)
            loss, _, _ = total_loss(pred, Yb, Xb, stats, λ)
            # Catch inf as well as NaN: back-propagating an infinite loss
            # would poison the weights one step before a NaN shows up.
            if not torch.isfinite(loss):
                print(f"Non-finite loss at epoch {epoch}, aborting.")
                diverged = True
                break
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_train += loss.item() * Xb.size(0)
            n_train += Xb.size(0)
        if diverged:
            break
        # Divide by the samples actually seen so the mean stays correct
        # with drop_last=True or a subsampling sampler.
        train_loss = running_train / max(n_train, 1)

        # Validation
        model.eval()
        running_val, n_val = 0.0, 0
        with torch.no_grad():
            for Xv, Yv in val_loader:
                Xv, Yv = Xv.to(device), Yv.to(device)
                pred = model(Xv)
                l, _, _ = total_loss(pred, Yv, Xv, stats, λ)
                running_val += l.item() * Xv.size(0)
                n_val += Xv.size(0)
        val_loss = running_val / max(n_val, 1)

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # Checkpointing
        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            torch.save(model.state_dict(), ckpt_path)
            saved_checkpoint = True
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    # Restore the best weights, but only from a checkpoint written during
    # this call: a stale file from another run may belong to a different
    # architecture or simply hold unrelated weights.
    if saved_checkpoint and os.path.exists(ckpt_path):
        model.load_state_dict(
            torch.load(ckpt_path, map_location=device)
        )
    return history


# ------------------------------------------------------------------------------
# 7. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(history):
    """Show the training and validation loss curves on one figure.

    Args:
        history: mapping with sequences under the keys ``'train'`` and
            ``'val'``, plotted against their index.
    """
    plt.figure()
    for key, legend_name in (('train', 'Train'), ('val', 'Val')):
        plt.plot(history[key], label=legend_name)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def plot_scatter(true, pred, title):
    """Scatter ``pred`` against ``true`` with a dashed red identity line.

    Args:
        true: reference values (x axis); must provide ``.min()`` / ``.max()``.
        pred: predicted values (y axis).
        title: figure title.
    """
    plt.figure()
    plt.scatter(true, pred, s=8, alpha=0.5)
    # The identity line spans the range of the reference values.
    lo = true.min()
    hi = true.max()
    plt.plot([lo, hi], [lo, hi], 'r--')
    plt.title(title)
    plt.show()


def plot_uncertainty(model, stats, device, grid_size=100, T=40):
    """Plot the MC-dropout standard deviation of the first output (STAB).

    A ``grid_size x grid_size`` grid is laid over REX in [0.5, 2.0] and SCL
    in [0, 1]; the four remaining inputs are held at their means from
    ``stats['X_mean']``. The grid is standardized, run through
    ``mc_dropout``, and the spread of output column 0 is drawn with
    ``pcolormesh``.

    Args:
        model: trained network.
        stats: mapping with tensors ``'X_mean'``, ``'X_std'`` and ``'Y_std'``
            located on ``device``.
        device: device on which the grid tensor is created.
        grid_size: number of grid points per axis.
        T: number of stochastic forward passes.
    """
    rex_axis = np.linspace(0.5, 2.0, grid_size, dtype=np.float32)
    scl_axis = np.linspace(0, 1, grid_size, dtype=np.float32)
    R, S = np.meshgrid(rex_axis, scl_axis)
    pts = grid_size * grid_size

    # Grid in physical units: columns 0-1 sweep REX/SCL, columns 2-5 are
    # pinned to the mean of the remaining inputs (broadcast over rows).
    Xg = torch.zeros((pts, 6), device=device, dtype=torch.float32)
    Xg[:, 2:] = stats['X_mean'][2:]
    Xg[:, 0] = torch.from_numpy(R.ravel()).to(device)
    Xg[:, 1] = torch.from_numpy(S.ravel()).to(device)

    # normalize entire grid
    Xn = (Xg - stats['X_mean']) / stats['X_std']

    # MC-dropout needs train mode; restore the caller's mode afterwards so
    # plotting does not leave the model stochastic.
    was_training = model.training
    try:
        _, std = mc_dropout(model, Xn, T=T)
    finally:
        model.train(was_training)

    # The network predicts standardized targets, so rescale the spread by
    # the target's std to report it in physical STAB units, matching the
    # physical REX/SCL axes and the colorbar label.
    U = (std[:, 0] * stats['Y_std'][0]).cpu().numpy().reshape(grid_size, grid_size)

    plt.figure(figsize=(5,4))
    plt.pcolormesh(R, S, U, cmap='viridis', shading='auto')
    plt.colorbar(label="Std(STAB)")
    plt.xlabel("REX")
    plt.ylabel("SCL")
    plt.title("Uncertainty: Stability")
    plt.show()


# ------------------------------------------------------------------------------
# 8. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: build the data, train, then draw all diagnostics.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One seed for the dataset, the weight initialisation / dropout masks and
    # the train/val split, so that a run can be reproduced.
    SEED = 123
    torch.manual_seed(SEED)

    # Prepare dataset and cast stats to float32 tensors on device
    ds    = UniverseDataset(n_samples=5000, seed=SEED)
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.float32, device=device),
        'X_std' : torch.tensor(ds.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.float32, device=device),
        'Y_std' : torch.tensor(ds.Y_std,  dtype=torch.float32, device=device),
    }

    # Split into train/val (80/20) with a dedicated, seeded generator
    n_val = int(0.2 * len(ds))
    tr_ds, va_ds = random_split(
        ds, [len(ds)-n_val, n_val],
        generator=torch.Generator().manual_seed(SEED),
    )
    tr_ld = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=256, shuffle=False)

    # Build, train, and evaluate
    model   = RealityConstructAI().to(device)
    history = train(model, tr_ld, va_ld, stats, device)

    # Plot training history
    plot_history(history)

    # Scatter: true vs. predicted.
    # Force eval mode so dropout is off for these point predictions no
    # matter which phase train() happened to finish in.
    model.eval()
    X_all = torch.from_numpy(ds.X).to(device)
    with torch.no_grad():
        Yp_norm = model(X_all).cpu().numpy()
    # Map standardized targets and predictions back to physical units
    Yt = ds.Y * ds.Y_std + ds.Y_mean
    Yp = Yp_norm * ds.Y_std + ds.Y_mean

    names = ["Stability", "Sustainability", "Expansion Rate"]
    for i, nm in enumerate(names):
        plot_scatter(Yt[:, i], Yp[:, i], nm)

    # Plot uncertainty map
    plot_uncertainty(model, stats, device)