In [None]:
# Reload edited modules (e.g. the local drcomp package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import wandb
from skorch import NeuralNet
from skorch.callbacks import Checkpoint, LRScheduler, ProgressBar, WandbLogger
from torchvision import transforms
from torchvision.datasets import CIFAR10

from drcomp.autoencoder import AbstractAutoEncoder, Cifar10ConvAE
from drcomp.reducers import AutoEncoder

%env "WANDB_NOTEBOOK_NAME" "./cifar10_autoencoder.ipynb"
import pickle

In [None]:
DATASET_PATH = "/storage/data"

In [None]:
train_dataset = CIFAR10(
    root="../data", train=True, transform=transforms.ToTensor(), download=True
)

In [None]:
# `train_dataset.data` is the raw uint8 image array in (N, H, W, C) layout.
# Move the channel axis with a transpose (a plain reshape to (N, 3, 32, 32)
# would scramble pixel values across channels), and scale to [0, 1] as the
# ToTensor transform would. order="C" keeps the result contiguous.
X = train_dataset.data.transpose(0, 3, 1, 2).astype(np.float32, order="C") / 255.0
X.shape

In [None]:
# Authenticate with Weights & Biases before wandb.init in the training cell
# (presumably prompts for an API key if none is configured — confirm).
wandb.login()

In [None]:
# Train the CIFAR-10 conv autoencoder, log to W&B, and pickle the fitted model.
lr_schedule = LRScheduler(policy="ReduceLROnPlateau")

config = {
    "epochs": 100,
    "batch_size": 100,
}

# Record the hyper-parameters on the run so it can be reproduced from W&B.
wandb_run = wandb.init(
    project="drcomp", group="CIFAR10_Autoencoder", config=config, reinit=True
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoEncoder(
    Cifar10ConvAE,
    batch_size=config["batch_size"],
    max_epochs=config["epochs"],
    device=device,
    # Create the logger inline: binding it to the name `wandb` would shadow
    # the wandb module and break `wandb.init` when this cell is re-run.
    callbacks=[lr_schedule, WandbLogger(wandb_run)],
)
try:
    model.fit(X)
finally:
    # Close the run even if training raises or is interrupted.
    wandb_run.finish()

# Make sure the output directory exists before writing the pickle.
os.makedirs("models", exist_ok=True)
with open("models/cifar10_autoencoder_contractive.pkl", "wb") as f:
    pickle.dump(model, f)

In [None]:
# Reload the fitted model. The `with` block closes the file handle.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open("models/cifar10_autoencoder_contractive.pkl", "rb") as f:
    model = pickle.load(f)

In [None]:
# Encode the training images with the fitted model
# (presumably returns the latent codes — confirm against drcomp's AutoEncoder.transform).
Y = model.transform(X)

In [None]:
# Map the encoded representation back to input space
# (presumably reconstructions with the same shape as X — verify).
X_hat = model.inverse_transform(Y)