In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import Image, clear_output, display

import glob
import os
import time

from PIL import Image
from skimage.color import lab2rgb, rgb2hsv, rgb2lab
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import Places365
from tqdm.notebook import tqdm
import cv2

# Baseline U-Net (taken from the baseline model)

In [10]:
class UnetBlock(nn.Module):
    """One nesting level of the baseline U-Net.

    Every block owns a stride-2 ``Conv2d`` (encoder side) and a stride-2
    ``ConvTranspose2d`` (decoder side), both with kernel size 4 and padding 1,
    and wraps an optional inner ``submodule`` between them:

    * ``outermost``: Conv2d -> submodule -> ReLU -> ConvTranspose2d -> Tanh
    * ``innermost``: LeakyReLU -> Conv2d -> ReLU -> ConvTranspose2d -> BatchNorm2d
    * otherwise:     LeakyReLU -> Conv2d -> BatchNorm2d -> submodule
                     -> ReLU -> ConvTranspose2d -> BatchNorm2d [-> Dropout(0.5)]

    Args:
        nf: Output channels of the transposed convolution; also the input
            channel count of the convolution when ``input_c`` is not given.
        ni: Output channels of the convolution. The transposed convolution
            reads ``ni`` channels in the innermost block and ``ni * 2`` in the
            other configurations.
        submodule: Inner block placed between the encoder and decoder halves
            (not used by the innermost configuration).
        input_c: Input channels of the convolution; defaults to ``nf``.
        dropout: Append ``Dropout(0.5)`` to an intermediate block.
        innermost: Build the bottleneck configuration.
        outermost: Build the top-level configuration (takes precedence over
            ``innermost``).
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c

        # The encoder convolution is always created before the decoder one.
        encode = nn.Conv2d(in_channels, ni, kernel_size=4,
                           stride=2, padding=1, bias=False)
        leaky = nn.LeakyReLU(0.2, True)  # in-place
        relu = nn.ReLU(True)             # in-place

        if outermost:
            # Only the outermost transposed convolution keeps its bias term.
            decode = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            layers = [encode, submodule, relu, decode, nn.Tanh()]
        elif innermost:
            decode = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            layers = [leaky, encode, relu, decode, nn.BatchNorm2d(nf)]
        else:
            decode = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            layers = [leaky, encode, nn.BatchNorm2d(ni), submodule,
                      relu, decode, nn.BatchNorm2d(nf)]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; every block except the outermost one returns its
        input concatenated with its output along the channel axis (dim 1).

        Non-outermost blocks start with an in-place LeakyReLU, so ``x`` has
        already been modified in place when it is concatenated.
        """
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([x, out], 1)
        

