In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import utils
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from typing import Iterator

In [2]:
from datasets import load_dataset, iterable_dataset

# Open the "nielsr/CelebA-faces" dataset in streaming mode and keep only its
# 'train' split; `ds` is what the dataset wrapper further down iterates over.
ds = load_dataset(
    path="nielsr/CelebA-faces",
    streaming=True,
)['train']

In [170]:
class FacesIterDataset(utils.data.IterableDataset):
    """Iterable-style torch dataset that yields transformed images one by one.

    Every record pulled from ``iterable`` must expose an ``'image'`` entry.
    That entry is converted to a NumPy array and passed through ``transforms``
    (called as ``transforms(image=...)``); the ``'image'`` entry of the result
    is what gets yielded.

    Args:
        iterable: Source of records, each indexable with the key ``'image'``.
        transforms: Callable pipeline invoked with the keyword ``image`` and
            returning a mapping that contains the key ``'image'``.
    """

    def __init__(self, 
                 iterable: iterable_dataset.IterableDataset,
                 transforms: A.Compose
                ):
        self.transforms = transforms
        self.iterable = iterable

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Lazily walk the source records and yield each transformed image."""

        def _prepare(record):
            # Array conversion first: the pipeline is fed a NumPy array.
            pixels = np.array(record['image'])
            return self.transforms(image=pixels)['image']

        yield from map(_prepare, self.iterable)

# Preprocessing pipeline applied to every image, in order:
#   1. centre crop to height 160 / width 140,
#   2. grayscale conversion (always applied, p=1),
#   3. normalisation with a single scalar mean/std -- the same values that
#      `image_denormalization` uses as defaults to undo this step,
#   4. conversion of the array to a torch tensor.
# NOTE(review): the meaning of ToGray's first positional argument depends on
# the installed albumentations version -- TODO confirm it is what is intended.
transforms = A.Compose([
    A.CenterCrop(height=160, width=140, p=1),
    A.ToGray(1, p=1),
    A.Normalize(mean=0.4375, std=0.2708),
    ToTensorV2(),
])

# Batches of 6 images drawn from the streamed dataset. `shuffle` stays False:
# the wrapped dataset is iterable-style, so the order is whatever it yields.
train_dataloader = utils.data.DataLoader(
    dataset=FacesIterDataset(ds, transforms),
    batch_size=6,
    shuffle=False,
)

In [171]:
def image_denormalization(img: torch.Tensor, mean: float = 0.4375, std: float = 0.2708) -> np.ndarray:
    """Convert a normalised image tensor back to 8-bit pixel values.

    Computes ``(img * std + mean) * 255``, clips the result to ``[0, 255]``,
    rounds to the nearest integer and casts to ``uint8``.

    Args:
        img: Normalised image tensor of any shape. It may require grad or live
            on a non-CPU device; it is detached and copied to the CPU first.
        mean: Mean that was subtracted during normalisation (on a 0..1 scale).
        std: Standard deviation that was divided by during normalisation.

    Returns:
        ``np.uint8`` array with the same shape as ``img``.
    """
    # `.numpy()` refuses tensors that require grad or are not on the CPU
    # (e.g. model outputs), so detach and move explicitly before converting.
    pixels = (img.detach().cpu().numpy() * std + mean) * 255
    # Round before the integer cast: a bare cast truncates, so a value such as
    # 127.99999 (float error from the round trip) would become 127, not 128.
    return np.rint(np.clip(pixels, 0, 255)).astype(np.uint8)

In [221]:
from math import ceil

class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to two latent vectors.

    One ``Conv2d(kernel_size=3, stride=2, padding=1, bias=False) ->
    BatchNorm2d -> ELU`` stage is built per entry of ``hidden_layers``. With
    these convolution settings every stage halves each spatial dimension,
    rounding up. The final feature map is flattened per sample and fed to two
    independent linear heads, whose outputs are returned as ``mu`` and
    ``sigma``.

    Args:
        img_size: Spatial size of the input images (two integers, e.g.
            ``(height, width)``); used to size the linear heads.
        in_channels: Number of channels of the input images.
        latent_dim: Output size of each linear head.
        hidden_layers: Output channel count of each convolution stage, in
            order. Must contain at least one entry.
    """

    def __init__(self, 
                 img_size: tuple[int, int],
                 in_channels: int,
                 latent_dim: int,
                 hidden_layers: tuple[int, ...]
                ):
        super().__init__()

        layers = []
        
        for layer in hidden_layers:
            layers.extend([
                # bias=False: the BatchNorm that follows has its own shift term.
                nn.Conv2d(in_channels, layer, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(layer),
                nn.ELU()
            ])
            in_channels = layer

        self.encoder = nn.Sequential(*layers)
        
        # Each stride-2 stage maps a side of length n to ceil(n / 2), so after
        # k stages a side has length ceil(n / 2**k).
        downscale = 2 ** len(hidden_layers)
        enc_dim = hidden_layers[-1] * ceil(img_size[0] / downscale) * ceil(img_size[1] / downscale)
        self.fc_mu = nn.Linear(enc_dim, latent_dim)
        self.fc_sigma = nn.Linear(enc_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, img_size[0], img_size[1])``.

        Returns:
            ``(mu, sigma)``, each of shape ``(batch, latent_dim)``.
        """
        # Flatten only the (channels, height, width) dimensions. Flattening the
        # batch dimension as well would merge all samples into one vector and
        # break the linear heads for any batch size other than 1.
        features = torch.flatten(self.encoder(x), start_dim=1)
        
        mu = self.fc_mu(features)
        sigma = self.fc_sigma(features)

        return mu, sigma

class Decoder(nn.Module):
    """Placeholder for the decoder network: no layers or forward pass are defined yet (TODO)."""

class VAE(nn.Module):
    """Container module holding an ``Encoder`` and a ``Decoder``.

    Work in progress: ``encode`` and ``decode`` are not implemented yet and
    currently return ``None``.

    Args:
        img_size: Spatial size of the input images (two integers); forwarded
            to the encoder.
        in_channels: Number of channels of the input images.
        latent_dim: Size of the latent vectors produced by the encoder.
        hidden_layers: Output channel count of each encoder convolution stage.
    """

    def __init__(self, 
                 img_size: tuple[int, int],
                 in_channels: int,
                 latent_dim: int = 256,
                 hidden_layers: tuple[int, ...] = (8, 32, 64, 128)
                ):
        super().__init__()
        
        # Forward the real hyper-parameters; a literal `...` placeholder makes
        # the Encoder constructor fail with a missing-arguments TypeError.
        self.encoder = Encoder(img_size, in_channels, latent_dim, hidden_layers)
        # Decoder is still a parameterless placeholder, so it is built without
        # arguments. TODO: pass the matching hyper-parameters once it exists.
        self.decoder = Decoder()

    def encode(self, x: torch.Tensor):
        """Not implemented yet (TODO); currently returns ``None``."""
        pass

    def decode(self, mu: torch.Tensor, sigma: torch.Tensor):
        """Not implemented yet (TODO); currently returns ``None``."""
        pass