In [86]:
import math
from typing import List, Optional, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F

%load_ext autoreload
%autoreload 2
import fold
import gpnn
import image as image_utils
import new_utils as utils
import resize_right

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [93]:
# Load the demo image and keep a 0.75x-scaled copy (`small_image`) that the
# retargeting and generation cells further down operate on.
image = utils.imread('balloons.png')
# Scale factors (1, 0.75, 0.75): presumably (channels, height, width) — TODO confirm layout.
small_image = resize_right.resize(image, (1, 0.75, 0.75))
image_utils.imshow(image)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xf8\x00\x00\x00\xba\x08\x02\x00\x00\x00\x0c#\x17…

In [5]:
def image_generation(
    image: torch.Tensor,
    noise_std: float = 0.75,
    alpha: float = float('inf'),
    patch_size: int = 7,
    reduce: str = 'weighted_mean',
    downscale_ratio: float = 0.75,
    num_levels: int = 9,
) -> torch.Tensor:
    """Run GPNN on ``image`` starting from a noise-perturbed coarse level.

    A pyramid of ``image`` is built with ``gpnn.make_pyramid``. Its last level
    (the coarsest one, judging by the shapes logged by the demo cells) plus
    standard-normal noise scaled by ``noise_std`` is used as the initial guess.

    Args:
        image: Input tensor the pyramid is built from.
        noise_std: Multiplier applied to the ``torch.randn_like`` noise.
        alpha: Forwarded unchanged to ``gpnn.gpnn``.
        patch_size: Forwarded unchanged to ``gpnn.gpnn``.
        reduce: Forwarded unchanged to ``gpnn.gpnn``.
        downscale_ratio: Used for the pyramid and forwarded to ``gpnn.gpnn``.
        num_levels: Passed to ``gpnn.make_pyramid``.

    Returns:
        Whatever ``gpnn.gpnn`` returns for the pyramid and noisy initial guess.
    """
    levels = gpnn.make_pyramid(image, num_levels, downscale_ratio)
    coarsest = levels[-1]
    # The noise is drawn from torch's global RNG; seed it outside for reproducibility.
    noisy_start = coarsest + noise_std * torch.randn_like(coarsest)
    gpnn_options = dict(
        alpha=alpha,
        downscale_ratio=downscale_ratio,
        patch_size=patch_size,
        reduce=reduce,
    )
    return gpnn.gpnn(levels, noisy_start, **gpnn_options)

In [143]:
def image_editing(
    source_image: torch.Tensor,
    edited_image: torch.Tensor,
    alpha: float = float('inf'),
    patch_size: int = 7,
    reduce: str = 'weighted_mean',
    downscale_ratio: float = 0.75,
    num_levels: int = 5,
) -> torch.Tensor:
    """Run GPNN on the source pyramid, initialised from the edited image.

    Both images are turned into pyramids with the same ``num_levels`` and
    ``downscale_ratio``. Only the last level of the edited pyramid is used:
    it is the initial guess handed to ``gpnn.gpnn`` together with the full
    source pyramid.

    Args:
        source_image: Image whose pyramid is passed to ``gpnn.gpnn``.
        edited_image: Image whose last pyramid level seeds the run.
        alpha: Forwarded unchanged to ``gpnn.gpnn``.
        patch_size: Forwarded unchanged to ``gpnn.gpnn``.
        reduce: Forwarded unchanged to ``gpnn.gpnn``.
        downscale_ratio: Used for both pyramids and forwarded to ``gpnn.gpnn``.
        num_levels: Passed to ``gpnn.make_pyramid`` for both images.

    Returns:
        Whatever ``gpnn.gpnn`` returns.
    """
    pyramid_args = (num_levels, downscale_ratio)
    source_levels = gpnn.make_pyramid(source_image, *pyramid_args)
    coarse_edit = gpnn.make_pyramid(edited_image, *pyramid_args)[-1]
    return gpnn.gpnn(
        source_levels,
        coarse_edit,
        alpha=alpha,
        downscale_ratio=downscale_ratio,
        patch_size=patch_size,
        reduce=reduce,
    )

In [145]:
# Image-editing demo: show the source and the edited image, then run
# `image_editing` with a 5-level pyramid and display the result.
source_image = utils.imread('stone.png')
edited_image = utils.imread('stone_edit.png')

