In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Normalize
import torch.nn as nn
from tqdm.notebook import tqdm, trange

In [2]:
import os

# Move one directory up (presumably from the notebooks folder to the project root)
# so that the `src.` imports and the relative './mnist_data/' path resolve.
# Guarded so that re-running this cell does not climb one more directory each time;
# a kernel restart clears the flag and resets the cwd together.
if not globals().get('_CWD_MOVED_TO_PROJECT_ROOT', False):
    os.chdir('..')
    _CWD_MOVED_TO_PROJECT_ROOT = True

In [3]:
# Automatically re-import edited project modules (e.g. `src.model...`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [4]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from torch.autograd import Variable
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
import numpy as np

def generate_samples(model, num_samples=15, nrow=5):
    """Sample images from the model's decoder and tile them into a single grid image.

    Args:
        model: Module whose ``decoder`` exposes ``generate(n)``. The returned batch is
            passed straight to ``make_grid`` (presumably ``(n, C, H, W)`` with values
            in [0, 1] — TODO confirm against the generator).
        num_samples: Number of images to generate (default 15, the previous fixed value).
        nrow: Number of images per grid row (default 5, the previous fixed value).

    Returns:
        Channels-last tensor ``(H, W, C)`` of the grid, scaled by 255 for ``go.Image``.
    """
    with torch.no_grad():
        generated = model.decoder.generate(num_samples)
        grid = make_grid(generated, nrow=nrow)  # channels-first (C, H, W)

        # Reorder to channels-last, which is the layout go.Image consumes.
        return grid.permute(1, 2, 0) * 255


def _sample_latent_projection(model, dataset, num_points):
    """Pick distinct random dataset items and look up their learned latent codes.

    Args:
        model: Module with a ``latent_embeddings`` embedding indexed by dataset position.
        dataset: Dataset whose items are ``(index, image, label)`` triples.
        num_points: Maximum number of items to draw.

    Returns:
        ``(codes, labels)``: ``codes`` is a CPU numpy array with one row per drawn
        item; ``labels`` is the matching list of targets.
    """
    num_points = min(num_points, len(dataset))
    # Sample without replacement so no item is plotted twice.
    sample_idx = np.random.choice(len(dataset), num_points, replace=False)
    labels = [dataset[int(i)][2] for i in sample_idx]

    # Put the indices on whatever device the embedding table lives on.
    weight_device = model.latent_embeddings.weight.device
    with torch.no_grad():
        idx_tensor = torch.as_tensor(sample_idx, device=weight_device)
        codes = model.latent_embeddings(idx_tensor)

    # Plotly cannot consume (possibly device-resident) torch tensors directly.
    return codes.cpu().numpy(), labels


def visualize_losses(train_losses, model, dataset=None, num_points=200):
    """Build a training dashboard: loss curve, decoder samples and latent codes.

    Layout: the epoch-loss curve spans the top row. The bottom row shows a grid of
    images from ``generate_samples(model)`` on the left. On the right is a scatter of
    the first two latent-code dimensions of random dataset items, coloured by label.

    Args:
        train_losses: Dict whose ``"epoch"`` entry is the list of per-epoch losses.
        model: Auto-decoder with ``decoder.generate`` and a ``latent_embeddings``
            embedding indexed by dataset position.
        dataset: Dataset yielding ``(index, image, label)``. Defaults to the
            notebook-level ``train_dataset`` (the previous hard-wired behaviour).
        num_points: Maximum number of latent codes to scatter.

    Returns:
        The plotly figure; call ``.show()`` to render it.
    """
    if dataset is None:
        dataset = train_dataset

    fig = make_subplots(
        rows=2, cols=2,
        specs=[[{"colspan": 2}, None],
               [{}, {}]],
        # One title per *existing* subplot, in row-major order. The ``None`` cell
        # takes no entry, so a placeholder here would shift every later title.
        subplot_titles=("Train epoch loss", "Generated images", "Latent space"),
    )

    epoch_losses = train_losses["epoch"]
    fig.add_trace(
        go.Scatter(x=list(range(len(epoch_losses))), y=epoch_losses),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Image(z=generate_samples(model).cpu()),
        row=2,
        col=1,
    )

    # Only the first two code dimensions are plotted (z_dim == 2 in this notebook).
    codes, labels = _sample_latent_projection(model, dataset, num_points)
    fig.add_trace(
        go.Scatter(
            x=codes[:, 0],
            y=codes[:, 1],
            mode='markers',
            marker=dict(
                size=6,
                color=labels,
                colorscale='Viridis',  # one of plotly colorscales
            ),
        ),
        row=2,
        col=2,
    )

    fig.update_layout(height=1000)
    return fig

In [5]:
# Mini-batch size used by both the train and the test DataLoader.
batch_size = 32

