<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/space_time_lattice_pinn_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
space_time_lattice_pinn.py

A self-contained script for training and evaluating the physics-informed SpaceTimeLatticeAI model:

 1. Generates synthetic lattice data
 2. Defines PINN with a toy determinant–sum constraint
 3. Trains with AdamW, ReduceLROnPlateau, early stopping
 4. Quantifies uncertainty via MC-Dropout
 5. Plots training curves and true vs predicted scatter plots
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Dataset
# ------------------------------------------------------------------------------
def true_lattice_deformation(x: torch.Tensor) -> torch.Tensor:
    """Return the noise-free synthetic target for a batch of inputs.

    Each of the first three features is scaled by their row-wise sum:
    ``y[:, i] = x[:, i] * (x[:, 0] + x[:, 1] + x[:, 2])`` for i in {0, 1, 2}.
    Columns beyond the third are ignored.
    """
    head = x[:, :3]
    row_sum = head.sum(dim=1, keepdim=True)  # [batch, 1]; broadcasts over the 3 columns
    return head * row_sum

class SyntheticLatticeDataset(Dataset):
    """Normalized synthetic (X, Y) pairs generated by ``true_lattice_deformation``.

    X_raw is drawn uniformly from [-1, 1]^5 and
    Y_raw = true_lattice_deformation(X_raw) + noise_std * N(0, 1).
    Both are z-scored per column; the means/stds are kept in ``self.stats`` so
    normalized predictions can be mapped back to raw units.

    Note: the constructor calls ``torch.manual_seed(seed)``, which seeds torch's
    global RNG and therefore also affects any random draws made afterwards.
    """

    def __init__(self, n_samples=5000, seed=42, noise_std=0.05):
        """
        Args:
            n_samples: number of samples to generate; must be >= 2.
            seed: value passed to ``torch.manual_seed`` before sampling.
            noise_std: std of the Gaussian noise added to the targets.

        Raises:
            ValueError: if ``n_samples < 2``.
        """
        super().__init__()
        if n_samples < 2:
            # The per-column std of fewer than two samples is NaN, which would
            # turn every normalized value into NaN.
            raise ValueError(f"n_samples must be >= 2, got {n_samples}")
        torch.manual_seed(seed)
        X_raw = torch.rand(n_samples, 5) * 2 - 1  # uniform in [-1,1]
        Y_raw = true_lattice_deformation(X_raw) + noise_std * torch.randn(n_samples, 3)

        # per-column means and stds of the raw data
        self.stats = {
            'X_mean': X_raw.mean(0), 'X_std': X_raw.std(0),
            'Y_mean': Y_raw.mean(0), 'Y_std': Y_raw.std(0),
        }

        # z-score normalization
        self.X = (X_raw - self.stats['X_mean']) / self.stats['X_std']
        self.Y = (Y_raw - self.stats['Y_mean']) / self.stats['Y_std']

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the normalized ``(x, y)`` pair at ``idx``."""
        return self.X[idx], self.Y[idx]

# ------------------------------------------------------------------------------
# 2. Model Definition
# ------------------------------------------------------------------------------
class SpaceTimeLatticeAI(nn.Module):
    """Feed-forward MLP mapping lattice inputs to deformation outputs.

    Every hidden stage is Linear -> LayerNorm -> ReLU -> Dropout(p_drop); a final
    Linear layer projects to ``output_dim``. All layers live in ``self.net``.
    """

    def __init__(self, input_dim=5, hidden_dims=(64, 64), output_dim=3, p_drop=0.1):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ))
        # Output head: from the last hidden width (or input_dim if no hidden layers).
        stages.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` (its last dimension must equal ``input_dim``)."""
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Loss
# ------------------------------------------------------------------------------
def physics_residual(pred: torch.Tensor,
                     inp: torch.Tensor,
                     stats: dict) -> torch.Tensor:
    """
    Enforce det(diag(pred_denorm)) ≈ (sum(inp_denorm[:3]))^2.

    ``pred`` and ``inp`` are normalized; both are mapped back to raw units with
    ``stats`` before the constraint is evaluated.

    Args:
        pred:  normalized predictions, shape [batch, 3].
        inp:   normalized inputs, shape [batch, D] with D == len(stats['X_mean']);
               only columns 0-2 enter the constraint.
        stats: mapping with 'X_mean', 'X_std', 'Y_mean', 'Y_std'. Values may be
               tensors on any device (or array-likes); they are moved to the
               device of the operand they are combined with.

    Returns:
        Scalar tensor: mean squared error between y0*y1*y2 and (x0 + x1 + x2)**2.

    NOTE(review): for targets produced by ``true_lattice_deformation``
    (y_i = x_i * S with S = x0 + x1 + x2), y0*y1*y2 = x0*x1*x2*S**3, not S**2, so
    this toy penalty is non-zero on the true mapping and pulls predictions away
    from it. Confirm the constraint is intended before tuning ``lambda_phys``.
    """
    def _stat_on(ref: torch.Tensor, key: str) -> torch.Tensor:
        # Stats may live on another device than the batch (e.g. CPU stats with a
        # CUDA batch); 1-D tensors on different devices cannot be combined.
        return torch.as_tensor(stats[key], device=ref.device)

    # denormalize
    y_den = pred * _stat_on(pred, 'Y_std') + _stat_on(pred, 'Y_mean')   # [batch, 3]
    x_den = inp * _stat_on(inp, 'X_std') + _stat_on(inp, 'X_mean')      # [batch, D]

    # determinant of a diagonal matrix = product of its diagonal entries
    det_pred = y_den[:, 0] * y_den[:, 1] * y_den[:, 2]
    target = (x_den[:, 0] + x_den[:, 1] + x_den[:, 2]).pow(2)

    return nn.functional.mse_loss(det_pred, target)

