In [1]:
import jupyter_black

from PIL import Image
import os, shutil
import sys
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt

jupyter_black.load()


if "/usr/src" not in sys.path:
    sys.path.append("/usr/src")

In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [3]:
# Custom modules

In [93]:
def size_after_conv(size_in: tuple, kernel_size, stride, padding):
    h_in, w_in = size_in
    h_out = int(
        np.floor((h_in + 2 * padding - (kernel_size[0] - 1) - 1) / stride[0] + 1)
    )
    w_out = int(
        np.floor((w_in + 2 * padding - (kernel_size[1] - 1) - 1) / stride[1] + 1)
    )
    return h_out, w_out


def size_after_conv_convtranspose(size_in, kernel_size, stride, padding):
    h_in, w_in = size_in
    h_out = (h_in - 1) * stride[0] - 2 * padding + (kernel_size[0] - 1) + 1
    w_out = (w_in - 1) * stride[1] - 2 * padding + (kernel_size[1] - 1) + 1

    return h_out, w_out

In [89]:
size = (64, 64)
kernel_sizes = [(3, 3) for i in range(5)]
strides = [(1, 1), (2, 2), (2, 2), (2, 2), (1, 1)]
paddings = [1 for i in range(5)]

for i in range(5):
    size = size_after_conv(size, kernel_sizes[i], strides[i], paddings[i])
    print(f"Conv {i+1}, size = {size}")

Conv 1, size = (64, 64)
Conv 2, size = (32, 32)
Conv 3, size = (16, 16)
Conv 4, size = (8, 8)
Conv 5, size = (8, 8)


In [38]:
img_channels = 1
in_channels = 32
channel_mults = (1, 2, 2, 2, 2)

conv_channels = [img_channels, *map(lambda m: in_channels * m, channel_mults)]
encoder_channels = list(zip(conv_channels[:-1], conv_channels[1:]))
conv_channels = list(
    reversed([img_channels, *map(lambda m: in_channels * m, channel_mults)])
)
decoder_channels = list(zip(conv_channels[:-1], conv_channels[1:]))

In [43]:
encoder_channels, decoder_channels

([(1, 32), (32, 64), (64, 64), (64, 64), (64, 64)],
 [(64, 64), (64, 64), (64, 64), (64, 32), (32, 1)])

In [101]:
class Encoder(nn.Module):
    def __init__(
        self,
        image_size: tuple | int,
        latent_size: int,
        in_out_channels: list[tuple],
        kernels: list[tuple],
        strides: list[tuple],
        pads: list[int] | list[tuple],
        flattened_size: int,
    ):
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.latent_size = latent_size

        self.convolution_seria = nn.Sequential()
        for in_out, kernel, stride, pad in zip(in_out_channels, kernels, strides, pads):
            conv = nn.Conv2d(
                in_out[0], in_out[1], stride=stride, kernel_size=kernel, padding=pad
            )
            image_size = size_after_conv(image_size, kernel, stride, pad)
            self.convolution_seria.append(conv)
            self.convolution_seria.append(nn.LeakyReLU(0.01))
        self.convolution_seria.pop(-1)  # delete last activation function
        self.convolution_seria.append(nn.Flatten())

        self.z_mean = torch.nn.Linear(flattened_size, self.latent_size)
        self.z_log_var = torch.nn.Linear(flattened_size, self.latent_size)


class Decoder(nn.Module):
    def __init__(
        self,
        image_size: tuple | int,
        latent_size: int,
        in_out_channels: list[tuple],
        kernels: list[tuple],
        strides: list[tuple],
        pads: list[int] | list[tuple],
        flattened_size: int,
        unflattened_size: tuple,
    ):
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.latent_size = latent_size

        self.convolution_transpose_seria = nn.Sequential()
        self.convolution_transpose_seria.append(
            torch.nn.Linear(self.latent_size, flattened_size)
        )
        self.convolution_transpose_seria.append(nn.Unflatten(1, unflattened_size))

        size = unflattened_size[1:]
        for in_out, kernel, stride, pad in zip(in_out_channels, kernels, strides, pads):
            conv_transpose = nn.ConvTranspose2d(
                in_out[0], in_out[1], stride=stride, kernel_size=kernel, padding=pad
            )
            size = size_after_conv_convtranspose(size, kernel, stride, pad)
            self.convolution_transpose_seria.append(conv_transpose)
            self.convolution_transpose_seria.append(nn.LeakyReLU(0.01))
        self.convolution_transpose_seria.pop(-1)
        self.convolution_transpose_seria.append(Trim(self.image_size))
        self.convolution_transpose_seria.append(nn.Sigmoid())

        print(size)

        # self.convolution_transpose_seria = nn.Sequential(
        #     torch.nn.Linear(self.latent_size, 3136),
        #     nn.Unflatten(1, (64, 7, 7)),
        #     nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
        #     nn.LeakyReLU(0.01),
        #     nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
        #     nn.LeakyReLU(0.01),
        #     nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=0),
        #     nn.LeakyReLU(0.01),
        #     nn.ConvTranspose2d(64, 32, stride=(2, 2), kernel_size=(3, 3), padding=0),
        #     nn.LeakyReLU(0.01),
        #     nn.ConvTranspose2d(32, 1, stride=(1, 1), kernel_size=(3, 3), padding=0),
        #     Trim(self.image_size),  # 1x57x57 -> 1x56x56
        #     nn.Sigmoid()
        # )


class Trim(nn.Module):
    def __init__(self, image_size):
        super().__init__()
        self.size = image_size

    def forward(self, x):
        return x[:, :, : self.size, : self.size]

In [None]:
import lightning.pytorch as pl

flatten_size = in_out_channels[-1][-1] * image_size[0] * image_size[1]
unflattened_size = (in_out_channels[-1][-1], image_size[0], image_size[1])


class VariationalAutoEncoder(pl.LightningModule):
    def __init__(self, img_channels, in_channels, channel_mults):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

In [103]:
kernels = [(3, 3) for i in range(5)]
strides = [(1, 1), (2, 2), (2, 2), (2, 2), (1, 1)]
paddings = [1 for i in range(5)]

encoder = Encoder(64, 4, encoder_channels, kernels, strides, paddings, 4096)
decoder = Decoder(
    (64, 64), 4, decoder_channels, kernels, strides, [1, 1, 0, 0, 0], 4096, (64, 8, 8)
)
decoder

(65, 65)


Decoder(
  (convolution_transpose_seria): Sequential(
    (0): Linear(in_features=4, out_features=4096, bias=True)
    (1): Unflatten(dim=1, unflattened_size=(64, 8, 8))
    (2): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01)
    (6): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (7): LeakyReLU(negative_slope=0.01)
    (8): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2))
    (9): LeakyReLU(negative_slope=0.01)
    (10): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(1, 1))
    (11): Trim()
    (12): Sigmoid()
  )
)