# Optimization

## Import Libraries

In [1]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import importlib
import numpy as np
import pandas as pd

from pymoo.optimize import minimize
from pymoo.termination import get_termination
from pymoo.algorithms.soo.nonconvex.ga import GA
from pymoo.operators.sampling.rnd import IntegerRandomSampling
from pymoo.operators.crossover.sbx import SBX
from pymoo.operators.mutation.pm import PM

from src import config
import src.optimization_utils as ou
import src.solutions as s
import src.display as disp
from src.problem import SpineProblem

# Reload to pick up any changes
importlib.reload(ou)
importlib.reload(s)
importlib.reload(disp)

pd.set_option('display.max_columns', None)

## Load Models

In [2]:
# Mechanical failure model
mech_fail_bundle = ou.load_model_bundle(config.MECH_FAIL_MODEL)

# Delta models from notebook 03
L4S1_bundle = ou.load_model_bundle(config.L4S1_MODEL)
LL_bundle = ou.load_model_bundle(config.LL_MODEL)
T4PA_bundle = ou.load_model_bundle(config.T4PA_MODEL)
L1PA_bundle = ou.load_model_bundle(config.L1PA_MODEL)

# Delta models from notebook 04
SVA_bundle = ou.load_model_bundle(config.SVA_MODEL)
SS_bundle = ou.load_model_bundle(config.SS_MODEL)
GT_bundle = ou.load_model_bundle(config.GLOBAL_TILT_MODEL)

print("Loaded models:")
print(f"  - Mechanical failure: {mech_fail_bundle.get('model_name', 'N/A')}")
print(f"  - L4S1: {L4S1_bundle.get('model_name', 'N/A')}")
print(f"  - LL: {LL_bundle.get('model_name', 'N/A')}")
print(f"  - T4PA: {T4PA_bundle.get('model_name', 'N/A')}")
print(f"  - L1PA: {L1PA_bundle.get('model_name', 'N/A')}")
print(f"  - SVA: {SVA_bundle.get('model_name', 'N/A')}")
print(f"  - SS: {SS_bundle.get('model_name', 'N/A')}")
print(f"  - Global Tilt: {GT_bundle.get('model_name', 'N/A')}")

Loaded models:
  - Mechanical failure: mech_fail_logreg
  - L4S1: L4S1_ridge_reg
  - LL: LL_ridge_reg
  - T4PA: T4PA_ridge_reg
  - L1PA: L1PA_ridge_reg
  - SVA: XGBRegressor_delta_SVA
  - SS: XGBRegressor_delta_SS
  - Global Tilt: XGBRegressor_delta_GlobalTilt


## Optimization Configuration

In [3]:
# =============================================================================
# COMPOSITE SCORE WEIGHTS - Adjust these to change optimization priorities
# =============================================================================
# All weights should sum to 1.0 for proper scaling
# Lower composite score = better outcome

WEIGHTS = {
    "w1": 1/6,  # GAP Score (normalized 0-100)
    "w2": 1/6,  # L1PA penalty (mismatch from ideal)
    "w3": 1/6,  # L4S1 penalty (ideal range 35-45)
    "w4": 1/6,  # T4L1PA penalty (T4PA - L1PA mismatch)
    "w5": 1/6,  # LL penalty (mismatch from ideal LL)
    "w6": 1/6,  # GAP category improvement penalty
}

WEIGHT_LABELS = {
    "w1": "GAP Score",
    "w2": "L1PA penalty",
    "w3": "L4S1 penalty",
    "w4": "T4L1PA penalty",
    "w5": "LL penalty",
    "w6": "GAP category improvement",
}

print("Composite Score Weights:")
for k, v in WEIGHTS.items():
    print(f"  {k}: {v:.4f}  ({WEIGHT_LABELS[k]})")
print(f"  Total: {sum(WEIGHTS.values()):.4f}")

Composite Score Weights:
  w1: 0.1667  (GAP Score)
  w2: 0.1667  (L1PA penalty)
  w3: 0.1667  (L4S1 penalty)
  w4: 0.1667  (T4L1PA penalty)
  w5: 0.1667  (LL penalty)
  w6: 0.1667  (GAP category improvement)
  Total: 1.0000


In [4]:
delta_bundles = {
    "L4S1": L4S1_bundle,
    "LL": LL_bundle,
    "T4PA": T4PA_bundle,
    "L1PA": L1PA_bundle,
    "SS": SS_bundle,
    "GlobalTilt": GT_bundle,
    "SVA": SVA_bundle,
}

print("Delta model bundles loaded:", list(delta_bundles.keys()))

Delta model bundles loaded: ['L4S1', 'LL', 'T4PA', 'L1PA', 'SS', 'GlobalTilt', 'SVA']


