In [None]:
%pip install nibabel matplotlib numpy torch

In [None]:
import torch

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

In [None]:
import numpy as np
import torch


class ToTensor:
    """Turn a ``(source, target)`` pair of NumPy arrays into float32 tensors.

    The pair is unpacked, so exactly two arrays are expected; each is wrapped with
    ``torch.from_numpy`` and then cast with ``.float()``.
    """

    def __call__(
        self, sample: tuple[np.ndarray, np.ndarray]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        source_array, target_array = sample
        source_tensor = torch.from_numpy(source_array).float()
        target_tensor = torch.from_numpy(target_array).float()
        return source_tensor, target_tensor


class RandomCrop3D:
    """Randomly crop the same 3D window out of an ``(image, label)`` pair.

    Args:
        output_size (tuple, list or int): Desired crop size ``(h, w, d)``. If an
            int ``n`` is given, a cubic ``(n, n, n)`` crop is made.

    Raises:
        TypeError: If ``output_size`` is neither an int nor a tuple/list.
        ValueError: If ``output_size`` does not have exactly three positive entries.
    """

    def __init__(self, output_size):
        if isinstance(output_size, (int, np.integer)):
            output_size = (output_size,) * 3
        if not isinstance(output_size, (tuple, list)):
            raise TypeError(
                f"output_size must be an int or a tuple of three ints, got {output_size!r}"
            )
        output_size = tuple(output_size)
        # Explicit checks instead of `assert`, which disappears under `python -O`.
        if len(output_size) != 3 or any(size <= 0 for size in output_size):
            raise ValueError(
                f"output_size must have three positive entries, got {output_size!r}"
            )
        self.output_size = output_size

    def __call__(self, sample):
        """Crop ``sample = (image, label)`` at one random offset shared by both.

        Only the first three axes are cropped; any trailing axes are kept whole.

        Raises:
            ValueError: If image and label differ in their first three dims, or
                the crop is larger than the volume.
        """
        image, label = sample[0], sample[1]

        h, w, d = image.shape[:3]
        # The crop window is computed from the image; a label of another shape
        # would be cut out of misaligned voxels (or come out too small to batch).
        if tuple(label.shape[:3]) != (h, w, d):
            raise ValueError(
                f"Image and label spatial shapes differ: {image.shape[:3]} vs {label.shape[:3]}."
            )
        new_h, new_w, new_d = self.output_size

        if new_h > h or new_w > w or new_d > d:
            raise ValueError("Output size is larger than input dimensions.")

        # Offsets are drawn from NumPy's global RNG in (h, w, d) order.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        front = np.random.randint(0, d - new_d + 1)

        window = (
            slice(top, top + new_h),
            slice(left, left + new_w),
            slice(front, front + new_d),
        )
        return image[window], label[window]

In [None]:
import os

# Presumably set by the Kaggle notebook runtime; any non-empty value counts as "on Kaggle".
is_kaggle = bool(os.getenv("KAGGLE_KERNEL_RUN_TYPE"))
if is_kaggle:
    route_to_small = "/kaggle/input/t1t2-mri-data/small"
else:
    route_to_small = "../small"

In [None]:
from torch.utils.data.dataset import Dataset
import os
import nibabel as nib


class NiftiDataset(Dataset):
    """Paired ``(source, target)`` NIfTI volumes, loaded eagerly into memory.

    Each directory listing is sorted and files are paired by position, so the
    i-th source file must correspond to the i-th target file.

    Args:
        source_dir: Directory holding the input volumes (e.g. T1).
        target_dir: Directory holding the target volumes (e.g. T2).
        transforms: Optional callable applied to the ``(source, target)`` tuple
            in ``__getitem__``.

    Raises:
        ValueError: If the two directories hold a different number of volumes.
    """

    def __init__(self, source_dir: str, target_dir: str, transforms=None):
        source_paths = self._list_volumes(source_dir)
        target_paths = self._list_volumes(target_dir)
        # Positional pairing silently misaligns (or IndexErrors later) if counts differ.
        if len(source_paths) != len(target_paths):
            raise ValueError(
                f"Found {len(source_paths)} source volumes in {source_dir!r} but "
                f"{len(target_paths)} target volumes in {target_dir!r}; they are "
                "paired by sorted position and must match."
            )
        self.source_images = [self._load_volume(path) for path in source_paths]
        self.target_images = [self._load_volume(path) for path in target_paths]
        self.transforms = transforms

    def __len__(self):
        return len(self.source_images)

    @staticmethod
    def _list_volumes(directory: str) -> list[str]:
        """Return the sorted paths of the non-hidden entries in ``directory``."""
        # Skips entries such as `.ipynb_checkpoints` or `.DS_Store`, which nibabel cannot load.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if not name.startswith(".")
        ]

    @staticmethod
    def _load_volume(path: str):
        """Load one image file with nibabel and return its ``get_fdata()`` array."""
        return nib.load(path).get_fdata()

    def __getitem__(self, idx: int):
        image_data = (self.source_images[idx], self.target_images[idx])

        if self.transforms is not None:
            image_data = self.transforms(image_data)

        return image_data

