<a href="https://colab.research.google.com/github/ShimonBezalel/style_gan_prior/blob/master/style_gan_prior.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
# Mount Google Drive (Colab only). Later cells copy the project files from
# "drive/My Drive/colab/style_gan_prior" into the working directory.
# force_remount=True remounts even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [25]:
# Copy the helper modules, the pretrained FFHQ StyleGAN checkpoint and the input
# images from Drive into the working directory (paths are relative to the cwd).
!cp "drive/My Drive/colab/style_gan_prior/perceptual_model.py" .
!cp "drive/My Drive/colab/style_gan_prior/inversion.py" .
!cp "drive/My Drive/colab/style_gan_prior/inpainting.py" .
!cp "drive/My Drive/colab/style_gan_prior/super_resolution.py" .
# NOTE(review): the recorded output of this cell is a KeyboardInterrupt, presumably
# during the large .pkl copy -- verify the checkpoint was copied completely.
!cp "drive/My Drive/colab/style_gan_prior/karras2019stylegan-ffhq-1024x1024.pkl" .
!cp -r "drive/My Drive/colab/style_gan_prior/imgs_dir" .

# Re-import edited local modules (e.g. inpainting.py) automatically before each cell runs.
%load_ext autoreload
%autoreload 2


KeyboardInterrupt: ignored

In [None]:
!pip install tensorflow==1.15

In [None]:
# Fetch NVIDIA's StyleGAN repository and make its modules (dnnlib, config, ...)
# importable. Both steps are guarded, so re-running the cell neither fails on an
# existing checkout nor adds duplicate sys.path entries.
import os
import subprocess
import sys

STYLEGAN_DIR = "stylegan"
if not os.path.isdir(STYLEGAN_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/NVlabs/stylegan.git", STYLEGAN_DIR],
        check=True,
    )

if STYLEGAN_DIR not in sys.path:
    sys.path.insert(1, STYLEGAN_DIR)

In [None]:
# Sanity check that the StyleGAN checkout from the previous cell is present
# (displays True when a `stylegan` path exists in the working directory).
import os
os.path.exists(os.path.join(os.curdir, 'stylegan'))


In [None]:
import perceptual_model
import inversion
import inpainting
import super_resolution
import numpy as np
import pickle
import os
import imageio
from tqdm import tqdm

In [None]:
import numpy as np
import scipy.stats as st
def gen_gaussian_kernel(size=21, nsig=3):
    """Build a normalized 2-D Gaussian kernel.

    The 1-D profile is the standard normal probability mass falling into each of
    ``size`` equal-width bins spanning [-nsig, nsig]. The 2-D kernel is the outer
    product of that profile with itself, rescaled to sum to 1.

    Args:
        size: side length of the square kernel.
        nsig: half-width of the sampled range, in standard deviations.

    Returns:
        float32 array of shape (size, size) whose entries sum to 1.
    """
    bin_edges = np.linspace(-nsig, nsig, num=size + 1)
    edge_cdf = st.norm.cdf(bin_edges)
    profile = edge_cdf[1:] - edge_cdf[:-1]
    kernel = profile[:, np.newaxis] * profile[np.newaxis, :]
    kernel = kernel / np.sum(kernel)
    return kernel.astype(np.float32)



In [None]:
# Inspect the 21x21 kernel later used for deblurring (nsig=1: bins span +/-1 std).
gen_gaussian_kernel(size=21, nsig=1)

In [None]:
# Sanity check: the kernel is normalized, so this should display a value of ~1.0.
gen_gaussian_kernel(size=21, nsig=1).sum()


### Code: StyleGAN latent-code optimization for inpainting and deblurring

In [None]:
"Style Image Prior for Inpainting"
"""inpainting.py --imgs-dir <input-imgs-dir> --masks-dir <output-masks-dir>
    --corruptions-dir <output-corruptions-dir> --restorations-dir <output-restorations-dir>
    --latents-dir <output-latents-dir>
    [--input-img-size INPUT_IMG_HEIGHT INPUT_IMG_WIDTH]
    [--perceptual-img-size EFFECTIVE_IMG_HEIGHT EFFECTIVE_IMG_WIDTH]
    [--mask-size MASK_HEIGHT MASK_WIDTH]
    [--learning-rate LEARNING_RATE]
    [--total-iterations TOTAL_ITERATIONS]"""

