<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_final_field_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# train_final_field_ai.py

"""
End-to-end pipeline for FinalFieldAI:
1. Synthetic “undefined intelligence field” dataset (6 inputs → 3 targets)
2. float32 normalization & dtype consistency
3. MLP with Dropout for MC‐Dropout uncertainty
4. Physics‐informed residual enforcing toy field law
5. Combined loss (MSE + physics residual)
6. Training loop: AdamW, ReduceLROnPlateau, gradient clipping, NaN checks, early stopping
7. Checkpointing & reload best model
8. Visualizations: loss curves, parity plot, uncertainty histogram
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# 1. Synthetic Undefined Intelligence Field Dataset
# -----------------------------------------------------------------------------
class FinalFieldDataset(Dataset):
    """Synthetic, standardized "undefined intelligence field" dataset.

    Six uniformly sampled inputs are mapped to three targets by a toy field
    law, a small amount of Gaussian noise is added to the targets, and both
    inputs and targets are standardized column-wise.

    Args:
        n_samples: Number of rows to generate.
        seed: Seed for the private random stream used to draw the samples.

    The instance exposes:
        X, Y            -- standardized float32 arrays of shape
                           (n_samples, 6) and (n_samples, 3).
        X_mean, X_std   -- float64 per-column statistics of the raw inputs.
        Y_mean, Y_std   -- float64 per-column statistics of the noisy targets.
    """

    def __init__(self, n_samples=5000, seed=123):
        # A private RandomState yields the same stream np.random.seed(seed)
        # would, but does not reseed the process-wide NumPy generator that
        # unrelated code may depend on.
        rng = np.random.RandomState(seed)

        # Inputs:
        # x0 ∈ [−1,1], x1 ∈ [0,2], x2 ∈ [−2,2], x3 ∈ [0,5], x4 ∈ [−1,1], x5 ∈ [0,1]
        bounds = [(-1, 1), (0, 2), (-2, 2), (0, 5), (-1, 1), (0, 1)]
        # Columns are drawn one after another so the sample stream order is
        # x0, x1, ..., x5.
        x0, x1, x2, x3, x4, x5 = [
            rng.uniform(lo, hi, (n_samples, 1)) for lo, hi in bounds
        ]
        X_raw = np.hstack([x0, x1, x2, x3, x4, x5]).astype(np.float64)

        # Targets (toy "field law"):
        # UIF = sin(x0) * x1
        # UIC = tanh(x2) + 0.1 * x3
        # UFS = UIF * UIC + x4 * x5
        UIF = np.sin(x0) * x1
        UIC = np.tanh(x2) + 0.1 * x3
        UFS = UIF * UIC + x4 * x5
        Y_raw = np.hstack([UIF, UIC, UFS]).astype(np.float64)

        # Add small noise, scaled to 1% of each target column's spread.
        Y_raw += 0.01 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Compute normalization stats (float64); the epsilon guards against
        # division by zero for a constant column.
        self.X_mean, self.X_std = X_raw.mean(0), X_raw.std(0) + 1e-8
        self.Y_mean, self.Y_std = Y_raw.mean(0), Y_raw.std(0) + 1e-8

        # Standardize to float32
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair stored at ``idx``."""
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])


# -----------------------------------------------------------------------------
# 2. FinalFieldAI Model with Dropout
# -----------------------------------------------------------------------------
class FinalFieldAI(nn.Module):
    """Single-hidden-layer MLP with dropout between hidden and output layers.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of predicted targets.
        p_drop: Dropout probability applied to the hidden activations.
    """

    def __init__(self, input_dim=6, hidden_dim=32, output_dim=3, p_drop=0.1):
        super().__init__()
        # Linear -> ReLU -> Dropout -> Linear, kept positional so the
        # parameters live under net.0 and net.3.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of inputs through the network."""
        out = self.net(x)
        return out


# -----------------------------------------------------------------------------
# 3. Physics-Informed Residual Loss
# -----------------------------------------------------------------------------
def physics_residual(pred, X, stats):
    """Mean-squared mismatch between ``pred`` and the toy field law.

    The standardized inputs are mapped back to raw units, the field law is
    evaluated on them, and the resulting targets are standardized again
    before being compared with ``pred``.

    Args:
        pred: Model output in standardized target units.
        X: Standardized inputs; transposed and unpacked into six columns.
        stats: Mapping with ``'X_mean'``, ``'X_std'``, ``'Y_mean'`` and
            ``'Y_std'`` entries that broadcast against ``X`` / the targets.

    Returns:
        Scalar tensor holding the mean squared error.
    """
    # Undo the input standardization.
    raw = stats['X_mean'] + stats['X_std'] * X
    a, b, c, d, e, f = raw.t()

    # Toy field law evaluated on raw inputs.
    field = torch.sin(a) * b
    coupling = torch.tanh(c) + 0.1 * d
    strength = field * coupling + e * f

    # Bring the law's targets into the same standardized space as pred.
    law = torch.stack((field, coupling, strength), dim=1)
    law_n = (law - stats['Y_mean']) / stats['Y_std']
    return nn.functional.mse_loss(pred, law_n)


