<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_meta_existence_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_meta_existence_ai.py

End-to-end pipeline for MetaExistenceAI:
1. Synthetic “meta-existence” dataset of 6 inputs → 3 targets
2. float32 normalization & dtype consistency
3. MLP with LayerNorm, Dropout & ReLU (hidden_dims accepts int or tuple)
4. Physics-informed residual enforcing toy meta-laws
5. MC-Dropout for uncertainty quantification
6. Training loop with AdamW, ReduceLROnPlateau, grad clipping, NaN checks, early stopping
7. Safe checkpoint loading
8. Visualizations: loss curves, scatter plots, uncertainty heatmap
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# -----------------------------------------------------------------------------
# 1. Synthetic Meta-Existence Dataset
# -----------------------------------------------------------------------------
class MetaExistenceDataset(Dataset):
    """Synthetic, z-score-normalized dataset of 6 inputs and 3 toy targets.

    Input columns, in order (all drawn uniformly):
        AE  absolute existence flux            in [0.5, 5.0]
        MCE meta-consciousness equilibrium     in [0, 1)
        SOM self-sustaining omnipresence       in [0.1, 10]
        DP  dimensional potential              in [0, 1)
        OM  ontological magnitude              in [1e2, 1e5]
        CE  consciousness entropy              in [0.01, 2.0]

    Target columns, in order (before noise, eps = 1e-6):
        BES = AE * MCE / (SOM + eps)
        MSC = (MCE + DP) * OM
        AA  = SOM * AE / (CE + eps)

    Attributes:
        X, Y: float32 arrays of shape (n_samples, 6) and (n_samples, 3),
            normalized per column with the statistics below.
        X_mean, X_std, Y_mean, Y_std: float64 per-column statistics of the
            raw data (the std values include a +1e-8 guard against zero).

    Args:
        n_samples: number of rows to generate; must be at least 1.
        seed: seed for the private random stream used to build the data.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
    """

    def __init__(self, n_samples=6000, seed=123):
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        # A private RandomState seeded with `seed` produces the same legacy
        # stream as np.random.seed(seed) followed by the module-level
        # functions, but does not reset NumPy's global RNG for the caller.
        rng = np.random.RandomState(seed)

        # The draw order below fixes the generated values; do not reorder.
        AE  = rng.uniform(0.5, 5.0,  (n_samples, 1))
        MCE = rng.rand(n_samples, 1)
        SOM = rng.uniform(0.1, 10.0, (n_samples, 1))
        DP  = rng.rand(n_samples, 1)
        OM  = rng.uniform(1e2, 1e5,  (n_samples, 1))
        CE  = rng.uniform(0.01, 2.0, (n_samples, 1))

        X_raw = np.hstack([AE, MCE, SOM, DP, OM, CE]).astype(np.float64)

        # Toy meta-existence targets; eps keeps the divisions finite.
        eps = 1e-6
        BES = AE * MCE / (SOM + eps)
        MSC = (MCE + DP) * OM
        AA  = SOM * AE / (CE + eps)

        Y_raw = np.hstack([BES, MSC, AA]).astype(np.float64)
        # Gaussian noise whose std is 2% of each target column's std.
        Y_raw += 0.02 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Statistics are computed in float64 before the float32 cast.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = X_raw.std(axis=0) + 1e-8
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = Y_raw.std(axis=0) + 1e-8

        # Normalize and cast to float32.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the normalized ``(x, y)`` pair at ``idx`` as torch tensors.

        The tensors are created with ``torch.from_numpy`` on the stored
        arrays, so they are not independent copies of the dataset's data.
        """
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])


# -----------------------------------------------------------------------------
# 2. MetaExistenceAI Model Definition
# -----------------------------------------------------------------------------
class MetaExistenceAI(nn.Module):
    """Fully connected regressor built from Linear/LayerNorm/ReLU/Dropout stages.

    Args:
        input_dim: number of input features.
        hidden_dims: width of each hidden stage; a single int is treated as
            one hidden stage of that width.
        output_dim: number of regression outputs.
        p_drop: dropout probability applied after every hidden stage.
    """

    def __init__(self, input_dim=6, hidden_dims=(64, 64),
                 output_dim=3, p_drop=0.2):
        super().__init__()
        widths = (hidden_dims,) if isinstance(hidden_dims, int) else hidden_dims

        modules = []
        in_features = input_dim
        for width in widths:
            # One hidden stage: affine map, normalization, activation, dropout.
            modules.append(nn.Linear(in_features, width))
            modules.append(nn.LayerNorm(width))
            modules.append(nn.ReLU())
            modules.append(nn.Dropout(p_drop))
            in_features = width

        # Final projection onto the outputs (no activation).
        modules.append(nn.Linear(in_features, output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the raw outputs."""
        return self.net(x)


