<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_wormhole_stability_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_wormhole_stability.py

End-to-end pipeline for WormholeStabilityAI:
  • Defines an MLP with Dropout for MC‐Dropout UQ
  • Synthetic dummy dataset of wormhole metric coefficients → stability corrections
  • Physics‐informed loss: MSE + curvature residual
  • Training loop with AdamW, ReduceLROnPlateau, early stopping
  • MC‐Dropout inference for uncertainty estimates
  • Scatter‐and‐residual plots of predictions
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Model Definition with Dropout
# ------------------------------------------------------------------------------
class WormholeStabilityAI(nn.Module):
    """
    Small fully connected network with two hidden layers.

    Layout: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.
    The two Dropout layers are what MC-Dropout sampling relies on when the
    module is left in training mode at inference time.

    Args:
        input_dim:  number of input features per sample.
        hidden_dim: width of both hidden layers.
        output_dim: number of outputs per sample.
        dropout_p:  drop probability shared by both Dropout layers.
    """

    def __init__(self, input_dim=5, hidden_dim=32, output_dim=3, dropout_p=0.1):
        super().__init__()
        # First hidden stage: projection, shared activation, dropout.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(dropout_p)
        # Second hidden stage (reuses self.relu).
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.drop2 = nn.Dropout(dropout_p)
        # Output head: plain linear layer, no activation.
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to a (batch, output_dim) tensor."""
        hidden = self.drop1(self.relu(self.fc1(x)))
        hidden = self.drop2(self.relu(self.fc2(hidden)))
        return self.fc3(hidden)

# ------------------------------------------------------------------------------
# 2. Physics‐Informed Loss Components
# ------------------------------------------------------------------------------
def curvature_residual(pred, inputs):
    """
    Physics residual on the first output channel.

    Column 0 of `pred` is treated as the predicted Ricci scalar R and is
    compared (mean squared error) against a dummy reference value
    0.1 * sum_i(inputs_i^2) computed per sample.

    Args:
        pred:   tensor indexed as pred[:, 0], i.e. (batch, >=1).
        inputs: tensor reduced over dim=1, i.e. (batch, features).

    Returns:
        Scalar tensor holding the MSE between predicted and reference R.
    """
    ricci_reference = 0.1 * inputs.pow(2).sum(dim=1)
    ricci_estimate = pred[:, 0]
    return nn.functional.mse_loss(ricci_estimate, ricci_reference)

def total_loss(pred, target, inputs, lambda_phys=0.5):
    """
    Weighted sum of the data-fit loss and the physics residual.

    Args:
        pred:        model outputs.
        target:      regression targets, compared to `pred` with MSE.
        inputs:      model inputs, forwarded to `curvature_residual`.
        lambda_phys: weight λ applied to the physics residual.

    Returns:
        Tuple (combined, data_term, physics_term) where
        combined = data_term + λ * physics_term.
    """
    data_term = nn.functional.mse_loss(pred, target)
    physics_term = curvature_residual(pred, inputs)
    combined = data_term + lambda_phys * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 3. Dummy Dataset
# ------------------------------------------------------------------------------
class DummyWormholeDataset(Dataset):
    """
    Synthetic dataset of random inputs and random targets.

    X is drawn uniformly from [-1, 1) with shape (N, input_dim) and Y is
    drawn uniformly from [-0.5, 0.5) with shape (N, output_dim); both are
    float32 tensors. Y is sampled independently of X.

    Args:
        N:          number of samples.
        input_dim:  number of input features per sample.
        output_dim: number of target values per sample.
        seed:       seed for the private random generator. The default (1)
                    reproduces the values the class has always generated.
    """

    def __init__(self, N=5000, input_dim=5, output_dim=3, seed=1):
        super().__init__()
        # Use a private legacy generator instead of np.random.seed(): it
        # yields the same stream for the same seed but leaves the global
        # NumPy RNG state of the calling program untouched.
        rng = np.random.RandomState(seed)
        # Random metric coefficients in [-1,1]
        self.X = torch.from_numpy(
            rng.uniform(-1, 1, size=(N, input_dim)).astype(np.float32)
        )
        # Dummy stability corrections and Ricci estimate
        self.Y = torch.from_numpy(
            rng.uniform(-0.5, 0.5, size=(N, output_dim)).astype(np.float32)
        )

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (input, target) pair stored at position `idx`."""
        return self.X[idx], self.Y[idx]

