In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from pydose_rt.data import Patient, MachineConfig, BeamSequence
from pydose_rt import DoseEngine
from IPython.display import clear_output
from pydose_rt.objectives.losses import dvh_percentile_objective
from pydose_rt.utils.plotting import overlay_mask_outline

In [None]:
# Load the labelled structure volume. The archive is opened in a `with` block so
# its file handle is closed as soon as the array has been read.
with np.load("../example_data/water_patient.npz") as archive:
    raw_data = archive["structures"]

# Voxel-label encoding implied by the comparisons below:
#   0 = unlabelled, 1 = External only, 2 = Rectum, 3 = Bladder,
#   4 = FemoralHead_R, 5 = FemoralHead_L, 6 = PTV (without CTV), 7 = CTV
# i.e. ["CTV", "PTV", "FemoralHead_L", "FemoralHead_R", "Bladder", "Rectum", "External"]
# listed from label 7 down to label 1.

# Every labelled voxel gets the value 0.0 and every unlabelled voxel -1000.0
# (presumably HU for water / air -- TODO confirm against Patient).
ct_image = torch.from_numpy(np.where(raw_data != 0.0, 0.0, -1000.0))

# Boolean masks, one per structure. External covers every labelled voxel.
external = torch.from_numpy(raw_data > 0.0)
rectum = torch.from_numpy(raw_data == 2.0)
bladder = torch.from_numpy(raw_data == 3.0)
femoralhead_r = torch.from_numpy(raw_data == 4.0)
femoralhead_l = torch.from_numpy(raw_data == 5.0)
ctv = torch.from_numpy(raw_data == 7.0)
# The PTV is the union of the voxels labelled 6 and the CTV.
ptv = torch.from_numpy(raw_data == 6.0) | ctv

patient = Patient(ct_tensor = ct_image, resolution=(2.0, 2.0, 2.0))
patient.add_mask("External", external)
patient.add_mask("Rectum", rectum)
patient.add_mask("Bladder", bladder)
patient.add_mask("FemoralHead_L", femoralhead_l)
patient.add_mask("FemoralHead_R", femoralhead_r)
patient.add_mask("PTV", ptv)
patient.add_mask("CTV", ctv)

# Outline colour used for each structure in every plot of this notebook.
colors = {"External": "orange",
          "Rectum": "purple",
          "Bladder": "yellow",
          "FemoralHead_L": "green",
          "FemoralHead_R": "blue",
          "PTV": "black",
          "CTV": "red"}

In [None]:
# Preview one slice of the density volume with every structure outlined on top.
slice_idx = 63  # index along the first axis of the volume

density_slice = patient.density_image.cpu().detach().numpy()[slice_idx, :, :]
plt.imshow(density_slice, cmap='gray')

for struct_name in patient.structures.keys():
    mask_slice = patient.structures[struct_name].cpu().detach().numpy()[slice_idx, :, :]
    overlay_mask_outline(mask_slice, color=colors[struct_name])

In [None]:

# Compute on the GPU when one is available, otherwise on the CPU; single precision.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# --- Treatment machine ------------------------------------------------------
number_of_leaf_pairs = 60
machine_config = MachineConfig(
    tpr_20_10=0.739,
    # NOTE(review): very small for a quantity named in MeV -- confirm the
    # unit/scale MachineConfig expects.
    mean_photon_energy_MeV=0.00057099,
    number_of_leaf_pairs=number_of_leaf_pairs,
)

# --- Beam arrangement -------------------------------------------------------
number_of_beams = 30
# Evenly spaced gantry angles starting at -170; the end value 170 is excluded.
gantry_angles = torch.from_numpy(np.linspace(-170, 170, number_of_beams, endpoint=False))
# One collimator angle per beam, all zero (float64, like gantry_angles).
collimator_angles = torch.from_numpy(np.zeros(number_of_beams))
field_size = (400, 400)
iso_center = (100.0, 200.0, 200.0)
sid = 1000.0
open_field_size = 0.0
kernel_size = 5

# Arguments are passed positionally, in the order BeamSequence.create expects.
beam_sequence = BeamSequence.create(
    gantry_angles,
    number_of_leaf_pairs,
    field_size,
    iso_center,
    collimator_angles,
    sid,
    open_field_size,
    device,
    dtype,
    True,
)

# Move the patient data to the compute device and working precision.
patient = patient.to(device).to(dtype)

# --- Dose engine ------------------------------------------------------------
# The dose grid takes its spacing and shape from the patient volume.
engine = DoseEngine(
    machine_config=machine_config,
    beam_template=beam_sequence,
    dose_grid_spacing=patient.resolution,
    dose_grid_shape=patient.density_image.shape,
    kernel_size=kernel_size,
    adjust_values=False,
    device=device,
    dtype=dtype,
)
# Left disabled:
# engine.fluence_map_layer.training_sharpness = 1.0
engine.train()

# Density volume with an extra leading dimension of size 1.
ct_volume = patient.density_image.unsqueeze(0)


In [None]:
slice_idx = 63  # index along the first volume axis of the slice that is displayed
dose_max = 47  # upper limit of the dose colour scale and of the DVH dose axis


