In [1]:
import torch

In [2]:
torch.cuda.is_available()

True

In [3]:
torch.cuda.get_device_name(0)

'GeForce RTX 2070'

In [4]:
img = torch.randn(1, 1, 28, 28)

In [5]:
img.shape

torch.Size([1, 1, 28, 28])

In [6]:
import torch.nn as nn

In [24]:
# 28 -> 10
# 10 -> 8
# 8 -> 16
# 16 -> 28

In [25]:
def conv_out_shape(w, k, s, p=0):
    return (w - k + 2 * p) / s + 1

In [26]:
def tconv_out_shape(w, k, s):
    return w * s + max(k - s, 0)


In [27]:
tconv_out_shape(4, 3, 2)

9

In [28]:
tconv_out_shape(9, 3, 2)

19

In [29]:
tconv_out_shape(19, 2, 2)

38

In [30]:
tconv_out_shape(27, 2, 1)

28

In [31]:
[4, 3, 2]

[4, 3, 2]

In [50]:
conv_out_shape(28, 4, 2)

13.0

In [51]:
conv_out_shape(13, 3, 2)

6.0

In [52]:
conv_out_shape(6, 3, 1)

4.0

In [35]:
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

In [36]:
flatten = Flatten()

In [37]:
f = flatten(x2)

In [38]:
f.shape

torch.Size([1, 1152])

In [72]:
class UnFlatten(nn.Module):
    def __init__(self, channels, height, width):
        super(UnFlatten, self).__init__()
        self.channels = channels
        self.height = height
        self.width = width

    def forward(self, x):
        return x.view(x.size(0), self.channels, self.height, self.width)

In [54]:
import sys
sys.path.append("..")
from utils.gelu import GELU

In [55]:
conv_layers = [
            nn.Conv2d(in_channels=1,
                      out_channels=64,
                      kernel_size=4,
                      stride=2),
            nn.BatchNorm2d(64),
            GELU(),
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=3,
                      stride=2),
            nn.BatchNorm2d(128),
            GELU(),
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=3,
                      stride=1),
            nn.BatchNorm2d(256),
            GELU(),
            Flatten(),
        ]

In [56]:
conv_encoding = nn.Sequential(*conv_layers)

In [57]:
i = conv_encoding(img)

In [65]:
i.shape

torch.Size([1, 4096])

In [73]:
tconv_layers = [
            UnFlatten(256, 4, 4),
            nn.ConvTranspose2d(in_channels=256,
                               out_channels=128,
                               kernel_size=3,
                               stride=2),
            nn.BatchNorm2d(128),
            GELU(),
            nn.ConvTranspose2d(in_channels=128,
                               out_channels=64,
                               kernel_size=3,
                               stride=3),
            nn.BatchNorm2d(64),
            GELU(),
            nn.ConvTranspose2d(in_channels=64,
                               out_channels=1,
                               kernel_size=2,
                               stride=1),
            nn.Sigmoid(),
        ]

decoding = nn.Sequential(*tconv_layers)

In [74]:
decoding(i).shape

torch.Size([1, 1, 28, 28])

In [83]:
from torch.distributions.normal import Normal

In [84]:
m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))

In [85]:
i.shape

torch.Size([1, 4096])

In [89]:
m.sample(torch.Size([1, 4096]))

tensor([[[ 4.4686e-05],
         [-1.8843e-01],
         [-2.6285e-01],
         ...,
         [ 2.3783e+00],
         [-1.1375e+00],
         [ 6.0490e-01]]])

In [88]:
torch.Size([1, 4096])

torch.Size([1, 4096])