<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_megastructure_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_megastructure_ai.py

Physics-informed AI pipeline for MegastructureAI with NaN safety and uncertainty.

1. Synthetic dataset: material strength, density, Young’s modulus, gravity, energy input, geometry
2. Float32 normalization and dtype consistency
3. MLP with LayerNorm, Dropout, ReLU (accepts int hidden_dims)
4. Physics-informed residual with denominator clamps
5. MC-Dropout for uncertainty quantification
6. Training loop: AdamW, ReduceLROnPlateau, gradient clipping, NaN checks, early stopping
7. Safe checkpoint load
8. Visualizations: loss history, scatter, uncertainty heatmap
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------------------
# 1. Synthetic Megastructure Dataset
# ------------------------------------------------------------------------------
class MegastructureDataset(Dataset):
    """Synthetic, physics-derived regression dataset for MegastructureAI.

    Six input features are drawn uniformly from fixed ranges, three targets are
    computed from closed-form formulas, Gaussian noise with std equal to 1 % of
    each target's std is added, and both inputs and targets are z-score
    normalised to float32.

    Args:
        n_samples: Number of rows to generate; must be >= 1.
        seed: Seed for a *private* ``np.random.RandomState``. The global NumPy
            RNG is not touched, while the drawn stream (and therefore the
            data) is the same one ``np.random.seed(seed)`` would produce.

    Attributes:
        X: Normalised inputs, float32, shape (n_samples, 6) with columns
            MS, Den, YM, GP, EI, GC.
        Y: Normalised targets, float32, shape (n_samples, 3) with columns
            SI, GS, PO.
        X_mean, X_std, Y_mean, Y_std: float64 per-column statistics used for
            the normalisation (the stds carry a 1e-8 floor).

    Raises:
        ValueError: If ``n_samples < 1``.
    """

    def __init__(self, n_samples=6000, seed=0):
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        # Private legacy RandomState: identical stream to np.random.seed(seed)
        # followed by np.random.* calls, but without reseeding the global RNG
        # behind the back of any other code in the process.
        rng = np.random.RandomState(seed)

        # Features:
        # MS (material strength)    ∈ [1e2, 1e4]
        # Den (density)             ∈ [1000, 8000]
        # YM (Young’s modulus)      ∈ [1e9, 3e11]
        # GP (gravity pull factor)  ∈ [0.01, 2.0]
        # EI (energy input rate)    ∈ [1e5, 1e8]
        # GC (geometry complexity)  ∈ [0.1, 1.0]
        MS  = rng.uniform(1e2,   1e4,    (n_samples, 1))
        Den = rng.uniform(1e3,   8e3,    (n_samples, 1))
        YM  = rng.uniform(1e9,   3e11,   (n_samples, 1))
        GP  = rng.uniform(1e-2,  2.0,    (n_samples, 1))
        EI  = rng.uniform(1e5,   1e8,    (n_samples, 1))
        GC  = rng.uniform(0.1,   1.0,    (n_samples, 1))

        X_raw = np.hstack([MS, Den, YM, GP, EI, GC]).astype(np.float64)

        # Physics targets:
        # Structural Integrity (SI)      = YM * MS / (GP + eps)
        # Gravitational Stability (GS)   = Den * MS / (YM + eps)
        # Power Output (PO)             = EI * GC / (Den + eps)
        eps = 1e-6
        SI = YM  * MS  / (GP  + eps)
        GS = Den * MS  / (YM  + eps)
        PO = EI  * GC  / (Den + eps)

        Y_raw = np.hstack([SI, GS, PO]).astype(np.float64)
        # Noise scaled to 1 % of each target column's spread.
        Y_raw += 0.01 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape)

        # Compute stats (the 1e-8 floor avoids division by zero, e.g. n_samples == 1)
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = X_raw.std(axis=0) + 1e-8
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = Y_raw.std(axis=0) + 1e-8

        # Normalize to float32
        self.X = ((X_raw - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y_raw - self.Y_mean) / self.Y_std).astype(np.float32)

        print(f"Dataset X range {self.X.min():.3e}–{self.X.max():.3e}")
        print(f"Dataset Y range {self.Y.min():.3e}–{self.Y.max():.3e}")

    def __len__(self):
        """Number of generated samples."""
        return len(self.X)

    def __getitem__(self, i):
        """Return sample ``i`` as a pair of float32 tensors (features, targets)."""
        return torch.from_numpy(self.X[i]), torch.from_numpy(self.Y[i])

