In [None]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("../../")

In [None]:
from fishjaw.util import util

# Presumably loads the user's configuration via the project helper — confirm
# against fishjaw.util.util.
# NOTE(review): `config` is not referenced by any of the cells visible in this
# notebook; check whether it is still needed.
config = util.userconf()

In [None]:
"""Read in a small number of CT scans and the corresponding labels"""

import pathlib


# Gather every DICOM file in the training directory. Sorting makes the
# position-based train/val/test split reproducible between runs.
dicom_dir = pathlib.Path("../../dicoms/Training set 2/")
dicom_paths = sorted(dicom_dir.glob("*.dcm"))
print(len(dicom_paths))

In [None]:
# Positional split of the sorted scans: the first 15 are used for training,
# the next 3 for validation and whatever remains for testing.
train_paths = dicom_paths[:15]
val_paths = dicom_paths[15:18]
test_paths = dicom_paths[18:]

In [None]:
from tqdm import tqdm
from fishjaw.images import io

def _read_scans(paths):
    """Read every DICOM in `paths` and transpose the results into per-field tuples."""
    scans = [io.read_dicom(path) for path in tqdm(paths)]
    return zip(*scans)


# Each read yields an (image, label) pair, so every split unpacks into two
# parallel tuples.
train_imgs, train_labels = _read_scans(train_paths)
test_imgs, test_labels = _read_scans(test_paths)
val_imgs, val_labels = _read_scans(val_paths)

In [None]:
from scipy.ndimage import center_of_mass

def _voxel_centroids(labels):
    """Return the centre of mass of each label volume, truncated to integer indices."""
    centroids = []
    for label in tqdm(labels):
        centre = center_of_mass(label)
        centroids.append(tuple(int(coord) for coord in centre))
    return centroids


# Centroids are in array-axis order, one integer per axis of the label volume.
train_centroids = _voxel_centroids(train_labels)
val_centroids = _voxel_centroids(val_labels)
test_centroids = _voxel_centroids(test_labels)

In [None]:
"""Plot the centroids"""

import matplotlib.pyplot as plt

# Sanity-check one training scan: overlay its label and centroid on the slice
# that passes through the centroid.
sample_idx = 0
sample_img = train_imgs[sample_idx]
sample_label = train_labels[sample_idx]
sample_centroid = train_centroids[sample_idx]

fig, axis = plt.subplots(figsize=(6, 6))

# Centroids are in array-axis order (0, 1, 2): slice along axis 0, then draw
# axis 2 horizontally and axis 1 vertically, matching how imshow lays out a
# 2D array.
img_slice = sample_img[sample_centroid[0]]
label_slice = sample_label[sample_centroid[0]]

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

axis.scatter(
    sample_centroid[2], sample_centroid[1], color="red", s=20, label="Centroid"
)
axis.axis("off")

In [None]:
# Since we don't really need all the information of the high-resolution images,
# we can downsample them to speed up processing.

# scipy.ndimage.zoom is faster, but I think skimage.transform.resize is more accurate
from scipy.ndimage import zoom

output_size = (128, 128, 128)


def resize_image_and_label(imgs, labels, centroids, target_shape):
    scale_factors = []
    resized_imgs = []
    resized_labels = []
    scaled_centroids = []

    for img, label, centroid in tqdm(
        zip(imgs, labels, centroids, strict=True), total=len(imgs)
    ):
        assert img.shape == label.shape, "Image and label must have the same shape"

        # Calculate scale factors for each dimension
        scale_factor = tuple(
            target_dim / orig_dim
            for orig_dim, target_dim in zip(img.shape, output_size)
        )
        scale_factors.append(scale_factor)

        # Resize image and label using zoom
        resized_img = zoom(
            img, scale_factor, order=3
        )  # Use cubic interpolation for images
        resized_label = zoom(
            label, scale_factor, order=0
        )  # Use nearest-neighbor for labels

        resized_imgs.append(resized_img)
        resized_labels.append(resized_label)

        # Rescale centroid to match the resized image
        scaled_centroid = tuple(int(c * sf) for c, sf in zip(centroid, scale_factor))
        scaled_centroids.append(scaled_centroid)

    return resized_imgs, resized_labels, scaled_centroids, scale_factors


# The training and validation sets are replaced by their downsampled versions;
# their scale factors are not needed afterwards.
*train_resized, _ = resize_image_and_label(
    train_imgs, train_labels, train_centroids, output_size
)
train_imgs, train_labels, train_centroids = train_resized

