In [13]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import mahalanobis
from pulp import LpProblem, LpVariable, LpMinimize, lpSum
import seaborn as sns
import matplotlib.pyplot as plt

In [14]:
# Define functions
def create_risk_sets(data, treatment_col, time_col):
    """Build, for every treated row, the list of control rows it may be matched to.

    A control row (``treatment_col == 0``) belongs to the risk set of a treated
    row (``treatment_col == 1``) when the control's ``time_col`` value is greater
    than or equal to the treated row's ``time_col`` value.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per patient.
    treatment_col : str
        Name of the treatment indicator column (1 = treated, 0 = control).
    time_col : str
        Name of the column holding the time compared between rows.

    Returns
    -------
    dict
        Maps the index label of each treated row to a list of index labels of
        its eligible control rows (possibly empty).
    """
    times = data[time_col]
    # The control mask does not depend on the treated row, so compute it once.
    is_control = data[treatment_col] == 0
    treated_times = times[data[treatment_col] == 1]

    risk_sets = {}
    # Read times straight from the column instead of via ``iterrows()``: a row
    # Series is upcast to a common dtype, which can turn an integer time into a
    # float before it is compared.
    for treated_id, t in treated_times.items():
        eligible = (is_control & (times >= t)).to_numpy()
        risk_sets[treated_id] = data.index[eligible].tolist()
    return risk_sets

In [15]:
def calculate_mahalanobis_matrix(data, cov_matrix):
    """Compute the Mahalanobis distance between every pair of rows in ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric covariates, one row per patient and one column per covariate.
    cov_matrix : array-like
        Square covariance matrix of the covariates, in the same column order as
        ``data``. It is inverted once with ``np.linalg.inv``.

    Returns
    -------
    pandas.DataFrame
        Square float64 frame indexed and labelled by ``data.index``; entry
        ``[i, j]`` is ``sqrt((x_i - x_j)^T  inv(cov_matrix)  (x_i - x_j))``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``cov_matrix`` is singular or not square.

    Notes
    -----
    All pairwise differences are held in one ``(n, n, k)`` array, so memory
    grows quadratically with the number of rows.
    """
    values = data.to_numpy(dtype=float)
    # Invert once up front; the inverse is the same for every pair.
    inv_cov = np.linalg.inv(np.asarray(cov_matrix, dtype=float))

    # diffs[i, j, :] = x_i - x_j
    diffs = values[:, np.newaxis, :] - values[np.newaxis, :, :]
    squared = np.einsum("ijk,kl,ijl->ij", diffs, inv_cov, diffs)
    # Clamp at zero so a slightly negative quadratic form (rounding) yields 0, not NaN.
    distances = np.sqrt(np.maximum(squared, 0.0))

    return pd.DataFrame(distances, index=data.index, columns=data.index)

In [16]:
def optimal_balanced_matching(distance_matrix, risk_sets):
    """Match each treated unit to one control by minimizing total distance.

    A binary program is built with one variable per (treated, control) pair
    allowed by ``risk_sets``. Every treated unit must receive exactly one
    control, every control may be used at most once, and the objective is the
    sum of ``distance_matrix.loc[treated, control]`` over the selected pairs.

    Parameters
    ----------
    distance_matrix : pandas.DataFrame
        Pairwise distances, addressable as ``distance_matrix.loc[t, c]``.
    risk_sets : dict
        Maps each treated label to an iterable of eligible control labels.
        Repeated control labels within one risk set are counted once.

    Returns
    -------
    dict
        Maps treated label to the control label selected by the solver.

    Warns
    -----
    RuntimeWarning
        If the solver does not report an optimal solution (for example when a
        treated unit has no eligible control); the returned dict may then be
        empty, partial, or meaningless.
    """
    # Local imports: these names are needed only here and are not among the
    # pulp names imported at the top of the notebook.
    import warnings
    from pulp import LpStatus, LpStatusOptimal

    # Drop repeated controls so a pair never enters a constraint twice.
    eligible = {t: list(dict.fromkeys(controls)) for t, controls in risk_sets.items()}

    prob = LpProblem("Optimal_Matching", LpMinimize)

    match_vars = {}
    treated_by_control = {}  # reverse index: control -> treated units that may use it
    for t, controls in eligible.items():
        for c in controls:
            match_vars[t, c] = LpVariable(f"match_{t}_{c}", 0, 1, cat="Binary")
            treated_by_control.setdefault(c, []).append(t)

    # Objective: total distance of the selected pairs.
    prob += lpSum(var * float(distance_matrix.loc[t, c]) for (t, c), var in match_vars.items())

    # Each treated unit gets exactly one control.
    for t, controls in eligible.items():
        prob += lpSum(match_vars[t, c] for c in controls) == 1

    # Each control is used at most once.
    for c, treated_units in treated_by_control.items():
        prob += lpSum(match_vars[t, c] for t in treated_units) <= 1

    prob.solve()

    if prob.status != LpStatusOptimal:
        warnings.warn(
            "Matching problem was not solved to optimality "
            f"(status: {LpStatus.get(prob.status, prob.status)}); "
            "the returned matches may be incomplete or invalid.",
            RuntimeWarning,
            stacklevel=2,
        )

    matches = {}
    for (t, c), var in match_vars.items():
        value = var.value()
        # Solvers return floats; threshold at 0.5 instead of testing == 1 exactly.
        if value is not None and value > 0.5:
            matches[t] = c
    return matches


