{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "overview",
   "metadata": {},
   "source": [
    "# Trial Analysis Notebook\n",
    "\n",
    "This notebook demonstrates a simplified analysis pipeline for trial data. The workflow includes:\n",
    "\n",
    "- Setting up directories\n",
    "- Creating dummy (simulated) trial data\n",
    "- Defining trial objects and configuring weight models\n",
    "- Calculating weights for the trial data\n",
    "- Fitting an outcome model (using a Cox proportional hazards model)\n",
    "- Expanding the trial data\n",
    "- Fitting a marginal structural model (MSM)\n",
    "- Predicting survival and plotting survival differences\n",
    "\n",
    "Each section is broken into separate cells with detailed documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "import-libraries",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import statsmodels.api as sm\n",
    "from lifelines import CoxPHFitter\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Enable inline plotting for notebooks\n",
    "%matplotlib inline\n",
    "\n",
    "print('Libraries imported successfully!')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "setup-directories",
   "metadata": {},
   "source": [
    "## 1. Set-up Directories\n",
    "\n",
    "We set up two directories for the per-protocol trial (`trial_pp`) and the intention-to-treat trial (`trial_itt`). These directories are created in a temporary directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-setup-directories",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 1. Set-up Directories\n",
    "# ---------------------------\n",
    "trial_pp_dir = os.path.join(os.getenv(\"TMPDIR\", \"/tmp\"), \"trial_pp\")\n",
    "os.makedirs(trial_pp_dir, exist_ok=True)\n",
    "trial_itt_dir = os.path.join(os.getenv(\"TMPDIR\", \"/tmp\"), \"trial_itt\")\n",
    "os.makedirs(trial_itt_dir, exist_ok=True)\n",
    "\n",
    "print('Directories created:')\n",
    "print('Per-protocol trial directory:', trial_pp_dir)\n",
    "print('ITT trial directory:', trial_itt_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "create-dummy-data",
   "metadata": {},
   "source": [
    "## 2. Create Dummy Data\n",
    "\n",
    "Here we create a dummy dataset named `data_censored` to simulate trial data. This dataset contains various columns such as `id`, `period`, `treatment`, `outcome`, `eligible`, and several covariates. The data is generated using NumPy's random functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-create-dummy-data",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 2. Create Dummy Data (data_censored)\n",
    "# ---------------------------\n",
    "np.random.seed(42)\n",
    "n_total = 1000\n",
    "n_patients = n_total // 10  # assume 10 periods per patient\n",
    "\n",
    "data_censored = pd.DataFrame({\n",
    "    \"id\": np.repeat(np.arange(1, n_patients + 1), 10),\n",
    "    \"period\": np.tile(np.arange(1, 11), n_patients),\n",
    "    \"treatment\": np.random.binomial(1, 0.5, n_total),\n",
    "    \"outcome\": np.random.binomial(1, 0.1, n_total),\n",
    "    \"eligible\": np.random.binomial(1, 0.8, n_total),\n",
    "    \"age\": np.random.normal(50, 10, n_total),\n",
    "    \"x1\": np.random.normal(0, 1, n_total),\n",
    "    \"x3\": np.random.normal(0, 1, n_total),\n",
    "    \"x2\": np.random.normal(0, 1, n_total),\n",
    "    \"censored\": np.random.binomial(1, 0.05, n_total)\n",
    "})\n",
    "print('Dummy data head:')\n",
    "print(data_censored.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "define-trial-objects",
   "metadata": {},
   "source": [
    "## 3. Define Trial Objects\n",
    "\n",
    "We define two trial objects, one for a per-protocol analysis (`trial_pp`) and one for an intention-to-treat analysis (`trial_itt`). In this example, we store the data and estimand information in dictionaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-define-trial-objects",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 3. Define “Trial” Objects\n",
    "# ---------------------------\n",
    "trial_pp = {\"estimand\": \"PP\", \"data\": data_censored.copy()}\n",
    "trial_itt = {\"estimand\": \"ITT\", \"data\": data_censored.copy()}\n",
    "\n",
    "print('Trial objects created with estimands:')\n",
    "print('trial_pp:', trial_pp['estimand'])\n",
    "print('trial_itt:', trial_itt['estimand'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "set-weight-models",
   "metadata": {},
   "source": [
    "## 4. Set Weight Models\n",
    "\n",
    "We define placeholder functions for setting up switch weight models and censor weight models. These functions store formulas and the save path in the trial object. Later on, you might replace these placeholders with the actual modeling steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-set-weight-models",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 4. Set Weight Models (Placeholders)\n",
    "# ---------------------------\n",
    "def set_switch_weight_model(trial, numerator_formula, denominator_formula, save_path):\n",
    "    # Placeholder: store formulas and save path\n",
    "    trial['switch_weight_model'] = {\n",
    "        \"numerator_formula\": numerator_formula,\n",
    "        \"denominator_formula\": denominator_formula,\n",
    "        \"save_path\": save_path\n",
    "    }\n",
    "    return trial\n",
    "\n",
    "def set_censor_weight_model(trial, censor_event, numerator_formula, denominator_formula, pool_models, save_path):\n",
    "    trial['censor_weight_model'] = {\n",
    "        \"censor_event\": censor_event,\n",
    "        \"numerator_formula\": numerator_formula,\n",
    "        \"denominator_formula\": denominator_formula,\n",
    "        \"pool_models\": pool_models,\n",
    "        \"save_path\": save_path\n",
    "    }\n",
    "    return trial\n",
    "\n",
    "# For Per-protocol trial (trial_pp)\n",
    "trial_pp = set_switch_weight_model(\n",
    "    trial_pp,\n",
    "    numerator_formula=\"age\",\n",
    "    denominator_formula=\"age + x1 + x3\",\n",
    "    save_path=os.path.join(trial_pp_dir, \"switch_models\")\n",
    ")\n",
    "print(\"trial_pp switch weight model:\", trial_pp.get(\"switch_weight_model\"))\n",
    "\n",
    "trial_pp = set_censor_weight_model(\n",
    "    trial_pp,\n",
    "    censor_event=\"censored\",\n",
    "    numerator_formula=\"x2\",\n",
    "    denominator_formula=\"x2 + x1\",\n",
    "    pool_models=\"none\",\n",
    "    save_path=os.path.join(trial_pp_dir, \"switch_models\")\n",
    ")\n",
    "print(\"trial_pp censor weight model:\", trial_pp.get(\"censor_weight_model\"))\n",
    "\n",
    "# For ITT trial (trial_itt)\n",
    "trial_itt = set_censor_weight_model(\n",
    "    trial_itt,\n",
    "    censor_event=\"censored\",\n",
    "    numerator_formula=\"x2\",\n",
    "    denominator_formula=\"x2 + x1\",\n",
    "    pool_models=\"numerator\",\n",
    "    save_path=os.path.join(trial_itt_dir, \"switch_models\")\n",
    ")\n",
    "print(\"trial_itt censor weight model:\", trial_itt.get(\"censor_weight_model\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "calculate-weights",
   "metadata": {},
   "source": [
    "## 5. Calculate Weights (Simulation)\n",
    "\n",
    "In real applications, weights would be estimated by fitting the specified models. Here, we simply simulate weight values and add them as new columns to our trial data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-calculate-weights",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 5. Calculate Weights (Simulation)\n",
    "# ---------------------------\n",
    "def calculate_weights(trial):\n",
    "    # Dummy weight estimation\n",
    "    trial['data']['weight'] = np.random.uniform(0.8, 1.2, len(trial['data']))\n",
    "    trial['data']['sample_weight'] = np.random.uniform(0.9, 1.1, len(trial['data']))\n",
    "    return trial\n",
    "\n",
    "trial_pp = calculate_weights(trial_pp)\n",
    "trial_itt = calculate_weights(trial_itt)\n",
    "\n",
    "print(\"trial_pp weights added:\", trial_pp['data'][['weight', 'sample_weight']].head())\n",
    "print(\"trial_itt weights added:\", trial_itt['data'][['weight', 'sample_weight']].head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "set-outcome-model",
   "metadata": {},
   "source": [
    "## 6. Set Outcome Model\n",
    "\n",
    "We now define a function to fit an outcome model using a Cox proportional hazards model from the `lifelines` package. This model uses `period` as the duration variable and `outcome` as the event indicator. Optionally, adjustment terms can be added."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-set-outcome-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 6. Set Outcome Model\n",
    "# ---------------------------\n",
    "def set_outcome_model(trial, adjustment_terms=None):\n",
    "    # Fit a Cox proportional hazards model\n",
    "    cph = CoxPHFitter()\n",
    "    df = trial['data']\n",
    "    \n",
    "    # Use 'treatment' as the main predictor; add adjustments if provided\n",
    "    predictors = ['treatment']\n",
    "    if adjustment_terms:\n",
    "        # Simplified parsing of adjustment formula (e.g. \"~ x2\")\n",
    "        adjustments = adjustment_terms.replace(\"~\", \"\").strip().split(\" + \")\n",
    "        predictors += adjustments\n",
    "    \n",
    "    # Prepare the model dataframe\n",
    "    cols = ['period', 'outcome'] + predictors\n",
    "    df_model = df[cols].dropna()\n",
    "    \n",
    "    try:\n",
    "        cph.fit(df_model, duration_col='period', event_col='outcome')\n",
    "    except Exception as e:\n",
    "        print(\"Outcome model fitting error:\", e)\n",
    "    \n",
    "    trial['outcome_model'] = cph\n",
    "    return trial\n",
    "\n",
    "trial_pp = set_outcome_model(trial_pp)\n",
    "trial_itt = set_outcome_model(trial_itt, adjustment_terms=\"~ x2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "expansion-options",
   "metadata": {},
   "source": [
    "## 7. Set Expansion Options & Expand Trials\n",
    "\n",
    "In this section, we define functions to set expansion options and then expand the trial data. This simulates restructuring or replicating the data. We also include a dummy function to load the expanded data with an added control flag."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-expansion-options",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 7. Set Expansion Options & Expand Trials\n",
    "# ---------------------------\n",
    "def set_expansion_options(trial, output, chunk_size):\n",
    "    trial['expansion_options'] = {\"output\": output, \"chunk_size\": chunk_size}\n",
    "    return trial\n",
    "\n",
    "def save_to_datatable():\n",
    "    # Dummy function: in practice, this might store data in a specific format\n",
    "    return lambda df: df  # identity function\n",
    "\n",
    "trial_pp = set_expansion_options(trial_pp, output=save_to_datatable(), chunk_size=500)\n",
    "trial_itt = set_expansion_options(trial_itt, output=save_to_datatable(), chunk_size=500)\n",
    "\n",
    "def expand_trials(trial):\n",
    "    # For this example, simply store the original data as \"expanded.\"\n",
    "    trial['expansion'] = trial['data']\n",
    "    return trial\n",
    "\n",
    "trial_pp = expand_trials(trial_pp)\n",
    "trial_itt = expand_trials(trial_itt)\n",
    "print(\"Expanded trial_pp data shape:\", trial_pp['expansion'].shape)\n",
    "\n",
    "def load_expanded_data(trial, seed=1234, p_control=0.5):\n",
    "    np.random.seed(seed)\n",
    "    df = trial['expansion'].copy()\n",
    "    df['control_flag'] = np.random.binomial(1, p_control, len(df))\n",
    "    trial['expansion'] = df\n",
    "    return trial\n",
    "\n",
    "trial_itt = load_expanded_data(trial_itt, seed=1234, p_control=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fit-msm",
   "metadata": {},
   "source": [
    "## 8. Fit Marginal Structural Model (MSM)\n",
    "\n",
    "We now fit a marginal structural model. In this example, we modify the weights (using a winsorization function) and then fit a weighted Cox model as a placeholder for MSM fitting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-fit-msm",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 8. Fit Marginal Structural Model (MSM)\n",
    "# ---------------------------\n",
    "def fit_msm(trial, weight_cols, modify_weights):\n",
    "    # Apply the provided modification to weights\n",
    "    weights = trial['data']['weight']\n",
    "    trial['data']['modified_weight'] = modify_weights(weights)\n",
    "    \n",
    "    # Fit a weighted Cox model as a placeholder for MSM fitting\n",
    "    cph = CoxPHFitter()\n",
    "    try:\n",
    "        cph.fit(trial['data'], duration_col='period', event_col='outcome', weights_col='modified_weight')\n",
    "    except Exception as e:\n",
    "        print(\"MSM fitting error:\", e)\n",
    "    trial['msm'] = cph\n",
    "    return trial\n",
    "\n",
    "def winsorize_weights(w):\n",
    "    q99 = np.quantile(w, 0.99)\n",
    "    return np.minimum(w, q99)\n",
    "\n",
    "trial_itt = fit_msm(trial_itt, weight_cols=[\"weight\", \"sample_weight\"], modify_weights=winsorize_weights)\n",
    "\n",
    "print(\"MSM outcome model summary for trial_itt:\")\n",
    "print(trial_itt['msm'].summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "predict-survival",
   "metadata": {},
   "source": [
    "## 9. Predict Survival & Plot Differences\n",
    "\n",
    "Finally, we define a function to predict survival functions using the fitted MSM. We then compute the difference between survival curves for treatment groups and plot these differences along with dummy confidence bounds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-predict-survival",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------\n",
    "# 9. Predict Survival & Plot Differences\n",
    "# ---------------------------\n",
    "def predict_survival(trial, newdata, predict_times):\n",
    "    # Use the fitted Cox model to predict survival functions\n",
    "    cph = trial['msm']\n",
    "    surv_funcs = cph.predict_survival_function(newdata)\n",
    "    \n",
    "    # Compare survival curves by treatment group\n",
    "    treatment_1_idx = newdata[newdata['treatment'] == 1].index\n",
    "    treatment_0_idx = newdata[newdata['treatment'] == 0].index\n",
    "    \n",
    "    surv_t1 = surv_funcs[treatment_1_idx].mean(axis=1)\n",
    "    surv_t0 = surv_funcs[treatment_0_idx].mean(axis=1)\n",
    "    \n",
    "    survival_diff = surv_t1 - surv_t0\n",
    "    # Dummy confidence intervals for illustration\n",
    "    lower_ci = survival_diff * 0.95\n",
    "    upper_ci = survival_diff * 1.05\n",
    "    \n",
    "    return pd.DataFrame({\n",
    "        \"followup_time\": surv_funcs.index,\n",
    "        \"survival_diff\": survival_diff,\n",
    "        \"lower\": lower_ci,\n",
    "        \"upper\": upper_ci\n",
    "    })\n",
    "\n",
    "# Select newdata from the expanded ITT data where period == 1\n",
    "newdata = trial_itt['expansion'][trial_itt['expansion']['period'] == 1]\n",
    "\n",
    "preds = predict_survival(trial_itt, newdata, predict_times=np.arange(0, 11))\n",
    "print('Predicted survival differences head:')\n",
    "print(preds.head())\n",
    "\n",
    "# Plot survival difference and confidence bounds\n",
    "plt.plot(preds[\"followup_time\"], preds[\"survival_diff\"], label=\"Survival Difference\")\n",
    "plt.plot(preds[\"followup_time\"], preds[\"lower\"], linestyle=\"--\", color=\"red\", label=\"2.5% CI\")\n",
    "plt.plot(preds[\"followup_time\"], preds[\"upper\"], linestyle=\"--\", color=\"red\", label=\"97.5% CI\")\n",
    "plt.xlabel(\"Follow up\")\n",
    "plt.ylabel(\"Survival difference\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.x"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
