In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Normalize
import torch.nn as nn
from tqdm.notebook import tqdm, trange

In [2]:
import os

# Switch the working directory to the notebook's parent directory
# (presumably the repository root, so that `src.*` imports and the relative
# `./mnist_data/` path resolve -- confirm against the project layout).
# The starting directory is remembered on first execution, which makes this
# cell idempotent: re-running it no longer climbs one more level each time.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from torch.autograd import Variable
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
import numpy as np

def generate_samples(model, device=None):
    """Generate 15 images with the model's decoder and tile them into a grid.

    Args:
        model: Module exposing ``model.decoder.generate(n, device)``.
        device: Device forwarded to ``decoder.generate``. When omitted, the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so ``generate_samples(model)`` also works.

    Returns:
        A channels-last tensor holding a grid with 5 images per row, with
        values multiplied by 255 (assumes the decoder emits values in
        [0, 1] -- TODO confirm against the decoder implementation).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Sampling is for visualisation only; do not build an autograd graph.
    with torch.no_grad():
        generated = model.decoder.generate(15, device)

    # make_grid returns (C, H, W); permute to (H, W, C) for image plotting.
    return make_grid(generated, nrow=5).permute(1, 2, 0) * 255

def visualize_losses(train_losses, test_losses, model, device=None, dataset=None,
                     num_projection_samples=200):
    """Build a 3x2 plotly dashboard of training progress.

    Rows 1-2 show the per-epoch loss curves, row 3 shows decoder samples and
    a scatter of encoder means for randomly chosen dataset items.

    Args:
        train_losses: Dict with list entries ``"epoch_kl"``, ``"epoch_rec"``
            and ``"epoch"``.
        test_losses: Dict with a list entry ``"epoch"``.
        model: Model exposing ``encoder`` (whose first output is plotted) and
            a decoder usable by ``generate_samples``.
        device: Device used for sampling and encoding. Defaults to the device
            of the model's first parameter (CPU if it has none).
        dataset: Indexable dataset yielding ``(image, label)`` pairs. Defaults
            to the notebook-level ``train_dataset`` for backward compatibility.
        num_projection_samples: Number of dataset items (drawn with
            replacement) shown in the projection scatter.

    Returns:
        The assembled ``plotly`` figure (not yet shown).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    if dataset is None:
        # Fallback keeps existing three-argument calls working.
        dataset = train_dataset

    fig = make_subplots(
        rows=3, cols=2,
        subplot_titles=(
            "Train epoch kl", "Train epoch reconstruction",
            "Train epoch loss", "Test epoch loss",
            "Generated samples", "Latent projections",
        ),
    )

    # (values, row, col) for each loss curve; x is simply the entry index.
    loss_panels = (
        (train_losses["epoch_kl"], 1, 1),
        (train_losses["epoch_rec"], 1, 2),
        (train_losses["epoch"], 2, 1),
        (test_losses["epoch"], 2, 2),
    )
    for values, row, col in loss_panels:
        fig.add_trace(
            go.Scatter(x=list(range(len(values))), y=values),
            row=row,
            col=col,
        )

    # Hand plotly a plain numpy array rather than a (possibly grad-tracking) tensor.
    sample_grid = generate_samples(model, device).detach().cpu().numpy()
    fig.add_trace(go.Image(z=sample_grid), row=3, col=1)

    # visualize projections
    with torch.no_grad():
        rand_indices = np.random.choice(len(dataset), num_projection_samples)
        x_samples, colors = [], []

        for idx in rand_indices:
            x, y = dataset[int(idx)]

            x_samples.append(x[None])  # add a leading batch dimension
            colors.append(y)

        x_samples = torch.cat(x_samples).to(device)
        # First encoder output is used as the point's coordinates.
        projections = model.encoder(x_samples)[0].cpu().numpy()

    fig.add_trace(
        go.Scatter(
            x=projections[:, 0],
            y=projections[:, 1],
            mode='markers',
            marker=dict(
                size=6,
                color=colors,  # colour points by their label
                colorscale='Viridis',  # one of plotly colorscales
            )
        ),
        row=3,
        col=2,
    )

    fig.update_layout(height=1000)
    return fig

In [5]:
# Images are only converted to tensors; no further normalisation is applied.
_default_mnist_avalanche_transform = Compose([ToTensor()])
batch_size = 32

# Both splits share the same root directory and transform. Only the train
# split requests a download; the test split is created with download=False.
train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=_default_mnist_avalanche_transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=_default_mnist_avalanche_transform,
    download=False,
)

# Data loaders: training batches are shuffled, test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [6]:
from src.model.rnd.vae_generator import MNISTVaeCNNGenerator, MNISTVaeLinearGenerator
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn.functional as F