def total_loss(pred: torch.Tensor,
               true: torch.Tensor,
               inp: torch.Tensor,
               stats: dict,
               lambda_phys: float = 0.5):
    """Combine the data-fit MSE with the weighted physics residual.

    Args:
        pred: normalized model predictions.
        true: normalized targets, same shape as ``pred``.
        inp: normalized inputs forwarded to ``physics_residual``.
        stats: normalization statistics forwarded to ``physics_residual``.
        lambda_phys: weight of the physics term.

    Returns:
        Tuple ``(total, mse, phys)`` with ``total = mse + lambda_phys * phys``.
    """
    data_term = nn.functional.mse_loss(pred, true)
    physics_term = physics_residual(pred, inp, stats)
    combined = data_term + lambda_phys * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Prediction
# ------------------------------------------------------------------------------
def mc_dropout_predict(model: nn.Module,
                       x: torch.Tensor,
                       n_samples: int = 30) -> "tuple[np.ndarray, np.ndarray]":
    """Estimate predictive mean/std with ``n_samples`` dropout-active forward passes.

    Only dropout layers are switched to train mode; every other submodule (e.g.
    normalization layers with running statistics) runs in eval mode. The mode of
    every submodule is restored afterwards, even if a forward pass raises.

    Args:
        model: network to sample; without dropout layers the passes are
            typically identical and the std is zero.
        x: input batch, passed to ``model`` unchanged (so it must already be on
            the model's device).
        n_samples: number of stochastic passes; must be >= 1.

    Returns:
        ``(mean, std)`` NumPy arrays over the sample axis, each shaped like ``model(x)``.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Remember each submodule's mode so the caller's model is left as it was.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.modules.dropout._DropoutNd):
            module.train()  # enable dropout only
    try:
        with torch.no_grad():
            preds = [model(x).cpu().numpy() for _ in range(n_samples)]
    finally:
        for module, mode in saved_modes:
            module.training = mode

    arr = np.stack(preds, axis=0)      # [samples, batch, output_dim]
    return arr.mean(axis=0), arr.std(axis=0)

# ------------------------------------------------------------------------------
# 5. Training Routine
# ------------------------------------------------------------------------------
def _run_epoch(model: nn.Module,
               loader: DataLoader,
               stats: dict,
               device: torch.device,
               lambda_phys: float,
               optimizer: "optim.Optimizer | None" = None) -> float:
    """Run one pass over ``loader`` and return the sample-weighted mean total loss.

    With an ``optimizer`` the model is put in train mode and updated per batch;
    without one it runs in eval mode with gradient tracking disabled.
    Returns NaN if the loader yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, seen = 0.0, 0
    with torch.set_grad_enabled(is_training):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss, _, _ = total_loss(model(xb), yb, xb, stats, lambda_phys)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_sum += loss.item() * xb.size(0)
            seen += xb.size(0)
    # Divide by samples actually seen (correct even with drop_last=True).
    return loss_sum / seen if seen else float('nan')


