<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/warp_drive_pinn_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
warp_drive_pinn.py

Physics‐informed net for warp‐drive curvature & energy field optimization.
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# Hyperparameters
# ------------------------------------------------------------------------------
# Network shape
INPUT_DIM: int = 4        # features: [rho, bubble_radius, velocity, init_curvature]
HIDDEN_DIM: int = 32      # width of the hidden layers (fc1, fc2, skip projection)
OUTPUT_DIM: int = 2       # targets: [opt_energy_field, opt_curvature_metric]
DROPOUT_P: float = 0.1    # dropout probability applied after the first hidden layer

# Optimisation / training schedule
LR: float = 1e-3              # AdamW learning rate
WEIGHT_DECAY: float = 1e-5    # AdamW weight decay
BATCH_SIZE: int = 64
EPOCHS: int = 200             # upper bound; early stopping may end training sooner
PATIENCE: int = 15            # epochs without val-MSE improvement before stopping
LAMBDA_PHY: float = 0.5       # weight of the physics residual in the total loss

# ------------------------------------------------------------------------------
# 1. Synthetic Data Generator (Toy Relativity)
# ------------------------------------------------------------------------------
def generate_synthetic_warp_data(n_samples=5000, seed=0):
    """Generate a reproducible toy dataset of warp-bubble inputs and targets.

    Each input column is drawn uniformly:
      rho in [0.1, 10), R in [1, 5), v in [0, 0.9), k_init in [0, 2).
    The targets are closed-form surrogates, not real general relativity:
      energy_field  = rho * R * (1 - v**2)
      curvature_opt = k_init + 0.1 * rho / (R + 1e-3)

    Args:
        n_samples: number of rows to generate (0 yields empty arrays).
        seed: seed for a private ``np.random.RandomState``. The default (0)
            yields exactly the data a global ``np.random.seed(0)`` would, but
            NumPy's global RNG state is left untouched.

    Returns:
        (X, Y): float32 arrays of shape (n_samples, 4) and (n_samples, 2).
    """
    # Private legacy generator: same stream as np.random.seed(seed) followed by
    # np.random.uniform(...) calls, without mutating global state that other
    # code may rely on.
    rng = np.random.RandomState(seed)
    rho    = rng.uniform(0.1, 10.0, n_samples)      # energy density
    R      = rng.uniform(1.0, 5.0,  n_samples)      # bubble radius
    v      = rng.uniform(0.0, 0.9,  n_samples)      # velocity fraction of c
    k_init = rng.uniform(0.0, 2.0,  n_samples)      # initial curvature
    X = np.vstack([rho, R, v, k_init]).T.astype(np.float32)

    # Toy "true" targets: known analytic surrogates (computed in float64, then cast)
    energy_field  = rho * R * (1 - v**2)
    curvature_opt = k_init + 0.1 * rho / (R + 1e-3)
    Y = np.vstack([energy_field, curvature_opt]).T.astype(np.float32)
    return X, Y

