In [36]:
import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_default_device(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [37]:
import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self, hidden_channels: int) -> None:
        super().__init__()

        self.lrelu = nn.LeakyReLU()

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=0),
            self.lrelu
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0),
            self.lrelu
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0),
            self.lrelu
        )


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x) # makes 14, 14, 32
        x = self.block2(x) # makes 7, 7, 64
        x = self.block3(x) # makes 7, 7, 4
        return x

In [38]:
import torch
from torch import nn


class Decoder(nn.Module):
    def __init__(self, hidden_channels: int) -> None:
        super().__init__()

        self.lrelu = nn.LeakyReLU()

        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=hidden_channels, out_channels=64, kernel_size=4, stride=1, padding=0),
            self.lrelu
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=0),
            self.lrelu
        )
        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=4, stride=2, padding=0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

In [39]:
import torch
from torch import nn


class VAE(nn.Module):
    def __init__(self, hidden_channels) -> None:
        super().__init__()
        self.encoder = Encoder(hidden_channels)
        self.decoder = Decoder(hidden_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = VAE(hidden_channels=4)

In [40]:
foo = torch.randn(28, 28)
foo = foo.unsqueeze(0)
skibidi = model(foo)

print(skibidi.size())

torch.Size([1, 38, 38])
