In [3]:
import os
import sys
import random
import warnings

import numpy as np
import pandas as pd
import glob

import matplotlib.pyplot as plt
from skimage.io import imread, imshow
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
# from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Nadam
import tensorflow as tf
import tensorflow.python.keras.backend as k

from skimage.transform import resize

# Size and channel count of the arrays fed to the model.
IMG_WIDTH = 256
IMG_HEIGHT = 256
IMG_CHANNELS = 1


# Dataset locations.
TRAIN_PATH = './page_data/train/'
TEST_PATH = './page_data/test/'

# Silence UserWarnings coming from skimage (optional; safe to remove).
warnings.filterwarnings('ignore', category=UserWarning, module='skimage')

# Collect the image / mask file names as lists.
# glob.glob() returns matches in arbitrary order, and images are later paired
# with masks purely by list index, so every list is sorted to make the pairing
# deterministic.
# NOTE(review): this assumes an image and its mask share the same file name
# under org/ and seg/ — confirm against the dataset layout.
train_imgs = sorted(glob.glob(TRAIN_PATH + 'org/*.jpg'))
train_masks = sorted(glob.glob(TRAIN_PATH + 'seg/*.jpg'))
test_imgs = sorted(glob.glob(TEST_PATH + 'org/*.jpg'))
test_masks = sorted(glob.glob(TEST_PATH + 'seg/*.jpg'))


# Count the training images and masks; the loop below pairs them by index.
num_of_train_imgs = len(train_imgs)
num_of_train_masks = len(train_masks)
if not (num_of_train_imgs == num_of_train_masks):
    print('invalid datasets, please check train data')

