In [None]:
# `IPython.display` is the public module; importing `display` from
# `IPython.core.display` is deprecated in recent IPython releases.
from IPython.display import display, HTML

# Widen the notebook's `.container` element to the full browser width
# (targets the classic Notebook UI; may have no effect in JupyterLab).
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from typing import List

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

import albumentations as albu

import torch
import numpy as np
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
# Respect a GPU selection made outside the notebook; otherwise default to GPU 0.
# (Only effective if CUDA has not been initialized in this kernel yet.)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Directory with the input *.png images; set RZD_DATA_DIR to run on another machine.
DATA_DIR = os.environ.get("RZD_DATA_DIR", "/media/akhan/NVME 1TB/RZD/data/test")

In [None]:
class Dataset(BaseDataset):
    """Inference dataset over a list of image paths.

    Each image is converted to RGB, resized to a fixed width (aspect ratio kept),
    zero-padded at the top by ``pad`` so its height is divisible by 32, and then
    passed through ``preprocessing`` (if any).

    Items are dicts ``{"image": ..., "remainder": int}``; ``remainder`` is the
    number of rows ``pad`` added and can be passed to ``unpad`` to undo it.
    """

    def __init__(
            self,
            images: List[str],
            preprocessing=None,
            width=None,
    ):
        """
        Args:
            images: paths of the images to load.
            preprocessing: optional albumentations-style callable, called as
                ``preprocessing(image=array)`` and expected to return a dict
                with an ``"image"`` entry.
            width: target width in pixels. Defaults to the notebook-level
                ``WIDTH`` (looked up when the dataset is constructed). Must be a
                multiple of 32, because ``pad`` asserts that.
        """
        self.basewidth = WIDTH if width is None else width
        self.images_list = images
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        # The context manager closes the file handle. convert("RGB") turns
        # grayscale / palette / RGBA PNGs into 3-channel images, which both
        # pad() (expects H, W, C) and the 3-channel model require.
        with Image.open(self.images_list[i]) as src:
            image = src.convert("RGB")

        # Scale the height by the same factor as the width to keep the aspect ratio.
        wpercent = self.basewidth / float(image.size[0])
        hsize = max(1, int(float(image.size[1]) * wpercent))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        image = image.resize((self.basewidth, hsize), resample=Image.LANCZOS, reducing_gap=3.)
        image = np.array(image)
        image, remainder = pad(image)

        # apply preprocessing (e.g. normalization + HWC -> CHW conversion)
        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        return {"image": image, "remainder": remainder}

    def __len__(self):
        return len(self.images_list)

In [None]:
# glob returns files in arbitrary (filesystem) order; sort for a reproducible run.
ALL_IMG = sorted(glob.glob(os.path.join(DATA_DIR, "*.png")))
# Fail loudly instead of silently running inference over an empty dataset.
if not ALL_IMG:
    raise FileNotFoundError(f"No .png images found in {DATA_DIR!r}")
x_val = np.array(ALL_IMG)

In [None]:
WIDTH: int = 1024  # resize width in px used by Dataset; pad() requires a multiple of 32

In [None]:
def pad(img: np.ndarray, divisor: int = 32) -> tuple:
    """
    Zero-pad an image at the top so that its height is divisible by ``divisor``.

    The width must already be divisible by ``divisor``. Only axis 0 (height) is
    padded, so both (H, W) and (H, W, C) arrays are supported.

    Parameters
    ----------
    img : np.ndarray
        image (H, W) or (H, W, C)
    divisor : int, optional
        required multiple for the height (default 32, the total downsampling
        factor of a 5-stage encoder)

    Returns
    -------
    tuple
        (padded image (H+p, W[, C]), p) where p is the number of rows prepended;
        ``img`` itself is returned unchanged when p == 0.

    Raises
    ------
    AssertionError
        if the width is not divisible by ``divisor``.
    """
    h, w = img.shape[:2]
    assert w % divisor == 0, f"Image width must be divisible by {divisor}"

    # Rows needed to reach the next multiple of `divisor` (0 if already aligned).
    remainder = -h % divisor
    if remainder == 0:
        return (img, 0)

    # Pad before the first row only; every other axis is left untouched.
    pad_width = [(remainder, 0)] + [(0, 0)] * (img.ndim - 1)
    return (np.pad(img, pad_width), remainder)

def unpad(padded: np.ndarray, remainder: int) -> np.ndarray:
    """
    Undo ``pad``: drop the first ``remainder`` rows (axis 0) of ``padded``.

    Parameters
    ----------
    padded : np.ndarray
        array whose axis 0 is the (padded) height
    remainder : int
        number of rows that ``pad`` prepended (0 means no padding)

    Returns
    -------
    np.ndarray
        view of ``padded`` without the padding rows

    Raises
    ------
    ValueError
        if ``remainder`` is negative (a negative slice start would silently
        keep only the last rows instead of removing padding).
    """
    if remainder < 0:
        raise ValueError(f"remainder must be non-negative, got {remainder}")
    return padded[remainder:, ...]