# -----------------------------------------------------------------------------
# 3. Physics-Informed Residual Loss
# -----------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Mean-squared error between ``pred`` and the closed-form toy targets.

    The inputs are mapped back to raw units with ``stats['X_std']`` and
    ``stats['X_mean']``, the three target formulas are evaluated on them, and
    the result is normalized with ``stats['Y_mean']`` / ``stats['Y_std']``
    before being compared with ``pred``.

    Args:
        pred: model output to compare against the closed-form targets.
        X: normalized inputs; its six columns are unpacked as
            AE, MCE, SOM, DP, OM, CE.
        stats: mapping with keys 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor holding the mean-squared residual.
    """
    tiny = 1e-6  # keeps the two divisions finite

    features = X * stats['X_std'] + stats['X_mean']
    # Transposing puts the features on dim 0 so they unpack one per name.
    AE, MCE, SOM, DP, OM, CE = features.t()

    closed_form = torch.stack(
        [
            AE * MCE / (SOM + tiny),   # BES
            (MCE + DP) * OM,           # MSC
            SOM * AE / (CE + tiny),    # AA
        ],
        dim=1,
    )
    target = (closed_form - stats['Y_mean']) / stats['Y_std']
    return nn.functional.mse_loss(pred, target)


# -----------------------------------------------------------------------------
# 4. Total Loss
# -----------------------------------------------------------------------------
def total_loss(pred, true, X, stats, lam=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: model predictions.
        true: reference targets compared with ``pred`` by mean-squared error.
        X: inputs forwarded to ``physics_residual``.
        stats: statistics mapping forwarded to ``physics_residual``.
        lam: weight applied to the physics residual.

    Returns:
        Tuple ``(data_term + lam * physics_term, data_term, physics_term)``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term


# -----------------------------------------------------------------------------
# 5. MC-Dropout Uncertainty
# -----------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=30):
    """Estimate predictive mean and spread with Monte-Carlo dropout.

    Runs ``T`` stochastic forward passes of ``model`` on ``X`` with only the
    dropout layers in training mode; all other layers run in eval mode.
    Every module's original train/eval flag is restored afterwards, even if
    a forward pass raises, so the call has no lasting effect on ``model``.

    Args:
        model: an ``nn.Module`` containing dropout layers.
        X: input batch passed to ``model`` unchanged on every pass.
        T: number of stochastic forward passes; must be at least 1.

    Returns:
        Tuple ``(mean, std)`` taken over the ``T`` passes, each shaped like a
        single model output. ``std`` uses torch's default estimator, so it is
        not meaningful for ``T == 1``.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                     nn.AlphaDropout, nn.FeatureAlphaDropout)
    # Remember each module's mode so the caller's model is left untouched.
    saved_modes = [(module, module.training) for module in model.modules()]
    try:
        model.eval()
        # Re-enable sampling only where it is wanted: the dropout layers.
        for module in model.modules():
            if isinstance(module, dropout_types):
                module.train()
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return samples.mean(dim=0), samples.std(dim=0)


