In [1]:
### Building blocks

import torch.nn as nn
import torch


class Downscale(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs):
        super().__init__(*kwargs)
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.conv1 = nn.Conv2d(self.in_ch, self.out_ch, self.kernel_size,
                               stride=2, padding="same")
        self.activation = nn.LeakyReLU(negative_slope=0.1)
    
    def forward(self,x):
        x = self.conv1(x)
        x = self.activation(x)
        return x

    def get_out_ch(self):
        return self.out_ch

class DownscaleBlock(nn.Module):

    def __init__(self, in_ch, ch, n_downscales, kernel_size):
        super().__init__()
        self.downs = []
        last_ch = in_ch
        for i in range(n_downscales):
            cur_ch = ch*( min(2**i, 8)  )
            self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size))
            last_ch = self.downs[-1].get_out_ch()
    
    def forward(self, inp):
        x = inp
        for down in self.downs:
            x = down(x)
        return x


def depth_to_space(x,size):
    x = torch.permute(x,(0,2,3,1))
    b,h,w,c = x.shape
    oh, ow = h * size, w * size
    oc = c // (size * size)
    x = x.reshape((-1,h,w,size,size,oc,))
    x = torch.permute(x,(0, 1, 3, 2, 4, 5))
    x = x.reshape((-1, oc, oh, ow))
    return x

class Upscale(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, *kwargs):
        super().__init__(*kwargs)
        self.conv1 = nn.Conv2d(in_ch, out_ch*4, kernel_size,
                               padding="same")
        self.activation = nn.LeakyReLU(negative_slope=0.1)
    
    def forward(self,x):
        x = self.conv1(x)
        x = self.activation(x)
        x = depth_to_space(x,2)
        return x

class ResidualBlock(nn.Module):
    def __init__(self, ch, kernel_size=3, conv_dtype=torch.float32):
        super().__init__()
        self.conv1 = nn.Conv2d( ch, ch, kernel_size=kernel_size, padding='same', dtype=conv_dtype)
        self.conv2 = nn.Conv2d( ch, ch, kernel_size=kernel_size, padding='same', dtype=conv_dtype)
        self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, inp):
        x = self.conv1(inp)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.activation(inp + x)
        return x

In [14]:
ae_dims = 128
e_dims = 64
d_dims = 64
d_mask_dims = 16
input_ch = 3
resolution = 96
opts = "ud"
use_fp16=False

In [32]:
class Encoder(nn.Module):
    def __init__(self, in_ch, e_ch,opts="ud", **kwargs ):
        self.in_ch = in_ch
        self.e_ch = e_ch
        self.opts = opts
        if 't' in self.opts:
            self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
            self.res1 = ResidualBlock(self.e_ch)
            self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
            self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
            self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
            self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
            self.res5 = ResidualBlock(self.e_ch*8)
        else:
            self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in self.opts else 5, kernel_size=5)
        super().__init__(**kwargs)

        

    def forward(self, x):

        if 't' in self.opts:
            x = self.down1(x)
            x = self.res1(x)
            x = self.down2(x)
            x = self.down3(x)
            x = self.down4(x)
            x = self.down5(x)
            x = self.res5(x)
        else:
            x = self.down1(x)
        x = torch.flatten(x)
        if 'u' in self.opts:
            x = torch.norm(x,dim = -1)
        return x

    def get_out_res(self, res):
        return res // ( (2**4) if 't' not in self.opts else (2**5) )

    def get_out_ch(self):
        return self.e_ch * 8

In [33]:
import torch

