In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
from model_mnist import LeNet5
from visualize import *
import dset_mnist as dset
import foolbox
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
import warnings
warnings.filterwarnings("ignore")
# disentangled vae
sys.path.append('../disentangling-vae')
from collections import defaultdict
import vae_trim, vae_trim_viz
from disvae.utils.modelIO import save_model, load_model, load_metadata
from disvae.models.losses import get_loss_f

### Train model

In [2]:
args = vae_trim.parse_arguments()
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 1
args.attr_lamb = 20
args.epochs = 60
name = args.loss + "_B_" + str(args.btcvae_B) + "_attr_" + str(args.attr_lamb) + "_bijection_3"
args.name = name

In [3]:
# train and evaluate model
vae_trim.main(args)

04:26:56 INFO - main: Root directory for saving and loading experiments: results/btcvae_B_1_attr_20_bijection_3
04:26:56 INFO - main: Train mnist with 60000 samples
04:26:56 INFO - main: Num parameters in model: 469173
04:26:57 INFO - __init__: Training Device: cuda
04:29:22 INFO - __call__: Epoch: 1 Average loss per image: 203.28   
04:31:48 INFO - __call__: Epoch: 2 Average loss per image: 166.29   
04:34:14 INFO - __call__: Epoch: 3 Average loss per image: 162.09   
04:36:38 INFO - __call__: Epoch: 4 Average loss per image: 159.89   
04:39:02 INFO - __call__: Epoch: 5 Average loss per image: 158.37   
04:41:26 INFO - __call__: Epoch: 6 Average loss per image: 157.18   
04:43:52 INFO - __call__: Epoch: 7 Average loss per image: 156.27   
04:46:15 INFO - __call__: Epoch: 8 Average loss per image: 155.58   
04:48:40 INFO - __call__: Epoch: 9 Average loss per image: 154.95   
04:51:02 INFO - __call__: Epoch: 10 Average loss per image: 154.48   
04:53:25 INFO - __call__: Epoch: 11 Averag

In [4]:
# generate visualization
args = vae_trim_viz.parse_arguments()
args.name = name
args.plots = "all"
args.n_rows = 10
args.n_cols = 10
vae_trim_viz.main(args)

Selected idcs: [50997, 28883, 7657, 490, 5940, 59701, 52887, 38156, 2288, 44011, 45422, 5500, 6450, 50232, 23241, 15519, 1142, 2019, 51693, 1038, 22681, 42488, 40847, 31735, 40358, 30379, 9735, 5982, 11999, 46754, 7498, 55400, 958, 32970, 31899, 57702, 16372, 4231, 43729, 35460, 59789, 30533, 4493, 39417, 44245, 5828, 33753, 37945, 3012, 17667, 54026, 36466, 4133, 42246, 19819, 31525, 55717, 23280, 17462, 16328, 42957, 13109, 29713, 57445, 34744, 43608, 1264, 33298, 4626, 378, 21856, 9422, 46531, 30987, 43080, 24729, 10951, 3550, 4996, 38504, 32543, 10748, 4936, 36525, 14096, 9453, 1777, 22690, 50526, 7267, 34713, 9255, 43942, 20014, 29762, 2594, 27279, 18139, 43410, 52855]


### Test model

In [5]:
args = vae_trim.parse_arguments()
args.loss = "VAE"
args.name = name
# results dir
exp_dir = os.path.join(vae_trim.RES_DIR, args.name)

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)
metadata = load_metadata(exp_dir)

# load model
model = load_model(exp_dir, is_gpu=not args.no_cuda)

In [6]:
# loss
loss_f = get_loss_f(args.loss,
                    n_data=len(test_loader.dataset),
                    device=device,
                    **vars(args))

# evaluate on testset
storer = defaultdict(list)
for data, _ in tqdm(test_loader, leave=False, disable=args.no_progress_bar):
    data = data.to(device)
    recon_batch, latent_dist, latent_sample = model(data)
    _ = loss_f(data, recon_batch, latent_dist, model.training,
                    storer, latent_sample=latent_sample)    
    losses = {k: sum(v) / len(test_loader) for k, v in storer.items()}
    break
print(losses)

                                      

{'recon_loss': 12.027501678466797, 'kl_loss': 1.930692481994629, 'kl_loss_0': 0.19817252159118653, 'kl_loss_1': 0.1984402060508728, 'kl_loss_2': 0.1867664098739624, 'kl_loss_3': 0.18077932596206664, 'kl_loss_4': 0.24604420661926268, 'kl_loss_5': 0.1641427159309387, 'kl_loss_6': 0.18898794651031495, 'kl_loss_7': 0.25300734043121337, 'kl_loss_8': 0.1653110384941101, 'kl_loss_9': 0.14904073476791382, 'loss': 13.958193969726562}


