<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/wormhole_ai_full_pipeline_fixed_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch umap-learn matplotlib scikit-learn

In [None]:
#!/usr/bin/env python3
"""
wormhole_ai_full_pipeline_fixed.py

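Complete pipeline for a wormhole-stability AI model (trained on synthetic data), with fixes:
  - Monte Carlo dropout uncertainty quantification (UQ), sampled under no_grad with detached outputs
  - Physics-informed residual loss (placeholder: the residual is currently always zero)
  - UMAP of hidden activations (activation buffer cleared first so rows match the validation labels)
  - Surface plot of one output over a 2-D grid sweep of two input dimensions (predicted in one batch)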
"""

import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import umap
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Reproducibility & Device
# ------------------------------------------------------------------------------
SEED = 42

# Seed Python's, NumPy's and PyTorch's global RNGs with the same value so
# data generation, splitting and weight initialisation are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ------------------------------------------------------------------------------
# 2. Monte Carlo Dropout Model Definition
# ------------------------------------------------------------------------------
class WormholeAI_UQ(nn.Module):
    """Three-layer MLP with dropout after each hidden ReLU.

    Dropout is active whenever the module is in ``train()`` mode. The Monte
    Carlo dropout sampling in this script relies on that to draw stochastic
    predictions.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of regression outputs.
        drop_p: Dropout probability applied after each hidden ReLU.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, drop_p=0.1):
        super().__init__()
        # These attribute names become the state_dict keys of the saved
        # checkpoint, and ``fc2`` is the target of a forward hook, so the
        # names and registration order must not change.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.do1 = nn.Dropout(drop_p)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.do2 = nn.Dropout(drop_p)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to (..., output_dim)."""
        # Each hidden stage is Linear -> ReLU -> Dropout; the head is linear.
        for linear, dropout in ((self.fc1, self.do1), (self.fc2, self.do2)):
            x = dropout(torch.relu(linear(x)))
        return self.fc3(x)

# ------------------------------------------------------------------------------
# 3. Synthetic Data Generator (replace with real data)
# ------------------------------------------------------------------------------
def generate_synthetic_data(n_samples, input_dim, output_dim):
    """Draw a toy regression dataset from PyTorch's global RNG.

    Inputs are standard-normal. Every output column is ``sin`` of the row sum
    of the inputs plus independent Gaussian noise with standard deviation 0.1.

    Args:
        n_samples: Number of rows to generate.
        input_dim: Number of input features per row.
        output_dim: Number of (noisy, identical-signal) target columns.

    Returns:
        ``(X, y)`` with shapes ``(n_samples, input_dim)`` and
        ``(n_samples, output_dim)``.
    """
    features = torch.randn(n_samples, input_dim)
    signal = features.sum(dim=1, keepdim=True).sin()
    # Noise is drawn after the features, so the RNG consumption order is fixed.
    noise = torch.randn(n_samples, output_dim) * 0.1
    targets = signal.expand(-1, output_dim) + noise
    return features, targets

# ------------------------------------------------------------------------------
# 4. Physics‐informed Residual (placeholder)
# ------------------------------------------------------------------------------
def compute_einstein_residual(x, u):
    """Placeholder for the physics (PDE) residual of prediction ``u`` at ``x``.

    This currently returns an all-zero tensor with ``u``'s shape, dtype and
    device. The physics penalty built from it is therefore identically zero
    and contributes no gradient.

    Args:
        x: Model inputs, shape [B, input_dim] (currently unused).
        u: Model outputs, shape [B, output_dim].
    """
    # TODO: Replace with actual PDE residual using autograd
    return u.new_zeros(u.shape)

# ------------------------------------------------------------------------------
# 5. Hyperparameters & Data Loaders
# ------------------------------------------------------------------------------
INPUT_DIM = 5
HIDDEN_DIM = 32
OUTPUT_DIM = 3
BATCH_SIZE = 64
LR = 1e-3
MAX_EPOCHS = 100
LAMBDA_PDE = 0.1

# Draw the synthetic dataset, then hold out 20% for validation. The split is
# done on NumPy arrays by scikit-learn, with SEED fixing which rows go where.
X, y = generate_synthetic_data(10000, INPUT_DIM, OUTPUT_DIM)
X_tr, X_val, y_tr, y_val = train_test_split(
    X.numpy(), y.numpy(), test_size=0.2, random_state=SEED
)


def _as_dataset(features, targets):
    """Wrap a pair of NumPy arrays as a float32 TensorDataset."""
    return TensorDataset(
        torch.from_numpy(features).float(),
        torch.from_numpy(targets).float(),
    )


train_ds = _as_dataset(X_tr, y_tr)
val_ds = _as_dataset(X_val, y_val)

# Only the training set is shuffled. Validation order must stay aligned with
# y_val, because the UMAP section colours points by y_val[:, 0].
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

# ------------------------------------------------------------------------------
# 6. Instantiate Model, Loss, Optimizer, Scheduler
# ------------------------------------------------------------------------------
model = WormholeAI_UQ(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
model = model.to(DEVICE)

mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

# Halve the learning rate once the validation loss (stepped once per epoch)
# has stopped improving for more than 10 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=10,
)

# ------------------------------------------------------------------------------
# 7. Hook for UMAP: capture hidden activations from fc2
# ------------------------------------------------------------------------------
# Host-side buffer of fc2 outputs (one NumPy array per recorded batch).
activations = []


def hook_fn(module, inp, out):
    """Forward hook on ``model.fc2`` that stores its output as a NumPy array.

    Only forward passes run with autograd disabled (inside ``torch.no_grad()``)
    are recorded. Training steps run with gradients enabled. Recording them
    would copy every training batch to host memory, forcing a device sync per
    step, and grow this list for every epoch, only for it to be cleared before
    the UMAP projection.

    Args:
        module: The hooked layer (unused).
        inp: Tuple of positional inputs to the layer (unused).
        out: The layer's output tensor. The model applies ReLU after fc2, so
            these are pre-activation hidden features.
    """
    if torch.is_grad_enabled():
        return
    activations.append(out.detach().cpu().numpy())


# Keep the handle so the hook can be detached later with hook_handle.remove().
hook_handle = model.fc2.register_forward_hook(hook_fn)

# ------------------------------------------------------------------------------
# 8. Training Loop with Physics Residual
# ------------------------------------------------------------------------------
def _batch_loss(xb, yb):
    """Return data MSE plus the weighted physics-residual penalty for one batch."""
    pred = model(xb)
    pde_res = compute_einstein_residual(xb, pred)
    return mse_loss(pred, yb) + LAMBDA_PDE * mse_loss(pde_res, torch.zeros_like(pde_res))


best_val = float("inf")

for epoch in range(1, MAX_EPOCHS + 1):
    # ---- training pass (dropout active) ----
    model.train()
    running = 0.0
    for inputs, targets in train_loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        batch_loss = _batch_loss(inputs, targets)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is per-sample.
        running += batch_loss.item() * inputs.size(0)
    train_loss = running / len(train_loader.dataset)

    # ---- validation pass (dropout off, no autograd graph) ----
    model.eval()
    running = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            running += _batch_loss(inputs, targets).item() * inputs.size(0)
    val_loss = running / len(val_loader.dataset)
    scheduler.step(val_loss)

    # Checkpoint whenever validation improves; the MC-dropout section reloads it.
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), "best_wormhole_uq.pt")

    print(f"Epoch {epoch:03d} | Train MSE: {train_loss:.4f} | Val MSE: {val_loss:.4f}")

print(f"\nTraining complete. Best Val MSE: {best_val:.4f}")

# ------------------------------------------------------------------------------
# 9. Monte Carlo Dropout UQ on a sample input (fixed)
# ------------------------------------------------------------------------------
# Reload the best checkpoint. map_location lets a checkpoint written on one
# device (e.g. GPU) be restored onto whatever DEVICE is available now.
model.load_state_dict(torch.load("best_wormhole_uq.pt", map_location=DEVICE))

# Monte Carlo dropout: train() mode turns the dropout layers ON, so every
# forward pass samples a different thinned network.
model.train()

MC_SAMPLES = 100
sample = torch.randn(1, INPUT_DIM).to(DEVICE)
with torch.no_grad():
    # Shape (MC_SAMPLES, 1, OUTPUT_DIM): one stochastic prediction per draw.
    mc_runs = torch.stack([model(sample) for _ in range(MC_SAMPLES)], dim=0)

# Restore deterministic inference so later code does not silently run with
# dropout still active.
model.eval()

uq_mean = mc_runs.mean(0).cpu().numpy()
uq_var = mc_runs.var(0).cpu().numpy()  # unbiased sample variance across draws

print("UQ Mean Corrections:", uq_mean)
print("UQ Variance Corrections:", uq_var)

# ------------------------------------------------------------------------------
# 10. UMAP Visualization of Hidden Activations (fixed)
# ------------------------------------------------------------------------------
# Disable dropout before capturing features. Earlier code may leave the model
# in train() mode (e.g. for MC-dropout sampling), which would make each
# captured activation a random dropout sample instead of the model's
# representation.
model.eval()

# Discard activations captured by the fc2 hook during earlier forward passes
# so that the rows collected below line up one-to-one with y_val.
activations.clear()

# Re-run the model on the validation set once to populate activations.
# val_loader is not shuffled, so batch order matches y_val's row order.
with torch.no_grad():
    for xb, _ in val_loader:
        model(xb.to(DEVICE))

# Stack all hidden vectors and gather labels
hid = np.vstack(activations)   # shape (N_val, hidden_dim)
labels = y_val[:, 0]           # shape (N_val,)
if hid.shape[0] != labels.shape[0]:
    raise RuntimeError(
        f"captured {hid.shape[0]} hidden vectors for {labels.shape[0]} "
        "validation labels; check the fc2 forward hook"
    )

# UMAP projection
um = umap.UMAP(n_components=2, random_state=SEED)
proj = um.fit_transform(hid)

# Scatter plot
plt.figure(figsize=(6, 5))
plt.scatter(proj[:, 0], proj[:, 1],
            c=labels, cmap="Spectral", s=5)
plt.colorbar(label="True First Correction")
plt.title("UMAP of Hidden Representations")
plt.tight_layout()
plt.savefig("umap_hidden.png")
print("Saved UMAP plot → umap_hidden.png")

# ------------------------------------------------------------------------------
# 11. Batch‐Sweep Surface Plot for First Output Dimension
# ------------------------------------------------------------------------------
# Sweep input dims 0 and 1 over [-2, 2] on a grid_n x grid_n grid while the
# remaining input dims are held at 0, and plot output dimension 0.
grid_n = 100
dim0   = np.linspace(-2, 2, grid_n)
dim1   = np.linspace(-2, 2, grid_n)
G0, G1 = np.meshgrid(dim0, dim1)

# Build input grid (all grid points are predicted in a single batch)
grid      = np.zeros((grid_n * grid_n, INPUT_DIM), dtype=np.float32)
grid[:, 0] = G0.ravel()
grid[:, 1] = G1.ravel()

# Predict corrections deterministically: dropout must be off, otherwise the
# surface is one random dropout sample (earlier code may leave train() on).
model.eval()
with torch.no_grad():
    preds = model(torch.from_numpy(grid).to(DEVICE)).cpu().numpy()[:, 0]

# Undo the ravel: Z[i, j] is the prediction at (G0[i, j], G1[i, j]).
Z = preds.reshape(grid_n, grid_n)

# Contour plot
plt.figure(figsize=(6, 5))
cs = plt.contourf(G0, G1, Z, levels=50, cmap="viridis")
plt.colorbar(cs, label="Correction[0]")
plt.xlabel("Input Dim 0")
plt.ylabel("Input Dim 1")
plt.title("Stability Correction Surface (dim0 vs dim1)")
plt.tight_layout()
plt.savefig("surface_plot.png")
print("Saved surface plot → surface_plot.png")