# In [None]:
# @title 🏥 OpenCSU: Standalone Digital Twin
# @markdown **Complete Simulation Environment**
# @markdown <br>Run the full 6-variable reaction-diffusion solver locally.
# @markdown No external servers or pre-trained weights required.

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from dataclasses import dataclass
from copy import deepcopy
from IPython.display import display, clear_output

# ==============================================================================
# 1. THE PHYSICS ENGINE (UnifiedSolver)
# ==============================================================================

@dataclass
class ModelParams:
    """Parameter set for the six-variable reaction-diffusion model.

    Variable naming used in the comments below:
    u1 histamine / mast, u2 tissue factor (TF), u3 basophil,
    u4 coagulation, u5 kallikrein, u6 bradykinin.

    The ``*_mod`` / ``*_boost`` fields are multiplicative drug modifiers
    (1.0 means no effect). ``b2_block`` is a fraction removed from the
    angioedema read-out (0.0 means no effect).
    """

    # ---- Seirin-Lee histamine / tissue-factor module -------------------
    # diffusion coefficients
    D_hist: float = 0.1        # diffusion of u1 (histamine)
    D_coag: float = 0.05       # diffusion of u4 (coagulation)

    # basal production rates
    delta_M: float = 0.01      # mast / histamine (u1)
    delta_T: float = 0.01      # tissue factor (u2)
    delta_B: float = 0.01      # basophil (u3)

    # linear decay rates
    mu_M: float = 0.5          # histamine (u1)
    mu_T: float = 0.8          # tissue factor (u2); 0.8 -> annular, 0.2 -> circular
    mu_B: float = 0.5          # basophil (u3)
    mu_C: float = 0.5          # coagulation (u4)

    # production / feedback gains
    gamma_M: float = 1.5       # coagulation (u4) -> mast (u1) feedback
    gamma_T: float = 1.8       # histamine -> tissue factor
    gamma_B: float = 1.2       # tissue factor -> basophil
    gamma_C: float = 2.0       # tissue factor -> coagulation (through the switch)

    # inhibition strength and switch shape
    alpha: float = 5.0         # adenosine inhibition strength
    u200: float = 0.67         # switch threshold on u2 ("gap threshold")
    beta: float = 50.0         # switch steepness

    # ---- Bradykinin module ----------------------------------------------
    D_bk: float = 0.15         # diffusion of u6 (fastest diffuser)
    phi_tryptase: float = 0.5  # cross-talk weight: mast (u1) -> kallikrein (u5)
    delta_Kal: float = 0.05    # kallikrein source gain, scales (u4 + phi_tryptase * u1)
    gamma_BK: float = 2.0      # kallikrein -> bradykinin production
    mu_Kal: float = 0.8        # kallikrein clearance (C1-INH)
    mu_BK: float = 2.0         # bradykinin clearance (ACE)

    # ---- Pharmacological modifiers --------------------------------------
    gamma_T_mod: float = 1.0   # multiplies gamma_T (H1 blockade)
    gamma_M_mod: float = 1.0   # multiplies gamma_M (H2 blockade)
    mu_T_boost: float = 1.0    # multiplies mu_T (immunomodulation)
    mu_BK_mod: float = 1.0     # multiplies mu_BK (ACE inhibitor effect)
    b2_block: float = 0.0      # fraction of the angioedema signal removed (icatibant)