In [None]:
def to_tensor(x, **kwargs):
    """Reorder an (H, W, C) array to (C, H, W) and cast it to float32.

    Extra keyword arguments are accepted (and ignored) so the function can be
    used as an albumentations ``Lambda`` target.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)

def get_preprocessing(preprocessing_fn):
    """Build the inference-time preprocessing pipeline.

    Args:
        preprocessing_fn (callable): encoder-specific normalization function
            (e.g. from ``smp.encoders.get_preprocessing_fn``); applied to the
            image only.
    Return:
        albumentations.Compose: first ``preprocessing_fn``, then ``to_tensor``
        (channel-last -> channel-first, float32).
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    to_chw = albu.Lambda(image=to_tensor)
    return albu.Compose([normalize, to_chw])

In [None]:
class rzdModel(pl.LightningModule):
    """LightningModule wrapping a segmentation_models_pytorch (smp) model.

    Trained with a multilabel Dice loss; the encoder is frozen in ``__init__``.
    Each ``*_step`` returns the loss plus tp/fp/fn/tn pixel counts, and the
    matching ``*_epoch_end`` hook aggregates them into IoU metrics.
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        """
        Args:
            arch: architecture name passed to ``smp.create_model`` (e.g. "UnetPlusPlus").
            encoder_name: encoder backbone name; also used to look up mean/std.
            in_channels: number of input image channels.
            out_classes: number of output mask channels.
            **kwargs: forwarded unchanged to ``smp.create_model``.
        """
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # Encoder normalization parameters, registered as (1, 3, 1, 1) buffers so they
        # follow the module across devices and are stored in the state_dict.
        # NOTE(review): forward() does not use them (its normalization line is commented out).
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # Dice loss in multilabel mode: each output channel is scored as its own binary mask.
        self.loss_fn = smp.losses.DiceLoss(smp.losses.MULTILABEL_MODE, smooth=0.05)
        # Alternative loss kept from experiments (disabled):
#         self.loss_fn = smp.losses.TverskyLoss(smp.losses.BINARY_MODE, alpha=0.6, beta=0.4, gamma=1.4)
    
        # Only the parts of self.model outside the encoder are trained.
        self.freeze_encoder()
        
    def forward(self, image):
        """Return the raw (pre-sigmoid) logits of the wrapped smp model.

        No normalization is applied here, so ``image`` must already be
        normalized; in this notebook the Dataset's preprocessing does that.
        """
        # In-model normalization (disabled; inputs are normalized upstream):
#         image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def freeze_encoder(self):
        """Set ``requires_grad = False`` on every parameter of the encoder's child modules."""
        for child in self.model.encoder.children():
            for param in child.parameters():
                param.requires_grad = False
        return
    
    def shared_step(self, batch, stage):
        """Compute the loss and per-image, per-class tp/fp/fn/tn counts for one batch.

        Args:
            batch: dict with "image" (N, C, H, W) and "mask" (N, num_classes, H, W) tensors.
            stage: stage name ("train" / "valid" / "test"); unused in this method.

        Returns:
            dict with "loss" and the "tp", "fp", "fn", "tn" counts from
            ``smp.metrics.get_stats``.
        """
        
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32:
        # the encoder typically downsamples by 2 in 5 stages (2 ^ 5 = 32) and the
        # decoder concatenates upsampled features with encoder features via skip
        # connections. For sizes not divisible by 32 (e.g. 65x65) the halving
        # rounds, so the upsampled decoder features no longer match the encoder
        # feature sizes and the concatenation fails.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]
        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Range check for mask values (expected in [0, 1], not [0, 255]) -- currently disabled:
