<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/train_tachyon_pinn_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
train_tachyon_pinn.py

Physics‐informed TachyonAI pipeline with:
  - Synthetic data generator for (E, p, f) → (f_next, rate)
  - PINN loss enforcing δf/δt = p − E·f
  - Residual‐skip MLP with LayerNorm & Dropout
  - Early stopping & ReduceLROnPlateau scheduler
  - MC‐Dropout uncertainty estimation
  - Field‐evolution animation saved as GIF
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# ------------------------------------------------------------------------------
# Hyperparameters
# ------------------------------------------------------------------------------
# Network shape: three inputs (E, p, f) -> two outputs (f_next, rate).
INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 64, 2
DROPOUT_P = 0.1          # dropout probability inside TachyonPINN

# Physics: forward-Euler step size and weight of the physics residual term.
DT = 0.1
LAMBDA_PHY = 1.0

# Optimisation / early stopping.
LR = 0.001
BATCH_SIZE = 128
EPOCHS = 200
PATIENCE = 15            # epochs without val improvement before stopping

# Uncertainty estimation and animation output.
MC_SAMPLES = 50          # stochastic forward passes for MC-dropout
GIF_STEPS = 80
GIF_CURVES = 4
GIF_PATH = "tachyon_evolution.gif"

# ------------------------------------------------------------------------------
# 1. Synthetic Data Generator
# ------------------------------------------------------------------------------
def generate_synthetic_tachyon_data(n_samples=12000, dt=DT, seed=42):
    """Generate a reproducible synthetic (X, Y) regression dataset.

    Inputs E, p and f are drawn independently and uniformly from [-1, 1).
    The two targets are

    * ``f_next = f + dt * (p - E*f)``: one forward-Euler step of
      df/dt = p - E*f;
    * ``rate = p*f - E``.

    Args:
        n_samples: Number of rows to generate.
        dt: Euler step size used for the ``f_next`` target.
        seed: Seed for a *private* legacy ``np.random.RandomState``. A given
            seed yields the same values as seeding NumPy's global RNG with it
            and drawing in the same order. NumPy's global RNG state is left
            untouched, so unrelated random draws elsewhere in the program are
            not silently reseeded.

    Returns:
        ``(X, Y)``. ``X`` is a float32 array of shape ``(n_samples, 3)``
        with columns ``[E, p, f]``. ``Y`` is a float32 array of shape
        ``(n_samples, 2)`` with columns ``[f_next, rate]``.
    """
    rng = np.random.RandomState(seed)
    # Draw order (E, then p, then f) must stay fixed so that a given seed
    # always produces the same dataset.
    E = rng.uniform(-1, 1, (n_samples, 1)).astype(np.float32)
    p = rng.uniform(-1, 1, (n_samples, 1)).astype(np.float32)
    f = rng.uniform(-1, 1, (n_samples, 1)).astype(np.float32)
    X = np.hstack([E, p, f])

    f_next = f + dt * (p - E * f)
    # NOTE(review): `rate` is p*f - E, which is not the physical rate p - E*f
    # used for f_next and in the PINN loss. Kept unchanged; confirm this
    # target is intentional.
    rate = p * f - E
    Y = np.hstack([f_next, rate]).astype(np.float32)
    return X, Y

# ------------------------------------------------------------------------------
# 2. Physics‐Informed Residual MLP with Residual Skip & Dropout
# ------------------------------------------------------------------------------
class TachyonPINN(nn.Module):
    """One-hidden-layer MLP with a linear input->output skip connection.

    Main path: Linear -> LayerNorm -> ReLU -> Dropout -> Linear.
    Skip path: a single Linear layer from the raw input to the output.
    The network returns the sum of both paths.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of outputs.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM,
                 output_dim=OUTPUT_DIM, dropout=DROPOUT_P):
        super().__init__()
        # Submodules are created and registered in this exact order. Changing
        # the order would change random weight initialisation and the
        # ordering of model.parameters(); the attribute names are the
        # state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.skip = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``fc2(dropout(relu(norm1(fc1(x))))) + skip(x)``."""
        hidden = self.drop(self.relu(self.norm1(self.fc1(x))))
        main_branch = self.fc2(hidden)
        return main_branch + self.skip(x)

