<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_beyond_state_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# train_beyond_state_ai.py

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Beyond-State Dataset
# ------------------------------------------------------------------------------
class BeyondStateDataset(Dataset):
    """Synthetic regression dataset: six "beyond-state" inputs -> three targets.

    Inputs (columns of ``X``, in this order) are drawn uniformly:
    BSF ~ U(0, 10), UTR ~ U(0.1, 5), SNE ~ U(0, 1), AWR ~ U(1, 100),
    CCL ~ U(0.01, 2), TSC ~ U(0.5, 10).

    Targets (columns of ``Y``, in this order) are closed-form functions of
    the inputs (``eps`` = 1e-8 in the denominators):
        BCO = BSF * log1p(AWR) / UTR
        UAW = sqrt(UTR * CCL)
        SNS = SNE * TSC / BSF
    plus Gaussian noise whose std is 2% of each target column's std.
    SNS can become very large when BSF is close to 0.

    Both ``X`` and ``Y`` are z-score normalised per column. The float64
    statistics are kept as ``X_mean``/``X_std``/``Y_mean``/``Y_std`` so
    callers can map values back to raw units; the stored tensors ``X`` and
    ``Y`` are float32.

    Args:
        n_samples: number of rows to generate (must be >= 1).
        seed: seed for a private NumPy RNG. The process-wide ``np.random``
            state is left untouched.

    Raises:
        ValueError: if ``n_samples < 1``.
    """

    def __init__(self, n_samples=5000, seed=123):
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        # A private legacy RandomState yields exactly the same draws as
        # np.random.seed(seed) followed by np.random.* calls, without
        # resetting the global NumPy RNG that unrelated code may depend on.
        rng = np.random.RandomState(seed)
        eps = 1e-8

        # Define inputs (draw order matters for reproducibility).
        BSF = rng.uniform(0.0, 10.0,  (n_samples, 1))
        UTR = rng.uniform(0.1, 5.0,   (n_samples, 1))
        SNE = rng.uniform(0.0, 1.0,   (n_samples, 1))
        AWR = rng.uniform(1,   100,   (n_samples, 1))
        CCL = rng.uniform(0.01, 2.0,  (n_samples, 1))
        TSC = rng.uniform(0.5, 10.0,  (n_samples, 1))
        X_raw = np.hstack([BSF, UTR, SNE, AWR, CCL, TSC]).astype(np.float64)

        # Compute noiseless targets.
        BCO = BSF * np.log1p(AWR) / (UTR + eps)
        UAW = np.sqrt(UTR * CCL)
        SNS = SNE * TSC / (BSF + eps)
        Y_raw = np.hstack([BCO, UAW, SNS]).astype(np.float64)

        # Add noise scaled per target column.
        noise_scale = 0.02 * Y_raw.std(axis=0)
        Y_raw += rng.randn(*Y_raw.shape) * noise_scale

        # Normalisation stats; eps keeps a constant column from dividing by 0.
        self.X_mean, self.X_std = X_raw.mean(axis=0), X_raw.std(axis=0) + eps
        self.Y_mean, self.Y_std = Y_raw.mean(axis=0), Y_raw.std(axis=0) + eps

        # Normalise, then downcast for training.
        Xn = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        Yn = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

        # Store tensors.
        self.X = torch.from_numpy(Xn)
        self.Y = torch.from_numpy(Yn)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


# ------------------------------------------------------------------------------
# 2. BeyondStateAI Model
# ------------------------------------------------------------------------------
class BeyondStateAI(nn.Module):
    """One-hidden-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Maps ``input_dim`` features to ``output_dim`` outputs through a hidden
    layer of width ``hidden_dim``. Dropout with probability ``p_drop`` is
    applied to the hidden activations.
    """

    def __init__(self, input_dim=6, hidden_dim=32, output_dim=3, p_drop=0.1):
        super().__init__()
        # Attribute names and registration order are part of the state_dict
        # layout (fc1.*, fc2.*), so saved checkpoints keep loading.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(p=p_drop)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        hidden = self.drop(self.relu(self.fc1(x)))
        return self.fc2(hidden)


# ------------------------------------------------------------------------------
# 3. Physics-Informed and Total Loss
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Mean-squared gap between ``pred`` and the analytic targets implied by ``X``.

    ``X`` is in normalised units. It is mapped back to raw feature scale with
    ``stats['X_std']``/``stats['X_mean']``. The three closed-form targets
    (BCO, UAW, SNS) are recomputed from the raw features, re-normalised with
    ``stats['Y_mean']``/``stats['Y_std']``, and compared to ``pred`` using a
    mean-reduced MSE.
    """
    eps = 1e-8
    raw = X * stats['X_std'] + stats['X_mean']
    flux, transcend, sne, awr, ccl, tsc = raw.unbind(dim=1)

    coherence = flux * torch.log1p(awr) / (transcend + eps)
    awareness = torch.sqrt(transcend * ccl)
    stability = sne * tsc / (flux + eps)

    target = torch.stack((coherence, awareness, stability), dim=1)
    target_norm = (target - stats['Y_mean']) / stats['Y_std']
    return nn.functional.mse_loss(pred, target_norm)

def total_loss(pred, true, X, stats, λ=1.0):
    """Combine the data-fit MSE with the λ-weighted physics residual.

    Returns:
        ``(combined, data_term, physics_term)`` where
        ``combined = data_term + λ * physics_term``. ``data_term`` is the
        mean-reduced MSE between ``pred`` and ``true``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + λ * physics_term
    return combined, data_term, physics_term


