# Data Combination & Sliding Window Creation

## Objective
1. **Step 1**: Convert wide format → long format (one row per patient-visit)
2. **Step 2**: Merge all data sources
3. **Step 3**: Create sliding windows for model training

## Input Files
- `0_OCT_paths_per_bscan.csv` - images paths
- `1_BCVA_processed.csv` - BCVA values (wide format)
- `3_CST_processed.csv` - CST values (wide format)
- `5_Injection_Data_processed.csv` - Injection data (wide format)
- `6_Leakage_Index_processed.csv` - Leakage Index (wide format)
- `4_Demographics_processed.csv` - Static patient features
- `2_biomarker_processed.csv` - Biomarkers (already long format, sparse)
- `7_fundus_path_extracted.csv` - Fundus image paths (long format)

## Output Files
- `8_longitudinal_data_long.csv` - All data in long format
- `9_sliding_window_dataset.csv` - Ready for model training


## merge logic
merged (base: BCVA)
- ↓ merge CST
- ↓ merge Injection
- ↓ merge Leakage
- ↓ merge Demographics
- ↓ merge Fundus          
- ↓ merge Biomarker       
- ↓ merge OCT_agg         
- Final merged dataframe

In [50]:
import pandas as pd
import numpy as np
from pathlib import Path

# Define data directory
DATA_DIR = Path('.')

---
# Part 1: Load All Data

In [51]:
# Load all processed files
oct_paths = pd.read_csv(DATA_DIR / '0_OCT_paths_per_bscan.csv')
bcva = pd.read_csv(DATA_DIR / '1_BCVA_processed.csv')
cst = pd.read_csv(DATA_DIR / '3_CST_processed.csv')
injection = pd.read_csv(DATA_DIR / '5_Injection_Data_processed.csv')
leakage = pd.read_csv(DATA_DIR / '6_Leakage_Index_processed.csv')
demographics = pd.read_csv(DATA_DIR / '4_Demographics_processed.csv')
biomarker = pd.read_csv(DATA_DIR / '2_biomarker_processed.csv')
fundus = pd.read_csv(DATA_DIR / '7_fundus_path_extracted.csv')


print("All files loaded successfully!")
print(f"OCT paths: {oct_paths.shape}")
print(f"BCVA: {bcva.shape}, CST: {cst.shape}, Injection: {injection.shape}")
print(f"Leakage: {leakage.shape}, Demographics: {demographics.shape}")
print(f"Biomarker: {biomarker.shape}, Fundus: {fundus.shape}")

All files loaded successfully!
OCT paths: (32337, 7)
BCVA: (40, 44), CST: (40, 44), Injection: (40, 24)
Leakage: (40, 24), Demographics: (40, 15)
Biomarker: (3920, 27), Fundus: (1284, 6)


In [52]:
# Define time points and their week values
TIME_POINTS = {
    'Screen': 0, 'Week4': 4, 'Week8': 8, 'Week12': 12, 'Week16': 16,
    'Week20': 20, 'Week24': 24, 'Week28': 28, 'Week32': 32, 'Week36': 36,
    'Week40': 40, 'Week44': 44, 'Week48': 48, 'Week52': 52, 'Week60': 60,
    'Week68': 68, 'Week76': 76, 'Week84': 84, 'Week92': 92, 'Week100': 100,
    'Week104': 104
}

WEEKS = list(TIME_POINTS.values())  # [0, 4, 8, 12, ..., 104]
print(f"Time points: {len(TIME_POINTS)}")
print(f"Weeks: {WEEKS}")

Time points: 21
Weeks: [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 60, 68, 76, 84, 92, 100, 104]


---
# Part 2: Convert Wide → Long Format
**Input (Wide format)** - 1 row per patient

**Output (Long format)** - 1 row per patient-visit

## 2.1 BCVA Wide → Long

In [53]:
bcva.head()  # Display the first few rows of the BCVA dataframe

Unnamed: 0,Patient_ID,Eye,Arm,Screen,Week4,Week4_Change,Week8,Week8_Change,Week12,Week12_Change,...,Week76,Week76_Change,Week84,Week84_Change,Week92,Week92_Change,Week100,Week100_Change,Week104,Week104_Change
0,01-001,OS,2,97,98,1,97.0,0.0,98.0,1.0,...,90.0,-7.0,95.0,-2.0,99.0,2.0,97.0,0.0,94.0,-3.0
1,01-002,OD,2,68,72,4,68.0,0.0,67.0,-1.0,...,,,,,,,,,,
2,01-013,OD,2,88,89,1,86.0,-2.0,89.0,1.0,...,89.0,1.0,89.0,1.0,89.0,1.0,82.0,-6.0,85.0,-3.0
3,01-014,OS,2,95,94,-1,94.0,-1.0,97.0,2.0,...,93.0,-2.0,98.0,3.0,97.0,2.0,98.0,3.0,97.0,2.0
4,01-023,OD,2,81,82,1,,,76.0,-5.0,...,,,81.0,0.0,84.0,3.0,,,,


In [54]:
TIME_POINTS

{'Screen': 0,
 'Week4': 4,
 'Week8': 8,
 'Week12': 12,
 'Week16': 16,
 'Week20': 20,
 'Week24': 24,
 'Week28': 28,
 'Week32': 32,
 'Week36': 36,
 'Week40': 40,
 'Week44': 44,
 'Week48': 48,
 'Week52': 52,
 'Week60': 60,
 'Week68': 68,
 'Week76': 76,
 'Week84': 84,
 'Week92': 92,
 'Week100': 100,
 'Week104': 104}

In [55]:
def wide_to_long_bcva(df):
    """Convert BCVA from wide to long format."""
    records = []
    
    for _, row in df.iterrows():
        patient_id = row['Patient_ID']
        eye = row['Eye']
        arm = row['Arm']
        
        for tp_name, week in TIME_POINTS.items():
            bcva_val = row.get(tp_name, np.nan)
            
            records.append({
                'Patient_ID': patient_id,
                'Eye': eye,
                'Arm': arm,
                'Week': week,
                'BCVA': bcva_val
            })
    
    return pd.DataFrame(records)

bcva_long = wide_to_long_bcva(bcva)
print(f"BCVA long format: {bcva_long.shape}")
bcva_long.head(10)

BCVA long format: (840, 5)


Unnamed: 0,Patient_ID,Eye,Arm,Week,BCVA
0,01-001,OS,2,0,97.0
1,01-001,OS,2,4,98.0
2,01-001,OS,2,8,97.0
3,01-001,OS,2,12,98.0
4,01-001,OS,2,16,97.0
5,01-001,OS,2,20,96.0
6,01-001,OS,2,24,98.0
7,01-001,OS,2,28,96.0
8,01-001,OS,2,32,97.0
9,01-001,OS,2,36,97.0


## 2.2 CST Wide → Long

In [56]:
cst.head()

Unnamed: 0,Patient_ID,Eye,Arm,Screen,Week4,Week4_Change,Week8,Week8_Change,Week12,Week12_Change,...,Week76,Week76_Change,Week84,Week84_Change,Week92,Week92_Change,Week100,Week100_Change,Week104,Week104_Change
0,01-001,OS,2,275,268,-7,268.0,-7.0,265.0,-10.0,...,272.0,-3.0,268.0,-7.0,269.0,-6.0,270.0,-5.0,276.0,1.0
1,01-002,OD,2,238,233,-5,222.0,-16.0,222.0,-16.0,...,,,,,,,,,,
2,01-013,OD,2,303,286,-17,280.0,-23.0,280.0,-23.0,...,291.0,-12.0,314.0,11.0,284.0,-19.0,378.0,75.0,322.0,19.0
3,01-014,OS,2,256,248,-8,246.0,-10.0,247.0,-9.0,...,255.0,-1.0,262.0,6.0,255.0,-1.0,265.0,9.0,292.0,36.0
4,01-023,OD,2,267,260,-7,,,252.0,-15.0,...,,,253.0,-14.0,242.0,-25.0,,,,


In [57]:
def wide_to_long_cst(df):
    """Convert CST from wide to long format."""
    records = []
    
    for _, row in df.iterrows(): #Loop through each patient (each row in wide format)
        patient_id = row['Patient_ID'] # Get patient ID for current row
        eye = row['Eye'] #Get eye (OS/OD) for current row
        
        for tp_name, week in TIME_POINTS.items(): #Loop through all time points: `{'Screen': 0, 'Week4': 4, ...}
            cst_val = row.get(tp_name, np.nan) # Get CST value for current time point (or NaN if missing)
            
            records.append({
                'Patient_ID': patient_id,
                'Eye': eye,
                'Week': week,
                'CST': cst_val
            })
    print(records)
    return pd.DataFrame(records)

cst_long = wide_to_long_cst(cst)
print(f"CST long format: {cst_long.shape}")
cst_long.head()