# -----------------------------------------------------------------------------
# 6. Training Loop
# -----------------------------------------------------------------------------
def train(model, tr_loader, va_loader, stats, device,
          lr=1e-4, wd=1e-6, lam=0.5,
          epochs=80, patience=10,
          ckpt_path="best_meta_existence_ai.pth"):
    """Train ``model`` and leave it holding the best weights of this run.

    Each epoch runs one pass over ``tr_loader`` (AdamW, gradient-norm clipping
    at 1.0) and one evaluation pass over ``va_loader``; both use
    ``total_loss``. The learning rate is halved by ``ReduceLROnPlateau`` when
    the validation loss stalls. Whenever the validation loss improves by more
    than 1e-6 the weights are snapshotted in memory and, if ``ckpt_path`` is
    set, written to disk. Training ends after ``epochs`` epochs, after
    ``patience`` epochs without improvement, or as soon as a training loss is
    NaN/inf.

    The best weights are restored from the in-memory snapshot, never from a
    file that might pre-date this run or belong to another architecture.

    Args:
        model: network to optimize; moved to ``device``.
        tr_loader: iterable of ``(X, Y)`` training batches.
        va_loader: iterable of ``(X, Y)`` validation batches.
        stats: statistics mapping forwarded to ``total_loss``.
        device: device the model and batches are moved to.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: physics-residual weight forwarded to ``total_loss``.
        epochs: maximum number of epochs.
        patience: epochs without improvement tolerated before stopping.
        ckpt_path: where the best ``state_dict`` is saved; a falsy value
            disables writing to disk.

    Returns:
        Dict with keys 'train' and 'val', each a list of per-epoch mean losses.

    Raises:
        ValueError: if a loader yields no samples during an epoch.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    best_state = None  # snapshot of the best weights seen in this run
    history = {'train': [], 'val': []}
    diverged = False

    for epoch in range(1, epochs + 1):
        # training
        model.train()
        train_loss, n_train = 0.0, 0
        for Xb, Yb in tr_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            pred = model(Xb)
            loss, _, _ = total_loss(pred, Yb, Xb, stats, lam)
            # isfinite also catches +/-inf, which isnan would let through.
            if not torch.isfinite(loss):
                print("Non-finite loss detected, stopping.")
                diverged = True
                break
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item() * Xb.size(0)
            n_train += Xb.size(0)
        if diverged:
            break
        if n_train == 0:
            raise ValueError("tr_loader yielded no samples")
        # Average over the samples actually seen, which stays correct when
        # the loader drops the last batch or uses a sampler.
        train_loss /= n_train

        # validation
        model.eval()
        val_loss, n_val = 0.0, 0
        with torch.no_grad():
            for Xb, Yb in va_loader:
                Xb, Yb = Xb.to(device), Yb.to(device)
                pred = model(Xb)
                loss, _, _ = total_loss(pred, Yb, Xb, stats, lam)
                val_loss += loss.item() * Xb.size(0)
                n_val += Xb.size(0)
        if n_val == 0:
            raise ValueError("va_loader yielded no samples")
        val_loss /= n_val

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # checkpoint & early stopping
        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # Clone: state_dict() tensors alias the live parameters.
            best_state = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}
            if ckpt_path:
                torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    # restore best (also reached after a non-finite loss)
    if best_state is not None:
        model.load_state_dict(best_state)
    return history


# -----------------------------------------------------------------------------
# 7. Visualization Helpers
# -----------------------------------------------------------------------------
def plot_history(hist):
    """Plot the 'train' and 'val' loss series of ``hist`` against epoch index.

    Args:
        hist: mapping with keys 'train' and 'val', each a sequence of losses.
    """
    plt.figure(figsize=(6, 4))
    for key, legend_name in (('train', 'Train'), ('val', 'Val')):
        plt.plot(hist[key], label=legend_name)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(y_true, y_pred, title):
    """Scatter ``y_pred`` against ``y_true`` with a dashed red identity line.

    Args:
        y_true: reference values (x axis); must provide ``min()`` and ``max()``.
        y_pred: predicted values (y axis).
        title: figure title.
    """
    plt.figure(figsize=(5, 5))
    plt.scatter(y_true, y_pred, s=6, alpha=0.5)

    # The identity line spans the range of the reference values.
    lo = y_true.min()
    hi = y_true.max()
    diagonal = [lo, hi]
    plt.plot(diagonal, diagonal, 'r--')

    plt.title(title)
    plt.show()

def plot_uncertainty_heatmap(model, stats, device):
    """Draw the MC-dropout std of the first output over an AE x MCE grid.

    AE sweeps [0.5, 5.0] and MCE sweeps [0, 1] on an 80 x 80 grid while the
    remaining four inputs are held at ``stats['X_mean'][2:]``. The plotted
    std is taken straight from the model outputs; it is not rescaled by
    ``stats['Y_std']``.

    Args:
        model: network passed to ``mc_dropout_predict``.
        stats: mapping with tensor entries 'X_mean' and 'X_std'.
        device: device on which the grid tensor is built.
    """
    n_side = 80
    ae_axis = np.linspace(0.5, 5.0, n_side, dtype=np.float32)
    mce_axis = np.linspace(0.0, 1.0, n_side, dtype=np.float32)
    AE_grid, MCE_grid = np.meshgrid(ae_axis, mce_axis)
    n_pts = n_side * n_side

    # Raw-unit inputs: the swept pair in columns 0-1, means everywhere else.
    features = torch.zeros((n_pts, 6), device=device, dtype=torch.float32)
    features[:, 0] = torch.from_numpy(AE_grid.ravel()).to(device)
    features[:, 1] = torch.from_numpy(MCE_grid.ravel()).to(device)
    features[:, 2:] = stats['X_mean'][2:]

    features_norm = (features - stats['X_mean']) / stats['X_std']
    _, spread = mc_dropout_predict(model, features_norm, T=50)
    bes_spread = spread[:, 0].cpu().reshape(AE_grid.shape)

    plt.figure(figsize=(5, 4))
    plt.pcolormesh(AE_grid, MCE_grid, bes_spread, shading='auto', cmap='viridis')
    plt.colorbar(label="Std(BES)")
    plt.xlabel("Absolute Existence Flux (AE)")
    plt.ylabel("Meta-Consciousness Equilibrium (MCE)")
    plt.title("Uncertainty: Beyond-Existence Stability")
    plt.show()


# -----------------------------------------------------------------------------
# 8. Main Execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: build the dataset, train the model, then draw the
    # loss curves, per-target scatter plots and the uncertainty heatmap.
    SEED = 123
    # Seed torch as well, so the train/val split, weight initialization and
    # dropout masks are repeatable alongside the seeded dataset.
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = MetaExistenceDataset(n_samples=6000, seed=SEED)
    # Normalization statistics as float32 tensors on the training device.
    stats = {
        'X_mean': torch.tensor(dataset.X_mean, dtype=torch.float32, device=device),
        'X_std' : torch.tensor(dataset.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(dataset.Y_mean, dtype=torch.float32, device=device),
        'Y_std' : torch.tensor(dataset.Y_std,  dtype=torch.float32, device=device),
    }

    # split: 80% train / 20% validation
    n_val = int(0.2 * len(dataset))
    train_ds, val_ds = random_split(dataset, [len(dataset)-n_val, n_val])
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False)

    # instantiate & train
    model = MetaExistenceAI(hidden_dims=(128,64), p_drop=0.2).to(device)
    history = train(model, train_loader, val_loader, stats, device)

    # plots
    plot_history(history)

    # full-dataset scatter
    X_all = torch.from_numpy(dataset.X).to(device)
    # Do not rely on the mode train() happened to leave the model in:
    # dropout must be off for deterministic point predictions.
    model.eval()
    with torch.no_grad():
        Yp_norm = model(X_all).cpu().numpy()
    # Map normalized targets and predictions back to raw units.
    Y_true = dataset.Y * dataset.Y_std + dataset.Y_mean
    Y_pred = Yp_norm * dataset.Y_std + dataset.Y_mean
    names = ["Beyond-Existence Stability", "Meta-State Coherence", "Absolute Autonomy"]
    for i, name in enumerate(names):
        plot_scatter(Y_true[:,i], Y_pred[:,i], name)

    # uncertainty map
    plot_uncertainty_heatmap(model, stats, device)