In [3]:
# !pip install -U segmentation-models-pytorch albumentations --user
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt
import linzhutils as lu
from segdataset import Dataset
from torch.utils.data import DataLoader
import albumentations as albu
import segmentation_models_pytorch as smp
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(DEVICE)

[H[2J

In [None]:
# Dataset root plus the two sub-folders that are later handed to the
# evaluation datasets as image directory and annotation directory.
DATA_DIR = '../data/hel2019/'
val_img_dir, val_anno_dir = (
    os.path.join(DATA_DIR, folder)
    for folder in ('val_images', 'val_bin_masks')
)


def get_training_augmentation():
    """Build the albumentations pipeline used to augment training samples.

    The steps run in this order: random horizontal flip, random shift/scale
    (rotation is disabled through ``rotate_limit=0``), padding up to at least
    512x512, a random 512x512 crop, optional Gaussian noise, an optional
    perspective warp, one randomly picked photometric change and an optional
    sharpening step.

    Returns:
        albu.Compose: the composed transform.
    """
    constant_border = 0  # numerically equal to cv2.BORDER_CONSTANT

    # Geometry: flip, shift/scale without rotation, then force a 512x512 crop.
    geometric_steps = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(
            scale_limit=0.5,
            rotate_limit=0,
            shift_limit=0.1,
            p=1,
            border_mode=constant_border,
        ),
        albu.PadIfNeeded(
            min_height=512,
            min_width=512,
            always_apply=True,
            border_mode=constant_border,
        ),
        albu.RandomCrop(height=512, width=512, always_apply=True),
    ]

    noise_and_warp_steps = [
        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),
    ]

    # With probability 0.9, apply exactly one of these photometric changes.
    photometric_choice = albu.OneOf(
        [
            albu.CLAHE(p=1),
            albu.RandomBrightnessContrast(p=1),
            albu.RandomGamma(p=1),
            albu.HueSaturationValue(p=1),
        ],
        p=0.9,
    )

    # Sharpen is the only active option; the blur variants stay disabled.
    sharpen_choice = albu.OneOf(
        [
            albu.Sharpen(p=1),
            # albu.Blur(blur_limit=2, p=1),
            # albu.MotionBlur(blur_limit=2, p=1),
        ],
        p=0.5,
    )

    pipeline = geometric_steps + noise_and_warp_steps
    pipeline += [photometric_choice, sharpen_choice]
    return albu.Compose(pipeline)


def get_validation_augmentation():
    """Build the evaluation-time transform: pad up to at least 512x512.

    Only ``PadIfNeeded`` with a 512 minimum height and width is applied.
    512 is a multiple of 32, but no divisor-based padding is requested, so
    this does not by itself make larger inputs divisible by 32.

    Returns:
        albu.Compose: the composed transform.
    """
    pad_to_min_size = albu.PadIfNeeded(min_height=512, min_width=512)
    return albu.Compose([pad_to_min_size])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Args:
        x: array with exactly three axes (e.g. height, width, channels).
        **kwargs: accepted and ignored; presumably present so the function
            can be used as an ``albu.Lambda`` callback -- confirm.

    Returns:
        The axes reordered as (2, 0, 1), as a ``float32`` array.
    """
    channels_first = np.transpose(x, (2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the transform that normalises images and reorders the axes.

    Args:
        preprocessing_fn (callable): data normalisation function applied to
            the image only (can be specific to each pretrained network).

    Returns:
        albu.Compose: first applies ``preprocessing_fn`` to the image, then
        applies ``to_tensor`` to both the image and the mask.
    """
    normalise_image = albu.Lambda(image=preprocessing_fn)
    reorder_axes = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalise_image, reorder_axes])

In [None]:
# Importing the submodule explicitly also makes `smp.utils` resolvable for the
# later cells that reference it.
from segmentation_models_pytorch import utils as smpu

# NOTE(review): `model`, `ENCODER` and `ENCODER_WEIGHTS` are not defined in any
# cell visible here; on a fresh kernel this cell raises NameError unless they
# are created earlier. Confirm where they come from.

# Dice loss in binary mode. The smp epoch runners report the loss under its
# `__name__`, so label it as what it actually is (it was mislabelled
# 'focal_loss').
# NOTE(review): DiceLoss defaults to from_logits=True -- confirm the checkpoint
# outputs raw logits rather than already-activated probabilities.
loss = smp.losses.DiceLoss('binary')
loss.__name__ = 'dice_loss'
metrics = [
    smpu.metrics.IoU(threshold=0.5),
]

# Kept from the training setup; no visible cell uses the optimizer.
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

# Normalisation function matching the encoder and its pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Load the best saved checkpoint.
# `map_location=DEVICE` remaps the stored tensors, so a checkpoint saved on a
# GPU still loads when DEVICE has fallen back to 'cpu'.
# NOTE: torch.load unpickles the file, which can execute arbitrary code; only
# load checkpoints from a trusted source.
best_model = torch.load('./best_model.pth', map_location=DEVICE)

In [None]:
# Evaluation dataset: validation images and annotations, run through the
# validation padding and the encoder-specific preprocessing.
test_dataset = Dataset(val_img_dir,
                       val_anno_dir,
                       augmentation=get_validation_augmentation(),
                       preprocessing=get_preprocessing(preprocessing_fn),
                       classes=CLASSES)

# One sample per batch, in dataset order (the DataLoader defaults, spelled out).
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Same validation data, but with no augmentation or preprocessing passed in;
# used only to display the images.
test_dataset_vis = Dataset(val_img_dir, val_anno_dir, classes=CLASSES)


# helper for showing several images side by side
def visualize(**images):
    """Show the given images next to each other in a single row.

    Args:
        **images: one panel per keyword. The keyword name, with underscores
            replaced by spaces and title-cased, is used as the panel title;
            the value is passed to ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # Hide the axis ticks.
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

In [None]:
# Run one validation pass of the loaded checkpoint over the test loader and
# keep the resulting loss/metric logs.
test_epoch = smpu.train.ValidEpoch(model=best_model,
                                   loss=loss,
                                   metrics=metrics,
                                   device=DEVICE)
logs = test_epoch.run(test_dataloader)

In [None]:
# Visualise predictions for up to 20 distinct, randomly chosen test samples.
# The indices are drawn once, without replacement, so no sample is shown twice
# and a dataset with fewer than 20 (or zero) samples does not raise.
num_to_show = min(20, len(test_dataset))
for sample_idx in np.random.choice(len(test_dataset),
                                   size=num_to_show,
                                   replace=False):
    n = int(sample_idx)

    # Image from the dataset built without transforms, for display.
    image_vis = test_dataset_vis[n][0].astype('uint8')
    # Sample from the dataset built with padding + preprocessing, for the model.
    image, gt_mask = test_dataset[n]

    gt_mask = gt_mask.squeeze()

    # Add a leading batch axis and move the input to the evaluation device.
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    # Round the prediction to whole numbers for display
    # (presumably probabilities in [0, 1] -- confirm against the model head).
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())

    visualize(image=image_vis,
              ground_truth_mask=gt_mask,
              predicted_mask=pr_mask)