[{'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 0, 'CST': 275}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 4, 'CST': 268}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 8, 'CST': 268.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 12, 'CST': 265.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 16, 'CST': 267.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 20, 'CST': 270.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 24, 'CST': 268.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 28, 'CST': 273.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 32, 'CST': '269'}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 36, 'CST': 270.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 40, 'CST': 266.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 44, 'CST': 268.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 48, 'CST': 264.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 52, 'CST': 265.0}, {'Patient_ID': '01-001', 'Eye': 'OS', 'Week': 60, 'CST': nan}, {'Patient_ID': '01-001', 'Eye': '

Unnamed: 0,Patient_ID,Eye,Week,CST
0,01-001,OS,0,275.0
1,01-001,OS,4,268.0
2,01-001,OS,8,268.0
3,01-001,OS,12,265.0
4,01-001,OS,16,267.0


## 2.3 Injection Wide → Long

In [58]:
def wide_to_long_injection(df):
    """Convert Injection from wide to long format."""
    records = []
    
    for _, row in df.iterrows():
        patient_id = row['Patient_ID']
        eye = row['Eye']
        
        for tp_name, week in TIME_POINTS.items():
            inj_val = row.get(tp_name, np.nan)
            
            records.append({
                'Patient_ID': patient_id,
                'Eye': eye,
                'Week': week,
                'Injection': inj_val
            })
    
    return pd.DataFrame(records)

injection_long = wide_to_long_injection(injection)
print(f"Injection long format: {injection_long.shape}")
injection_long.head()

Injection long format: (840, 4)


Unnamed: 0,Patient_ID,Eye,Week,Injection
0,01-001,OS,0,1.0
1,01-001,OS,4,1.0
2,01-001,OS,8,1.0
3,01-001,OS,12,1.0
4,01-001,OS,16,0.0


## 2.4 Leakage Index Wide → Long

In [59]:
def wide_to_long_leakage(df):
    """Convert Leakage Index from wide to long format."""
    records = []
    
    for _, row in df.iterrows():
        patient_id = row['Patient_ID']
        eye = row['Eye']
        
        for tp_name, week in TIME_POINTS.items():
            # Column name format: Leakage_Screen, Leakage_Week4, etc.
            col_name = f'Leakage_{tp_name}'
            leakage_val = row.get(col_name, np.nan)
            
            records.append({
                'Patient_ID': patient_id,
                'Eye': eye,
                'Week': week,
                'Leakage_Index': leakage_val
            })
    
    return pd.DataFrame(records)

leakage_long = wide_to_long_leakage(leakage)
print(f"Leakage long format: {leakage_long.shape}")
leakage_long.head()

Leakage long format: (840, 4)


Unnamed: 0,Patient_ID,Eye,Week,Leakage_Index
0,01-001,OS,0,1.59
1,01-001,OS,4,0.28
2,01-001,OS,8,0.21
3,01-001,OS,12,0.15
4,01-001,OS,16,0.15


---
# Part 3: Merge All Data Sources

In [60]:
# Start with BCVA as base (has Patient_ID, Eye, Arm, Week)
merged = bcva_long.copy()
print(f"Base (BCVA): {merged.shape}")

# Merge CST
merged = merged.merge(
    cst_long[['Patient_ID', 'Eye', 'Week', 'CST']],
    on=['Patient_ID', 'Eye', 'Week'],
    how='left'
)
print(f"After CST merge: {merged.shape}")

# Merge Injection
merged = merged.merge(
    injection_long[['Patient_ID', 'Eye', 'Week', 'Injection']],
    on=['Patient_ID', 'Eye', 'Week'],
    how='left'
)
print(f"After Injection merge: {merged.shape}")

# Merge Leakage Index
merged = merged.merge(
    leakage_long[['Patient_ID', 'Eye', 'Week', 'Leakage_Index']],
    on=['Patient_ID', 'Eye', 'Week'],
    how='left'
)
print(f"After Leakage merge: {merged.shape}")

Base (BCVA): (840, 5)
After CST merge: (840, 6)
After Injection merge: (840, 7)
After Leakage merge: (840, 8)


In [61]:
merged.head()

Unnamed: 0,Patient_ID,Eye,Arm,Week,BCVA,CST,Injection,Leakage_Index
0,01-001,OS,2,0,97.0,275.0,1.0,1.59
1,01-001,OS,2,4,98.0,268.0,1.0,0.28
2,01-001,OS,2,8,97.0,268.0,1.0,0.21
3,01-001,OS,2,12,98.0,265.0,1.0,0.15
4,01-001,OS,2,16,97.0,267.0,0.0,0.15


In [62]:
demographics.head()

Unnamed: 0,Patient_ID,Eye,Treatment_Arm,Age,Gender,Ethnicity,Race,Diabetes_Type,Diabetes_Years,Baseline_HbA1c,BMI,W24_HbA1c,W52_HbA1c,W76_HbA1c,W104_HbA1c
0,01-001,OS,2,44,M,N H/L,White,2,20,7.1,,8.7,8.4,9.1,8.4
1,01-002,OD,2,56,F,N H/L,White,2,25,11.3,34.484657,9.1,,,
2,01-013,OD,2,38,M,H/L,White,1,13,8.0,25.997929,9.5,9.6,11.7,9.8
3,01-014,OS,2,55,M,N H/L,White,2,12,10.1,31.871377,7.2,7.4,7.4,7.4
4,01-023,OD,2,56,M,H/L,White,2,22,7.1,35.669938,6.7,5.9,6.9,


In [63]:
# Merge Demographics (static features - merge on Patient_ID, Eye only)
demo_cols = ['Patient_ID', 'Eye', 'Age', 'Gender', 'Ethnicity', 'Race',
             'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']

merged = merged.merge(
    demographics[demo_cols],
    on=['Patient_ID', 'Eye'],
    how='left'
)
print(f"After Demographics merge: {merged.shape}")
print(merged.columns.tolist())
merged.head()

After Demographics merge: (840, 16)
['Patient_ID', 'Eye', 'Arm', 'Week', 'BCVA', 'CST', 'Injection', 'Leakage_Index', 'Age', 'Gender', 'Ethnicity', 'Race', 'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']


Unnamed: 0,Patient_ID,Eye,Arm,Week,BCVA,CST,Injection,Leakage_Index,Age,Gender,Ethnicity,Race,Diabetes_Type,Diabetes_Years,Baseline_HbA1c,BMI
0,01-001,OS,2,0,97.0,275.0,1.0,1.59,44,M,N H/L,White,2,20,7.1,
1,01-001,OS,2,4,98.0,268.0,1.0,0.28,44,M,N H/L,White,2,20,7.1,
2,01-001,OS,2,8,97.0,268.0,1.0,0.21,44,M,N H/L,White,2,20,7.1,
3,01-001,OS,2,12,98.0,265.0,1.0,0.15,44,M,N H/L,White,2,20,7.1,
4,01-001,OS,2,16,97.0,267.0,0.0,0.15,44,M,N H/L,White,2,20,7.1,


In [64]:
fundus.head()

Unnamed: 0,Patient_ID_String,Week,Eye,File_Path,Filename,Extension
0,01-001,0,OD,Prime_FULL/01-001/W0/OD/fundus_W0.png,fundus_W0.png,png
1,01-001,0,OS,Prime_FULL/01-001/W0/OS/fundus_W0.png,fundus_W0.png,png
2,01-001,100,OD,Prime_FULL/01-001/W100/OD/fundus_W100.tif,fundus_W100.tif,tif
3,01-001,100,OS,Prime_FULL/01-001/W100/OS/fundus_W100.tif,fundus_W100.tif,tif
4,01-001,104,OD,Prime_FULL/01-001/W104/OD/fundus_W104.tif,fundus_W104.tif,tif


In [65]:
# Merge Fundus paths
fundus_renamed = fundus.rename(columns={'Patient_ID_String': 'Patient_ID'})
fundus_renamed = fundus_renamed[['Patient_ID', 'Eye', 'Week', 'File_Path']].rename(
    columns={'File_Path': 'Fundus_Path'}
)

merged = merged.merge(
    fundus_renamed,
    on=['Patient_ID', 'Eye', 'Week'],
    how='left'
)
print(f"After Fundus merge: {merged.shape}")

After Fundus merge: (840, 17)


In [66]:
print(merged.columns.tolist())
merged.head()

['Patient_ID', 'Eye', 'Arm', 'Week', 'BCVA', 'CST', 'Injection', 'Leakage_Index', 'Age', 'Gender', 'Ethnicity', 'Race', 'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI', 'Fundus_Path']


Unnamed: 0,Patient_ID,Eye,Arm,Week,BCVA,CST,Injection,Leakage_Index,Age,Gender,Ethnicity,Race,Diabetes_Type,Diabetes_Years,Baseline_HbA1c,BMI,Fundus_Path
0,01-001,OS,2,0,97.0,275.0,1.0,1.59,44,M,N H/L,White,2,20,7.1,,Prime_FULL/01-001/W0/OS/fundus_W0.png
1,01-001,OS,2,4,98.0,268.0,1.0,0.28,44,M,N H/L,White,2,20,7.1,,Prime_FULL/01-001/W4/OS/fundus_W4.png
2,01-001,OS,2,8,97.0,268.0,1.0,0.21,44,M,N H/L,White,2,20,7.1,,Prime_FULL/01-001/W8/OS/fundus_W8.png
3,01-001,OS,2,12,98.0,265.0,1.0,0.15,44,M,N H/L,White,2,20,7.1,,Prime_FULL/01-001/W12/OS/fundus_W12.png
4,01-001,OS,2,16,97.0,267.0,0.0,0.15,44,M,N H/L,White,2,20,7.1,,Prime_FULL/01-001/W16/OS/fundus_W16.png


In [67]:
print(biomarker.columns.tolist())
biomarker.head()

['Path (Trial/Arm/Folder/Visit/Eye/Image Name)', 'Scan (n/49)', 'Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL', 'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face', 'Fully attached vitreous face', 'Preretinal tissue/hemorrhage', 'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)', 'Disruption of RPE', 'PED (serous)', 'SHRM', 'Eye_ID', 'BCVA', 'CST', 'Patient_ID', 'Patient_ID_String', 'Week', 'Eye', 'OCT_Filename', 'B_scan_number']


Unnamed: 0,Path (Trial/Arm/Folder/Visit/Eye/Image Name),Scan (n/49),Atrophy / thinning of retinal layers,Disruption of EZ,DRIL,IR hemorrhages,IR HRF,Partially attached vitreous face,Fully attached vitreous face,Preretinal tissue/hemorrhage,...,SHRM,Eye_ID,BCVA,CST,Patient_ID,Patient_ID_String,Week,Eye,OCT_Filename,B_scan_number
0,/Prime_FULL/02-010/W0/OD/0.tif,1,0,0,0,0,1,0.0,0.0,0.0,...,0,57,88,307,57,02-010,0.0,OD,0.tif,0
1,/Prime_FULL/02-010/W0/OD/1.tif,2,0,0,0,0,1,0.0,0.0,0.0,...,0,57,88,307,57,02-010,0.0,OD,1.tif,1
2,/Prime_FULL/02-010/W0/OD/2.tif,3,0,0,0,0,1,0.0,0.0,0.0,...,0,57,88,307,57,02-010,0.0,OD,2.tif,2
3,/Prime_FULL/02-010/W0/OD/3.tif,4,0,0,0,0,1,0.0,0.0,0.0,...,0,57,88,307,57,02-010,0.0,OD,3.tif,3
4,/Prime_FULL/02-010/W0/OD/4.tif,5,0,0,0,0,1,0.0,0.0,0.0,...,0,57,88,307,57,02-010,0.0,OD,4.tif,4


## About Groupby Use
-  to combine into the 1 csv, Do Groupby for OCT and Biomarker -- compress 49 Bscans into 1, easy for reference
-  in modeling, use detailed paths 

In [68]:
# Merge Biomarkers (aggregate per visit - take mean across B-scans)
biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                  'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                  'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                  'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                  'Disruption of RPE', 'PED (serous)', 'SHRM']

# Aggregate biomarkers per patient-eye-week (mean across 49 B-scans)
biomarker_agg = biomarker.groupby(['Patient_ID_String', 'Eye', 'Week'])[biomarker_cols].mean().reset_index()
biomarker_agg = biomarker_agg.rename(columns={'Patient_ID_String': 'Patient_ID'})

print(f"Biomarker aggregated: {biomarker_agg.shape}")
print(f"Weeks with biomarkers: {sorted(biomarker_agg['Week'].unique())}")

Biomarker aggregated: (80, 19)
Weeks with biomarkers: [np.float64(0.0), np.float64(12.0), np.float64(24.0), np.float64(28.0), np.float64(36.0), np.float64(48.0), np.float64(52.0), np.float64(84.0), np.float64(92.0), np.float64(100.0), np.float64(104.0)]


In [69]:
# Merge aggregated biomarkers
merged = merged.merge(
    biomarker_agg,
    on=['Patient_ID', 'Eye', 'Week'],
    how='left'
)
print(f"After Biomarker merge: {merged.shape}")
print(merged.columns.tolist())
merged.head()

After Biomarker merge: (840, 33)
['Patient_ID', 'Eye', 'Arm', 'Week', 'BCVA', 'CST', 'Injection', 'Leakage_Index', 'Age', 'Gender', 'Ethnicity', 'Race', 'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI', 'Fundus_Path', 'Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL', 'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face', 'Fully attached vitreous face', 'Preretinal tissue/hemorrhage', 'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)', 'Disruption of RPE', 'PED (serous)', 'SHRM']


Unnamed: 0,Patient_ID,Eye,Arm,Week,BCVA,CST,Injection,Leakage_Index,Age,Gender,...,Fully attached vitreous face,Preretinal tissue/hemorrhage,Vitreous debris,VMT,DRT/ME,Fluid (IRF),Fluid (SRF),Disruption of RPE,PED (serous),SHRM
0,01-001,OS,2,0,97.0,275.0,1.0,1.59,44,M,...,0.897959,0.0,0.632653,0.0,0.0,0.265306,0.0,0.0,0.0,0.0
1,01-001,OS,2,4,98.0,268.0,1.0,0.28,44,M,...,,,,,,,,,,
2,01-001,OS,2,8,97.0,268.0,1.0,0.21,44,M,...,,,,,,,,,,
3,01-001,OS,2,12,98.0,265.0,1.0,0.15,44,M,...,,,,,,,,,,
4,01-001,OS,2,16,97.0,267.0,0.0,0.15,44,M,...,,,,,,,,,,


In [70]:
# Aggregate OCT paths per visit (join all B-scan paths with delimiter)
oct_agg = oct_paths.groupby(['Patient_ID', 'Eye', 'Week']).agg({
    'OCT_Path': lambda x: '|'.join(x.astype(str)),  # Join all paths
    'B_scan_num': 'count'  # Count B-scans
}).reset_index()

oct_agg = oct_agg.rename(columns={
    'OCT_Path': 'OCT_Paths',
    'B_scan_num': 'Num_B_scans'
})

# Merge OCT paths
merged = merged.merge(
    oct_agg[['Patient_ID', 'Eye', 'Week', 'OCT_Paths', 'Num_B_scans']],
    on=['Patient_ID', 'Eye', 'Week'],
    how='left'
)

# Sort by Patient_ID, Eye, Week
merged = merged.sort_values(['Patient_ID', 'Eye', 'Week']).reset_index(drop=True)

In [71]:
print(f"\nFinal merged long format: {merged.shape}")
print(f"Columns: {merged.columns.tolist()}")
merged


Final merged long format: (840, 35)
Columns: ['Patient_ID', 'Eye', 'Arm', 'Week', 'BCVA', 'CST', 'Injection', 'Leakage_Index', 'Age', 'Gender', 'Ethnicity', 'Race', 'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI', 'Fundus_Path', 'Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL', 'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face', 'Fully attached vitreous face', 'Preretinal tissue/hemorrhage', 'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)', 'Disruption of RPE', 'PED (serous)', 'SHRM', 'OCT_Paths', 'Num_B_scans']


Unnamed: 0,Patient_ID,Eye,Arm,Week,BCVA,CST,Injection,Leakage_Index,Age,Gender,...,Vitreous debris,VMT,DRT/ME,Fluid (IRF),Fluid (SRF),Disruption of RPE,PED (serous),SHRM,OCT_Paths,Num_B_scans
0,01-001,OS,2,0,97.0,275,1.0,1.59,44,M,...,0.632653,0.0,0.0,0.265306,0.0,0.0,0.0,0.0,/Prime_FULL/01-001/W0/OS/0.png|/Prime_FULL/01-...,49.0
1,01-001,OS,2,4,98.0,268,1.0,0.28,44,M,...,,,,,,,,,/Prime_FULL/01-001/W4/OS/0.png|/Prime_FULL/01-...,49.0
2,01-001,OS,2,8,97.0,268.0,1.0,0.21,44,M,...,,,,,,,,,/Prime_FULL/01-001/W8/OS/0.png|/Prime_FULL/01-...,49.0
3,01-001,OS,2,12,98.0,265.0,1.0,0.15,44,M,...,,,,,,,,,/Prime_FULL/01-001/W12/OS/0.png|/Prime_FULL/01...,49.0
4,01-001,OS,2,16,97.0,267.0,0.0,0.15,44,M,...,,,,,,,,,/Prime_FULL/01-001/W16/OS/0.png|/Prime_FULL/01...,49.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
835,02-046,OD,1,76,,,,,48,M,...,,,,,,,,,,
836,02-046,OD,1,84,,,,,48,M,...,,,,,,,,,,
837,02-046,OD,1,92,,,,,48,M,...,,,,,,,,,,
838,02-046,OD,1,100,84.0,257.0,1.0,6.22,48,M,...,,,,,,,,,/Prime_FULL/02-046/W100/OD/0.tif|/Prime_FULL/0...,49.0


In [72]:
# Check missing values
print("=== Missing Values ===")
missing = merged.isna().sum()
missing_pct = 100 * missing / len(merged)
missing_df = pd.DataFrame({'Missing': missing, 'Pct': missing_pct}).sort_values('Pct', ascending=False)
print(missing_df[missing_df['Missing'] > 0])

=== Missing Values ===
                                      Missing        Pct
SHRM                                      762  90.714286
Atrophy / thinning of retinal layers      762  90.714286
Disruption of EZ                          762  90.714286
VMT                                       762  90.714286
Vitreous debris                           762  90.714286
Preretinal tissue/hemorrhage              762  90.714286
PED (serous)                              762  90.714286
Disruption of RPE                         762  90.714286
Fluid (SRF)                               762  90.714286
Fluid (IRF)                               762  90.714286
DRT/ME                                    762  90.714286
Partially attached vitreous face          762  90.714286
IR HRF                                    762  90.714286
IR hemorrhages                            762  90.714286
Fully attached vitreous face              762  90.714286
DRIL                                      762  90.714286
Leakage_

# ANALYZE missing value
- Biomarker 90% missing: Only labeled at 11 weeks by design [0, 12, 24, 28, 36, 48, 52, 84, 92, 100, 104]✅ OK, handle as sparse feature
( Biomarkers only labeled at 11 out of 21 weeks. So:
Total visits: 840 (40 eyes × 21 weeks),
Visits with biomarkers: ~78 (40 eyes × ~11 weeks, minus missing)
Missing: ~762 → 90.7% ✓ This is correct!)
- OCT, Fundus, BCVA, CST, Injection (~21% missing) : Real missing visits, mostly after W52 ✅ OK, you're using max_week=52
- Leakage slightly higher than others: DRSS.csv had some "Missed"/"Dropped" values ✅ Already converted to NaN

In [73]:
# Save long format data
merged.to_csv('8_combined_data_long.csv', index=False)
print(f"✓ Saved: 8_combined_data_long.csv ({merged.shape})")

✓ Saved: 8_combined_data_long.csv ((840, 35))


---
# Part 4: Create Sliding Windows

## Task Definition
- **Input**: Data from K=3 consecutive visits (spanning 12 weeks)
- **Target**: BCVA change at next visit: `BCVA(t+12) - BCVA(t+8)`
- **Example**: [W0, W4, W8] → predict ΔBCVA at W12

In [74]:
# Sliding window parameters
K = 3  # Number of visits in input window
VISIT_INTERVAL = 4  # Weeks between visits
PREDICTION_HORIZON = 4  # Predict 4 weeks ahead (next visit)

print(f"Window size: K={K} visits")
print(f"Visit interval: {VISIT_INTERVAL} weeks")
print(f"Prediction horizon: {PREDICTION_HORIZON} weeks")

Window size: K=3 visits
Visit interval: 4 weeks
Prediction horizon: 4 weeks


In [75]:
def create_sliding_windows(df, k=3, visit_interval=4, max_week=52):
    """
    Create sliding window samples from longitudinal data.
    
    Args:
        df: Long format dataframe with Patient_ID, Eye, Week, BCVA, etc.
        k: Number of visits in input window
        visit_interval: Weeks between consecutive visits
        max_week: Maximum week to use for training (W52 recommended)
    
    Returns:
        DataFrame with sliding window samples
    """
    windows = []
    
    # Group by patient-eye
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        # weeks_available = group['Week'].values
        # 改进: 只保留 BCVA 非空的 weeks
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            target_week = window_weeks[-1] + visit_interval
            
            if target_week > max_week:
                continue
            
            # 改进: 宽松模式（只要求 t2 和 target 的 BCVA 有值，允许 t0 和 t1 缺失）
            t2_week = window_weeks[-1]
            if t2_week not in valid_weeks or target_week not in valid_weeks:
                continue
            
            # Get data for each visit in window
            window_data = {
                'Patient_ID': patient_id,
                'Eye': eye,
                'Window_Start_Week': start_week,
                'Target_Week': target_week,
            }
            
            # Add static features (from first visit)
            first_visit = group[group['Week'] == window_weeks[0]].iloc[0]
            static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race',
                          'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']
            for col in static_cols:
                if col in first_visit:
                    window_data[col] = first_visit[col]
            
            # Add temporal features for each visit in window
            temporal_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index', 'Fundus_Path', 'OCT_Paths', 'Num_B_scans']
            biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                             'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                             'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                             'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                             'Disruption of RPE', 'PED (serous)', 'SHRM']
            
            all_temporal = temporal_cols + biomarker_cols
            
            for i, week in enumerate(window_weeks):
                visit_data = group[group['Week'] == week].iloc[0]
                for col in all_temporal:
                    if col in visit_data:
                        window_data[f'{col}_t{i}'] = visit_data[col]
                window_data[f'Week_t{i}'] = week
            
            # Target
            target_visit = group[group['Week'] == target_week].iloc[0]
            bcva_target = target_visit['BCVA']
            bcva_last = window_data['BCVA_t2']
            
            window_data['BCVA_Target'] = bcva_target
            window_data['BCVA_Change'] = bcva_target - bcva_last if pd.notna(bcva_target) and pd.notna(bcva_last) else np.nan
            
            # Missing flags (标记 t0, t1 是否缺失)
            for i, week in enumerate(window_weeks):
                window_data[f'Missing_t{i}'] = 0 if week in valid_weeks else 1
            
            windows.append(window_data)
    
    return pd.DataFrame(windows)

print("Function defined.")

Function defined.


In [76]:
# Create sliding windows (using W0-W52 data)
sliding_windows = create_sliding_windows(merged, k=3, visit_interval=4, max_week=52)

print(f"Created {len(sliding_windows)} sliding window samples")
print(f"Shape: {sliding_windows.shape}")

print(f"Totaled {len(sliding_windows.columns)} cols: {sliding_windows.columns.tolist()}")
sliding_windows

Created 369 sliding window samples
Shape: (369, 90)
Totaled 90 cols: ['Patient_ID', 'Eye', 'Window_Start_Week', 'Target_Week', 'Arm', 'Age', 'Gender', 'Ethnicity', 'Race', 'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI', 'BCVA_t0', 'CST_t0', 'Injection_t0', 'Leakage_Index_t0', 'Fundus_Path_t0', 'OCT_Paths_t0', 'Num_B_scans_t0', 'Atrophy / thinning of retinal layers_t0', 'Disruption of EZ_t0', 'DRIL_t0', 'IR hemorrhages_t0', 'IR HRF_t0', 'Partially attached vitreous face_t0', 'Fully attached vitreous face_t0', 'Preretinal tissue/hemorrhage_t0', 'Vitreous debris_t0', 'VMT_t0', 'DRT/ME_t0', 'Fluid (IRF)_t0', 'Fluid (SRF)_t0', 'Disruption of RPE_t0', 'PED (serous)_t0', 'SHRM_t0', 'Week_t0', 'BCVA_t1', 'CST_t1', 'Injection_t1', 'Leakage_Index_t1', 'Fundus_Path_t1', 'OCT_Paths_t1', 'Num_B_scans_t1', 'Atrophy / thinning of retinal layers_t1', 'Disruption of EZ_t1', 'DRIL_t1', 'IR hemorrhages_t1', 'IR HRF_t1', 'Partially attached vitreous face_t1', 'Fully attached vitreous face_t1',

Unnamed: 0,Patient_ID,Eye,Window_Start_Week,Target_Week,Arm,Age,Gender,Ethnicity,Race,Diabetes_Type,...,Fluid (SRF)_t2,Disruption of RPE_t2,PED (serous)_t2,SHRM_t2,Week_t2,BCVA_Target,BCVA_Change,Missing_t0,Missing_t1,Missing_t2
0,01-001,OS,0,12,2,44,M,N H/L,White,2,...,,,,,8,98.0,1.0,0,0,0
1,01-001,OS,4,16,2,44,M,N H/L,White,2,...,,,,,12,97.0,-1.0,0,0,0
2,01-001,OS,8,20,2,44,M,N H/L,White,2,...,,,,,16,96.0,-1.0,0,0,0
3,01-001,OS,12,24,2,44,M,N H/L,White,2,...,,,,,20,98.0,2.0,0,0,0
4,01-001,OS,16,28,2,44,M,N H/L,White,2,...,,,,,24,96.0,-2.0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
364,02-046,OD,12,24,1,48,M,N H/L,White,1,...,,,,,20,89.0,-3.0,0,0,0
365,02-046,OD,16,28,1,48,M,N H/L,White,1,...,,,,,24,92.0,3.0,0,0,0
366,02-046,OD,32,44,1,48,M,N H/L,White,1,...,,,,,40,91.0,4.0,1,1,0
367,02-046,OD,36,48,1,48,M,N H/L,White,1,...,,,,,44,92.0,1.0,1,0,0


In [77]:
# Check sample distribution
print("=== Sliding Window Statistics ===")
print(f"Total samples: {len(sliding_windows)}")
print(f"Unique patients: {sliding_windows['Patient_ID'].nunique()}")
print(f"Samples per patient: {len(sliding_windows) / sliding_windows['Patient_ID'].nunique():.1f}")

print(f"\nWindow start weeks:")
print(sliding_windows['Window_Start_Week'].value_counts().sort_index())

=== Sliding Window Statistics ===
Total samples: 369
Unique patients: 39
Samples per patient: 9.5

Window start weeks:
Window_Start_Week
0     36
4     38
8     38
12    35
16    33
20    33
24    32
28    31
32    31
36    32
40    30
Name: count, dtype: int64


- Window Distribution:
W0  → 36 samples
W4  → 38 samples
...
W40 → 30 samples
最后一个窗口: [W40, W44, W48] → predict W52


In [78]:
# Check target distribution
print("=== Target Distribution (BCVA_Change) ===")
print(sliding_windows['BCVA_Change'].describe())

# Clinical significance thresholds
print(f"\nClinical significance:")
print(f"  Improved ≥5 letters: {(sliding_windows['BCVA_Change'] >= 5).sum()} ({100*(sliding_windows['BCVA_Change'] >= 5).mean():.1f}%)")
print(f"  Worsened ≤-5 letters: {(sliding_windows['BCVA_Change'] <= -5).sum()} ({100*(sliding_windows['BCVA_Change'] <= -5).mean():.1f}%)")
print(f"  Stable (-5 to 5): {((sliding_windows['BCVA_Change'] > -5) & (sliding_windows['BCVA_Change'] < 5)).sum()}")

=== Target Distribution (BCVA_Change) ===
count    369.000000
mean      -0.016260
std        3.581316
min      -31.000000
25%       -2.000000
50%        0.000000
75%        2.000000
max       11.000000
Name: BCVA_Change, dtype: float64

Clinical significance:
  Improved ≥5 letters: 29 (7.9%)
  Worsened ≤-5 letters: 29 (7.9%)
  Stable (-5 to 5): 311


In [79]:
# 目前的sliding window结果 （15.7% 绝对改变） vs 之前的估算方式 （26.4%） 对比
# 绝对变化 ≥5 字母的比例
significant_change = (sliding_windows['BCVA_Change'].abs() >= 5).sum()
pct = 100 * significant_change / len(sliding_windows)
print(f"|BCVA_Change| ≥ 5: {significant_change} ({pct:.1f}%)")
# 之前的估算方式 (可能是这样)
changes_12week = []
for (patient_id, eye), group in merged.groupby(['Patient_ID', 'Eye']):
    group = group.sort_values('Week')
    for i, row in group.iterrows():
        week_now = row['Week']
        week_future = week_now + 12
        future_row = group[group['Week'] == week_future]
        if len(future_row) > 0:
            bcva_now = row['BCVA']
            bcva_future = future_row.iloc[0]['BCVA']
            if pd.notna(bcva_now) and pd.notna(bcva_future):
                changes_12week.append(bcva_future - bcva_now)

changes_12week = pd.Series(changes_12week)
print(f"样本量: {len(changes_12week)}")
print(f"|ΔBCVA| ≥ 5: {(changes_12week.abs() >= 5).sum()} ({100*(changes_12week.abs() >= 5).mean():.1f}%)")

'''之前：可能包含了所有有效的 12 周间隔配对（W0-W12, W4-W16, W12-W24, ... 包括 W52 之后的数据）
现在：
限制了 max_week=52（target 必须 ≤ W52）
排除了 t2 或 target BCVA 缺失的样本
只用了宽松模式筛选后的 369 个样本'''

|BCVA_Change| ≥ 5: 58 (15.7%)
样本量: 409
|ΔBCVA| ≥ 5: 108 (26.4%)


'之前：可能包含了所有有效的 12 周间隔配对（W0-W12, W4-W16, W12-W24, ... 包括 W52 之后的数据）\n现在：\n限制了 max_week=52（target 必须 ≤ W52）\n排除了 t2 或 target BCVA 缺失的样本\n只用了宽松模式筛选后的 369 个样本'

## decide the best window size by trying all possibles

In [80]:
def analyze_window_size(df, k, visit_interval=4, max_week=52):
    """分析不同窗口大小下的 BCVA Change 分布"""
    changes = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            target_week = window_weeks[-1] + visit_interval
            
            if target_week > max_week:
                continue
            
            t_last = window_weeks[-1]  # 窗口最后一个时间点
            
            # 宽松模式：只要求 t_last 和 target 有 BCVA
            if t_last not in valid_weeks or target_week not in valid_weeks:
                continue
            
            bcva_last = group[group['Week'] == t_last].iloc[0]['BCVA']
            bcva_target = group[group['Week'] == target_week].iloc[0]['BCVA']
            
            change = bcva_target - bcva_last
            changes.append({
                'Patient_ID': patient_id,
                'Eye': eye,
                'Start_Week': start_week,
                'Target_Week': target_week,
                'BCVA_Change': change
            })
    
    return pd.DataFrame(changes)

# 测试不同的 K 值
results = []
for k in [2, 3, 4, 6, 8]:
    window_span = (k - 1) * 4 + 4  # 窗口跨度 + 预测间隔
    df_windows = analyze_window_size(merged, k=k, visit_interval=4, max_week=52)
    
    if len(df_windows) == 0:
        continue
    
    n_samples = len(df_windows)
    mean_change = df_windows['BCVA_Change'].mean()
    std_change = df_windows['BCVA_Change'].std()
    mean_abs_change = df_windows['BCVA_Change'].abs().mean()
    pct_ge5 = 100 * (df_windows['BCVA_Change'].abs() >= 5).mean()
    pct_ge10 = 100 * (df_windows['BCVA_Change'].abs() >= 10).mean()
    
    results.append({
        'K': k,
        'Window_Span': f"{(k-1)*4}周 → +4周",
        'Total_Weeks': window_span,
        'Samples': n_samples,
        'Mean': round(mean_change, 2),
        'Std': round(std_change, 2),
        'Mean_Abs': round(mean_abs_change, 2),
        '≥5_letters%': round(pct_ge5, 1),
        '≥10_letters%': round(pct_ge10, 1)
    })

results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))


 K Window_Span  Total_Weeks  Samples  Mean  Std  Mean_Abs  ≥5_letters%  ≥10_letters%
 2    4周 → +4周            8      405  0.05 3.54      2.44         15.6           1.5
 3    8周 → +4周           12      369 -0.02 3.58      2.47         15.7           1.4
 4   12周 → +4周           16      333 -0.08 3.61      2.45         15.6           1.5
 6   20周 → +4周           24      257 -0.10 3.80      2.61         17.1           1.2
 8   28周 → +4周           32      189  0.04 3.75      2.45         14.8           1.1