import numpy as np
import tensorflow as tf
import cv2

import dnnlib
import dnnlib.tflib as tflib
import config

from perceptual_model import PerceptualModel

# Presumably the Google Drive download link for the StyleGAN checkpoint.
# NOTE(review): not referenced by the code in this notebook -- both optimizers load
# 'karras2019stylegan-ffhq-1024x1024.pkl' from the working directory instead.
STYLEGAN_MODEL_URL = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'



def generate_random_mask(img_shape, mask_size):
    """Create a binary inpainting mask with one randomly placed rectangular hole.

    The hole is positioned uniformly at random (using NumPy's global random state)
    so that it lies entirely inside the central half of the image: rows in
    ``[H // 4, 3 * (H // 4))`` and columns in ``[W // 4, 3 * (W // 4))``.

    Args:
        img_shape: (height, width) of the image.
        mask_size: (height, width) of the hole.

    Returns:
        uint8 array of shape (height, width, 1): 0 inside the hole, 1 elsewhere.

    Raises:
        ValueError: if the hole does not fit inside the central half of the image.
    """
    mask_2d = np.ones(img_shape, dtype=np.uint8)

    row_quarter = img_shape[0] // 4
    col_quarter = img_shape[1] // 4
    # Largest (inclusive) start offsets that keep the hole inside the central half.
    max_top = 3 * row_quarter - mask_size[0]
    max_left = 3 * col_quarter - mask_size[1]
    if max_top < row_quarter or max_left < col_quarter:
        raise ValueError(
            'mask_size %s does not fit in the central half of an image of shape %s'
            % (tuple(mask_size), tuple(img_shape)))

    # randint's upper bound is exclusive, hence the +1 (so the last central
    # row/column can also be covered).
    top = np.random.randint(low=row_quarter, high=max_top + 1)
    left = np.random.randint(low=col_quarter, high=max_left + 1)

    mask_2d[top:top + mask_size[0], left:left + mask_size[1]] = 0

    return mask_2d[..., np.newaxis]


