<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_gravwave_comm_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_gravwave_comm.py

Physics‐informed AI pipeline for gravitational wave signal encoding:

1. Synthetic dataset: (mass, freq, amp, spin, distance) → (clarity, efficiency, bandwidth)
2. Physics‐informed residual enforcing SNR scaling laws
3. MLP with LayerNorm & Dropout
4. MC‐Dropout inference for uncertainty
5. Training loop with AdamW, ReduceLROnPlateau, gradient clipping
6. Visualizations: losses, scatter, uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# 1. Synthetic Dataset
class GravWaveDataset(Dataset):
    """Synthetic (mass, freq, amp, spin, distance) -> (clarity, efficiency, bandwidth) data.

    Features are drawn uniformly, targets follow simple closed-form rules plus
    2% Gaussian noise, and both are z-score normalized per column.

    Attributes:
        X, Y: normalized float32 arrays of shape (n, 5) and (n, 3).
        X_mean, X_std, Y_mean, Y_std: float32 per-column statistics used for
            the normalization (needed to map values back to physical units).

    Args:
        n: number of samples to generate.
        seed: seed of the private random generator used for sampling.
    """

    def __init__(self, n=6000, seed=0):
        # A private RandomState yields the same stream as np.random.seed(seed)
        # followed by the global np.random functions, but leaves the global
        # generator untouched.
        rng = np.random.RandomState(seed)

        # Everything is kept in float64 until the very end: amplitudes are
        # ~1e-21, so squared deviations (~1e-42) are subnormal in float32 and
        # the resulting statistics would be badly truncated.
        M = rng.uniform(5, 50, (n, 1))
        f = rng.uniform(10, 1000, (n, 1))
        A = rng.uniform(1e-22, 1e-20, (n, 1))
        a = rng.uniform(0, 0.99, (n, 1))
        d = rng.uniform(1e2, 1e4, (n, 1))
        X = np.hstack([M, f, A, a, d])

        # Targets via simple physics
        clarity = M**2 * A / d
        efficiency = clarity / f
        bandwidth = f * A
        Y = np.hstack([clarity, efficiency, bandwidth])
        Y += 0.02 * Y.std(0) * rng.randn(*Y.shape)

        # normalization stats
        X_mean, X_std = X.mean(0), self._safe_std(X)
        Y_mean, Y_std = Y.mean(0), self._safe_std(Y)

        self.X = ((X - X_mean) / X_std).astype(np.float32)
        self.Y = ((Y - Y_mean) / Y_std).astype(np.float32)
        self.X_mean, self.X_std = X_mean.astype(np.float32), X_std.astype(np.float32)
        self.Y_mean, self.Y_std = Y_mean.astype(np.float32), Y_std.astype(np.float32)

    @staticmethod
    def _safe_std(values):
        """Return the per-column std, with exact zeros replaced by 1.

        An additive epsilon (e.g. 1e-6) is deliberately avoided: it would
        dominate columns whose natural scale is far below it and flatten the
        normalized values to zero.
        """
        std = values.std(0)
        return np.where(std > 0, std, 1.0)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return torch.from_numpy(self.X[i]), torch.from_numpy(self.Y[i])

# 2. Model
class GravWaveCommAI(nn.Module):
    def __init__(self, in_dim=5, hid=64, out=3, p=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid), nn.LayerNorm(hid), nn.ReLU(), nn.Dropout(p),
            nn.Linear(hid, hid),   nn.LayerNorm(hid), nn.ReLU(), nn.Dropout(p),
            nn.Linear(hid, out)
        )

    def forward(self, x):
        return self.net(x)

# 3. Physics Loss (SNR scaling)
def phys_loss(pred, X, stats):
    """Physics residual: predicted clarity vs. the closed-form M^2 * A / d.

    The residual is measured in *normalized* target units. Comparing in
    physical units would square differences of order 1e-20, which underflows
    in float32 and makes the term (and its gradient) vanish.

    Args:
        pred: normalized model outputs, shape (batch, 3); column 0 is clarity.
        X: normalized inputs, shape (batch, 5), columns (M, f, A, a, d).
        stats: dict with tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std'.

    Returns:
        Scalar tensor, the mean squared normalized residual.
    """
    # Undo the input normalization to evaluate the physical formula.
    X_den = X * stats['X_std'] + stats['X_mean']
    M, A, d = X_den[:, 0], X_den[:, 2], X_den[:, 4]
    snr_true = M**2 * A / d

    # Bring the analytic value into the same normalized space as `pred`.
    snr_norm = (snr_true - stats['Y_mean'][0]) / stats['Y_std'][0]
    return nn.MSELoss()(pred[:, 0], snr_norm)

