<img src="../assets/tumor_twin.png" alt="Tumor Twin" width="500"/>

# Triple negative breast cancer (TNBC) Demo
This demo showcases an end-to-end image-guided digital twin workflow with `TumorTwin`.

---
## 📚 Table of Contents
- [Step 1: Data Loading](#step-1-data-loading)
- [Step 2: Create Tumor Growth Model](#step-2-create-tumor-growth-model)
- [Step 3: Create a Solver object](#step-3-create-a-solver-object)
- [Step 4: Make a prediction](#step-4-make-a-prediction)
- [Step 5 (Optional): Compute a quantity of interest and its gradient](#step-5-optional-compute-a-quantity-of-interest-and-its-gradient)
- [Step 6: Compare the model prediction to patient data](#step-6-compare-the-model-prediction-to-patient-data)
- [Step 7: Calibrate the model to patient data via numerical optimization](#step-7-calibrate-the-model-to-patient-data-via-numerical-optimization)
- [Step 8: Predict patient response under alternative treatment plan](#step-8-predict-patient-response-under-alternative-treatment-plan)
- [Conclusion](#conclusion)
- [Discussion Questions](#discussion-questions)


---

In [1]:
%load_ext autoreload
%autoreload 2
## Imports...
import os
from datetime import timedelta
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from dotenv import load_dotenv
from pydantic import FilePath
from rich import print

from tumortwin.models import ReactionDiffusion3D
from tumortwin.optimizers import LMoptimizer, LMoptions
from tumortwin.postprocessing import (
    compute_total_cell_count,
    plot_calibration,
    plot_calibration_iter,
    plot_cellularity_map,
    plot_imaging_summary,
    plot_loss,
    plot_maps_final,
    plot_measured_TCC,
    plot_patient_timeline,
    plot_predicted_TCC,
)
from tumortwin.preprocessing import ADC_to_cellularity, compute_carrying_capacity
from tumortwin.solvers import TorchDiffEqSolver, TorchDiffEqSolverOptions
from tumortwin.types import (
    ChemotherapySpecification,
    CropSettings,
    CropTarget,
    TNBCPatientData,
)
from tumortwin.utils import daterange, days_since_first

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%matplotlib inline
font = {
    "weight": "normal",
    "size": 10,
}
matplotlib.rc("font", **font)
matplotlib.rc("figure", dpi=300)
matplotlib.rc("savefig", dpi=300)

---
## Step 1: Data Loading
### Create PatientData object

Here we load in the input dataset. You will need to ensure that the relevant input filepaths are set correctly.

In [2]:
PATIENT_INFO_PATH = FilePath("../input_files/TNBC_demo_001/TNBC_demo_001.json")
IMAGE_PATH = FilePath("../input_files/TNBC_demo_001")
crop_settings = CropSettings(crop_to=CropTarget.ROI_ENHANCE, padding=10, visit_index=-1)
patient_data = TNBCPatientData.from_file(
    PATIENT_INFO_PATH, image_dir=IMAGE_PATH, crop_settings=crop_settings
)
measured_cellularity_maps = [
    ADC_to_cellularity(
        visit.adc_image, visit.roi_enhance_image
    )
    for visit in patient_data.visits
]
k = torch.tensor(0.025, requires_grad=True, device=device)
d = torch.tensor(0.05, requires_grad=True, device=device)
theta = torch.tensor(1.0, requires_grad=False, device=device)
ct = ChemotherapySpecification(
    sensitivity=0.2,
    decay_rate=0.7,
    times=[c.time for c in patient_data.chemotherapy],
    doses=[c.dose for c in patient_data.chemotherapy],
)
# print(ct)
model = ReactionDiffusion3D(
    k=k,
    d=d,
    theta=theta,
    patient_data=patient_data,
    initial_time=patient_data.visits[0].time,
    chemotherapy_specifications=[ct],
    radiotherapy_specification=None,
)

---
## Step 3: Create a Solver object
Under the hood, our `ReactionDiffusion3D` model is actually a spatially discretized version of the reaction-diffusion PDE. Solving this semi-discrete model requires solving a large system of ordinary differential equations (ODEs). We do this via the `torchdiffeq` library, using the `TorchDiffEqSolver` object.

The `TorchDiffEqSolverOptions` object contains standard solver options: the integration method (e.g. fourth-order Runge-Kutta, `"rk4"`), the timestep used by the solver, and whether to use the adjoint method for gradient computations. Setting `use_adjoint=False` instead computes gradients by automatic differentiation directly through the solver steps.
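
To make the idea of a semi-discrete model concrete, here is a minimal sketch of the approach; it is not TumorTwin's implementation. It assumes a 1D domain, zero-flux boundaries, a logistic growth term, and no treatment. The names `rhs`, `rk4_step`, and `simulate` are hypothetical, and the parameter values are only illustrative.

```python
import math


def rhs(u, d, k, theta, dx):
    """Semi-discrete right-hand side: diffusion plus logistic growth on a 1D grid."""
    n = len(u)
    out = []
    for i in range(n):
        # Zero-flux boundaries: mirror the edge value outside the domain.
        left = u[i - 1] if i > 0 else u[i]
        right = u[i + 1] if i < n - 1 else u[i]
        laplacian = (left - 2.0 * u[i] + right) / dx**2
        out.append(d * laplacian + k * u[i] * (1.0 - u[i] / theta))
    return out


def shifted(u, scale, direction):
    """Return u + scale * direction, element by element."""
    return [ui + scale * di for ui, di in zip(u, direction)]


def rk4_step(u, dt, d, k, theta, dx):
    """Advance the ODE system by one classical fourth-order Runge-Kutta step."""
    k1 = rhs(u, d, k, theta, dx)
    k2 = rhs(shifted(u, 0.5 * dt, k1), d, k, theta, dx)
    k3 = rhs(shifted(u, 0.5 * dt, k2), d, k, theta, dx)
    k4 = rhs(shifted(u, dt, k3), d, k, theta, dx)
    return [
        ui + dt / 6.0 * (a + 2.0 * b + 2.0 * c + e)
        for ui, a, b, c, e in zip(u, k1, k2, k3, k4)
    ]


def simulate(u_initial, days, dt, d, k, theta, dx):
    """Integrate from the initial state for `days` using fixed steps of size `dt`."""
    u = list(u_initial)
    for _ in range(round(days / dt)):
        u = rk4_step(u, dt, d, k, theta, dx)
    return u


dx = 0.25
xs = [i * dx for i in range(41)]
u_start = [0.5 * math.exp(-((x - 5.0) ** 2)) for x in xs]  # Gaussian "tumor"
u_final = simulate(u_start, days=20.0, dt=0.5, d=0.05, k=0.025, theta=1.0, dx=dx)
print(f"total before: {sum(u_start) * dx:.3f}, total after: {sum(u_final) * dx:.3f}")
```

In this toy setting the total cellularity grows while the peak flattens out. Because RK4 is an explicit method, the step size has to stay small relative to `dx**2 / d`; a much finer grid would need a smaller step to remain stable.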

In [3]:
solver_options = TorchDiffEqSolverOptions(
    step_size=timedelta(days=0.5),
    use_adjoint=True,
    device=device,
    method="rk4",
)

solver = TorchDiffEqSolver(model, solver_options)

---
## Step 4: Make a prediction
We are now ready to leverage our model and solver to make a prediction of tumor growth and response to treatment.

We will use the measured cellularity map from the first patient visit (`measured_cellularity_maps[0]`) as an initial condition for the model, and make a prediction from the first visit date until the final visit date: a total of `119 days`. We will output the solution every `0.5 days`.
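
As a rough illustration of what the output grid looks like, the sketch below builds half-day timepoints with the standard library only. The helper `make_timepoints` is hypothetical and stands in for `daterange`, whose exact behavior (for example, whether the end date is included) is an assumption here; the start date is illustrative.

```python
from datetime import datetime, timedelta


def make_timepoints(start, end, step):
    """Evenly spaced datetimes from start to end (end included when it lies on the grid)."""
    n_steps = int((end - start) / step)  # timedelta / timedelta gives a float
    return [start + i * step for i in range(n_steps + 1)]


start = datetime(2001, 9, 14)  # illustrative first-visit date
end = start + timedelta(days=119)
timepoints_demo = make_timepoints(start, end, timedelta(days=0.5))

# Elapsed time in days since the first timepoint.
days_elapsed = [(t - start) / timedelta(days=1) for t in timepoints_demo]
print(len(timepoints_demo))  # 239 timepoints: 238 half-day steps plus the start
```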

In [4]:
timepoints = daterange(
    patient_data.visits[0].time, patient_data.visits[-1].time, timedelta(days=0.5)
)
u0 = torch.from_numpy(measured_cellularity_maps[0].array)

times, predicted_cellularity_maps = solver.solve(timepoints=timepoints, u_initial=u0)

---
## Step 7: Calibrate the model to patient data via numerical optimization
Rather than a trial-and-error approach, we can instead leverage numerical optimization to calibrate our model parameters to better match the observed data.

We will use the Levenberg-Marquardt (LM) algorithm ([Wikipedia Link](https://en.wikipedia.org/wiki/Levenberg–Marquardt_algorithm)). This algorithm seeks to minimize the sum-of-squares difference between our predicted and measured cellularity fields, i.e.,

$$
\sum_{v=0}^{\texttt{N\_visits}} \sum_{j=0}^{\texttt{N\_voxels}} (N_j(t_v) - \hat{N}_j(t_v))^2
$$
where $t_v$ is the time of visit $v$, and $N_j$ and $\hat{N}_j$ are the predicted and measured cellularity values, respectively, at voxel $j$.
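
To show how a Levenberg-Marquardt iteration drives down a sum-of-squares objective of this kind, here is a minimal textbook-style sketch on a toy two-parameter problem. It is not the `LMoptimizer` implementation: the forward model (an exponential decay), the helper names, the bound handling by simple clipping, and the factor-of-ten damping schedule are all assumptions made for illustration.

```python
import math


def predict(params, times):
    """Toy forward model: exponential decay a * exp(-b * t)."""
    a, b = params
    return [a * math.exp(-b * t) for t in times]


def sum_of_squares(params, times, data):
    """Sum-of-squares misfit between the toy prediction and the data."""
    return sum((p - y) ** 2 for p, y in zip(predict(params, times), data))


def clip(value, bounds):
    low, high = bounds
    return min(max(value, low), high)


def lm_step(params, times, data, damping, bounds):
    """One damped Gauss-Newton (Levenberg-Marquardt) update for two parameters."""
    a, b = params
    residuals = [p - y for p, y in zip(predict(params, times), data)]
    # Analytic Jacobian rows: derivatives of the prediction w.r.t. (a, b).
    jac = [(math.exp(-b * t), -a * t * math.exp(-b * t)) for t in times]
    # Normal equations: (J^T J + damping * diag(J^T J)) delta = -J^T r
    jtj_aa = sum(ja * ja for ja, _ in jac)
    jtj_ab = sum(ja * jb for ja, jb in jac)
    jtj_bb = sum(jb * jb for _, jb in jac)
    grad_a = sum(ja * r for (ja, _), r in zip(jac, residuals))
    grad_b = sum(jb * r for (_, jb), r in zip(jac, residuals))
    m_aa = jtj_aa * (1.0 + damping)
    m_bb = jtj_bb * (1.0 + damping)
    det = m_aa * m_bb - jtj_ab * jtj_ab
    delta_a = (-grad_a * m_bb + jtj_ab * grad_b) / det
    delta_b = (-grad_b * m_aa + jtj_ab * grad_a) / det
    # Keep the trial point inside the parameter bounds by clipping.
    trial = (clip(a + delta_a, bounds[0]), clip(b + delta_b, bounds[1]))
    if sum_of_squares(trial, times, data) < sum_of_squares(params, times, data):
        return trial, damping / 10.0  # accept the step and reduce the damping
    return params, damping * 10.0  # reject the step and increase the damping


times = [float(t) for t in range(10)]
data = predict((2.0, 0.3), times)  # noise-free synthetic "measurements"
params, damping = (1.0, 1.0), 1e-2
bounds = ((0.1, 5.0), (0.0, 5.0))
for _ in range(100):
    params, damping = lm_step(params, times, data, damping, bounds)
print("Estimated parameters:", params)
```

Large damping makes the update behave like a short gradient-descent step, while small damping approaches a Gauss-Newton step; the accept/reject rule switches between the two regimes automatically.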

First we choose how many imaging visits to calibrate to by setting `n_visits_calibration`, and pick out the corresponding measured cellularity maps and timepoints:

In [5]:
# How many imaging dates do we want to try and match
n_visits_calibration = 2 # *Including* the initial visit

target_timepoints = [visit.time for visit in patient_data.visits[:n_visits_calibration]]
target_solution = torch.stack(
    [
        torch.from_numpy(m.array)
        for m in measured_cellularity_maps[:n_visits_calibration]
    ]
)

Next we create a helper function for the optimizer. This function simply takes a set of parameter values, updates the model with these parameter values, runs a forward solve, and outputs the solution at the target timepoints.

In [14]:
def update_model_and_predict(model_parameters, timepoints):
    """Update the model parameters, run a forward solve, and return the predicted maps."""
    # Parameter ordering: (d, k, chemotherapy sensitivity)
    d, k, ct_sens = model_parameters
    solver.model.d = torch.nn.Parameter(d)
    solver.model.k = torch.nn.Parameter(k)
    solver.model.chemotherapy_specifications[0].sensitivity = torch.nn.Parameter(ct_sens)
    _, predicted_cellularity_maps = solver.solve(timepoints=timepoints, u_initial=u0)
    return predicted_cellularity_maps

# Initial guess for the optimizer, ordered as (d, k, chemotherapy sensitivity)
initial_parameters = torch.tensor((0.025, 0.05, 0.5))
options = LMoptions()
optim = LMoptimizer(
    solver=solver,
    timepoints=target_timepoints,
    u0=u0,
    initial_guess=initial_parameters,
    bounds=torch.tensor(((0.0, 2.0), (0.0, 0.5), (0.0, 1.0))),
    y_data=target_solution,
    options=options,
)

# Run the optimizer for n_iter steps
n_iter = 1

for i in range(n_iter):
    print(f"Optimization Step: {i + 1}/{n_iter}")
    optim.step()

best_parameters = optim.parameters[-1]
print("Best parameters: ", best_parameters)


AttributeError: 'TorchDiffEqSolver' object has no attribute 'k'

### Visualize the optimization iterations
Let's track the progress of the optimizer by computing the predicted solution at every `n_opt_viz`-th optimization step.

In [None]:
n_opt_viz = 2
sols = []
for params in optim.parameters[::n_opt_viz]:
    predicted_cellularity_maps = update_model_and_predict(params, timepoints)
    sols.append(predicted_cellularity_maps)

Now we plot the solutions visited by the optimizer, along with the corresponding loss (sum-of-squares error) values:

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 2))
plot_calibration_iter(
    sols,
    carrying_capacity,
    timepoints,
    measured_cellularity_maps,
    patient_data,
    t_calibration_end=target_timepoints[-1],
    ax=ax,
)

fig, ax = plt.subplots(1, 1, figsize=(5, 2))
plot_loss(torch.tensor(optim.error), ax=ax)

For this particular dataset, we know what the ground-truth parameter values are since we used them to generate the data! Of course, in a real scenario you wouldn't know what the "true" parameters are. In fact, the data might not exactly match the model for _any_ value of the parameters.

Ideally, our calibrated parameter values will be close to these, so let's compute the relative error in each parameter:

In [None]:
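# Ground-truth values used to generate the data, ordered as (d, k, ct_sens)
true_parameters = torch.tensor((0.1, 0.05, 0.2))
final_parameters = optim.parameters[-1]


def relative_error_2dp(estimate, truth):
    """Relative error as a percentage, rounded to two decimal places."""
    return round(float(100 * abs(truth - estimate) / truth), 2)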


print(f"Error in d: {relative_error_2dp(final_parameters[0], true_parameters[0])}%")
print(f"Error in k: {relative_error_2dp(final_parameters[1], true_parameters[1])}%")
print(f"Error in ct_sens: {relative_error_2dp(final_parameters[2], true_parameters[2])}%")

## Step 8: Predict patient response under alternative treatment plan
Now that we have a calibrated digital twin model, we can use it to predict how this particular patient might respond to different treatment plans.

Recall that we calibrated the digital twin model to imaging visits acquired after _four weeks of neoadjuvant chemotherapy_. A remaining treatment decision might be the neoadjuvant chemotherapy dosages and schedule. We would expect that increasing the dosage will lead to greater tumor control, but note that higher dosages are also likely to lead to greater toxicity. Let's explore the tradeoff using our calibrated digital twin model!

First, we'll define a function that updates the remaining chemotherapy doses based on a given *total* chemotherapy dosage.

In [None]:
def update_ct_total_dose(ct: ChemotherapySpecification, total_dose: float):
    """Rescale the doses after the first four, in place, so that all doses sum to `total_dose`."""
    delivered_dose = np.sum(ct.doses[:4])  # the first four doses are left unchanged
    remaining_dose = np.sum(ct.doses[4:])  # doses that can still be adjusted

    # Scale the remaining doses so that delivered + remaining equals total_dose
    dose_multiplier = (total_dose - delivered_dose) / remaining_dose
    ct.doses[4:] = [dose * dose_multiplier for dose in ct.doses[4:]]
    return ct
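
The rescaling arithmetic can be checked on a plain list of doses. The sketch below is a standalone variant that returns a new list instead of modifying its input; the function name `rescale_remaining_doses`, the argument `n_delivered`, and the eight-dose schedule are hypothetical and chosen only for illustration.

```python
def rescale_remaining_doses(doses, total_dose, n_delivered):
    """Return a new schedule whose doses sum to total_dose, keeping the first n_delivered fixed."""
    delivered = sum(doses[:n_delivered])
    remaining = sum(doses[n_delivered:])
    if remaining <= 0:
        raise ValueError("There are no remaining doses to rescale.")
    if total_dose < delivered:
        raise ValueError("The target total is below the dose already delivered.")
    factor = (total_dose - delivered) / remaining
    return doses[:n_delivered] + [dose * factor for dose in doses[n_delivered:]]


schedule = [5.0] * 8  # hypothetical schedule: eight equal doses, total 40
new_schedule = rescale_remaining_doses(schedule, total_dose=60.0, n_delivered=4)
print(new_schedule)  # [5.0, 5.0, 5.0, 5.0, 10.0, 10.0, 10.0, 10.0]
```

Returning a fresh list keeps the original schedule available, so several candidate totals can be compared against the same baseline without one rescaling feeding into the next.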

Now let's predict the tumor response for a range of total dosages, and plot the results:

In [None]:
sols = []
candidate_doses = [20, 30, 40, 50, 60]
dose_min, dose_max = min(candidate_doses), max(candidate_doses)

fig, ax = plt.subplots(1, 1, figsize=(5, 2))
for ct_total_dose in candidate_doses:
    update_ct_total_dose(ct, ct_total_dose)
    solver.model.chemotherapy_specifications = [ct]
    print(f"Running forward solve with total dose = {ct_total_dose}")
    times, predicted_cellularity_maps = solver.solve(timepoints=timepoints, u_initial=u0)
    # Higher total doses are drawn more opaque (alpha ranges from 0.25 to 1.0)
    alpha = 0.25 + 0.75 * (ct_total_dose - dose_min) / (dose_max - dose_min)
    plot_predicted_TCC(predicted_cellularity_maps, timepoints, ax=ax, alpha=alpha)

ax.legend([f"Total dose: {dose}" for dose in candidate_doses])
plt.show()

## Conclusion 
Here we have demonstrated the core workflows of `TumorTwin`. We have shown how to load in a patient dataset, create a tumor growth model, create a solver for the model, make predictions with the model under various parameters and treatments, and calibrate the model to patient data.

## Discussion Questions
__Modeling__
- What effects could we add to the reaction-diffusion model?

__Calibration__
- How much data is needed for calibration?
- How does the timing of the imaging visits influence the calibration performance?
- Under what conditions might the calibration be unable to uniquely identify all the parameters?