In [357]:
import torch

import numpy as np
from math import gamma

def random_choice_full(input, n_samples, number_of_gausses):
        from torch import multinomial, ones
        if n_samples * number_of_gausses < input.shape[0]:
            replacement = False
        else:
            replacement = True
        idx = multinomial(ones(input.shape[0]), n_samples * number_of_gausses, replacement=replacement)
        sampled = input[idx].reshape(number_of_gausses, n_samples, -1)
        return torch.mean(sampled, axis=1)

def provide_weights_for_x(x, how = None, device = None):
    dim = x.shape[1]
    scale = (1 / dim)

    if how == "gauss":
        from torch.distributions import MultivariateNormal
        
        sampled_points = random_choice_full(x, dim, dim)
        cov_mat = (scale * torch.eye(dim)).repeat(dim, 1, 1)
        mvn = MultivariateNormal(loc=sampled_points.to(device), covariance_matrix=cov_mat.to(device))
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
        
    elif how == "sqrt":
        weight_vector = 1 / torch.sqrt(1 + x.reshape(-1, 1, dim).to(device)**2)
        
    elif how == "log":
        weight_vector = torch.log(1 + x.reshape(-1, 1, dim).to(device)**2)
        
    elif how == "TStudent":
        from torch.distributions.studentT import StudentT
        
        mvn = StudentT(df=1, loc=x.mean(0), scale=scale)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
        # to trzeba poprawić ?!
    elif how == "Cauchy":
        from torch.distributions.cauchy import Cauchy
        
        mvn = Cauchy(loc=x.mean(0), scale=1)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
    elif how == "Gumbel":
        from torch.distributions.gumbel import Gumbel
        
        mvn = Gumbel(loc=x.mean(0), scale=1)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
    elif how == "Laplace":
        from torch.distributions.laplace import Laplace
        
        mvn = Laplace(loc=x.mean(0), scale=1)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
    return weight_vector
    
class WICA(object):
    def __init__(self):
        self.number_of_gausses = 5
        self.z_dim = 5
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def wica_loss(self, z, latent_normalization=False, how="gauss"):
        if latent_normalization:
            x = (z - z.mean(dim=1, keepdim=True)) / z.std(dim=1, keepdim=True)
        else:
            x = z
        dim = self.z_dim if self.z_dim is not None else x.shape[1]


        weight_vector = provide_weights_for_x(
                x=x, 
                how=how,
                device = self.device
        )
                
        print(weight_vector.shape)
        
        sum_of_weights = torch.sum(weight_vector, axis=0)
        weight_sum = torch.sum(x * weight_vector.T.reshape(self.number_of_gausses, -1, 1), axis=1)
        weight_mean = weight_sum / sum_of_weights.reshape(-1, 1)

        xm = x - weight_mean.reshape(self.number_of_gausses, 1, -1)
        wxm = xm * weight_vector.T.reshape(self.number_of_gausses, -1, 1)

        wcov = (wxm.permute(0, 2, 1).matmul(xm)) / sum_of_weights.reshape(-1, 1, 1)

        diag = torch.diagonal(wcov ** 2, dim1=1, dim2=2)
        diag_pow_plus = diag.reshape(diag.shape[0], diag.shape[1], -1) + diag.reshape(diag.shape[0], -1, diag.shape[1])

        tmp = (2 * wcov ** 2 / diag_pow_plus)

        triu = torch.triu(tmp, diagonal=1)
        normalize = 2.0 / (dim * (dim - 1))
        cost = torch.sum(normalize * triu) / self.number_of_gausses
        return cost

In [358]:
wica = WICA()

In [359]:
t = torch.sin(torch.randn(100)).resize(20,5)
t.mean(0)

tensor([-0.0815, -0.1086,  0.2232, -0.0168, -0.0041])

In [363]:
for i in range(10):
    print(wica.wica_loss(t, latent_normalization=False, how="TStudent"))

torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)
torch.Size([20, 1, 5])
tensor(0.0530)


In [308]:
for i in range(10):
    print(wica.wica_loss(t, latent_normalization=True, how="sqrt"))

torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)
torch.Size([20, 1, 5])
tensor(0.1065)


In [364]:
import numpy as np
import pandas as pd
import torch

import sys  
sys.path.insert(0, '/Users/andrzej/Personal/Projects/disentanglement-pytorch')
from models.vae import VAEModel

In [368]:
from architectures import encoders, decoders

encoder_name = "SimpleGaussianConv64"
decoder_name = "SimpleConv64"

encoder = getattr(encoders, encoder_name)
decoder = getattr(decoders, decoder_name)

model = VAEModel(encoder(8, 1, 64), decoder(8, 1, 64)).to(torch.device('cpu'))

checkpoint = torch.load('/Users/andrzej/Personal/results/multiple-wicas/last', map_location=torch.device('cpu'))

model.load_state_dict(checkpoint['model_states']['G'])
model.eval()

