In [1]:
# numerical / deep-learning stack and RNG sources
import numpy as np
import torch
import random
# use the GPU when torch reports one, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
# short alias for building paths
opj = os.path.join
from tqdm import tqdm
# import acd
from random import randint
from copy import deepcopy
import pickle as pkl
import argparse

# make the project sources importable; these relative entries are resolved
# against the current working directory, so the notebook presumably has to be
# run from its own folder -- TODO confirm
sys.path.append('../../src/vae')
sys.path.append('../../src/vae/models')
sys.path.append('../../src/dsets/images')
# project-local modules, expected to be found via the paths appended above
from dset import get_dataloaders
from model import init_specific_model
from losses import get_loss_f
from training import Trainer

sys.path.append('../../lib/trim')
# trim modules
from trim import DecoderEncoder

# `main` is expected to come from the disentangling-vae checkout added here
sys.path.append('../../lib/disentangling-vae')
import main

### Train model

In [2]:
# Parse the base arguments, then apply this experiment's overrides in one pass.
args = main.parse_arguments()
for _name, _value in [
    # data / model architecture
    ("dataset", "dsprites"),
    ("model_type", "Burgess"),
    ("latent_dim", 10),
    ("img_size", (1, 64, 64)),
    # training
    ("rec_dist", "bernoulli"),
    ("reg_anneal", 0),
    # loss hyperparameters
    ("beta", 0),
    ("lamPT", 1),
    ("lamNN", 0.1),
    ("lamH", 0),
    ("lamSP", 1),
]:
    setattr(args, _name, _value)

In [3]:
class p:
    '''Default parameters for training a model on the dsprites dataset.

    Every setting is a class attribute, so the class itself can be used as the
    parameter container (``p.lr``, ``p._str(p)``, ``p._dict(p)``). The helper
    methods also work on an instance (``p()._str()``, ``p()._dict()``), in
    which case attributes set on the instance take precedence.
    '''
    # parameters for generating data
    seed = 13
    dataset = "dsprites"

    # parameters for model architecture
    model_type = "Burgess"
    latent_dim = 10
    img_size = (1, 64, 64)

    # parameters for training
    train_batch_size = 64
    test_batch_size = 100
    lr = 1e-4
    rec_dist = "bernoulli"
    reg_anneal = 0
    num_epochs = 100

    # hyperparameters for loss
    beta = 0.0
    lamPT = 0.0
    lamNN = 0.0
    lamH = 0.0
    lamSP = 0.0

    # parameters for exp
    warm_start = None # which parameter to warm start with respect to
    seq_init = 1      # value of warm_start parameter to start with respect to

    # SAVE MODEL
    out_dir = "/home/ubuntu/local-vae/notebooks/ex_dsprites/results" # wooseok's setup
#     out_dir = '/scratch/users/vision/chandan/local-vae' # chandan's setup
    dirname = "vary"
    # random 10-digit run id, drawn from the global `random` state when the
    # class body is executed
    pid = ''.join(str(randint(0, 9)) for _ in range(10))

    def _str(self):
        '''Return an identifier string built from the key parameters.

        The format is
        ``beta=<..>_lamPT=<..>_lamNN=<..>_lamSP=<..>_seed=<..>_pid=<..>``.
        Values are looked up on ``self``, so this works for both the class
        (``p._str(p)``) and an instance. Note that ``lamH`` is not part of
        the identifier.
        '''
        shown = ('beta', 'lamPT', 'lamNN', 'lamSP', 'seed', 'pid')
        return '_'.join('%s=%s' % (name, getattr(self, name)) for name in shown)

    def _dict(self):
        '''Return all public (non-underscore) parameters as a plain dict.

        Works when called on the class (``p._dict(p)``) and on an instance
        (``p()._dict()``). For an instance, the class-level defaults are
        included and any attribute set on the instance overrides them.
        '''
        cls = self if isinstance(self, type) else type(self)
        merged = dict(vars(cls))
        if cls is not self:
            # vars(instance) alone would miss every class-level default
            merged.update(vars(self))
        return {attr: val for (attr, val) in merged.items()
                if not attr.startswith('_')}

In [4]:
# copy every parsed argument onto the parameter class, overriding its defaults
for arg in vars(args):
    setattr(p, arg, getattr(args, arg))

# make sure the results directory exists
out_dir = os.path.join(p.out_dir, p.dirname)
os.makedirs(out_dir, exist_ok=True)

# seed each RNG in use (python, numpy, torch) with the same value
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(p.seed)

# data
train_loader = get_dataloaders(
    p.dataset,
    batch_size=p.train_batch_size,
    logger=None,
)

# model, built first and then moved to the selected device
model = init_specific_model(
    model_type=p.model_type,
    img_size=p.img_size,
    latent_dim=p.latent_dim,
    hidden_dim=None,
)
model = model.to(device)

# optimizer, loss and trainer
optimizer = torch.optim.Adam(model.parameters(), lr=p.lr)
loss_f = get_loss_f(decoder=model.decoder, **vars(p))
trainer = Trainer(model, optimizer, loss_f, device=device)

In [5]:
# run training for p.num_epochs epochs; the captured output below shows the
# average train loss reported for each epoch
trainer(train_loader, epochs=p.num_epochs)

====> Epoch: 0 Average train loss: 345.9500
====> Epoch: 1 Average train loss: 159.0704
====> Epoch: 2 Average train loss: 140.5256
====> Epoch: 3 Average train loss: 133.0484
====> Epoch: 4 Average train loss: 130.1644
====> Epoch: 5 Average train loss: 128.5116
====> Epoch: 6 Average train loss: 127.3136
====> Epoch: 7 Average train loss: 126.2052
====> Epoch: 8 Average train loss: 125.6266
====> Epoch: 9 Average train loss: 125.1243
====> Epoch: 10 Average train loss: 124.6516
====> Epoch: 11 Average train loss: 124.4359
====> Epoch: 12 Average train loss: 123.9639
====> Epoch: 13 Average train loss: 123.5313
====> Epoch: 14 Average train loss: 123.3407
====> Epoch: 15 Average train loss: 123.2088
====> Epoch: 16 Average train loss: 122.8320
====> Epoch: 17 Average train loss: 122.6288
====> Epoch: 18 Average train loss: 122.3611
====> Epoch: 19 Average train loss: 122.2580
====> Epoch: 20 Average train loss: 121.9465
====> Epoch: 21 Average train loss: 121.7966
====> Epoch: 22 Aver