- 当前实验说明：窗口大小 K 对显著变化比例影响不大，因为预测的都是 4 周后的短期变化。


In [81]:
# 固定 K=3，改变 prediction_horizon
def analyze_prediction_horizon(df, k=3, visit_interval=4, prediction_weeks=4, max_week=52):
    """分析不同预测间隔下的 BCVA Change"""
    changes = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            t_last = window_weeks[-1]
            target_week = t_last + prediction_weeks  # 可变的预测间隔
            
            if target_week > max_week:
                continue
            
            if t_last not in valid_weeks or target_week not in valid_weeks:
                continue
            
            bcva_last = group[group['Week'] == t_last].iloc[0]['BCVA']
            bcva_target = group[group['Week'] == target_week].iloc[0]['BCVA']
            
            change = bcva_target - bcva_last
            changes.append({'BCVA_Change': change})
    
    return pd.DataFrame(changes)

# 测试不同预测间隔
results = []
for pred_weeks in [4, 8, 12, 16, 24]:
    df_windows = analyze_prediction_horizon(merged, k=3, prediction_weeks=pred_weeks, max_week=104)
    
    if len(df_windows) == 0:
        continue
    
    pct_ge5 = 100 * (df_windows['BCVA_Change'].abs() >= 5).mean()
    
    results.append({
        'Prediction_Horizon': f"{pred_weeks}周",
        'Samples': len(df_windows),
        '≥5_letters%': round(pct_ge5, 1)
    })

