In [None]:
import os
import urllib.request
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm.notebook import tqdm

%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()
sns.set()

# Tensorboard extension (for visualization purposes later)
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data/raw"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = None

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.determinstic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
from drcomp import estimate_intrinsic_dimension, DimensionalityReducer
from sklearn.datasets import fetch_lfw_people

In [None]:
# Transformations applied on each image => only make them a tensor
transform = transforms.Compose([transforms.ToTensor()])

# Loading the LFW faces. Only the `.data` matrix is kept (no labels); each row
# is treated as one flattened sample by the rest of the notebook.
# The download/cache directory is the shared DATASET_PATH constant.
train_dataset = fetch_lfw_people(min_faces_per_person=70, data_home=DATASET_PATH).data

# Split into a training and a validation part. The training size is fixed and
# the validation part takes whatever remains, so the two lengths always sum to
# the dataset size (1100 + 188 for the 1288 samples this notebook was written for).
pl.seed_everything(42)
n_train = 1100
train_set, val_set = torch.utils.data.random_split(
    train_dataset, [n_train, len(train_dataset) - n_train]
)

# NOTE: no separate test split is created; only train/validation loaders exist.

# Pinned host memory only speeds up host-to-CUDA copies, so request it only
# when a CUDA device is actually present.
use_pinned_memory = torch.cuda.is_available()

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    pin_memory=use_pinned_memory,
    num_workers=4,
)
val_loader = data.DataLoader(
    val_set, batch_size=16, shuffle=False, drop_last=False, num_workers=4
)

