# Propensity Score and Other Matching Methods

In [None]:
import duckdb
import pandas as pd
import numpy as np
import json

import ipywidgets as widgets
from IPython.display import display, Markdown

# timeoutput
import datetime

# regex
import re

# plots
import matplotlib.pyplot as plt
import seaborn as sns

# sklearn
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.neighbors import NearestNeighbors
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.covariance import EmpiricalCovariance
from scipy.spatial.distance import cdist

In [None]:
#! change the base_path to the IC data location in Wynton


# Functions for easy pulling of CDW data

def file_path_parquet(filename, datatype):
    """Return the glob pattern that matches every parquet part of one table.

    Parameters
    ----------
    filename : name of the table directory (e.g. ``'person'``).
    datatype : name of the dataset directory the table lives under
        (e.g. ``'DEID_OMOP'``).

    Returns
    -------
    str of the form ``path/to/ic/data/<datatype>/<filename>/*.parquet``.
    """
    # Root of the dataset; edit this to point at the local data location.
    base_path = f"path/to/ic/data/{datatype}/"
    return base_path + f"{filename}/*.parquet"

def rtime():
    """Display a bold "Ran: <timestamp>" marker, coloured by day of the week.

    The timestamp is the current local time formatted as
    ``YYYY-MM-DD HH:MM:SS``; the result is rendered as Markdown in the
    notebook output. Returns None.
    """
    now = datetime.datetime.now()

    # One colour per weekday, ordered Monday (0) through Sunday (6)
    weekday_colors = ('red', 'orange', 'green', 'blue', 'purple', 'brown', 'gray')
    weekday = now.weekday()
    if 0 <= weekday < len(weekday_colors):
        text_color = weekday_colors[weekday]
    else:
        # Fallback colour for a weekday outside the table
        text_color = 'black'

    stamp = now.strftime("%Y-%m-%d %H:%M:%S")
    display(Markdown(f"\n<b><span style='color:{text_color}'>Ran: {stamp}</span></b>\n"))
    
rtime()

In [None]:
#! change the path to scratch and the username


# Replace 'name' with your actual Wynton username
username = 'name'

# Spill data that doesn't fit into memory into Wynton Scratch storage (BeeGFS).
# Use at most 12 threads and 150 GB of memory so as not to overwhelm the system.
# Recommendation: ~12 GB of memory for each thread;
# reduce these values if other system limitations are in place.
config_query = f"""
    SET temp_directory = 'path/to/scratch/{username}/duckdb_dir';
    SET preserve_insertion_order = false;
    SET memory_limit = '150GB';
    SET threads TO 12;
"""

# Create a connection and apply the configuration above.
# Every relation and query below goes through this `con` object.
con = duckdb.connect()
con_info = con.execute(config_query)  # Apply configuration settings

# Show the object returned by execute() and stamp the run time
display(con_info)
rtime()

# Data

In [None]:
# Numeric concept ids; presumably the OMOP condition concepts that define MS --
# TODO confirm. Not referenced again in the visible part of this notebook.
ms_list = [374919, 4178929, 4145049, 4137855, 37110514]

# Placeholder patient identifiers; presumably meant to be excluded -- TODO confirm.
# Not referenced again in the visible part of this notebook.
bad_pats = ['-1', '*Unspecified']

# Both cohorts are materialised as pandas DataFrames. The SQL below refers to
# them by these exact variable names (FROM ms_pats / FROM ctl_pats), so do not
# rename them without editing the queries.
ms_pats = pd.read_csv("data/ms_cohort_250318.csv")
ctl_pats = con.read_parquet("data/control_pats_250318.parquet").df()

rtime()

### OMOP Data

In [None]:
# Each variable below is a DuckDB relation over a parquet glob built by
# file_path_parquet(). person_ucsf is referenced by name inside the SQL of the
# demographic queries, so its variable name must stay as is. The other
# relations are not used in the visible part of this notebook.

# condition_occurrence
condition_occurrence_ucsf = con.read_parquet(file_path_parquet('condition_occurrence', 'DEID_OMOP'))

# person demographics
person_ucsf = con.read_parquet(file_path_parquet('person', 'DEID_OMOP'))