class MNISTVaeEncoder(nn.Module):
    """Convolutional encoder mapping an image batch to two latent vectors.

    A small conv stack reduces each image to a 16-dimensional feature vector,
    from which two independent linear heads (``mu_head`` and ``sigma_head``)
    each produce an ``output_dim``-sized vector.

    Args:
        output_dim: Size of each of the two output vectors.
    """

    def __init__(self, output_dim: int) -> None:
        super().__init__()

        self.input_dim = 28 * 28
        self.output_dim = output_dim

        # Channel count of the last convolution == size of the flattened features.
        feature_dim = 16

        # Spatial size for a 28x28 input: 28 -> 13 -> 6 -> 5 -> 1 (after pooling).
        conv_stack = [
            nn.Conv2d(1, 3, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(3, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, feature_dim, kernel_size=2),
            nn.AvgPool2d(5),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.module = nn.Sequential(*conv_stack)

        self.mu_head = nn.Linear(feature_dim, self.output_dim)
        self.sigma_head = nn.Linear(feature_dim, self.output_dim)

    def forward(self, x):
        """Return ``(mu_head(features), sigma_head(features))`` for batch ``x``."""
        features = self.module(x)
        mu = self.mu_head(features)
        sigma = self.sigma_head(features)
        return mu, sigma


class VAE(nn.Module):
    """Variational auto-encoder whose forward pass returns its two loss terms.

    Args:
        z_dim: Dimensionality of the latent code shared by encoder and decoder.
    """

    def __init__(self, z_dim: int):
        super().__init__()

        self.encoder = MNISTVaeEncoder(output_dim=z_dim)
        self.decoder = MNISTVaeCNNGenerator(input_dim=z_dim, apply_sigmoid=True)
        # Not referenced by forward(), which uses binary cross-entropy instead.
        self.reconstruction_criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode it and return the losses.

        Returns:
            ``(kl_div, reconstruction_loss)``; both are summed over all
            elements of the batch (no averaging).
        """
        # The second encoder output is used as a log-variance below.
        mu, log_var = self.encoder(x)

        # Reparameterisation: z = mu + eps * std, with std = exp(0.5 * log_var).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std

        reconstruction = self.decoder(z)

        # KL divergence between N(mu, std^2) and the standard normal.
        kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_div = -0.5 * torch.sum(kl_terms)

        # Per-pixel binary cross-entropy between reconstruction and input.
        reconstruction_loss = F.binary_cross_entropy(
            reconstruction.flatten(1),
            x.flatten(1),
            reduction='sum',
        )

        return kl_div, reconstruction_loss

In [7]:
# Pick the best available accelerator instead of hard-coding 'mps', which
# raises on machines without Apple's Metal backend. MPS stays first so
# behaviour is unchanged where it was previously working.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

z_dim = 2  # two latent dimensions, so codes can be plotted directly in 2-D
vae = VAE(z_dim=z_dim).to(device)

In [8]:
from IPython.display import clear_output

num_epochs = 50
validate_every = 5  # run the test loop every this many epochs

# "batch*" lists hold the per-batch values of the epoch in progress;
# "epoch*" lists hold one mean value per finished epoch.
train_losses = {
    "epoch": [],
    "epoch_kl": [],
    "epoch_rec": [],
    "batch": [],
    "batch_kl": [],
    "batch_rec": []
}
test_losses = {
    "epoch": [],
    "batch": []
}

optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)

for epoch_num in trange(num_epochs, desc="Epoch: "):

    # train loop
    # Reset ALL per-batch accumulators. Previously only "batch" was cleared,
    # so "epoch_kl"/"epoch_rec" were running means over every batch seen so
    # far rather than per-epoch means like "epoch".
    for batch_key in ("batch", "batch_kl", "batch_rec"):
        train_losses[batch_key] = []
    vae.train()

    for batch in tqdm(train_loader, desc="Train batch: ", leave=False):
        x, _ = batch
        x = x.to(device)

        # Zero gradients before the backward pass so that gradients left over
        # from an interrupted previous run of this cell cannot leak into the
        # first optimisation step.
        optimizer.zero_grad()

        kl, rec = vae(x)
        loss = kl + rec
        loss.backward()
        optimizer.step()

        train_losses["batch"].append(loss.item())
        train_losses["batch_kl"].append(kl.item())
        train_losses["batch_rec"].append(rec.item())

    # Collapse the per-batch values of this epoch into one mean per metric.
    for epoch_key, batch_key in (
        ("epoch", "batch"),
        ("epoch_kl", "batch_kl"),
        ("epoch_rec", "batch_rec"),
    ):
        train_losses[epoch_key].append(
            torch.as_tensor(train_losses[batch_key]).mean().item()
        )

    # Test loop
    if epoch_num % validate_every == 0:
        test_losses["batch"] = []
        vae.eval()

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Test batch: ", leave=False):
                x, _ = batch
                x = x.to(device)

                kl, rec = vae(x)
                loss = kl + rec
                test_losses["batch"].append(loss.item())

            test_epoch_loss = torch.as_tensor(test_losses["batch"]).mean()
            test_losses["epoch"].append(test_epoch_loss.item())

    # Replace the previous dashboard with an updated one after every epoch.
    clear_output(wait=True)
    visualize_losses(train_losses, test_losses, vae).show()

Train batch:   0%|          | 0/1875 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
from matplotlib import pyplot as plt

# Pass the device explicitly: generate_samples is declared as (model, device).
# The returned grid is scaled to 0..255, while imshow expects float RGB data
# in [0, 1] (or integers in 0..255), so convert to uint8 before plotting.
sample_grid = generate_samples(vae, device).detach().cpu()
plt.imshow(sample_grid.clamp(0, 255).to(torch.uint8).numpy())