# Pre-allocate the uint8 arrays that will hold the resized samples.
# images
X_train = np.zeros(shape=(num_of_train_imgs, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
# masks
Y_train = np.zeros(shape=(num_of_train_masks, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
print('Getting and resizing train images and masks ... ')
sys.stdout.flush()

for n, img_path in enumerate(train_imgs):
    # Read as grayscale, scale by 255 into uint8, then resize to the model's
    # input size.
    # NOTE(review): the * 255 presumes imread() returned floats in [0, 1];
    # confirm this holds for files that are already single-channel.
    img = resize(
        (imread(img_path, as_gray=True) * 255).astype(np.uint8),
        (IMG_HEIGHT, IMG_WIDTH),
        mode='constant',
        preserve_range=True,
    )
    X_train[n] = img.reshape((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    # The mask with the same index is read and resized the same way.
    # NOTE(review): unlike the image, the mask is not multiplied by 255 before
    # it is stored as uint8 — confirm this is intended.
    mask = imread(train_masks[n], as_gray=True)
    mask_resized = resize(mask, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    Y_train[n] = mask_resized.reshape((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    
# Load the test images and masks into arrays.
num_of_test_imgs = len(test_imgs)
num_of_test_masks = len(test_masks)
# Images and masks are paired by index, so report a count mismatch the same
# way the training set does.
if num_of_test_imgs != num_of_test_masks:
    print('invalid datasets, please check test data')

X_test = np.zeros((num_of_test_imgs, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
Y_test = np.zeros((num_of_test_masks, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
print('Getting test images ... ')
sys.stdout.flush()

for n in range(num_of_test_imgs):
    # Read as grayscale, scale by 255 into uint8, then resize to the model's
    # input size.
    # NOTE(review): the * 255 presumes imread() returned floats in [0, 1];
    # confirm this holds for files that are already single-channel.
    img = imread(test_imgs[n], as_gray=True)
    img = (img * 255).astype(np.uint8)
    img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    X_test[n] = img.reshape(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

    # NOTE(review): the mask is stored as uint8 without the * 255 scaling the
    # image gets — confirm this is intended.
    mask = imread(test_masks[n], as_gray=True)
    mask_resized = resize(mask, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    Y_test[n] = mask_resized.reshape(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

print('Preparing is Done!')



Getting and resizing train images and masks ... 


Getting test images ... 
Preparing is Done!


In [7]:
# file_names = [os.path.basename(path) for path in train_imgs]
# output_file_path = './page_data/pages.txt'
# with open(output_file_path, 'w') as file:
#     for name in file_names:
#         file.write(name + '\n')

In [5]:
# 데이터 증강 해볼것이다
import cv2
import numpy as np
import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import random

import kornia.augmentation as KA
import kornia.geometry.transform as KG

class DIW(Dataset):
    """Dataset yielding a page image, its segmentation mask and a background texture.

    Image names are read from ``./page_data/pages.txt`` (one per line) and
    resolved to ``<root_dir>/org/<name>.jpg``. The mask for an image is read
    from the sibling ``seg`` directory under the same file name. Background
    texture paths are read from ``./data/bgtex.txt`` (one per line).

    Args:
        root_dir: directory containing the ``org`` and ``seg`` sub-directories.
        is_train: stored on the instance; not otherwise used by this class.
        num: when non-zero, reported as the dataset length instead of the
            number of listed images.
    """

    def __init__(self, root_dir, is_train=True, num=0):
        super(DIW, self).__init__()
        self.is_train = is_train
        self.num = num
        # load the list of diw images
        with open('./page_data/pages.txt', 'r') as fid:
            names = fid.read().splitlines()
        # The list may hold bare names or names that already end in '.jpg';
        # only append the extension when it is missing so both forms resolve.
        self.X = [root_dir + '/org/' + self._with_jpg_ext(t) for t in names]

        with open('./data/bgtex.txt', 'r') as fid:
            self.bgtex = fid.read().splitlines()

    @staticmethod
    def _with_jpg_ext(name):
        """Return ``name`` with a '.jpg' suffix, leaving it alone if already present."""
        return name if name.lower().endswith('.jpg') else name + '.jpg'

    @staticmethod
    def _mask_path(img_path):
        """Map an image path under ``/org/`` to the same file under ``/seg/``."""
        # Swap only the last '/org/' directory component; the file name itself
        # is left untouched even if it happens to contain 'org' or 'img'.
        head, sep, tail = img_path.rpartition('/org/')
        if not sep:
            return img_path
        return head + '/seg/' + tail

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV and return it as float32 divided by 255.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (imread gives None).
        """
        data = cv2.imread(path)
        if data is None:
            raise FileNotFoundError('cannot read image: ' + str(path))
        return data.astype(np.float32) / 255.0

    def __len__(self):
        """Return ``num`` when it is non-zero, otherwise the number of listed images."""
        if self.num:
            return self.num
        else:
            return len(self.X)

    def __getitem__(self, index):
        """Return ``(im, ms, bg)`` float tensors in channel-first layout.

        ``im`` is the image with its channel axis reversed, ``ms`` is the mask
        averaged over channels (one channel), and ``bg`` is a randomly chosen
        background texture resized to 200x200 and tiled 3x3.
        """
        t = self.X[index]
        im = self._read_image(t)
        # Reverse the channel axis (OpenCV decodes colour images as BGR).
        im = im[..., ::-1]

        # The mask lives next to the image, in 'seg' instead of 'org'.
        ms = self._read_image(self._mask_path(t))
        ms = np.mean(ms, axis=2, keepdims=True)

        # random sample a background image
        # NOTE(review): the background keeps OpenCV's channel order, unlike
        # `im` — confirm this is intended.
        ind = random.randint(0, len(self.bgtex) - 1)
        bg = self._read_image(self.bgtex[ind])
        bg = cv2.resize(bg, (200, 200))
        bg = np.tile(bg, (3, 3, 1))

        # HWC -> CHW; copy() makes the reversed / transposed views contiguous
        # before handing them to torch.
        im = torch.from_numpy(im.transpose((2, 0, 1)).copy())
        ms = torch.from_numpy(ms.transpose((2, 0, 1)).copy())
        bg = torch.from_numpy(bg.transpose((2, 0, 1)).copy())

        return im, ms, bg

class DIWDataAug(nn.Module):
    """Batch augmentation: tight crop to the mask, random padding, background
    replacement, resize and colour jitter."""

    def __init__(self):
        super(DIWDataAug, self).__init__()
        self.cj = KA.ColorJitter(0.1, 0.1, 0.1, 0.1)

    def forward(self, img, ms, bg):
        """Augment a batch of images and masks.

        Args:
            img: batch of images indexed as (batch, channel, row, col); the
                solid-colour background branch draws one value per channel for
                3 channels.
            ms: batch of masks; channel 0 is thresholded at 0.5.
            bg: batch of background images, one per sample, channel-first.

        Returns:
            ``(img, msk)``: images resized to 256x256 and colour-jittered, and
            masks resized to 64x64.

        Raises:
            ValueError: if a sample's mask has no pixel above 0.5, so no crop
                box can be computed.
        """
        # tight crop
        mask = ms[:, 0] > 0.5

        B = img.size(0)
        # Per sample: four padding amounts in [0, 20) plus one value that
        # selects the background type.
        c = torch.randint(20, (B, 5))
        img_list = []
        msk_list = []
        for ii in range(B):
            x_img = img[ii]
            x_msk = mask[ii]
            y, x = x_msk.nonzero(as_tuple=True)
            if y.numel() == 0:
                # min()/max() on an empty tensor fails with an opaque error;
                # say what is actually wrong.
                raise ValueError('mask of sample %d is empty; cannot compute a tight crop' % ii)
            minx = x.min()
            maxx = x.max()
            miny = y.min()
            maxy = y.max()
            x_img = x_img[:, miny : maxy + 1, minx : maxx + 1]
            x_msk = x_msk[None, miny : maxy + 1, minx : maxx + 1]

            # padding
            x_img = F.pad(x_img, c[ii, : 4].tolist())
            x_msk = F.pad(x_msk, c[ii, : 4].tolist())

            # replace bg
            if c[ii][-1] > 2:
                x_bg = bg[ii]
                need_h, need_w = x_img.size(1), x_img.size(2)
                # Tile the background when it is smaller than the padded crop;
                # a plain slice would come out too small and the blend below
                # would fail on mismatched shapes.
                rep_h = -(-need_h // x_bg.size(1))
                rep_w = -(-need_w // x_bg.size(2))
                if rep_h > 1 or rep_w > 1:
                    x_bg = x_bg.repeat(1, rep_h, rep_w)
                x_bg = x_bg[:, :need_h, :need_w]
            else:
                # Solid background: one random value per channel.
                x_bg = torch.ones_like(x_img) * torch.rand((3, 1, 1), device=x_img.device)
            x_msk = x_msk.float()
            # Keep image pixels inside the mask, background elsewhere.
            x_img = x_img * x_msk + x_bg * (1. - x_msk)

            # resize
            x_img = KG.resize(x_img[None, :], (256, 256))
            x_msk = KG.resize(x_msk[None, :], (64, 64))
            img_list.append(x_img)
            msk_list.append(x_msk)
        img = torch.cat(img_list)
        msk = torch.cat(msk_list)
        # jitter color
        img = self.cj(img)
        return img, msk


In [None]:
from torch.utils.data import DataLoader

# Build the dataset and the augmentation module.
root_dir = './page_data/train/'
dataset = DIW(root_dir=root_dir, is_train=True)
data_aug = DIWDataAug()

# Serve the dataset in shuffled mini-batches of four samples.
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

# Run the augmentation on every batch.
for im, ms, bg in dataloader:
    augmented_im, augmented_ms = data_aug(im, ms, bg)
    # augmented_im and augmented_ms are now ready to be used for training