# person linkage OMOP - CDW
person_extension_ucsf = con.read_parquet(file_path_parquet('person_extension', 'DEID_OMOP'))

# visit_occurrence
visit_occurrence_ucsf = con.read_parquet(file_path_parquet('visit_occurrence', 'DEID_OMOP'))

# condition occurrence to link to CDW
condition_occurrence_extension_ucsf = con.read_parquet(file_path_parquet('condition_occurrence_extension', 'DEID_OMOP'))


rtime()

### CDW Data

In [None]:
# DuckDB relations over the CDW parquet tables. None of them is used in the
# visible part of this notebook.

# deid_note_key and negation terms
note_concepts = con.read_parquet(file_path_parquet('note_concepts', 'DEID_CDW'))

# linker to patientdurablekey, encounterkey, and deid_note_key
note_metadata = con.read_parquet(file_path_parquet('note_metadata', 'DEID_CDW'))

# note text - only deid_note_key and note_text
note_text = con.read_parquet(file_path_parquet('note_text', 'DEID_CDW'))

# diagnosis event fact
diag_fact = con.read_parquet(file_path_parquet('diagnosiseventfact', 'DEID_CDW'))

# patdurabledim
patdurabledim = con.read_parquet(file_path_parquet('patdurabledim', 'DEID_CDW'))


rtime()

# Demographic Info

In [None]:
# Demographics for the MS cohort (is_ms = 1).
# Joins the ms_pats DataFrame to person_ucsf on person_id and derives:
#   follow_up          <- preex_years
#   n_notes            <- preex_note_count
#   age_at_first_visit <- days between birth_datetime and note_fdate, / 365.25
# Rows are kept only when age_at_first_visit >= 0 and follow_up >= 1.
# NOTE(review): the CAST(... AS INT) appears to round to the nearest integer in
# DuckDB rather than truncate -- confirm that rounded (not floored) age is intended.
query_ms_demo = f"""
SELECT ms.patientepicid,
    ms.note_fdate,
    ms.preex_years AS follow_up,
    ms.preex_note_count AS n_notes,
    CAST(DATEDIFF('day', 
        CAST(prsn.birth_datetime AS DATE), 
        CAST(ms.note_fdate AS DATE)
    ) / 365.25 AS INT) AS age_at_first_visit,
    gender_concept_id,
    race_concept_id,
    1 AS is_ms
FROM ms_pats ms
JOIN (
    SELECT person_id,
        birth_datetime,
        gender_concept_id,
        race_concept_id
    FROM person_ucsf
) prsn ON ms.person_id = prsn.person_id
WHERE age_at_first_visit >= 0
    AND follow_up >= 1
"""

# note_fdate is dropped from the result; remove the .drop(...) call to keep it
ms_demo = con.query(query_ms_demo).df().drop('note_fdate', axis=1)

In [None]:
# Demographics for the control cohort (is_ms = 0); same output columns as the
# MS query, but follow_up / n_notes come from note_years / note_count.
# Inside the SQL, `con` is only a table alias for ctl_pats; it is unrelated to
# the Python DuckDB connection of the same name.
# Rows are kept only when age_at_first_visit >= 0 and follow_up >= 1.
# NOTE(review): the CAST(... AS INT) appears to round to the nearest integer in
# DuckDB rather than truncate -- confirm that rounded (not floored) age is intended.
query_ctl_demo = f"""
SELECT con.patientepicid,
    con.note_fdate,
    con.note_years AS follow_up,
    con.note_count AS n_notes,
    CAST(DATEDIFF('day', 
        CAST(prsn.birth_datetime AS DATE), 
        CAST(con.note_fdate AS DATE)
    ) / 365.25 AS INT) AS age_at_first_visit,
    gender_concept_id,
    race_concept_id,
    0 AS is_ms
FROM ctl_pats con
JOIN (
    SELECT person_id,
        birth_datetime,
        gender_concept_id,
        race_concept_id
    FROM person_ucsf
) prsn ON con.person_id = prsn.person_id
WHERE age_at_first_visit >= 0
    AND follow_up >= 1
"""

# note_fdate is dropped from the result; remove the .drop(...) call to keep it
ctl_demo = con.query(query_ctl_demo).df().drop('note_fdate', axis=1)

