In [5]:
import math

import torch as t

from data import PokemonIMG
from distributions import DiscretizedMixtureLogitsDistribution
from train import train
from vae_small import VAE_SMALL

In [2]:
n_mixtures = 1


def state_to_dist(state):
    """Build the output distribution from a decoder state tensor.

    `state` should have shape (batch_size, z_space, height, width). Only the
    first ``n_mixtures * 10`` channels are forwarded to
    ``DiscretizedMixtureLogitsDistribution``; any remaining channels are ignored.
    """
    # NOTE(review): presumably 10 parameters per mixture component -- confirm
    # against DiscretizedMixtureLogitsDistribution.
    n_dist_channels = n_mixtures * 10
    dist_params = state[:, :n_dist_channels, :, :]
    return DiscretizedMixtureLogitsDistribution(n_mixtures, dist_params)

In [3]:
# --- Model sizes ---
# vae_hid is not referenced by the cells visible in this notebook.
z_size, vae_hid, encoder_hid = 256, 128, 32

# Same value as the declaration next to state_to_dist; state_to_dist reads this
# global at call time, so the most recent assignment is the one in effect.
n_mixtures = 1

# --- Training / update settings ---
# dmg_size is not referenced by the cells visible in this notebook.
batch_size, dmg_size = 32, 16
p_update = 1.0
min_steps = 64
max_steps = 128

# --- Image geometry ---
h, w = 32, 32
n_channels = 3

In [4]:
# Load the full dataset and split it into train / validation / test subsets.
dset = PokemonIMG()

num_samples = len(dset)
train_split = 0.7
val_split = 0.2
# Nominal share only: the test subset actually receives every sample that the
# floored train/val sizes leave over (see num_test below).
test_split = 0.1

# Fixed seed so that subset membership is identical after every kernel restart;
# without it random_split draws from the global RNG and the split changes per run.
split_seed = 0

num_train = math.floor(num_samples * train_split)
num_val = math.floor(num_samples * val_split)
# Flooring can drop samples; assigning the remainder to the test subset keeps
# the three sizes summing to num_samples, which random_split requires.
num_test = num_samples - num_train - num_val

train_set, val_set, test_set = t.utils.data.random_split(
    dset,
    [num_train, num_val, num_test],
    generator=t.Generator().manual_seed(split_seed),
)

In [6]:
# Instantiate the model. All arguments are positional, so their order must
# match the VAE_SMALL constructor signature exactly.
vae = VAE_SMALL(
    h,
    w,
    n_channels,
    z_size,
    train_set,
    val_set,
    test_set,
    state_to_dist,
    batch_size,
    p_update,
    min_steps,
    max_steps,
    encoder_hid,
)
# Kept as the last expression of the cell so the notebook displays its result.
vae.eval_batch()

VAE_SMALL(
  (conv2d1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2d2): Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv2d3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv2d4): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (elu): ELU(alpha=1.0)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear): Linear(in_features=4096, out_features=512, bias=True)
  (dec_lin): Linear(in_features=256, out_features=2048, bias=True)
  (conv_t2d1): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  (conv_t2d2): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  (conv_t2d3): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  (conv_t2d4): ConvTranspose2d(64, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  (conv_t2d



12459.662109375

In [None]:
# Launch training on the model built above; this is the long-running cell.
train(
    vae,
    n_updates=100_000,
    eval_interval=100,
    suffix="VAE_SMALL",
)

In [None]:
# Run the model's test routine.
# NOTE(review): the meaning of the argument 128 is defined by VAE_SMALL.test,
# which is not visible here. It equals max_steps -- confirm whether it should
# reference that constant instead of a literal.
vae.test(128)