print(pd.DataFrame(results).to_string(index=False))

Prediction_Horizon  Samples  ≥5_letters%
                4周      393         15.8
                8周      447         19.9
               12周      331         23.6
               16周      375         22.1
               24周      308         24.0


- 预测间隔越长，BCVA 变化越显著，这符合临床直觉。
- 12周是一个合理的选择: 临床意义：12周（3个月）是常见的随访间隔; ≥5 letters% = 23.6%：接近之前估算的 26.4%

--> 推荐配置：K=3, prediction_horizon=12周
即：用 [t, t+4, t+8] 三次 visit 的数据，预测 12 周后（t+20）相对于 t+8 的 BCVA 变化。



In [82]:
# update the sliding window function
def create_sliding_windows(df, k=3, visit_interval=4, prediction_horizon=12, max_week=52):
    """
    Create sliding window samples from longitudinal data.
    
    Args:
        df: Long format dataframe with Patient_ID, Eye, Week, BCVA, etc.
        k: Number of visits in input window
        visit_interval: Weeks between consecutive visits (default: 4)
        prediction_horizon: Weeks ahead to predict (default: 12)
        max_week: Maximum week for target (default: 52)
    
    Returns:
        DataFrame with sliding window samples
    
    Example with k=3, prediction_horizon=12:
        Input: [W0, W4, W8] → Predict: BCVA(W20) - BCVA(W8)
    """
    windows = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            t_last = window_weeks[-1]
            target_week = t_last + prediction_horizon  # 改为 prediction_horizon
            
            if target_week > max_week:
                continue
            
            # 宽松模式: 只检查 t_last 和 target
            if t_last not in valid_weeks or target_week not in valid_weeks:
                continue
            
            window_data = {
                'Patient_ID': patient_id,
                'Eye': eye,
                'Window_Start_Week': start_week,
                'Window_End_Week': t_last,
                'Target_Week': target_week,
                'Prediction_Horizon': prediction_horizon,
            }
            
            # Static features
            first_visit = group[group['Week'] == window_weeks[0]].iloc[0]
            static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race',
                          'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']
            for col in static_cols:
                if col in first_visit:
                    window_data[col] = first_visit[col]
            
            # Temporal features
            temporal_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index', 'Fundus_Path', 'OCT_Paths', 'Num_B_scans']
            biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                             'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                             'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                             'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                             'Disruption of RPE', 'PED (serous)', 'SHRM']
            all_temporal = temporal_cols + biomarker_cols
            
            for i, week in enumerate(window_weeks):
                visit_data = group[group['Week'] == week].iloc[0]
                for col in all_temporal:
                    if col in visit_data:
                        window_data[f'{col}_t{i}'] = visit_data[col]
                window_data[f'Week_t{i}'] = week
            
            # Target: BCVA change (target - t_last)
            target_visit = group[group['Week'] == target_week].iloc[0]
            bcva_target = target_visit['BCVA']
            bcva_last = window_data[f'BCVA_t{k-1}']  # 最后一个时间点
            
            window_data['BCVA_Target'] = bcva_target
            window_data['BCVA_Change'] = bcva_target - bcva_last if pd.notna(bcva_target) and pd.notna(bcva_last) else np.nan
            
            # Missing flags
            for i, week in enumerate(window_weeks):
                window_data[f'Missing_t{i}'] = 0 if week in valid_weeks else 1
            
            windows.append(window_data)
    
    return pd.DataFrame(windows)

In [83]:
# 参数设置
K = 3                    # 窗口包含 3 次 visit
VISIT_INTERVAL = 4       # 每次 visit 间隔 4 周
PREDICTION_HORIZON = 12  # 预测 12 周后的变化
MAX_WEEK = 52            # target 不超过 W52

# 创建 sliding windows
sliding_windows = create_sliding_windows(
    merged, 
    k=K, 
    visit_interval=VISIT_INTERVAL, 
    prediction_horizon=PREDICTION_HORIZON, 
    max_week=MAX_WEEK
)

print(f"Total samples: {len(sliding_windows)}")
print(f"≥5 letters%: {100 * (sliding_windows['BCVA_Change'].abs() >= 5).mean():.1f}%")


Total samples: 293
≥5 letters%: 23.2%


#### Try multi-output: 样本: [W0, W4, W8] → BCVA_Change_4w, BCVA_Change_8w, BCVA_Change_12w

建模方式

1. Multi-output regression：一个模型同时预测 3 个 target
2. Shared encoder：共享特征提取层，不同 horizon 用不同的 prediction head
3. 损失函数：可以加权组合，比如更重视长期预测

