<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_null_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_null_ai.py

End-to-end pipeline for NullAI:
1. Synthetic dataset (6 inputs → 3 targets)
2. Float32 normalization
3. MLP with dropout for uncertainty
4. Physics-informed residual enforcing toy nullification laws
5. Combined loss (MSE + residual)
6. Training loop: AdamW, lr scheduler, grad clipping, early stopping
7. Checkpointing & reload best weights
8. Visualizations: loss plot, parity plot, uncertainty histogram
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# 1. Synthetic Dataset
# -----------------------------------------------------------------------------
class NullDataset(Dataset):
    """Synthetic regression dataset for NullAI: 6 input features -> 3 targets.

    Features ``U, N, O, P, Q, R`` are drawn uniformly. Targets follow the toy
    laws ``C = U*N / (|O| + eps)``, ``E = P*sin(Q)`` and ``T = R*(C + E)``,
    plus Gaussian noise scaled to 1% of each target column's standard
    deviation. Inputs and targets are z-score normalised in float32. The
    statistics are kept on the instance (``X_mean``, ``X_std``, ``Y_mean``,
    ``Y_std``) so callers can denormalise.

    Args:
        n_samples: number of samples to generate.
        seed: seed for a private ``np.random.RandomState``; the global NumPy
            RNG is left untouched.
    """

    def __init__(self, n_samples=5000, seed=2025):
        # A private legacy RandomState seeded with ``seed`` yields exactly the
        # same stream as ``np.random.seed(seed)`` followed by ``np.random.*``
        # calls, but does not silently reset the global RNG for other code.
        rng = np.random.RandomState(seed)
        eps = 1e-6

        # features: U, N, O, P, Q, R  (each of shape (n_samples, 1))
        U = rng.uniform(0.0, 1.0, (n_samples, 1))
        N = rng.uniform(0.0, 5.0, (n_samples, 1))
        O = rng.uniform(-2.0, 2.0, (n_samples, 1))
        P = rng.uniform(0.1, 10.0, (n_samples, 1))
        Q = rng.uniform(-1.0, 1.0, (n_samples, 1))
        R = rng.uniform(0.0, 2.0, (n_samples, 1))
        X_raw = np.hstack([U, N, O, P, Q, R]).astype(np.float32)

        # Toy ground truth, computed in float64 and then cast. C is
        # heavy-tailed: |O| close to 0 makes it very large.
        C = U * N / (np.abs(O) + eps)
        E = P * np.sin(Q)
        T = R * (C + E)
        Y_raw = np.hstack([C, E, T]).astype(np.float32)
        # Per-column noise at 1% of that column's spread.
        Y_raw += 0.01 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape).astype(np.float32)

        # Normalisation statistics in float32; eps guards against a zero std.
        self.X_mean = X_raw.mean(0).astype(np.float32)
        self.X_std  = X_raw.std(0).astype(np.float32) + eps
        self.Y_mean = Y_raw.mean(0).astype(np.float32)
        self.Y_std  = Y_raw.std(0).astype(np.float32) + eps

        # Normalised float32 arrays served by __getitem__.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        # Returns (x, y): a 6-element and a 3-element float32 tensor that
        # share memory with the underlying NumPy rows.
        return (
            torch.from_numpy(self.X[i]),
            torch.from_numpy(self.Y[i])
        )

# -----------------------------------------------------------------------------
# 2. Model with Dropout
# -----------------------------------------------------------------------------
class NullAI(nn.Module):
    """One-hidden-layer MLP with dropout, used for MC-dropout uncertainty.

    Layout: Linear(input_dim -> hidden_dim) -> ReLU -> Dropout(p_drop)
    -> Linear(hidden_dim -> output_dim).
    """

    def __init__(self, input_dim=6, hidden_dim=32, output_dim=3, p_drop=0.1):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),   # net.0
            nn.ReLU(inplace=True),              # net.1
            nn.Dropout(p_drop),                 # net.2
            nn.Linear(hidden_dim, output_dim),  # net.3
        ]
        # The attribute name and layer order fix the state_dict keys
        # (net.0.*, net.3.*) that saved checkpoints depend on.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (..., input_dim) tensor to (..., output_dim) predictions."""
        return self.net(x)

