In [9]:
import math

import torch

In [10]:
class VADE:
    """Gaussian-mixture prior and loss terms for the VADE model.

    The mixture is described by three tensors created in ``__init__``:

    * ``theta_p``  -- mixing weights, shape ``(n_centroid,)``, uniform.
    * ``u_p``      -- component means, shape ``(latent_dim, n_centroid)``, zeros.
    * ``lambda_p`` -- component variances, shape ``(latent_dim, n_centroid)``, ones.
    """

    def __init__(
        self,
        batch_size=3,
        latent_dim=2,
        n_centroid=6,
        datatype="sigmoid",
        alpha=1.0,
        original_dim=None,
    ):
        """Create a uniform, zero-mean, unit-variance mixture prior.

        Args:
            batch_size: Stored on the instance; the methods below take the
                batch size from their inputs instead.
            latent_dim: Dimensionality of the latent vectors ``z``.
            n_centroid: Number of mixture components.
            datatype: ``'sigmoid'`` makes ``vae_loss`` use binary cross-entropy
                for the reconstruction term; any other value selects MSE.
            alpha: Weight of the reconstruction term in ``vae_loss``.
            original_dim: Multiplier of the reconstruction term in ``vae_loss``.
                ``None`` means "number of elements per sample of ``x``".

        NOTE(review): the defaults of ``datatype``, ``alpha`` and
        ``original_dim`` are assumptions -- confirm against the training setup.
        """
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.n_centroid = n_centroid
        self.datatype = datatype
        self.alpha = alpha
        self.original_dim = original_dim
        self.theta_p = torch.ones(self.n_centroid) / self.n_centroid
        self.u_p = torch.zeros((self.latent_dim, self.n_centroid))
        self.lambda_p = torch.ones((self.latent_dim, self.n_centroid))

    def _expand_over_centroids(self, t: torch.Tensor):
        """Expand a (batch, dim) tensor to (batch, dim, n_centroid)."""
        return t.unsqueeze(2).expand(t.shape[0], t.shape[1], self.n_centroid)

    def _expand_prior(self, batch: int, dim: int):
        """Return (u_p, lambda_p, theta_p), each expanded to (batch, dim, n_centroid)."""
        u_p_temp = self.u_p.unsqueeze(0).expand(batch, -1, -1)
        lambda_p_temp = self.lambda_p.unsqueeze(0).expand(batch, -1, -1)
        theta_p_temp = self.theta_p.reshape(1, 1, self.theta_p.shape[0]).expand(
            batch, dim, self.n_centroid
        )
        return u_p_temp, lambda_p_temp, theta_p_temp

    def get_gamma(self, temp_z: torch.Tensor):
        """Return the per-component responsibilities for each latent vector.

        Args:
            temp_z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, n_centroid)`` whose rows sum to 1.
        """
        z_temp = self._expand_over_centroids(temp_z)
        u_p_temp, lambda_p_temp, theta_p_temp = self._expand_prior(
            temp_z.shape[0], temp_z.shape[1]
        )

        # Per-dimension log terms; log(theta) is included once per latent
        # dimension, exactly as the expression was originally written.
        log_terms = (
            torch.log(theta_p_temp)
            - 0.5 * torch.log(2 * torch.pi * lambda_p_temp)
            - torch.square(z_temp - u_p_temp) / (2 * lambda_p_temp)
        )
        # The 1e-10 keeps every entry strictly positive, so the normalisation
        # below never divides by zero and log(gamma) stays finite.
        temp_p_c_z = torch.exp(torch.sum(log_terms, dim=1)) + 1e-10
        return temp_p_c_z / torch.sum(temp_p_c_z, dim=-1, keepdim=True)

    # TODO: finish it someday
    def vade_loss(self, z: torch.Tensor, x: torch.Tensor, x_decoded: torch.Tensor):
        """Unfinished placeholder: builds the broadcast operands and returns 0.

        ``x`` and ``x_decoded`` are not used yet.
        """
        # Operands the finished loss is expected to need; unused for now.
        z_temp = self._expand_over_centroids(z)
        u_p_temp, lambda_p_temp, theta_p_temp = self._expand_prior(
            z.shape[0], z.shape[1]
        )
        return 0

    def vae_loss(self, x, x_decoded_mean, z_mean, z_log_var, z=None):
        """Return the per-sample loss (reconstruction + mixture-prior terms).

        Args:
            x: Input batch; first dimension is the batch.
            x_decoded_mean: Reconstruction of ``x``, same shape as ``x``.
            z_mean: Encoder means, shape ``(batch, latent_dim)``.
            z_log_var: Encoder log-variances, shape ``(batch, latent_dim)``.
            z: Latent sample used to compute the responsibilities. When omitted
                it is drawn as ``z_mean + exp(z_log_var / 2) * eps`` with
                ``eps ~ N(0, I)``, which makes the result stochastic.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        if z is None:
            z = z_mean + torch.exp(0.5 * z_log_var) * torch.randn_like(z_mean)

        batch, dim = z_mean.shape[0], z_mean.shape[1]
        z_mean_t = self._expand_over_centroids(z_mean)
        z_log_var_t = self._expand_over_centroids(z_log_var)
        u_p_temp, lambda_p_temp, _ = self._expand_prior(batch, dim)

        gamma = self.get_gamma(z)  # (batch, n_centroid)
        gamma_t = gamma.unsqueeze(1).expand(-1, dim, -1)  # (batch, dim, n_centroid)

        if self.datatype == 'sigmoid':
            func = torch.nn.functional.binary_cross_entropy
        else:
            func = torch.nn.functional.mse_loss

        # torch's functional losses take (input=prediction, target). Reduce per
        # sample so that the reconstruction term lines up with the other
        # per-sample terms instead of being a batch-wide scalar.
        per_element = func(x_decoded_mean, x, reduction="none")
        per_element = per_element.reshape(per_element.shape[0], -1)
        original_dim = (
            self.original_dim if self.original_dim is not None else per_element.shape[1]
        )
        reconstruction = self.alpha * original_dim * per_element.mean(dim=1)

        prior_term = torch.sum(
            0.5
            * gamma_t
            * (
                self.latent_dim * math.log(2 * math.pi)
                + torch.log(lambda_p_temp)
                + torch.exp(z_log_var_t) / lambda_p_temp
                + torch.square(z_mean_t - u_p_temp) / lambda_p_temp
            ),
            dim=(1, 2),
        )
        entropy_term = 0.5 * torch.sum(z_log_var + 1, dim=-1)
        # theta_p broadcasts over the batch dimension of gamma.
        weight_term = torch.sum(torch.log(self.theta_p).unsqueeze(0) * gamma, dim=-1)
        assignment_term = torch.sum(torch.log(gamma) * gamma, dim=-1)

        loss = reconstruction + prior_term - entropy_term - weight_term + assignment_term
        return loss

In [11]:
# Three 2-D integer points used to exercise get_gamma.
test_tensor = torch.tensor(
    [
        [1, 2],
        [4, 5],
        [7, 8],
    ]
)
test_tensor

tensor([[1, 2],
        [4, 5],
        [7, 8]])

In [12]:
# Build a VADE with its default sizes and display the responsibilities that
# get_gamma assigns to the three test points.
vade = VADE()
vade.get_gamma(test_tensor)

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])

# Code below wasn't changed

In [13]:
# Sizes shared by the standalone functions below.
batch_size, latent_dim, n_centroid = 3, 2, 6

In [14]:
def gmmpara_init(latent_dim: int, n_centroid: int):
    """Create the initial Gaussian-mixture parameters.

    Args:
        latent_dim: Number of latent dimensions.
        n_centroid: Number of mixture components.

    Returns:
        A tuple ``(theta, u, lambda)``: uniform weights of shape
        ``(n_centroid,)``, zeros of shape ``(latent_dim, n_centroid)`` and
        ones of shape ``(latent_dim, n_centroid)``.
    """
    component_shape = (latent_dim, n_centroid)

    uniform_weights = torch.ones(n_centroid).div(n_centroid)
    zero_means = torch.zeros(component_shape)
    unit_variances = torch.ones(component_shape)

    return uniform_weights, zero_means, unit_variances

In [15]:
def get_gamma(temp_z: torch.Tensor, u_p: torch.Tensor, lambda_p: torch.Tensor, theta_p: torch.Tensor, n_centroid: int):
    """Return the per-component responsibilities for each row of ``temp_z``.

    Args:
        temp_z: Tensor of shape ``(batch, latent_dim)``.
        u_p: Component means, shape ``(latent_dim, n_centroid)``.
        lambda_p: Component variances, shape ``(latent_dim, n_centroid)``.
        theta_p: Mixing weights, shape ``(n_centroid,)``.
        n_centroid: Number of mixture components.

    Returns:
        Tensor of shape ``(batch, n_centroid)`` whose rows sum to 1.
    """
    n_samples, n_dims = temp_z.shape[0], temp_z.shape[1]

    # Bring every operand to (batch, latent_dim, n_centroid).
    samples = temp_z.unsqueeze(2).expand(n_samples, n_dims, n_centroid)
    means = u_p.unsqueeze(0).expand(n_samples, -1, -1)
    variances = lambda_p.unsqueeze(0).expand(n_samples, -1, -1)
    weights = theta_p.reshape(1, 1, theta_p.shape[0]) * torch.ones(n_samples, n_dims, n_centroid)

    log_weight = torch.log(weights)
    log_norm = 0.5 * torch.log(2 * torch.pi * variances)
    scaled_sq_dist = torch.square(samples - means) / (2 * variances)

    # Sum the log terms over the latent dimension, then add a small constant
    # so the normalisation below never divides by zero.
    unnormalised = torch.exp(torch.sum(log_weight - log_norm - scaled_sq_dist, dim=1)) + 1e-10
    return unnormalised / torch.sum(unnormalised, dim=-1, keepdim=True)

In [16]:
#TODO: finish it someday
def vade_loss(z: torch.Tensor, x: torch.Tensor, x_decoded: torch.Tensor, u_p: torch.Tensor, lambda_p: torch.Tensor, theta_p: torch.Tensor, n_centroid: int):
    """Unfinished loss: only builds the broadcast operands and returns None.

    ``x`` and ``x_decoded`` are not used yet.
    """
    n_samples, n_dims = z.shape[0], z.shape[1]
    full_shape = (n_samples, n_dims, n_centroid)

    # Every operand is brought to (batch, latent_dim, n_centroid).
    z_temp = z.unsqueeze(2).expand(*full_shape)
    u_p_temp = u_p.unsqueeze(0).expand(n_samples, -1, -1)
    lambda_p_temp = lambda_p.unsqueeze(0).expand(n_samples, -1, -1)
    theta_p_temp = theta_p.reshape(1, 1, theta_p.shape[0]) * torch.ones(full_shape)

In [17]:
# Initial mixture weights, means and variances for the standalone functions.
theta_init, u_init, lambda_init = gmmpara_init(
    latent_dim=latent_dim,
    n_centroid=n_centroid,
)

In [18]:
# Three 2-D integer points used to exercise get_gamma.
test_tensor = torch.tensor(
    [
        [1, 2],
        [4, 5],
        [7, 8],
    ]
)
test_tensor

tensor([[1, 2],
        [4, 5],
        [7, 8]])

In [19]:
# Responsibilities of the test points under the freshly initialised mixture.
get_gamma(
    test_tensor,
    u_p=u_init,
    lambda_p=lambda_init,
    theta_p=theta_init,
    n_centroid=n_centroid,
)

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])