In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
from model_mnist import LeNet5
from visualize import *
import dset_mnist as dset
import foolbox
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
from trim import *
from util import *
from attributions import *
from captum.attr import *
import warnings
warnings.filterwarnings("ignore")
# disentangled vae
sys.path.append('../disentangling-vae')
from collections import defaultdict
import vae_trim, vae_trim_viz
from disvae.utils.modelIO import save_model, load_model, load_metadata
from disvae.models.losses import get_loss_f

In [2]:
from disvae import init_specific_model
from utils.datasets import get_img_size

### BTCVAE-1 (`btcvae_B = 6`, `attr_lamb = 0`)

In [3]:
# Experiment configuration for this run: start from the project defaults
# returned by vae_trim.parse_arguments() and override the fields below.
args = vae_trim.parse_arguments()
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 6  # total correlation reg
args.attr_lamb = 0  # change of attribute wrt other attributes
# The run name encodes the loss type and both regularization weights,
# e.g. "btcvae_B_6_attr_0".
name = "_".join([args.loss, "B", str(args.btcvae_B), "attr", str(args.attr_lamb)])
args.name = name

In [4]:
# Build everything the training / evaluation cells need for this configuration:
# data loaders, model, optimizer, loss function and the latent-space wrapper `m`.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# optimizer
# Referenced through the explicitly imported `torch` package: a bare `optim`
# name is never imported in this notebook and would only resolve if one of the
# star imports happened to leak it.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
L1Loss = torch.nn.L1Loss()  # used by the evaluation cell for the attribute penalty
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# Wrapper around the model that the evaluation cell differentiates w.r.t. its
# input. NOTE(review): presumably maps latent -> decoder -> encoder -> latent;
# confirm against the DecoderEncoder definition.
m = DecoderEncoder(model)

Train model

In [5]:
# Train the model for 5 epochs on train_loader using loss_f only.
# No attribute regularizer is computed in this loop, so the "Average reg"
# field of the per-epoch summary is reported as the constant 0.
model.train()
for epoch in range(5):
    epoch_loss = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the running sum and the
        # progress line.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        # `labels` is only needed for its length (samples in this batch).
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), batch_loss), end='')
    mean_epoch_loss = epoch_loss / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, 0))

Average loss: 45.49810778001732 - Average reg: 0
Average loss: 12.368510833935442 - Average reg: 0
Average loss: 7.7154602180920175 - Average reg: 0
Average loss: 4.973543276918977 - Average reg: 0
Average loss: 3.2386133157368153 - Average reg: 0


In [6]:
# Evaluate the trained model: prints the mean loss per batch and the mean
# cross-attribute gradient penalty per batch.
# NOTE(review): this iterates over train_loader; test_loader is created but not
# used in the visible cells -- confirm that evaluating on training data is intended.
model.eval()
epoch_loss = 0
epoch_reg = 0
for inputs, labels in train_loader:
    inputs = inputs.to(device)

    # The forward pass and loss are only read back as numbers, so no autograd
    # graph is needed for them.
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    # Fresh leaf tensor holding latent_dist[0] so gradients of m(s) can be taken
    # with respect to it (latent_dist[0] is presumably the posterior mean -- confirm).
    s = latent_dist[0].detach().clone().requires_grad_(True)
    s_output = m(s)
    attr_reg = 0
    for i in range(model.latent_dim):
        # Boolean mask selecting every latent column except i.
        col_idx = np.arange(model.latent_dim) != i
        # Gradient of output column i w.r.t. all input columns. The graph of
        # s_output is reused on every iteration, hence retain_graph=True; the
        # result is only reduced to a number, so no higher-order graph is built.
        gradients = torch.autograd.grad(s_output[:, i], s,
                                        grad_outputs=torch.ones_like(s_output[:, i]),
                                        retain_graph=True)[0]
        gradient_pairwise = gradients[:, col_idx]
        # Mean absolute value of the off-diagonal gradients.
        attr_reg += L1Loss(gradient_pairwise, torch.zeros_like(gradient_pairwise))
    epoch_loss += loss.item()
    epoch_reg += attr_reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

4.056267093747918 0.582081064105288


### BTCVAE-2 (`btcvae_B = 18`, `attr_lamb = 0`)

In [7]:
from disvae import init_specific_model
from utils.datasets import get_img_size

In [8]:
# Experiment configuration for this run: start from the project defaults
# returned by vae_trim.parse_arguments() and override the fields below.
args = vae_trim.parse_arguments()
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 18  # total correlation reg
args.attr_lamb = 0  # change of attribute wrt other attributes
# The run name encodes the loss type and both regularization weights,
# e.g. "btcvae_B_18_attr_0".
name = "_".join([args.loss, "B", str(args.btcvae_B), "attr", str(args.attr_lamb)])
args.name = name

In [9]:
# Build everything the training / evaluation cells need for this configuration:
# data loaders, model, optimizer, loss function and the latent-space wrapper `m`.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# optimizer
# Referenced through the explicitly imported `torch` package: a bare `optim`
# name is never imported in this notebook and would only resolve if one of the
# star imports happened to leak it.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
L1Loss = torch.nn.L1Loss()  # used by the evaluation cell for the attribute penalty
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# Wrapper around the model that the evaluation cell differentiates w.r.t. its
# input. NOTE(review): presumably maps latent -> decoder -> encoder -> latent;
# confirm against the DecoderEncoder definition.
m = DecoderEncoder(model)

Train model

In [10]:
# Train the model for 5 epochs on train_loader using loss_f only.
# No attribute regularizer is computed in this loop, so the "Average reg"
# field of the per-epoch summary is reported as the constant 0.
model.train()
for epoch in range(5):
    epoch_loss = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the running sum and the
        # progress line.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        # `labels` is only needed for its length (samples in this batch).
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), batch_loss), end='')
    mean_epoch_loss = epoch_loss / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, 0))

