In [None]:
import torch
from jupyter_core.migrate import security_file
from torch import nn

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that compresses a 3-channel image into a latent map.

    Four strided convolutions (each followed by LeakyReLU) halve the spatial
    size at every stage; a final 1x1 convolution then projects the 256 feature
    channels down to ``hidden_channels``. For inputs whose height and width
    are multiples of 16 the result has shape
    ``(..., hidden_channels, H / 16, W / 16)``.

    Args:
        hidden_channels: Number of channels of the latent feature map.
        verbose: When True, print the tensor size after every stage, which is
            handy for shape debugging. Defaults to False so that ordinary
            forward passes (e.g. inside a training loop) do not write to stdout.
    """

    def __init__(self, hidden_channels: int, verbose: bool = False) -> None:
        super().__init__()

        self.verbose = verbose

        # One activation instance is shared by every stage; it holds no
        # parameters, so sharing does not tie any weights together.
        self.lrelu = nn.LeakyReLU()

        # kernel 4 / stride 2 / padding 1 halves H and W.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1),
            self.lrelu,
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            self.lrelu,
        )
        # kernel 2 / stride 2 / padding 0 also halves H and W (non-overlapping patches).
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=2, padding=0),
            self.lrelu,
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=2, padding=0),
            self.lrelu,
        )
        # 1x1 projection to the latent channel count; no activation afterwards.
        self.block5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the five stages and return the latent feature map.

        Args:
            x: Image tensor with 3 channels.

        Returns:
            The output of the last stage, with ``hidden_channels`` channels.
        """
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for stage in stages:
            x = stage(x)
            if self.verbose:
                print(f"{x.size()}")
        if self.verbose:
            # Blank separator line after the encoder's shape trace.
            print()
        return x

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that expands a latent map into a 3-channel image.

    A 1x1 transposed convolution first lifts ``hidden_channels`` to 256
    channels; four stride-2 transposed convolutions then double the spatial
    size at every stage. A latent map of shape ``(..., hidden_channels, h, w)``
    therefore becomes ``(..., 3, 16 * h, 16 * w)``. The last stage has no
    activation, so the output is unbounded.

    Args:
        hidden_channels: Number of channels of the incoming latent map.
        verbose: When True, print the tensor size after every stage, which is
            handy for shape debugging. Defaults to False so that ordinary
            forward passes (e.g. inside a training loop) do not write to stdout.
    """

    def __init__(self, hidden_channels: int, verbose: bool = False) -> None:
        super().__init__()

        self.verbose = verbose

        # One activation instance is shared by every stage; it holds no
        # parameters, so sharing does not tie any weights together.
        self.lrelu = nn.LeakyReLU()

        # 1x1 lift from the latent channel count; spatial size is unchanged.
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=hidden_channels, out_channels=256, kernel_size=1, stride=1, padding=0),
            self.lrelu,
        )
        # kernel 2 / stride 2 / padding 0 doubles H and W.
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0),
            self.lrelu,
        )
        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0),
            self.lrelu,
        )
        # kernel 4 / stride 2 / padding 1 also doubles H and W.
        self.block4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
            self.lrelu,
        )
        self.block5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the five stages and return the reconstructed image.

        Args:
            x: Latent tensor with ``hidden_channels`` channels.

        Returns:
            The output of the last stage, with 3 channels.
        """
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for stage in stages:
            x = stage(x)
            if self.verbose:
                print(f"{x.size()}")
        return x

In [None]:
class VAE(nn.Module):
    """Encoder/decoder pair applied back to back.

    NOTE(review): despite the name, nothing here samples from a latent
    distribution or produces mean/variance terms; ``forward`` simply feeds the
    encoder output straight into the decoder.

    Args:
        hidden_channels: Channel count of the latent map shared by the
            encoder's output and the decoder's input.
    """

    def __init__(self, hidden_channels: int) -> None:
        super().__init__()
        # Both halves are built with the same latent width so they fit together.
        self.encode = Encoder(hidden_channels)
        self.decode = Decoder(hidden_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` to the latent map, then decode it and return the result."""
        latent = self.encode(x)
        return self.decode(latent)

In [None]:
# Shape smoke test: push one random image through the encoder and report how
# much smaller the latent representation is than the input.
c, h, w = 3, 512, 512
test_tensor = torch.randn(c, h, w)  # unbatched (C, H, W) input
hidden_num_channels = 16
encoder = Encoder(hidden_num_channels)
lat_space = encoder(test_tensor)

# The encoder halves H and W four times, so each axis shrinks by a factor of 16.
assert tuple(lat_space.shape) == (hidden_num_channels, h // 16, w // 16), (
    f"unexpected latent shape {tuple(lat_space.shape)}"
)

# Count elements on the actual tensors so both numbers are exact integers
# (the previous float arithmetic rendered the latent size as e.g. "16,384.0").
before, after = test_tensor.numel(), lat_space.numel()
print(f"Number of pixels before: {before:,} | Number of pixels after: {after:,} | reduction amount: {before/after}")

In [None]:
# End-to-end smoke test: encode and then decode the same random image.
model = VAE(hidden_num_channels)
encdec = model(test_tensor)

# For this input size the decoder's four 2x upsampling stages undo the
# encoder's four 2x downsampling stages, so the shapes must match exactly.
assert encdec.shape == test_tensor.shape, (
    f"round trip changed the shape: {tuple(test_tensor.shape)} -> {tuple(encdec.shape)}"
)