<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_physics_informed_pinn_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib scikit-learn

In [None]:
#!/usr/bin/env python3
"""
train_physics_informed_pinn.py

End-to-end training and OOD analysis for a physics-informed MLP with MC-Dropout.

Features:
- DataLoader tuned (num_workers, pin_memory)
- cuDNN autotuner enabled
- Batch-wise physics-consistent augmentation
- PINN-style MLP with Dropout
- Placeholder physics residual: norm of the input gradient ∂y_pred/∂x (swap in your real ODE/PDE residual)
- Cosine LR schedule, weight decay
- MC-Dropout inference for uncertainty & OOD AUROC
- Reliability diagram & uncertainty histograms
- Loss curves saved to disk
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.backends import cudnn
import matplotlib.pyplot as plt

# calibration_curve lives in sklearn.calibration, not sklearn.metrics
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve

# Let cuDNN time candidate algorithms the first time it sees an input shape
# and reuse the fastest one afterwards.
# NOTE(review): this model is Linear-only; I believe Linear layers are served
# by cuBLAS rather than cuDNN, so this flag likely has no effect here — confirm.
cudnn.benchmark = True

# Compute device: CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ----------------------------
# 1) Physics-consistent augmentation
# ----------------------------
def physics_augment(x, sigma=1e-2):
    """Return ``x`` plus i.i.d. standard-normal noise scaled by ``sigma``.

    Args:
        x: input tensor; the noise has the same shape, dtype and device.
        sigma: noise scale (standard deviation of the added noise).

    Returns:
        A new tensor; ``x`` itself is not modified.
    """
    noise = torch.randn_like(x)
    return noise.mul(sigma).add(x)

# ----------------------------
# 2) Model definition
# ----------------------------
class PINN(nn.Module):
    """MLP of ``[Linear -> ReLU -> Dropout]`` blocks followed by a final Linear.

    Args:
        in_dim: number of input features.
        hidden_dims: widths of the hidden layers, in order (any iterable of
            ints). Defaults to two hidden layers of width 128. A tuple is used
            as the default so no mutable object is shared between instances.
        dropout_p: dropout probability applied after every hidden ReLU. With
            ``dropout_p=0`` every MC-Dropout sample is identical.
        out_dim: number of outputs; ``None`` (default) means ``in_dim``.
    """

    def __init__(self, in_dim, hidden_dims=(128, 128), dropout_p=0.1, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = in_dim
        layers = []
        last = in_dim
        for width in hidden_dims:
            layers += [nn.Linear(last, width), nn.ReLU(), nn.Dropout(dropout_p)]
            last = width
        layers.append(nn.Linear(last, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_dim)`` to ``(..., out_dim)``."""
        return self.net(x)

# ----------------------------
# 3) Physics residual function
# ----------------------------
def physics_residual(x, y_pred):
    """Placeholder physics residual: L2 norm of d(sum of y_pred)/dx per row.

    Replace with your actual ODE/PDE residual.

    Args:
        x: input tensor with ``requires_grad=True`` that ``y_pred`` was
            computed from.
        y_pred: model output depending on ``x``.

    Returns:
        Norm of the input gradient over the last dimension. The graph is kept
        (``create_graph=True``) so the result can itself be backpropagated.
    """
    # Seeding the backward pass with ones is the same as differentiating
    # y_pred.sum(): every output element contributes with weight 1.
    (dy_dx,) = torch.autograd.grad(
        outputs=y_pred,
        inputs=x,
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True,
    )
    return torch.norm(dy_dx, dim=-1)

