In [5]:
import sys
sys.path.append("./../")

import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from torch.utils.data import DataLoader

from datasets.triple_mnist import TripleMnistDataset
from config_reader import ConfigReader

from modules.dvae.model import DVAE
from train_utils.dvae_utils import KLD_uniform_loss

In [2]:
def show(img, figsize=(8, 4)):
    """Render an image tensor with matplotlib, with both axes hidden.

    Args:
        img: 3-D tensor whose first axis is moved last before plotting
            (i.e. laid out as (C, H, W)). It may require grad or live on
            a non-CPU device.
        figsize: figure size forwarded to ``plt.figure``.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    plt.figure(figsize=figsize)
    # A plain .numpy() raises on tensors that require grad and on non-CPU
    # tensors, so detach from the graph and move to host memory first.
    npimg = img.detach().cpu().numpy()
    # Move axis 0 to the end: imshow wants (H, W, C).
    image = plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    image.axes.get_xaxis().set_visible(False)
    image.axes.get_yaxis().set_visible(False)
    plt.show()

In [3]:
# Location of the YAML config. The default is the original machine-specific
# path, so behaviour is unchanged when the variable is unset; export
# DVAE_CONFIG_PATH to run this notebook from a different checkout.
CONFIG = ConfigReader(
    config_path=os.environ.get(
        "DVAE_CONFIG_PATH",
        "/home/andrey/Aalto/thesis/TA-VQVAE/configs/dvae_triplemnist_local.yaml",
    )
)
# Print the loaded settings so the run parameters are recorded with the notebook.
CONFIG.print_config_info()

BATCH_SIZE                              128                                     
DEVICE                                  cpu                                     
KLD_lambda_end                          6                                       
KLD_lambda_start                        0                                       
KLD_lambda_steps                        2500                                    
LR                                      0.0001                                  
LR_gamma                                0.35                                    
NUM_EPOCHS                              800                                     
config_path                             /home/andrey/Aalto/thesis/TA-VQVAE/configs/dvae_triplemnist_local.yaml
in_channels                             1                                       
num_resids_bottleneck                   2                                       
num_resids_downsample                   2                                      

In [6]:
# Dataset rooted at the image directory named in the config.
dataset = TripleMnistDataset(root_img_path=CONFIG.root_img_path)

# Shuffling loader over that dataset, batch size taken from the config.
train_loader = DataLoader(dataset, batch_size=CONFIG.BATCH_SIZE, shuffle=True)

In [7]:
# Pull a single batch from the loader; only the first element of the pair is
# kept (the second is discarded into `_`).
# NOTE(review): the loader shuffles and no seed is set in the visible cells,
# so this batch presumably differs between runs — confirm if reproducibility matters.
x, _ = next(iter(train_loader))

In [8]:
# Build the DVAE with every architecture hyper-parameter taken from CONFIG.
model = DVAE(
    in_channels=CONFIG.in_channels,
    vocab_size=CONFIG.vocab_size,
    num_x2downsamples=CONFIG.num_x2downsamples,
    num_resids_downsample=CONFIG.num_resids_downsample,
    num_resids_bottleneck=CONFIG.num_resids_bottleneck,
)

# Restore saved weights from the checkpoint location given in the config.
# NOTE(review): the model is not switched to eval mode here — confirm that is intended.
model.load_model(CONFIG.save_model_path, CONFIG.save_model_name)

In [12]:
# Forward pass on the sampled batch with tau=1; the model returns a pair that
# is bound to x_recon and z_dist.
# NOTE(review): tau is presumably the relaxation temperature — confirm against DVAE.forward.
x_recon, z_dist = model(x, tau=1)

In [13]:
# Display z_dist. In the saved output every visible entry is 0.0156 (1/64 = 0.015625).
# NOTE(review): that would be a uniform distribution if vocab_size is 64 — TODO confirm.
z_dist

tensor([[[[0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          ...,
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156]],

         [[0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          ...,
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156]],

         [[0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0.0156, 0.0156, 0.0156],
          [0.0156, 0.0156, 0.0156,  ..., 0

In [14]:
# Check the batch dimensions; the saved output is torch.Size([128, 1, 84, 84]).
x.shape

torch.Size([128, 1, 84, 84])