class UnifiedSolver(torch.nn.Module):
    """Six-variable reaction-diffusion solver on a square, periodic grid.

    State channels (index -> variable), as used by ``reaction_dynamics``:
        0: u1 histamine / mast      3: u4 coagulation
        1: u2 tissue factor         4: u5 kallikrein
        2: u3 basophil              5: u6 bradykinin

    Each ``step`` advances the reaction terms with classical RK4. It then
    applies an explicit-Euler diffusion update to u1, u4 and u6
    (operator splitting) and clamps the state to be non-negative. Grid
    spacing is implicitly 1, and ``torch.roll`` makes the boundaries periodic.

    NOTE: explicit diffusion with unit spacing is only stable for
    ``D * dt <= 0.25``. The defaults (D <= 0.15, dt = 0.01) satisfy this.

    Args:
        params: Model parameters (any object exposing the ``ModelParams``
            attributes).
        grid_size: Side length of the square grid.
        dt: Time step for both the reaction and diffusion updates.
    """

    def __init__(self, params: "ModelParams", grid_size=128, dt=0.01):
        super().__init__()
        self.p = params
        self.dt = dt
        self.grid_size = grid_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize State (6 Channels)
        self.state = torch.zeros(6, grid_size, grid_size, device=self.device)
        self.initialize()

    def initialize(self):
        """Reset the state to the localized initial trigger.

        All channels are cleared first, so calling this again restarts the
        simulation instead of stacking extra noise onto the previous state.
        """
        self.state.zero_()

        n = self.grid_size
        coords = torch.arange(n, device=self.device)
        y, x = torch.meshgrid(coords, coords, indexing='ij')
        center = n // 2
        r2 = ((x - center) ** 2 + (y - center) ** 2).float()

        # Gaussian basophil (u3) spike at the grid centre.
        self.state[2] = 2.0 * torch.exp(-r2 / 50.0)
        # Small uniform noise (amplitude 0.01) on u2 and u5.
        self.state[1] += 0.01 * torch.rand_like(self.state[1])
        self.state[4] += 0.01 * torch.rand_like(self.state[4])

    def reaction_dynamics(self, u):
        """Return the reaction-only time derivatives of every channel.

        Args:
            u: State tensor indexed as ``u[channel]`` with 6 channels.

        Returns:
            ``torch.stack`` of the six derivatives (diffusion not included).
        """
        p = self.p
        u1, u2, u3, u4, u5, u6 = u[0], u[1], u[2], u[3], u[4], u[5]

        # Saturating inhibition terms of the form 1 / (1 + alpha * s^2 / (1 + s^2)).
        inh_M = 1.0 / (1.0 + p.alpha * u1**2 / (1.0 + u1**2))
        s = u1 + u3
        inh_T = 1.0 / (1.0 + p.alpha * s**2 / (1.0 + s**2))
        # Steep logistic switch on u2 around u200. torch.sigmoid is the
        # overflow-safe form of 1 / (1 + exp(-z)).
        switch = torch.sigmoid(p.beta * (u2 - p.u200))

        # Drug modifiers scale the corresponding rates multiplicatively.
        g_T = p.gamma_T * p.gamma_T_mod
        g_M = p.gamma_M * p.gamma_M_mod
        m_T = p.mu_T * p.mu_T_boost

        # Histamine / tissue-factor / coagulation module.
        du1 = p.delta_M + g_M * u4 * inh_M - p.mu_M * u1
        du2 = p.delta_T + g_T * (s / (1 + s)) * inh_T - m_T * u2
        du3 = p.delta_B + p.gamma_B * u2 - p.mu_B * u3
        du4 = p.gamma_C * switch - p.mu_C * u4

        # Bradykinin module, driven by coagulation and mast cross-talk.
        kal_source = p.delta_Kal * (u4 + p.phi_tryptase * u1)
        du5 = kal_source - p.mu_Kal * u5
        du6 = p.gamma_BK * u5 - (p.mu_BK * p.mu_BK_mod) * u6

        return torch.stack([du1, du2, du3, du4, du5, du6])

    @staticmethod
    def _laplacian(field):
        """Five-point Laplacian with periodic boundaries and unit spacing."""
        return (torch.roll(field, 1, 0) + torch.roll(field, -1, 0)
                + torch.roll(field, 1, 1) + torch.roll(field, -1, 1)
                - 4 * field)

    def step(self):
        """Advance one time step: RK4 reactions, Euler diffusion, then clamp."""
        dt = self.dt

        # RK4 integration of the reaction terms.
        k1 = self.reaction_dynamics(self.state)
        k2 = self.reaction_dynamics(self.state + 0.5*dt*k1)
        k3 = self.reaction_dynamics(self.state + 0.5*dt*k2)
        k4 = self.reaction_dynamics(self.state + dt*k3)
        self.state += (dt/6.0)*(k1 + 2*k2 + 2*k3 + k4)

        # Explicit diffusion for the diffusing species (u1, u4, u6).
        for idx, D in ((0, self.p.D_hist), (3, self.p.D_coag), (5, self.p.D_bk)):
            self.state[idx] += D * self._laplacian(self.state[idx]) * dt

        # Concentrations cannot go negative.
        self.state = torch.clamp(self.state, min=0.0)

    def get_visuals(self):
        """Return ``(wheal, angio)`` maps as NumPy arrays with values in [0, 1].

        ``wheal`` is the logistic switch applied to u2. ``angio`` is the
        saturating map u6^2 / (1 + u6^2), scaled by ``1 - b2_block``.
        """
        wheal = torch.sigmoid(self.p.beta * (self.state[1] - self.p.u200))
        bk = self.state[5]
        angio = (bk**2 / (1.0 + bk**2)) * (1.0 - self.p.b2_block)
        return wheal.cpu().numpy(), angio.cpu().numpy()

