<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_and_analyze_undefined_presence_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib seaborn umap-learn

In [None]:
#!/usr/bin/env python3
"""
train_and_analyze_undefined_presence_ai.py

1. Synthetic dataset (6 → 3)
2. MC‐Dropout model
3. Physics‐ and ODE‐informed residual losses
4. AdamW training with scheduler, clipping, early stop
5. MC‐Dropout inference (mean/std)
6. Loss curves, OOD uncertainty distributions, reliability diagram
7. UMAP embedding of first-layer hidden features, colored by true presence
8. Physics residual histogram
"""

import os
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.autograd import grad
import umap

# 1. Reproducibility & Device
SEED = 42  # one seed shared by Python's, NumPy's and torch's RNGs

# Seed every global RNG the script draws from: NumPy (data generation),
# torch (weight init, dropout masks, random_split) and Python's `random`.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2. Dataset Definition
class UndefinedPresenceDataset(Dataset):
    """Synthetic 6-input / 3-output regression set with closed-form targets.

    Inputs (u, v, w, x, y, z) are drawn uniformly from fixed ranges using
    NumPy's global RNG. Targets are
        presence      = sin(u) * v + cos(w)
        dissolution   = exp(-x * y)
        transcendence = z * (presence + dissolution)
    plus Gaussian noise whose per-column std is 1% of that column's std.

    Inputs and targets are z-scored. The statistics are kept as
    ``X_mean``/``X_std``/``Y_mean``/``Y_std`` so callers can denormalize;
    ``1e-8`` is added to each std to avoid division by zero.
    """

    def __init__(self, n=5000):
        # (low, high) sampling range per input column, in u, v, w, x, y, z
        # order -- the order fixes the sequence of RNG draws.
        ranges = ((-1, 1), (0, 2), (-2, 2), (0, 5), (-1, 1), (0, 1))
        cols = [np.random.uniform(lo, hi, (n, 1)) for lo, hi in ranges]
        u, v, w, x, y, z = cols
        X = np.hstack(cols).astype(np.float32)

        # Closed-form targets, one (n, 1) column each.
        pres = np.sin(u) * v + np.cos(w)
        diss = np.exp(-x * y)
        trans = z * (pres + diss)
        Y = np.hstack([pres, diss, trans]).astype(np.float32)

        # Small additive noise scaled to each target column's spread.
        noise = np.random.randn(*Y.shape).astype(np.float32)
        Y += 0.01 * Y.std(axis=0) * noise

        self.X_mean = X.mean(0)
        self.X_std = X.std(0) + 1e-8
        self.Y_mean = Y.mean(0)
        self.Y_std = Y.std(0) + 1e-8

        # Store the z-scored arrays used by __getitem__.
        self.X = ((X - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Number of samples (rows of the input matrix)."""
        return self.X.shape[0]

    def __getitem__(self, i):
        """Return the i-th (normalized input, normalized target) tensor pair."""
        features = torch.from_numpy(self.X[i])
        targets = torch.from_numpy(self.Y[i])
        return features, targets

# 3. MC‐Dropout Model
class UndefinedPresenceAI(nn.Module):
    """Small MLP with a dropout layer so it can be sampled for MC-Dropout.

    Layout (stored as ``self.net``, an ``nn.Sequential``):
        0: Linear(inp, hid) -> 1: ReLU -> 2: Dropout(p_drop) -> 3: Linear(hid, out)
    """

    def __init__(self, inp=6, hid=32, out=3, p_drop=0.1):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # draws from the torch RNG in the same sequence as before.
        layers = [
            nn.Linear(inp, hid),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hid, out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` must equal ``inp``."""
        out = self.net(x)
        return out

# 4a. Physics residual
def physics_residual(pred, X, stats):
    """MSE between normalized predictions and the normalized closed-form targets.

    Args:
        pred: model outputs in z-scored target space, columns ordered
            (presence, dissolution, transcendence).
        X: z-scored inputs whose six columns are (u, v, w, x, y, z).
        stats: dict of tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor: mean squared difference between ``pred`` and the physics
        targets after the same z-scoring.
    """
    # Back to raw input units so the closed-form laws can be evaluated.
    raw = X * stats['X_std'] + stats['X_mean']
    u, v, w, x, y, z = raw.T
    presence = torch.sin(u) * v + torch.cos(w)
    dissolution = torch.exp(-x * y)
    targets = torch.stack(
        (presence, dissolution, z * (presence + dissolution)), dim=1
    )
    # Compare in the same z-scored space the network predicts in.
    targets_norm = (targets - stats['Y_mean']) / stats['Y_std']
    criterion = nn.MSELoss()
    return criterion(pred, targets_norm)

# 4b. ODE‐informed residual (∂presence/∂u ≈ v⋅cos(u))
def ode_residual(pred, X, stats):
    """Penalty tying the model's d(presence)/du to the analytic v * cos(u).

    Since presence = sin(u) * v + cos(w), d(presence)/du = v * cos(u) in raw
    (denormalized) units. Both sides of the comparison are expressed in those
    raw units.

    Args:
        pred: z-scored model outputs; column 0 is presence. Must have been
            computed from ``X`` differentiably.
        X: z-scored inputs (u, v, w, x, y, z) with ``requires_grad=True``.
        stats: dict of tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor: MSE between the model derivative and v * cos(u). The
        derivative graph is kept (``create_graph=True``) so the penalty can be
        back-propagated into the model weights.
    """
    # Presence in raw output units.
    pred_den = pred * stats['Y_std'] + stats['Y_mean']
    pres = pred_den[:, 0]

    # Summing before differentiating yields per-row gradients, assuming each
    # output row depends only on its own input row (true for a row-wise MLP).
    grads = grad(pres.sum(), X, create_graph=True)[0]

    # autograd differentiates w.r.t. the *normalized* u. With
    # u_raw = u_norm * X_std[0] + X_mean[0], the chain rule gives
    # d/d(u_raw) = d/d(u_norm) / X_std[0]; without this division the model
    # derivative would be compared to the raw-unit target in mismatched units.
    dp_du = grads[:, 0] / stats['X_std'][0]

    # Analytic target in raw input units.
    X_den = X * stats['X_std'] + stats['X_mean']
    u = X_den[:, 0]
    v = X_den[:, 1]
    target = v * torch.cos(u)

    return nn.MSELoss()(dp_du, target)

# 4c. Combined loss
def total_loss(pred, y_true, X, stats, lam_phys=1.0, lam_ode=0.5):
    """Weighted sum of data MSE, physics residual and ODE residual.

    ``X`` must require grad, because ``ode_residual`` differentiates the
    predictions with respect to it.

    Returns:
        (total, mse, phys, ode): the combined scalar loss
        ``mse + lam_phys * phys + lam_ode * ode`` followed by the three
        unweighted components, in that order.
    """
    criterion = nn.MSELoss()
    data_term = criterion(pred, y_true)
    phys_term = physics_residual(pred, X, stats)
    ode_term = ode_residual(pred, X, stats)
    weighted = data_term + lam_phys * phys_term + lam_ode * ode_term
    return weighted, data_term, phys_term, ode_term

# 5. MC‐Dropout Prediction
def mc_predict(model, X, T=50):
    """Monte-Carlo Dropout prediction: mean and std over ``T`` stochastic passes.

    Dropout is forced on by switching the model to training mode while
    sampling. The caller's previous train/eval mode is restored afterwards
    (also on error), so the call has no lasting effect on the module's mode.

    Args:
        model: torch module containing dropout layers.
        X: input batch, already on the model's device.
        T: number of stochastic forward passes.

    Returns:
        (mean, std) over the T samples, each shaped like ``model(X)``. ``std``
        uses torch's default (unbiased) estimator, so ``T=1`` yields NaN.
    """
    was_training = model.training
    model.train()  # keep dropout on
    try:
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)])
    finally:
        # Previously the model was silently left in train mode.
        model.train(was_training)
    return samples.mean(0), samples.std(0)