In [5]:
UIV_CHOICES, xl, xu = ou.get_decision_config()

In [6]:
print("UIV_CHOICES:", UIV_CHOICES)
print("xl:", xl)
print("xu:", xu)

# Use column names from config
print(pd.DataFrame([xl, xu], index=["xl","xu"], columns=config.DECISION_VAR_NAMES))

UIV_CHOICES: ['Hook', 'PS', 'FS']
xl: [0 0 0 0 0 0 1 2 0]
xu: [2 1 5 1 1 1 6 4 1]
    uiv_code  num_levels_cat_code  num_interbody_fusion_levels  ALIF  XLIF  \
xl         0                    0                            0     0     0   
xu         2                    1                            5     1     1   

    TLIF  num_rods  num_pelvic_screws  osteotomy  
xl     0         1                  2          0  
xu     1         6                  4          1  


## Test Patient w fixed parameters

In [7]:
# Load patient from dataset by ID
# PATIENT_ID = 1206016
PATIENT_ID = 817388
patient_fixed = ou.load_patient_data(patient_id=PATIENT_ID)

# Alternatively can load patient by index
# patient_fixed = ou.load_patient_data(index=0)

print(f"Loaded patient with ID {PATIENT_ID}")
print(f"Total patients in dataset: {ou.get_patient_count()}")
print("\nPatient fixed parameters:")
for k, v in patient_fixed.items():
    print(f"  {k}: {v}")

Loaded patient with ID 817388
Total patients in dataset: 277

Patient fixed parameters:
  age: 69
  sex: MALE
  bmi: 26.93
  C7CSVL_preop: 16.0
  SVA_preop: 96.8
  T4PA_preop: 26.8
  L1PA_preop: 22.3
  LL_preop: 8.3
  L4S1_preop: 15.8
  PT_preop: 21.0
  PI_preop: 47.4
  SS_preop: 26.4
  cobb_main_curve_preop: 36.4
  FC_preop: 17.5
  tscore_femneck_preop: -0.7
  HU_UIV_preop: 174.0
  HU_UIVplus1_preop: 157.0
  HU_UIVplus2_preop: 195.0
  gap_category: SD
  gap_score_preop: 12.0
  GlobalTilt_preop: 32.2


## Build optimization problem

**Objective:** Minimize composite score (lower = better patient outcomes)

**Constraints:** 
- If `num_interbody_fusion_levels > 0`, at least one fusion type (`ALIF`, `XLIF`, or `TLIF`) must be selected.

In [8]:
problem = SpineProblem(
    patient_fixed=patient_fixed,
    delta_bundles=delta_bundles,
    xl=xl,
    xu=xu,
    weights=WEIGHTS
)

## Run GA and view results

In [9]:
algorithm = GA(
    pop_size=100,
    sampling=IntegerRandomSampling(),
    crossover=SBX(prob=0.9, eta=15),
    mutation=PM(eta=20),
    eliminate_duplicates=True,
)

res = minimize(
    problem,
    algorithm,
    get_termination("n_gen", 20),
    seed=42,
    verbose=True,
    save_history=True
)

n_gen  |  n_eval  |     cv_min    |     cv_avg    |     f_avg     |     f_min    
     1 |       99 |  0.000000E+00 |  0.1919191919 |  1.310859E+01 |  1.9552369842
     2 |      199 |  0.000000E+00 |  0.000000E+00 |  1.052896E+01 |  1.9406913416
     3 |      299 |  0.000000E+00 |  0.000000E+00 |  8.0099945352 |  1.9406913416
     4 |      399 |  0.000000E+00 |  0.000000E+00 |  5.0989983921 |  1.7129311263
     5 |      499 |  0.000000E+00 |  0.000000E+00 |  2.0831616299 |  1.7129311263
     6 |      599 |  0.000000E+00 |  0.000000E+00 |  1.9852365969 |  1.7129311263
     7 |      699 |  0.000000E+00 |  0.000000E+00 |  1.9139651819 |  1.6388306241
     8 |      799 |  0.000000E+00 |  0.000000E+00 |  1.8551903058 |  1.6388306241
     9 |      899 |  0.000000E+00 |  0.000000E+00 |  1.7924485268 |  1.6388306241
    10 |      999 |  0.000000E+00 |  0.000000E+00 |  1.7710845297 |  1.6388306241
    11 |     1099 |  0.000000E+00 |  0.000000E+00 |  1.7206482221 |  1.6388306241
    12 |     119

## Actual surgical plan & outcome for comparison

In [10]:
df_actual = disp.display_actual_outcomes(PATIENT_ID, patient_fixed)
display(df_actual)