def train(model: nn.Module,
          train_loader: DataLoader,
          val_loader: DataLoader,
          stats: dict,
          device: torch.device,
          lr: float = 1e-3,
          weight_decay: float = 1e-5,
          lambda_phys: float = 0.5,
          max_epochs: int = 100,
          patience: int = 10,
          ckpt_path: str = "best_lattice_ai.pth"):
    """Fit ``model`` with AdamW, ReduceLROnPlateau and early stopping on val loss.

    Each epoch runs one optimization pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``, both scored with ``total_loss``.
    Whenever the validation loss improves by more than 1e-6 the weights are
    snapshotted in memory and written to ``ckpt_path``. Training stops after
    ``patience`` epochs without improvement. At the end the best snapshot (if
    any epoch improved) is loaded back into ``model``. When at least one epoch
    ran, the model is left in eval mode.

    Args:
        model: network to train (moved to ``device``).
        train_loader / val_loader: loaders yielding normalized ``(x, y)`` batches.
        stats: normalization statistics for the physics term; copied to ``device``.
        device: device used for training.
        lr, weight_decay: AdamW hyper-parameters.
        lambda_phys: weight of the physics residual in the loss.
        max_epochs: upper bound on the number of epochs.
        patience: epochs without improvement before stopping early.
        ckpt_path: file the best ``state_dict`` is saved to on each improvement.

    Returns:
        dict with 'train' and 'val' lists of per-epoch mean losses.
    """
    model.to(device)
    # Move the normalization stats to the training device once, instead of
    # mixing CPU stats with device batches inside the loss on every step.
    stats_dev = {k: torch.as_tensor(v, device=device) for k, v in stats.items()}

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    best_val = float('inf')
    best_state = None
    wait = 0
    history = {'train': [], 'val': []}

    for epoch in range(1, max_epochs + 1):
        train_loss = _run_epoch(model, train_loader, stats_dev, device,
                                lambda_phys, optimizer)
        val_loss = _run_epoch(model, val_loader, stats_dev, device, lambda_phys)

        scheduler.step(val_loss)
        history['train'].append(train_loss)
        history['val'].append(val_loss)
        print(f"Epoch {epoch:03d} | Train: {train_loss:.4e} | Val: {val_loss:.4e}")

        # Early stopping
        if val_loss < best_val - 1e-6:
            best_val = val_loss
            wait = 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    # Restore from memory rather than re-reading ckpt_path: if no epoch improved
    # (e.g. NaN losses or max_epochs == 0) the file would be missing or stale
    # from an earlier run.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_history(history: dict):
    """Plot the per-epoch train and validation loss curves in a new figure.

    Args:
        history: dict with 'train' and 'val' lists of per-epoch losses
            (the structure returned by ``train``).

    The x axis is numbered from 1 so it lines up with the "Epoch NNN" lines
    printed during training.
    """
    plt.figure()
    for key, label in (('train', 'Train Loss'), ('val', 'Val Loss')):
        losses = history[key]
        plt.plot(range(1, len(losses) + 1), losses, label=label)
    plt.xlabel('Epoch'); plt.ylabel('Loss')
    plt.legend(); plt.title('Training Curve')
    plt.tight_layout(); plt.show()

def plot_scatter(true_vals: np.ndarray, pred_vals: np.ndarray, name: str):
    """Scatter predictions against ground truth with a dashed y = x guide.

    Args:
        true_vals: ground-truth values (x axis).
        pred_vals: predicted values (y axis), same length as ``true_vals``.
        name: quantity label used in the axis labels and title.
    """
    plt.figure()
    plt.scatter(true_vals, pred_vals, s=5, alpha=0.5)
    # Perfect-prediction reference spanning the range of the true values.
    lo = true_vals.min()
    hi = true_vals.max()
    plt.plot([lo, hi], [lo, hi], 'r--')
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Pred {name}")
    plt.title(f"{name}: True vs. Pred")
    plt.tight_layout()
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data. The dataset constructor seeds torch's global RNG, so the
    # random_split below is reproducible across runs.
    dataset = SyntheticLatticeDataset(n_samples=8000)
    stats = dataset.stats

    n_val = int(0.2 * len(dataset))
    n_trn = len(dataset) - n_val
    trn_ds, val_ds = random_split(dataset, [n_trn, n_val])
    trn_loader = DataLoader(trn_ds, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=256)

    # Build and train model
    model = SpaceTimeLatticeAI().to(device)
    history = train(
        model, trn_loader, val_loader, stats, device,
        lr=1e-3, weight_decay=1e-5, lambda_phys=0.5,
        max_epochs=100, patience=10
    )

    # Plot training history
    plot_history(history)

    # Held-out evaluation: score only the validation split, because the full
    # dataset also contains the samples the model was fitted on.
    model.eval()  # dropout off -> deterministic point predictions
    val_idx = torch.as_tensor(val_ds.indices, dtype=torch.long)
    X_val = dataset.X[val_idx].to(device)
    with torch.no_grad():
        Y_pred_norm = model(X_val).cpu()
    Y_true_norm = dataset.Y[val_idx]

    # Denormalize back to raw units
    Y_pred = Y_pred_norm * stats['Y_std'] + stats['Y_mean']
    Y_true = Y_true_norm * stats['Y_std'] + stats['Y_mean']

    # Scatter for each output dim
    Y_pred_np = Y_pred.numpy()
    Y_true_np = Y_true.numpy()
    for i, name in enumerate(["Deformation₁", "Deformation₂", "Deformation₃"]):
        plot_scatter(Y_true_np[:, i], Y_pred_np[:, i], name)

    # MC-Dropout uncertainty on in-distribution (validation) inputs. Normalized
    # uniform features are bounded, so torch.randn draws would often fall
    # outside the range the model was trained on.
    sample = X_val[:100]
    mean, std = mc_dropout_predict(model, sample, n_samples=50)
    # Report in raw units: the mean takes the full affine map, the std only the scale.
    y_std_np = stats['Y_std'].numpy()
    mean = mean * y_std_np + stats['Y_mean'].numpy()
    std = std * y_std_np
    print("MC-Dropout sample mean (first 3):", mean[:3])
    print("MC-Dropout sample std  (first 3):", std[:3])