In [None]:
!pip install -q git+https://github.com/matjesg/deepflash2.git
!pip install segmentation_models_pytorch
!pip install pytorch-lightning
!pip install neptune-client

In [None]:
import zarr, cv2
from fastai.vision.all import *
from fastai.data.transforms import Normalize as N
from deepflash2.all import *
import albumentations as alb

import torch
from torch import nn
import torchvision
import os
import gc
import numpy as np
import pandas as pd
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
from scipy.ndimage.interpolation import zoom
from albumentations import *
from torch.nn import functional as F
import matplotlib.pyplot as plt
from PIL import Image
import tifffile as tiff
import cv2
import zipfile
import time
import random
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.unet import Unet
from segmentation_models_pytorch import Linknet
from tqdm.notebook import tqdm
from torchmetrics import Metric

import pytorch_lightning as pl

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers.neptune import NeptuneLogger

import warnings
warnings.filterwarnings("ignore")

In [None]:
@patch
def read_img(self:BaseDataset, file, *args, **kwargs):
    "Open `file` as a read-only zarr store; extra positional/keyword arguments are accepted but ignored."
    store_path = str(file)
    return zarr.open(store_path, mode='r')

@patch
def _name_fn(self:BaseDataset, g):
    "Name of preprocessed and compressed data: `g` rendered with its default string formatting."
    name = f'{g}'
    return name

@patch
def apply(self:DeformationField, data, offset=(0, 0), pad=(0, 0), order=1):
    """Apply the deformation field to `data` using interpolation and return the resampled tile.

    Only the window of `data` that the sampling coordinates touch is sliced out
    before calling `cv2.remap`, so array-like stores (e.g. zarr) are not loaded in full.

    Args:
        data: 2D array-like, or 3D with channels on the last axis (detected by
            `data` having exactly one more dimension than the field).
        offset, pad: forwarded to `self.get`; `pad` is also subtracted from
            `self.shape` to obtain the output tile shape.
        order: passed straight through as the OpenCV `interpolation` flag.

    Returns:
        Array of shape `outshape` (plus a trailing channel axis for 3D input).
        NOTE(review): channel data is divided by 255 before remapping, the
        single-channel path is not — presumably image vs. mask; confirm.
    """
    # Output tile shape: field shape minus the padding, per axis.
    outshape = tuple(int(s - p) for (s, p) in zip(self.shape, pad))
    # One float32 coordinate map per axis, reshaped to the output tile shape.
    coords = [np.squeeze(d).astype('float32').reshape(*outshape) for d in self.get(offset, pad)]
    # Build one slice per axis covering the coordinate range, to avoid loading all data (.zarr files)
    sl = []
    for i in range(len(coords)):
        # Integer bounds of the sampled coordinates along axis i (truncated by int()).
        cmin, cmax = int(coords[i].min()), int(coords[i].max())
        dmax = data.shape[i]
        if cmin<0: 
            # Coordinates reach below 0: window starts at 0 and extends at least to -cmin.
            # Coordinates are left unshifted because the window origin is 0.
            cmax = max(-cmin, cmax)
            cmin = 0 
        elif cmax>dmax:
            # Coordinates exceed the axis length: window ends at dmax and starts no later
            # than 2*dmax-cmax (cmax mirrored about dmax); coordinates become window-relative.
            # NOTE(review): only reached when cmin >= 0; a range crossing both borders
            # takes the first branch — confirm that is intended.
            cmin = min(cmin, 2*dmax-cmax)
            cmax = dmax
            coords[i] -= cmin
        # Fully inside the data: make coordinates relative to the window start.
        else: coords[i] -= cmin
        sl.append(slice(cmin, cmax))    
    if len(data.shape) == len(self.shape) + 1:
        # Multi-channel input: remap each channel separately into a preallocated tile.
        tile = np.empty((*outshape, data.shape[-1]))
        for c in range(data.shape[-1]):
            # Channel values are divided by 255 before remapping (assumes 8-bit intensities — TODO confirm).
            # cv2.remap takes the x map first, hence coords[1] (axis 1) before coords[0] (axis 0).
            tile[..., c] = cv2.remap(data[sl[0],sl[1], c]/255, coords[1],coords[0], interpolation=order, borderMode=cv2.BORDER_REFLECT)
    else:
        # Single-channel input: remapped as-is, without the /255 scaling.
        tile = cv2.remap(data[sl[0], sl[1]], coords[1], coords[0], interpolation=order, borderMode=cv2.BORDER_REFLECT)
    return tile