class Unet(nn.Module):
    """Baseline U-Net assembled from nested ``UnetBlock`` instances.

    The network is built from the bottleneck outwards:

    1. one innermost block at ``num_filters * 8`` channels,
    2. ``n_down - 5`` dropout-enabled blocks at the same width,
    3. three blocks whose width halves at each step,
    4. one outermost block mapping ``input_c`` to ``output_c`` channels.

    Args:
        input_c: Channels of the input tensor.
        output_c: Channels of the produced tensor.
        n_down: Controls how many full-width blocks wrap the bottleneck
            (``n_down - 5`` of them).
        num_filters: Channel width of the outermost block's convolution.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()

        # Bottleneck first, then wrap it level by level.
        core = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            core = UnetBlock(num_filters * 8, num_filters * 8,
                             submodule=core, dropout=True)

        inner_width = num_filters * 8
        for _ in range(3):
            outer_width = inner_width // 2
            core = UnetBlock(outer_width, inner_width, submodule=core)
            inner_width = outer_width

        self.model = UnetBlock(output_c, inner_width, input_c=input_c,
                               submodule=core, outermost=True)

    def forward(self, x):
        """Delegate to the outermost block."""
        return self.model(x)

In [12]:
# Instantiate the baseline U-Net (1 input channel, 2 output channels) and
# print its nested module tree.
u = Unet(input_c=1, output_c=2, n_down=8, num_filters=64)
print(u)

Unet(
  (model): UnetBlock(
    (model): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): UnetBlock(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): UnetBlock(
            (model): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace=True)
              (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): UnetBlock(
                (model): Sequential(
                  (0): LeakyReLU(negative_slope=0.2, inplace=True)
                  (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
 

In [13]:
# Smoke test: push one random batch through the network and display the
# shape of the result.
dummy_input = torch.randn(16, 1, 256, 256) # batch_size, channels, size, size
out = u(dummy_input)
out.shape

torch.Size([16, 2, 256, 256])

# New implementation (work in progress: U-Net with a label embedding)

In [14]:
class UnetBlock(nn.Module):
    """One level of the reworked U-Net, buildable as an encoder-only or a
    decoder-only stage.

    Note: this definition shadows the baseline ``UnetBlock`` declared earlier
    in the notebook.

    Layer layout per configuration (``sub`` is ``submodule`` when one is given,
    and is simply left out when it is ``None``):

    * ``outermost`` and ``down``:      Conv2d -> sub
    * ``outermost`` and not ``down``:  sub -> ReLU -> ConvTranspose2d -> Tanh
    * ``innermost``:                   LeakyReLU -> Conv2d -> ReLU
                                       -> ConvTranspose2d -> BatchNorm2d
    * otherwise, ``down``:             LeakyReLU -> Conv2d -> BatchNorm2d -> sub
    * otherwise, not ``down``:         sub -> ReLU -> ConvTranspose2d
                                       -> BatchNorm2d [-> Dropout(0.5)]

    All convolutions use kernel size 4, stride 2 and padding 1.

    Args:
        nf: Output channels of the transposed convolution; also the input
            channel count of the convolution when ``input_c`` is not given.
        ni: Output channels of the convolution. The transposed convolution
            reads ``ni`` channels in the innermost block and ``ni * 2``
            otherwise.
        submodule: Nested module run after the encoder layers (``down``) or
            before the decoder layers (not ``down``). Ignored by the innermost
            configuration.
        input_c: Input channels of the convolution; defaults to ``nf``.
        dropout: Append ``Dropout(0.5)``. Only honoured by decoder-only,
            non-outermost blocks.
        innermost: Build the bottleneck (the only configuration that contains
            both a convolution and a transposed convolution).
        outermost: Build the top-level configuration (takes precedence over
            ``innermost``).
        down: ``True`` builds the encoder half, ``False`` the decoder half.

    NOTE(review): decoder-only blocks still size their transposed convolution
    for ``ni * 2`` input channels, i.e. they appear to expect skip-concatenated
    features from the caller - confirm when the decoder path is wired up.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False, down=True):
        super().__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.down_sample = down
        # Only the bottleneck maps its input back to the same spatial size
        # (stride-2 conv followed by stride-2 transposed conv), so it is the
        # only block whose input can be concatenated with its output.
        self.use_skip = bool(innermost) and not outermost
        if input_c is None:
            input_c = nf

        conv_kwargs = dict(kernel_size=4, stride=2, padding=1)
        # An absent submodule is left out instead of placing None in the
        # Sequential, which would only fail once the block is called.
        nested = [] if submodule is None else [submodule]

        # Layers are created only in the branch that uses them.
        if outermost:
            if down:
                layers = [nn.Conv2d(input_c, ni, bias=False, **conv_kwargs)] + nested
            else:
                layers = nested + [
                    nn.ReLU(True),
                    # The outermost transposed convolution keeps its bias.
                    nn.ConvTranspose2d(ni * 2, nf, **conv_kwargs),
                    nn.Tanh(),
                ]
        elif innermost:
            layers = [
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(input_c, ni, bias=False, **conv_kwargs),
                nn.ReLU(True),
                nn.ConvTranspose2d(ni, nf, bias=False, **conv_kwargs),
                nn.BatchNorm2d(nf),
            ]
        elif down:
            layers = [
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(input_c, ni, bias=False, **conv_kwargs),
                nn.BatchNorm2d(ni),
            ] + nested
        else:
            layers = nested + [
                nn.ReLU(True),
                nn.ConvTranspose2d(ni * 2, nf, bias=False, **conv_kwargs),
                nn.BatchNorm2d(nf),
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block.

        The bottleneck returns ``x`` concatenated with its output along the
        channel axis (dim 1). Every other configuration changes the spatial
        size, so its output is returned as-is.
        """
        y = self.model(x)
        if self.use_skip:
            return torch.cat([x, y], 1)
        return y
        
class Unet(nn.Module):
    """Work-in-progress U-Net variant; only the encoder stack is wired up.

    Note: this definition shadows the baseline ``Unet`` declared earlier in
    the notebook.

    ``self.model`` nests, from the inside out: one innermost (bottleneck)
    block, ``n_down - 5`` encoder blocks at ``num_filters * 8`` channels,
    three encoder blocks whose width halves at each step, and one outermost
    encoder block that reads ``input_c`` channels.

    ``self.embed`` is created but not used by ``forward`` yet; the decoder
    and label-conditioning steps exist only as commented-out sketches.

    Args:
        input_c: Channels of the input tensor.
        output_c: Passed to the outermost block as ``nf``.
        n_down: Controls how many full-width encoder blocks wrap the
            bottleneck (``n_down - 5`` of them).
        num_filters: Channel width of the outermost block's convolution.
        num_classes: Number of rows of the label embedding table. The default
            of 365 presumably matches the Places365 categories imported at the
            top of the notebook - TODO confirm.
        embed_dim: Size of each label embedding vector.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64,
                 num_classes=365, embed_dim=100):
        super().__init__()

        # down sampling + bottleneck
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)  # bottleneck
        # n_down - 5 full-width encoder blocks around the bottleneck.
        # NOTE: dropout=True has no effect here, because UnetBlock only adds
        # Dropout in its decoder-only (down=False) branch.
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True, down=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block, down=True)
            out_filters //= 2
        unet_block = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True, down=True)
        self.model = unet_block

        # TODO: upsampling path (sketch, not wired up yet)