# ------------------------------------------------------------------------------
# 2. Model: PhysWarpDrive with Residual Skip & Dropout
# ------------------------------------------------------------------------------
class PhysWarpDrive(nn.Module):
    """Small MLP from INPUT_DIM features to OUTPUT_DIM predictions with an input skip.

    Forward path:
        fc1 -> LayerNorm -> ReLU -> Dropout -> fc2, plus a linear projection of
        the raw input (``skip``) added before the second ReLU, then the linear
        ``out`` head.

    The submodule attribute names below become the ``state_dict`` keys, so
    renaming them would break previously saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Construction order determines the RNG draws used for weight init;
        # keep it stable so seeded runs stay reproducible.
        self.fc1 = nn.Linear(INPUT_DIM, HIDDEN_DIM)
        self.norm = nn.LayerNorm(HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
        self.out = nn.Linear(HIDDEN_DIM, OUTPUT_DIM)
        self.skip = nn.Linear(INPUT_DIM, HIDDEN_DIM)
        self.drop = nn.Dropout(DROPOUT_P)

    def forward(self, x):
        """Return raw (unactivated) predictions for a batch of inputs ``x``."""
        hidden = self.drop(F.relu(self.norm(self.fc1(x))))
        # Residual: second layer output plus a learned projection of the input.
        mixed = self.fc2(hidden) + self.skip(x)
        return self.out(F.relu(mixed))

# ------------------------------------------------------------------------------
# 3. Physics Components & Loss
# ------------------------------------------------------------------------------
def compute_stress_energy(inputs):
    """Toy stress-energy scalar per sample: T = rho * v**2.

    Reads column 0 (rho) and column 2 (velocity) of a 2-D ``inputs`` tensor,
    matching the feature order [rho, bubble_radius, velocity, init_curvature].
    """
    energy_density, velocity = inputs[:, 0], inputs[:, 2]
    return energy_density * velocity ** 2

def compute_einstein_tensor(curv):
    """Toy Einstein 'tensor': the identity map on the predicted curvature.

    The input object itself is returned unchanged (no copy is made).
    """
    einstein = curv
    return einstein

def physics_loss(pred, inputs):
    """Mean-squared residual of the toy field equation G(curvature) = T(inputs).

    Only the second prediction column (``pred[:, 1]``, the curvature output)
    enters the residual; the first (energy-field) column is not constrained here.

    NOTE(review): with the toy definitions in this file, T = rho * v**2, while
    the synthetic curvature target is k_init + 0.1 * rho / (R + 1e-3). This term
    therefore pulls against the data MSE rather than agreeing with it — confirm
    this is intended.
    """
    stress_energy = compute_stress_energy(inputs)
    einstein = compute_einstein_tensor(pred[:, 1])
    return F.mse_loss(einstein, stress_energy)

def total_loss(pred, true, inputs, λ=LAMBDA_PHY):
    """Combine data fit and physics residual into the training objective.

    Args:
        pred: model outputs.
        true: regression targets with the same shape as ``pred``.
        inputs: the batch of input features (used by the physics term).
        λ: weight of the physics residual (keyword name kept for callers).

    Returns:
        (combined, data_mse, phys) where combined = data_mse + λ * phys.
    """
    data_mse = F.mse_loss(pred, true)
    residual = physics_loss(pred, inputs)
    combined = data_mse + λ * residual
    return combined, data_mse, residual

# ------------------------------------------------------------------------------
# 4. Data Preparation
# ------------------------------------------------------------------------------
def get_dataloaders(val_frac=0.2, seed=0):
    """Build train / validation DataLoaders over the synthetic warp dataset.

    Args:
        val_frac: fraction of samples held out for validation; it must leave
            both splits non-empty.
        seed: seed for the ``torch.Generator`` that drives ``random_split``, so
            the train/val partition is reproducible across runs (the data itself
            is already generated with a fixed seed).

    Returns:
        (train_loader, val_loader), both with batch size ``BATCH_SIZE``; only the
        training loader shuffles.

    Raises:
        ValueError: if ``val_frac`` would leave the train or validation split empty.
    """
    X, Y = generate_synthetic_warp_data()
    ds = TensorDataset(torch.from_numpy(X), torch.from_numpy(Y))
    n_val = int(len(ds) * val_frac)
    n_tr = len(ds) - n_val
    # An empty split yields a loader over zero samples, which breaks any
    # per-sample averaging of the losses downstream.
    if n_val <= 0 or n_tr <= 0:
        raise ValueError(
            f"val_frac={val_frac} gives {n_tr} train / {n_val} val samples; "
            "both splits must be non-empty"
        )
    # Dedicated generator: reproducible split without depending on (or
    # advancing) the global torch RNG.
    split_gen = torch.Generator().manual_seed(seed)
    tr_ds, val_ds = random_split(ds, [n_tr, n_val], generator=split_gen)
    tr_ld = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True)
    va_ld = DataLoader(val_ds, batch_size=BATCH_SIZE)
    return tr_ld, va_ld

# ------------------------------------------------------------------------------
# 5. Training Loop with Scheduler & Early Stopping
# ------------------------------------------------------------------------------
def _train_one_epoch(model, loader, opt, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean data MSE."""
    model.train()
    tot = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        pred = model(xb)
        loss, mse, _ = total_loss(pred, yb, xb)  # optimise data + physics
        loss.backward()
        opt.step()
        tot += mse.item() * xb.size(0)  # weight by batch size (last batch may be short)
    return tot / len(loader.dataset)


