In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Normalize
import torch.nn as nn
from tqdm.notebook import tqdm, trange

In [2]:
# Images are only converted to tensors; no further normalisation is applied.
_default_mnist_avalanche_transform = Compose([ToTensor()])
batch_size = 256

# The train split is fetched on demand (download=True); the test split is
# opened with download=False, so its files must already exist under the root.
train_dataset = datasets.MNIST(
    './mnist_data/',
    train=True,
    transform=_default_mnist_avalanche_transform,
    download=True,
)
test_dataset = datasets.MNIST(
    './mnist_data/',
    train=False,
    transform=_default_mnist_avalanche_transform,
    download=False,
)

# Mini-batch iterators: shuffled for training, fixed order for evaluation.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [3]:
import os

# Record the directory the kernel started in exactly once, then always move to
# its parent. A bare `os.chdir('..')` climbs one level further on every re-run
# of this cell; anchoring on the recorded start directory makes the cell
# idempotent within a kernel session.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [4]:
% load_ext autoreload
% autoreload 2

UsageError: Line magic function `%` not found.


In [5]:
from src.model.rnd.vae_generator import MNISTVaeCNNGenerator, MNISTVaeLinearGenerator
import torch.nn.functional as F


class MNISTVaeEncoder(nn.Module):
    """Fully connected VAE encoder for 28x28 (MNIST-sized) images.

    The input is flattened, passed through a two-layer MLP trunk and then fed
    to two independent linear heads that each produce a vector of length
    ``output_dim``.

    Args:
        output_dim: Size of the latent space, i.e. the length of each of the
            two vectors returned by ``forward``.
    """

    def __init__(self, output_dim: int) -> None:
        super().__init__()

        self.input_dim = 28 * 28
        self.output_dim = output_dim

        # The ReLUs are essential: two stacked `Linear` layers without an
        # activation in between collapse into a single affine map, which makes
        # the extra layer (and its parameters) pointless.
        self.module = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )

        self.mu_head = nn.Linear(256, self.output_dim)
        # NOTE(review): the attribute name is kept for compatibility, but the
        # output is unconstrained (may be negative), so consumers presumably
        # treat it as a log-variance rather than sigma itself - confirm.
        self.sigma_head = nn.Linear(256, self.output_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Batch of images; everything after the batch dimension is
                flattened and must contain ``28 * 28`` values.

        Returns:
            Tuple ``(mu_head(h), sigma_head(h))`` where ``h`` is the output of
            the shared trunk; both tensors have shape ``(batch, output_dim)``.
        """
        hidden = self.module(x)

        return self.mu_head(hidden), self.sigma_head(hidden)


class VAE(nn.Module):
    """Variational auto-encoder whose ``forward`` returns its two loss terms.

    Args:
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, z_dim: int):
        super().__init__()

        self.encoder = MNISTVaeEncoder(output_dim=z_dim)
        # self.decoder = MNISTVaeCNNGenerator(input_dim=z_dim)
        self.decoder = MNISTVaeLinearGenerator(28 * 28, 512, 256, z_dim)
        # Summed (not averaged) squared error: the KL term below is summed over
        # latent dimensions, so the reconstruction term must be summed over
        # pixels too. Averaging over all pixels down-weights it by the number
        # of pixels and lets the KL term dominate the objective.
        self.reconstruction_criterion = nn.MSELoss(reduction="sum")

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode it and compute the losses.

        Args:
            x: Batch of input images with the batch dimension first.

        Returns:
            Tuple ``(kl_div, reconstruction_loss)`` of scalar tensors, both
            averaged over the batch:
            ``kl_div`` is the KL divergence between the diagonal-Gaussian
            posterior and a standard normal prior, and ``reconstruction_loss``
            is half the per-sample summed squared error.
        """
        # The encoder's second output is used as a log-variance: the standard
        # deviation is obtained as exp(0.5 * log_var).
        mu, log_var = self.encoder(x)

        # Reparameterisation trick: z = mu + std * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std

        x_pred = self.decoder(z)
        # NOTE(review): the decoder's output layout is not visible here; align
        # it with the input so a flat (batch, 784) output is compared
        # element-wise. This is a no-op when the shapes already match.
        x_pred = x_pred.reshape(x.shape)

        kl_div = 0.5 * (log_var.exp() + mu.pow(2) - log_var - 1).sum(dim=1).mean()
        reconstruction_loss = 0.5 * self.reconstruction_criterion(x_pred, x) / x.size(0)

        return kl_div, reconstruction_loss

In [6]:
# Pick the best available accelerator instead of hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

z_dim = 2
vae = VAE(z_dim=z_dim).to(device)

In [7]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go


def visualize_losses(train_losses, test_losses, validate_every=1):
    """Build a plotly figure with the training and test loss curves.

    Layout: row 1 holds the train KL (left) and train reconstruction (right)
    curves, row 2 the total train loss and row 3 the total test loss.

    Args:
        train_losses: Mapping holding one list of per-epoch values under each
            of the keys ``"epoch_kl"``, ``"epoch_rec"`` and ``"epoch"``.
        test_losses: Mapping holding a list of per-evaluation values under the
            key ``"epoch"``.
        validate_every: Number of epochs between two consecutive test
            evaluations. Only used to place the test points on an epoch axis;
            the default of 1 plots them against their plain index.

    Returns:
        The plotly figure. It is not shown; call ``.show()`` on the result.
    """
    # (title, values, x step, row, col) - listed in the row-major order in
    # which plotly assigns `subplot_titles` to the non-empty grid cells.
    panels = (
        ("Train epoch kl", train_losses["epoch_kl"], 1, 1, 1),
        ("Train epoch reconstruction", train_losses["epoch_rec"], 1, 1, 2),
        ("Train epoch loss", train_losses["epoch"], 1, 2, 1),
        ("Test epoch loss", test_losses["epoch"], validate_every, 3, 1),
    )

    fig = make_subplots(
        rows=3, cols=2,
        specs=[[{}, {}],
               [{"colspan": 2}, None],
               [{"colspan": 2}, None]],
        subplot_titles=[title for title, *_ in panels],
    )

    for title, values, x_step, row, col in panels:
        fig.add_trace(
            go.Scatter(
                x=[index * x_step for index in range(len(values))],
                y=values,
                # Named traces give a readable legend instead of "trace N".
                name=title,
            ),
            row=row,
            col=col,
        )

    fig.update_layout(height=1000)
    return fig

In [9]:
from IPython.display import clear_output

num_epochs = 100
validate_every = 5

# "epoch*" lists grow by one entry per (evaluated) epoch; "batch*" lists hold
# the per-batch values of the epoch currently being processed.
train_losses = {
    "epoch": [],
    "epoch_kl": [],
    "epoch_rec": [],
    "batch": [],
    "batch_kl": [],
    "batch_rec": []
}
test_losses = {
    "epoch": [],
    "batch": []
}

optimizer = torch.optim.Adam(vae.parameters())

for epoch_num in trange(num_epochs, desc="Epoch: "):

    # Train loop. Reset *all* per-batch buffers: clearing only "batch" would
    # turn "epoch_kl" / "epoch_rec" into running means over every batch since
    # the start of training instead of per-epoch means like "epoch".
    for key in ("batch", "batch_kl", "batch_rec"):
        train_losses[key] = []
    vae.train()

    for x, _ in tqdm(train_loader, desc="Train batch: ", leave=False):
        x = x.to(device)

        # Clear gradients before the forward pass so that stale gradients left
        # on the parameters by an interrupted earlier run of this cell cannot
        # accumulate into the first optimisation step.
        optimizer.zero_grad()

        kl, rec = vae(x)
        loss = kl + rec
        loss.backward()
        optimizer.step()

        # `.item()` already copies the scalar to the host as a Python float.
        train_losses["batch"].append(loss.item())
        train_losses["batch_kl"].append(kl.item())
        train_losses["batch_rec"].append(rec.item())

    for suffix in ("", "_kl", "_rec"):
        epoch_mean = torch.as_tensor(train_losses["batch" + suffix]).mean().item()
        train_losses["epoch" + suffix].append(epoch_mean)

    # Test loop (only every `validate_every`-th epoch, starting with epoch 0).
    if epoch_num % validate_every == 0:
        test_losses["batch"] = []
        vae.eval()

        with torch.no_grad():
            for x, _ in tqdm(test_loader, desc="Test batch: ", leave=False):
                x = x.to(device)

                kl, rec = vae(x)
                loss = kl + rec
                test_losses["batch"].append(loss.item())

            test_epoch_loss = torch.as_tensor(test_losses["batch"]).mean()
            test_losses["epoch"].append(test_epoch_loss.item())

    # Redraw the loss curves in place after every epoch.
    clear_output(wait=True)
    visualize_losses(train_losses, test_losses).show()

Train batch:   0%|          | 0/235 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
from torchvision.utils import save_image
from torch.autograd import Variable

# Decode latent codes drawn from a standard normal into images and save them
# as a single image grid.
# Switch to evaluation mode first: the training cell may have left the model
# in train mode, which matters if the decoder contains dropout or batch-norm
# layers (its definition is not visible here).
vae.eval()
with torch.no_grad():
    # `torch.autograd.Variable` is a deprecated no-op wrapper; a plain tensor
    # behaves identically.
    test_z = torch.randn(batch_size, z_dim).to(device)
    generated = vae.decoder(test_z)

    save_image(generated.view(generated.size(0), 1, 28, 28), 'sample_.png')

In [5]:
# Per-image mean / standard deviation plus the global pixel range of the
# training split (after the dataset transform has been applied).
sigmas, mus = [], []
min_v, max_v = float('+inf'), float('-inf')

for image, _label in train_dataset:
    mus.append(image.mean())
    sigmas.append(image.std())

    # Track the running extremes over all images seen so far.
    min_v = min(image.min(), min_v)
    max_v = max(image.max(), max_v)

In [6]:
# Summary shown as the cell output: mean per-image std, mean per-image mean,
# and the global minimum / maximum pixel value.
(
    torch.as_tensor(sigmas).mean(),
    torch.as_tensor(mus).mean(),
    min_v,
    max_v,
)

(tensor(0.3015), tensor(0.1307), tensor(0.), tensor(1.))