def total_loss(pred, truth, X, stats, lam=1.0):
    """Combine the data-fit MSE with the weighted physics residual.

    Returns:
        Tuple ``(combined, data_term, physics_term)`` where
        ``combined = data_term + lam * physics_term``.
    """
    criterion = nn.MSELoss()
    data_term = criterion(pred, truth)
    physics_term = phys_loss(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# 4. MC‐Dropout for Uncertainty
def mc_dropout(model, X, T=50):
    """Monte-Carlo dropout: run `T` stochastic forward passes over `X`.

    Only dropout layers are switched to training behaviour; every other
    module runs in eval mode. Each module's previous training flag is
    restored before returning, so the caller's model mode is left unchanged.

    Args:
        model: the network to sample from.
        X: input batch passed to `model` on every pass.
        T: number of stochastic forward passes.

    Returns:
        Tuple ``(mean, std)`` of the stacked predictions over the `T` passes.
    """
    # Dropout1d only exists in newer torch releases, hence the hasattr filter.
    dropout_types = tuple(
        getattr(nn, cls_name)
        for cls_name in ('Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
                         'AlphaDropout', 'FeatureAlphaDropout')
        if hasattr(nn, cls_name)
    )
    previous_modes = [(module, module.training) for module in model.modules()]
    try:
        model.eval()
        for module, _ in previous_modes:
            if isinstance(module, dropout_types):
                module.train()
        with torch.no_grad():
            stacked = torch.stack([model(X) for _ in range(T)], 0)
    finally:
        # Put every module back exactly as we found it.
        for module, was_training in previous_modes:
            module.training = was_training
    return stacked.mean(0), stacked.std(0)

# 5. Training Loop
def train(model, tr_loader, val_loader, stats, dev,
          lr=1e-3, wd=1e-5, lam=1.0, max_epochs=100, patience=10):
    """Train `model` with AdamW, LR-on-plateau scheduling and early stopping.

    Each epoch computes the sample-weighted mean of `total_loss` on the
    training and validation loaders. Whenever the validation loss improves by
    more than 1e-6 the weights are snapshotted in memory and also written to
    'best_model.pth'. On exit the best snapshot (if any) is loaded back into
    `model`.

    Args:
        model: network to optimize (updated in place).
        tr_loader, val_loader: loaders yielding ``(X, Y)`` batches.
        stats: normalization statistics forwarded to `total_loss`.
        dev: device the batches are moved to.
        lr, wd: AdamW learning rate and weight decay.
        lam: weight of the physics term, forwarded to `total_loss`.
        max_epochs: upper bound on the number of epochs.
        patience: epochs without improvement before stopping early.

    Returns:
        dict with per-epoch lists under 'train_loss' and 'val_loss'.
    """
    import copy  # local import: only needed for the in-memory snapshot

    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=5
    )
    best_val, wait = float('inf'), 0
    best_state = None
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(1, max_epochs+1):
        model.train()
        running_train = 0.0
        for Xb, Yb in tr_loader:
            Xb, Yb = Xb.to(dev), Yb.to(dev)
            preds = model(Xb)
            loss, _, _ = total_loss(preds, Yb, Xb, stats, lam)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Weight by batch size so a short final batch is not over-counted.
            running_train += loss.item() * Xb.size(0)
        running_train /= len(tr_loader.dataset)

        model.eval()
        running_val = 0.0
        with torch.no_grad():
            for Xb, Yb in val_loader:
                Xb, Yb = Xb.to(dev), Yb.to(dev)
                preds = model(Xb)
                loss, _, _ = total_loss(preds, Yb, Xb, stats, lam)
                running_val += loss.item() * Xb.size(0)
        running_val /= len(val_loader.dataset)

        scheduler.step(running_val)
        history['train_loss'].append(running_train)
        history['val_loss'].append(running_val)
        print(f"Epoch {epoch:03d} – Train: {running_train:.4e}, Val: {running_val:.4e}")

        if running_val < best_val - 1e-6:
            best_val, wait = running_val, 0
            # Keep an in-memory copy so restoring never depends on the file:
            # if no epoch improves (e.g. NaN losses) the file may be missing
            # or left over from an earlier run.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping triggered.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return history