# ------------------------------------------------------------------------------
# 3. PINN Loss: supervised MSE + physical constraint δf/δt = p − E·f
# ------------------------------------------------------------------------------
def pinn_loss(pred, true, inputs, dt=DT, lambda_phys=LAMBDA_PHY):
    """Combine a supervised MSE with a forward-Euler physics residual.

    The physics term penalises the distance between the predicted ``f_next``
    (column 0 of ``pred``) and ``f + dt * (p - E*f)`` computed from the
    input columns ``[E, p, f]``.

    Args:
        pred: Model output; column 0 is the predicted ``f_next``.
        true: Supervised targets, same shape as ``pred``.
        inputs: Model inputs with columns ``[E, p, f]``.
        dt: Euler step size.
        lambda_phys: Weight of the physics term.

    Returns:
        Tuple ``(total, data_mse, physics_mse)`` of scalar tensors, where
        ``total = data_mse + lambda_phys * physics_mse``.
    """
    mse = nn.functional.mse_loss  # mean reduction, same as nn.MSELoss()
    E, p, f = inputs[:, 0], inputs[:, 1], inputs[:, 2]

    data_term = mse(pred, true)
    euler_target = f + dt * (p - E * f)
    physics_term = mse(pred[:, 0], euler_target)

    total = data_term + lambda_phys * physics_term
    return total, data_term, physics_term

# ------------------------------------------------------------------------------
# 4. MC‐Dropout Uncertainty Estimation
# ------------------------------------------------------------------------------
def mc_dropout_predict(model, x, n_samples=MC_SAMPLES, device='cpu'):
    """Estimate predictive mean and spread with Monte-Carlo dropout.

    Runs ``n_samples`` stochastic forward passes with the model in training
    mode, so dropout stays active, and with gradients disabled.

    NOTE: ``model.train()`` switches *every* submodule to training mode. For
    models with BatchNorm this would also update running statistics; the
    models in this file use LayerNorm, which is unaffected.

    Args:
        model: Network to sample from.
        x: Input batch tensor; moved to ``device`` once before sampling.
        n_samples: Number of stochastic forward passes (must be >= 1).
        device: Device on which to run the model.

    Returns:
        ``(mean, std)`` NumPy arrays, reduced over the sample axis, with the
        same shape as a single model output.

    Raises:
        ValueError: If ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    was_training = model.training
    x = x.to(device)  # transfer once, not once per sample
    model.train()  # keep dropout active
    try:
        with torch.no_grad():
            preds = np.stack(
                [model(x).cpu().numpy() for _ in range(n_samples)], axis=0
            )  # (n_samples, *output_shape)
    finally:
        # Restore the caller's mode even if a forward pass raises, so the
        # model is never left stuck with dropout enabled.
        model.train(was_training)

    return preds.mean(axis=0), preds.std(axis=0)

# ------------------------------------------------------------------------------
# 5. DataLoaders
# ------------------------------------------------------------------------------
def prepare_dataloaders(batch_size=BATCH_SIZE, val_frac=0.2, seed=None):
    """Build shuffled training and sequential validation DataLoaders.

    The dataset comes from ``generate_synthetic_tachyon_data()`` with its
    default arguments. It is split at random into train/validation subsets.

    Args:
        batch_size: Mini-batch size for both loaders.
        val_frac: Fraction of samples placed in the validation split; the
            count is truncated with ``int()``.
        seed: Optional seed for the train/validation split. ``None`` keeps
            the previous behaviour of using torch's global RNG, which makes
            the split vary between runs. Pass an int for a reproducible
            split.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If ``val_frac`` leaves either split empty. An empty
            validation split would otherwise cause a division by zero during
            training.
    """
    X, Y = generate_synthetic_tachyon_data()
    dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(Y))

    n_val = int(len(dataset) * val_frac)
    n_train = len(dataset) - n_val
    if n_val < 1 or n_train < 1:
        raise ValueError(
            f"val_frac={val_frac} leaves an empty split "
            f"({n_train} train / {n_val} val samples)"
        )

    # Only pass a generator when a seed is requested, so the default path
    # still uses torch's global RNG exactly as before.
    split_kwargs = {} if seed is None else {
        "generator": torch.Generator().manual_seed(seed)
    }
    train_ds, val_ds = random_split(dataset, [n_train, n_val], **split_kwargs)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    return train_loader, val_loader

# ------------------------------------------------------------------------------
# 6. Training Loop with Early Stopping & Scheduler
# ------------------------------------------------------------------------------
def _train_one_epoch(model, loader, optimizer, lambda_phys, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    total, count = 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss, _, _ = pinn_loss(model(xb), yb, xb, lambda_phys=lambda_phys)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-counted.
        total += loss.item() * xb.size(0)
        count += xb.size(0)
    if count == 0:
        raise ValueError("training loader yielded no samples")
    return total / count


def _evaluate(model, loader, lambda_phys, device):
    """Return the per-sample mean PINN loss over ``loader`` in eval mode, without gradients."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss, _, _ = pinn_loss(model(xb), yb, xb, lambda_phys=lambda_phys)
            total += loss.item() * xb.size(0)
            count += xb.size(0)
    if count == 0:
        raise ValueError("validation loader yielded no samples")
    return total / count


