In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
from model_mnist import LeNet5
from visualize import *
import dset_mnist as dset
import foolbox
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
import warnings
warnings.filterwarnings("ignore")
# disentangled vae
sys.path.append('../disentangling-vae')
from collections import defaultdict
import vae_trim, vae_trim_viz
from disvae.utils.modelIO import save_model, load_model, load_metadata
from disvae.models.losses import get_loss_f

### Train model

In [2]:
args = vae_trim.parse_arguments()
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 0
args.attr_lamb = 10
args.epochs = 50
name = args.loss + "_B_" + str(args.btcvae_B) + "_attr_" + str(args.attr_lamb)
args.name = name

In [3]:
# train and evaluate model
vae_trim.main(args)

11:42:24 INFO - main: Root directory for saving and loading experiments: results/btcvae_B_0_attr_10
11:42:24 INFO - main: Train mnist with 60000 samples
11:42:24 INFO - main: Num parameters in model: 469173
11:42:26 INFO - __init__: Training Device: cuda
11:43:51 INFO - __call__: Epoch: 1 Average loss per image: 224.77
11:45:17 INFO - __call__: Epoch: 2 Average loss per image: 183.25
11:46:43 INFO - __call__: Epoch: 3 Average loss per image: 178.54
11:48:09 INFO - __call__: Epoch: 4 Average loss per image: 175.97
11:49:34 INFO - __call__: Epoch: 5 Average loss per image: 174.12
11:50:58 INFO - __call__: Epoch: 6 Average loss per image: 172.73
11:52:23 INFO - __call__: Epoch: 7 Average loss per image: 171.61
11:53:45 INFO - __call__: Epoch: 8 Average loss per image: 170.72
11:55:12 INFO - __call__: Epoch: 9 Average loss per image: 169.96
11:56:36 INFO - __call__: Epoch: 10 Average loss per image: 169.33
11:58:01 INFO - __call__: Epoch: 11 Average loss per image: 168.79
11:59:28 INFO - _

In [4]:
# generate visualization
args = vae_trim_viz.parse_arguments()
args.name = name
args.plots = "all"
args.n_rows = 10
args.n_cols = 10
vae_trim_viz.main(args)

Selected idcs: [50997, 28883, 7657, 490, 5940, 59701, 52887, 38156, 2288, 44011, 45422, 5500, 6450, 50232, 23241, 15519, 1142, 2019, 51693, 1038, 22681, 42488, 40847, 31735, 40358, 30379, 9735, 5982, 11999, 46754, 7498, 55400, 958, 32970, 31899, 57702, 16372, 4231, 43729, 35460, 59789, 30533, 4493, 39417, 44245, 5828, 33753, 37945, 3012, 17667, 54026, 36466, 4133, 42246, 19819, 31525, 55717, 23280, 17462, 16328, 42957, 13109, 29713, 57445, 34744, 43608, 1264, 33298, 4626, 378, 21856, 9422, 46531, 30987, 43080, 24729, 10951, 3550, 4996, 38504, 32543, 10748, 4936, 36525, 14096, 9453, 1777, 22690, 50526, 7267, 34713, 9255, 43942, 20014, 29762, 2594, 27279, 18139, 43410, 52855]


### Test model

In [5]:
args = vae_trim.parse_arguments()
args.loss = "VAE"
args.name = name
# results dir
exp_dir = os.path.join(vae_trim.RES_DIR, args.name)

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)
metadata = load_metadata(exp_dir)

# load model
model = load_model(exp_dir, is_gpu=not args.no_cuda)

In [6]:
# inputs, labels = iter(test_loader).next()
# inputs = inputs.to(device)
# labels = labels.to(device)
# recon_batch, latent_dist, latent_sample = model(inputs)
# s, _ = latent_dist
# s = s.detach()

# loss_f = get_loss_f(args.loss,
#                     n_data=len(test_loader.dataset),
#                     device=device,
#                     **vars(args))
# storer = defaultdict(list)
# loss = loss_f(inputs, recon_batch, latent_dist, model.training,
#                    storer, latent_sample=latent_sample)

In [7]:
# loss
loss_f = get_loss_f(args.loss,
                    n_data=len(test_loader.dataset),
                    device=device,
                    **vars(args))

# evaluate on testset
storer = defaultdict(list)
for data, _ in tqdm(test_loader, leave=False, disable=args.no_progress_bar):
    data = data.to(device)
    recon_batch, latent_dist, latent_sample = model(data)
    _ = loss_f(data, recon_batch, latent_dist, model.training,
                    storer, latent_sample=latent_sample)    
    losses = {k: sum(v) / len(test_loader) for k, v in storer.items()}
    break
print(losses)

                                      

{'recon_loss': 11.618922424316406, 'kl_loss': 2.9720155715942385, 'kl_loss_0': 0.30269935131073, 'kl_loss_1': 0.28816895484924315, 'kl_loss_2': 0.28763034343719485, 'kl_loss_3': 0.2977740287780762, 'kl_loss_4': 0.29319636821746825, 'kl_loss_5': 0.2898214101791382, 'kl_loss_6': 0.30013489723205566, 'kl_loss_7': 0.3069141149520874, 'kl_loss_8': 0.30695381164550783, 'kl_loss_9': 0.2987224102020264, 'loss': 14.59093780517578}


