## Imports and setup.

In [2]:
# Enable autoreloading, when local modules are modified.
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys

# Allow absolute path imports.
project_root = os.path.abspath(os.path.join('../..'))
if project_root not in sys.path:
  sys.path.append(project_root)

In [4]:
import torch
import numpy as np
from src.models.svae import SpatialVAE

In [5]:
torch.set_printoptions(precision=3, sci_mode=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Torch device: {device}')

Torch device: cpu


## Make dummy data.

In [6]:
batch_size = 8
n_channels = 1  # e.g RGB (set to 1 for MNIST).
width = 28  # Image width.
height = 28  # Image height.

In [7]:
# A dummy input.
x = torch.rand(batch_size, n_channels, width, height).to(device)

## Instantiate model and call forward().

In [8]:
svae = SpatialVAE(
    width=width,
    height=height,
    n_channels=n_channels,
    n_hidden_units=500,
    n_layers=2,
    n_unconstrained=2).to(device)

In [9]:
# losgstd and mu shape contain the batch size, and parameters
# parameter order: [rotation, 2 x translation, unconstrained...]
reconstruction, mu, logstd = svae.forward(x)
print(reconstruction.shape)
print(mu.shape)
print(logstd.shape)

torch.Size([8, 1, 28, 28])
torch.Size([8, 5])
torch.Size([8, 5])


In [12]:
# Loss function
# last parameter is the stddev of prior of rotation distribution
# A good value is a larger value, can be set to pi
pi = torch.tensor(np.pi).float().unsqueeze(0).to(device)
result = svae.loss(x, reconstruction, mu, logstd, pi)
print(result.shape)

torch.Size([])
torch.Size([])


RuntimeError: output with shape [] doesn't match the broadcast shape [1]