**race_concept_id**
<br>8515: Asian
<br>8516: Black or African American
<br>8522: Other
<br>8527: White
<br>8552: None, Declined, Unknown, Unknown/Declined
<br>8557: Native Hawaiian, Native Hawaiian or Other Pacific Islander, Other Pacific Islander
<br>8657: Native American or Alaska Native


**gender_concept_id**
<br>8507: Male
<br>8521: Nonbinary
<br>8532: Female
<br>8551: Unknown

# Propensity and Other Methods

In [None]:
# Stack controls first, then MS patients. reset_index() gives the combined
# frame a fresh 0..n-1 index (the matching helpers use index values to address
# NumPy arrays) and keeps each row's previous index as a column named 'index'.
concat_demo = pd.concat([ctl_demo, ms_demo]).reset_index()

In [None]:
# # optional to clear up memory
# del ctl_demo, ms_demo

In [None]:
concat_demo

In [None]:
# Work on a copy so that concat_demo itself stays untouched and can be used
# to reset psm_data later
psm_data = concat_demo.copy()

## sklearn Approach

In [None]:
def calculate_propensity_scores(df, treatment_col='is_ms'):
    """Fit a logistic regression of treatment on the covariates and score every row.

    Numeric covariates (follow_up, age_at_first_visit, n_notes) are
    standardised; gender_concept_id and race_concept_id are one-hot encoded
    with the first level of each dropped. The model is fitted with balanced
    class weights and a fixed random_state of 42.

    Parameters
    ----------
    df : DataFrame holding the five covariate columns and ``treatment_col``.
    treatment_col : name of the treatment indicator column.

    Returns
    -------
    (propensity_scores, X_transformed) : column 1 of ``predict_proba`` for each
        row of ``df`` (the second fitted class, i.e. 1 for 0/1 labels), and the
        encoded design matrix the model was fitted on.
    """
    numeric_features = ['follow_up', 'age_at_first_visit', 'n_notes']
    categorical_features = ['gender_concept_id', 'race_concept_id']
    covariate_cols = numeric_features + categorical_features

    encoder = ColumnTransformer(transformers=[
        ('num', StandardScaler(), numeric_features),
        ('cat', OneHotEncoder(drop='first'), categorical_features),
    ])
    X_transformed = encoder.fit_transform(df[covariate_cols])

    model = LogisticRegression(random_state=42, class_weight='balanced')
    model.fit(X_transformed, df[treatment_col])

    class_probabilities = model.predict_proba(X_transformed)
    return class_probabilities[:, 1], X_transformed

In [None]:
# very simple matching
def perform_matching(df, treatment_col, propensity_scores, n_neighbors=1):
    """Nearest-neighbour matching on the propensity score, with replacement.

    Every treated row is kept and paired with its ``n_neighbors`` nearest
    control rows by propensity score. There is no caliper and a control may be
    selected for several treated rows.

    Parameters
    ----------
    df : DataFrame whose index values are used directly to address
        ``propensity_scores``.
    treatment_col : name of the 0/1 treatment indicator column.
    propensity_scores : NumPy array of scores, one per row of ``df``.
    n_neighbors : number of controls to pick for each treated row.

    Returns
    -------
    DataFrame with all treated rows followed by the selected control rows.
    """
    treated_rows = df[df[treatment_col] == 1]
    control_rows = df[df[treatment_col] == 0]

    def _scores_for(frame):
        # The scores are looked up by the frame's index values, as one column
        return propensity_scores[frame.index].reshape(-1, 1)

    knn = NearestNeighbors(n_neighbors=n_neighbors)
    knn.fit(_scores_for(control_rows))
    _, neighbor_positions = knn.kneighbors(_scores_for(treated_rows))

    # kneighbors returns positions within the control subset; map them back to
    # index labels of df
    chosen_labels = control_rows.iloc[neighbor_positions.flatten()].index

    return pd.concat([treated_rows, df.loc[chosen_labels]])