# ----------------------------
# 4) MC-Dropout inference
# ----------------------------
def mc_dropout_predict(model, x, n_samples=50):
    """Monte-Carlo Dropout inference: ``n_samples`` stochastic forward passes.

    Only dropout layers are put in train mode. Every other submodule (e.g. a
    BatchNorm, if the model has one) stays in eval mode, so repeated passes do
    not update running statistics. Each submodule's original train/eval mode
    is restored before returning, even if a forward pass raises.

    Args:
        model: network containing dropout layers. ``model(x)`` is assumed to
            return a 2-D ``(N, D)`` tensor (the variance reduction below
            averages over axis 1).
        x: input batch; expected on the same device as ``model``.
        n_samples: number of stochastic forward passes; must be >= 1.

    Returns:
        ``(mean, var)`` as NumPy arrays. ``mean`` is the per-element mean
        prediction, ``(N, D)``. ``var`` is the per-element population variance
        (ddof=0) across samples, averaged over the output dimension, ``(N,)``.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    dropout_types = (
        nn.Dropout, nn.Dropout2d, nn.Dropout3d,
        nn.AlphaDropout, nn.FeatureAlphaDropout,
    )
    # modules() is pre-order (parent before children), so restoring in this
    # order leaves each submodule with exactly the mode it had on entry.
    saved_modes = [(module, module.training) for module in model.modules()]

    model.eval()
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()  # keep dropout stochastic for MC sampling

    try:
        with torch.no_grad():
            # Stack on-device and copy to host once, rather than one
            # device->host sync per sample.
            samples = torch.stack([model(x) for _ in range(n_samples)], dim=0)  # [S, N, D]
    finally:
        for module, mode in saved_modes:
            module.train(mode)

    arr = samples.cpu().numpy()
    mean = arr.mean(axis=0)              # [N, D]
    var = arr.var(axis=0).mean(axis=1)   # [N]
    return mean, var

# ----------------------------
# 5) Training & validation loops
# ----------------------------
# Shared data-fit criterion: mean squared error averaged over all elements.
mse_loss = nn.MSELoss(reduction="mean")

def train_epoch(model, loader, optimizer, phys_coeff):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Each batch is passed through ``physics_augment`` before the forward pass.
    The per-batch loss is
    ``mse_loss(model(x_aug), y) + phys_coeff * physics_residual(x_aug, y_pred).mean()``.

    Args:
        model: network to train; switched to train mode.
        loader: iterable of ``(xb, yb)`` tensor batches.
        optimizer: optimiser over ``model``'s parameters; stepped once per batch.
        phys_coeff: weight of the physics-residual term.

    Returns:
        Sample-weighted mean of the per-batch loss over the samples the loader
        actually yielded. Counting yielded samples, rather than using
        ``len(loader.dataset)``, stays correct with ``drop_last=True`` or a
        sampler that visits only part of the dataset.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for xb, yb in loader:
        # non_blocking lets the copy overlap compute when the batch is in
        # pinned memory; otherwise it behaves like a normal copy.
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)

        # The residual differentiates w.r.t. the augmented inputs, so they
        # must be a leaf tensor that requires grad.
        xb_aug = physics_augment(xb).requires_grad_(True)

        y_pred = model(xb_aug)
        mse = mse_loss(y_pred, yb)
        phys = physics_residual(xb_aug, y_pred).mean()
        loss = mse + phys_coeff * phys

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        batch_size = xb.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return total_loss / n_seen


def validate_epoch(model, loader, phys_coeff):
    """Evaluate the training objective on ``loader`` without updating weights.

    The model is put in eval mode, so dropout is disabled. No augmentation is
    applied. The per-batch loss is
    ``mse_loss(y_pred, y) + phys_coeff * physics_residual(x, y_pred).mean()``.

    Args:
        model: network to evaluate.
        loader: iterable of ``(xb, yb)`` tensor batches.
        phys_coeff: weight of the physics-residual term.

    Returns:
        Sample-weighted mean of the per-batch loss over the samples the loader
        actually yielded.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_seen = 0
    # The physics term needs d(y_pred)/dx, so autograd must stay enabled even
    # if the caller wrapped this call in torch.no_grad().
    with torch.enable_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)

            # Fresh leaf input that requires grad. The copy keeps
            # requires_grad_ from touching a tensor owned by the loader.
            xb = xb.detach().clone().requires_grad_(True)
            y_pred = model(xb)
            mse = mse_loss(y_pred, yb)
            phys = physics_residual(xb, y_pred).mean()

            batch_size = xb.size(0)
            total_loss += (mse + phys_coeff * phys).item() * batch_size
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("validate_epoch: loader yielded no samples")
    return total_loss / n_seen

