In [1]:
import torch
import torch.nn as nn
import torchio as tio

from torch.utils.data import TensorDataset, DataLoader

from layer.component import EncoderBlock, DecoderBlock

from PIL import Image
import numpy as np
import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt

import nibabel as nib
import os
from pathlib import Path

from torchio.transforms import (
    RandomFlip,
    RandomAffine,
    RandomElasticDeformation, 
    RandomNoise,
    RandomMotion,
    RandomBiasField,
    RescaleIntensity,
    Resample,
    ToCanonical,
    ZNormalization,
    CropOrPad,
    HistogramStandardization,
    OneOf,
    Compose,
)

In [2]:
# train_data = nib.load(r'data\0160_20180601_190903_T2FLAIR_to_MAG.nii.gz')
# train_mask = nib.load(r'data\0160_20180601_190903_T2FLAIR_to_MAG_ROI.nii.gz')

# # header = proxy.header
# # print(header)
# # print(header['dim'])
# image_dataset = train_data.get_fdata()[:,:,10:42]
# mask_dataset = train_mask.get_fdata()[:,:,10:42]

# train_data.shape
print(type(tio.datasets.ICBM2009CNonlinearSymmetric()))

<class 'torchio.datasets.mni.icbm.ICBM2009CNonlinearSymmetric'>


In [3]:
def Dataset_load(base_path, mask_path):

    dataset_dir = Path(base_path)
    maskset_dir = Path(mask_path)
    
    image_paths = sorted(dataset_dir.glob('*.nii.gz'))
    label_paths = sorted(maskset_dir.glob('*.nii.gz'))
    print(len(image_paths), len(label_paths))

    assert len(image_paths) == len(label_paths)

    T2F = 't2f'
    LABEL = 'label'
    subjects = []

    training_transform = tio.Compose([
    tio.ToCanonical(),
    tio.CropOrPad((128,128,32)),
    # tio.Resample(2),
    # tio.RandomMotion(p=0.2),
    # tio.RandomBiasField(p=0.3),
    # tio.RandomNoise(p=0.5),
    # tio.RandomFlip(axes=(0,)),
    # tio.RandomAffine(),
    # ZNormalization(),
    ])

    for (image_path, label_path) in zip(image_paths, label_paths):
        subject = tio.Subject(
            T2F = tio.ScalarImage(image_path),
            LABEL = tio.LabelMap(label_path),
        )
        subjects.append(subject)
    dataset = tio.SubjectsDataset(subjects, transform=training_transform)
    return dataset

In [14]:
training_set = Dataset_load(
    base_path=r'data\image',
    mask_path=r'data\mask'
)

training_set._subjects

3 3


[Subject(Keys: ('T2F', 'LABEL'); images: 2),
 Subject(Keys: ('T2F', 'LABEL'); images: 2),
 Subject(Keys: ('T2F', 'LABEL'); images: 2)]

In [5]:
training_batch_size = 1
training_loader = torch.utils.data.DataLoader(dataset = training_set, batch_size = training_batch_size, shuffle = True,
                                              num_workers=0)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
CHANNELS_DIMENSION = 1
SPATIAL_DIMENSIONS = 2,3,4

class UnetModel(nn.Module):

    def __init__(self, in_channels, out_channels, model_depth=4, final_activation="sigmoid"):
        super(UnetModel, self).__init__()
        self.encoder = EncoderBlock(in_channels=in_channels, model_depth=model_depth)
        self.decoder = DecoderBlock(out_channels=out_channels, model_depth=model_depth)
        if final_activation == "sigmoid":
            self.sigmoid = nn.Sigmoid()
        else:
            self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x, downsampling_features = self.encoder(x)
        x = self.decoder(x, downsampling_features)
        x = self.sigmoid(x)
        # print("Final output shape: ", x.shape)
        return x

def prepare_batch(batch, device):
  inputs = batch['T2F'][tio.DATA].to(device)
  foreground = batch['LABEL'][tio.DATA].type(torch.float32).to(device)
  background = 1 - foreground
  targets = torch.cat((background, foreground), dim = CHANNELS_DIMENSION)
  return inputs, targets

