<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/Enhanced_Wormhole_Stability_AI_with_Physics_Informed_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# ----------------------------------------
# 1. MODEL DEFINITION
# ----------------------------------------
class WormholeStabilityAI(nn.Module):
    """Small two-layer MLP that predicts stability correction factors.

    Architecture: ``Linear(input_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> output_dim)``. No activation is applied to the
    final layer, so outputs are unbounded real values.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the single hidden layer.
        output_dim: number of output channels.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # The attribute names (fc1 / relu / fc2) and their creation order are
        # load-bearing: they define the state_dict keys and the order in which
        # the layers draw their random initial weights.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a batch ``x`` to raw stability correction factors."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

# ----------------------------------------
# 2. PHYSICS-INFORMED LOSS
# ----------------------------------------
def curvature_residual(pred, params, *, channel=0, scale=0.1):
    """Return the mean squared physics residual R_predicted - R_true(params).

    This is a placeholder physics term: the "true" Ricci scalar is faked as
    ``scale * sum(params**2)`` per sample. Replace it with your wormhole
    curvature equations.

    Args:
        pred: model output. Column ``channel`` is read as the Ricci-scalar
            estimate. Assumed 2-D ``(batch, n_outputs)``; TODO confirm for
            other callers.
        params: metric parameters, squared and summed over dim 1. Assumed
            2-D ``(batch, n_params)`` with the same batch size as ``pred``.
        channel: output column holding the Ricci estimate (default 0, the
            original hard-coded choice).
        scale: multiplier of the fake ground-truth Ricci scalar (default 0.1,
            the original hard-coded value).

    Returns:
        0-d tensor: the batch mean of the squared residual.
    """
    ricci_est = pred[:, channel]
    # Fake ground-truth Ricci scalar from the metric params.
    ricci_true = params.pow(2).sum(dim=1) * scale
    return (ricci_est - ricci_true).pow(2).mean()

# ----------------------------------------
# 3. DATASET (DUMMY)
# ----------------------------------------
class DummyWormholeDataset(torch.utils.data.Dataset):
    """Synthetic dataset of random inputs and random targets for smoke tests.

    Inputs are standard-normal with shape ``(N, input_dim)``. Targets are
    normal with standard deviation 0.5 and shape ``(N, output_dim)``. The two
    are drawn independently, so the targets carry no information about the
    inputs.

    Args:
        N: number of samples.
        input_dim: features per input sample.
        output_dim: channels per target. The default of 3 is the original
            layout: two correction factors plus one curvature estimate.
        generator: optional ``torch.Generator`` for reproducible data. When
            None (default), the global torch RNG is used, as before.
    """

    def __init__(self, N=1000, input_dim=5, output_dim=3, generator=None):
        super().__init__()
        self.X = torch.randn(N, input_dim, generator=generator)
        # fake targets (for output_dim=3: two correction factors plus one curvature estimate)
        self.Y = torch.randn(N, output_dim, generator=generator) * 0.5

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at row ``idx``."""
        return self.X[idx], self.Y[idx]

# ----------------------------------------
# 4. TRAINING SETUP
# ----------------------------------------
# Hyperparameters
input_dim, hidden_dim, output_dim = 5, 32, 3
lr, epochs, lambda_phys = 1e-3, 200, 0.5
batch_size, n_samples, seed = 64, 2000, 0

# Seed the global torch RNG before anything random happens. Weight
# initialisation, the dummy data and the DataLoader shuffle order then repeat
# from run to run.
torch.manual_seed(seed)

# Model, optimizer, dataloader
model = WormholeStabilityAI(input_dim, hidden_dim, output_dim)
opt = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()  # supervised term, averaged over all output channels

dataset = DummyWormholeDataset(N=n_samples, input_dim=input_dim)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ----------------------------------------
# 5. TRAINING LOOP
# ----------------------------------------
train_losses, phys_losses = [], []  # per-epoch averages of the two loss terms
total_losses = []                   # per-epoch average of the optimised objective

for epoch in range(1, epochs + 1):
    # Explicit train mode: a no-op for the current Linear/ReLU model, but
    # required once dropout or batch-norm layers are added.
    model.train()
    running_loss, running_phys = 0.0, 0.0  # running_loss accumulates the MSE term only

    for Xb, yb in loader:
        opt.zero_grad()

        pred = model(Xb)
        mse_loss = criterion(pred, yb)
        phys_loss = curvature_residual(pred, Xb)
        # NOTE(review): with DummyWormholeDataset, yb[:, 0] is random while the
        # physics term targets a function of Xb on the same channel. The two
        # terms therefore compete on channel 0; confirm this is intended once
        # real data is used.

        loss = mse_loss + lambda_phys * phys_loss
        loss.backward()
        opt.step()

        # Weight each batch mean by its size so the epoch average is exact
        # even when the final batch is smaller than the others.
        running_loss += mse_loss.item() * Xb.size(0)
        running_phys += phys_loss.item() * Xb.size(0)

    # logging
    avg_mse = running_loss / len(dataset)
    avg_phys = running_phys / len(dataset)
    # The objective is linear in the two terms, so its epoch average follows directly.
    avg_total = avg_mse + lambda_phys * avg_phys
    train_losses.append(avg_mse)
    phys_losses.append(avg_phys)
    total_losses.append(avg_total)

    if epoch % 20 == 0 or epoch == 1:
        print(
            f"Epoch {epoch:03d} | MSE: {avg_mse:.4f} | PhysRes: {avg_phys:.4f} "
            f"| Total: {avg_total:.4f}"
        )

# ----------------------------------------
# 6. VISUALIZATION
# ----------------------------------------
# Per-epoch loss curves. The training loop numbers epochs from 1, so the
# x-axis starts at 1 too (plotting the bare list would start at 0).
epoch_axis = range(1, len(train_losses) + 1)
plt.figure(figsize=(8, 4))
plt.plot(epoch_axis, train_losses, label="MSE Loss")
plt.plot(epoch_axis, phys_losses, label="Physics Residual")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training Curves")
plt.grid(True)
plt.tight_layout()
plt.show()

# Sample inference on one shuffled batch.
# NOTE(review): this batch comes from the training data, so the scatter shows
# fit to seen samples, not generalisation.
model.eval()  # inference mode; matters once dropout/batch-norm layers exist
X_sample, Y_sample = next(iter(loader))
with torch.no_grad():
    Y_pred = model(X_sample)

# Convert to NumPy once so matplotlib and min/max work on plain arrays
# instead of 0-d tensors.
true_ch0 = Y_sample[:, 0].numpy()
pred_ch0 = Y_pred[:, 0].numpy()

plt.figure(figsize=(6, 6))
plt.scatter(true_ch0, pred_ch0, s=20, alpha=0.5)
# Identity line over the combined value range: points on it are exact predictions.
lims = [float(min(true_ch0.min(), pred_ch0.min())),
        float(max(true_ch0.max(), pred_ch0.max()))]
plt.plot(lims, lims, 'k--')
plt.xlabel("True Curvature Correction")
plt.ylabel("Predicted")
plt.title("Channel 0: True vs Pred")
plt.grid(True)
plt.tight_layout()
plt.show()