In [None]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import mahalanobis
from sklearn.preprocessing import StandardScaler
import warnings
import matplotlib.pyplot as plt

%matplotlib inline


In [None]:
from typing import List, Tuple, Optional

class RiskSetMatcher:
    """
    A class to perform Risk Set Matching for observational studies.

    Each treated patient (``treatment_column == 1``) is paired with the
    untreated patient (``treatment_column == 0``) whose ``time_column`` value
    is at or after the treated patient's time, and which is closest in
    Mahalanobis distance on the standardized covariates.

    Note: only rows with ``treatment_column == 0`` are used as controls;
    patients treated later than the index patient are not considered.
    """

    def __init__(self,
                 time_column: str = 'time',
                 treatment_column: str = 'treatment',
                 patient_id_column: str = 'patient_id',
                 covariate_columns: Optional[List[str]] = None):
        """
        Initialize the RiskSetMatcher.

        Args:
            time_column: Column holding each patient's time; used to build
                the risk set and to order treated patients.
            treatment_column: Column with 1 for treated and 0 for control rows.
            patient_id_column: Column identifying patients; these ids are
                returned in the matched pairs.
            covariate_columns: Columns to match on. If None, ``fit`` uses every
                numeric column except the time, treatment and id columns.
        """
        self.time_column = time_column
        self.treatment_column = treatment_column
        self.patient_id_column = patient_id_column
        self.covariate_columns = covariate_columns
        self.scaler = StandardScaler()
        # Inverse covariance matrix of the standardized covariates; None until fit().
        self.cov_inv = None

    def fit(self, data: pd.DataFrame) -> None:
        """
        Fit the scaler and the inverse covariance matrix used for matching.

        Args:
            data: Frame containing the covariate columns.

        Raises:
            ValueError: If no covariate columns are specified or detected.
        """
        if self.covariate_columns is None:
            # Automatically detect numeric columns if not specified
            excluded = {self.time_column, self.treatment_column, self.patient_id_column}
            self.covariate_columns = [
                col for col in data.select_dtypes(include=[np.number]).columns
                if col not in excluded
            ]
        if not self.covariate_columns:
            raise ValueError("No covariate columns available for matching.")

        # Standardize covariates
        self.scaler.fit(data[self.covariate_columns])
        scaled_data = self.scaler.transform(data[self.covariate_columns])

        # np.cov returns a 0-d array when there is a single covariate, which
        # np.linalg.inv/pinv reject, so force a 2-D matrix.
        cov_matrix = np.atleast_2d(np.cov(scaled_data, rowvar=False))
        try:
            self.cov_inv = np.linalg.inv(cov_matrix)
        except np.linalg.LinAlgError:
            warnings.warn("Covariance matrix is singular. Using pseudoinverse instead.", stacklevel=2)
            self.cov_inv = np.linalg.pinv(cov_matrix)

    def _calculate_distances(self, treated_patient: pd.Series, potential_controls: pd.DataFrame) -> np.ndarray:
        """
        Mahalanobis distances between one treated patient and each potential control.

        Covariates are standardized with the fitted scaler and compared using
        ``self.cov_inv``.

        Returns:
            1-D array with one distance per row of ``potential_controls``.
        """
        # to_frame().T keeps the column names the scaler was fitted with.
        treated_covs = self.scaler.transform(
            treated_patient[self.covariate_columns].to_frame().T
        )
        control_covs = self.scaler.transform(
            potential_controls[self.covariate_columns]
        )

        diff = control_covs - treated_covs
        squared = np.sum(diff.dot(self.cov_inv) * diff, axis=1)
        # Round-off (notably with a pseudoinverse) can make the quadratic form
        # slightly negative; clip so sqrt never returns NaN.
        return np.sqrt(np.clip(squared, 0.0, None))

    def match(self, data: pd.DataFrame, caliper: Optional[float] = None,
              replace: bool = True) -> List[Tuple]:
        """
        Perform risk set matching on the data.

        Treated patients are processed in increasing ``time_column`` order.
        For each one, eligible controls are rows with ``treatment_column == 0``
        whose time is >= the treated patient's time. This assumes a control's
        time marks the end of its follow-up, so it is still in the risk set.
        The eligible control with the smallest Mahalanobis distance is chosen.

        Args:
            data: Frame with time, treatment, id and covariate columns.
            caliper: Optional maximum distance. Treated patients with no
                eligible control within it are skipped with a warning.
            replace: If True (default), a control may be matched to several
                treated patients; if False, each control is used at most once.

        Returns:
            List of ``(treated_id, control_id, distance)`` tuples.
        """
        if self.cov_inv is None:
            self.fit(data)

        pid = self.patient_id_column
        time_col = self.time_column

        # Sort by time so treated patients are matched in chronological order
        # (earlier patients get first pick when matching without replacement).
        data = data.sort_values(time_col)
        treated_patients = data[data[self.treatment_column] == 1]
        controls = data[data[self.treatment_column] == 0]

        matched_pairs = []
        used_controls = set()

        # Ids and times are read from their own columns: iterrows() upcasts
        # each row to a common dtype, which can turn integer ids into floats.
        for (_, treated), treated_id, treated_time in zip(
            treated_patients.iterrows(),
            treated_patients[pid].to_numpy(),
            treated_patients[time_col].to_numpy(),
        ):
            # Risk set: controls whose time has not ended before treated_time.
            eligible = controls[time_col] >= treated_time
            if not replace:
                eligible &= ~controls[pid].isin(used_controls)
            potential_controls = controls[eligible]

            if potential_controls.empty:
                warnings.warn(f"No eligible controls found for treated patient {treated_id}", stacklevel=2)
                continue

            distances = self._calculate_distances(treated, potential_controls)

            # Apply caliper if specified
            if caliper is not None:
                within = distances <= caliper
                if not within.any():
                    warnings.warn(f"No controls within caliper for treated patient {treated_id}", stacklevel=2)
                    continue
                potential_controls = potential_controls.iloc[within]
                distances = distances[within]

            best = int(np.argmin(distances))
            control_id = potential_controls[pid].iloc[best]
            matched_pairs.append((treated_id, control_id, distances[best]))
            used_controls.add(control_id)

        return matched_pairs

    def _rows_for_ids(self, data: pd.DataFrame, ids: List) -> pd.DataFrame:
        """Return rows of ``data`` whose id is in ``ids``, each repeated once per occurrence in ``ids``."""
        counts = pd.Series(list(ids), dtype=object).value_counts()
        # Ids absent from `ids` map to NaN -> 0 repeats; ids absent from `data` are ignored.
        repeats = data[self.patient_id_column].map(counts).fillna(0).astype(int).to_numpy()
        return data.iloc[np.repeat(np.arange(len(data)), repeats)]

    def assess_balance(self, data: pd.DataFrame, matched_pairs: List[Tuple]) -> pd.DataFrame:
        """
        Assess covariate balance between matched treated and control groups.

        Every matched pair contributes its treated row and its control row. A
        control matched k times (matching with replacement) is therefore
        counted k times.

        Args:
            data: Frame containing the id and covariate columns.
            matched_pairs: Output of ``match``: ``(treated_id, control_id, distance)``.

        Returns:
            One row per covariate with the treated/control mean and standard
            deviation, and the standardized mean difference
            ``(treated_mean - control_mean) / sqrt((treated_std**2 + control_std**2) / 2)``.
        """
        treated_data = self._rows_for_ids(data, [pair[0] for pair in matched_pairs])
        control_data = self._rows_for_ids(data, [pair[1] for pair in matched_pairs])

        balance_stats = []

        for col in self.covariate_columns:
            treated_mean = treated_data[col].mean()
            treated_std = treated_data[col].std()
            control_mean = control_data[col].mean()
            control_std = control_data[col].std()

            # Compute standardized mean difference
            pooled_std = np.sqrt((treated_std ** 2 + control_std ** 2) / 2)
            mean_diff = treated_mean - control_mean
            if pooled_std > 0:
                std_diff = mean_diff / pooled_std
            elif pooled_std == 0:
                # Both groups are constant: equal means are perfectly balanced.
                std_diff = 0.0 if mean_diff == 0 else np.copysign(np.inf, mean_diff)
            else:
                # NaN SD (fewer than two rows in a group): SMD is undefined.
                std_diff = np.nan

            balance_stats.append({
                'covariate': col,
                'treated_mean': treated_mean,
                'control_mean': control_mean,
                'std_diff': std_diff,
                'treated_std': treated_std,
                'control_std': control_std
            })

        return pd.DataFrame(balance_stats)