# ------------------------------------------------------------------------------
# 2. Model Definition (accepts int hidden_dims)
# ------------------------------------------------------------------------------
class MegastructureAI(nn.Module):
    """Feed-forward regressor: k x [Linear -> LayerNorm -> ReLU -> Dropout] -> Linear.

    Args:
        input_dim: Number of input features.
        hidden_dims: Width of each hidden block; a bare ``int`` means a single
            hidden block of that width.
        output_dim: Number of regression outputs.
        p_drop: Dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = (hidden_dims,) if isinstance(hidden_dims, int) else hidden_dims

        blocks = []
        in_features = input_dim
        for out_features in widths:
            # Construction order matters: it fixes both the RNG draws for
            # weight init and the Sequential indices used in state_dict keys.
            blocks.extend((
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
            in_features = out_features
        blocks.append(nn.Linear(in_features, output_dim))

        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Map inputs whose last dimension is ``input_dim`` to ``output_dim`` outputs."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Loss with Clamps
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats, eps=1e-6):
    """Mean-squared deviation of ``pred`` from the closed-form physics targets.

    The normalised inputs ``X`` are mapped back to physical units with
    ``stats['X_mean']`` / ``stats['X_std']``. The target formulas used by
    ``MegastructureDataset`` (SI = YM*MS/GP, GS = Den*MS/YM, PO = EI*GC/Den)
    are re-evaluated, normalised with ``stats['Y_mean']`` / ``stats['Y_std']``
    and compared with ``pred``.

    Args:
        pred: Normalised model output with 3 columns (SI, GS, PO).
        X: Normalised inputs, 2-D with 6 columns (MS, Den, YM, GP, EI, GC).
        stats: Dict of tensors broadcastable against ``X`` / ``pred``.
        eps: Denominator offset. The default 1e-6 is the value the dataset uses
            to generate its targets. The previous hard-coded 1e-4 shifted SI by
            about 1 % at small GP, so it penalised predictions that matched the
            data.

    Returns:
        Scalar tensor: MSE between ``pred`` and the normalised physics targets.
    """
    X_den = X * stats['X_std'] + stats['X_mean']
    MS, Den, YM, GP, EI, GC = X_den.t()

    def _safe_div(num, den):
        # The clamp guards against denormalised denominators drifting to <= 0.
        return num / torch.clamp(den + eps, min=eps)

    Yt = torch.stack(
        [_safe_div(YM * MS, GP), _safe_div(Den * MS, YM), _safe_div(EI * GC, Den)],
        dim=1,
    )
    Yt_norm = (Yt - stats['Y_mean']) / stats['Y_std']
    return nn.functional.mse_loss(pred, Yt_norm)

def total_loss(pred, true, X, stats, lam=1.0):
    """Training objective: data MSE plus ``lam`` times the physics residual.

    Returns:
        Tuple ``(total, mse, phys)``: the weighted sum followed by its data
        and physics components (all scalar tensors).
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Uncertainty Estimation
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Monte-Carlo-dropout prediction: mean and spread over ``T`` stochastic passes.

    The model is put in training mode so its Dropout layers stay active. It
    then runs ``T`` forward passes without gradient tracking and returns the
    element-wise mean and standard deviation across passes (``torch.std``
    default, i.e. the unbiased estimator).

    Every submodule's train/eval flag is restored afterwards, even if a forward
    pass raises. Previously the model was left in training mode, so later
    "deterministic" predictions silently kept dropout enabled.

    Args:
        model: ``torch.nn.Module`` to sample from.
        X: Input passed unchanged to ``model``.
        T: Number of stochastic passes. ``T == 1`` gives a NaN std (unbiased
            estimator with one sample); ``T < 1`` fails inside ``torch.stack``.

    Returns:
        ``(mean, std)`` tensors shaped like a single ``model(X)`` output.
    """
    # Record per-module modes so mixed train/eval configurations survive exactly.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)], 0)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    return samples.mean(0), samples.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop (NaN safety, checkpoints, early stopping)
# ------------------------------------------------------------------------------
def _evaluate(model, loader, stats, device, lam):
    """Return the sample-weighted mean ``total_loss`` of ``model`` over ``loader``."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss, _, _ = total_loss(model(xb), yb, xb, stats, lam)
            running += loss.item() * xb.size(0)
    return running / len(loader.dataset)


