<center>    
PCN520 Research Project
<hr>
<h1> Utilising Deep Learning for Individualised Quality Assurance in Radiotherapy: DOSE PREDICTION AND OPTIMISATION </h1>
<hr>
<h3> 3D U-Net Voxel-Wise Dose Prediction Model </h3>
</center>

# OpenKBP Database Download

Before running this notebook, we need to get the repo which contains the data. The download should be quick as it's a sercer-to-server process.

In [2]:
# Get the repo
repo_dir = 'open-kbp'
!git clone https://github.com/ababier/open-kbp.git {repo_dir}

fatal: destination path 'open-kbp' already exists and is not an empty directory.


In [None]:
# Add repo to path
import sys
sys.path.append(repo_dir)

!pip install -r /content/open-kbp/requirements.txt

# Step 1: **Loading the Data**

In [4]:
import pandas as pd
import numpy as np
import os
from pathlib import Path
from provided_code.batch import DataBatch
from provided_code.data_shapes import DataShapes
from provided_code.data_loader import DataLoader

In [66]:
# Load the testing dataset
testdata_dir = '/content/open-kbp/provided-data/test-pats/'
patient_paths = [Path(os.path.join(testdata_dir, patient)) for patient in os.listdir(testdata_dir) if os.path.isdir(os.path.join(testdata_dir, patient))]

# Initialize the DataLoader
data_loader = DataLoader(patient_paths=patient_paths, batch_size=1)

# Set the mode for DataLoader to load training data
data_loader.set_mode("training_model")

# Get a batch of data
batch_data = next(data_loader.get_batches())

# Print all attributes of batch_data
print("Dose shape:", batch_data.dose.shape if batch_data.dose is not None else "Dose data is None")
print("CT shape:", batch_data.ct.shape if batch_data.ct is not None else "CT data is None")
print("Structure masks shape:", batch_data.structure_masks.shape if batch_data.structure_masks is not None else "Structure masks data is None")
print("Possible dose mask shape:", batch_data.possible_dose_mask.shape if batch_data.possible_dose_mask is not None else "Possible dose mask data is None")
print("Voxel dimensions shape:", batch_data.voxel_dimensions.shape if batch_data.voxel_dimensions is not None else "Voxel dimensions data is None")
print("Patient list:", batch_data.patient_list)
print("Patient path list:", batch_data.patient_path)


0it [00:00, ?it/s]

Dose shape: (1, 128, 128, 128, 1)
CT shape: (1, 128, 128, 128, 1)
Structure masks shape: (1, 128, 128, 128, 10)
Possible dose mask shape: (1, 128, 128, 128, 1)
Voxel dimensions shape: (1, 3)
Patient list: ['pt_293']
Patient path list: None





# Step 2: **Creating the Autoencoder**

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),  # Adjust input channels as needed
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.bottleneck = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Assuming output need to be in a normalized range [0, 1]
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x)
        return x

# Instantiate the model
autoencoder = Autoencoder()

# Step 3: **Creating the 3D U-Net Model**

In [57]:
class UNet3D(nn.Module):
    def __init__(self):
        super(UNet3D, self).__init__()
        self.encoder1 = nn.Conv3d(1, 32, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool3d(2)

        self.bottleneck = nn.Conv3d(64, 128, kernel_size=3, padding=1)

        self.decoder1 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)

        self.final = nn.Conv3d(32, 1, kernel_size=1)

    def forward(self, x):
        enc1 = F.relu(self.encoder1(x))
        enc2 = F.relu(self.encoder2(self.pool(enc1)))

        bott = F.relu(self.bottleneck(self.pool(enc2)))

        dec1 = F.relu(self.decoder1(bott))
        dec1 = dec1 + enc2  # Skip connection
        dec2 = F.relu(self.decoder2(dec1))
        dec2 = dec2 + enc1  # Skip connection

        return self.final(dec2)

# Instantiate the U-Net model
unet_model = UNet3D()

In [49]:
class DoseLoss(nn.Module):
    def __init__(self, max_dose):
        super(DoseLoss, self).__init__()
        self.max_dose = max_dose

    def forward(self, predicted_dose, true_dose, possible_dose_mask):
        # Apply possible dose mask
        predicted_dose = predicted_dose * possible_dose_mask
        true_dose = true_dose * possible_dose_mask

        # Compute mean squared error loss
        mse_loss = F.mse_loss(predicted_dose, true_dose)

        # Normalize by max dose to ensure that the dose values are in the same scale
        normalized_loss = mse_loss / self.max_dose

        return normalized_loss

