<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_wormhole_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# train_wormhole_ai.py

import math
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# ------------------------------------------------------------------------------
# 1. Define the WormholeAI model
# ------------------------------------------------------------------------------
class WormholeAI(nn.Module):
    """Small fully-connected regressor: Linear -> ReLU -> Linear.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the single hidden layer.
        output_dim: number of values predicted per sample.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        # Submodules are registered in the order fc1, relu, fc2 so that random
        # weight initialisation order and state_dict keys stay the same.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` with trailing dimension ``input_dim`` to ``output_dim`` values."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

# ------------------------------------------------------------------------------
# 2. Synthetic data generator
#    Replace this with your real curvature→stability dataset
# ------------------------------------------------------------------------------
def generate_synthetic_data(n_samples, input_dim, output_dim):
    """Draw a toy regression dataset.

    Features are standard-normal with shape (n_samples, input_dim). Every one
    of the ``output_dim`` target columns equals sin(sum of the row's features)
    plus independent Gaussian noise with standard deviation 0.1.

    Returns:
        (features, targets) tensors of shape (n_samples, input_dim) and
        (n_samples, output_dim).
    """
    features = torch.randn(n_samples, input_dim)
    signal = torch.sin(features.sum(dim=1, keepdim=True))  # (n_samples, 1)
    noise = torch.randn(n_samples, output_dim)
    # The (n_samples, 1) signal broadcasts across all output columns.
    targets = signal + 0.1 * noise
    return features, targets

# ------------------------------------------------------------------------------
# 3. Hyperparameters & device
# ------------------------------------------------------------------------------
INPUT_DIM   = 5    # number of wormhole metric coefficients
HIDDEN_DIM  = 16   # width of the single hidden layer
OUTPUT_DIM  = 3    # number of stability corrections
BATCH_SIZE  = 64
LR          = 1e-3
MAX_EPOCHS  = 100
SEED        = 42   # global RNG seed for reproducible data, init and shuffling
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT  = "best_wormhole_ai.pt"

# Seed torch's RNGs (CPU and, if present, CUDA) before any data is generated or
# weights are initialised, so repeated runs give the same results. Some CUDA
# kernels may still be nondeterministic on GPU.
torch.manual_seed(SEED)

# ------------------------------------------------------------------------------
# 4. Prepare data loaders
# ------------------------------------------------------------------------------
# Train / validation split: two independent draws from the same generator
# (training set drawn first, then validation set).
X_train, y_train = generate_synthetic_data(8000, INPUT_DIM, OUTPUT_DIM)
X_val, y_val = generate_synthetic_data(2000, INPUT_DIM, OUTPUT_DIM)

train_ds, val_ds = TensorDataset(X_train, y_train), TensorDataset(X_val, y_val)

# Only the training set is shuffled; the validation set is iterated in order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# ------------------------------------------------------------------------------
# 5. Instantiate model, loss, optimizer, scheduler
# ------------------------------------------------------------------------------
model = WormholeAI(
    input_dim=INPUT_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
).to(DEVICE)
criterion = nn.MSELoss()  # mean squared error, averaged over all elements
optimizer = optim.Adam(params=model.parameters(), lr=LR)
# Halve the learning rate once the validation loss has stopped improving for
# `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10)

# ------------------------------------------------------------------------------
# 6. Inference before training (random init)
# ------------------------------------------------------------------------------
# One random input, reused after training so the before/after outputs compare
# the same point. Drawn on CPU and then moved, matching the data generator's RNG.
sample = torch.randn(1, INPUT_DIM).to(DEVICE)
model.eval()
with torch.no_grad():
    # Under no_grad the output carries no autograd graph, so .numpy() works
    # without an explicit detach().
    rand_pred = model(sample)
    rand_out = rand_pred.cpu().numpy()
print("Random-init corrections:", rand_out)

# ------------------------------------------------------------------------------
# 7. Training loop with best-model checkpointing
# ------------------------------------------------------------------------------
best_val_mse = float("inf")
best_epoch = 0  # 0 means no checkpoint has been written by this run

for epoch in range(1, MAX_EPOCHS + 1):
    # — Train —
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        # criterion returns a per-batch mean; weight it by batch size so the
        # epoch value is a per-sample mean even when the last batch is smaller.
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)

    # — Validate —
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            val_loss += criterion(model(xb), yb).item() * xb.size(0)
    val_loss /= len(val_loader.dataset)

    # A NaN/inf loss never compares below best_val_mse, so training would keep
    # running without ever checkpointing. Stop as soon as it diverges.
    if not (math.isfinite(train_loss) and math.isfinite(val_loss)):
        print(
            f"Epoch {epoch:03d} | non-finite loss "
            f"(train={train_loss}, val={val_loss}); stopping early."
        )
        break

    # LR scheduler step
    scheduler.step(val_loss)

    # Checkpoint best model
    if val_loss < best_val_mse:
        best_val_mse = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), CHECKPOINT)

    print(
        f"Epoch {epoch:03d} | "
        f"Train MSE: {train_loss:.4f} | "
        f"Val MSE: {val_loss:.4f}"
    )

if best_epoch:
    print(f"\nTraining complete. Best Val MSE: {best_val_mse:.6f} (epoch {best_epoch})")
else:
    print("\nTraining complete. No finite validation loss; no checkpoint was saved.")

# ------------------------------------------------------------------------------
# 8. Inference after training (loaded checkpoint)
# ------------------------------------------------------------------------------
# Restore the best weights saved during training.
# - map_location: a checkpoint written on another device (e.g. GPU) loads on
#   the current DEVICE.
# - weights_only=True: limits unpickling to tensors and plain containers, which
#   is all a state_dict needs.
# NOTE(review): if this run never saved a checkpoint, an existing file at
# CHECKPOINT may come from a previous run.
if os.path.exists(CHECKPOINT):
    model.load_state_dict(torch.load(CHECKPOINT, map_location=DEVICE, weights_only=True))
else:
    print(f"No checkpoint found at {CHECKPOINT!r}; using the final in-memory weights.")
model.eval()
with torch.no_grad():
    # No autograd graph under no_grad, so .numpy() needs no detach().
    best_out = model(sample).cpu().numpy()
print("Post-training corrections:", best_out)

# ------------------------------------------------------------------------------
# 9. (Optional) Batch inference example
# ------------------------------------------------------------------------------
# B = 10
# batch_samples = torch.randn(B, INPUT_DIM).to(DEVICE)
# with torch.no_grad():
#     batch_out = model(batch_samples).cpu().numpy()
# print("Batch corrections:", batch_out)

# ------------------------------------------------------------------------------
# 10. Suggested next steps
# ------------------------------------------------------------------------------
# • Swap in real wormhole curvature→stability data
# • Add physics-informed residual losses (Einstein equations, energy conditions)
# • Quantify uncertainty via Monte Carlo Dropout or Gaussian NLL head
# • Visualize hidden representations (UMAP, t-SNE)
# ------------------------------------------------------------------------------