In [None]:
!pip install torch==1.11.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

In [1]:
import os
from typing import Tuple

import tqdm
import math
from datetime import datetime

import torch as t

from data import PokemonIMG
from distributions import DiscretizedMixtureLogitsDistribution
from model import Model
from vae import VAE


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def train(model: Model, n_updates=int(1e6), eval_interval=1000):
    """Run the training loop, checkpointing at regular intervals.

    Calls ``model.train_batch()`` once per update. After every
    ``eval_interval``-th update the value returned by ``model.eval_batch()``
    is treated as the current loss, the model is saved under the name
    ``"latest"``, and, if that loss is lower than every loss seen so far in
    this call, it is also saved under the name ``"best"``.

    Args:
        model: Object providing ``train_batch()``, ``eval_batch()`` and
            ``save(name)``.
        n_updates: Total number of ``train_batch()`` calls to make.
        eval_interval: Number of updates between evaluations/checkpoints.

    Returns:
        None.
    """
    lowest_loss = float("inf")
    progress = tqdm.tqdm(range(n_updates))
    for step in progress:
        model.train_batch()

        # Only every eval_interval-th update (1-based) triggers evaluation.
        if (step + 1) % eval_interval != 0:
            continue

        current_loss = model.eval_batch()
        model.save("latest")
        if current_loss < lowest_loss:
            lowest_loss = current_loss
            model.save("best")

In [8]:
n_mixtures = 1


def state_to_dist(state):
    """Build the output distribution from the leading channels of ``state``.

    Args:
        state: Expected to have shape (batch_size, z_space, height, width).

    Returns:
        A ``DiscretizedMixtureLogitsDistribution`` constructed from
        ``n_mixtures`` and the first ``n_mixtures * 10`` entries along
        dimension 1 of ``state``.
    """
    # NOTE(review): the factor 10 is presumably the number of distribution
    # parameters per mixture component -- confirm against
    # DiscretizedMixtureLogitsDistribution.
    n_param_channels = n_mixtures * 10
    dist_params = state[:, :n_param_channels, :, :]
    return DiscretizedMixtureLogitsDistribution(n_mixtures, dist_params)

In [9]:
# --- Hyperparameters ---------------------------------------------------------
# Most of these are forwarded positionally to the VAE constructor in the final
# cell of this notebook.

# Image geometry handed to VAE(h, w, n_channels, ...).
h = 32
w = 32
n_channels = 3

# Size arguments handed to VAE (z_size is the fourth argument, encoder_hid the
# last one).
z_size = 256
encoder_hid = 32

# Read by state_to_dist() when it slices the state and builds the distribution.
n_mixtures = 1

# Remaining VAE constructor arguments.
batch_size = 32
p_update = 1.0
min_steps = 64
max_steps = 128

# NOTE(review): vae_hid and dmg_size are not referenced by any other cell
# visible in this notebook -- confirm whether they are still needed.
vae_hid = 128
dmg_size = 16

In [13]:
dset = PokemonIMG()

num_samples = len(dset)
train_split = 0.7
val_split = 0.2
test_split = 0.1
# Fixed seed so the train/val/test membership is identical on every
# "Restart Kernel -> Run All"; without it the split is redrawn each run.
split_seed = 0

# The test set absorbs whatever train and val leave over, so make sure the
# declared fractions actually describe a full partition of the dataset.
assert math.isclose(train_split + val_split + test_split, 1.0), \
    "train/val/test fractions must sum to 1"

num_train = math.floor(num_samples * train_split)
num_val = math.floor(num_samples * val_split)
# Everything not assigned to train/val goes to test; this also picks up the
# samples lost to flooring so the three sizes always sum to num_samples.
num_test = num_samples - num_train - num_val

# A dedicated generator keeps the split reproducible without touching the
# global torch RNG state.
train_set, val_set, test_set = t.utils.data.random_split(
    dset,
    [num_train, num_val, num_test],
    generator=t.Generator().manual_seed(split_seed),
)

In [None]:
# Build the model from the hyperparameters and dataset splits defined above.
# Arguments are positional, in the order the VAE constructor is called with.
vae = VAE(
    h,
    w,
    n_channels,
    z_size,
    train_set,
    val_set,
    test_set,
    state_to_dist,
    batch_size,
    p_update,
    min_steps,
    max_steps,
    encoder_hid,
)

# One evaluation pass before any training; the return value is discarded.
# NOTE(review): presumably a smoke test of the evaluation path -- confirm.
vae.eval_batch()

# 100k updates, evaluating and checkpointing every 100 updates.
train(vae, n_updates=100_000, eval_interval=100)