<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_space_time_network_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_space_time_network_ai.py

Physics-informed AI pipeline for SpaceTimeNetworkAI with NaN safety.

1. Synthetic dataset: wormhole stability, warp field, temporal distortions, exotic energy, cosmic flux, spatial curvature
2. Float32 normalization and dtype consistency
3. MLP with LayerNorm, Dropout, ReLU
4. Physics-informed residual with denominator clamps
5. MC-Dropout for uncertainty quantification
6. Training loop: AdamW, ReduceLROnPlateau, gradient clipping, NaN checks, early stopping
7. Safe checkpoint load
8. Visualizations: losses, scatter, uncertainty heatmap
"""

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------------------
# 1. Synthetic Space-Time Network Dataset
# ------------------------------------------------------------------------------
class SpaceTimeNetworkDataset(Dataset):
    """Synthetic (features, targets) pairs for SpaceTimeNetworkAI.

    Features, in column order, each drawn uniformly:
        WS ∈ [0.5,1.5], WF ∈ [1e2,1e4], TD ∈ [0,0.1],
        EE ∈ [1e-9,1e-7], CF ∈ [0,1],   SC ∈ [0.1,1]

    Targets, in column order:
        GS  = WS * CF / (TD + eps)
        OC  = SC * WF / (EE + eps)
        MEC = EE * WF * SC
    plus Gaussian noise whose scale is 1% of each target column's std.

    X and Y are z-score normalised with float64 statistics and stored as
    float32. The statistics (X_mean, X_std, Y_mean, Y_std) are kept on the
    instance so callers can de-normalise.

    Args:
        n_samples: number of rows to generate.
        seed: seed for a private RNG. The global NumPy RNG is not modified.
        eps: denominator offset used for GS and OC. Any physics-residual loss
            that recomputes these targets should use the same value.
    """

    def __init__(self, n_samples=6000, seed=0, eps=1e-6):
        # A private legacy RandomState yields exactly the stream that
        # np.random.seed(seed) + np.random.* would, so the generated data is
        # unchanged, but constructing a dataset no longer reseeds the
        # process-global RNG as a side effect.
        rng = np.random.RandomState(seed)

        # Draw order matters for reproducibility: WS, WF, TD, EE, CF, SC, noise.
        WS = rng.uniform(0.5, 1.5,   (n_samples, 1))
        WF = rng.uniform(1e2, 1e4,   (n_samples, 1))
        TD = rng.uniform(0.0, 0.1,   (n_samples, 1))
        EE = rng.uniform(1e-9, 1e-7, (n_samples, 1))
        CF = rng.uniform(0.0, 1.0,   (n_samples, 1))
        SC = rng.uniform(0.1, 1.0,   (n_samples, 1))

        X_raw = np.hstack([WS, WF, TD, EE, CF, SC]).astype(np.float64)

        # Physics targets (eps keeps the TD -> 0 / tiny-EE denominators finite).
        GS  = WS * CF / (TD + eps)
        OC  = SC * WF / (EE + eps)
        MEC = EE * WF * SC

        Y_raw = np.hstack([GS, OC, MEC]).astype(np.float64)
        Y_raw += 0.01 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Per-column statistics; the 1e-8 floor avoids division by zero for a
        # constant column.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = X_raw.std(axis=0) + 1e-8
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = Y_raw.std(axis=0) + 1e-8

        # Normalise, then downcast to float32 for the model.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

        print(f"Dataset X range {self.X.min():.3e}–{self.X.max():.3e}")
        print(f"Dataset Y range {self.Y.min():.3e}–{self.Y.max():.3e}")

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, i):
        """Return row *i* as a (features, targets) pair of float32 tensors.

        torch.from_numpy shares memory with the stored arrays (no copy).
        """
        return torch.from_numpy(self.X[i]), torch.from_numpy(self.Y[i])

# ------------------------------------------------------------------------------
# 2. Model Definition with hidden_dims int-handling
# ------------------------------------------------------------------------------
class SpaceTimeNetworkAI(nn.Module):
    """MLP regressor: [Linear -> LayerNorm -> ReLU -> Dropout] per hidden width, then a Linear head.

    Args:
        input_dim: number of input features.
        hidden_dims: iterable of hidden widths; a bare int means one hidden block.
        output_dim: number of regression outputs.
        p_drop: dropout probability applied after every hidden block.

    All layers live in ``self.net`` (an ``nn.Sequential``), so state_dict keys
    have the form ``net.<index>.<param>``.
    """

    def __init__(self, input_dim=6, hidden_dims=(64, 64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = (hidden_dims,) if isinstance(hidden_dims, int) else tuple(hidden_dims)
        # fan_ins[k] is the input width of hidden block k; its last entry feeds the head.
        fan_ins = (input_dim,) + widths

        modules = []
        for fan_in, fan_out in zip(fan_ins, widths):
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Output head: last hidden width (or the raw input if there are no hidden blocks).
        modules.append(nn.Linear(fan_ins[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Loss with Clamps
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats, eps=1e-6):
    """MSE between *pred* and the closed-form physics targets recomputed from *X*.

    The normalised inputs are mapped back to physical units with
    ``stats['X_std']`` / ``stats['X_mean']``. The analytic targets

        GS  = WS * CF / (TD + eps)
        OC  = SC * WF / (EE + eps)
        MEC = EE * WF * SC

    are then evaluated, re-normalised with ``stats['Y_mean']`` /
    ``stats['Y_std']``, and compared to *pred* with mean-squared error.

    Args:
        pred: model output in normalised target space, one column per target
            (GS, OC, MEC).
        X: normalised inputs whose 6 columns are WS, WF, TD, EE, CF, SC.
        stats: dict of tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std'.
        eps: denominator offset. Defaults to 1e-6, the value
            SpaceTimeNetworkDataset uses when generating its targets. With a
            different value this term pulls the model toward targets that
            disagree with the supervised labels (the previous hard-coded
            1e-4 made OC ~100x smaller than the data).

    Returns:
        Scalar loss tensor.
    """
    X_den = X * stats['X_std'] + stats['X_mean']
    WS, WF, TD, EE, CF, SC = X_den.unbind(dim=1)

    # De-normalisation can produce tiny negative TD/EE through float rounding.
    # Clamping the denominator at eps keeps it strictly positive (no sign flip,
    # no division by zero).
    GS_t  = WS * CF / (TD + eps).clamp(min=eps)
    OC_t  = SC * WF / (EE + eps).clamp(min=eps)
    MEC_t = EE * WF * SC

    Yt = torch.stack([GS_t, OC_t, MEC_t], dim=1)
    Yt_norm = (Yt - stats['Y_mean']) / stats['Y_std']
    return nn.MSELoss()(pred, Yt_norm)

def total_loss(pred, true, X, stats, lam=1.0):
    """Training objective: data MSE plus ``lam`` times the physics residual.

    Returns:
        ``(total, mse, phys)``, the combined loss and its two components, so
        callers can log the terms separately.
    """
    criterion = nn.MSELoss()
    data_term = criterion(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Uncertainty
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Monte-Carlo dropout: mean and std of *T* stochastic forward passes on *X*.

    The model is switched to training mode so dropout is active, and the
    passes run under ``torch.no_grad()``. Afterwards every submodule's
    previous train/eval flag is restored. The original version left the
    model in training mode, so later "deterministic" predictions silently
    kept dropout on.

    Args:
        model: network containing dropout layers.
        X: input batch passed unchanged to ``model``.
        T: number of stochastic passes.

    Returns:
        ``(mean, std)`` reduced over the sample axis. ``std`` uses torch's
        default (unbiased) estimator, so ``T=1`` yields NaN.
    """
    # Remember each submodule's mode individually so mixed states survive.
    prev_modes = [(m, m.training) for m in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)], 0)
    finally:
        for module, was_training in prev_modes:
            module.training = was_training
    return samples.mean(0), samples.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop with NaN Protection & Safe Checkpoint
