<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_planetary_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_planetary_ai.py

Full pipeline for PlanetaryAI:
1. Synthetic “planetary” dataset of 6 inputs → 3 outputs
2. Standardization, noise injection
3. MLP with residual head & dropout
4. Training loop with gradient clipping, LR scheduler, checkpointing
5. Validation and loss reporting
6. Optional scatter plots of true vs. predicted
"""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------------------
# 1. Synthetic Planetary Dataset
# ------------------------------------------------------------------------------
class PlanetaryDataset(Dataset):
    """Synthetic "planetary" regression data: 6 uniform features -> 3 targets.

    Six features ``(CNS, CER, ICL, PRA, ECI, CDF)`` are drawn uniformly from
    [0, 1).  Three targets are derived from toy closed-form laws:

    * coherence  = CNS * CER / (ICL + 1e-6)
    * efficiency = PRA / (ICL + 1)
    * retention  = CDF * exp(-ECI)

    Gaussian noise scaled to 1% of each target's standard deviation is then
    added.  Features and targets are z-scored with statistics computed over the
    whole dataset.  The statistics are kept on the instance (``X_mean``,
    ``X_std``, ``Y_mean``, ``Y_std``) so that predictions can be mapped back to
    original units.

    Note: construction calls ``torch.manual_seed(seed)``, which re-seeds the
    *global* torch RNG as a side effect.
    """

    def __init__(self, n_samples=5000, seed=42):
        # Seeds the global RNG; any later random op (splits, weight init, ...)
        # continues from this state.
        torch.manual_seed(seed)
        eps = 1e-6  # guards the division in the coherence law and the std scaling

        features = torch.rand(n_samples, 6)
        cns, cer, icl, pra, eci, cdf = features.unbind(dim=1)

        laws = (
            (cns * cer) / (icl + eps),  # coherence
            pra / (icl + 1.0),          # efficiency
            cdf * torch.exp(-eci),      # retention
        )
        targets = torch.stack(laws, dim=1)
        noise_scale = 0.01 * targets.std(0)  # 1% of each target's spread
        targets += noise_scale * torch.randn_like(targets)

        self.X_mean, self.X_std, self.X = self._zscore(features, eps)
        self.Y_mean, self.Y_std, self.Y = self._zscore(targets, eps)

    @staticmethod
    def _zscore(tensor, eps):
        """Return ``(mean, std + eps, standardized float32 tensor)`` along dim 0."""
        mu = tensor.mean(0)
        sigma = tensor.std(0) + eps
        return mu, sigma, ((tensor - mu) / sigma).float()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

# ------------------------------------------------------------------------------
# 2. PlanetaryAI Model with Residuals & Dropout
# ------------------------------------------------------------------------------
class PlanetaryAI(nn.Module):
    """Small MLP (defaults: 6 inputs -> 3 outputs) with a learned linear skip.

    Forward pass::

        hidden = ReLU(fc1(x)) + res_head(x)
        hidden = Dropout(ReLU(fc2(hidden)))
        y      = fc3(hidden)

    The submodule attribute names are the state_dict keys used by saved
    checkpoints, so keep them stable.  The order in which the Linear layers are
    constructed also determines which initial weights are drawn for a given
    seed.
    """

    def __init__(self, input_dim=6, hidden_dim=32, output_dim=3, p_drop=0.1):
        super().__init__()
        # Construction order (fc1, res_head, fc2, fc3) fixes the seeded initial weights.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.res_head = nn.Linear(input_dim, hidden_dim)  # projection skip from the input
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(p_drop)

    def forward(self, x):
        """Map a batch ``x`` with trailing dim ``input_dim`` to raw (linear) outputs."""
        skip = self.res_head(x)
        mixed = self.relu(self.fc1(x)) + skip
        hidden = self.drop(self.relu(self.fc2(mixed)))
        return self.fc3(hidden)

# ------------------------------------------------------------------------------
# 3. Training & Validation Helpers
# ------------------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, device, *, max_norm=1.0):
    """Run one optimisation pass over ``loader`` and return the mean MSE per sample.

    Args:
        model: module to train; it is switched to training mode (enables dropout).
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to before the forward pass.
        max_norm: global L2 gradient-norm clipping threshold (default 1.0).

    Returns:
        Sample-weighted average of the per-batch MSE losses, taken over the
        samples the loader actually yielded.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    criterion = nn.MSELoss()
    loss_sum = 0.0
    n_seen = 0

    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)
        loss = criterion(model(Xb), Yb)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()

        batch_size = Xb.size(0)
        loss_sum += loss.item() * batch_size  # undo the per-batch mean before averaging
        n_seen += batch_size

    # Normalise by what was actually processed. len(loader.dataset) over-counts
    # when the loader drops a trailing partial batch or uses a subset sampler,
    # and it is undefined for datasets without __len__.
    if n_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / n_seen

