In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Plotting, numerics and deep-learning stack used throughout the notebook.
# NOTE(review): `random` is imported but no seed is set anywhere in this cell,
# so runs are not reproducible from this cell alone.
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join  # short alias for os.path.join
from tqdm import tqdm
import acd
from copy import deepcopy
from model_mnist import LeNet5
# NOTE(review): star import; the names it adds cannot be determined from this cell.
from visualize import *
import dset_mnist as dset
import foolbox
# Extend the module search path with ../trim (presumably where the modules
# imported right below live; confirm against the repository layout).
sys.path.append('../trim')
from transforms_torch import transform_bandpass, tensor_t_augment, batch_fftshift2d, batch_ifftshift2d
# NOTE(review): four more star imports; any of them may shadow names defined
# above, and names used later in the notebook may silently originate here.
from trim import *
from util import *
from attributions import *
from captum.attr import *
import warnings
# Suppress every warning for the rest of the session.
warnings.filterwarnings("ignore")
# disentangled vae
# Extend the module search path with ../disentangling-vae before importing
# the `disvae` package and the local `vae_trim` helpers.
sys.path.append('../disentangling-vae')
from collections import defaultdict
import vae_trim, vae_trim_viz
from disvae.utils.modelIO import save_model, load_model, load_metadata
from disvae.models.losses import get_loss_f

In [2]:
from disvae import init_specific_model
from utils.datasets import get_img_size

### BTCVAE-Attr-Pen-1

In [3]:
# Experiment 1 configuration: start from the parser defaults and override
# only the fields that define this run.
args = vae_trim.parse_arguments()

# Loss family and its regularisation knobs.
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 0        # weight of the total-correlation regulariser
args.attr_lamb = 0.001   # weight of the cross-attribute gradient penalty

# Run name encodes the loss and both regularisation weights.
name = "".join([args.loss, "_B_", str(args.btcvae_B), "_attr_", str(args.attr_lamb)])
args.name = name

In [4]:
# Build everything this run needs: data loaders, model, optimiser, losses and
# the wrapper used by the attribution penalty.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# Train
# Adam is reached through the explicitly imported `torch` package. A bare
# `optim` name is never imported by name in this notebook, so relying on it
# only works if one of the star imports happens to leak it.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
L1Loss = torch.nn.L1Loss()  # used by later cells to penalise gradients towards zero
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# Project wrapper around the model; later cells call it on the detached first
# element of `latent_dist`. Presumably decode -> re-encode (TODO confirm).
m = DecoderEncoder(model)

Train model

In [5]:
# Train for 5 epochs on the VAE loss plus an attribution penalty.
# The penalty takes the gradient of each output column of `m` with respect to
# its input `s`, drops the column with the same index, and pulls the remaining
# entries towards zero with an L1 loss weighted by `args.attr_lamb`.
model.train()
for epoch in range(5):
    epoch_loss = 0
    epoch_reg = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

        # Always define the penalty. Previously it was only assigned inside the
        # `if`, so attr_lamb == 0 raised NameError (or reused a stale tensor
        # left in the kernel) at `loss += attr_reg` / `attr_reg.item()`.
        attr_reg = loss.new_zeros(())
        # penalize change in one attribute wrt the other attributes
        if args.attr_lamb > 0:
            # Fresh leaf tensor: gradients are taken w.r.t. `s` itself and do
            # not flow back into the forward pass above.
            s = latent_dist[0].detach().clone().requires_grad_(True)
            s_output = m(s)
            for i in range(model.latent_dim):
                # boolean mask selecting every latent column except i
                col_idx = np.arange(model.latent_dim) != i
                # create_graph=True keeps the gradient differentiable so the
                # penalty can be backpropagated; retain_graph=True lets the
                # same forward graph be reused for the next column.
                gradients = torch.autograd.grad(s_output[:, i], s,
                                                grad_outputs=torch.ones_like(s_output[:, i]),
                                                retain_graph=True, create_graph=True)[0]
                gradient_pairwise = gradients[:, col_idx]
                attr_reg = attr_reg + args.attr_lamb * L1Loss(
                    gradient_pairwise, torch.zeros_like(gradient_pairwise))
        loss = loss + attr_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running totals; `epoch_reg` tracks the already-weighted penalty.
        epoch_loss += loss.item()
        epoch_reg += attr_reg.item()
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
                   100. * batch_idx / len(train_loader), loss.item()), end='')
    mean_epoch_loss = epoch_loss / len(train_loader)
    mean_epoch_reg = epoch_reg / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, mean_epoch_reg))

