In [52]:

%matplotlib inline

# autoload
%load_ext autoreload
%autoreload 2

import sys
import pandas as pd
# append path
sys.path.append('../')
from scipy.special import logit

sys.path.append('../../')
from process.tables import *
from models.bayes import *
from models.predict import *
from process.config import *
from process.measurements import cutoff_measurements_df

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.optimize import minimize


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
measurements = pd.read_csv('../data/processed/all_lab_measurements.csv')
measurements['sex'] = measurements['gender_concept_id'].map({1: 'male', 0: 'female'})
processed_df = cutoff_measurements_df(measurements, percent=0.8, min_tests=5)
processed_df = processed_df.replace([np.inf, -np.inf], np.nan).dropna()

def get_priors_dict():
    priors = {
    (test, sex): {
        'mean': (low + high) / 2,
        'var': ((high - low) / 4)**2
    }
    for test, intervals in REFERENCE_INTERVALS.items()
        for sex, (low, high, _) in intervals.items()
    }
    return priors


Filtering Sequences:  96%|█████████▌| 146384/152269 [00:40<00:01, 3596.24it/s]


In [57]:
from models.ornstein_uhlenbeck import run_ou_with_prior
from models.ornstein_uhlenbeck import estimate_ou_parameters_with_prior

def run_ou_with_prior(processed_df):
    groups = processed_df.groupby(['subject_id', 'test_name'])
    priors = get_priors_dict()
    results = []
    
    for (subject, test_name), group in tqdm(groups):
        sample = group[['time', 'numeric_value', 'sex']]
        sex = sample['sex'].unique()[0]
        prior_mean, prior_var = priors[(test_name, sex)]['mean'], priors[(test_name, sex)]['var'] / 4
        prior_sigma_mean = np.sqrt(prior_var) 
        prior_sigma_var = 0.2   # log-variance of sigma (log-normal prior)
        
        df = sample.copy()
        df['time'] = pd.to_datetime(df['time'])
        df = df.sort_values('time')
        df['dt'] = df['time'].diff().dt.total_seconds().div(86400)
        df = df.dropna()
        
        S = df['numeric_value'].values
        dt = df['dt'].mean()
        
        theta_ml, mu_ml, sigma_ml = estimate_ou_parameters_with_prior(S, dt, prior_mean, prior_var, prior_sigma_mean, prior_sigma_var)
        
        results.append({
            'subject_id': subject,
            'test_name': test_name,
            'sex': sex,
            'ou_prior_mean': prior_mean,
            'ou_prior_var': prior_var,
            'ou_prior_sigma_mean': prior_sigma_mean,
            'ou_prior_sigma_var': prior_sigma_var,
            'ou_mean': theta_ml,
            'ou_speed': mu_ml,
            'ou_std': sigma_ml
        })
    
    results_df = pd.DataFrame(results)
    return results_df

sample_subjects = processed_df.subject_id.unique()[:10]
results_df = run_ou_with_prior(processed_df.query('subject_id in @sample_subjects'))


  0%|          | 0/153 [00:00<?, ?it/s]


KeyError: ('BUN', 'male')

In [None]:

subject = 115969474
test_name = 'HCT'
subjects = processed_df.query('test_name == @test_name')['subject_id'].unique()[10:15]
print(subjects)
# for subject in subjects:
#     sample = processed_df.query('subject_id == @subject and test_name == @test_name')[['time', 'numeric_value', 'sex']]

#     prior_mean = 42.0  # theta prior mean
#     prior_var = 2.5 # theta prior variance
#     prior_sigma_mean = 2.5
#     prior_sigma_var = 0.2  # log-variance of sigma (log-normal prior)

#     # can you ignore points not in the range?
#     # filtered_df = sample['numeric_value'].between(prior_mean - 3 * prior_sigma_mean, prior_mean + 3 * prior_sigma_mean)
#     # sample = sample[filtered_df]
#     df = sample.copy()
#     df['time'] = pd.to_datetime(df['time'])
#     df = df.sort_values('time')
#     df['dt'] = df['time'].diff().dt.total_seconds().div(86400)
#     df = df.dropna()

#     S = df['numeric_value'].values
#     dt = df['dt'].mean()
#     #dt = df['dt'].mean() / 365.25
    