def train(model, train_loader, val_loader, stats, device,
          lr=1e-4, wd=1e-5, lam=1.0, epochs=100, patience=10,
          ckpt_path="best_megastruct.pth"):
    """Train ``model`` with AdamW, ReduceLROnPlateau, clipping and early stopping.

    Each epoch runs one optimisation pass over ``train_loader`` (gradient norm
    clipped to 1.0), then evaluates ``total_loss`` on ``val_loader``. The
    validation loss drives the LR scheduler, checkpointing and early stopping.
    Training aborts as soon as a training-batch loss is non-finite. On every
    exit path the best weights saved *during this call* are reloaded, and the
    model is returned in eval mode.

    Args:
        model: Module to train (moved to ``device``).
        train_loader, val_loader: DataLoaders yielding ``(x, y)`` batches.
        stats: Normalisation statistics forwarded to ``total_loss``.
        device: Torch device to train on.
        lr, wd: AdamW learning rate and weight decay.
        lam: Weight of the physics residual in ``total_loss``.
        epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping.
        ckpt_path: Where the best ``state_dict`` is written.

    Returns:
        Dict with per-epoch ``'train'`` and ``'val'`` mean losses.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

    best_val, wait = float('inf'), 0
    # Only reload a checkpoint written by *this* call. A stale file from an
    # earlier run (possibly another architecture) must never be picked up.
    saved_this_run = False
    history = {'train': [], 'val': []}

    for epoch in range(1, epochs + 1):
        # training
        model.train()
        train_loss = 0.0
        diverged = False
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            loss, _, _ = total_loss(model(xb), yb, xb, stats, lam)
            # isfinite also catches +/-inf; checked before backward() so the
            # weights are never updated from a non-finite loss.
            if not torch.isfinite(loss):
                diverged = True
                break

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item() * xb.size(0)
        if diverged:
            print(f"Non-finite loss ({loss.item()}) at epoch {epoch}! aborting.")
            break
        train_loss /= len(train_loader.dataset)

        # validation
        val_loss = _evaluate(model, val_loader, stats, device, lam)

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} ┃ Train {train_loss:.4e} ┃ Val {val_loss:.4e}")

        # checkpoint & early stop
        if val_loss < best_val - 1e-8:
            best_val, wait = val_loss, 0
            torch.save(model.state_dict(), ckpt_path)
            saved_this_run = True
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    # load best
    if saved_this_run:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print("No checkpoint; using last model.")

    # Consistent return state regardless of which path ended training.
    model.eval()
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(hist):
    """Plot the per-epoch ``'train'`` and ``'val'`` loss curves and show them."""
    for split in ('train', 'val'):
        plt.plot(hist[split], label=split)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_scatter(y_true, y_pred, title=None):
    """Scatter predictions against ground truth with a dashed y = x reference.

    Each call opens a fresh figure. Previously every call drew onto the current
    axes, so on non-interactive backends (where ``plt.show()`` does not close
    the figure) the per-target plots piled up on top of each other.

    Args:
        y_true: Ground-truth values (any array-like; flattened).
        y_pred: Predicted values, same length as ``y_true``.
        title: Optional axes title, e.g. the target name.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    plt.figure()
    plt.scatter(y_true, y_pred, s=5)
    if y_true.size:  # min()/max() raise on empty arrays
        lo, hi = y_true.min(), y_true.max()
        plt.plot([lo, hi], [lo, hi], 'r--')
    plt.xlabel("True")
    plt.ylabel("Pred")
    if title is not None:
        plt.title(title)
    plt.show()

