In [None]:
import torch
import torch.nn as nn
import torchio as tio

from torch.utils.data import TensorDataset, DataLoader

from layer.component import EncoderBlock, DecoderBlock

from PIL import Image
import numpy as np
import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt

import nibabel as nib
import os
from pathlib import Path

from torchio.transforms import (
    RandomFlip,
    RandomAffine,
    RandomElasticDeformation, 
    RandomNoise,
    RandomMotion,
    RandomBiasField,
    RescaleIntensity,
    Resample,
    ToCanonical,
    ZNormalization,
    CropOrPad,
    HistogramStandardization,
    OneOf,
    Compose,
)

In [None]:
# Earlier single-volume exploration, kept commented out for reference: it loaded
# one image/mask pair with nibabel and kept indices 10:42 along the third axis.
# train_data = nib.load(r'data\0160_20180601_190903_T2FLAIR_to_MAG.nii.gz')
# train_mask = nib.load(r'data\0160_20180601_190903_T2FLAIR_to_MAG_ROI.nii.gz')

# # header = proxy.header
# # print(header)
# # print(header['dim'])
# image_dataset = train_data.get_fdata()[:,:,10:42]
# mask_dataset = train_mask.get_fdata()[:,:,10:42]

# train_data.shape
# Instantiate TorchIO's ICBM2009CNonlinearSymmetric dataset object and print its
# Python type. NOTE(review): constructing it may download the atlas on first
# use - confirm before running offline.
print(type(tio.datasets.ICBM2009CNonlinearSymmetric()))

In [None]:
def Dataset_load(base_path, mask_path):
    """Build a TorchIO ``SubjectsDataset`` from paired image and mask folders.

    Every ``*.nii.gz`` file in ``base_path`` is paired with the file at the same
    position (after sorting by path) in ``mask_path``. Each pair becomes one
    ``tio.Subject`` whose image is stored under the key ``'T2F'`` and whose
    segmentation is stored under the key ``'LABEL'``.

    Args:
        base_path: Directory containing the image volumes.
        mask_path: Directory containing the label volumes.

    Returns:
        A ``tio.SubjectsDataset`` that applies ``tio.ToCanonical`` to every
        subject.

    Raises:
        AssertionError: If the two directories hold a different number of
            ``*.nii.gz`` files.
    """
    dataset_dir = Path(base_path)
    maskset_dir = Path(mask_path)

    # Sorting both listings is what pairs image i with mask i; this assumes the
    # file names sort in the same order in both folders.
    image_paths = sorted(dataset_dir.glob('*.nii.gz'))
    label_paths = sorted(maskset_dir.glob('*.nii.gz'))
    print(len(image_paths), len(label_paths))

    assert len(image_paths) == len(label_paths), (
        f"found {len(image_paths)} images in {dataset_dir} "
        f"but {len(label_paths)} masks in {maskset_dir}"
    )

    training_transform = tio.Compose([
        tio.ToCanonical(),
        # Optional preprocessing / augmentation steps, currently disabled:
        # tio.CropOrPad((256,128,32)),
        # tio.Resample(2),
        # tio.RandomMotion(p=0.2),
        # tio.RandomBiasField(p=0.3),
        # tio.RandomNoise(p=0.5),
        # tio.RandomFlip(axes=(0,)),
        # tio.RandomAffine(),
        # ZNormalization(),
    ])

    # The keyword names themselves ('T2F', 'LABEL') are the keys under which the
    # images are stored; the training loop indexes batches with these keys.
    subjects = [
        tio.Subject(
            T2F=tio.ScalarImage(image_path),
            LABEL=tio.LabelMap(label_path),
        )
        for image_path, label_path in zip(image_paths, label_paths)
    ]
    return tio.SubjectsDataset(subjects, transform=training_transform)

In [None]:
# Build the training dataset from the image and mask folders (paths are
# relative to the notebook's working directory).
training_set = Dataset_load(
    base_path=r'data/image',
    mask_path=r'data/mask'
)

# Display the list of loaded subjects as the cell output.
# NOTE(review): `_subjects` is an underscore-prefixed attribute of the dataset;
# it may not be a stable public API - confirm against the installed TorchIO.
training_set._subjects

