In [None]:
import math
import os
import random
from random import shuffle

In [None]:
import cv2
import numpy as np
from keras.utils import Sequence

In [None]:
import import_ipynb

In [None]:
from config import batch_size
from config import fg_path, bg_path, a_path
from config import img_cols, img_rows
from config import unknown_code
from utils import safe_crop

In [None]:
# 3x3 elliptical structuring element; generate_trimap uses it for dilation.
# The module is imported as `cv2`, so the constant must be looked up on `cv2`
# (the previous `cv.MORPH_ELLIPSE` raised NameError on a fresh kernel).
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

# Foreground / background file-name lists for the training and test splits,
# one name per line.
with open('data/Combined_Dataset/Training_set/training_fg_names.txt') as f:
    fg_files = f.read().splitlines()
with open('data/Combined_Dataset/Test_set/test_fg_names.txt') as f:
    fg_test_files = f.read().splitlines()
with open('data/Combined_Dataset/Training_set/training_bg_names.txt') as f:
    bg_files = f.read().splitlines()
with open('data/Combined_Dataset/Test_set/test_bg_names.txt') as f:
    bg_test_files = f.read().splitlines()

In [None]:
def get_alpha(name):
    """Load the alpha matte called ``name`` as a single-channel image.

    The file is looked up under the fixed Google Drive alpha directory with a
    ``.jpg`` extension. The value returned is whatever ``cv2.imread`` yields
    for that path in grayscale mode.
    """
    alpha_path = '/content/gdrive/My Drive/DIM/alpha/{}.jpg'.format(name)
    # cv2.IMREAD_GRAYSCALE is the named form of the flag value 0.
    return cv2.imread(alpha_path, cv2.IMREAD_GRAYSCALE)

In [None]:
def composite4(fg, bg, a, w, h):
    """Alpha-composite a foreground over a random crop of a background.

    Args:
        fg: foreground image, indexable as (h, w, channels).
        bg: background image whose height/width are at least ``h``/``w``.
        a: alpha matte of shape (h, w) with values in 0..255.
        w: width of the foreground / output.
        h: height of the foreground / output.

    Returns:
        Tuple ``(im, a, fg, bg)`` where ``im`` is the uint8 composite, ``a`` is
        the alpha matte passed in (unchanged), ``fg`` is the foreground as
        float32 and ``bg`` is the cropped background as float32.
    """
    fg = np.array(fg, np.float32)
    bg_h, bg_w = bg.shape[:2]

    # np.random.randint excludes its upper bound, so add 1 to make the
    # right-most / bottom-most crop position reachable (previously a
    # background one pixel larger than the foreground always cropped at 0).
    x = np.random.randint(0, bg_w - w + 1) if bg_w > w else 0
    y = np.random.randint(0, bg_h - h + 1) if bg_h > h else 0

    bg = np.array(bg[y:y + h, x:x + w], np.float32)

    # Scale alpha to [0, 1] and add a trailing axis so it broadcasts over the
    # colour channels.
    alpha = np.zeros((h, w, 1), np.float32)
    alpha[:, :, 0] = a / 255.

    im = alpha * fg + (1 - alpha) * bg
    im = im.astype(np.uint8)
    return im, a, fg, bg

In [None]:
def process(im_name, bg_name):
    """Load a foreground, its alpha matte and a background, then composite them.

    Args:
        im_name: base name (no extension) of the foreground image under
            ``fg_path`` and of its alpha matte under ``a_path``.
        bg_name: base name (no extension) of the background under ``bg_path``.

    Returns:
        The ``(im, a, fg, bg)`` tuple produced by ``composite4``; all colour
        images are in RGB channel order.

    Raises:
        FileNotFoundError: if any of the three images cannot be read.
    """
    fg_file = '{}/{}.jpg'.format(fg_path, im_name)
    a_file = '{}/{}.jpg'.format(a_path, im_name)
    bg_file = '{}/{}.jpg'.format(bg_path, bg_name)

    im = cv2.imread(fg_file)
    a = cv2.imread(a_file, 0)
    bg = cv2.imread(bg_file)

    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than deep inside later array code.
    for path, loaded in ((fg_file, im), (a_file, a), (bg_file, bg)):
        if loaded is None:
            raise FileNotFoundError('Cannot read image: {}'.format(path))

    # Convert BOTH colour images to RGB. Previously only the foreground was
    # converted, so the composite blended RGB foreground with BGR background.
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    bg = cv2.cvtColor(bg, cv2.COLOR_BGR2RGB)

    h, w = im.shape[:2]
    bh, bw = bg.shape[:2]

    # Upscale the background (preserving aspect ratio) when it is smaller than
    # the foreground in either dimension, so the crop in composite4 fits.
    ratio = max(w / bw, h / bh)
    if ratio > 1:
        bg = cv2.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv2.INTER_CUBIC)

    return composite4(im, bg, a, w, h)

