In [27]:
import math
from typing import List, Optional, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F

%load_ext autoreload
%autoreload 2
import fold
import gpnn
import image as image_utils
import new_utils as utils
import resize_right

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
query = torch.rand(1, 3, 10, 11)
key = torch.rand(1, 3, 12, 13)
value = torch.rand(1, 3, 12, 13)

In [3]:
out = gpnn.pnn(query, key, value, patch_size=3)
assert out.shape == query.shape

In [4]:
pyramid = gpnn.make_pyramid(query, 2, 0.75)
initial_guess = torch.rand_like(pyramid[-1])
out = gpnn.gpnn(pyramid, initial_guess, patch_size=3)
assert out.shape == pyramid[0].shape

level: 2
level: 1
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
level: 0
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9


In [5]:
def image_generation(image: torch.Tensor, noise_std: float = 0.75, alpha: float = 10.0,
                     patch_size: int = 7, reduce: str = 'weighted_mean', 
                     downscale_ratio: float = 0.75, num_levels: int = 9):
    pyramid = gpnn.make_pyramid(image, num_levels, downscale_ratio)
    initial_guess = pyramid[-1] + noise_std * torch.randn_like(pyramid[-1])
    return gpnn.gpnn(pyramid, initial_guess, alpha=alpha, downscale_ratio=downscale_ratio,
                     patch_size=patch_size, reduce=reduce)

In [45]:
def _get_retargeting_step_size(retargeting_ratio: Tuple[float, float],
                               gradual: bool) -> Tuple[int, Tuple[float, float]]:
    if gradual:
        raise ValueError('gradual=True is not supported.')
        # step_size = [0.9, 0.9]
        # if retargeting_ratio[0] >= 1:
        #     step_size[0] = 1.1
        # if retargeting_ratio[1] >= 1:
        #     step_size[0] = 1.1
        # num_steps = math.floor(max(math.log(retargeting_ratio[0])/math.log(step_size[0]), 
        #                         math.log(retargeting_ratio[1])/math.log(step_size[1])))
        # step_size[0] = 10**(math.log10(retargeting_ratio[0])/num_steps)
        # step_size[1] = 10**(math.log10(retargeting_ratio[1])/num_steps)
    else:
        num_steps = 1
        step_size = retargeting_ratio
    return num_steps, tuple(step_size)

def _get_num_levels_1d(current_size, min_axis_size_coarsest, downscale_ratio):
    return math.floor(math.log(min_axis_size_coarsest / current_size) / math.log(downscale_ratio))

def _get_num_levels(current_size, min_axis_size_coarsest, downscale_ratio):
    return min(
        _get_num_levels_1d(current_size_1d, min_axis_size_coarsest, downscale_ratio)
        for current_size_1d in current_size
    )

def retargeting(image: torch.Tensor, retargeting_ratio: Tuple[float, float],
                alpha: float = 1e-3,  patch_size: int = 7, reduce: str = 'weighted_mean',
                downscale_ratio: float = 0.8, max_num_levels: int = 9, gradual: bool = False,
                min_axis_size_coarsest: int = 21):
    pyramid = gpnn.make_pyramid(image, max_num_levels, downscale_ratio)
    num_steps, step_size = _get_retargeting_step_size(retargeting_ratio, gradual)
    generated = image
    for _ in range(num_steps):
        retargeted_generated = resize_right.resize(generated, step_size)
        current_num_levels = _get_num_levels(retargeted_generated.shape[-2:], min_axis_size_coarsest, downscale_ratio)
        if current_num_levels > max_num_levels:
            raise RuntimeError('max_num_levels is smaller than one of the requested num_levels. '
                               'Please increase max_num_levels or min_axis_size_coarsest.')
        retargeted_pyramid = gpnn.make_pyramid(retargeted_generated, current_num_levels, downscale_ratio)
        retargeted_pyramid_shape = [level.shape for level in retargeted_pyramid]
        initial_quess = retargeted_pyramid[-1]
        generated = gpnn.gpnn(pyramid[:current_num_levels], initial_quess, # output_pyramid_shape=retargeted_pyramid_shape,
                              alpha=alpha, downscale_ratio=downscale_ratio, patch_size=patch_size,
                              reduce=reduce, num_iters_in_coarsest_level=10)
    return generated

In [46]:
out = retargeting(small_image.unsqueeze(0), alpha=100, retargeting_ratio=(0.75, 0.75), min_axis_size_coarsest=28)

torch.Size([1, 3, 35, 46])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00.\x00\x00\x00#\x08\x02\x00\x00\x00dE\xaf\xbe\x00\…

level: 4
corasest iteration: 0
corasest iteration: 1
corasest iteration: 2
corasest iteration: 3
corasest iteration: 4
corasest iteration: 5
corasest iteration: 6
corasest iteration: 7
corasest iteration: 8
corasest iteration: 9
torch.Size([1, 3, 35, 46])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00.\x00\x00\x00#\x08\x02\x00\x00\x00dE\xaf\xbe\x00\…

level: 3
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 72, 96])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00`\x00\x00\x00H\x08\x02\x00\x00\x00\x86\x05g4\x00\…

level: 2
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 90, 120])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00x\x00\x00\x00Z\x08\x02\x00\x00\x00\xfcb\x05\xb8\x…

level: 1
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6


KeyboardInterrupt: 

In [None]:
out = gpnn.gpnn(pyramid, torch.rand_like(pyramid[-1]), patch_size=3)

In [6]:
image = utils.imread('balloons.png')

In [14]:
image_utils.imshow(image)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xf8\x00\x00\x00\xba\x08\x02\x00\x00\x00\x0c#\x17…

In [10]:
small_image = resize_right.resize(image, (1, 0.75, 0.75))

In [25]:
out = image_generation(small_image.unsqueeze(0), num_levels=8, alpha=10)

torch.Size([1, 3, 15, 19])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x13\x00\x00\x00\x0f\x08\x02\x00\x00\x00\x89&c{\x…

level: 8
corasest iteration: 0
torch.Size([1, 3, 15, 19])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x13\x00\x00\x00\x0f\x08\x02\x00\x00\x00\x89&c{\x…

level: 7
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 19, 25])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x19\x00\x00\x00\x13\x08\x02\x00\x00\x00\xea\x101…

level: 6
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 25, 34])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00"\x00\x00\x00\x19\x08\x02\x00\x00\x00\xdb\x98k\xf…

level: 5
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 34, 45])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00-\x00\x00\x00"\x08\x02\x00\x00\x00D.\xc7\x18\x00\…

level: 4
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 45, 59])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00;\x00\x00\x00-\x08\x02\x00\x00\x00\x9f\xc8\x84b\x…

level: 3
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 60, 79])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00O\x00\x00\x00<\x08\x02\x00\x00\x00\xa8\xe0p\xbe\x…

level: 2
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 79, 105])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00i\x00\x00\x00O\x08\x02\x00\x00\x00g\x15|F\x00\x00…

level: 1
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 105, 140])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x8c\x00\x00\x00i\x08\x02\x00\x00\x00g+\\X\x00\x0…

level: 0
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
torch.Size([1, 3, 140, 186])


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xba\x00\x00\x00\x8c\x08\x02\x00\x00\x00DO$\x16\x…

In [24]:
imshow(small_image)
imshow(out)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xba\x00\x00\x00\x8c\x08\x02\x00\x00\x00DO$\x16\x…

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xba\x00\x00\x00\x8c\x08\x02\x00\x00\x00DO$\x16\x…

In [33]:
out = retargeting(small_image, retargeting_ratio=(0.75, 0.75))

TypeError: 'int' object is not iterable