In [84]:
def create_sliding_windows_multi_horizon(df, k=3, visit_interval=4, prediction_horizons=[4, 8, 12], max_week=52):
    """
    Create sliding window samples with multiple prediction horizons.
    
    Args:
        df: Long format dataframe
        k: Number of visits in input window
        visit_interval: Weeks between consecutive visits
        prediction_horizons: List of prediction horizons (weeks)
        max_week: Maximum week for the furthest target
    
    Returns:
        DataFrame with multiple BCVA_Change columns
    
    Example with k=3, prediction_horizons=[4, 8, 12]:
        Input: [W0, W4, W8]
        Targets: BCVA_Change_4w  = BCVA(W12) - BCVA(W8)
                 BCVA_Change_8w  = BCVA(W16) - BCVA(W8)
                 BCVA_Change_12w = BCVA(W20) - BCVA(W8)
    """
    windows = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            t_last = window_weeks[-1]
            
            # 计算所有 target weeks
            target_weeks = {h: t_last + h for h in prediction_horizons}
            max_target_week = max(target_weeks.values())
            
            if max_target_week > max_week:
                continue
            
            # 宽松模式: 要求 t_last 有值，至少一个 target 有值
            if t_last not in valid_weeks:
                continue
            
            # 检查至少有一个 target 可用
            any_target_valid = any(tw in valid_weeks for tw in target_weeks.values())
            if not any_target_valid:
                continue
            
            window_data = {
                'Patient_ID': patient_id,
                'Eye': eye,
                'Window_Start_Week': start_week,
                'Window_End_Week': t_last,
            }
            
            # Static features
            first_visit = group[group['Week'] == window_weeks[0]].iloc[0]
            static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race',
                          'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']
            for col in static_cols:
                if col in first_visit:
                    window_data[col] = first_visit[col]
            
            # Temporal features
            temporal_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index', 'Fundus_Path', 'OCT_Paths', 'Num_B_scans']
            biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                             'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                             'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                             'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                             'Disruption of RPE', 'PED (serous)', 'SHRM']
            all_temporal = temporal_cols + biomarker_cols
            
            for i, week in enumerate(window_weeks):
                visit_data = group[group['Week'] == week].iloc[0]
                for col in all_temporal:
                    if col in visit_data:
                        window_data[f'{col}_t{i}'] = visit_data[col]
                window_data[f'Week_t{i}'] = week
            
            # Multiple targets
            bcva_last = window_data[f'BCVA_t{k-1}']
            
            for horizon in prediction_horizons:
                target_week = target_weeks[horizon]
                window_data[f'Target_Week_{horizon}w'] = target_week
                
                if target_week in valid_weeks:
                    bcva_target = group[group['Week'] == target_week].iloc[0]['BCVA']
                    window_data[f'BCVA_Target_{horizon}w'] = bcva_target
                    window_data[f'BCVA_Change_{horizon}w'] = bcva_target - bcva_last if pd.notna(bcva_last) else np.nan
                else:
                    window_data[f'BCVA_Target_{horizon}w'] = np.nan
                    window_data[f'BCVA_Change_{horizon}w'] = np.nan
            
            # Missing flags
            for i, week in enumerate(window_weeks):
                window_data[f'Missing_t{i}'] = 0 if week in valid_weeks else 1
            
            windows.append(window_data)
    
    return pd.DataFrame(windows)
# 创建多 horizon 的 sliding windows
sliding_windows = create_sliding_windows_multi_horizon(
    merged,
    k=3,
    visit_interval=4,
    prediction_horizons=[4, 8, 12],
    max_week=52
)

print(f"Total samples: {len(sliding_windows)}")
print(f"\nTarget columns:")
for h in [4, 8, 12]:
    valid = sliding_windows[f'BCVA_Change_{h}w'].notna().sum()
    pct_ge5 = 100 * (sliding_windows[f'BCVA_Change_{h}w'].abs() >= 5).mean()
    print(f"  {h}w: {valid} valid, ≥5 letters = {pct_ge5:.1f}%")

Total samples: 316

Target columns:
  4w: 307 valid, ≥5 letters = 14.2%
  8w: 300 valid, ≥5 letters = 19.0%
  12w: 293 valid, ≥5 letters = 21.5%


### 做GRID SEARCH.  测试所有窗口大小, 所有Prediction Horizon, 找出最SIGNIFICANT BCVA CHANGE

In [85]:

def grid_search_window_params(df, k_values=[2, 3, 4, 6], horizon_values=[4, 8, 12, 16, 24], visit_interval=4, max_week=52):
    """
    Grid search to find optimal window size and prediction horizon.
    """
    results = []
    
    for k in k_values:
        for horizon in horizon_values:
            changes = []
            
            for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
                group = group.sort_values('Week')
                valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
                
                for start_week in group['Week'].values:
                    window_weeks = [start_week + i * visit_interval for i in range(k)]
                    t_last = window_weeks[-1]
                    target_week = t_last + horizon
                    
                    if target_week > max_week:
                        continue
                    
                    if t_last not in valid_weeks or target_week not in valid_weeks:
                        continue
                    
                    bcva_last = group[group['Week'] == t_last].iloc[0]['BCVA']
                    bcva_target = group[group['Week'] == target_week].iloc[0]['BCVA']
                    
                    if pd.notna(bcva_last) and pd.notna(bcva_target):
                        changes.append(bcva_target - bcva_last)
            
            if len(changes) == 0:
                continue
            
            changes = pd.Series(changes)
            
            results.append({
                'K': k,
                'Window_Span': (k - 1) * visit_interval,
                'Horizon': horizon,
                'Total_Weeks': (k - 1) * visit_interval + horizon,
                'Samples': len(changes),
                'Mean_Abs': round(changes.abs().mean(), 2),
                'Std': round(changes.std(), 2),
                '≥5_pct': round(100 * (changes.abs() >= 5).mean(), 1),
                '≥10_pct': round(100 * (changes.abs() >= 10).mean(), 1),
            })
    
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('≥5_pct', ascending=False)
    
    return results_df

# 运行
results = grid_search_window_params(
    merged,
    k_values=[2, 3, 4],
    horizon_values=[4, 8, 12, 16,20, 24],
    visit_interval=4,
    max_week=52
)

print("=== Grid Search Results (sorted by ≥5_pct) ===")
print(results.to_string(index=False))



=== Grid Search Results (sorted by ≥5_pct) ===
 K  Window_Span  Horizon  Total_Weeks  Samples  Mean_Abs  Std  ≥5_pct  ≥10_pct
 2            4       12           16      332      2.86 4.09    24.1      2.7
 4           12       12           24      258      2.85 4.15    23.6      2.7
 3            8       12           20      293      2.79 4.05    23.2      2.4
 2            4       20           24      258      2.88 4.11    22.5      1.9
 4           12       20           32      188      2.93 4.34    22.3      2.7
 3            8       20           28      222      2.91 4.20    22.1      2.3
 3            8       16           24      256      2.86 4.17    21.5      2.3
 2            4       16           20      294      2.89 4.17    21.4      2.7
 2            4        8           12      368      2.76 3.86    21.2      1.9
 2            4       24           28      225      2.84 4.05    20.9      2.7
 3            8        8           16      329      2.68 3.81    20.4      1.5
 4   

- K=2, 追求最高显著变化比例，简单模型 (输入：[t, t+4]（两次 visit）
预测：BCVA(t+16) - BCVA(t+4))
- K=3, multi-horizon, 追求预测稳定性

In [86]:
# k=2, multiple horizons
def create_sliding_windows_multi_horizon(df, k, visit_interval=4, prediction_horizons=[4, 8, 12, 16, 20, 24], max_week=52):
    """
    Create sliding window samples with multiple prediction horizons.
    
    Args:
        df: Long format dataframe
        k: Number of visits in input window
        visit_interval: Weeks between consecutive visits
        prediction_horizons: List of prediction horizons (weeks)
        max_week: Maximum week for the furthest target
    
    Returns:
        DataFrame with multiple BCVA_Change columns
    
    Example with k=3, prediction_horizons=[4, 8, 12]:
        Input: [W0, W4, W8]
        Targets: BCVA_Change_4w  = BCVA(W12) - BCVA(W8)
                 BCVA_Change_8w  = BCVA(W16) - BCVA(W8)
                 BCVA_Change_12w = BCVA(W20) - BCVA(W8)
    """
    windows = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            t_last = window_weeks[-1]
            
            # 计算所有 target weeks
            target_weeks = {h: t_last + h for h in prediction_horizons}
            max_target_week = max(target_weeks.values())
            
            if max_target_week > max_week:
                continue
            
            # 宽松模式: 要求 t_last 有值，至少一个 target 有值
            if t_last not in valid_weeks:
                continue
            
            # 检查至少有一个 target 可用
            any_target_valid = any(tw in valid_weeks for tw in target_weeks.values())
            if not any_target_valid:
                continue
            
            window_data = {
                'Patient_ID': patient_id,
                'Eye': eye,
                'Window_Start_Week': start_week,
                'Window_End_Week': t_last,
            }
            
            # Static features
            first_visit = group[group['Week'] == window_weeks[0]].iloc[0]
            static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race',
                          'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']
            for col in static_cols:
                if col in first_visit:
                    window_data[col] = first_visit[col]
            
            # Temporal features
            temporal_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index', 'Fundus_Path', 'OCT_Paths', 'Num_B_scans']
            biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                             'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                             'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                             'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                             'Disruption of RPE', 'PED (serous)', 'SHRM']
            all_temporal = temporal_cols + biomarker_cols
            
            for i, week in enumerate(window_weeks):
                visit_data = group[group['Week'] == week].iloc[0]
                for col in all_temporal:
                    if col in visit_data:
                        window_data[f'{col}_t{i}'] = visit_data[col]
                window_data[f'Week_t{i}'] = week
            
            # Multiple targets
            bcva_last = window_data[f'BCVA_t{k-1}']
            
            for horizon in prediction_horizons:
                target_week = target_weeks[horizon]
                window_data[f'Target_Week_{horizon}w'] = target_week
                
                if target_week in valid_weeks:
                    bcva_target = group[group['Week'] == target_week].iloc[0]['BCVA']
                    window_data[f'BCVA_Target_{horizon}w'] = bcva_target
                    window_data[f'BCVA_Change_{horizon}w'] = bcva_target - bcva_last if pd.notna(bcva_last) else np.nan
                else:
                    window_data[f'BCVA_Target_{horizon}w'] = np.nan
                    window_data[f'BCVA_Change_{horizon}w'] = np.nan
            
            # Missing flags
            for i, week in enumerate(window_weeks):
                window_data[f'Missing_t{i}'] = 0 if week in valid_weeks else 1
            
            windows.append(window_data)
    
    return pd.DataFrame(windows)
# 创建多 horizon 的 sliding windows
sliding_windows = create_sliding_windows_multi_horizon(
    merged,
    k=2,
    visit_interval=4,
    prediction_horizons=[4, 8, 12, 16, 20, 24],
    max_week=52
)

print(f"When k=2, Total samples: {len(sliding_windows)}")
print(f"\nTarget columns:")
for h in [4, 8, 12, 16, 20, 24]:
    valid = sliding_windows[f'BCVA_Change_{h}w'].notna().sum()
    pct_ge5 = 100 * (sliding_windows[f'BCVA_Change_{h}w'].abs() >= 5).mean()
    print(f"  {h}w: {valid} valid, ≥5 letters = {pct_ge5:.1f}%")

When k=2, Total samples: 260

Target columns:
  4w: 249 valid, ≥5 letters = 15.0%
  8w: 247 valid, ≥5 letters = 21.2%
  12w: 241 valid, ≥5 letters = 24.6%
  16w: 234 valid, ≥5 letters = 21.2%
  20w: 230 valid, ≥5 letters = 18.8%
  24w: 225 valid, ≥5 letters = 18.1%


## when k=3, handling missing data strategy 3: Only require t2 and target to have BCVA, allow t0/t1 missing (369 samples, model learns to handle missing visits) 

In [87]:
# k=3, multiple horizons

def create_sliding_windows_multi_horizon(df, k, visit_interval=4, prediction_horizons=[4, 8, 12, 16, 20, 24], max_week=52):
    """
    Create sliding window samples with multiple prediction horizons.
    
    Args:
        df: Long format dataframe
        k: Number of visits in input window
        visit_interval: Weeks between consecutive visits
        prediction_horizons: List of prediction horizons (weeks)
        max_week: Maximum week for the furthest target
    
    Returns:
        DataFrame with multiple BCVA_Change columns
    
    Example with k=3, prediction_horizons=[4, 8, 12]:
        Input: [W0, W4, W8]
        Targets: BCVA_Change_4w  = BCVA(W12) - BCVA(W8)
                 BCVA_Change_8w  = BCVA(W16) - BCVA(W8)
                 BCVA_Change_12w = BCVA(W20) - BCVA(W8)
    """
    windows = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            t_last = window_weeks[-1]
            
            # 计算所有 target weeks
            target_weeks = {h: t_last + h for h in prediction_horizons}
            max_target_week = max(target_weeks.values())
            
            if max_target_week > max_week:
                continue
            
            # 宽松模式: 要求 t_last 有值，至少一个 target 有值
            if t_last not in valid_weeks:
                continue
            
            # 检查至少有一个 target 可用
            any_target_valid = any(tw in valid_weeks for tw in target_weeks.values())
            if not any_target_valid:
                continue
            
            window_data = {
                'Patient_ID': patient_id,
                'Eye': eye,
                'Window_Start_Week': start_week,
                'Window_End_Week': t_last,
            }
            
            # Static features
            first_visit = group[group['Week'] == window_weeks[0]].iloc[0]
            static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race',
                          'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']
            for col in static_cols:
                if col in first_visit:
                    window_data[col] = first_visit[col]
            
            # Temporal features
            temporal_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index', 'Fundus_Path', 'OCT_Paths', 'Num_B_scans']
            biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                             'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                             'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                             'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                             'Disruption of RPE', 'PED (serous)', 'SHRM']
            all_temporal = temporal_cols + biomarker_cols
            
            for i, week in enumerate(window_weeks):
                visit_data = group[group['Week'] == week].iloc[0]
                for col in all_temporal:
                    if col in visit_data:
                        window_data[f'{col}_t{i}'] = visit_data[col]
                window_data[f'Week_t{i}'] = week
            
            # Multiple targets
            bcva_last = window_data[f'BCVA_t{k-1}']
            
            for horizon in prediction_horizons:
                target_week = target_weeks[horizon]
                window_data[f'Target_Week_{horizon}w'] = target_week
                
                if target_week in valid_weeks:
                    bcva_target = group[group['Week'] == target_week].iloc[0]['BCVA']
                    window_data[f'BCVA_Target_{horizon}w'] = bcva_target
                    window_data[f'BCVA_Change_{horizon}w'] = bcva_target - bcva_last if pd.notna(bcva_last) else np.nan
                else:
                    window_data[f'BCVA_Target_{horizon}w'] = np.nan
                    window_data[f'BCVA_Change_{horizon}w'] = np.nan
            
            # Missing flags
            for i, week in enumerate(window_weeks):
                window_data[f'Missing_t{i}'] = 0 if week in valid_weeks else 1
            
            windows.append(window_data)
    
    return pd.DataFrame(windows)