# Image preprocessing: ToTensor only — no Normalize step is applied here.
_default_mnist_avalanche_transform = Compose([ToTensor()])


class MnistWithIndices(datasets.MNIST):
    """MNIST whose items also carry their own position in the dataset.

    Each item is ``(index, image, label)`` instead of ``(image, label)``. This lets
    the training loop look up a per-sample latent code by ``index``.
    """

    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        item = (index, image, label)
        return item


# Both splits share one data root. `download=True` skips the download when the
# files already exist. Using it for the test split too means that split no longer
# silently depends on the train split's download having fetched its files.
MNIST_ROOT = './mnist_data/'

train_dataset = MnistWithIndices(root=MNIST_ROOT, train=True, transform=_default_mnist_avalanche_transform,
                                 download=True)
test_dataset = MnistWithIndices(root=MNIST_ROOT, train=False, transform=_default_mnist_avalanche_transform,
                                download=True)

# Data Loader (Input Pipeline): only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [6]:
from src.model.rnd.ad_generator import MNISTAutoDecoderLinearGenerator
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn.functional as F


class AutoDecoder(nn.Module):
    """Auto-decoder: a learnable latent code per dataset item plus a shared decoder.

    There is no encoder. The codes in ``latent_embeddings`` are parameters of the
    module, so they are optimized together with the decoder weights.

    Args:
        z_dim: Dimensionality of each latent code.
        num_embeddings: Number of codes (one per indexed dataset item).
    """

    def __init__(self, z_dim: int, num_embeddings: int):
        super().__init__()
        # Row i is the code for dataset item i. With max_norm=1.0, a lookup rescales
        # any looked-up row whose norm exceeds 1 (this modifies the weight in place).
        self.latent_embeddings = nn.Embedding(num_embeddings, z_dim, max_norm=1.0)
        self.decoder = MNISTAutoDecoderLinearGenerator(z_dim)

    def forward(self, emb):
        """Decode a batch of latent codes ``emb`` with the shared decoder."""
        return self.decoder(emb)

In [7]:
# 2-D codes match the 2-D latent-space scatter drawn by visualize_losses.
z_dim = 2
device = torch.device('cpu')

# One learnable latent code per training image.
auto_decoder = AutoDecoder(
    z_dim=z_dim,
    num_embeddings=len(train_dataset),
).to(device)

In [8]:
# Sanity check: decode the codes of three training items and display the output shape.
# Runs without autograd (nothing is trained here) and puts the indices on `device`,
# as the training loop does, so the check also works off-CPU.
with torch.no_grad():
    emb = auto_decoder.latent_embeddings(torch.as_tensor([1, 2, 4], device=device))
    decoded = auto_decoder(emb)
decoded.shape

torch.Size([3, 784])

In [12]:
from IPython.display import clear_output

# Training schedule.
num_epochs = 300
validate_every = 5  # NOTE(review): not read by the training loop below.

# Loss history. The loop below fills only "epoch" and "batch"; the *_kl / *_rec
# keys look like leftovers from a VAE setup — TODO confirm before removing them.
train_losses = {
    key: []
    for key in ("epoch", "epoch_kl", "epoch_rec", "batch", "batch_kl", "batch_rec")
}
test_losses = {key: [] for key in ("epoch", "batch")}

# Weight of the mean latent-code norm that is added to the reconstruction loss.
lambda_code_regularization = 0.1
# A single Adam optimizer over both the decoder weights and the latent codes.
optimizer = torch.optim.Adam(auto_decoder.parameters(), lr=1e-4)

for epoch_num in trange(num_epochs, desc="Epoch: "):
    # train_losses["batch"] only ever holds the losses of the current epoch.
    train_losses["batch"] = []
    auto_decoder.train()

    for batch in tqdm(train_loader, desc="Train batch: ", leave=False):
        # Items are (dataset index, image, label); the label is not used for training.
        indices, x, _ = batch
        x = x.to(device)
        indices = indices.int().to(device)

        # Look up this batch's latent codes and decode them.
        latent_emb = auto_decoder.latent_embeddings(indices)
        x_pred = auto_decoder(latent_emb)

        # L1 reconstruction on flattened 784-pixel images, plus a penalty on the mean
        # norm of the codes (presumably to keep the latent space compact).
        reconstruction_loss = F.l1_loss(
            x_pred.reshape(-1, 784),
            x.reshape(-1, 784),
            reduction='mean',
        )
        emb_regularization = torch.norm(latent_emb, dim=1).mean() * lambda_code_regularization
        loss = reconstruction_loss + emb_regularization

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_losses["batch"].append(loss.item())

    epoch_loss = torch.as_tensor(train_losses["batch"]).mean().item()
    train_losses["epoch"].append(epoch_loss)

    # Redraw the dashboard in place instead of stacking one figure per epoch.
    clear_output(wait=True)
    visualize_losses(train_losses, auto_decoder).show()