In [None]:
from torchvision.transforms import Compose

output_size = (64, 64, 32)
# Each item: take one random (64, 64, 32) crop shared by the T1/T2 pair, then convert both to float tensors.
crop_then_tensorize = Compose([RandomCrop3D(output_size=output_size), ToTensor()])
dataset = NiftiDataset(
    source_dir=f"{route_to_small}/t1",
    target_dir=f"{route_to_small}/t2",
    transforms=crop_then_tensorize,
)

In [None]:
import os

valid_split = 0.1  # fraction of volumes held out for validation
batch_size = 16
# DataLoader worker processes: at most 12, but never more than the host's CPUs,
# so small machines are not oversubscribed (os.cpu_count() may return None).
num_jobs = min(12, os.cpu_count() or 1)
num_epochs = 50

In [None]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Seed for the train/validation partition so it is identical across kernel restarts.
split_seed = 42

num_train = len(dataset)
indices = list(range(num_train))
split = int(valid_split * num_train)
# With few volumes int() can round the validation share down to zero, leaving
# valid_loader empty and the logged validation loss NaN; keep at least one.
if valid_split > 0 and split == 0 and num_train > 1:
    split = 1
split_rng = np.random.default_rng(split_seed)
valid_idx = split_rng.choice(indices, size=split, replace=False).tolist()
train_idx = sorted(set(indices) - set(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)


def make_loader(sampler):
    """Return a DataLoader over ``dataset`` that draws items through ``sampler``."""
    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_jobs,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory=(device == "cuda"),
    )


train_loader = make_loader(train_sampler)
valid_loader = make_loader(valid_sampler)

In [None]:
# Smoke-test the data pipeline on one batch; show shapes rather than dumping voxel tensors.
sample_source, sample_target = next(iter(train_loader))
assert sample_source.shape == sample_target.shape, "source/target crops disagree"
sample_source.shape, sample_target.shape

In [None]:
from torch import Tensor, nn
import torch.nn.functional as F


def conv(in_channels: int, out_channels: int):
    """Build one 3x3x3 conv -> BatchNorm -> ReLU unit as a tuple of layers.

    The convolution uses ``padding=1`` (spatial size preserved) and no bias, since
    the following BatchNorm supplies its own shift. A tuple is returned so callers
    can splat it into ``nn.Sequential``.
    """
    convolution = nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        padding=1,
        bias=False,
    )
    normalization = nn.BatchNorm3d(num_features=out_channels)
    activation = nn.ReLU(inplace=True)
    return convolution, normalization, activation


def double_conv(in_channels: int, intermediate_channels: int, out_channels: int):
    """Stack two ``conv`` units: in -> intermediate -> out channels, in one Sequential."""
    first_unit = conv(in_channels=in_channels, out_channels=intermediate_channels)
    second_unit = conv(in_channels=intermediate_channels, out_channels=out_channels)
    return nn.Sequential(*first_unit, *second_unit)