Average loss: 225.70974871344657 - Average reg: 0.0010709765868987253
Average loss: 182.90000163975046 - Average reg: 0.0011176248453358916
Average loss: 178.3071426196393 - Average reg: 0.001090005140946205
Average loss: 175.87265733666004 - Average reg: 0.0010700983435674104
Average loss: 174.0766552109708 - Average reg: 0.0010565968333626353


In [6]:
# Report the mean loss and the mean *unweighted* attribution penalty per batch.
# NOTE(review): this iterates `train_loader`, not `test_loader`; confirm that
# evaluating on the training split is intended.
# NOTE(review): the penalty here is not multiplied by `args.attr_lamb`, so it
# is not on the same scale as the "Average reg" printed during training.
model.eval()
epoch_loss = 0
epoch_reg = 0
for inputs, labels in train_loader:
    inputs = inputs.to(device)

    # Nothing is backpropagated during evaluation, so the forward pass and the
    # loss run without building an autograd graph.
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    # Fresh leaf tensor so gradients can be taken w.r.t. `s` itself.
    s = latent_dist[0].detach().clone().requires_grad_(True)
    s_output = m(s)
    attr_reg = s.new_zeros(())
    for i in range(model.latent_dim):
        # boolean mask selecting every latent column except i
        col_idx = np.arange(model.latent_dim) != i
        # Only first-order gradients are needed for reporting, so no
        # create_graph; retain_graph=True keeps the forward graph alive for
        # the remaining columns.
        gradients = torch.autograd.grad(s_output[:, i], s,
                                        grad_outputs=torch.ones_like(s_output[:, i]),
                                        retain_graph=True)[0]
        gradient_pairwise = gradients[:, col_idx]
        attr_reg = attr_reg + L1Loss(gradient_pairwise, torch.zeros_like(gradient_pairwise))
    epoch_loss += loss.item()
    epoch_reg += attr_reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

173.11531616731494 1.0372125168344868


### BTCVAE-Attr-Pen-2

In [7]:
# Experiment 2 configuration: start from the parser defaults and override
# only the fields that define this run.
args = vae_trim.parse_arguments()

# Loss family and its regularisation knobs.
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 0      # weight of the total-correlation regulariser
args.attr_lamb = 0.1   # weight of the cross-attribute gradient penalty

# Run name encodes the loss and both regularisation weights.
name = "".join([args.loss, "_B_", str(args.btcvae_B), "_attr_", str(args.attr_lamb)])
args.name = name

In [8]:
# Build everything this run needs: data loaders, model, optimiser, losses and
# the wrapper used by the attribution penalty.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# Train
# Adam is reached through the explicitly imported `torch` package. A bare
# `optim` name is never imported by name in this notebook, so relying on it
# only works if one of the star imports happens to leak it.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
L1Loss = torch.nn.L1Loss()  # used by later cells to penalise gradients towards zero
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# Project wrapper around the model; later cells call it on the detached first
# element of `latent_dist`. Presumably decode -> re-encode (TODO confirm).
m = DecoderEncoder(model)

Train model

In [9]:
# Train for 5 epochs on the VAE loss plus an attribution penalty.
# The penalty takes the gradient of each output column of `m` with respect to
# its input `s`, drops the column with the same index, and pulls the remaining
# entries towards zero with an L1 loss weighted by `args.attr_lamb`.
model.train()
for epoch in range(5):
    epoch_loss = 0
    epoch_reg = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

        # Always define the penalty. Previously it was only assigned inside the
        # `if`, so attr_lamb == 0 raised NameError (or reused a stale tensor
        # left in the kernel) at `loss += attr_reg` / `attr_reg.item()`.
        attr_reg = loss.new_zeros(())
        # penalize change in one attribute wrt the other attributes
        if args.attr_lamb > 0:
            # Fresh leaf tensor: gradients are taken w.r.t. `s` itself and do
            # not flow back into the forward pass above.
            s = latent_dist[0].detach().clone().requires_grad_(True)
            s_output = m(s)
            for i in range(model.latent_dim):
                # boolean mask selecting every latent column except i
                col_idx = np.arange(model.latent_dim) != i
                # create_graph=True keeps the gradient differentiable so the
                # penalty can be backpropagated; retain_graph=True lets the
                # same forward graph be reused for the next column.
                gradients = torch.autograd.grad(s_output[:, i], s,
                                                grad_outputs=torch.ones_like(s_output[:, i]),
                                                retain_graph=True, create_graph=True)[0]
                gradient_pairwise = gradients[:, col_idx]
                attr_reg = attr_reg + args.attr_lamb * L1Loss(
                    gradient_pairwise, torch.zeros_like(gradient_pairwise))
        loss = loss + attr_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running totals; `epoch_reg` tracks the already-weighted penalty.
        epoch_loss += loss.item()
        epoch_reg += attr_reg.item()
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
                   100. * batch_idx / len(train_loader), loss.item()), end='')
    mean_epoch_loss = epoch_loss / len(train_loader)
    mean_epoch_reg = epoch_reg / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, mean_epoch_reg))

