## 필요 모듈 설치

In [None]:
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

In [2]:
import torch
import numpy as np

import os
import cv2
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

import segmentation_models_pytorch as smp
import albumentations as albu

## 경로 설정

In [5]:
# NOTE(review): segmentation_path is not referenced anywhere else in the
# visible part of this notebook.
segmentation_path = '../input/lv2-dataset/SIA_pytorch/segmentation_models'

# Load the saved model object that is evaluated and used for prediction below.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. No map_location is passed — presumably the notebook is meant
# to run on the same kind of device the model was saved from; confirm.
best_model = torch.load('../input/lv2-dataset/best_model.pth')

# Folders holding the validation images and their label masks.
x_valid_dir = '../input/lv2-dataset/LV2_validation_set/images'
y_valid_dir = '../input/lv2-dataset/LV2_validation_set/labels'

## 데이터로더 정의

In [6]:
class Dataset(BaseDataset):
    """Paired image / segmentation-mask dataset read from two folders.

    Every file name found in ``images_dir`` is looked up under the same name
    in ``masks_dir``. Each sample is returned as ``(image, mask)`` where the
    image is converted from BGR to RGB and the mask holds one float channel
    (stacked on the last axis) per entry in ``class_values``.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        classes (list): accepted so existing call sites keep working, but
            currently ignored — the mask values to extract are fixed to
            ``[200, 255]``.
            NOTE(review): presumably 200 -> 'building' and 255 -> 'road';
            confirm against the label files before honoring ``classes``.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image or mask file
            is missing or cannot be decoded.
    """

    CLASSES = ['building', 'road']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index refers to the same sample on every run.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # Pixel values in the mask that are turned into one channel each.
        self.class_values = [200, 255]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair stored at index ``i``."""
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(
                f'Image is missing or unreadable: {image_path}'
            )
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # Without this check `None == v` evaluates to False and a bogus
            # shape-(2,) mask would be produced silently.
            raise FileNotFoundError(
                f'Mask is missing or unreadable: {mask_path}'
            )

        # One boolean channel per class value, stacked on the last axis.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of files found in ``images_dir``."""
        return len(self.ids)

In [7]:
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Any extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied to each sample.

    The pipeline first runs ``preprocessing_fn`` on the image only, then
    runs ``to_tensor`` on both the image and the mask.

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose

    """
    apply_fn = albu.Lambda(image=preprocessing_fn)
    tensorize = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([apply_fn, tensorize])

In [8]:
# Encoder name and pretrained-weight tag used to look up the matching
# preprocessing function.
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'
# Device string handed to the validation epoch and used to move input tensors
# before prediction.
DEVICE = 'cuda'
# Class names passed to Dataset through its `classes` argument.
CLASSES = ['building', 'road']

# Preprocessing function supplied by segmentation_models_pytorch for this
# encoder / weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [11]:
# Validation dataset fed to the model: each sample goes through the pipeline
# built by get_preprocessing (preprocessing_fn on the image, then to_tensor on
# image and mask). No augmentation is passed.
valid_dataset = Dataset(
    x_valid_dir, 
    y_valid_dir, 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

In [12]:
# Loss and metric objects handed to the validation epoch below:
# a Dice loss and an IoU metric configured with threshold=0.5.
loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

In [13]:
# Validation runner wired to the loaded model, the loss and metrics defined
# above, and the configured device.
valid_epoch = smp.utils.train.ValidEpoch(
    model=best_model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

# One sample per batch, in dataset order, loaded in the main process.
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

# Same folders as valid_dataset but without preprocessing, so the images it
# returns can be displayed directly.
valid_dataset_vis = Dataset(
    x_valid_dir, y_valid_dir, 
    classes=CLASSES,
)

## 모델 검증

In [None]:
# Time each full validation pass with CUDA events and report throughput.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
# NOTE(review): the pass is run twice — presumably the first run doubles as a
# warm-up; confirm which of the two printed numbers is the one to report.
for i in range(2):
    start.record()
    valid_logs = valid_epoch.run(valid_loader)
    end.record()
    # elapsed_time() needs both events to have completed.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; convert to seconds.
    elapsed_seconds = start.elapsed_time(end) / 1000
    # Count images rather than batches so the figure stays correct for any
    # batch size.
    num_images = len(valid_loader.dataset)
    # Frames per second = images processed / seconds taken. The previous
    # expression printed seconds per batch under the "FPS" label.
    print(f'FPS : {num_images / elapsed_seconds}')

## 이미지 시각화

In [15]:
def combine_masks(masks):
    """Collapse per-channel binary masks into a single label map.

    ``masks`` is indexed and iterated channel by channel. The result is a
    uint8 array shaped like ``masks[0]`` in which every position where
    channel ``k`` (0-based) equals 1 holds the label ``k + 1``; positions
    set in no channel stay 0. Where several channels are set, the last
    channel wins.
    """
    label_map = np.zeros(masks[0].shape, dtype=np.uint8)

    for label, channel in enumerate(masks, start=1):
        label_map[channel == 1] = label

    return label_map

In [16]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one subplot; its name, with underscores
    replaced by spaces and title-cased, is used as the subplot title.
    """
    columns = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, columns, position)
        # Hide the axis ticks.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

In [17]:
# Show ten randomly chosen validation samples: the raw image, the ground-truth
# mask and the model's predicted mask side by side.
for i in range(10):
    # Indices are drawn independently, so the same sample can appear twice.
    n = np.random.choice(len(valid_dataset))
    
    # Un-preprocessed image for display; the preprocessed one goes to the model.
    image_vis = valid_dataset_vis[n][0].astype('uint8')
    image, gt_mask = valid_dataset[n]
    
    gt_mask = gt_mask.squeeze()
    
    # Add a leading batch dimension and move the input to the target device.
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    # Drop size-1 dimensions, bring the result back to NumPy and round it so
    # combine_masks can compare against 1.
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())
        
    visualize(
        image=image_vis, 
        gt_mask=combine_masks(gt_mask),
        pr_mask=combine_masks(pr_mask),
    )