<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_pinn_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib scikit-learn

In [None]:
#!/usr/bin/env python3
# train_pinn.py

"""
End-to-end PINN pipeline with MC-Dropout uncertainty and OOD evaluation.

Features:
- Simple sine-wave dataset (replace with your physics data loader)
- MLP with Dropout (PINN style)
- Physics residual via ∂²y/∂x² + y = 0 (harmonic oscillator placeholder)
- Training loop (data + physics loss)
- MC-Dropout inference for uncertainty
- OOD AUROC computation
- Reliability diagram (using sklearn.calibration)
- Ensemble baseline comparison
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve

# -- Configuration ------------------------------------------------------------
# Data loading
BATCH_SIZE: int = 128
NUM_WORKERS: int = 2          # keep the DataLoader worker count small

# Optimisation
EPOCHS: int = 100
LR: float = 1e-3
PHYS_COEFF: float = 0.1       # weight of the physics-residual term in the loss

# Uncertainty estimation
MC_SAMPLES: int = 50          # stochastic forward passes per MC-Dropout call
ENSEMBLE_SIZE: int = 3        # independently trained members in the baseline

# Output
PLOTS_DIR: str = "plots"
os.makedirs(PLOTS_DIR, exist_ok=True)   # no-op if the directory already exists

# -----------------------------------------------------------------------------
# 1. Dataset (Sine wave) - replace with your loader
# -----------------------------------------------------------------------------
def load_datasets(n_points=10000, n_train=8000, seed=None):
    """Build a toy sine-wave regression dataset and split it into train/val.

    Samples ``n_points`` evenly spaced x values on [0, 2*pi], sets
    y = sin(x), and randomly assigns ``n_train`` of them to the training
    split and the remainder to the validation split. Replace with a real
    physics data loader.

    Args:
        n_points: total number of (x, y) samples.
        n_train: size of the training split; must satisfy
            ``0 <= n_train <= n_points``.
        seed: optional seed for the train/val permutation. ``None`` keeps the
            original behaviour of drawing from NumPy's global RNG.

    Returns:
        (train_ds, val_ds): TensorDatasets of float32 tensors shaped (N, 1).

    Raises:
        ValueError: if ``n_train`` is outside ``[0, n_points]``.
    """
    if not 0 <= n_train <= n_points:
        raise ValueError(f"n_train must be in [0, {n_points}], got {n_train}")

    X = np.linspace(0, 2 * np.pi, n_points)[:, None].astype(np.float32)
    y = np.sin(X).astype(np.float32)

    # A seeded Generator gives a reproducible split without touching the
    # global NumPy RNG state; seed=None falls back to the global RNG.
    rng = np.random if seed is None else np.random.default_rng(seed)
    idx = rng.permutation(n_points)
    train_idx, val_idx = idx[:n_train], idx[n_train:]

    train_ds = TensorDataset(torch.from_numpy(X[train_idx]),
                             torch.from_numpy(y[train_idx]))
    val_ds = TensorDataset(torch.from_numpy(X[val_idx]),
                           torch.from_numpy(y[val_idx]))
    return train_ds, val_ds

train_ds, val_ds = load_datasets()

# Page-locked host memory only helps host->GPU copies. PyTorch warns when
# pin_memory=True but no accelerator is available, so enable it only with CUDA.
_PIN_MEMORY = torch.cuda.is_available()
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=_PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=_PIN_MEMORY)

# -----------------------------------------------------------------------------
# 2. PINN Model with Dropout
# -----------------------------------------------------------------------------
class PINN(nn.Module):
    """Fully connected network with dropout, used as the PINN surrogate y(x).

    Architecture: Linear(1, hidden) -> act -> Dropout -> Linear(hidden, hidden)
    -> act -> Dropout -> Linear(hidden, 1). The dropout layers are what the
    MC-Dropout inference in this file relies on for stochastic predictions.

    The activation defaults to ``nn.Tanh``. The physics loss differentiates the
    output twice with respect to x, and a piecewise-linear activation such as
    ReLU has a zero second derivative almost everywhere. With ReLU the residual
    d2y/dx2 + y would silently reduce to y, and the physics term would just pull
    predictions toward zero.

    Args:
        hidden: width of the two hidden layers.
        dropout: dropout probability after each hidden activation.
        activation: zero-argument callable returning an activation module
            (pass ``nn.ReLU`` to reproduce the previous architecture).
    """

    def __init__(self, hidden=64, dropout=0.1, activation=nn.Tanh):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        return self.net(x)

# -----------------------------------------------------------------------------
# 3. Physics Residual (computed with torch.autograd)
# d²y/dx² + y = 0 (harmonic oscillator placeholder)
# -----------------------------------------------------------------------------
def physics_residual(x, y_pred):
    """Return the harmonic-oscillator residual d2y/dx2 + y at every input point.

    ``x`` must have ``requires_grad=True`` and ``y_pred`` must have been computed
    from it. Both derivatives are built with ``create_graph=True`` so that a loss
    built on the residual can itself be back-propagated.

    An all-ones ``grad_outputs`` is equivalent to differentiating the sum of the
    outputs. This gives per-point derivatives on the assumption that each output
    row depends only on its own input row, which holds for a row-wise MLP.
    """
    first = torch.autograd.grad(
        y_pred, x, grad_outputs=torch.ones_like(y_pred), create_graph=True
    )[0]
    second = torch.autograd.grad(
        first, x, grad_outputs=torch.ones_like(first), create_graph=True
    )[0]
    return second + y_pred

# -----------------------------------------------------------------------------
# 4. Training & Validation
# -----------------------------------------------------------------------------
# Element-wise squared error averaged over the whole batch; shared by the
# training and validation loops.
mse_loss = nn.MSELoss(reduction="mean")

def train_epoch(model, loader, optimizer):
    """Run one optimisation epoch and return the sample-weighted mean loss.

    The per-batch loss is ``MSE(y_pred, y) + PHYS_COEFF * mean(residual**2)``.
    The residual comes from ``physics_residual`` evaluated at the batch inputs.
    Batches are moved to the device holding the model's parameters, so the
    function does not depend on a module-level ``device`` variable.

    Args:
        model: network to train; switched to training mode (dropout active).
        loader: DataLoader yielding ``(x, y)`` batches.
        optimizer: optimiser over ``model``'s parameters.

    Returns:
        Average combined (data + physics) loss per sample over the epoch.
    """
    device = next(model.parameters()).device
    model.train()
    total = 0.0
    for xb, yb in loader:
        # Input gradients are needed to form the physics residual.
        xb = xb.to(device).requires_grad_(True)
        yb = yb.to(device)

        y_pred = model(xb)
        data_l = mse_loss(y_pred, yb)
        phys_l = physics_residual(xb, y_pred).pow(2).mean()
        loss = data_l + PHYS_COEFF * phys_l

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch is not over-counted.
        total += loss.item() * xb.size(0)
    return total / len(loader.dataset)

def validate_epoch(model, loader):
    """Return the sample-weighted mean data MSE of ``model`` over ``loader``.

    The model is switched to eval mode (dropout off) and no gradients are
    tracked. Batches are moved to the device of the model's parameters. Only
    the data term is measured; the physics residual used in ``train_epoch`` is
    not included, so this value is not directly comparable to the training loss.
    """
    device = next(model.parameters()).device
    model.eval()
    total = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            total += mse_loss(model(xb), yb).item() * xb.size(0)
    return total / len(loader.dataset)

# -----------------------------------------------------------------------------
# 5. MC-Dropout Inference & OOD AUROC
# -----------------------------------------------------------------------------
def mc_dropout_predict(model, x, samples=MC_SAMPLES):
    """Estimate the predictive mean and spread with Monte-Carlo Dropout.

    Runs ``samples`` stochastic forward passes over ``x`` with only the dropout
    layers in training mode. Returns the per-point mean and standard deviation
    across passes. The model's original train/eval mode is restored afterwards,
    even if a forward pass raises.

    Args:
        model: network containing dropout layers.
        x: input tensor, already on the model's device (presumably shape (N, 1)).
        samples: number of stochastic forward passes.

    Returns:
        (mean, std): NumPy arrays. A trailing output axis of size 1 is dropped,
        so single-output models give shape (N,), including N == 1.
    """
    was_training = model.training
    # Enable dropout only; any other mode-dependent layers stay in eval mode.
    model.eval()
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                               nn.AlphaDropout)):
            module.train()
    try:
        with torch.no_grad():
            preds = [model(x).cpu().numpy() for _ in range(samples)]
    finally:
        model.train(was_training)

    arr = np.stack(preds)  # [S, N, 1] for a single-output model
    mean = arr.mean(axis=0)
    std = arr.std(axis=0)
    # Drop only the output axis. A blanket squeeze() would give a 0-d array
    # when N == 1, and .tolist() on that is not a list.
    if mean.ndim > 1 and mean.shape[-1] == 1:
        mean, std = mean[..., 0], std[..., 0]
    return mean, std

def compute_auroc_in_ood(model, loader, n_ood=None,
                         ood_low=-1.5 * np.pi, ood_high=2.5 * np.pi):
    """Score how well MC-Dropout std separates in-domain from synthetic OOD inputs.

    In-domain scores are the per-point MC-Dropout standard deviations of every
    input in ``loader``. OOD scores are the same quantity for ``n_ood`` points
    drawn uniformly from ``[ood_low, ood_high)``. Returns the ROC AUC with OOD
    as the positive class, so a higher std should mean "more OOD".

    NOTE(review): the default OOD interval [-1.5*pi, 2.5*pi) contains the
    training domain [0, 2*pi] produced by ``load_datasets``. About half of the
    "OOD" points are therefore in-domain, which biases the AUROC toward 0.5.
    The defaults are kept for comparability with the ensemble baseline.

    Args:
        model: network with dropout layers.
        loader: DataLoader of in-domain ``(x, y)`` batches (labels unused).
        n_ood: number of synthetic OOD points. Defaults to ``len(loader.dataset)``
            so both classes are the same size, as in the ensemble baseline.
        ood_low, ood_high: bounds of the uniform OOD sampling interval.

    Returns:
        ROC AUC as a float.
    """
    device = next(model.parameters()).device

    # In-domain uncertainty.
    in_unc = []
    for xb, _ in loader:
        _, std = mc_dropout_predict(model, xb.to(device))
        # atleast_1d guards against a 0-d std for a single-sample batch.
        in_unc.extend(np.atleast_1d(std).tolist())

    # Synthetic OOD: uniform samples over a wider range.
    if n_ood is None:
        n_ood = len(loader.dataset)
    ood_X = (torch.rand(n_ood, 1) * (ood_high - ood_low) + ood_low).to(device)
    ood_unc = []
    for start in range(0, n_ood, BATCH_SIZE):
        _, std = mc_dropout_predict(model, ood_X[start:start + BATCH_SIZE])
        ood_unc.extend(np.atleast_1d(std).tolist())

    labels = np.concatenate([np.zeros(len(in_unc), dtype=int),
                             np.ones(len(ood_unc), dtype=int)])
    scores = np.asarray(in_unc + ood_unc)
    return roc_auc_score(labels, scores)

# -----------------------------------------------------------------------------
# 6. Reliability Diagram
# -----------------------------------------------------------------------------
def plot_reliability(y_true, y_prob, bins=10, fname="reliability.png"):
    """Save a reliability (calibration) diagram to ``PLOTS_DIR/fname``.

    ``calibration_curve`` splits ``y_prob`` into ``bins`` uniform bins. For each
    non-empty bin it returns the observed positive rate and the mean predicted
    probability. The dashed diagonal marks perfect calibration.
    """
    observed, predicted = calibration_curve(y_true, y_prob, n_bins=bins)
    out_path = os.path.join(PLOTS_DIR, fname)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(predicted, observed, 'o-', label="Model")
    ax.plot([0, 1], [0, 1], '--', color='gray', label="Perfect")
    ax.set_xlabel("Predicted prob")
    ax.set_ylabel("Empirical freq")
    ax.legend()
    ax.set_title("Reliability Diagram")
    fig.savefig(out_path)
    plt.close(fig)

# -----------------------------------------------------------------------------
# 7. Ensemble Baseline
# -----------------------------------------------------------------------------
def train_ensemble(device=None, epochs=EPOCHS):
    """Train ``ENSEMBLE_SIZE`` independently initialised PINNs on ``train_loader``.

    Each member gets its own Adam optimiser (learning rate ``LR``) and is
    trained for ``epochs`` epochs with ``train_epoch``. Members are returned in
    eval mode, so predictions made from them have dropout disabled. Otherwise
    the ensemble spread would mix member disagreement with dropout noise.

    Args:
        device: device to place the members on. ``None`` selects CUDA when
            available and CPU otherwise, the same rule used in ``__main__``.
        epochs: training epochs per member.

    Returns:
        list of trained ``PINN`` models.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ensemble = []
    for _ in range(ENSEMBLE_SIZE):
        member = PINN().to(device)
        opt = optim.Adam(member.parameters(), lr=LR)
        for _ in range(epochs):
            train_epoch(member, train_loader, opt)
        member.eval()
        ensemble.append(member)
    return ensemble

