In [396]:
import torch

import numpy as np
from math import gamma

def random_choice_full(input, n_samples, number_of_gausses):
        from torch import multinomial, ones
        if n_samples * number_of_gausses < input.shape[0]:
            replacement = False
        else:
            replacement = True
        idx = multinomial(ones(input.shape[0]), n_samples * number_of_gausses, replacement=replacement)
        sampled = input[idx].reshape(number_of_gausses, n_samples, -1)
        return torch.mean(sampled, axis=1)
#         return sampled

def provide_weights_for_x(x, how=None, device=None, n_samples=None, times=None):
    dim = x.shape[1]
    
    if n_samples is None:
        n_samples = dim
    if times is None:
        times = dim
        
    scale = (1 / dim)
    sampled_points = random_choice_full(x, n_samples, times)

    if how == "gauss":
        from torch.distributions import MultivariateNormal
        
        cov_mat = (scale * torch.eye(dim)).repeat(dim, 1, 1)
        mvn = MultivariateNormal(loc=sampled_points.to(device), covariance_matrix=cov_mat.to(device))
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
        
    elif how == "sqrt":
        weight_vector = torch.sqrt(1 + sampled_points.reshape(-1, 1, dim).to(device) ** 2) ** (-1)
        
    elif how == "log":
        weight_vector = torch.log(1 + sampled_points.reshape(-1, 1, dim).to(device)**2)
        
    elif how == "TStudent":
        from torch.distributions.studentT import StudentT
        
        mvn = StudentT(df=1, loc=x.mean(0), scale=scale)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
        # to trzeba poprawić ?!
    elif how == "Cauchy":
        from torch.distributions.cauchy import Cauchy
        
        mvn = Cauchy(loc=x.mean(0), scale=1)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
    elif how == "Gumbel":
        from torch.distributions.gumbel import Gumbel
        
        mvn = Gumbel(loc=x.mean(0), scale=1)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
    elif how == "Laplace":
        from torch.distributions.laplace import Laplace
        
        mvn = Laplace(loc=x.mean(0), scale=1)
        weight_vector = torch.exp(mvn.log_prob(x.reshape(-1, 1, dim).to(device)))
    return sampled_points, weight_vector
    
class WICA(object):
    def __init__(self):
        self.number_of_gausses = 10
        self.z_dim = 5
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def wica_loss(self, z, latent_normalization=False, how="gauss"):
        if latent_normalization:
            x = (z - z.mean(dim=1, keepdim=True)) / z.std(dim=1, keepdim=True)
        else:
            x = z
        dim = self.z_dim if self.z_dim is not None else x.shape[1]


        _, weight_vector = provide_weights_for_x(
                x=x, 
                how=how,
                device = self.device
        )
                        
        sum_of_weights = torch.sum(weight_vector, axis=0)
        
        weight_sum = torch.sum(x.reshape(1,x.shape[0], x.shape[1]) * weight_vector, axis=0)

        weight_mean = weight_sum / sum_of_weights

        xm = x - weight_mean
        wxm = torch.sum(xm.reshape(1,xm.shape[0], xm.shape[1]) * weight_vector, axis=0)

        wcov = (wxm.reshape(1 ,wxm.shape[0], wxm.shape[1]).permute(0, 2, 1).matmul(xm)) / sum_of_weights

        diag = torch.diagonal(wcov ** 2, dim1=1, dim2=2)
        diag_pow_plus = diag.reshape(diag.shape[0], diag.shape[1], -1) + diag.reshape(diag.shape[0], -1, diag.shape[1])

        tmp = (2 * wcov ** 2 / diag_pow_plus)

        triu = torch.triu(tmp, diagonal=1)
        normalize = 2.0 / (dim * (dim - 1))
        cost = torch.sum(normalize * triu) / self.number_of_gausses
        return cost

In [399]:
_, weight_vector = provide_weights_for_x(t, 'sqrt')
sum_of_weights = torch.sum(weight_vector, axis=0)

weight_sum = torch.sum(t.reshape(1,t.shape[0], t.shape[1])*weight_vector, axis=0)
weight_mean = weight_sum / sum_of_weights
xm = t-weight_mean
wxm = torch.sum(xm.reshape(1,xm.shape[0], xm.shape[1])*weight_vector, axis=0)

wcov = (wxm.reshape(1 ,wxm.shape[0], wxm.shape[1]).permute(0, 2, 1).matmul(xm)) / sum_of_weights
diag = torch.diagonal(wcov ** 2, dim1=1, dim2=2)

diag_pow_plus = diag.reshape(diag.shape[0], diag.shape[1], -1) + diag.reshape(diag.shape[0], -1, diag.shape[1])
tmp = (2 * wcov ** 2 / diag_pow_plus)

