### DataLoader 상의 이미지 확인하기
#### Augmentation 결과 확인

In [86]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image, ImageOps
from shapely.geometry import Polygon
import albumentations as A
from IPython.display import Image
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt

import os
import sys
import cv2
import json
import torch
import numpy as np
import os.path as osp
import matplotlib.patches as patches
# 상위 디렉토리를 import하기 위해, code 폴더를 시스템경로에 추가
# sys.path.append('/opt/ml/code/')

from dataset import *
from east_dataset import *
from dataset import SceneTextDataset
from east_dataset import EASTDataset

In [206]:
class TestAug(SceneTextDataset):
    """SceneTextDataset variant used to preview augmentation results.

    Differs from the parent only in ``__getitem__``: it runs the geometric
    augmentations, an experimental Albumentations weather block, and it
    additionally returns the image file name so a sample can be traced back
    to its source file.

    Args:
        root_dir: Dataset root forwarded to ``SceneTextDataset``. Defaults to
            the ICDAR17 path the notebook was written against.
        **kwargs: Forwarded unchanged to ``SceneTextDataset.__init__``.
    """

    def __init__(self, root_dir='/opt/ml/input/data/ICDAR17_All/', **kwargs):
        super().__init__(root_dir, **kwargs)

    def __getitem__(self, idx):
        """Load, augment and return one sample.

        Args:
            idx: Index into ``self.image_fnames``.

        Returns:
            Tuple ``(image, word_bboxes, roi_mask, image_fname)`` where
            ``image`` is the augmented image as a numpy array,
            ``word_bboxes`` is ``vertices`` reshaped to ``(-1, 4, 2)``,
            ``roi_mask`` is the output of ``generate_roi_mask`` and
            ``image_fname`` is the file name of the sample.
        """
        # Local, aliased import: the notebook also executes
        # `from IPython.display import Image`, which shadows PIL's Image and
        # has no `open`. Importing here keeps this method independent of the
        # import order in the surrounding cells.
        from PIL import Image as PILImage

        image_fname = self.image_fnames[idx]
        image_fpath = osp.join(self.image_dir, image_fname)

        vertices, labels = [], []
        for word_info in self.anno['images'][image_fname]['words'].values():
            # Only quadrilateral annotations are kept; other polygons are skipped.
            if len(word_info['points']) == 4:
                vertices.append(np.round(word_info['points']).flatten())
                # Label is 1 for legible words, 0 for illegible ones.
                labels.append(int(not word_info['illegibility']))

        vertices, labels = np.array(vertices, dtype=np.float32), np.array(labels, dtype=np.int64)
        vertices, labels = filter_vertices(vertices, labels, ignore_under=10, drop_under=1)

        image = PILImage.open(image_fpath)
        # Apply the EXIF orientation so the pixels match the annotated coordinates.
        image = ImageOps.exif_transpose(image)

        ### Editable region: start ###
        image, vertices = resize_img(image, vertices, self.image_size)
        image, vertices = adjust_height(image, vertices)
        image, vertices = rotate_img(image, vertices)
        image, vertices = crop_img(image, vertices, labels, self.crop_size)

        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = np.array(image)

        funcs = []
        # Weather augmentations under test. Per the Albumentations docs, OneOf
        # normalizes the children's `p` values into selection weights, so with
        # these values RandomRain is picked almost every time OneOf fires.
        funcs.append(A.OneOf([
            A.RandomFog(fog_coef_lower=0.2, fog_coef_upper=0.3,
                        alpha_coef=0.08, always_apply=False, p=0.001),
            A.RandomRain(p=0.5),
            A.RandomShadow(p=0.001),
            A.RandomSunFlare(p=0.001)
        ]))

        if self.color_jitter:
            funcs.append(A.ColorJitter(0.5, 0.5, 0.5, 0.25))
        if self.normalize:
            funcs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
        transform = A.Compose(funcs)

        ### Editable region: end ###

        image = transform(image=image)['image']
        word_bboxes = np.reshape(vertices, (-1, 4, 2))
        roi_mask = generate_roi_mask(image, vertices, labels)

        return image, word_bboxes, roi_mask, image_fname
        

