In [1]:
from monai import transforms
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
import matplotlib.pyplot as plt
import monai
import os
import glob
import re
import torch
from torch import nn

In [2]:
root_dir = '../Task_Dataset/Task03_Liver'
val_percent = 0.1
data_select = 20

train_images = sorted(glob.glob(os.path.join(root_dir, 'imagesTr', '*.nii.gz')))
train_labels = sorted(glob.glob(os.path.join(root_dir, 'labelsTr', '*.nii.gz')))
test_images = sorted(glob.glob(os.path.join(root_dir, 'imagesTs', '*.nii.gz')))

data_dicts = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(train_images, train_labels)]
data_dicts = data_dicts[:data_select]

num_percent = int(len(data_dicts) * val_percent)
train_files, valid_files = data_dicts[:-num_percent], data_dicts[-num_percent:]
print(f"train length: {len(train_files)}, valid length: {len(valid_files)}")

train length: 18, valid length: 2


In [3]:
monai.utils.set_determinism(seed=0)

train_transform = transforms.Compose([
    transforms.LoadImaged(keys=['image', 'label']),
    transforms.EnsureChannelFirstd(keys=['image', 'label']),
    transforms.ScaleIntensityRanged(
        keys=['image'],
        a_min=-100,
        a_max=200,
        b_min=0.0,
        b_max=1.0,
        clip=True
    ),
    transforms.CropForegroundd(keys=['image', 'label'], source_key='image', allow_smaller=True),
    transforms.Orientationd(keys=['image', 'label'], axcodes='RAS'),
    transforms.Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.0), mode=('bilinear', 'nearest')),
    transforms.RandCropByPosNegLabeld(
        keys=['image', 'label'],
        image_key='image',
        label_key='label',
        image_threshold=0,
        spatial_size=(36, 36, 36),
        pos=1,
        neg=1,
        num_samples=4
    )
])

valid_transform = transforms.Compose([
    transforms.LoadImaged(keys=['image', 'label']),
    transforms.EnsureChannelFirstd(keys=['image', 'label']),
    transforms.ScaleIntensityRanged(
        keys=['image'],
        a_min=-100,
        a_max=200,
        b_min=0.0,
        b_max=1.0,
        clip=True
    ),
    transforms.CropForegroundd(keys=['image', 'label'], source_key='image', allow_smaller=True),
    transforms.Orientationd(keys=['image', 'label'], axcodes='RAS'),
    transforms.Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.0), mode=('bilinear', 'nearest'))
])

