In [2]:
import torch
from torch.utils.data import DataLoader
from ae.utils import set_seed
from ae.dataset import HiddenManifold, load_mnist, load_cifar
from ae.sam import SAM
from ae.model import AutoEncoder
from ae.trainer import Trainer

In [27]:
# --- architecture ---
N = 32 * 32 * 3  # input dimension (= 3072); presumably a flattened 32x32 RGB CIFAR image — confirm against load_cifar
B = 30  # bottleneck width: the encoder's last layer outputs B features, the decoder's first layer takes B
encoder_hidden = [N, B]  # encoder layer widths after the input layer
decoder_hidden = [N]  # decoder hidden widths between the bottleneck and the N-dim output

# --- data ---
train_size = 8192
test_size = 2048
get_dataset = load_cifar
root = "../data"

# --- training ---
batch_size = 32
lr = 5e-4
weight_decay = 0.1
max_epochs = 150
target_loss = 0.0  # NOTE(review): presumably an early-stop loss threshold; 0.0 would effectively disable it — confirm in Trainer
base_optimizer = torch.optim.AdamW
optimizer_class = "adamw"  # "sam" -> SAM wrapping base_optimizer; "adamw" -> base_optimizer used directly
rho = 0.1  # SAM perturbation radius; only consumed when optimizer_class is "sam"

# --- pytorch ---
device = "cuda" if torch.cuda.is_available() else "cpu"
compile_model = True
seed = 50
print("Using device: ", device)

# --- logging ---
log_to_wandb = True
log_images = True
log_interval = 10  # batches
checkpoint_interval = max_epochs  # epochs; presumably a single checkpoint at the final epoch
flatness_interval = max_epochs + 1  # epochs; one past max_epochs, presumably so flatness is never evaluated — confirm in Trainer
flatness_iters = 10
denoising_iters = 3
wandb_project = "ae-prove"

Using device:  cpu


In [28]:
# data
dataset, dataset_metadata = get_dataset(log_to_wandb=False, project=wandb_project, root=root)
# Train takes the first `train_size` shuffled samples and test takes the last
# `test_size`. If the two do not fit in the dataset, the slices overlap and test
# samples leak into training, so fail loudly instead of silently continuing.
if train_size + test_size > len(dataset):
    raise ValueError(
        f"train_size + test_size = {train_size + test_size} exceeds the dataset size "
        f"({len(dataset)}); the train and test splits would overlap"
    )
# The split is seeded with a fixed 0 rather than the model `seed`, so varying
# `seed` does not change which samples land in train/test.
# NOTE(review): assumes set_seed seeds torch's global RNG used by randperm.
set_seed(0)
dataset = dataset[torch.randperm(len(dataset))]
# NOTE(review): DataLoader defaults to shuffle=False, so every epoch visits the
# same batches in the same order — confirm this is intended.
train_loader = DataLoader(dataset[:train_size], batch_size=batch_size)
# The whole test split is evaluated as one batch, so its memory cost scales with test_size.
test_loader = DataLoader(dataset[-test_size:], batch_size=test_size)

# model
model = AutoEncoder(input_dim=N, encoder_hidden=encoder_hidden, activation="ReLU", seed=seed, decoder_hidden=decoder_hidden)
# "sam" wraps base_optimizer in SAM (the only branch that uses rho);
# any other value uses base_optimizer directly.
if optimizer_class.lower() == "sam":
    optimizer = SAM(model.parameters(), base_optimizer, lr=lr, weight_decay=weight_decay, rho=rho)
else:
    optimizer = base_optimizer(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = torch.nn.MSELoss()
scheduler = None

# trainer config: each entry is passed to Trainer as a keyword argument
train_config = {
    "model": model,
    "optimizer": optimizer,
    "criterion": criterion,    
    "train_loader": train_loader,
    "test_loader": test_loader,
    "dataset_metadata": dataset_metadata,
    "max_epochs": max_epochs,
    "device": device,
    "scheduler": scheduler,
    "log_to_wandb": log_to_wandb,
    "log_interval": log_interval,
    "log_images": log_images,
    "checkpoint_interval": checkpoint_interval,
    "checkpoint_root_dir": "../checkpoints",
    "flatness_interval": flatness_interval,
    "train_set_percentage_for_flatness": 'auto',
    "flatness_iters": flatness_iters,
    "denoising_iters": denoising_iters,
    "target_loss": target_loss,
    "seed": seed,
    "compile_model": compile_model,
    "wandb_project": wandb_project,
}

Files already downloaded and verified


In [29]:
# Display the module tree to check the layer shapes built from encoder_hidden / decoder_hidden.
model

AutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=3072, out_features=3072, bias=True)
    (1): ReLU()
    (2): Linear(in_features=3072, out_features=30, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=30, out_features=3072, bias=True)
    (1): ReLU()
    (2): Linear(in_features=3072, out_features=3072, bias=True)
  )
)

