In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------
# Synthetic cohort: one entry per patient, seeded for repeatability
# ---------------------------------------------------------------------
np.random.seed(42)
n_patients = 100

patient_ids = np.arange(n_patients)

# NOTE: with a fixed seed the generated values depend on the order of the
# draws below, so these calls must stay in this sequence.
times = np.random.uniform(0, 100, n_patients)
treatments = np.random.binomial(1, 0.3, n_patients)
age = np.random.normal(65, 10, n_patients)
blood_pressure = np.random.normal(120, 15, n_patients)
glucose = np.random.normal(100, 20, n_patients)

# Covariate matrix with one column each for age, blood pressure, glucose
covariates = np.column_stack((age, blood_pressure, glucose))

# Z-score each column (np.std default, i.e. population std with ddof=0)
column_means = covariates.mean(axis=0)
column_stds = covariates.std(axis=0)
covariates_std = (covariates - column_means) / column_stds

# Covariance of the standardized covariates (columns are the variables)
# and its Moore-Penrose pseudo-inverse for the distance computation
cov_matrix = np.cov(covariates_std, rowvar=False)
cov_inv = np.linalg.pinv(cov_matrix)

def mahalanobis_distance(x, y, cov_inv):
    """Calculate the Mahalanobis distance between two points.

    Evaluates ``sqrt((x - y)^T . cov_inv . (x - y))``.

    Args:
        x: First point, a 1-D array-like of length ``d``.
        y: Second point, a 1-D array-like of length ``d``.
        cov_inv: ``(d, d)`` inverse (or pseudo-inverse) of the covariance
            matrix that defines the metric.

    Returns:
        The distance as a float. It is never negative; NaN inputs still
        propagate to a NaN result.
    """
    # Coerce to arrays so plain lists/tuples work as well as ndarrays.
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    quad_form = diff.dot(cov_inv).dot(diff)
    # Floating-point round-off can push a (mathematically non-negative)
    # quadratic form slightly below zero, which would make sqrt return NaN.
    return np.sqrt(max(quad_form, 0.0))

# --- Greedy 1:1 nearest-neighbour matching ----------------------------
matched_pairs = []       # (treated id, control id, distance) tuples
used_controls = set()    # ids of controls already assigned to a pair

# Reorder every per-patient array into ascending time order
sort_idx = np.argsort(times)
times, treatments, patient_ids, covariates_std, covariates = (
    arr[sort_idx]
    for arr in (times, treatments, patient_ids, covariates_std, covariates)
)

# Positions (in the sorted arrays) of the treated patients
treated_indices = np.flatnonzero(treatments == 1)

for treated_idx in treated_indices:
    treated_time = times[treated_idx]
    treated_covs = covariates_std[treated_idx]

    # Candidate controls: untreated, time not later than the treated
    # patient's time, and not already consumed by an earlier pair
    is_unused = np.isin(patient_ids, list(used_controls), invert=True)
    control_mask = (treatments == 0) & (times <= treated_time) & is_unused
    control_indices = np.flatnonzero(control_mask)
    if control_indices.size == 0:
        continue

    # Distance from the treated patient to each candidate
    distances = np.array([
        mahalanobis_distance(treated_covs, covariates_std[idx], cov_inv)
        for idx in control_indices
    ])

    # Closest candidate (first one on ties, per np.argmin)
    nearest = np.argmin(distances)
    best_control_idx = control_indices[nearest]
    min_distance = distances[nearest]

    # Accept the pair only when it falls inside the caliper of 2.0
    if min_distance <= 2.0:
        control_id = patient_ids[best_control_idx]
        matched_pairs.append((patient_ids[treated_idx], control_id, min_distance))
        used_controls.add(control_id)

# Split the matched pairs into treated-side and control-side patient ids
matched_treated_ids = [treated_id for treated_id, _, _ in matched_pairs]
matched_control_ids = [control_id for _, control_id, _ in matched_pairs]

# Boolean masks, aligned with patient_ids, flagging the matched patients
treated_mask, control_mask = (
    np.isin(patient_ids, id_list)
    for id_list in (matched_treated_ids, matched_control_ids)
)

# Create visualization of covariate distributions: one panel per covariate,
# overlaying the matched treated and matched control patients
plt.figure(figsize=(15, 5))
variables = ['Age', 'Blood Pressure', 'Glucose']