# Step 4: **Training the Models**

In [50]:
# Define the loss function and optimizer for the Autoencoder
criterion_ae = DoseLoss(100.0)  # max_dose = 100.0; replace with the appropriate max dose
optimizer_ae = torch.optim.Adam(autoencoder.parameters(), lr=0.001)

# Define training function for Autoencoder
def train_autoencoder(data_loader, model, criterion, optimizer):
    model.train()
#    for batch_data in data_loader.get_batches():
#        # Let the Autoencoder uses Dose data as input
#        data_tensor = torch.tensor(batch_data.ct, dtype=torch.float32).squeeze(-1)  # Convert CT data to tensor
#        optimizer.zero_grad()
#        output = model(data_tensor)
#        loss = criterion(output, data_tensor)
#        loss.backward()
#        optimizer.step()
#        print(f"Training Loss: {loss.item()}")

    for epoch in range(20):
        running_loss = 0.0
        for i, batch_data in enumerate(data_loader.get_batches()):
            ct_scans = torch.tensor(batch_data.ct, dtype=torch.float32).squeeze(-1)
            structure_masks = torch.tensor(batch_data.structure_masks, dtype=torch.float32).squeeze(-1)
            possible_dose_masks = torch.tensor(batch_data.possible_dose_mask, dtype=torch.float32).squeeze(-1)
            true_doses = torch.tensor(batch_data.dose, dtype=torch.float32).squeeze(-1)

            optimizer.zero_grad()

            predicted_doses = model(ct_scans)

            loss = criterion(predicted_doses, true_doses, possible_dose_masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 10 == 9:
                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 10:.3f}")
                running_loss = 0.0
    print("Finished Training")

# Example usage
train_autoencoder(data_loader, autoencoder, criterion_ae, optimizer_ae)

10it [00:13,  1.33s/it]

[1, 10] loss: 0.786


20it [00:26,  1.28s/it]

[1, 20] loss: 0.640


30it [00:39,  1.36s/it]

[1, 30] loss: 0.665


40it [00:51,  1.34s/it]

[1, 40] loss: 0.688


50it [01:02,  1.13s/it]

[1, 50] loss: 0.590


60it [01:14,  1.16s/it]

[1, 60] loss: 0.651


70it [01:27,  1.16s/it]

[1, 70] loss: 0.657


80it [01:40,  1.18s/it]

[1, 80] loss: 0.698


90it [01:52,  1.20s/it]

[1, 90] loss: 0.928


100it [02:05,  1.25s/it]


[1, 100] loss: 0.656


10it [00:13,  1.27s/it]

[2, 10] loss: 0.786


20it [00:26,  1.20s/it]

[2, 20] loss: 0.640


30it [00:38,  1.24s/it]

[2, 30] loss: 0.665


40it [00:50,  1.34s/it]

[2, 40] loss: 0.688


50it [01:02,  1.35s/it]

[2, 50] loss: 0.590


60it [01:14,  1.29s/it]

[2, 60] loss: 0.651


70it [01:26,  1.15s/it]

[2, 70] loss: 0.657


80it [01:39,  1.16s/it]

[2, 80] loss: 0.698


90it [01:52,  1.18s/it]

[2, 90] loss: 0.928


100it [02:04,  1.25s/it]


[2, 100] loss: 0.656


10it [00:12,  1.19s/it]

[3, 10] loss: 0.786


20it [00:25,  1.16s/it]

[3, 20] loss: 0.640


30it [00:37,  1.19s/it]

[3, 30] loss: 0.665


40it [00:49,  1.28s/it]

[3, 40] loss: 0.688


50it [01:01,  1.27s/it]

[3, 50] loss: 0.590


60it [01:15,  1.69s/it]

[3, 60] loss: 0.651


70it [01:27,  1.32s/it]

[3, 70] loss: 0.657


80it [01:40,  1.25s/it]

[3, 80] loss: 0.698


90it [01:53,  1.18s/it]

[3, 90] loss: 0.928


100it [02:06,  1.26s/it]


[3, 100] loss: 0.656


10it [00:12,  1.19s/it]

[4, 10] loss: 0.786


20it [00:25,  1.13s/it]

[4, 20] loss: 0.640


30it [00:37,  1.14s/it]

