# Assignment 5: Deep Image Prior

**Due Date:** Feb 10, 2019.

**Submission:** In pairs, [here]().

**Student 1**
</br>**Name:** Jane Doe
</br>**Email:** jane@doe.com


**Student 2**
</br>**Name:** John Doe
</br>**Email:** john@doe.com

## Introduction

In a recent CVPR paper, [Deep Image Prior](http://openaccess.thecvf.com/content_cvpr_2018/papers/Ulyanov_Deep_Image_Prior_CVPR_2018_paper.pdf), Ulyanov et al. showed that deep neural nets have an inductive bias toward _natural images_: when a deep net architecture is used to generate an image, it is _easier_ for the net to produce natural-looking images than noisy or blurry ones. Based on this observation, they proposed using deep neural nets as a prior for image enhancement tasks.

In this exercise we will follow their lead and try different deep architectures as a _prior_ for the task of image denoising. To denoise an image using the _Deep Image Prior_ framework, one needs a deep net that **takes random noise** and **generates the noisy image** from it (we assume the clean image is unavailable and use it only for measuring performance). Because of the inductive bias of deep nets toward natural images, we expect a clean version of the image to emerge early in the process, before the net overfits and reproduces the noise as well.

The exercise has three parts, one for each of three basic architectures:
* 1D-Generator (generating from 1D noise)
* 2D-Generator (generating from 2D noise)
* 2D-Generator with skip-connections

Throughout the exercise we encourage you to try various design choices so you can get the "feel" of working and experimenting with deep learning. We are not aiming for a specific solution here; we are interested in the process (but we do hope you get appealing results in the end!). You are given a set of questions for each part. Please answer them, but also feel free to discuss other insights from your experiments.

## Setup
First, we'd like to use the free GPU provided by Google Colab. This will accelerate the training by an order of magnitude.
1. In the menu, select: **Runtime -> Change runtime type**.
2. Choose "GPU" under **Hardware accelerator**.

Next, we'll need to install some Python dependencies and download the dataset. You may need to repeat this process whenever the runtime is restarted.
1. Run the cell **Install requirements**.
2. Run the second cell in **Download dataset**.
3. Restart the runtime, either by typing **"`Ctrl+M .`"**, or by using the menus: **Runtime -> Restart runtime...**.


In [0]:
#@title Install requirements
#@markdown Please run this cell to install python dependencies.
#@markdown When finished, type **`Ctrl+M .`** to restart the runtime. Alternatively, use the menus **Runtime -> Restart Runtime...**
#@markdown This should fix the following error:<br/>`AttributeError: module 'PIL.Image' has no attribute 'register_extensions'`

# Install pytorch=1.0.0
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
if accelerator.startswith('cu9'):
  accelerator = 'cu90'  
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision


# Install pillow=5.4.1
!pip install -U -q "pillow~=5.4.1"


# Install livelossplot
!pip install -U -q "livelossplot~=0.3.0"


# should fix a problem with pillow
%reload_ext autoreload
%autoreload


In [0]:
#@title Download dataset
#@markdown Please run this cell to download the datasets.

# download images
!wget -q https://wis-intro-vision-2019.wikidot.com/local--files/assignments/ex5-data.tar.gz
!tar -zxf ex5-data.tar.gz && rm -f ex5-data.tar.gz

In [0]:
#@title Download dataset-v2
#@markdown Please run this cell to download the datasets.

# download images
!wget -q https://wis-intro-vision-2019.wikidot.com/local--files/assignments/ex5-data-v2.tar.gz
!tar -zxf ex5-data-v2.tar.gz && rm -f ex5-data-v2.tar.gz

## General imports

If you encounter an error in this step, please follow the setup instructions above.

In [0]:
import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

## Utility Methods

You may want to use the provided utility functions, instead of implementing them yourselves.

### I/O

In [0]:
def gdrive_mount():
  if os.path.isdir('/gdrive'):
    return True

  import google.colab
  google.colab.drive.mount('/gdrive')  


def gdrive_path(path):
  return os.path.join('/gdrive/My Drive', path)

In [0]:
import tempfile

import PIL

import IPython.display

# Low-Level Utility Methods

def _img_to_float32(image, bounds=(0, 1)):
  """Receives an `image` and range `bounds`, 
     normalizes it and converts it to `float32`. 

  Arguments:
    image (np.ndarray): the image
    bounds (Tuple[float, float], optional): expected minimum and maximum values 
                                            of image pixels

  Returns:
    np.ndarray: the converted image, with `dtype=np.float32`.

  """
  minval, maxval = bounds
  image = np.asarray(image, dtype=np.float32) / 255.0
  image = np.clip((maxval - minval) * image + minval, minval, maxval)
  return image


def _img_to_uint8(image, bounds=(0, 1)):
  """Receives an `image` and range `bounds`, 
     normalizes it and converts it to `uint8`.

  Arguments:
    image (np.ndarray): the image
    bounds (Tuple[float, float], optional): expected minimum and maximum values 
                                            of image pixels

  Returns:
    np.ndarray: the converted image, with `dtype=np.uint8`.
    
  """
  if image.dtype != np.uint8:
    minval, maxval = bounds
    image = (image.astype(np.float32) - minval) / (maxval - minval)
    image = (image * 255.0).round().clip(0, 255).astype(np.uint8)
  return image


def _check_path(path):
  """Checks whether a path exists. 
  If not, the required missing directories are created"""
  path = os.path.abspath(path)
  if not os.path.exists(os.path.dirname(path)):
    os.makedirs(os.path.dirname(path))
  return path


# High-Level Utility Methods

def imread(path, size=(256, 256), bounds=(0, 1)):
  """Reads an image, changes its size to match `size` (may crop and lose data),
     and converts it to float.
  
  Arguments:
    path (str): path to image
    size (Tuple[int, int], optional): desired image size
    bounds (Tuple[float, float], optional): expected minimum and maximum values 
                                            of image pixels
  
  Returns:
    image (np.ndarray): the image, whose values lie within the given `bounds`.
  
  """
  image = PIL.Image.open(path).convert(mode='RGB')
  if size is not None:
    scale_factor = max([float(size[dim]) / float(image.size[dim]) for dim in range(2)])
    new_size = [int(scale_factor * sz) for sz in image.size]
    image = image.resize(new_size, PIL.Image.LANCZOS)
    left, top = [(image.size[dim] - size[dim]) // 2 for dim in range(2)] 
    image = image.crop((left, top, left + size[0], top + size[1]))
  image = _img_to_float32(image, bounds)
  return image


def imwrite(path, image, bounds=(0, 1), **kwargs):  
  """Normalize `image` and save it to `path`.

  Arguments:
    path (str): saving location
    image (np.ndarray): the image
    bounds (Tuple[float, float], optional): expected minimum and maximum values 
                                            of image pixels
  """  
  image = _img_to_uint8(image, bounds)
  path = _check_path(path)
  image = PIL.Image.fromarray(image)
  image.save(path, **kwargs)


def imshow(image, path=None, **kwargs):
  """Normalize `image`, save it, and show it.

  Arguments:
    image (np.ndarray): the image
    path (str, optional): saving location (if None, the image is saved to a
                          temporary location)
  
  """
  fd = None
  if path is None:
    fd, path = tempfile.mkstemp(suffix='.png')
  
  imwrite(path, image, **kwargs)  
  output = IPython.display.Image(path)
  
  if fd is not None:
    os.close(fd)
  
  display(output)


# colab version
try:
  import google.colab.widgets
  
  def imshow_tabs(noisy, result, clean=None, paths=None, **kwargs):
    """Normalizes input images (`noisy`, `result`, and possibly `clean`)
    and shows them in tabs.

    Arguments:
      noisy (np.ndarray): noisy image to show.
      result (np.ndarray): cleaned image to show.
      clean (np.ndarray, optional): ground truth clean image to show.
      paths (List[str], optional): list of locations to save the images to.
    
    """
    images = [noisy, result]
    titles = ['noisy', 'result']
    if clean is not None:
      images.append(clean)
      titles.append('clean')
    if paths is None:
      paths = [None] * len(titles)
    assert len(paths) == len(titles)
    
    tab = google.colab.widgets.TabBar(titles)
    
    for title, path, image in zip(titles, paths, images):
      with tab.output_to(title):
        imshow(image, path=path, **kwargs)


# jupyter version
except ImportError:
  import ipywidgets
  
  def imshow_tabs(noisy, result, clean=None, paths=None, **kwargs):
    """Normalizes input images (`noisy`, `result`, and possibly `clean`)
    and shows them in tabs.

    Arguments:
      noisy (np.ndarray): noisy image to show.
      result (np.ndarray): cleaned image to show.
      clean (np.ndarray, optional): ground truth clean image to show.
      paths (List[str], optional): list of locations to save the images to.
    
    """
    images = [noisy, result]
    titles = ['noisy', 'result']
    if clean is not None:
      images.append(clean)
      titles.append('clean')
    if paths is None:
      paths = [None] * len(titles)
    assert len(paths) == len(titles)

    tab = ipywidgets.Tab([ipywidgets.Output() for _ in titles])
    
    for i, (title, path, image) in enumerate(zip(titles, paths, images)):
      tab.set_title(i, title)
      with tab.children[i]:
        imshow(image, path=path, **kwargs)
    
    display(tab)

### Evaluation

In [0]:
def root_mean_square_error(a, b):
  """Computes the RMSE between two images, or between two batches of images.
  
  Arguments:
    a (np.ndarray): the first image (or batch of images, stacked along axis 0)
    b (np.ndarray): the second image (or batch of images, stacked along axis 0)
  
  Returns:
    rmse (float / np.ndarray): if `a` has 2 or 3 dimensions, returns the
                               RMSE between `a` and `b`.
                               if `a` has 4 dimensions, returns a list of
                               RMSE between corresponding images in `a`
                               and `b`. 
  """
  assert a.ndim in {2, 3, 4}, '`a` should have 2/3/4 dimensions'
  assert a.shape == b.shape, '`a` and `b` should have the same shape'
  
  # batch of images
  if a.ndim == 4:  
    mse = np.mean((a - b)**2, axis=(1, 2, 3))
  
  # single image
  else:            
    mse = np.mean((a - b)**2)
  
  return np.sqrt(mse)
    

def peak_signal_noise_ratio(a, b, bounds=(0, 1)):
  """Computes the PSNR between two images, or between two batches of images.
  
  Arguments:
    a (np.ndarray): the first image (or batch of images, stacked along axis 0)
    b (np.ndarray): the second image (or batch of images, stacked along axis 0)
    bounds (Tuple[float, float]): the valid bounds of `a` and `b`.
 
  Returns:
    psnr (float / np.ndarray): if `a` has 2 or 3 dimensions, returns the
                               PSNR between `a` and `b`.
                               if `a` has 4 dimensions, returns list of
                               PSNR between corresponding images in `a`
                               and `b`.
  """

  assert a.ndim in {2, 3, 4}, '`a` should have 2/3/4 dimensions'
  assert a.shape == b.shape, '`a` and `b` should have the same shape'
  
  minval, maxval = bounds
  rmse = root_mean_square_error(a, b)
  return 20 * np.log10(maxval - minval) - 20 * np.log10(rmse)
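
As a quick sanity check of the PSNR formula above, here is a minimal sketch on synthetic arrays (not images from the dataset). It assumes bounds of `(0, 1)` and a constant per-pixel error of 0.1, which should give a PSNR of about 20 dB.

```python
import numpy as np

# Synthetic example: every pixel differs by exactly 0.1, so RMSE = 0.1.
a = np.zeros((4, 4, 3), dtype=np.float64)
b = np.full((4, 4, 3), 0.1, dtype=np.float64)

rmse = np.sqrt(np.mean((a - b) ** 2))

# With bounds (0, 1) the peak term 20 * log10(maxval - minval) is zero,
# so PSNR reduces to -20 * log10(rmse).
psnr = 20 * np.log10(1.0 - 0.0) - 20 * np.log10(rmse)
print(round(float(psnr), 6))  # 20.0
```

Each tenfold reduction in RMSE adds 20 dB, which may help when reading the tracker plots.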

### Datasets

In [0]:
def list_dataset(dataset_dir='dataset'):
  """Lists the images in a given dataset.
  
  Arguments:
    dataset_dir (str): where the dataset is.
  
  Returns:
    dataset (List[str]): list of paths to images in the dataset.
  """
  return sorted([os.path.join(dataset_dir, fname) for fname in os.listdir(dataset_dir)])

### Training

In [0]:
def add_noise(images, scale, bounds=(0, 1)):
  """Adds i.i.d. uniform noise in [-scale, scale] to every pixel, then clips to `bounds`."""
  minval, maxval = bounds
  noise = np.random.uniform(-scale, scale, size=images.shape)
  return np.clip(images + noise, minval, maxval)


def add_noise_old(images, sigma, bounds=(0, 1)):
  """Adds i.i.d. Gaussian noise to all the pixels of an image
  
  Arguments:
    images (np.ndarray): an image (or batch of images)
    sigma (float): standard deviation of the noise
    bounds (Tuple[float, float], optional): desired minimum and maximum values 
                                            of image pixels
                                            
  Returns: 
    np.ndarray: noisy images
    
  """
  minval, maxval = bounds
  noise = np.random.normal(loc=0.0, scale=sigma, size=images.shape)
  return np.clip(images + noise, minval, maxval)

In [0]:
import livelossplot
import matplotlib.pyplot as plt
plt.style.use(['seaborn-whitegrid', 'seaborn-notebook'])

class Tracker(object):
  def __init__(self, fig_path=None):
    """Creates a tracker, that keeps track of training.
    
    You should probably create a tracker when creating the network, because they
    are usually coupled.

    Arguments:
      fig_path (str, optional): where to save an image of the tracker's plots.

    """
    self._plot = livelossplot.PlotLosses(
      plot_extrema=False,
      fig_path=fig_path,
      metric2title={'loss': 'Loss', 'rmse': 'RMSE', 'psnr': 'PSNR'}
    )
  
  def update(self, logs):
    """Update the tracker's data.
    
    Arguments:
      logs (dict): a dictionary with data for the tracker. Valid keys are
                   'loss', 'rmse', 'psnr' and 'val_*' (for the mentioned keys).
                   'val_*' key means value on validation data.
                   the value corresponding to the key should be a single value,
                   not an array.
    
    """
    self._plot.update(logs)
  
  def draw(self):
    """Refresh the tracker's plot."""
    self._plot.draw()

In [0]:
class DeepImagePriorTrainer(object):
  def __init__(self, noise_input, noisy_image, clean_image=None, device=None):
    """Trainer for Deep Image Prior, on a given pre-set example.
    
    Arguments:
      noise_input (np.ndarray): vector of the noise to use as the network's input.
      noisy_image (np.ndarray): the noisy image the network tries to learn.
      clean_image (np.ndarray, optional): the clean image (ground truth).
                                          SHOULD BE USED FOR VALIDATION ONLY!!!
      device (str, optional): where to run. If None, chooses device automatically.
                              Usually better to keep this unchanged.
    
    """
    if device is None:
      device = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.device = device
    
    if noise_input.ndim == 3:
      noise_input = np.transpose(noise_input, (2, 0, 1))
    self.noise = noise_input[None, ...]
    self.noise_t = torch.as_tensor(self.noise, dtype=torch.float32, device=self.device)

    self.noisy = np.transpose(noisy_image, (2, 0, 1))[None, ...]
    self.noisy_t = torch.as_tensor(self.noisy, dtype=torch.float32, device=self.device)
    
    if clean_image is not None:
      self.clean = np.transpose(clean_image, (2, 0, 1))[None, ...]
      self.clean_t = torch.as_tensor(self.clean, dtype=torch.float32, device=self.device)
    else:
      self.clean = None
      self.clean_t = None

  def train(self, net, optimizer, criterion, epochs, output_rate=50, tracker=None):
    """Trains a Deep Image Prior network on the pre-set example (input noise and noisy image).
    
    Arguments:
      net (torch.nn.Module): your network. receives the noise, and should return the image.
      optimizer (torch.optim.Optimizer): optimizer (make sure it's optimizing net's parameters).
      criterion: loss function.
      epochs (int): number of training iterations.
      output_rate (int): how frequently (number of epochs) should the tracker's plot be refreshed.
      tracker (Tracker, optional): tracker of the training.
      
    """
    net = net.to(self.device)
    net.train()
    
    for step in range(epochs):
      optimizer.zero_grad()
      cleaned_t = net(self.noise_t)
      loss_t = criterion(cleaned_t, self.noisy_t)
      loss_t.backward()
      optimizer.step()
      
      logs = {}
      cleaned = cleaned_t.cpu().detach().numpy()

      # report learning on provided (noisy) image
      loss = loss_t.item()
      logs['rmse'] = root_mean_square_error(cleaned, self.noisy)[0]
      logs['psnr'] = peak_signal_noise_ratio(cleaned, self.noisy)[0]
      logs['loss'] = loss

      # report learning on ground-truth (clean) image
      if self.clean is not None:
        with torch.no_grad():
          val_loss_t = criterion(cleaned_t, self.clean_t)
        val_loss = val_loss_t.item()
        logs['val_rmse'] = root_mean_square_error(cleaned, self.clean)[0]
        logs['val_psnr'] = peak_signal_noise_ratio(cleaned, self.clean)[0]
        logs['val_loss'] = val_loss
      
      if tracker is not None:
        tracker.update(logs)
      
      if step % output_rate == 0:
        if tracker is not None:
          tracker.draw()
        else:
          print(logs)  # print logs
        
    if tracker is not None:
      tracker.draw()
  
  def eval(self, net):
    """Predicts the image by feeding `net` with the pre-set noise.
    
    Arguments:
      net (torch.nn.Module): network to feed with input noise, and take its prediction as image.
      
    Returns:
      image (np.ndarray): the predicted image.
    """
    net = net.to(self.device)
    net.eval()
    
    # no gradients are needed for prediction
    with torch.no_grad():
      cleaned_t = net(self.noise_t)
    cleaned = cleaned_t.cpu().detach().numpy()
    return np.transpose(cleaned[0, ...], (1, 2, 0))

In [0]:
def example_main_loop(net, opt_class, opt_kwargs, criterion, noise_size, image_path, sigma, epochs, tracker=None):
  optimizer = opt_class(net.parameters(), **opt_kwargs)
  noise_input = np.random.normal(size=noise_size).astype(np.float32)
  clean_image = imread(image_path, size=(256, 256))
  noisy_image = add_noise(clean_image, sigma)
  dip = DeepImagePriorTrainer(noise_input, noisy_image, clean_image)
  dip.train(net, optimizer, criterion, epochs=epochs, tracker=tracker)
  result_image = dip.eval(net)
  imshow_tabs(noisy_image, result_image, clean_image)
  return result_image

## Deep Image Prior

In this exercise, you'll test several architectures for _Deep Image Prior_. For **each** architecture, answer the following questions (in the report):
1. Does the "inductive bias" assumption hold? Do we see a clean image emerge before the noise is reconstructed? Is the net "rich" enough to overfit the noise?
2. What is the effect of each design choice?
3. Report average PSNR on all images.
4. How many trainable parameters are in the model? How many operations (Mult/Add)?

Support your claims with examples (of cleaned images), plots, etc.
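
For question 4, one possible way to count trainable parameters is sketched below. The helper name `count_trainable_parameters` and the toy convolution are illustrative assumptions, not part of the provided utilities.

```python
import torch.nn as nn


def count_trainable_parameters(net):
    """Sum the element counts of all parameters that require gradients."""
    return sum(p.numel() for p in net.parameters() if p.requires_grad)


# Toy layer for illustration only: 3 -> 8 channels, 3x3 kernel, with bias.
toy = nn.Conv2d(3, 8, kernel_size=3, padding=1)
print(count_trainable_parameters(toy))  # 8*3*3*3 weights + 8 biases = 224
```

For the operation count, a convolution performs roughly (output height) x (output width) x (kernel area) x (input channels) x (output channels) multiply-adds. For this toy layer on a 256x256 input that is 256 * 256 * 9 * 3 * 8 = 14,155,776.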

### 1. 1D-Generator

Construct a "generating" net that expects a 1D random noise vector as input and passes it through a fully connected layer to arrange it in a 2D feature space. Add "upsampling" convolutional layers to produce the target image.
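
The first step described above (1D noise through a fully connected layer into a 2D feature map) could look roughly like the sketch below. All sizes (`noise_dim`, `channels`, `side`) are hypothetical placeholders; choosing them is part of the exercise.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only to illustrate the reshape.
noise_dim, channels, side = 64, 16, 8

fc = nn.Linear(noise_dim, channels * side * side)

z = torch.randn(1, noise_dim)                 # a batch holding one 1D noise vector
feat = fc(z).view(-1, channels, side, side)   # rearranged into a 2D feature map
print(feat.shape)  # torch.Size([1, 16, 8, 8])
```

The upsampling layers would then grow this small feature map to the 256x256 target resolution.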

#### Design Choices
1. "width" of the different layers (number of filters)
2. Number of different layers: how many "upsamples", by what scale factor each time
3. "Upsample" method: transposed convolution vs. interp + conv (see [here](https://distill.pub/2016/deconv-checkerboard/))
4. How to init weights: uniform/random/zero?
5. What loss to use? L1/L2/CE?
6. Regularization?
7. Solver: SGD (with and without momentum) / ADAM
8. Learning rates: try different values `1e-5`, `1e-3`, `1e-1`, `1e1`. \\
See effect of learning rate on optimization progress: slow progress vs divergence.
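
To make design choice 3 concrete, the sketch below builds one upsampling block of each kind and checks that both double the spatial resolution. Channel counts and kernel sizes are arbitrary assumptions, not recommended settings.

```python
import torch
import torch.nn as nn

# Option A: transposed convolution, doubling the spatial size in one learned step.
up_transposed = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)

# Option B: fixed interpolation followed by an ordinary convolution.
up_interp = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(16, 8, kernel_size=3, padding=1),
)

x = torch.randn(1, 16, 32, 32)
print(up_transposed(x).shape)  # torch.Size([1, 8, 64, 64])
print(up_interp(x).shape)      # torch.Size([1, 8, 64, 64])
```

Both blocks produce the same output shape, so they can be swapped in an architecture to compare their effect on the results.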

In [0]:
class Generator1D(nn.Module):
  def __init__(self):
    super(Generator1D, self).__init__()
    
    # ENTER YOUR SOLUTION HERE
    raise NotImplementedError()
   
  def forward(self, x):
    # ENTER YOUR SOLUTION HERE
    raise NotImplementedError()

### 2. 2D-Generator

Construct a "U-net": start with 2D noise of the same spatial size as the input image (but possibly with more or fewer channels). Add some conv + stride/pooling layers to reduce the spatial dimensions and create an "information bottleneck", then use "upsampling" blocks to recover the original spatial resolution.

#### Design Choices
1. "width" of the different layers
2. Number of different layers: how many "upsamples", by what scale factor each time
3. "Upsample" method: transposed convolution vs. interp + conv (see [here](https://distill.pub/2016/deconv-checkerboard/))
4. How to init weights: uniform/random/zero?
5. Use the loss function that worked best for you in the previous part.
6. Regularization?
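
When preparing 2D noise for this part, note that `DeepImagePriorTrainer` transposes a 3D noise array from height-width-channels to channels-first order and adds a batch axis. The sketch below mirrors that step with NumPy; the choice of 8 noise channels is an arbitrary assumption.

```python
import numpy as np

# Height, width, channels; the channel count is a free design choice.
noise_size = (256, 256, 8)
noise = np.random.normal(size=noise_size).astype(np.float32)

# Same layout change the trainer applies before feeding the network.
batched = np.transpose(noise, (2, 0, 1))[None, ...]
print(batched.shape)  # (1, 8, 256, 256)
```

The first convolution of the network should therefore expect as many input channels as the last entry of `noise_size`.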


In [0]:
class Generator2D(nn.Module):
  def __init__(self):
    super(Generator2D, self).__init__()
    
    # ENTER YOUR SOLUTION HERE
    raise NotImplementedError()
   
  def forward(self, x):
    # ENTER YOUR SOLUTION HERE
    raise NotImplementedError()

### 3. 2D-Generator with skip-connections

Use the same "U-net" architecture from the previous part, but now add "skip connections" connecting feature maps of the same spatial resolution from the "downscale" part to the "upscale" part.

#### Design Choices
1. "width" of the different layers, and specifically, the "width" of the skip connections.
2. How to propagate the information passed through "skip connections": do we add it (like residual links), or concat it (like "densenet")?
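
The two ways of merging a skip connection differ in the resulting channel count, as the sketch below shows on random tensors. The shapes are illustrative assumptions only.

```python
import torch

down_feat = torch.randn(1, 16, 64, 64)  # feature map from the "downscale" path
up_feat = torch.randn(1, 16, 64, 64)    # feature map from the "upscale" path

# Residual-style: element-wise sum, channel count unchanged.
added = down_feat + up_feat

# DenseNet-style: stack along the channel axis, channel counts add up.
concat = torch.cat([down_feat, up_feat], dim=1)

print(added.shape)   # torch.Size([1, 16, 64, 64])
print(concat.shape)  # torch.Size([1, 32, 64, 64])
```

Addition requires both feature maps to have the same number of channels, while concatenation requires the next layer to accept the summed channel count.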

In [0]:
class Generator2DS(nn.Module):
  def __init__(self):
    super(Generator2DS, self).__init__()
    
    # ENTER YOUR SOLUTION HERE
    raise NotImplementedError()
   
  def forward(self, x):
    # ENTER YOUR SOLUTION HERE
    raise NotImplementedError()

## Experiments

Please run your experiments below this point.

In [0]:
dataset = list_dataset('dataset-v2')

In [0]:
for fp in dataset:
  im = imread(fp)
  assert im.shape == (256, 256, 3)