In [132]:
import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_default_device(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [133]:
import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self, hidden_channels: int) -> None:
        super().__init__()

        self.lrelu = nn.LeakyReLU()

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=0),
            self.lrelu
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2, stride=2, padding=0),
            self.lrelu
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0),
            self.lrelu
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        print(f"enc block 1: {x.size()}")
        x = self.block2(x)
        print(f"enc block 2: {x.size()}",end="\n\n")
        x = self.block3(x)
        print(f"latent space: {x.size()}", end="\n\n")
        return x

In [134]:
import torch
from torch import nn


class Decoder(nn.Module):
    def __init__(self, hidden_channels: int) -> None:
        super().__init__()

        self.lrelu = nn.LeakyReLU()

        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=hidden_channels, out_channels=64, kernel_size=1, stride=1, padding=0),
            self.lrelu
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2, padding=0),
            self.lrelu
        )
        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=4, stride=2, padding=0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        print(f"dec block 1: {x.size()}")
        x = self.block2(x)
        print(f"dec block 2: {x.size()}")
        x = self.block3(x)
        print(f"post dec: {x.size()}")
        return x

In [135]:
import torch
from torch import nn


class VAE(nn.Module):
    def __init__(self, hidden_channels) -> None:
        super().__init__()
        self.encoder = Encoder(hidden_channels)
        self.decoder = Decoder(hidden_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = VAE(hidden_channels=2)

In [136]:
foo = torch.randn(28, 28)
foo = foo.unsqueeze(0)
print(f"pre enc:{foo.size()}")
foo = model(foo)

pre enc:torch.Size([1, 28, 28])
enc block 1: torch.Size([32, 13, 13])
enc block 2: torch.Size([64, 6, 6])

latent space: torch.Size([2, 6, 6])

dec block 1: torch.Size([64, 6, 6])
dec block 2: torch.Size([32, 12, 12])
post dec: torch.Size([1, 26, 26])