In [None]:
# One subject per batch; batches are drawn in random order and loaded in the
# main process (no worker subprocesses).
training_batch_size = 1
training_loader = DataLoader(
    training_set,
    batch_size=training_batch_size,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Axis indices, presumably for (batch, channel, x, y, z) tensors - TODO confirm;
# neither constant is referenced in the visible cells.
CHANNELS_DIMENSION = 1
SPATIAL_DIMENSIONS = (2, 3, 4)

class UnetModel(nn.Module):
    """Encoder/decoder U-Net wrapper with a configurable final activation.

    The encoder returns the bottleneck tensor together with the intermediate
    down-sampling features; the decoder consumes both. The decoder output is
    then passed through the activation selected at construction time.

    Args:
        in_channels: Number of input channels, forwarded to ``EncoderBlock``.
        out_channels: Number of output channels, forwarded to ``DecoderBlock``.
        model_depth: Depth forwarded to both encoder and decoder.
        final_activation: ``"sigmoid"`` selects ``nn.Sigmoid``; any other value
            selects ``nn.Softmax(dim=1)``.
    """

    def __init__(self, in_channels, out_channels, model_depth=4, final_activation="sigmoid"):
        super(UnetModel, self).__init__()
        self.encoder = EncoderBlock(in_channels=in_channels, model_depth=model_depth)
        self.decoder = DecoderBlock(out_channels=out_channels, model_depth=model_depth)
        # Remember which activation was requested so forward() can dispatch on
        # it; only one of `sigmoid` / `softmax` exists on a given instance.
        self.final_activation = final_activation
        if final_activation == "sigmoid":
            self.sigmoid = nn.Sigmoid()
        else:
            self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run encoder, decoder and the configured final activation on ``x``."""
        x, downsampling_features = self.encoder(x)
        x = self.decoder(x, downsampling_features)
        # Use the activation that was actually created in __init__ instead of
        # unconditionally reading `self.sigmoid`, which is absent (and raised
        # AttributeError) whenever softmax was requested.
        if self.final_activation == "sigmoid":
            x = self.sigmoid(x)
        else:
            x = self.softmax(x)
        # print("Final output shape: ", x.shape)
        return x
        
def get_model_and_optimizer(device):
    """Create a single-channel-in, single-channel-out U-Net and its optimizer.

    Args:
        device: Device the model is moved to.

    Returns:
        A ``(model, optimizer)`` tuple, where the optimizer is SGD over the
        model's parameters with learning rate 1e-3 and momentum 0.9.
    """
    net = UnetModel(in_channels=1, out_channels=1)
    net = net.to(device)
    sgd = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)
    return net, sgd

In [None]:
class DiceLoss(nn.Module):
    """Differentiable soft Dice loss for binary segmentation.

    The loss is ``1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)``, where
    ``p`` is the flattened prediction and ``t`` the flattened target. Gradients
    are produced by regular autograd, so no hand-written backward is needed.

    This replaces a legacy instance-style ``autograd.Function`` that could not
    be defined here (``Function`` was never imported), relied on an API that
    current PyTorch rejects, and took a hard arg-max over the channel axis,
    which carries no information for a single-channel prediction.

    Args:
        *args: Accepted and ignored, for compatibility with the old constructor.
        eps: Smoothing term; it keeps the ratio defined when both prediction and
            target are empty, in which case the loss is 0.
        **kwargs: Accepted and ignored, for compatibility with the old
            constructor.
    """

    def __init__(self, *args, eps=0.000001, **kwargs):
        super().__init__()
        self.eps = eps

    def forward(self, input, target, save=True):
        """Return the scalar soft Dice loss between ``input`` and ``target``.

        Args:
            input: Predicted values; presumably probabilities in [0, 1] such as
                a sigmoid output - TODO confirm against the model in use.
            target: Ground-truth mask with the same number of elements as
                ``input``; it is cast to the prediction's floating dtype.
            save: Unused; kept so existing calls that pass it keep working.

        Returns:
            A zero-dimensional tensor: values near 0 mean strong overlap and
            values near 1 mean no overlap.

        Raises:
            ValueError: If ``input`` and ``target`` differ in element count.
        """
        if input.numel() != target.numel():
            raise ValueError(
                "input and target must have the same number of elements, got "
                f"{tuple(input.shape)} and {tuple(target.shape)}"
            )

        # Flatten everything: the Dice ratio is computed over the whole batch.
        probs = input.reshape(-1).float()
        truth = target.reshape(-1).to(probs.dtype)

        intersect = torch.dot(probs, truth)
        total = probs.sum() + truth.sum()

        # eps in both numerator and denominator makes empty-vs-empty score 1.
        dice = (2 * intersect + self.eps) / (total + self.eps)
        return 1 - dice


def dice_loss(input, target):
    """Functional shortcut: soft Dice loss with the default smoothing term."""
    return DiceLoss()(input, target)

In [None]:
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, verbose=True)

model, optimizer = get_model_and_optimizer(device)

for epoch in range(10):
    for subjects_batch in training_loader:
        # Image and label tensors are stored under the keys chosen when the
        # subjects were created ('T2F' and 'LABEL').
        inputs = subjects_batch['T2F'][tio.DATA].to(device)
        target = subjects_batch['LABEL'][tio.DATA].to(device)

        outputs = model(inputs)
        # The loss has to compare the model's prediction with the target;
        # comparing the raw input image would give the model no gradient.
        loss = dice_loss(outputs, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Hard IoU of the thresholded prediction, for monitoring only.
        # Assumes outputs and target share the same shape - TODO confirm.
        with torch.no_grad():
            pred_mask = outputs > 0.5
            true_mask = target > 0.5
            intersection = (pred_mask & true_mask).sum().item()
            union = (pred_mask | true_mask).sum().item()
            # Both masks empty counts as a perfect match.
            iou = intersection / union if union else 1.0

        print(f"epoch : {epoch} | loss : {loss.item()} iou : {iou}")

In [None]:
# for epoch in range(1000):
#     for batch_idx, samples in enumerate(train):
#         # Forward pass: compute the model's prediction y_pred for x.

#         x_batch, y_batch = samples

#         y_pred = model(x_batch)
#         y_pred = y_pred.cpu()

#         loss = loss_function(y_pred,y_batch)
#         scheduler.step(loss)


#         optimizer.zero_grad()

#         # Backward pass
#         loss.backward()

#         # Calling the optimizer's step function updates the parameters
#         optimizer.step()

#         print(f"epoch {epoch} | batch {batch_idx+1}/{len(dataloader)} | lose : {loss.item()}")