In [4]:
class UNet3D(nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
                 n_channels: list=None, batch_norm: bool=False) -> None:
        super(UNet3D, self).__init__()
        if n_channels is None:
            n_channels = [64, 128, 256, 512]

        self.in_conv = DoubleConv(in_channels, n_channels[0], batch_norm=batch_norm)
        self.encoder_1 = DownSample(n_channels[0], n_channels[1], batch_norm=batch_norm)
        self.encoder_2 = DownSample(n_channels[1], n_channels[2], batch_norm=batch_norm)
        self.encoder_3 = DownSample(n_channels[2], n_channels[3], batch_norm=batch_norm)

        self.decoder_1 = UpSample(n_channels[3], n_channels[2], n_channels[2], batch_norm=batch_norm)
        self.decoder_2 = UpSample(n_channels[2], n_channels[1], n_channels[1], batch_norm=batch_norm)
        self.decoder_3 = UpSample(n_channels[1], n_channels[0], n_channels[0], batch_norm=batch_norm)
        self.out_conv = OutConv(n_channels[0], out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.in_conv(x)
        x2 = self.encoder_1(x1)
        x3 = self.encoder_2(x2)
        x4 = self.encoder_3(x3)

        x = self.decoder_1(x4, x3)
        x = self.decoder_2(x, x2)
        x = self.decoder_3(x, x1)
        x = self.out_conv(x)
        return x


class DoubleConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, batch_norm: bool=False) -> None:
        super(DoubleConv, self).__init__()
        mid_channels = out_channels // 2
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        if batch_norm:
            self.conv1.append(nn.BatchNorm3d(mid_channels))
            self.conv2.append(nn.BatchNorm3d(out_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class DownSample(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, batch_norm: bool=False) -> None:
        super(DownSample, self).__init__()
        self.down = nn.Sequential(
            nn.MaxPool3d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, batch_norm=batch_norm)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.down(x)
        return x


class UpSample(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, encoder_channels: int, batch_norm: bool=False) -> None:
        super(UpSample, self).__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels + encoder_channels, out_channels, batch_norm=batch_norm)

    def forward(self, decoder: torch.Tensor, encoder: torch.Tensor) -> torch.Tensor:
        decoder = self.up(decoder)
        x = torch.cat([encoder, decoder], dim=1)
        x = self.conv(x)
        return x


class OutConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(OutConv, self).__init__()
        self.out = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.out(x)
        return x

UNet3D(
  (inputs): ConvDouble(
    (conv_double): Sequential(
      (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (down_1): DownSampling(
    (downsample): Sequential(
      (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): ConvDouble(
        (conv_double): Sequential(
          (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (4): BatchNorm3d(128, eps=1e-0

In [5]:
train_cache = CacheDataset(data=train_files, transform=train_transform, cache_rate=1.0, num_workers=4)
valid_cache = CacheDataset(data=valid_files, transform=valid_transform, cache_rate=1.0, num_workers=4)

Loading dataset: 100%|██████████| 18/18 [01:30<00:00,  5.02s/it]
Loading dataset: 100%|██████████| 2/2 [00:22<00:00, 11.01s/it]


In [6]:
train_loader = DataLoader(dataset=train_cache, batch_size=2, shuffle=True)
valid_loader = DataLoader(dataset=valid_cache, batch_size=1, shuffle=False)

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet3D(in_channels=1, out_channels=1, batch_norm=True).to(device)
loss_fn = DiceLoss(to_onehot_y=True, softmax=True, squared_pred=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dice_metric = DiceMetric(include_background=False, reduction='mean')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                       patience=2, threshold=1e-04, threshold_mode='rel',
                                                       cooldown=0, min_lr=1e-06, eps=1e-08)

In [8]:
MAX_EPOCHS = 100
ROI_SIZE = (24, 24, 24)

SW_BATCH_SIZE = 4
VAL_INTERVAL = 2

mean_dice = []
train_loss = []

best_dice = -1
best_epoch = -1

post_pred = transforms.Compose([transforms.AsDiscrete(argmax=True, to_onehot=2)])
post_label = transforms.Compose([transforms.AsDiscrete(to_onehot=2)])

In [11]:
for epoch in range(MAX_EPOCHS):
    print(f"{f'Epoch {epoch + 1}/{MAX_EPOCHS}':-^50}")

    train_step = 0
    epoch_loss = 0

    model.train()
    for batch in train_loader:
        train_step += 1
        images, labels = batch['image'].to(device), batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        print(f"{train_step}/{len(train_cache) // train_loader.batch_size}, train loss: {loss.item():.4f}")

    epoch_loss /= train_step
    train_loss.append(epoch_loss)
    print(f"epoch: {epoch + 1}/{MAX_EPOCHS}, average loss: {epoch_loss:.4f}")

    if epoch % VAL_INTERVAL:
        continue

    model.eval()
    with torch.no_grad():
        for batch in valid_loader:
            images, labels = batch['image'].to(device), batch['label'].to(device)
            outputs = sliding_window_inference(images, ROI_SIZE, SW_BATCH_SIZE, model)
            valid_outputs = [post_pred(i) for i in decollate_batch(outputs)]
            valid_labels = [post_label(i) for i in decollate_batch(labels)]
            dice_metric(y_pred=valid_outputs, y=valid_labels)

        dice = dice_metric.aggregate().item()
        scheduler.step(dice)

        mean_dice.append(dice)
        dice_metric.reset()

        if dice > best_dice:
            best_dice = dice
            best_epoch = epoch + 1
            # torch.save(model.state_dict(), f"./results/best_dice_model.pth")

        print(f"epoch: {epoch + 1}/{MAX_EPOCHS}, current mean dice: {dice:.4f}, "
              f"best mean dice: {best_dice:.4f} at epoch {best_epoch}")

-------------------Epoch 1/100--------------------


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 9 but got size 8 for tensor number 1 in the list.