def train_model(model, train_loader, val_loader,
                epochs=EPOCHS, lr=LR, weight_decay=1e-5,
                lambda_phys=LAMBDA_PHY, patience_max=PATIENCE,
                device='cpu', checkpoint_path="best_tachyon_pinn.pt"):
    """Train ``model`` with Adam, ReduceLROnPlateau and early stopping.

    After each epoch the validation loss drives the LR scheduler (halving
    the LR after 5 stagnant epochs). Whenever the validation loss improves,
    the weights are snapshotted. Training stops after ``patience_max``
    consecutive epochs without improvement.

    Args:
        model: Network to train; updated in place.
        train_loader: Loader of ``(inputs, targets)`` training batches.
        val_loader: Loader of ``(inputs, targets)`` validation batches.
        epochs: Maximum number of epochs.
        lr: Initial Adam learning rate.
        weight_decay: Adam weight decay.
        lambda_phys: Weight of the physics term in ``pinn_loss``.
        patience_max: Early-stopping patience in epochs.
        device: Device on which batches are placed.
        checkpoint_path: File that receives the best ``state_dict`` on every
            improvement. Pass ``None`` to skip writing to disk.

    Returns:
        ``(model, history, best_val_loss)``:
        - ``model`` holds the best weights seen. If the validation loss never
          improved (e.g. it was always NaN, or ``epochs`` == 0), the final
          weights are kept.
        - ``history`` has ``'train'`` and ``'val'`` per-epoch loss lists.

    Raises:
        ValueError: If either loader yields no samples.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val_loss = float('inf')
    best_state = None
    patience_cnt = 0
    history = {'train': [], 'val': []}

    for epoch in range(1, epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer,
                                      lambda_phys, device)
        val_loss = _evaluate(model, val_loader, lambda_phys, device)
        scheduler.step(val_loss)

        history['train'].append(train_loss)
        history['val'].append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Keep an in-memory copy so restoring never depends on a file
            # that may be stale from an earlier run or missing entirely.
            best_state = {k: v.detach().clone()
                          for k, v in model.state_dict().items()}
            if checkpoint_path is not None:
                torch.save(model.state_dict(), checkpoint_path)
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience_max:
                print(f"Early stopping at epoch {epoch}")
                break

        if epoch % 10 == 0 or epoch == 1:
            print(f"Epoch {epoch:03d} | Train Loss {train_loss:.4e} | Val Loss {val_loss:.4e}")

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: validation loss never improved; keeping final weights.")
    return model, history, best_val_loss

# ------------------------------------------------------------------------------
# 7. Field Evolution Animation
# ------------------------------------------------------------------------------
def animate_evolution(model, n_steps=GIF_STEPS, n_curves=GIF_CURVES,
                      dt=DT, save_path=GIF_PATH, device='cpu', seed=None):
    """Roll the model forward autoregressively and save the f(t) curves as a GIF.

    Each curve starts from random ``(E, p, f)`` drawn uniformly from
    [-1, 1). At every step the model's predicted ``f_next`` (output column
    0) replaces ``f`` in the input, while ``E`` and ``p`` stay fixed.

    Args:
        model: Trained network; switched to eval mode.
        n_steps: Number of rollout steps (frames beyond the initial state).
        n_curves: Number of independent trajectories.
        dt: Time step used for the x-axis.
        save_path: Output GIF path (written with the ``pillow`` writer).
        device: Device on which to run the model.
        seed: Optional seed for the initial conditions. ``None`` draws from
            NumPy's global RNG as before.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    X0 = rng.uniform(-1, 1, (n_curves, 3)).astype(np.float32)
    x = torch.from_numpy(X0).to(device)

    # paths[i, t] = f of curve i after t steps.
    paths = np.zeros((n_curves, n_steps + 1), dtype=np.float32)
    paths[:, 0] = X0[:, 2]

    model.eval()
    with torch.no_grad():
        for t in range(1, n_steps + 1):
            f_next = model(x)[:, 0]
            paths[:, t] = f_next.cpu().numpy()
            # Keep the rollout on-device; only the recorded copy goes to CPU.
            x = torch.stack([x[:, 0], x[:, 1], f_next], dim=1)

    fig, ax = plt.subplots()
    try:
        lines = [ax.plot([], [], lw=2)[0] for _ in range(n_curves)]
        ax.set_xlim(0, n_steps * dt)
        # A diverging rollout can produce inf/NaN, which set_ylim rejects;
        # base the limits on finite values only. With no finite values,
        # matplotlib's default limits are kept.
        finite = paths[np.isfinite(paths)]
        if finite.size:
            ax.set_ylim(finite.min() - 0.1, finite.max() + 0.1)
        ax.set_xlabel('t')
        ax.set_ylabel('f(t)')
        ax.set_title('Tachyonic Field Evolution')

        def init():
            for ln in lines:
                ln.set_data([], [])
            return lines

        def update(frame):
            t_vals = np.linspace(0, frame * dt, frame + 1)
            for i, ln in enumerate(lines):
                ln.set_data(t_vals, paths[i, :frame + 1])
            return lines

        anim = animation.FuncAnimation(fig, update, frames=n_steps + 1,
                                       init_func=init, blit=True)
        anim.save(save_path, writer='pillow', fps=10)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    print(f"Animation saved to {save_path}")

# ------------------------------------------------------------------------------
# 8. Main Execution
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: data -> training -> loss plot -> MC-dropout demo -> GIF.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    train_loader, val_loader = prepare_dataloaders()

    model, history, best_val = train_model(
        TachyonPINN().to(device), train_loader, val_loader, device=device
    )
    print(f"\nTraining complete. Best Val Loss: {best_val:.4e}")

    # Per-epoch loss curves.
    plt.figure()
    for key, label in (("train", "Train Loss"), ("val", "Val Loss")):
        plt.plot(history[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    # Uncertainty estimates for the first few rows of one validation batch.
    xb, _ = next(iter(val_loader))
    mean_p, std_p = mc_dropout_predict(model, xb, device=device)
    print("\nMC‐Dropout sample predictions (mean ± std):")
    for row, mu, sigma in zip(xb[:5], mean_p[:5], std_p[:5]):
        print(row.numpy(), "→", mu, "±", sigma)

    animate_evolution(model, device=device)