In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
# short alias for joining path components
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
from model_mnist import LeNet5
from visualize import *
import dset_mnist as dset
import foolbox
# extend the module search path before the imports that follow;
# presumably those modules live in ../trim -- TODO confirm
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
# NOTE(review): wildcard imports -- later cells use names that are never
# imported explicitly (e.g. `DecoderEncoder`) and presumably come from one of
# these modules; verify before replacing them with explicit imports
from trim import *
from util import *
from attributions import *
from captum.attr import *
import warnings
# silence every warning for the rest of the session
warnings.filterwarnings("ignore")
# disentangled vae
# the path must be extended before the `vae_trim` / `disvae` imports below;
# presumably they resolve from this directory -- TODO confirm
sys.path.append('../disentangling-vae')
from collections import defaultdict
import vae_trim, vae_trim_viz
from disvae.utils.modelIO import save_model, load_model, load_metadata
from disvae.models.losses import get_loss_f

### BTCVAE-Attr-Pen

In [2]:
from disvae import init_specific_model
from utils.datasets import get_img_size

In [3]:
# Start from the parsed defaults, then apply this experiment's overrides.
args = vae_trim.parse_arguments()
overrides = {
    "loss": "btcvae",
    "reg_anneal": 0,
    "btcvae_B": 6,   # total correlation reg
    "attr_lamb": 0,  # change of attribute wrt other attributes
}
for key, value in overrides.items():
    setattr(args, key, value)

# The run name records the loss and both regularisation weights.
name = f"{args.loss}_B_{args.btcvae_B}_attr_{args.attr_lamb}"
args.name = name

In [4]:
# Build everything the training and evaluation cells need: data loaders,
# model, optimizer, loss function and the attribution method.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# optimizer -- referenced through `torch.optim` because no cell imports
# `optim` explicitly; the bare name would only resolve if a wildcard import
# happened to leak it into the namespace
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
# L1 criterion instance (alternative tried: torch.nn.MSELoss())
L1Loss = torch.nn.L1Loss()
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# saliency map: input-times-gradient attributions computed on the model
# wrapped in `DecoderEncoder`
saliency = InputXGradient(DecoderEncoder(model))

### Train model

In [5]:
# Train for 10 epochs on the loss returned by `loss_f` alone. No attribution
# penalty is added in this loop, so `epoch_reg` stays at 0; it is still
# reported so the summary line reflects what was actually accumulated.
model.train()
for epoch in range(10):
    epoch_loss = 0
    epoch_reg = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # unpacking into four names requires a 4-D batch
        batch_size, channel, height, width = inputs.size()
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # read the scalar once and reuse it for both the running sum and the log
        batch_loss = loss.item()
        epoch_loss += batch_loss
        # '\r' with end='' rewrites the same progress line in place
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), batch_loss), end='')
    # per-batch averages over the epoch
    mean_epoch_loss = epoch_loss / len(train_loader)
    mean_epoch_reg = epoch_reg / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, mean_epoch_reg))

Average loss: 45.67456041673607 - Average reg: 0
Average loss: 12.729268006932761 - Average reg: 0
Average loss: 8.173371142161681 - Average reg: 0
Average loss: 5.3074500077823075 - Average reg: 0
Average loss: 3.670491330405034 - Average reg: 0
Average loss: 2.3702857336764143 - Average reg: 0
Average loss: 1.344481185555204 - Average reg: 0
Average loss: 0.5583316664705907 - Average reg: 0
Average loss: -0.08227057548474147 - Average reg: 0
Average loss: -0.7664251429185684 - Average reg: 0


In [14]:
# Evaluate the model over `train_loader` in eval mode: accumulate the loss from
# `loss_f` and an attribution penalty per batch, then print their per-batch means.
model.eval()
epoch_loss = 0
epoch_reg = 0
for batch_idx, (inputs, labels) in enumerate(train_loader):
    inputs = inputs.to(device)
    labels = labels.to(device)
    # unpacking into four names requires a 4-D batch
    batch_size, channel, height, width = inputs.size()

    # nothing is backpropagated in this cell, so the forward pass and the loss
    # do not need an autograd graph
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    reg = 0
    # independent, graph-free copy of the first element of `latent_dist`,
    # used as the input the attributions are computed against
    s = latent_dist[0].detach().clone()
    for i in range(model.latent_dim):
        # boolean mask selecting every latent column except i
        col_idx = np.arange(model.latent_dim) != i
        # attribution of target i with respect to all the other columns of s
        attributions = saliency.attribute(s, target=i)[:, col_idx]
        # mean absolute attribution (L1 distance to zero)
        reg += L1Loss(attributions, torch.zeros_like(attributions))
    epoch_loss += loss.item()
    epoch_reg += reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

-1.7869764543545525 0.4399063771785195