# -----------------------------------------------------------------------------
# 3. Physics-Informed Residual
# -----------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Mean-squared deviation of ``pred`` from the analytic toy laws.

    The normalised inputs ``X`` are mapped back to physical units using
    ``stats['X_std']`` / ``stats['X_mean']``, and the closed-form targets
    C = U*N/(|O|+eps), E = P*sin(Q), T = R*(C+E) are evaluated. These are
    re-normalised with ``stats['Y_mean']`` / ``stats['Y_std']`` and compared
    against ``pred``.

    Args:
        pred: normalised model outputs, one column per target (C, E, T).
        X: normalised inputs whose 6 columns are U, N, O, P, Q, R.
        stats: dict of tensors with keys 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor: mean of the element-wise squared differences.
    """
    eps = 1e-6
    raw = X * stats['X_std'] + stats['X_mean']
    U, N, O, P, Q, R = raw.unbind(dim=1)

    law_c = U * N / (O.abs() + eps)
    law_e = P * torch.sin(Q)
    law_t = R * (law_c + law_e)

    analytic = torch.stack([law_c, law_e, law_t], dim=1)
    target = (analytic - stats['Y_mean']) / stats['Y_std']
    return nn.functional.mse_loss(pred, target)

# -----------------------------------------------------------------------------
# 4. Combined Loss
# -----------------------------------------------------------------------------
def total_loss(pred, true, X, stats, lam=1.0):
    """Combined objective: data-fit MSE plus ``lam`` times the physics residual.

    Args:
        pred: normalised model outputs.
        true: normalised (noisy) targets from the dataset.
        X: normalised inputs, forwarded to ``physics_residual``.
        stats: normalisation statistics, forwarded to ``physics_residual``.
        lam: weight of the physics term.

    Returns:
        Tuple ``(combined, data_mse, physics)`` of scalar tensors; ``combined``
        is the one to back-propagate.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# -----------------------------------------------------------------------------
# 5. MC-Dropout Uncertainty
# -----------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Monte-Carlo dropout prediction: mean and spread over ``T`` stochastic passes.

    Only dropout layers are switched to training mode; every other module
    (e.g. BatchNorm) runs in eval mode, so no running statistics are
    modified. The training/eval flag of every module is restored on exit,
    even if a forward pass raises.

    Args:
        model: network containing dropout layers.
        X: input batch, already on the model's device.
        T: number of stochastic forward passes (must be >= 1). With T == 1
            the std is NaN, because ``torch.std`` uses the unbiased estimator.

    Returns:
        ``(mean, std)`` over the T passes, each shaped like ``model(X)``.

    Raises:
        ValueError: if ``T < 1``.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    # Remember each module's own flag so mixed modes are restored exactly.
    saved_modes = [(m, m.training) for m in model.modules()]
    try:
        model.eval()
        for m in model.modules():
            # _DropoutNd is the common base of Dropout, Dropout1d/2d/3d,
            # AlphaDropout and FeatureAlphaDropout.
            if isinstance(m, nn.modules.dropout._DropoutNd):
                m.train()
        with torch.no_grad():
            stacked = torch.stack([model(X) for _ in range(T)])
    finally:
        for m, mode in saved_modes:
            m.training = mode
    return stacked.mean(0), stacked.std(0)

# -----------------------------------------------------------------------------
# 6. Training Loop
# -----------------------------------------------------------------------------
def _train_one_epoch(model, loader, optimizer, stats, device, lam):
    """Run one optimisation pass over ``loader``; return the mean per-sample combined loss."""
    model.train()
    running = 0.0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)
        pred = model(Xb)
        loss, _, _ = total_loss(pred, Yb, Xb, stats, lam)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Weight by batch size so a short final batch is averaged correctly.
        running += loss.item() * Xb.size(0)
    return running / len(loader.dataset)


def _evaluate(model, loader, stats, device, lam):
    """Return the mean per-sample combined loss over ``loader`` without updating weights."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for Xv, Yv in loader:
            Xv, Yv = Xv.to(device), Yv.to(device)
            pred = model(Xv)
            loss, _, _ = total_loss(pred, Yv, Xv, stats, lam)
            running += loss.item() * Xv.size(0)
    return running / len(loader.dataset)


def train(model, tr_dl, va_dl, stats, device,
          lr=1e-3, weight_decay=1e-5, lam=1.0,
          epochs=100, patience=10, ckpt_path="best_null_ai.pth"):
    """Train ``model`` with AdamW, ReduceLROnPlateau, gradient clipping and early stopping.

    Each time the validation loss improves by more than 1e-6, the weights are
    kept in memory and also written to ``ckpt_path``. On exit the best
    in-memory weights are loaded back into ``model``. If validation never
    improved (e.g. the loss was NaN), the final weights are kept instead of
    reading a possibly stale or missing checkpoint file.

    Args:
        model: network to train (moved to ``device``).
        tr_dl, va_dl: training / validation DataLoaders yielding (X, Y).
        stats: normalisation statistics passed to ``total_loss``.
        device: torch device to train on.
        lr, weight_decay: AdamW hyper-parameters.
        lam: weight of the physics-residual term.
        epochs: maximum number of epochs.
        patience: epochs without improvement before stopping early.
        ckpt_path: where the best state_dict is saved.

    Returns:
        dict with 'train' and 'val' lists of per-epoch mean losses.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

    best_val = float('inf')
    best_state = None
    wait = 0
    history = {'train': [], 'val': []}

    for epoch in range(1, epochs + 1):
        running_train = _train_one_epoch(model, tr_dl, optimizer, stats, device, lam)
        running_val = _evaluate(model, va_dl, stats, device, lam)

        scheduler.step(running_val)
        history['train'].append(running_train)
        history['val'].append(running_val)
        print(f"Epoch {epoch:03d} | Train {running_train:.4e} | Val {running_val:.4e}")

        # Checkpoint (a NaN val never compares as an improvement)
        if running_val < best_val - 1e-6:
            best_val = running_val
            wait = 0
            # Clone so later optimizer steps don't mutate the saved tensors.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    # Load best from memory: avoids a crash or a stale file from a previous run.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: validation loss never improved; keeping final weights.")
    return history