#     print(dt)
#     # Estimate parameters
#     theta_ml, mu_ml, sigma_ml = estimate_ou_parameters_with_prior(
#         S, dt, prior_mean, prior_var, prior_sigma_mean, prior_sigma_var
#     )
priors = get_priors_dict()
for subject in subjects:
    sample = processed_df.query('subject_id == @subject and test_name == @test_name') #[['time', 'numeric_value', 'sex']]
    print(sample.head())
    sample = sample[['time', 'numeric_value', 'sex']]
    sex = sample['sex'].unique()[0]
    
    # only use points in the range
    
    # Get priors properly
    if (test_name, sex) in priors:
        prior_mean = priors[(test_name, sex)]['mean']
        prior_var = priors[(test_name, sex)]['var']
    else:
        # Fallback if test/sex combination not found
        prior_mean = sample['numeric_value'].mean()
        prior_var = sample['numeric_value'].var()

    
    # Set sigma prior based on data variability
    prior_sigma_mean = np.sqrt(priors[(test_name, sex)]['var'])
    prior_sigma_var = 0.1  # More reasonable log-variance
        

    df = sample.copy()
    df['time'] = pd.to_datetime(df['time'])
    df = df.sort_values('time')
    df['dt'] = df['time'].diff().dt.total_seconds().div(86400)
    
    df = df.dropna()
    
    if len(df) < 2:  # Need at least 2 points for OU
        continue
        
    S = df['numeric_value'].values
    dt = df['dt'].mean()
    #dt = df['dt'].mean() / 365.25
    
    theta_ml, mu_ml, sigma_ml = estimate_ou_parameters_with_prior(
        S, dt, prior_mean, prior_var, prior_sigma_mean, prior_sigma_var
    )
        
    # Plot results
    plt.figure(figsize=(12, 3))
    plt.scatter(sample['time'], sample['numeric_value'], color='blue', alpha=0.7, label='Observed HCT')

    std_dev = prior_sigma_mean
    confidence_interval = 2 * std_dev

    plt.axhline(prior_mean, color='red', linestyle='--', alpha=0.7, label=f'Population: {prior_mean:.1f}%')
    plt.fill_between(sample['time'], prior_mean - confidence_interval, prior_mean + confidence_interval,
                     color='red', alpha=0.2, label='Population 95% CI')

    ou_std_dev = sigma_ml / np.sqrt(2 * mu_ml)

    plt.axhline(theta_ml, color='green', linestyle='--', alpha=0.7, label=f'OU θ (estimated): {theta_ml:.2f}%')
    plt.fill_between(sample['time'], theta_ml - 2 * ou_std_dev, theta_ml + 2 * ou_std_dev,
                     color='green', alpha=0.2, label='OU 95% CI')

    plt.title(f'HCT OU Process Fit for Patient {subject}')
    plt.xlabel('Time')
    plt.ylabel('HCT (%)')
    plt.legend(loc='lower right', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("   'HCT': {'F': (37, 47, '%'), 'M': (42, 50, '%')},  ")
    print(f"\nComparison for subject {subject}:")
    print(f"Prior θ mean: {prior_mean:.1f}%")
    print(f"Prior σ mean: {prior_sigma_mean:.1f}%")
    print(f"Estimated θ (OU): {theta_ml:.2f}%")
    print(f"Estimated μ (OU): {mu_ml:.4f}")
    print(f"Estimated σ (OU): {sigma_ml:.2f}%")
    print(f"Difference from prior θ: {theta_ml - prior_mean:.2f}%")
    print(f"Difference from prior σ: {sigma_ml - prior_sigma_mean:.2f}%")


[115967111 115967113 115967114 115967115 115967116]
         subject_id                time           code  numeric_value  \
2114664   115967111 2014-10-25 20:55:00   LOINC/4544-3           38.7   
2114674   115967111 2014-10-25 20:58:00  LOINC/20570-8           38.0   
2114685   115967111 2014-10-26 02:30:00   LOINC/4544-3           28.0   
2114704   115967111 2014-10-26 10:00:00   LOINC/4544-3           28.5   
2114722   115967111 2014-10-26 15:25:00   LOINC/4544-3           31.0   

         care_site_id      clarity_table  end  note_id  provider_id  \
2114664           NaN  shc_order_results  NaN      NaN    6767342.0   
2114674           NaN  shc_order_results  NaN      NaN    6808394.0   
2114685           NaN  shc_order_results  NaN      NaN    6811155.0   
2114704           NaN  shc_order_results  NaN      NaN    6811155.0   
2114722           NaN  shc_order_results  NaN      NaN    6811155.0   

               table  text_value unit    visit_id test_name birth_DATETIME  \
2114

KeyError: ('HCT', 'female')

In [43]:
subject = processed_df.query('test_name == @test_name')['subject_id'].unique()[:1]
sample = processed_df.query('subject_id == @subject and test_name == @test_name')[['time', 'numeric_value']]
print(sample)

                       time  numeric_value
3635697 2008-10-08 01:11:00           28.7
3635716 2008-10-08 08:10:00           28.0
3635725 2008-10-08 18:45:00           32.3
3635735 2008-10-08 23:20:00           30.1
3635752 2008-10-09 05:05:00           29.5
3635761 2008-10-09 19:30:00           36.2
3635779 2008-10-10 06:05:00           35.5
3635796 2008-10-11 06:40:00           36.6
3635813 2008-10-12 06:20:00           33.0
3635830 2008-10-13 06:56:00           35.7
