<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/warp_pinn_analysis_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy matplotlib plotly dash

In [None]:
# warp_pinn_analysis.py

import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

# Dash imports
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objs as go

# -----------------------------
# 1. CONFIGURATION & UTILITIES
# -----------------------------
# Run on CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Checkpoint file of the trained model, loaded at startup.
MODEL_PATH = "best_warp_pinn.pt"

def load_model(path: str, device=None):
    """Load the trained PINN model from a checkpoint file.

    The checkpoint may be either a pickled ``nn.Module`` itself or a dict
    whose ``'model'`` entry is the ``nn.Module``.

    Args:
        path: Path of the file written by ``torch.save``.
        device: Device to map the checkpoint onto and move the model to.
            Defaults to the module-level ``DEVICE``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint is a dict without a ``'model'`` entry.
        TypeError: If the checkpoint does not hold an ``nn.Module`` (e.g. it
            only contains a state_dict).
    """
    if device is None:
        device = DEVICE
    # Restoring a whole nn.Module requires full unpickling; since PyTorch 2.6
    # torch.load defaults to weights_only=True, which rejects such files.
    # SECURITY: unpickling can execute arbitrary code -- only load trusted
    # checkpoints. (The model's class must also be importable here.)
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model = checkpoint if isinstance(checkpoint, nn.Module) else checkpoint['model']
    if not isinstance(model, nn.Module):
        raise TypeError(
            f"checkpoint {path!r} does not contain an nn.Module "
            f"(got {type(model).__name__}); a state_dict needs the model "
            "class to be instantiated first"
        )
    model.to(device)
    model.eval()
    return model

def mc_dropout_predict(model, x, repeats=50):
    """
    Perform MC-Dropout at inference.

    Runs ``repeats`` (T) stochastic forward passes with dropout active and
    returns the element-wise mean and (population) standard deviation of the
    outputs over those passes.

    Normalization layers (BatchNorm / InstanceNorm) are kept in eval mode
    during sampling so the passes do not overwrite their running statistics.
    The model's previous train/eval mode is restored afterwards, even if a
    forward pass raises.

    Args:
        model: The ``nn.Module`` to sample from.
        x: Input tensor, expected to be on the model's device.
        repeats: Number of stochastic passes; must be >= 1.

    Returns:
        Tuple ``(mean, std)`` of NumPy arrays, each shaped like one model
        output.

    Raises:
        ValueError: If ``repeats`` is less than 1.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    norm_types = (
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm,
        nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
    )
    was_training = model.training
    try:
        # Whole-model train mode enables both nn.Dropout modules and any
        # functional dropout gated on self.training.
        model.train()
        for module in model.modules():
            # In train mode these layers update running_mean/var on every
            # forward call (no_grad does not prevent it), which would corrupt
            # the trained statistics; keep them frozen.
            if isinstance(module, norm_types):
                module.eval()
        with torch.no_grad():
            preds = np.stack(
                [model(x).cpu().numpy() for _ in range(repeats)], axis=0
            )
    finally:
        model.train(was_training)
    mean = preds.mean(axis=0)
    std = preds.std(axis=0)
    return mean, std

# -----------------------------
# 2. LOAD DATA & MODEL
# -----------------------------
# Replace these with your real data loaders or arrays
# Test set archive; it is read for arrays "X" (model inputs) and "Y"
# (ground-truth targets). `data` stays open: later code reads data["X"] again.
data = np.load("test_data.npz")
# Model inputs as a float32 tensor on DEVICE (the device load_model moves the model to).
X_test = torch.tensor(data["X"], dtype=torch.float32, device=DEVICE)
y_true = data["Y"]  # assumed shape (N,2) for [energy_field, curvature_opt]; kept as a NumPy array

model = load_model(MODEL_PATH)

# MC-Dropout predictions: element-wise mean and std over 100 stochastic passes
y_pred_mean, y_pred_std = mc_dropout_predict(model, X_test, repeats=100)

# -----------------------------
# 3. PREDICTED VS TRUE PLOTS
# -----------------------------
def plot_pred_vs_true(y_true, y_pred, field_name):
    """Draw a parity (true vs. predicted) scatter plot for one output field.

    A dashed black y = x reference line spans the combined value range of
    both arrays; points on it are perfect predictions.

    Args:
        y_true: Ground-truth values (1-D array).
        y_pred: Predicted values, same length as ``y_true``.
        field_name: Human-readable field name used in labels and title.

    Returns:
        The matplotlib ``Figure`` holding the plot.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(y_true, y_pred, s=10, alpha=0.6)

    lo = min(y_true.min(), y_pred.min())
    hi = max(y_true.max(), y_pred.max())
    ax.plot([lo, hi], [lo, hi], 'k--', lw=1)

    ax.set(
        xlabel=f"True {field_name}",
        ylabel=f"Predicted {field_name}",
        title=f"{field_name}: True vs Predicted",
    )
    ax.grid(True)
    plt.tight_layout()
    return fig

# One parity figure per model output column; each is then written to disk.
fig1, fig2 = (
    plot_pred_vs_true(y_true[:, col], y_pred_mean[:, col], label)
    for col, label in enumerate(("Energy Field", "Curvature Opt"))
)

