<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_terraforming_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_terraforming_ai.py

Physics‐informed AI pipeline for TerraformingAI:

1. Synthetic dataset: (CO2, radiation, pressure, temp, water_vapor, soil)
   → (O2_stability, greenhouse_balance, habitability_index)
2. Normalize features and targets
3. MLP with LayerNorm & Dropout
4. Physics‐informed residual enforcing O2 ≈ rad·(1−CO2)·0.21
5. MC‐Dropout inference for uncertainty quantification
6. Training loop with AdamW (configured via keyword arguments), ReduceLROnPlateau, gradient clipping, early stopping
7. Visualizations: loss curves, scatter plots, uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Terraforming Dataset
# ------------------------------------------------------------------------------
class TerraformingDataset(Dataset):
    """Synthetic, standardized dataset for the terraforming model.

    Six features are drawn uniformly at random, in column order
    (CO2, radiation, pressure, temperature, water vapor, soil).  Three
    targets are derived from them in closed form, in column order
    (O2 stability, greenhouse balance, habitability index), and perturbed
    with Gaussian noise whose scale is 2% of each target's standard
    deviation.  Features and targets are then standardized column-wise.

    Args:
        n_samples: Number of rows to generate.
        seed: Seed of the private random generator used for sampling.

    The instance exposes float32 NumPy arrays:
        X, Y            -- standardized features / targets.
        X_mean, X_std   -- per-column feature statistics (std has 1e-6 added).
        Y_mean, Y_std   -- per-column target statistics (std has 1e-6 added).
    """

    def __init__(self, n_samples=6000, seed=0):
        # A private RandomState produces the same stream as seeding the
        # global legacy generator, but does not clobber np.random state that
        # other code in the process may depend on.
        rng = np.random.RandomState(seed)

        def draw(low, high):
            # One float32 column of uniform samples.
            return rng.uniform(low, high, (n_samples, 1)).astype(np.float32)

        # The draw order determines the generated values; keep it fixed.
        CO2  = draw(0.01, 0.1)
        rad  = draw(0.7,  1.4)
        pres = draw(0.3,  2.0)
        temp = draw(200,  320)
        wv   = draw(0.0,  0.05)
        soil = draw(0.1,  0.6)

        X_raw = np.hstack([CO2, rad, pres, temp, wv, soil])

        # Targets
        O2_stab = rad * (1 - CO2) * 0.21
        gb      = CO2 * temp / pres
        hab     = (O2_stab * soil) / (temp + 1e-6)
        Y_raw   = np.hstack([O2_stab, gb, hab]).astype(np.float32)
        # Additive noise proportional to each target column's spread.
        Y_raw  += 0.02 * Y_raw.std(axis=0) * rng.randn(*Y_raw.shape).astype(np.float32)

        # Per-column normalization stats; the epsilon guards against a
        # zero standard deviation.
        self.X_mean = X_raw.mean(axis=0)
        self.X_std  = X_raw.std(axis=0) + 1e-6
        self.Y_mean = Y_raw.mean(axis=0)
        self.Y_std  = Y_raw.std(axis=0) + 1e-6

        # Standardize
        self.X = (X_raw - self.X_mean) / self.X_std
        self.Y = (Y_raw - self.Y_mean) / self.Y_std

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the standardized (features, targets) pair at ``idx`` as tensors."""
        return (
            torch.from_numpy(self.X[idx]),
            torch.from_numpy(self.Y[idx]),
        )

# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class TerraformingAI(nn.Module):
    """Fully connected regressor with LayerNorm, ReLU and Dropout per hidden layer.

    Args:
        input_dim: Number of input features.
        hidden_dims: Width of each hidden layer, in order.
        output_dim: Number of regression outputs.
        p_drop: Dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim=6, hidden_dims=(64,64), output_dim=3, p_drop=0.1):
        super().__init__()
        # Consecutive pairs of this list give each hidden Linear's (in, out).
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Linear head mapping the last hidden width to the outputs.
        stack.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential stack to ``x`` and return its output."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics‐Informed Residual
# ------------------------------------------------------------------------------
def physics_residual(pred, inp, stats):
    """Mean squared violation of the O2 law by the model's first output.

    Inputs and predictions are mapped back from standardized units with
    ``stats`` ('X_mean', 'X_std', 'Y_mean', 'Y_std').  Column 0 of the
    prediction is then compared against ``rad * (1 - CO2) * 0.21``, where
    CO2 and rad are columns 0 and 1 of the denormalized inputs.

    Returns:
        Scalar tensor: the batch mean of the squared difference.
    """
    # Undo the standardization of features and targets.
    features = stats['X_mean'] + stats['X_std'] * inp
    targets = stats['Y_mean'] + stats['Y_std'] * pred

    co2 = features[:, 0]
    radiation = features[:, 1]

    # Closed-form O2 stability implied by the inputs.
    o2_law = 0.21 * (radiation * (1 - co2))
    gap = targets[:, 0] - o2_law
    return (gap ** 2).mean()

