In [None]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("../../")

In [None]:
from fishjaw.util import util

# Load the user configuration via the project helper.
# NOTE(review): `config` is not referenced by any later cell shown here — confirm
# whether this cell is still needed
config = util.userconf()

In [None]:
"""Locate the CT scans (DICOM files) whose images and labels are read below"""

import pathlib


dicom_dir = pathlib.Path("../../dicoms/Training set 2/")
# sorted() already returns a new list; sorting by path makes the file order
# (and so which scans are kept by the subset cell) deterministic
dicom_paths = sorted(dicom_dir.glob("*.dcm"))

In [None]:
# Only use the first few scans (in sorted filename order) to keep this exploration fast
N_SCANS = 5
dicom_paths = dicom_paths[:N_SCANS]

In [None]:
from tqdm import tqdm
from fishjaw.images import io

# Fail early with a clear message: unpacking zip(*[]) would otherwise raise a
# cryptic "not enough values to unpack" error
if not dicom_paths:
    raise ValueError("No DICOM files to read - check the DICOM directory path")

# Each io.read_dicom result is unpacked as an (image, label) pair; zip(*...)
# regroups them into a tuple of images and a tuple of labels
imgs, labels = zip(*[io.read_dicom(path) for path in tqdm(dicom_paths)])

In [None]:
from scipy.ndimage import center_of_mass

# Centre of mass of each label volume, truncated to integer voxel indices so it
# can be used directly to index a slice
centroids = []
for label in tqdm(labels):
    centroids.append(tuple(int(coord) for coord in center_of_mass(label)))

In [None]:
"""Plot the centroids"""

import matplotlib.pyplot as plt

# zip() stops at the shortest input, so only as many scans as there are axes (2) are drawn
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

for img, label, centroid, axis in zip(imgs, labels, centroids, axes):
    z, y, x = centroid

    # Slice through the centroid along the first array axis, label mask overlaid.
    # NOTE(review): vmax=255**2 is a fixed display ceiling — presumably chosen to
    # cover the scans' intensity range; confirm
    img_slice, label_slice = img[z], label[z]
    axis.imshow(img_slice, cmap="gray", vmin=0, vmax=255**2)
    axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

    # imshow draws array columns along x and rows along y, hence (x, y) = (axis 2, axis 1)
    axis.scatter(x, y, color="red", s=20, label="Centroid")
    axis.axis("off")

In [None]:
# Since we don't really need all the information of the high-resolution images,
# we can downsample them to speed up processing.

# ndimage.zoom is fast; skimage.transform.resize may be more accurate but is slower
from scipy.ndimage import zoom

scale_factors = []
resized_imgs = []
resized_labels = []
scaled_centroids = []

output_size = (128, 128, 128)
for img, label, centroid in tqdm(
    zip(imgs, labels, centroids, strict=True), total=len(imgs)
):
    assert img.shape == label.shape, "Image and label must have the same shape"

    # Per-axis zoom factors that take this volume to `output_size`
    scale_factor = tuple(
        target_dim / orig_dim for orig_dim, target_dim in zip(img.shape, output_size)
    )
    scale_factors.append(scale_factor)

    # Cubic interpolation for intensities; nearest-neighbour keeps label values discrete
    resized_img = zoom(img, scale_factor, order=3)
    resized_label = zoom(label, scale_factor, order=0)

    resized_imgs.append(resized_img)
    resized_labels.append(resized_label)

    # Map the centroid into the resized volume. zoom() with its default
    # grid_mode=False aligns the centres of the first and last voxels on each
    # axis, so input voxel i lands at i * (out - 1) / (in - 1), not i * out / in.
    # Round to the nearest voxel rather than truncating; a length-1 axis maps to 0.
    scaled_centroid = tuple(
        round(c * (out_dim - 1) / (in_dim - 1)) if in_dim > 1 else 0
        for c, in_dim, out_dim in zip(centroid, img.shape, resized_img.shape)
    )
    scaled_centroids.append(scaled_centroid)

print("Scale factors:", scale_factors)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Same view as the full-resolution plot, but on the downsampled volumes with the
# rescaled centroids (intensity range auto-scaled this time)
for img, label, centroid, axis in zip(
    resized_imgs, resized_labels, scaled_centroids, axes
):
    z, y, x = centroid

    img_slice, label_slice = img[z], label[z]
    axis.imshow(img_slice, cmap="gray")
    axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)

    # Columns are plotted along x and rows along y, hence (x, y) = (axis 2, axis 1)
    axis.scatter(x, y, color="red", s=20, label="Centroid")
    axis.axis("off")

In [None]:
# Define the model architecture
import torch
import torch.nn as nn
import torch.nn.functional as F


class Simple3DCNNRegressor(nn.Module):
    """Small 3D CNN mapping a volume to a 3-element output vector.

    Three 3x3x3 convolutions (16, 32 and 64 channels), each followed by ReLU and
    the first two also by 2x max-pooling, are global-average-pooled to 64
    features. A two-layer MLP (64 -> 128 -> 3) then produces the output, which
    this notebook trains as centroid coordinates.

    Args:
        in_channels: number of input channels (default 1).
    """

    def __init__(self, in_channels=1):
        super().__init__()

        # Build the convolutional trunk stage by stage. The resulting layer order
        # (and therefore the Sequential indices / state_dict keys) is:
        # conv, relu, maxpool, conv, relu, maxpool, conv, relu, global-avg-pool.
        widths = (in_channels, 16, 32, 64)
        last_stage = len(widths) - 2
        trunk = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            trunk.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            trunk.append(nn.ReLU())
            trunk.append(nn.AdaptiveAvgPool3d(1) if stage == last_stage else nn.MaxPool3d(2))
        self.features = nn.Sequential(*trunk)

        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        """Return a (batch, 3) tensor for a (batch, in_channels, D, H, W) input."""
        return self.regressor(self.features(x))

