<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_stellar_engine_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_stellar_engine.py

Physics‐informed AI pipeline for Stellar Engine Optimization:

1. Synthetic dataset of 6 stellar engine parameters → 3 efficiency & stability metrics
2. Overflow‐safe normalization
3. Physics‐informed residual enforcing analytic extraction, stability, and stress laws
4. MLP with LayerNorm & Dropout for robustness
5. MC‐Dropout inference for uncertainty quantification
6. Training loop with AdamW, ReduceLROnPlateau, gradient clipping, early stopping
7. Visualizations: training curves, scatter plots, uncertainty heatmap
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Stellar Engine Dataset
# ------------------------------------------------------------------------------
class StellarEngineDataset(Dataset):
    """Synthetic stellar-engine samples with standardized inputs and targets.

    Six input parameters are drawn uniformly and three targets are computed
    from closed-form expressions, perturbed by Gaussian noise whose scale is
    2% of each target column's standard deviation.

    Attributes:
        X: ``(n_samples, 6)`` float32 array of standardized inputs, columns
            ``[M_s, L_rel, A, Rf, d, wobble]``.
        Y: ``(n_samples, 3)`` float32 array of standardized targets, columns
            ``[extraction_efficiency, stability_factor, thermal_stress]``.
        X_mean, X_std: float64 per-column statistics of the raw inputs
            (``X_std`` includes a ``1e-6`` floor).
        Y_mean, Y_std: float64 per-column statistics of the noisy raw targets
            (``Y_std`` includes a ``1e-6`` floor).

    Args:
        n_samples: Number of samples to generate.
        seed: Seed for the dataset's private random generator.
    """

    def __init__(self, n_samples=8000, seed=0):
        # A private legacy RandomState produces exactly the stream that
        # ``np.random.seed(seed)`` plus module-level draws would, but without
        # resetting the process-wide NumPy RNG for everyone else.
        rng = np.random.RandomState(seed)
        col = (n_samples, 1)

        # Inputs (draw order matters for reproducibility):
        # M_s ∈ [0.1,10] (solar masses),
        # L_rel ∈ [0.5,2] (luminosity),
        # A ∈ [1e3,1e5] (mirror area km²),
        # Rf ∈ [0.5,1.0] (reflectivity),
        # d ∈ [0.1,5.0] (AU),
        # wobble ∈ [0.0,0.1]
        M_s    = rng.uniform(0.1, 10.0, col)
        L_rel  = rng.uniform(0.5, 2.0, col)
        A      = rng.uniform(1e3, 1e5, col)
        Rf     = rng.uniform(0.5, 1.0, col)
        d      = rng.uniform(0.1, 5.0, col)
        wobble = rng.uniform(0.0, 0.1, col)

        X_raw = np.hstack([M_s, L_rel, A, Rf, d, wobble])

        # Analytical targets:
        # 1. extraction_efficiency ≈ Rf * L_rel * A / (4π d²)
        ext = Rf * L_rel * A / (4 * np.pi * d**2)
        # 2. stability_factor ≈ Rf * A / (M_s * d)
        stab = Rf * A / (M_s * d)
        # 3. thermal_stress ≈ L_rel * (1 - Rf) / A
        stress = L_rel * (1 - Rf) / A

        Y_raw = np.hstack([ext, stab, stress]).astype(np.float64)
        # Additive noise scaled per target column (2% of that column's std).
        Y_raw += 0.02 * np.std(Y_raw, axis=0) * rng.randn(*Y_raw.shape)

        # Statistics are kept in float64; the small floor on the std avoids
        # division by zero for a constant column.
        self.X_mean = X_raw.mean(0)
        self.Y_mean = Y_raw.mean(0)
        with np.errstate(over='ignore'):
            self.X_std = X_raw.std(0) + 1e-6
            self.Y_std = Y_raw.std(0) + 1e-6

        # Standardize in float64, then store as float32 for the model.
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair at ``idx``."""
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class StellarEngineAI(nn.Module):
    """Fully connected regressor.

    Each hidden width contributes a Linear → LayerNorm → ReLU → Dropout stage;
    a final Linear layer maps the last hidden width to ``output_dim``. All
    layers live in ``self.net`` (an ``nn.Sequential``) in that order.

    Args:
        input_dim: Number of input features.
        hidden_dims: Iterable of hidden-layer widths.
        output_dim: Number of regression outputs.
        p_drop: Dropout probability used after every hidden stage.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages = []
        # Walk consecutive (fan_in, fan_out) pairs to build the hidden stack.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        head = nn.Linear(widths[-1], output_dim)
        self.net = nn.Sequential(*stages, head)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        out = self.net(x)
        return out

