# 🧠 Parkinson's Brain: Inference & Visualization Demo

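Welcome to the inference sandbox! This notebook walks through what happens inside the Latent State Space Model (LSSM) at inference time: encoding a patient with the three data agents, projecting their latent brain state forward in time, and visualizing the results.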

### What this model provides for your interface:
1.  **Severity Score Over Time**: A curve of the predicted severity score (UPDRS) for each month, showing how the disease progresses.
2.  **Latent Brain Fingerprint**: A 32-dimensional vector that represents the "biological state" of the brain. Great for heatmaps or radar charts.
3.  **Agent Influences**: A measure of how strongly the motor, non-motor, or biological data is "pushing" the brain state.
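One simple way to estimate agent influences is input ablation: run the model once with all three encodings, then again with one agent's encoding zeroed out, and measure how far the severity curve moves. The sketch below assumes a hypothetical `predict` callable that maps a dict of agent inputs to a severity trajectory (in this notebook it would wrap `model(h0, t_span, u_m, u_nm, u_b)`); `toy_predict` is a made-up linear stand-in so the example runs on its own.

```python
import numpy as np


def agent_influence(predict, inputs):
    """Estimate each agent's influence by zeroing its input (ablation).

    predict: callable mapping a dict of agent inputs -> 1-D severity trajectory
             (hypothetical stand-in for running BrainLSSM on u_m, u_nm, u_b).
    inputs:  dict of agent name -> input vector.
    Returns a dict of agent name -> mean absolute change in the trajectory.
    """
    baseline = np.asarray(predict(inputs))
    influence = {}
    for name, vec in inputs.items():
        ablated = dict(inputs)
        ablated[name] = np.zeros_like(vec)  # remove this agent's signal only
        diff = baseline - np.asarray(predict(ablated))
        influence[name] = float(np.mean(np.abs(diff)))
    return influence


def toy_predict(inputs, months=24):
    """Hypothetical toy model: severity grows linearly at a rate set by a weighted input sum."""
    drive = (2.0 * inputs["motor"].sum()
             + 1.0 * inputs["non_motor"].sum()
             + 0.5 * inputs["bio"].sum())
    return drive * np.arange(months + 1) / months


inputs = {"motor": np.ones(4), "non_motor": np.ones(4), "bio": np.ones(4)}
res = agent_influence(toy_predict, inputs)
print(res)  # in this toy, motor has the largest influence and bio the smallest
```

Zeroing an input is only a meaningful "no signal" baseline if zero is a neutral value for that agent's encoding; if it is not, ablating to the cohort-mean encoding is a common alternative.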

## 1. Setup & Load Model

In [None]:
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Ensure we are in the project root
# Note: When running in Colab, the repository is usually cloned into /content/parkinson_official_project
if os.path.exists('parkinson_official_project'):
    os.chdir('parkinson_official_project')

sys.path.append(os.getcwd())

from config import config
from brain.lssm import BrainLSSM
from motor_agent.agents.version2.motor_agent import MotorAgent
from non_motor.agent.non_motor_agent import NonMotorAgent
from biomarker.agent.biological_agent import BiologicalAgent

print("Project Root:", os.getcwd())

# 1. Initialize Agents
motor = MotorAgent()
non_motor = NonMotorAgent()
bio = BiologicalAgent()

motor.load_data()
non_motor.load_data()
bio.load_data()

# 2. Initialize Brain and Load Weights
model = BrainLSSM(config).to(config.DEVICE)

checkpoint_path = config.CHECKPOINT_DIR / 'brain_final.pth'  # or a specific epoch's checkpoint
if checkpoint_path.exists():
    model.load_state_dict(torch.load(checkpoint_path, map_location=config.DEVICE))
    print(f"Loaded model from {checkpoint_path}")
else:
    print("No checkpoint found. Running with untrained (random) weights for the demo.")
model.eval()  # inference mode in both cases (disables dropout, etc.)

## 2. Real Patient Simulation
Pick a random patient who has both motor and non-motor records, and project their disease trajectory over the next 24 months.

In [None]:
import random

# Patients present in both the motor and non-motor datasets
pats = set(motor.patient_data.keys()).intersection(non_motor.patient_data.keys())
if not pats:
    print("No common patients found. Using synthetic patient.")
    sample_pat = 0
else:
    sample_pat = random.choice(list(pats))  # pick one common patient at random
    print(f"Simulating Patient ID: {sample_pat}")

# 1. Encode their current state
u_m = motor.encode(sample_pat)
u_nm = non_motor.encode(sample_pat)
u_b = bio.encode(sample_pat)

# 2. Define Time-span (simulate next 24 months)
months = 24
t_span = torch.linspace(0, months, steps=months+1).to(config.DEVICE)

# 3. Initial Latent state (starts at 0)
h0 = torch.zeros(1, config.LATENT_DIM).to(config.DEVICE)

# 4. Run Model
with torch.no_grad():
    h_traj, y_pred = model(h0, t_span, u_m, u_nm, u_b)

# Convert to numpy for plotting
severity_scores = y_pred.squeeze().cpu().numpy()
brain_states = h_traj.squeeze().cpu().numpy()

print(f"Generated {len(severity_scores)} time points (months 0 to {months}).")
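Under the hood, `model(h0, t_span, u_m, u_nm, u_b)` integrates a latent ODE over `t_span` and reads a severity score out of each latent state. `BrainLSSM`'s actual dynamics are not shown in this notebook, so the sketch below assumes a toy form, dh/dt = tanh(A h + B u) with a linear readout y = w·h + b, integrated with fixed-step Euler and a single concatenated input `u`. It illustrates the shapes involved (25 time points for a 24-month span), not the trained model.

```python
import numpy as np


def simulate_latent_ode(h0, t_span, u, A, B, w, b=0.0):
    """Fixed-step Euler integration of a toy latent ODE (hypothetical dynamics).

        dh/dt = tanh(A @ h + B @ u)
        y(t)  = w @ h(t) + b        # linear severity readout

    Returns (h_traj, y_pred) with shapes (T, latent_dim) and (T,).
    """
    h = np.asarray(h0, dtype=float)
    h_traj = [h.copy()]
    for k in range(1, len(t_span)):
        dt = t_span[k] - t_span[k - 1]
        h = h + dt * np.tanh(A @ h + B @ u)  # one Euler step
        h_traj.append(h.copy())
    h_traj = np.stack(h_traj)
    y_pred = h_traj @ w + b
    return h_traj, y_pred


months, latent_dim, input_dim = 24, 32, 8
rng = np.random.default_rng(0)
t_span = np.linspace(0, months, months + 1)  # 25 points: month 0 ... month 24
h_traj, y_pred = simulate_latent_ode(
    h0=np.zeros(latent_dim),
    t_span=t_span,
    u=rng.normal(size=input_dim),
    A=-0.1 * np.eye(latent_dim),  # mild decay toward equilibrium
    B=0.1 * rng.normal(size=(latent_dim, input_dim)),
    w=rng.normal(size=latent_dim) / latent_dim,
)
print(h_traj.shape, y_pred.shape)  # (25, 32) (25,)
```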

## 3. Visualization for Interface Design
The two arrays computed above, `severity_scores` and `brain_states`, are the data you would feed into your UI.

In [None]:
plt.figure(figsize=(15, 5))

# Plot 1: Progression Curve (The "Business" View)
plt.subplot(1, 2, 1)
plt.plot(range(months + 1), severity_scores, marker='o', color='#ff4b5c', linewidth=2)
plt.title("Predicted Disease Progression", fontsize=14)
plt.xlabel("Months from Baseline")
plt.ylabel("Severity Score (UPDRS)")
plt.grid(alpha=0.3)

# Plot 2: Brain Latent Heatmap (The "Scientific" View)
plt.subplot(1, 2, 2)
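if brain_states.ndim > 1:
    sampled_months = list(range(0, brain_states.shape[0], 3))  # month index of each heatmap column
    # Transpose to (latent_dim, time) and keep every 3rd month
    sns.heatmap(brain_states.T[:, ::3], cmap="magma", cbar=True, xticklabels=sampled_months)
    plt.title(f"Latent Brain State Evolution ({brain_states.shape[1]}-dim)", fontsize=14)
    plt.xlabel("Months from Baseline (every 3rd month)")
    plt.ylabel("Latent Dimension ID")
else:
    plt.text(0.5, 0.5, "Insufficient data for heatmap", ha='center', va='center')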

plt.tight_layout()
plt.show()
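For a radar chart or a color-consistent heatmap, it helps to put every latent dimension on a common scale before plotting. The sketch below (the helper name `latent_fingerprint` is hypothetical) min-max scales each dimension over the trajectory, keeps every 3rd month for the heatmap, and uses the final time point as the radar-chart fingerprint. It assumes one row of `brain_states` per month, as produced by the monthly `t_span` used in this notebook.

```python
import numpy as np


def latent_fingerprint(brain_states, step=3):
    """Prepare a latent trajectory for a heatmap or radar chart.

    brain_states: array of shape (T, latent_dim), one row per month.
    Returns (heatmap, labels, radar):
      heatmap: (latent_dim, ceil(T / step)) array, min-max scaled per dimension
      labels:  month index of each heatmap column (0, step, 2*step, ...)
      radar:   (latent_dim,) scaled values at the last time point
    """
    states = np.asarray(brain_states, dtype=float)
    lo = states.min(axis=0)
    span = states.max(axis=0) - lo
    span[span == 0] = 1.0            # avoid dividing by zero for flat dimensions
    scaled = (states - lo) / span    # each dimension now lies in [0, 1]
    heatmap = scaled[::step].T
    labels = list(range(0, states.shape[0], step))
    return heatmap, labels, scaled[-1]


demo = np.random.default_rng(0).normal(size=(25, 32))  # stand-in for brain_states
heatmap, labels, radar = latent_fingerprint(demo)
print(heatmap.shape, labels)  # (32, 9) [0, 3, 6, 9, 12, 15, 18, 21, 24]
```

Because the scaling is computed per trajectory, colors are comparable across time for one patient but not across patients; for cross-patient views, compute `lo` and `span` once over a reference cohort instead.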

## 4. "What-If" Stress Test
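What if we artificially raise the motor and non-motor inputs? The cell below feeds constant synthetic encodings at a low (0.1) and a high (0.8) intensity, with the biological input set to zero, to show how sensitive the predicted progression is to symptom load. These intensities are arbitrary demo values, not calibrated to the agents' real encoding scale.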

In [None]:
def simulate_custom(motor_intensity, non_motor_intensity, months=48):
    """Predict severity from constant synthetic agent encodings (biological input set to zero)."""
    # Synthetic agent encodings: every feature set to the given intensity
    u_m_synth = torch.ones(1, config.INPUT_DIM_MOTOR, device=config.DEVICE) * motor_intensity
    u_nm_synth = torch.ones(1, config.INPUT_DIM_NON_MOTOR, device=config.DEVICE) * non_motor_intensity
    u_b_synth = torch.zeros(1, config.INPUT_DIM_BIOLOGICAL, device=config.DEVICE)

    h0 = torch.zeros(1, config.LATENT_DIM, device=config.DEVICE)
    t_span = torch.linspace(0, months, months + 1, device=config.DEVICE)  # monthly grid; 48 months = 4 years

    with torch.no_grad():
        _, y_pred = model(h0, t_span, u_m_synth, u_nm_synth, u_b_synth)
    return y_pred.squeeze().cpu().numpy()

low_stress = simulate_custom(0.1, 0.1)
high_stress = simulate_custom(0.8, 0.8)
months_axis = range(len(low_stress))  # one point per month, starting at month 0

plt.figure(figsize=(10, 6))
plt.plot(months_axis, low_stress, label="Low Symptom Intensity (0.1)", color='green')
plt.plot(months_axis, high_stress, label="High Symptom Intensity (0.8)", color='red', linestyle='--')
plt.fill_between(months_axis, low_stress, high_stress, color='orange', alpha=0.1, label="Progression Risk Zone")

plt.title("Sensitivity Analysis: Progression Acceleration", fontsize=16)
plt.xlabel("Months")
plt.ylabel("Predicted Severity (UPDRS)")
plt.legend()
plt.show()
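Two intensities give two curves; sweeping a grid of intensities shows whether the final severity responds smoothly, linearly, or saturates. The sketch below assumes a `simulate` callable that returns one severity trajectory per intensity (with this notebook's function, `lambda x: simulate_custom(x, x)`), and estimates the slope with `np.gradient`, which accepts unevenly spaced sample points. The function name `sensitivity_sweep` and the `toy` model are hypothetical stand-ins so the example runs without the trained model.

```python
import numpy as np


def sensitivity_sweep(simulate, intensities):
    """Run a what-if simulation for each intensity and summarise the outcome.

    simulate: callable(intensity) -> 1-D severity trajectory.
    Returns the final severity per intensity and its finite-difference slope
    d(final severity) / d(intensity).
    """
    intensities = np.asarray(intensities, dtype=float)
    final = np.array([np.asarray(simulate(x))[-1] for x in intensities])
    slope = np.gradient(final, intensities)  # handles uneven intensity spacing
    return final, slope


def toy(intensity):
    """Hypothetical stand-in: severity rises linearly over 48 months at a rate set by intensity."""
    return 10.0 * intensity * np.linspace(0, 1, 49)


final, slope = sensitivity_sweep(toy, [0.1, 0.3, 0.5, 0.8])
print(final, slope)
```

A roughly constant slope suggests the model responds linearly to symptom load over this range; a slope that shrinks toward high intensities would indicate saturation.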

## Summary for Your Interface
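- **Backend API Input**: You send a patient ID or a JSON payload of symptom values.
- **Model Integration**: The model integrates the latent ODE over the requested time span and returns a severity score and a latent state for each time point.
- **Frontend Visuals**: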
    - Line charts for progression.
    - Animated heatmaps for the Latent State.
    - Comparison toggles for "What-If" scenarios.
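A minimal sketch of the response the backend could return to the frontend, assuming hypothetical field names (`patient_id`, `months`, `severity`, `latent_states`) and a hypothetical helper `build_response`; adapt both to your own API contract. NumPy `float32` values (what `.cpu().numpy()` typically yields from a float32 model) and NumPy integers are not JSON-serializable, so the helper converts everything to plain Python types first.

```python
import json

import numpy as np


def build_response(patient_id, t_months, severity, brain_states):
    """Package model outputs into a JSON-serialisable dict (hypothetical schema)."""
    return {
        "patient_id": str(patient_id),                # dataset keys may be NumPy ints
        "months": [float(t) for t in t_months],
        "severity": [float(s) for s in severity],     # one score per time point
        # one list per time point, each of length latent_dim
        "latent_states": np.asarray(brain_states, dtype=float).tolist(),
    }


payload = build_response(
    patient_id=np.int64(42),
    t_months=np.arange(3),
    severity=np.array([10.0, 11.5, 13.2], dtype=np.float32),
    brain_states=np.zeros((3, 4), dtype=np.float32),
)
body = json.dumps(payload)  # would raise TypeError without the conversions above
print(body[:60])
```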