Average loss: 227.0549631667798 - Average reg: 0.10116995455248397
Average loss: 182.547677029933 - Average reg: 0.1074968228922851
Average loss: 177.87978973063323 - Average reg: 0.10524751053753692
Average loss: 175.2153774660025 - Average reg: 0.10432412872500003
Average loss: 173.47549109875774 - Average reg: 0.10305916701457393


In [10]:
# Report the mean loss and the mean *unweighted* attribution penalty per batch.
# NOTE(review): this iterates `train_loader`, not `test_loader`; confirm that
# evaluating on the training split is intended.
# NOTE(review): the penalty here is not multiplied by `args.attr_lamb`, so it
# is not on the same scale as the "Average reg" printed during training.
model.eval()
epoch_loss = 0
epoch_reg = 0
for inputs, labels in train_loader:
    inputs = inputs.to(device)

    # Nothing is backpropagated during evaluation, so the forward pass and the
    # loss run without building an autograd graph.
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    # Fresh leaf tensor so gradients can be taken w.r.t. `s` itself.
    s = latent_dist[0].detach().clone().requires_grad_(True)
    s_output = m(s)
    attr_reg = s.new_zeros(())
    for i in range(model.latent_dim):
        # boolean mask selecting every latent column except i
        col_idx = np.arange(model.latent_dim) != i
        # Only first-order gradients are needed for reporting, so no
        # create_graph; retain_graph=True keeps the forward graph alive for
        # the remaining columns.
        gradients = torch.autograd.grad(s_output[:, i], s,
                                        grad_outputs=torch.ones_like(s_output[:, i]),
                                        retain_graph=True)[0]
        gradient_pairwise = gradients[:, col_idx]
        attr_reg = attr_reg + L1Loss(gradient_pairwise, torch.zeros_like(gradient_pairwise))
    epoch_loss += loss.item()
    epoch_reg += attr_reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

173.6711936251187 1.0244821412985259


### BTCVAE-Attr-Pen-3

In [11]:
# Experiment 3 configuration: start from the parser defaults and override
# only the fields that define this run.
args = vae_trim.parse_arguments()

# Loss family and its regularisation knobs.
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 0     # weight of the total-correlation regulariser
args.attr_lamb = 50   # weight of the cross-attribute gradient penalty

# Run name encodes the loss and both regularisation weights.
name = "".join([args.loss, "_B_", str(args.btcvae_B), "_attr_", str(args.attr_lamb)])
args.name = name

In [12]:
# Build everything this run needs: data loaders, model, optimiser, losses and
# the wrapper used by the attribution penalty.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# Train
# Adam is reached through the explicitly imported `torch` package. A bare
# `optim` name is never imported by name in this notebook, so relying on it
# only works if one of the star imports happens to leak it.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
L1Loss = torch.nn.L1Loss()  # used by later cells to penalise gradients towards zero
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# Project wrapper around the model; later cells call it on the detached first
# element of `latent_dist`. Presumably decode -> re-encode (TODO confirm).
m = DecoderEncoder(model)

Train model

