<a href="https://colab.research.google.com/github/DhargaveAC/All4SamplesIntegrated/blob/master/PCN520.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<center>    
PCN520 Research Project
<hr>
<h1> Utilising Deep Learning for Individualised Quality Assurance in Radiotherapy: DOSE PREDICTION AND OPTIMISATION </h1>
<hr>
<h3> 3D U-Net Voxel-Wise Dose Prediction Model </h3>
</center>

In [None]:
# Clone the OpenKBP repository into `repo_dir`; the provided data is read from inside it later.
# NOTE(review): re-running this cell presumably fails once the directory already exists — consider guarding the clone.
repo_dir = 'OpenKBP'
!git clone https://github.com/ababier/open-kbp.git {repo_dir}

# Make the cloned repository importable. `repo_dir` is a relative path, so this depends on
# the notebook's current working directory at the time the cell runs.
import sys
sys.path.append(repo_dir)

Cloning into 'OpenKBP'...
remote: Enumerating objects: 4260, done.[K
remote: Counting objects: 100% (52/52), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 4260 (delta 38), reused 30 (delta 29), pack-reused 4208 (from 1)[K
Receiving objects: 100% (4260/4260), 523.57 MiB | 12.35 MiB/s, done.
Resolving deltas: 100% (1603/1603), done.
Updating files: 100% (4016/4016), done.


# Data Loading

**Preamble**: import all required libraries.

In [None]:
from __future__ import annotations
from pathlib import Path
from typing import Optional, Union, List, Dict, Iterator
from collections import OrderedDict

import os
import glob
import numpy as np
import pandas as pd
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from numpy.typing import NDArray
from more_itertools import windowed
from tqdm import tqdm
import matplotlib.pyplot as plt

Before we run anything, first define the paths where the provided data is stored and where the results (e.g., models, predictions) should be saved.

In [None]:
# Resolve every project directory relative to the cloned repository root.
primary_directory = Path(repo_dir).resolve()  # root under which all inputs and outputs live
provided_data_dir = primary_directory.joinpath("provided-data")
training_data_dir, validation_data_dir, testing_data_dir = (
    provided_data_dir.joinpath(split_name)
    for split_name in ("train-pats", "validation-pats", "test-pats")
)
results_dir = primary_directory.joinpath("results")  # destination for generated artefacts (e.g. models, predictions)

Define the classes and helper functions used to load the patient data.

In [None]:
class DataBatch:
    """Container for one batch of patient data.

    Tensor fields are batch-first: ``initialize_from_required_data`` prepends the batch
    axis to every per-patient shape. Fields not requested by the loading mode stay ``None``.
    """

    def __init__(
        self,
        dose: Optional[torch.Tensor] = None,
        predicted_dose: Optional[torch.Tensor] = None,
        ct: Optional[torch.Tensor] = None,
        structure_masks: Optional[torch.Tensor] = None,
        structure_mask_names: Optional[List[str]] = None,
        possible_dose_mask: Optional[torch.Tensor] = None,
        voxel_dimensions: Optional[torch.Tensor] = None,
        patient_list: Optional[List[str]] = None,
        patient_path_list: Optional[List[Path]] = None,
    ):
        self.dose = dose
        self.predicted_dose = predicted_dose
        self.ct = ct
        self.structure_masks = structure_masks
        self.structure_mask_names = structure_mask_names
        self.possible_dose_mask = possible_dose_mask
        self.voxel_dimensions = voxel_dimensions
        self.patient_list = patient_list
        # Stored under the constructor-argument name, which is also the attribute name the
        # data loader assigns, so readers and writers agree.
        self.patient_path_list = patient_path_list

    @property
    def patient_path(self) -> Optional[List[Path]]:
        """Alias of ``patient_path_list`` kept for backward compatibility."""
        return self.patient_path_list

    @patient_path.setter
    def patient_path(self, value: Optional[List[Path]]) -> None:
        self.patient_path_list = value

    @classmethod
    def initialize_from_required_data(cls, data_dimensions: dict[str, torch.Size], batch_size: int) -> "DataBatch":
        """Create a batch whose requested fields are zero-filled float32 tensors.

        Args:
            data_dimensions: maps a field name (e.g. ``"dose"``) to its per-patient shape.
            batch_size: length of the leading batch axis added to every shape.
        """
        attribute_values = {
            data: torch.zeros((batch_size, *dimensions), dtype=torch.float32)
            for data, dimensions in data_dimensions.items()
        }
        return cls(**attribute_values)

    def set_values(self, data_name: str, batch_index: int, values: torch.Tensor):
        """Write ``values`` into slot ``batch_index`` of the field named ``data_name``."""
        getattr(self, data_name)[batch_index] = values

    def get_index_structure_from_structure(self, structure_name: str):
        """Return the channel index of ``structure_name`` within ``structure_mask_names``."""
        return self.structure_mask_names.index(structure_name)


class DataShapes:
    """Per-patient shape catalogue for every data item the loader can produce.

    Volumetric items share ``patient_shape`` (128 x 128 x 128 voxels) and differ only in
    the size of their trailing channel axis.
    """

    def __init__(self, num_rois):
        # Number of ROI channels stacked in the structure-mask volume.
        self.num_rois = num_rois
        self.patient_shape = (128, 128, 128)

    def _volume(self, channels) -> torch.Size:
        """Return ``patient_shape`` extended by a trailing channel axis of size ``channels``."""
        return torch.Size((*self.patient_shape, channels))

    @property
    def dose(self) -> torch.Size:
        return self._volume(1)

    @property
    def predicted_dose(self) -> torch.Size:
        # Predictions share the ground-truth dose layout.
        return self.dose

    @property
    def ct(self) -> torch.Size:
        return self._volume(1)

    @property
    def structure_masks(self) -> torch.Size:
        return self._volume(self.num_rois)

    @property
    def possible_dose_mask(self) -> torch.Size:
        return self._volume(1)

    @property
    def voxel_dimensions(self) -> torch.Size:
        # Three values (presumably one spacing per spatial axis).
        return torch.Size((3,))

    def from_data_names(self, data_names: list[str]) -> dict[str, torch.Size]:
        """Map each requested data name to its shape, preserving the request order."""
        return {name: getattr(self, name) for name in data_names}


def load_file(file_path: Path) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Load one patient CSV file as tensors.

    Args:
        file_path: path to the CSV file; a file whose stem is ``voxel_dimensions`` is read
            with ``np.loadtxt``, every other file with ``pd.read_csv(index_col=0)``.

    Returns:
        * ``voxel_dimensions``: a float32 tensor of the values in the file.
        * mask files (the value column is empty/NaN, or there is no value column):
          ``{"indices": int64 flat voxel indices, "data": float32 ones}``.
        * value files: ``{"indices": int64 flat voxel indices, "data": float32 values}``
          taken from the ``data`` column (or the first column if none is named ``data``).
    """
    if file_path.stem == "voxel_dimensions":
        return torch.tensor(np.loadtxt(file_path), dtype=torch.float32)

    loaded_file_df = pd.read_csv(file_path, index_col=0)
    # to_numpy(dtype=...) also copes with an empty (object-typed) index.
    indices = torch.tensor(loaded_file_df.index.to_numpy(dtype=np.int64), dtype=torch.int64)

    # A mask file lists voxel indices with an empty value column, which pandas reads as NaN.
    # This must be checked before treating the column as values; otherwise the NaNs would be
    # scattered into the dense mask.
    if loaded_file_df.shape[1] == 0 or loaded_file_df.isnull().values.any():
        return {"indices": indices, "data": torch.ones(len(indices), dtype=torch.float32)}

    if "data" in loaded_file_df.columns:
        value_column = loaded_file_df["data"]
    else:
        value_column = loaded_file_df.iloc[:, 0]
    values = torch.tensor(value_column.to_numpy(dtype=np.float32), dtype=torch.float32)
    return {"indices": indices, "data": values}


class DataLoaderCustom:
    """Load patient folders of sparse CSV files into dense, batch-first ``DataBatch`` objects.

    Call :meth:`set_mode` before requesting data; the mode decides which items are loaded.
    """

    def __init__(self, patient_paths: List[Path], batch_size: int = 2):
        self.patient_paths = patient_paths
        self.batch_size = batch_size
        # Patient folders are keyed by their final path component (the patient id).
        self.paths_by_patient_id = {patient_path.stem: patient_path for patient_path in self.patient_paths}
        self.required_files: Optional[Dict] = None
        self.mode_name: Optional[str] = None
        self.rois = dict(
            oars=["Brainstem", "SpinalCord", "RightParotid", "LeftParotid", "Esophagus", "Larynx", "Mandible"],
            targets=["PTV56", "PTV63", "PTV70"],
        )
        # The flattened ROI order defines the channel order of ``structure_masks``.
        self.full_roi_list = sum(map(list, self.rois.values()), [])
        self.num_rois = len(self.full_roi_list)
        self.data_shapes = DataShapes(self.num_rois)

    @property
    def patient_id_list(self) -> List[str]:
        return list(self.paths_by_patient_id.keys())

    def get_batches(self) -> Iterator["DataBatch"]:
        """Yield full batches in the current ``patient_paths`` order (a trailing partial batch is dropped)."""
        batches = windowed(self.patient_paths, n=self.batch_size, step=self.batch_size)
        complete_batches = (batch for batch in batches if None not in batch)
        for batch_paths in tqdm(complete_batches):
            yield self.prepare_data(batch_paths)

    def get_patients(self, patient_list: List[str]) -> "DataBatch":
        """Load the named patients as one batch."""
        file_paths_to_load = [self.paths_by_patient_id[patient] for patient in patient_list]
        return self.prepare_data(file_paths_to_load)

    def set_mode(self, mode: str) -> None:
        """Select which data items are loaded; prediction/evaluation modes force batch size 1."""
        self.mode_name = mode
        if mode == "training_model":
            required_data = ["dose", "ct", "structure_masks", "possible_dose_mask", "voxel_dimensions"]
        elif mode == "predicted_dose":
            required_data = [mode]
            self._force_batch_size_one()
        elif mode == "evaluation":
            required_data = ["dose", "structure_masks", "possible_dose_mask", "voxel_dimensions"]
            self._force_batch_size_one()
        elif mode == "dose_prediction":
            required_data = ["ct", "structure_masks", "possible_dose_mask", "voxel_dimensions"]
            self._force_batch_size_one()
        else:
            raise ValueError(
                f"Mode `{mode}` does not exist. Mode must be either training_model, dose_prediction, predicted_dose, or evaluation"
            )
        self.required_files = self.data_shapes.from_data_names(required_data)

    def _force_batch_size_one(self) -> None:
        """Reset the batch size to 1 and warn the user when it changes."""
        if self.batch_size != 1:
            import warnings  # local import: the notebook preamble does not import warnings

            self.batch_size = 1
            # Emit the warning; merely instantiating ``Warning(...)`` would be silent.
            warnings.warn(f"Batch size has been changed to 1 for {self.mode_name} mode")

    def shuffle_data(self) -> None:
        """Shuffle ``patient_paths`` in place."""
        np.random.shuffle(self.patient_paths)

    def prepare_data(self, file_paths_to_load: List[Path]) -> "DataBatch":
        """Load each patient folder into its own slot of a freshly zeroed batch.

        Missing files are skipped and leave zeros in place.

        Raises:
            RuntimeError: if :meth:`set_mode` has not been called yet.
        """
        if self.required_files is None:
            raise RuntimeError("Call set_mode() before loading data")
        batch_data = DataBatch.initialize_from_required_data(self.required_files, self.batch_size)
        for batch_idx, path in enumerate(file_paths_to_load):
            for data_name in self.required_files:
                if data_name == "structure_masks":
                    self._load_structure_masks(batch_data, batch_idx, path)
                    continue
                file_path = Path(path) / f"{data_name}.csv"
                if not file_path.exists():
                    continue
                file_data = load_file(file_path)
                if data_name == "voxel_dimensions":
                    # load_file already returns this as a dense tensor of values;
                    # it is not a sparse voxel list and must not be scattered.
                    values = file_data
                else:
                    values = self.dense_scatter(file_data).unsqueeze(-1)  # (*patient_shape, 1)
                batch_data.set_values(data_name, batch_idx, values)
        batch_data.patient_list = [Path(path).name for path in file_paths_to_load]
        batch_data.patient_path_list = list(file_paths_to_load)
        return batch_data

    def _load_structure_masks(self, batch_data: "DataBatch", batch_idx: int, path: Path) -> None:
        """Fill the ROI channels of ``structure_masks[batch_idx]``; ROIs without a file stay zero."""
        for roi_idx, roi in enumerate(self.full_roi_list):
            roi_file_path = Path(path) / f"{roi}.csv"
            if not roi_file_path.exists():
                continue
            # Index the batch slot explicitly so one patient's mask never overwrites the others.
            batch_data.structure_masks[batch_idx, ..., roi_idx] = self.dense_scatter(load_file(roi_file_path))
        batch_data.structure_mask_names = self.full_roi_list

    def dense_scatter(self, sparse_vector: Union[Dict[str, torch.Tensor], torch.Tensor]) -> torch.Tensor:
        """Expand a sparse voxel list into a dense float32 volume of shape ``patient_shape``.

        Args:
            sparse_vector: either ``{"indices": flat voxel indices, "data": their values}`` or a
                bare tensor of flat indices, which is treated as a binary mask (value 1).

        Raises:
            IndexError: if any index falls outside the flattened volume.
        """
        if isinstance(sparse_vector, torch.Tensor):
            indices = sparse_vector.long().reshape(-1)
            values = torch.ones(indices.numel(), dtype=torch.float32)
        else:
            indices = sparse_vector["indices"].long()
            values = sparse_vector["data"]
        dense_data = torch.zeros(self.data_shapes.patient_shape, dtype=torch.float32)
        flat_view = dense_data.view(-1)  # writes through to dense_data
        if torch.any(indices >= flat_view.numel()) or torch.any(indices < 0):
            raise IndexError("Indices are out of bounds for the dense array shape.")
        flat_view[indices] = values
        return dense_data

    def sparse_scatter(self, data: torch.Tensor, sparse: bool = True) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
        """Inverse of :meth:`dense_scatter`.

        Returns ``{"indices": flat indices of non-zero voxels, "data": their values}`` when
        ``sparse`` is true, otherwise the flattened tensor.
        """
        flat = data.reshape(-1)
        if not sparse:
            return flat
        indices = torch.nonzero(flat != 0, as_tuple=True)[0]
        return {"indices": indices, "data": flat[indices]}


In [None]:
# Data loading and preprocessing
def load_data(data_dir: Union[str, Path], batch_size: int = 2) -> "DataLoaderCustom":
    """Build a training-mode loader over every patient sub-directory of ``data_dir``.

    Args:
        data_dir: directory containing one sub-directory per patient.
        batch_size: number of patients per batch (defaults to the loader's default of 2).

    Returns:
        A ``DataLoaderCustom`` already switched to ``"training_model"`` mode.
    """
    data_dir = Path(data_dir)
    # Path.iterdir yields entries in arbitrary order; sort so batch composition is reproducible.
    patient_dirs = sorted(entry for entry in data_dir.iterdir() if entry.is_dir())

    data_loader = DataLoaderCustom(patient_dirs, batch_size=batch_size)
    data_loader.set_mode("training_model")

    return data_loader

# 3D U-Net Model (Voxel-Wise Dose Prediction)

In [None]:
class UNet3D(nn.Module):
    """3D U-Net for voxel-wise regression.

    Four encoder stages (each followed by a 2x max-pool), a bottleneck, and four decoder
    stages that up-sample with transposed convolutions and concatenate the matching encoder
    output (skip connection). Channel widths are ``init_features * (1, 2, 4, 8, 16)``.

    Input must be channels-first, ``(B, in_channels, D, H, W)``, as ``nn.Conv3d`` expects.
    Because of the four 2x poolings, D, H and W should be divisible by 16; otherwise the
    skip-connection concatenations receive tensors of mismatched size. The output has
    ``out_channels`` channels and passes through a sigmoid, so it lies in [0, 1].

    NOTE: attribute names and the order in which submodules are created determine the
    ``state_dict`` keys and the order in which parameters are initialised.
    """

    def __init__(self, in_channels: int, out_channels: int, init_features: int = 32):
        super(UNet3D, self).__init__()

        features = init_features
        # Encoder: the channel width doubles at each stage while each pool halves the resolution.
        self.encoder1 = UNet3D._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder2 = UNet3D._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder3 = UNet3D._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder4 = UNet3D._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        self.bottleneck = UNet3D._block(features * 8, features * 16, name="bottleneck")

        # Decoder: each transposed conv doubles the resolution and halves the width. The
        # following block takes twice that width because the skip connection is concatenated.
        self.upconv4 = nn.ConvTranspose3d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet3D._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose3d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet3D._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose3d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet3D._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose3d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet3D._block(features * 2, features, name="dec1")

        # 1x1x1 projection to the requested number of output channels.
        self.conv = nn.Conv3d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a channels-first volume ``(B, in_channels, D, H, W)``."""
        # Contracting path; the enc* outputs are kept for the skip connections.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        # Expanding path: up-sample, concatenate the same-resolution encoder features
        # along the channel axis (dim=1), then refine.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        # Sigmoid bounds the output to [0, 1]; callers must rescale it to dose units.
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels: int, features: int, name: str) -> nn.Sequential:
        """Two (3x3x3 conv -> BatchNorm3d -> ReLU) stages that preserve spatial size (padding=1).

        Sub-module names are prefixed with ``name``; they become part of the ``state_dict`` keys.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv3d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,  # redundant: BatchNorm3d immediately follows
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm3d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv3d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,  # redundant: BatchNorm3d immediately follows
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm3d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )


class DoseLoss(nn.Module):
    """Mean squared dose error over the masked voxels, normalised by ``max_dose ** 2``."""

    def __init__(self, max_dose):
        super(DoseLoss, self).__init__()
        self.max_dose = max_dose

    def forward(self, predicted_dose, true_dose, possible_dose_mask):
        """Compute the masked, scale-normalised MSE.

        Args:
            predicted_dose: predicted dose, in the same units as ``true_dose``.
            true_dose: reference dose; must broadcast against ``predicted_dose``.
            possible_dose_mask: weight per voxel (1 = counts, 0 = ignored); must broadcast
                against the dose tensors. Boolean masks are accepted.

        Returns:
            Scalar tensor equal to the mean of ``((pred - true) / max_dose) ** 2`` over the
            masked voxels, or 0 when the mask is empty.

            Dividing by ``max_dose ** 2`` (rather than ``max_dose``) makes the loss
            dimensionless. Averaging only over masked voxels keeps the value independent of
            how much of the volume lies outside the mask.
        """
        mask = possible_dose_mask.to(predicted_dose.dtype)
        squared_error = (predicted_dose - true_dose) ** 2 * mask
        # Count masked voxels at the broadcast shape; clamp so an empty mask gives 0, not NaN.
        num_voxels = mask.expand_as(squared_error).sum().clamp(min=1.0)
        mse_loss = squared_error.sum() / num_voxels
        return mse_loss / (self.max_dose ** 2)


def train_model(
    model, data_loader, num_epochs: int, learning_rate: float, max_dose: float
):
    """Train ``model`` to predict dose from CT with Adam.

    Args:
        model: network taking a channels-first CT volume ``(B, 1, D, H, W)``. Its output is
            assumed to lie in [0, 1] (e.g. a sigmoid head); it is multiplied by ``max_dose``
            so that it is compared with the reference dose in the same units.
        data_loader: object exposing ``shuffle_data()`` and ``get_batches()``; batches carry
            channels-last ``ct``, ``dose`` and ``possible_dose_mask`` tensors ``(B, D, H, W, 1)``.
        num_epochs: number of passes over the data.
        learning_rate: Adam learning rate.
        max_dose: dose scale used to map the model output to dose units and by the loss.
    """
    criterion = DoseLoss(max_dose)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Batches are moved to wherever the model lives (CPU or GPU).
    device = next(model.parameters()).device

    def to_channels_first(volume):
        # The loader stores volumes channels-last (B, D, H, W, C); Conv3d expects (B, C, D, H, W).
        return volume.to(device=device, dtype=torch.float32).permute(0, 4, 1, 2, 3)

    model.train()
    for epoch in range(num_epochs):
        # Visit patients in a fresh order every epoch.
        data_loader.shuffle_data()
        running_loss = 0.0
        for i, data_batch in enumerate(data_loader.get_batches()):
            # NOTE(review): structure masks are loaded in training mode but not fed to the
            # network (in_channels=1); consider stacking them with CT as extra input channels.
            ct_scans = to_channels_first(data_batch.ct)
            possible_dose_masks = to_channels_first(data_batch.possible_dose_mask)
            true_doses = to_channels_first(data_batch.dose)

            optimizer.zero_grad()

            predicted_doses = model(ct_scans) * max_dose

            loss = criterion(predicted_doses, true_doses, possible_dose_masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 10 == 9:
                print(
                    f"[{epoch + 1}, {i + 1}] loss: {running_loss / 10:.3f}"
                )
                running_loss = 0.0

    print("Finished Training")


if __name__ == "__main__":
    # Reuse the training directory resolved in the path-setup cell instead of a
    # hard-coded, Colab-only absolute path.
    data_dir = training_data_dir
    data_loader = load_data(data_dir)

    # Initialize model: one CT input channel, one dose output channel
    in_channels = 1
    out_channels = 1
    model = UNet3D(in_channels, out_channels)

    # Training parameters
    num_epochs = 20
    learning_rate = 0.001
    max_dose = 100.0  # Replace with the appropriate max dose for your dataset

    # Train model
    train_model(model, data_loader, num_epochs, learning_rate, max_dose)


0it [00:00, ?it/s]

indices shape: torch.Size([69403])
indices: tensor([ 581179,  581307,  581435,  ..., 1499689, 1499816, 1499944])
indices shape: torch.Size([86095])
indices: tensor([ 532027,  532155,  532283,  ..., 1532456, 1532457, 1532458])
indices shape: torch.Size([683])
indices: tensor([1040157, 1040285, 1040413, 1040541, 1040669, 1056413, 1056415, 1056416,
        1056540, 1056541, 1056542, 1056543, 1056544, 1056545, 1056546, 1056668,
        1056669, 1056670, 1056671, 1056672, 1056673, 1056674, 1056675, 1056796,
        1056797, 1056798, 1056799, 1056800, 1056801, 1056802, 1056803, 1056924,
        1056925, 1056926, 1056927, 1056928, 1056929, 1056930, 1056931, 1057052,
        1057053, 1057054, 1057055, 1057056, 1057057, 1057058, 1057059, 1057181,
        1057182, 1057183, 1057184, 1057185, 1057310, 1057311, 1072671, 1072674,
        1072795, 1072797, 1072798, 1072799, 1072800, 1072801, 1072802, 1072803,
        1072923, 1072924, 1072925, 1072926, 1072927, 1072928, 1072929, 1072930,
        1072

0it [00:00, ?it/s]

indices shape: torch.Size([4858])
indices: tensor([ 810803,  810931,  810932,  ..., 1255886, 1255887, 1255888])
indices shape: torch.Size([3884])
indices: tensor([ 843851,  843978,  843979,  ..., 1107909, 1107910, 1108037])
indices shape: torch.Size([69403])
indices: tensor([ 581179,  581307,  581435,  ..., 1499689, 1499816, 1499944])





IndexError: too many indices for tensor of dimension 1