In [None]:
# Notebook setup: re-import edited modules automatically before running cells
# (autoreload mode 2) and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [1]:
import requests
import tempfile
import tarfile
import os

def download_file_from_google_drive(id, destination):
    """Download a publicly shared Google Drive file to ``destination``.

    If the first response carries a ``download_warning*`` cookie, its value is
    sent back as the ``confirm`` parameter in a second request. This is
    presumably Drive's large-file confirmation step.

    Parameters
    ----------
    id : str
        Google Drive file id.
    destination : str
        Local path the downloaded bytes are written to.

    Raises
    ------
    requests.HTTPError
        If any request returns an error status. Without this check, an HTML
        error page would be silently written to ``destination``.
    """
    URL = "https://docs.google.com/uc?export=download"

    # The context manager closes the session (and its pooled connections).
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        response.raise_for_status()
        token = get_confirm_token(response)

        if token:
            response = session.get(URL, params={'id': id, 'confirm': token}, stream=True)
            response.raise_for_status()

        # Must stay inside the session block: the body is streamed lazily.
        save_response_content(response, destination)

def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with
    ``download_warning``, or ``None`` when the response has no such cookie."""
    return next(
        (cookie_value
         for cookie_name, cookie_value in response.cookies.items()
         if cookie_name.startswith('download_warning')),
        None,
    )

def save_response_content(response, destination):
    """Stream the body of ``response`` into the file at ``destination``.

    The body is read in 32 KiB chunks. Empty chunks (keep-alive) are skipped.
    """
    chunk_size = 32768

    with open(destination, "wb") as out_file:
        # filter(None, ...) drops falsy (empty) chunks
        for piece in filter(None, response.iter_content(chunk_size)):
            out_file.write(piece)

# Download the hippocampus archive into a fresh temporary directory and
# extract it there. `data_dir` is the extracted task folder that the
# dataset below reads `dataset.json` from.
temp_dir = tempfile.mkdtemp()
# Use the same folder name for the "already extracted?" check and for later
# use, so the check can actually succeed.
data_dir = os.path.join(temp_dir, 'Task04_Hippocampus')
# Single archive path, used for both the existence check and the download.
tar_path = os.path.join(temp_dir, 'data.tar')

if not os.path.exists(data_dir):
    if not os.path.exists(tar_path):
        print('Downloading Data')
        download_file_from_google_drive('1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C', tar_path)
    print('Extracting Data')
    # tarfile.open auto-detects compression, and the `with` closes the file.
    # NOTE(review): extractall trusts member paths inside the archive; only
    # use this with an archive from a trusted source.
    with tarfile.open(tar_path) as archive:
        archive.extractall(temp_dir)
    print('Success!')

Downloading Data
Extracting Data
Success!


In [21]:
import SimpleITK as sitk
import json
from rising import loading
from rising.loading import Dataset
import torch
class NiiDataset(Dataset):
    """Image/label volume pairs listed under ``'training'`` in ``<data_dir>/dataset.json``.

    Each entry is a dict with ``'image'`` and ``'label'`` paths, which are
    joined onto ``data_dir``. The list is split by its order: the first
    ``train_fraction`` of entries form the training set and the rest form
    the validation set.

    Parameters
    ----------
    train : bool
        Select the training part (True) or the held-out part (False).
    data_dir : str
        Directory containing ``dataset.json``.
    train_fraction : float
        Fraction of entries assigned to the training part (default 0.9).
    """

    def __init__(self, train: bool, data_dir: str, train_fraction: float = 0.9):
        with open(os.path.join(data_dir, 'dataset.json')) as f:
            content = json.load(f)['training']

        num_train_samples = int(len(content) * train_fraction)
        self.data = content[:num_train_samples] if train else content[num_train_samples:]
        self.data_dir = data_dir

    def __getitem__(self, item: int):
        """Load one sample as ``{'data': float tensor, 'label': float 0/1 tensor}``.

        Both tensors get a leading channel dimension when the loaded array is 3D.
        """
        sample = self.data[item]
        img = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(self.data_dir, sample['image'])))

        # add channel dim if necessary
        if img.ndim == 3:
            img = img[None]

        label = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(self.data_dir, sample['label'])))

        # convert multiclass to binary task by combining all positives
        label = label > 0

        # add channel dim if necessary (kept so spatial augmentations can treat
        # label like image; the training step removes it again)
        if label.ndim == 3:
            label = label[None]

        return {'data': torch.from_numpy(img).float(),
                'label': torch.from_numpy(label).float()}

    def __len__(self):
        return len(self.data)


In [22]:
import pytorch_lightning as pl
class Unet(pl.LightningModule):
    """3D U-Net built from plain conv + ReLU blocks.

    ``hparams`` keys (all optional):

    * ``start_filts`` (default 16): filters of the first encoder block,
      doubled at every deeper level.
    * ``depth`` (default 3): number of encoder blocks. There are
      ``depth - 1`` max-pool / up-sample steps and decoder blocks.
    * ``in_channels`` (default 1) and ``num_classes`` (default 2).

    The 3x3x3 convolutions use ``padding=1`` and keep the spatial size. Every
    spatial input size must be divisible by ``2 ** (depth - 1)``. Otherwise
    the up-sampled features do not line up with the skip connections and
    ``torch.cat`` fails. ``forward`` returns unnormalized scores of shape
    (N, num_classes, *spatial).
    """

    def __init__(self, hparams: dict):
        super().__init__()
        out_filts = hparams.get('start_filts', 16)
        depth = hparams.get('depth', 3)
        in_filts = hparams.get('in_channels', 1)
        num_classes = hparams.get('num_classes', 2)

        # Encoder: `depth` blocks of two conv + ReLU, doubling the filters per level.
        for idx in range(depth):
            down_block = torch.nn.Sequential(
                torch.nn.Conv3d(in_filts, out_filts, kernel_size=3, padding=1), torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(out_filts, out_filts, kernel_size=3, padding=1), torch.nn.ReLU(inplace=True))
            setattr(self, 'down_block_%d' % idx, down_block)
            in_filts = out_filts
            out_filts *= 2

        # Decoder: in_filts now equals the bottleneck width. Each level halves
        # the filters. A block's input is the up-sampled features (in_filts)
        # concatenated with the encoder skip of matching width (out_filts).
        out_filts = in_filts // 2
        for idx in range(depth - 1):
            up_block = torch.nn.Sequential(
                torch.nn.Conv3d(in_filts + out_filts, out_filts, kernel_size=3, padding=1), torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(out_filts, out_filts, kernel_size=3, padding=1), torch.nn.ReLU(inplace=True))
            setattr(self, 'up_block_%d' % idx, up_block)
            in_filts = out_filts
            out_filts = out_filts // 2

        self.final_conv = torch.nn.Conv3d(in_filts, num_classes, kernel_size=1)
        self.max_pool = torch.nn.MaxPool3d(2, stride=2)
        self.up_sample = torch.nn.Upsample(scale_factor=2)
        self.hparams = hparams

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the encoder/decoder and return per-voxel class scores."""
        depth = self.hparams.get('depth', 3)

        skips = []
        features = input_tensor
        for idx in range(depth):
            features = getattr(self, 'down_block_%d' % idx)(features)
            # every level except the bottleneck feeds a skip connection and is pooled
            if idx < depth - 1:
                skips.append(features)
                features = self.max_pool(features)

        for idx in range(depth - 1):
            features = self.up_sample(features)
            # deepest remaining skip first
            features = torch.cat([features, skips.pop()], dim=1)
            features = getattr(self, 'up_block_%d' % idx)(features)

        return self.final_conv(features)

        