In [None]:
class CONFIG():
    """Single place for the notebook's tunable settings: data paths, tiling and
    augmentation options, model architecture, and optimiser/training parameters.
    All values are class attributes; `cfg` below is the instance the other cells read."""
    
    # data paths
    data_path = Path('../input/hubmap-kidney-segmentation')
    data_path_zarr = Path('../input/hubmap-zarr-dataset/train_scale2')
    mask_preproc_dir = '../input/hubmap-labels/masks_scale2'
    
    # deepflash2 dataset
    scale = 1 # data is already downscaled to 2, so absolute downscale is 3 (NOTE(review): original author's figure — confirm)
    tile_shape = (512, 512)
    padding = (0,0) # Border overlap for prediction
    n_jobs = 1
    sample_mult = 100 # Sample 100 tiles from each image, per epoch
    val_length = 500 # Randomly sample 500 validation tiles
    # (mean, std) pair of per-channel statistics, unpacked into Normalize.from_stats when building the DataLoaders
    stats = np.array([0.61561477, 0.5179343 , 0.64067212]), np.array([0.2915353 , 0.31549066, 0.28647661])
    
    # deepflash2 augmentation options
    zoom_sigma = 0.1
    flip = True
    max_rotation = 360
    deformation_grid_size = (150,150)
    deformation_magnitude = (10,10)
    

    # pytorch model (segmentation_models_pytorch)
    encoder_name = "efficientnet-b4"
    encoder_weights = 'imagenet'
    in_channels = 3
    classes = 2
    
    # Lightning Learner
    batch_size = 8
    # NOTE(review): `epochs` is assigned again in the fastai section below, so this
    # value (50) is overwritten and `cfg.epochs` evaluates to 12.
    epochs = 50
    learning_rate = 1e-3
    decay = 0.01
    # Effective batch size targeted via gradient accumulation (virtual_batch_size / batch_size steps)
    virtual_batch_size = 32
    
    # fastai Learner 
    mixed_precision_training = True
    weight_decay = 0.01
    loss_func = CrossEntropyLossFlat(axis=1)
    metrics = [Iou(), Dice_f1()]
    optimizer = ranger
    max_learning_rate = 1e-4
    # Shadows the Lightning `epochs = 50` defined above.
    epochs = 12
    
# Shared configuration instance used by the rest of the notebook.
cfg = CONFIG()

In [None]:
# Colour/contrast augmentation: the OneOf block fires with p=0.3 and then
# picks a single transform from the list below.
tfms = alb.OneOf(
    [
        alb.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10),
        alb.CLAHE(clip_limit=2),
        alb.RandomBrightnessContrast(),
    ],
    p=0.3,
)

In [None]:
# Competition metadata tables shipped with the raw dataset.
df_train = pd.read_csv(cfg.data_path / 'train.csv')
df_info = pd.read_csv(cfg.data_path / 'HuBMAP-20-dataset_information.csv')

# Every non-hidden sub-directory of the zarr folder is collected as a dataset input.
files = [
    entry
    for entry in cfg.data_path_zarr.iterdir()
    if entry.is_dir() and not entry.name.startswith('.')
]


def label_fn(o):
    "Identity mapping: the label reference for an item is the item itself."
    return o

In [None]:
# Datasets
# Keyword options shared by the training and validation dataset constructors,
# all taken from `cfg` except the fixed `loss_weights=False` and the `tfms` pipeline.
ds_kwargs = dict(
    tile_shape=cfg.tile_shape,
    padding=cfg.padding,
    scale=cfg.scale,
    n_jobs=cfg.n_jobs,
    preproc_dir=cfg.mask_preproc_dir,
    val_length=cfg.val_length,
    sample_mult=cfg.sample_mult,
    loss_weights=False,
    zoom_sigma=cfg.zoom_sigma,
    flip=cfg.flip,
    max_rotation=cfg.max_rotation,
    deformation_grid_size=cfg.deformation_grid_size,
    deformation_magnitude=cfg.deformation_magnitude,
    albumentations_tfms=tfms,
)

