In [5]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os,sys
# Shorthand for building filesystem paths.
opj = os.path.join
from tqdm import tqdm
import acd
from random import randint
from copy import deepcopy
# Project-local modules: model construction, loss definition, training loop.
from model import init_specific_model
from losses import Loss
# NOTE(review): the star imports from `dset` and `utils` presumably provide the helpers used
# in later cells (define_dataloaders, calc_losses, calc_disentangle_metric, plot_2d_samples,
# traversals); TODO replace them with explicit imports so each name's origin is visible.
from dset import *
from training import Trainer
from utils import *
import pickle as pkl
import itertools
# NOTE(review): `p` (hyperparameters) and `s` (results container) are used by later cells but
# are not defined in any visible cell — presumably supplied by a star import or an earlier,
# since-removed cell. TODO define them explicitly so Restart & Run All works on a fresh kernel.

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Build everything needed for training: seeded RNGs, dataloaders, model, optimizer, loss, trainer.
# NOTE(review): `p` (hyperparameter namespace) is not defined in any visible cell — it must
# already exist in the kernel; TODO confirm where it comes from.

# Reproducibility: seed Python's, NumPy's and torch's global RNGs (in that order).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(p.seed)

# Train/test dataloaders plus their ground-truth latents.
(train_loader, train_latents), (test_loader, test_latents) = define_dataloaders(p)

# Model, moved onto the selected device (nn.Module.to returns the module itself).
model = init_specific_model(
    orig_dim=p.orig_dim,
    latent_dim=p.latent_dim,
    hidden_dim=p.hidden_dim,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=p.lr)

# Loss weights and training length, pulled out of `p` as module-level names.
beta, attr, alpha, gamma, tc = p.beta, p.attr, p.alpha, p.gamma, p.tc
num_epochs = p.num_epochs

loss_f = Loss(beta=beta, attr=attr, alpha=alpha, gamma=gamma, tc=tc, is_mss=True)
trainer = Trainer(model, optimizer, loss_f, device=device)

In [None]:
# Train the model for `num_epochs` epochs. test_loader is passed alongside train_loader
# (presumably for per-epoch evaluation inside Trainer — TODO confirm).
trainer(train_loader, test_loader, epochs=num_epochs)

In [None]:
# Score the trained model on the test split and record every metric on `s`.
# NOTE(review): `s` is not defined in any visible cell — TODO confirm where it comes from.
test_losses = calc_losses(model, test_loader, loss_f)
rec_loss, kl_loss, mi_loss, tc_loss, dw_kl_loss, attr_loss = test_losses

# (attribute on `s`, value) pairs, in the same order the attributes were originally set.
for _field, _value in (
    ('reconstruction_loss', rec_loss),
    ('kl_normal_loss', kl_loss),
    ('total_correlation', tc_loss),
    ('mutual_information', mi_loss),
    ('dimensionwise_kl_loss', dw_kl_loss),
    ('attribution_loss', attr_loss),
):
    setattr(s, _field, _value)

# Only the mean of the disentanglement metric is stored.
s.disentanglement_metric = calc_disentangle_metric(model, test_loader).mean()
s.net = model

In [None]:
# Histogram of the values returned by calc_disentangle_metric on the test set
# (the evaluation cell stores only their mean). Presumably one value per sample or per
# latent — TODO confirm against calc_disentangle_metric.
disentangle_scores = calc_disentangle_metric(model, test_loader)
fig, ax = plt.subplots()
ax.hist(disentangle_scores)
ax.set(title='Disentanglement metric on test set', xlabel='disentanglement metric', ylabel='count')
# plt.show() renders the figure without echoing the (counts, bins, patches) repr.
plt.show()

In [None]:
# Display the mean disentanglement metric recorded on `s` when the test losses were evaluated.
s.disentanglement_metric

In [None]:
# EVALUATE TEST DATA
# Run the whole test split through the trained model once; later cells reuse
# `data`, `recon_data` and `latent_sample` for plotting.
# NOTE(review): assumes the test dataset exposes all samples as one tensor via `.data` — TODO confirm.
data = test_loader.dataset.data.to(device)
# Inference only: no_grad avoids building an autograd graph over the full test set.
with torch.no_grad():
    recon_data, latent_dist, latent_sample = model(data)

plot_2d_samples(latent_sample.detach().cpu())
plt.title('Estimated latent variables')
plt.show()

In [None]:
# Overlay the original test data and its reconstruction, both projected onto the
# first two coordinates (original plotted first, reconstruction second).
for _samples in (data, recon_data):
    plot_2d_samples(_samples.detach().cpu()[:, :2])
plt.title('Original and reconstructed data after projection')
plt.show()

In [None]:
# Mark one test point on the projected original/reconstruction scatter, then decode
# latent traversals that start from that point.
ind = 5000  # index of the test sample to inspect
assert 0 <= ind < len(data), f"ind={ind} is out of range for {len(data)} test samples"
plot_2d_samples(data.detach().cpu()[:,:2])
plot_2d_samples(recon_data.detach().cpu()[:,:2])
plt.title('Original and reconstructed data after projection')
pt = data[ind:ind+1][:,:2]
# matplotlib needs host-side numbers: a tensor living on `device` (e.g. CUDA) cannot be
# converted to numpy implicitly, so move the (x, y) pair to plain Python floats first.
plt.annotate("x", tuple(pt[0].detach().cpu().tolist()), size=15)
plt.show()

# GET TRAVERSAL
# NOTE(review): n_latents=4 is hard-coded; the cells below slice the result in blocks of
# 100 rows, presumably one block per latent — TODO confirm against `traversals`/p.latent_dim.
decoded_traversal = traversals(model, data=data[ind:ind+1], n_latents=4)
# PROJECTION to FIRST TWO COORDINATES
decoded_traversal0 = decoded_traversal[:,:2]

In [None]:
# Rows 0-99 of the decoded traversal, projected onto the first two coordinates
# (presumably the sweep over latent 0 — TODO confirm the 100-rows-per-latent layout).
plot_2d_samples(decoded_traversal0[0:100])
plt.title('Decoded traversal rows 0-99 (first two coordinates)')
plt.show()

In [None]:
# Rows 100-199 of the decoded traversal, projected onto the first two coordinates
# (presumably the sweep over latent 1 — TODO confirm the 100-rows-per-latent layout).
plot_2d_samples(decoded_traversal0[100:200])
plt.title('Decoded traversal rows 100-199 (first two coordinates)')
plt.show()

In [None]:
# Rows 200-299 of the decoded traversal, projected onto the first two coordinates
# (presumably the sweep over latent 2 — TODO confirm the 100-rows-per-latent layout).
plot_2d_samples(decoded_traversal0[200:300])
plt.title('Decoded traversal rows 200-299 (first two coordinates)')
plt.show()

In [None]:
# Rows 300-399 of the decoded traversal, projected onto the first two coordinates
# (presumably the sweep over latent 3 — TODO confirm the 100-rows-per-latent layout).
plot_2d_samples(decoded_traversal0[300:400])
plt.title('Decoded traversal rows 300-399 (first two coordinates)')
plt.show()