# 创建多 horizon 的 sliding windows
sliding_windows = create_sliding_windows_multi_horizon(
    merged,
    k=3,
    visit_interval=4,
    prediction_horizons=[4, 8, 12, 16, 20, 24],
    max_week=52
)

print(f"When k=3, Total samples: {len(sliding_windows)}")
print(f"\nTarget columns:")
for h in [4, 8, 12, 16, 20, 24]:
    valid = sliding_windows[f'BCVA_Change_{h}w'].notna().sum()
    pct_ge5 = 100 * (sliding_windows[f'BCVA_Change_{h}w'].abs() >= 5).mean()
    print(f"  {h}w: {valid} valid, ≥5 letters = {pct_ge5:.1f}%")

When k=3, Total samples: 220

Target columns:
  4w: 213 valid, ≥5 letters = 15.5%
  8w: 208 valid, ≥5 letters = 20.0%
  12w: 202 valid, ≥5 letters = 23.6%
  16w: 196 valid, ≥5 letters = 21.4%
  20w: 194 valid, ≥5 letters = 18.2%
  24w: 189 valid, ≥5 letters = 15.5%


## when k=3, handling missing data strategy 4 “最近 3 次随访（至少 2 次有数据）→ 预测下次随访”

条件3 (新增): {t0, t1, t2} 中至少 2 个有 BCVA（保证有纵向信息）

In [88]:
def create_sliding_windows_multi_horizon(df, k=3, visit_interval=4, prediction_horizons=[4, 8, 12], max_week=52):
    """
    Create sliding window samples with multiple prediction horizons.
    
    Strategy 4 (Lenient with minimum information requirement):
    - BCVA_t2 and at least one BCVA_target must be observed
    - At least 2 of {t0, t1, t2} must have observed clinical features (BCVA not missing)
    - Windows with only 1 observed visit are discarded (no longitudinal information)
    
    Args:
        df: Long format dataframe
        k: Number of visits in input window
        visit_interval: Weeks between consecutive visits
        prediction_horizons: List of prediction horizons (weeks)
        max_week: Maximum week for the furthest target
    
    Returns:
        DataFrame with multiple BCVA_Change columns
    """
    windows = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            t_last = window_weeks[-1]
            
            # 计算所有 target weeks
            target_weeks = {h: t_last + h for h in prediction_horizons}
            max_target_week = max(target_weeks.values())
            
            if max_target_week > max_week:
                continue
            
            # ========== Strategy 4 条件 ==========
            # 条件1: t_last (t2) 必须有 BCVA
            if t_last not in valid_weeks:
                continue
            
            # 条件2: 至少一个 target 有 BCVA
            any_target_valid = any(tw in valid_weeks for tw in target_weeks.values())
            if not any_target_valid:
                continue
            
            # 条件3 (新增): {t0, t1, t2} 中至少 2 个有 BCVA（保证有纵向信息）
            observed_count = sum(1 for w in window_weeks if w in valid_weeks)
            if observed_count < 2:
                continue  # 只有1个observed，丢弃
            # =====================================
            
            window_data = {
                'Patient_ID': patient_id,
                'Eye': eye,
                'Window_Start_Week': start_week,
                'Window_End_Week': t_last,
            }
            
            # Static features
            first_visit = group[group['Week'] == window_weeks[0]].iloc[0]
            static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race',
                          'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']
            for col in static_cols:
                if col in first_visit:
                    window_data[col] = first_visit[col]
            
            # Temporal features
            temporal_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index', 'Fundus_Path', 'OCT_Paths', 'Num_B_scans']
            biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                             'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                             'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                             'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                             'Disruption of RPE', 'PED (serous)', 'SHRM']
            all_temporal = temporal_cols + biomarker_cols
            
            for i, week in enumerate(window_weeks):
                visit_data = group[group['Week'] == week].iloc[0]
                for col in all_temporal:
                    if col in visit_data:
                        window_data[f'{col}_t{i}'] = visit_data[col]
                window_data[f'Week_t{i}'] = week
                # 添加 missing flag
                window_data[f'BCVA_missing_t{i}'] = 0 if week in valid_weeks else 1
            
            # Target BCVA values and changes
            t_last_bcva = group[group['Week'] == t_last].iloc[0]['BCVA']
            
            for h in prediction_horizons:
                tw = target_weeks[h]
                window_data[f'Target_Week_{h}w'] = tw
                if tw in valid_weeks:
                    target_bcva = group[group['Week'] == tw].iloc[0]['BCVA']
                    window_data[f'BCVA_Target_{h}w'] = target_bcva
                    window_data[f'BCVA_Change_{h}w'] = target_bcva - t_last_bcva
                else:
                    window_data[f'BCVA_Target_{h}w'] = np.nan
                    window_data[f'BCVA_Change_{h}w'] = np.nan
            
            windows.append(window_data)
    
    return pd.DataFrame(windows)
# 创建多 horizon 的 sliding windows
sliding_windows = create_sliding_windows_multi_horizon(
    merged,
    k=3,
    visit_interval=4,
    prediction_horizons=[4, 8, 12, 16, 20, 24],
    max_week=52
)

print(f"When k=3, Total samples: {len(sliding_windows)}")
print(f"\nTarget columns:")
for h in [4, 8, 12, 16, 20, 24]:
    valid = sliding_windows[f'BCVA_Change_{h}w'].notna().sum()
    pct_ge5 = 100 * (sliding_windows[f'BCVA_Change_{h}w'].abs() >= 5).mean()
    print(f"  {h}w: {valid} valid, ≥5 letters = {pct_ge5:.1f}%")

When k=3, Total samples: 219

Target columns:
  4w: 212 valid, ≥5 letters = 15.5%
  8w: 208 valid, ≥5 letters = 20.1%
  12w: 201 valid, ≥5 letters = 23.7%
  16w: 196 valid, ≥5 letters = 21.5%
  20w: 194 valid, ≥5 letters = 18.3%
  24w: 189 valid, ≥5 letters = 15.5%


## 最终配置：
✅ Window Size: K=3 (8周观察期)
✅ Primary Horizon: 12w (23.7% ≥5 letters，最高)
✅ Multi-Horizon: [4w, 8w, 12w] 都保留

In [89]:
def create_sliding_windows_multi_horizon(df, k=3, visit_interval=4, prediction_horizons=[4, 8, 12], max_week=52):
    """
    Create sliding window samples with multiple prediction horizons.
    
    Strategy 4 (Lenient with minimum information requirement):
    - BCVA at t_last (t2) must be observed (required for computing ΔBCVA)
    - At least one BCVA_target must be observed
    - At least 2 of {t0, t1, ..., t_{k-1}} must have observed BCVA (ensures longitudinal info)
    - Windows with only 1 observed visit are discarded
    
    Args:
        df: Long format dataframe with Patient_ID, Eye, Week, BCVA, etc.
        k: Number of visits in input window (default: 3)
        visit_interval: Weeks between consecutive visits (default: 4)
        prediction_horizons: List of prediction horizons in weeks (default: [4, 8, 12])
        max_week: Maximum week for the furthest target (default: 52)
    
    Returns:
        DataFrame with sliding window samples and multiple BCVA_Change columns
    
    Example with k=3, prediction_horizons=[4, 8, 12]:
        Input: [W0, W4, W8]
        Targets: BCVA_Change_4w  = BCVA(W12) - BCVA(W8)
                 BCVA_Change_8w  = BCVA(W16) - BCVA(W8)
                 BCVA_Change_12w = BCVA(W20) - BCVA(W8)
    """
    windows = []
    
    for (patient_id, eye), group in df.groupby(['Patient_ID', 'Eye']):
        group = group.sort_values('Week')
        valid_weeks = set(group[group['BCVA'].notna()]['Week'].values)
        
        for start_week in group['Week'].values:
            window_weeks = [start_week + i * visit_interval for i in range(k)]
            t_last = window_weeks[-1]
            
            # 计算所有 target weeks
            target_weeks = {h: t_last + h for h in prediction_horizons}
            max_target_week = max(target_weeks.values())
            
            if max_target_week > max_week:
                continue
            
            # ========== Strategy 4 条件 ==========
            # 条件1: t_last 必须有 BCVA（用于计算 ΔBCVA）
            if t_last not in valid_weeks:
                continue
            
            # 条件2: 至少一个 target 有 BCVA
            any_target_valid = any(tw in valid_weeks for tw in target_weeks.values())
            if not any_target_valid:
                continue
            
            # 条件3: window 中至少 2 个时间点有 BCVA（保证有纵向信息）
            observed_count = sum(1 for w in window_weeks if w in valid_weeks)
            if observed_count < 2:
                continue  # 只有1个observed，丢弃（无纵向信息）
            # =====================================
            
            window_data = {
                'Patient_ID': patient_id,
                'Eye': eye,
                'Window_Start_Week': start_week,
                'Window_End_Week': t_last,
                'Num_Observed_Visits': observed_count,  # 记录有多少个有效观测点
            }
            
            # Static features (从第一个有效 visit 获取，避免缺失问题)
            # 找到 window 中第一个有数据的 visit
            for w in window_weeks:
                if w in valid_weeks:
                    first_valid_visit = group[group['Week'] == w].iloc[0]
                    break
            else:
                first_valid_visit = group[group['Week'] == window_weeks[0]].iloc[0]
            
            static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race',
                          'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']
            for col in static_cols:
                if col in first_valid_visit:
                    window_data[col] = first_valid_visit[col]
            
            # Temporal features for each visit in window
            temporal_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index', 
                            'Fundus_Path', 'OCT_Paths', 'Num_B_scans']
            biomarker_cols = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                             'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                             'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                             'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                             'Disruption of RPE', 'PED (serous)', 'SHRM']
            all_temporal = temporal_cols + biomarker_cols
            
            for i, week in enumerate(window_weeks):
                window_data[f'Week_t{i}'] = week
                # Missing flag for BCVA (核心临床指标)
                is_missing = week not in valid_weeks
                window_data[f'BCVA_missing_t{i}'] = 1 if is_missing else 0
                
                # 获取该时间点的数据
                visit_rows = group[group['Week'] == week]
                if len(visit_rows) > 0:
                    visit_data = visit_rows.iloc[0]
                    for col in all_temporal:
                        if col in visit_data:
                            window_data[f'{col}_t{i}'] = visit_data[col]
                else:
                    # 如果该 week 完全没有记录，填充 NaN
                    for col in all_temporal:
                        window_data[f'{col}_t{i}'] = np.nan
            
            # Target BCVA values and changes
            t_last_bcva = group[group['Week'] == t_last].iloc[0]['BCVA']
            
            for h in prediction_horizons:
                tw = target_weeks[h]
                window_data[f'Target_Week_{h}w'] = tw
                if tw in valid_weeks:
                    target_bcva = group[group['Week'] == tw].iloc[0]['BCVA']
                    window_data[f'BCVA_Target_{h}w'] = target_bcva
                    window_data[f'BCVA_Change_{h}w'] = target_bcva - t_last_bcva
                else:
                    window_data[f'BCVA_Target_{h}w'] = np.nan
                    window_data[f'BCVA_Change_{h}w'] = np.nan
            
            windows.append(window_data)
    
    result_df = pd.DataFrame(windows)
    
    # 打印统计信息
    print(f"=== Sliding Window Creation Summary ===")
    print(f"Total windows created: {len(result_df)}")
    print(f"\nObserved visits distribution:")
    print(result_df['Num_Observed_Visits'].value_counts().sort_index())
    print(f"\nMissing pattern (BCVA):")
    for i in range(k):
        missing_pct = 100 * result_df[f'BCVA_missing_t{i}'].mean()
        print(f"  t{i}: {missing_pct:.1f}% missing")
    
    return result_df