# 6. Visualization
def plot_history(history):
    """Plot the per-epoch training and validation loss curves and show them."""
    curves = (('train_loss', 'Train'), ('val_loss', 'Val'))
    for key, label in curves:
        plt.plot(history[key], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def scatter_true_vs_pred(Y_true, Y_pred, name):
    """Scatter predictions against true values with a dashed y = x reference.

    The reference line spans the min/max of `Y_true`; `name` is the title.
    """
    lo = Y_true.min()
    hi = Y_true.max()
    diagonal = [lo, hi]

    plt.scatter(Y_true, Y_pred, s=5, alpha=0.6)
    plt.plot(diagonal, diagonal, 'r--')
    plt.title(name)
    plt.xlabel('True')
    plt.ylabel('Pred')
    plt.show()

def plot_uncertainty_map(model, stats, dev):
    """Heatmap of the MC-dropout std of output 0 over a mass x distance grid.

    Mass spans [5, 50] and distance [1e2, 1e4] on a 50 x 50 grid; the other
    three features are pinned to their mean from `stats`. The plotted std is
    that of the model's raw (normalized) first output over 100 passes.
    """
    mass_axis = torch.linspace(5, 50, 50, device=dev)
    dist_axis = torch.linspace(1e2, 1e4, 50, device=dev)
    Mg, Dg = torch.meshgrid(mass_axis, dist_axis, indexing='xy')

    # Physical-unit feature rows: grid values in columns 0 and 4,
    # dataset means in columns 1..3.
    n_points = Mg.numel()
    Xg = torch.zeros(n_points, 5, device=dev)
    Xg[:, 0] = Mg.flatten()
    Xg[:, 4] = Dg.flatten()
    Xg[:, 1:4] = stats['X_mean'][1:4]

    # Same z-score normalization the model was trained on.
    Xn = (Xg - stats['X_mean']) / stats['X_std']

    _, spread = mc_dropout(model, Xn, T=100)
    clarity_spread = spread[:, 0].cpu().numpy().reshape(Mg.shape)

    mass_grid = Mg.cpu().numpy()
    dist_grid = Dg.cpu().numpy()
    plt.pcolormesh(mass_grid, dist_grid, clarity_spread, cmap='viridis')
    plt.colorbar(label='std of clarity')
    plt.xlabel('Mass')
    plt.ylabel('Distance')
    plt.show()

# 7. Main
if __name__ == '__main__':
    # Script entry point: build data, train, then draw the diagnostic plots.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = GravWaveDataset()
    # Normalization statistics as tensors on the training device, so the
    # loss and plotting helpers can (de)normalize batches without copies.
    stats = {
        'X_mean': torch.tensor(dataset.X_mean, device=device),
        'X_std':  torch.tensor(dataset.X_std,  device=device),
        'Y_mean': torch.tensor(dataset.Y_mean, device=device),
        'Y_std':  torch.tensor(dataset.Y_std,  device=device),
    }

    n_val = 1200
    train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val])
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=256)

    model   = GravWaveCommAI().to(device)
    history = train(
        model, train_loader, val_loader,
        stats, device,
        lr=1e-3, wd=1e-5, lam=1.0,
        max_epochs=100, patience=10
    )

    plot_history(history)

    # Scatter plots (normalized units, over the full dataset).
    # Switch to eval mode explicitly so dropout is off for these point
    # predictions regardless of the mode train() happened to leave behind.
    model.eval()
    with torch.no_grad():
        X_all = torch.from_numpy(dataset.X).to(device)
        Y_pred_norm = model(X_all).cpu().numpy()
    for i, name in enumerate(['clarity','efficiency','bandwidth']):
        scatter_true_vs_pred(dataset.Y[:, i], Y_pred_norm[:, i], name)

    # Uncertainty heatmap
    plot_uncertainty_map(model, stats, device)