In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Normalize
import torch.nn as nn
from tqdm.notebook import tqdm, trange

In [2]:
import os

os.chdir('..')

In [9]:
%load_ext autoreload
%autoreload 2

In [10]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from torchvision.utils import make_grid
import numpy as np


def generate_samples(model, device):
    generated = model.decoder.generate(15, device)
    return make_grid(generated, nrow=5).permute(1, 2, 0) * 255


def visualize_losses(train_losses, test_losses, model):
    fig = make_subplots(
        rows=3, cols=2,
        subplot_titles=("Train epoch kl", "Train epoch reconstruction", "Train epoch loss", "Test epoch loss")
    )

    fig.add_trace(
        go.Scatter(
            x=list(range(len(train_losses["epoch_kl"]))),
            y=train_losses["epoch_kl"],
        ),
        col=1,
        row=1,
    )
    fig.add_trace(
        go.Scatter(
            x=list(range(len(train_losses["epoch_rec"]))),
            y=train_losses["epoch_rec"],
        ),
        col=2,
        row=1,
    )
    fig.add_trace(
        go.Scatter(
            x=list(range(len(train_losses["epoch"]))),
            y=train_losses["epoch"],
        ),
        col=1,
        row=2,
    )
    fig.add_trace(
        go.Scatter(
            x=list(range(len(test_losses["epoch"]))),
            y=test_losses["epoch"],
        ),
        col=2,
        row=2,
    )
    fig.add_trace(
        go.Image(
            z=generate_samples(model, device).cpu()
        ),
        row=3,
        col=1,
    )

    # visualize projections
    with torch.no_grad():
        rand_indices = np.random.choice(len(train_dataset), 200)
        x_samples, colors = [], []

        for idx in rand_indices:
            x, y = train_dataset[idx]

            x_samples.append(x[None])
            colors.append(y)

        x_samples = torch.cat(x_samples).to(device)
        projections = model.encoder(x_samples)[0].cpu()

    color_palette = np.array([[158, 1, 6],
                              [213, 62, 79],
                              [244, 109, 67],
                              [253, 174, 97],
                              [254, 224, 139],
                              [230, 245, 152],
                              [172, 221, 164],
                              [102, 194, 165],
                              [50, 136, 189],
                              [94, 79, 162]])
    # colors = color_palette[colors]
    # print(colors.shape, projections.shape)
    fig.add_trace(
        go.Scatter(
            x=projections[:, 0],
            y=projections[:, 1],
            mode='markers',
            marker=dict(
                size=6,
                color=colors,
                # colorscale='Viridis', # one of plotly colorscales
            )

        ),
        row=3,
        col=2,
    )

    fig.update_layout(height=1000)
    return fig

In [18]:
_default_mnist_avalanche_transform = Compose(
    [ToTensor()]
)
batch_size = 32

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=_default_mnist_avalanche_transform,
                               download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=_default_mnist_avalanche_transform,
                              download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [19]:
from src.vae_ft.model.vae import MLPVae

device = torch.device('mps')
z_dim = 2
vae = MLPVae(input_dim=28*28, z_dim=z_dim).to(device)

In [21]:
from IPython.display import clear_output

num_epochs = 50
validate_every = 5

train_losses = {
    "epoch": [],
    "epoch_kl": [],
    "epoch_rec": [],
    "batch": [],
    "batch_kl": [],
    "batch_rec": []
}
test_losses = {
    "epoch": [],
    "batch": []
}

optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)

for epoch_num in trange(num_epochs, desc="Epoch: "):

    # train loop
    train_losses["batch"] = []
    vae.train()

    for batch in tqdm(train_loader, desc="Train batch: ", leave=False):
        x, _ = batch
        x = x.to(device)

        x_pred, x, log_sigma, mu =  vae(x)
        kl_div, reconstruction_loss = vae.criterion((x_pred, x, log_sigma, mu)) * x.shape[0]
        loss = kl_div + reconstruction_loss

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        train_losses["batch"].append(loss.cpu().item())
        train_losses["batch_kl"].append(kl.cpu().item())
        train_losses["batch_rec"].append(rec.cpu().item())

    train_losses["epoch"].append(torch.as_tensor(train_losses["batch"]).mean().cpu().item())
    train_losses["epoch_kl"].append(torch.as_tensor(train_losses["batch_kl"]).mean().cpu().item())
    train_losses["epoch_rec"].append(torch.as_tensor(train_losses["batch_rec"]).mean().cpu().item())

    # Test loop
    if epoch_num % validate_every == 0:
        test_losses["batch"] = []
        vae.eval()

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Test batch: ", leave=False):
                x, _ = batch
                x = x.to(device)

                kl, rec = vae(x)
                loss = kl + rec
                test_losses["batch"].append(loss.cpu().item())

            test_epoch_loss = torch.as_tensor(test_losses["batch"]).mean()
            test_losses["epoch"].append(test_epoch_loss.cpu().item())

    clear_output(wait=True)
    visualize_losses(train_losses, test_losses, vae).show()

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Train batch:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batch:   0%|          | 0/313 [00:00<?, ?it/s]

ValueError: too many values to unpack (expected 2)

In [25]:
visualize_losses(train_losses, test_losses, vae).show()