# 6. Training Loop
def train(model, dl_tr, dl_va, stats,
          epochs=100, lr=1e-3, wd=1e-5, patience=10):
    """Fit ``model`` with the combined data + physics + ODE loss.

    Uses AdamW, ReduceLROnPlateau (factor 0.5, patience 5) on the validation
    loss, gradient-norm clipping at 1.0, and early stopping after
    ``patience`` epochs without a validation improvement larger than 1e-6.

    Args:
        model: network to train; moved to ``DEVICE``.
        dl_tr, dl_va: loaders yielding (X, Y) batches in normalized space.
        stats: normalization tensors forwarded to ``total_loss``.
        epochs, lr, wd, patience: optimisation hyper-parameters.

    Returns:
        (history, model). ``history`` maps 'train_loss'/'val_loss' to
        per-epoch losses averaged over samples. ``model`` holds the weights
        from the best validation epoch; if no epoch ever improved (e.g.
        ``epochs=0`` or only NaN losses), the final weights are kept.

    Side effects:
        Writes the best weights to ``best_model.pth`` whenever they improve.
    """
    model.to(DEVICE)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5, patience=5)
    best_val = float('inf')
    best_state = None  # in-memory copy of the best weights
    wait = 0
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(1, epochs + 1):
        # ——— Training ———
        model.train()
        running_tr = 0.0
        for Xb, Yb in dl_tr:
            # fresh leaf tensor that requires grad, needed by ode_residual
            Xb = Xb.to(DEVICE).clone().detach().requires_grad_(True)
            Yb = Yb.to(DEVICE)

            pred = model(Xb)
            loss, _, _, _ = total_loss(pred, Yb, Xb, stats)
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

            running_tr += loss.item() * Xb.size(0)

        train_loss = running_tr / len(dl_tr.dataset)
        history['train_loss'].append(train_loss)

        # ——— Validation ———
        # No optimizer step here, but autograd must stay enabled: the ODE
        # residual differentiates predictions w.r.t. the inputs.
        model.eval()
        running_va = 0.0
        for Xv, Yv in dl_va:
            Xv = Xv.to(DEVICE).clone().detach().requires_grad_(True)
            Yv = Yv.to(DEVICE)
            pred = model(Xv)
            loss, _, _, _ = total_loss(pred, Yv, Xv, stats)
            running_va += loss.item() * Xv.size(0)

        val_loss = running_va / len(dl_va.dataset)
        history['val_loss'].append(val_loss)

        sch.step(val_loss)
        print(f"Epoch {epoch:3d} | Train {train_loss:.4f} | Val {val_loss:.4f}")

        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # Snapshot in memory so restoring never depends on the file
            # (which may be missing or stale from an earlier run).
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, "best_model.pth")
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return history, model

