<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_quantum_ai_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_quantum_ai.py

Physics-informed AI pipeline for Quantum Supercomputing Network Stability:

1. Synthetic dataset: (ent_strength, noise, decoherence, gate_error, temperature)
   → (fidelity, coherence_time, error_rate)
2. Physics-informed residuals enforcing the closed-form fidelity decay, coherence time, and error-rate scaling
3. MLP with LayerNorm & Dropout
4. MC-Dropout inference for uncertainty quantification
5. Training loop with AdamW, ReduceLROnPlateau, gradient clipping, and early stopping
6. Visualizations: loss curves, scatter plots, uncertainty heatmap
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1. Synthetic Quantum Dataset
# ------------------------------------------------------------------------------
class QuantumDataset(Dataset):
    """Synthetic quantum-network stability dataset, stored z-score normalised.

    Inputs (5 columns, each sampled uniformly):
        ent_strength in [0.5, 1.0], noise in [0, 0.1], decoherence in [1e-3, 1e-1],
        gate_error in [1e-4, 1e-2], temperature in [0.01, 1.0].
    Targets (3 columns):
        fidelity       = ent * exp(-10 * noise - 5 * deco)
        coherence_time = 1 / (deco + noise)
        error_rate     = gate * (1 + temp)
    Gaussian noise with standard deviation 2% of each target column's std is
    added to the targets.

    The per-column statistics are kept in ``X_mean``/``X_std``/``Y_mean``/``Y_std``
    so callers can map normalised values back to physical units.

    Args:
        n: number of samples.
        seed: seed for a private NumPy ``RandomState``. It yields exactly the
            same draws as the legacy ``np.random.seed(seed)`` stream, but the
            global NumPy RNG is no longer reseeded as a side effect.
    """

    def __init__(self, n=6000, seed=0):
        # Private legacy generator: same sequence as np.random.seed + np.random.*,
        # without clobbering global RNG state used elsewhere in the process.
        rng = np.random.RandomState(seed)
        ent   = rng.uniform(0.5, 1.0, (n, 1)).astype(np.float32)    # entanglement
        noise = rng.uniform(0.0, 0.1, (n, 1)).astype(np.float32)    # noise fraction
        deco  = rng.uniform(1e-3, 1e-1, (n, 1)).astype(np.float32)  # decoherence rate
        gate  = rng.uniform(1e-4, 1e-2, (n, 1)).astype(np.float32)  # gate error
        temp  = rng.uniform(0.01, 1.0, (n, 1)).astype(np.float32)   # temperature

        X = np.hstack([ent, noise, deco, gate, temp])

        # Physics-based targets
        fidelity    = ent * np.exp(-noise * 10 - deco * 5)
        coherence_t = 1.0 / (deco + noise)
        error_rate  = gate * (1 + temp)

        Y = np.hstack([fidelity, coherence_t, error_rate])
        # In-place add keeps Y float32 (the float64 noise is cast back).
        Y += 0.02 * Y.std(axis=0) * rng.randn(*Y.shape)

        # Normalisation stats; the epsilon guards against zero-variance columns.
        self.X_mean, self.X_std = X.mean(0), X.std(0) + 1e-6
        self.Y_mean, self.Y_std = Y.mean(0), Y.std(0) + 1e-6
        self.X = (X - self.X_mean) / self.X_std
        self.Y = (Y - self.Y_mean) / self.Y_std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return torch.from_numpy(self.X[i]), torch.from_numpy(self.Y[i])