# ==============================================================================
# 2. PATIENT & DRUG LOGIC
# ==============================================================================

@dataclass
class Patient:
    """Minimal patient description used for dose-to-concentration scaling.

    Attributes:
        age: Patient age (the dashboard slider labels it in years).
        weight: Body weight (the dashboard slider labels it in kg).
        liver_func: Liver-function factor; not read by ``get_vd``.
    """

    age: float
    weight: float
    liver_func: float = 1.0

    def get_vd(self):
        """Return a simple volume-of-distribution estimate: weight * factor.

        The factor is 0.7 for patients younger than 12 and 0.6 otherwise.
        """
        if self.age < 12:
            scale = 0.7
        else:
            scale = 0.6
        return self.weight * scale

def get_params_for_patient(patient, dose_h1, dose_h2, dose_ace, *,
                           kd_h1=1.0, kd_h2=3.0, max_blockade=0.95,
                           ace_decay_factor=0.1):
    """Build a ``ModelParams`` with the patient's drug effects applied.

    Concentration is a one-compartment estimate (dose / Vd). Receptor
    occupancy follows the single-site curve ``c / (Kd + c)``. Each
    occupancy reduces the matching histamine-driven gain by up to
    ``max_blockade``.

    Args:
        patient: Object exposing ``get_vd()`` (e.g. ``Patient``).
        dose_h1: H1 antihistamine dose; negative values are treated as 0.
        dose_h2: H2 blocker dose; negative values are treated as 0.
        dose_ace: Any value > 0 marks the patient as on an ACE inhibitor.
        kd_h1: Dissociation constant for H1 binding (must be > 0).
        kd_h2: Dissociation constant for H2 binding (must be > 0).
        max_blockade: Fractional gain reduction at full occupancy.
        ace_decay_factor: Multiplier on bradykinin clearance while on an ACE
            inhibitor (0.1 means 10x slower breakdown).

    Returns:
        A new ``ModelParams`` instance with ``gamma_T_mod``,
        ``gamma_M_mod`` and (if applicable) ``mu_BK_mod`` set.

    Raises:
        ValueError: If ``patient.get_vd()`` is not positive.
    """
    base = ModelParams()
    vd = patient.get_vd()
    if vd <= 0:
        raise ValueError(f"volume of distribution must be positive, got {vd!r}")

    def occupancy(dose, kd):
        # Negative doses are meaningless; clamping them prevents
        # negative occupancy, which would *amplify* the gain (mod > 1).
        conc = max(dose, 0.0) / vd
        return conc / (kd + conc)

    base.gamma_T_mod = 1.0 - max_blockade * occupancy(dose_h1, kd_h1)  # H1: histamine -> TF
    base.gamma_M_mod = 1.0 - max_blockade * occupancy(dose_h2, kd_h2)  # H2: coag -> mast feedback

    if dose_ace > 0:
        base.mu_BK_mod = ace_decay_factor  # ACE inhibition slows bradykinin breakdown

    return base