def validate(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return the mean MSE per sample.

    Args:
        model: module to evaluate; it is switched to eval mode (disables
            dropout) and left in that mode.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device every batch is moved to before the forward pass.

    Returns:
        Sample-weighted average of the per-batch MSE losses, taken over the
        samples the loader actually yielded.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    criterion = nn.MSELoss()
    loss_sum = 0.0
    n_seen = 0

    with torch.no_grad():
        for Xb, Yb in loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            batch_size = Xb.size(0)
            loss_sum += criterion(model(Xb), Yb).item() * batch_size
            n_seen += batch_size

    # Same normalisation as train_one_epoch: divide by samples actually seen,
    # not len(loader.dataset), which is wrong with drop_last or subset samplers.
    if n_seen == 0:
        raise ValueError("validate: loader yielded no samples")
    return loss_sum / n_seen

# ------------------------------------------------------------------------------
# 4. Main Execution
# ------------------------------------------------------------------------------
def _plot_true_vs_predicted(model, dataset, device):
    """Show one true-vs-predicted scatter plot per target, in original (de-standardised) units."""
    model.eval()  # make sure dropout is off for inference
    with torch.no_grad():
        Y_pred = model(dataset.X.to(device)).cpu()

    # Do all plotting math on CPU tensors: matplotlib cannot convert CUDA
    # tensors to NumPy. Note this covers the whole dataset (train + val samples).
    Y_true = dataset.Y * dataset.Y_std + dataset.Y_mean
    Y_est = Y_pred * dataset.Y_std + dataset.Y_mean

    for i, name in enumerate(["Coherence", "Efficiency", "Retention"]):
        true_i, est_i = Y_true[:, i], Y_est[:, i]
        lo, hi = true_i.min().item(), true_i.max().item()
        plt.figure(figsize=(4, 4))
        plt.scatter(true_i.numpy(), est_i.numpy(), s=5, alpha=0.6)
        plt.plot([lo, hi], [lo, hi], 'r--')  # y = x reference line
        plt.title(name)
        plt.xlabel("True")
        plt.ylabel("Predicted")
        plt.tight_layout()
        plt.show()


def main(*, epochs=50, checkpoint_path="checkpoints/planetary_best.pth", show_plots=True):
    """Train PlanetaryAI on the synthetic dataset, reload the best checkpoint, optionally plot.

    Args:
        epochs: number of training epochs (default 50).
        checkpoint_path: file the best-validation weights are written to.
        show_plots: whether to display true-vs-predicted scatter plots.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data. PlanetaryDataset calls torch.manual_seed, so the split and
    # weight initialisation below draw from that seeded global RNG.
    dataset = PlanetaryDataset(n_samples=5000, seed=42)
    val_size = int(0.2 * len(dataset))
    train_ds, val_ds = random_split(dataset, [len(dataset) - val_size, val_size])
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)

    # Model, optimizer, scheduler
    model = PlanetaryAI().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=5)

    best_val = float('inf')
    # Track whether *this* run wrote a checkpoint, so we never reload a stale
    # file left over from a previous run (e.g. if every val loss was NaN).
    saved_this_run = False
    ckpt_dir = os.path.dirname(checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    # Training loop
    for epoch in range(1, epochs + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer, device)
        va_loss = validate(model, val_loader, device)
        scheduler.step(va_loss)
        print(f"Epoch {epoch:02d} | Train Loss: {tr_loss:.4f} | Val Loss: {va_loss:.4f}")

        # Save best (require a small improvement to avoid churning on noise)
        if va_loss < best_val - 1e-5:
            best_val = va_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_this_run = True

    # Load the best model; a state_dict needs no arbitrary unpickling.
    if saved_this_run:
        model.load_state_dict(
            torch.load(checkpoint_path, map_location=device, weights_only=True)
        )
        print("Best model loaded.")
    else:
        print("No checkpoint saved this run; keeping final weights.")

    if show_plots:
        _plot_true_vs_predicted(model, dataset, device)

if __name__ == "__main__":
    # Run the full pipeline only when executed directly, not when imported.
    main()