[4, 30] loss: 0.665


40it [00:50,  1.19s/it]

[4, 40] loss: 0.688


50it [01:02,  1.22s/it]

[4, 50] loss: 0.590


60it [01:15,  1.31s/it]

[4, 60] loss: 0.651


70it [01:27,  1.33s/it]

[4, 70] loss: 0.657


80it [01:40,  1.38s/it]

[4, 80] loss: 0.698


90it [01:53,  1.45s/it]

[4, 90] loss: 0.928


100it [02:05,  1.25s/it]


[4, 100] loss: 0.656


10it [00:13,  1.46s/it]

[5, 10] loss: 0.786


20it [00:25,  1.12s/it]

[5, 20] loss: 0.640


30it [00:38,  1.11s/it]

[5, 30] loss: 0.665


40it [00:50,  1.16s/it]

[5, 40] loss: 0.688


50it [01:02,  1.16s/it]

[5, 50] loss: 0.590


60it [01:15,  1.24s/it]

[5, 60] loss: 0.651


70it [01:27,  1.29s/it]

[5, 70] loss: 0.657


80it [01:40,  1.29s/it]

[5, 80] loss: 0.698


90it [01:52,  1.37s/it]

[5, 90] loss: 0.928


100it [02:05,  1.25s/it]


[5, 100] loss: 0.656


10it [00:12,  1.34s/it]

[6, 10] loss: 0.786


20it [00:23,  1.13s/it]

[6, 20] loss: 0.640


30it [00:35,  1.10s/it]

[6, 30] loss: 0.665


40it [00:48,  1.15s/it]

[6, 40] loss: 0.688


50it [01:00,  1.14s/it]

[6, 50] loss: 0.590


60it [01:13,  1.20s/it]

[6, 60] loss: 0.651


70it [01:26,  1.32s/it]

[6, 70] loss: 0.657


80it [01:39,  1.27s/it]

[6, 80] loss: 0.698


90it [01:52,  1.34s/it]

[6, 90] loss: 0.928


100it [02:04,  1.25s/it]


[6, 100] loss: 0.656


10it [00:12,  1.41s/it]

[7, 10] loss: 0.786


20it [00:24,  1.26s/it]

[7, 20] loss: 0.640


30it [00:37,  1.25s/it]

[7, 30] loss: 0.665


40it [00:49,  1.15s/it]

[7, 40] loss: 0.688


50it [01:01,  1.13s/it]

[7, 50] loss: 0.590


60it [01:14,  1.25s/it]

[7, 60] loss: 0.651


70it [01:27,  1.18s/it]

[7, 70] loss: 0.657


80it [01:40,  1.19s/it]

[7, 80] loss: 0.698


90it [01:53,  1.25s/it]

[7, 90] loss: 0.928


100it [02:05,  1.26s/it]


[7, 100] loss: 0.656


10it [00:12,  1.24s/it]

[8, 10] loss: 0.786


20it [00:25,  1.35s/it]

[8, 20] loss: 0.640


30it [00:38,  1.45s/it]

[8, 30] loss: 0.665


40it [00:50,  1.38s/it]

[8, 40] loss: 0.688


50it [01:02,  1.25s/it]

[8, 50] loss: 0.590


60it [01:14,  1.17s/it]

[8, 60] loss: 0.651


70it [01:28,  1.24s/it]

[8, 70] loss: 0.657


80it [01:42,  1.26s/it]

[8, 80] loss: 0.698


90it [01:55,  1.26s/it]

[8, 90] loss: 0.928


100it [02:07,  1.28s/it]


[8, 100] loss: 0.656


10it [00:12,  1.17s/it]

[9, 10] loss: 0.786


20it [00:24,  1.12s/it]

[9, 20] loss: 0.640


30it [00:37,  1.14s/it]

[9, 30] loss: 0.665


40it [00:49,  1.24s/it]

[9, 40] loss: 0.688


50it [01:02,  1.26s/it]

[9, 50] loss: 0.590


60it [01:14,  1.26s/it]

[9, 60] loss: 0.651


70it [01:28,  1.61s/it]

[9, 70] loss: 0.657


80it [01:41,  1.35s/it]

[9, 80] loss: 0.698


90it [01:54,  1.41s/it]

[9, 90] loss: 0.928


100it [02:06,  1.27s/it]


[9, 100] loss: 0.656


10it [00:12,  1.31s/it]

[10, 10] loss: 0.786