In [207]:
class EAST_4Viz(EASTDataset):
    """EASTDataset variant that also returns what is needed for visualization.

    Construction is inherited unchanged from ``EASTDataset``. Each item
    additionally carries the source file name and the word boxes so they can
    be drawn on top of the image.
    """

    def __getitem__(self, idx):
        """Return one sample plus its file name and word boxes.

        Args:
            idx: Index forwarded to the wrapped ``self.dataset``.

        Returns:
            Tuple ``(image, score_map, geo_map, roi_mask, file_name,
            word_bboxes)``. When ``self.to_tensor`` is truthy the first four
            entries are ``torch.Tensor`` objects permuted from HWC to CHW.
        """
        image, word_bboxes, roi_mask, file_name = self.dataset[idx]
        score_map, geo_map = generate_score_geo_maps(image, word_bboxes, map_scale=self.map_scale)

        # cv2.resize expects dsize as (width, height), while image.shape is
        # (height, width, ...). Passing them swapped only works for square images.
        mask_h = int(image.shape[0] * self.map_scale)
        mask_w = int(image.shape[1] * self.map_scale)
        roi_mask = cv2.resize(roi_mask, dsize=(mask_w, mask_h))
        if roi_mask.ndim == 2:
            # cv2.resize drops a trailing singleton channel; restore it.
            roi_mask = np.expand_dims(roi_mask, axis=2)

        if self.to_tensor:
            image = torch.Tensor(image).permute(2, 0, 1)
            score_map = torch.Tensor(score_map).permute(2, 0, 1)
            geo_map = torch.Tensor(geo_map).permute(2, 0, 1)
            roi_mask = torch.Tensor(roi_mask).permute(2, 0, 1)

        return image, score_map, geo_map, roi_mask, file_name, word_bboxes


In [208]:
# Batch size used for the preview. Samples carry a varying number of word
# boxes, so only one image can be batched at a time.
sample_num = 1

# Augmenting dataset -> EAST wrapper that also yields file name and boxes
# -> shuffled loader, so every run shows a different image.
dataset = TestAug()
visual_dataset = EAST_4Viz(dataset)
visual_dataloader = DataLoader(
    dataset=visual_dataset,
    batch_size=sample_num,
    shuffle=True,
)

In [None]:
# Shows a random image on every run.
# Each image has a different number of bboxes, so only one image can be
# visualized at a time.

# Mean/std passed to A.Normalize in TestAug, shaped for CHW broadcasting.
_NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _to_displayable(img_tensor, normalized):
    """Map a CHW image tensor back to floats in [0, 1] for display.

    to_pil_image multiplies float tensors by 255 and casts to uint8, so a
    normalized (or 0..255 float) tensor would wrap around and look garbled.

    Args:
        img_tensor: CHW float tensor produced by the dataset.
        normalized: True when A.Normalize was applied to the image.

    Returns:
        CHW float tensor clamped to [0, 1].
    """
    img_tensor = img_tensor.detach().cpu().float()
    if normalized:
        # Inverse of A.Normalize: (x / 255 - mean) / std  ->  x / 255.
        img_tensor = img_tensor * _NORM_STD + _NORM_MEAN
    else:
        # Un-normalized pixels are still in the 0..255 range.
        img_tensor = img_tensor / 255.0
    return img_tensor.clamp(0.0, 1.0)


# TestAug.__getitem__ reads `self.normalize` to decide whether to normalize.
is_normalized = bool(getattr(dataset, 'normalize', False))

fig = plt.figure(figsize=(20, 20))

for img, gt_score_map, gt_geo_map, roi_mask, file_name, word_bboxes in visual_dataloader:
    batch_len = len(img)
    for i in range(batch_len):
        # One column per sample in the batch (a single plot when batch size is 1).
        ax = fig.add_subplot(1, batch_len, i + 1)
        ax.imshow(to_pil_image(_to_displayable(img[i], is_normalized)))

        for bbox in word_bboxes[i]:
            poly = patches.Polygon(np.asarray(bbox), closed=True, edgecolor='black', fill=False, lw=5)
            ax.add_patch(poly)
    # Only the first (randomly drawn) batch is shown.
    break
plt.show()