In [None]:
# Automatically reload edited project modules (e.g. fishjaw) before code runs
%load_ext autoreload
%autoreload 2

import sys

# Make the package two directories up importable; presumably the repository
# root containing `fishjaw` (path is relative to the notebook's working directory)
sys.path.append("../../")

In [None]:
from fishjaw.util import util

# User configuration from the project helper.
# NOTE(review): `config` is not referenced in the cells shown here — confirm it is still needed.
config = util.userconf()

In [None]:
"""List the DICOM CT scans in the training directory (they are read, with their labels, in a later cell)"""

import pathlib


dicom_dir = pathlib.Path("../../dicoms/Training set 2/")
# sorted() already returns a list; sorting gives a deterministic split order
dicom_paths = sorted(dicom_dir.glob("*.dcm"))
# Fail early with a clear message if the directory is wrong, instead of a
# confusing "not enough values to unpack" when the empty splits are read later
assert dicom_paths, f"No .dcm files found in {dicom_dir.resolve()}"
print(len(dicom_paths))

In [None]:
# Fixed split over the sorted paths: first 15 scans train, next 3 validate,
# everything after that is held out for testing
train_paths = dicom_paths[:15]
val_paths = dicom_paths[15:18]
test_paths = dicom_paths[18:]

In [None]:
from tqdm import tqdm
from fishjaw.images import io

def _read_split(paths):
    """Read every DICOM in ``paths``; return (images, labels) as two parallel tuples."""
    return zip(*(io.read_dicom(path) for path in tqdm(paths)))


train_imgs, train_labels = _read_split(train_paths)
test_imgs, test_labels = _read_split(test_paths)
val_imgs, val_labels = _read_split(val_paths)

In [None]:
from scipy.ndimage import center_of_mass

def _centroid_voxel(label):
    """Centre of mass of a label volume, each coordinate truncated with int()."""
    return tuple(int(coord) for coord in center_of_mass(label))


train_centroids = [_centroid_voxel(label) for label in tqdm(train_labels)]
val_centroids = [_centroid_voxel(label) for label in tqdm(val_labels)]
test_centroids = [_centroid_voxel(label) for label in tqdm(test_labels)]

In [None]:
"""Plot the centroids"""

import matplotlib.pyplot as plt

fig, axis = plt.subplots(figsize=(6, 6))

# 2D slice of the first training scan (and its label) through the label centroid
img_slice, label_slice = (
    train_imgs[0][train_centroids[0][0]],
    train_labels[0][train_centroids[0][0]],
)

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

# scatter takes (x, y) = (column, row), i.e. centroid axes 2 and 1
axis.scatter(
    x=train_centroids[0][2],
    y=train_centroids[0][1],
    color="red",
    s=20,
    label="Centroid",
)
axis.axis("off")

In [None]:
# Since we don't really need all the information of the high-resolution images,
# we can downsample them to speed up processing.
import numpy as np

# ndimage.zoom is faster, but skimage.transform.resize may be more accurate
from scipy.ndimage import zoom

# Shape every volume is resampled to. Axis 0 is the slice axis (the plots index
# volumes with centroid[0] to pick a 2D slice); axes 1 and 2 are in-plane rows/columns.
output_size = (512, 128, 128)


def resize_image_and_label(imgs, labels, centroids, target_shape):
    scale_factors = []
    resized_imgs = []
    resized_labels = []
    scaled_centroids = []

    for img, label, centroid in tqdm(
        zip(imgs, labels, centroids, strict=True), total=len(imgs)
    ):
        assert img.shape == label.shape, "Image and label must have the same shape"

        # Calculate scale factors for each dimension
        scale_factor = tuple(
            target_dim / orig_dim
            for orig_dim, target_dim in zip(img.shape, output_size)
        )
        scale_factors.append(scale_factor)

        # Resize image and label using zoom
        resized_img = zoom(
            img, scale_factor, order=3
        )  # Use cubic interpolation for images
        resized_label = zoom(
            label, scale_factor, order=0
        )  # Use nearest-neighbor for labels

        resized_imgs.append(resized_img)
        resized_labels.append(resized_label)

        # Rescale centroid to match the resized image
        scaled_centroid = tuple(int(c * sf) for c, sf in zip(centroid, scale_factor))
        scaled_centroids.append(scaled_centroid)

    print(np.mean(scale_factors, axis=0), np.std(scale_factors, axis=0))

    return resized_imgs, resized_labels, scaled_centroids, scale_factors


# Train/val data are only used at the reduced resolution, so their names are
# rebound to the resized volumes. NOTE: re-running this cell on its own therefore
# resamples already-resized data; re-run the loading/centroid cells first.
train_imgs, train_labels, train_centroids, _ = resize_image_and_label(
    train_imgs, train_labels, train_centroids, output_size
)
val_imgs, val_labels, val_centroids, _ = resize_image_and_label(
    val_imgs, val_labels, val_centroids, output_size
)