# Simple caliper matching across some columns.
# No additional coding/handling of vars, so not for multivariate use
def perform_caliper_matching(df, treatment_col, propensity_scores, caliper=0.05, n_neighbors=1):
    """Caliper matching on the propensity score, with replacement.

    For every treated row, the controls whose propensity score lies within
    ``caliper`` are found and the ``n_neighbors`` closest ones are kept.
    Treated rows with no control inside the caliper are dropped. A control may
    be selected for several treated rows.

    Parameters
    ----------
    df : DataFrame whose index values are used directly to address
        ``propensity_scores``.
    treatment_col : name of the 0/1 treatment indicator column.
    propensity_scores : NumPy array of scores, one per row of ``df``.
    caliper : maximum allowed absolute difference in propensity score.
    n_neighbors : maximum number of controls kept per treated row.

    Returns
    -------
    DataFrame with the matched treated rows followed by the matched control
    rows. Each matched treated row appears once per control matched to it, so
    the i-th treated row and the i-th control row form a pair. Empty when
    there are no treated or no control rows.
    """
    treated = df[df[treatment_col] == 1]
    control = df[df[treatment_col] == 0]
    if treated.empty or control.empty:
        # Nothing can be matched; avoid fitting/querying on an empty array
        return df.iloc[0:0].copy()

    treated_scores = propensity_scores[treated.index]
    control_scores = propensity_scores[control.index]

    nbrs = NearestNeighbors(radius=caliper, n_neighbors=n_neighbors)
    nbrs.fit(control_scores.reshape(-1, 1))

    # radius_neighbors does not order its results unless asked to, so request
    # sorting; otherwise taking the first n_neighbors is not "the closest".
    _, indices = nbrs.radius_neighbors(treated_scores.reshape(-1, 1), sort_results=True)

    matched_treated_indices = []
    matched_control_indices = []
    for treated_label, ind in zip(treated.index, indices):
        closest = ind[:n_neighbors]
        if len(closest) == 0:
            continue
        matched_control_indices.extend(control.index[closest].tolist())
        # Record the treated row that actually produced these matches
        matched_treated_indices.extend([treated_label] * len(closest))

    matched_treated = df.loc[matched_treated_indices]
    matched_control = df.loc[matched_control_indices]

    return pd.concat([matched_treated, matched_control])