# ============ 使用函数 ============
sliding_windows = create_sliding_windows_multi_horizon(
    merged,
    k=3,
    visit_interval=4,
    prediction_horizons=[4, 8, 12],  # 推荐的 horizons
    max_week=52
)

print(f"\n=== Target Statistics ===")
for h in [4, 8, 12]:
    valid = sliding_windows[f'BCVA_Change_{h}w'].notna().sum()
    pct_ge5 = 100 * (sliding_windows[f'BCVA_Change_{h}w'].abs() >= 5).mean()
    print(f"  {h}w: {valid} valid samples, ≥5 letters change = {pct_ge5:.1f}%")

=== Sliding Window Creation Summary ===
Total windows created: 313

Observed visits distribution:
Num_Observed_Visits
2     12
3    301
Name: count, dtype: int64

Missing pattern (BCVA):
  t0: 2.2% missing
  t1: 1.6% missing
  t2: 0.0% missing

=== Target Statistics ===
  4w: 304 valid samples, ≥5 letters change = 14.4%
  8w: 298 valid samples, ≥5 letters change = 18.8%
  12w: 290 valid samples, ≥5 letters change = 21.1%


- 由于预测窗口是3个（而不是之前的5个），样本数量从219增加到313
- 数据质量很高： 301/313 (96.2%) 的窗口有完整的3个观测点， 只有 12 个窗口是 2 个观测点
- ≥5 letters 比例下降了，从 23.7% 降到 21.1%。这可能是因为：样本量增加后，新增的样本变化较小

In [90]:
sliding_windows

Unnamed: 0,Patient_ID,Eye,Window_Start_Week,Window_End_Week,Num_Observed_Visits,Arm,Age,Gender,Ethnicity,Race,...,SHRM_t2,Target_Week_4w,BCVA_Target_4w,BCVA_Change_4w,Target_Week_8w,BCVA_Target_8w,BCVA_Change_8w,Target_Week_12w,BCVA_Target_12w,BCVA_Change_12w
0,01-001,OS,0,8,3,2,44,M,N H/L,White,...,,12,98.0,1.0,16,97.0,0.0,20,96.0,-1.0
1,01-001,OS,4,12,3,2,44,M,N H/L,White,...,,16,97.0,-1.0,20,96.0,-2.0,24,98.0,0.0
2,01-001,OS,8,16,3,2,44,M,N H/L,White,...,,20,96.0,-1.0,24,98.0,1.0,28,96.0,-1.0
3,01-001,OS,12,20,3,2,44,M,N H/L,White,...,,24,98.0,2.0,28,96.0,0.0,32,97.0,1.0
4,01-001,OS,16,24,3,2,44,M,N H/L,White,...,,28,96.0,-2.0,32,97.0,-1.0,36,97.0,-1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
308,02-046,OD,4,12,3,1,48,M,N H/L,White,...,,16,94.0,4.0,20,92.0,2.0,24,89.0,-1.0
309,02-046,OD,8,16,3,1,48,M,N H/L,White,...,,20,92.0,-2.0,24,89.0,-5.0,28,92.0,-2.0
310,02-046,OD,12,20,3,1,48,M,N H/L,White,...,,24,89.0,-3.0,28,92.0,0.0,32,,
311,02-046,OD,16,24,3,1,48,M,N H/L,White,...,,28,92.0,3.0,32,,,36,,


In [97]:
# Check missing values in sliding windows
print("=== Missing Values in Sliding Windows ===")
missing = sliding_windows.isna().sum()
missing_pct = 100 * missing / len(sliding_windows)
missing_df = pd.DataFrame({'Missing': missing, 'Pct': missing_pct.round(1)})
missing_df = missing_df[missing_df['Missing'] > 0].sort_values('Pct', ascending=False)
print(missing_df)

=== Missing Values in Sliding Windows ===
                                     Missing    Pct
Vitreous debris_t1                       313  100.0
Preretinal tissue/hemorrhage_t1          313  100.0
Fully attached vitreous face_t1          313  100.0
Partially attached vitreous face_t1      313  100.0
IR HRF_t1                                313  100.0
...                                      ...    ...
BCVA_t1                                    5    1.6
OCT_Paths_t1                               5    1.6
BCVA_Delta_Recent                          5    1.6
CST_Delta_Recent                           5    1.6
Leakage_Index_t2                           3    1.0

[72 rows x 2 columns]


In [None]:

# ============================================================
# Part 1: 分离不同类型的列
# ============================================================

# Biomarker 列（暂不用于 baseline，留给 SupCon 预训练）
biomarker_names = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                   'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                   'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                   'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                   'Disruption of RPE', 'PED (serous)', 'SHRM']

biomarker_cols = [col for col in sliding_windows.columns 
                  if any(bm in col for bm in biomarker_names)]

# 标识符列
id_cols = ['Patient_ID', 'Eye', 'Window_Start_Week', 'Window_End_Week', 'Num_Observed_Visits']

# 静态特征
static_cols = ['Arm', 'Age', 'Gender', 'Ethnicity', 'Race', 
               'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']

# Target 列
target_cols = [col for col in sliding_windows.columns 
               if 'Target_Week' in col or 'BCVA_Target' in col or 'BCVA_Change' in col]

# 时序临床特征（用于 baseline）
temporal_clinical_cols = []
for i in range(3):
    temporal_clinical_cols.extend([
        f'Week_t{i}',
        f'BCVA_t{i}', f'BCVA_missing_t{i}',
        f'CST_t{i}',
        f'Injection_t{i}',
        f'Leakage_Index_t{i}',
    ])

# 图像路径列（用于后续多模态模型）
image_path_cols = [col for col in sliding_windows.columns 
                   if 'Path' in col or 'OCT_Paths' in col or 'Num_B_scans' in col]

print("=== Column Categories ===")
print(f"ID columns: {len(id_cols)}")
print(f"Static features: {len(static_cols)}")
print(f"Temporal clinical features: {len(temporal_clinical_cols)}")
print(f"Target columns: {len(target_cols)}")
print(f"Image path columns: {len(image_path_cols)}")
print(f"Biomarker columns (not used in baseline): {len(biomarker_cols)}")

# ============================================================
# Part 2: 检查临床特征的缺失情况
# ============================================================

print("\n=== Clinical Feature Missing Pattern ===")
clinical_check_cols = ['BCVA', 'CST', 'Injection', 'Leakage_Index']
for col_base in clinical_check_cols:
    missing_str = ""
    for i in range(3):
        col = f'{col_base}_t{i}'
        if col in sliding_windows.columns:
            missing_pct = 100 * sliding_windows[col].isna().mean()
            missing_str += f"t{i}: {missing_pct:.1f}%  "
    print(f"  {col_base}: {missing_str}")

# ============================================================
# Part 3: 创建 Baseline 特征 DataFrame
# ============================================================

# Baseline 模型使用的列
baseline_feature_cols = static_cols + temporal_clinical_cols
baseline_feature_cols = [col for col in baseline_feature_cols if col in sliding_windows.columns]

# 创建 baseline 数据集
baseline_df = sliding_windows[id_cols + baseline_feature_cols + target_cols].copy()

print(f"\n=== Baseline Dataset ===")
print(f"Shape: {baseline_df.shape}")
print(f"Features: {len(baseline_feature_cols)}")

# 检查 baseline 数据集的缺失
baseline_missing = baseline_df[baseline_feature_cols].isna().sum()
baseline_missing = baseline_missing[baseline_missing > 0]
if len(baseline_missing) > 0:
    print(f"\nMissing in baseline features:")
    print(baseline_missing)
else:
    print("\nNo missing values in baseline features!")

=== Missing Values in Sliding Windows ===
                                         Missing    Pct
IR HRF_t1                                    313  100.0
Fully attached vitreous face_t1              313  100.0
Disruption of EZ_t1                          313  100.0
DRIL_t1                                      313  100.0
Atrophy / thinning of retinal layers_t1      313  100.0
...                                          ...    ...
BCVA_t1                                        5    1.6
CST_t1                                         5    1.6
OCT_Paths_t1                                   5    1.6
Fundus_Path_t1                                 5    1.6
Leakage_Index_t2                               3    1.0

[70 rows x 2 columns]
=== Column Categories ===
ID columns: 5
Static features: 9
Temporal clinical features: 18
Target columns: 9
Image path columns: 9
Biomarker columns (not used in baseline): 48

=== Clinical Feature Missing Pattern ===
  BCVA: t0: 2.2%  t1: 1.6%  t2: 0.0%  
  CST: 

## 缺失值处理见下

- Biomarkers only available at FIRST and LAST visit per patient!
- Biomarker 不应该作为时序输入特征，而是：
用于 OCT Encoder 预训练（SupCon）; 
或作为辅助任务的标签（multi-task learning）
- 暂时不处理其他缺失值， BMI只有一个病人有缺失，之后模型如xgboost可以处理缺失值


## 特征工程（添加衍生特征）

In [93]:
# ============================================================
# Part 4: 特征工程
# ============================================================

# 先检查并修复 CST 列的数据类型
print("=== Checking data types ===")
for i in range(3):
    col = f'CST_t{i}'
    if col in sliding_windows.columns:
        print(f"{col}: {sliding_windows[col].dtype}")
        # 转换为数值型
        sliding_windows[col] = pd.to_numeric(sliding_windows[col], errors='coerce')
        print(f"  → converted to: {sliding_windows[col].dtype}")

def add_engineered_features(df, k=3):
    """添加衍生特征，正确处理缺失"""
    
    # 1. BCVA/CST 斜率
    def calc_slope_safe(row, feature, k):
        points = []
        for i in range(k):
            week = row.get(f'Week_t{i}')
            val = row.get(f'{feature}_t{i}')
            # 确保都是数值
            if pd.notna(week) and pd.notna(val):
                try:
                    week = float(week)
                    val = float(val)
                    points.append((week, val))
                except (ValueError, TypeError):
                    continue
        if len(points) >= 2:
            denom = points[-1][0] - points[0][0]
            if denom != 0:
                return (points[-1][1] - points[0][1]) / denom
        return np.nan
    
    df['BCVA_Slope'] = df.apply(lambda r: calc_slope_safe(r, 'BCVA', k), axis=1)
    df['CST_Slope'] = df.apply(lambda r: calc_slope_safe(r, 'CST', k), axis=1)
    
    # 2. 当前值（最重要的特征）
    df['BCVA_Current'] = df['BCVA_t2']
    df['CST_Current'] = df['CST_t2']
    
    # 3. 短期变化 (t1→t2)
    df['BCVA_Delta_Recent'] = df['BCVA_t2'] - df['BCVA_t1']
    df['CST_Delta_Recent'] = df['CST_t2'] - df['CST_t1']
    
    # 4. 总体变化 (t0→t2)
    df['BCVA_Delta_Total'] = df['BCVA_t2'] - df['BCVA_t0']
    df['CST_Delta_Total'] = df['CST_t2'] - df['CST_t0']
    
    # 5. 窗口内注射次数
    injection_cols = [f'Injection_t{i}' for i in range(k) if f'Injection_t{i}' in df.columns]
    df['Injection_Count'] = df[injection_cols].sum(axis=1, skipna=True)
    
    # 6. 最近是否注射
    df['Injection_Recent'] = df['Injection_t2']
    
    return df

sliding_windows = add_engineered_features(sliding_windows, k=3)

# 打印结果
print("\n=== Engineered Features Added ===")
eng_cols = ['BCVA_Slope', 'CST_Slope', 
            'BCVA_Current', 'CST_Current',
            'BCVA_Delta_Recent', 'CST_Delta_Recent',
            'BCVA_Delta_Total', 'CST_Delta_Total',
            'Injection_Count', 'Injection_Recent']
print(sliding_windows[eng_cols].describe().round(2))

=== Checking data types ===
CST_t0: object
  → converted to: float64
CST_t1: object
  → converted to: float64
CST_t2: object
  → converted to: float64

=== Engineered Features Added ===
       BCVA_Slope  CST_Slope  BCVA_Current  CST_Current  BCVA_Delta_Recent  \