triu = torch.triu(tmp, diagonal=1)
normalize = 2.0 / (3 * (3 - 1))
cost = torch.sum(normalize * triu) / 5
cost

tensor(0.0011)

In [400]:
wica = WICA()
t = torch.sin(torch.randn(1200)).resize(300,4)
t

tensor([[ 0.1773, -0.6457, -0.8364, -0.9291],
        [ 0.2543, -0.0167, -0.9255,  0.7730],
        [ 0.9994,  0.2729,  0.0856,  0.6789],
        ...,
        [-0.5104, -0.9665,  0.0550,  0.8703],
        [-0.1197, -0.4765,  0.8096,  0.9986],
        [-0.9210, -0.2407,  0.9671, -0.8304]])

In [325]:
dim = t.shape[1]

sampled_points = random_choice_full(t, 1, 5)
weight_vector = torch.log(1 + sampled_points.reshape(-1, 1, dim)**2)
print("sampled_points: \n", sampled_points)
print("weight_vector: \n", weight_vector)
print("t: \n", t)
# własciwa funkcja
# weight_sum = t.T.matmul(weight_vector)


sampled_points: 
 tensor([[[-0.1189, -0.9492,  0.8771]],

        [[ 0.9893, -0.7897,  0.0857]],

        [[-0.2262, -0.6038, -0.9691]],

        [[-0.3576, -0.4485, -0.9810]],

        [[ 0.8463,  0.6288, -0.0036]]])
weight_vector: 
 tensor([[[1.4048e-02, 6.4240e-01, 5.7057e-01]],

        [[6.8247e-01, 4.8469e-01, 7.3247e-03]],

        [[4.9899e-02, 3.1083e-01, 6.6221e-01]],

        [[1.2031e-01, 1.8329e-01, 6.7419e-01]],

        [[5.4013e-01, 3.3318e-01, 1.3113e-05]]])
