<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_galactic_comm_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_galactic_comm.py

AI pipeline for Type-III civilization resource management:

1. Synthetic dataset:
   (energy, logistics, distribution, population, growth, colonization, tech)
   → (econ_balance, energy_stability, resource_efficiency)
2. Domain-informed loss enforcing econ_balance = energy*distribution/population
3. MLP with LayerNorm & Dropout
4. MC-Dropout inference for uncertainty
5. Training loop with AdamW, ReduceLROnPlateau, gradient clipping
6. Visualizations: loss curves, scatter plots, uncertainty map
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# 1. Synthetic Dataset
class GalacticDataset(Dataset):
    """Synthetic, z-score-normalized dataset of civilization resource metrics.

    Each sample has 7 input features, in column order
    (energy, logistics, distribution, population, growth, colonization, tech),
    and 3 targets, in column order
    (econ_balance, energy_stability, resource_efficiency). Targets are computed
    from closed-form formulas and perturbed with Gaussian noise whose scale is
    1% of each target column's standard deviation.

    Attributes:
        X: float32 array of shape (n, 7), normalized features.
        Y: float32 array of shape (n, 3), normalized targets.
        X_mean, X_std: per-column feature statistics used for normalization.
        Y_mean, Y_std: per-column target statistics used for normalization.
            Both ``*_std`` arrays include a 1e-6 offset to avoid division by
            zero.

    Args:
        n: number of samples to generate.
        seed: seed for the private random generator.
    """

    def __init__(self, n=8000, seed=42):
        # A private legacy generator yields the same stream the global one
        # would after np.random.seed(seed), but does not reseed NumPy's
        # process-wide RNG as a side effect of building a dataset.
        rng = np.random.RandomState(seed)

        def draw(low, high):
            """Sample an (n, 1) float32 column uniformly from [low, high)."""
            return rng.uniform(low, high, (n, 1)).astype(np.float32)

        # Features (draw order matters for reproducibility of the stream)
        E = draw(1e5, 1e8)     # energy prod
        L = draw(0.5, 1.0)     # logistics eff
        R = draw(0.1, 1.0)     # distribution index
        P = draw(1e9, 1e12)    # population
        G = draw(0.001, 0.02)  # growth rate
        C = draw(0.0, 0.1)     # colonization rate
        T = draw(1.0, 10.0)    # tech level

        X = np.hstack([E, L, R, P, G, C, T])

        # Targets via simple domain formulas
        econ_balance = E * R / P
        energy_stability = (E * L) / (P * G + 1e6)
        resource_efficiency = (R * L * T) / (C + 0.1)
        Y = np.hstack([econ_balance, energy_stability, resource_efficiency])

        # Add noise proportional to each target column's spread
        Y += 0.01 * Y.std(axis=0) * rng.randn(*Y.shape)

        # Normalization statistics are kept so callers can denormalize later
        self.X_mean, self.X_std = X.mean(0), X.std(0) + 1e-6
        self.Y_mean, self.Y_std = Y.mean(0), Y.std(0) + 1e-6
        self.X = (X - self.X_mean) / self.X_std
        self.Y = (Y - self.Y_mean) / self.Y_std

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (features, targets) tensor pair for sample ``idx``."""
        return (
            torch.from_numpy(self.X[idx]),
            torch.from_numpy(self.Y[idx])
        )

# 2. Model Definition
class GalacticAI(nn.Module):
    """Fully connected regressor with two hidden stages.

    Each hidden stage is Linear -> LayerNorm -> ReLU -> Dropout, followed by a
    final Linear projection to the output size.

    Args:
        in_dim: number of input features.
        hid: width of both hidden stages.
        out: number of outputs.
        p: dropout probability used in both hidden stages.
    """

    def __init__(self, in_dim=7, hid=64, out=3, p=0.1):
        super().__init__()
        stages = []
        fan_in = in_dim
        for _ in range(2):
            # Layers are created in forward order so that Sequential indices
            # (and therefore state-dict keys) stay net.0 ... net.8.
            stages += [
                nn.Linear(fan_in, hid),
                nn.LayerNorm(hid),
                nn.ReLU(),
                nn.Dropout(p),
            ]
            fan_in = hid
        stages.append(nn.Linear(hid, out))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map a batch of feature rows to a batch of predictions."""
        prediction = self.net(x)
        return prediction

