In [None]:
# @title 🚀 Initialize OpenCSU Neural Surrogate
# @markdown Click **Play** to load the AI model. This connects to the pre-trained physics engine.

import os
import urllib.request

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from IPython.display import display, clear_output

# --- 1. DEFINE THE NEURAL ARCHITECTURE ---
# This matches the DGX training script exactly.
class CSUSurrogate(nn.Module):
    """Decode three scalar parameters into a single-channel 64x64 image.

    The input holds 3 parameters per sample ([mu_T, gamma_M, drug_eff]). A
    dense stack expands them to 2048 features, which are viewed as a
    128-channel 4x4 grid and doubled in resolution four times by stride-2
    transposed convolutions. A final sigmoid bounds every pixel to [0, 1].

    The position of each layer inside ``fc`` and ``decoder`` defines the
    state-dict keys (``fc.0.weight``, ``decoder.9.bias``, ...), so the layer
    order must not change if previously saved weights are to load.
    """

    def __init__(self):
        super().__init__()

        # Dense stem: 3 -> 512 -> 1024 -> 2048, each Linear followed by ReLU.
        dense_widths = (3, 512, 1024, 4 * 4 * 128)
        dense_layers = []
        for n_in, n_out in zip(dense_widths[:-1], dense_widths[1:]):
            dense_layers.append(nn.Linear(n_in, n_out))
            dense_layers.append(nn.ReLU())
        self.fc = nn.Sequential(*dense_layers)

        # Three upsampling stages (4x4 -> 8x8 -> 16x16 -> 32x32), each a
        # transposed convolution + batch norm + ReLU that halves the channels.
        stage_channels = (128, 64, 32, 16)
        up_layers = []
        for c_in, c_out in zip(stage_channels[:-1], stage_channels[1:]):
            up_layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
        # Output stage (32x32 -> 64x64): one channel, squashed to 0-1 brightness.
        up_layers.extend([
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ])
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Return images of shape (batch, 64, 64) for parameters ``x``."""
        features = self.fc(x)
        # Interpret the 2048 features of each sample as a 128-channel 4x4 grid.
        grid = features.view(-1, 128, 4, 4)
        image = self.decoder(grid)
        # Drop the singleton channel axis.
        return image.squeeze(1)

# --- 2. DOWNLOAD & LOAD WEIGHTS ---
# Replace this URL with the raw link to your GitHub weights file
WEIGHTS_URL = "https://github.com/YOUR_USERNAME/OpenCSU/raw/main/csu_surrogate_weights.pt"
WEIGHTS_FILE = "csu_surrogate_weights.pt"

if not os.path.exists(WEIGHTS_FILE):
    print(f"⬇️ Downloading Neural Weights from {WEIGHTS_URL}...")
    # Download under a temporary name and rename only on success. A failed or
    # interrupted transfer must never leave a file at WEIGHTS_FILE, otherwise
    # the existence check above would skip the download on every later run.
    partial_file = WEIGHTS_FILE + ".part"
    try:
        urllib.request.urlretrieve(WEIGHTS_URL, partial_file)
        os.replace(partial_file, WEIGHTS_FILE)
    except (OSError, ValueError) as e:
        # URLError/HTTPError are OSError subclasses; ValueError covers a
        # malformed URL. Report and fall through so the load step below
        # prints its own guidance.
        print(f"❌ Download failed: {e}")
        if os.path.exists(partial_file):
            os.remove(partial_file)

print("🧠 Loading Model...")
model = CSUSurrogate()
# Switch to inference mode up front so the BatchNorm layers use their stored
# statistics even if loading the weights fails below.
model.eval()
try:
    # map_location lets GPU-saved weights load on a CPU-only machine.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which matters because the file comes from the network.
    state_dict = torch.load(
        WEIGHTS_FILE,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    print("✅ Model Ready. Neural Surrogate Active.")
except Exception as e:
    # Best effort: keep the notebook cell running and tell the user what to fix.
    print(f"❌ Error loading weights: {e}")
    print("Ensure you have uploaded 'csu_surrogate_weights.pt' to your repo.")

# --- 3. THE REAL-TIME DASHBOARD ---

# Shared presentation settings: descriptions take their natural width and each
# slider spans most of its container.
style = {'description_width': 'initial'}
layout = widgets.Layout(width='95%')


def _param_slider(label, start, low, high):
    """Build a 0.01-step float slider that reports its value while dragging."""
    return widgets.FloatSlider(
        description=label,
        value=start,
        min=low,
        max=high,
        step=0.01,
        style=style,
        layout=layout,
        continuous_update=True,  # fire on every movement, not only on release
    )


# One slider per model input.
w_mu_t = _param_slider('Clearance Rate (mu_T)', 0.2, 0.1, 1.2)
w_gamma_m = _param_slider('Feedback Strength (gamma_M)', 1.5, 0.5, 2.5)
w_drug = _param_slider('Antihistamine Efficacy', 0.0, 0.0, 1.0)

# Output area that the rendered plot is drawn into.
out_viz = widgets.Output()

def _pattern_label(mu_val):
    """Return the pattern name shown for clearance rate ``mu_val``."""
    # (exclusive upper bound, label) pairs, checked in ascending order.
    bands = (
        (0.3, "Circular (Filled)"),
        (0.6, "Geographic (Merged)"),
        (0.9, "Annular (Ring)"),
        (1.1, "Broken Annular"),
    )
    for upper, name in bands:
        if mu_val < upper:
            return name
    return "Dot (Transient)"


def update_view(change=None):
    """Run the model on the current slider values and redraw the plot.

    Args:
        change: Ignored. Accepted so the function can be registered as a
            widget observer callback as well as called directly.
    """
    # Get slider values
    mu_val = w_mu_t.value
    gamma_val = w_gamma_m.value
    drug_val = w_drug.value

    # Inference (no gradients needed); a batch of one, so take element 0.
    with torch.no_grad():
        inputs = torch.tensor([[mu_val, gamma_val, drug_val]])
        prediction = model(inputs).numpy()[0]

    # Render: wait=True defers clearing until the new plot is ready.
    out_viz.clear_output(wait=True)
    with out_viz:
        fig, ax = plt.subplots(figsize=(6, 6))
        # Fixed 0-1 color range keeps the heatmap comparable across updates.
        ax.imshow(prediction, cmap='RdPu', vmin=0, vmax=1)
        ax.set_title("Predicted Morphology\n(Neural Inference <10ms)")
        ax.axis('off')

        # Pattern label is chosen from mu_T alone.
        ax.text(2, 60, f"Pattern: {_pattern_label(mu_val)}", color='black', fontsize=12,
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

        plt.show()

# Re-render whenever any slider's value changes
for slider in (w_mu_t, w_gamma_m, w_drug):
    slider.observe(update_view, names='value')

# Draw once so the output area is populated before the first interaction
update_view()

# Assemble the page: heading and hint on top, then the slider column
# (40% wide) next to the plot output.
heading = widgets.HTML("<h2>OpenCSU: Neural Surrogate</h2>")
hint = widgets.HTML("<i>Move sliders to see instant morphological predictions.</i>")
controls = widgets.VBox(
    [w_mu_t, w_gamma_m, w_drug],
    layout=widgets.Layout(width='40%'),
)
ui = widgets.VBox([heading, hint, widgets.HBox([controls, out_viz])])

display(ui)