In [None]:
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt

from drcomp.reducers import AutoEncoder
from drcomp.autoencoder.base import AbstractAutoEncoder
import torch.nn as nn
import torch

In [None]:
# Load every CelebA split ("all") from the local data root, downloading it if
# absent; each image is converted to a tensor by ToTensor.
celeba = datasets.CelebA(
    root="/storage/data",
    split="all",
    transform=transforms.ToTensor(),
    download=True,
)

In [None]:
# Inspect the first image's shape. The dataset wrapper and both autoencoders
# in this notebook hard-code 3x218x178, so fail fast if the data disagrees.
channels, *image_size = celeba[0][0].shape
assert (channels, *image_size) == (3, 218, 178), (
    f"unexpected image shape {(channels, *image_size)}; "
    "the hard-coded model dimensions assume (3, 218, 178)"
)

In [None]:
class CustomDataSet(torch.utils.data.Dataset):
    """Wrap an indexable ``(image, target)`` dataset for autoencoder training.

    Each item is returned as an ``(x, x)`` pair: the image is both the model
    input and the reconstruction target. The original target is discarded.

    Args:
        data: Indexable, sized collection whose items are sequences with an
            image tensor at position 0 (e.g. a torchvision dataset).
        shape: Shape each image tensor is viewed as. Defaults to
            ``(3, 218, 178)``, the CelebA layout this notebook was written
            for; pass another shape to reuse the wrapper elsewhere.
    """

    def __init__(self, data, shape=(3, 218, 178)):
        self.data = data
        self.shape = tuple(shape)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, x)`` where ``x`` is item ``idx``'s image viewed as ``self.shape``."""
        # ``view`` raises if the element count does not match ``self.shape``,
        # so a mis-sized image fails loudly instead of training on garbage.
        x = self.data[idx][0].view(self.shape)
        return x, x


# Re-bind ``celeba`` to the (x, x) reconstruction wrapper. The guard keeps the
# cell idempotent: re-running it must not wrap the dataset a second time.
if not isinstance(celeba, CustomDataSet):
    celeba = CustomDataSet(celeba)

In [None]:
class CelebaAutoEncoder(AbstractAutoEncoder):
    """Small convolutional autoencoder for 3x218x178 images with a 10-d code.

    Encoder: one 3x3 convolution (3 -> 8 channels, stride 3, padding 1), ReLU,
    flatten, then a linear projection to 10 dimensions.
    Decoder: a linear layer back to 8x73x60, then a 3x3 transposed convolution
    (stride 3) back to 3x218x178, followed by ReLU.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=3, padding=1),  # 3x218x178 -> 8x73x60
            nn.ReLU(),
            nn.Flatten(),  # 8x73x60 -> 35040
            nn.Linear(8 * 73 * 60, 10),
        )
        self.decoder = nn.Sequential(
            nn.Linear(10, 8 * 73 * 60),
            nn.Unflatten(1, (8, 73, 60)),
            # out = (in - 1) * stride - 2 * padding + kernel + output_padding
            #   height: (73 - 1) * 3 - 2 + 3 + 1 = 218
            #   width:  (60 - 1) * 3 - 2 + 3 + 0 = 178
            # A scalar output_padding=1 would give width 179 and the
            # reconstruction would no longer match the input shape.
            nn.ConvTranspose2d(8, 3, 3, stride=3, padding=1, output_padding=(1, 0)),
            nn.ReLU(),
        )


class SimpleAutoEncoder(AbstractAutoEncoder):
    """Fully connected autoencoder mapping 3x218x178 images to a 10-d code.

    Each half is a single linear layer with a ReLU; the decoder reshapes its
    flat output back to channel/height/width layout.
    """

    def __init__(self):
        super().__init__()
        flat_dim = 3 * 218 * 178  # 116412 values per image
        latent_dim = 10
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, latent_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_dim),
            nn.ReLU(),
            nn.Unflatten(1, (3, 218, 178)),
        )

In [None]:
# Seed torch's global RNG so layer weight initialisation is repeatable
# (whether drcomp's data shuffling also uses it is unverified — TODO confirm),
# and fall back to CPU so the notebook still runs on machines without a GPU.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
ae = AutoEncoder(SimpleAutoEncoder, batch_size=128, max_epochs=2, device=device)
ae.fit(celeba)

In [None]:
# saving
import pickle
from pathlib import Path

# NOTE: unpickling executes arbitrary code — only load this file from a
# trusted source (i.e. checkpoints you produced yourself).
model_path = Path("models/celeba_simple.pkl")
# open() raises FileNotFoundError if models/ does not exist yet.
model_path.parent.mkdir(parents=True, exist_ok=True)
with model_path.open("wb") as f:
    pickle.dump(ae, f)