# ------------------------------------------------------------------------------
def train(model, tr_ld, va_ld, stats, device,
          lr=1e-4, wd=1e-5, lam=1.0, epochs=100, patience=10,
          ckpt_path="best_netai.pth"):
    """Train *model* on the combined data + physics loss with early stopping.

    Uses AdamW, ReduceLROnPlateau on the validation loss (factor 0.5,
    patience 5), and gradient-norm clipping at 1.0. Training aborts on the
    first non-finite (NaN or Inf) batch loss.

    Args:
        model: network to optimise; moved to *device* in place.
        tr_ld, va_ld: DataLoaders yielding ``(x, y)`` batches.
        stats: normalisation tensors forwarded to ``total_loss``.
        device: torch device for the model and batches.
        lr, wd: AdamW learning rate and weight decay.
        lam: weight of the physics residual term.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        ckpt_path: file the best ``state_dict`` is saved to on each
            improvement. Pass ``None`` or ``""`` to skip writing to disk.

    Returns:
        ``{'train': [...], 'val': [...]}``, the per-epoch mean losses.

    On every exit path (normal end, early stop, divergence) the model is
    loaded with the best-validation weights seen in *this* run and put in
    eval mode. The original version reloaded whatever ``best_netai.pth``
    existed on disk, which could be a stale file from a previous run, and
    skipped the reload entirely on the NaN abort.
    """
    model.to(device)
    opt   = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    best_state = None  # in-memory copy of the best weights from this run
    history = {'train': [], 'val': []}

    for epoch in range(1, epochs + 1):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        diverged = False
        for xb, yb in tr_ld:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, lam)
            # isfinite also catches +/-Inf, which would poison the weights
            # just like NaN once backpropagated.
            if not torch.isfinite(loss):
                print("Non-finite (NaN/Inf) loss detected. Aborting training.")
                diverged = True
                break

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            train_loss += loss.item() * xb.size(0)

        if diverged:
            break

        train_loss /= len(tr_ld.dataset)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in va_ld:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats, lam)
                val_loss += loss.item() * xb.size(0)

        val_loss /= len(va_ld.dataset)
        sched.step(val_loss)

        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:02d} ┃ Train: {train_loss:.4e} ┃ Val: {val_loss:.4e}")

        # ---- checkpoint & early stop ----
        # A NaN val_loss never counts as an improvement (NaN < x is False).
        if val_loss < best_val - 1e-8:
            best_val, wait = val_loss, 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if ckpt_path:
                torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping triggered.")
                break

    # Restore from memory so a stale file from an earlier run can never be loaded.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("No checkpoint found; using last epoch model.")
    model.eval()

    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_hist(history):
    """Plot the per-epoch 'train' and 'val' loss curves stored in *history*."""
    for split in ('train', 'val'):
        plt.plot(history[split], label=split)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(y_true, y_pred):
    """Scatter predictions against ground truth with a red dashed y = x reference.

    The reference line spans the range of *y_true*.
    """
    lo = y_true.min()
    hi = y_true.max()
    plt.scatter(y_true, y_pred, s=5)
    plt.plot([lo, hi], [lo, hi], 'r--')
    plt.xlabel("True")
    plt.ylabel("Pred")
    plt.show()