# 3. Domain-Informed Loss
def phys_loss(pred, X, stats):
    """Penalize violation of econ_balance = energy * distribution / population.

    The constraint target is computed from the denormalized inputs and then
    mapped into the normalized target space, where it is compared with the
    model's first output column. Measuring the residual in normalized units
    keeps this term on the same scale as a data MSE computed on normalized
    targets; in raw econ_balance units the residual is scaled down by
    Y_std[0]**2, which makes the constraint weight ineffective whenever
    Y_std[0] is far from 1.

    Args:
        pred: normalized predictions, shape (batch, n_targets); column 0 is
            econ_balance.
        X: normalized features, shape (batch, 7), columns ordered as
            (energy, logistics, distribution, population, growth,
            colonization, tech).
        stats: mapping with tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std' used
            for (de)normalization.

    Returns:
        Scalar tensor: mean squared residual in normalized target units.
    """
    # Denormalize the features so the domain formula applies to physical values
    X_den = X * stats['X_std'] + stats['X_mean']
    energy = X_den[:, 0]
    distribution = X_den[:, 2]
    population = X_den[:, 3]
    econ_true = energy * distribution / population

    # Express the analytic target in the same normalized space as `pred`
    econ_true_norm = (econ_true - stats['Y_mean'][0]) / stats['Y_std'][0]
    return nn.MSELoss()(pred[:, 0], econ_true_norm)