20it [00:24,  1.11s/it]

[10, 20] loss: 0.640


30it [00:36,  1.11s/it]

[10, 30] loss: 0.665


40it [00:48,  1.16s/it]

[10, 40] loss: 0.688


50it [01:00,  1.14s/it]

[10, 50] loss: 0.590


60it [01:13,  1.20s/it]

[10, 60] loss: 0.651


70it [01:26,  1.22s/it]

[10, 70] loss: 0.657


80it [01:38,  1.23s/it]

[10, 80] loss: 0.698


90it [01:51,  1.29s/it]

[10, 90] loss: 0.928


100it [02:03,  1.24s/it]


[10, 100] loss: 0.656


10it [00:12,  1.38s/it]

[11, 10] loss: 0.786


20it [00:24,  1.37s/it]

[11, 20] loss: 0.640


30it [00:38,  1.38s/it]

[11, 30] loss: 0.665


40it [00:50,  1.33s/it]

[11, 40] loss: 0.688


50it [01:02,  1.12s/it]

[11, 50] loss: 0.590


60it [01:14,  1.18s/it]

[11, 60] loss: 0.651


70it [01:27,  1.18s/it]

[11, 70] loss: 0.657


80it [01:40,  1.18s/it]

[11, 80] loss: 0.698


90it [01:52,  1.19s/it]

[11, 90] loss: 0.928


100it [02:05,  1.25s/it]


[11, 100] loss: 0.656


10it [00:12,  1.22s/it]

[12, 10] loss: 0.786


20it [00:25,  1.19s/it]

[12, 20] loss: 0.640


30it [00:37,  1.24s/it]

[12, 30] loss: 0.665


40it [00:49,  1.36s/it]

[12, 40] loss: 0.688


50it [01:02,  1.39s/it]

[12, 50] loss: 0.590


60it [01:14,  1.36s/it]

[12, 60] loss: 0.651


70it [01:26,  1.20s/it]

[12, 70] loss: 0.657


80it [01:40,  1.46s/it]

[12, 80] loss: 0.698


90it [01:52,  1.18s/it]

[12, 90] loss: 0.928


100it [02:05,  1.25s/it]


[12, 100] loss: 0.656


10it [00:12,  1.18s/it]

[13, 10] loss: 0.786


20it [00:25,  1.13s/it]

[13, 20] loss: 0.640


30it [00:37,  1.15s/it]

[13, 30] loss: 0.665


40it [00:50,  1.20s/it]

[13, 40] loss: 0.688


50it [01:02,  1.19s/it]

[13, 50] loss: 0.590


60it [01:15,  1.26s/it]

[13, 60] loss: 0.651


70it [01:27,  1.33s/it]

[13, 70] loss: 0.657


80it [01:40,  1.38s/it]

[13, 80] loss: 0.698


90it [01:52,  1.45s/it]

[13, 90] loss: 0.928


100it [02:05,  1.25s/it]


[13, 100] loss: 0.656


10it [00:12,  1.26s/it]

[14, 10] loss: 0.786


20it [00:24,  1.11s/it]

[14, 20] loss: 0.640


30it [00:37,  1.34s/it]

[14, 30] loss: 0.665


40it [00:50,  1.17s/it]

[14, 40] loss: 0.688


50it [01:02,  1.14s/it]

[14, 50] loss: 0.590


60it [01:15,  1.21s/it]

[14, 60] loss: 0.651


70it [01:27,  1.21s/it]

[14, 70] loss: 0.657


80it [01:40,  1.23s/it]

[14, 80] loss: 0.698


90it [01:52,  1.27s/it]

[14, 90] loss: 0.928


100it [02:05,  1.25s/it]


[14, 100] loss: 0.656


10it [00:12,  1.33s/it]

[15, 10] loss: 0.786


20it [00:25,  1.35s/it]

[15, 20] loss: 0.640


30it [00:37,  1.34s/it]

[15, 30] loss: 0.665


40it [00:49,  1.28s/it]

[15, 40] loss: 0.688


50it [01:01,  1.15s/it]

[15, 50] loss: 0.590


60it [01:14,  1.17s/it]

[15, 60] loss: 0.651


70it [01:26,  1.18s/it]

[15, 70] loss: 0.657


80it [01:39,  1.18s/it]

[15, 80] loss: 0.698


90it [01:53,  1.26s/it]

[15, 90] loss: 0.928


