<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_pinn_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Fix every RNG the script draws from up front so repeated runs produce the
# same synthetic data, split, and initial weights.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)


class PhysicsDataset(Dataset):
    """
    Wrap input/output arrays in a torch Dataset with z-score normalization.

    Inputs are validated for NaN/Inf *before* normalization. By default the
    normalization statistics are computed from ``X``/``Y`` themselves; pass
    ``stats`` (e.g. the training split's statistics) to normalize a held-out
    split on the same scale as the data the model is trained on.

    Args:
        X: Array-like of inputs, one sample per entry along axis 0.
        Y: Array-like of targets with the same number of samples as ``X``.
        stats: Optional mapping with keys ``'X_mean'``, ``'X_std'``,
            ``'Y_mean'`` and ``'Y_std'``. The values are used as given (no
            epsilon is added to the provided standard deviations).

    Attributes:
        X, Y: Normalized float32 arrays.
        X_mean, X_std, Y_mean, Y_std: float32 statistics used for the
            normalization.

    Raises:
        ValueError: If ``X`` and ``Y`` differ in length or contain NaN/Inf.
    """
    def __init__(self, X, Y, stats=None):
        X = np.asarray(X, dtype=np.float32)
        Y = np.asarray(Y, dtype=np.float32)

        # Explicit exceptions rather than ``assert`` so the checks still run
        # under ``python -O``.
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        if not np.isfinite(X).all():
            raise ValueError("X contains NaN or Inf")
        if not np.isfinite(Y).all():
            raise ValueError("Y contains NaN or Inf")

        if stats is None:
            # The epsilon avoids division by zero for constant columns.
            self.X_mean, self.X_std = X.mean(0), X.std(0) + 1e-8
            self.Y_mean, self.Y_std = Y.mean(0), Y.std(0) + 1e-8
        else:
            self.X_mean = np.asarray(stats['X_mean'], dtype=np.float32)
            self.X_std = np.asarray(stats['X_std'], dtype=np.float32)
            self.Y_mean = np.asarray(stats['Y_mean'], dtype=np.float32)
            self.Y_std = np.asarray(stats['Y_std'], dtype=np.float32)

        self.X = ((X - self.X_mean) / self.X_std).astype(np.float32)
        self.Y = ((Y - self.Y_mean) / self.Y_std).astype(np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


class MLP(nn.Module):
    """
    Fully connected network with Tanh activations between hidden layers.

    The output layer is linear (no activation), so predictions are not
    confined to Tanh's (-1, 1) range; z-scored targets can exceed it.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        hidden: Widths of the hidden layers, in order. An empty sequence
            yields a single linear layer.
    """
    def __init__(self, in_dim, out_dim, hidden=(64, 64)):
        super().__init__()
        dims = [in_dim, *hidden, out_dim]
        n_linear = len(dims) - 1
        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            # Activation only between layers, never after the output layer.
            if i < n_linear - 1:
                layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension is ``in_dim``."""
        return self.net(x)


def physics_residual(u_pred, x, stats, source_fn=None):
    """
    Example PDE residual ``u_xx - f(x) = 0``, evaluated in physical units.

    The network works on z-scored data, so both ``u_pred`` and ``x`` are
    mapped back to physical coordinates. By the chain rule, each derivative
    taken with respect to the normalized input picks up a factor
    ``1 / X_std``. Adjust ``source_fn`` (or this function) to your operator.

    Args:
        u_pred: Normalized network output computed from ``x``.
        x: Normalized input with ``requires_grad=True``. The second derivative
            is exact only for a single input feature.
        stats: Mapping with ``'X_mean'``, ``'X_std'``, ``'Y_mean'`` and
            ``'Y_std'``; they are moved to ``u_pred``'s device and dtype here.
        source_fn: Callable mapping physical ``x`` to ``f(x)``. Defaults to
            ``-sin(x)``, the source for which ``u = sin(x)`` satisfies
            ``u_xx = f``.

    Returns:
        The residual divided by ``Y_std``, so its scale matches the
        normalized targets.
    """
    device, dtype = u_pred.device, u_pred.dtype

    def _stat(key):
        # Stats may be built on CPU; align them with the prediction tensor.
        return torch.as_tensor(stats[key], dtype=dtype, device=device)

    x_mean, x_std = _stat('X_mean'), _stat('X_std')
    y_mean, y_std = _stat('Y_mean'), _stat('Y_std')

    if source_fn is None:
        def source_fn(z):
            return -torch.sin(z)

    # Un-normalize the solution and the coordinates.
    u = u_pred * y_std + y_mean
    x_phys = x * x_std + x_mean

    # du/dx_phys = (du/dx_norm) / X_std
    du_dx = torch.autograd.grad(
        outputs=u.sum(),
        inputs=x,
        create_graph=True
    )[0] / x_std

    # d2u/dx_phys2 = d(du/dx_phys)/dx_norm / X_std
    u_xx = torch.autograd.grad(
        outputs=du_dx.sum(),
        inputs=x,
        create_graph=True
    )[0] / x_std

    f = source_fn(x_phys)

    # Residual, rescaled to the normalized-target scale.
    return (u_xx - f) / y_std


def _train_one_epoch(model, loader, optimizer, stats, lam, residual_fn,
                     device, epoch):
    """Run one optimisation pass over ``loader``; return the mean loss per sample."""
    model.train()
    total, seen = 0.0, 0
    for xb, yb in loader:
        # Inputs need gradients so the residual can differentiate u w.r.t. x.
        xb = xb.to(device).requires_grad_(True)
        yb = yb.to(device)

        optimizer.zero_grad()
        yp = model(xb)

        mse_loss = nn.functional.mse_loss(yp, yb)
        pres = residual_fn(yp, xb, stats)
        # Robust physics penalty: r**2 for |r| <= 1 and 2|r| - 1 beyond.
        # Unlike clamp(-1, 1).pow(2), whose gradient is zero once |r| > 1,
        # large residuals still push the model toward satisfying the PDE.
        pres_loss = 2.0 * nn.functional.huber_loss(
            pres, torch.zeros_like(pres), delta=1.0)

        loss = mse_loss + lam * pres_loss
        # Check *before* backward/step so a non-finite loss never reaches
        # the weights.
        if not torch.isfinite(loss):
            raise RuntimeError(f"NaN or Inf encountered in loss at epoch {epoch}")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight by the samples actually seen (correct even with drop_last).
        total += loss.item() * xb.size(0)
        seen += xb.size(0)

    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return total / seen


def _evaluate(model, loader, device):
    """Return the element-wise mean squared error of ``model`` over ``loader``."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            total += nn.functional.mse_loss(model(xb), yb, reduction='sum').item()
            # Count elements, not samples, to match mse_loss's 'mean' reduction.
            count += yb.numel()
    if count == 0:
        raise ValueError("val_loader yielded no samples")
    return total / count


def train(model, train_loader, val_loader, stats,
          lr=1e-3, lam=1.0, max_epochs=100, patience=10,
          residual_fn=None, restore_best=True):
    """
    Train ``model`` on a data-fit loss plus a weighted physics-residual loss.

    Features:
    - Adam optimizer with a ReduceLROnPlateau scheduler driven by the
      validation loss
    - Gradient-norm clipping (max norm 1.0)
    - Non-finite loss check before any weight update
    - Early stopping on the validation loss, optionally restoring the best
      weights

    Args:
        model: Network mapping normalized inputs to normalized outputs.
        train_loader: Yields ``(x, y)`` batches used for optimisation.
        val_loader: Yields ``(x, y)`` batches; scored by MSE only (the
            physics term needs input gradients, which ``no_grad`` disables).
        stats: Normalization statistics passed to ``residual_fn``; moved to
            the training device here.
        lr: Initial learning rate.
        lam: Weight of the physics-residual loss.
        max_epochs: Upper bound on training epochs.
        patience: Epochs without validation improvement before stopping.
        residual_fn: ``f(y_pred, x, stats) -> residual``. Defaults to
            ``physics_residual``.
        restore_best: If True, load the weights with the best validation loss
            before returning.

    Returns:
        Dict with per-epoch ``'train_loss'`` and ``'val_loss'`` lists.

    Raises:
        RuntimeError: If the training loss becomes NaN or Inf.
        ValueError: If a loader yields no samples.
    """
    if residual_fn is None:
        residual_fn = physics_residual

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # The caller builds stats on CPU; keep them on the same device as the
    # model so the residual does not mix devices.
    stats = {k: torch.as_tensor(v).to(device) for k, v in stats.items()}

    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5, min_lr=1e-6)

    best_val, epochs_no_improve = float('inf'), 0
    best_state = None
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(1, max_epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, stats,
                                      lam, residual_fn, device, epoch)
        val_loss = _evaluate(model, val_loader, device)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        scheduler.step(val_loss)

        print(f"Epoch {epoch:03d} | Train {train_loss:.6f} | Val {val_loss:.6f}")

        # Early stopping (1e-6 is the minimum improvement that counts).
        if val_loss + 1e-6 < best_val:
            best_val, epochs_no_improve = val_loss, 0
            if restore_best:
                best_state = {k: v.detach().clone()
                              for k, v in model.state_dict().items()}
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    return history