# The test set is resized too, but under new names: the original-resolution
# test_imgs/test_labels stay available for evaluation, and test_scale_factors is
# kept so predictions can be mapped back onto the original voxel grid
resized_test_imgs, resized_test_labels, resized_test_centroids, test_scale_factors = (
    resize_image_and_label(test_imgs, test_labels, test_centroids, output_size)
)

In [None]:
fig, axis = plt.subplots(figsize=(6, 6))

# Sanity check on the downsampled first test volume: slice through its
# rescaled centroid with the resized label overlaid
img_slice, label_slice = (
    resized_test_imgs[0][resized_test_centroids[0][0]],
    resized_test_labels[0][resized_test_centroids[0][0]],
)

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

# scatter takes (x, y) = (column, row), i.e. centroid axes 2 and 1
axis.scatter(
    x=resized_test_centroids[0][2],
    y=resized_test_centroids[0][1],
    color="red",
    s=20,
    label="Centroid",
)
axis.axis("off")

In [None]:
# Define a model arch

from monai.networks.nets import AttentionUnet

# Single-channel 3D volume in, single-channel map out; the training loss applies
# a softmax over that map, so the network output is treated as unnormalised scores
model = AttentionUnet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(4, 8, 16, 32),
    strides=(2, 2, 2),
    dropout=0.05,
)
model = model.to("cuda")

In [None]:
# Training loop
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

from scipy.ndimage import gaussian_filter


class ImgDataset(Dataset):
    def __init__(self, images, centroids, sigma):
        self.data = torch.tensor(
            np.array(images, dtype=np.float32), dtype=torch.float32
        ).unsqueeze(1)

        self.centroids = []
        for img, centroid in zip(images, centroids):
            gaussian_mask = np.zeros_like(img, dtype=np.float32)
            gaussian_mask[centroid[0], centroid[1], centroid[2]] = 1  # Set the centroid
            gaussian_mask = gaussian_filter(
                gaussian_mask, sigma=sigma
            )  # Apply Gaussian filter

            # Normalise to sum to 1 to make it a valid probability distribution
            gaussian_mask /= np.sum(gaussian_mask) if np.sum(gaussian_mask) > 0 else 1

            self.centroids.append(gaussian_mask)

        self.centroids = torch.tensor(
            np.array(self.centroids), dtype=torch.float32
        ).unsqueeze(1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx].to("cuda"), self.centroids[idx].to("cuda")


# Hyper-parameters: Gaussian width (voxels) of the heatmap targets, loader batch size
sigma = 5
batch_size = 1

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Wrap the resized volumes with their centroid heatmaps
train_dataset = ImgDataset(train_imgs, train_centroids, sigma=sigma)
val_dataset = ImgDataset(val_imgs, val_centroids, sigma=sigma)

# Shuffle only the training data; validation order stays fixed
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Plot the first thing in the dataloader


def plot_sample(img, centroid):
    """Show the slice through the peak of a heatmap, overlaid on its volume.

    Only the first batch item and its first channel are plotted.

    :param img: volume batch, indexed as img[batch][channel][slice] -> 2D image
    :param centroid: heatmap batch with the same layout as ``img``; either a
        target heatmap from the dataset or raw model output (may track gradients)
    """
    heatmap = centroid[0][0].detach()

    # Slice holding the heatmap's largest voxel. For the separable Gaussian
    # targets this is also the slice with the largest sum; for raw model scores a
    # per-slice sum is not meaningful, whereas the peak voxel is the prediction.
    centre = int(np.unravel_index(int(heatmap.argmax()), tuple(heatmap.shape))[0])

    _, axis = plt.subplots(figsize=(6, 6))

    img_slice = img[0][0][centre].detach().cpu().numpy()
    centroid_slice = heatmap[centre].cpu().numpy()
    axis.imshow(img_slice, cmap="gray")
    axis.imshow(centroid_slice, cmap="afmhot_r", alpha=0.3)
    axis.set_title(f"Slice {centre}")

    axis.axis("off")


sample_volume, sample_heatmap = next(iter(train_loader))
plot_sample(sample_volume, sample_heatmap)

In [None]:
# Training loop
import torch.nn.functional as F


def _heatmap_kl_loss(logits, target):
    """KL(target || softmax(logits)) per sample, averaged over the batch.

    Both tensors are flattened to (batch, voxels). The softmax is taken over each
    sample's voxels separately, and ``batchmean`` divides the summed divergence by
    the batch size (not by the number of voxels).
    """
    log_probs = F.log_softmax(logits.flatten(start_dim=1), dim=1)
    return F.kl_div(log_probs, target.flatten(start_dim=1), reduction="batchmean")


