In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
from model_mnist import LeNet5
from visualize import *
import dset_mnist as dset
import foolbox
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
import warnings
warnings.filterwarnings("ignore")
# disentangled vae
sys.path.append('../disentangling-vae')
from collections import defaultdict
import vae_trim, vae_trim_viz
from disvae.utils.modelIO import save_model, load_model, load_metadata
from disvae.models.losses import get_loss_f

### Train model

In [2]:
# Experiment configuration: parse the script's arguments, then override the
# hyperparameters used for this run.
args = vae_trim.parse_arguments()
args.btcvae_B, args.reg_anneal, args.trim_lamb = 3, 0, 10

# The experiment name encodes the two hyperparameters being varied; it is kept
# in `name` as well because later cells reuse it.
args.name = name = f"disvae_btcvae_B_{args.btcvae_B}_lamb_{args.trim_lamb}"

# load classifier
# map_location remaps the checkpoint's tensors onto `device`, so weights that
# were saved from a GPU session still load when only the CPU is available.
m = LeNet5().eval()
m.load_state_dict(torch.load('weights/lenet_epoch=12_test_acc=0.991.pth',
                             map_location=device))
m = m.to(device)

In [None]:
# train and evaluate model
# Calls vae_trim's entry point with the configured `args`, handing it the
# LeNet5 classifier `m` through the `classifier` keyword.
# NOTE(review): in the run recorded in this notebook's output, epochs took
# roughly 30 s each on cuda, so this cell is slow to re-run.
vae_trim.main(args, classifier=m)

08:03:53 INFO - main: Root directory for saving and loading experiments: results/disvae_btcvae_B_3_lamb_10
08:03:53 INFO - main: Train mnist with 60000 samples
08:03:53 INFO - main: Num parameters in model: 469173
08:03:53 INFO - __init__: Training Device: cuda
08:04:25 INFO - __call__: Epoch: 1 Average loss per image: 151.97
08:04:56 INFO - __call__: Epoch: 2 Average loss per image: 115.05
08:05:26 INFO - __call__: Epoch: 3 Average loss per image: 110.29
08:05:56 INFO - __call__: Epoch: 4 Average loss per image: 107.85
08:06:27 INFO - __call__: Epoch: 5 Average loss per image: 106.11
08:06:58 INFO - __call__: Epoch: 6 Average loss per image: 104.86
08:07:30 INFO - __call__: Epoch: 7 Average loss per image: 104.01
08:08:00 INFO - __call__: Epoch: 8 Average loss per image: 103.15
08:08:31 INFO - __call__: Epoch: 9 Average loss per image: 102.45
08:09:01 INFO - __call__: Epoch: 10 Average loss per image: 101.90
08:09:33 INFO - __call__: Epoch: 11 Average loss per image: 101.32
08:10:05 I

In [None]:
# generate visualization
# The plotting options get their own name: rebinding `args` here would replace
# the training configuration, so re-running the training cell after this one
# would pass the wrong options object to vae_trim.main.
viz_args = vae_trim_viz.parse_arguments()
viz_args.name = name  # same experiment name as the training run
viz_args.plots = "all"
viz_args.n_rows = 10
viz_args.n_cols = 10
vae_trim_viz.main(viz_args)

### Test model

In [None]:
# Re-parse the arguments; only the experiment name is carried over.
# NOTE(review): the overrides applied before training (btcvae_B, reg_anneal,
# trim_lamb) are not reapplied here -- confirm that is intended.
args = vae_trim.parse_arguments()
args.name = name

# results dir
exp_dir = opj(vae_trim.RES_DIR, args.name)

# dataloaders, then the metadata stored in the experiment directory
train_loader, test_loader = dset.load_data(
    args.batch_size,
    args.eval_batchsize,
    device,
)
metadata = load_metadata(exp_dir)

# model restored from the experiment directory
model = load_model(exp_dir, is_gpu=not args.no_cuda)

In [None]:
# loss
loss_f = get_loss_f(args.loss,
                    n_data=len(test_loader.dataset),
                    device=device,
                    **vars(args))

# evaluate on testset
# Put the model in eval mode explicitly: loss_f receives model.training, so the
# flag it sees should not depend on what an earlier cell left behind.
model.eval()
storer = defaultdict(list)
n_batches = 0
for data, _ in tqdm(test_loader, leave=False, disable=args.no_progress_bar):
    data = data.to(device)
    recon_batch, latent_dist, latent_sample = model(data)
    # loss_f appends the per-batch values to `storer`; its return value is not
    # needed here.
    loss_f(data, recon_batch, latent_dist, model.training,
           storer, latent_sample=latent_sample)
    n_batches += 1

# Average over the batches that were actually processed. Previously the loop
# stopped after one batch but still divided by len(test_loader), which
# under-reported every metric. With an empty loader `storer` is empty, so this
# yields {} without dividing by zero.
losses = {k: sum(v) / n_batches for k, v in storer.items()}
print(losses)