def get_dice_score(output, target, epsilon = 1e-9):
  p0 = output
  g0 = target
  p1 = 1 - p0
  g1 = 1 - g0
  tp = (p0 * g0).sum(dim = SPATIAL_DIMENSIONS)
  fp = (p0 * g1).sum(dim = SPATIAL_DIMENSIONS)
  fn = (p1 * g0).sum(dim = SPATIAL_DIMENSIONS)
  num = 2 * tp
  denom = 2 * tp + fp + fn + epsilon
  dice_score = num / denom

  return dice_score

def get_dice_loss(output, target):
  return 1 - get_dice_score(output, target)

def forward(model, inputs):
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category = UserWarning)
    logits = model(inputs)
  return logits
def get_model_and_optimizer(device):
    model = UnetModel(in_channels=1, out_channels=1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3, momentum = 0.9)
    return model, optimizer

In [8]:
class DiceLoss(nn.Module):
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|}

    where:
       - :math:`X` expects to be the scores of each class.
       - :math:`Y` expects to be the one-hot tensor with the class labels.

    the loss, is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Shape:
        - Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Examples:
        >>> N = 5  # num_classes
        >>> loss = tgm.losses.DiceLoss()
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self) -> None:
        super(DiceLoss, self).__init__()
        self.eps: float = 1e-6

    def forward(
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 5:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            raise ValueError("input and target shapes must be the same. Got: {}"
                             .format(input.shape, input.shape))
        if not input.device == target.device:
            raise ValueError(
                "input and target must be in the same device. Got: {}" .format(
                    input.device, target.device))
        # compute softmax over the classes axis
        input_soft = F.softmax(input, dim=1)

        # create the labels one hot tensor
        target_one_hot = one_hot(target, num_classes=input.shape[1],
                                 device=input.device, dtype=input.dtype)

        # compute the actual dice score
        dims = (1, 2, 3)
        intersection = torch.sum(input_soft * target_one_hot, dims)
        cardinality = torch.sum(input_soft + target_one_hot, dims)

        dice_score = 2. * intersection / (cardinality + self.eps)
        return torch.mean(1. - dice_score)

def dice_loss(
        input: torch.Tensor,
        target: torch.Tensor) -> torch.Tensor:
    r"""Function that computes Sørensen-Dice Coefficient loss.

    See :class:`~torchgeometry.losses.DiceLoss` for details.
    """
    return DiceLoss()(input, target)
    
def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    # You can comment out this line if you are passing tensors of equal shape
    # But if you are passing output from UNet or something it will most probably
    # be with the BATCH x 1 x H x W shape
    outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    
    intersection = torch.bitwise_and(outputs, labels).float().sum((1, 2))  # Will be zero if Truth=0 or Prediction=0
    union = torch.bitwise_or(outputs,labels).float().sum((1, 2))         # Will be zzero if both are 0
    
    iou = (intersection + SMOOTH) / (union + SMOOTH)  # We smooth our devision to avoid 0/0
    
    thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # This is equal to comparing with thresolds
    
    return thresholded  # Or thresholded.mean() if you are interested in average across the batch

In [9]:
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, verbose=True)

model, optimizer = get_model_and_optimizer(device)

for epoch in range(10):
    for subjects_batch in training_loader:
        inputs = subjects_batch['T2F'][tio.DATA].to(device)
        target = subjects_batch['LABEL'][tio.DATA].to(device)

        outputs = model(inputs)
        loss = dice_loss(inputs,target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"epoch : {epoch} | loss : {loss.item()} iou : {iou}")

ValueError: Invalid input shape, we expect BxNxHxW. Got: torch.Size([3, 1, 128, 128, 32])

In [None]:
# for epoch in range(1000):
#     for batch_idx, samples in enumerate(train):
#         # 순전파 단계 : 모델에서 x에 대한 예측 값 y_pred를 계산합니다.

#         x_batch, y_batch = samples

#         y_pred = model(x_batch)
#         y_pred = y_pred.cpu()

#         loss = loss_function(y_pred,y_batch)
#         scheduler.step(loss)


#         optimizer.zero_grad()

#         # 역전파 단계
#         loss.backward()

#         # optimizer의 step 함수를 호출 하면 매개 변수가 갱신됨
#         optimizer.step()

#         print(f"epoch {epoch} | batch {batch_idx+1}/{len(dataloader)} | lose : {loss.item()}")