# Data loading and augmentation

In [None]:
import os

path_image = '/media/riccelli/Disco 2/datasets_covid/segmentation/all_images_250'
path_mask = '/media/riccelli/Disco 2/datasets_covid/segmentation/all_lung_masks'


def _sorted_file_paths(directory):
    """Return the full path of every entry in `directory`, sorted by entry name."""
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]


paths_image = _sorted_file_paths(path_image)
paths_mask = _sorted_file_paths(path_mask)

# Images and masks are paired purely by sorted position (index i of one list goes with
# index i of the other), so a count mismatch would silently misalign every later pair.
assert len(paths_image) == len(paths_mask), (
    f'{len(paths_image)} images but {len(paths_mask)} masks; cannot pair them by position'
)

In [None]:
# import wandb
# wandb.login()

In [None]:
from torch.utils.data import Dataset, DataLoader
import cv2
import torch
import random
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

def _to_normalized_tensor():
    """Fresh normalize + ToTensorV2 steps shared by the end of every pipeline below."""
    return [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]


# Albumentations pipelines keyed by split. Both resize to 224x224 and finish with
# mean/std normalization plus tensor conversion; 'train' first applies random flips,
# shift/scale/rotate jitter, RGB shifts and brightness/contrast changes.
data_transforms = {
    'train': A.Compose(
        [
            A.Resize(224, 224),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
            A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        ]
        + _to_normalized_tensor()
    ),
    'val': A.Compose([A.Resize(224, 224)] + _to_normalized_tensor()),
}