# NOTE(review): both datasets are built from the same `files` list, so the
# validation tiles come from the images that are also used for training.
train_ds = RandomTileDataset(files, label_fn=label_fn, **ds_kwargs)
valid_ds = TileDataset(files, label_fn=label_fn, **ds_kwargs, is_zarr=True)

In [None]:
# Wrap both datasets in fastai DataLoaders, normalise every batch with the
# statistics from `cfg.stats`, and move the loaders to the GPU.
dls = DataLoaders.from_dsets(
    train_ds,
    valid_ds,
    bs=cfg.batch_size,
    after_batch=N.from_stats(*cfg.stats),
).cuda()

In [None]:
class HuBMAP(nn.Module):
    """Thin `nn.Module` wrapper around a segmentation_models_pytorch `Unet`.

    Encoder, pretrained weights, input channels and class count all come from
    `cfg`; no final activation is requested (`activation=None`).
    """

    def __init__(self):
        super().__init__()
        self.cnn_model = Unet(
            encoder_name=cfg.encoder_name,
            encoder_weights=cfg.encoder_weights,
            in_channels=cfg.in_channels,
            classes=cfg.classes,
            activation=None,
        )
        # Disabled decoder experiment, kept for reference:
        #self.cnn_model.decoder.blocks.append(self.cnn_model.decoder.blocks[-1])
        #self.cnn_model.decoder.blocks[-2] = self.cnn_model.decoder.blocks[-3]

    def forward(self, imgs):
        "Run `imgs` through the U-Net and return its output unchanged."
        return self.cnn_model(imgs)

In [None]:
#losses

class BCEMultiClass(nn.Module):
    """Binary cross-entropy with logits, applied per class to one-hot encoded targets.

    Args:
        weight, size_average: accepted for signature compatibility; not used.

    forward(inputs, targets, smooth=1):
        inputs:  logits of shape (batch, classes, ...).
        targets: integer class indices of shape (batch, ...) with the same
                 number of spatial elements as `inputs`.
        smooth:  unused, kept for a signature parallel to the Dice losses.
        Returns the mean BCE-with-logits over all classes and positions.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()
        self.loss_func = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets, smooth=1):
        bs, n_classes = inputs.size(0), inputs.size(1)

        # Flatten spatial dims; the class count is read from the logits themselves
        # so the one-hot depth can never disagree with the prediction tensor.
        logits = inputs.reshape(bs, n_classes, -1)
        labels = targets.reshape(bs, -1).long()
        one_hot = F.one_hot(labels, n_classes).permute(0, 2, 1).type_as(logits)

        # The logits are passed through as-is: re-wrapping them with
        # torch.tensor(...).detach() would cut them out of the autograd graph
        # and the model would receive no gradient from this loss.
        return self.loss_func(logits, one_hot)

class LogCosHDice(nn.Module):
    """Log-cosh Dice loss: log(cosh(1 - dice)) on softmax probabilities.

    Args:
        weight, size_average: accepted for signature compatibility; not used.

    forward(inputs, targets, smooth=1):
        inputs:  logits of shape (batch, classes, ...).
        targets: integer class indices of shape (batch, ...).
        smooth:  additive smoothing term in the Dice numerator and denominator.
        The Dice coefficient is computed over the whole batch and all classes at once.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        bs, n_classes = inputs.size(0), inputs.size(1)

        # Class probabilities along the channel axis.
        probs = F.softmax(inputs, dim=1).reshape(bs, n_classes, -1)

        # One-hot depth follows the logits' channel count (previously hard-coded
        # to 2, which broke for any other number of classes).
        labels = targets.reshape(bs, -1).long()
        one_hot = F.one_hot(labels, n_classes).permute(0, 2, 1).type_as(probs)

        intersection = (probs * one_hot).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + one_hot.sum() + smooth)

        return torch.log(torch.cosh(1 - dice))
    