def plot_progress(epoch, raw_losses, dose_pred):
    """Show the current dose on one slice next to a cumulative DVH per structure.

    Reads the notebook globals ``patient``, ``colors``, ``slice_idx`` and
    ``dose_max``.

    Args:
        epoch: Value shown in the figure title.
        raw_losses: Iterable of scalar tensors; each is rounded to three
            decimals and listed in the figure title.
        dose_pred: Dose tensor indexed like ``patient.density_image`` (it is
            sliced with ``[slice_idx, :, :]`` and masked with each structure).
    """
    fig = plt.figure(figsize=(20, 6))  # Make figure span screen
    plt.suptitle(f'Epoch {epoch}: {[str(np.round(loss.item(), 3)) for loss in raw_losses]}', fontsize=16)

    # Left panel: density slice, semi-transparent dose wash, structure outlines.
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(patient.density_image.cpu().detach().numpy()[slice_idx, :, :], cmap='gray')
    plt.imshow(dose_pred.cpu().detach().numpy()[slice_idx, :, :], vmin=0.0, vmax=dose_max, cmap='jet', alpha=0.8)
    for struct_name in patient.structures.keys():
        overlay_mask_outline(patient.structures[struct_name].cpu().detach().numpy()[slice_idx, :, :], color=colors[struct_name])

    # Right panel: cumulative dose-volume histogram of every non-empty structure.
    dvh_ax = plt.subplot(1, 2, 2)
    bins = np.linspace(0, dose_max, 1000)  # identical for all structures
    for struct_name in patient.structures.keys():
        roi = patient.structures[struct_name]
        dose_values = dose_pred[roi > 0.0].cpu().detach().numpy()
        if dose_values.size == 0:
            continue
        # np.histogram ignores samples outside the bin range. Clip first so that
        # voxels hotter than dose_max still count towards every dose level
        # instead of vanishing from the curve.
        in_range = np.clip(dose_values, 0, dose_max)
        hist, bin_edges = np.histogram(in_range, bins=bins, density=False)
        # Reverse cumulative sum: number of voxels receiving at least each dose.
        cumulative_hist = np.cumsum(hist[::-1])[::-1]
        # Normalise by the voxel count of the structure (never zero here).
        volume_fraction = cumulative_hist / dose_values.size
        dvh_ax.plot(bin_edges[:-1], volume_fraction, linestyle="solid", label=struct_name, color=colors[struct_name])

    dvh_ax.set_xlabel("Dose")
    dvh_ax.set_ylabel("Volume fraction")
    if dvh_ax.lines:  # avoid the "no labelled artists" warning on an empty plot
        dvh_ax.legend(loc="upper right")

    plt.show()
    # This function runs once per epoch; release the figure explicitly so
    # figures cannot pile up on backends that keep them open after show().
    plt.close(fig)

In [None]:
# Optimisation hyper-parameters.
lr = 5.0
num_epochs = 2000

# Defined here but not referenced by the cells visible in this notebook.
lr_decay = 1e-4
patience = 0
epoch = 0

# AdamW updates the parameters exposed by the beam sequence.
optimizer = torch.optim.AdamW(beam_sequence.parameters(), lr=lr, weight_decay=1e-4)

# Written by closure() on every evaluation so the latest loss terms and dose
# can be inspected and plotted outside of it.
last_raw_losses = None
last_dose_pred = None

def closure():
    """Evaluate the plan objective once and back-propagate it.

    Meant to be passed to ``optimizer.step``. Clears the gradients, computes the
    dose for the current ``beam_sequence``, sums the loss terms, calls
    ``backward()`` and stores detached copies of the individual terms and of the
    dose in the globals ``last_raw_losses`` and ``last_dose_pred``.

    Returns:
        The summed loss tensor.
    """
    global last_raw_losses, last_dose_pred

    optimizer.zero_grad(set_to_none=True)

    # Forward pass; the density volume is given a leading dimension of size 1.
    batched_dose = engine.compute_dose(
        beam_sequence,
        density_image=patient.density_image.unsqueeze(0),
    )

    structures = patient.structures
    body = structures["External"]
    # Zero the dose outside the External mask, then scale by 7
    # (presumably the number of fractions -- TODO confirm).
    dose = torch.where(body, batched_dose[0], 0.0) * 7

    terms = [
        # Weighted mean squared deviation of the PTV dose from 42.7.
        100.0 * torch.mean(torch.abs(dose[structures["PTV"]] - 42.7) ** 2),
        # DVH objectives for both femoral heads (meaning of the argument 20 is
        # defined by dvh_percentile_objective).
        dvh_percentile_objective(dose, structures["FemoralHead_L"], 20),
        dvh_percentile_objective(dose, structures["FemoralHead_R"], 20),
        # Weighted mean squared dose over the whole External mask.
        10.0 * torch.mean(torch.abs(dose[body]) ** 2),
    ]

    total = torch.stack(terms).sum()
    total.backward()

    # Keep detached copies for inspection outside the closure.
    last_raw_losses = [term.detach().clone() for term in terms]
    last_dose_pred = dose.detach().clone()

    return total



In [None]:

# One optimiser step per epoch. closure() refreshes last_raw_losses and
# last_dose_pred, which are then plotted.
for epoch_idx in range(num_epochs):
    step_loss = optimizer.step(closure)  # value handed back by optimizer.step

    plot_progress(epoch_idx, last_raw_losses, last_dose_pred)
    # wait=True postpones clearing until new output is available.
    clear_output(wait=True)