*val_resized, _ = resize_image_and_label(
    val_imgs, val_labels, val_centroids, output_size
)
val_imgs, val_labels, val_centroids = val_resized

# The test set keeps its original-resolution arrays for evaluation: the
# downsampled copies go into new names, and the scale factors are kept so that
# predictions can be mapped back to the original resolution.
*test_resized, test_scale_factors = resize_image_and_label(
    test_imgs, test_labels, test_centroids, output_size
)
resized_test_imgs, resized_test_labels, resized_test_centroids = test_resized

In [None]:
# Sanity-check the first downsampled test scan: overlay its label and centroid
# on the slice that passes through the centroid.
fig, axis = plt.subplots(figsize=(6, 6))

test_centroid = resized_test_centroids[0]

# Slice along axis 0; the overlay must come from the labels, not the images.
img_slice = resized_test_imgs[0][test_centroid[0]]
label_slice = resized_test_labels[0][test_centroid[0]]

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

# imshow draws axis 2 horizontally and axis 1 vertically
axis.scatter(
    test_centroid[2],
    test_centroid[1],
    color="red",
    s=20,
    label="Centroid",
)
axis.axis("off")

In [None]:
# Define a model arch
import torch
import torch.nn as nn
import torch.nn.functional as F


class Simple3DCNNRegressor(nn.Module):
    """
    Small 3D CNN that regresses three coordinates from a volume.

    Four convolutional stages each halve the spatial size while widening the
    channels (32, 64, 128, 256). A global average pool then collapses the
    volume to a 256-vector, which a two-layer head maps to three outputs.

    The outputs carry whatever axis order the training targets use; in this
    notebook the targets are centroids in array-axis order (axis 0, 1, 2).

    :param in_channels: number of channels in the input volume
    """

    def __init__(self, in_channels=1):
        super().__init__()

        # [B, C, D, H, W] -> [B, 256, D/16, H/16, W/16] -> [B, 256, 1, 1, 1]
        stages = []
        prev_width = in_channels
        for width in (32, 64, 128, 256):
            stages.append(self._conv_block(prev_width, width))
            prev_width = width
        stages.append(nn.AdaptiveAvgPool3d(1))
        self.features = nn.Sequential(*stages)

        head = [
            nn.Flatten(),  # [B, 256]
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 3),  # one output per coordinate
        ]
        self.regressor = nn.Sequential(*head)

    @staticmethod
    def _conv_block(in_c, out_c):
        """Two conv-BN-ReLU layers followed by a 2x max-pool."""
        layers = [
            nn.Conv3d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_c),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of volumes [B, C, D, H, W] to a [B, 3] tensor."""
        pooled = self.features(x)
        return self.regressor(pooled)

In [None]:
# Dataset, model, optimiser, loss and data loaders
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np


class ImgDataset(Dataset):
    """
    In-memory dataset of (volume, centroid) pairs held on a single device.

    Every volume is converted to float32 and given a leading channel axis, so
    item `i` is `(data[i], centroids[i])` with `data[i]` shaped
    `(1, *volume_shape)`. All volumes must share one shape so that they stack
    into a single tensor.

    :param images: sequence of equally-shaped image arrays
    :param centroids: sequence of per-image target coordinates
    :param device: device both tensors are moved to; defaults to "cuda"
    :raises ValueError: if `images` and `centroids` differ in length
    """

    def __init__(self, images, centroids, device="cuda"):
        if len(images) != len(centroids):
            raise ValueError(
                f"Got {len(images)} images but {len(centroids)} centroids"
            )

        self.data = (
            torch.tensor(np.array(images, dtype=np.float32), dtype=torch.float32)
            .unsqueeze(1)  # add the channel axis: [N, 1, D, H, W]
            .to(device)
        )
        self.centroids = torch.tensor(np.array(centroids), dtype=torch.float32).to(
            device
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.centroids[idx]


# Build the model on the GPU, with Adam and a mean-squared-error loss on the
# regressed coordinates.
model = Simple3DCNNRegressor()
model = model.to("cuda")
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# The datasets move all of their data to the GPU when constructed.
train_dataset = ImgDataset(train_imgs, train_centroids)
val_dataset = ImgDataset(val_imgs, val_centroids)

# Only the training batches are shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

In [None]:
# Training loop
def train(model, train_data, val_data, n_epochs):
    """
    Train `model` for `n_epochs`, evaluating on `val_data` after every epoch.

    Uses the notebook-level `optimizer` and `criterion`. The model is switched
    to training mode for the optimisation pass and to evaluation mode for the
    validation pass, so Dropout is disabled and BatchNorm running statistics
    are not updated by validation batches. The model is left in evaluation
    mode when the function returns.

    :param model: the network to optimise
    :param train_data: iterable of (volumes, coords) training batches
    :param val_data: iterable of (volumes, coords) validation batches
    :param n_epochs: number of passes over `train_data`

    :returns: (train_losses, val_losses); each is a list with one entry per
        epoch, and each entry is the list of per-batch loss values
    """
    train_losses, val_losses = [], []

    pbar = tqdm(range(n_epochs), total=n_epochs, desc="Training Epochs")
    for _ in pbar:
        # Optimisation pass
        model.train()
        train_loss = []
        for volumes, coords in train_data:
            optimizer.zero_grad()
            loss = criterion(model(volumes), coords)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # Validation pass: no gradients, and layers in inference behaviour
        model.eval()
        val_loss = []
        with torch.no_grad():
            for volumes, coords in val_data:
                loss = criterion(model(volumes), coords)
                val_loss.append(loss.item())

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        pbar.set_postfix(train_loss=np.mean(train_loss), val_loss=np.mean(val_loss))

    return train_losses, val_losses


# Run 100 epochs; each returned list holds one list of per-batch losses per epoch.
train_losses, val_losses = train(
    model=model, train_data=train_loader, val_data=val_loader, n_epochs=100
)

Training Epochs:  13%|█▎        | 13/100 [00:11<01:20,  1.08it/s, train_loss=6.33e+3, val_loss=5.73e+3]

In [None]:
from fishjaw.visualisation import training

# Plot the recorded losses with the project helper.
# NOTE(review): presumably it returns a matplotlib figure — confirm against
# fishjaw.visualisation.training.
fig = training.plot_losses(train_losses, val_losses)
# Limit the y-range of pyplot's *current* axes to 0–1000.
# NOTE(review): this assumes plot_losses leaves the loss axes as the current
# axes; if it creates several, this only affects the last one used — confirm.
plt.ylim(0, 1000)

In [None]:
# Run the trained model on the first downsampled test volume
model.eval()
with torch.no_grad():
    # Add batch and channel axes: [D, H, W] -> [1, 1, D, H, W]
    test_vol = (
        torch.tensor(
            np.array(resized_test_imgs[0], dtype=np.float32), dtype=torch.float32
        )
        .unsqueeze(0)
        .unsqueeze(0)
        .to("cuda")
    )

    # No detach needed: nothing is tracked by autograd inside no_grad
    predicted_coords = model(test_vol).squeeze(0).cpu().numpy()

# This is a held-out test volume, so label the output accordingly
print(f"Test data: Predicted {predicted_coords}, True: {resized_test_centroids[0]}")

In [None]:
# Show the predicted centroid on the downsampled test volume
fig, axis = plt.subplots(1, 1)

# The regressor's output is unbounded, so clamp it into the volume before using
# it as a slice index: a negative index would silently wrap around to the far
# end, and one past the end would raise.
n_slices = len(resized_test_imgs[0])
slice_idx = min(max(int(predicted_coords[0]), 0), n_slices - 1)

img_slice = resized_test_imgs[0][slice_idx]
label_slice = resized_test_labels[0][slice_idx]

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

# The raw (unclamped) prediction is drawn: axis 2 horizontally, axis 1 vertically
axis.scatter(
    predicted_coords[2],
    predicted_coords[1],
    color="red",
    s=20,
    label="Predicted Centroid",
)
axis.axis("off")

In [None]:
# Scale the prediction back up to the original resolution by undoing the
# per-axis zoom that was applied to this test volume
scaled_prediction = tuple(
    int(coord / scale_factor)
    for coord, scale_factor in zip(predicted_coords, test_scale_factors[0])
)
print(f"Scaled prediction: {scaled_prediction}")

fig, axis = plt.subplots(1, 1)

# The prediction is unbounded, so clamp it into the volume before using it as a
# slice index: a negative index would silently wrap around to the far end, and
# one past the end would raise.
n_slices = len(test_imgs[0])
slice_idx = min(max(scaled_prediction[0], 0), n_slices - 1)

img_slice = test_imgs[0][slice_idx]
label_slice = test_labels[0][slice_idx]

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

# The raw (unclamped) prediction is drawn: axis 2 horizontally, axis 1 vertically
axis.scatter(
    scaled_prediction[2],
    scaled_prediction[1],
    color="red",
    s=20,
    label="Predicted Centroid",
)
axis.axis("off")