def train(model, train_data, val_data, n_epochs, opt=None):
    """Fit ``model`` to predict centroid heatmaps, recording per-batch losses.

    The loss compares each target heatmap (a per-sample distribution over
    voxels) with the softmax of the model output over that sample's voxels.

    :param model: network mapping a batch of volumes to same-shaped scores
    :param train_data: iterable of (volumes, target_heatmaps) batches
    :param val_data: iterable of (volumes, target_heatmaps) batches; evaluated
        with the model in eval mode (dropout off) and without gradients
    :param n_epochs: number of passes over train_data
    :param opt: optimizer to step; defaults to the notebook-level ``optimizer``
    :returns: (train_losses, val_losses), each a list holding one list of
        per-batch loss floats per epoch. The model is left in train mode.
    """
    if opt is None:
        opt = optimizer

    train_losses, val_losses = [], []

    pbar = tqdm(range(n_epochs), total=n_epochs, desc="Training Epochs")
    for _ in pbar:
        model.train()
        epoch_train = []
        for volumes, heatmaps in train_data:
            opt.zero_grad()
            loss = _heatmap_kl_loss(model(volumes), heatmaps)
            loss.backward()
            opt.step()
            epoch_train.append(loss.item())

        # Same loss as training (there is no separate criterion), with dropout off
        model.eval()
        epoch_val = []
        with torch.no_grad():
            for volumes, heatmaps in val_data:
                epoch_val.append(_heatmap_kl_loss(model(volumes), heatmaps).item())

        train_losses.append(epoch_train)
        val_losses.append(epoch_val)

        pbar.set_postfix(train_loss=np.mean(epoch_train), val_loss=np.mean(epoch_val))

    model.train()
    return train_losses, val_losses


train_losses, val_losses = train(model, train_loader, val_loader, n_epochs=10)

In [None]:
from fishjaw.visualisation import training

# Plot the loss curves. train() returns, for each of train/val, one list of
# per-batch losses per epoch.
# NOTE(review): confirm plot_losses expects this nested layout.
fig = training.plot_losses(train_losses, val_losses)

In [None]:
# Test the model: predict a heatmap for the first (downsampled) test volume,
# take its peak as the predicted centroid, and visualise the slice through it
model.eval()  # disable dropout for inference
with torch.no_grad():
    test_vol = (
        torch.tensor(
            np.array(resized_test_imgs[0], dtype=np.float32), dtype=torch.float32
        )
        .unsqueeze(0)
        .unsqueeze(0)
        .to("cuda")
    )

    predicted_heatmap = model(test_vol)

    # The highest-scoring voxel is also the highest-probability voxel under the
    # softmax used in training; it is the predicted centroid in the resized grid.
    # The later plotting / rescaling cells rely on `predicted_coords`.
    predicted_coords = tuple(
        int(i)
        for i in np.unravel_index(
            int(predicted_heatmap[0, 0].argmax()), tuple(predicted_heatmap.shape[2:])
        )
    )
    print(
        f"Predicted centroid: {predicted_coords}, "
        f"true centroid: {resized_test_centroids[0]}"
    )

    plot_sample(test_vol, predicted_heatmap)

In [None]:
# Plot the predicted centroid on the downsampled test volume, with the
# ground-truth label overlaid on the same slice
fig, axis = plt.subplots()

img_slice, label_slice = (
    resized_test_imgs[0][int(predicted_coords[0])],
    resized_test_labels[0][int(predicted_coords[0])],
)

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

# scatter takes (x, y) = (column, row), i.e. coordinate axes 2 and 1
axis.scatter(
    x=predicted_coords[2],
    y=predicted_coords[1],
    color="red",
    s=20,
    label="Predicted Centroid",
)
axis.axis("off")

In [None]:
# Undo the per-axis zoom used for the first test volume, mapping the prediction
# from the resized grid back onto the original-resolution voxel grid
scaled_prediction = tuple(
    map(
        lambda coord, factor: int(coord / factor),
        predicted_coords,
        test_scale_factors[0],
    )
)
print(f"Scaled prediction: {scaled_prediction}")

fig, axis = plt.subplots()

# Entries are already ints, so they index the original volume directly
img_slice, label_slice = (
    test_imgs[0][scaled_prediction[0]],
    test_labels[0][scaled_prediction[0]],
)

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

# scatter takes (x, y) = (column, row), i.e. coordinate axes 2 and 1
axis.scatter(
    x=scaled_prediction[2],
    y=scaled_prediction[1],
    color="red",
    s=20,
    label="Predicted Centroid",
)
axis.axis("off")