def total_loss(pred, truth, X, stats, lam=1.0):
    """Combine the data-fit MSE with the weighted domain-constraint loss.

    Args:
        pred: model predictions.
        truth: target values compared against ``pred`` with MSE.
        X: input features forwarded to ``phys_loss``.
        stats: normalization statistics forwarded to ``phys_loss``.
        lam: weight applied to the domain-constraint term.

    Returns:
        Tuple ``(combined, data_term, physics_term)`` where
        ``combined = data_term + lam * physics_term``.
    """
    criterion = nn.MSELoss()
    data_term = criterion(pred, truth)
    physics_term = phys_loss(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# 4. MC-Dropout for Uncertainty Quantification
def mc_dropout(model, X, T=50):
    """Estimate predictive mean and spread with Monte Carlo dropout.

    Runs ``T`` stochastic forward passes with only the dropout layers active.
    All other modules are evaluated in eval mode, so layers with
    train-time side effects (e.g. batch-norm running statistics) are not
    disturbed. Every module's original train/eval flag is restored before
    returning, even if a forward pass raises.

    Args:
        model: the network to sample from.
        X: input batch passed unchanged to ``model``.
        T: number of stochastic forward passes; must be at least 1. With
            ``T == 1`` the default standard deviation is undefined (NaN).

    Returns:
        Tuple ``(mean, std)`` of the stacked predictions over the ``T`` passes.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    dropout_types = (
        nn.Dropout, nn.Dropout2d, nn.Dropout3d,
        nn.AlphaDropout, nn.FeatureAlphaDropout,
    )

    # Remember each module's mode so the caller's model is left untouched
    saved_modes = [(module, module.training) for module in model.modules()]

    model.eval()
    for module, _ in saved_modes:
        if isinstance(module, dropout_types):
            module.train()  # keep dropout stochastic at inference time

    try:
        with torch.no_grad():
            stacked = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return stacked.mean(0), stacked.std(0)

# 5. Training Loop
def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lam=1.0, max_epochs=100, patience=10,
          ckpt_path='best_galactic.pth'):
    """Train ``model`` with early stopping and restore its best weights.

    Uses AdamW, a ReduceLROnPlateau scheduler driven by the validation loss,
    and gradient-norm clipping at 1.0. The monitored quantity is the combined
    loss returned by ``total_loss`` (data term plus ``lam`` times the domain
    term), averaged per sample over each loader's dataset.

    The best weights are kept in memory and also written to ``ckpt_path``.
    Restoring from memory means a run in which no epoch ever improves (for
    example a NaN validation loss, or ``max_epochs=0``) leaves the model as
    is instead of failing on a missing file or loading a stale checkpoint
    from an earlier run.

    Args:
        model: network to optimize (already on ``device``).
        train_loader: iterable of ``(X, Y)`` training batches.
        val_loader: iterable of ``(X, Y)`` validation batches.
        stats: normalization statistics forwarded to ``total_loss``.
        device: device the batches are moved to.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        lam: weight of the domain-constraint term.
        max_epochs: upper bound on the number of epochs.
        patience: epochs without validation improvement before stopping.
        ckpt_path: file the best state dict is saved to.

    Returns:
        Dict with per-epoch lists under 'train_loss' and 'val_loss'.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val = float('inf')
    best_state = None
    wait = 0
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(1, max_epochs + 1):
        # --- training pass ---
        model.train()
        running_train = 0.0
        for Xb, Yb in train_loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            preds = model(Xb)
            loss, _, _ = total_loss(preds, Yb, Xb, stats, lam)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # weight by batch size so the epoch figure is a per-sample mean
            running_train += loss.item() * Xb.size(0)
        running_train /= len(train_loader.dataset)

        # --- validation pass ---
        model.eval()
        running_val = 0.0
        with torch.no_grad():
            for Xb, Yb in val_loader:
                Xb, Yb = Xb.to(device), Yb.to(device)
                preds = model(Xb)
                loss, _, _ = total_loss(preds, Yb, Xb, stats, lam)
                running_val += loss.item() * Xb.size(0)
        running_val /= len(val_loader.dataset)

        scheduler.step(running_val)
        history['train_loss'].append(running_train)
        history['val_loss'].append(running_val)
        print(f"Epoch {epoch:03d} – Train: {running_train:.4e}, Val: {running_val:.4e}")

        if running_val < best_val - 1e-6:
            best_val, wait = running_val, 0
            # clone so later optimizer steps cannot mutate the snapshot
            best_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping triggered.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# 6. Visualization Helpers
def plot_history(hist):
    """Plot the training and validation loss curves and show the figure.

    Args:
        hist: mapping with sequences under 'train_loss' and 'val_loss',
            one value per epoch.
    """
    curves = (('train_loss', 'Train'), ('val_loss', 'Val'))
    for key, label in curves:
        plt.plot(hist[key], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def scatter_compare(y_true, y_pred, name):
    """Scatter predicted against true values with a y = x reference line.

    Args:
        y_true: true values; must provide ``min()`` and ``max()``.
        y_pred: predicted values, paired element-wise with ``y_true``.
        name: title shown above the plot.
    """
    plt.scatter(y_true, y_pred, s=5, alpha=0.6)

    # The dashed diagonal spans the range of the true values
    lowest = y_true.min()
    highest = y_true.max()
    diagonal = [lowest, highest]
    plt.plot(diagonal, diagonal, 'r--')

    plt.title(name)
    plt.xlabel('True')
    plt.ylabel('Pred')
    plt.show()

def plot_uncertainty_map(model, stats, device):
    """Plot the MC-dropout std of econ_balance over an energy/population grid.

    A 50 x 50 grid sweeps energy production and population while the other
    five features are held at their dataset means. The predictive standard
    deviation returned by ``mc_dropout`` lives in normalized target space, so
    it is multiplied by ``stats['Y_std'][0]`` to report the spread in
    econ_balance units, matching the colorbar label.

    Args:
        model: trained network taking normalized 7-feature inputs.
        stats: mapping with tensors 'X_mean', 'X_std' and 'Y_std' on
            ``device``.
        device: device on which the grid is built and evaluated.
    """
    # Sweep Energy vs Population
    Es = torch.linspace(1e5, 1e8, 50, device=device)
    Ps = torch.linspace(1e9, 1e12, 50, device=device)
    Em, Pm = torch.meshgrid(Es, Ps, indexing='xy')

    # Build grid tensor: column 0 is energy, column 3 is population
    grid = torch.zeros(Em.numel(), 7, device=device)
    grid[:, 0] = Em.reshape(-1)
    grid[:, 3] = Pm.reshape(-1)
    # fix others at their means
    for i in (1, 2, 4, 5, 6):
        grid[:, i] = stats['X_mean'][i]

    Xn = (grid - stats['X_mean']) / stats['X_std']
    _, std = mc_dropout(model, Xn, T=100)

    # Convert the normalized spread back to physical econ_balance units
    econ_std = std[:, 0] * stats['Y_std'][0]
    U = econ_std.cpu().numpy().reshape(Em.shape)

    plt.pcolormesh(
        Em.cpu().numpy(),
        Pm.cpu().numpy(),
        U,
        cmap='inferno'
    )
    plt.colorbar(label='std of econ_balance')
    plt.xlabel('Energy Production')
    plt.ylabel('Population')
    plt.show()

# 7. Main
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Seed torch so the train/val split, weight initialization and batch
    # shuffling are reproducible, in line with the seeded dataset.
    torch.manual_seed(42)

    # Prepare data & stats (statistics are needed on `device` for the
    # denormalization done inside the loss and the uncertainty plot)
    dataset = GalacticDataset()
    stats = {
        'X_mean': torch.tensor(dataset.X_mean, device=device),
        'X_std':  torch.tensor(dataset.X_std,  device=device),
        'Y_mean': torch.tensor(dataset.Y_mean, device=device),
        'Y_std':  torch.tensor(dataset.Y_std,  device=device),
    }

    # Split & loaders
    val_size = 1600
    train_ds, val_ds = random_split(dataset, [len(dataset) - val_size, val_size])
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=256)

    # Initialize & train
    model   = GalacticAI().to(device)
    history = train(
        model, train_loader, val_loader,
        stats, device,
        lr=1e-3, wd=1e-5, lam=1.0,
        max_epochs=100, patience=10
    )

    # Plot training history
    plot_history(history)

    # True vs. Predicted for each metric (both in normalized units).
    # Switch to eval mode explicitly so dropout is off for these point
    # predictions regardless of the mode train() left the model in.
    model.eval()
    with torch.no_grad():
        X_all = torch.from_numpy(dataset.X).to(device)
        Y_pred = model(X_all).cpu().numpy()
    for i, name in enumerate(['econ_balance', 'energy_stability', 'resource_efficiency']):
        scatter_compare(dataset.Y[:, i], Y_pred[:, i], name)

    # Uncertainty heatmap
    plot_uncertainty_map(model, stats, device)