ACTUAL SURGICAL PLAN (WHAT WAS PERFORMED)
  UIV_implant: Hook
  num_levels_cat: higher
  num_interbody_fusion_levels: 3.0
  ALIF: 0
  XLIF: 0
  TLIF: 0
  num_rods: 3.0
  num_pelvic_screws: 3.0
  osteotomy: 0.0

ALIGNMENT PARAMETERS: PREOP → POSTOP (ACTUAL)


Unnamed: 0,Parameter,Preop,Delta (actual),Postop (actual)
0,LL,8.3,28.4,36.7
1,SS,26.4,5.0,31.4
2,L4S1,15.8,12.0,27.8
3,GlobalTilt,32.2,-15.8,16.4
4,T4PA,26.8,-15.9,10.9
5,L1PA,22.3,-12.6,9.7
6,PI,47.4,-3.0,44.4
7,PT,21.0,-8.0,13.0
8,SVA,96.8,-44.6,52.2
9,PI-LL,39.1 ⚠,-31.4,7.7 ✓


## Show best solution

In [11]:
best_x = np.asarray(res.X).astype(int)
result = ou.evaluate_solution(
    best_x, 
    patient_fixed, 
    delta_bundles, 
    mech_fail_bundle,
    weights=WEIGHTS
)

df_comparison = disp.display_optimized_solution(result, patient_fixed)
display(df_comparison)

BEST SOLUTION SUMMARY (OPTIMIZED)

Composite Score: 1.6388 (lower is better)
Mechanical Failure Probability: 7.4%

Surgical Plan:
  UIV_implant: Hook
  num_levels_cat: lower
  num_interbody_fusion_levels: 5
  ALIF: 1
  XLIF: 0
  TLIF: 0
  num_rods: 5
  num_pelvic_screws: 4
  osteotomy: 1

ALIGNMENT PARAMETERS: PREOP → POSTOP (PREDICTED)


Unnamed: 0,Parameter,Preop,Delta (pred),Postop (pred)
0,LL,8.3,43.2,51.5
1,SS,26.4,5.6,32.0
2,L4S1,15.8,25.3,41.1
3,GlobalTilt,32.2,-19.6,12.6
4,T4PA,26.8,-19.8,7.0
5,L1PA,22.3,-13.1,9.2
6,SVA,96.8,-60.1,36.7
7,PI,47.4,-,47.4
8,PT,21.0,-5.6,15.4
9,PI-LL,39.1 ⚠,-43.2,-4.1 ⚠


## Getting multiple solutions

Extracts diverse surgical plans from the GA optimization history:
1. Collects top candidates across all generations
2. Filters to solutions within `score_tolerance` of best score
3. Buckets by `(UIV_implant, ALIF, XLIF, TLIF)` to ensure variety in implant/fusion types
4. Returns 'top_n` unique plans with full evaluation (postop values, GAP score, mech fail prob)

In [12]:
top12_solutions = s.get_diverse_solutions(
    res=res,
    top_n=12,
    top_per_gen=50,
    score_tolerance=2,  # Include solutions within 2 points of best
    bucket_cols=("UIV_implant", "ALIF", "XLIF", "TLIF"),
    n_per_bucket=1,
    patient_fixed=patient_fixed,
    delta_bundles=delta_bundles,
    mech_fail_bundle=mech_fail_bundle,
    weights=WEIGHTS,
)

_ = disp.display_multiple_solutions(top12_solutions, patient_fixed)
# top12_solutions

SOLUTIONS COMPARISON


Unnamed: 0,Parameter,Sol 1,Sol 2,Sol 3,Sol 4,Sol 5,Sol 6,Sol 7,Sol 8,Sol 9,Sol 10,Sol 11,Sol 12
0,Composite Score,1.64,1.71,1.85,1.91,1.96,1.97,2.02,2.05,2.07,2.12,2.13,2.38
1,Mech Fail Prob,7.4%,14.3%,22.1%,4.6%,26.5%,37.1%,16.1%,11.1%,30.6%,16.5%,37.7%,47.0%
2,GAP Score,12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P),12.0 (SD) → 1 (P)
3,────────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────,──────────
4,SURGICAL PLAN,,,,,,,,,,,,
5,UIV_implant,Hook,FS,Hook,Hook,Hook,FS,FS,FS,PS,Hook,FS,FS
6,num_levels_cat,lower,lower,lower,lower,lower,lower,higher,higher,lower,higher,lower,higher
7,num_interbody_fusion_levels,5,5,5,5,5,5,5,5,5,5,5,4
8,ALIF,1,1,1,0,0,1,0,0,1,0,0,0
9,XLIF,0,0,1,0,0,1,1,0,1,1,0,1