class DiceLoss(nn.Module):
    """Soft Dice loss (1 - dice) on softmax probabilities.

    Args:
        weight, size_average: accepted for signature compatibility; not used.

    forward(inputs, targets, smooth=1):
        inputs:  logits of shape (batch, classes, ...).
        targets: integer class indices of shape (batch, ...).
        smooth:  additive smoothing term in the Dice numerator and denominator.
        The Dice coefficient is computed over the whole batch and all classes at once.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        bs, n_classes = inputs.size(0), inputs.size(1)

        # Class probabilities along the channel axis.
        probs = F.softmax(inputs, dim=1).reshape(bs, n_classes, -1)

        # One-hot depth follows the logits' channel count (previously hard-coded
        # to 2, which broke for any other number of classes).
        labels = targets.reshape(bs, -1).long()
        one_hot = F.one_hot(labels, n_classes).permute(0, 2, 1).type_as(probs)

        intersection = (probs * one_hot).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + one_hot.sum() + smooth)

        return 1 - dice
    
class CELogCosHDice(nn.Module):
    """Combined loss: flattened cross-entropy plus log-cosh Dice, summed with equal weight.

    `weight` and `size_average` are accepted but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()
        self.ce = CrossEntropyLossFlat(axis=1)
        self.lcd = LogCosHDice()

    def forward(self, inputs, targets):
        "Evaluate both terms on the same logits/targets and return their sum."
        cross_entropy_term = self.ce(inputs, targets)
        log_cosh_dice_term = self.lcd(inputs, targets)
        return cross_entropy_term + log_cosh_dice_term
        
    
class CEDiceLoss(nn.Module):
    """Combined loss: flattened cross-entropy plus soft Dice loss, summed with equal weight.

    `weight` and `size_average` are accepted but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()
        self.ce = CrossEntropyLossFlat(axis=1)
        self.d = DiceLoss()

    def forward(self, inputs, targets):
        "Evaluate both terms on the same logits/targets and return their sum."
        cross_entropy_term = self.ce(inputs, targets)
        dice_term = self.d(inputs, targets)
        return cross_entropy_term + dice_term

In [None]:
class DiceCoeff(Metric):
    """Dice coefficient for multi-class logits, accumulated across batches.

    Predictions are the argmax over dim 1; prediction and target are multiplied
    and added element-wise, so the value is a Dice score for label 1 when labels
    are in {0, 1}.

    Args:
        dist_sync_on_step: forwarded to `Metric.__init__`.
        eps: added to the denominator in `compute` to avoid division by zero.
    """

    def __init__(self, dist_sync_on_step=False, eps = 1e-7):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.eps = eps

        # Running sums; reduced with "sum" across processes.
        self.add_state("inter", default=torch.tensor(0, dtype = torch.float32), dist_reduce_fx="sum")
        self.add_state("union", default=torch.tensor(0, dtype = torch.float32), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        "Accumulate sum(pred*target) and sum(pred+target) for one batch."
        preds, target = preds.argmax(dim=1).view(-1), target.view(-1)
        assert preds.shape == target.shape, f"Expected output and target to have the same number of elements but got {len(preds)} and {len(target)}."

        self.inter += (preds*target).float().sum().item()
        self.union += (preds+target).float().sum().item()

    def compute(self):
        "Return 2*inter / (union + eps) from the accumulated sums."
        # Use the configured eps; it was previously stored but ignored in
        # favour of a hard-coded 1e-7 (same value as the default).
        return (2.0 * self.inter)/(self.union + self.eps)

In [None]:
#losses single class
class BCELoss(nn.Module):
    """Mean binary cross-entropy with logits over all elements (single-class output).

    Args:
        weight, size_average: accepted for signature compatibility; not used.

    forward(inputs, targets, smooth=1):
        inputs:  logits of any shape.
        targets: 0/1 labels with the same number of elements; any numeric dtype.
        smooth:  unused, kept for a signature parallel to the Dice losses.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # reshape (not view) so non-contiguous tensors are accepted.
        logits = inputs.reshape(-1)
        # Cast targets to the logits' dtype so integer masks are accepted.
        labels = targets.reshape(-1).type_as(logits)
        return F.binary_cross_entropy_with_logits(logits, labels, reduction='mean')
    