# ----------------------------
# 6) Plotting utilities
# ----------------------------
def plot_losses(hist, save_dir="plots"):
    """Plot per-epoch train/validation loss and save it as ``loss_curve.png``.

    Args:
        hist: mapping with ``'train'`` and ``'val'`` sequences of per-epoch
            losses. Each curve gets its own 1-based epoch axis, so the two
            need not be the same length.
        save_dir: output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    train_epochs = np.arange(1, len(hist['train']) + 1)
    val_epochs = np.arange(1, len(hist['val']) + 1)

    fig, ax = plt.subplots()
    try:
        ax.plot(train_epochs, hist['train'], label="Train Loss")
        ax.plot(val_epochs, hist['val'], label="Val   Loss")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.set_title("Loss Curve")
        fig.savefig(os.path.join(save_dir, "loss_curve.png"))
    finally:
        # Close even if savefig fails, so figures don't pile up in pyplot.
        plt.close(fig)

def plot_reliability(y_true, conf, save_dir="plots", n_bins=10):
    """Draw a reliability diagram and save it as ``reliability.png``.

    Args:
        y_true: binary labels (0/1); class 1 is the positive class.
        conf: predicted probability of the positive class, each value in
            [0, 1] (``calibration_curve`` rejects values outside that range).
        save_dir: output directory; created if it does not exist.
        n_bins: number of equal-width confidence bins.
    """
    os.makedirs(save_dir, exist_ok=True)
    # calibration_curve returns, per non-empty bin, the observed fraction of
    # positives and the mean predicted confidence.
    frac_true, frac_pred = calibration_curve(y_true, conf, n_bins=n_bins)

    fig, ax = plt.subplots()
    try:
        ax.plot(frac_pred, frac_true, 'o-')
        ax.plot([0, 1], [0, 1], '--', color='gray')  # perfect-calibration diagonal
        ax.set_xlabel("Predicted Confidence")
        ax.set_ylabel("True Frequency")
        ax.set_title("Reliability Diagram")
        fig.savefig(os.path.join(save_dir, "reliability.png"))
    finally:
        plt.close(fig)

def plot_uncertainty_hist(unc_dict, save_dir="plots"):
    """Overlay predictive-variance histograms; save as ``ood_uncertainty_hist.png``.

    Args:
        unc_dict: mapping from a set name (used as the legend label) to a
            1-D array of per-sample predictive variances.
        save_dir: output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    # One shared set of 30 bin edges across all sets. Calling hist(bins=30)
    # per set would bin each set over its own range, so bar widths and
    # heights would not be comparable.
    values = [np.ravel(uc) for uc in unc_dict.values()]
    bins = np.histogram_bin_edges(np.concatenate(values), bins=30) if values else 30

    fig, ax = plt.subplots()
    try:
        for mode, uc in unc_dict.items():
            ax.hist(uc, bins=bins, alpha=0.6, label=mode)
        ax.set_xlabel("Predictive Variance")
        if unc_dict:  # avoid matplotlib's "no artists with labels" warning
            ax.legend()
        ax.set_title("OOD Uncertainty Histogram")
        fig.savefig(os.path.join(save_dir, "ood_uncertainty_hist.png"))
    finally:
        plt.close(fig)

