In [15]:
import time
from typing import Tuple
import random

import numpy as np
import matplotlib.pyplot as plt
import torch as th
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from mmidas.nn_model import mk_vae

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def load_mnist(B: int) -> Tuple[DataLoader, DataLoader]:
    """Build the MNIST training and test data loaders.

    Every image is converted to a tensor, normalized with mean 0.1307 and
    standard deviation 0.3081, and then has its leading axis squeezed away
    when that axis has size 1.

    Args:
        B: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda img: img.squeeze(0)),
    ])

    def make_loader(train: bool) -> DataLoader:
        # The training split is shuffled; the test split keeps its stored order.
        split = datasets.MNIST('./data', train=train, download=True, transform=preprocess)
        return DataLoader(split, batch_size=B, shuffle=train)

    return make_loader(True), make_loader(False)

def load_cifar10(B: int) -> Tuple[DataLoader, DataLoader]:
    """Build the CIFAR-10 training and test data loaders.

    Images are converted to tensors and each of the three channels is
    normalized with mean 0.5 and standard deviation 0.5.

    Args:
        B: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    # The return annotation uses typing.Tuple, like the sibling loaders in this
    # notebook; the builtin ``tuple[...]`` form fails at definition time on
    # Python versions older than 3.9.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True,
                                     download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False,
                                    download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=B, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=B, shuffle=False)
    return train_loader, test_loader


def load_fashion_mnist(B: int) -> Tuple[DataLoader, DataLoader]:
    """Build the Fashion-MNIST training and test data loaders.

    Every image is converted to a tensor, normalized with mean 0.5 and
    standard deviation 0.5, and flattened into a 1-D vector.

    Args:
        B: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    # flatten() already yields a 1-D tensor; the trailing squeeze(0) only has
    # an effect if that vector happened to contain a single element.
    to_flat_vector = transforms.Lambda(lambda img: th.flatten(img).squeeze(0))
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        to_flat_vector,
    ])

    # Both splits are created first (training, then test), then wrapped.
    train_split, test_split = (
        datasets.FashionMNIST('./data', train=is_train, download=True, transform=pipeline)
        for is_train in (True, False)
    )
    return (
        DataLoader(train_split, batch_size=B, shuffle=True),
        DataLoader(test_split, batch_size=B, shuffle=False),
    )


# Batch size shared by the training and test loaders.
B = 500
# Fashion-MNIST loaders; load_fashion_mnist flattens every image to a 1-D vector.
train_loader, test_loader = load_fashion_mnist(B)


  self.train_data, self.train_labels = torch.load(
  self.test_data, self.test_labels = torch.load(


In [4]:
# Integer class label -> human-readable class name for the Fashion-MNIST data
# loaded above. The position in the tuple is the label.
itos = dict(enumerate((
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
)))

def visualize(samples: int, xs, ys, itos):
    """Plot a grid of randomly chosen example images, one column per class.

    Args:
        samples: Maximum number of examples drawn per class; also the number
            of rows in the subplot grid.
        xs: Indexable collection of images (nested lists, NumPy arrays or
            torch tensors on any device). A flat vector whose length is a
            perfect square is reshaped to a square image before plotting.
        ys: Labels aligned with ``xs`` (sequence, array or tensor).
        itos: Mapping from label value to class name. Each entry becomes one
            column, in the mapping's iteration order.

    Examples are picked with ``random.sample``, so the selection depends on
    the global ``random`` state.
    """
    def as_numpy(values):
        # Tensors may sit on the GPU or track gradients; NumPy needs a
        # detached CPU copy. Anything else goes straight to np.asarray.
        if hasattr(values, 'detach'):
            values = values.detach().cpu()
        return np.asarray(values)

    labels = as_numpy(ys).reshape(-1)
    K = len(itos)
    for col, (y, cls) in enumerate(itos.items()):
        # Search ALL labels for this class. Looking only at the first
        # `samples` entries left most columns empty or nearly empty.
        candidates = np.flatnonzero(labels == y).tolist()
        chosen = random.sample(candidates, min(samples, len(candidates)))
        for row, idx in enumerate(chosen):
            # Subplot indices are 1-based and run row by row.
            plt.subplot(samples, K, row * K + col + 1)
            img = as_numpy(xs[idx])
            if img.ndim == 1:
                # imshow needs a 2-D array; restore square images that were
                # flattened by the data pipeline.
                side = int(round(img.size ** 0.5))
                if side * side == img.size:
                    img = img.reshape(side, side)
            plt.imshow(img, cmap='gray')
            plt.axis('off')
            if row == 0:
                plt.title(cls)
    plt.show()

In [5]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if th.cuda.is_available() else 'cpu'
# Number of copies of each batch handed to the model in the forward-pass cell.
A = 5
# NOTE(review): the positional arguments are presumably (number of categories,
# state dimension, input dimension) — confirm against mk_vae's signature.
model = mk_vae(10, 2, 784, A=A, device=device).to(device)

In [6]:
# Take one training batch and move it to the model's device.
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)

# The model receives the same batch A times, as a list.
# NOTE(review): the meaning of the second positional argument (1) is not
# visible in this notebook — confirm against the model's forward signature.
(x_recs, _, _, x_lows, cs,
 s_smps, c_smps, s_means, s_logvars, c_probs) = model([x] * A, 1)

# Stack the sequence of tensors in `cs` along a new leading axis.
cs = th.stack(cs)
cs.shape

  img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))


torch.Size([500, 784])
torch.Size([500, 784])
torch.Size([500, 784])
torch.Size([500, 784])
torch.Size([500, 784])


torch.Size([5, 500, 10])

In [29]:
def _sync_device() -> None:
    """Wait for queued CUDA work so the timer sees finished kernels, not just launches."""
    if th.cuda.is_available():
        th.cuda.synchronize()


# perf_counter is a monotonic, high-resolution clock meant for short intervals;
# time.time() can be coarse and may jump if the system clock is adjusted.
_sync_device()
tic = time.perf_counter()
loss_naive = model.loss_naive(cs)
_sync_device()
t1 = time.perf_counter() - tic

tic = time.perf_counter()
loss_vec = model.loss_vectorize(cs)
_sync_device()
t2 = time.perf_counter() - tic

# The two implementations should agree; the relative error quantifies any drift.
print(f"Naive loss: {loss_naive.item()}")
print(f"Vectorized loss: {loss_vec.item()}")
print(f"Relative error: {th.norm(loss_naive - loss_vec) / th.norm(loss_naive)}\n")


print(f"Naive loss computation took: {t1}s")
print(f"Vectorized loss computation took: {t2}s")
# Reported as the percentage reduction in wall-clock time relative to the naive version.
print(f"Speedup: {100 * (t1 - t2) / t1:.2f}%")

Naive loss: 340394.8125
Vectorized loss: 340394.8125
Relative error: 0.0

Naive loss computation took: 0.0037140846252441406s
Vectorized loss computation took: 0.0005176067352294922s
Speedup: 86.06%