t: 
 tensor([[-0.6933,  0.3938,  0.8895],
        [-0.5661, -0.7051,  0.8036],
        [-0.2262, -0.6038, -0.9691],
        [ 0.3380,  0.9316, -0.2815],
        [ 0.4570,  0.1752,  0.9920],
        [-0.9608, -0.6589,  0.4055],
        [ 0.9893, -0.7897,  0.0857],
        [ 0.8710,  0.9984, -0.5808],
        [ 0.9694,  0.9705, -0.9842],
        [ 0.9922, -0.3587,  0.9572],
        [ 0.0845,  0.7502,  0.9995],
        [-0.9964,  0.6247, -0.3304],
        [ 0.7349, -0.7074,  0.8904],
        [-0.9408, -0.8618,  0.7585

In [408]:
for i in range(10):
    print(1000*wica.wica_loss(t, latent_normalization=True, how="sqrt"))

tensor(0.0920)
tensor(0.1662)
tensor(0.4097)
tensor(0.3237)
tensor(0.0496)
tensor(0.1337)
tensor(0.2578)
tensor(0.7654)
tensor(0.2746)
tensor(1.8804)


In [406]:
print(t.reshape(1,10,2)[0])
print(weight_vector)

t.reshape(1,10,2) * weight_vector
# t.reshape(1,5,2).matmul(weight_vector)

RuntimeError: shape '[1, 10, 2]' is invalid for input of size 1200

In [315]:
for i in range(10):
    print(wica.wica_loss(t, latent_normalization=True, how="sqrt"))

UnboundLocalError: local variable 'sampled_points' referenced before assignment

In [413]:
import numpy as np
import pandas as pd
import torch

import sys  
sys.path.insert(0, '/Users/andrzej/Personal/Projects/disentanglement-pytorch')
from models.ae import AEModel

In [425]:
from architectures import encoders, decoders

encoder_name = "SimpleConv64"
decoder_name = "SimpleConv64"

encoder = getattr(encoders, encoder_name)
decoder = getattr(decoders, decoder_name)

model = AEModel(encoder(8, 1, 64), decoder(8, 1, 64)).to(torch.device('cpu'))

checkpoint = torch.load('/Users/andrzej/Personal/results/sqrt-first-batch-wica/last', map_location=torch.device('cpu'))

model.load_state_dict(checkpoint['model_states']['G'])
model.eval()

AEModel(
  (encoder): SimpleConv64(
    (main): Sequential(
      (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Flatten3D()
      (13): Linear(in_features=256, out_features=8, bias=True)
    )
  )
  (decoder): SimpleConv64(
    (main): Sequential(
      (0): Unsqueeze3D()
      (1): Conv2d(8, 256, kernel_size=(1, 1), stride=(2, 2))
      (2): ReLU(inplace=True)
      (3): ConvTranspose2d(2

In [439]:
from common.data_loader import get_dataloader

train_loader = get_dataloader('dsprites_full', '/Users/andrzej/Personal/Projects/data/test_dsets', 3,
                              123, num_workers=1, pin_memory=True, image_size=64, 
                              include_labels=None, shuffle=True, droplast=False)

In [433]:
x, y = next(iter(train_loader))
a,b = model(x)

In [420]:
import torchvision.utils
def visualize_recon(input_image, recon_image):
        input_image = torchvision.utils.make_grid(input_image)
        recon_image = torchvision.utils.make_grid(recon_image)

        white_line = torch.ones((3, input_image.size(1), 10)).to('cpu')
        samples = torch.cat([input_image, white_line, recon_image], dim=2)

        torchvision.utils.save_image(samples, "test.png")


In [434]:
visualize_recon(x,a)

In [437]:
from common.utils import grid2gif, get_data_for_visualization, prepare_data_for_visualization
import os

z_dim = 8
l_dim = 0
traverse_z = True
traverse_c = False
num_labels = 0
image_size = 64
num_channels = train_loader.dataset.num_channels()

def set_z(z, latent_id, val):
    z[:, latent_id] = val

def encode_deterministic(**kwargs):
    images = kwargs['images']
    if len(images.size()) == 3:
        images = images.unsqueeze(0)
    return model.encode(images)

def decode_deterministic(**kwargs):
    latent = kwargs['latent']
    if len(latent.size()) == 1:
        latent = latent.unsqueeze(0)
    return model.decode(latent)

def visualize_traverse(limit: tuple, spacing, data=None, test=False):
    interp_values = torch.arange(limit[0], limit[1]+spacing, spacing)
    num_cols = interp_values.size(0)

    sample_images_dict, sample_labels_dict = prepare_data_for_visualization(next(iter(train_loader)))
    encodings = dict()

    for key in sample_images_dict.keys():
        encodings[key]  = encode_deterministic(images=sample_images_dict[key], labels=sample_labels_dict[key])

    gifs = []
    for key in encodings:
        latent_orig = encodings[key]
        label_orig = sample_labels_dict[key]
        print('latent_orig: {}, label_orig: {}'.format(latent_orig, label_orig))
        samples = []

        # encode original on the first row
        sample = decode_deterministic(latent=latent_orig, labels=label_orig)
        for _ in interp_values:
            samples.append(sample)
        for zid in range(z_dim):
            for val in interp_values:
                latent = latent_orig
                latent[:, zid] = val
                set_z(latent, zid, val)
                sample = decode_deterministic(latent=latent, labels=label_orig)

                samples.append(sample)
                gifs.append(sample)
                    
        samples = torch.cat(samples, dim=0).cpu()
        samples = torchvision.utils.make_grid(samples, nrow=num_cols)
        
        file_name = os.path.join(".", '{}_{}.{}'.format("traverse", key, "png"))
        torchvision.utils.save_image(samples, file_name)
        
    total_rows = num_labels * l_dim + \
                 z_dim * int(traverse_z) + \
                 num_labels * int(traverse_c)
    gifs = torch.cat(gifs)
    gifs = gifs.view(len(encodings), total_rows, num_cols,
                     num_channels, image_size, image_size).transpose(1, 2)
    for i, key in enumerate(encodings.keys()):
        for j, val in enumerate(interp_values):
            file_name = \
                os.path.join('.', '{}_{}_{}.{}'.format('tmp', key, str(j).zfill(2), '.png'))
            torchvision.utils.save_image(tensor=gifs[i][j].cpu(),
                                         filename=file_name,
                                         nrow=total_rows, pad_value=1)
            
        file_name = os.path.join('.', '{}_{}.{}'.format('traverse', key, 'gif'))

        grid2gif(str(os.path.join('.', '{}_{}*.{}').format('tmp', key, 'png')),
                 file_name, delay=10)

        # Delete temp image files
        for j, val in enumerate(interp_values):
            os.remove(
                os.path.join('.', '{}_{}_{}.{}'.format('tmp', key, str(j).zfill(2), '.png')))
    return samples

In [442]:
min_ = -50
max_ = 50
spacing_ = 5
samples = visualize_traverse(limit=(min_,max_), spacing=spacing_)

latent_orig: tensor([[-18.7481,   2.6798,  16.7035,  40.3148,  39.0097,  20.5979,  57.7464,
         -64.9600]], grad_fn=<AddmmBackward>), label_orig: tensor([0])
latent_orig: tensor([[-59.9391,  72.4069,  -4.9910,   9.8507,  15.5894, -51.6062, -55.8661,
           8.1131]], grad_fn=<AddmmBackward>), label_orig: tensor([0])
latent_orig: tensor([[ 47.1662,  -7.9580,  11.1962, -23.2260,  53.1356,  52.4977,  54.3462,
         -66.5915]], grad_fn=<AddmmBackward>), label_orig: tensor([0])