In [None]:
def generate_trimap(alpha):
    """Derive a three-level trimap from an alpha matte.

    Pixels whose alpha equals 255 map to 255. Pixels inside the dilated
    non-zero-alpha region that are not fully opaque map to 128. Everything
    else maps to 0. The dilation uses the module-level ``kernel`` with a random
    iteration count drawn from 1..19. No erosion is applied to the foreground
    mask.

    Returns:
        The trimap as a uint8 array.
    """
    opaque = np.asarray(np.equal(alpha, 255).astype(np.float32))
    nonzero = np.asarray(np.not_equal(alpha, 0).astype(np.float32))

    # Grow the non-zero region by a random amount to widen the unknown band.
    n_iterations = np.random.randint(1, 20)
    grown = cv2.dilate(nonzero, kernel, iterations=n_iterations)

    levels = opaque * 255 + (grown - opaque) * 128
    return levels.astype(np.uint8)

In [None]:
def random_choice(trimap, crop_size=(320, 320)):
    """Pick the top-left corner of a crop centred on a random unknown pixel.

    Args:
        trimap: 2-D array compared element-wise against ``unknown_code``.
        crop_size: ``(height, width)`` of the intended crop.

    Returns:
        ``(x, y)`` of the crop origin, clamped so neither coordinate is
        negative; ``(0, 0)`` when the trimap holds no unknown pixel.
    """
    crop_height, crop_width = crop_size
    rows, cols = np.where(trimap == unknown_code)
    candidates = len(rows)

    # Nothing to centre on: fall back to the image origin.
    if not candidates > 0:
        return 0, 0

    pick = np.random.choice(range(candidates))
    half_w = int(crop_width / 2)
    half_h = int(crop_height / 2)
    x = max(0, cols[pick] - half_w)
    y = max(0, rows[pick] - half_h)
    return x, y

In [None]:
class DataGenSequence(Sequence):
    """Keras ``Sequence`` that composites matting samples on the fly.

    Sample names are read from ``<usage>.txt`` in the Google Drive project
    directory and shuffled on construction and after every epoch.
    """

    def __init__(self, usage):
        """Read the name list for ``usage`` and shuffle it in place."""
        self.usage = usage

        names_file = '/content/gdrive/My Drive/DIM/{}.txt'.format(usage)
        with open(names_file, 'r') as handle:
            self.names = handle.read().splitlines()

        np.random.shuffle(self.names)

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return int(np.ceil(len(self.names) / float(batch_size)))

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Returns:
            ``(batch_x, batch_y)`` as float32 arrays. ``batch_x`` has 4
            channels: image / 255 (0-2) and trimap / 255 (3). ``batch_y`` has
            11 channels: alpha / 255 (0), unknown-region mask (1), image / 255
            (2-4), foreground / 255 (5-7) and background / 255 (8-10).
        """
        start = idx * batch_size
        length = min(batch_size, (len(self.names) - start))

        batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32)
        batch_y = np.empty((length, img_rows, img_cols, 11), dtype=np.float32)

        # Candidate crop sizes, sampled uniformly (320:480:640 = 1:1:1).
        candidate_sizes = [(320, 320), (480, 480), (640, 640)]

        for slot in range(length):
            name = self.names[start + slot]
            parts = name.split('_')
            im_name, bg_name = parts[0], parts[1]
            image, alpha, fg, bg = process(im_name, bg_name)

            crop_size = random.choice(candidate_sizes)

            # A trimap of the full matte is only used to choose where to crop.
            x, y = random_choice(generate_trimap(alpha), crop_size)

            # NOTE(review): safe_crop presumably returns arrays sized
            # (img_rows, img_cols) so they fit the batch buffers - confirm.
            image = safe_crop(image, x, y, crop_size)
            alpha = safe_crop(alpha, x, y, crop_size)
            fg = safe_crop(fg, x, y, crop_size)
            bg = safe_crop(bg, x, y, crop_size)

            # The trimap fed to the network is regenerated from the cropped matte.
            trimap = generate_trimap(alpha)

            # Horizontal flip with probability 1/2, applied to every plane.
            if np.random.random_sample() > 0.5:
                image, trimap, alpha, fg, bg = (
                    np.fliplr(plane) for plane in (image, trimap, alpha, fg, bg)
                )

            batch_x[slot, :, :, 0:3] = image / 255.
            batch_x[slot, :, :, 3] = trimap / 255.

            unknown_mask = np.equal(trimap, 128).astype(np.float32)
            batch_y[slot, :, :, 0] = alpha / 255.
            batch_y[slot, :, :, 1] = unknown_mask
            batch_y[slot, :, :, 2:5] = image / 255.
            batch_y[slot, :, :, 5:8] = fg / 255.
            batch_y[slot, :, :, 8:11] = bg / 255.

        return batch_x, batch_y

    def on_epoch_end(self):
        """Reshuffle the sample names after each epoch."""
        np.random.shuffle(self.names)

In [None]:
def train_gen(name='train'):
    """Create the data generator for the name list ``name`` ('train' by default)."""
    generator = DataGenSequence(name)
    return generator

def valid_gen(name='valid'):
    """Create the data generator for the name list ``name`` ('valid' by default)."""
    generator = DataGenSequence(name)
    return generator