#metric single class
class DiceCoeffS(Metric):
    """Soft Dice coefficient for single-channel logits, accumulated across batches.

    Sigmoid probabilities (not thresholded) are multiplied and added with the
    target element-wise; `compute` yields None while the accumulated
    denominator is not positive.
    """

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Running sums; reduced with "sum" across processes.
        for state_name in ("inter", "union"):
            self.add_state(state_name, default=torch.tensor(0, dtype = torch.float32), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        "Accumulate sum(sigmoid(pred)*target) and sum(sigmoid(pred)+target) for one batch."
        preds = torch.sigmoid(preds).contiguous().view(-1)
        target = target.contiguous().view(-1)
        assert preds.shape == target.shape, f"Expected output and target to have the same number of elements but got {len(preds)} and {len(target)}."

        batch_inter = (preds*target).float().sum().item()
        batch_union = (preds+target).float().sum().item()
        self.inter += batch_inter
        self.union += batch_union

    def compute(self):
        "Return 2*inter/union, or None when union is not greater than zero."
        if self.union > 0:
            return 2.0 * self.inter/self.union
        return None

In [None]:
class HuBMAP_net(pl.LightningModule):
    """LightningModule that trains the `HuBMAP` U-Net with `CEDiceLoss` and tracks Dice.

    Args:
        dls: object exposing `.train` and `.valid` data loaders; these are what
            `train_dataloader` / `val_dataloader` return.
        ths: array of thresholds; stored on the module but not used by it.
        learning_rate: learning rate handed to AdamW in `configure_optimizers`
            (must be set before training; the default `None` is not a valid rate).
        batch_size, validation_batch_size: stored on the module only; the actual
            batch size is determined by `dls`.
    """

    def __init__(self, dls, ths=np.arange(0.1,0.9,0.05), learning_rate = None, batch_size = 32, validation_batch_size = 16):
        super().__init__()
        #model
        self.net = HuBMAP()

        self.dls = dls

        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.ths = ths
        self.validation_batch_size = validation_batch_size
        # NOTE(review): called without arguments, so presumably every __init__
        # argument (including `dls`) is recorded as a hyperparameter — confirm
        # this is wanted when checkpoints are saved.
        self.save_hyperparameters()

        # Separate metric objects so train and validation sums never mix.
        self.train_metric = DiceCoeff()
        self.valid_metric = DiceCoeff()

        self.loss_function = CEDiceLoss()

    def forward(self, x):
        "Return the segmentation network's output for a batch of images."
        return self.net(x)

    def configure_optimizers(self):
        "AdamW (weight decay 0.01) with a MultiStepLR schedule: x0.1 at epochs 3, 15 and 30."
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=0.01)
        # Schedulers tried earlier and since disabled: ReduceLROnPlateau
        # (monitoring "Validation_loss_epoch"), StepLR (gamma=cfg.decay) and
        # CosineAnnealingWarmRestarts.
        lr_scheduler = {
        "scheduler":torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=  [3,15,30], gamma=0.1, last_epoch=-1, verbose=True),
        "name":"MultiStepLR",
        }
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        "Forward pass, loss and batch Dice for one training batch; both are logged."
        image, targets = batch
        y_pred = self.forward(image)
        loss = self.loss_function(y_pred, targets)
        train_dice_batch = self.train_metric(y_pred, targets)
        self.log('DiceCoeff', train_dice_batch, prog_bar = True)
        self.log('train_loss_batch', loss)
        return {
            'loss': loss,
            'train_dice_batch': train_dice_batch,
        }

    def training_epoch_end(self, outputs):
        "Log the unweighted mean of the per-batch training loss and Dice."
        current_train_loss = torch.stack([x['loss'] for x in outputs]).mean()
        train_dice_epoch = torch.stack([x['train_dice_batch'] for x in outputs]).mean()
        self.log('Training_loss_epoch', current_train_loss)
        self.log('Training_dice_epoch', train_dice_epoch)

    def validation_step(self, batch, batch_idx):
        "Forward pass, loss and batch Dice for one validation batch; both are logged."
        image, targets = batch
        y_pred = self.forward(image)
        loss = self.loss_function(y_pred, targets)
        valid_dice_batch = self.valid_metric(y_pred, targets)
        self.log('DiceCoeff_Valid', valid_dice_batch)
        self.log('valid_loss_batch', loss)
        return {
            'valid_loss': loss,
            'valid_dice_batch': valid_dice_batch,
        }

    def validation_epoch_end(self, outputs):
        "Print and log the unweighted mean of the per-batch validation loss and Dice."
        current_val_loss = torch.stack([x['valid_loss'] for x in outputs]).mean()
        valid_dice_epoch = torch.stack([x['valid_dice_batch'] for x in outputs]).mean()
        print(f"Epoch {self.current_epoch}: Dice:{valid_dice_epoch:4f}")
        # These two keys are the ones checkpointing / early stopping monitor.
        self.log("Validation_loss_epoch", current_val_loss)
        self.log("Validation_dice_epoch", valid_dice_epoch)

    def train_dataloader(self):
        # Use the loaders given to the constructor rather than a notebook global,
        # so the module works with whatever `dls` it was built with.
        return self.dls.train

    def val_dataloader(self):
        return self.dls.valid

