<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_spacetime_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib

In [None]:
#!/usr/bin/env python3
"""
train_spacetime_ai.py

Physics-informed AI pipeline for space-time curvature control:

1. Generate synthetic dataset of 6 inputs → 3 curvature outputs
2. Normalize with 1D mean/std vectors
3. Build MLP with LayerNorm & Dropout
4. Physics-informed loss enforcing R ≈ κ·(energy_density - tension)
5. MC-Dropout for uncertainty estimation
6. Training loop with AdamW, LR scheduler, early stopping
7. Plot losses, scatter, and uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Dataset
# ------------------------------------------------------------------------------
class SpaceTimeDataset(Dataset):
    """
    Synthetic regression dataset: 6 physical inputs -> 3 curvature targets.

    Input columns (in order): quantum_fluc, exotic_density, energy_density,
    tension_anisotropy, curvature_perturb, control_field.
    Target columns (in order): local_curvature, global_curvature, warp_factor.

    Inputs and targets are standardized per column. The statistics are kept
    as ``X_mean``, ``X_std``, ``Y_mean`` and ``Y_std`` (1-D float32 arrays)
    so callers can undo the normalization.

    Args:
        n_samples: number of rows to generate (must be >= 1).
        seed: seed for a private NumPy ``RandomState``. The global NumPy RNG
            is left untouched. The generated data is identical to what
            seeding the global RNG with the same value would produce.

    Raises:
        ValueError: if ``n_samples < 1`` (the statistics would be NaN).
    """

    def __init__(self, n_samples=6000, seed=1):
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        # A private legacy RandomState yields exactly the same stream as
        # np.random.seed(seed) followed by np.random.* calls, but without
        # clobbering the process-wide NumPy RNG other code may rely on.
        rng = np.random.RandomState(seed)

        def column(low, high):
            return rng.uniform(low, high, (n_samples, 1)).astype(np.float32)

        # Draw order determines the data; keep it fixed for reproducibility.
        qf = column(0.0, 1.0)    # quantum_fluc
        ed = column(0.1, 10.0)   # exotic_density
        en = column(0.5, 20.0)   # energy_density
        ta = column(0.0, 1.0)    # tension_anisotropy
        cp = column(-0.5, 0.5)   # curvature_perturb
        cf = column(0.0, 5.0)    # control_field

        X_raw = np.hstack([qf, ed, en, ta, cp, cf])

        # Targets:
        #   local_curvature  = kappa * (energy_density - tension_anisotropy)
        #   global_curvature = control_field / (1 + exotic_density)
        #   warp_factor      = exp(-quantum_fluc * |curvature_perturb|)
        kappa = 8.0 * np.pi
        lc = (kappa * (en - ta)).astype(np.float32)
        gc = (cf / (1.0 + ed)).astype(np.float32)
        wf = np.exp(- qf * np.abs(cp)).astype(np.float32)

        Y_raw = np.hstack([lc, gc, wf])
        # Small Gaussian observation noise on every target.
        Y_raw += 0.01 * rng.randn(*Y_raw.shape).astype(np.float32)

        # Per-column (1-D) normalization stats; the epsilon keeps std > 0.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std = X_raw.std(axis=0) + 1e-6
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std = Y_raw.std(axis=0) + 1e-6

        # Normalized arrays served by __getitem__.
        self.X = (X_raw - self.X_mean) / self.X_std
        self.Y = (Y_raw - self.Y_mean) / self.Y_std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return the normalized (input, target) pair at ``idx`` as tensors."""
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])


# ------------------------------------------------------------------------------
# 2. Model Architecture
# ------------------------------------------------------------------------------
class SpaceTimeAI(nn.Module):
    """
    Plain MLP regressor.

    Each hidden width ``h`` in ``hidden_dims`` contributes the sequence
    Linear -> LayerNorm -> ReLU -> Dropout. A final Linear maps to
    ``output_dim``. All layers live in ``self.net`` (an ``nn.Sequential``),
    so state_dict keys are ``net.<index>.<param>``.

    Args:
        input_dim: number of input features.
        hidden_dims: iterable of hidden-layer widths (may be empty).
        output_dim: number of regression outputs.
        p_drop: dropout probability applied after every hidden block.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Output head projects the last hidden width (or the input) to targets.
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to ``x`` of trailing dimension ``input_dim``."""
        return self.net(x)