# ----------------------------
# 7) Main workflow
# ----------------------------
def main():
    """Train a PINN on dummy data, then evaluate MC-Dropout OOD detection.

    Side effects:
    - creates ``plots/`` and writes the loss curve, uncertainty histogram and
      reliability diagram into it via the plotting helpers;
    - saves the loss history to ``history.npz`` in the current directory;
    - prints per-epoch losses and per-OOD-set AUROC.
    """
    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)

    # Run configuration. T_max is tied to n_epochs so the cosine schedule
    # (stepped once per epoch) decays over exactly the length of the run.
    n_epochs = 100
    phys_coeff = 1e-3
    mc_samples = 50
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU runs.
    pin_memory = DEVICE.type == "cuda"

    # Dummy data (replace with your physics dataset)
    N_train, N_val = 5000, 1000
    D = 4
    X_train = torch.randn(N_train, D)
    Y_train = torch.randn(N_train, D)
    X_val   = torch.randn(N_val,   D)
    Y_val   = torch.randn(N_val,   D)

    # OOD sets: inputs shifted by +/-2 and scaled by 3 relative to the
    # standard-normal training inputs.
    X_ood = {
        'presence':      torch.randn(500, D) + 2.0,
        'dissolution':   torch.randn(500, D) - 2.0,
        'transcendence': torch.randn(500, D) * 3.0,
    }

    # DataLoaders
    train_loader = DataLoader(
        TensorDataset(X_train, Y_train),
        batch_size=128, shuffle=True,
        num_workers=4, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        TensorDataset(X_val, Y_val),
        batch_size=128,
        num_workers=2, pin_memory=pin_memory,
    )

    # Model, optimizer, scheduler
    model = PINN(in_dim=D, hidden_dims=[128, 128], dropout_p=0.1).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    # Training loop
    history = {'train': [], 'val': []}
    for epoch in range(1, n_epochs + 1):
        tr_loss = train_epoch(model, train_loader, optimizer, phys_coeff)
        vl_loss = validate_epoch(model, val_loader, phys_coeff)
        scheduler.step()
        history['train'].append(tr_loss)
        history['val'].append(vl_loss)
        print(f"Epoch {epoch:3d} | Train {tr_loss:.4f} | Val {vl_loss:.4f}")

    # Save curves and history
    plot_losses(history, save_dir)
    np.savez("history.npz", train=history['train'], val=history['val'])

    # MC-Dropout ID & OOD evaluation. For AUROC, OOD is the positive class
    # (label 1) and predictive variance is the score.
    _, unc_id = mc_dropout_predict(model, X_val.to(DEVICE), n_samples=mc_samples)
    y_id = np.zeros_like(unc_id)

    aurocs, unc_dict = {}, {}
    for mode, Xo in X_ood.items():
        _, unc_ood = mc_dropout_predict(model, Xo.to(DEVICE), n_samples=mc_samples)
        y_ood = np.ones_like(unc_ood)
        aurocs[mode] = roc_auc_score(
            np.concatenate([y_id, y_ood]),
            np.concatenate([unc_id, unc_ood]),
        )
        unc_dict[mode] = unc_ood
        print(f"{mode:13s} AUROC: {aurocs[mode]:.4f}")

    # Include the in-distribution variances so the OOD histograms have a
    # baseline to compare against.
    plot_uncertainty_hist({"in-distribution": unc_id, **unc_dict}, save_dir)

    # Reliability: conf_all = 1 - min-max-normalised variance, a heuristic
    # "this point is in-distribution" confidence (not a calibrated
    # probability). The positive class must therefore be ID (label 1) so
    # label and confidence have the same meaning.
    var_all = np.concatenate([unc_id, *unc_dict.values()])
    var_min = var_all.min()
    var_span = var_all.max() - var_min
    if var_span > 0:
        conf_all = 1.0 - (var_all - var_min) / var_span
    else:
        # Every point sits at the minimum variance, which the formula maps to
        # full confidence; this also avoids a 0/0 NaN.
        conf_all = np.ones_like(var_all)
    y_all = np.concatenate(
        [np.ones_like(unc_id)] + [np.zeros_like(v) for v in unc_dict.values()]
    )
    plot_reliability(y_all, conf_all, save_dir)

    print("Done. Check the 'plots/' folder for figures.")

# Script entry point. In a notebook cell __name__ is also "__main__", so
# executing the cell runs the full training + evaluation pipeline.
if __name__ == "__main__":
    main()