class Decoder(nn.Module):

    def __init__(self, in_ch, d_ch, d_mask_ch, opts):
        "create Decoder as torch Module using previously defined building blocks"
        super().__init__()
        self.opts = opts
        conv_dtype = torch.float32

        if 't' not in self.opts:
            self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
            self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
            self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
            self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
            self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
            self.res2 = ResidualBlock(d_ch*2, kernel_size=3)

            self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
            self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
            self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)

            self.out_conv  = nn.Conv2d( d_ch*2, 3, kernel_size=1, padding='same', dtype=conv_dtype)

            if 'd' in self.opts:
                self.out_conv1 = nn.Conv2d( d_ch*2, 3, kernel_size=3, padding='same', dtype=conv_dtype)
                self.out_conv2 = nn.Conv2d( d_ch*2, 3, kernel_size=3, padding='same', dtype=conv_dtype)
                self.out_conv3 = nn.Conv2d( d_ch*2, 3, kernel_size=3, padding='same', dtype=conv_dtype)
                self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
                self.out_convm = nn.Conv2d( d_mask_ch*1, 1, kernel_size=1, padding='same', dtype=conv_dtype)
            else:
                self.out_convm = nn.Conv2d( d_mask_ch*2, 1, kernel_size=1, padding='same', dtype=conv_dtype)
        else:
            self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
            self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
            self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
            self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
            self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
            self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
            self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
            self.res3 = ResidualBlock(d_ch*2, kernel_size=3)

            self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
            self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
            self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
            self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
            self.out_conv  = nn.Conv2d( d_ch*2, 3, kernel_size=1, padding='same', dtype=conv_dtype)

            if 'd' in self.opts:
                self.out_conv1 = nn.Conv2d( d_ch*2, 3, kernel_size=3, padding='same', dtype=conv_dtype)
                self.out_conv2 = nn.Conv2d( d_ch*2, 3, kernel_size=3, padding='same', dtype=conv_dtype)
                self.out_conv3 = nn.Conv2d( d_ch*2, 3, kernel_size=3, padding='same', dtype=conv_dtype)
                self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
                self.out_convm = nn.Conv2d( d_mask_ch*1, 1, kernel_size=1, padding='same', dtype=conv_dtype)
            else:
                self.out_convm = nn.Conv2d( d_mask_ch*2, 1, kernel_size=1, padding='same', dtype=conv_dtype)

    
        
    def forward(self, z):
        x = self.upscale0(z)
        x = self.res0(x)
        x = self.upscale1(x)
        x = self.res1(x)
        x = self.upscale2(x)
        x = self.res2(x)

        if 't' in self.opts:
            x = self.upscale3(x)
            x = self.res3(x)

        if 'd' in self.opts:
            x = torch.sigmoid( depth_to_space(torch.concat( (self.out_conv(x),
                                                                self.out_conv1(x),
                                                                self.out_conv2(x),
                                                                self.out_conv3(x)), -1), 2) )
        else:
            x = torch.sigmoid(self.out_conv(x))


        m = self.upscalem0(z)
        m = self.upscalem1(m)
        m = self.upscalem2(m)

        if 't' in self.opts:
            m = self.upscalem3(m)
            if 'd' in self.opts:
                m = self.upscalem4(m)
        else:
            if 'd' in self.opts:
                m = self.upscalem3(m)

        m = torch.sigmoid(self.out_convm(m))

        # if use_fp16:
        #     x = tf.cast(x, tf.float32)
        #     m = tf.cast(m, tf.float32)

        return x, m


In [34]:
lowest_dense_res = resolution // (32 if 'd' in opts else 16)

class Inter(nn.Module):
    def __init__(self, in_ch, ae_ch, ae_out_ch,opts = "ud", **kwargs):
        self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
        self.opts = opts
        in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch

        self.dense1 = nn.Linear( in_ch, ae_ch )
        self.dense2 = nn.Linear( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
        if 't' not in self.opts:
            self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
        super().__init__(**kwargs)



    def forward(self, inp):
        x = inp
        x = self.dense1(x)
        x = self.dense2(x)
        x = x.reshape((-1,lowest_dense_res, lowest_dense_res, self.ae_out_ch))
        if 't' not in self.opts:
            x = self.upscale1(x)

        return x

    def get_out_res(self):
        return lowest_dense_res * 2 if 't' not in self.opts else lowest_dense_res

    def get_out_ch(self):
        return self.ae_out_ch

In [35]:
# out_ch is multiplied by 4 -> understand "image channels"

dec = Decoder(in_ch=ae_dims, d_ch=d_dims, d_mask_ch=d_mask_dims,opts="up")

In [36]:
img = torch.rand((1,128,32,32))
y = dec(img)

In [37]:
y[0].shape

torch.Size([1, 3, 256, 256])

In [38]:
import numpy as np
params = sum([np.prod(p.size()) for p in dec.parameters()])
print(dec)
print("num of params", params)

Decoder(
  (upscale0): Upscale(
    (conv1): Conv2d(128, 2048, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): LeakyReLU(negative_slope=0.1)
  )
  (upscale1): Upscale(
    (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): LeakyReLU(negative_slope=0.1)
  )
  (upscale2): Upscale(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): LeakyReLU(negative_slope=0.1)
  )
  (res0): ResidualBlock(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (res1): ResidualBlock(
    (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (res2): ResidualBlock(
    (conv1): Conv2d(128, 128, ke

In [39]:
encoder = Encoder(in_ch=input_ch, e_ch=e_dims,)
encoder_out_ch = encoder.get_out_ch()*encoder.get_out_res(resolution)**2

inter = Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims)
inter_out_ch = inter.get_out_ch()

decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims,opts="ud")
decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims,opts="ud")

ValueError: padding='same' is not supported for strided convolutions

In [24]:
import numpy as np
params = sum([np.prod(p.size()) for p in encoder.parameters()])
print(encoder)
print("num of params", params)

Encoder()
num of params 0


In [25]:
encoder

Encoder()