# Loading library and data

To use the VitalDB open dataset, the pandas library is required.

In [None]:
# Install numba for performance optimization (if not already installed)
try:
    import numba
    print(f"‚úÖ Numba {numba.__version__} is already installed")
except ImportError:
    print("üì¶ Installing numba for performance optimization...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "numba"])
    print("‚úÖ Numba installed successfully!")


At first, we need to load 3 endpoints of the VitalDB open dataset. 

In [None]:
import numpy as np
import pandas as pd
from numba import jit, prange
import warnings
warnings.filterwarnings('ignore')

df_cases = pd.read_csv("https://api.vitaldb.net/cases")  # clinical information
df_trks = pd.read_csv("https://api.vitaldb.net/trks")  # track list
df_labs = pd.read_csv('https://api.vitaldb.net/labs')  # laboratory results

## Using clinical information data
Let's visually check the cases and variables of the VitalDB dataset.

In [2]:
df_cases

Unnamed: 0,caseid,subjectid,casestart,caseend,anestart,aneend,opstart,opend,adm,dis,...,intraop_colloid,intraop_ppf,intraop_mdz,intraop_ftn,intraop_rocu,intraop_vecu,intraop_eph,intraop_phe,intraop_epi,intraop_ca
0,1,5955,0,11542,-552,10848.0,1668,10368,-236220,627780,...,0,120,0.0,100,70,0,10,0,0,0
1,2,2487,0,15741,-1039,14921.0,1721,14621,-221160,1506840,...,0,150,0.0,0,100,0,20,0,0,0
2,3,2861,0,4394,-590,4210.0,1090,3010,-218640,40560,...,0,0,0.0,0,50,0,0,0,0,0
3,4,1903,0,20990,-778,20222.0,2522,17822,-201120,576480,...,0,80,0.0,100,100,0,50,0,0,0
4,5,4416,0,21531,-1009,22391.0,2591,20291,-67560,3734040,...,0,0,0.0,0,160,0,10,900,0,2100
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6383,6384,5583,0,15248,-260,15640.0,2140,14140,-215340,648660,...,0,150,0.0,0,90,0,20,0,0,0
6384,6385,2278,0,20643,-544,20996.0,2396,19496,-225600,1675200,...,0,100,0.0,0,100,0,25,30,0,300
6385,6386,4045,0,19451,-667,19133.0,3533,18233,-200460,836340,...,0,70,0.0,0,130,0,10,0,0,0
6386,6387,5230,0,12025,-550,12830.0,1730,11030,-227760,377040,...,0,120,0.0,0,50,0,0,0,0,0


In [3]:
df_cases['optype'].value_counts()

optype
Colorectal          1350
Biliary/Pancreas     812
Others               799
Stomach              676
Major resection      584
Minor resection      553
Breast               434
Transplantation      403
Vascular             262
Hepatic              258
Thyroid              257
Name: count, dtype: int64

## Using track list data

In [4]:
df_trks

Unnamed: 0,caseid,tname,tid
0,1,BIS/BIS,fd869e25ba82a66cc95b38ed47110bf4f14bb368
1,1,BIS/EEG1_WAV,0aa685df768489a18a5e9f53af0d83bf60890c73
2,1,BIS/EEG2_WAV,ad13b2c39b19193c8ae4a2de4f8315f18d61a57e
3,1,BIS/EMG,2525603efe18d982764dbca457affe7a45e766a9
4,1,BIS/SEF,1c91aec859304840dec75acf4a35da78be0e8ef0
...,...,...,...
486444,6388,Solar8000/VENT_PIP,2d63adbc7e2653f14348e219816673cde3358cf6
486445,6388,Solar8000/VENT_PPLAT,6f6604255858ddc8f6a01b9f4774b0d43105f6da
486446,6388,Solar8000/VENT_RR,f34f3fae7fd963355c1c7060e1e876d20fa87536
486447,6388,Solar8000/VENT_SET_TV,4a4a55b8aebf9c76a4a76f62a7c1ec6fcb80e618


In [5]:
print('{} track types'.format(len(df_trks['tname'].unique())))
for tname in sorted(df_trks['tname'].unique()):
    print('{}\t{}'.format(tname, (df_trks['tname'] == tname).sum() / len(df_cases) * 100))


196 track types
BIS/BIS	91.84408265497808
BIS/EEG1_WAV	91.9067000626174
BIS/EEG2_WAV	91.9067000626174
BIS/EMG	87.30432060112712
BIS/SEF	87.17908578584847
BIS/SQI	91.84408265497808
BIS/SR	87.17908578584847
BIS/TOTPOW	86.8973074514715
CardioQ/ABP	0.4539762053850971
CardioQ/CI	0.4383218534752661
CardioQ/CO	0.4539762053850971
CardioQ/FLOW	0.4539762053850971
CardioQ/FTc	0.4539762053850971
CardioQ/FTp	0.4383218534752661
CardioQ/HR	0.4539762053850971
CardioQ/MA	0.4383218534752661
CardioQ/MD	0.4539762053850971
CardioQ/PV	0.4383218534752661
CardioQ/SD	0.4539762053850971
CardioQ/SV	0.4539762053850971
CardioQ/SVI	0.4383218534752661
EV1000/ART_MBP	9.267376330619912
EV1000/CI	9.658735128365684
EV1000/CO	9.658735128365684
EV1000/CVP	3.678772698810269
EV1000/SV	9.658735128365684
EV1000/SVI	9.658735128365684
EV1000/SVR	3.991859737006888
EV1000/SVRI	3.991859737006888
EV1000/SVV	9.658735128365684
FMS/FLOW_RATE	0.234815278647464
FMS/INPUT_AMB_TEMP	0.234815278647464
FMS/INPUT_TEMP	0.234815278647464
FMS/OU

## Using laboratory results data

In [6]:
print('{} lab types'.format(len(df_labs['name'].unique())))
df_labs

34 lab types


Unnamed: 0,caseid,dt,name,result
0,1,594470,alb,2.90
1,1,399575,alb,3.20
2,1,12614,alb,3.40
3,1,137855,alb,3.60
4,1,399575,alt,12.00
...,...,...,...,...
928443,6388,3503,sao2,100.00
928444,6388,408770,wbc,3.28
928445,6388,-32848,wbc,6.27
928446,6388,-249820,wbc,7.66


## Find a case that satisfies a specific condition

In [7]:
df_t = df_cases[df_cases['optype'] == "Transplantation"]
df_t

Unnamed: 0,caseid,subjectid,casestart,caseend,anestart,aneend,opstart,opend,adm,dis,...,intraop_colloid,intraop_ppf,intraop_mdz,intraop_ftn,intraop_rocu,intraop_vecu,intraop_eph,intraop_phe,intraop_epi,intraop_ca
11,12,491,0,31203,-220,31460.0,5360,30860,-208500,1519500,...,200,100,0.0,100,70,0,20,0,0,3300
28,29,3720,0,21394,-1176,21324.0,3324,20424,-114540,576660,...,0,0,0.0,0,130,0,0,0,0,0
51,52,1724,0,15590,-1453,15647.0,3647,14747,-220140,1075860,...,35,0,0.0,0,120,0,0,0,0,0
53,54,1517,0,15346,-939,15321.0,2421,14421,-132240,299760,...,35,0,0.0,0,90,0,10,0,0,0
54,55,5077,0,21734,-722,22498.0,3598,21151,-210900,1603500,...,100,0,0.0,0,50,0,20,0,0,1200
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6336,6337,5565,0,24223,-449,25651.0,4051,23551,-1286520,1391880,...,20,0,0.0,0,3,0,0,0,0,0
6343,6344,2628,0,12168,-863,11737.0,2737,10837,-114840,317160,...,0,0,0.0,0,70,0,0,0,0,0
6345,6346,5525,0,17606,-380,17560.0,2560,16660,-225420,984180,...,35,0,0.0,0,80,0,0,0,0,300
6362,6363,5396,0,26282,-552,25788.0,2988,23988,-559200,1341600,...,800,0,0.0,0,50,0,0,0,0,0


In [8]:
caseids = df_t['caseid']
caseids

11        12
28        29
51        52
53        54
54        55
        ... 
6336    6337
6343    6344
6345    6346
6362    6363
6382    6383
Name: caseid, Length: 403, dtype: int64

# Transplantation

In [9]:
transplantation_cases = df_cases[df_cases['optype'] == 'Transplantation']
transplantation_cases

Unnamed: 0,caseid,subjectid,casestart,caseend,anestart,aneend,opstart,opend,adm,dis,...,intraop_colloid,intraop_ppf,intraop_mdz,intraop_ftn,intraop_rocu,intraop_vecu,intraop_eph,intraop_phe,intraop_epi,intraop_ca
11,12,491,0,31203,-220,31460.0,5360,30860,-208500,1519500,...,200,100,0.0,100,70,0,20,0,0,3300
28,29,3720,0,21394,-1176,21324.0,3324,20424,-114540,576660,...,0,0,0.0,0,130,0,0,0,0,0
51,52,1724,0,15590,-1453,15647.0,3647,14747,-220140,1075860,...,35,0,0.0,0,120,0,0,0,0,0
53,54,1517,0,15346,-939,15321.0,2421,14421,-132240,299760,...,35,0,0.0,0,90,0,10,0,0,0
54,55,5077,0,21734,-722,22498.0,3598,21151,-210900,1603500,...,100,0,0.0,0,50,0,20,0,0,1200
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6336,6337,5565,0,24223,-449,25651.0,4051,23551,-1286520,1391880,...,20,0,0.0,0,3,0,0,0,0,0
6343,6344,2628,0,12168,-863,11737.0,2737,10837,-114840,317160,...,0,0,0.0,0,70,0,0,0,0,0
6345,6346,5525,0,17606,-380,17560.0,2560,16660,-225420,984180,...,35,0,0.0,0,80,0,0,0,0,300
6362,6363,5396,0,26282,-552,25788.0,2988,23988,-559200,1341600,...,800,0,0.0,0,50,0,0,0,0,0


In [10]:
transplantation_caseids = transplantation_cases['caseid'].tolist()
transplantation_caseids

[12,
 29,
 52,
 54,
 55,
 58,
 60,
 81,
 83,
 97,
 111,
 146,
 164,
 177,
 195,
 202,
 236,
 237,
 251,
 264,
 280,
 284,
 290,
 304,
 345,
 349,
 363,
 378,
 391,
 397,
 401,
 406,
 431,
 448,
 457,
 464,
 470,
 507,
 524,
 553,
 626,
 631,
 638,
 675,
 690,
 691,
 706,
 733,
 734,
 741,
 775,
 783,
 785,
 814,
 818,
 822,
 847,
 870,
 872,
 902,
 945,
 964,
 985,
 986,
 1018,
 1029,
 1042,
 1045,
 1056,
 1061,
 1083,
 1095,
 1157,
 1166,
 1168,
 1173,
 1191,
 1218,
 1229,
 1231,
 1284,
 1292,
 1325,
 1327,
 1350,
 1410,
 1454,
 1482,
 1512,
 1539,
 1545,
 1548,
 1564,
 1586,
 1590,
 1656,
 1683,
 1710,
 1716,
 1720,
 1724,
 1725,
 1730,
 1738,
 1762,
 1785,
 1807,
 1809,
 1820,
 1831,
 1835,
 1858,
 1885,
 1896,
 1900,
 1959,
 1964,
 1976,
 1983,
 1995,
 2014,
 2016,
 2047,
 2069,
 2096,
 2106,
 2130,
 2137,
 2160,
 2168,
 2185,
 2192,
 2202,
 2238,
 2245,
 2252,
 2267,
 2272,
 2273,
 2304,
 2320,
 2325,
 2326,
 2327,
 2331,
 2332,
 2337,
 2360,
 2375,
 2402,
 2453,
 2461,
 2489,
 24

In [11]:
# Load track data for transplantation cases
print("üìà Loading track data for transplantation cases...")

# Get tracks available for transplantation cases
transplantation_tracks = df_trks[df_trks['caseid'].isin(transplantation_caseids)]
print(f"Found {len(transplantation_tracks)} track records for {len(transplantation_caseids)} transplantation cases")
print(f"Unique track types available: {transplantation_tracks['tname'].nunique()}")

# Show track categories for transplantation cases
print("\nüîç Track categories for transplantation cases:")
track_categories = {}
for track in transplantation_tracks['tname'].unique():
    device = track.split('/')[0] if '/' in track else 'Unknown'
    if device not in track_categories:
        track_categories[device] = []
    track_categories[device].append(track)

for device, tracks in sorted(track_categories.items()):
    print(f"\n{device} ({len(tracks)} tracks):")
    for track in sorted(tracks):
        print(f"  - {track}")

# Show most common tracks for transplantation cases
print(f"\nüìä Most common tracks for transplantation cases:")
track_counts = transplantation_tracks['tname'].value_counts().head(15)
for track, count in track_counts.items():
    percentage = (count / len(transplantation_caseids)) * 100
    print(f"  - {track}: {count} cases ({percentage:.1f}%)")


üìà Loading track data for transplantation cases...
Found 35698 track records for 403 transplantation cases
Unique track types available: 183

üîç Track categories for transplantation cases:

BIS (8 tracks):
  - BIS/BIS
  - BIS/EEG1_WAV
  - BIS/EEG2_WAV
  - BIS/EMG
  - BIS/SEF
  - BIS/SQI
  - BIS/SR
  - BIS/TOTPOW

CardioQ (13 tracks):
  - CardioQ/ABP
  - CardioQ/CI
  - CardioQ/CO
  - CardioQ/FLOW
  - CardioQ/FTc
  - CardioQ/FTp
  - CardioQ/HR
  - CardioQ/MA
  - CardioQ/MD
  - CardioQ/PV
  - CardioQ/SD
  - CardioQ/SV
  - CardioQ/SVI

EV1000 (9 tracks):
  - EV1000/ART_MBP
  - EV1000/CI
  - EV1000/CO
  - EV1000/CVP
  - EV1000/SV
  - EV1000/SVI
  - EV1000/SVR
  - EV1000/SVRI
  - EV1000/SVV

FMS (7 tracks):
  - FMS/FLOW_RATE
  - FMS/INPUT_AMB_TEMP
  - FMS/INPUT_TEMP
  - FMS/OUTPUT_AMB_TEMP
  - FMS/OUTPUT_TEMP
  - FMS/PRESSURE
  - FMS/TOTAL_VOL

Invos (2 tracks):
  - Invos/SCO2_L
  - Invos/SCO2_R

Orchestra (39 tracks):
  - Orchestra/DEX4_RATE
  - Orchestra/DEX4_VOL
  - Orchestra/DOBU_RAT

In [12]:
# Load lab data for transplantation cases
print("üß™ Loading lab data for transplantation cases...")

# Get lab results for transplantation cases
transplantation_labs = df_labs[df_labs['caseid'].isin(transplantation_caseids)]
print(f"Found {len(transplantation_labs)} lab records for transplantation cases")
print(f"Unique lab types available: {transplantation_labs['name'].nunique()}")

# Show lab types and their frequency for transplantation cases
print(f"\nüî¨ Lab types for transplantation cases:")
lab_counts = transplantation_labs['name'].value_counts()
for lab, count in lab_counts.items():
    percentage = (count / len(transplantation_caseids)) * 100
    print(f"  - {lab}: {count} records ({percentage:.1f}% of cases)")

# Show key organ function tests
print(f"\nü´Ä Key organ function tests for transplantation cases:")
organ_function_tests = ['ast', 'alt', 'tbil', 'alb', 'tprot', 'ptsec', 'ptinr', 'aptt', 'fib']
for test in organ_function_tests:
    if test in transplantation_labs['name'].values:
        count = (transplantation_labs['name'] == test).sum()
        percentage = (count / len(transplantation_caseids)) * 100
        print(f"  - {test}: {count} records ({percentage:.1f}% of cases)")
    else:
        print(f"  - {test}: Not available")

# Show basic metabolic panel
print(f"\nüß™ Basic metabolic panel for transplantation cases:")
metabolic_tests = ['na', 'k', 'cl', 'cr', 'bun', 'gfr', 'gluc']
for test in metabolic_tests:
    if test in transplantation_labs['name'].values:
        count = (transplantation_labs['name'] == test).sum()
        percentage = (count / len(transplantation_caseids)) * 100
        print(f"  - {test}: {count} records ({percentage:.1f}% of cases)")
    else:
        print(f"  - {test}: Not available")


üß™ Loading lab data for transplantation cases...
Found 193144 lab records for transplantation cases
Unique lab types available: 34

üî¨ Lab types for transplantation cases:
  - hct: 11573 records (2871.7% of cases)
  - k: 10706 records (2656.6% of cases)
  - na: 10703 records (2655.8% of cases)
  - hb: 8769 records (2175.9% of cases)
  - wbc: 8753 records (2172.0% of cases)
  - plt: 8665 records (2150.1% of cases)
  - gluc: 8398 records (2083.9% of cases)
  - cl: 7872 records (1953.3% of cases)
  - alb: 7841 records (1945.7% of cases)
  - cr: 7679 records (1905.5% of cases)
  - bun: 7661 records (1901.0% of cases)
  - tprot: 7654 records (1899.3% of cases)
  - tbil: 7559 records (1875.7% of cases)
  - ast: 7523 records (1866.7% of cases)
  - alt: 7521 records (1866.3% of cases)
  - gfr: 7369 records (1828.5% of cases)
  - ptinr: 5580 records (1384.6% of cases)
  - ptsec: 5580 records (1384.6% of cases)
  - pt%: 5579 records (1384.4% of cases)
  - aptt: 5173 records (1283.6% of cases

In [14]:
!pip install vitaldb

Collecting vitaldb
  Downloading vitaldb-1.5.8-py3-none-any.whl.metadata (314 bytes)
Collecting requests (from vitaldb)
  Using cached requests-2.32.5-py3-none-any.whl.metadata (4.9 kB)
Collecting wfdb (from vitaldb)
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting charset_normalizer<4,>=2 (from requests->vitaldb)
  Using cached charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl.metadata (38 kB)
Collecting idna<4,>=2.5 (from requests->vitaldb)
  Using cached idna-3.11-py3-none-any.whl.metadata (8.4 kB)
Collecting urllib3<3,>=1.21.1 (from requests->vitaldb)
  Downloading urllib3-2.6.2-py3-none-any.whl.metadata (6.6 kB)
Collecting certifi>=2017.4.17 (from requests->vitaldb)
  Using cached certifi-2025.11.12-py3-none-any.whl.metadata (2.5 kB)
Collecting aiohttp>=3.10.11 (from wfdb->vitaldb)
  Using cached aiohttp-3.13.2-cp314-cp314-win_amd64.whl.metadata (8.4 kB)
Collecting fsspec>=2023.10.0 (from wfdb->vitaldb)
  Downloading fsspec-2025.12.0-py3-none-any.whl.metadat

In [None]:
# Load waveform data for transplantation cases using vitaldb
import vitaldb

print("üîÑ Loading waveform data for transplantation cases...")

# Get available track names for transplantation cases
available_tracks = transplantation_tracks['tname'].unique().tolist()
print(f"Available tracks: {len(available_tracks)} types")

# Select key tracks for transplantation surgery monitoring
key_tracks = [
    'Solar8000/HR',           # Heart rate
    'Solar8000/ART_SBP',      # Systolic blood pressure  
    'Solar8000/ART_DBP',      # Diastolic blood pressure
    'Solar8000/ART_MBP',      # Mean blood pressure
    'Solar8000/PLETH_SPO2',   # Oxygen saturation
    'SNUADC/ART',             # Arterial pressure waveform
    'SNUADC/ECG_II',          # ECG Lead II
    'Primus/CO2',             # CO2 monitoring
    'Primus/ETCO2',           # End-tidal CO2
    'BIS/BIS'                 # Bispectral index
]

# Filter to only tracks that are available for transplantation cases
available_key_tracks = [track for track in key_tracks if track in available_tracks]
print(f"Key tracks available: {len(available_key_tracks)}")

if len(available_key_tracks) > 0:
    print("Available key tracks:")
    for track in available_key_tracks:
        print(f"  - {track}")
        
    # Load waveform data for first transplantation case as example
    if len(transplantation_caseids) > 0:
        example_case_id = transplantation_caseids[0]
        print(f"\nüìä Loading waveform data for transplantation case {example_case_id}...")
        
        try:
            # Load waveform data
            waveform_data = vitaldb.load_case(
                caseid=example_case_id,
                track_names=available_key_tracks,
                interval=1.0  # 1Hz sampling rate (consistent with training data)
            )
            
            print(f"‚úÖ Successfully loaded waveform data!")
            print(f"  - Data shape: {waveform_data.shape}")
            print(f"  - Time points: {waveform_data.shape[0]}")
            print(f"  - Tracks: {waveform_data.shape[1]}")
            
            # Show data availability for each track
            print("\nüìà Data availability for each track:")
            for i, track_name in enumerate(available_key_tracks):
                track_data = waveform_data[:, i]
                non_nan_count = np.sum(~np.isnan(track_data))
                total_count = len(track_data)
                availability = (non_nan_count / total_count) * 100
                print(f"  - {track_name}: {non_nan_count}/{total_count} ({availability:.1f}%)")
                
        except Exception as e:
            print(f"‚ùå Error loading waveform data: {e}")
else:
    print("‚ùå No key tracks available for transplantation cases")


üîÑ Loading waveform data for transplantation cases...
Available tracks: 183 types
Key tracks available: 10
Available key tracks:
  - Solar8000/HR
  - Solar8000/ART_SBP
  - Solar8000/ART_DBP
  - Solar8000/ART_MBP
  - Solar8000/PLETH_SPO2
  - SNUADC/ART
  - SNUADC/ECG_II
  - Primus/CO2
  - Primus/ETCO2
  - BIS/BIS

üìä Loading waveform data for transplantation case 12...
‚úÖ Successfully loaded waveform data!
  - Data shape: (3120291, 10)
  - Time points: 3120291
  - Tracks: 10

üìà Data availability for each track:
  - Solar8000/HR: 15482/3120291 (0.5%)
  - Solar8000/ART_SBP: 13858/3120291 (0.4%)
  - Solar8000/ART_DBP: 13861/3120291 (0.4%)
  - Solar8000/ART_MBP: 14346/3120291 (0.5%)
  - Solar8000/PLETH_SPO2: 15473/3120291 (0.5%)
  - SNUADC/ART: 2255264/3120291 (72.3%)
  - SNUADC/ECG_II: 2255264/3120291 (72.3%)
  - Primus/CO2: 3117940/3120291 (99.9%)
  - Primus/ETCO2: 4313/3120291 (0.1%)
  - BIS/BIS: 31184/3120291 (1.0%)


In [16]:
# Create a comprehensive data summary for transplantation cases
print("üìã TRANSPLANTATION CASES DATA SUMMARY")
print("=" * 50)

print(f"\nüè• CASES:")
print(f"  - Total transplantation cases: {len(transplantation_cases)}")
print(f"  - Case IDs: {transplantation_caseids[:10]}..." if len(transplantation_caseids) > 10 else f"  - Case IDs: {transplantation_caseids}")

print(f"\nüìà WAVEFORMS:")
print(f"  - Total track records: {len(transplantation_tracks)}")
print(f"  - Unique track types: {transplantation_tracks['tname'].nunique()}")
print(f"  - Track categories: {len(track_categories)}")

print(f"\nüß™ LAB RESULTS:")
print(f"  - Total lab records: {len(transplantation_labs)}")
print(f"  - Unique lab types: {transplantation_labs['name'].nunique()}")

print(f"\nüéØ KEY INSIGHTS FOR TRANSPLANTATION SURGERY:")
print("  - Cardiovascular monitoring: Heart rate, blood pressure, ECG")
print("  - Respiratory monitoring: CO2, oxygen saturation, ventilation")
print("  - Neurological monitoring: BIS index, EEG")
print("  - Organ function tests: AST, ALT, bilirubin, albumin, PT/INR")
print("  - Metabolic panel: Electrolytes, creatinine, glucose")
print("  - Coagulation studies: PT, PTT, fibrinogen")

print(f"\n‚úÖ Ready for real-time transplantation vital forecasting and anomaly detection!")
print(f"   - {len(transplantation_caseids)} cases available")
print(f"   - {transplantation_tracks['tname'].nunique()} waveform types")
print(f"   - {transplantation_labs['name'].nunique()} lab result types")


üìã TRANSPLANTATION CASES DATA SUMMARY

üè• CASES:
  - Total transplantation cases: 403
  - Case IDs: [12, 29, 52, 54, 55, 58, 60, 81, 83, 97]...

üìà WAVEFORMS:
  - Total track records: 35698
  - Unique track types: 183
  - Track categories: 11

üß™ LAB RESULTS:
  - Total lab records: 193144
  - Unique lab types: 34

üéØ KEY INSIGHTS FOR TRANSPLANTATION SURGERY:
  - Cardiovascular monitoring: Heart rate, blood pressure, ECG
  - Respiratory monitoring: CO2, oxygen saturation, ventilation
  - Neurological monitoring: BIS index, EEG
  - Organ function tests: AST, ALT, bilirubin, albumin, PT/INR
  - Metabolic panel: Electrolytes, creatinine, glucose
  - Coagulation studies: PT, PTT, fibrinogen

‚úÖ Ready for real-time transplantation vital forecasting and anomaly detection!
   - 403 cases available
   - 183 waveform types
   - 34 lab result types


## Data Quality Analysis for BP Forecasting

Analyze track availability and data completeness to identify transplantation cases suitable for blood pressure forecasting.


In [17]:
# Data Quality Analysis for BP Forecasting
print("üîç Analyzing data quality for BP forecasting...")

# Define BP-related tracks and feature tracks
bp_tracks = {
    'target': ['Solar8000/ART_MBP'],  # Mean Blood Pressure (target)
    'bp_related': ['Solar8000/ART_SBP', 'Solar8000/ART_DBP', 'SNUADC/ART'],
    'features': ['Solar8000/HR', 'Solar8000/PLETH_SPO2', 'Primus/ETCO2', 'BIS/BIS']
}

# Analyze track availability for each transplantation case
case_quality = []

for caseid in transplantation_caseids:
    case_tracks = transplantation_tracks[transplantation_tracks['caseid'] == caseid]['tname'].tolist()
    
    quality_metrics = {
        'caseid': caseid,
        'has_mbp': 'Solar8000/ART_MBP' in case_tracks,
        'has_sbp': 'Solar8000/ART_SBP' in case_tracks,
        'has_dbp': 'Solar8000/ART_DBP' in case_tracks,
        'has_hr': 'Solar8000/HR' in case_tracks,
        'has_spo2': 'Solar8000/PLETH_SPO2' in case_tracks,
        'has_etco2': 'Primus/ETCO2' in case_tracks,
        'has_bis': 'BIS/BIS' in case_tracks,
        'available_tracks': len(case_tracks)
    }
    
    # Count available feature tracks
    feature_count = sum([
        quality_metrics['has_hr'],
        quality_metrics['has_spo2'],
        quality_metrics['has_etco2'],
        quality_metrics['has_bis']
    ])
    quality_metrics['feature_count'] = feature_count
    
    case_quality.append(quality_metrics)

quality_df = pd.DataFrame(case_quality)

print(f"\nüìä Track Availability Summary:")
print(f"  - Cases with MBP track: {quality_df['has_mbp'].sum()} / {len(quality_df)} ({quality_df['has_mbp'].sum()/len(quality_df)*100:.1f}%)")
print(f"  - Cases with HR track: {quality_df['has_hr'].sum()} / {len(quality_df)} ({quality_df['has_hr'].sum()/len(quality_df)*100:.1f}%)")
print(f"  - Cases with SpO2 track: {quality_df['has_spo2'].sum()} / {len(quality_df)} ({quality_df['has_spo2'].sum()/len(quality_df)*100:.1f}%)")
print(f"  - Cases with ETCO2 track: {quality_df['has_etco2'].sum()} / {len(quality_df)} ({quality_df['has_etco2'].sum()/len(quality_df)*100:.1f}%)")
print(f"  - Cases with BIS track: {quality_df['has_bis'].sum()} / {len(quality_df)} ({quality_df['has_bis'].sum()/len(quality_df)*100:.1f}%)")

print(f"\nüìà Feature Availability:")
print(f"  - Cases with 0 features: {(quality_df['feature_count'] == 0).sum()}")
print(f"  - Cases with 1-2 features: {((quality_df['feature_count'] >= 1) & (quality_df['feature_count'] <= 2)).sum()}")
print(f"  - Cases with 3-4 features: {((quality_df['feature_count'] >= 3) & (quality_df['feature_count'] <= 4)).sum()}")

quality_df.head(10)


üîç Analyzing data quality for BP forecasting...

üìä Track Availability Summary:
  - Cases with MBP track: 316 / 403 (78.4%)
  - Cases with HR track: 403 / 403 (100.0%)
  - Cases with SpO2 track: 403 / 403 (100.0%)
  - Cases with ETCO2 track: 402 / 403 (99.8%)
  - Cases with BIS track: 387 / 403 (96.0%)

üìà Feature Availability:
  - Cases with 0 features: 0
  - Cases with 1-2 features: 0
  - Cases with 3-4 features: 403


Unnamed: 0,caseid,has_mbp,has_sbp,has_dbp,has_hr,has_spo2,has_etco2,has_bis,available_tracks,feature_count
0,12,True,True,True,True,True,True,True,96,4
1,29,True,True,True,True,True,True,True,89,4
2,52,True,True,True,True,True,True,True,94,4
3,54,False,False,False,True,True,True,True,76,4
4,55,True,True,True,True,True,True,True,113,4
5,58,True,True,True,True,True,True,True,84,4
6,60,True,True,True,True,True,True,True,98,4
7,81,False,False,False,True,True,True,True,78,4
8,83,True,True,True,True,True,True,True,93,4
9,97,True,True,True,True,True,True,True,96,4


In [None]:
# Numba-optimized function for calculating data completeness
@jit(nopython=True)
def calculate_completeness_numba(data_array):
    """
    Calculate completeness percentage for a data array.
    Optimized with numba for faster computation.
    """
    total = len(data_array)
    non_nan = 0
    for i in prange(total):
        if not np.isnan(data_array[i]):
            non_nan += 1
    return (non_nan / total * 100.0) if total > 0 else 0.0

# Load sample data to check data completeness
print("üìä Checking data completeness for cases with MBP track...")

mbp_cases = quality_df[quality_df['has_mbp'] == True]['caseid'].tolist()
print(f"Found {len(mbp_cases)} cases with MBP track")

# Check data completeness for a sample of cases
completeness_results = []

# Sample first 20 cases to check completeness (to avoid long runtime)
sample_cases = mbp_cases[:20] if len(mbp_cases) > 20 else mbp_cases

for caseid in sample_cases:
    try:
        # Load MBP data at 1Hz (1 second intervals)
        mbp_data = vitaldb.load_case(
            caseid=caseid,
            track_names=['Solar8000/ART_MBP'],
            interval=1.0  # 1Hz sampling
        )
        
        if mbp_data is not None and mbp_data.shape[0] > 0:
            mbp_values = mbp_data[:, 0].astype(np.float64)
            # Use numba-optimized function
            completeness = calculate_completeness_numba(mbp_values)
            total_count = len(mbp_values)
            non_nan_count = int(total_count * completeness / 100.0)
            duration_minutes = total_count / 60.0  # Convert seconds to minutes
            
            completeness_results.append({
                'caseid': caseid,
                'completeness': completeness,
                'duration_minutes': duration_minutes,
                'non_nan_count': non_nan_count,
                'total_count': total_count
            })
    except Exception as e:
        print(f"  ‚ö†Ô∏è Error loading case {caseid}: {e}")
        continue

if completeness_results:
    completeness_df = pd.DataFrame(completeness_results)
    print(f"\n‚úÖ Data completeness analysis for {len(completeness_df)} cases:")
    print(f"  - Mean completeness: {completeness_df['completeness'].mean():.1f}%")
    print(f"  - Mean duration: {completeness_df['duration_minutes'].mean():.1f} minutes")
    print(f"  - Cases with >20% completeness: {(completeness_df['completeness'] > 20).sum()}")
    print(f"  - Cases with >2 minutes: {(completeness_df['duration_minutes'] > 2).sum()}")
    
    # Filter cases with good data quality (relaxed criteria)
    # Lower threshold: >20% completeness and >2 minutes duration (sufficient for 90s minimum)
    good_cases = completeness_df[
        (completeness_df['completeness'] > 20) & 
        (completeness_df['duration_minutes'] > 2)
    ]['caseid'].tolist()
    
    print(f"\n‚úÖ Cases suitable for forecasting: {len(good_cases)}")
    completeness_df.head(10)
else:
    print("‚ùå No completeness data available")
    good_cases = []


üìä Checking data completeness for cases with MBP track...
Found 316 cases with MBP track

‚úÖ Data completeness analysis for 20 cases:
  - Mean completeness: 49.6%
  - Mean duration: 356.1 minutes
  - Cases with >80% completeness: 0
  - Cases with >30 minutes: 20

‚úÖ Cases suitable for forecasting: 0


In [19]:
# Filter transplantation cases for forecasting
# Use cases with MBP track and at least 2 feature tracks
filtered_cases = quality_df[
    (quality_df['has_mbp'] == True) & 
    (quality_df['feature_count'] >= 2)
]['caseid'].tolist()

print(f"üìã Filtered cases for BP forecasting:")
print(f"  - Total transplantation cases: {len(transplantation_caseids)}")
print(f"  - Cases with MBP track: {quality_df['has_mbp'].sum()}")
print(f"  - Cases with MBP + 2+ features: {len(filtered_cases)}")
print(f"  - Selected cases: {len(filtered_cases)}")

# If we have completeness data, further filter
if 'good_cases' in locals() and len(good_cases) > 0:
    # Intersect with good quality cases
    filtered_cases = [c for c in filtered_cases if c in good_cases]
    print(f"  - After quality filtering: {len(filtered_cases)} cases")

print(f"\n‚úÖ Final filtered case IDs: {filtered_cases[:10]}..." if len(filtered_cases) > 10 else f"\n‚úÖ Final filtered case IDs: {filtered_cases}")


üìã Filtered cases for BP forecasting:
  - Total transplantation cases: 403
  - Cases with MBP track: 316
  - Cases with MBP + 2+ features: 316
  - Selected cases: 316

‚úÖ Final filtered case IDs: [12, 29, 52, 55, 58, 60, 83, 97, 111, 146]...


## Load Waveform Data and Export to CSV

Load waveform data for filtered transplantation cases and export to CSV for BP forecasting.


In [20]:
# Batch load waveform data for filtered cases
print("üì• Loading waveform data for filtered transplantation cases...")
print(f"Processing {len(filtered_cases)} cases...")

# Define tracks to load (MBP as target, others as features)
tracks_to_load = {
    'mbp': 'Solar8000/ART_MBP',
    'hr': 'Solar8000/HR',
    'spo2': 'Solar8000/PLETH_SPO2',
    'etco2': 'Primus/ETCO2',
    'bis': 'BIS/BIS'
}

# Get available tracks for each case
all_data = []

for idx, caseid in enumerate(filtered_cases):
    if (idx + 1) % 10 == 0:
        print(f"  Processing case {idx + 1}/{len(filtered_cases)}...")
    
    try:
        # Get available tracks for this case
        case_tracks = transplantation_tracks[transplantation_tracks['caseid'] == caseid]['tname'].tolist()
        
        # Build list of tracks to load (only if available)
        tracks_list = []
        track_mapping = {}  # Map track name to column index
        
        for key, track_name in tracks_to_load.items():
            if track_name in case_tracks:
                tracks_list.append(track_name)
                track_mapping[track_name] = key
        
        if len(tracks_list) == 0:
            continue
        
        # Load waveform data at 1Hz (1 second intervals)
        waveform_data = vitaldb.load_case(
            caseid=caseid,
            track_names=tracks_list,
            interval=1.0  # 1Hz sampling rate
        )
        
        if waveform_data is None or waveform_data.shape[0] == 0:
            continue
        
        # Create DataFrame for this case
        case_df = pd.DataFrame()
        case_df['caseid'] = caseid
        case_df['timestamp'] = np.arange(len(waveform_data))  # Time in seconds
        
        # Add each track as a column
        for i, track_name in enumerate(tracks_list):
            column_name = track_mapping[track_name]
            case_df[column_name] = waveform_data[:, i]
        
        all_data.append(case_df)
        
    except Exception as e:
        print(f"  ‚ö†Ô∏è Error loading case {caseid}: {e}")
        continue

print(f"\n‚úÖ Successfully loaded data from {len(all_data)} cases")

# Combine all cases into single DataFrame
if len(all_data) > 0:
    combined_data = pd.concat(all_data, ignore_index=True)
    print(f"\nüìä Combined dataset:")
    print(f"  - Total rows: {len(combined_data)}")
    print(f"  - Columns: {list(combined_data.columns)}")
    print(f"  - Memory usage: {combined_data.memory_usage(deep=True).sum() / 1024**2:.2f} MB")
    
    # Show data availability
    print(f"\nüìà Data availability:")
    for col in ['mbp', 'hr', 'spo2', 'etco2', 'bis']:
        if col in combined_data.columns:
            non_nan = combined_data[col].notna().sum()
            total = len(combined_data)
            print(f"  - {col.upper()}: {non_nan}/{total} ({non_nan/total*100:.1f}%)")
    
    combined_data.head()
else:
    print("‚ùå No data loaded")
    combined_data = pd.DataFrame()


üì• Loading waveform data for filtered transplantation cases...
Processing 316 cases...
  Processing case 10/316...
  Processing case 20/316...
  Processing case 30/316...
  Processing case 40/316...
  Processing case 50/316...
  Processing case 60/316...
  Processing case 70/316...
  Processing case 80/316...
  Processing case 90/316...
  Processing case 100/316...
  Processing case 110/316...
  Processing case 120/316...


KeyboardInterrupt: 

In [None]:
# Numba-optimized forward/backward fill
@jit(nopython=True)
def forward_backward_fill_numba(data_array):
    """
    Forward fill then backward fill for a 1D array.
    Optimized with numba for faster computation.
    """
    n = len(data_array)
    result = data_array.copy()
    
    # Forward fill
    last_valid = np.nan
    for i in range(n):
        if not np.isnan(result[i]):
            last_valid = result[i]
        else:
            result[i] = last_valid
    
    # Backward fill
    last_valid = np.nan
    for i in range(n - 1, -1, -1):
        if not np.isnan(result[i]):
            last_valid = result[i]
        else:
            result[i] = last_valid
    
    return result

# Handle missing data with forward fill and interpolation
print("üîß Processing missing data...")

if len(combined_data) > 0:
    # Forward fill missing values (within each case)
    processed_data = combined_data.copy()
    
    # Group by caseid and forward fill
    for caseid in processed_data['caseid'].unique():
        case_mask = processed_data['caseid'] == caseid
        case_data = processed_data.loc[case_mask].copy()
        
        # Forward fill, then backward fill for each column using numba
        for col in ['mbp', 'hr', 'spo2', 'etco2', 'bis']:
            if col in case_data.columns:
                # Convert to numpy array for numba processing
                col_array = case_data[col].values.astype(np.float64)
                filled_array = forward_backward_fill_numba(col_array)
                case_data[col] = filled_array
                processed_data.loc[case_mask, col] = filled_array
    
    # Remove rows where MBP is still NaN (critical for target variable)
    initial_rows = len(processed_data)
    processed_data = processed_data[processed_data['mbp'].notna()]
    removed_rows = initial_rows - len(processed_data)
    
    print(f"  - Removed {removed_rows} rows with missing MBP")
    print(f"  - Remaining rows: {len(processed_data)}")
    
    # Show final data availability
    print(f"\nüìà Final data availability after processing:")
    for col in ['mbp', 'hr', 'spo2', 'etco2', 'bis']:
        if col in processed_data.columns:
            non_nan = processed_data[col].notna().sum()
            total = len(processed_data)
            print(f"  - {col.upper()}: {non_nan}/{total} ({non_nan/total*100:.1f}%)")
    
    processed_data.head(10)
else:
    processed_data = pd.DataFrame()


In [None]:
# Export to CSV
import os

if len(processed_data) > 0:
    output_file = 'transplantation_bp_data.csv'
    processed_data.to_csv(output_file, index=False)
    print(f"‚úÖ Data exported to {output_file}")
    print(f"  - Total rows: {len(processed_data)}")
    print(f"  - Unique cases: {processed_data['caseid'].nunique()}")
    if os.path.exists(output_file):
        print(f"  - File size: {os.path.getsize(output_file) / 1024**2:.2f} MB")
    
    # Create metadata summary
    metadata = {
        'total_cases': processed_data['caseid'].nunique(),
        'total_rows': len(processed_data),
        'sampling_rate_hz': 1.0,
        'tracks_included': list(processed_data.columns),
        'case_ids': processed_data['caseid'].unique().tolist(),
        'date_created': pd.Timestamp.now().isoformat()
    }
    
    print(f"\nüìã Metadata:")
    print(f"  - Cases: {metadata['total_cases']}")
    print(f"  - Sampling rate: {metadata['sampling_rate_hz']} Hz")
    print(f"  - Columns: {', '.join(metadata['tracks_included'])}")
else:
    print("‚ùå No data to export")


## Time Series Preprocessing for LSTM

Prepare data for LSTM forecasting: normalization, sliding windows, and train/val/test split.


In [None]:
# Numba-optimized normalization helper
@jit(nopython=True)
def normalize_array_numba(data_array):
    """
    Normalize array using z-score (mean=0, std=1).
    Returns normalized array, mean, and std.
    """
    n = len(data_array)
    mean = 0.0
    std = 0.0
    
    # Calculate mean
    valid_count = 0
    for i in range(n):
        if not np.isnan(data_array[i]):
            mean += data_array[i]
            valid_count += 1
    if valid_count > 0:
        mean = mean / valid_count
    
    # Calculate std
    for i in range(n):
        if not np.isnan(data_array[i]):
            diff = data_array[i] - mean
            std += diff * diff
    if valid_count > 0:
        std = np.sqrt(std / valid_count)
    
    # Normalize
    normalized = np.zeros(n, dtype=np.float64)
    for i in range(n):
        if not np.isnan(data_array[i]) and std > 1e-8:
            normalized[i] = (data_array[i] - mean) / std
        else:
            normalized[i] = data_array[i]
    
    return normalized, mean, std

# Load the exported data
import os
from sklearn.preprocessing import StandardScaler, MinMaxScaler

if os.path.exists('transplantation_bp_data.csv'):
    data = pd.read_csv('transplantation_bp_data.csv')
    print(f"‚úÖ Loaded data: {len(data)} rows, {data['caseid'].nunique()} cases")
else:
    print("‚ùå CSV file not found. Please run previous cells first.")
    data = pd.DataFrame()

if len(data) > 0:
    print(f"\nüìä Data summary:")
    print(f"  - Cases: {data['caseid'].nunique()}")
    print(f"  - Columns: {list(data.columns)}")
    print(f"  - Date range per case: {data.groupby('caseid')['timestamp'].agg(['min', 'max', 'count']).head()}")


In [None]:
# Preprocessing parameters
SEQ_LENGTH = 60  # 1 minute at 1Hz (input sequence) - 60 past values
FORECAST_HORIZON = 30  # 30 seconds at 1Hz (forecast horizon) - forecast 30s ahead
FEATURE_COLS = ['mbp', 'hr', 'spo2', 'etco2', 'bis']  # Features to use

print(f"‚öôÔ∏è Preprocessing configuration:")
print(f"  - Input sequence length: {SEQ_LENGTH} seconds ({SEQ_LENGTH/60:.1f} minutes)")
print(f"  - Forecast horizon: {FORECAST_HORIZON} seconds ({FORECAST_HORIZON/60:.1f} minutes)")
print(f"  - Features: {FEATURE_COLS}")

# Filter to only cases with sufficient data
if len(data) > 0:
    # Calculate minimum required length per case
    min_length = SEQ_LENGTH + FORECAST_HORIZON
    
    case_lengths = data.groupby('caseid').size()
    valid_cases = case_lengths[case_lengths >= min_length].index.tolist()
    
    print(f"\nüìã Case filtering:")
    print(f"  - Total cases: {data['caseid'].nunique()}")
    print(f"  - Cases with sufficient data (>{min_length} seconds): {len(valid_cases)}")
    
    # Filter data to valid cases
    data_filtered = data[data['caseid'].isin(valid_cases)].copy()
    print(f"  - Filtered data rows: {len(data_filtered)}")
    
    # Separate features and target
    available_features = [col for col in FEATURE_COLS if col in data_filtered.columns]
    print(f"  - Available features: {available_features}")
    
    # Normalize features (per case to avoid data leakage)
    # Use numba-optimized normalization for better performance
    scalers = {}
    normalized_data = data_filtered.copy()
    
    for caseid in valid_cases:
        case_mask = normalized_data['caseid'] == caseid
        case_data = normalized_data.loc[case_mask, available_features].copy()
        
        # Use numba-optimized normalization for each feature
        scaled_values = case_data.values.copy().astype(np.float64)
        n_features = scaled_values.shape[1]
        
        for feat_idx in range(n_features):
            feat_array = scaled_values[:, feat_idx]
            normalized_feat, mean_val, std_val = normalize_array_numba(feat_array)
            scaled_values[:, feat_idx] = normalized_feat
        
        normalized_data.loc[case_mask, available_features] = scaled_values
        
        # Store scaler info (for denormalization later)
        scaler = StandardScaler()
        scaler.fit(case_data.values)
        scalers[caseid] = scaler
    
    print(f"‚úÖ Normalized data for {len(valid_cases)} cases")
    normalized_data.head()
else:
    normalized_data = pd.DataFrame()
    valid_cases = []


In [None]:
# Numba-optimized sequence creation helper
@jit(nopython=True)
def create_sequences_numba(feature_data, target_data, seq_length, forecast_horizon):
    """
    Create sequences from numpy arrays (numba-optimized).
    feature_data: (n_samples, n_features) array
    target_data: (n_samples,) array
    """
    n_samples = len(feature_data)
    n_features = feature_data.shape[1]
    min_length = seq_length + forecast_horizon
    
    if n_samples < min_length:
        return None, None
    
    max_sequences = n_samples - min_length + 1
    sequences = np.zeros((max_sequences, seq_length, n_features), dtype=np.float64)
    targets = np.zeros(max_sequences, dtype=np.float64)
    valid_count = 0
    
    for i in range(n_samples - min_length + 1):
        # Extract sequence
        for j in range(seq_length):
            for k in range(n_features):
                sequences[valid_count, j, k] = feature_data[i + j, k]
        
        # Extract target
        target_idx = i + seq_length + forecast_horizon - 1
        targets[valid_count] = target_data[target_idx]
        valid_count += 1
    
    return sequences[:valid_count], targets[:valid_count]

# Create sliding windows for LSTM (optimized version)
def create_sequences(data, caseid_col, feature_cols, target_col, seq_length, forecast_horizon):
    """
    Create input sequences and target values for time series forecasting.
    Uses numba-optimized helper for faster processing.
    """
    sequences = []
    targets = []
    case_ids = []
    
    for caseid in data[caseid_col].unique():
        case_data = data[data[caseid_col] == caseid].sort_values('timestamp').reset_index(drop=True)
        
        # Need at least seq_length + forecast_horizon data points
        if len(case_data) < seq_length + forecast_horizon:
            continue
        
        # Convert to numpy arrays for numba processing
        feature_array = case_data[feature_cols].values.astype(np.float64)
        target_array = case_data[target_col].values.astype(np.float64)
        
        # Use numba-optimized function
        case_sequences, case_targets = create_sequences_numba(
            feature_array, target_array, seq_length, forecast_horizon
        )
        
        if case_sequences is not None and len(case_sequences) > 0:
            sequences.append(case_sequences)
            targets.append(case_targets)
            case_ids.extend([caseid] * len(case_targets))
    
    if len(sequences) > 0:
        return np.vstack(sequences), np.concatenate(targets), np.array(case_ids)
    else:
        return np.array([]), np.array([]), np.array([])

if len(normalized_data) > 0:
    print("üîÑ Creating sequences...")
    
    # Create sequences
    X, y, case_ids = create_sequences(
        normalized_data,
        caseid_col='caseid',
        feature_cols=available_features,
        target_col='mbp',
        seq_length=SEQ_LENGTH,
        forecast_horizon=FORECAST_HORIZON
    )
    
    print(f"‚úÖ Created sequences:")
    print(f"  - Input shape: {X.shape}")
    print(f"  - Target shape: {y.shape}")
    print(f"  - Unique cases: {len(np.unique(case_ids))}")
    
    # Split by case ID to avoid data leakage
    unique_cases = np.unique(case_ids)
    np.random.seed(42)
    np.random.shuffle(unique_cases)
    
    train_split = int(0.7 * len(unique_cases))
    val_split = int(0.85 * len(unique_cases))
    
    train_cases = unique_cases[:train_split]
    val_cases = unique_cases[train_split:val_split]
    test_cases = unique_cases[val_split:]
    
    # Create masks
    train_mask = np.isin(case_ids, train_cases)
    val_mask = np.isin(case_ids, val_cases)
    test_mask = np.isin(case_ids, test_cases)
    
    X_train, y_train = X[train_mask], y[train_mask]
    X_val, y_val = X[val_mask], y[val_mask]
    X_test, y_test = X[test_mask], y[test_mask]
    
    print(f"\nüìä Train/Val/Test split:")
    print(f"  - Train: {len(X_train)} sequences from {len(train_cases)} cases")
    print(f"  - Val: {len(X_val)} sequences from {len(val_cases)} cases")
    print(f"  - Test: {len(X_test)} sequences from {len(test_cases)} cases")
    
    print(f"\n‚úÖ Data ready for LSTM training!")
else:
    print("‚ùå No data available for sequence creation")


## Build LSTM Model for BP Forecasting

Create a multi-layer LSTM model to forecast Mean Blood Pressure 30 seconds ahead.


In [None]:
# Build LSTM model
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

print("üèóÔ∏è Building LSTM model...")

if len(X_train) > 0:
    # Model parameters
    lstm_units = 128
    dropout_rate = 0.2
    learning_rate = 0.001
    
    # Get input dimensions
    n_features = X_train.shape[2]
    seq_length = X_train.shape[1]
    
    print(f"  - Input shape: (batch, {seq_length}, {n_features})")
    print(f"  - LSTM units: {lstm_units}")
    print(f"  - Dropout: {dropout_rate}")
    print(f"  - Learning rate: {learning_rate}")
    
    # Build model
    model = keras.Sequential([
        # First LSTM layer
        layers.LSTM(lstm_units, return_sequences=True, input_shape=(seq_length, n_features)),
        layers.Dropout(dropout_rate),
        
        # Second LSTM layer
        layers.LSTM(lstm_units, return_sequences=True),
        layers.Dropout(dropout_rate),
        
        # Third LSTM layer
        layers.LSTM(lstm_units // 2, return_sequences=False),
        layers.Dropout(dropout_rate),
        
        # Dense output layer
        layers.Dense(64, activation='relu'),
        layers.Dropout(dropout_rate),
        layers.Dense(1)  # Single output: MBP forecast
    ])
    
    # Compile model
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='mean_absolute_error',  # MAE for BP forecasting
        metrics=['mean_squared_error', 'mean_absolute_error']
    )
    
    print(f"\n‚úÖ Model built successfully!")
    print(f"\nüìã Model architecture:")
    model.summary()
else:
    print("‚ùå No training data available")
    model = None


## Train LSTM Model

Train the model with early stopping and checkpointing.


In [None]:
# Train the model
if model is not None and len(X_train) > 0:
    print("üöÄ Training LSTM model...")
    
    # Callbacks
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    )
    
    checkpoint = ModelCheckpoint(
        'lstm_bp_model.h5',
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    )
    
    # Training parameters
    batch_size = 32
    epochs = 50
    
    print(f"  - Batch size: {batch_size}")
    print(f"  - Max epochs: {epochs}")
    print(f"  - Early stopping patience: 10")
    
    # Train model
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        batch_size=batch_size,
        epochs=epochs,
        callbacks=[early_stopping, checkpoint],
        verbose=1
    )
    
    print(f"\n‚úÖ Training completed!")
    print(f"  - Best validation loss: {min(history.history['val_loss']):.4f}")
    print(f"  - Best validation MAE: {min(history.history['val_mean_absolute_error']):.4f}")
    
    # Plot training history
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Loss plot
    axes[0].plot(history.history['loss'], label='Train Loss')
    axes[0].plot(history.history['val_loss'], label='Val Loss')
    axes[0].set_title('Model Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss (MAE)')
    axes[0].legend()
    axes[0].grid(True)
    
    # MAE plot
    axes[1].plot(history.history['mean_absolute_error'], label='Train MAE')
    axes[1].plot(history.history['val_mean_absolute_error'], label='Val MAE')
    axes[1].set_title('Mean Absolute Error')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('MAE')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\nüìä Training history saved to training_history.png")
else:
    print("‚ùå Model not available for training")


## Generate Forecasts and Evaluate

Generate 30-second ahead MBP forecasts for test cases and calculate evaluation metrics.


In [None]:
# Numba-optimized metric calculations
@jit(nopython=True)
def calculate_mae_numba(y_true, y_pred):
    """Calculate Mean Absolute Error (numba-optimized)."""
    n = len(y_true)
    mae = 0.0
    for i in prange(n):
        mae += np.abs(y_true[i] - y_pred[i])
    return mae / n

@jit(nopython=True)
def calculate_rmse_numba(y_true, y_pred):
    """Calculate Root Mean Squared Error (numba-optimized)."""
    n = len(y_true)
    mse = 0.0
    for i in prange(n):
        diff = y_true[i] - y_pred[i]
        mse += diff * diff
    return np.sqrt(mse / n)

@jit(nopython=True)
def calculate_mape_numba(y_true, y_pred):
    """Calculate Mean Absolute Percentage Error (numba-optimized)."""
    n = len(y_true)
    mape = 0.0
    for i in prange(n):
        if np.abs(y_true[i]) > 1e-8:
            mape += np.abs((y_true[i] - y_pred[i]) / y_true[i])
    return (mape / n) * 100.0

# Generate forecasts on test set
if model is not None and len(X_test) > 0:
    print("üîÆ Generating forecasts...")
    
    # Predict on test set
    y_pred = model.predict(X_test, verbose=0)
    y_pred = y_pred.flatten()
    
    # Convert to numpy arrays for numba
    y_test_array = y_test.astype(np.float64)
    y_pred_array = y_pred.astype(np.float64)
    
    # Calculate metrics using numba-optimized functions
    mae = calculate_mae_numba(y_test_array, y_pred_array)
    rmse = calculate_rmse_numba(y_test_array, y_pred_array)
    mape = calculate_mape_numba(y_test_array, y_pred_array)
    
    print(f"\nüìä Test Set Metrics:")
    print(f"  - MAE: {mae:.4f}")
    print(f"  - RMSE: {rmse:.4f}")
    print(f"  - MAPE: {mape:.2f}%")
    
    # Denormalize predictions (need to use scalers from training)
    # For simplicity, we'll use the mean and std from test set
    # In production, you'd use the scalers saved during training
    test_mbp_mean = normalized_data[normalized_data['caseid'].isin(test_cases)]['mbp'].mean()
    test_mbp_std = normalized_data[normalized_data['caseid'].isin(test_cases)]['mbp'].std()
    
    # Get original MBP values from data
    test_original_mbp = []
    test_predicted_mbp = []
    
    # Reconstruct original values per case
    for caseid in test_cases:
        case_data = normalized_data[normalized_data['caseid'] == caseid].sort_values('timestamp')
        if len(case_data) >= SEQ_LENGTH + FORECAST_HORIZON:
            # Get the actual MBP values at forecast points
            case_indices = np.where(case_ids[test_mask] == caseid)[0]
            if len(case_indices) > 0:
                # Get original MBP from data
                forecast_points = [SEQ_LENGTH + FORECAST_HORIZON - 1 + i * FORECAST_HORIZON 
                                  for i in range(len(case_indices))]
                forecast_points = [p for p in forecast_points if p < len(case_data)]
                
                if len(forecast_points) > 0:
                    actual_mbp = case_data.iloc[forecast_points[:len(case_indices)]]['mbp'].values
                    pred_mbp = y_pred[case_indices[:len(actual_mbp)]]
                    
                    test_original_mbp.extend(actual_mbp)
                    test_predicted_mbp.extend(pred_mbp)
    
    if len(test_original_mbp) > 0:
        test_original_mbp = np.array(test_original_mbp)
        test_predicted_mbp = np.array(test_predicted_mbp)
        
        # Denormalize (approximate - in production use saved scalers)
        # We'll use the overall statistics as approximation
        original_mean = data[data['caseid'].isin(test_cases)]['mbp'].mean()
        original_std = data[data['caseid'].isin(test_cases)]['mbp'].std()
        
        denorm_actual = test_original_mbp * original_std + original_mean
        denorm_pred = test_predicted_mbp * original_std + original_mean
        
        # Calculate denormalized metrics using numba-optimized functions
        denorm_actual_array = denorm_actual.astype(np.float64)
        denorm_pred_array = denorm_pred.astype(np.float64)
        denorm_mae = calculate_mae_numba(denorm_actual_array, denorm_pred_array)
        denorm_rmse = calculate_rmse_numba(denorm_actual_array, denorm_pred_array)
        denorm_mape = calculate_mape_numba(denorm_actual_array, denorm_pred_array)
        
        print(f"\nüìä Denormalized Metrics (mmHg):")
        print(f"  - MAE: {denorm_mae:.2f} mmHg")
        print(f"  - RMSE: {denorm_rmse:.2f} mmHg")
        print(f"  - MAPE: {denorm_mape:.2f}%")
        
        # Store for visualization
        forecast_results = {
            'actual': denorm_actual,
            'predicted': denorm_pred
        }
    else:
        forecast_results = None
        print("‚ö†Ô∏è Could not denormalize predictions")
else:
    print("‚ùå Model or test data not available")
    forecast_results = None


In [None]:
# Generate forecasts for specific test cases for visualization
if model is not None and len(test_cases) > 0:
    print("üìà Preparing forecasts for visualization...")
    
    # Select a few test cases for detailed visualization
    viz_cases = test_cases[:5] if len(test_cases) >= 5 else test_cases
    
    case_forecasts = {}
    
    for caseid in viz_cases:
        try:
            # Get case data
            case_data = normalized_data[normalized_data['caseid'] == caseid].sort_values('timestamp').reset_index(drop=True)
            
            if len(case_data) < SEQ_LENGTH + FORECAST_HORIZON:
                continue
            
            # Create sequences for this case
            case_sequences = []
            case_actuals = []
            case_timestamps = []
            
            # Generate multiple forecasts along the case timeline
            step_size = FORECAST_HORIZON  # Non-overlapping forecasts
            for i in range(0, len(case_data) - SEQ_LENGTH - FORECAST_HORIZON + 1, step_size):
                seq = case_data.iloc[i:i+SEQ_LENGTH][available_features].values
                target_idx = i + SEQ_LENGTH + FORECAST_HORIZON - 1
                
                if target_idx < len(case_data):
                    case_sequences.append(seq)
                    case_actuals.append(case_data.iloc[target_idx]['mbp'])
                    case_timestamps.append(case_data.iloc[target_idx]['timestamp'])
            
            if len(case_sequences) > 0:
                # Predict
                case_X = np.array(case_sequences)
                case_pred = model.predict(case_X, verbose=0).flatten()
                
                # Get original MBP values for denormalization
                case_original = data[data['caseid'] == caseid].sort_values('timestamp').reset_index(drop=True)
                case_mbp_mean = case_original['mbp'].mean()
                case_mbp_std = case_original['mbp'].std()
                
                # Denormalize
                denorm_actual = np.array(case_actuals) * case_mbp_std + case_mbp_mean
                denorm_pred = case_pred * case_mbp_std + case_mbp_mean
                
                # Get full timeline
                full_timestamps = case_original['timestamp'].values
                full_actual_mbp = case_original['mbp'].values
                
                case_forecasts[caseid] = {
                    'timestamps': full_timestamps,
                    'actual_mbp': full_actual_mbp,
                    'forecast_timestamps': [case_original.iloc[int(ts)]['timestamp'] if int(ts) < len(case_original) else None 
                                          for ts in case_timestamps],
                    'forecast_actual': denorm_actual,
                    'forecast_predicted': denorm_pred
                }
                
        except Exception as e:
            print(f"  ‚ö†Ô∏è Error processing case {caseid}: {e}")
            continue
    
    print(f"‚úÖ Prepared forecasts for {len(case_forecasts)} cases")
    print(f"  - Cases: {list(case_forecasts.keys())}")
else:
    case_forecasts = {}
    print("‚ùå Could not prepare case forecasts")


## Visualize Forecasts

Create visualizations showing actual vs forecasted MBP for multiple case IDs.


In [None]:
# Create visualization plots
if len(case_forecasts) > 0:
    print("üìä Creating visualization plots...")
    
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    from datetime import datetime, timedelta
    from sklearn.metrics import mean_absolute_error, mean_squared_error
    
    # Create subplots
    n_cases = len(case_forecasts)
    fig, axes = plt.subplots(n_cases, 1, figsize=(14, 4 * n_cases))
    
    if n_cases == 1:
        axes = [axes]
    
    for idx, (caseid, forecast_data) in enumerate(case_forecasts.items()):
        ax = axes[idx]
        
        # Plot actual MBP over time
        timestamps = forecast_data['timestamps']
        actual_mbp = forecast_data['actual_mbp']
        
        # Convert timestamps to minutes for readability
        time_minutes = timestamps / 60
        
        ax.plot(time_minutes, actual_mbp, 'b-', alpha=0.6, label='Actual MBP', linewidth=1.5)
        
        # Plot forecast points
        if len(forecast_data['forecast_timestamps']) > 0:
            forecast_ts = [t/60 if t is not None else None for t in forecast_data['forecast_timestamps']]
            forecast_ts = [t for t in forecast_ts if t is not None]
            
            if len(forecast_ts) == len(forecast_data['forecast_predicted']):
                ax.scatter(forecast_ts, forecast_data['forecast_actual'], 
                          color='green', marker='o', s=50, label='Actual (forecast points)', zorder=5)
                ax.scatter(forecast_ts, forecast_data['forecast_predicted'], 
                          color='red', marker='x', s=100, linewidths=2, label='Forecasted MBP', zorder=5)
        
        # Add forecast horizon indicator
        if len(forecast_ts) > 0:
            # Show forecast horizon on first forecast point
            first_forecast_time = forecast_ts[0]
            ax.axvspan(first_forecast_time, first_forecast_time + FORECAST_HORIZON/60, 
                      alpha=0.2, color='yellow', label=f'{FORECAST_HORIZON}-sec forecast horizon')
        
        ax.set_xlabel('Time (minutes)', fontsize=11)
        ax.set_ylabel('Mean Blood Pressure (mmHg)', fontsize=11)
        ax.set_title(f'Case ID {caseid}: Actual vs Forecasted MBP', fontsize=12, fontweight='bold')
        ax.legend(loc='best', fontsize=9)
        ax.grid(True, alpha=0.3)
        
        # Calculate and display metrics for this case using numba-optimized functions
        if len(forecast_data['forecast_predicted']) > 0:
            case_actual_array = np.array(forecast_data['forecast_actual'], dtype=np.float64)
            case_pred_array = np.array(forecast_data['forecast_predicted'], dtype=np.float64)
            case_mae = calculate_mae_numba(case_actual_array, case_pred_array)
            case_rmse = calculate_rmse_numba(case_actual_array, case_pred_array)
            
            ax.text(0.02, 0.98, f'MAE: {case_mae:.2f} mmHg | RMSE: {case_rmse:.2f} mmHg',
                   transform=ax.transAxes, fontsize=9, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.savefig('transplantation_bp_forecasts.png', dpi=150, bbox_inches='tight')
    print(f"‚úÖ Visualization saved to transplantation_bp_forecasts.png")
    plt.show()
    
    # Create a summary comparison plot
    if len(case_forecasts) > 1:
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        
        all_actual = []
        all_predicted = []
        
        for caseid, forecast_data in case_forecasts.items():
            all_actual.extend(forecast_data['forecast_actual'])
            all_predicted.extend(forecast_data['forecast_predicted'])
        
        all_actual = np.array(all_actual)
        all_predicted = np.array(all_predicted)
        
        # Scatter plot
        ax.scatter(all_actual, all_predicted, alpha=0.6, s=50)
        
        # Perfect prediction line
        min_val = min(min(all_actual), min(all_predicted))
        max_val = max(max(all_actual), max(all_predicted))
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction')
        
        ax.set_xlabel('Actual MBP (mmHg)', fontsize=12)
        ax.set_ylabel('Forecasted MBP (mmHg)', fontsize=12)
        ax.set_title('Actual vs Forecasted MBP - All Test Cases', fontsize=13, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        
        # Add metrics text using numba-optimized functions
        overall_actual_array = all_actual.astype(np.float64)
        overall_pred_array = all_predicted.astype(np.float64)
        overall_mae = calculate_mae_numba(overall_actual_array, overall_pred_array)
        overall_rmse = calculate_rmse_numba(overall_actual_array, overall_pred_array)
        overall_mape = calculate_mape_numba(overall_actual_array, overall_pred_array)
        
        textstr = f'MAE: {overall_mae:.2f} mmHg\nRMSE: {overall_rmse:.2f} mmHg\nMAPE: {overall_mape:.2f}%'
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=11,
               verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        plt.tight_layout()
        plt.savefig('transplantation_bp_scatter.png', dpi=150, bbox_inches='tight')
        print(f"‚úÖ Scatter plot saved to transplantation_bp_scatter.png")
        plt.show()
else:
    print("‚ùå No forecast data available for visualization")