# ==============================================================================
# 3. INTERACTIVE DASHBOARD
# ==============================================================================

# UI Setup
style = {'description_width': 'initial'}
layout = widgets.Layout(width='95%')

w_age = widgets.IntSlider(value=12, min=2, max=100, description='Patient Age', style=style)
w_weight = widgets.FloatSlider(value=45, min=10, max=120, description='Weight (kg)', style=style)
w_h1 = widgets.FloatSlider(value=0, min=0, max=40, step=5, description='Antihistamine (mg)', style=style)
w_h2 = widgets.FloatSlider(value=0, min=0, max=80, step=10, description='H2 Blocker (mg)', style=style)
w_ace = widgets.Checkbox(value=False, description='On ACE Inhibitor?')

btn_run = widgets.Button(description='Run Simulation', button_style='primary', icon='play')
out_viz = widgets.Output()

def run_sim(b, n_steps=400, grid_size=100, progress_every=50):
    """Button callback: simulate the current patient/treatment and plot it.

    Reads the dashboard widgets, converts them to model parameters, runs
    the solver for ``n_steps`` steps and draws the wheal and angioedema
    maps into ``out_viz``.

    Args:
        b: The clicked button, passed by ``Button.on_click`` (unused).
        n_steps: Number of solver steps to run.
        grid_size: Side length of the simulation grid (100 balances detail
            against speed).
        progress_every: Update the progress bar every this many steps
            (must be positive).
    """
    out_viz.clear_output()
    with out_viz:
        print("⏳ Initializing Physics Engine (CUDA if available)...")

        # 1. Patient and treatment from the widgets.
        pat = Patient(w_age.value, w_weight.value)
        ace_dose = 10.0 if w_ace.value else 0.0

        # 2. Translate doses into model parameters.
        params = get_params_for_patient(pat, w_h1.value, w_h2.value, ace_dose)

        # 3. Run the solver with a visual progress bar.
        solver = UnifiedSolver(params, grid_size=grid_size)

        print(f"🔄 Solving PDEs ({n_steps} steps)...")
        prog = widgets.IntProgress(value=0, min=0, max=n_steps)
        display(prog)

        for i in range(n_steps):
            solver.step()
            if i % progress_every == 0:
                prog.value = i
        prog.value = n_steps

        # 4. Visualize.
        wheal, angio = solver.get_visuals()

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Histamine (Red)
        im1 = ax1.imshow(wheal, cmap='Reds', vmin=0, vmax=1)
        ax1.set_title(f"Surface Hives (H1 Activity)\nBlockade: {(1-params.gamma_T_mod)*100:.0f}%")
        plt.colorbar(im1, ax=ax1)

        # Bradykinin (Blue)
        im2 = ax2.imshow(angio, cmap='Blues', vmin=0, vmax=1)
        ax2.set_title(f"Deep Swelling (Bradykinin)\nDecay Rate: {params.mu_BK_mod*100:.0f}%")
        plt.colorbar(im2, ax=ax2)

        plt.show()

btn_run.on_click(run_sim)

# --- Layout ---------------------------------------------------------------
# Title, two input columns side by side, a divider, the button and the output.
_title = widgets.HTML("<h2>OpenCSU: Standalone Physics Core</h2>")
_patient_column = widgets.VBox([widgets.HTML("<b>Patient</b>"), w_age, w_weight])
_treatment_column = widgets.VBox([widgets.HTML("<b>Treatment</b>"), w_h1, w_h2, w_ace])
_divider = widgets.HTML("<hr>")

ui = widgets.VBox([
    _title,
    widgets.HBox([_patient_column, _treatment_column]),
    _divider,
    btn_run,
    out_viz,
])

display(ui)