Average loss: -386.7244673608971 - Average reg: 0
Average loss: -416.68307049391365 - Average reg: 0
Average loss: -421.13327055648443 - Average reg: 0
Average loss: -423.2155005936938 - Average reg: 0
Average loss: -424.6122833772509 - Average reg: 0


In [11]:
# Evaluate the trained model: prints the mean loss per batch and the mean
# cross-attribute gradient penalty per batch.
# NOTE(review): this iterates over train_loader; test_loader is created but not
# used in the visible cells -- confirm that evaluating on training data is intended.
model.eval()
epoch_loss = 0
epoch_reg = 0
for inputs, labels in train_loader:
    inputs = inputs.to(device)

    # The forward pass and loss are only read back as numbers, so no autograd
    # graph is needed for them.
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    # Fresh leaf tensor holding latent_dist[0] so gradients of m(s) can be taken
    # with respect to it (latent_dist[0] is presumably the posterior mean -- confirm).
    s = latent_dist[0].detach().clone().requires_grad_(True)
    s_output = m(s)
    attr_reg = 0
    for i in range(model.latent_dim):
        # Boolean mask selecting every latent column except i.
        col_idx = np.arange(model.latent_dim) != i
        # Gradient of output column i w.r.t. all input columns. The graph of
        # s_output is reused on every iteration, hence retain_graph=True; the
        # result is only reduced to a number, so no higher-order graph is built.
        gradients = torch.autograd.grad(s_output[:, i], s,
                                        grad_outputs=torch.ones_like(s_output[:, i]),
                                        retain_graph=True)[0]
        gradient_pairwise = gradients[:, col_idx]
        # Mean absolute value of the off-diagonal gradients.
        attr_reg += L1Loss(gradient_pairwise, torch.zeros_like(gradient_pairwise))
    epoch_loss += loss.item()
    epoch_reg += attr_reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

-424.3654141293914 0.2639923615019713


### BTCVAE-3 (`btcvae_B = 50`, `attr_lamb = 0`)

In [12]:
from disvae import init_specific_model
from utils.datasets import get_img_size

In [13]:
# Experiment configuration for this run: start from the project defaults
# returned by vae_trim.parse_arguments() and override the fields below.
args = vae_trim.parse_arguments()
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 50  # total correlation reg
args.attr_lamb = 0  # change of attribute wrt other attributes
# The run name encodes the loss type and both regularization weights,
# e.g. "btcvae_B_50_attr_0".
name = "_".join([args.loss, "B", str(args.btcvae_B), "attr", str(args.attr_lamb)])
args.name = name

In [14]:
# Build everything the training / evaluation cells need for this configuration:
# data loaders, model, optimizer, loss function and the latent-space wrapper `m`.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# optimizer
# Referenced through the explicitly imported `torch` package: a bare `optim`
# name is never imported in this notebook and would only resolve if one of the
# star imports happened to leak it.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
L1Loss = torch.nn.L1Loss()  # used by the evaluation cell for the attribute penalty
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# Wrapper around the model that the evaluation cell differentiates w.r.t. its
# input. NOTE(review): presumably maps latent -> decoder -> encoder -> latent;
# confirm against the DecoderEncoder definition.
m = DecoderEncoder(model)

Train model

In [15]:
# Train the model for 5 epochs on train_loader using loss_f only.
# No attribute regularizer is computed in this loop, so the "Average reg"
# field of the per-epoch summary is reported as the constant 0.
model.train()
for epoch in range(5):
    epoch_loss = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the running sum and the
        # progress line.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        # `labels` is only needed for its length (samples in this batch).
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), batch_loss), end='')
    mean_epoch_loss = epoch_loss / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, 0))

Average loss: -1568.6629865113607 - Average reg: 0
Average loss: -1592.5403731291228 - Average reg: 0
Average loss: -1596.2460808662463 - Average reg: 0
Average loss: -1598.9513557613022 - Average reg: 0
Average loss: -1601.7783611761226 - Average reg: 0


In [16]:
# Evaluate the trained model: prints the mean loss per batch and the mean
# cross-attribute gradient penalty per batch.
# NOTE(review): this iterates over train_loader; test_loader is created but not
# used in the visible cells -- confirm that evaluating on training data is intended.
model.eval()
epoch_loss = 0
epoch_reg = 0
for inputs, labels in train_loader:
    inputs = inputs.to(device)

    # The forward pass and loss are only read back as numbers, so no autograd
    # graph is needed for them.
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    # Fresh leaf tensor holding latent_dist[0] so gradients of m(s) can be taken
    # with respect to it (latent_dist[0] is presumably the posterior mean -- confirm).
    s = latent_dist[0].detach().clone().requires_grad_(True)
    s_output = m(s)
    attr_reg = 0
    for i in range(model.latent_dim):
        # Boolean mask selecting every latent column except i.
        col_idx = np.arange(model.latent_dim) != i
        # Gradient of output column i w.r.t. all input columns. The graph of
        # s_output is reused on every iteration, hence retain_graph=True; the
        # result is only reduced to a number, so no higher-order graph is built.
        gradients = torch.autograd.grad(s_output[:, i], s,
                                        grad_outputs=torch.ones_like(s_output[:, i]),
                                        retain_graph=True)[0]
        gradient_pairwise = gradients[:, col_idx]
        # Mean absolute value of the off-diagonal gradients.
        attr_reg += L1Loss(gradient_pairwise, torch.zeros_like(gradient_pairwise))
    epoch_loss += loss.item()
    epoch_reg += attr_reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

-1596.7284784052672 0.2097970805982791