# ------------------------------------------------------------------------------
# 2. Model
# ------------------------------------------------------------------------------
class QuantumAI(nn.Module):
    """MLP regressor: two hidden blocks of Linear -> LayerNorm -> ReLU -> Dropout, then a Linear head.

    Args:
        in_dim: number of input features.
        hid: width of both hidden layers.
        out: number of output values.
        p: dropout probability used in both hidden blocks.
    """

    def __init__(self, in_dim=5, hid=64, out=3, p=0.1):
        super().__init__()
        layers = []
        width_in = in_dim
        for _ in range(2):
            layers += [nn.Linear(width_in, hid), nn.LayerNorm(hid), nn.ReLU(), nn.Dropout(p)]
            width_in = hid
        layers.append(nn.Linear(hid, out))
        # A single Sequential named `net` keeps state_dict keys as net.0.*, net.1.*, ...
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# ------------------------------------------------------------------------------
# 3. Physics-Informed Residual
# ------------------------------------------------------------------------------
def physics_residual(pred, X, stats, *, normalize=True):
    """Penalise deviation of the predictions from the closed-form physics targets.

    Both ``X`` and ``pred`` are given in normalised units; they are mapped back
    to physical units with ``stats`` before the closed-form targets are computed:
        fidelity       = ent * exp(-10 * noise - 5 * deco)
        coherence_time = 1 / (deco + noise)
        error_rate     = gate * (1 + temp)

    Args:
        pred: normalised model outputs, one row per sample, columns
            (fidelity, coherence_time, error_rate).
        X: normalised inputs, columns (ent, noise, deco, gate, temp).
        stats: dict of tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std'.
        normalize: if True (default), divide each residual by ``Y_std`` so every
            target contributes on the same dimensionless scale as the data MSE.
            In raw units the coherence-time term (values up to ~1e3) would swamp
            the error-rate term (~1e-3) and the data loss. Pass False to get the
            original raw-unit residual.

    Returns:
        Scalar tensor: the sum over the three targets of the per-target mean
        squared residual.
    """
    X_den = X * stats['X_std'] + stats['X_mean']
    ent, noise, deco, gate, temp = X_den.t()
    Y_den = pred * stats['Y_std'] + stats['Y_mean']

    Y_phys = torch.stack([
        ent * torch.exp(-10 * noise - 5 * deco),  # fidelity decay
        1.0 / (deco + noise),                     # coherence time
        gate * (1 + temp),                        # error-rate scaling
    ], dim=1)

    diff = Y_den - Y_phys
    if normalize:
        diff = diff / stats['Y_std']
    # Mean over the batch per target, summed over targets
    # (equivalent to summing three separate MSELoss terms).
    return diff.pow(2).mean(dim=0).sum()

def total_loss(pred, truth, X, stats, lam=1.0):
    """Combine the supervised data loss with the weighted physics residual.

    Args:
        pred: normalised model outputs.
        truth: normalised targets.
        X: normalised inputs (needed by the physics residual).
        stats: normalisation statistics passed through to ``physics_residual``.
        lam: weight of the physics term.

    Returns:
        Tuple ``(total, mse, phys)`` with ``total = mse + lam * phys``.
    """
    data_term = nn.functional.mse_loss(pred, truth)
    physics_term = physics_residual(pred, X, stats)
    combined = data_term + lam * physics_term
    return combined, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC-Dropout Inference
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, X, T=50):
    """Monte-Carlo dropout: mean and std over T stochastic forward passes.

    Only dropout layers are switched to training mode. Every other layer stays
    in inference mode, so layers with running statistics (e.g. BatchNorm) are
    neither used in batch mode nor updated by the sampling passes. The
    training flag of every submodule is restored afterwards, so the call leaves
    ``model`` in whatever mode it was in.

    Args:
        model: network to sample.
        X: input batch, passed unchanged to ``model``.
        T: number of stochastic passes. Must be >= 1; T == 1 gives a NaN std
            because the sample std uses Bessel's correction.

    Returns:
        ``(mean, std)`` over the T passes. Each is shaped like one ``model(X)``
        output.

    Raises:
        ValueError: if ``T < 1``.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")
    saved_modes = [(m, m.training) for m in model.modules()]
    try:
        model.eval()
        for m in model.modules():
            # _DropoutNd is the shared base of Dropout, Dropout1d/2d/3d and AlphaDropout.
            if isinstance(m, nn.modules.dropout._DropoutNd):
                m.train()
        with torch.no_grad():
            samples = torch.stack([model(X) for _ in range(T)], dim=0)
    finally:
        for m, mode in saved_modes:
            m.training = mode
    return samples.mean(0), samples.std(0)

# ------------------------------------------------------------------------------
# 5. Training Loop
# ------------------------------------------------------------------------------
def _train_one_epoch(model, loader, optimizer, stats, device, lam, grad_clip):
    """Run one optimisation pass over `loader`; return the sample-weighted mean loss."""
    model.train()
    running = 0.0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)
        loss, _, _ = total_loss(model(Xb), Yb, Xb, stats, lam)
        optimizer.zero_grad()
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        running += loss.item() * Xb.size(0)
    return running / len(loader.dataset)


def _evaluate(model, loader, stats, device, lam):
    """Return the sample-weighted mean total loss over `loader` in eval mode, without gradients."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for Xb, Yb in loader:
            Xb, Yb = Xb.to(device), Yb.to(device)
            loss, _, _ = total_loss(model(Xb), Yb, Xb, stats, lam)
            running += loss.item() * Xb.size(0)
    return running / len(loader.dataset)