# ------------------------------------------------------------------------------
# 4. MC‐Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, n_samples=100, device='cpu'):
    """
    Returns mean and std of model(x) under MC‐Dropout.

    The model is switched to training mode so its Dropout layers stay
    active, `n_samples` stochastic forward passes are run without gradient
    tracking, and the per-element mean and (population) standard deviation
    across passes are returned.

    Args:
        model:     module to sample from; expected to already live on `device`.
        x:         input batch; moved to `device` once before sampling.
        n_samples: number of stochastic forward passes (must be >= 1).
        device:    device on which the forward passes run.

    Returns:
        (mean, std): NumPy arrays with the shape of a single model output,
        e.g. (batch, output_dim).

    Raises:
        ValueError: if `n_samples` is smaller than 1.

    Note:
        The model is always left in eval mode on return, including when a
        forward pass raises.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    # Transfer the batch once instead of on every sampling iteration.
    x = x.to(device)

    model.train()  # keep dropout on
    try:
        with torch.no_grad():
            # (n_samples, batch, output_dim)
            preds = np.stack(
                [model(x).cpu().numpy() for _ in range(n_samples)], axis=0
            )
    finally:
        # Never leave the caller's model stuck in training mode.
        model.eval()

    mean = preds.mean(axis=0)  # (batch, output_dim)
    std = preds.std(axis=0)    # (batch, output_dim)
    return mean, std

# ------------------------------------------------------------------------------
# 5. Data Loaders
# ------------------------------------------------------------------------------
def get_loaders(batch_size=64, val_frac=0.2):
    """
    Build train/validation DataLoaders over a 5000-sample dummy dataset.

    Args:
        batch_size: batch size used by both loaders.
        val_frac:   fraction of the dataset held out for validation; the
                    validation size is int(len(dataset) * val_frac).

    Returns:
        (train_loader, val_loader). Only the training loader shuffles.
    """
    dataset = DummyWormholeDataset(N=5000)
    total = len(dataset)
    val_count = int(total * val_frac)
    train_count = total - val_count

    # Random, non-overlapping partition of the dataset.
    train_subset, val_subset = random_split(dataset, [train_count, val_count])

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size)
    return train_loader, val_loader

# ------------------------------------------------------------------------------
# 6. Training Loop with Early Stopping & Scheduler
# ------------------------------------------------------------------------------
def train_model(
    model, train_loader, val_loader,
    lr=1e-3, weight_decay=1e-5, lambda_phys=0.5,
    max_epochs=200, patience_max=20, device='cpu',
    ckpt_path="best_wormhole_stab.pt"
):
    """
    Train `model` with AdamW, ReduceLROnPlateau and early stopping.

    Each epoch runs one pass over `train_loader` optimising `total_loss`,
    then evaluates on `val_loader`. The validation loss drives both the LR
    scheduler and early stopping. Whenever the validation loss improves by
    more than 1e-6 the weights are saved to `ckpt_path`; the best saved
    weights are reloaded into the model before returning.

    Args:
        model:        module to train; moved to `device`.
        train_loader: loader yielding (inputs, targets) training batches.
        val_loader:   loader yielding (inputs, targets) validation batches.
        lr:           initial AdamW learning rate.
        weight_decay: AdamW weight decay.
        lambda_phys:  weight of the physics residual inside `total_loss`.
        max_epochs:   upper bound on the number of epochs.
        patience_max: epochs without improvement before stopping early.
        device:       device used for training and evaluation.
        ckpt_path:    file the best weights are written to.

    Returns:
        (model, history) where history maps 'train_loss', 'val_loss' and
        'val_phys' to per-epoch lists of sample-weighted averages.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # The scheduler's `verbose` flag is deprecated in recent PyTorch
    # releases, so it is not passed; LR reductions are reported manually.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val = float('inf')
    patience = 0
    # Tracks whether THIS run wrote a checkpoint, so a stale file left by a
    # previous run is never loaded by mistake.
    saved_checkpoint = False
    history = {'train_loss': [], 'val_loss': [], 'val_phys': []}

    for epoch in range(1, max_epochs + 1):
        # Training
        model.train()
        running_tr_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            out = model(xb)
            loss, _, _ = total_loss(out, yb, xb, lambda_phys)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample.
            running_tr_loss += loss.item() * xb.size(0)

        train_loss = running_tr_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        running_val_loss = 0.0
        running_phys = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                out = model(xb)
                loss, _, phys_l = total_loss(out, yb, xb, lambda_phys)
                running_val_loss += loss.item() * xb.size(0)
                running_phys += phys_l.item() * xb.size(0)

        val_loss = running_val_loss / len(val_loader.dataset)
        val_phys = running_phys / len(val_loader.dataset)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Epoch {epoch:03d}: reducing learning rate to {lr_after:.4e}")

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_phys'].append(val_phys)

        # Early stopping
        if val_loss < best_val - 1e-6:
            best_val = val_loss
            torch.save(model.state_dict(), ckpt_path)
            saved_checkpoint = True
            patience = 0
        else:
            patience += 1
            if patience >= patience_max:
                print(f"Early stopping at epoch {epoch}")
                break

        if epoch % 20 == 0 or epoch == 1:
            print(f"Epoch {epoch:03d} | Train {train_loss:.4f} | Val {val_loss:.4f} | PhysRes {val_phys:.4f}")

    # Load best model (only if this run produced one, e.g. not when the
    # validation loss was NaN throughout or max_epochs was 0).
    if saved_checkpoint:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    return model, history

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: train on the dummy dataset, plot the loss curves,
    # then run MC-Dropout on one validation batch and plot channel 0.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data
    tr_loader, va_loader = get_loaders(batch_size=64, val_frac=0.2)

    # Build and train
    model = WormholeStabilityAI(input_dim=5, hidden_dim=32, output_dim=3, dropout_p=0.1)
    model, hist = train_model(
        model, tr_loader, va_loader,
        lr=1e-3, weight_decay=1e-5, lambda_phys=0.5,
        max_epochs=200, patience_max=20, device=device
    )

    # Plot training history
    epoch_axis = range(1, len(hist['train_loss']) + 1)
    plt.figure(figsize=(10,4))
    plt.plot(epoch_axis, hist['train_loss'], label="Train Loss")
    plt.plot(epoch_axis, hist['val_loss'],   label="Val Loss")
    plt.plot(epoch_axis, hist['val_phys'],   label="Val PhysRes")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend()
    plt.title("WormholeStabilityAI Training History")
    plt.tight_layout()
    plt.show()

    # MC‐Dropout on a mini‐batch (targets are not needed here)
    xb, _ = next(iter(va_loader))
    mean_p, std_p = mc_dropout_predict(model, xb, n_samples=100, device=device)
    print("\nMC‐Dropout predictions (first 5 samples):")
    # Guard against a batch with fewer than 5 rows.
    for i in range(min(5, len(xb))):
        print(f"Input: {xb[i].cpu().numpy()}")
        print(f"  Pred Mean: {mean_p[i]}  Std: {std_p[i]}")

    # Scatter plot: channel 0 (Ricci) true vs. predicted
    # Pseudo‐true Ricci scalar, computed once and reused for the axis limits.
    ricci_true = 0.1 * (xb**2).sum(dim=1).cpu().numpy()
    ricci_pred = mean_p[:,0]
    plt.figure(figsize=(6,6))
    plt.scatter(ricci_true, ricci_pred, s=20, alpha=0.5)
    # Diagonal y = x spanning the joint range of both axes.
    m = min(ricci_pred.min(), ricci_true.min())
    M = max(ricci_pred.max(), ricci_true.max())
    plt.plot([m, M], [m, M], 'r--')
    plt.xlabel("True Ricci Scalar")
    plt.ylabel("Predicted Ricci")
    plt.title("Ricci: True vs Predicted")
    plt.grid(True)
    plt.tight_layout()
    plt.show()