def plot_uncertainty(model, stats, device):
    """Heatmap of MC-dropout std for output 0 over an 80x80 (WS, TD) grid.

    WS (feature 0) sweeps [0.5, 1.5] and TD (feature 2) sweeps [0, 0.1].
    Every other feature is held at ``stats['X_mean']``. The grid is
    normalised with ``stats`` before being passed to ``mc_dropout_predict``
    with T=100.
    """
    ws_mesh, td_mesh = np.meshgrid(np.linspace(0.5, 1.5, 80),
                                   np.linspace(0.0, 0.1, 80))
    n_points = ws_mesh.size

    grid = torch.zeros((n_points, 6), device=device)
    # Hold the non-swept features at their dataset mean...
    for col in (1, 3, 4, 5):
        grid[:, col] = stats['X_mean'][col]
    # ...then fill the two swept columns from the flattened mesh.
    grid[:, 0] = torch.from_numpy(ws_mesh.ravel()).to(device)
    grid[:, 2] = torch.from_numpy(td_mesh.ravel()).to(device)

    normalised = (grid - stats['X_mean']) / stats['X_std']
    sigma = mc_dropout_predict(model, normalised, T=100)[1]
    heat = sigma[:, 0].cpu().numpy().reshape(ws_mesh.shape)

    plt.pcolormesh(ws_mesh, td_mesh, heat, cmap='magma')
    plt.colorbar(label="Uncertainty (σ)")
    plt.xlabel("Wormhole Stability (WS)")
    plt.ylabel("Temporal Distortion (TD)")
    plt.show()

# ------------------------------------------------------------------------------
# 7. Run Everything
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Seed model initialisation and dropout so runs are reproducible.
    torch.manual_seed(0)

    # Prepare data
    dataset = SpaceTimeNetworkDataset(n_samples=6000)
    stats = {
        key: torch.tensor(getattr(dataset, key), dtype=torch.float32, device=device)
        for key in ('X_mean', 'X_std', 'Y_mean', 'Y_std')
    }

    n_val = 1200
    # An explicit generator makes the train/val split reproducible across runs.
    split_gen = torch.Generator().manual_seed(0)
    train_set, val_set = random_split(
        dataset, [len(dataset) - n_val, n_val], generator=split_gen
    )
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_set,   batch_size=256, shuffle=False)

    # Build & train
    model = SpaceTimeNetworkAI(input_dim=6, hidden_dims=32, output_dim=3).to(device)
    history = train(model, train_loader, val_loader, stats, device)

    # Visualize results
    plot_hist(history)

    # train() can exit mid-epoch (divergence abort) with the model still in
    # training mode; force eval so dropout is off for these point predictions.
    model.eval()
    X_all = torch.from_numpy(dataset.X).to(device)
    with torch.no_grad():
        Y_pred = model(X_all).cpu().numpy()

    for i in range(3):
        plot_scatter(dataset.Y[:, i], Y_pred[:, i])

    plot_uncertainty(model, stats, device)