In [1]:
import numpy as np
import torch
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
opj = os.path.join
from tqdm import tqdm
# import acd
from random import randint
from copy import deepcopy
import pickle as pkl
import argparse

sys.path.append('../../lib/disentangling-vae')
import main

sys.path.append('../../src/vae')
sys.path.append('../../src/vae/models')
sys.path.append('../../src/dsets/images')
from dset import get_dataloaders
from model import init_specific_model
from losses import get_loss_f
from training import Trainer

sys.path.append('../../lib/trim')
# trim modules
from trim import DecoderEncoder

### Train model

In [2]:
# Start from whatever main.parse_arguments() returns, then override the
# settings for this run. Overrides are applied in the listed order.
args = main.parse_arguments()
for _name, _value in (
    ("dataset", "dsprites"),
    ("model_type", "Burgess"),
    ("latent_dim", 10),
    ("img_size", (1, 64, 64)),
    ("rec_dist", "bernoulli"),
    ("reg_anneal", 0),
    # loss weights
    ("beta", 0),
    ("lamPT", 1),
    ("lamNN", 0.1),
    ("lamH", 0),
    ("lamSP", 1),
):
    setattr(args, _name, _value)
del _name, _value  # keep the notebook namespace free of loop temporaries

In [3]:
class p:
    '''Parameters for the dsprites VAE experiment.

    The class is used as a plain namespace: settings live in class attributes
    and may be overridden with `setattr(p, name, value)`. The helper methods
    accept either the class itself (`p._str(p)`, `p._dict(p)`) or an instance.
    '''
    # parameters for generating data
    seed = 13
    dataset = "dsprites"

    # parameters for model architecture
    model_type = "Burgess"
    latent_dim = 10
    img_size = (1, 64, 64)

    # parameters for training
    train_batch_size = 64
    test_batch_size = 100
    lr = 1e-4
    rec_dist = "bernoulli"
    reg_anneal = 0
    num_epochs = 100

    # hyperparameters for loss
    beta = 0.0
    lamPT = 0.0
    lamNN = 0.0
    lamH = 0.0
    lamSP = 0.0

    # parameters for exp
    warm_start = None # which parameter to warm start with respect to
    seq_init = 1      # value of warm_start parameter to start with respect to

    # SAVE MODEL
    # NOTE(review): machine-specific absolute path; adjust for your own setup.
    out_dir = "/home/ubuntu/local-vae/notebooks/ex_dsprites/results" # wooseok's setup
#     out_dir = '/scratch/users/vision/chandan/local-vae' # chandan's setup
    dirname = "vary"
    # random 10-digit identifier, drawn once when the class body is executed
    pid = ''.join(str(randint(0, 9)) for _ in range(10))

    def _str(self):
        '''Return a run identifier built from the loss weights, seed and pid.

        Values are read through `self`, so overrides set on the class, on a
        subclass or on an instance are all reflected in the result.
        '''
        fields = ('beta', 'lamPT', 'lamNN', 'lamSP', 'seed', 'pid')
        return '_'.join(name + '=' + str(getattr(self, name)) for name in fields)

    def _dict(self):
        '''Return all public (non-underscore) parameters as a dict.

        Works when `self` is the class itself (`p._dict(p)`) and when it is an
        instance; in the latter case class-level settings are included and
        instance attributes take precedence.
        '''
        owner = self if isinstance(self, type) else type(self)
        merged = {}
        # Walk base classes first so that subclasses override their parents.
        for klass in reversed(owner.__mro__):
            if klass is not object:
                merged.update(vars(klass))
        if owner is not self:
            # `self` is an instance: its own attributes win over class defaults.
            merged.update(vars(self))
        return {attr: val for (attr, val) in merged.items()
                if not attr.startswith('_')}

In [4]:
class s:
    '''Results to save.

    Used as a plain namespace: values are attached as attributes (for example
    `s.reconstruction_loss = ...`) and collected with `_dict`.
    '''
    def _dict(self):
        '''Return all public (non-underscore) attributes as a dict.

        Works when `self` is the class itself (`s._dict(s)`) and when it is an
        instance; in the latter case class-level values are included and
        instance attributes take precedence.
        '''
        owner = self if isinstance(self, type) else type(self)
        merged = {}
        # Walk base classes first so that subclasses override their parents.
        for klass in reversed(owner.__mro__):
            if klass is not object:
                merged.update(vars(klass))
        if owner is not self:
            # `self` is an instance: its own attributes win over class values.
            merged.update(vars(self))
        return {attr: val for (attr, val) in merged.items()
                if not attr.startswith('_')}
    
    
# calculate losses
def _tensor_item(term):
    """Return `term.item()` when `term` is a tensor; anything else counts as 0."""
    return term.item() if isinstance(term, torch.Tensor) else 0


def calc_losses(model, data_loader, loss_f):
    """
    Average each loss term over one pass through `data_loader`.

    The model is switched to eval mode and is left in that mode.

    Parameters
    ----------
    model: callable with an `eval()` method
        Called as `model(data)`; must return
        `(recon_data, latent_dist, latent_sample)`.

    data_loader: torch.utils.data.DataLoader
        Iterable yielding `(data, label)` pairs; the label is ignored.

    loss_f: loss object
        Callable whose per-batch terms are read back from its attributes
        `rec_loss`, `kl_loss`, `pt_loss`, `nearest_neighbor_loss`,
        `hessian_loss` and `sp_loss` after each call. The last four are
        counted as 0 for a batch when they are not tensors.

    Return
    ------
    tuple of float
        `(rec_loss, kl_loss, pt_loss, nn_loss, h_loss, sp_loss)`, each the
        mean of the per-batch values.

    Raises
    ------
    ValueError
        If `data_loader` yields no batches.
    """
    model.eval()
    rec_loss = kl_loss = pt_loss = nn_loss = h_loss = sp_loss = 0
    n_batch = 0

    # NOTE(review): gradients are left enabled on purpose -- some loss terms
    # may rely on autograd; confirm before wrapping this in torch.no_grad().
    for data, _ in data_loader:
        data = data.to(device)
        recon_data, latent_dist, latent_sample = model(data)
        # NOTE(review): the wrapper is rebuilt for every batch; it can probably
        # be hoisted if DecoderEncoder keeps no per-batch state -- confirm.
        latent_map = DecoderEncoder(model, use_residuals=True)
        latent_output = latent_map(latent_sample, data)
        # The return value is not needed; the individual terms are read from
        # the attributes that loss_f sets during the call.
        loss_f(data, recon_data, latent_dist, model.training, storer=None,
               latent_sample=latent_sample, latent_output=latent_output, n_data=None)
        rec_loss += loss_f.rec_loss.item()
        kl_loss += loss_f.kl_loss.item()
        pt_loss += _tensor_item(loss_f.pt_loss)
        nn_loss += _tensor_item(loss_f.nearest_neighbor_loss)
        h_loss += _tensor_item(loss_f.hessian_loss)
        sp_loss += _tensor_item(loss_f.sp_loss)
        n_batch += 1

    if n_batch == 0:
        raise ValueError("data_loader yielded no batches; cannot average losses")

    return (rec_loss / n_batch, kl_loss / n_batch, pt_loss / n_batch,
            nn_loss / n_batch, h_loss / n_batch, sp_loss / n_batch)
        

In [5]:
# Copy every parsed/overridden argument onto the parameter namespace `p`.
for arg, arg_value in vars(args).items():
    setattr(p, arg, arg_value)

# Output directory for this run (created if missing).
out_dir = os.path.join(p.out_dir, p.dirname)
os.makedirs(out_dir, exist_ok=True)

# Seed Python, NumPy and PyTorch with the same value before any data or
# model is created.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(p.seed)

# Data
train_loader = get_dataloaders(
    p.dataset,
    batch_size=p.train_batch_size,
    logger=None,
)

# Model, moved to the selected device
model = init_specific_model(
    model_type=p.model_type,
    img_size=p.img_size,
    latent_dim=p.latent_dim,
    hidden_dim=None,
)
model = model.to(device)

# Optimizer, loss and trainer. Every entry of vars(p) is forwarded to
# get_loss_f as a keyword argument.
optimizer = torch.optim.Adam(model.parameters(), lr=p.lr)
loss_f = get_loss_f(decoder=model.decoder, **vars(p))
trainer = Trainer(model, optimizer, loss_f, device=device)

In [6]:
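# NOTE: training is disabled here, so as written the cells below evaluate the
# freshly initialized (untrained) model. Uncomment the next line to train first.
# trainer(train_loader, epochs=p.num_epochs)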

In [7]:
# Average every loss term over the training loader and record the results,
# together with the model itself, on the `s` namespace.
print('calculating losses and metric...')
(rec_loss, kl_loss, pt_loss,
 nn_loss, h_loss, sp_loss) = calc_losses(model, train_loader, loss_f)
(s.reconstruction_loss,
 s.kl_normal_loss,
 s.pt_local_independence_loss,
 s.nearest_neighbor_loss,
 s.hessian_loss,
 s.sparsity_loss) = (rec_loss, kl_loss, pt_loss, nn_loss, h_loss, sp_loss)
# s.disentanglement_metric = calc_disentangle_metric(model, test_loader).mean().item()
s.net = model

calculating losses and metric...


In [8]:
# Show the six averaged loss terms in the order they were stored on `s`.
print(
    s.reconstruction_loss,
    s.kl_normal_loss,
    s.pt_local_independence_loss,
    s.nearest_neighbor_loss,
    s.hessian_loss,
    s.sparsity_loss,
)

2464.6376222398544 0.009973166474891413 8.981220646673036e-06 0.03372489919032281 0.0 0.024180983006954194
