In [1]:
import os
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

from utils import fix_seed
from train import epoch_loop
from callbacks import EarlyStopping
from models import VariationalAutoEncoder
from datasets import load_tfds, pre_train_preprocessing

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_dim = 28 * 28  # flattened input size (784)
z_dim = 2        # latent size; the checkpoint file name below carries the same "z2" suffix

model = VariationalAutoEncoder(x_dim, z_dim, device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# written from a CUDA run can still be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load("./models/checkpoint_z2.pth", map_location=device)
)
# Switch to inference mode (affects layers such as Dropout).
model.eval()
# Drop any gradient tensors attached to the parameters.
for param in model.parameters():
    param.grad = None

In [3]:
# Attach PyTorch's default quantization configuration to the top-level model.
# NOTE(review): only the root module receives `qconfig` here; whether sub-modules
# pick it up depends on the quantization calls made afterwards -- confirm.
model.qconfig = torch.quantization.default_qconfig

In [4]:
# Ask PyTorch to convert `model` to its quantized form in place; the returned
# module is the cell's displayed value.
# NOTE(review): the repr recorded for this cell still lists plain `Linear` layers,
# so nothing appears to have been converted. No `torch.quantization.prepare(...)`
# call or calibration pass is visible before this point -- presumably that step is
# missing; verify against the PyTorch eager-mode static quantization workflow.
torch.quantization.convert(model, inplace=True)

VariationalAutoEncoder(
  (encoder): Encoder(
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (mean): Linear(in_features=256, out_features=2, bias=True)
    (log_var): Linear(in_features=256, out_features=2, bias=True)
  )
  (decoder): Decoder(
    (fc1): Linear(in_features=2, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=512, bias=True)
    (drop): Dropout(p=0.2, inplace=False)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )
)

In [5]:
# --- Experiment configuration -------------------------------------------------
log_dir = "./logs"
seed = 42
x_dim = 28 * 28
# NOTE(review): an earlier cell sets z_dim = 2 for the loaded checkpoint; this
# rebinds it to 3 -- confirm the change is intentional.
z_dim = 3
batch_size = 1024
num_epochs = 1000
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def loss_fn(lower_bound):
    """Return the negated sum of the elements of ``lower_bound``.

    Args:
        lower_bound: An iterable of summable values.
            Presumably the per-sample evidence lower bound -- TODO confirm.

    Returns:
        ``-sum(lower_bound)``.
    """
    return -sum(lower_bound)


# --- Logging and reproducibility ----------------------------------------------
# Create the log directory up front; exist_ok avoids the check-then-create race
# and makes the cell safe to re-run.
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)
fix_seed(seed)

# --- Data ---------------------------------------------------------------------
dataset_train, dataset_valid, dataset_test = load_tfds(
    "mnist",
    batch_size=batch_size,
    preprocess_fn=pre_train_preprocessing,
    seed=seed,
)

In [8]:
# Peek at the first training batch only: show the image batch shape and the labels,
# then stop iterating.
for data in dataset_train:
    x, t = data['image'], data['label']
    print(x.shape, t, sep="\n")
    break

(1024, 28, 28, 1)
tf.Tensor([8 4 8 ... 3 1 4], shape=(1024,), dtype=int64)


2022-06-22 17:06:48.886743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
