<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_hyperdim_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_hyperdim_ai.py

Physics-informed AI pipeline for hyperdimensional space-time analysis,
using full double precision and a normalized physics residual.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Double-Precision Synthetic Dataset
# ------------------------------------------------------------------------------
class HyperDimDataset(Dataset):
    """Synthetic, column-standardized float64 dataset for the hyperdimensional toy problem.

    Six input features per sample, in column order: curvature scalar R, number of
    extra dimensions ex, energy density ed, momentum flux mf, torsion scalar ts and
    brane tension bt. Each is drawn uniformly from a fixed range.

    Three toy targets:

    * geometric stability  gs = R**2 / (ed + 1e-60)
    * topology coherence   tc = 1 / (1 + ex * mf * 1e10)
    * warp wrinkle index   ww = (ts + bt) / (ed + 1e-60)

    Gaussian noise is added to the targets; its std is 2% of each target's spread.
    Inputs and targets are then standardized column-wise. The statistics are kept
    on the instance as NumPy float64 arrays (``X_mean``, ``X_std``, ``Y_mean``,
    ``Y_std``) so callers can invert the transform.

    Args:
        n_samples: number of samples to generate.
        seed: seed for a private ``np.random.RandomState``. The global NumPy RNG is
            not touched.
    """

    def __init__(self, n_samples=6000, seed=42):
        # Private RNG: yields the same stream as the legacy np.random.seed(seed)
        # path, without clobbering global RNG state for the rest of the program.
        rng = np.random.RandomState(seed)

        def draw(lo, hi):
            return rng.uniform(lo, hi, (n_samples, 1)).astype(np.float64)

        # Draw order matters for reproducibility: R, ex, ed, mf, ts, bt.
        R  = draw(-1e-35, 1e-35)
        ex = draw(0, 7)
        ed = draw(1e-27, 1e-25)
        mf = draw(1e-20, 1e-18)
        ts = draw(-1e-30, 1e-30)
        bt = draw(1e-28, 1e-26)

        X_raw = np.hstack([R, ex, ed, mf, ts, bt])

        # Toy targets (the 1e-60 term only guards against division by zero).
        gs = R**2 / (ed + 1e-60)
        tc = 1.0 / (1.0 + ex * mf * 1e10)
        ww = (ts + bt) / (ed + 1e-60)

        Y_raw = np.hstack([gs, tc, ww])
        # Small noise, scaled by each target's own spread.
        Y_raw += 0.02 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Normalization stats. Uses a scale-aware guard instead of "+ 1e-12",
        # which would swamp features whose natural spread is ~1e-36.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = self._safe_std(X_raw)
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = self._safe_std(Y_raw)

        # Standardize and convert to torch.double.
        Xn = (X_raw - self.X_mean) / self.X_std
        Yn = (Y_raw - self.Y_mean) / self.Y_std
        self.X = torch.from_numpy(Xn).double()
        self.Y = torch.from_numpy(Yn).double()

    @staticmethod
    def _safe_std(a, rel_eps=1e-12):
        """Return the column-wise std of ``a``, with degenerate columns mapped to 1.0.

        A column counts as degenerate when its std is at most ``rel_eps`` times
        its largest absolute value, i.e. it is effectively constant. The
        threshold is relative, so columns with tiny natural scales (e.g. values
        ~1e-35) keep their true std instead of being divided by an absolute
        epsilon that dwarfs them.
        """
        std = a.std(axis=0)
        # initial=0.0 keeps the reduction defined for an empty sample axis.
        scale = np.abs(a).max(axis=0, initial=0.0)
        return np.where(std <= rel_eps * scale, 1.0, std)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


