# Optimization

## Import Libraries

In [1]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import importlib
import numpy as np
import pandas as pd

from pymoo.optimize import minimize
from pymoo.termination import get_termination
from pymoo.algorithms.soo.nonconvex.ga import GA
from pymoo.operators.sampling.rnd import IntegerRandomSampling
from pymoo.operators.crossover.sbx import SBX
from pymoo.operators.mutation.pm import PM

from src import config
import src.optimization_utils as ou
import src.solutions as s
import src.display as disp
from src.problem import SpineProblem

# Reload to pick up any changes
importlib.reload(ou)
importlib.reload(s)
importlib.reload(disp)

pd.set_option('display.max_columns', None)

## Load Models

In [2]:
# Load all models
model_configs = {
    "Mechanical Failure": config.MECH_FAIL_MODEL,
    "L4S1": config.L4S1_MODEL,
    "LL": config.LL_MODEL,
    "T4PA": config.T4PA_MODEL,
    "L1PA": config.L1PA_MODEL,
    "SVA": config.SVA_MODEL,
    "SS": config.SS_MODEL,
    "Global Tilt": config.GLOBAL_TILT_MODEL,
    "ODI": config.ODI_MODEL,
}

bundles = {name: ou.load_model_bundle(path) for name, path in model_configs.items()}

# Verify and display
print("Models loaded and verified:")
for name, bundle in bundles.items():
    status = "✓" if all(k in bundle for k in ["pipe", "features", "target"]) else "✗"
    print(f"  {status} {name}: {bundle.get('model_name', 'N/A')}")

# Extract individual bundles for use
mech_fail_bundle = bundles["Mechanical Failure"]
L4S1_bundle = bundles["L4S1"]
LL_bundle = bundles["LL"]
T4PA_bundle = bundles["T4PA"]
L1PA_bundle = bundles["L1PA"]
SVA_bundle = bundles["SVA"]
SS_bundle = bundles["SS"]
GT_bundle = bundles["Global Tilt"]
ODI_bundle = bundles["ODI"]

Models loaded and verified:
  ✓ Mechanical Failure: mech_fail_logreg
  ✓ L4S1: L4S1_ridge_reg
  ✓ LL: LL_ridge_reg
  ✓ T4PA: T4PA_ridge_reg
  ✓ L1PA: L1PA_rf_reg
  ✓ SVA: XGBRegressor_delta_SVA
  ✓ SS: XGBRegressor_delta_SS
  ✓ Global Tilt: XGBRegressor_delta_GlobalTilt
  ✓ ODI: ODI_ridge_reg


## Optimization Configuration

In [3]:
# =============================================================================
# COMPOSITE SCORE WEIGHTS - Adjust these to change optimization priorities
# =============================================================================
# All weights should sum to 1.0 for proper scaling
# Lower composite score = better outcome

WEIGHTS = {
    "w1": 1/6,  # GAP Score (normalized 0-100)
    "w2": 1/6,  # L1PA penalty (mismatch from ideal)
    "w3": 1/6,  # L4S1 penalty (ideal range 35-45)
    "w4": 1/6,  # T4L1PA penalty (T4PA - L1PA mismatch)
    "w5": 1/6,  # LL penalty (mismatch from ideal LL)
    "w6": 1/6,  # GAP category improvement penalty
}

WEIGHT_LABELS = {
    "w1": "GAP Score",
    "w2": "L1PA penalty",
    "w3": "L4S1 penalty",
    "w4": "T4L1PA penalty",
    "w5": "LL penalty",
    "w6": "GAP category improvement",
}

print("Composite Score Weights:")
for k, v in WEIGHTS.items():
    print(f"  {k}: {v:.4f}  ({WEIGHT_LABELS[k]})")
print(f"  Total: {sum(WEIGHTS.values()):.4f}")

Composite Score Weights:
  w1: 0.1667  (GAP Score)
  w2: 0.1667  (L1PA penalty)
  w3: 0.1667  (L4S1 penalty)
  w4: 0.1667  (T4L1PA penalty)
  w5: 0.1667  (LL penalty)
  w6: 0.1667  (GAP category improvement)
  Total: 1.0000


In [4]:
print("⚠️  MODEL FEATURE MISMATCH DETECTED")
print("="*60)
print("The delta models in artifacts/ have mismatched feature counts:")
print("  - Bundle stores 30 features")
print("  - Models expect 36-44 features (due to OneHotEncoding)")
print("\nFIX: Run notebook 04_delta_models_SVA_SS_GT.ipynb to retrain")
print("="*60)

# Temporarily using empty bundles - replace with trained models
delta_bundles = {
    # "L4S1": L4S1_bundle,
    # "LL": LL_bundle,
    # "T4PA": T4PA_bundle,
    # "L1PA": L1PA_bundle,
    # "SS": SS_bundle,
    # "GlobalTilt": GT_bundle,
    # "SVA": SVA_bundle,
}