In [23]:
# Smoke test: a random 128^3 single-channel volume should map to
# (1, num_classes, 128, 128, 128). no_grad skips storing activations for
# backward, which a shape check does not need and which is costly for 3D
# volumes of this size.
net = Unet({'num_classes': 2, 'in_channels': 1, 'depth': 3})
with torch.no_grad():
    out = net(torch.rand(1, 1, 128, 128, 128))
assert out.shape == (1, 2, 128, 128, 128)
print(out.shape)

torch.Size([1, 2, 128, 128, 128])


In [43]:
import rising

# Taken from https://github.com/justusschock/dl-utils/blob/master/dlutils/losses/soft_dice.py
class SoftDiceLoss(torch.nn.Module):
    """Negative soft Dice loss for multi-class segmentation.

    The per-class soft Dice ratio is computed per sample, optionally weighted
    per class, and averaged over classes and negated. It is then reduced over
    the batch according to ``reduction``.
    """

    _REDUCTIONS = ("elementwise_mean", "mean", "sum", "none")

    def __init__(self, square_nom=False, square_denom=False, weight=None,
                 smooth=1., reduction="elementwise_mean", non_lin=None):
        """
        SoftDice Loss

        Parameters
        ----------
        square_nom : bool
            square the per-voxel products in the numerator
        square_denom : bool
            square predictions and targets in the denominator
        weight : iterable, optional
            additional weighting of individual classes (length = number of classes)
        smooth : float
            smoothing added to numerator and denominator
        reduction : str
            'elementwise_mean' or 'mean': average over the batch (scalar);
            'sum': sum over the batch (scalar); 'none': per-sample loss (shape N)
        non_lin : callable, optional
            applied to the predictions before computing the loss

        Raises
        ------
        ValueError
            if ``reduction`` is not one of the supported values
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError("Unknown reduction %r, expected one of %s"
                             % (reduction, self._REDUCTIONS))

        self.square_nom = square_nom
        self.square_denom = square_denom
        self.smooth = smooth

        if weight is not None:
            self.register_buffer("weight", torch.tensor(weight))
        else:
            self.weight = None

        self.reduction = reduction
        self.non_lin = non_lin

    def forward(self, predictions, targets):
        """
        Compute SoftDice Loss

        Parameters
        ----------
        predictions : torch.Tensor
            class scores of shape (N, C, *spatial), e.g. softmax outputs
            (or raw scores when ``non_lin`` is set)
        targets : torch.Tensor
            integer class indices of shape (N, *spatial)

        Returns
        -------
        torch.Tensor
            scalar for 'elementwise_mean' / 'mean' / 'sum', shape (N,) for 'none'
        """
        n_classes = predictions.shape[1]
        with torch.no_grad():
            targets_onehot = torch.nn.functional.one_hot(targets.long(), num_classes=n_classes)
            # one_hot appends the class axis last; move it to dim 1 to match predictions
            last = targets_onehot.dim() - 1
            targets_onehot = targets_onehot.permute(0, last, *range(1, last)).float()

        # sum over spatial dimensions
        dims = tuple(range(2, predictions.dim()))

        if self.non_lin is not None:
            predictions = self.non_lin(predictions)

        overlap = predictions * targets_onehot
        if self.square_nom:
            overlap = overlap ** 2
        nom = 2 * torch.sum(overlap, dim=dims) + self.smooth

        if self.square_denom:
            i_sum = torch.sum(predictions ** 2, dim=dims)
            t_sum = torch.sum(targets_onehot ** 2, dim=dims)
        else:
            i_sum = torch.sum(predictions, dim=dims)
            t_sum = torch.sum(targets_onehot, dim=dims)
        denom = i_sum + t_sum + self.smooth

        # per-sample, per-class soft dice ratio: shape (N, C)
        frac = nom / denom

        # apply weight for individual classes
        if self.weight is not None:
            frac = self.weight * frac

        # average over classes and negate so that minimizing maximizes overlap
        loss = -torch.mean(frac, dim=1)

        if self.reduction in ("elementwise_mean", "mean"):
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


        



In [44]:
# Taken from https://github.com/justusschock/dl-utils/blob/master/dlutils/metrics/dice.py
def binary_dice_coefficient(pred: torch.Tensor, gt: torch.Tensor,
                            thresh: float = 0.5, smooth: float = 1e-7):
    """
    Dice overlap between a thresholded prediction and a target mask.

    All elements of the inputs (the whole batch) are pooled into one score.
    ``gt`` is used as-is (not thresholded), so it should already be binary.
    ``smooth`` is added only to the denominator, so two empty masks score 0.

    Parameters
    ----------
    pred : torch.Tensor
        predicted segmentation (of shape NxCx(Dx)HxW)
    gt : torch.Tensor
        target segmentation (of shape NxCx(Dx)HxW)
    thresh : float
        values of ``pred`` strictly above this count as foreground
    smooth : float
        smoothing value to avoid division by zero

    Returns
    -------
    torch.Tensor
        scalar dice score
    """
    foreground = pred > thresh

    overlap = (foreground * gt).float().sum()
    pred_volume = foreground.float().sum()
    gt_volume = gt.float().sum()

    return 2 * overlap / (pred_volume + gt_volume + smooth)



In [45]:
from rising.transforms import Compose, ResizeNative

def common_per_sample_trafos(size: tuple = (32, 64, 32)):
    """Per-sample resizing shared by the training and validation loaders.

    Parameters
    ----------
    size : tuple
        Target spatial size for both image and label (default (32, 64, 32)).
        Both transforms use the same value so they cannot drift apart.

    Returns
    -------
    Compose
        Trilinear resize of ``'data'`` followed by nearest-neighbour resize of
        ``'label'``, so label values stay discrete.
    """
    return Compose(ResizeNative(size=size, keys=('data',), mode='trilinear'),
                   ResizeNative(size=size, keys=('label',), mode='nearest'))

In [53]:
from rising.transforms.affine import BaseAffine
import random
from typing import Optional, Sequence

class RandomAffine(BaseAffine):
    """Affine transform whose scale/rotation/translation are re-sampled on every call.

    Each ``*_range`` is a ``(low, high)`` tuple. A value is drawn uniformly
    from it per spatial axis (one angle in 2D, three in 3D for rotation).
    A range of ``None`` leaves that component at ``None``. All other arguments
    are forwarded unchanged to ``BaseAffine``. The same sampled matrix is used
    for every key in ``keys`` during one call.
    """

    def __init__(self, scale_range: Optional[tuple] = None,
                 rotation_range: Optional[tuple] = None,
                 translation_range: Optional[tuple] = None,
                 degree: bool = True,
                 image_transform: bool = True,
                 keys: Sequence = ('data',),
                 grad: bool = False,
                 output_size: Optional[tuple] = None,
                 adjust_size: bool = False,
                 interpolation_mode: str = 'nearest',
                 padding_mode: str = 'zeros',
                 align_corners: bool = False,
                 reverse_order: bool = False,
                 **kwargs,):
        # Start with no fixed parameters; they are drawn in assemble_matrix.
        super().__init__(scale=None, rotation=None, translation=None,
                         degree=degree,
                         image_transform=image_transform,
                         keys=keys,
                         grad=grad,
                         output_size=output_size,
                         adjust_size=adjust_size,
                         interpolation_mode=interpolation_mode,
                         padding_mode=padding_mode,
                         align_corners=align_corners,
                         reverse_order=reverse_order,
                         **kwargs)

        self.scale_range = scale_range
        self.rotation_range = rotation_range
        self.translation_range = translation_range

    def assemble_matrix(self, **data):
        """Sample fresh affine parameters, then delegate to ``BaseAffine.assemble_matrix``.

        The sampled values are stored on ``self.scale`` / ``self.rotation`` /
        ``self.translation``, which the parent presumably reads when building
        the matrix.
        """
        # number of spatial dims; assumes batched (N, C, *spatial) tensors
        ndim = data[self.keys[0]].ndim - 2

        if self.scale_range is not None:
            self.scale = [random.uniform(*self.scale_range) for _ in range(ndim)]

        if self.translation_range is not None:
            self.translation = [random.uniform(*self.translation_range) for _ in range(ndim)]

        if self.rotation_range is not None:
            if ndim == 3:
                # one angle per rotation axis in 3D
                self.rotation = [random.uniform(*self.rotation_range) for _ in range(ndim)]
            elif ndim == 2:
                # a 2D rotation has a single angle (was `ndim == 1`, so 2D
                # inputs were silently never rotated)
                self.rotation = random.uniform(*self.rotation_range)

        return super().assemble_matrix(**data)
        

In [54]:
from rising.transforms import NormZeroMeanUnitStd
from rising.loading import DataLoader
import torch
class TrainableUNet(Unet):
    """``Unet`` plus Lightning data loading, optimisation and logging for the binary task.

    Optional ``hparams`` keys read here, in addition to those read by ``Unet``:
    ``data_dir`` (defaults to the notebook-level ``data_dir``), ``min_scale``,
    ``max_scale``, ``min_rotation``, ``max_rotation``, ``batch_size``,
    ``num_workers`` and ``learning_rate``.
    """

    def __init__(self, hparams: Optional[dict] = None):
        if hparams is None:
            hparams = {}
        super().__init__(hparams)

        # weight [0, 1]: only the foreground class contributes to the dice term
        self.dice_loss = SoftDiceLoss(weight=[0., 1.])
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def _make_dataloader(self, train: bool, batch_transforms):
        """Build a DataLoader over the train or validation split (shuffled only for training)."""
        dataset = NiiDataset(train=train, data_dir=self.hparams.get('data_dir', data_dir))
        return DataLoader(dataset,
                          batch_size=self.hparams.get('batch_size', 1),
                          batch_transforms=batch_transforms,
                          shuffle=train,
                          sample_transforms=common_per_sample_trafos(),
                          pseudo_batch_dim=True,
                          num_workers=self.hparams.get('num_workers', 4))

    def train_dataloader(self):
        batch_transforms = Compose([
            # The spatial augmentation must move image and label together.
            # Augmenting only 'data' misaligns the segmentation targets.
            RandomAffine(scale_range=(self.hparams.get('min_scale', 0.9), self.hparams.get('max_scale', 1.1)),
                         rotation_range=(self.hparams.get('min_rotation', -10), self.hparams.get('max_rotation', 10)),
                         keys=('data', 'label')),
            NormZeroMeanUnitStd(keys=('data',))
        ])
        return self._make_dataloader(True, batch_transforms)

    def val_dataloader(self):
        return self._make_dataloader(False, NormZeroMeanUnitStd(keys=('data',)))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.get('learning_rate', 1e-3))

    def _shared_step(self, batch, prefix: str):
        """Forward pass, losses and metric logging shared by training and validation.

        Returns ``(total_loss, dice_coeff)`` and logs four scalars under ``prefix``.
        """
        x, y = batch['data'], batch['label']

        # remove channel dim from gt (was necessary for augmentation)
        y = y[:, 0].long()

        pred = self(x)
        softmaxed_pred = torch.nn.functional.softmax(pred, dim=1)

        ce_loss = self.ce_loss(pred, y)
        # .mean() reduces a per-sample dice loss to a scalar (no-op if it is
        # already one), so batch_size > 1 yields a scalar total loss
        dice_loss = self.dice_loss(softmaxed_pred, y).mean()
        total_loss = (ce_loss + dice_loss) / 2

        dice_coeff = binary_dice_coefficient(torch.argmax(softmaxed_pred, dim=1), y)

        # Without global_step every point is logged at the same step.
        step = self.trainer.global_step
        writer = self.logger.experiment
        for name, value in (('DiceCoeff', dice_coeff), ('CE', ce_loss),
                            ('SoftDiceLoss', dice_loss), ('TotalLoss', total_loss)):
            writer.add_scalar('%s/%s' % (prefix, name), value, global_step=step)

        return total_loss, dice_coeff

    def training_step(self, batch, batch_idx):
        total_loss, _ = self._shared_step(batch, 'Train')
        return {'loss': total_loss}

    def validation_step(self, batch, batch_idx):
        total_loss, dice_coeff = self._shared_step(batch, 'Val')
        return {'val_loss': total_loss, 'dice': dice_coeff}

    def validation_epoch_end(self, outputs):
        """Average every returned metric over the validation batches ('dice' feeds EarlyStopping)."""
        if not outputs:
            return {}
        return {k: torch.stack([x[k] for x in outputs]).mean() for k in outputs[0]}
    

In [61]:
# Start TensorBoard inside the notebook. It reads the logs written under
# temp_dir, which is also passed as default_save_path to the Trainer below.

%reload_ext tensorboard
%tensorboard --logdir {temp_dir}

In [None]:
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import Trainer

# Stop once the mean validation dice (the 'dice' key averaged in
# validation_epoch_end) has not improved by at least 0.001 for 3 checks.
early_stop_callback = EarlyStopping(monitor='dice', min_delta=0.001, patience=3, verbose=False, mode='max')

# Lightning reads an integer `gpus` as the *number* of GPUs. The previous
# `gpus = 0` therefore trained on CPU even when a GPU was available.
gpus = 1 if torch.cuda.is_available() else None

model = TrainableUNet({'num_workers': 0})

trainer = Trainer(gpus=gpus, default_save_path=temp_dir, early_stop_callback=early_stop_callback, max_nb_epochs=5)
trainer.fit(model)


INFO:lightning:GPU available: False, used: False
INFO:lightning:
   | Name           | Type             | Params
------------------------------------------------
0  | down_block_0   | Sequential       | 7 K   
1  | down_block_0.0 | Conv3d           | 448   
2  | down_block_0.1 | ReLU             | 0     
3  | down_block_0.2 | Conv3d           | 6 K   
4  | down_block_0.3 | ReLU             | 0     
5  | down_block_1   | Sequential       | 41 K  
6  | down_block_1.0 | Conv3d           | 13 K  
7  | down_block_1.1 | ReLU             | 0     
8  | down_block_1.2 | Conv3d           | 27 K  
9  | down_block_1.3 | ReLU             | 0     
10 | down_block_2   | Sequential       | 166 K 
11 | down_block_2.0 | Conv3d           | 55 K  
12 | down_block_2.1 | ReLU             | 0     
13 | down_block_2.2 | Conv3d           | 110 K 
14 | down_block_2.3 | ReLU             | 0     
15 | up_block_0     | Sequential       | 110 K 
16 | up_block_0.0   | Conv3d           | 82 K  
17 | up_block_0.1   | 

Validation sanity check: 0it [00:00, ?it/s]

  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1:   0%|          | 0/260 [00:00<?, ?it/s]                      



Epoch 1:   3%|▎         | 8/260 [00:12<06:34,  1.57s/it, loss=0.242, v_num=12]