## Imports and setup.

In [1]:
# Automatically reload locally edited modules (such as the project's `src` package)
# so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the repository root importable so absolute imports such as
# `from src.models.svae import SpatialVAE` resolve.
# NOTE(review): the path is resolved relative to the current working directory
# (normally the notebook's own folder in Jupyter) and assumes the repo root is
# two levels above it — confirm if the notebook is moved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
import torch

from src.models.svae import SpatialVAE

In [4]:
# Global torch configuration: a fixed RNG seed (so the random dummy data and the
# model's initial weights are reproducible on Restart & Run All), compact tensor
# printing, and the device every tensor/model below is placed on.
SEED = 0
torch.manual_seed(SEED)  # Seeds the RNGs on all devices, CPU and CUDA alike.
torch.set_printoptions(precision=3, sci_mode=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Torch device: {device}')

Torch device: cuda:0


## Make dummy data.

In [5]:
# Dimensions of the dummy input batch.
batch_size = 8
n_channels = 3  # e.g. RGB; use 1 for single-channel images such as MNIST.
height, width = 28, 28  # Image height and width, in pixels.

In [6]:
# A dummy input batch of uniform [0, 1) noise, allocated directly on `device`
# rather than built on the CPU and then copied over.
# NOTE(review): PyTorch image tensors are conventionally (N, C, H, W), but this
# uses (N, C, W, H). That is harmless while width == height — confirm which
# order SpatialVAE expects before using non-square images.
x = torch.rand(batch_size, n_channels, width, height, device=device)

## Instantiate the model and run a forward pass.

In [7]:
# Build the model for the dummy-data dimensions defined above, then move it
# to the compute device (Module.to moves parameters in place and returns the module).
# NOTE(review): the meaning of `n_unconstrained` isn't visible here — see
# SpatialVAE's definition in src/models/svae.py.
svae = SpatialVAE(
    n_channels=n_channels,
    height=height,
    width=width,
    n_layers=2,
    n_hidden_units=500,
    n_unconstrained=2,
)
svae = svae.to(device)

In [8]:
# Invoke the module itself rather than `.forward()` directly: nn.Module.__call__
# dispatches to forward() and also runs any registered hooks, which a direct
# `.forward()` call silently skips.
reconstruction, mu, logstd = svae(x)

# Sanity check: the reconstruction should have the same shape as the input batch.
assert reconstruction.shape == x.shape, (reconstruction.shape, x.shape)
reconstruction.shape

torch.Size([8, 3, 28, 28])