# ------------------------------------------------------------------------------
# 4. MC-Dropout Uncertainty
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Monte-Carlo dropout: mean and std over ``T`` stochastic forward passes.

    Only dropout layers are switched to training mode. Layers such as
    BatchNorm therefore keep using (and do not update) their running
    statistics. Every submodule's original ``training`` flag is restored
    afterwards, even if a forward pass raises, so the caller gets the model
    back in the mode it was handed in.

    Args:
        model: ``nn.Module``. Without dropout layers the returned std is 0.
        X: input batch accepted by ``model``.
        T: number of stochastic passes. Must be >= 2, because the
            (unbiased) standard deviation of a single sample is NaN.

    Returns:
        ``(mean, std)`` across the ``T`` passes, each shaped like ``model(X)``.

    Raises:
        ValueError: if ``T < 2``.
    """
    if T < 2:
        raise ValueError(f"T must be >= 2 to estimate a standard deviation, got {T}")

    saved_modes = [(module, module.training) for module in model.modules()]
    try:
        model.eval()
        for module in model.modules():
            # _DropoutNd is the common base class of torch's dropout layers.
            if isinstance(module, nn.modules.dropout._DropoutNd):
                module.train()
        with torch.no_grad():
            preds = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        for module, mode in saved_modes:
            module.training = mode
    return preds.mean(0), preds.std(0)


# ------------------------------------------------------------------------------
# 5. Training Loop with Early Stopping
# ------------------------------------------------------------------------------
def train(model, train_dl, val_dl, stats, device,
          lr=1e-4, wd=1e-5, λ=1.0, epochs=100, patience=10,
          ckpt_path="best_beyond_state.pth"):
    """Fit ``model`` with the combined data + physics loss and early stopping.

    Each epoch runs one AdamW pass over ``train_dl``, with gradient norm
    clipped to 1.0. It then evaluates the same combined loss (``total_loss``
    with weight ``λ``) on ``val_dl``. The validation loss drives a
    ``ReduceLROnPlateau`` scheduler (patience=5). Training stops once
    ``patience`` consecutive epochs fail to improve on the best validation
    loss by more than 1e-6.

    The best weights are kept in memory and loaded back into ``model``
    before returning. They are also written to ``ckpt_path`` unless it is
    None. Restoring from memory means a stale checkpoint left by an earlier
    run is never loaded. It also means ``epochs=0`` or a validation loss that
    is NaN from the first epoch leave the current weights in place instead
    of failing on a missing file.

    Args:
        model: module to train; moved to ``device``.
        train_dl, val_dl: loaders yielding ``(x, y)`` batches.
        stats: normalisation tensors forwarded to ``total_loss``.
        device: torch device for the batches.
        lr, wd: AdamW learning rate and weight decay.
        λ: weight of the physics residual term.
        epochs: maximum number of epochs.
        patience: early-stopping patience in epochs.
        ckpt_path: file to save the best ``state_dict`` to, or None to skip.

    Returns:
        ``{'train': [...], 'val': [...]}`` with per-epoch losses averaged
        over samples.
    """
    model.to(device)
    optimzr = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched   = optim.lr_scheduler.ReduceLROnPlateau(optimzr, 'min', patience=5)

    best_val, wait = float('inf'), 0
    best_state = None  # in-memory copy of the best weights seen so far
    history = {'train': [], 'val': []}

    for ep in range(1, epochs + 1):
        # Train
        model.train()
        run_tr = 0.0
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, λ)
            optimzr.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimzr.step()
            run_tr += loss.item() * xb.size(0)
        train_loss = run_tr / len(train_dl.dataset)

        # Validate
        model.eval()
        run_va = 0.0
        with torch.no_grad():
            for xb, yb in val_dl:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                v_loss, _, _ = total_loss(pred, yb, xb, stats, λ)
                run_va += v_loss.item() * xb.size(0)
        val_loss = run_va / len(val_dl.dataset)

        sched.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {ep:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        # Checkpoint (NaN never compares smaller, so it counts as no improvement)
        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # Clone so later optimizer steps do not mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if ckpt_path is not None:
                torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    # Load best (if any epoch ever improved)
    if best_state is not None:
        model.load_state_dict(best_state)
    return history


# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_loss(hist):
    """Draw the per-epoch training and validation loss curves stored in ``hist``."""
    plt.figure()
    for key, label in (('train', 'Train Loss'), ('val', 'Val Loss')):
        plt.plot(hist[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(y_true, y_pred, title):
    """Scatter predictions against ground truth, with a red dashed y = x guide.

    The guide spans the range of ``y_true``.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=5, alpha=0.6)
    lo = y_true.min()
    hi = y_true.max()
    diagonal = [lo, hi]
    plt.plot(diagonal, diagonal, 'r--')
    plt.title(title)
    plt.show()