# -----------------------------------------------------------------------------
# 7. Visualizations
# -----------------------------------------------------------------------------
def plot_history(history):
    """Draw per-epoch training and validation loss curves in a new figure.

    Args:
        history: dict with 'train' and 'val' sequences of losses, as returned
            by ``train``.
    """
    plt.figure()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        plt.plot(history[key], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def plot_parity(model, dataset, stats, device):
    """Scatter denormalised predictions against denormalised targets.

    All samples of ``dataset`` and all target columns are pooled into one
    scatter, with a red dashed y = x reference line spanning the target range.

    Args:
        model: trained network; it is left in eval mode afterwards.
        dataset: object exposing normalised ``X``/``Y`` arrays and
            ``Y_mean``/``Y_std`` (e.g. NullDataset).
        stats: unused here; kept for signature parity with the other helpers.
        device: device on which to run the model.
    """
    model.eval()
    with torch.no_grad():
        inputs = torch.from_numpy(dataset.X).to(device)
        preds_norm = model(inputs).cpu().numpy()

    scale, shift = dataset.Y_std, dataset.Y_mean
    truth = dataset.Y * scale + shift
    preds = preds_norm * scale + shift
    lo, hi = truth.min(), truth.max()

    plt.figure(figsize=(6, 6))
    plt.scatter(truth.ravel(), preds.ravel(), s=4, alpha=0.5)
    plt.plot([lo, hi], [lo, hi], 'r--')
    plt.xlabel("True")
    plt.ylabel("Pred")
    plt.show()

def plot_uncertainty(model, dataset, stats, device, target=0, T=50):
    """Histogram of MC-dropout predictive std for one output column.

    Args:
        model: trained network containing dropout layers.
        dataset: object exposing a normalised ``X`` NumPy array (e.g. NullDataset).
        stats: unused here; kept for signature parity with the other helpers.
        device: device on which to run the model.
        target: output column to plot; for NullDataset, 0 = C, 1 = E, 2 = T.
            The default of 0 matches the previous behaviour.
        T: number of stochastic forward passes for MC dropout.

    The std is in normalised target units, because the model predicts
    normalised targets.
    """
    X = torch.from_numpy(dataset.X).to(device)
    _, std = mc_dropout_predict(model, X, T=T)
    u = std[:, target].cpu().numpy()
    # Open a fresh figure, as the sibling plot helpers do, so the histogram
    # never draws onto a figure left open by earlier code.
    plt.figure()
    plt.hist(u, bins=30, color='teal')
    plt.xlabel('Prediction Stddev')
    plt.show()

# -----------------------------------------------------------------------------
# 8. Main Execution
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    seed = 2025
    # The synthetic data is seeded. Seed torch too, so that weight init,
    # dropout masks, loader shuffling and the split are reproducible (on CPU;
    # some CUDA kernels may still be non-deterministic).
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare dataset + normalisation stats as float32 tensors on the device
    ds = NullDataset(n_samples=5000, seed=seed)
    stats = {
        key: torch.tensor(getattr(ds, key), dtype=torch.float32, device=device)
        for key in ('X_mean', 'X_std', 'Y_mean', 'Y_std')
    }

    # 80/20 split using a dedicated seeded generator
    val_count = int(len(ds) * 0.2)
    split_gen = torch.Generator().manual_seed(seed)
    tr_ds, va_ds = random_split(ds, [len(ds) - val_count, val_count], generator=split_gen)
    tr_dl = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_dl = DataLoader(va_ds, batch_size=256, shuffle=False)

    # Model & training
    model = NullAI()
    hist = train(model, tr_dl, va_dl, stats, device)

    # Plots
    plot_history(hist)
    plot_parity(model, ds, stats, device)
    plot_uncertainty(model, ds, stats, device)