# DBSCAN Implementation of the Sessa Empirical Estimator (SEE)

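In this notebook, we implement the Sessa Empirical Estimator (SEE) using **DBSCAN** clustering. For a given medication, the pipeline:

1. Samples one interval (days between consecutive prescriptions) per patient.
2. Keeps the intervals that lie in the lower 80% of their empirical CDF.
3. Clusters the log-transformed intervals with DBSCAN.
4. Assigns each patient the median interval of the cluster their sampled interval falls into.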

In [None]:
# Cell 1: Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.cluster import DBSCAN
from statsmodels.distributions.empirical_distribution import ECDF

import warnings
warnings.filterwarnings('ignore')


In [None]:
# Cell 2: Load and Prepare the Data
def load_data(filepath="../data/med_events.csv"):
    """Read the medication-events CSV and normalise its schema.

    The file's five columns are renamed positionally to PATIENT_ID, DATE,
    PERDAY, CATEGORY and DURATION, and DATE is parsed using the
    month/day/year format ``%m/%d/%Y``.

    Parameters
    ----------
    filepath : str
        Path to the CSV file.

    Returns
    -------
    pandas.DataFrame
        The events with renamed columns and a datetime DATE column.
    """
    column_names = ['PATIENT_ID', 'DATE', 'PERDAY', 'CATEGORY', 'DURATION']
    # set_axis raises if the CSV does not have exactly five columns.
    events = pd.read_csv(filepath).set_axis(column_names, axis=1)
    events['DATE'] = pd.to_datetime(events['DATE'], format='%m/%d/%Y')
    return events

# Load the raw event table and preview its first five rows.
med_events = load_data(filepath="../data/med_events.csv")
med_events.head(5)


In [None]:
# Cell 3: DBSCAN-based SEE Function
def _sample_intervals(drug_events, random_state=1234):
    """Return one randomly chosen prescription interval per patient.

    Events are sorted by patient and date, each row is paired with the same
    patient's previous DATE (``prev_date``), first prescriptions (no previous
    date) are dropped, and one remaining row is sampled per patient. Adds an
    ``event_interval`` column holding the gap in days.
    """
    events = drug_events.sort_values(by=['PATIENT_ID', 'DATE']).copy()
    events['prev_date'] = events.groupby('PATIENT_ID')['DATE'].shift(1)
    events = events.dropna(subset=['prev_date'])
    if events.empty:
        sampled = events.reset_index(drop=True)
    else:
        # Selecting the columns explicitly (PATIENT_ID included) keeps the
        # grouping column in each sampled row; newer pandas otherwise drops it
        # from what groupby.apply passes to the lambda.
        sampled = (
            events.groupby('PATIENT_ID', group_keys=False)[list(events.columns)]
            .apply(lambda group: group.sample(1, random_state=random_state))
            .reset_index(drop=True)
        )
    sampled['event_interval'] = (sampled['DATE'] - sampled['prev_date']).dt.days
    return sampled


def _plot_ecdf(df_ecdf, df_ecdf_80):
    """Show the lower-80% ECDF and the full ECDF side by side."""
    fig, (ax_80, ax_all) = plt.subplots(1, 2, figsize=(12, 5))
    panels = ((ax_80, df_ecdf_80, '80% ECDF'), (ax_all, df_ecdf, '100% ECDF'))
    for ax, frame, title in panels:
        ax.scatter(frame['x'], frame['y'])
        ax.set(title=title, xlabel='Interval (days)', ylabel='ECDF')
    fig.tight_layout()
    plt.show()


def _dbscan_cluster_stats(intervals, eps, min_samples):
    """Cluster log-intervals with DBSCAN and summarise each non-noise cluster.

    Returns a DataFrame with columns Cluster, Minimum, Maximum and Median.
    The last three are computed on the original (day) scale of ``intervals``.
    Noise points (label -1) are excluded.
    """
    columns = ['Cluster', 'Minimum', 'Maximum', 'Median']
    if len(intervals) == 0:
        # DBSCAN cannot be fitted on zero samples.
        return pd.DataFrame(columns=columns)
    raw = np.asarray(intervals)
    # The tiny offset keeps log() finite for zero-day intervals.
    log_intervals = np.log(raw.astype(float) + 1e-8)
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(log_intervals.reshape(-1, 1))
    # fit_predict returns labels in input order, so they pair row-for-row
    # with `raw` -- the intervals that were actually clustered.
    clustered = pd.DataFrame({'interval': raw, 'label': labels})
    valid = clustered[clustered['label'] != -1]
    stats = valid.groupby('label')['interval'].agg(['min', 'max', 'median']).reset_index()
    stats.columns = columns
    return stats


def _match_to_clusters(sampled, cluster_stats):
    """Map each patient's sampled interval to the first cluster whose range contains it.

    Returns one row per patient with columns PATIENT_ID, Median (NaN when
    unmatched) and Cluster (-1 when unmatched).
    """
    stats = list(cluster_stats.itertuples(index=False))
    records = []
    for patient_id, interval in zip(sampled['PATIENT_ID'], sampled['event_interval']):
        match = next((cs for cs in stats if cs.Minimum <= interval <= cs.Maximum), None)
        records.append({
            'PATIENT_ID': patient_id,
            'Median': np.nan if match is None else match.Median,
            'Cluster': -1 if match is None else match.Cluster,
        })
    # Explicit columns/dtype keep the later merge valid even with zero records.
    return pd.DataFrame(records, columns=['PATIENT_ID', 'Median', 'Cluster']).astype(
        {'PATIENT_ID': sampled['PATIENT_ID'].dtype}
    )


def see_dbscan(data, medication_code, eps=0.3, min_samples=5):
    """
    Python implementation of the Sessa Empirical Estimator using DBSCAN.

    Parameters
    ----------
    data : pandas.DataFrame
        Medication events with at least PATIENT_ID, DATE (datetime) and
        CATEGORY columns.
    medication_code : str
        CATEGORY value to analyse.
    eps, min_samples :
        Passed to ``DBSCAN``. ``eps`` is measured on the log-interval scale.

    Returns
    -------
    pandas.DataFrame
        All events of ``medication_code`` plus two columns:

        * ``Median``: the median interval (days) of the cluster the patient's
          sampled interval falls in.
        * ``Cluster``: that cluster's DBSCAN label.

        Patients whose interval matches no cluster, and patients with a single
        event, get ``Cluster == -1``. Their ``Median`` is that of the cluster
        matched by the most patients (NaN when no cluster was found).

    Side effect: displays a figure with the 80% and 100% ECDF of the sampled
    intervals.
    """
    drug_see_p0 = data[data['CATEGORY'] == medication_code].copy()

    # One interval (days since the previous prescription) per patient.
    drug_see_p1 = _sample_intervals(drug_see_p0)

    # Empirical CDF of the sampled intervals.
    intervals = np.sort(drug_see_p1['event_interval'].to_numpy())
    df_ecdf = pd.DataFrame({
        'x': intervals,
        'y': np.arange(1, len(intervals) + 1) / len(intervals),
    })
    df_ecdf_80 = df_ecdf[df_ecdf['y'] <= 0.8]
    _plot_ecdf(df_ecdf, df_ecdf_80)

    # Keep intervals within the lower 80% of the ECDF. An empty 80% slice
    # gives a NaN cut-off, which keeps nothing.
    ni = df_ecdf_80['x'].max()
    drug_see_p2 = drug_see_p1[drug_see_p1['event_interval'] <= ni]

    # Cluster stats come from the same rows DBSCAN was fitted on (not from the
    # full, sorted ECDF frame, whose length and order differ).
    cluster_stats = _dbscan_cluster_stats(drug_see_p2['event_interval'], eps, min_samples)

    # Match every sampled interval (not only the <=80% ones) to a cluster range.
    results = _match_to_clusters(drug_see_p1, cluster_stats)

    # Fallback median: that of the cluster matched by the most patients.
    matched = results[results['Cluster'] != -1]
    if matched.empty:
        default_median = np.nan
    else:
        most_common = matched['Cluster'].value_counts().idxmax()
        default_median = matched.loc[matched['Cluster'] == most_common, 'Median'].iloc[0]

    final_df = drug_see_p0.merge(results, on='PATIENT_ID', how='left')
    final_df['Median'] = final_df['Median'].fillna(default_median)
    final_df['Cluster'] = final_df['Cluster'].fillna(-1).astype(int)
    return final_df


In [None]:
# Cell 4: Assumption Checking
def see_assumption(data):
    """Box-plot the gap between consecutive prescriptions by prescription number.

    Rows are ordered per patient by DATE. The n-th prescription (n >= 2) is
    paired with the (n-1)-th to give ``Duration`` in days, which is shown as
    one box per prescription number. A dashed red line marks the median of
    the per-patient median durations.

    Parameters
    ----------
    data : pandas.DataFrame
        Events with at least PATIENT_ID and DATE (datetime) columns.
    """
    ordered = data.sort_values(by=['PATIENT_ID', 'DATE']).copy()
    ordered = ordered.assign(
        prev_date=ordered.groupby('PATIENT_ID')['DATE'].shift(1),
        p_number=ordered.groupby('PATIENT_ID').cumcount() + 1,
    )

    # First prescriptions have no predecessor, so they carry no gap.
    gaps = ordered.loc[ordered['p_number'] >= 2].copy()
    gaps['Duration'] = (gaps['DATE'] - gaps['prev_date']).dt.days
    overall_median = gaps.groupby('PATIENT_ID')['Duration'].median().median()

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.boxplot(x='p_number', y='Duration', data=gaps, ax=ax)
    ax.axhline(overall_median, color='red', linestyle='--',
               label=f'Median = {overall_median:.1f}')
    ax.set_title("Duration by Prescription Number")
    ax.set_xlabel("Prescription Number")
    ax.set_ylabel("Duration (days)")
    ax.legend()
    plt.show()

In [None]:
# Cell 5: Run the DBSCAN-based SEE for medA, preview it, then check the SEE assumptions
result_dbscan_medA = see_dbscan(med_events, medication_code='medA', eps=0.3, min_samples=5)
display(result_dbscan_medA.head(5))

print("\nChecking assumptions for medA (DBSCAN):")
see_assumption(data=result_dbscan_medA)

# Cell 6: Insights

**Observations:**