# ------------------------------------------------------------------------------
# 3. Physics‐Informed Residual
# ------------------------------------------------------------------------------
def physics_residual(pred, inp, stats):
    """Sum of MSEs between de-standardized predictions and the analytic laws.

    Both ``inp`` and ``pred`` are mapped back to raw units using ``stats``,
    the three closed-form targets are recomputed from the raw inputs, and the
    mean squared error of each target channel is summed.

    Args:
        pred: ``(..., 3)`` standardized predictions
            ``[extraction, stability, stress]``.
        inp: ``(..., 6)`` standardized inputs
            ``[M_s, L_rel, A, Rf, d, wobble]``.
        stats: Mapping with ``'X_mean'``, ``'X_std'``, ``'Y_mean'``,
            ``'Y_std'`` (tensors or array-likes broadcastable over the last
            axis).

    Returns:
        Scalar tensor: ``mse(ext) + mse(stab) + mse(stress)`` in raw units.
    """
    def _like(value, ref):
        # Align the statistics with the activations. Without this, float64
        # stats silently promote the whole loss to float64, and stats living
        # on another device raise at the first arithmetic op.
        return torch.as_tensor(value, dtype=ref.dtype, device=ref.device)

    # Denormalize
    x_den = inp * _like(stats['X_std'], inp) + _like(stats['X_mean'], inp)
    y_den = pred * _like(stats['Y_std'], pred) + _like(stats['Y_mean'], pred)

    ext_pred, stab_pred, stress_pred = y_den.unbind(dim=-1)
    # wobble (last column) does not enter any of the analytic laws.
    M_s, L_rel, A, Rf, d, _ = x_den.unbind(dim=-1)

    ext_true = Rf * L_rel * A / (4 * torch.pi * d**2)
    stab_true = Rf * A / (M_s * d)
    stress_true = L_rel * (1 - Rf) / A

    # NOTE(review): the residual is in raw units, so the channel with the
    # largest magnitude dominates the sum — confirm this weighting is intended.
    mse = nn.functional.mse_loss
    return (
        mse(ext_pred, ext_true)
        + mse(stab_pred, stab_true)
        + mse(stress_pred, stress_true)
    )

