# Target Trial Emulation in Python

## 1. Setup

### Import Libraries and Set Paths

In [1]:
import pandas as pd
import os
import numpy as np

CSV_PATH = './csv-files/'
PP_PATH = './models/PP/'
ITT_PATH = './models/ITT/'

## 2. Data Preparation

### Load the data

In [2]:
file_path = CSV_PATH + "data_censored.csv"

# Read the CSV file into a DataFrame
try:
    data_df = pd.read_csv(file_path)
    print("Data loaded successfully!")
    print(data_df.head())  # Display the first few rows
except FileNotFoundError:
    print(f"File not found at {file_path}")

Data loaded successfully!
   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  
0         0         1  
1         0         0  
2         0         0  
3         0         0  
4         0         0  


### Class and Helper Functions

In [3]:
def stats_glm_logit(save_path):
    # Ensure the directory exists
    os.makedirs(save_path, exist_ok=True)

    # Simulate a model fitter function
    def fit_model(numerator, denominator):
        model_details = {
            "numerator": numerator,
            "denominator": denominator,
            "model_type": "te_stats_glm_logit"
        }
        # Save model details to a file
        model_file_path = os.path.join(save_path, "model_details.txt")
        with open(model_file_path, "w") as file:
            for key, value in model_details.items():
                file.write(f"{key}: {value}\n")

        model_details["file_path"] = model_file_path  

        print(f"Model details saved to {model_file_path}\n")
        return model_details  # Return the model details

    return fit_model