#         assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        
        # The model outputs logits. DiceLoss is built without `from_logits`, so this
        # relies on smp's default (from_logits=True) -- confirm for the installed version.
        loss = self.loss_fn(logits_mask, mask)

        # Metrics at a fixed threshold: convert logits to probabilities with a
        # sigmoid, then binarize at 0.5.
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute IoU metric by two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class
        # these values will be aggregated in the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="multilabel")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        """Aggregate step outputs and log per-image and dataset-level IoU.

        Args:
            outputs: list of the dicts returned by ``shared_step`` for the epoch's batches.
            stage: prefix for the logged metric names ("train" / "valid" / "test").
        """
        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image 
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        
        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. For datasets with "empty" images (images
        # without the target class) the two scores can differ a lot:
        # empty images influence per_image_iou much more than dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")            

    # NOTE(review): the *_epoch_end(self, outputs) hooks are PyTorch Lightning 1.x API;
    # they were removed in Lightning 2.0 -- confirm the installed pytorch_lightning version.
    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")  

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        # Adam over all parameters (lr=1e-4); the frozen encoder parameters are
        # included but have requires_grad=False, so they never receive gradients.
        return torch.optim.Adam(self.parameters(), lr=0.0001)

In [None]:
# Model configuration; presumably must match the checkpoint loaded further down
# (load_state_dict is strict).
ARCH: str = "UnetPlusPlus"          # smp architecture name for smp.create_model
ENCODER: str = "resnext50_32x4d"    # encoder backbone name
ENCODER_WEIGHTS: str = "imagenet"   # pretrained-weights key used to pick normalization stats
NUM_CLASS: int = 3                  # number of output mask channels

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Normalization function matching the encoder's pretrained weights; used by the Dataset.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
# encoder_weights=None (forwarded to smp.create_model): every weight is overwritten by
# the checkpoint loaded next (strict load_state_dict), so downloading the ImageNet
# encoder weights first would be wasted work and would fail offline.
model = rzdModel(ARCH, ENCODER, in_channels=3, out_classes=NUM_CLASS, encoder_weights=None).to(device)

In [None]:
CKPT_PATH = "../dzr/ckpts/multilabel_1024/resnext50_32x4d/epoch=59-step=110760.ckpt"
# Despite the name, this is the whole Lightning checkpoint dict; the model weights
# live under its "state_dict" key. map_location lets a checkpoint saved on GPU load
# on a CPU-only machine. NOTE: torch.load unpickles the file -- only load trusted
# checkpoints.
state_dict = torch.load(CKPT_PATH, map_location=device)

In [None]:
# Lightning checkpoints keep the module weights under the "state_dict" key;
# strict=True (the default) requires every key to match the model.
model.load_state_dict(state_dict["state_dict"], strict=True)

In [None]:
class UnNormalize(object):
    """Invert per-channel normalization: ``x * std + mean``.

    Accepts a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``; in both
    cases the channel axis is the third from last. Returns a new tensor and
    leaves the input unchanged.
    """

    def __init__(self, mean, std):
        # Per-channel statistics (sequences of length C).
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image (C, H, W) or batch (N, C, H, W).
        Returns:
            Tensor: de-normalized copy of ``tensor`` with the same shape.
        """
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device)
        # (C, 1, 1) broadcasts over H, W and any leading batch axis. A zip over
        # `tensor` would iterate the batch axis for 4-D input and apply channel-0
        # stats to the whole image.
        return tensor * std.view(-1, 1, 1) + mean.view(-1, 1, 1)
    

# Look up the stats for the same encoder AND pretrained weights as preprocessing_fn,
# so de-normalization matches the normalization applied in the Dataset even if
# ENCODER_WEIGHTS is changed.
encoder_params = smp.encoders.get_preprocessing_params(ENCODER, pretrained=ENCODER_WEIGHTS)
std = encoder_params['std']
mean = encoder_params['mean']

postprocess_fn = UnNormalize(mean, std)

In [None]:
# Normalize with the encoder's preprocessing_fn, then convert to CHW float32.
val_transforms = get_preprocessing(preprocessing_fn)
val_dataset = Dataset(images=x_val, preprocessing=val_transforms)

In [None]:
# One image per batch, no shuffling: samples are visualized in the order of x_val.
val_loader = DataLoader(dataset=val_dataset, batch_size=1, num_workers=4, shuffle=False)

In [None]:
model.to(device)
model.eval()  # inference mode for dropout / batch-norm; set once, not per batch

with torch.no_grad():
    for batch in val_loader:
        logits = model(batch["image"].to(device))
        # Binarize each class channel at 0.5 and bring the masks back to the CPU once.
        pr_masks = (logits.sigmoid() > 0.5).float().cpu()

        for i, pr_mask in enumerate(pr_masks):
            # Per-sample padding (works for any batch_size, not only 1).
            remainder = int(batch["remainder"][i])

            # (C, H, W) -> (H, W, C), then strip the rows added by pad().
            pr_mask = unpad(pr_mask.permute(1, 2, 0).numpy(), remainder)

            # Denormalize a copy of this sample only, so the batch tensor is untouched;
            # clip before the uint8 cast so out-of-range values do not wrap around.
            image = postprocess_fn(batch["image"][i].clone()) * 255
            image = np.clip(image.permute(1, 2, 0).numpy(), 0, 255).astype(np.uint8)
            # Remove the same padding rows so the image aligns with the mask.
            image = unpad(image, remainder)

            fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(16, 16))
            ax_img.imshow(image)
            ax_img.set_title("Image")
            ax_img.axis("off")

            # Show only class channel 2 of the NUM_CLASS predicted channels.
            ax_mask.imshow(pr_mask[..., 2])
            ax_mask.set_title("Prediction")
            ax_mask.axis("off")
            plt.show()
            plt.close(fig)  # avoid accumulating open figures over many images


In [None]:
# NOTE(review): housekeeping scratch cell, not part of the inference workflow.
# It opens .gitignore in the gedit GUI editor via a shell escape; the cell likely
# blocks until the editor is closed. Consider removing it before sharing.
!gedit .gitignore

In [None]:
# NOTE(review): repository housekeeping (shell escape), not part of the inference
# workflow; consider removing it before sharing.
!git status