def train(model, train_loader, val_loader, stats, device,
          lr=1e-3, wd=1e-5, lam=1.0, epochs=100, patience=10,
          *, ckpt_path="best_quantum_ai.pth", grad_clip=1.0, sched_patience=5):
    """Train with AdamW, ReduceLROnPlateau, gradient clipping and early stopping.

    After every epoch the validation loss drives the LR scheduler and early
    stopping. The best weights seen so far are kept in memory and also written
    to ``ckpt_path``. At the end the best weights are loaded back into
    ``model``. If no epoch ever improved (e.g. every validation loss was NaN,
    or ``epochs == 0``), the current weights are kept instead of loading a
    missing or stale checkpoint file.

    Args:
        model: network to optimise in place.
        train_loader / val_loader: loaders yielding ``(X, Y)`` normalised batches.
        stats: normalisation statistics forwarded to ``total_loss``.
        device: device the batches are moved to.
        lr, wd: AdamW learning rate and weight decay.
        lam: weight of the physics residual in the loss.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping early.
        ckpt_path: file the best state_dict is saved to.
        grad_clip: max gradient norm, or None to disable clipping.
        sched_patience: ReduceLROnPlateau patience.

    Returns:
        dict with per-epoch lists ``'train_loss'`` and ``'val_loss'``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=sched_patience
    )

    best_val, wait = float('inf'), 0
    best_state = None
    history = {'train_loss': [], 'val_loss': []}

    for ep in range(1, epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, stats, device, lam, grad_clip)
        val_loss = _evaluate(model, val_loader, stats, device, lam)

        scheduler.step(val_loss)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        print(f"Epoch {ep:03d} | Train {train_loss:.4e} | Val {val_loss:.4e}")

        if val_loss < best_val - 1e-6:
            best_val, wait = val_loss, 0
            # Detached copy: state_dict() tensors alias the live parameters.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {ep}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("No validation improvement recorded; keeping current weights.")
    return history

# ------------------------------------------------------------------------------
# 6. Visualization Helpers
# ------------------------------------------------------------------------------
def plot_losses(history):
    """Plot per-epoch training and validation loss curves.

    The x-axis is numbered from 1, matching the 1-based epoch numbers that the
    training loop prints (``range(1, epochs+1)``).

    Args:
        history: dict with lists ``'train_loss'`` and ``'val_loss'``, one entry per epoch.
    """
    train_hist = history['train_loss']
    val_hist = history['val_loss']
    plt.figure()
    plt.plot(range(1, len(train_hist) + 1), train_hist, label='Train')
    plt.plot(range(1, len(val_hist) + 1), val_hist, label='Val')
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.show()

def plot_scatter(y_true, y_pred, name):
    """Scatter predictions against ground truth, with a red dashed y = x reference line.

    The reference line spans the range of ``y_true``.
    """
    _, ax = plt.subplots()
    ax.scatter(y_true, y_pred, s=5, alpha=0.6)
    bounds = [y_true.min(), y_true.max()]
    ax.plot(bounds, bounds, 'r--')
    ax.set(title=name, xlabel="True", ylabel="Pred")
    plt.show()

def plot_uncertainty_heatmap(model, stats, device, *, n_grid=50, T=100):
    """Plot the MC-dropout uncertainty of predicted fidelity over (entanglement, noise).

    Entanglement is swept over [0.5, 1.0] and noise over [0, 0.1] on an
    ``n_grid x n_grid`` grid. The remaining three inputs (decoherence, gate
    error, temperature) are held at their training means. The model predicts
    normalised targets, so the sampled std is multiplied by ``Y_std[0]`` to
    express it in fidelity units, as the colourbar label states.

    The model's training flag is restored afterwards.

    Args:
        model: trained network containing dropout layers.
        stats: dict of tensors 'X_mean', 'X_std', 'Y_mean', 'Y_std' on ``device``.
        device: device the grid is built on.
        n_grid: number of grid points per axis.
        T: number of MC-dropout passes.
    """
    ents = torch.linspace(0.5, 1.0, n_grid, device=device)
    noises = torch.linspace(0.0, 0.1, n_grid, device=device)
    E_grid, N_grid = torch.meshgrid(ents, noises, indexing='xy')
    E_flat = E_grid.reshape(-1)
    N_flat = N_grid.reshape(-1)

    # Hold the other inputs at their (physical-unit) training means.
    fixed = [torch.full_like(E_flat, stats['X_mean'][i].item()) for i in (2, 3, 4)]
    inputs = torch.stack([E_flat, N_flat, *fixed], dim=1)
    Xn = (inputs - stats['X_mean']) / stats['X_std']

    was_training = model.training
    try:
        _, std = mc_dropout_predict(model, Xn, T=T)
    finally:
        model.train(was_training)
    # Std is linear in scale, so rescaling by Y_std converts z-units to fidelity units.
    std_map = (std[:, 0] * stats['Y_std'][0]).reshape(E_grid.shape).cpu().numpy()

    plt.figure()
    plt.pcolormesh(
        E_grid.cpu().numpy(),
        N_grid.cpu().numpy(),
        std_map,
        cmap='viridis',
        shading='auto'
    )
    plt.colorbar(label="Std of Fidelity")
    plt.xlabel("Entanglement Strength")
    plt.ylabel("Noise Level")
    plt.title("Fidelity Uncertainty Heatmap")
    plt.show()

# ------------------------------------------------------------------------------
# 7. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Seed torch so weight init, the train/val split and shuffling are
    # reproducible, matching the seeded synthetic dataset.
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = QuantumDataset()
    stats = {
        'X_mean': torch.tensor(dataset.X_mean, device=device),
        'X_std' : torch.tensor(dataset.X_std,  device=device),
        'Y_mean': torch.tensor(dataset.Y_mean, device=device),
        'Y_std' : torch.tensor(dataset.Y_std,  device=device),
    }

    n_val = 1200
    train_ds, val_ds = random_split(
        dataset, [len(dataset) - n_val, n_val],
        generator=torch.Generator().manual_seed(0),
    )
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=256)

    model   = QuantumAI().to(device)
    history = train(
        model, train_loader, val_loader,
        stats, device,
        lr=1e-3, wd=1e-5, lam=1.0,
        epochs=100, patience=10
    )

    plot_losses(history)

    # Scatter True vs Pred. Call eval() explicitly so dropout is off, rather than
    # relying on whichever mode train() happened to leave the model in.
    model.eval()
    with torch.no_grad():
        X_all   = torch.from_numpy(dataset.X).to(device)
        Y_pred  = model(X_all).cpu().numpy() * dataset.Y_std + dataset.Y_mean
    Y_true = dataset.Y * dataset.Y_std + dataset.Y_mean
    for i, nm in enumerate(["Fidelity", "Coherence_Time", "Error_Rate"]):
        plot_scatter(Y_true[:, i], Y_pred[:, i], nm)

    plot_uncertainty_heatmap(model, stats, device)