# ------------------------------------------------------------------------------
# 3. Physics-Informed Loss
# ------------------------------------------------------------------------------
def physics_residual(pred, inp, stats, *, normalize=True):
    """
    Penalize deviation of predicted local curvature from an Einstein-like law.

    Enforces  R_pred ≈ κ·(energy_density − tension_anisotropy)  with κ = 8π.

    Args:
        pred: model output in normalized units, indexed as ``pred[:, 0]``
            (local curvature).
        inp: model input in normalized units, indexed as ``inp[:, 2]``
            (energy_density) and ``inp[:, 3]`` (tension_anisotropy).
        stats: dict of tensors ``'X_mean'``, ``'X_std'``, ``'Y_mean'``,
            ``'Y_std'`` used to undo the normalization (same device as pred).
        normalize: if True (default), the residual is divided by
            ``Y_std[0]``. The penalty is then on the same scale as the
            normalized data MSE it is summed with. If False, the residual is
            in raw curvature units, and the returned loss is exactly
            ``Y_std[0]**2`` times larger. That makes ``lambda_phys``
            effectively enormous.

    Returns:
        Scalar tensor: mean squared residual.
    """
    kappa = 8.0 * np.pi

    # Only the two physical inputs the relation needs are denormalized.
    en = inp[:, 2] * stats['X_std'][2] + stats['X_mean'][2]
    ta = inp[:, 3] * stats['X_std'][3] + stats['X_mean'][3]
    lc_phys = kappa * (en - ta)

    # Predicted local curvature in physical units.
    lc_pred = pred[:, 0] * stats['Y_std'][0] + stats['Y_mean'][0]

    residual = lc_pred - lc_phys
    if normalize:
        residual = residual / stats['Y_std'][0]
    return torch.mean(residual ** 2)

def total_loss(pred, true, inp, stats, lambda_phys=1.0):
    """
    Combine the supervised data term with the physics penalty.

    Args:
        pred: model output (normalized units).
        true: normalized targets, same shape as ``pred``.
        inp: normalized model inputs for the batch.
        stats: normalization statistics forwarded to ``physics_residual``.
        lambda_phys: weight of the physics penalty.

    Returns:
        Tuple ``(combined, data_mse, physics)`` of scalar tensors, where
        ``combined = data_mse + lambda_phys * physics``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, inp, stats)
    combined = data_term + lambda_phys * physics_term
    return combined, data_term, physics_term


# ------------------------------------------------------------------------------
# 4. MC-Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, n_samples=50):
    """
    Monte-Carlo dropout: run ``n_samples`` stochastic forward passes on ``x``.

    Only dropout layers are put in training mode. Every other module, such as
    a normalization layer that tracks running statistics, stays in eval mode,
    so sampling does not mutate the model. The model's previous train/eval
    mode is restored on exit, even if a forward pass raises.

    Args:
        model: torch module, expected to contain dropout layers (otherwise
            every sample is identical and the std is zero).
        x: input tensor, already on the model's device.
        n_samples: number of stochastic forward passes (must be >= 1).

    Returns:
        ``(mean, std)`` NumPy arrays computed over the sample axis; each has
        the shape of a single ``model(x)`` output.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    was_training = model.training
    model.eval()
    for module in model.modules():
        # _DropoutNd is the shared base of Dropout, Dropout2d/3d, AlphaDropout, ...
        if isinstance(module, nn.modules.dropout._DropoutNd):
            module.train()
    try:
        with torch.no_grad():
            preds = [model(x).cpu().numpy() for _ in range(n_samples)]
    finally:
        # Restores the top-level mode for the whole module tree.
        model.train(was_training)

    arr = np.stack(preds, axis=0)
    return arr.mean(0), arr.std(0)


# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def _run_epoch(model, loader, stats, device, lambda_phys, optimizer=None):
    """
    Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns the sample-weighted mean of the combined loss over
    ``loader.dataset``.
    """
    training = optimizer is not None
    model.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, lambda_phys)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight by batch size so a short final batch is not over-counted.
            total += loss.item() * xb.size(0)
    return total / len(loader.dataset)


def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lambda_phys=1.0,
          epochs=100, patience=10, checkpoint_path="best_spacetime_ai.pth"):
    """
    Train ``model`` with AdamW, ReduceLROnPlateau and early stopping.

    Whenever the validation loss improves by more than 1e-6, the weights are
    snapshotted in memory and, if ``checkpoint_path`` is not None, saved there
    with ``torch.save``. When training ends, the best in-memory snapshot is
    loaded back into ``model``. This never reads from disk, so a missing or
    stale checkpoint file from another run cannot be picked up.

    Args:
        model: module to train (moved to ``device``).
        train_loader / val_loader: DataLoaders yielding ``(x, y)`` batches.
        stats: normalization statistics forwarded to ``total_loss``.
        device: torch device.
        lr, wd: AdamW learning rate and weight decay.
        lambda_phys: physics-loss weight.
        epochs: maximum number of epochs.
        patience: epochs without improvement before stopping.
        checkpoint_path: file for the best weights, or None to skip writing.

    Returns:
        dict with per-epoch ``'train'`` and ``'val'`` loss lists.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    best_state = None
    history = {'train': [], 'val': []}

    for epoch in range(1, epochs + 1):
        avg_train = _run_epoch(model, train_loader, stats, device,
                               lambda_phys, optimizer)
        avg_val = _run_epoch(model, val_loader, stats, device, lambda_phys)

        scheduler.step(avg_val)
        history['train'].append(avg_train)
        history['val'].append(avg_val)
        print(f"Epoch {epoch:03d} | Train {avg_train:.4e} | Val {avg_val:.4e}")

        # Early stopping (a NaN loss never counts as an improvement).
        if avg_val < best_val - 1e-6:
            best_val, wait = avg_val, 0
            best_state = {k: v.detach().clone()
                          for k, v in model.state_dict().items()}
            if checkpoint_path is not None:
                torch.save(best_state, checkpoint_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: validation loss never improved; keeping final weights")
    return history


# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(history):
    """
    Plot per-epoch training and validation loss curves.

    Args:
        history: dict with ``'train'`` and ``'val'`` lists of per-epoch
            losses, one entry per completed epoch.
    """
    # The training log numbers epochs from 1; use the same numbering here
    # instead of the 0-based list index.
    train_epochs = range(1, len(history['train']) + 1)
    val_epochs = range(1, len(history['val']) + 1)

    plt.figure()
    plt.plot(train_epochs, history['train'], label='Train')
    plt.plot(val_epochs, history['val'], label='Val')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training & Validation Loss")
    plt.tight_layout()
    plt.show()

def plot_scatter(true_vals, pred_vals, name):
    """
    Show a predicted-vs-true scatter for one output with a dashed y = x guide.

    Args:
        true_vals: ground-truth values (array-like supporting ``.min()`` and
            ``.max()``).
        pred_vals: predictions aligned element-wise with ``true_vals``.
        name: output name used in the axis labels and the title.
    """
    plt.figure()
    plt.scatter(true_vals, pred_vals, s=5, alpha=0.5)

    # The identity guide spans the observed range of the ground truth.
    span = [true_vals.min(), true_vals.max()]
    plt.plot(span, span, 'r--')

    x_text = f"True {name}"
    y_text = f"Pred {name}"
    title_text = f"{name}: True vs Pred"
    plt.xlabel(x_text)
    plt.ylabel(y_text)
    plt.title(title_text)

    plt.tight_layout()
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device, *, n_samples=50,
                             ed_range=(0.1, 20.0), cf_range=(0.1, 20.0),
                             n_grid=100):
    """
    Heatmap of MC-dropout uncertainty in local curvature.

    The map varies energy_density (input column 2) and control_field (input
    column 5) over a regular grid. All other inputs are held at their
    training mean.

    Args:
        model: trained regressor whose first output is local curvature.
        stats_np: dict of NumPy arrays ``'X_mean'``, ``'X_std'`` and
            ``'Y_std'`` (normalization statistics).
        device: torch device used for inference.
        n_samples: MC-dropout forward passes per grid point.
        ed_range: (min, max) of energy_density on the x-axis.
        cf_range: (min, max) of control_field on the y-axis.
            The defaults match the original plot. The synthetic dataset draws
            energy_density from [0.5, 20] and control_field from [0, 5], so
            much of the default grid lies outside the training distribution.
        n_grid: grid points per axis.
    """
    ed_vals = np.linspace(ed_range[0], ed_range[1], n_grid)
    cf_vals = np.linspace(cf_range[0], cf_range[1], n_grid)
    ED, CF = np.meshgrid(ed_vals, cf_vals)

    # Start every row at the feature means, then overwrite the two varied columns.
    grid = np.tile(np.asarray(stats_np['X_mean'], dtype=np.float32), (ED.size, 1))
    grid[:, 2] = ED.ravel()  # energy_density
    grid[:, 5] = CF.ravel()  # control_field

    Xn = (grid - stats_np['X_mean']) / stats_np['X_std']
    Xt = torch.from_numpy(Xn).float().to(device)

    was_training = model.training
    try:
        _, std = mc_dropout_predict(model, Xt, n_samples=n_samples)
    finally:
        model.train(was_training)

    # std is in normalized-target units; scaling by Y_std[0] reports it in
    # physical local-curvature units (the mean offset does not affect a std).
    std_map = std[:, 0].reshape(ED.shape) * stats_np['Y_std'][0]

    plt.figure(figsize=(6,5))
    plt.pcolormesh(ed_vals, cf_vals, std_map, cmap='viridis', shading='auto')
    plt.colorbar(label="Std(local_curvature)")
    plt.xlabel("Energy Density")
    plt.ylabel("Control Field")
    plt.title("Uncertainty Heatmap for Local Curvature")
    plt.tight_layout()
    plt.show()


# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Seed torch so weight initialization and the train/val split are
    # reproducible; the dataset itself is seeded via SpaceTimeDataset(seed=...).
    torch.manual_seed(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare dataset and its normalization statistics (NumPy and torch copies).
    ds = SpaceTimeDataset(n_samples=8000)
    stats_np = {
        'X_mean': ds.X_mean,
        'X_std':  ds.X_std,
        'Y_mean': ds.Y_mean,
        'Y_std':  ds.Y_std,
    }
    stats = {key: torch.tensor(arr, device=device) for key, arr in stats_np.items()}

    # Split into train/validation
    n_val = int(0.2 * len(ds))
    tr_ds, va_ds = random_split(ds, [len(ds) - n_val, n_val])
    tr_ld = DataLoader(tr_ds, batch_size=64, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=128)

    # Build and train model
    model   = SpaceTimeAI().to(device)
    history = train(model, tr_ld, va_ld, stats, device)

    # Visualizations
    plot_losses(history)

    # Scatter plots on the held-out validation split only: including the
    # training rows would overstate how well the model fits.
    val_idx = np.asarray(va_ds.indices)
    X_val = torch.from_numpy(ds.X[val_idx]).float().to(device)
    model.eval()  # disable dropout for deterministic point predictions
    with torch.no_grad():
        Y_pred_n = model(X_val).cpu().numpy()
    Y_pred = Y_pred_n * ds.Y_std + ds.Y_mean
    Y_true = ds.Y[val_idx] * ds.Y_std + ds.Y_mean
    names = ["Local Curvature", "Global Curvature", "Warp Factor"]
    for i, nm in enumerate(names):
        plot_scatter(Y_true[:,i], Y_pred[:,i], nm)

    # Uncertainty heatmap
    plot_uncertainty_heatmap(model, stats_np, device)