In [30]:
# Build the Trainer from the keyword arguments collected in train_config and run
# the training loop; it reports train/val loss once per epoch.
# NOTE(review): the recorded run was stopped by KeyboardInterrupt around epoch 20
# and logged "mutex lock failed" thread-pool errors; the cause (e.g. compile_model
# on CPU) is not confirmed from here.
trainer = Trainer(**train_config)
trainer.train()

Epoch 1/150, train_loss: 0.2730, val_loss: 0.1375
Epoch 2/150, train_loss: 0.1301, val_loss: 0.1247
Epoch 3/150, train_loss: 0.1121, val_loss: 0.1006
Epoch 4/150, train_loss: 0.0958, val_loss: 0.0912
Epoch 5/150, train_loss: 0.0854, val_loss: 0.0820
Epoch 6/150, train_loss: 0.0772, val_loss: 0.0753
Epoch 7/150, train_loss: 0.0719, val_loss: 0.0727
Epoch 8/150, train_loss: 0.0678, val_loss: 0.0705
Epoch 9/150, train_loss: 0.0648, val_loss: 0.0683
Epoch 10/150, train_loss: 0.0629, val_loss: 0.0662
Epoch 11/150, train_loss: 0.0606, val_loss: 0.0642
Epoch 12/150, train_loss: 0.0581, val_loss: 0.0624
Epoch 13/150, train_loss: 0.0567, val_loss: 0.0619
Epoch 14/150, train_loss: 0.0555, val_loss: 0.0615
Epoch 15/150, train_loss: 0.0547, val_loss: 0.0618
Epoch 16/150, train_loss: 0.0540, val_loss: 0.0631
Epoch 17/150, train_loss: 0.0529, val_loss: 0.0617
Epoch 18/150, train_loss: 0.0515, val_loss: 0.0612
Epoch 19/150, train_loss: 0.0507, val_loss: 0.0615
Epoch 20/150, train_loss: 0.0502, val_lo

[E thread_pool.cpp:130] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:130] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:130] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:130] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:130] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:130] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 

In [None]:
import numpy as np  # NOTE(review): np and plt are unused in the cells shown here
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Linear baseline: a rank-B PCA fitted on the same shuffled train split the
# autoencoder is trained on, evaluated on the same test split.
train_set = dataset[:train_size].numpy()
test_set = dataset[-test_size:].numpy()
# For data this size (train_size x N, far above 500x500, with B components well
# below the smallest dimension), svd_solver="auto" selects the randomized SVD
# solver, so random_state is pinned to make the baseline reproducible.
pca = PCA(n_components=B, random_state=seed)
pca.fit(train_set)

In [None]:
# Reconstruction MSE of the rank-B PCA baseline on the train split; presumably
# comparable to the autoencoder's train_loss (both use mean-reduced MSE).
reconstructed = torch.from_numpy(pca.inverse_transform(pca.transform(train_set)))
# Compare against the exact array PCA encoded, promoted to the reconstruction's
# dtype so the loss is computed in one precision even if PCA returns float64.
train_target = torch.from_numpy(train_set).to(reconstructed.dtype)
mse = torch.nn.functional.mse_loss(reconstructed, train_target)
print("MSE on train set: ", mse.item())

MSE on train set:  0.1132894903421402


In [None]:
# Reconstruction MSE of the rank-B PCA baseline on the held-out test split;
# presumably comparable to the autoencoder's val_loss (both use mean-reduced MSE).
reconstructed = torch.from_numpy(pca.inverse_transform(pca.transform(test_set)))
# Compare against the exact array PCA encoded, promoted to the reconstruction's
# dtype so the loss is computed in one precision even if PCA returns float64.
test_target = torch.from_numpy(test_set).to(reconstructed.dtype)
mse = torch.nn.functional.mse_loss(reconstructed, test_target)
print("MSE on test set: ", mse.item())

MSE on test set:  0.11342918872833252
