<a href="https://colab.research.google.com/github/MartonJToth/DIP_Labor/blob/master/DIP_Labor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#Based on https://github.com/DmitryUlyanov/deep-image-prior/blob/master/inpainting.ipynb
# Install the Python dependencies used below (torch/torchvision, matplotlib, Pillow).
# NOTE(review): the Pillow==4.2.1 pin looks like a legacy Colab workaround; such an old
# pin may not install on current Python versions — confirm whether it is still needed.
# NOTE(review): `%pip install` (instead of `!pip`) installs into the running kernel's environment.

!pip3 install torch torchvision 
!pip install matplotlib
!pip3 install Pillow==4.2.1

In [0]:
# Environment check: print the CUDA compiler version and whether PyTorch can see a GPU.
!nvcc --version
import torch
print(torch.cuda.is_available())

In [0]:
!wget http://cg.iit.bme.hu/~tmarton/deeplearning/DIPLabData.zip
!unzip -qq DIPLabData.zip
!rm DIPLabData.zip

In [0]:
imsize = -1       # -1: keep the original image size (no resizing in get_image)
dim_div_by = 64   # images are center-cropped so H and W are multiples of this

img_path  = 'DIPLabData/kate.png'
mask_path = 'DIPLabData/kate_mask.png'

# Tensor type for the network and data. Falls back to CPU tensors when no GPU is
# available, so the notebook still runs (slowly) instead of failing on `.type(dtype)`.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [0]:
from PIL import Image
import PIL
import numpy as np
import torchvision
import matplotlib.pyplot as plt
from IPython import display


def get_noise(input_depth, spatial_size, noise_type='u', var=1./10):
    """Creates a random input tensor for the network.

    Args:
        input_depth: number of channels C of the noise tensor.
        spatial_size: an int (square size) or an (H, W) pair.
        noise_type: 'u' for uniform noise in [0, 1), 'n' for standard normal
            noise; either way the result is then multiplied by `var`.
        var: scale factor applied to the noise (a multiplier, not a statistical
            variance).

    Returns:
        Float tensor of shape (1, C, H, W), created with torch.zeros defaults.

    Raises:
        ValueError: if `noise_type` is not 'u' or 'n'.
    """
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    shape = [1, input_depth, spatial_size[0], spatial_size[1]]
    net_input = torch.zeros(shape)

    # Honour `noise_type` instead of always drawing uniform noise.
    if noise_type == 'u':
        net_input.uniform_()
    elif noise_type == 'n':
        net_input.normal_()
    else:
        raise ValueError("noise_type must be 'u' or 'n', got %r" % (noise_type,))
    net_input *= var

    return net_input

def pil_to_np(img_PIL):
    """Converts a PIL image to a float32 numpy array in channel-first layout.

    `np.array(img)` on a PIL image is H x W x C (or H x W for single-band
    images); the result is C x H x W. Single-band images get a leading channel
    axis of size 1. Integer values are divided by 255 (8-bit [0, 255] -> [0, 1]);
    boolean arrays (e.g. 1-bit masks) map directly to 0.0 / 1.0.

    Args:
        img_PIL: PIL image (anything accepted by `np.array` works).

    Returns:
        np.ndarray of dtype float32 and shape (C, H, W).
    """
    ar = np.array(img_PIL)

    if ar.ndim == 3:
        ar = ar.transpose(2, 0, 1)   # H x W x C -> C x H x W
    else:
        ar = ar[None, ...]           # H x W -> 1 x H x W

    if ar.dtype == np.bool_:
        # True already means "full intensity"; dividing by 255 would give 1/255.
        return ar.astype(np.float32)

    return ar.astype(np.float32) / 255.

def get_image(path, imsize=-1):
    """Loads an image from disk and optionally resizes it.

    Args:
        path: image file path.
        imsize: -1 keeps the original size; an int requests a square size; a
            (width, height) pair requests an explicit size (PIL order).

    Returns:
        (img, img_np): the (possibly resized) PIL image and its `pil_to_np`
        conversion.
    """
    img = Image.open(path)

    if isinstance(imsize, int):
        imsize = (imsize, imsize)
    # PIL's `.size` is a tuple; normalise so a list argument compares correctly.
    imsize = tuple(imsize)

    if imsize[0] != -1 and img.size != imsize:
        # Bicubic when enlarging, Lanczos when shrinking. Image.LANCZOS is the filter
        # the former Image.ANTIALIAS alias pointed to (that alias is gone in Pillow 10).
        if imsize[0] > img.size[0]:
            img = img.resize(imsize, Image.BICUBIC)
        else:
            img = img.resize(imsize, Image.LANCZOS)

    img_np = pil_to_np(img)

    return img, img_np
  
  
def crop_image(img, d=32):
    """Center-crops `img` so that its width and height are both multiples of `d`.

    Args:
        img: PIL image (anything with `.size` as (width, height) and `.crop(box)`).
        d: required divisor of both output dimensions.

    Returns:
        Whatever `img.crop` returns for the computed box.
    """
    width, height = img.size
    target_w = width - width % d
    target_h = height - height % d

    # The trimmed border is split between both sides; an odd extra pixel goes right/bottom.
    left = (width - target_w) // 2
    upper = (height - target_h) // 2
    right = (width + target_w) // 2
    lower = (height + target_h) // 2

    return img.crop([left, upper, right, lower])
  
def plot_image_grid(images_np, nrow=8, factor=1, interpolation='lanczos'):
    """Shows a list of C x H x W numpy images as a single grid figure.

    All images must end up with 1 or 3 channels; when any image is RGB, the
    single-channel ones are replicated to 3 channels.

    Args:
        images_np: list of channel-first numpy images.
        nrow: images per grid row (passed to torchvision's make_grid).
        factor: added to both figure dimensions.
        interpolation: matplotlib imshow interpolation.

    Returns:
        The grid as a channel-first numpy array.
    """
    n_channels = max(image.shape[0] for image in images_np)
    assert n_channels in (1, 3), "images should have 1 or 3 channels"

    def _match_channels(image):
        # Repeat a grayscale image three times so it can sit next to RGB ones.
        if image.shape[0] == n_channels:
            return image
        return np.concatenate([image] * 3, axis=0)

    images_np = list(map(_match_channels, images_np))

    tensors = [torch.from_numpy(image) for image in images_np]
    grid = torchvision.utils.make_grid(tensors, nrow).numpy()

    plt.figure(figsize=(len(images_np) + factor, 12 + factor))
    is_gray = images_np[0].shape[0] == 1
    if is_gray:
        plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
    else:
        plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
    plt.show()

    return grid

In [0]:
# Load the inpainting mask and the photo, each as (PIL image, float C x H x W array).
img_mask_pil, img_mask_np = get_image(path=mask_path, imsize=imsize)
img_pil, img_np = get_image(path=img_path, imsize=imsize)

In [0]:
# Center-crop photo and mask so H and W are divisible by `dim_div_by`, then
# rebuild the numpy copies from the cropped PIL images.
img_pil = crop_image(img_pil, d=dim_div_by)
img_mask_pil = crop_image(img_mask_pil, d=dim_div_by)

img_np = pil_to_np(img_pil)
img_mask_np = pil_to_np(img_mask_pil)

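**TODO:** saját maszk készítése, meddig képes a részleteket összerakni, vagy zaj hozzákeverése. *(EN: create a custom mask and test how much missing detail the network can reconstruct, or try mixing noise into the image.)*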

In [0]:
# Side by side: the photo, the mask, and the photo multiplied by the mask.
plot_image_grid([img_np, img_mask_np, img_np * img_mask_np], nrow=3, factor=11);

In [0]:
import torch
import torch.nn as nn

def add_module(self, module):
    """Appends `module` to a container under the next 1-based numeric name.

    Installed as `nn.Module.add` below so containers can be built with
    `container.add(layer)`. The container must support `len()`.
    """
    next_name = '%d' % (len(self) + 1)
    self.add_module(next_name, module)

setattr(torch.nn.Module, 'add', add_module)

def bn(num_features):
    """Returns a BatchNorm2d layer over `num_features` channels."""
    layer = nn.BatchNorm2d(num_features)
    return layer
  
def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
    """Builds a [padding ->] Conv2d [-> downsampler] block.

    Args:
        in_f, out_f: input / output channels.
        kernel_size: square kernel size; padding (kernel_size - 1) // 2 keeps the
            spatial size for odd kernels at stride 1.
        stride: downsampling factor.
        bias: whether the convolution has a bias term.
        pad: 'zero' (padding inside Conv2d) or 'reflection' (explicit ReflectionPad2d).
        downsample_mode: how a stride != 1 is realised. 'stride' uses a strided conv.
            'avg' / 'max' use a stride-1 conv followed by pooling. 'lanczos2' /
            'lanczos3' use a stride-1 conv followed by `Downsampler`.
            NOTE(review): `Downsampler` is not defined in the visible code; confirm
            it exists before using the lanczos modes.

    Returns:
        nn.Sequential of the non-None layers in the order [padder, conv, downsampler].

    Raises:
        ValueError: for an unknown `pad`, or an unknown `downsample_mode` when stride != 1.
    """
    if pad not in ('zero', 'reflection'):
        raise ValueError("pad must be 'zero' or 'reflection', got %r" % (pad,))

    downsampler = None
    if stride != 1 and downsample_mode != 'stride':

        if downsample_mode == 'avg':
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ['lanczos2', 'lanczos3']:
            downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
        else:
            # Raise with a message instead of `assert False` (stripped under `python -O`).
            raise ValueError("unknown downsample_mode %r" % (downsample_mode,))

        # The separate downsampler reduces resolution, so the conv itself stays at stride 1.
        stride = 1

    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = [layer for layer in (padder, convolver, downsampler) if layer is not None]
    return nn.Sequential(*layers)
  
  
def act(act_fun='LeakyReLU'):
    """Returns an activation module.

    Args:
        act_fun: 'LeakyReLU' (negative slope 0.2, in-place), 'Swish', 'ELU',
            'none' (an empty nn.Sequential, i.e. identity), or a module
            class/factory that is called with no arguments (e.g. nn.ReLU).

    Raises:
        ValueError: for an unknown activation name.
    """
    if not isinstance(act_fun, str):
        return act_fun()

    if act_fun == 'LeakyReLU':
        return nn.LeakyReLU(0.2, inplace=True)
    if act_fun == 'Swish':
        # Swish(x) = x * sigmoid(x) is provided by PyTorch as nn.SiLU (PyTorch >= 1.7);
        # no `Swish` class is defined in the visible code.
        return nn.SiLU()
    if act_fun == 'ELU':
        return nn.ELU()
    if act_fun == 'none':
        return nn.Sequential()
    raise ValueError("unknown activation %r; expected 'LeakyReLU', 'Swish', 'ELU' or 'none'" % (act_fun,))

class Concat(nn.Module):
    """Runs every child module on the same input and concatenates the outputs along `dim`.

    If the children produce different spatial sizes (dims 2 and 3), each output is
    center-cropped to the smallest height and width before concatenation.
    """

    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim

        # Register the children as submodules named '0', '1', ...
        for position, child in enumerate(args):
            self.add_module(str(position), child)

    def forward(self, input):
        outputs = [child(input) for child in self._modules.values()]

        heights = [out.shape[2] for out in outputs]
        widths = [out.shape[3] for out in outputs]
        min_h = min(heights)
        min_w = min(widths)

        same_size = all(h == min_h for h in heights) and all(w == min_w for w in widths)
        if not same_size:
            cropped = []
            for out in outputs:
                top = (out.size(2) - min_h) // 2
                left = (out.size(3) - min_w) // 2
                cropped.append(out[:, :, top:top + min_h, left:left + min_w])
            outputs = cropped

        return torch.cat(outputs, dim=self.dim)

    def __len__(self):
        return len(self._modules)

def skip(
        num_input_channels=2, num_output_channels=3,
        num_channels_down=(16, 32, 64, 128, 128), num_channels_up=(16, 32, 64, 128, 128), num_channels_skip=(4, 4, 4, 4, 4),
        filter_size_down=3, filter_size_up=3, filter_skip_size=1,
        need_sigmoid=True, need_bias=True,
        pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU',
        need1x1_up=True):
    """Assembles an encoder-decoder with skip connections.

    The net is built one scale at a time. At scale i, the current container gets
    `Concat(skip_i, deeper_i)` (or just `deeper_i` when the skip width is 0).
    That is followed by BN -> conv -> BN -> act, plus an optional 1x1 conv block.
    `deeper_i` downsamples by 2, applies conv blocks, contains the next scale
    (except at the deepest scale) and upsamples by 2.

    Arguments:
        num_input_channels: channels of the network input.
        num_output_channels: channels of the network output.
        num_channels_down / num_channels_up / num_channels_skip: per-scale widths;
            all three must have the same length (the number of scales).
        filter_size_down / filter_size_up: kernel size, an int or a per-scale list/tuple.
        filter_skip_size: kernel size of the skip-branch convolution.
        need_sigmoid: append an nn.Sigmoid to the output.
        need_bias: use a bias in the convolutions.
        act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
        pad (string): zero|reflection (default: 'zero')
        upsample_mode (string): 'nearest|bilinear' (default: 'nearest'), or a per-scale list/tuple
        downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride'), or a per-scale list/tuple
        need1x1_up: add an extra 1x1 conv block after each decoder conv.

    Returns:
        nn.Sequential: the assembled network.

    Raises:
        ValueError: if the three per-scale channel sequences differ in length.
    """
    if not (len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)):
        raise ValueError(
            'num_channels_down, num_channels_up and num_channels_skip must have the same length, '
            'got %d, %d and %d' % (len(num_channels_down), len(num_channels_up), len(num_channels_skip)))

    n_scales = len(num_channels_down)

    def _per_scale(setting):
        # Broadcast a scalar setting to one entry per scale; lists/tuples are used as given.
        if isinstance(setting, (list, tuple)):
            return setting
        return [setting] * n_scales

    upsample_mode = _per_scale(upsample_mode)
    downsample_mode = _per_scale(downsample_mode)
    filter_size_down = _per_scale(filter_size_down)
    filter_size_up = _per_scale(filter_size_up)

    last_scale = n_scales - 1

    model = nn.Sequential()
    # `model_tmp` is the container for the current scale; it moves one level deeper per iteration.
    model_tmp = model

    input_depth = num_input_channels
    for i in range(n_scales):

        deeper = nn.Sequential()
        # Not named `skip`, so the function name is not shadowed inside the loop.
        skip_branch = nn.Sequential()

        if num_channels_skip[i] != 0:
            model_tmp.add(Concat(1, skip_branch, deeper))
        else:
            model_tmp.add(deeper)

        # Channels after the concat: skip width + what `deeper` outputs (the next
        # scale's up width, or this scale's down width at the deepest scale).
        model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])))

        if num_channels_skip[i] != 0:
            skip_branch.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad))
            skip_branch.add(bn(num_channels_skip[i]))
            skip_branch.add(act(act_fun))

        # Encoder: downsample by 2, then one more conv block at the reduced resolution.
        deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]))
        deeper.add(bn(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
        deeper.add(bn(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper_main = nn.Sequential()

        if i == last_scale:
            # The deepest scale: nothing is nested below it.
            k = num_channels_down[i]
        else:
            # The next iteration fills `deeper_main` with the next (coarser) scale.
            deeper.add(deeper_main)
            k = num_channels_up[i + 1]

        deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))

        # Decoder conv block over the concatenated (skip + upsampled) features.
        model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
        model_tmp.add(bn(num_channels_up[i]))
        model_tmp.add(act(act_fun))

        if need1x1_up:
            model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
            model_tmp.add(bn(num_channels_up[i]))
            model_tmp.add(act(act_fun))

        input_depth = num_channels_down[i]
        model_tmp = deeper_main

    # Final 1x1 projection to the requested number of output channels.
    model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
    if need_sigmoid:
        model.add(nn.Sigmoid())

    return model

In [0]:
input_depth = 32        # channels of the random network input
LR = 0.01               # Adam learning rate
num_iter = 6001         # number of optimisation steps
param_noise = False     # if True, closure perturbs the 4-D (conv) weights every step
show_every = 50         # plot the current output every `show_every` iterations
figsize = 5             # passed to plot_image_grid as `factor`
reg_noise_std = 0.03    # std of the fresh Gaussian noise added to the input every step
pad = 'reflection' # 'zero'

# 5-scale skip network with 128 channels everywhere; output channels match the image.
# `.type(dtype)` casts the whole module once, so no second cast is needed.
net = skip(input_depth, img_np.shape[0],
           num_channels_down = [128] * 5,
           num_channels_up =   [128] * 5,
           num_channels_skip =    [128] * 5,
           filter_size_up = 3, filter_size_down = 3,
           upsample_mode='nearest', filter_skip_size=1,
           need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)

# Random base input of shape 1 x input_depth x H x W.
net_input = get_noise(input_depth, img_np.shape[1:]).type(dtype)

In [11]:
# Compute number of parameters
s  = sum(np.prod(list(p.size())) for p in net.parameters())
print ('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

img_var = torch.from_numpy(img_np)[None, :].type(dtype)
mask_var = torch.from_numpy(img_mask_np)[None, :].type(dtype)

Number of params: 3002627


In [0]:
def optimize(parameters, closure, LR, num_iter):
    """Runs `num_iter` Adam steps over `parameters` with learning rate `LR`.

    Gradients are zeroed before every call to `closure`, which is expected to
    compute the loss and call `.backward()`; its return value is ignored here.
    """
    print('Starting optimization with ADAM')
    adam = torch.optim.Adam(parameters, lr=LR)

    for _ in range(num_iter):
        adam.zero_grad()
        closure()
        adam.step()

In [0]:

# Global iteration counter, read and advanced by `closure` (logging / plot cadence).
i = 0
def closure():
    """Forward/backward pass for one optimisation step.

    When `reg_noise_std > 0`, the input is the saved base input plus fresh Gaussian
    noise with that std. The network output and the target image are both multiplied
    by the mask, so pixels where the mask is 0 do not contribute to the MSE. The loss
    is back-propagated and printed on every call. The clipped output is plotted every
    `show_every` iterations.

    Returns:
        The scalar loss tensor.
    """
    global i

    if param_noise:
        # Perturb every 4-D (conv) weight in place. The update must be in place:
        # rebinding the loop variable (`n = n + ...`) would leave the weights unchanged.
        with torch.no_grad():
            for n in [x for x in net.parameters() if len(x.size()) == 4]:
                n.add_(n.detach().clone().normal_() * n.std() / 50)

    net_input = net_input_saved
    if reg_noise_std > 0:
        # `noise` is a global buffer; normal_() refills it in place on every call.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    total_loss = mse(out * mask_var, img_var * mask_var)
    total_loss.backward()

    print ('Iteration %05d    Loss %f' % (i, total_loss.item()), '\r', end='')
    if i % show_every == 0:
        out_np = out.detach().cpu().numpy()[0]
        plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)

    i += 1

    return total_loss

In [0]:
# `noise` is a same-shaped buffer that closure refills with fresh noise (when
# reg_noise_std > 0); `net_input_saved` is the fixed base input. Then train.
noise = net_input.detach().clone()
net_input_saved = net_input.detach().clone()

optimize(parameters=net.parameters(), closure=closure, LR=LR, num_iter=num_iter)