def optimize_latent_codes(args):
    """Inpaint each image in ``args.imgs_dir`` by optimizing a StyleGAN latent code.

    For every image, a random rectangular hole is cut out (``generate_random_mask``).
    A (1, 18, 512) latent code, reset to zeros for each image, is then optimized with
    plain gradient descent. The target is that the masked generated image matches
    the masked input in ``PerceptualModel`` feature space (mean absolute difference).

    Writes one file per input image, under the same file name, to:
      * ``args.corruptions_dir``  -- the masked input,
      * ``args.masks_dir``        -- the mask scaled to {0, 255},
      * ``args.restorations_dir`` -- the generated image after optimization,
      * ``args.latents_dir``      -- ``<name>.npz`` holding the optimized ``latent_code``.

    Args:
        args: namespace-like object providing:
            ``input_img_size`` and ``perceptual_img_size`` as (height, width);
            ``mask_size`` as (height, width); ``learning_rate``;
            ``total_iterations``; and the directory attributes above.
    """
    tflib.init_tf()

    # NOTE: pickle.load can execute arbitrary code -- only load a trusted checkpoint.
    with open('karras2019stylegan-ffhq-1024x1024.pkl', "rb") as f:
        _G, _D, Gs = pickle.load(f)
    # AUTO_REUSE lets a repeated call pick up the already-created variable.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        latent_code = tf.get_variable(
            name='latent_code', shape=(1, 18, 512), dtype='float32', initializer=tf.initializers.zeros())

    # The transpose/rescale assume the synthesis output is NCHW in [-1, 1];
    # afterwards the image is NHWC in [0, 255], like the input images.
    generated_img = Gs.components.synthesis.get_output_for(latent_code, randomize_noise=False)
    generated_img = tf.transpose(generated_img, [0, 2, 3, 1])
    generated_img = ((generated_img + 1) / 2) * 255

    height, width = args.input_img_size
    original_img = tf.placeholder(tf.float32, [None, height, width, 3])
    degradation_mask = tf.placeholder(tf.float32, [None, height, width, 1])

    degraded_img_resized_for_perceptual = tf.image.resize_images(
        original_img * degradation_mask, tuple(args.perceptual_img_size),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

    generated_img_resized_to_original = tf.image.resize_images(
        generated_img, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

    # Mask the generated image too, so the loss only compares known pixels.
    generated_img_resized_for_perceptual = tf.image.resize_images(
        generated_img_resized_to_original * degradation_mask, tuple(args.perceptual_img_size),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

    generated_img_for_display = tf.saturate_cast(generated_img_resized_to_original, tf.uint8)

    perceptual_model = PerceptualModel(img_size=args.perceptual_img_size)
    generated_img_features = perceptual_model(generated_img_resized_for_perceptual)
    target_img_features = perceptual_model(degraded_img_resized_for_perceptual)

    loss_op = tf.reduce_mean(tf.abs(generated_img_features - target_img_features))

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
    train_op = optimizer.minimize(loss_op, var_list=[latent_code])
    # Built once (not per image) so the graph does not grow inside the loop.
    init_op = tf.variables_initializer([latent_code] + optimizer.variables())

    sess = tf.get_default_session()

    for img_name in sorted(os.listdir(args.imgs_dir)):
        img = imageio.imread(os.path.join(args.imgs_dir, img_name))
        # cv2.resize expects dsize as (width, height); input_img_size is (height, width).
        img = cv2.resize(img, dsize=(width, height))
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)  # grayscale -> RGB
        img = img[..., :3]  # drop any alpha channel; the placeholder expects 3 channels
        mask = generate_random_mask(img.shape[:2], mask_size=args.mask_size)

        corrupted_img = img * mask

        imageio.imwrite(os.path.join(args.corruptions_dir, img_name), corrupted_img)
        imageio.imwrite(os.path.join(args.masks_dir, img_name), mask * 255)

        feed_dict = {
            original_img: img[np.newaxis, ...],
            degradation_mask: mask[np.newaxis, ...]
        }

        sess.run(init_op)  # start every image from the zero latent code

        progress_bar_iterator = tqdm(
            iterable=range(args.total_iterations),
            bar_format='{desc}: {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}{postfix}',
            desc=img_name
        )

        for _ in progress_bar_iterator:
            loss = sess.run(fetches=[loss_op, train_op], feed_dict=feed_dict)[0]
            progress_bar_iterator.set_postfix_str('loss=%.2f' % loss)

        reconstructed_imgs, latent_codes = sess.run(
            fetches=[generated_img_for_display, latent_code],
            feed_dict=feed_dict
        )

        imageio.imwrite(os.path.join(args.restorations_dir, img_name), reconstructed_imgs[0])
        np.savez(file=os.path.join(args.latents_dir, img_name + '.npz'), latent_code=latent_codes[0])





In [None]:
def conv_image(image, kernel_2d):
    """Filter an RGB image with one 2-D kernel, applied to each channel independently.

    Args:
        image: RGB image (tensor or array). Singleton axes, e.g. a leading batch
            axis of size 1, are squeezed away before filtering.
        kernel_2d: 2-D kernel (tensor or array). Singleton axes, e.g. a trailing
            axis of size 1, are squeezed away.

    Returns:
        The filtered image with singleton axes squeezed, i.e. (H, W, 3) for an
        H x W RGB input. Stride 1, 'SAME' padding (zero-padded borders).

    Note: TensorFlow convolutions are cross-correlations (the kernel is not
    flipped). This makes no difference for symmetric kernels such as a Gaussian.
    """
    kernel_2d = tf.squeeze(kernel_2d)
    image = tf.squeeze(image)
    # Depthwise filter of shape (kh, kw, 3, 1): the same kernel for each of the 3
    # channels with channel multiplier 1, so channels are never mixed.
    depthwise_kernel = tf.tile(kernel_2d[:, :, tf.newaxis, tf.newaxis], [1, 1, 3, 1])

    batched = tf.expand_dims(image, 0)
    filtered = tf.nn.depthwise_conv2d(batched, depthwise_kernel,
                                      strides=[1, 1, 1, 1], padding='SAME')
    return tf.squeeze(filtered)



def optimize_latent_for_deblur(args):
    """Deblur each image in ``args.imgs_dir`` by optimizing a StyleGAN latent code.

    Every input image is blurred with ``args.kernel`` (via ``conv_image``) to form the
    observation. A (1, 18, 512) latent code, reset to zeros for each image, is then
    optimized with plain gradient descent. The target is that the generated image,
    blurred with the same kernel, matches the observation in ``PerceptualModel``
    feature space (mean absolute difference).

    Writes one file per input image, under the same file name, to:
      * ``args.corruptions_dir``  -- the blurred input, as uint8,
      * ``args.masks_dir``        -- the blur kernel rescaled so its peak is 255,
      * ``args.restorations_dir`` -- the generated image after optimization,
      * ``args.latents_dir``      -- ``<name>.npz`` holding the optimized ``latent_code``.

    Args:
        args: namespace-like object providing:
            ``input_img_size`` and ``perceptual_img_size`` as (height, width);
            a 2-D ``kernel`` array; ``learning_rate``; ``total_iterations``;
            and the directory attributes above. The kernel placeholder takes its
            shape from ``args.kernel``.

    Raises:
        ValueError: if ``args.kernel`` is not 2-D.
    """
    tflib.init_tf()

    # NOTE: pickle.load can execute arbitrary code -- only load a trusted checkpoint.
    with open('karras2019stylegan-ffhq-1024x1024.pkl', "rb") as f:
        _G, _D, Gs = pickle.load(f)
    # AUTO_REUSE lets a repeated call pick up the already-created variable.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        latent_code = tf.get_variable(
            name='latent_code', shape=(1, 18, 512), dtype='float32', initializer=tf.initializers.zeros())

    # The transpose/rescale assume the synthesis output is NCHW in [-1, 1];
    # afterwards the image is NHWC in [0, 255], like the input images.
    generated_img = Gs.components.synthesis.get_output_for(latent_code, randomize_noise=False)
    generated_img = tf.transpose(generated_img, [0, 2, 3, 1])
    generated_img = ((generated_img + 1) / 2) * 255

    height, width = args.input_img_size
    kernel = np.asarray(args.kernel, dtype=np.float32)
    if kernel.ndim != 2:
        raise ValueError('args.kernel must be 2-D, got shape %s' % (kernel.shape,))
    kernel_feed = kernel[..., np.newaxis]

    # Unbatched (H, W, 3) input image and the (kh, kw, 1) blur kernel. The kernel
    # placeholder is shaped after the array that is actually fed to it.
    original_img = tf.placeholder(tf.float32, [height, width, 3])
    blur_kernel = tf.placeholder(tf.float32, list(kernel_feed.shape))

    # Observation: the input image blurred with the kernel, shape (H, W, 3).
    degraded_img = conv_image(original_img, blur_kernel)
    degraded_img_resized_for_perceptual = tf.image.resize_images(
        degraded_img, tuple(args.perceptual_img_size), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

    generated_img_resized_to_original = tf.image.resize_images(
        generated_img, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

    # Blur the generated image with the same kernel so both sides are comparable.
    generated_img_resized_for_perceptual = tf.image.resize_images(
        conv_image(generated_img_resized_to_original, blur_kernel), tuple(args.perceptual_img_size),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

    generated_img_for_display = tf.saturate_cast(generated_img_resized_to_original, tf.uint8)

    perceptual_model = PerceptualModel(img_size=args.perceptual_img_size)
    # conv_image squeezes away the batch axis; restore it for the perceptual model.
    generated_img_features = perceptual_model(generated_img_resized_for_perceptual[tf.newaxis, ...])
    target_img_features = perceptual_model(degraded_img_resized_for_perceptual[tf.newaxis, ...])

    loss_op = tf.reduce_mean(tf.abs(generated_img_features - target_img_features))

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
    train_op = optimizer.minimize(loss_op, var_list=[latent_code])
    # Built once (not per image) so the graph does not grow inside the loop.
    init_op = tf.variables_initializer([latent_code] + optimizer.variables())

    sess = tf.get_default_session()

    # Kernel image for inspection, rescaled so that its peak maps to 255.
    kernel_peak = float(kernel.max())
    if kernel_peak > 0:
        kernel_img = np.clip(np.rint(kernel / kernel_peak * 255), 0, 255).astype(np.uint8)
    else:
        kernel_img = np.zeros(kernel.shape, dtype=np.uint8)

    for img_name in sorted(os.listdir(args.imgs_dir)):
        img = imageio.imread(os.path.join(args.imgs_dir, img_name))
        # cv2.resize expects dsize as (width, height); input_img_size is (height, width).
        img = cv2.resize(img, dsize=(width, height))
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)  # grayscale -> RGB
        img = img[..., :3].astype(np.float32)  # drop any alpha channel
        feed_dict = {original_img: img, blur_kernel: kernel_feed}

        # Reuse the graph's blur op instead of adding new conv ops for every image.
        corrupted_img = sess.run(degraded_img, feed_dict=feed_dict)

        # Write explicit uint8 to avoid imageio's lossy float -> uint8 conversion.
        imageio.imwrite(os.path.join(args.corruptions_dir, img_name),
                        np.clip(np.rint(corrupted_img), 0, 255).astype(np.uint8))
        imageio.imwrite(os.path.join(args.masks_dir, img_name), kernel_img)

        sess.run(init_op)  # start every image from the zero latent code

        progress_bar_iterator = tqdm(
            iterable=range(args.total_iterations),
            bar_format='{desc}: {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}{postfix}',
            desc=img_name
        )

        for _ in progress_bar_iterator:
            loss = sess.run(fetches=[loss_op, train_op], feed_dict=feed_dict)[0]
            progress_bar_iterator.set_postfix_str('loss=%.2f' % loss)

        reconstructed_imgs, latent_codes = sess.run(
            fetches=[generated_img_for_display, latent_code],
            feed_dict=feed_dict
        )

        imageio.imwrite(os.path.join(args.restorations_dir, img_name), reconstructed_imgs[0])
        np.savez(file=os.path.join(args.latents_dir, img_name + '.npz'), latent_code=latent_codes[0])

In [None]:
# Scratch check of the tensor shapes conv_image deals with: a 3x3 RGB image of
# ones and a 2x2 kernel of twos. Displays (3, 3, 3), then (1, 3, 3, 3) once a
# batch axis is added.
A = np.ones((3, 3, 3), dtype=int)
B = np.full((2, 2), 2)

kernel_2d = tf.constant(B, dtype=tf.float32)
image = tf.constant(A, dtype=tf.float32)

print(image.shape)
print(tf.expand_dims(image, 0).shape)



In [None]:

class HackArgs:
    """Hard-coded stand-in for the command-line arguments of the optimizers above.

    Configuration lives in class attributes. Instantiating the class creates the
    output directories; existing directories are kept.
    """

    input_img_size = (256, 256)       # (height, width) input images are resized to
    perceptual_img_size = (256, 256)  # (height, width) fed to PerceptualModel
    mask_size = (5, 5)                # (height, width)
    learning_rate = 1e-2
    total_iterations = 1000

    imgs_dir = 'imgs_dir'
    masks_dir = 'masks_dir'
    corruptions_dir = 'corruptions_dir'
    restorations_dir = 'restorations_dir'
    latents_dir = 'latents_dir'
    # 21x21 Gaussian blur kernel (nsig=1), evaluated once when the class is defined.
    kernel = gen_gaussian_kernel(21, 1)

    def __init__(self):
        for out_dir in (self.masks_dir, self.corruptions_dir,
                        self.restorations_dir, self.latents_dir):
            os.makedirs(out_dir, exist_ok=True)


args = HackArgs()
optimize_latent_for_deblur(args)

In [None]:
from skimage import color
from skimage.transform import resize
def im_to_kernel(path, size=21):
    """Load an image file and turn it into a normalized (size, size) blur kernel.

    The image is converted to grayscale (an alpha channel is ignored), resized to
    (size, size) and divided by its sum so the entries sum to 1.

    Args:
        path: path of the kernel image.
        size: side length of the returned square kernel.

    Returns:
        2-D float array of shape (size, size) whose entries sum to 1.

    Raises:
        ValueError: if the image is not 2-D after channel handling, or is all
            zeros after resizing.
    """
    # imageio is already a dependency of this notebook; `plt` was never imported.
    import imageio

    kernel = np.squeeze(np.asarray(imageio.imread(path)))
    if kernel.ndim == 3:
        if kernel.shape[-1] in (1, 2):  # gray or gray + alpha
            kernel = kernel[..., 0]
        else:  # RGB or RGBA: drop alpha before the gray conversion
            kernel = color.rgb2gray(kernel[..., :3])
    if kernel.ndim != 2:
        raise ValueError('expected a 2-D kernel image, got shape %s' % (kernel.shape,))

    resized = resize(kernel, (size, size), mode='constant')

    total = np.sum(resized)
    if not total > 0:
        raise ValueError('kernel image %r is empty (all zeros) after resizing' % (path,))
    kernel = resized / total

    assert np.abs(np.sum(kernel) - 1.0) < 0.01

    return kernel