def _evaluate(model, loader, device):
    """Return (per-sample mean data MSE, per-sample mean physics residual) without gradients."""
    model.eval()
    tot_mse = tot_phy = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            _, mse_v, phy_v = total_loss(pred, yb, xb)
            tot_mse += mse_v.item() * xb.size(0)
            tot_phy += phy_v.item() * xb.size(0)
    n = len(loader.dataset)
    return tot_mse / n, tot_phy / n


def train(model, tr_ld, va_ld, device, ckpt_path='best_warp_pinn.pt'):
    """Train ``model`` with AdamW, LR-on-plateau scheduling and early stopping.

    Model selection uses the validation data MSE only. Each new best state is
    written to ``ckpt_path`` and also kept in memory; at the end the in-memory
    best weights are restored into ``model``, so a stale checkpoint file from an
    earlier run can never be loaded by mistake.

    Args:
        model: the network to train (modified in place).
        tr_ld: training DataLoader yielding (inputs, targets).
        va_ld: validation DataLoader yielding (inputs, targets).
        device: device the batches are moved to.
        ckpt_path: where the best ``state_dict`` is saved.

    Returns:
        dict with per-epoch lists 'train_mse', 'val_mse' and 'val_phys'.
    """
    opt       = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sched     = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5, patience=5)
    best_val  = float('inf')
    best_state = None
    patience  = 0
    history   = {'train_mse': [], 'val_mse': [], 'val_phys': []}

    for ep in range(1, EPOCHS + 1):
        tr_mse = _train_one_epoch(model, tr_ld, opt, device)
        val_mse, val_phy = _evaluate(model, va_ld, device)

        sched.step(val_mse)
        history['train_mse'].append(tr_mse)
        history['val_mse'].append(val_mse)
        history['val_phys'].append(val_phy)

        # A NaN val_mse never compares as an improvement, so it only burns patience.
        if val_mse < best_val:
            best_val = val_mse
            # Detached clones: later optimiser steps must not mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
            patience = 0
        else:
            patience += 1
            if patience >= PATIENCE:
                print(f"Early stopping at epoch {ep}")
                break

        if ep % 10 == 0 or ep == 1:
            print(f"Epoch {ep:03d} | Train MSE {tr_mse:.4f} | Val MSE {val_mse:.4f} | Val Phys {val_phy:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: validation MSE never improved (non-finite losses?); keeping final weights")
    return history

# ------------------------------------------------------------------------------
# 6. Visualization
# ------------------------------------------------------------------------------
def plot_history(hist):
    """Plot per-epoch train MSE, validation MSE and validation physics residual.

    ``hist`` is the dict returned by ``train``; all three lists are plotted
    against 1-based epoch numbers taken from the length of 'train_mse'.
    """
    xs = range(1, len(hist['train_mse']) + 1)
    series = (
        ('train_mse', 'Train MSE'),
        ('val_mse', 'Val MSE'),
        ('val_phys', 'Val Physics Residual'),
    )
    plt.figure(figsize=(8, 4))
    for key, label in series:
        plt.plot(xs, hist[key], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.tight_layout()
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Seed torch as well (weight init, DataLoader shuffling, dropout) so runs are
    # repeatable like the NumPy data generator, which already uses seed 0.
    # GPU kernels may still introduce small nondeterminism.
    torch.manual_seed(0)
    device     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tr_loader, va_loader = get_dataloaders()
    model      = PhysWarpDrive().to(device)
    history    = train(model, tr_loader, va_loader, device)
    print("Training complete. Best validation saved to best_warp_pinn.pt")
    plot_history(history)