for i, (var_name, var_data) in enumerate(zip(variables, covariates.T)):
    plt.subplot(1, 3, i+1)

    treated_values = var_data[treated_mask]
    control_values = var_data[control_mask]

    # Both groups share one set of bin edges. Letting each hist() call pick
    # its own 15 bins gives the groups different bar widths and positions,
    # so the overlaid counts would not be directly comparable.
    bin_edges = np.histogram_bin_edges(
        np.concatenate([treated_values, control_values]), bins=15
    )

    # Plot treated group
    plt.hist(treated_values, alpha=0.5, bins=bin_edges,
             label='Treated', color='blue')

    # Plot control group
    plt.hist(control_values, alpha=0.5, bins=bin_edges,
             label='Control', color='orange')

    plt.title(f'{var_name} Distribution')
    plt.xlabel(var_name)
    plt.ylabel('Count')
    plt.legend()

plt.tight_layout()
plt.show()

# Histogram of the matched-pair distances, with the mean marked by a
# dashed vertical line
fig, ax = plt.subplots(figsize=(10, 5))
distances = [pair_distance for _, _, pair_distance in matched_pairs]
mean_distance = np.mean(distances)

ax.hist(distances, bins=20, color='green', alpha=0.6)
ax.axvline(mean_distance, color='red', linestyle='--',
           label=f'Mean Distance: {mean_distance:.2f}')
ax.set_title('Distribution of Matching Distances')
ax.set_xlabel('Mahalanobis Distance')
ax.set_ylabel('Count')
ax.legend()
plt.show()

# Timeline view: every matched patient is a point at (time, patient id)
fig, ax = plt.subplots(figsize=(10, 5))
for group_mask, group_label, group_color in (
    (treated_mask, 'Treated', 'blue'),
    (control_mask, 'Control', 'orange'),
):
    ax.scatter(times[group_mask], patient_ids[group_mask],
               label=group_label, alpha=0.6, color=group_color)

# Join the two members of each matched pair with a faint gray segment
for treated_id, control_id, _ in matched_pairs:
    # Position of each member inside the per-patient arrays
    treated_idx = np.flatnonzero(patient_ids == treated_id)[0]
    control_idx = np.flatnonzero(patient_ids == control_id)[0]
    segment_times = [times[treated_idx], times[control_idx]]
    segment_ids = [patient_ids[treated_idx], patient_ids[control_idx]]
    ax.plot(segment_times, segment_ids, 'gray', alpha=0.2)

ax.set_title('Time Distribution of Matched Pairs')
ax.set_xlabel('Time')
ax.set_ylabel('Patient ID')
ax.legend()
plt.show()

# Print summary statistics
# Read the distances straight from matched_pairs rather than from the
# module-level `distances` name: the matching loop reuses that name for its
# per-iteration candidate distances, so it only holds the pair distances if
# the distance-plot code has run first.
pair_distances = [pair[2] for pair in matched_pairs]
# np.mean([]) warns and returns NaN; yield the NaN directly when nothing matched.
avg_distance = np.mean(pair_distances) if pair_distances else float('nan')

print("\nMatching Summary:")
print(f"Total number of treated patients: {np.sum(treatments == 1)}")
print(f"Number of successful matches: {len(matched_pairs)}")
print(f"Average matching distance: {avg_distance:.3f}")

print("\nCovariate Balance:")
print("Variable     | Treated Mean | Control Mean | Std. Diff.")
print("-" * 55)
if matched_pairs:
    for i, name in enumerate(variables):
        treated_values = covariates[treated_mask, i]
        control_values = covariates[control_mask, i]
        t_mean = np.mean(treated_values)
        c_mean = np.mean(control_values)
        # Pooled standard deviation from the two group variances
        # (np.var default, i.e. population variance with ddof=0)
        pooled_std = np.sqrt((np.var(treated_values) + np.var(control_values)) / 2)
        # The standardized difference is undefined when neither group has
        # any spread (e.g. a single matched pair); report NaN instead of
        # dividing by zero.
        std_diff = (t_mean - c_mean) / pooled_std if pooled_std > 0 else float('nan')
        print(f"{name:12s} | {t_mean:11.2f} | {c_mean:11.2f} | {std_diff:10.2f}")
else:
    # Means and variances of empty groups would only produce NaN and warnings.
    print("No matched pairs; covariate balance is not available.")