fig1.savefig("energy_true_vs_pred.png", dpi=150)
fig2.savefig("curv_true_vs_pred.png", dpi=150)

# -----------------------------
# 4. HIGH-DIMENSIONAL SURFACE PLOTS
# -----------------------------
# Assumes the first two input columns are the surface coordinates [rho, R].
rho, R = data["X"][:, 0], data["X"][:, 1]
# Surface heights come from the MC-Dropout mean predictions.
Z1, Z2 = y_pred_mean[:, 0], y_pred_mean[:, 1]

def make_surface(x, y, z, title, fname):
    """Render a triangulated 3-D surface of ``z`` over scattered (x, y) points.

    The axes are labelled "rho" / "R" and the z-axis and title both use
    ``title``. The figure gets a colorbar, is saved to ``fname`` at 150 dpi,
    and is returned.

    Args:
        x, y: Point coordinates (1-D arrays of equal length).
        z: Height at each point.
        title: Text for the z-axis label and the plot title.
        fname: Output image path.

    Returns:
        The matplotlib ``Figure``.
    """
    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    surface = ax.plot_trisurf(x, y, z, cmap=cm.viridis, linewidth=0.2)
    ax.set(xlabel="rho", ylabel="R", zlabel=title)
    fig.colorbar(surface, shrink=0.5, aspect=12)
    ax.set_title(title)
    plt.tight_layout()
    fig.savefig(fname, dpi=150)
    return fig

# Render (and save) one predicted-surface figure per output field.
surf1, surf2 = (
    make_surface(rho, R, z, title, fname)
    for z, title, fname in (
        (Z1, "Predicted Energy Field Surface", "surf_energy.png"),
        (Z2, "Predicted Curvature Surface", "surf_curv.png"),
    )
)

# -----------------------------
# 5. DASHBOARD WITH PLOTLY DASH
# -----------------------------
app = dash.Dash(__name__)
server = app.server  # for deployment


def _identity_range(true_vals, pred_vals):
    """Return ``(lo, hi)`` spanning both arrays, as plain floats.

    Used for the dashed y = x reference line so it covers every plotted
    point (same convention as ``plot_pred_vs_true``).
    """
    lo = min(np.min(true_vals), np.min(pred_vals))
    hi = max(np.max(true_vals), np.max(pred_vals))
    return float(lo), float(hi)


def _parity_graph(graph_id, true_vals, pred_vals, title, xlabel, ylabel, marker):
    """Build a true-vs-predicted scatter ``dcc.Graph`` with a dashed y = x line."""
    lo, hi = _identity_range(true_vals, pred_vals)
    return dcc.Graph(
        id=graph_id,
        figure={
            'data': [go.Scatter(
                x=true_vals, y=pred_vals,
                mode='markers', marker=marker,
            )],
            'layout': go.Layout(
                title=title,
                xaxis={'title': xlabel}, yaxis={'title': ylabel},
                shapes=[{
                    'type': 'line', 'x0': lo, 'y0': lo,
                    'x1': hi, 'y1': hi,
                    'line': {'dash': 'dash'},
                }],
            ),
        },
    )


def _surface_graph(graph_id, x, y, z, colorscale, zlabel):
    """Build a ``dcc.Graph`` showing ``z`` over (x, y) as a ``Mesh3d`` surface."""
    return dcc.Graph(
        id=graph_id,
        figure=go.Figure(data=[
            go.Mesh3d(
                x=x, y=y, z=z,
                intensity=z, colorscale=colorscale, opacity=0.8,
            )
        ], layout=go.Layout(
            scene={'xaxis_title': 'rho', 'yaxis_title': 'R', 'zaxis_title': zlabel},
            margin={'l': 0, 'r': 0, 'b': 0, 't': 30},
        )),
    )


# Each parity plot's reference line is computed from that field's own column
# (the energy plot previously used the min/max of the whole y_true array).
app.layout = html.Div([
    html.H1("WarpDriveAI Diagnostics Dashboard"),
    html.Div([
        html.Div([
            _parity_graph(
                'scatter-energy', y_true[:, 0], y_pred_mean[:, 0],
                'Energy Field: True vs Predicted',
                'True Energy', 'Predicted Energy',
                {'size': 5, 'opacity': 0.7},
            ),
        ], style={'width': '48%', 'display': 'inline-block'}),
        html.Div([
            _parity_graph(
                'scatter-curv', y_true[:, 1], y_pred_mean[:, 1],
                'Curvature Opt: True vs Predicted',
                'True Curvature', 'Predicted Curvature',
                {'size': 5, 'opacity': 0.7, 'color': 'crimson'},
            ),
        ], style={'width': '48%', 'display': 'inline-block', 'float': 'right'}),
    ]),
    html.Hr(),
    html.H3("Energy Field Surface"),
    _surface_graph('surface-energy', rho, R, Z1, 'Viridis', 'Energy'),
    html.H3("Curvature Surface"),
    _surface_graph('surface-curv', rho, R, Z2, 'Cividis', 'Curvature'),
])

if __name__ == '__main__':
    # Dash 3 removed ``app.run_server`` (accessing it raises); ``app.run`` is
    # its replacement. Fall back to run_server only on old Dash releases that
    # lack ``run``.
    _run_app = getattr(app, "run", None) or app.run_server
    _run_app(debug=True)