# -----------------------------------------------------------------------------
# 4. Combined Loss Function
# -----------------------------------------------------------------------------
def total_loss(pred, true, X, stats, lam=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: Model output.
        true: Reference targets compared against ``pred``.
        X: Inputs forwarded to ``physics_residual``.
        stats: Normalization statistics forwarded to ``physics_residual``.
        lam: Weight applied to the physics term.

    Returns:
        Tuple ``(combined, data_term, physics_term)`` where
        ``combined = data_term + lam * physics_term``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term


# -----------------------------------------------------------------------------
# 5. MC-Dropout Uncertainty Estimation
# -----------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Estimate predictive mean and spread with Monte-Carlo dropout.

    The model is run ``T`` times in training mode (so dropout stays active)
    without gradient tracking, and the per-element mean and standard
    deviation across those passes are returned.

    The training/eval mode of every submodule is restored before returning,
    so calling this on a model that is in eval mode does not leave it
    stochastic for later inference.

    Args:
        model: Module to sample from.
        X: Input batch passed unchanged to ``model`` on every pass.
        T: Number of stochastic forward passes.

    Returns:
        Tuple ``(mean, std)`` reduced over the ``T`` passes.
    """
    # Snapshot per-module flags; model.train(flag) alone would flatten any
    # mixed train/eval configuration the caller had set up.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.train()  # keep dropout active
    try:
        with torch.no_grad():
            stack = torch.stack([model(X) for _ in range(T)], dim=0)  # [T, batch, out]
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    return stack.mean(dim=0), stack.std(dim=0)


# -----------------------------------------------------------------------------
# 6. Training Loop with Checkpointing & Early Stopping
# -----------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lam=1.0, epochs=100, patience=10,
          ckpt_path="best_final_field_ai.pth"):
    """Fit ``model`` with the combined loss, checkpointing and early stopping.

    Each epoch runs one optimisation pass over ``train_loader`` (AdamW,
    gradient-norm clipping at 1.0) followed by an evaluation pass over
    ``val_loader``. The learning rate is halved by ReduceLROnPlateau when the
    validation loss stalls. Whenever the validation loss improves by more
    than 1e-6 the weights are written to ``ckpt_path`` and kept in memory;
    training stops after ``patience`` consecutive epochs without improvement.
    On return the model holds the best weights seen during *this* call.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Iterable of ``(X, Y)`` training batches.
        val_loader: Iterable of ``(X, Y)`` validation batches.
        stats: Normalization statistics forwarded to ``total_loss``.
        device: Device the model and batches are moved to.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: Weight of the physics term in ``total_loss``.
        epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping.
        ckpt_path: File the best state dict is saved to.

    Returns:
        Dict with ``'train'`` and ``'val'`` lists of per-epoch mean losses.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    # Best weights of this run are held in memory so the final restore never
    # depends on (possibly stale or missing) file contents.
    best_state = None
    history = {'train': [], 'val': []}

    for epoch in range(1, epochs+1):
        # Training phase
        model.train()
        run_tr, seen_tr = 0.0, 0
        for Xb, Yb in train_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            pred = model(Xb)
            loss, _, _ = total_loss(pred, Yb, Xb, stats, lam)
            # NaN/Inf check: stepping on a non-finite loss would poison the
            # weights, so the batch is skipped instead.
            if not torch.isfinite(loss):
                print(f"Epoch {epoch:03d} | non-finite training loss, batch skipped")
                continue
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            run_tr += loss.item() * Xb.size(0)
            seen_tr += Xb.size(0)
        # Average over the samples that actually contributed, which stays
        # correct with drop_last loaders or skipped batches.
        tr_loss = run_tr / seen_tr if seen_tr else float('nan')

        # Validation phase
        model.eval()
        run_val, seen_val = 0.0, 0
        with torch.no_grad():
            for Xv, Yv in val_loader:
                Xv, Yv = Xv.to(device), Yv.to(device)
                pred = model(Xv)
                l, _, _ = total_loss(pred, Yv, Xv, stats, lam)
                run_val += l.item() * Xv.size(0)
                seen_val += Xv.size(0)
        val_loss = run_val / seen_val if seen_val else float('nan')

        scheduler.step(val_loss)
        history['train'].append(tr_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {tr_loss:.4e} | Val {val_loss:.4e}")

        # Checkpoint & early stopping (a NaN val_loss never compares as
        # better, so it counts against patience).
        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            best_state = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}
            torch.save(model.state_dict(), ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    # Load best
    if best_state is None:
        print("No validation improvement recorded; keeping final weights.")
    else:
        model.load_state_dict(best_state)
    return history


# -----------------------------------------------------------------------------
# 7. Visualization Helpers
# -----------------------------------------------------------------------------
def plot_history(history):
    """Draw the training and validation loss curves on one figure.

    Args:
        history: Mapping with ``'train'`` and ``'val'`` sequences, plotted
            against their index (labelled "Epoch").
    """
    plt.figure()
    for key, label in (('train', 'Train Loss'), ('val', 'Val Loss')):
        plt.plot(history[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_parity(model, dataset, stats, device):
    """Scatter de-standardized predictions against de-standardized targets.

    Predictions are computed in eval mode so dropout cannot add noise to the
    plot, regardless of the mode the caller left the model in; the previous
    mode of every submodule is restored afterwards. All target columns are
    flattened into a single scatter, with a dashed identity line spanning the
    range of the true values.

    Args:
        model: Trained network producing standardized outputs.
        dataset: Object exposing ``X``, ``Y``, ``Y_mean`` and ``Y_std``.
        stats: Unused; kept so all plotting helpers share one signature.
        device: Device on which inference runs.
    """
    X_all = torch.from_numpy(dataset.X).to(device)

    # Snapshot per-module flags so a mixed train/eval setup survives.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            pred_n = model(X_all).cpu().numpy()
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    # Undo target standardization for both axes.
    Y_true = dataset.Y * dataset.Y_std + dataset.Y_mean
    Y_pred = pred_n * dataset.Y_std + dataset.Y_mean

    plt.figure(figsize=(6,6))
    plt.scatter(Y_true.ravel(), Y_pred.ravel(), s=4, alpha=0.5)
    mn, mx = Y_true.min(), Y_true.max()
    plt.plot([mn, mx], [mn, mx], 'r--')  # identity line: perfect prediction
    plt.xlabel("True"); plt.ylabel("Predicted"); plt.title("Parity Plot")
    plt.show()

def plot_uncertainty_hist(model, dataset, stats, device):
    """Histogram the MC-dropout standard deviation of the first output.

    Runs 50 stochastic passes over every row of ``dataset.X`` and plots the
    distribution of the resulting standard deviation for output column 0.
    The values are taken straight from the model output and are not rescaled
    by ``dataset.Y_std``.

    Args:
        model: Network sampled via ``mc_dropout_predict``.
        dataset: Object exposing the input array ``X``.
        stats: Unused.
        device: Device on which inference runs.
    """
    inputs = torch.from_numpy(dataset.X).to(device)
    _, spread = mc_dropout_predict(model, inputs, T=50)
    first_col_spread = spread[:, 0].detach().cpu().numpy()

    plt.figure()
    plt.hist(first_col_spread, bins=30, color='skyblue')
    plt.xlabel("Std Dev of UIF")
    plt.title("Uncertainty Distribution (UIF)")
    plt.show()


# -----------------------------------------------------------------------------
# 8. Main Execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One seed drives data generation, the train/val split and the weight
    # initialisation / dropout masks, so repeated runs are comparable.
    seed = 123
    torch.manual_seed(seed)

    # Prepare dataset & stats (stats are float32 tensors on the training
    # device so they broadcast against batches inside the loss).
    ds = FinalFieldDataset(n_samples=5000, seed=seed)
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.float32, device=device),
        'X_std' : torch.tensor(ds.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.float32, device=device),
        'Y_std' : torch.tensor(ds.Y_std,  dtype=torch.float32, device=device),
    }

    # Split & loaders: 80/20 split with a dedicated, seeded generator so the
    # validation set is the same on every run.
    n_val = int(0.2 * len(ds))
    split_gen = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(ds, [len(ds)-n_val, n_val],
                                    generator=split_gen)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False)

    # Build, train, evaluate
    model   = FinalFieldAI().to(device)
    history = train(model, train_loader, val_loader, stats, device)

    # Visualize
    plot_history(history)
    plot_parity(model, ds, stats, device)
    plot_uncertainty_hist(model, ds, stats, device)