In [None]:
def run(test_run=True):
    """Build a `HuBMAP_net` from the notebook-level `dls` and `cfg` and fit it.

    Args:
        test_run: When True (default; matches the previous behaviour of this
            cell) train for 5 epochs with Lightning's default logger and no
            extra callbacks. When False, train for `cfg.epochs` epochs with the
            Neptune logger, best-checkpoint saving, early stopping and
            learning-rate monitoring attached to the Trainer.

    The Neptune credentials are read from the NEPTUNE_API_TOKEN and
    NEPTUNE_PROJECT environment variables (empty string when unset) instead of
    being written into the notebook.
    """
    model = HuBMAP_net(
        dls,
        learning_rate=cfg.learning_rate,
        batch_size=cfg.batch_size,
        validation_batch_size=int(cfg.batch_size/2),
    )

    trainer_kwargs = dict(
        max_epochs=5,  # for test run
        progress_bar_refresh_rate=20,
        gpus=1,
        # Accumulate gradients to reach the virtual batch size; never fewer than 1 step.
        accumulate_grad_batches=max(1, int(cfg.virtual_batch_size // cfg.batch_size)),
        precision=16,
        move_metrics_to_cpu=True,
    )

    if not test_run:
        # The logger and callbacks are only built when they are actually used.
        neptune_logger = NeptuneLogger(
            api_key=os.environ.get("NEPTUNE_API_TOKEN", ""),
            project_name=os.environ.get("NEPTUNE_PROJECT", ""),
            # Scheduler label matches what HuBMAP_net.configure_optimizers sets up.
            experiment_name=f"{cfg.encoder_name}::lr={cfg.learning_rate}::lr_scheduler=MultiStepLR::CEDiceLoss",
            tags=[f"{cfg.encoder_name}"],
        )

        check_path = './lightning_checkpoints/'

        # Keep only the single best weights-only checkpoint by validation loss.
        checkpointer = ModelCheckpoint(
            monitor = 'Validation_loss_epoch',
            dirpath = check_path,
            filename = f"{cfg.encoder_name}" + "-{epoch:02d}-{Validation_loss_epoch:2f}",
            mode = 'min',
            save_weights_only = True,
            save_top_k = 1,
            verbose = True
        )

        # Stop when validation loss has not improved for 5 checks.
        early_stopping = EarlyStopping(
            monitor = 'Validation_loss_epoch',
            patience = 5,
            mode = 'min',
            verbose = True
        )

        learning_rate_monitor = LearningRateMonitor(
            logging_interval = 'epoch'
        )

        trainer_kwargs.update(
            logger=neptune_logger,
            callbacks=[checkpointer, early_stopping, learning_rate_monitor],
            max_epochs=cfg.epochs,
        )

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model)

In [None]:
# Start training: builds the model and Trainer inside `run` and calls `trainer.fit`.
run()