def _renormalize_like(ds, X_raw, Y_raw, ref):
    """Re-normalize ``ds`` in place with ``ref``'s statistics and return it."""
    ds.X_mean, ds.X_std = ref.X_mean, ref.X_std
    ds.Y_mean, ds.Y_std = ref.Y_mean, ref.Y_std
    ds.X = ((np.asarray(X_raw, dtype=np.float32) - ref.X_mean)
            / ref.X_std).astype(np.float32)
    ds.Y = ((np.asarray(Y_raw, dtype=np.float32) - ref.Y_mean)
            / ref.Y_std).astype(np.float32)
    return ds


def main():
    """
    Train the example PINN on ``u = sin(x)`` sampled uniformly on [-pi, pi].

    Returns:
        The per-epoch loss history returned by ``train``.
    """
    # -- Generate or load data --
    N = 2000
    X = np.random.uniform(-np.pi, np.pi, size=(N, 1))
    Y = np.sin(X)  # target function

    # 80/20 split into train/validation
    perm = np.random.permutation(N)
    n_train = N * 4 // 5
    train_idx, val_idx = perm[:n_train], perm[n_train:]
    ds_train = PhysicsDataset(X[train_idx], Y[train_idx])
    # The validation split must use the *training* statistics; with its own
    # mean/std its inputs and targets sit on a slightly different scale than
    # the model was trained on, skewing the validation loss that drives LR
    # scheduling and early stopping.
    ds_val = _renormalize_like(
        PhysicsDataset(X[val_idx], Y[val_idx]), X[val_idx], Y[val_idx], ds_train
    )

    # DataLoaders
    train_loader = DataLoader(ds_train, batch_size=64, shuffle=True)
    val_loader = DataLoader(ds_val, batch_size=64, shuffle=False)

    # Stats for residual normalization (taken from the training split)
    stats = {
        'X_mean': torch.tensor(ds_train.X_mean, dtype=torch.float32),
        'X_std':  torch.tensor(ds_train.X_std,  dtype=torch.float32),
        'Y_mean': torch.tensor(ds_train.Y_mean, dtype=torch.float32),
        'Y_std':  torch.tensor(ds_train.Y_std,  dtype=torch.float32),
    }

    # Model instantiation
    model = MLP(in_dim=1, out_dim=1, hidden=(64, 64))

    # Train
    history = train(
        model,
        train_loader,
        val_loader,
        stats,
        lr=1e-4,      # Lower LR for stability
        lam=0.1,      # Physics loss weight
        max_epochs=200,
        patience=20
    )
    return history


if __name__ == "__main__":
    # Run the end-to-end demo only when executed as a script, not on import.
    main()