print("\nDelta model bundles: NONE (awaiting model retraining)")

⚠️  MODEL FEATURE MISMATCH DETECTED
The delta models in artifacts/ have mismatched feature counts:
  - Bundle stores 30 features
  - Models expect 36-44 features (due to OneHotEncoding)

FIX: Run notebook 04_delta_models_SVA_SS_GT.ipynb to retrain

Delta model bundles: NONE (awaiting model retraining)


In [5]:
UIV_CHOICES, xl, xu = ou.get_decision_config()

In [6]:
print("UIV_CHOICES:", UIV_CHOICES)
print("xl:", xl)
print("xu:", xu)

# Use column names from config
print(pd.DataFrame([xl, xu], index=["xl","xu"], columns=config.DECISION_VAR_NAMES))

UIV_CHOICES: ['Hook', 'PS', 'FS']
xl: [0 0 0 0 0 0 1 2 0]
xu: [2 1 5 1 1 1 6 4 1]
    uiv_code  num_levels_cat_code  num_interbody_fusion_levels  ALIF  XLIF  \
xl         0                    0                            0     0     0   
xu         2                    1                            5     1     1   

    TLIF  num_rods  num_pelvic_screws  osteotomy  
xl     0         1                  2          0  
xu     1         6                  4          1  


## Test Patient w fixed parameters

In [7]:
# Show holdout patients summary
holdout_df = pd.read_csv(config.DATA_HOLDOUT)

summary_cols = [
    "id", "description", "revision", "mech_fail_last",
    "gap_score_preop", "gap_category", "gap_score_postop", "gap_category_postop",
    "UIV_implant", "num_levels_cat", "num_interbody_fusion_levels",
    "ALIF", "XLIF", "TLIF", "num_rods", "num_pelvic_screws", "osteotomy",
]
display(holdout_df[summary_cols].set_index("id"))

# Select patient
PATIENT_ID = 6380632
patient_fixed = ou.load_patient_data(patient_id=PATIENT_ID, data_path=config.DATA_HOLDOUT)
print(f"\nSelected patient: {PATIENT_ID}")

Unnamed: 0_level_0,description,revision,mech_fail_last,gap_score_preop,gap_category,gap_score_postop,gap_category_postop,UIV_implant,num_levels_cat,num_interbody_fusion_levels,ALIF,XLIF,TLIF,num_rods,num_pelvic_screws,osteotomy
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
818588,mech failure - small GAP improvement;,1,1.0,12.0,SD,10.0,SD,PS,lower,1,0,0,0,3,2,0
1176294,mech failure - large GAP improvement;,0,1.0,10.0,SD,1.0,P,FS,lower,4,1,1,0,4,2,1
2964021,revision - GAP improvement 3 to 1,0,0.0,3.0,MD,1.0,P,PS,lower,2,1,0,0,3,3,0
6380632,test,0,0.0,12.0,SD,1.0,P,Hook,higher,2,1,0,0,3,2,0



Selected patient: 6380632


## Build optimization problem

**Objective:** Minimize composite score (lower = better patient outcomes)

**Constraints:** 
- If `num_interbody_fusion_levels > 0`, at least one fusion type (`ALIF`, `XLIF`, or `TLIF`) must be selected.
- If `ALIF=1` and `XLIF=0` and `TLIF=0`, then `num_interbody_fusion_levels` must be < 4.

In [8]:
problem = SpineProblem(
    patient_fixed=patient_fixed,
    delta_bundles=delta_bundles,
    xl=xl,
    xu=xu,
    weights=WEIGHTS
)

## Run GA and view results

In [9]:
algorithm = GA(
    pop_size=100,
    sampling=IntegerRandomSampling(),
    crossover=SBX(prob=0.9, eta=15),
    mutation=PM(eta=20),
    eliminate_duplicates=True,
)

res = minimize(
    problem,
    algorithm,
    get_termination("n_gen", 20),
    seed=42,
    verbose=True,
    save_history=True
)

n_gen  |  n_eval  |     cv_min    |     cv_avg    |     f_avg     |     f_min    
     1 |       99 |  0.000000E+00 |  0.2828282828 |  3.768436E+01 |  3.768436E+01
     2 |      199 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
     3 |      299 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
     4 |      399 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
     5 |      499 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
     6 |      599 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
     7 |      699 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
     8 |      799 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
     9 |      899 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
    10 |      999 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
    11 |     1099 |  0.000000E+00 |  0.000000E+00 |  3.768436E+01 |  3.768436E+01
    12 |     119

## Actual surgical plan & outcome for comparison