100it [02:05,  1.26s/it]


[15, 100] loss: 0.656


10it [00:12,  1.20s/it]

[16, 10] loss: 0.786


20it [00:25,  1.17s/it]

[16, 20] loss: 0.640


30it [00:37,  1.19s/it]

[16, 30] loss: 0.665


40it [00:49,  1.26s/it]

[16, 40] loss: 0.688


50it [01:01,  1.29s/it]

[16, 50] loss: 0.590


60it [01:14,  1.37s/it]

[16, 60] loss: 0.651


70it [01:26,  1.44s/it]

[16, 70] loss: 0.657


80it [01:39,  1.41s/it]

[16, 80] loss: 0.698


90it [01:51,  1.33s/it]

[16, 90] loss: 0.928


100it [02:03,  1.24s/it]


[16, 100] loss: 0.656


10it [00:12,  1.17s/it]

[17, 10] loss: 0.786


20it [00:24,  1.13s/it]

[17, 20] loss: 0.640


30it [00:37,  1.11s/it]

[17, 30] loss: 0.665


40it [00:50,  1.27s/it]

[17, 40] loss: 0.688


50it [01:03,  1.15s/it]

[17, 50] loss: 0.590


60it [01:15,  1.20s/it]

[17, 60] loss: 0.651


70it [01:28,  1.21s/it]

[17, 70] loss: 0.657


80it [01:40,  1.20s/it]

[17, 80] loss: 0.698


90it [01:53,  1.23s/it]

[17, 90] loss: 0.928


100it [02:06,  1.26s/it]


[17, 100] loss: 0.656


10it [00:12,  1.26s/it]

[18, 10] loss: 0.786


20it [00:25,  1.23s/it]

[18, 20] loss: 0.640


30it [00:37,  1.28s/it]

[18, 30] loss: 0.665


40it [00:49,  1.39s/it]

[18, 40] loss: 0.688


50it [01:01,  1.34s/it]

[18, 50] loss: 0.590


60it [01:14,  1.32s/it]

[18, 60] loss: 0.651


70it [01:26,  1.18s/it]

[18, 70] loss: 0.657


80it [01:38,  1.17s/it]

[18, 80] loss: 0.698


90it [01:51,  1.22s/it]

[18, 90] loss: 0.928


100it [02:05,  1.25s/it]


[18, 100] loss: 0.656


10it [00:12,  1.22s/it]

[19, 10] loss: 0.786


20it [00:25,  1.14s/it]

[19, 20] loss: 0.640


30it [00:37,  1.15s/it]

[19, 30] loss: 0.665


40it [00:49,  1.20s/it]

[19, 40] loss: 0.688


50it [01:02,  1.21s/it]

[19, 50] loss: 0.590


60it [01:14,  1.28s/it]

[19, 60] loss: 0.651


70it [01:27,  1.31s/it]

[19, 70] loss: 0.657


80it [01:40,  1.35s/it]

[19, 80] loss: 0.698


90it [01:52,  1.40s/it]

[19, 90] loss: 0.928


100it [02:05,  1.25s/it]


[19, 100] loss: 0.656


10it [00:12,  1.39s/it]

[20, 10] loss: 0.786


20it [00:24,  1.24s/it]

[20, 20] loss: 0.640


30it [00:36,  1.12s/it]

[20, 30] loss: 0.665


40it [00:48,  1.15s/it]

[20, 40] loss: 0.688


50it [01:02,  1.18s/it]

[20, 50] loss: 0.590


60it [01:14,  1.18s/it]

[20, 60] loss: 0.651


70it [01:26,  1.18s/it]

[20, 70] loss: 0.657


80it [01:39,  1.18s/it]

[20, 80] loss: 0.698


90it [01:52,  1.20s/it]

[20, 90] loss: 0.928


100it [02:04,  1.25s/it]

[20, 100] loss: 0.656
Finished Training





# Step 5: **Use Autoencoder Output as Input for U-Net**

In [None]:
# Define the training loop for U-Net
criterion_unet = DoseLoss(100.0)  # max_dose = 100.0; replace with the appropriate max dose
optimizer_unet = torch.optim.Adam(unet_model.parameters(), lr=0.001)

def train_unet(data_loader, unet_model, autoencoder, criterion, optimizer):
    unet_model.train()
    for epoch in range(20):
        running_loss = 0.0
        for i, batch_data in enumerate(data_loader.get_batches()):
            ct_scans = autoencoder(torch.tensor(batch_data.ct, dtype=torch.float32).squeeze(-1))
