In [48]:
%pip install nibabel matplotlib numpy torch

Looking in indexes: https://pypi.org/simple, https://aws:****@annaliseai-274616382064.d.codeartifact.ap-southeast-2.amazonaws.com/pypi/pypi/simple/
Note: you may need to restart the kernel to use updated packages.


In [49]:
import numpy as np
import torch


class ToTensor:
    """Turn a ``(source, target)`` pair of NumPy arrays into float32 tensors."""

    @staticmethod
    def _as_float_tensor(array: np.ndarray) -> torch.Tensor:
        # from_numpy wraps the array's buffer; .float() then casts to float32
        # (a no-op returning the same tensor when the array is already float32).
        return torch.from_numpy(array).float()

    def __call__(
        self, sample: tuple[np.ndarray, np.ndarray]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Unpacking enforces that the sample is exactly a pair.
        source_array, target_array = sample
        source_tensor = self._as_float_tensor(source_array)
        target_tensor = self._as_float_tensor(target_array)
        return source_tensor, target_tensor


class RandomCrop3D:
    """Randomly crop the same 3D window out of an ``(image, label)`` pair.

    The crop is taken over the first three axes of both arrays; any trailing
    axes are kept whole. The window position is drawn with ``np.random``, so
    seed NumPy's global RNG for reproducible crops.

    Args:
        output_size (int or sequence of 3 ints): Desired crop size along the
            first three axes. If int, a cubic crop of that edge length is made.

    Raises:
        TypeError: if ``output_size`` is neither an int nor a tuple/list.
        ValueError: if ``output_size`` does not have exactly 3 positive
            entries; or, when called, if the image has fewer than 3 axes, the
            image and label disagree on their first three axes, or the volume
            is smaller than the crop along any axis.
    """

    def __init__(self, output_size):
        if isinstance(output_size, int):
            output_size = (output_size,) * 3
        elif isinstance(output_size, (tuple, list)):
            output_size = tuple(output_size)
        else:
            raise TypeError(
                "output_size must be an int or a sequence of 3 ints, "
                f"got {type(output_size).__name__}"
            )
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if len(output_size) != 3 or any(size <= 0 for size in output_size):
            raise ValueError(
                f"output_size must have 3 positive entries, got {output_size}"
            )
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample[0], sample[1]

        spatial_shape = tuple(image.shape[:3])
        if len(spatial_shape) != 3:
            raise ValueError(
                f"expected at least 3 axes, got image shape {tuple(image.shape)}"
            )
        # The same window is applied to both arrays, so they must be aligned;
        # otherwise the label crop would silently cover a different region.
        if tuple(label.shape[:3]) != spatial_shape:
            raise ValueError(
                f"image and label spatial shapes differ: {spatial_shape} "
                f"vs {tuple(label.shape[:3])}"
            )
        if any(size < crop for size, crop in zip(spatial_shape, self.output_size)):
            raise ValueError(
                f"volume of shape {spatial_shape} is smaller than crop size "
                f"{self.output_size}"
            )

        # One draw per axis, in axis order; randint's upper bound is exclusive,
        # hence the +1 so a crop may end exactly at the volume edge.
        starts = [
            np.random.randint(0, size - crop + 1)
            for size, crop in zip(spatial_shape, self.output_size)
        ]
        window = tuple(
            slice(start, start + crop) for start, crop in zip(starts, self.output_size)
        )
        return image[window], label[window]

In [50]:
from torch.utils.data.dataset import Dataset
import os
import nibabel as nib


class NiftiDataset(Dataset):
    """Paired ``(source, target)`` NIfTI volumes loaded from two directories.

    Volumes are paired by position after sorting each directory's file names,
    so both directories must hold the same number of volumes in matching
    order. Hidden files and subdirectories are ignored. All volumes are loaded
    eagerly in ``__init__``.

    Args:
        source_dir: Directory containing the source volumes.
        target_dir: Directory containing the target volumes.
        transforms: Optional callable applied to each ``(source, target)``
            tuple returned by ``__getitem__``.

    Raises:
        ValueError: if the two directories contain different numbers of volumes.
    """

    def __init__(self, source_dir: str, target_dir: str, transforms=None):
        source_files = self._list_volume_files(source_dir)
        target_files = self._list_volume_files(target_dir)
        if len(source_files) != len(target_files):
            raise ValueError(
                f"{source_dir!r} has {len(source_files)} volumes but "
                f"{target_dir!r} has {len(target_files)}; cannot pair them"
            )
        # NOTE(review): pairing is purely by sorted order; if both directories
        # use the same study IDs as file names, consider asserting they match.
        self.source_images = [
            self._get_image_data(source_dir, name) for name in source_files
        ]
        self.target_images = [
            self._get_image_data(target_dir, name) for name in target_files
        ]
        self.transforms = transforms

    @staticmethod
    def _list_volume_files(directory: str) -> list[str]:
        """Sorted names of the regular, non-hidden files in ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.source_images)

    def _get_image_data(self, image_dir: str, file_name: str):
        """Load ``image_dir/file_name`` with nibabel and return its voxel data."""
        image = nib.load(os.path.join(image_dir, file_name))
        return image.get_fdata()

    def __getitem__(self, idx: int):
        image_data = (self.source_images[idx], self.target_images[idx])

        if self.transforms is not None:
            image_data = self.transforms(image_data)

        return image_data

In [51]:
from torchvision.transforms import Compose

# Seed both RNGs so Restart & Run All reproduces the same random crops
# (RandomCrop3D draws from np.random) and the same weights for any model
# created later in the notebook (torch's global generator).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_DIR = "../small"
output_size = (32, 32, 32)
train_data = NiftiDataset(
    source_dir=os.path.join(DATA_DIR, "t1"),
    target_dir=os.path.join(DATA_DIR, "t2"),
    transforms=Compose([RandomCrop3D(output_size=output_size), ToTensor()]),
)

In [52]:
# Index once: each __getitem__ call draws a fresh random crop, so indexing
# twice would inspect two unrelated samples.
source_crop, target_crop = train_data[0]
print(source_crop.shape)
print(target_crop.shape)
assert source_crop.shape == target_crop.shape == torch.Size(output_size)

torch.Size([32, 32, 32])
torch.Size([32, 32, 32])


In [129]:
from torch import Tensor, nn
import torch.nn.functional as F


def conv(in_channels: int, out_channels: int):
    """Build one 3x3x3 conv -> batch-norm -> ReLU stage as a tuple of layers.

    The tuple is meant to be star-unpacked into an ``nn.Sequential``.
    ``padding=1`` with a size-3 kernel keeps the spatial size unchanged, and
    the convolution has no bias since the BatchNorm right after adds its own
    learnable shift.
    """
    convolution = nn.Conv3d(
        in_channels, out_channels, kernel_size=3, padding=1, bias=False
    )
    normalisation = nn.BatchNorm3d(out_channels)
    activation = nn.ReLU(inplace=True)
    return convolution, normalisation, activation


def double_conv(in_channels: int, intermediate_channels: int, out_channels: int):
    """Stack two ``conv`` stages (in -> intermediate -> out) into one ``nn.Sequential``."""
    first_stage = conv(in_channels=in_channels, out_channels=intermediate_channels)
    second_stage = conv(in_channels=intermediate_channels, out_channels=out_channels)
    return nn.Sequential(*first_stage, *second_stage)


class UNet(nn.Module):
    """Three-level 3D U-Net: two down-sampling levels plus a bridge level.

    Maps an input of shape ``(N, in_channels, D, H, W)`` to
    ``(N, out_channels, D, H, W)``. Every spatial dimension must be at least 8
    so the deepest max-pool is non-empty; sizes that are not multiples of 8
    still work because each decoder stage is resized to its skip connection.

    Args:
        image_size: Crop size the model is used with. Only ``image_size[0]`` is
            read, as the default for ``base_channels`` (kept for backward
            compatibility); the network is fully convolutional and does not
            depend on the spatial size otherwise.
        base_channels: Feature width of the top level; deeper levels use 2x and
            4x this, and the bridge expands to 8x internally. Defaults to
            ``image_size[0]``.
        in_channels: Channels of the input volume.
        out_channels: Channels of the predicted volume.
    """

    def __init__(
        self,
        image_size: tuple[int, int, int],
        base_channels: "int | None" = None,
        in_channels: int = 1,
        out_channels: int = 1,
    ):
        super().__init__()

        width = image_size[0] if base_channels is None else base_channels

        # Attribute names are the state_dict key prefixes; keep them stable.
        self.start = double_conv(
            in_channels=in_channels,
            intermediate_channels=width,
            out_channels=width,
        )
        self.down1 = double_conv(
            in_channels=width,
            intermediate_channels=width * 2,
            out_channels=width * 2,
        )
        self.down2 = double_conv(
            in_channels=width * 2,
            intermediate_channels=width * 4,
            out_channels=width * 4,
        )
        self.bridge = double_conv(
            in_channels=width * 4,
            intermediate_channels=width * 8,
            out_channels=width * 4,
        )
        # Decoder inputs are the upsampled features concatenated with the
        # matching encoder skip, hence the doubled input widths.
        self.up2 = double_conv(
            in_channels=width * 8,
            intermediate_channels=width * 4,
            out_channels=width * 2,
        )
        self.up1 = double_conv(
            in_channels=width * 4,
            intermediate_channels=width * 2,
            out_channels=width,
        )
        self.final = nn.Sequential(
            *conv(in_channels=width * 2, out_channels=width),
            nn.Conv3d(in_channels=width, out_channels=out_channels, kernel_size=1),
        )

    @staticmethod
    def _resize_like(x: Tensor, reference: Tensor) -> Tensor:
        """Resize ``x`` to ``reference``'s spatial size (F.interpolate's default nearest mode)."""
        return F.interpolate(x, size=reference.shape[2:])

    def forward(self, x: Tensor) -> Tensor:
        # Encoder: each level halves the spatial size with a 2x max-pool.
        skip0 = self.start(x)
        skip1 = self.down1(F.max_pool3d(skip0, 2))
        skip2 = self.down2(F.max_pool3d(skip1, 2))

        # Decoder: run each conv block, then upsample to the next skip's exact
        # size before concatenating along the channel axis.
        x = self._resize_like(self.bridge(F.max_pool3d(skip2, 2)), skip2)
        x = self._resize_like(self.up2(torch.cat((x, skip2), dim=1)), skip1)
        x = self._resize_like(self.up1(torch.cat((x, skip1), dim=1)), skip0)
        return self.final(torch.cat((x, skip0), dim=1))

In [130]:
model = UNet(image_size=output_size)
# Add the channel axis: (D, H, W) -> (1, D, H, W), then batch two copies.
first_image = train_data[0][0].unsqueeze(0)
batched_images = torch.stack([first_image, first_image])

# Shape smoke test only, so skip autograd bookkeeping to save memory.
with torch.no_grad():
    predictions = model(batched_images)
print(predictions.shape)
assert predictions.shape == batched_images.shape

torch.Size([2, 1, 32, 32, 32])