# This method works very well
def multi_covariate_adjusted_matching(df, treatment_col, propensity_scores, caliper=0.05, n_neighbors=1, numeric_method='scaler', bin_features=None):
    """Greedy radius matching on propensity score plus encoded covariates.

    Each row is represented by its propensity score, the standardised numeric
    covariates (follow_up, age_at_first_visit, n_notes) and the one-hot encoded
    gender/race columns. Treated rows are visited in order; each takes up to
    ``n_neighbors`` of the closest not-yet-used controls within Euclidean
    distance ``caliper`` in that combined space. Controls are used at most once
    and treated rows without an available control are dropped.

    Parameters
    ----------
    df : DataFrame with the covariate columns and ``treatment_col``. It is not
        modified.
    treatment_col : name of the 0/1 treatment indicator column.
    propensity_scores : array of scores, one per row of ``df`` in row order.
    caliper : radius of the neighbour search in the combined feature space
        (not a bound on the propensity score alone).
    n_neighbors : maximum number of controls per treated row.
    numeric_method : ``'scaler'`` or ``'binned_scaler'``.
    bin_features : columns to replace by ``(value - 1) // 10 + 1`` before
        encoding; only used with ``'binned_scaler'``.

    Returns
    -------
    DataFrame with the matched treated rows followed by the matched control
    rows, plus a ``match_group`` column that is shared by a treated row and the
    controls matched to it.

    Raises
    ------
    ValueError : if ``numeric_method`` is not one of the supported values, or if
        ``propensity_scores`` does not have one entry per row of ``df``.
    """
    numeric_features = ['follow_up', 'age_at_first_visit', 'n_notes']
    categorical_features = ['gender_concept_id', 'race_concept_id']

    # Reject unsupported methods before touching any data
    if numeric_method not in ('scaler', 'binned_scaler'):
        raise ValueError("Numeric method not valid/complete")

    # Work on a copy so binning never leaks into the caller's frame
    df = df.copy()
    if numeric_method == 'binned_scaler' and bin_features:
        for feature in bin_features:
            df[feature] = (df[feature] - 1) // 10 + 1

    # sparse_threshold=0 forces a dense array: the slice assignment and hstack
    # below do not work on a sparse matrix, which the one-hot block could
    # otherwise produce depending on the data.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', 'passthrough', numeric_features),
            ('cat', OneHotEncoder(drop='first'), categorical_features),
        ],
        sparse_threshold=0,
    )
    features = np.asarray(preprocessor.fit_transform(df), dtype=float)

    # The numeric columns come first in the transformer output
    n_numeric = len(numeric_features)
    features[:, :n_numeric] = StandardScaler().fit_transform(features[:, :n_numeric])

    scores = np.asarray(propensity_scores, dtype=float).reshape(-1, 1)
    if len(scores) != len(df):
        raise ValueError("propensity_scores must have one entry per row of df")
    features = np.hstack([scores, features])

    # Row positions (not index labels), so any DataFrame index works
    treated_pos = np.flatnonzero((df[treatment_col] == 1).to_numpy())
    control_pos = np.flatnonzero((df[treatment_col] == 0).to_numpy())
    if treated_pos.size == 0 or control_pos.size == 0:
        return df.iloc[0:0].assign(match_group=pd.Series(dtype='int64'))

    nbrs = NearestNeighbors(radius=caliper, n_neighbors=n_neighbors, metric='euclidean')
    nbrs.fit(features[control_pos])
    # sort_results=True orders each neighbourhood by increasing distance, so
    # the first entries really are the closest controls.
    _, neighborhoods = nbrs.radius_neighbors(features[treated_pos], sort_results=True)

    matched_treated_pos = []
    matched_control_pos = []
    control_groups = []  # match_group of each matched control, in order
    used_controls = set()  # positions within the control subset

    for t_pos, neighborhood in zip(treated_pos, neighborhoods):
        available = [c for c in neighborhood if c not in used_controls]
        if not available:
            continue
        chosen = available[:n_neighbors]

        group_id = len(matched_treated_pos)
        matched_treated_pos.append(t_pos)
        matched_control_pos.extend(control_pos[c] for c in chosen)
        # One group id per control actually taken, so groups stay aligned even
        # when fewer than n_neighbors controls were available.
        control_groups.extend([group_id] * len(chosen))
        used_controls.update(chosen)

    matched_treated = df.iloc[matched_treated_pos].copy()
    matched_control = df.iloc[matched_control_pos].copy()

    matched_treated['match_group'] = range(len(matched_treated))
    matched_control['match_group'] = control_groups

    return pd.concat([matched_treated, matched_control])