In [10]:
df_actual = disp.display_actual_outcomes(PATIENT_ID, patient_fixed, data_path=config.DATA_HOLDOUT)
display(df_actual)

ACTUAL SURGICAL PLAN (WHAT WAS PERFORMED)
  UIV_implant: Hook
  num_interbody_fusion_levels: 2
  ALIF: 1
  XLIF: 0
  TLIF: 0
  num_rods: 3
  num_pelvic_screws: 2
  osteotomy: 0

ALIGNMENT PARAMETERS: PREOP → POSTOP (ACTUAL)


Unnamed: 0,Parameter,Preop,Delta (actual),Postop (actual)
0,LL,30.4,32.5,62.9
1,SS,33.8,11.5,45.3
2,L4S1,34.5,7.0,41.5
3,GlobalTilt,42.2,-23.9,18.3
4,T4PA,24.8,-15.2,9.6
5,L1PA,12.2,-4.8,7.4
6,PI,58.6,2.2,60.8
7,PT,24.8,-9.3,15.5
8,SVA,154.6,-124.8,29.8
9,PI-LL,28.2 ⚠,-30.3,-2.1 ⚠


## Show best solution

In [11]:
best_x = np.asarray(res.X).astype(int)
result = ou.evaluate_solution(
    best_x, 
    patient_fixed, 
    delta_bundles, 
    mech_fail_bundle,
    weights=WEIGHTS
)

df_comparison = disp.display_optimized_solution(result, patient_fixed)
display(df_comparison)

BEST SOLUTION SUMMARY (OPTIMIZED)

Composite Score: 37.6844 (lower is better)
Mechanical Failure Probability: 51.9%

Surgical Plan:
  UIV_implant: Hook
  num_levels_cat: higher
  num_interbody_fusion_levels: 2
  ALIF: 0
  XLIF: 1
  TLIF: 0
  num_rods: 4
  num_pelvic_screws: 3
  osteotomy: 0

ALIGNMENT PARAMETERS: PREOP → POSTOP (PREDICTED)


Unnamed: 0,Parameter,Preop,Delta (pred),Postop (pred)
0,LL,30.4,-,30.4
1,SS,33.8,-,33.8
2,L4S1,34.5,-,34.5
3,GlobalTilt,42.2,-,42.2
4,T4PA,24.8,-,24.8
5,L1PA,12.2,-,12.2
6,SVA,154.6,-,-
7,PI,58.6,-,58.6
8,PT,24.8,0.0,24.8
9,PI-LL,28.2 ⚠,0.0,28.2 ⚠


## Getting multiple solutions

Extracts diverse surgical plans from the GA optimization history:
1. Collects top candidates across all generations
2. Filters to solutions within `score_tolerance` of best score
3. Buckets by `(UIV_implant, ALIF, XLIF, TLIF)` to ensure variety in implant/fusion types
4. Returns 'top_n` unique plans with full evaluation (postop values, GAP score, mech fail prob)

In [12]:
top12_solutions = s.get_diverse_solutions(
    res=res,
    top_n=12,
    top_per_gen=50,
    score_tolerance=2,  # Include solutions within 2 points of best
    bucket_cols=("UIV_implant", "ALIF", "XLIF", "TLIF"),
    n_per_bucket=1,
    patient_fixed=patient_fixed,
    delta_bundles=delta_bundles,
    mech_fail_bundle=mech_fail_bundle,
    weights=WEIGHTS,
)

_ = disp.display_multiple_solutions(top12_solutions, patient_fixed)
# top12_solutions

SOLUTIONS COMPARISON


Unnamed: 0,Parameter,Sol 1,Sol 2,Sol 3,Sol 4,Sol 5,Sol 6,Sol 7,Sol 8,Sol 9,Sol 10,Sol 11,Sol 12
0,Composite Score,37.68,37.68,37.68,37.68,37.68,37.68,37.68,37.68,37.68,37.68,37.68,37.68
1,Mech Fail Prob,51.9%,55.2%,69.7%,85.2%,94.9%,26.7%,18.5%,18.9%,26.5%,33.2%,91.2%,90.1%
2,GAP Score,12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD),12.0 (SD) → 12 (SD)
3,────────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────
4,SURGICAL PLAN,,,,,,,,,,,,
5,UIV_implant,Hook,PS,Hook,FS,PS,Hook,PS,FS,PS,FS,FS,FS
6,num_levels_cat,higher,higher,lower,lower,lower,lower,lower,lower,lower,lower,higher,higher
7,num_interbody_fusion_levels,2,0,0,1,3,3,2,2,5,5,4,3
8,ALIF,0,0,1,0,1,0,0,0,0,1,1,1
9,XLIF,1,0,1,1,1,0,1,0,0,0,1,0