VAEModel(
  (encoder): SimpleGaussianConv64(
    (main): Sequential(
      (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Flatten3D()
      (13): Linear(in_features=256, out_features=16, bias=True)
    )
  )
  (decoder): SimpleConv64(
    (main): Sequential(
      (0): Unsqueeze3D()
      (1): Conv2d(8, 256, kernel_size=(1, 1), stride=(2, 2))
      (2): ReLU(inplace=True)
      (3): ConvTra

In [387]:
from common.data_loader import get_dataloader

train_loader = get_dataloader('dsprites_full', '/Users/andrzej/Personal/Projects/data/test_dsets', 3,
                              123, num_workers=1, pin_memory=True, image_size=64, 
                              include_labels=None, shuffle=True, droplast=False)

In [381]:
x, y = next(iter(train_loader))
model(x)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0.

In [384]:
import torchvision.utils
def visualize_recon(input_image, recon_image):
        input_image = torchvision.utils.make_grid(input_image)
        recon_image = torchvision.utils.make_grid(recon_image)

        white_line = torch.ones((3, input_image.size(1), 10)).to('cpu')
        samples = torch.cat([input_image, white_line, recon_image], dim=2)

        torchvision.utils.save_image(samples, "test.png")


In [385]:
visualize_recon(x,model(x))

In [422]:
from common.utils import grid2gif, get_data_for_visualization, prepare_data_for_visualization
import os

z_dim = 8
l_dim = 0
traverse_z = True
traverse_c = False
num_labels = 0
image_size = 64
num_channels = train_loader.dataset.num_channels()

def set_z(z, latent_id, val):
    z[:, latent_id] = val

def encode_deterministic(**kwargs):
    images = kwargs['images']
    if len(images.size()) == 3:
        images = images.unsqueeze(0)
    return model.encode(images)

def decode_deterministic(**kwargs):
    latent = kwargs['latent']
    if len(latent.size()) == 1:
        latent = latent.unsqueeze(0)
    return model.decode(latent)

def visualize_traverse(limit: tuple, spacing, data=None, test=False):
    interp_values = torch.arange(limit[0], limit[1]+spacing, spacing)
    num_cols = interp_values.size(0)

    sample_images_dict, sample_labels_dict = prepare_data_for_visualization(next(iter(train_loader)))
    encodings = dict()

    for key in sample_images_dict.keys():
        encodings[key], _  = encode_deterministic(images=sample_images_dict[key], labels=sample_labels_dict[key])

    gifs = []
    for key in encodings:
        latent_orig = encodings[key]
        label_orig = sample_labels_dict[key]
        print('latent_orig: {}, label_orig: {}'.format(latent_orig, label_orig))
        samples = []

        # encode original on the first row
        sample = decode_deterministic(latent=latent_orig, labels=label_orig)
        for _ in interp_values:
            samples.append(sample)
        for zid in range(z_dim):
            for val in interp_values:
                latent = latent_orig
                latent[:, zid] = val
                set_z(latent, zid, val)
                sample = decode_deterministic(latent=latent, labels=label_orig)

                samples.append(sample)
                gifs.append(sample)
                    
        samples = torch.cat(samples, dim=0).cpu()
        samples = torchvision.utils.make_grid(samples, nrow=num_cols)
        
        file_name = os.path.join(".", '{}_{}.{}'.format("traverse", key, "png"))
        torchvision.utils.save_image(samples, file_name)
        
    total_rows = num_labels * l_dim + \
                 z_dim * int(traverse_z) + \
                 num_labels * int(traverse_c)
    gifs = torch.cat(gifs)
    gifs = gifs.view(len(encodings), total_rows, num_cols,
                     num_channels, image_size, image_size).transpose(1, 2)
    for i, key in enumerate(encodings.keys()):
        for j, val in enumerate(interp_values):
            file_name = \
                os.path.join('.', '{}_{}_{}.{}'.format('tmp', key, str(j).zfill(2), '.png'))
            torchvision.utils.save_image(tensor=gifs[i][j].cpu(),
                                         filename=file_name,
                                         nrow=total_rows, pad_value=1)
            
        file_name = os.path.join('.', '{}_{}.{}'.format('traverse', key, 'gif'))

        grid2gif(str(os.path.join('.', '{}_{}*.{}').format('tmp', key, 'png')),
                 file_name, delay=10)

        # Delete temp image files
        for j, val in enumerate(interp_values):
            os.remove(
                os.path.join('.', '{}_{}_{}.{}'.format('tmp', key, str(j).zfill(2), '.png')))
    return samples

In [426]:
min_ = -3
max_ = 3
spacing_ = 0.1
samples = visualize_traverse(limit=(min_,max_), spacing=spacing_)

latent_orig: tensor([[-0.0428, -0.0960,  0.0179, -1.3545,  1.4628,  1.5551,  0.7217, -0.6405]],
       grad_fn=<SliceBackward>), label_orig: tensor([0])
latent_orig: tensor([[ 0.0060, -0.0571,  0.0738,  1.2968,  0.0695, -0.4443, -0.0866, -1.1218]],
       grad_fn=<SliceBackward>), label_orig: tensor([0])
latent_orig: tensor([[-0.0310, -0.0670,  0.0620, -0.7422, -1.1748, -0.1942,  1.5837, -1.1176]],
       grad_fn=<SliceBackward>), label_orig: tensor([0])