class UNet(nn.Module):
    """Three-level 3D U-Net mapping a single-channel volume to a single-channel volume.

    Args:
        image_size: Spatial size of the training crops. Only ``image_size[0]`` is
            read, and only as the default for ``base_channels``.
        base_channels: Channel width of the first level; deeper levels use 2x and
            4x this width (8x inside the bridge). ``None`` keeps the original
            behaviour of using ``image_size[0]``, so previously saved
            ``state_dict``s still load.

    Raises:
        ValueError: If the resolved channel width is not positive.
    """

    def __init__(self, image_size: tuple[int, int, int], base_channels=None):
        super().__init__()

        # Historically the channel width was tied to the first spatial dimension;
        # that remains the default, but it can now be chosen independently.
        width = image_size[0] if base_channels is None else base_channels
        if width <= 0:
            raise ValueError(f"base channel width must be positive, got {width}")

        # Encoder: input (batch, 1, *spatial) -> width -> 2*width -> 4*width channels.
        self.start = double_conv(
            in_channels=1,
            intermediate_channels=width,
            out_channels=width,
        )
        self.down1 = double_conv(
            in_channels=width,
            intermediate_channels=width * 2,
            out_channels=width * 2,
        )
        self.down2 = double_conv(
            in_channels=width * 2,
            intermediate_channels=width * 4,
            out_channels=width * 4,
        )
        # Bottleneck widens to 8*width internally and returns to 4*width so it can
        # be concatenated with the down2 skip (4*width + 4*width = 8*width).
        self.bridge = double_conv(
            in_channels=width * 4,
            intermediate_channels=width * 8,
            out_channels=width * 4,
        )
        # Decoder: each stage consumes (upsampled features ++ matching skip).
        self.up2 = double_conv(
            in_channels=width * 8,
            intermediate_channels=width * 4,
            out_channels=width * 2,
        )
        self.up1 = double_conv(
            in_channels=width * 4,
            intermediate_channels=width * 2,
            out_channels=width,
        )
        self.final = nn.Sequential(
            *conv(in_channels=width * 2, out_channels=width),
            nn.Conv3d(in_channels=width, out_channels=1, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return a ``(batch, 1, *spatial)`` prediction for ``x``.

        ``x`` may be ``(batch, *spatial)`` (a channel axis is inserted) or
        ``(batch, 1, *spatial)``. Each spatial dim must be at least 8 (three 2x
        poolings) but need not be divisible by 8.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)

        skip1 = self.start(x)
        skip2 = self.down1(F.max_pool3d(skip1, 2))
        skip3 = self.down2(F.max_pool3d(skip2, 2))

        # max_pool3d floors odd sizes, so upsample (interpolate's default nearest
        # mode) straight to each skip's exact spatial shape, not by a fixed factor.
        x = F.interpolate(self.bridge(F.max_pool3d(skip3, 2)), size=skip3.shape[2:])
        x = F.interpolate(self.up2(torch.cat((x, skip3), dim=1)), size=skip2.shape[2:])
        x = F.interpolate(self.up1(torch.cat((x, skip2), dim=1)), size=skip1.shape[2:])
        return self.final(torch.cat((x, skip1), dim=1))

In [None]:
# Network on the selected device, its optimiser, and the voxel-wise regression loss.
model = UNet(image_size=output_size)
model = model.to(device=device)
optimizer = torch.optim.AdamW(params=model.parameters(), weight_decay=1e-6)
criterion = nn.SmoothL1Loss()  # alternative: nn.MSELoss()

In [None]:
# Set to True to warm-start from the weights saved by the last cell of a previous run.
resume_from_checkpoint = False
if resume_from_checkpoint:
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    model.load_state_dict(torch.load("trained.pth", map_location=device))

In [None]:
import math

epoch_train_losses, epoch_val_losses = [], []
num_batches = len(train_loader)


def batch_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Move one batch to ``device``, run ``model`` on it and return ``criterion``'s loss.

    ``model`` returns ``(batch, 1, *spatial)`` while the loaders yield
    channel-less ``(batch, *spatial)`` targets. The target therefore gets a
    channel axis here; otherwise the loss broadcasts to
    ``(batch, batch, *spatial)`` and scores every prediction against every
    target in the batch.
    """
    source = source.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    prediction = model(source)
    if target.dim() == prediction.dim() - 1:
        target = target.unsqueeze(1)
    return criterion(prediction, target)


for t in range(1, num_epochs + 1):
    # training
    train_losses = []
    model.train()
    for source, target in train_loader:
        optimizer.zero_grad()
        loss = batch_loss(source, target)
        loss_value = loss.item()
        # Stop before backward()/step(): updating on a non-finite loss would write NaN/Inf into the weights.
        if not math.isfinite(loss_value):
            raise RuntimeError("NaN or Inf in training loss, cannot recover. Exiting.")
        train_losses.append(loss_value)
        loss.backward()
        optimizer.step()
    epoch_train_losses.append(train_losses)

    # validation
    val_losses = []
    model.eval()
    with torch.no_grad():
        for source, target in valid_loader:
            val_losses.append(batch_loss(source, target).item())
    epoch_val_losses.append(val_losses)

    # Empty loaders report NaN explicitly instead of np.mean([]) warning.
    train_mean = np.mean(train_losses) if train_losses else float("nan")
    val_mean = np.mean(val_losses) if val_losses else float("nan")
    log = f"Epoch: {t} - Training Loss: {train_mean:.2e}, Validation Loss: {val_mean:.2e}"
    print(log)

In [None]:
# Persist only the learned weights (state_dict), not the pickled module object.
checkpoint_path = "trained.pth"
torch.save(model.state_dict(), checkpoint_path)