<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_meta_omniversal_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_meta_omniversal_ai.py

End-to-end pipeline for MetaOmniversalAI:
1. Synthetic “meta-omniversal” dataset of 6 inputs → 3 targets
2. Float32 normalization and dtype consistency
3. MLP with LayerNorm, Dropout & ReLU
4. Theory-informed residual enforcing toy “meta-omniversal” laws
5. MC-Dropout for uncertainty quantification
6. Training loop with AdamW, ReduceLROnPlateau, gradient clipping, NaN checks, early stopping
7. Safe checkpoint loading
8. Visualizations: training history, true vs. predicted scatter, uncertainty map
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------------------
# 1. Synthetic Meta-Omniversal Dataset
# ------------------------------------------------------------------------------
class MetaUniverseDataset(Dataset):
    """Synthetic "meta-omniversal" regression dataset: 6 inputs -> 3 targets.

    Inputs (columns of ``X``, in order):
        INF: Computation beyond infinity, uniform in [1e6, 1e9]
        NDC: Non-Deterministic Causality, uniform in [0, 1)
        QTS: Quantum-Temporal Structure, uniform in [0.1, 5]
        X1, X2, X3: auxiliary meta-parameters, uniform in [-2, 2]

    Targets (columns of ``Y``, in order; toy meta-omniversal laws):
        CT  = log(INF) * QTS / (1 + NDC)
        MES = exp(-X1 * X2) / (1 + X3**2)
        CC  = (NDC + X3) * sqrt(QTS)

    Gaussian noise with a scale of 0.5% of each target's standard deviation
    is added to the targets. Inputs and targets are then standardized
    column-wise and stored as float32 arrays ``X`` and ``Y``; the float64
    statistics used are kept in ``X_mean``, ``X_std``, ``Y_mean``, ``Y_std``.

    Args:
        n_samples: number of rows to generate (must be >= 1).
        seed: seed of the private random generator.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
    """

    def __init__(self, n_samples=3000, seed=42):
        if n_samples < 1:
            # Mean/std of an empty array are NaN and would poison every
            # normalization downstream.
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        # A private legacy generator produces exactly the stream that
        # ``np.random.seed(seed)`` followed by ``np.random.*`` calls would,
        # without clobbering NumPy's global RNG state for the caller.
        rng = np.random.RandomState(seed)

        # Draw order matters for reproducibility: INF, NDC, QTS, X1, X2, X3.
        INF = rng.uniform(1e6, 1e9, (n_samples, 1))
        NDC = rng.rand(n_samples, 1)
        QTS = rng.uniform(0.1, 5.0, (n_samples, 1))
        X1 = rng.uniform(-2, 2, (n_samples, 1))
        X2 = rng.uniform(-2, 2, (n_samples, 1))
        X3 = rng.uniform(-2, 2, (n_samples, 1))

        X_raw = np.hstack([INF, NDC, QTS, X1, X2, X3]).astype(np.float64)

        # Targets (toy meta-omniversal laws); eps keeps the log argument
        # strictly positive.
        eps = 1e-9
        CT = np.log(INF + eps) * QTS / (1 + NDC)
        MES = np.exp(-X1 * X2) / (1 + X3**2)
        CC = (NDC + X3) * np.sqrt(QTS)

        Y_raw = np.hstack([CT, MES, CC]).astype(np.float64)
        # Observation noise proportional to each target's spread.
        Y_raw += 0.005 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Normalization stats (the small constant avoids division by zero
        # for a constant column).
        self.X_mean = X_raw.mean(axis=0)
        self.X_std = X_raw.std(axis=0) + 1e-8
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std = Y_raw.std(axis=0) + 1e-8

        # Standardize in float64, store as float32 to match the model dtype.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the standardized ``(input, target)`` pair at ``idx`` as tensors."""
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

# ------------------------------------------------------------------------------
# 2. MetaOmniversalAI Model Definition
# ------------------------------------------------------------------------------
class MetaOmniversalAI(nn.Module):
    """Fully connected regressor built from Linear/LayerNorm/ReLU/Dropout stages.

    Args:
        input_dim: number of input features.
        hidden_dims: width of each hidden stage; a single int is treated as
            one hidden stage of that width.
        output_dim: number of regression outputs.
        p_drop: dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.15):
        super().__init__()
        widths = (hidden_dims,) if isinstance(hidden_dims, int) else hidden_dims

        stages = []
        fan_in = input_dim
        for width in widths:
            # One hidden stage: affine map, normalization, non-linearity, dropout.
            stages.append(nn.Linear(fan_in, width))
            stages.append(nn.LayerNorm(width))
            stages.append(nn.ReLU())
            stages.append(nn.Dropout(p_drop))
            fan_in = width

        # Final linear read-out with no activation (regression head).
        stages.append(nn.Linear(fan_in, output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked stages and return the raw outputs."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Theory-Informed Residual & Total Loss
# ------------------------------------------------------------------------------
def theory_residual(pred, X, stats):
    """Mean-squared gap between ``pred`` and the noise-free toy-law targets.

    Args:
        pred: model output in standardized target space, one row per sample
            with columns (CT, MES, CC).
        X: standardized inputs, one row per sample with columns
            (INF, NDC, QTS, X1, X2, X3).
        stats: mapping with tensors ``'X_mean'``, ``'X_std'``, ``'Y_mean'``
            and ``'Y_std'`` used to undo / apply the standardization.

    Returns:
        Scalar tensor: MSE between ``pred`` and the standardized law values.
    """
    tiny = 1e-9  # keeps the log argument strictly positive

    # Back to physical units, then split into one vector per input column.
    raw_inputs = X * stats['X_std'] + stats['X_mean']
    inf_, ndc, qts, a1, a2, a3 = raw_inputs.t()

    # Toy meta-omniversal laws evaluated on the de-normalized inputs.
    law_ct = torch.log(inf_ + tiny) * qts / (1 + ndc)
    law_mes = torch.exp(-a1 * a2) / (1 + a3**2)
    law_cc = (ndc + a3) * torch.sqrt(qts)

    # Express the law values in the same standardized space as ``pred``.
    law_targets = torch.stack((law_ct, law_mes, law_cc), dim=1)
    law_targets = (law_targets - stats['Y_mean']) / stats['Y_std']

    criterion = nn.MSELoss()
    return criterion(pred, law_targets)

def total_loss(pred, true, X, stats, lam=0.8):
    """Combine the data-fit MSE with the weighted theory residual.

    Args:
        pred: model output (standardized targets).
        true: reference targets (standardized).
        X: standardized inputs passed on to ``theory_residual``.
        stats: normalization statistics passed on to ``theory_residual``.
        lam: weight of the theory term in the combined loss.

    Returns:
        Tuple ``(combined, data_term, theory_term)`` where
        ``combined = data_term + lam * theory_term``.
    """
    data_term = nn.MSELoss()(pred, true)
    theory_term = theory_residual(pred, X, stats)
    combined = data_term + lam * theory_term
    return combined, data_term, theory_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout for Uncertainty Quantification
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=60):
    """Estimate predictive mean and spread with Monte-Carlo dropout.

    The model is switched to training mode so its dropout layers stay
    stochastic, ``T`` forward passes are run without gradient tracking, and
    the per-element mean and standard deviation across passes are returned.
    The train/eval mode of every sub-module is restored afterwards, so a
    model that was in eval mode is still in eval mode when this returns.

    Args:
        model: ``nn.Module`` to sample from.
        X: input batch, already on the model's device.
        T: number of stochastic forward passes. With ``T == 1`` the returned
            standard deviation is NaN (sample std of a single draw).

    Returns:
        Tuple ``(mean, std)``, each shaped like a single ``model(X)`` output.
    """
    # Remember each sub-module's own flag: a blanket ``model.train(flag)``
    # afterwards would flatten any mixed train/eval configuration.
    previous_modes = [(module, module.training) for module in model.modules()]

    model.train()  # keeps dropout active during the sampling passes
    try:
        with torch.no_grad():
            stacked = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    return stacked.mean(0), stacked.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop with Safety & Checkpointing
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=2e-4, wd=1e-5, lam=0.8, epochs=80, patience=8,
          ckpt_path="best_meta_ai.pth"):
    """Train ``model`` with AdamW, LR-on-plateau scheduling and early stopping.

    Each epoch runs one pass over ``train_loader`` (gradient norm clipped to
    1.0) and one pass over ``val_loader``. Whenever the validation loss
    improves by more than 1e-6 the weights are written to ``ckpt_path``;
    training stops after ``patience`` epochs without improvement. If a
    non-finite training loss is encountered the run is aborted. In every
    case the best weights saved *during this call* are loaded back into
    ``model`` before returning, and the model is left in eval mode.

    Args:
        model: network to optimize (moved to ``device``).
        train_loader: iterable of ``(X, Y)`` training batches.
        val_loader: iterable of ``(X, Y)`` validation batches.
        stats: normalization statistics forwarded to ``total_loss``.
        device: torch device for model and batches.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: weight of the theory term in ``total_loss``.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        ckpt_path: file the best weights are saved to and restored from.

    Returns:
        Dict with lists ``'train'`` and ``'val'`` holding the per-epoch mean
        losses of all completed epochs.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=4
    )

    best_val, wait = float('inf'), 0
    history = {'train': [], 'val': []}
    # Only a checkpoint written by this call may be restored; a leftover file
    # from an earlier run could belong to a different architecture.
    saved_this_run = False
    aborted = False

    for epoch in range(1, epochs + 1):
        # Training
        model.train()
        run_loss, n_seen = 0.0, 0
        for Xb, Yb in train_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            pred = model(Xb)
            loss, _, _ = total_loss(pred, Yb, Xb, stats, lam)
            # Catch Inf as well as NaN: back-propagating either one would
            # corrupt the weights on the next optimizer step.
            if not torch.isfinite(loss):
                print(f"NaN loss at epoch {epoch}, aborting.")
                aborted = True
                break
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            run_loss += loss.item() * Xb.size(0)
            n_seen += Xb.size(0)
        if aborted:
            break
        # Normalize by the samples actually processed so the mean stays
        # correct with ``drop_last=True`` or a sub-sampling sampler.
        train_loss = run_loss / n_seen

        # Validation
        model.eval()
        run_loss, n_seen = 0.0, 0
        with torch.no_grad():
            for Xv, Yv in val_loader:
                Xv, Yv = Xv.to(device), Yv.to(device)
                pred = model(Xv)
                loss, _, _ = total_loss(pred, Yv, Xv, stats, lam)
                run_loss += loss.item() * Xv.size(0)
                n_seen += Xv.size(0)
        val_loss = run_loss / n_seen

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            torch.save(model.state_dict(), ckpt_path)
            saved_this_run = True
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # Roll back to the best validated weights (also after an abort).
    if saved_this_run:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(history):
    """Show training and validation loss curves against the epoch index.

    Args:
        history: mapping with sequences under ``'train'`` and ``'val'``.
    """
    ax = plt.figure().add_subplot(111)
    for key, label in (('train', 'Train'), ('val', 'Val')):
        ax.plot(history[key], label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

def plot_scatter(y_true, y_pred, title):
    """Show a true-vs-predicted scatter with a dashed red identity line.

    Args:
        y_true: reference values (x axis); their min/max span the identity line.
        y_pred: predicted values (y axis).
        title: figure title.
    """
    ax = plt.figure().add_subplot(111)
    ax.scatter(y_true, y_pred, s=8, alpha=0.5)

    # Identity line across the range of the reference values.
    lo, hi = y_true.min(), y_true.max()
    diagonal = [lo, hi]
    ax.plot(diagonal, diagonal, 'r--')

    ax.set_title(title)
    plt.show()

def plot_uncertainty_map(model, stats, device, grid=80, T=60):
    """Plot the MC-dropout standard deviation of output 0 over an INF x NDC grid.

    INF is swept over [1e6, 1e9] and NDC over [0, 1]; the remaining four
    inputs (QTS, X1, X2, X3) are held at their dataset means taken from
    ``stats['X_mean']``. The plotted quantity is the spread of the model's
    raw first output across the stochastic passes, i.e. it is in the same
    (standardized) units the model predicts in, not rescaled by ``Y_std``.

    Args:
        model: trained network to probe.
        stats: mapping with tensors ``'X_mean'`` and ``'X_std'`` on ``device``.
        device: torch device the grid is evaluated on.
        grid: number of points along each axis (``grid * grid`` evaluations).
        T: number of MC-dropout forward passes.
    """
    inf_axis = np.linspace(1e6, 1e9, grid, dtype=np.float32)
    ndc_axis = np.linspace(0, 1, grid, dtype=np.float32)
    G1, G2 = np.meshgrid(inf_axis, ndc_axis)

    # Physical-unit inputs: sweep columns 0/1, pin the rest at their means.
    Xg = torch.zeros((G1.size, 6), device=device)
    Xg[:, 2:] = stats['X_mean'][2:]  # QTS,X1,X2,X3 means
    Xg[:, 0] = torch.from_numpy(G1.ravel()).to(device)
    Xg[:, 1] = torch.from_numpy(G2.ravel()).to(device)

    # The model consumes standardized inputs.
    Xn = (Xg - stats['X_mean']) / stats['X_std']
    _, std = mc_dropout_predict(model, Xn, T=T)
    # Hand matplotlib a plain NumPy array rather than a torch tensor.
    U = std[:, 0].cpu().numpy().reshape(G1.shape)

    plt.figure(figsize=(5,4))
    plt.pcolormesh(G1, G2, U, cmap='viridis', shading='auto')
    plt.colorbar(label="Std(CT)")
    plt.xlabel("INF")
    plt.ylabel("NDC")
    plt.title("Uncertainty: Computational Transcendence")
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # End-to-end run: build data, train, then show diagnostics.
    SEED = 42
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed torch as well as the dataset so weight initialization, batch
    # shuffling and dropout are repeatable from run to run.
    torch.manual_seed(SEED)

    ds = MetaUniverseDataset(n_samples=3000, seed=SEED)
    # Normalization statistics as float32 tensors on the training device.
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.float32, device=device),
        'X_std' : torch.tensor(ds.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.float32, device=device),
        'Y_std' : torch.tensor(ds.Y_std,  dtype=torch.float32, device=device),
    }

    # 80/20 train/validation split with its own generator so the partition
    # does not depend on how much global randomness was consumed before it.
    n_val = int(0.2 * len(ds))
    tr_ds, va_ds = random_split(
        ds, [len(ds) - n_val, n_val],
        generator=torch.Generator().manual_seed(SEED),
    )
    tr_ld = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=256, shuffle=False)

    model = MetaOmniversalAI(input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.15).to(device)
    history = train(model, tr_ld, va_ld, stats, device)

    plot_history(history)

    # Deterministic predictions: make sure dropout is off regardless of the
    # mode the training routine left the model in.
    model.eval()
    X_all = torch.from_numpy(ds.X).to(device)
    with torch.no_grad():
        Yp_norm = model(X_all).cpu().numpy()
    # Back to physical units for plotting.
    Yp = Yp_norm * ds.Y_std + ds.Y_mean
    Yt = ds.Y * ds.Y_std + ds.Y_mean

    names = ["Computational Transcendence", "Meta-Existence Stability", "Causality Control"]
    for i, nm in enumerate(names):
        plot_scatter(Yt[:,i], Yp[:,i], nm)

    plot_uncertainty_map(model, stats, device)