def plot_uncertainty(model, stats, device, output_idx=0, T=100):
    """Heat-map of MC-dropout std over Material Strength x Gravity Pull.

    Builds an 80 x 80 grid over MS (feature 0, [1e2, 1e4]) and GP (feature 3,
    [0.01, 2.0]), with the other features fixed at ``stats['X_mean']``. The grid
    is normalised with ``stats``, ``T`` MC-dropout passes are run, and the std
    of output column ``output_idx`` is plotted. The default 0 is SI; the std is
    in normalised target units. The model's train/eval mode is restored
    afterwards.

    Args:
        model: Trained model with Dropout layers.
        stats: Dict holding ``'X_mean'`` / ``'X_std'`` tensors (6 features).
        device: Device on which the grid is evaluated.
        output_idx: Which model output's uncertainty to display.
        T: Number of MC-dropout forward passes.
    """
    MS = np.linspace(1e2, 1e4, 80)
    GP = np.linspace(0.01, 2.0, 80)
    MSm, GPm = np.meshgrid(MS, GP)

    # Start every row at the feature means, then sweep the two plotted axes.
    grid = stats['X_mean'].to(device=device, dtype=torch.float32).repeat(MSm.size, 1)
    grid[:, 0] = torch.as_tensor(MSm.ravel(), dtype=grid.dtype, device=device)
    grid[:, 3] = torch.as_tensor(GPm.ravel(), dtype=grid.dtype, device=device)

    Xn = (grid - stats['X_mean']) / stats['X_std']
    was_training = model.training
    try:
        _, std = mc_dropout_predict(model, Xn, T=T)
    finally:
        # MC sampling switches dropout on; do not leak that into later use.
        model.train(was_training)
    U = std[:, output_idx].cpu().numpy().reshape(MSm.shape)

    plt.figure()
    # Explicit shading: C has the same shape as the X/Y grids.
    plt.pcolormesh(MSm, GPm, U, cmap='magma', shading='auto')
    plt.colorbar(label="σ (uncertainty)")
    plt.xlabel("Material Strength")
    plt.ylabel("Gravity Pull Factor")
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Reproducible weight init, dropout masks and DataLoader shuffling
    # (the dataset itself is already seeded).
    torch.manual_seed(0)

    # prepare data
    dataset = MegastructureDataset(n_samples=6000)
    stats = {
        'X_mean': torch.tensor(dataset.X_mean, dtype=torch.float32, device=device),
        'X_std' : torch.tensor(dataset.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(dataset.Y_mean, dtype=torch.float32, device=device),
        'Y_std' : torch.tensor(dataset.Y_std,  dtype=torch.float32, device=device)
    }

    n_val = 1200
    # Dedicated seeded generator: the split is reproducible run to run.
    train_set, val_set = random_split(
        dataset,
        [len(dataset) - n_val, n_val],
        generator=torch.Generator().manual_seed(0),
    )
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_set,   batch_size=256, shuffle=False)

    # build & train
    model = MegastructureAI(input_dim=6, hidden_dims=32, output_dim=3).to(device)
    history = train(model, train_loader, val_loader, stats, device)

    # visualize
    plot_history(history)

    # Deterministic predictions (dropout off) on held-out samples only;
    # scoring the full dataset would mix in the training data.
    model.eval()
    val_idx = np.asarray(val_set.indices)
    X_val = torch.from_numpy(dataset.X[val_idx]).to(device)
    with torch.no_grad():
        Y_pred = model(X_val).cpu().numpy()
    for i in range(3):
        plot_scatter(dataset.Y[val_idx, i], Y_pred[:, i])

    plot_uncertainty(model, stats, device)