In [13]:
# Train for 5 epochs on the VAE loss plus an attribution penalty.
# The penalty takes the gradient of each output column of `m` with respect to
# its input `s`, drops the column with the same index, and pulls the remaining
# entries towards zero with an L1 loss weighted by `args.attr_lamb`.
model.train()
for epoch in range(5):
    epoch_loss = 0
    epoch_reg = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

        # Always define the penalty. Previously it was only assigned inside the
        # `if`, so attr_lamb == 0 raised NameError (or reused a stale tensor
        # left in the kernel) at `loss += attr_reg` / `attr_reg.item()`.
        attr_reg = loss.new_zeros(())
        # penalize change in one attribute wrt the other attributes
        if args.attr_lamb > 0:
            # Fresh leaf tensor: gradients are taken w.r.t. `s` itself and do
            # not flow back into the forward pass above.
            s = latent_dist[0].detach().clone().requires_grad_(True)
            s_output = m(s)
            for i in range(model.latent_dim):
                # boolean mask selecting every latent column except i
                col_idx = np.arange(model.latent_dim) != i
                # create_graph=True keeps the gradient differentiable so the
                # penalty can be backpropagated; retain_graph=True lets the
                # same forward graph be reused for the next column.
                gradients = torch.autograd.grad(s_output[:, i], s,
                                                grad_outputs=torch.ones_like(s_output[:, i]),
                                                retain_graph=True, create_graph=True)[0]
                gradient_pairwise = gradients[:, col_idx]
                attr_reg = attr_reg + args.attr_lamb * L1Loss(
                    gradient_pairwise, torch.zeros_like(gradient_pairwise))
        loss = loss + attr_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running totals; `epoch_reg` tracks the already-weighted penalty.
        epoch_loss += loss.item()
        epoch_reg += attr_reg.item()
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
                   100. * batch_idx / len(train_loader), loss.item()), end='')
    mean_epoch_loss = epoch_loss / len(train_loader)
    mean_epoch_reg = epoch_reg / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, mean_epoch_reg))

Average loss: 260.4856673446037 - Average reg: 12.300184852532995
Average loss: 210.9487006506686 - Average reg: 12.4859372781538
Average loss: 203.65534685212157 - Average reg: 11.08942149290398
Average loss: 200.07705140266336 - Average reg: 10.401542523268189
Average loss: 198.0178859360945 - Average reg: 9.98408572861889


In [14]:
# Report the mean loss and the mean *unweighted* attribution penalty per batch.
# NOTE(review): this iterates `train_loader`, not `test_loader`; confirm that
# evaluating on the training split is intended.
# NOTE(review): the penalty here is not multiplied by `args.attr_lamb`, so it
# is not on the same scale as the "Average reg" printed during training.
model.eval()
epoch_loss = 0
epoch_reg = 0
for inputs, labels in train_loader:
    inputs = inputs.to(device)

    # Nothing is backpropagated during evaluation, so the forward pass and the
    # loss run without building an autograd graph.
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    # Fresh leaf tensor so gradients can be taken w.r.t. `s` itself.
    s = latent_dist[0].detach().clone().requires_grad_(True)
    s_output = m(s)
    attr_reg = s.new_zeros(())
    for i in range(model.latent_dim):
        # boolean mask selecting every latent column except i
        col_idx = np.arange(model.latent_dim) != i
        # Only first-order gradients are needed for reporting, so no
        # create_graph; retain_graph=True keeps the forward graph alive for
        # the remaining columns.
        gradients = torch.autograd.grad(s_output[:, i], s,
                                        grad_outputs=torch.ones_like(s_output[:, i]),
                                        retain_graph=True)[0]
        gradient_pairwise = gradients[:, col_idx]
        attr_reg = attr_reg + L1Loss(gradient_pairwise, torch.zeros_like(gradient_pairwise))
    epoch_loss += loss.item()
    epoch_reg += attr_reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

187.61602703493034 0.19825433282010846


### BTCVAE-Attr-Pen-4

In [15]:
# Experiment 4 configuration: start from the parser defaults and override
# only the fields that define this run.
args = vae_trim.parse_arguments()

# Loss family and its regularisation knobs.
args.loss = "btcvae"
args.reg_anneal = 0
args.btcvae_B = 0      # weight of the total-correlation regulariser
args.attr_lamb = 200   # weight of the cross-attribute gradient penalty

# Run name encodes the loss and both regularisation weights.
name = "".join([args.loss, "_B_", str(args.btcvae_B), "_attr_", str(args.attr_lamb)])
args.name = name

In [16]:
# Build everything this run needs: data loaders, model, optimiser, losses and
# the wrapper used by the attribution penalty.

# load dataloaders
train_loader, test_loader = dset.load_data(args.batch_size, args.eval_batchsize, device)

# initialize model
args.img_size = get_img_size(args.dataset)
model = init_specific_model(args.model_type, args.img_size, args.latent_dim).to(device)

# Train
# Adam is reached through the explicitly imported `torch` package. A bare
# `optim` name is never imported by name in this notebook, so relying on it
# only works if one of the star imports happens to leak it.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# loss
L1Loss = torch.nn.L1Loss()  # used by later cells to penalise gradients towards zero
loss_f = get_loss_f(args.loss,
                    n_data=len(train_loader.dataset),
                    device=device,
                    **vars(args))

# Project wrapper around the model; later cells call it on the detached first
# element of `latent_dist`. Presumably decode -> re-encode (TODO confirm).
m = DecoderEncoder(model)