#            structure_masks = autoencoder(torch.tensor(batch_data.structure_masks, dtype=torch.float32).squeeze(-1))
            possible_dose_masks = autoencoder(torch.tensor(batch_data.possible_dose_mask, dtype=torch.float32).squeeze(-1))
            true_doses = autoencoder(torch.tensor(batch_data.dose, dtype=torch.float32).squeeze(-1))

            optimizer.zero_grad()
            predicted_doses = unet_model(ct_scans)

            loss = criterion(predicted_doses, true_doses, possible_dose_masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 10 == 9:
                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 10:.3f}")
                running_loss = 0.0
    print("Finished Training")

# Example usage
train_unet(data_loader, unet_model, autoencoder, criterion_unet, optimizer_unet)

10it [02:14, 13.59s/it]

[1, 10] loss: 0.001


20it [04:26, 13.24s/it]

[1, 20] loss: 0.000


30it [06:41, 13.34s/it]

[1, 30] loss: 0.000


40it [08:56, 13.52s/it]

[1, 40] loss: 0.000


50it [11:09, 13.36s/it]

[1, 50] loss: 0.000


60it [13:24, 13.31s/it]

[1, 60] loss: 0.000


70it [15:37, 13.37s/it]

[1, 70] loss: 0.000


80it [17:52, 13.42s/it]

[1, 80] loss: 0.000


90it [20:06, 13.36s/it]

[1, 90] loss: 0.000


100it [22:22, 13.43s/it]


[1, 100] loss: 0.000


10it [02:17, 14.18s/it]

[2, 10] loss: 0.000


20it [04:30, 13.30s/it]

[2, 20] loss: 0.000


30it [06:42, 13.06s/it]

[2, 30] loss: 0.000


40it [08:56, 13.49s/it]

[2, 40] loss: 0.000


50it [11:09, 13.35s/it]

[2, 50] loss: 0.000


60it [13:25, 13.57s/it]

[2, 60] loss: 0.000


70it [15:40, 13.45s/it]

[2, 70] loss: 0.000


80it [17:53, 13.24s/it]

[2, 80] loss: 0.000


90it [20:08, 13.36s/it]

[2, 90] loss: 0.000


100it [22:22, 13.43s/it]


[2, 100] loss: 0.000


10it [02:12, 13.34s/it]

[3, 10] loss: 0.000


20it [04:27, 13.25s/it]

[3, 20] loss: 0.000


30it [06:40, 13.19s/it]

[3, 30] loss: 0.000


40it [08:56, 13.77s/it]

[3, 40] loss: 0.000


50it [11:06, 13.11s/it]

[3, 50] loss: 0.000


60it [13:21, 13.42s/it]

[3, 60] loss: 0.000


70it [15:36, 13.54s/it]

[3, 70] loss: 0.000


80it [17:50, 13.40s/it]

[3, 80] loss: 0.000


90it [20:04, 13.33s/it]

[3, 90] loss: 0.000


100it [22:19, 13.39s/it]


[3, 100] loss: 0.000


10it [02:14, 13.41s/it]

[4, 10] loss: 0.000


20it [04:29, 13.34s/it]

[4, 20] loss: 0.000


30it [06:43, 13.32s/it]

[4, 30] loss: 0.000


40it [08:58, 13.59s/it]

[4, 40] loss: 0.000


50it [11:11, 13.28s/it]

[4, 50] loss: 0.000


60it [13:26, 13.37s/it]

[4, 60] loss: 0.000


70it [15:38, 13.44s/it]

[4, 70] loss: 0.000


80it [17:52, 13.39s/it]

[4, 80] loss: 0.000


90it [20:08, 13.47s/it]

[4, 90] loss: 0.000


100it [22:23, 13.43s/it]


[4, 100] loss: 0.000


10it [02:15, 13.75s/it]

[5, 10] loss: 0.000


20it [04:28, 13.28s/it]

[5, 20] loss: 0.000


30it [06:42, 13.22s/it]

[5, 30] loss: 0.000


40it [08:57, 13.54s/it]

[5, 40] loss: 0.000


50it [11:11, 13.40s/it]

[5, 50] loss: 0.000


60it [13:28, 13.57s/it]

[5, 60] loss: 0.000


70it [15:42, 13.43s/it]