count      313.00     313.00        313.00       313.00             308.00   
mean         0.06      -0.31         85.60       256.24               0.19   
std          0.50       1.39          7.64        20.28               3.21   
min         -1.25      -6.62         60.00       200.00             -10.00   
25%         -0.25      -0.88         82.00       246.00              -1.00   
50%          0.00      -0.25         87.00       260.00               0.00   
75%          0.38       0.25         91.00       269.00               2.00   
max          2.00      10.00         99.00       308.00              11.00   

       CST_Delta_Recent  BCVA_Delta_Total  CST_Delta_Total  Injection_Count  \
count            308.00         

In [94]:

sliding_windows.columns.tolist()

['Patient_ID',
 'Eye',
 'Window_Start_Week',
 'Window_End_Week',
 'Num_Observed_Visits',
 'Arm',
 'Age',
 'Gender',
 'Ethnicity',
 'Race',
 'Diabetes_Type',
 'Diabetes_Years',
 'Baseline_HbA1c',
 'BMI',
 'Week_t0',
 'BCVA_missing_t0',
 'BCVA_t0',
 'CST_t0',
 'Injection_t0',
 'Leakage_Index_t0',
 'Fundus_Path_t0',
 'OCT_Paths_t0',
 'Num_B_scans_t0',
 'Atrophy / thinning of retinal layers_t0',
 'Disruption of EZ_t0',
 'DRIL_t0',
 'IR hemorrhages_t0',
 'IR HRF_t0',
 'Partially attached vitreous face_t0',
 'Fully attached vitreous face_t0',
 'Preretinal tissue/hemorrhage_t0',
 'Vitreous debris_t0',
 'VMT_t0',
 'DRT/ME_t0',
 'Fluid (IRF)_t0',
 'Fluid (SRF)_t0',
 'Disruption of RPE_t0',
 'PED (serous)_t0',
 'SHRM_t0',
 'Week_t1',
 'BCVA_missing_t1',
 'BCVA_t1',
 'CST_t1',
 'Injection_t1',
 'Leakage_Index_t1',
 'Fundus_Path_t1',
 'OCT_Paths_t1',
 'Num_B_scans_t1',
 'Atrophy / thinning of retinal layers_t1',
 'Disruption of EZ_t1',
 'DRIL_t1',
 'IR hemorrhages_t1',
 'IR HRF_t1',
 'Partially at

## remove redundant cols

In [95]:
# ============================================================
# Part 4.2: 删除冗余列
# ============================================================

# 1. 删除冗余的工程特征（完全重复）
redundant_engineered = ['BCVA_Current', 'CST_Current', 'Injection_Recent']

# 2. 删除高度相关的特征（保留 Slope，删除 Delta_Total）
redundant_correlated = ['BCVA_Delta_Total', 'CST_Delta_Total']

# 3. 删除可能有问题的静态特征
redundant_static = ['Arm', 'Ethnicity', 'Race']

# 合并所有要删除的列
cols_to_remove = redundant_engineered + redundant_correlated + redundant_static
cols_to_remove = [c for c in cols_to_remove if c in sliding_windows.columns]

# 执行删除
sliding_windows.drop(columns=cols_to_remove, inplace=True)

print("=== Removed Redundant Columns ===")
print(f"Redundant engineered: {redundant_engineered}")
print(f"Highly correlated: {redundant_correlated}")
print(f"Problematic static: {redundant_static}")
print(f"\nTotal removed: {len(cols_to_remove)}")
print(f"New shape: {sliding_windows.shape}")

# ============================================================
# 更新特征定义
# ============================================================

# 静态特征（简化版）
STATIC_COLS = ['Age', 'Gender', 'Diabetes_Type', 'Diabetes_Years', 'Baseline_HbA1c', 'BMI']

# 时序临床特征
TEMPORAL_COLS = []
for i in range(3):
    TEMPORAL_COLS.extend([
        f'BCVA_t{i}', f'CST_t{i}', f'Injection_t{i}', f'Leakage_Index_t{i}',
        f'BCVA_missing_t{i}'
    ])

# 工程特征（简化版）
ENGINEERED_COLS = ['BCVA_Slope', 'CST_Slope', 'BCVA_Delta_Recent', 'CST_Delta_Recent', 'Injection_Count']

# Baseline 特征集
BASELINE_FEATURES = STATIC_COLS + TEMPORAL_COLS + ENGINEERED_COLS
BASELINE_FEATURES = [f for f in BASELINE_FEATURES if f in sliding_windows.columns]

print(f"\n=== Final Baseline Features ({len(BASELINE_FEATURES)}) ===")
print(f"Static: {len([f for f in STATIC_COLS if f in sliding_windows.columns])}")
print(f"Temporal: {len([f for f in TEMPORAL_COLS if f in sliding_windows.columns])}")
print(f"Engineered: {len([f for f in ENGINEERED_COLS if f in sliding_windows.columns])}")


=== Removed Redundant Columns ===
Redundant engineered: ['BCVA_Current', 'CST_Current', 'Injection_Recent']
Highly correlated: ['BCVA_Delta_Total', 'CST_Delta_Total']
Problematic static: ['Arm', 'Ethnicity', 'Race']

Total removed: 8
New shape: (313, 100)

=== Final Baseline Features (26) ===
Static: 6
Temporal: 15
Engineered: 5


In [98]:
# ============================================================
# Part 4.3: 处理 Biomarker 列
# ============================================================

# Biomarker 列名
BIOMARKER_NAMES = ['Atrophy / thinning of retinal layers', 'Disruption of EZ', 'DRIL',
                   'IR hemorrhages', 'IR HRF', 'Partially attached vitreous face',
                   'Fully attached vitreous face', 'Preretinal tissue/hemorrhage',
                   'Vitreous debris', 'VMT', 'DRT/ME', 'Fluid (IRF)', 'Fluid (SRF)',
                   'Disruption of RPE', 'PED (serous)', 'SHRM']

# 找到所有 biomarker 相关列
biomarker_cols = [col for col in sliding_windows.columns 
                  if any(bm in col for bm in BIOMARKER_NAMES)]

print(f"=== Biomarker Columns ({len(biomarker_cols)}) ===")
print(f"These columns are 98-100% missing and will be removed from baseline model")
print(f"They will be saved separately for SupCon pretraining\n")

# 保存 biomarker 数据（用于后续 SupCon）
biomarker_df = sliding_windows[['Patient_ID', 'Eye', 'Window_Start_Week'] + biomarker_cols].copy()
biomarker_df.to_csv('9a_biomarker_for_supcon.csv', index=False)
print(f"✓ Saved biomarker data: 9a_biomarker_for_supcon.csv")

# 从主数据集删除 biomarker 列
sliding_windows.drop(columns=biomarker_cols, inplace=True)
print(f"✓ Removed {len(biomarker_cols)} biomarker columns from main dataset")
print(f"New shape: {sliding_windows.shape}")

# ============================================================
# 重新检查缺失值
# ============================================================

print("\n=== Remaining Missing Values ===")
missing = sliding_windows.isna().sum()
missing = missing[missing > 0].sort_values(ascending=False)
print(missing)

print("\n=== Summary ===")
print(f"Columns with missing values: {len(missing)}")
print(f"Max missing rate: {100 * missing.max() / len(sliding_windows):.1f}%")
print("\nNote: XGBoost can handle these missing values natively (no imputation needed)")


=== Biomarker Columns (48) ===
These columns are 98-100% missing and will be removed from baseline model
They will be saved separately for SupCon pretraining

✓ Saved biomarker data: 9a_biomarker_for_supcon.csv
✓ Removed 48 biomarker columns from main dataset
New shape: (313, 52)

=== Remaining Missing Values ===
BCVA_Change_12w      23
BCVA_Target_12w      23
BCVA_Change_8w       15
BCVA_Target_8w       15
Leakage_Index_t0     12
Leakage_Index_t1     10
BCVA_Target_4w        9
BMI                   9
BCVA_Change_4w        9
OCT_Paths_t0          7
Num_B_scans_t0        7
Injection_t0          7
BCVA_t0               7
CST_t0                7
Fundus_Path_t0        7
BCVA_t1               5
Injection_t1          5
Num_B_scans_t1        5
OCT_Paths_t1          5
Fundus_Path_t1        5
BCVA_Delta_Recent     5
CST_t1                5
CST_Delta_Recent      5
Leakage_Index_t2      3
dtype: int64

=== Summary ===
Columns with missing values: 24
Max missing rate: 7.3%

Note: XGBoost can handl

In [99]:
sliding_windows.columns.tolist()

['Patient_ID',
 'Eye',
 'Window_Start_Week',
 'Window_End_Week',
 'Num_Observed_Visits',
 'Age',
 'Gender',
 'Diabetes_Type',
 'Diabetes_Years',
 'Baseline_HbA1c',
 'BMI',
 'Week_t0',
 'BCVA_missing_t0',
 'BCVA_t0',
 'CST_t0',
 'Injection_t0',
 'Leakage_Index_t0',
 'Fundus_Path_t0',
 'OCT_Paths_t0',
 'Num_B_scans_t0',
 'Week_t1',
 'BCVA_missing_t1',
 'BCVA_t1',
 'CST_t1',
 'Injection_t1',
 'Leakage_Index_t1',
 'Fundus_Path_t1',
 'OCT_Paths_t1',
 'Num_B_scans_t1',
 'Week_t2',
 'BCVA_missing_t2',
 'BCVA_t2',
 'CST_t2',
 'Injection_t2',
 'Leakage_Index_t2',
 'Fundus_Path_t2',
 'OCT_Paths_t2',
 'Num_B_scans_t2',
 'Target_Week_4w',
 'BCVA_Target_4w',
 'BCVA_Change_4w',
 'Target_Week_8w',
 'BCVA_Target_8w',
 'BCVA_Change_8w',
 'Target_Week_12w',
 'BCVA_Target_12w',
 'BCVA_Change_12w',
 'BCVA_Slope',
 'CST_Slope',
 'BCVA_Delta_Recent',
 'CST_Delta_Recent',
 'Injection_Count']

## 4.1 Save Sliding Window Dataset

In [100]:
# Save sliding window dataset
sliding_windows.to_csv('10_sliding_window_dataset.csv', index=False)
print(f"✓ Saved: 10_sliding_window_dataset.csv")
print(f"  Shape: {sliding_windows.shape}")
print(f"  Samples: {len(sliding_windows)}")

✓ Saved: 10_sliding_window_dataset.csv
  Shape: (313, 52)
  Samples: 313


---
# Part 5: Create Train/Val/Test Split (Patient-Level)

In [None]:
from sklearn.model_selection import train_test_split

# Get unique patients
unique_patients = sliding_windows['Patient_ID'].unique()
print(f"Total unique patients: {len(unique_patients)}")

# Split patients (not samples!) - 70% train, 15% val, 15% test
train_patients, temp_patients = train_test_split(unique_patients, test_size=0.3, random_state=42)
val_patients, test_patients = train_test_split(temp_patients, test_size=0.5, random_state=42)

print(f"Train patients: {len(train_patients)}")
print(f"Val patients: {len(val_patients)}")
print(f"Test patients: {len(test_patients)}")

In [None]:
# Create split column
def assign_split(patient_id):
    if patient_id in train_patients:
        return 'train'
    elif patient_id in val_patients:
        return 'val'
    else:
        return 'test'

sliding_windows['Split'] = sliding_windows['Patient_ID'].apply(assign_split)

print("=== Sample Distribution by Split ===")
print(sliding_windows['Split'].value_counts())

In [None]:
# Save final dataset with split
sliding_windows.to_csv('9_sliding_window_dataset.csv', index=False)
print(f"✓ Updated: 9_sliding_window_dataset.csv (with Split column)")

---
# Summary

## Output Files
1. `8_longitudinal_data_long.csv` - All data in long format (one row per patient-visit)
2. `9_sliding_window_dataset.csv` - Sliding window samples ready for modeling

## Sliding Window Dataset Structure
- **Identifiers**: Patient_ID, Eye, Window_Start_Week, Target_Week
- **Static features**: Age, Gender, Ethnicity, Race, Diabetes_Type, Diabetes_Years, Baseline_HbA1c, BMI
- **Temporal features (t0, t1, t2)**: BCVA, CST, Injection, Leakage_Index, Biomarkers
- **Target**: BCVA_Change = BCVA(target_week) - BCVA(t2)
- **Split**: train/val/test (patient-level split)