In [None]:
import torch

from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class Encoder(nn.Module):
    def __init__(self, hidden_channels:int) -> None:
        super().__init__()

        self.lrelu = nn.LeakyReLU()

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1),
            self.lrelu
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            self.lrelu
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=2, padding=0),
            self.lrelu
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=2, padding=0),
            self.lrelu
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        print(f"{x.size()}")
        x = self.block2(x)
        print(f"{x.size()}")
        x = self.block3(x)
        print(f"{x.size()}")
        x = self.block4(x)
        print(f"{x.size()}")
        x = self.block5(x)
        print(f"{x.size()}", end="\n\n")
        return x

In [None]:
class Decoder(nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()

        self.lrelu = nn.LeakyReLU()

        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=hidden_channels, out_channels=256, kernel_size=1, stride=1, padding=0),
            self.lrelu,
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0),
            self.lrelu,
        )
        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0),
            self.lrelu,
        )
        self.block4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
            self.lrelu,
        )
        self.block5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        x = self.block1(x)
        print(f"{x.size()}")
        x = self.block2(x)
        print(f"{x.size()}")
        x = self.block3(x)
        print(f"{x.size()}")
        x = self.block4(x)
        print(f"{x.size()}")
        x = self.block5(x)
        print(f"{x.size()}")
        return x

In [None]:
class VAE(nn.Module):
    def __init__(self, hidden_channels: int) -> None:
        super().__init__()
        self.encode = Encoder(hidden_channels)
        self.decode = Decoder(hidden_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encode(x)
        x = self.decode(x)
        return x

In [None]:
import torch
import torch.nn.functional as function


def resize_to_multiple_of_16(tensor: torch.Tensor, mode="bicubic") -> torch.Tensor:
    """
    Resizes tensor (N, C, H, W) or (C, H, W) such that H and W are multiples of 16.
    :param tensor: tensor to resize
    :param mode: rescale mode
    :return: resized tensor
    """
    orig_h, orig_w = tensor.shape[-2], tensor.shape[-1]

    new_h = round(orig_h / 16) * 16
    new_w = round(orig_w / 16) * 16

    if tensor.ndim == 3:  # (C, H, W)
        tensor = tensor.unsqueeze(0)  # Add batch dim -> (1, C, H, W)
        resized = function.interpolate(tensor, size=(new_h, new_w), mode=mode, align_corners=False)
        resized = resized.squeeze(0)  # Remove batch dim -> (C, H, W)
    else:  # (N, C, H, W)
        resized = function.interpolate(tensor, size=(new_h, new_w), mode=mode, align_corners=False)

    return resized

In [127]:
test_model = VAE(hidden_channels=16)

test_tensor = torch.randn(1, 3, 1920, 1080)
resized_tensor = resize_to_multiple_of_16(test_tensor)
print(resized_tensor.shape, end="\n\n")
out_tensor = test_model(resized_tensor)
# print(out_tensor.shape)

torch.Size([1, 3, 1920, 1088])

torch.Size([1, 32, 960, 544])
torch.Size([1, 64, 480, 272])
torch.Size([1, 128, 240, 136])
torch.Size([1, 256, 120, 68])
torch.Size([1, 16, 120, 68])

torch.Size([1, 256, 120, 68])
torch.Size([1, 128, 240, 136])
torch.Size([1, 64, 480, 272])
torch.Size([1, 32, 960, 544])
torch.Size([1, 3, 1920, 1088])