image_utils.imshow(source_image)
image_utils.imshow(edited_image)
output = image_editing(source_image, edited_image, num_levels=5)

# Disabled leftovers from the structural-analogy experiment (they use
# `structural_analogy`, `birds_a` and `birds_b`, which are set up in later cells):
# analogy_ab = structural_analogy(birds_a, birds_b, num_levels=5, alpha=5e-3)
# analogy_ba = structural_analogy(birds_b, birds_a, num_levels=5, alpha=5e-5)

image_utils.imshow(output)
# image_utils.imshow(analogy_ba)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xfa\x00\x00\x00\xa8\x08\x02\x00\x00\x00F\xc8d\x9…

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xfa\x00\x00\x00\xa8\x08\x02\x00\x00\x00F\xc8d\x9…

initial guess shape: torch.Size([3, 40, 60])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00<\x00\x00\x00(\x08\x02\x00\x00\x00-\xd9\x0e\xa8\x…

level: 5
corasest iteration: 0
output shape: torch.Size([3, 40, 60])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00<\x00\x00\x00(\x08\x02\x00\x00\x00-\xd9\x0e\xa8\x…

level: 4
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 54, 80])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00P\x00\x00\x006\x08\x02\x00\x00\x00\xdf\xde\x89\xb…

level: 3
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 71, 106])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00j\x00\x00\x00G\x08\x02\x00\x00\x00`qE(\x00\x008\x…

level: 2
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 95, 141])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x8d\x00\x00\x00_\x08\x02\x00\x00\x00Z\xca\xd3\xd…

level: 1
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 126, 188])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xbc\x00\x00\x00~\x08\x02\x00\x00\x00\x11\x0b\xec…

level: 0
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 168, 250])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xfa\x00\x00\x00\xa8\x08\x02\x00\x00\x00F\xc8d\x9…

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xfa\x00\x00\x00\xa8\x08\x02\x00\x00\x00F\xc8d\x9…

In [128]:
def conditional_inpainting(
    masked_image: torch.Tensor,
    mask: torch.Tensor,
    alpha: float = float('inf'),
    patch_size: int = 7,
    reduce: str = 'weighted_mean',
    downscale_ratio: float = 0.75,
    num_levels: int = 5,
) -> torch.Tensor:
    """Run GPNN on ``masked_image`` with a matching pyramid of ``mask``.

    The mask is first converted with ``mask.to(masked_image)`` (same dtype and
    device as the image), then both tensors are turned into pyramids with the
    same settings. The last image level is the initial guess and the mask
    pyramid is handed over as ``mask_pyramid``.

    Args:
        masked_image: Image tensor the pyramid is built from.
        mask: Mask tensor. Its polarity is defined by ``gpnn.gpnn``; the demo
            cell passes ``1 - mask`` of the marked region — TODO confirm.
        alpha: Forwarded unchanged to ``gpnn.gpnn``.
        patch_size: Forwarded unchanged to ``gpnn.gpnn``.
        reduce: Forwarded unchanged to ``gpnn.gpnn``.
        downscale_ratio: Used for both pyramids and forwarded to ``gpnn.gpnn``.
        num_levels: Passed to ``gpnn.make_pyramid`` for image and mask.

    Returns:
        Whatever ``gpnn.gpnn`` returns.
    """
    pyramid_args = (num_levels, downscale_ratio)
    image_levels = gpnn.make_pyramid(masked_image, *pyramid_args)
    mask_levels = gpnn.make_pyramid(mask.to(masked_image), *pyramid_args)
    return gpnn.gpnn(
        image_levels,
        image_levels[-1],
        mask_pyramid=mask_levels,
        alpha=alpha,
        downscale_ratio=downscale_ratio,
        patch_size=patch_size,
        reduce=reduce,
    )


In [131]:
def structural_analogy(
    source_image: torch.Tensor,
    structure_image: torch.Tensor,
    alpha: float = 5e-3,
    patch_size: int = 7,
    reduce: str = 'weighted_mean',
    downscale_ratio: float = 0.75,
    num_levels: int = 5,
) -> torch.Tensor:
    """Run GPNN on the source pyramid, shaped and seeded by the structure image.

    Pyramids are built for both images with the same settings. The structure
    pyramid contributes two things: the per-level shapes (passed as
    ``output_pyramid_shape``) and its last level (the initial guess). The
    source pyramid is what ``gpnn.gpnn`` receives as its pyramid argument.

    Args:
        source_image: Image whose pyramid is passed to ``gpnn.gpnn``.
        structure_image: Image providing output shapes and the initial guess.
        alpha: Forwarded unchanged to ``gpnn.gpnn``.
        patch_size: Forwarded unchanged to ``gpnn.gpnn``.
        reduce: Forwarded unchanged to ``gpnn.gpnn``.
        downscale_ratio: Used for both pyramids and forwarded to ``gpnn.gpnn``.
        num_levels: Passed to ``gpnn.make_pyramid`` for both images.

    Returns:
        Whatever ``gpnn.gpnn`` returns.
    """
    source_levels = gpnn.make_pyramid(source_image, num_levels, downscale_ratio)
    structure_levels = gpnn.make_pyramid(structure_image, num_levels, downscale_ratio)
    level_shapes = [level.shape for level in structure_levels]
    return gpnn.gpnn(
        source_levels,
        structure_levels[-1],
        output_pyramid_shape=level_shapes,
        alpha=alpha,
        downscale_ratio=downscale_ratio,
        patch_size=patch_size,
        reduce=reduce,
    )

In [141]:
# Structural-analogy demo: `birds_a` is passed as the source image and `birds_b`
# as the structure image (it provides the output shapes and the initial guess).
# NOTE(review): the recorded run of this cell ended with KeyboardInterrupt,
# so no final result is stored in the notebook.
birds_a = utils.imread('snow_real_a.jpeg')
birds_b = utils.imread('snow_real_b.jpeg')

image_utils.imshow(birds_a)
image_utils.imshow(birds_b)

analogy_ab = structural_analogy(birds_a, birds_b, num_levels=5, alpha=5e-3)
# Reverse direction, currently disabled:
# analogy_ba = structural_analogy(birds_b, birds_a, num_levels=5, alpha=5e-5)

image_utils.imshow(analogy_ab)
# image_utils.imshow(analogy_ba)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xdc\x00\x00\x00\xdc\x08\x02\x00\x00\x00\x948X\xd…

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xdc\x00\x00\x00\xdc\x08\x02\x00\x00\x00\x948X\xd…

initial guess shape: torch.Size([3, 53, 53])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x005\x00\x00\x005\x08\x02\x00\x00\x00n\x844\'\x00\x0…

level: 5
corasest iteration: 0
output shape: torch.Size([3, 53, 53])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x005\x00\x00\x005\x08\x02\x00\x00\x00n\x844\'\x00\x0…

level: 4
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 70, 70])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00F\x00\x00\x00F\x08\x02\x00\x00\x00\xfeLu\xd3\x00\…

level: 3
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 93, 93])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00]\x00\x00\x00]\x08\x02\x00\x00\x00H\x13\xfd\x94\x…

level: 2
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 124, 124])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00|\x00\x00\x00|\x08\x02\x00\x00\x00$|C\xe9\x00\x00…

level: 1
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5


KeyboardInterrupt: 

In [140]:
# Conditional-inpainting demo. Pixels whose value is exactly (0, 1.0, 1.0) are
# selected as the mask; presumably they mark the region to be filled — TODO confirm.
# NOTE(review): the recorded run of this cell ended with KeyboardInterrupt.
balloons_in_blue = utils.imread('balloons_in_blue.png')
color = torch.tensor([0, 1.0, 1.0]).reshape(3, 1, 1)
# A pixel is selected only when all three channels match `color` (match count == 3).
# The single-channel result is cast to the image's dtype/device and repeated to 3 channels.
mask = ((balloons_in_blue == color).to(torch.int).sum(dim=0, keepdim=True) == 3).to(balloons_in_blue).repeat((3, 1, 1))
image_utils.imshow(balloons_in_blue)
image_utils.imshow(mask)

# Only the first mask channel is passed, inverted: 1 where the pixel is NOT the marker colour.
inpainted = conditional_inpainting(
    balloons_in_blue,
    1 - mask[:1],
    num_levels=5,
)
image_utils.imshow(inpainted)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xdc\x00\x00\x00\xdc\x08\x02\x00\x00\x00\x948X\xd…

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xdc\x00\x00\x00\xdc\x08\x02\x00\x00\x00\x948X\xd…

initial guess shape: torch.Size([3, 53, 53])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x005\x00\x00\x005\x08\x02\x00\x00\x00n\x844\'\x00\x0…

level: 5
corasest iteration: 0
output shape: torch.Size([3, 53, 53])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x005\x00\x00\x005\x08\x02\x00\x00\x00n\x844\'\x00\x0…

level: 4
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 70, 70])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00F\x00\x00\x00F\x08\x02\x00\x00\x00\xfeLu\xd3\x00\…

level: 3
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
output shape: torch.Size([3, 93, 93])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00]\x00\x00\x00]\x08\x02\x00\x00\x00H\x13\xfd\x94\x…

level: 2
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6


KeyboardInterrupt: 

In [114]:
# Scratch check: shape of the first mask channel with a leading dimension added.
mask[:1].unsqueeze(0).shape

torch.Size([1, 1, 220, 220])

In [94]:
def _get_retargeting_step_size(
    retargeting_ratio: Tuple[float, float],
    gradual: bool,
) -> Tuple[int, Tuple[float, float]]:
    """Return ``(num_steps, step_size)`` for a retargeting run.

    Only the single-step mode is implemented: the whole ``retargeting_ratio``
    is applied in one step.

    Args:
        retargeting_ratio: Per-axis resize ratio; any iterable is accepted and
            converted to a tuple.
        gradual: Must be falsy; a multi-step schedule is not supported.

    Returns:
        ``(1, tuple(retargeting_ratio))``.

    Raises:
        ValueError: If ``gradual`` is truthy.
    """
    if not gradual:
        return 1, tuple(retargeting_ratio)

    # Unfinished sketch of the gradual schedule, kept for reference only:
    #   step_size = [0.9, 0.9]
    #   if retargeting_ratio[0] >= 1:
    #       step_size[0] = 1.1
    #   if retargeting_ratio[1] >= 1:
    #       step_size[0] = 1.1
    #   num_steps = math.floor(max(math.log(retargeting_ratio[0])/math.log(step_size[0]),
    #                              math.log(retargeting_ratio[1])/math.log(step_size[1])))
    #   step_size[0] = 10**(math.log10(retargeting_ratio[0])/num_steps)
    #   step_size[1] = 10**(math.log10(retargeting_ratio[1])/num_steps)
    raise ValueError('gradual=True is not supported.')

def _get_num_levels_1d(current_size: int, min_axis_size_coarsest: int, downscale_ratio: float) -> int:
    """Count how many downscale steps one axis can take before dropping below the minimum.

    For ``0 < downscale_ratio < 1`` this is the largest integer ``n`` with
    ``current_size * downscale_ratio ** n >= min_axis_size_coarsest``
    (up to floating-point rounding of the logarithms).

    Args:
        current_size: Axis length at the finest level; must be positive.
        min_axis_size_coarsest: Smallest axis length allowed at the coarsest level.
        downscale_ratio: Per-level scale factor.

    Returns:
        The number of levels, never negative. If the axis is already smaller
        than ``min_axis_size_coarsest`` the result is 0.
    """
    num_levels = math.floor(math.log(min_axis_size_coarsest / current_size) / math.log(downscale_ratio))
    # The raw formula goes negative when current_size < min_axis_size_coarsest.
    # A negative count is meaningless and would make slices such as
    # ``pyramid[:num_levels + 1]`` silently pick the wrong levels, so clamp to 0.
    return max(num_levels, 0)

def _get_num_levels(current_size, min_axis_size_coarsest, downscale_ratio):
    """Return the smallest per-axis level count over all axes in ``current_size``.

    Each axis length is passed to ``_get_num_levels_1d`` with the same
    ``min_axis_size_coarsest`` and ``downscale_ratio``; the most restrictive
    axis decides the result.

    Args:
        current_size: Iterable of axis lengths (e.g. ``tensor.shape[-2:]``).
        min_axis_size_coarsest: Forwarded to ``_get_num_levels_1d``.
        downscale_ratio: Forwarded to ``_get_num_levels_1d``.

    Raises:
        ValueError: If ``current_size`` is empty (raised by ``min``).
    """
    per_axis_levels = [
        _get_num_levels_1d(axis_size, min_axis_size_coarsest, downscale_ratio)
        for axis_size in current_size
    ]
    return min(per_axis_levels)

def retargeting(
    image: torch.Tensor,
    retargeting_ratio: Tuple[float, float],
    alpha: float = 1e-3,
    patch_size: int = 7,
    reduce: str = 'weighted_mean',
    downscale_ratio: float = 0.8,
    max_num_levels: int = 9,
    gradual: bool = False,
    min_axis_size_coarsest: int = 21,
):
    """Retarget ``image`` by ``retargeting_ratio`` using GPNN.

    For every step returned by ``_get_retargeting_step_size`` the current
    result is resized with ``resize_right.resize`` and a pyramid of that
    resized image is built. ``gpnn.gpnn`` is then run on the matching levels
    of the source pyramid, with the last level of the resized pyramid as the
    initial guess and its per-level shapes as ``output_pyramid_shape``.

    Side effects: each step displays the resized image via
    ``image_utils.imshow`` and prints the target and source pyramid shapes.

    Args:
        image: Source tensor; its last two dimensions are treated as the
            spatial axes when computing the number of levels.
        retargeting_ratio: Scale factors handed to ``resize_right.resize``.
        alpha: Forwarded unchanged to ``gpnn.gpnn``.
        patch_size: Forwarded unchanged to ``gpnn.gpnn``.
        reduce: Forwarded unchanged to ``gpnn.gpnn``.
        downscale_ratio: Used for all pyramids and forwarded to ``gpnn.gpnn``.
        max_num_levels: Level count of the source pyramid, and the upper bound
            on the per-step level count.
        gradual: Forwarded to ``_get_retargeting_step_size``.
        min_axis_size_coarsest: Lower bound on the coarsest axis size, used to
            derive the per-step level count.

    Returns:
        The result of the last ``gpnn.gpnn`` call.

    Raises:
        RuntimeError: If a step needs more levels than ``max_num_levels``.
    """
    source_pyramid = gpnn.make_pyramid(image, max_num_levels, downscale_ratio)
    num_steps, step_size = _get_retargeting_step_size(retargeting_ratio, gradual)

    result = image
    for _step in range(num_steps):
        resized = resize_right.resize(result, step_size)
        image_utils.imshow(resized)

        num_levels = _get_num_levels(resized.shape[-2:], min_axis_size_coarsest, downscale_ratio)
        if num_levels > max_num_levels:
            raise RuntimeError('max_num_levels is smaller than one of the requested num_levels. '
                               'Please increase max_num_levels or min_axis_size_coarsest.')

        target_pyramid = gpnn.make_pyramid(resized, num_levels, downscale_ratio)
        target_shapes = [level.shape for level in target_pyramid]
        # Use only as many source levels as the target pyramid has.
        source_levels = source_pyramid[:num_levels + 1]
        print(target_shapes)
        print([level.shape for level in source_levels])

        result = gpnn.gpnn(
            source_levels,
            target_pyramid[-1],
            output_pyramid_shape=target_shapes,
            alpha=alpha,
            downscale_ratio=downscale_ratio,
            patch_size=patch_size,
            reduce=reduce,
            num_iters_in_coarsest_level=10,
        )
    return result

In [97]:
# Retargeting demo on the 0.75x balloons image, with a leading dimension added
# via unsqueeze(0). In the recorded output the ratio (0.7, 1.0) shrinks the
# second-to-last axis (140 -> 98) and leaves the last one unchanged (186).
out = retargeting(small_image.unsqueeze(0),
                  alpha=1e-2,
                  retargeting_ratio=(0.7, 1.0),
                  downscale_ratio=0.75,
                  max_num_levels=8,
                  min_axis_size_coarsest=28)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xba\x00\x00\x00b\x08\x02\x00\x00\x00h\x01^\xe4\x…

[torch.Size([1, 3, 98, 186]), torch.Size([1, 3, 74, 140]), torch.Size([1, 3, 56, 105]), torch.Size([1, 3, 42, 79]), torch.Size([1, 3, 32, 59])]
[torch.Size([1, 3, 140, 186]), torch.Size([1, 3, 105, 140]), torch.Size([1, 3, 79, 105]), torch.Size([1, 3, 60, 79]), torch.Size([1, 3, 45, 59])]
torch.Size([1, 3, 32, 59])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00;\x00\x00\x00 \x08\x02\x00\x00\x00#V\x97\xbc\x00\…

level: 4
corasest iteration: 0
corasest iteration: 1
corasest iteration: 2
corasest iteration: 3
corasest iteration: 4
corasest iteration: 5
corasest iteration: 6
corasest iteration: 7
corasest iteration: 8
corasest iteration: 9
torch.Size([1, 3, 32, 59])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00;\x00\x00\x00 \x08\x02\x00\x00\x00#V\x97\xbc\x00\…

level: 3
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 42, 79])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00O\x00\x00\x00*\x08\x02\x00\x00\x00}o\x918\x00\x00…

level: 2
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 56, 105])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00i\x00\x00\x008\x08\x02\x00\x00\x00q2A?\x00\x00%VI…

level: 1
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 74, 140])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x8c\x00\x00\x00J\x08\x02\x00\x00\x00\xe6\x13+\xc…

level: 0
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 98, 186])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xba\x00\x00\x00b\x08\x02\x00\x00\x00h\x01^\xe4\x…

In [None]:
# Direct GPNN call seeded with uniform noise shaped like the last pyramid level.
# NOTE(review): `pyramid` is not assigned at notebook scope in any visible cell
# (only as a local inside the functions above) and this cell has no execution
# count. It presumably relies on state from an earlier session; build the
# pyramid explicitly before running.
out = gpnn.gpnn(pyramid, torch.rand_like(pyramid[-1]), patch_size=3)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xf8\x00\x00\x00\xba\x08\x02\x00\x00\x00\x0c#\x17…

In [92]:
# Generation demo on the 0.75x balloons image (leading dimension added via unsqueeze(0)).
# NOTE: image_generation draws its noise with torch.randn_like and no seed is set
# in the visible cells, so the result differs between runs.
# NOTE(review): the recorded run of this cell ended with KeyboardInterrupt.
out = image_generation(
    small_image.unsqueeze(0),
    num_levels=8, 
    alpha=10,
    noise_std=0.75,
)

torch.Size([1, 3, 15, 19])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x13\x00\x00\x00\x0f\x08\x02\x00\x00\x00\x89&c{\x…

level: 8
corasest iteration: 0
torch.Size([1, 3, 15, 19])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x13\x00\x00\x00\x0f\x08\x02\x00\x00\x00\x89&c{\x…

level: 7
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 19, 25])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x19\x00\x00\x00\x13\x08\x02\x00\x00\x00\xea\x101…

level: 6
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 25, 34])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00"\x00\x00\x00\x19\x08\x02\x00\x00\x00\xdb\x98k\xf…

level: 5
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 34, 45])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00-\x00\x00\x00"\x08\x02\x00\x00\x00D.\xc7\x18\x00\…

level: 4
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 45, 59])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00;\x00\x00\x00-\x08\x02\x00\x00\x00\x9f\xc8\x84b\x…

level: 3
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 60, 79])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00O\x00\x00\x00<\x08\x02\x00\x00\x00\xa8\xe0p\xbe\x…

level: 2
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 79, 105])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00i\x00\x00\x00O\x08\x02\x00\x00\x00g\x15|F\x00\x00…

level: 1
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6


KeyboardInterrupt: 

In [24]:
# Show the downscaled input next to the most recent result stored in `out`.
# `imshow` is reached through `image_utils`, as in every other cell: the imports
# cell does not bring a bare `imshow` into scope, so the unqualified call would
# raise NameError on a fresh kernel.
image_utils.imshow(small_image)
image_utils.imshow(out)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xba\x00\x00\x00\x8c\x08\x02\x00\x00\x00DO$\x16\x…

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xba\x00\x00\x00\x8c\x08\x02\x00\x00\x00DO$\x16\x…

In [33]:
# NOTE(review): the recorded output of this cell is
# `TypeError: 'int' object is not iterable`. Its execution count (33) is lower
# than that of the cell defining `retargeting` (94), so the error may come from
# an older definition — re-run on a fresh kernel to confirm. Unlike the working
# retargeting call earlier in the notebook, `small_image` is passed here
# without `.unsqueeze(0)`.
out = retargeting(small_image, retargeting_ratio=(0.75, 0.75))

TypeError: 'int' object is not iterable