def plot_uncertainty(model, stats, device, grid_size=80, T=30):
    """Heatmap of MC-dropout uncertainty of Beyond Coherence over BSF x UTR.

    BSF (feature 0) is swept over [0, 10] and UTR (feature 1) over [0.1, 5],
    the ranges ``BeyondStateDataset`` samples from. The other four features
    are held at ``stats['X_mean']``. The plotted value is the std of model
    output 0 across ``T`` stochastic passes. It is multiplied by
    ``stats['Y_std'][0]`` so it is in the same physical units as the
    denormalised scatter plots, rather than in normalised units.

    The model's per-module train/eval flags are restored afterwards.

    Args:
        model: trained model with dropout.
        stats: dict of normalisation tensors (``X_mean``, ``X_std``, ``Y_std``).
        device: torch device for the grid tensors.
        grid_size: number of grid points per axis.
        T: number of MC-dropout passes.
    """
    bsf = np.linspace(0, 10, grid_size, dtype=np.float32)
    utr = np.linspace(0.1, 5, grid_size, dtype=np.float32)
    B, U = np.meshgrid(bsf, utr)

    # Raw-unit grid: start every column at its mean, then overwrite the two
    # swept features.
    x_mean = stats['X_mean'].to(device=device, dtype=torch.float32)
    Xg = x_mean.expand(B.size, -1).clone()
    Xg[:, 0] = torch.from_numpy(B.ravel()).to(device)
    Xg[:, 1] = torch.from_numpy(U.ravel()).to(device)

    Xn = (Xg - stats['X_mean']) / stats['X_std']

    saved_modes = [(m, m.training) for m in model.modules()]
    try:
        _, std = mc_dropout_predict(model, Xn, T=T)
    finally:
        for m, mode in saved_modes:
            m.training = mode

    Umap = (std[:, 0] * stats['Y_std'][0]).cpu().numpy().reshape(B.shape)

    plt.figure(figsize=(5, 4))
    plt.pcolormesh(bsf, utr, Umap, shading='auto', cmap='viridis')
    plt.colorbar(label="Std(Beyond Coherence)")
    plt.xlabel("Flux (BSF)"); plt.ylabel("Transcendence (UTR)")
    plt.title("Uncertainty Heatmap"); plt.show()


# ------------------------------------------------------------------------------
# 7. Main
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The dataset is seeded; seed torch too so weight init, shuffling and
    # dropout masks are reproducible across runs.
    torch.manual_seed(123)

    # Dataset & stats
    ds = BeyondStateDataset(n_samples=5000, seed=123)
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.float32, device=device),
        'X_std' : torch.tensor(ds.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.float32, device=device),
        'Y_std' : torch.tensor(ds.Y_std,  dtype=torch.float32, device=device),
    }

    # Split (explicit generator -> same train/val partition every run)
    n_val = int(0.2 * len(ds))
    split_gen = torch.Generator().manual_seed(123)
    tr_ds, va_ds = random_split(ds, [len(ds) - n_val, n_val], generator=split_gen)
    tr_dl = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_dl = DataLoader(va_ds, batch_size=256, shuffle=False)

    # Model and train
    model = BeyondStateAI().to(device)
    history = train(model, tr_dl, va_dl, stats, device)

    # Plot training curves
    plot_loss(history)

    # True vs predicted scatter (dropout off for deterministic predictions)
    model.eval()
    X_all = ds.X.to(device)
    with torch.no_grad():
        Y_pred_n = model(X_all).cpu().numpy()
    # Denormalise both sides as NumPy arrays. ds.Y is a torch tensor, and
    # mixing it directly with the NumPy stats would yield a tensor here but
    # an ndarray for Y_pred.
    Y_true = ds.Y.numpy() * ds.Y_std + ds.Y_mean
    Y_pred = Y_pred_n * ds.Y_std + ds.Y_mean

    names = ["Beyond Coherence", "Undefined Awareness", "Nullification Stability"]
    for i, name in enumerate(names):
        plot_scatter(Y_true[:, i], Y_pred[:, i], name)

    # Uncertainty heatmap
    plot_uncertainty(model, stats, device)