In [None]:
class AutoEncoderBase(nn.Module):
    """Fully connected autoencoder with a symmetric encoder/decoder pair.

    The encoder maps ``input_size`` features through hidden widths 1024 and
    256 down to ``intrinsic_dim``; the decoder mirrors that path back up to
    ``input_size``. ReLU is applied between linear layers, but not after the
    bottleneck layer nor after the final reconstruction layer.

    Args:
        input_size: Number of features of each (flattened) input sample.
        intrinsic_dim: Width of the bottleneck (latent) representation.
    """

    def __init__(self, input_size: int, intrinsic_dim: int):
        super().__init__()
        wide, narrow = 1024, 256

        # Layers are created encoder-first, in forward order, and wrapped in
        # nn.Sequential so parameter names stay `encoder.<idx>.*` / `decoder.<idx>.*`.
        encoder_layers = [
            nn.Linear(input_size, wide),
            nn.ReLU(),
            nn.Linear(wide, narrow),
            nn.ReLU(),
            nn.Linear(narrow, intrinsic_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(intrinsic_dim, narrow),
            nn.ReLU(),
            nn.Linear(narrow, wide),
            nn.ReLU(),
            nn.Linear(wide, input_size),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent space and return its reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
class Autoencoder(pl.LightningModule):
    """LightningModule that trains an ``AutoEncoderBase`` with an MSE reconstruction loss.

    Args:
        input_size: Number of features of each (flattened) input sample.
        intrinsic_dim: Width of the bottleneck (latent) representation.
    """

    def __init__(self, input_size: int, intrinsic_dim: int):
        super().__init__()
        # Saving hyperparameters of autoencoder (stored in checkpoints, which
        # lets `load_from_checkpoint` rebuild the module without arguments).
        self.save_hyperparameters()
        self.base = AutoEncoderBase(input_size, intrinsic_dim)
        # NOTE: no `example_input_array` is set. If one is added for graph
        # visualization it must end in `input_size` features, because the first
        # layer of the network is `nn.Linear(input_size, ...)`.

    def forward(self, x):
        """Return the reconstruction of ``x`` (last dimension = ``input_size``)."""
        # Call the submodule itself rather than `.forward()` so that any
        # registered nn.Module hooks are executed.
        return self.base(x)

    def _get_reconstruction_loss(self, batch):
        """Return the mean squared reconstruction error for ``batch``.

        ``batch`` is used directly as the input tensor; no (input, label)
        unpacking is performed.
        """
        x = batch
        x_hat = self(x)
        # Conventional (input, target) order; MSE is symmetric, so the value
        # is the same either way.
        loss = F.mse_loss(x_hat, x, reduction="mean")
        return loss

    def configure_optimizers(self):
        """Adam (lr=1e-3) plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler multiplies the LR by 0.2 when the monitored value has
        # not improved for 20 epochs, never going below 5e-5.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one training batch."""
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the loss for one validation batch as ``val_loss``."""
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        """Compute and log the loss for one test batch as ``test_loss``."""
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)

In [None]:
def train_lfw(latent_dim, input_size=2914, max_epochs=500, accelerator="auto"):
    """Train an ``Autoencoder`` on the LFW loaders and evaluate it on the validation loader.

    Uses the module-level ``train_loader`` and ``val_loader``.

    Args:
        latent_dim: Width of the autoencoder bottleneck.
        input_size: Number of features per flattened sample (default 2914,
            the value this notebook uses for the LFW data).
        max_epochs: Upper bound on the number of training epochs.
        accelerator: Passed to ``pl.Trainer``. ``"auto"`` lets Lightning pick
            the available hardware instead of requiring Apple's MPS backend;
            pass ``"mps"`` explicitly to force the previous behavior.

    Returns:
        A ``(model, result)`` tuple, where ``result`` is ``{"val": ...}`` holding
        the output of ``trainer.test`` on the validation loader.
    """
    trainer = pl.Trainer(
        default_root_dir="../data",
        accelerator=accelerator,
        max_epochs=max_epochs,
        callbacks=[
            ModelCheckpoint(save_weights_only=True),
            LearningRateMonitor("epoch"),
        ],
    )
    # Ask the logger to plot the computation graph in tensorboard.
    # NOTE(review): this presumably needs `example_input_array` on the model,
    # which Autoencoder does not define — confirm whether a graph is produced.
    trainer.logger._log_graph = True
    # Optional logging argument that we don't need
    trainer.logger._default_hp_metric = None

    model = Autoencoder(input_size=input_size, intrinsic_dim=latent_dim)
    trainer.fit(model, train_loader, val_loader)

    # Evaluate the model as it is in memory after fitting (its final weights;
    # no checkpoint is reloaded here). There is no separate test loader, so the
    # validation loader is run through the test loop.
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    result = {"val": val_result}
    return model, result

In [None]:
# Train with a 22-dimensional bottleneck and keep both return values, so the
# trained model stays reachable and only the compact metrics are displayed.
model, result = train_lfw(22)
result

In [None]:
# The checkpoint path points at one specific earlier run (version_1, epoch 17);
# it has to exist on disk and must be updated when another run should be loaded.
checkpoint_file = "../data/lightning_logs/version_1/checkpoints/epoch=17-step=1224.ckpt"
model = Autoencoder.load_from_checkpoint(checkpoint_file)

In [None]:
# Switch the module to evaluation mode before running inference. `eval()`
# returns the module itself, so as the last expression of the cell it also
# displays the module's layer summary.
model.eval()

In [None]:
# Convert the sample-by-feature matrix directly to a float32 tensor.
# `transforms.ToTensor` is meant for single images (H x W x C) and would treat
# the whole matrix as one image, so it is not used here.
X = torch.as_tensor(train_dataset, dtype=torch.float32).reshape(-1, 2914)

In [None]:
# Inference only: disable autograd so no computation graph is kept for the
# whole dataset, and call the module itself (not `.forward`) so hooks run.
with torch.no_grad():
    X_hat = model(X)
X_hat[0]

In [None]:
# Original sample 200, viewed as a 62x47 grayscale image.
original_face = X[200].reshape(62, 47)
plt.imshow(original_face, cmap="gray")
plt.show()

In [None]:
# Autoencoder reconstruction of sample 200, viewed as a 62x47 grayscale image.
reconstructed_face = X_hat[200].detach().numpy().reshape(62, 47)
plt.imshow(reconstructed_face, cmap="gray")
plt.show()