# This method works and is optional to use.
# You might get better matches depending on the selection criteria
def mahalanobis_matching(df, treatment_col, caliper=0.05, n_neighbors=1, numeric_method='scaler', bin_features=None):
    """Greedy matching on Mahalanobis distance between encoded covariates.

    Rows are represented by the standardised numeric covariates (follow_up,
    age_at_first_visit, n_notes) and the one-hot encoded gender/race columns.
    The covariance is estimated on all rows. Treated rows are visited in order;
    each takes up to ``n_neighbors`` of the closest not-yet-used controls whose
    distance is at most ``caliper``. Controls are used at most once and treated
    rows without an available control are dropped.

    Parameters
    ----------
    df : DataFrame with the covariate columns and ``treatment_col``. It is not
        modified.
    treatment_col : treatment indicator column; rows equal to 1 are treated and
        every other row is considered a control.
    caliper : maximum Mahalanobis distance for a valid match.
    n_neighbors : maximum number of controls per treated row.
    numeric_method : ``'scaler'`` or ``'binned_scaler'``.
    bin_features : columns to replace by ``(value - 1) // 10 + 1`` before
        encoding; only used with ``'binned_scaler'``.

    Returns
    -------
    DataFrame with the matched treated rows followed by the matched control
    rows, plus a ``match_group`` column that is shared by a treated row and the
    controls matched to it.

    Raises
    ------
    ValueError : if ``numeric_method`` is not one of the supported values.

    Notes
    -----
    The full treated-by-control distance matrix is held in memory.
    """
    numeric_features = ['follow_up', 'age_at_first_visit', 'n_notes']
    categorical_features = ['gender_concept_id', 'race_concept_id']

    # Reject unsupported methods before doing any work
    if numeric_method not in ('scaler', 'binned_scaler'):
        raise ValueError("Numeric method not valid/complete")

    # Work on a copy so binning never leaks into the caller's frame
    df = df.copy()
    if numeric_method == 'binned_scaler' and bin_features:
        for feature in bin_features:
            df[feature] = (df[feature] - 1) // 10 + 1

    # sparse_threshold=0 forces a dense array: the slice assignment, the
    # covariance fit and cdist below all need one, and the one-hot block could
    # otherwise make the output sparse depending on the data.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', 'passthrough', numeric_features),
            ('cat', OneHotEncoder(drop='first'), categorical_features),
        ],
        sparse_threshold=0,
    )
    features = np.asarray(preprocessor.fit_transform(df), dtype=float)

    # The numeric columns come first in the transformer output
    n_numeric = len(numeric_features)
    features[:, :n_numeric] = StandardScaler().fit_transform(features[:, :n_numeric])

    # Row positions (not index labels), so any DataFrame index works
    treated_mask = (df[treatment_col] == 1).to_numpy()
    treated_pos = np.flatnonzero(treated_mask)
    control_pos = np.flatnonzero(~treated_mask)

    # Inverse covariance estimated on treated and control rows together
    cov = EmpiricalCovariance().fit(features)
    distances = cdist(features[treated_pos], features[control_pos],
                      metric='mahalanobis', VI=cov.precision_)

    matched_treated_pos = []
    matched_control_pos = []
    control_groups = []  # match_group of each matched control, in order
    control_used = np.zeros(len(control_pos), dtype=bool)

    for row, t_pos in zip(distances, treated_pos):
        # Controls inside the caliper that no earlier treated row has taken
        candidates = np.flatnonzero((row <= caliper) & ~control_used)
        if candidates.size == 0:
            continue
        nearest_first = np.argsort(row[candidates], kind='stable')[:n_neighbors]
        chosen = candidates[nearest_first]

        group_id = len(matched_treated_pos)
        matched_treated_pos.append(t_pos)
        matched_control_pos.extend(control_pos[chosen])
        control_groups.extend([group_id] * len(chosen))
        control_used[chosen] = True

    matched_treated = df.iloc[matched_treated_pos].copy()
    matched_control = df.iloc[matched_control_pos].copy()

    matched_treated['match_group'] = range(len(matched_treated))
    matched_control['match_group'] = control_groups

    return pd.concat([matched_treated, matched_control])

### Propensity Score

In [None]:
# Column that flags MS patients (1) versus controls (0)
treatment_col = 'is_ms'

# Calculate propensity scores for every row of psm_data; X_transformed is the
# encoded design matrix the model was fitted on
propensity_scores, X_transformed = calculate_propensity_scores(psm_data, treatment_col)

In [None]:
# Plot the propensity score support.
# This adds a 'propensity_score' column to psm_data so seaborn can plot it.
psm_data['propensity_score'] = propensity_scores

# Density of the scores, one curve per treatment group
sns.kdeplot(data=psm_data, x='propensity_score', hue='is_ms')
plt.title('Density Plot of Propensity Scores by Treatment Status')
plt.xlabel('Propensity Score')
plt.ylabel('Density')
plt.show()

# Cumulative Distribution Function (CDF)
plt.figure(figsize=(12, 6))
sns.ecdfplot(data=psm_data, x='propensity_score', hue='is_ms')
plt.title('CDF of Propensity Scores by Treatment Status')
plt.xlabel('Propensity Score')
plt.ylabel('Cumulative Probability')
plt.show()

### Matching

In [None]:
# Reset to a fresh copy of concat_demo; this discards the 'propensity_score'
# column that the plotting cell added to psm_data
psm_data = concat_demo.copy()

In [None]:
# Perform matching: up to 25 controls per MS patient within a radius of 0.1.
# bin_features is only consulted for the binned numeric methods, so with
# numeric_method='scaler' no binning of age_at_first_visit happens here.
matched_data = multi_covariate_adjusted_matching(psm_data, treatment_col, 
                                                 propensity_scores, caliper=0.1, 
                                                 n_neighbors=25, numeric_method='scaler', 
                                                 bin_features=['age_at_first_visit'])

rtime()

