In [None]:
# Automatically re-import edited Python modules before each cell executes,
# so changes to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

import torch
import torch.nn as nn
import torch.utils.data as data
from skorch import NeuralNet

In [None]:
# Root directory handed to torchvision; CIFAR10 is stored (and downloaded, if absent) here.
DATASET_PATH: str = "../data/"

In [None]:
# ToTensor is the only transformation: it turns each PIL image into a float
# tensor of shape (C, H, W) with values scaled to [0, 1]. No normalisation or
# augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Full CIFAR10 training split, downloaded into DATASET_PATH on the first run.
# No manual train/validation split is made here: skorch's default
# `train_split` holds out a validation fraction of this dataset during `fit`.
train_dataset = CIFAR10(
    root=DATASET_PATH, train=True, transform=transform, download=True
)

In [None]:
class ConvAutoEncoder(nn.Module):
    """Shallow convolutional autoencoder for square images.

    Encoder: stride-2 3x3 conv -> ``act_fn`` -> flatten -> linear to ``latent_dim``.
    Decoder mirrors it: linear -> ``act_fn`` -> unflatten -> stride-2 3x3
    transposed conv back to ``num_input_channels`` at the input resolution
    -> ``act_fn``.

    Args:
        num_input_channels: Channels of the input images (3 for RGB).
        latent_dim: Size of the bottleneck vector returned as ``encoded``.
        act_fn: Zero-argument callable returning an ``nn.Module`` (e.g. an
            activation class such as ``nn.ReLU``); it is called once for each
            activation slot, so every slot gets its own module instance.
        hidden_channels: Feature maps produced by the conv layer (default 8).
        image_size: Height and width of the square input images
            (default 32, i.e. CIFAR10).
    """

    def __init__(
        self,
        num_input_channels: int,
        latent_dim: int,
        act_fn: object = nn.ReLU,
        hidden_channels: int = 8,
        image_size: int = 32,
    ):
        super().__init__()
        # Conv2d(k=3, s=2, p=1) maps a side of H to floor((H - 1) / 2) + 1 == ceil(H / 2).
        feature_size = (image_size + 1) // 2
        flat_dim = hidden_channels * feature_size * feature_size
        # ConvTranspose2d(k=3, s=2, p=1) maps f to 2 * f - 1 + output_padding:
        # output_padding=1 restores an even input size, 0 restores an odd one.
        output_padding = 1 if image_size % 2 == 0 else 0

        self.encoder = nn.Sequential(
            nn.Conv2d(
                num_input_channels, hidden_channels, kernel_size=3, stride=2, padding=1
            ),  # b, hidden_channels, feature_size, feature_size
            act_fn(),
            nn.Flatten(),
            nn.Linear(flat_dim, latent_dim),  # b, latent_dim
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_dim),
            act_fn(),
            nn.Unflatten(1, (hidden_channels, feature_size, feature_size)),
            nn.ConvTranspose2d(
                hidden_channels,
                num_input_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=output_padding,
            ),  # b, num_input_channels, image_size, image_size
            # NOTE(review): act_fn is also applied to the reconstruction. With
            # nn.ReLU the output is clamped at 0 but unbounded above, while
            # inputs scaled to [0, 1] might be better matched by nn.Sigmoid.
            act_fn(),
        )

    def forward(self, x):
        """Encode then decode ``x``.

        Returns:
            ``(decoded, encoded)``: the reconstruction, shaped like ``x``, and
            the latent code of shape (b, latent_dim).
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, encoded

In [None]:
class CustomNet(NeuralNet):
    """skorch ``NeuralNet`` that trains an autoencoder on reconstruction loss.

    The wrapped module's ``forward`` returns ``(decoded, encoded)``. Only
    ``decoded`` enters the loss, and it is compared with the *input batch*
    ``X``. With a labelled dataset such as CIFAR10, the ``y_true`` that skorch
    supplies is the class-label batch, which is not a reconstruction target.
    """

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return the criterion evaluated on (reconstruction, input).

        Args:
            y_pred: ``(decoded, encoded)`` tuple produced by the module.
            y_true: Target taken from the batch. Ignored when ``X`` is given;
                used as the reconstruction target only if ``X`` is None.
            X: Input batch; the reconstruction target.
            training: Forwarded unchanged to ``NeuralNet.get_loss``.
        """
        decoded, _ = y_pred
        # An autoencoder reconstructs its input, so the input batch is the target.
        target = y_true if X is None else X
        return super().get_loss(decoded, target, X=X, training=training)

In [None]:
# Seed torch before the module's weights are initialised and before skorch
# shuffles the training data, so Restart & Run All reproduces the results.
SEED = 42
torch.manual_seed(SEED)

num_input_channels = 3  # RGB images
latent_dim = 64
model = CustomNet(
    ConvAutoEncoder(num_input_channels, latent_dim),
    criterion=nn.MSELoss,
    iterator_train__shuffle=True,
    batch_size=128,
    # skorch trains on CPU unless a device is given explicitly.
    device="cuda" if torch.cuda.is_available() else "cpu",
)
model

In [None]:
# train_dataset is a torch Dataset, so skorch draws (X, y) batches directly
# from its items; no separate target array is passed.
model.fit(X=train_dataset, y=None)