In [17]:

def sensitivity_analysis(matches, bias_factor=1.0):
    """Assign every matched pair the score ``1 / (1 + bias_factor)``.

    Parameters
    ----------
    matches : dict
        Maps treated label to matched control label.
    bias_factor : float, optional
        Value plugged into ``1 / (1 + bias_factor)``; defaults to 1.0.

    Returns
    -------
    dict
        Keys are ``(treated, control)`` tuples, one per entry of ``matches``;
        every value is ``1 / (1 + bias_factor)``.
    """
    scores = {}
    for treated_id, control_id in matches.items():
        # Evaluated per pair, so an empty ``matches`` never performs the division.
        scores[(treated_id, control_id)] = 1 / (1 + bias_factor)
    return scores
    

In [18]:
def visualize_matching(data, matches, treatment_col="Treatment",
                       x_col="Covariate1", y_col="Covariate2",
                       x_label="Covariate 1", y_label="Covariate 2"):
    """Scatter-plot treated and control rows and connect each matched pair.

    Treated rows (``treatment_col == 1``) are drawn in red and control rows
    (``treatment_col == 0``) in blue. A thin dashed black line joins every
    (treated, control) pair in ``matches``. The figure is shown with
    ``plt.show()``; nothing is returned.

    Parameters
    ----------
    data : pandas.DataFrame
        Patient table whose index labels are the keys and values of ``matches``.
    matches : dict
        Maps treated index label to matched control index label.
    treatment_col : str, optional
        Treatment indicator column. Defaults to ``"Treatment"``.
    x_col, y_col : str, optional
        Columns plotted on the x and y axes. Default to ``"Covariate1"`` and
        ``"Covariate2"``.
    x_label, y_label : str, optional
        Axis labels. Default to ``"Covariate 1"`` and ``"Covariate 2"``.
    """
    # Draw on an explicit Axes rather than relying on pyplot's current-axes state.
    _, ax = plt.subplots(figsize=(8, 6))

    treated = data[data[treatment_col] == 1]
    control = data[data[treatment_col] == 0]
    sns.scatterplot(x=treated[x_col], y=treated[y_col], color='red', label='Treated', ax=ax)
    sns.scatterplot(x=control[x_col], y=control[y_col], color='blue', label='Control', ax=ax)

    for t, c in matches.items():
        ax.plot([data.loc[t, x_col], data.loc[c, x_col]],
                [data.loc[t, y_col], data.loc[c, y_col]], 'k--', lw=0.5)

    # Set labels after plotting so they replace the ones seaborn derives from the Series names.
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title("Matched Pairs Visualization")
    ax.legend()
    plt.show()

In [19]:
# Load and preprocess data
# Example cohort, one record per patient, in column order:
#   Patient_ID  - example patient IDs
#   Treatment   - 1 = Treated, 0 = Control
#   Time        - treatment or observation time
#   Covariate1, Covariate2
data = pd.DataFrame(
    [
        (1, 1, 5, 2.3, 3.5),
        (2, 0, 6, 2.1, 3.6),
        (3, 1, 8, 2.5, 3.4),
        (4, 0, 9, 2.4, 3.5),
        (5, 1, 4, 2.2, 3.3),
        (6, 0, 7, 2.0, 3.2),
        (7, 1, 10, 2.6, 3.7),
        (8, 0, 12, 2.7, 3.8),
    ],
    columns=["Patient_ID", "Treatment", "Time", "Covariate1", "Covariate2"],
)

In [20]:

# Risk sets: for each treated row, the control rows whose Time is >= the treated row's Time
risk_sets = create_risk_sets(data=data, treatment_col="Treatment", time_col="Time")

# 2x2 covariance of the two covariates, then all pairwise Mahalanobis distances
cov_matrix = np.cov(data["Covariate1"], data["Covariate2"])
distance_matrix = calculate_mahalanobis_matrix(
    data=data[["Covariate1", "Covariate2"]],
    cov_matrix=cov_matrix,
)

In [21]:

# Perform matching: pair each treated patient with one control from its risk set
matches = optimal_balanced_matching(
    distance_matrix=distance_matrix,
    risk_sets=risk_sets,
)
    

In [22]:

# Perform sensitivity analysis (default bias_factor)
sensitivity_scores = sensitivity_analysis(matches=matches)


# Visualize results: scatter plot with matched pairs joined by dashed lines
visualize_matching(data=data, matches=matches)