In [None]:
# Dataset, model, optimiser and data-loader set-up
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np


class ImgDataset(Dataset):
    """In-memory dataset of volumes paired with their target coordinates.

    Everything is converted to float32 tensors and moved to ``device`` once, at
    construction, so indexing returns tensors that already live on that device.

    Args:
        images: sequence of equally-shaped arrays (3D volumes in this notebook).
            They are stacked and given a channel axis: (N, ...) -> (N, 1, ...).
        centroids: sequence of N equal-length coordinate tuples, one per image.
        device: torch device for the stored tensors (default "cuda").

    Raises:
        ValueError: if ``images`` and ``centroids`` differ in length.
    """

    def __init__(self, images, centroids, device="cuda"):
        # A length mismatch would otherwise only surface as an IndexError mid-epoch
        if len(images) != len(centroids):
            raise ValueError(
                f"Got {len(images)} images but {len(centroids)} centroids"
            )
        # torch.tensor copies, so the dataset never aliases the caller's arrays
        self.data = (
            torch.tensor(np.asarray(images, dtype=np.float32)).unsqueeze(1).to(device)
        )
        self.centroids = torch.tensor(np.asarray(centroids, dtype=np.float32)).to(
            device
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(volume, centroid)`` tensor pair at ``idx``."""
        return self.data[idx], self.centroids[idx]


# Initialize
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All
SEED = 0
torch.manual_seed(SEED)

model = Simple3DCNNRegressor().to("cuda")
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

train_dataset = ImgDataset(resized_imgs, scaled_centroids)
# NOTE: there is no held-out data yet, so "validation" reuses the training set and
# its loss measures fit, not generalisation. Reusing the same dataset object
# (instead of building a second ImgDataset from identical inputs) avoids keeping
# a duplicate copy of every volume in GPU memory.
val_dataset = train_dataset

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

In [None]:
# Training loop
def train(model, train_data, val_data, n_epochs=25):
    """Train ``model`` and record the per-batch losses of every epoch.

    Uses the module-level ``optimizer`` and ``criterion`` from the initialisation
    cell. ``optimizer`` is bound to the parameters of the model created there, so
    re-run that cell whenever the model is re-created.

    Args:
        model: the network to train (updated in place; left in eval mode).
        train_data: iterable of (volumes, coords) batches used for updates.
        val_data: iterable of (volumes, coords) batches used only for evaluation.
        n_epochs: number of passes over ``train_data`` (default 25).

    Returns:
        (train_losses, val_losses): two lists with one entry per epoch, each entry
        being the list of per-batch loss values recorded in that epoch.
    """
    train_losses, val_losses = [], []

    pbar = tqdm(range(n_epochs), desc="Training Epochs")
    for _ in pbar:
        # Re-enter training mode each epoch, since validation below switches to eval
        model.train()
        train_loss = []
        for volumes, coords in train_data:
            optimizer.zero_grad()
            loss = criterion(model(volumes), coords)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # Evaluate in eval mode without tracking gradients, so layers such as
        # dropout/batch-norm (if a model has them) behave as at inference time
        model.eval()
        val_loss = []
        with torch.no_grad():
            for volumes, coords in val_data:
                val_loss.append(criterion(model(volumes), coords).item())

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        pbar.set_postfix(train_loss=np.mean(train_loss), val_loss=np.mean(val_loss))

    return train_losses, val_losses


train_losses, val_losses = train(model, train_loader, val_loader)

In [None]:
from fishjaw.visualisation import training

# Plot the loss histories returned by train(): one entry per epoch, each a list
# of per-batch loss values
fig = training.plot_losses(train_losses, val_losses)

In [None]:
# Check the model's prediction on one training example (no cropping happens here;
# the model was fit on this example, so this checks fit, not generalisation)
model.eval()
with torch.no_grad():
    test_volume, test_coords = train_dataset[0]
    test_volume = test_volume.unsqueeze(0).to("cuda")  # add a batch axis
    predicted_coords = model(test_volume)
    predicted_coords = predicted_coords.squeeze(0).cpu().numpy()

# The dataset keeps its tensors on the GPU, and Tensor.numpy() only works for CPU
# tensors, so move the target to the CPU first
print(f"train data: Predicted {predicted_coords}, True: {test_coords.cpu().numpy()}")

In [None]:
# Plot the slice through the predicted centroid, with the label mask overlaid
fig, axis = plt.subplots(1, 1)

# The regression output is unconstrained, so round it to the nearest slice and
# clamp it into the volume: a negative index would silently select a slice from
# the far end of the array, and one past the end would raise IndexError
n_slices = resized_imgs[0].shape[0]
slice_idx = int(np.clip(np.rint(predicted_coords[0]), 0, n_slices - 1))

img_slice = resized_imgs[0][slice_idx]
label_slice = resized_labels[0][slice_idx]

axis.imshow(img_slice, cmap="gray")
axis.imshow(label_slice, cmap="afmhot_r", alpha=0.3)
# imshow draws array columns along x and rows along y, hence (axis 2, axis 1)
axis.scatter(
    predicted_coords[2],
    predicted_coords[1],
    color="red",
    s=20,
    label="Predicted Centroid",
)
axis.axis("off")