class LungSegmentationDataset(Dataset):
    """Paired chest-image / lung-mask dataset read from disk with OpenCV.

    Images and masks are paired by index: ``paths_image[i]`` must correspond to
    ``paths_mask[i]``.

    Args:
        paths_image: sequence of image file paths.
        paths_mask: sequence of mask file paths, same length and order as ``paths_image``.
        transform: optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with ``"image"`` and
            ``"mask"``. The returned mask is expected to be an (H, W, 1) tensor (as
            produced by ``ToTensorV2``) and is permuted to (1, H, W).
        target_transform: optional callable applied to the mask after ``transform``.

    Returns (from ``__getitem__``):
        ``{'image': ..., 'mask': ...}``. Without ``transform`` both are numpy arrays,
        the image (H, W, 3) RGB uint8 and the mask (H, W, 1) float32 in [0, 1].
    """

    def __init__(self, paths_image, paths_mask, transform=None, target_transform=None):
        self.paths_image = paths_image
        self.paths_mask = paths_mask
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.paths_image)

    @staticmethod
    def _read(path, flags):
        """Read `path` with cv2.imread, raising instead of silently returning None."""
        array = cv2.imread(path, flags)
        if array is None:
            raise FileNotFoundError(f'Could not read image file: {path}')
        return array

    def __getitem__(self, idx):
        # OpenCV decodes colour images in BGR order; convert to RGB because the
        # normalization statistics used in this notebook are listed in RGB order.
        image = cv2.cvtColor(self._read(self.paths_image[idx], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

        # Grayscale mask rescaled from [0, 255] to [0, 1] with a trailing channel axis.
        # (Indexing with None avoids depending on numpy being imported in a later cell.)
        mask = self._read(self.paths_mask[idx], cv2.IMREAD_GRAYSCALE).astype('float32') / 255.0
        mask = mask[:, :, None]

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            # ToTensorV2 leaves the mask as (H, W, 1); the model expects (1, H, W).
            mask = transformed["mask"].permute(2, 0, 1)

        if self.target_transform:
            # Previously applied to an undefined `label`, which raised NameError.
            mask = self.target_transform(mask)
        return {'image': image, 'mask': mask}
    
# Preview-only loader over the full list using the training augmentations.
lung_dataset = LungSegmentationDataset(paths_image, paths_mask, transform=data_transforms['train'])
lung_dataloader = DataLoader(lung_dataset, shuffle=True, batch_size=16)

In [None]:
import torchvision
from matplotlib import pyplot as plt
import numpy as np

plt.rcParams.update({'figure.figsize': [30, 20]})  # large canvas for the image/mask grids below

def imshow_img(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a normalized (C, H, W) image tensor, e.g. a ``make_grid`` output.

    Undoes ``x_norm = (x - mean) / std`` channel-wise, clips to [0, 1] and plots the
    result. The default mean/std match the ``A.Normalize`` step in ``data_transforms``.

    Args:
        inp: (C, H, W) tensor; may live on any device and may require grad.
        title: optional plot title.
        mean, std: per-channel statistics that were used to normalize ``inp``.
    """
    # detach().cpu() so GPU tensors and tensors that require grad can be converted.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.clip(np.asarray(std) * img + np.asarray(mean), 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
    
def imshow_mask(inp, title=None):
    """Show a (C, H, W) mask tensor, e.g. a single mask or a ``make_grid`` of masks.

    Single-channel input is squeezed to a 2-D (H, W) array before plotting, since an
    (H, W, 1) array is not a standard 2-D/RGB/RGBA image layout for ``plt.imshow``.

    Args:
        inp: (C, H, W) tensor; may live on any device and may require grad.
        title: optional plot title.
    """
    # detach().cpu() so GPU tensors and tensors that require grad can be converted.
    mask = inp.detach().cpu().numpy().transpose((1, 2, 0))
    if mask.shape[2] == 1:
        mask = mask[:, :, 0]
    plt.imshow(mask)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Pull one augmented batch and tile images and masks into grids for a visual sanity check.
batch = next(iter(lung_dataloader))
inputs, masks = batch['image'], batch['mask']

out, out2 = (torchvision.utils.make_grid(tiles) for tiles in (inputs, masks))
for grid in (out, out2):
    print(grid.shape)

imshow_img(out)
imshow_mask(out2)

In [None]:
from torch.nn.functional import pairwise_distance

def hausdorff_distance(x, y):
    """Mean symmetric Hausdorff distance, in pixels, between two batches of binary masks.

    ``x`` and ``y`` must share the shape ``(..., H, W)`` (e.g. ``(batch, 1, H, W)``).
    Every leading index is treated as one 2-D mask, and any value > 0 is foreground.
    For each mask pair the distance is taken between the boundary pixels of the two
    foreground regions. A foreground pixel is on the boundary when one of its 8
    neighbours is background or lies outside the image. This keeps the pairwise
    distance matrix small.

    Edge cases per mask pair:
        - both masks empty -> 0;
        - exactly one empty -> the image diagonal ``sqrt(H**2 + W**2)`` as a worst-case penalty.

    Returns:
        0-d float32 tensor on ``x.device``: the mean distance over all mask pairs.

    Raises:
        ValueError: if the shapes differ or the inputs have fewer than two dimensions.
    """
    import torch.nn.functional as F

    if x.shape != y.shape:
        raise ValueError(f'x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}')
    if x.ndim < 2:
        raise ValueError('inputs must have at least two dimensions (H, W)')

    height, width = x.shape[-2:]
    empty_penalty = float((height ** 2 + width ** 2) ** 0.5)

    def _boundary_coords(mask):
        # (num_boundary_pixels, 2) float tensor of (row, col) coordinates.
        foreground = mask > 0
        background = (~foreground).float()[None, None]
        # Pad with background so foreground pixels on the image border count as boundary.
        background = F.pad(background, (1, 1, 1, 1), value=1.0)
        touches_background = F.max_pool2d(background, kernel_size=3, stride=1)[0, 0] > 0
        return torch.nonzero(foreground & touches_background).float()

    distances = []
    for mask_x, mask_y in zip(x.reshape(-1, height, width), y.reshape(-1, height, width)):
        points_x, points_y = _boundary_coords(mask_x), _boundary_coords(mask_y)
        if len(points_x) == 0 and len(points_y) == 0:
            distances.append(0.0)
        elif len(points_x) == 0 or len(points_y) == 0:
            distances.append(empty_penalty)
        else:
            pairwise = torch.cdist(points_x, points_y)
            directed_xy = pairwise.min(dim=1).values.max()  # farthest x point from y
            directed_yx = pairwise.min(dim=0).values.max()  # farthest y point from x
            distances.append(torch.maximum(directed_xy, directed_yx).item())
    return torch.tensor(distances, dtype=torch.float32, device=x.device).mean()

# Train the segmentation model (FPN with a MobileNetV2 encoder)

In [None]:
import segmentation_models_pytorch as smp
import pytorch_lightning as pl

class LungModel(pl.LightningModule):
    """PyTorch Lightning module for binary lung segmentation built on segmentation_models_pytorch.

    Args:
        arch: architecture name passed to ``smp.create_model`` (this notebook uses ``'FPN'``).
        encoder_name: encoder backbone name; also used to look up the encoder's
            preprocessing mean/std.
        in_channels: number of input image channels (the mean/std buffers assume 3).
        out_classes: number of output mask channels (1 for binary segmentation).
        normalize_input: if True, ``forward`` standardizes images with the encoder's
            mean/std. Defaults to False because the albumentations pipelines in this
            notebook already end with ``A.Normalize`` (ImageNet mean/std), and normalizing
            again in ``forward`` standardized the images twice. Set True only when feeding
            un-normalized images (this also reproduces the previous behaviour).
        lr: AdamW learning rate.
        **kwargs: forwarded unchanged to ``smp.create_model``.
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, normalize_input=False, lr=1e-4, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )
        self.normalize_input = normalize_input
        self.lr = lr

        # Preprocessing parameters for the encoder. They stay registered as buffers even
        # when normalize_input is False so checkpoint state_dicts keep the same keys.
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # For image segmentation, Dice loss is a reasonable first choice.
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

        self.save_hyperparameters()

    def forward(self, image):
        """Return raw mask logits for a (batch, channels, H, W) image batch."""
        if self.normalize_input:
            image = (image - self.mean) / self.std
        return self.model(image)

    def shared_step(self, batch, stage):
        """Compute the loss and per-image confusion statistics for one batch.

        Args:
            batch: dict with ``"image"`` (B, C, H, W) and ``"mask"`` (B, 1, H, W) in [0, 1].
            stage: ``"train"``, ``"valid"`` or ``"test"`` (not used inside the step).

        Returns:
            dict with ``loss``, the ``tp``/``fp``/``fn``/``tn`` tensors from
            ``smp.metrics.get_stats`` and the batch Hausdorff distance ``hd``.
        """
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width);
        # grayscale images need an explicit channel dim: [batch_size, 1, height, width].
        assert image.ndim == 4

        # Encoder and decoder are joined by skip connections, and encoders usually have
        # 5 downsampling stages by a factor of 2 (2 ** 5 = 32). If H or W is not divisible
        # by 32, the upsampled decoder features no longer match the encoder feature sizes
        # and concatenating them fails.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width];
        # for binary segmentation num_classes = 1.
        assert mask.ndim == 4

        # Mask values must be in [0, 1], NOT [0, 255], for binary segmentation.
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # The predicted mask contains logits, and loss_fn was built with from_logits=True.
        loss = self.loss_fn(logits_mask, mask)

        # Threshold sigmoid probabilities at 0.5 to get a hard prediction for the metrics.
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # IoU is computed two ways (dataset-wise and image-wise) at epoch end; here we only
        # collect per-image, per-class true/false positive/negative pixel counts.
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")
        hd = hausdorff_distance(pred_mask.long(), mask.long())

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
            "hd": hd,
        }

    def shared_epoch_end(self, outputs, stage):
        """Aggregate step outputs into epoch metrics and log them with a ``stage`` prefix."""
        # Aggregate step metrics.
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])
        hd = torch.mean(torch.stack([x["hd"] for x in outputs]))

        # Per-image IoU: IoU is computed for each image first, then averaged.
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # Dataset IoU: intersection and union are aggregated over the whole dataset before
        # computing IoU. The two can differ a lot on datasets with "empty" images (no
        # target class), which weigh heavily in per_image_iou and little in dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
        f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
            f"{stage}_accuracy": accuracy,
            f"{stage}_f1_score": f1_score,
            f"{stage}_hausdorff": hd,
        }

        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [None]:
from sklearn.model_selection import train_test_split, KFold
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import time
from pprint import pprint
from random import randint

# Configuration for the k-fold cross-validation experiment.
WEIGHTS_DIR = '/media/riccelli/Disco 1/datasets_covid/weights/'
N_FOLDS = 10
seg_encoder = 'mobilenet_v2'
seg_model = 'FPN'

# Random tag that keeps checkpoint file names from different runs apart; drawn before
# seeding so it is not fixed by the seed below.
random_num = randint(0, 10000)
pl.seed_everything(42)

kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

# Arrays (not lists) so the index arrays yielded by KFold can be used for fancy indexing.
paths_image = np.array(paths_image)
paths_mask = np.array(paths_mask)

fold_metrics = []   # one trainer.test(...) result per fold
training_time = []  # seconds spent in trainer.fit per fold
testing_time = []   # seconds spent in trainer.test per fold


def _make_dataloader(images, masks, split, shuffle):
    """Build a DataLoader over (images, masks) using the `split` augmentation pipeline."""
    dataset = LungSegmentationDataset(images, masks, data_transforms[split])
    return DataLoader(dataset, batch_size=8, shuffle=shuffle, num_workers=20)


# `k` is the fold index used in checkpoint and run names (it was previously undefined).
for k, (train_index, test_index) in enumerate(kf.split(paths_image)):
    X_train, y_train = paths_image[train_index], paths_mask[train_index]
    X_test, y_test = paths_image[test_index], paths_mask[test_index]

    # Hold out 10% of the training fold for validation (early stopping / checkpointing).
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

    train_dataloader = _make_dataloader(X_train, y_train, 'train', shuffle=True)
    val_dataloader = _make_dataloader(X_val, y_val, 'val', shuffle=False)
    test_dataloader = _make_dataloader(X_test, y_test, 'val', shuffle=False)

    checkpoint_callback = ModelCheckpoint(
        dirpath=WEIGHTS_DIR,
        filename=f'lung-paper-{seg_encoder}-{seg_model}-{k}-{random_num}-250',
        monitor='valid_per_image_iou',
        mode="max",
    )
    early_stop_callback = EarlyStopping(
        monitor="valid_per_image_iou",
        patience=5,
        verbose=False,
        mode="max",
    )

    # To log to Weights & Biases, pass
    # logger=WandbLogger(project="LungSegmentation", name=run_name) to pl.Trainer
    # and call wandb.finish() at the end of the fold.
    run_name = f'{seg_encoder}-{seg_model}-fold-{k}-250'

    model = LungModel(seg_model, seg_encoder, in_channels=3, out_classes=1)
    trainer = pl.Trainer(
        gpus=1,
        max_epochs=30,
        callbacks=[checkpoint_callback, early_stop_callback],
        accumulate_grad_batches=8,
    )

    start_time = time.time()
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    training_time.append(time.time() - start_time)

    # Reload the best (highest valid_per_image_iou) weights. Ask the callback for the path
    # rather than rebuilding it: ModelCheckpoint can append a version suffix (e.g. "-v1")
    # when a file with the same name already exists.
    ckpt_path = checkpoint_callback.best_model_path
    if not ckpt_path:
        raise RuntimeError(f'Fold {k}: ModelCheckpoint did not save a checkpoint')
    print(ckpt_path)
    model = LungModel.load_from_checkpoint(ckpt_path)

    # Run the validation split with the best weights.
    valid_metrics = trainer.validate(model, dataloaders=val_dataloader, verbose=False)
    pprint(valid_metrics)

    # Run the held-out test fold and time it.
    start_time = time.time()
    test_metrics = trainer.test(model, dataloaders=test_dataloader, verbose=False)
    testing_time.append(time.time() - start_time)
    pprint(test_metrics)

    fold_metrics.append(test_metrics)
    # Delete this fold's checkpoint once it has been evaluated.
    os.remove(ckpt_path)

