In [None]:
# Reload all imported modules before each cell executes, so edits to local
# packages (such as drcomp) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import wandb
from skorch import NeuralNet
from skorch.callbacks import Checkpoint, LRScheduler, ProgressBar, WandbLogger
from torchvision import transforms
from torchvision.datasets import CIFAR10

from drcomp.autoencoder.base import AbstractAutoEncoder
from drcomp.reducers import AutoEncoder

%env "WANDB_NOTEBOOK_NAME" "./cifar10_autoencoder.ipynb"
import pickle

In [None]:
# Root directory for torchvision datasets. Override with the DATASET_PATH
# environment variable instead of editing this absolute path.
DATASET_PATH = os.environ.get("DATASET_PATH", "/storage/data")

In [None]:
# Load the CIFAR-10 training split, downloading it into DATASET_PATH if absent.
# NOTE: `transform` only applies to item access (`train_dataset[i]`); the raw
# `.data` array read in the next cell bypasses it.
train_dataset = CIFAR10(
    DATASET_PATH,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# `CIFAR10.data` stores the raw uint8 images channels-last, (N, 32, 32, 3).
# The conv encoder expects channels-first (N, 3, 32, 32), so permute the axes.
# A plain `reshape` keeps the memory order and scrambles pixels across channels.
# Dividing by 255 puts values in [0, 1], the range `transforms.ToTensor()` gives.
X = train_dataset.data.transpose(0, 3, 1, 2).astype(np.float32, order="C") / 255.0
assert X.shape == (len(train_dataset), 3, 32, 32), X.shape
X.shape

In [None]:
class CIFAR10_Autoencoder(AbstractAutoEncoder):
    """Convolutional autoencoder for 3x32x32 CIFAR-10 images.

    The encoder halves the spatial size four times with stride-2 convolutions
    (32 -> 16 -> 8 -> 4 -> 2). It then flattens the 128x2x2 feature map and
    projects it to a ``latent_dim``-dimensional code. The decoder mirrors this:
    a linear layer back to 128x2x2, then four stride-2 transposed convolutions
    up to 3x32x32.

    Args:
        latent_dim: Size of the bottleneck code. Defaults to 64, the value
            previously hard-coded.
    """

    def __init__(self, latent_dim: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),  # b, 16, 16, 16
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # b, 32, 8, 8
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # b, 64, 4, 4
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # b, 128, 2, 2
            # Without an activation here, this conv and the Linear below would
            # compose into a single linear map.
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 2 * 2, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * 2 * 2),
            nn.ReLU(),
            nn.Unflatten(1, (128, 2, 2)),
            nn.ConvTranspose2d(
                128, 64, 3, stride=2, padding=1, output_padding=1
            ),  # b, 64, 4, 4
            nn.ReLU(),
            nn.ConvTranspose2d(
                64, 32, 3, stride=2, padding=1, output_padding=1
            ),  # b, 32, 8, 8
            # Previously missing: every other hidden layer is followed by a ReLU.
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 16, 3, stride=2, padding=1, output_padding=1
            ),  # b, 16, 16, 16
            nn.ReLU(),
            # Final layer is left without an activation, so outputs are unbounded.
            nn.ConvTranspose2d(
                16, 3, 3, stride=2, padding=1, output_padding=1
            ),  # b, 3, 32, 32
        )

In [None]:
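# Authenticate with W&B once per machine via `wandb login` in a shell, or by
# exporting WANDB_API_KEY; never paste the API key into this notebook.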

In [None]:
# Fit the autoencoder with LR scheduling and W&B logging, then pickle the
# fitted reducer so later cells can reload it without retraining.
SEED = 42
MODEL_PATH = "../models/cifar10_autoencoder.pkl"
config = {
    "epochs": 100,
    "batch_size": 128,
}

# Seed torch's RNG (e.g. for weight initialisation) so runs are repeatable.
torch.manual_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create the output directory before training, so a missing ../models fails
# fast instead of after the full training run.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

lr_schedule = LRScheduler(policy="ReduceLROnPlateau")

# Record the hyper-parameters with the run. The skorch callback gets its own
# name: binding it to `wandb` shadows the module and breaks `wandb.init` when
# this cell is re-run.
wandb_run = wandb.init(project="drcomp", group="CIFAR10_Autoencoder", config=config)
wandb_logger = WandbLogger(wandb_run)

model = AutoEncoder(
    CIFAR10_Autoencoder,
    batch_size=config["batch_size"],
    max_epochs=config["epochs"],
    device=device,
    callbacks=[lr_schedule, wandb_logger],
)
try:
    model.fit(X)
finally:
    # Close the W&B run even if training raises or is interrupted.
    wandb_run.finish()

with open(MODEL_PATH, "wb") as f:
    pickle.dump(model, f)

In [None]:
# Reload the fitted reducer saved by the training cell. The `with` block closes
# the file handle. `pickle.load` can execute arbitrary code, so only load files
# this notebook produced itself.
with open("../models/cifar10_autoencoder.pkl", "rb") as f:
    model = pickle.load(f)

In [None]:
# Encode the training images with the fitted reducer.
# NOTE(review): presumably returns the latent codes, shape (N, 64); confirm against drcomp's AutoEncoder.transform.
Y = model.transform(X)

In [None]:
# Decode the latent codes back through the fitted reducer (reconstructions of X).
# NOTE(review): presumably shape (N, 3, 32, 32), matching X; confirm before comparing X_hat to X.
X_hat = model.inverse_transform(Y)