def ensemble_predict(models, x):
    """Return the per-point mean and std of the ensemble members' predictions.

    Each member runs in eval mode (dropout disabled) under ``torch.no_grad``,
    so the spread reflects disagreement between members only. Each member's
    original train/eval mode is restored afterwards. ``x`` is moved to each
    member's own parameter device.

    Returns:
        (mean, std): NumPy arrays (population std). A trailing output axis of
        size 1 is dropped, so single-output members give shape (N,), including
        N == 1.
    """
    outputs = []
    for member in models:
        was_training = member.training
        member.eval()
        try:
            with torch.no_grad():
                member_device = next(member.parameters()).device
                outputs.append(member(x.to(member_device)).cpu().numpy())
        finally:
            member.train(was_training)

    preds = np.stack(outputs, axis=0)
    mean, std = preds.mean(axis=0), preds.std(axis=0)
    # Drop only the output axis; squeeze() would yield 0-d arrays for N == 1,
    # which np.concatenate rejects.
    if mean.ndim > 1 and mean.shape[-1] == 1:
        mean, std = mean[..., 0], std[..., 0]
    return mean, std

# -----------------------------------------------------------------------------
# 8. Main
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Assigned at module scope (this `if` body is not a function), so any
    # function that reads a global ``device`` can see it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Single-model training
    model = PINN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    history = {'train': [], 'val': []}

    for ep in range(1, EPOCHS + 1):
        tr = train_epoch(model, train_loader, optimizer)
        vl = validate_epoch(model, val_loader)
        history['train'].append(tr)
        history['val'].append(vl)
        print(f"Epoch {ep:3d} | Train {tr:.4f} | Val {vl:.4f}")

    # Save loss curve. NOTE: the training loss includes the physics term
    # (data MSE + PHYS_COEFF * residual**2) while validation is data MSE only,
    # so the two curves are not directly comparable.
    plt.figure()
    plt.plot(history['train'], label="Train")
    plt.plot(history['val'], label="Val")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend()
    plt.title("Loss Curve")
    plt.savefig(os.path.join(PLOTS_DIR, "loss_curve.png"))
    plt.close()

    # Compute OOD AUROC
    auroc = compute_auroc_in_ood(model, val_loader)
    print(f"\nMC-Dropout OOD AUROC: {auroc:.4f}")

    # Reliability diagram on an evenly spaced in-domain grid with as many points
    # as the validation set (not the validation samples themselves).
    x_grid = np.linspace(0, 2 * np.pi, len(val_ds))
    X_val = torch.from_numpy(x_grid[:, None]).float().to(device)
    _, val_std = mc_dropout_predict(model, X_val)
    # Scale std into [0, 1] so it can act as a pseudo-probability;
    # calibration_curve rejects values outside that range.
    max_std = float(np.max(val_std))
    val_prob = val_std / max_std if max_std > 0 else np.zeros_like(val_std)
    # NOTE(review): placeholder target (sign of sin(x)); there is no real
    # probabilistic label here, so the diagram is illustrative only.
    y_true = (np.sin(x_grid) > 0).astype(int)
    plot_reliability(y_true, val_prob, fname="reliability_mc.png")
    print(f"Reliability plot saved to {PLOTS_DIR}/reliability_mc.png")

    # Ensemble baseline. Members are put in eval mode so the spread reflects
    # member disagreement rather than dropout noise.
    ensemble = train_ensemble()
    for member in ensemble:
        member.eval()
    with torch.no_grad():
        # NOTE(review): this interval overlaps the training domain [0, 2*pi];
        # it is kept identical to the MC-Dropout evaluation for comparability.
        ood_X = torch.rand(len(val_ds), 1).to(device) * 4 * np.pi - 1.5 * np.pi
        _, in_std = ensemble_predict(ensemble, X_val)
        _, ood_std = ensemble_predict(ensemble, ood_X)
    in_std, ood_std = np.atleast_1d(in_std), np.atleast_1d(ood_std)
    en_auroc = roc_auc_score(
        np.concatenate([np.zeros(len(in_std), dtype=int),
                        np.ones(len(ood_std), dtype=int)]),
        np.concatenate([in_std, ood_std]),
    )
    print(f"Ensemble OOD AUROC: {en_auroc:.4f}")