# 7. Main Execution & Analysis
if __name__ == "__main__":
    ds = UndefinedPresenceDataset()
    n_val = int(0.2 * len(ds))
    ds_tr, ds_va = random_split(ds, [len(ds) - n_val, n_val])
    dl_tr = DataLoader(ds_tr, batch_size=128, shuffle=True)
    dl_va = DataLoader(ds_va, batch_size=256)

    stats = {
        'X_mean': torch.tensor(ds.X_mean, device=DEVICE),
        'X_std':  torch.tensor(ds.X_std,  device=DEVICE),
        'Y_mean': torch.tensor(ds.Y_mean, device=DEVICE),
        'Y_std':  torch.tensor(ds.Y_std,  device=DEVICE),
    }

    model = UndefinedPresenceAI().to(DEVICE)
    history, model = train(model, dl_tr, dl_va, stats)

    # Save losses
    np.savez("training_history.npz",
             train_loss=np.array(history['train_loss']),
             val_loss=np.array(history['val_loss']))

    # In‐distribution MC‐Dropout
    X_all = torch.from_numpy(ds.X).to(DEVICE)
    mean_pred, std_pred = mc_predict(model, X_all)

    # 8. Plotting
    os.makedirs("plots", exist_ok=True)

    # 8.1 Loss curves
    plt.figure()
    plt.plot(history['train_loss'], label="Train Loss")
    plt.plot(history['val_loss'],   label="Val Loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend()
    plt.title("Training & Validation Loss")
    plt.savefig("plots/loss_curves.png", dpi=150)

    # 8.2 OOD Uncertainty
    def sample_ood(n=2000):
        u = np.random.uniform(-2,2,(n,1))
        v = np.random.uniform(0,2,(n,1))
        w = np.random.uniform(-2,2,(n,1))
        x = np.random.uniform(0,5,(n,1))
        y = np.random.uniform(-1,1,(n,1))
        z = np.random.uniform(0,1,(n,1))
        X = np.hstack([u,v,w,x,y,z]).astype(np.float32)
        X_std = (X - ds.X_mean) / ds.X_std
        return torch.from_numpy(X_std).to(DEVICE)

    X_ood = sample_ood()
    _, std_ood = mc_predict(model, X_ood)
    plt.figure()
    sns.kdeplot(std_ood[:,0].cpu(), label="Presence STD")
    sns.kdeplot(std_ood[:,1].cpu(), label="Dissolution STD")
    sns.kdeplot(std_ood[:,2].cpu(), label="Transcendence STD")
    plt.title("OOD Uncertainty Distributions")
    plt.legend()
    plt.savefig("plots/ood_uncertainty.png", dpi=150)

    # 8.3 Reliability Diagram
    errors = (mean_pred - torch.from_numpy(ds.Y).to(DEVICE)).abs().cpu().numpy()
    stds   = std_pred.cpu().numpy()

    def reliability_plot(err, std, label):
        bins = np.linspace(std.min(), std.max(), 10)
        ids = np.digitize(std, bins) - 1
        avg_err, avg_std = [], []
        for i in range(len(bins)):
            mask = ids == i
            if mask.sum() > 0:
                avg_err.append(err[mask].mean())
                avg_std.append(std[mask].mean())
        plt.plot(avg_std, avg_err, '-o', label=label)

    plt.figure()
    reliability_plot(errors[:,0], stds[:,0], "Presence")
    reliability_plot(errors[:,1], stds[:,1], "Dissolution")
    reliability_plot(errors[:,2], stds[:,2], "Transcendence")
    plt.xlabel("Avg Pred STD"); plt.ylabel("Avg Abs Error")
    plt.title("Reliability Diagram")
    plt.legend()
    plt.savefig("plots/reliability.png", dpi=150)

    # 8.4 UMAP
    model.eval()
    with torch.no_grad():
        feats = model.net[0](X_all).cpu().numpy()
    emb = umap.UMAP(n_components=2, random_state=SEED).fit_transform(feats)
    raw = ds.X * ds.X_std + ds.X_mean
    u, v, w = raw[:,0], raw[:,1], raw[:,2]
    true_presence = np.sin(u)*v + np.cos(w)
    plt.figure(figsize=(6,5))
    sc = plt.scatter(emb[:,0], emb[:,1], c=true_presence, cmap="coolwarm", s=4)
    plt.colorbar(sc, label="True Presence")
    plt.title("UMAP of Hidden Features")
    plt.savefig("plots/umap_presence.png", dpi=150)

    # 8.5 Physics Residual Histogram
    with torch.no_grad():
        X_den = X_all * stats['X_std'] + stats['X_mean']
        u, v, w, x, y, z = X_den.T
        phys_pres = torch.sin(u)*v + torch.cos(w)
        phys_dis  = torch.exp(-x*y)
        phys_tr   = z*(phys_pres+phys_dis)
        Y_phys = torch.stack([phys_pres, phys_dis, phys_tr], dim=1).cpu().numpy()

    pred_den = mean_pred.cpu().numpy() * ds.Y_std + ds.Y_mean
    residuals = ((pred_den - Y_phys)**2).mean(axis=1)
    plt.figure()
    sns.histplot(residuals, bins=50, kde=True)
    plt.title("Physics Residual (MSE) Histogram")
    plt.xlabel("Residual"); plt.ylabel("Count")
    plt.savefig("plots/physics_residual_hist.png", dpi=150)

    print("Done. Plots saved in ./plots/")