## Imports and setup.

In [1]:
# Automatically reload local modules whenever their source files change, so
# edits to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make project packages importable by absolute path: resolve the directory
# two levels above the current working directory and add it to sys.path once.
project_root = os.path.abspath('../..')
already_importable = project_root in sys.path
if not already_importable:
  sys.path.append(project_root)

In [12]:
import torch
import numpy as np
from src.models.svae import SpatialVAE

In [4]:
# Seed PyTorch's random number generators so the dummy input, the model
# initialisation and the resulting loss are reproducible across
# "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Compact tensor printing: three decimals, no scientific notation.
torch.set_printoptions(precision=3, sci_mode=False)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Torch device: {device}')

Torch device: cpu


## Make dummy data.

In [5]:
# Shape of the dummy batch: 8 single-channel 28x28 images.
# n_channels is 1 for grayscale data such as MNIST; use 3 for RGB images.
batch_size, n_channels = 8, 1
width = height = 28  # Image width and height (square images).

In [6]:
# Dummy input batch: uniform random values in [0, 1), laid out as
# (batch, channel, width, height) and moved to the selected device.
x = torch.rand((batch_size, n_channels, width, height)).to(device=device)

## Instantiate model and call forward().

In [9]:
# Build the spatial VAE for the dummy image size, then move it to the
# selected device. n_unconstrained is the number of unconstrained latent
# dimensions.
svae = SpatialVAE(
    n_channels=n_channels,
    width=width,
    height=height,
    n_layers=2,
    n_hidden_units=500,
    n_unconstrained=2,
)
svae = svae.to(device)

In [10]:
# Run a forward pass. The module is called directly instead of invoking
# .forward(): nn.Module.__call__ dispatches to forward() and also runs any
# registered hooks.
# NOTE(review): assumes SpatialVAE subclasses torch.nn.Module - confirm.
#
# mu and logstd have shape (batch_size, n_latent).
# Latent parameter order: [unconstrained..., rotation, 2 x translation].
reconstruction, mu, logstd = svae(x)
print(reconstruction.shape)
print(mu.shape)
print(logstd.shape)

torch.Size([8, 1, 28, 28])
torch.Size([8, 5])
torch.Size([8, 5])


In [21]:
# Evaluate the loss on the outputs of the forward pass.
# The final argument is the standard deviation of the prior over the rotation
# latent; a fairly large value is a good choice, so pi is used here.
#
# NOTE(review): the recorded output of this cell is a RuntimeError (tensor
# sizes 8 vs 2 at dimension 1). pi has shape [1], so the mismatch presumably
# originates inside SpatialVAE.loss - TODO investigate before relying on
# this cell.
pi = torch.tensor([np.pi], dtype=torch.float32, device=device)
print(pi.shape)
result = svae.loss(x, reconstruction, mu, logstd, pi)
print(result)

torch.Size([1])


RuntimeError: The size of tensor a (8) must match the size of tensor b (2) at non-singleton dimension 1