<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/Implementation_AI_Stabilized_Wormhole_Geometry.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# ------------------------------------------------------------------------------
# 1. Define your WormholeAI (as given)
# ------------------------------------------------------------------------------
class WormholeAI(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Maps ``input_dim`` input features to ``output_dim`` outputs through a
    single hidden layer of width ``hidden_dim``.

    NOTE: the submodule attribute names (``fc1``, ``relu``, ``fc2``) become the
    keys of ``state_dict()``; keep them stable so saved checkpoints still load.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # fc1 is constructed before fc2 so weight initialisation consumes the
        # global RNG in the same order as before.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

# ------------------------------------------------------------------------------
# 2. Synthetic data generator
#    Replace this with your real curvature → stability mapping
# ------------------------------------------------------------------------------
def generate_synthetic_data(n_samples, input_dim, output_dim, noise_std=0.1, generator=None):
    """Generate a toy regression dataset.

    Each input row is drawn from a standard normal distribution. Every output
    column's target is ``sin(sum of the row's inputs)`` plus independent
    Gaussian noise, so all ``output_dim`` targets share one underlying signal.

    Args:
        n_samples: Number of rows to generate.
        input_dim: Number of input features per row.
        output_dim: Number of target columns per row.
        noise_std: Standard deviation of the additive Gaussian noise. Defaults
            to 0.1 (the previously hard-coded value); 0 gives noise-free
            targets.
        generator: Optional ``torch.Generator`` used for every random draw.
            Pass a seeded generator for a reproducible dataset that does not
            touch the global RNG; ``None`` uses the global default RNG.

    Returns:
        Tuple ``(X, y)`` with shapes ``(n_samples, input_dim)`` and
        ``(n_samples, output_dim)``.

    Raises:
        ValueError: If ``noise_std`` is negative.
    """
    if noise_std < 0:
        raise ValueError(f"noise_std must be non-negative, got {noise_std}")
    # Draw order (inputs first, then noise) matches the original, so results
    # under the default arguments and the same seed are unchanged.
    X = torch.randn(n_samples, input_dim, generator=generator)
    # keepdim=True leaves the signal as (n_samples, 1) so it can be tiled
    # across every output column.
    signal = torch.sin(X.sum(dim=1, keepdim=True))
    noise = noise_std * torch.randn(n_samples, output_dim, generator=generator)
    y = signal.repeat(1, output_dim) + noise
    return X, y

# ------------------------------------------------------------------------------
# 3. Hyperparameters & DataLoader
# ------------------------------------------------------------------------------
input_dim, hidden_dim, output_dim = 5, 16, 3
batch_size, lr, epochs = 64, 1e-3, 100
seed = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNG before any random draw. The synthetic data, the
# nn.Linear weight initialisation inside WormholeAI and the training
# DataLoader's shuffle order all draw from it, so a fixed seed makes runs (and
# the reported "Best Val MSE") reproducible and comparable.
# torch.manual_seed also seeds every CUDA device.
torch.manual_seed(seed)

# generate train/val splits (two independent draws from the same distribution)
X_train, y_train = generate_synthetic_data(8000, input_dim, output_dim)
X_val,   y_val   = generate_synthetic_data(2000, input_dim, output_dim)

train_ds = TensorDataset(X_train, y_train)
val_ds   = TensorDataset(X_val,   y_val)

# Only the training set is shuffled; validation order does not change the
# averaged validation loss.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size)

# ------------------------------------------------------------------------------
# 4. Model, Loss, Optimizer, Scheduler
# ------------------------------------------------------------------------------
model     = WormholeAI(input_dim, hidden_dim, output_dim).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
# Halve the learning rate once the monitored value (lower is better, mode="min")
# has stopped improving for more than `patience` consecutive steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                 factor=0.5, patience=10)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
# Lowest validation MSE seen so far; its weights are checkpointed to disk.
best_val = float("inf")

for epoch in range(1, epochs + 1):
    # Training pass. criterion returns a per-batch mean, so weighting it by
    # the batch size and dividing by the dataset size gives a dataset mean.
    model.train()
    running = 0.0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item() * inputs.size(0)
    train_loss = running / len(train_loader.dataset)

    # Validation pass: no gradient tracking, same size-weighted averaging.
    model.eval()
    running = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            running += criterion(model(inputs), targets).item() * inputs.size(0)
    val_loss = running / len(val_loader.dataset)

    # The plateau scheduler watches validation loss; checkpoint on improvement.
    scheduler.step(val_loss)
    improved = val_loss < best_val
    if improved:
        best_val = val_loss
        torch.save(model.state_dict(), "best_wormhole_ai.pt")

    print(f"Epoch {epoch:03d} | Train MSE: {train_loss:.4f} | Val MSE: {val_loss:.4f}")

print("\nTraining complete. Best Val MSE:", best_val)

# ------------------------------------------------------------------------------
# 6. Loading & Inference Example
# ------------------------------------------------------------------------------
# Restore the best checkpoint written by the training loop.
# map_location remaps stored tensors onto the current `device`, so a checkpoint
# saved on a GPU also loads on a CPU-only host. weights_only=True limits
# unpickling to tensors and plain containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load("best_wormhole_ai.pt", map_location=device, weights_only=True)
)
model.eval()

# Run inference with autograd disabled. Under torch.no_grad() the output has
# requires_grad=False, so it converts to NumPy directly without .detach(), and
# no computation graph is built for the forward pass.
with torch.no_grad():
    sample = torch.randn(1, input_dim, device=device)
    output = model(sample).cpu().numpy()
print("Stability corrections:", output)