Skip to content

Commit

Permalink
Merge pull request #12 from ain-soph/master
Browse files Browse the repository at this point in the history
code optimization for insert_reflection.py
  • Loading branch information
DreamtaleCore committed Mar 6, 2022
2 parents f4237d6 + df03096 commit f280a99
Showing 1 changed file with 17 additions and 40 deletions.
57 changes: 17 additions & 40 deletions scripts/insert_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
generate different types of reflection images.
"""
import os
from functools import partial

import cv2
import random
import numpy as np
import scipy.stats as st
from skimage.measure import compare_ssim
from skimage.metrics import structural_similarity
import xml.etree.ElementTree as ET
import tqdm

Expand Down Expand Up @@ -64,7 +63,7 @@ def blend_images(img_t, img_r, max_image_size=560, ghost_rate=0.49, alpha_t=-1.,
if alpha_t < 0:
alpha_t = 1. - random.uniform(0.05, 0.45)

if random.randint(0, 100) < ghost_rate * 100:
if random.random() < ghost_rate:
t = np.power(t, 2.2)
r = np.power(r, 2.2)

Expand All @@ -76,24 +75,19 @@ def blend_images(img_t, img_r, max_image_size=560, ghost_rate=0.49, alpha_t=-1.,
r_2 = np.lib.pad(r, ((offset[0], 0), (offset[1], 0), (0, 0)),
'constant', constant_values=(0, 0))
if ghost_alpha < 0:
ghost_alpha_switch = 1 if random.random() > 0.5 else 0
ghost_alpha = abs(ghost_alpha_switch - random.uniform(0.15, 0.5))
ghost_alpha = abs(round(random.random()) - random.uniform(0.15, 0.5))

ghost_r = r_1 * ghost_alpha + r_2 * (1 - ghost_alpha)
ghost_r = cv2.resize(ghost_r[offset[0]: -offset[0], offset[1]: -offset[1], :], (w, h))
ghost_r = cv2.resize(ghost_r[offset[0]: -offset[0], offset[1]: -offset[1], :],
(w, h), cv2.INTER_CUBIC)
reflection_mask = ghost_r * (1 - alpha_t)

blended = reflection_mask + t * alpha_t

transmission_layer = np.power(t * alpha_t, 1 / 2.2)

ghost_r = np.power(reflection_mask, 1 / 2.2)
ghost_r[ghost_r > 1.] = 1.
ghost_r[ghost_r < 0.] = 0.

blended = np.power(blended, 1 / 2.2)
blended[blended > 1.] = 1.
blended[blended < 0.] = 0.
ghost_r = np.clip(np.power(reflection_mask, 1 / 2.2), 0, 1)
blended = np.clip(np.power(blended, 1 / 2.2), 0, 1)

reflection_layer = np.uint8(ghost_r * 255)
blended = np.uint8(blended * 255)
Expand Down Expand Up @@ -130,7 +124,7 @@ def gen_kernel(kern_len=100, nsig=1):
kernel = kernel / kernel.max()
return kernel

h, w = r_blur.shape[0: 2]
h, w = r_blur.shape[:2]
new_w = np.random.randint(0, max_image_size - w - 10) if w < max_image_size - 10 else 0
new_h = np.random.randint(0, max_image_size - h - 10) if h < max_image_size - 10 else 0

Expand Down Expand Up @@ -179,43 +173,26 @@ def gen_main_func():
print('Gather reflections with class name: ', REFLECT_SEM)
dir_rf = gather_reflection_images()

ssim_func = partial(compare_ssim, multichannel=True)
t_bar = tqdm.tqdm(range(NUM_ATTACK))
bg_pwds = os.listdir(dir_bg)
bg_pwds = [os.path.join(dir_bg, x) for x in bg_pwds]
t_bar = tqdm.tqdm(bg_pwds[:NUM_ATTACK])
t_bar.set_description('Generating: ')

for i in t_bar:
bg_pwd = bg_pwds[i]
rf_id = 0
for bg_pwd in t_bar:
img_bg = cv2.imread(bg_pwd)
while True:
if rf_id >= len(dir_rf):
break
rf_pwd = dir_rf[rf_id]
rf_id = rf_id + 1
for rf_pwd in dir_rf:
img_rf = cv2.imread(rf_pwd)
img_in, img_tr, img_rf = blend_images(img_bg, img_rf, ghost_rate=0.39)
# find a image with reflections with transmission as the primary layer
if np.mean(img_rf) > np.mean(img_in - img_rf) * 0.8:
continue
elif img_in.max() < 0.1 * 255:
continue
else:
if np.mean(img_rf) <= np.mean(img_in - img_rf) * 0.8 and img_in.max() >= 0.1 * 255:
# remove the image-pair which share too similar or distinct outlooks
ssim_diff = np.mean(ssim_func(img_in, img_tr))
if ssim_diff < 0.70 or ssim_diff > 0.85:
continue
else:
if 0.7 < np.mean(structural_similarity(img_in, img_tr, channel_axis=2)) < 0.85):
image_name = '%s+%s' % (os.path.basename(bg_pwd).split('.')[0], os.path.basename(rf_pwd).split('.')[0])
cv2.imwrite(os.path.join(dir_out, '%s-input.jpg' % image_name), img_in)
cv2.imwrite(os.path.join(dir_out, '%s-background.jpg' % image_name), img_tr)
cv2.imwrite(os.path.join(dir_out, '%s-reflection.jpg' % image_name), img_rf)
break

if rf_id >= len(dir_rf):
continue
image_name = '%s+%s' % (os.path.basename(bg_pwd).split('.')[0], os.path.basename(rf_pwd).split('.')[0])
cv2.imwrite(os.path.join(dir_out, '%s-input.jpg' % image_name), img_in)
cv2.imwrite(os.path.join(dir_out, '%s-background.jpg' % image_name), img_tr)
cv2.imwrite(os.path.join(dir_out, '%s-reflection.jpg' % image_name), img_rf)


if __name__ == '__main__':
gen_main_func()
Expand Down

0 comments on commit f280a99

Please sign in to comment.