# ------------------------------------------------------------------------------
# 2. Double-Precision Model Definition
# ------------------------------------------------------------------------------
class HyperDimAI(nn.Module):
    """Fully connected regressor.

    The network stacks one [Linear -> LayerNorm -> ReLU -> Dropout] stage per
    entry of ``hidden_dims``, followed by a plain Linear output head.

    Args:
        input_dim: number of input features.
        hidden_dims: widths of the hidden stages, in order.
        output_dim: number of outputs produced by the final Linear layer.
        p_drop: dropout probability used after every hidden stage.
    """

    def __init__(self, input_dim=6, hidden_dims=(128,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        # Consecutive width pairs give each hidden stage its (in, out) sizes.
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the sequential stack to ``x`` and return the raw outputs."""
        return self.net(x)


# ------------------------------------------------------------------------------
# 3. Normalized Physics-Informed Residual
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Mean squared mismatch between predicted and analytic geometric stability.

    The normalized inputs are mapped back to physical units with
    ``stats['X_std']`` and ``stats['X_mean']``. The analytic target
    R**2 / (ed + 1e-60) is computed from columns 0 (R) and 2 (ed), then
    standardized with ``stats['Y_mean'][0]`` and ``stats['Y_std'][0]``. It is
    compared against column 0 of the model's normalized prediction.

    Args:
        pred: normalized model output; only column 0 is used.
        X: normalized inputs; columns 0 and 2 are used.
        stats: dict of tensors with keys 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor holding the mean squared residual over the batch.
    """
    physical = X * stats['X_std'] + stats['X_mean']
    curvature = physical[:, 0]
    energy_density = physical[:, 2]

    analytic = curvature**2 / (energy_density + 1e-60)
    analytic_n = (analytic - stats['Y_mean'][0]) / stats['Y_std'][0]

    residual = pred[:, 0] - analytic_n
    return (residual**2).mean()


def total_loss(pred, true, X, stats, lam=1.0):
    """Combine the data-fit MSE with a weighted physics residual.

    Args:
        pred: normalized model predictions.
        true: normalized targets, same shape as ``pred``.
        X: normalized inputs, forwarded to ``physics_residual``.
        stats: normalization statistics, forwarded to ``physics_residual``.
        lam: weight applied to the physics term.

    Returns:
        Tuple ``(mse + lam * phys, mse, phys)`` of scalar tensors.
    """
    data_term = nn.functional.mse_loss(pred, true)
    phys_term = physics_residual(pred, X, stats)
    combined = data_term + lam * phys_term
    return combined, data_term, phys_term


# ------------------------------------------------------------------------------
# 4. MC-Dropout for Uncertainty Quantification
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Estimate the predictive mean and std via Monte-Carlo dropout.

    The model is switched to train mode so dropout stays active. It then runs
    ``T`` stochastic forward passes without gradient tracking. Each submodule's
    original train/eval flag is restored afterwards, even if a forward pass
    raises, so later "deterministic" calls are not silently stochastic.

    Args:
        model: the ``nn.Module`` to sample.
        X: input batch passed unchanged to ``model``.
        T: number of stochastic passes. ``std`` uses the unbiased estimator, so
            ``T`` should be at least 2 for a finite result.

    Returns:
        ``(mean, std)`` over the ``T`` samples, each shaped like ``model(X)``.
    """
    # Remember every submodule's mode; model.train() overwrites them all.
    prev_modes = [(m, m.training) for m in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        for m, mode in prev_modes:
            m.training = mode
    return samples.mean(0), samples.std(0)


# ------------------------------------------------------------------------------
# 5. Training Loop with Early Stopping & Scheduler
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lam=1.0, epochs=100, patience=10,
          ckpt_path="best_hyperdim_ai.pth"):
    """Train ``model`` with AdamW, ReduceLROnPlateau and early stopping.

    Each epoch runs one optimization pass over ``train_loader``, with gradient
    norm clipped to 1.0, and one no-grad pass over ``val_loader``. Both passes
    use ``total_loss`` with physics weight ``lam``. The learning rate is halved
    after 5 epochs without validation improvement. Training stops once
    ``patience`` consecutive epochs fail to improve the best validation loss by
    more than 1e-8.

    Whenever validation improves, the weights are snapshotted in memory and also
    written to ``ckpt_path``. At the end the best in-memory snapshot is loaded
    back into ``model``. If no epoch ever improved (e.g. ``epochs=0`` or NaN
    losses), the final weights are kept rather than reading a possibly stale
    file from disk.

    Args:
        model: network to optimize; moved to ``device``.
        train_loader: yields ``(X, Y)`` training batches.
        val_loader: yields ``(X, Y)`` validation batches.
        stats: normalization statistics forwarded to ``total_loss``.
        device: device batches are moved to.
        lr: initial learning rate.
        wd: AdamW weight decay.
        lam: weight of the physics residual in the loss.
        epochs: maximum number of epochs.
        patience: early-stopping patience in epochs.
        ckpt_path: file the best ``state_dict`` is saved to.

    Returns:
        Dict with per-epoch ``'train'`` and ``'val'`` sample-weighted mean losses.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    best_state = None
    history = {'train': [], 'val': []}

    for ep in range(1, epochs+1):
        # training
        model.train()
        train_loss = 0.0
        for Xb, Yb in train_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            pred = model(Xb)
            loss, _, _ = total_loss(pred, Yb, Xb, stats, lam=lam)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item() * Xb.size(0)
        train_loss /= len(train_loader.dataset)

        # validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for Xb, Yb in val_loader:
                Xb, Yb = Xb.to(device), Yb.to(device)
                pred = model(Xb)
                loss, _, _ = total_loss(pred, Yb, Xb, stats, lam=lam)
                val_loss += loss.item() * Xb.size(0)
        val_loss /= len(val_loader.dataset)

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {ep:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # early stopping
        if val_loss + 1e-8 < best_val:
            best_val, wait = val_loss, 0
            # Clone so later optimizer steps do not mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("No improving epoch recorded; keeping final weights.")
    return history


# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(hist):
    """Plot per-epoch training and validation loss curves and show the figure.

    The x-axis starts at 1 so it matches the 1-based "Epoch NNN" numbers that
    ``train`` prints.

    Args:
        hist: dict with ``'train'`` and ``'val'`` sequences of per-epoch losses.
    """
    train_epochs = range(1, len(hist['train']) + 1)
    val_epochs = range(1, len(hist['val']) + 1)
    plt.figure()
    plt.plot(train_epochs, hist['train'], label='Train')
    plt.plot(val_epochs,   hist['val'],   label='Val')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def plot_scatter(y_true, y_pred, title):
    """Scatter predictions against ground truth, with a dashed red y = x guide.

    The guide spans from the smallest to the largest value of ``y_true``.

    Args:
        y_true: array of true values (must support ``.min()`` / ``.max()``).
        y_pred: array of predicted values, same length as ``y_true``.
        title: figure title.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=5, alpha=0.6)
    lo = y_true.min()
    hi = y_true.max()
    diagonal = [lo, hi]
    plt.plot(diagonal, diagonal, 'r--')
    plt.title(title)
    plt.xlabel("True")
    plt.ylabel("Pred")
    plt.show()


def plot_uncertainty(model, stats, device, n_grid=100, T=100,
                     R_range=(-1e-35, 1e-35), ex_range=(0, 7)):
    """Heatmap of MC-dropout uncertainty of geometric stability over an (R, ex) grid.

    The grid spans ``R_range`` by ``ex_range`` with ``n_grid`` points per axis.
    The remaining four features are held at their raw means (``stats['X_mean'][2:]``).
    Inputs are standardized with ``stats`` before sampling. The predictive std
    of output column 0 is rescaled by ``stats['Y_std'][0]``, so the colorbar is
    in the same physical units as the target rather than normalized units. The
    model's train/eval mode is restored after sampling.

    Args:
        model: trained network, float64 on ``device``.
        stats: dict of float64 tensors on ``device`` with keys
            'X_mean', 'X_std', 'Y_mean', 'Y_std'.
        device: device to build the grid on.
        n_grid: grid resolution per axis.
        T: number of MC-dropout passes.
        R_range: ``(min, max)`` of the curvature axis.
        ex_range: ``(min, max)`` of the extra-dimensions axis.
    """
    # 1) build R-ex grid
    R_vals = torch.linspace(R_range[0], R_range[1], n_grid,
                            device=device, dtype=torch.double)
    ex_vals = torch.linspace(ex_range[0], ex_range[1], n_grid,
                             device=device, dtype=torch.double)
    Rg, Ex = torch.meshgrid(R_vals, ex_vals, indexing='xy')
    Rf, Ef = Rg.reshape(-1), Ex.reshape(-1)

    # 2) other four features at their (raw) mean; expand avoids a copy
    rest = stats['X_mean'][2:].unsqueeze(0).expand(Rf.shape[0], -1)  # [N, 4]

    # 3) assemble raw inputs & standardize
    Xg = torch.cat([Rf.unsqueeze(1), Ef.unsqueeze(1), rest], dim=1)
    Xn = (Xg - stats['X_mean']) / stats['X_std']

    # 4) MC-dropout uncertainty; restore the caller's train/eval mode afterwards
    was_training = model.training
    _, std = mc_dropout_predict(model, Xn, T=T)
    model.train(was_training)
    # std is in normalized target units -> convert to physical units
    std_phys = std[:, 0] * stats['Y_std'][0]
    std_map = std_phys.reshape(Rg.shape).cpu().numpy()

    # 5) plot heatmap
    plt.figure()
    plt.pcolormesh(Rg.cpu().numpy(), Ex.cpu().numpy(), std_map,
                   cmap='viridis', shading='auto')
    plt.colorbar(label="Std of Geometric Stability")
    plt.xlabel("Curvature Scalar R")
    plt.ylabel("Extra Dimensions")
    plt.title("Geometric Stability Uncertainty Heatmap")
    plt.show()


# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare data; a seeded generator makes the train/val split reproducible
    ds = HyperDimDataset()
    n_val = int(0.2 * len(ds))
    split_gen = torch.Generator().manual_seed(42)
    train_ds, val_ds = random_split(ds, [len(ds) - n_val, n_val],
                                    generator=split_gen)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=256)

    # stats for physics residual (torch.double on device)
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.double, device=device),
        'X_std':  torch.tensor(ds.X_std,  dtype=torch.double, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.double, device=device),
        'Y_std':  torch.tensor(ds.Y_std,  dtype=torch.double, device=device),
    }

    # build and train model
    model = HyperDimAI().double().to(device)
    history = train(model, train_loader, val_loader, stats, device,
                    lr=1e-3, wd=1e-5, lam=1.0, epochs=100, patience=10)

    # visualize
    plot_losses(history)

    # Evaluate deterministically (dropout off) on the held-out validation
    # subset only, so the scatter plots reflect generalization, not training fit.
    model.eval()
    val_idx = torch.as_tensor(val_ds.indices)
    with torch.no_grad():
        X_val = ds.X[val_idx].to(device)
        Y_pred_n = model(X_val).cpu().numpy()
    Y_pred = Y_pred_n * ds.Y_std + ds.Y_mean
    Y_true = ds.Y[val_idx].cpu().numpy() * ds.Y_std + ds.Y_mean

    names = ["Geom Stability", "Topo Coherence", "Warp Wrinkle"]
    for i, nm in enumerate(names):
        plot_scatter(Y_true[:, i], Y_pred[:, i], nm)

    plot_uncertainty(model, stats, device)