In [None]:
matched_data

In [None]:
# After matching, you can verify the results:
def verify_matching_results(matched_data, treatment_col):
    """Print a short matching summary and return the de-duplicated data.

    Reports the number of unique treated and control patients (by
    ``patientepicid``), the control:treated ratio and how many repeated
    patient rows each group contains.

    Parameters
    ----------
    matched_data : DataFrame with a ``patientepicid`` column and
        ``treatment_col``.
    treatment_col : name of the 0/1 treatment indicator column.

    Returns
    -------
    ``matched_data`` with repeated (patientepicid, treatment) rows removed,
    keeping the first occurrence.
    """
    print("Matching Summary:")
    print("-" * 50)

    treated_ids = matched_data.loc[matched_data[treatment_col] == 1, 'patientepicid']
    control_ids = matched_data.loc[matched_data[treatment_col] == 0, 'patientepicid']

    # Count unique patients in each group
    n_treated = treated_ids.nunique()
    n_control = control_ids.nunique()

    print(f"Unique treated patients: {n_treated}")
    print(f"Unique control patients: {n_control}")
    if n_treated:
        print(f"Matching ratio (control:treated): {n_control/n_treated:.2f}:1")
    else:
        # No treated patients (e.g. nothing matched): the ratio is undefined
        print("Matching ratio (control:treated): n/a (no treated patients)")

    # Check for duplicates
    treated_dups = treated_ids.duplicated().sum()
    control_dups = control_ids.duplicated().sum()

    print(f"\nDuplicate treated patients: {treated_dups}")
    print(f"Duplicate control patients: {control_dups}")

    return matched_data.drop_duplicates(subset=['patientepicid', treatment_col])

clean_matched_data = verify_matching_results(matched_data, 'is_ms')

In [None]:
# Summary table
def summary_stats_table(data, covariates, categorical, treatment_col):
    """Build a covariate balance table for treated versus control rows.

    Parameters
    ----------
    data : DataFrame containing the covariates and ``treatment_col``.
    covariates : numeric columns; reported as mean, standard deviation and SMD.
    categorical : columns treated as categories; one row per level observed
        among the treated, reported as percentages and SMD.
    treatment_col : name of the 0/1 treatment indicator column.

    Returns
    -------
    DataFrame with columns Covariate, Mean_Treated, Std_Treated, Mean_Control,
    Std_Control and SMD. For categorical levels the "Mean" columns hold
    percentages and the "Std" columns are NaN.
    """
    columns = ['Covariate', 'Mean_Treated', 'Std_Treated',
               'Mean_Control', 'Std_Control', 'SMD']

    # Split once instead of re-filtering for every statistic
    treated = data[data[treatment_col] == 1]
    control = data[data[treatment_col] == 0]

    rows = []

    # numeric covariates
    for covariate in covariates:
        rows.append({
            'Covariate': covariate,
            'Mean_Treated': treated[covariate].mean(),
            'Std_Treated': treated[covariate].std(),
            'Mean_Control': control[covariate].mean(),
            'Std_Control': control[covariate].std(),
            'SMD': calculate_smd(treated, control, [covariate])[covariate],
        })

    # categorical covariates
    for covariate in categorical:
        prop_treated = treated[covariate].value_counts(normalize=True)
        prop_control = control[covariate].value_counts(normalize=True)

        smds = calculate_smd_categorical(prop_treated, prop_control, covariate)

        for category in prop_treated.index:
            rows.append({
                'Covariate': f"{covariate}_{category}",
                'Mean_Treated': prop_treated[category] * 100,
                'Std_Treated': np.nan,  # Standard deviation is not applicable for proportions
                # A level may be absent among controls; report it as 0 %
                # instead of raising KeyError.
                'Mean_Control': prop_control.get(category, 0) * 100,
                'Std_Control': np.nan,  # Standard deviation is not applicable for proportions
                'SMD': smds[category],
            })

    # Explicit columns keep the schema stable even when there are no rows
    return pd.DataFrame(rows, columns=columns)