Train model

In [17]:
# Train for 5 epochs on the VAE loss plus an attribution penalty.
# The penalty takes the gradient of each output column of `m` with respect to
# its input `s`, drops the column with the same index, and pulls the remaining
# entries towards zero with an L1 loss weighted by `args.attr_lamb`.
model.train()
for epoch in range(5):
    epoch_loss = 0
    epoch_reg = 0
    # one epoch
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        recon_batch, latent_dist, latent_sample = model(inputs)

        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

        # Always define the penalty. Previously it was only assigned inside the
        # `if`, so attr_lamb == 0 raised NameError (or reused a stale tensor
        # left in the kernel) at `loss += attr_reg` / `attr_reg.item()`.
        attr_reg = loss.new_zeros(())
        # penalize change in one attribute wrt the other attributes
        if args.attr_lamb > 0:
            # Fresh leaf tensor: gradients are taken w.r.t. `s` itself and do
            # not flow back into the forward pass above.
            s = latent_dist[0].detach().clone().requires_grad_(True)
            s_output = m(s)
            for i in range(model.latent_dim):
                # boolean mask selecting every latent column except i
                col_idx = np.arange(model.latent_dim) != i
                # create_graph=True keeps the gradient differentiable so the
                # penalty can be backpropagated; retain_graph=True lets the
                # same forward graph be reused for the next column.
                gradients = torch.autograd.grad(s_output[:, i], s,
                                                grad_outputs=torch.ones_like(s_output[:, i]),
                                                retain_graph=True, create_graph=True)[0]
                gradient_pairwise = gradients[:, col_idx]
                attr_reg = attr_reg + args.attr_lamb * L1Loss(
                    gradient_pairwise, torch.zeros_like(gradient_pairwise))
        loss = loss + attr_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running totals; `epoch_reg` tracks the already-weighted penalty.
        epoch_loss += loss.item()
        epoch_reg += attr_reg.item()
        print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(labels), len(train_loader.dataset),
                   100. * batch_idx / len(train_loader), loss.item()), end='')
    mean_epoch_loss = epoch_loss / len(train_loader)
    mean_epoch_reg = epoch_reg / len(train_loader)
    print('\nAverage loss: {} - Average reg: {}'.format(mean_epoch_loss, mean_epoch_reg))

Average loss: 290.3483241392351 - Average reg: 17.01171578120575
Average loss: 238.6353965873149 - Average reg: 11.175237426879818
Average loss: 226.92110200731486 - Average reg: 8.544843212119552
Average loss: 220.50440168889094 - Average reg: 7.37843575762279
Average loss: 216.28087433772302 - Average reg: 6.5540944284467555


In [18]:
# Report the mean loss and the mean *unweighted* attribution penalty per batch.
# NOTE(review): this iterates `train_loader`, not `test_loader`; confirm that
# evaluating on the training split is intended.
# NOTE(review): the penalty here is not multiplied by `args.attr_lamb`, so it
# is not on the same scale as the "Average reg" printed during training.
model.eval()
epoch_loss = 0
epoch_reg = 0
for inputs, labels in train_loader:
    inputs = inputs.to(device)

    # Nothing is backpropagated during evaluation, so the forward pass and the
    # loss run without building an autograd graph.
    with torch.no_grad():
        recon_batch, latent_dist, latent_sample = model(inputs)
        loss = loss_f(inputs, recon_batch, latent_dist, model.training,
                      None, latent_sample=latent_sample)

    # penalize change in one attribute wrt the other attributes
    # Fresh leaf tensor so gradients can be taken w.r.t. `s` itself.
    s = latent_dist[0].detach().clone().requires_grad_(True)
    s_output = m(s)
    attr_reg = s.new_zeros(())
    for i in range(model.latent_dim):
        # boolean mask selecting every latent column except i
        col_idx = np.arange(model.latent_dim) != i
        # Only first-order gradients are needed for reporting, so no
        # create_graph; retain_graph=True keeps the forward graph alive for
        # the remaining columns.
        gradients = torch.autograd.grad(s_output[:, i], s,
                                        grad_outputs=torch.ones_like(s_output[:, i]),
                                        retain_graph=True)[0]
        gradient_pairwise = gradients[:, col_idx]
        attr_reg = attr_reg + L1Loss(gradient_pairwise, torch.zeros_like(gradient_pairwise))
    epoch_loss += loss.item()
    epoch_reg += attr_reg.item()
print(epoch_loss/len(train_loader), epoch_reg/len(train_loader))

238.98592105043977 0.030306410858594278