#         unet_block = UnetBlock(num_filters * 8, num_filters * 8, dropout=True, down=False) # first upsampling layer
#         for _ in range(n_down - 6): # NOTE: we loop one time less here since we have already defined the first upsampling layer
#             unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule = unet_block, dropout=True, down=False)
#         out_filters = num_filters * 8
#         for _ in range(3):
#             unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block, down=False)
#             out_filters //= 2
#         unet_block = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True, down=False)

        # Label embedding (input to the planned FFNN for the label).
        self.embed = nn.Embedding(num_classes, embed_dim)
#         self.linear = nn.Sequential()

    def forward(self, x):
        """Run the encoder stack on ``x`` and return its output.

        The label embedding is not applied yet (see the sketch below).
        """
        x = self.model(x)
        # TODO: label conditioning and upsampling (sketch, not wired up yet)
#         embedding = self.embed(self.label)
#         encoded_label = self.linear(embedding)
#         x = torch.cat([x, encoded_label],dim=1) # concatenate label and data in dim 1
#         x = self.up(x)
        return x

In [15]:
# Instantiate the reworked U-Net (this rebinds `u`, replacing the baseline
# instance) and print its nested module tree.
u = Unet(input_c=1, output_c=2, n_down=8, num_filters=64)
print(u)

Unet(
  (model): UnetBlock(
    (model): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): UnetBlock(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): UnetBlock(
            (model): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace=True)
              (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): UnetBlock(
                (model): Sequential(
                  (0): LeakyReLU(negative_slope=0.2, inplace=True)
                  (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
 

In [None]:
# Smoke test for the reworked network: push one random batch through it and
# display the shape of the result.
# NOTE(review): this cell has no execution count and no recorded output, so
# the forward pass of the reworked network is unverified here.
dummy_input = torch.randn(16, 1, 256, 256) # batch_size, channels, size, size
out = u(dummy_input)
out.shape