In [None]:
# Each fold_metrics entry is the list returned by trainer.test (one dict per test dataloader).
def _fold_values(metric_name):
    """Collect `metric_name` from the first test-dataloader dict of every fold."""
    return [fold[0][metric_name] for fold in fold_metrics]


def _report(label, values, scale=100):
    """Print mean ± std of `values` multiplied by `scale`, each rounded to 2 decimals."""
    print(f'{label}: {round(np.mean(values) * scale, 2)} \u00B1 {round(np.std(values) * scale, 2)}')


# Iterate over the folds actually collected instead of assuming exactly 10.
dataset_iou = _fold_values('test_dataset_iou')
image_iou = _fold_values('test_per_image_iou')
accuracy = _fold_values('test_accuracy')
f1_score = _fold_values('test_f1_score')
hausdorff = _fold_values('test_hausdorff')

print(f'======= {seg_encoder} + {seg_model} Results =======\n')
_report('Accuracy', accuracy)
_report('IoU', dataset_iou)
_report('per image IoU', image_iou)
_report('F1-Score', f1_score)
_report('Hausdorff', hausdorff, scale=1)  # pixel distance, not a fraction

In [None]:
import json

# Write per-fold metrics and timings as JSON for generating plots, then read them back.
metrics = {'Accuracy': accuracy, 'F1-Score': f1_score, 'Dataset IoU': dataset_iou, 'Per Image IoU': image_iou,
           'Hausdorff': hausdorff, 'Training Time': training_time, 'Testing Time': testing_time}

metrics_path = f'{seg_encoder}-{seg_model}-metrics.txt'
# Context managers close (and flush) the file deterministically before it is re-read.
with open(metrics_path, 'w') as metrics_file:
    json.dump(metrics, metrics_file)
with open(metrics_path) as metrics_file:
    d2 = json.load(metrics_file)

In [None]:
# Per-fold test-set evaluation times in seconds, as reloaded from the metrics JSON file.
testing_seconds = d2['Testing Time']
testing_seconds