def total_loss(pred, true, inp, stats, lam=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: Model outputs (standardized units).
        true: Targets (standardized units).
        inp: Model inputs (standardized units), passed to the physics term.
        stats: Normalization statistics forwarded to ``physics_residual``.
        lam: Weight of the physics term.

    Returns:
        Tuple ``(mse + lam * phys, mse, phys)``.
    """
    criterion = nn.MSELoss()
    data_term = criterion(pred, true)
    physics_term = physics_residual(pred, inp, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC‐Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, T=50):
    """Estimate predictive mean and spread with Monte Carlo dropout.

    The model is put in train mode so that dropout stays active, ``T``
    forward passes are run without gradient tracking, and the passes are
    reduced along the sample axis.  The train/eval flag of every submodule
    is restored before returning, so the caller's model is left exactly as
    it was passed in.

    Args:
        model: Module to sample from.
        x: Input batch, already on the model's device.
        T: Number of stochastic forward passes (must be >= 1).

    Returns:
        Tuple ``(mean, std)`` over the T passes, each shaped like one
        model output.  With ``T == 1`` the sample standard deviation is
        not defined, so ``std`` is not meaningful.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be a positive integer, got {T}")

    # Remember each submodule's mode; train() below flips all of them.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.train()
    try:
        with torch.no_grad():
            draws = torch.stack([model(x) for _ in range(T)], dim=0)
    finally:
        # Restore even if a forward pass raised.
        for module, was_training in saved_modes:
            module.training = was_training
    return draws.mean(0), draws.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lam=1.0, epochs=100, patience=10):
    """Fit ``model`` with AdamW, LR reduction on plateau and early stopping.

    Each epoch runs one optimization pass over ``train_loader`` (gradients
    clipped to norm 1.0) and one evaluation pass over ``val_loader``.  The
    loss in both passes is ``total_loss`` (MSE + ``lam`` * physics term),
    averaged per sample over the loader's dataset.  Whenever the validation
    loss improves by more than 1e-6, the weights are snapshotted in memory
    and also written to ``best_terraform_ai.pth``.  Training stops after
    ``patience`` consecutive epochs without improvement.

    Args:
        model: Network to train; moved to ``device`` in place.
        train_loader: Loader yielding (inputs, targets) training batches.
        val_loader: Loader yielding (inputs, targets) validation batches.
        stats: Normalization statistics forwarded to ``total_loss``.
        device: Device on which batches and the model are placed.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: Weight of the physics term in the loss.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience, in epochs.

    Returns:
        Dict with per-epoch lists under 'train_loss' and 'val_loss'.  On
        return the model holds the best weights seen; if no epoch ever
        improved, it keeps the weights it ended with.
    """
    model.to(device)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val, wait = float('inf'), 0
    # In-memory copy of the best weights.  Restoring from this instead of
    # the checkpoint file avoids loading a stale file left by an earlier
    # run when no epoch of this run improved (e.g. NaN losses).
    best_state = None
    history = {'train_loss': [], 'val_loss': []}

    for ep in range(1, epochs+1):
        # Train
        model.train()
        t_loss = 0.0
        for Xb, Yb in train_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            pred = model(Xb)
            loss, _, _ = total_loss(pred, Yb, Xb, stats, lam)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            t_loss += loss.item() * Xb.size(0)
        t_loss /= len(train_loader.dataset)

        # Validate
        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for Xb, Yb in val_loader:
                Xb, Yb = Xb.to(device), Yb.to(device)
                pred = model(Xb)
                loss, _, _ = total_loss(pred, Yb, Xb, stats, lam)
                v_loss += loss.item() * Xb.size(0)
        v_loss /= len(val_loader.dataset)

        scheduler.step(v_loss)
        history['train_loss'].append(t_loss)
        history['val_loss'].append(v_loss)
        print(f"Epoch {ep:03d} | Train {t_loss:.4e} | Val {v_loss:.4e}")

        # Early stopping
        if v_loss < best_val - 1e-6:
            best_val, wait = v_loss, 0
            # Clone: state_dict() tensors alias the live parameters, which
            # later optimizer steps overwrite.
            best_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
            torch.save(model.state_dict(), "best_terraform_ai.pth")
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    # Load best
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("No improving epoch was recorded; keeping current weights")
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(hist):
    """Plot the training and validation loss curves on one figure.

    Args:
        hist: Mapping with per-epoch sequences under 'train_loss' and
            'val_loss'.
    """
    plt.figure()
    for key, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
        plt.plot(hist[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Loss Curves")
    plt.show()

def plot_scatter(y_true, y_pred, name):
    """Scatter predictions against ground truth with a dashed identity line.

    Args:
        y_true: Reference values; must provide ``min()`` and ``max()``.
        y_pred: Predicted values, aligned with ``y_true``.
        name: Figure title.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=5, alpha=0.6)

    # Identity line spanning the range of the reference values.
    lo, hi = y_true.min(), y_true.max()
    diagonal = [lo, hi]
    plt.plot(diagonal, diagonal, 'r--')

    plt.xlabel("True")
    plt.ylabel("Pred")
    plt.title(name)
    plt.show()

def plot_uncertainty(model, stats, device):
    """Plot the MC-dropout spread of the O2 output over a CO2 x radiation grid.

    A 100 x 100 grid spans CO2 in [0.01, 0.1] and radiation in [0.7, 1.4];
    the remaining four features are held at their means from ``stats``.
    The grid is standardized, sampled with 100 MC-dropout passes, and the
    standard deviation of output column 0 is shown as a heatmap.

    The standard deviation is rescaled by ``stats['Y_std'][0]`` so the
    colorbar is in the target's original units rather than in standardized
    units.  The model's train/eval mode is restored after sampling.

    Args:
        model: Trained network, expected to live on ``device``.
        stats: Dict of tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std'.
        device: Device on which the grid is built.
    """
    co2_vals = torch.linspace(0.01, 0.1, 100, device=device)
    rad_vals = torch.linspace(0.7,  1.4, 100, device=device)
    CO2, RAD = torch.meshgrid(co2_vals, rad_vals, indexing='xy')

    # Flatten the grid to column vectors: one row per grid point.
    co2_col = CO2.reshape(-1, 1)
    rad_col = RAD.reshape(-1, 1)

    # Hold features 2..5 at their dataset means, broadcast to every row.
    fixed = stats['X_mean'][2:6].to(dtype=CO2.dtype, device=CO2.device)
    others = fixed.expand(co2_col.shape[0], 4)
    grid = torch.cat([co2_col, rad_col, others], dim=1)

    # Normalize
    Xn = (grid - stats['X_mean']) / stats['X_std']

    # MC-Dropout; put the model back in its previous mode afterwards so
    # later inference is not silently run with dropout enabled.
    was_training = model.training
    try:
        _, std = mc_dropout_predict(model, Xn, T=100)
    finally:
        model.train(was_training)

    # Scaling a standardized quantity by Y_std converts its spread back to
    # the original target units (the mean offset does not affect a std).
    std_o2 = std[:, 0] * stats['Y_std'][0]
    std_map = std_o2.reshape(CO2.shape).cpu().numpy()

    plt.figure()
    plt.pcolormesh(
        CO2.cpu().numpy(),
        RAD.cpu().numpy(),
        std_map,
        cmap='viridis',
        shading='auto'
    )
    plt.colorbar(label="Std of O2 Stability")
    plt.xlabel("CO2 Fraction")
    plt.ylabel("Solar Radiation Factor")
    plt.title("O2 Stability Uncertainty")
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The dataset has its own seed; seed torch as well so the train/val
    # split, weight initialization, batch shuffling and dropout masks are
    # repeatable from run to run.
    torch.manual_seed(0)

    # Prepare dataset and stats (stats as tensors on the training device,
    # as needed by the physics loss and the uncertainty plot)
    ds = TerraformingDataset(n_samples=6000)
    stats = {
        'X_mean': torch.tensor(ds.X_mean, dtype=torch.float32, device=device),
        'X_std':  torch.tensor(ds.X_std,  dtype=torch.float32, device=device),
        'Y_mean': torch.tensor(ds.Y_mean, dtype=torch.float32, device=device),
        'Y_std':  torch.tensor(ds.Y_std,  dtype=torch.float32, device=device),
    }

    # Split and loaders: 80% train / 20% validation
    n_val = int(0.2 * len(ds))
    train_ds, val_ds = random_split(ds, [len(ds)-n_val, n_val])
    train_loader = DataLoader(train_ds, batch_size=64,  shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=128, shuffle=False)

    # Build, train, and visualize
    model   = TerraformingAI(input_dim=6, hidden_dims=(64,64), output_dim=3).to(device)
    history = train(model, train_loader, val_loader, stats, device,
                    lr=1e-3, wd=1e-5, lam=1.0, epochs=100, patience=10)

    plot_losses(history)

    # Scatter True vs Predicted (denormalized), over the full dataset.
    # Switch to eval mode explicitly so dropout is off for these point
    # predictions regardless of the mode train() left the model in.
    model.eval()
    with torch.no_grad():
        X_all  = torch.from_numpy(ds.X).to(device)
        Y_pred = model(X_all).cpu().numpy() * ds.Y_std + ds.Y_mean
    Y_true = ds.Y * ds.Y_std + ds.Y_mean
    names = ["O2 Stability", "Greenhouse Balance", "Habitability Index"]
    for i, nm in enumerate(names):
        plot_scatter(Y_true[:,i], Y_pred[:,i], nm)

    # Uncertainty heatmap for O2 stability
    plot_uncertainty(model, stats, device)