def calculate_smd(group1, group2, var_list):
    """Return the absolute standardised mean difference for each variable.

    For every name in ``var_list`` the absolute difference of the two group
    means is divided by the square root of the average of the two group
    variances.

    Parameters
    ----------
    group1, group2 : DataFrames holding the columns named in ``var_list``.
    var_list : iterable of column names.

    Returns
    -------
    dict mapping each variable name to its SMD.
    """
    def _smd(var):
        first, second = group1[var], group2[var]
        mean_gap = abs(first.mean() - second.mean())
        pooled_variance = (first.std() ** 2 + second.std() ** 2) / 2
        return mean_gap / np.sqrt(pooled_variance)

    return {var: _smd(var) for var in var_list}


def calculate_smd_categorical(prop_treated, prop_control, covariate):
    """Return the absolute standardised difference for each treated category.

    For every category in ``prop_treated.index`` the absolute difference of the
    two proportions is divided by the square root of the average of the two
    binomial variances ``p * (1 - p)``. A category missing from either series
    counts as proportion 0.

    Parameters
    ----------
    prop_treated, prop_control : Series of proportions indexed by category.
    covariate : accepted for interface compatibility; not used in the
        calculation.

    Returns
    -------
    dict mapping each category to its SMD.
    """
    def _smd(category):
        p_treated = prop_treated.get(category, 0)
        p_control = prop_control.get(category, 0)
        variance_sum = p_treated * (1 - p_treated) + p_control * (1 - p_control)
        return abs(p_treated - p_control) / np.sqrt(variance_sum / 2)

    return {category: _smd(category) for category in prop_treated.index}


def plot_covariate_balance(data, covariates, treatment_col):
    """Show one KDE figure per covariate, overlaying treated and control rows.

    Parameters
    ----------
    data : DataFrame containing the covariates and ``treatment_col``.
    covariates : iterable of column names to plot.
    treatment_col : name of the 0/1 treatment indicator column.
    """
    # (indicator value, legend label), drawn in this order
    arms = ((1, 'Treated'), (0, 'Control'))

    for covariate in covariates:
        plt.figure(figsize=(8, 4))
        for flag, legend_label in arms:
            arm_values = data[data[treatment_col] == flag][covariate]
            sns.kdeplot(arm_values, label=legend_label)
        plt.title(f'Distribution of {covariate} by Treatment Status')
        plt.xlabel(covariate)
        plt.ylabel('Density')
        plt.legend()
        plt.show()

In [None]:
# Numeric covariates: summarised by mean, standard deviation and SMD
covariates = ['follow_up', 'n_notes', 'age_at_first_visit']
categorical = ['gender_concept_id', 'race_concept_id']   # numerics that should be treated as categoricals
# The balance table is computed on matched_data (not clean_matched_data)
summary_stats = summary_stats_table(matched_data, covariates, categorical, 'is_ms')
# Displayed rounded to 3 decimals; summary_stats itself is left unrounded
np.round(summary_stats, 3)

In [None]:
# KDE
plot_covariate_balance(matched_data, covariates, 'is_ms')

In [None]:
# boxplots: one figure per numeric covariate, split by is_ms, on the matched data
for covariate in covariates:
    plt.figure(figsize=(8, 4))
    sns.boxplot(x='is_ms', y=covariate, data=matched_data)
    plt.title(f'Box Plot of {covariate} by Treatment Status')
    plt.show()

In [None]:
# Investigate specific covariates: side-by-side density histograms of one
# column before (psm_data) and after (matched_data) matching.
# common_norm=False normalises each is_ms group separately.
# NOTE(review): log_scale=10 puts the x axis on a log scale; ages of 0 pass the
# cohort filter (age_at_first_visit >= 0) and cannot be placed on a log axis --
# confirm this is intended for age.
col_to_plot = 'age_at_first_visit'


plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
sns.histplot(data=psm_data, x=col_to_plot, hue='is_ms', element='step', stat='density', common_norm=False, log_scale=10)
plt.title('Before Matching')

plt.subplot(1, 2, 2)
sns.histplot(data=matched_data, x=col_to_plot, hue='is_ms', element='step', stat='density', common_norm=False, log_scale=10)
plt.title('After Matching')

plt.show()


### Save the output

In [None]:
# Write matched_data (not clean_matched_data) to the working directory,
# without the DataFrame index
matched_data.to_csv("matched25_cohort.csv", index=False)

rtime()