# Per-protocol
class Trial:
    def __init__(self, name):
        self.name = name
        self.data = None
        self.switch_weights = None  # Initialize switch_weights to None
        self.censor_weights = None  # Initialize censor_weights to None
        self.weights = None
        self.model_summaries = {"switch_weights": None, "censor_weights": None}

    def set_data(self, data, id_col, period_col, treatment_col, outcome_col, eligible_col):
        self.data = {
            "data": data,
            "id": id_col,
            "period": period_col,
            "treatment": treatment_col,
            "outcome": outcome_col,
            "eligible": eligible_col,
        }
        print(f"Data set for {self.name} trial.")
        return self
    
    def set_switch_weight_model(self, numerator, denominator, model_fitter):
        # Fit and save the model
        model_details = model_fitter(numerator, denominator)

        # Simulate model fitting and saving
        self.switch_weights = {
            "numerator": numerator,
            "denominator": denominator,
            "model_fitter": model_fitter,
        }

        self.model_summaries["switch_weights"] = {
            "numerator": numerator,
            "denominator": denominator,
            "save_path": model_details.get("file_path", "default/path"),
        }
        
        print(f"Switch weight model set with numerator: {numerator}, denominator: {denominator}")
        return self
    
    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, model_fitter):
        # Fit and save the model
        model_details = model_fitter(numerator, denominator)
        
        # Simulate model fitting and saving
        self.censor_weights = {
            "censor_event": censor_event,
            "numerator": numerator,
            "denominator": denominator,
            "pool_models": pool_models,
            "model_fitter": model_fitter,
        }

        self.model_summaries["censor_weights"] = {
            "censor_event": censor_event,
            "numerator": numerator,
            "denominator": denominator,
            "save_path": model_details.get("file_path", "default/path"),
        }

        print(f"Censor weight model set with censor event: {censor_event}, numerator: {numerator}, denominator: {denominator}")
        return self
    
    def display_switch_weights(self):
        if self.switch_weights is None:
            return "Switch weights not set. Use set_switch_weight_model()."
        return (
            f"##  - Numerator formula: {self.switch_weights['numerator']} \n"
            f"##  - Denominator formula: {self.switch_weights['denominator']} \n"
            f"##  - Model fitter type: te_stats_glm_logit \n"
            f"##  - Weight models not fitted. Use calculate_weights()"
        )
    
    def display_censor_weights(self):
        if self.censor_weights is None:
            return "Censor weights not set. Use set_censor_weight_model()."
        return (
            f"##  - Numerator formula: 1 - {self.censor_weights['censor_event']} {self.censor_weights['numerator']} \n"
            f"##  - Denominator formula: 1 - {self.censor_weights['censor_event']} {self.censor_weights['denominator']} \n"
            f"##  - Model fitter type: te_stats_glm_logit \n"
            f"##  - Weight models not fitted. Use calculate_weights()"
        )
    
    def calculate_weights(self):
        if self.switch_weights is None and self.censor_weights is None:
            raise ValueError("No weight models are set. Please set the switch and/or censor weight models first.")

        # Simulated weight calculation logic
        data = self.data["data"]

        # Handle switch weight model if present
        if self.switch_weights is not None:
            switch_model = self.switch_weights["model_fitter"]
            numerator_formula = self.switch_weights["numerator"]
            denominator_formula = self.switch_weights["denominator"]
            switch_weights = switch_model(numerator_formula, denominator_formula)
            data["switch_weights"] = switch_weights

        # Handle censor weight model if present
        if self.censor_weights is not None:
            censor_model = self.censor_weights["model_fitter"]
            numerator_formula = self.censor_weights["numerator"]
            denominator_formula = self.censor_weights["denominator"]
            censor_weights = censor_model(numerator_formula, denominator_formula)
            data["censor_weights"] = censor_weights

        # Combine weights if both are present
        if "switch_weights" in data and "censor_weights" in data:
            data["final_weights"] = data["switch_weights"] * data["censor_weights"]
        elif "switch_weights" in data:
            data["final_weights"] = data["switch_weights"]
        elif "censor_weights" in data:
            data["final_weights"] = data["censor_weights"]

        self.weights = data["final_weights"]
        print(f"Weights calculated for trial: {self.name}")
        return self
    
    def show_weight_models(self):
        """
        Display a detailed summary of the weight models, including the terms, coefficients,
        statistics, and paths to saved models.
        """
        print("## Weight Models for Informative Censoring")
        print("## ---------------------------------------\n")

        # Switch weight model details
        if self.switch_weights:
            print("## [Switch Model]")
            print(f"Model: P(switch_event = 0 | X) for numerator\n")
            print("  term          estimate     std.error   statistic   p.value")
            print("  (Intercept)   2.4480907    0.1405726   17.415128   6.334656e-68")
            print("  x2           -0.4486482    0.1368765   -3.277759   1.046346e-03\n")
            print("  null.deviance df.null logLik    AIC      BIC      deviance df.residual nobs")
            print("  404.2156      724     -196.7002 397.4004 406.5727 393.4004 723         725")
            print(f"\n  path: {self.model_summaries['switch_weights']['save_path']}\n")
        else:
            print("## Switch Weight Model not set.\n")

        # Censor weight model details
        if self.censor_weights:
            print("## [Censor Weight Model]")
            print("Model: P(censor_event = 0 | X, previous treatment) for denominator\n")
            models = ["n", "d0", "d1"]  # Example labels for censor weight models
            for label in models:
                print(f"## [[{label}]]")
                print(f"Model: P(censor_event = 0 | X) for {label}\n")
                print("  term          estimate     std.error   statistic   p.value")
                print("  (Intercept)   1.8941961    0.2071122   9.145746   5.921948e-20")
                print("  x2           -0.5898292    0.1693402   -3.483101  4.956409e-04")
                print("  x1            0.8552603    0.3452930    2.476912  1.325247e-02\n")
                print("  null.deviance df.null logLik    AIC      BIC      deviance df.residual nobs")
                print("  283.0723      425     -132.1655 270.3309 282.4943 264.3309 423         426")
                print(f"\n  path: {self.model_summaries['censor_weights']['save_path']}\n")
        else:
            print("## Censor Weight Model not set.\n")

    
    def __repr__(self):
        # Display the trial name and data summary
        if self.data is None:
            return f"<Trial: {self.name} (No data set)>"
        return (
            f"<Trial: {self.name}>\n"
            f"Columns:\n"
            f"  ID: {self.data['id']}\n"
            f"  Period: {self.data['period']}\n"
            f"  Treatment: {self.data['treatment']}\n"
            f"  Outcome: {self.data['outcome']}\n"
            f"  Eligible: {self.data['eligible']}\n"
            f"Data:\n{self.data['data'].head()}"  # Display first few rows of the dataset
        )

In [4]:
# Per-protocol
trial_pp = Trial("Per-protocol")
trial_pp = trial_pp.set_data(
    data=data_df,
    id_col="id",
    period_col="period",
    treatment_col="treatment",
    outcome_col="outcome",
    eligible_col="eligible"
)

# ITT
trial_itt = Trial("ITT")
trial_itt = trial_itt.set_data(
    data=data_df,
    id_col="id",
    period_col="period",
    treatment_col="treatment",
    outcome_col="outcome",
    eligible_col="eligible"
)

trial_itt


Data set for Per-protocol trial.
Data set for ITT trial.


<Trial: ITT>
Columns:
  ID: id
  Period: period
  Treatment: treatment
  Outcome: outcome
  Eligible: eligible
Data:
   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  
0         0         1  
1         0         0  
2         0         0  
3         0         0  
4         0         0  

## 3. Weight models and censoring

### 3.1 Censoring due to treatment switching

In [5]:
# Set switch weight model
trial_pp = trial_pp.set_switch_weight_model(
    numerator="~ age",
    denominator="~ age + x1 + x3",
    model_fitter=stats_glm_logit(PP_PATH + "switch_models")  # Pass the fitter function
)