[5, 70] loss: 0.000


80it [17:57, 13.50s/it]

[5, 80] loss: 0.000


90it [20:11, 13.36s/it]

[5, 90] loss: 0.000


100it [22:27, 13.47s/it]


[5, 100] loss: 0.000


10it [02:14, 13.42s/it]

[6, 10] loss: 0.000


20it [04:29, 13.56s/it]

[6, 20] loss: 0.000


30it [06:42, 13.27s/it]

[6, 30] loss: 0.000


40it [08:56, 13.39s/it]

[6, 40] loss: 0.000


50it [11:11, 13.46s/it]

[6, 50] loss: 0.000


60it [13:27, 13.98s/it]

[6, 60] loss: 0.000


70it [15:41, 13.43s/it]

[6, 70] loss: 0.000


80it [17:57, 13.40s/it]

[6, 80] loss: 0.000


90it [20:12, 13.50s/it]

[6, 90] loss: 0.000


100it [22:27, 13.47s/it]


[6, 100] loss: 0.000


10it [02:13, 13.38s/it]

[7, 10] loss: 0.000


15it [03:22, 13.65s/it]

# Step 6: **Load the Validation Data**

In [48]:
# Load the validating dataset
validation_dir = Path('/content/open-kbp/provided-data/validation-pats/')
validation_patient_paths = [Path(os.path.join(validation_dir, patient)) for patient in os.listdir(validation_dir) if os.path.isdir(os.path.join(validation_dir, patient))]

# Initialize the DataLoader for validating dataset
validation_loader = DataLoader(validation_patient_paths, batch_size=1)

# Set mode to evaluation
validation_loader.set_mode('evaluation')

# Load validation data
validation_batch = next(validation_loader.get_batches())

# Print shapes of validation data
print("Dose shape:", validation_batch.dose.shape if validation_batch.dose is not None else "Dose data is None")
print("CT shape:", validation_batch.ct.shape if validation_batch.ct is not None else "CT data is None")
print("Structure masks shape:", validation_batch.structure_masks.shape)
print("Possible dose mask shape:", validation_batch.possible_dose_mask.shape)
print("Voxel dimensions shape:", validation_batch.voxel_dimensions.shape)
print("Patient list:", validation_batch.patient_list)
print("Patient path list:", validation_batch.patient_path)

0it [00:00, ?it/s]

Dose shape: (1, 128, 128, 128, 1)
CT shape: CT data is None
Structure masks shape: (1, 128, 128, 128, 10)
Possible dose mask shape: (1, 128, 128, 128, 1)
Voxel dimensions shape: (1, 3)
Patient list: ['pt_240']
Patient path list: None





# Step 7: **Define Evaluation Function**

In [25]:
def evaluate_unet(validation_loader, unet_model, autoencoder, criterion):
    unet_model.eval()
    total_loss = 0
    with torch.no_grad():
        for validation_batch in validation_loader.get_batches():
            ct_scans = torch.tensor(batch_data.ct, dtype=torch.float32).squeeze(-1)
            structure_masks = torch.tensor(batch_data.structure_masks, dtype=torch.float32).squeeze(-1)
            possible_dose_masks = torch.tensor(batch_data.possible_dose_mask, dtype=torch.float32).squeeze(-1)
            true_doses = torch.tensor(batch_data.dose, dtype=torch.float32).squeeze(-1)

            # Pass data through the autoencoder
            features = autoencoder(ct_data)

            # Run U-Net on features
            preds = unet_model(features)

            # Loss calculation
            loss = criterion(preds, dose_data)
            total_loss += loss.item()

    avg_loss = total_loss / len(data_loader.patient_paths)  # Adjust if needed
    return avg_loss


# Step 8: **Prepare the DataLoader for Validation Data**

In [26]:
from torch.utils.data import DataLoader, TensorDataset

# Example code for converting validation data
validation_inputs = []
validation_targets = []

# Assuming validation_patient_data is a dict with preprocessed tensors
for patient, data in validation_patient_data.items():
    input_tensor = torch.tensor(data['ct']).unsqueeze(0)  # Adjust key if needed
    target_tensor = torch.tensor(data['dose']).unsqueeze(0)  # Adjust key if needed

    validation_inputs.append(input_tensor)
    validation_targets.append(target_tensor)

validation_dataset = TensorDataset(torch.stack(validation_inputs), torch.stack(validation_targets))
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False)


KeyError: 'input'