def total_loss(pred, true, inp, stats, λ=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: Model predictions.
        true: Targets compared against ``pred`` with mean squared error.
        inp: Model inputs, forwarded to ``physics_residual``.
        stats: Normalization statistics, forwarded to ``physics_residual``.
        λ: Weight applied to the physics residual.

    Returns:
        Tuple ``(mse + λ * phys, mse, phys)``.
    """
    criterion = nn.MSELoss()
    data_term = criterion(pred, true)
    physics_term = physics_residual(pred, inp, stats)
    combined = data_term + λ * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC‐Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, T=50):
    """Estimate predictive mean and spread with Monte-Carlo dropout.

    The model is switched to training mode so dropout stays stochastic, run
    ``T`` times on ``x`` without gradient tracking, and then every submodule
    is returned to the train/eval mode it had on entry.

    Args:
        model: Module to sample from.
        x: Input batch passed to ``model`` unchanged on each pass.
        T: Number of stochastic forward passes (must be >= 1).

    Returns:
        Tuple ``(mean, std)`` over the ``T`` passes. ``std`` uses
        ``torch.Tensor.std`` defaults, so ``T == 1`` yields NaNs.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    # Remember each submodule's mode so the caller's model is left untouched.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(x) for _ in range(T)], dim=0)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    return samples.mean(0), samples.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def train(model, train_ld, val_ld, stats, device,
          lr=1e-3, wd=1e-5, λ=1.0, max_epochs=150, patience=10,
          ckpt_path="best_stellar_ai.pth"):
    """Fit ``model`` with AdamW, LR-on-plateau scheduling and early stopping.

    Each epoch runs one optimization pass over ``train_ld`` (gradient norm
    clipped to 1.0) and one evaluation pass over ``val_ld``. Losses are
    sample-weighted averages of ``total_loss``. Whenever the validation loss
    improves by more than ``1e-6`` the weights are snapshotted; training
    stops after ``patience`` consecutive epochs without improvement. On exit
    the best snapshot is restored into ``model``.

    Args:
        model: Network to train (moved to ``device`` in place).
        train_ld: Loader yielding ``(inputs, targets)`` training batches.
        val_ld: Loader yielding ``(inputs, targets)`` validation batches.
        stats: Normalization statistics forwarded to ``total_loss``.
        device: Device on which batches and the model are placed.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        λ: Physics-residual weight forwarded to ``total_loss``.
        max_epochs: Upper bound on the number of epochs.
        patience: Early-stopping patience in epochs.
        ckpt_path: File the best weights are written to; a falsy value
            disables writing the checkpoint file.

    Returns:
        Dict with per-epoch loss lists under ``'train'`` and ``'val'``.
    """
    model.to(device)
    opt   = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5, patience=5)
    best_val, wait = float('inf'), 0
    # The best weights are held in memory so the final restore never depends
    # on a file existing (or on a stale file left behind by an earlier run).
    best_state = None
    history = {'train': [], 'val': []}

    for epoch in range(1, max_epochs+1):
        # Training
        model.train()
        tr_loss = 0.0
        for xb, yb in train_ld:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, _, _ = total_loss(pred, yb, xb, stats, λ)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Weight by batch size so a short final batch is averaged correctly.
            tr_loss += loss.item() * xb.size(0)
        tr_loss /= len(train_ld.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_ld:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb)
                loss, _, _ = total_loss(pred, yb, xb, stats, λ)
                val_loss += loss.item() * xb.size(0)
        val_loss /= len(val_ld.dataset)

        sched.step(val_loss)
        history['train'].append(tr_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train {tr_loss:.4e} | Val {val_loss:.4e}")

        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # state_dict() tensors alias the live parameters, so clone them.
            best_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
            if ckpt_path:
                torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # If no epoch ever improved (e.g. NaN losses or max_epochs == 0) there is
    # nothing to restore; the model keeps its current weights.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(history):
    """Draw the per-epoch training and validation losses on a new figure.

    Args:
        history: Mapping with sequences under ``'train'`` and ``'val'``.
    """
    plt.figure()
    for key, label in (('train', 'Train'), ('val', 'Val')):
        plt.plot(history[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Curve")
    plt.show()

def plot_scatter(y_true, y_pred, name):
    """Scatter predictions against targets with a dashed identity line.

    Args:
        y_true: Array of reference values (must provide ``min``/``max``).
        y_pred: Array of predicted values, same length as ``y_true``.
        name: Quantity name used in the axis labels and the title.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=5, alpha=0.6)

    # The identity line spans the range of the reference values.
    lo = y_true.min()
    hi = y_true.max()
    plt.plot([lo, hi], [lo, hi], 'r--')

    plt.xlabel(f"True {name}")
    plt.ylabel(f"Pred {name}")
    plt.title(name)
    plt.show()

def plot_uncertainty_heatmap(model, stats_np, device):
    """Plot the MC-dropout std of output 0 over a mirror-area × distance grid.

    Mirror area (column 2) and distance (column 4) sweep a 100 × 100 grid in
    raw units while the other four inputs are held at their dataset means.
    The grid is standardized with ``stats_np`` before it reaches the model,
    so the model sees inputs on the same scale as its training data.

    Args:
        model: Trained network taking standardized 6-feature inputs.
        stats_np: Mapping with NumPy ``'X_mean'`` and ``'X_std'`` (length 6).
        device: Device on which the grid is evaluated.
    """
    # Vary mirror area vs distance, fix others at mean
    A_vals = np.linspace(1e3,1e5,100)
    d_vals = np.linspace(0.1,5.0,100)
    AA, dd = np.meshgrid(A_vals, d_vals)

    x_mean = np.asarray(stats_np['X_mean'], dtype=np.float64)
    x_std = np.asarray(stats_np['X_std'], dtype=np.float64)

    # Start every row at the mean, then overwrite the two swept columns.
    grid_raw = np.tile(x_mean, (AA.size, 1))
    grid_raw[:,2] = AA.ravel()
    grid_raw[:,4] = dd.ravel()

    # Standardize exactly like the training inputs; feeding raw magnitudes
    # (areas up to 1e5) would put the network far outside its input range.
    grid = ((grid_raw - x_mean) / x_std).astype(np.float32)

    Xn = torch.from_numpy(grid).to(device)
    _, std = mc_dropout_predict(model, Xn, T=100)
    # NOTE(review): std is in the model's output units (presumably the
    # standardized targets); multiply by Y_std[0] if raw units are wanted.
    std_ext = std[:,0].cpu().numpy().reshape(AA.shape)

    plt.figure(figsize=(6,5))
    plt.pcolormesh(AA, dd, std_ext, cmap='viridis', shading='auto')
    plt.colorbar(label='Std(Extraction Efficiency)')
    plt.xlabel('Mirror Area (km²)'); plt.ylabel('Distance (AU)')
    plt.title('Uncertainty Heatmap: Efficiency'); plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Seed PyTorch so the train/val split, weight initialization, batch
    # shuffling and dropout masks are repeatable, matching the seeded dataset.
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ds = StellarEngineDataset(n_samples=8000)
    stats_np = {
        'X_mean': ds.X_mean, 'X_std': ds.X_std,
        'Y_mean': ds.Y_mean, 'Y_std': ds.Y_std,
    }
    # The dataset keeps its statistics in float64; cast to float32 so the
    # loss is computed in the same precision as the model's activations.
    stats = {
        key: torch.as_tensor(value, dtype=torch.float32, device=device)
        for key, value in stats_np.items()
    }

    # 80/20 train/validation split.
    n_val = int(0.2 * len(ds))
    tr_ds, va_ds = random_split(ds, [len(ds)-n_val, n_val])
    tr_ld = DataLoader(tr_ds, batch_size=128, shuffle=True)
    va_ld = DataLoader(va_ds, batch_size=256)

    model = StellarEngineAI().to(device)
    history = train(model, tr_ld, va_ld, stats, device)

    plot_losses(history)

    # Deterministic predictions over the full dataset (standardized units);
    # switch dropout off explicitly rather than relying on train()'s exit mode.
    model.eval()
    X_all = torch.from_numpy(ds.X).float().to(device)
    with torch.no_grad():
        Yp = model(X_all).cpu().numpy()
    Yt = ds.Y
    names = ['Extraction Efficiency','Stability Factor','Thermal Stress']
    for i, nm in enumerate(names):
        plot_scatter(Yt[:,i], Yp[:,i], nm)

    plot_uncertainty_heatmap(model, stats_np, device)