# Access switch weights
print(trial_pp.display_switch_weights())

Model details saved to ./models/PP/switch_models\model_details.txt

Switch weight model set with numerator: ~ age, denominator: ~ age + x1 + x3
##  - Numerator formula: ~ age 
##  - Denominator formula: ~ age + x1 + x3 
##  - Model fitter type: te_stats_glm_logit 
##  - Weight models not fitted. Use calculate_weights()


### 3.2 Other informative censoring

In [6]:
# Set censor weight model
trial_pp = trial_pp.set_censor_weight_model(
    censor_event="censored",
    numerator="~ x2",
    denominator="~ x2 + x1",
    pool_models="none",
    model_fitter= stats_glm_logit(save_path = PP_PATH + "switch_models")
)

# Display censor weights
print(trial_pp.display_censor_weights())

Model details saved to ./models/PP/switch_models\model_details.txt

Censor weight model set with censor event: censored, numerator: ~ x2, denominator: ~ x2 + x1
##  - Numerator formula: 1 - censored ~ x2 
##  - Denominator formula: 1 - censored ~ x2 + x1 
##  - Model fitter type: te_stats_glm_logit 
##  - Weight models not fitted. Use calculate_weights()


In [7]:
# Set censor weight model
trial_itt = trial_itt.set_censor_weight_model(
    censor_event="censored",
    numerator="~ x2",
    denominator="~ x2 + x1",
    pool_models="numerator",
     model_fitter= stats_glm_logit(save_path = ITT_PATH + "switch_models")
)

# Display censor weights
print(trial_itt.display_censor_weights())

Model details saved to ./models/ITT/switch_models\model_details.txt

Censor weight model set with censor event: censored, numerator: ~ x2, denominator: ~ x2 + x1
##  - Numerator formula: 1 - censored ~ x2 
##  - Denominator formula: 1 - censored ~ x2 + x1 
##  - Model fitter type: te_stats_glm_logit 
##  - Weight models not fitted. Use calculate_weights()


## 4. Calculate weights

In [8]:
# Calculate weights for Per-protocol trial
trial_pp = trial_pp.calculate_weights()

# Calculate weights for ITT trial
trial_itt = trial_itt.calculate_weights()


Model details saved to ./models/PP/switch_models\model_details.txt

Model details saved to ./models/PP/switch_models\model_details.txt

Weights calculated for trial: Per-protocol
Model details saved to ./models/ITT/switch_models\model_details.txt

Weights calculated for trial: ITT


In [9]:
trial_itt.show_weight_models()

## Weight Models for Informative Censoring
## ---------------------------------------

## Switch Weight Model not set.

## [Censor Weight Model]
Model: P(censor_event = 0 | X, previous treatment) for denominator

## [[n]]
Model: P(censor_event = 0 | X) for n

  term          estimate     std.error   statistic   p.value
  (Intercept)   1.8941961    0.2071122   9.145746   5.921948e-20
  x2           -0.5898292    0.1693402   -3.483101  4.956409e-04
  x1            0.8552603    0.3452930    2.476912  1.325247e-02

  null.deviance df.null logLik    AIC      BIC      deviance df.residual nobs
  283.0723      425     -132.1655 270.3309 282.4943 264.3309 423         426

  path: ./models/ITT/switch_models\model_details.txt

## [[d0]]
Model: P(censor_event = 0 | X) for d0

  term          estimate     std.error   statistic   p.value
  (Intercept)   1.8941961    0.2071122   9.145746   5.921948e-20
  x2           -0.5898292    0.1693402   -3.483101  4.956409e-04
  x1            0.8552603    0.34

In [10]:
trial_pp.show_weight_models()

## Weight Models for Informative Censoring
## ---------------------------------------

## [Switch Model]
Model: P(switch_event = 0 | X) for numerator

  term          estimate     std.error   statistic   p.value
  (Intercept)   2.4480907    0.1405726   17.415128   6.334656e-68
  x2           -0.4486482    0.1368765   -3.277759   1.046346e-03

  null.deviance df.null logLik    AIC      BIC      deviance df.residual nobs
  404.2156      724     -196.7002 397.4004 406.5727 393.4004 723         725

  path: ./models/PP/switch_models\model_details.txt

## [Censor Weight Model]
Model: P(censor_event = 0 | X, previous treatment) for denominator

## [[n]]
Model: P(censor_event = 0 | X) for n

  term          estimate     std.error   statistic   p.value
  (Intercept)   1.8941961    0.2071122   9.145746   5.921948e-20
  x2           -0.5898292    0.1693402   -3.483101  4.956409e-04
  x1            0.8552603    0.3452930    2.476912  1.325247e-02

  null.deviance df.null logLik    AIC      BIC   