<a href="https://colab.research.google.com/github/RyanBalshaw/VAEs-with-Conditioning/blob/main/VAEs_with_clustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

VaDE versus CURL:


---


The differences are minor, and it is easy to show that the two are essentially the same. Write this up at some point for reference.

VaDE objective function:
\begin{equation}
L_{VaDE} = \mathbb{E}_{q(\mathbf{z}\vert\mathbf{x})}\left[ \log p(\mathbf{x}\vert\mathbf{z}) \right] - \mathbb{E}_{q(\mathbf{y}\vert\mathbf{x})}\left[ KL(q(\mathbf{z}\vert\mathbf{x})\Vert p(\mathbf{z}\vert \mathbf{y}))\right] - KL(q(\mathbf{y}\vert\mathbf{x})\Vert p(\mathbf{y}))
\end{equation}

CURL objective function:
\begin{equation}
L_{CURL} = \mathbb{E}_{q(\mathbf{z}\vert\mathbf{x}, \mathbf{y})q(\mathbf{y}\vert\mathbf{x})}\left[ \log p(\mathbf{x}\vert\mathbf{z}) \right] - \mathbb{E}_{q(\mathbf{y}\vert\mathbf{x})}\left[ KL(q(\mathbf{z}\vert\mathbf{x}, \mathbf{y})\Vert p(\mathbf{z}\vert \mathbf{y}))\right] - KL(q(\mathbf{y}\vert\mathbf{x})\Vert p(\mathbf{y}))
\end{equation}

Important notes:
- Decoder: $p(\mathbf{x}\vert\mathbf{z}) \sim \mathcal{N}(\mathbf{x}\vert \mathbf{\mu}(\mathbf{z}), \mathbf{\sigma}^2(\mathbf{z})\mathbf{I})$
- Encoder: $q_{VaDE}(\mathbf{z}\vert \mathbf{x}) \sim \mathcal{N}(\mathbf{z}\vert \mathbf{\mu}_{\mathbf{z}}(\mathbf{x}), \mathbf{\sigma}_{\mathbf{z}}^2(\mathbf{x})\mathbf{I})$ OR $q_{CURL}(\mathbf{z}\vert \mathbf{x}, \mathbf{y}) \sim \mathcal{N}(\mathbf{z}\vert \mathbf{\mu}_{\mathbf{z}}(\mathbf{x}, \mathbf{y}), \mathbf{\sigma}_{\mathbf{z}}^2(\mathbf{x}, \mathbf{y})\mathbf{I})$
- Prior $p(\mathbf{z}\vert\mathbf{y}) \sim \mathcal{N}(\mathbf{z}\vert \mathbf{\mu}_{\mathbf{z}}(\mathbf{y}), \mathbf{\sigma}_{\mathbf{z}}^2(\mathbf{y})\mathbf{I})$

- Prior $p(\mathbf{y}) \sim Cat(\mathbf{\pi}) = \prod_{k=1}^{K}\pi_k^{y_k}$

The final component is to define $q(\mathbf{y}\vert\mathbf{x})$. VaDE and CURL take vastly different approaches:

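For VaDE:
$q(\mathbf{y}=k\vert\mathbf{x}) = p(\mathbf{y}=k\vert\mathbf{z}) = \frac{p(\mathbf{y}=k)p(\mathbf{z}\vert\mathbf{y}=k)}{\sum_{i=1}^K p(\mathbf{y}=i)p(\mathbf{z}\vert\mathbf{y}=i)} = \frac{\pi_k\, p(\mathbf{z}\vert\mathbf{y}=k)}{\sum_{i=1}^K \pi_i\, p(\mathbf{z}\vert\mathbf{y}=i)},$

where this expression also appears in mixture models (it is the responsibility of component $k$ in a Gaussian mixture model). It is the posterior probability of $\mathbf{y}$ given an observation, here the latent sample $\mathbf{z}$ drawn from $q(\mathbf{z}\vert\mathbf{x})$.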
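The VaDE assignment above can be sketched in PyTorch as follows. This is a minimal sketch assuming a diagonal-Gaussian $p(\mathbf{z}\vert\mathbf{y}=k)$ with per-component means `mu_c` and variances `var_c` (hypothetical names, not taken from the VaDE code). It evaluates the same Bayes-rule ratio in log-space and normalises with a softmax, which avoids underflow when the Gaussian densities are tiny.

```python
import math
import torch

def vade_responsibilities(z, pi, mu_c, var_c):
    """q(y=k|x) = p(y=k|z) for a mixture prior with diagonal-Gaussian components.

    z:     (B, D) latent samples drawn from q(z|x)
    pi:    (K,)   mixture weights p(y=k)
    mu_c:  (K, D) means of p(z|y=k)
    var_c: (K, D) variances of p(z|y=k)
    Returns a (B, K) tensor whose rows sum to one.
    """
    diff = z.unsqueeze(1) - mu_c  # (B, K, D) via broadcasting
    # log N(z | mu_k, diag(var_k)), summed over the D latent dimensions -> (B, K)
    log_pdf = -0.5 * (math.log(2 * math.pi) + torch.log(var_c) + diff ** 2 / var_c).sum(dim=2)
    # Bayes' rule in log-space: softmax over k of log pi_k + log p(z|y=k)
    return torch.softmax(torch.log(pi) + log_pdf, dim=1)

# Two well-separated components; each z should be assigned to the component it sits on
z = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
pi = torch.tensor([0.5, 0.5])
mu_c = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
var_c = torch.ones(2, 2)
print(vade_responsibilities(z, pi, mu_c, var_c).argmax(dim=1))  # tensor([0, 1])
```

Working with log-densities rather than raw densities is a design choice: in higher-dimensional latent spaces $p(\mathbf{z}\vert\mathbf{y}=k)$ can underflow to zero for every $k$, which would make the plain ratio undefined.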

For CURL:
$q(\mathbf{y}\vert\mathbf{x})$ is part of the encoder, with a softmax 'task inference' head. I like this formulation a little less as I am not convinced that it works well.

Why do I say this? Well, I noted one potentially problematic area in how CURL estimates the 'categorical regulariser'. In their code, they take a batch and average the argmax of the labels $\mathbf{y}$, essentially estimating the batch class likelihood. The problem is that a batch is not guaranteed to contain equal numbers of samples from each 'hidden class', so I am not sure how useful this estimate is when the classes are unevenly represented. Perhaps some inference of the prior $p(\mathbf{y})$ needs to be performed before training.
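To make the concern concrete, here is a sketch of a batch class-frequency estimate as I read the description above: take the argmax of $q(\mathbf{y}\vert\mathbf{x})$, convert it to a 1-of-K vector, and average over the batch. The function name is hypothetical and this is not CURL's actual code; the toy batch shows how the estimate simply mirrors whatever class mix the batch happens to contain.

```python
import torch
import torch.nn.functional as F

def batch_class_frequencies(q_y_given_x):
    """Fraction of the batch assigned to each class by argmax of q(y|x).

    q_y_given_x: (B, K) class probabilities for one batch
    Returns a (K,) tensor of empirical class frequencies that sums to one.
    """
    K = q_y_given_x.size(1)
    hard = F.one_hot(q_y_given_x.argmax(dim=1), num_classes=K).float()  # (B, K) 1-of-K labels
    return hard.mean(dim=0)

# A batch of 4 samples: three assigned to class 0, one to class 2, none to class 1
q = torch.tensor([[0.9, 0.05, 0.05],
                  [0.8, 0.1, 0.1],
                  [0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
print(batch_class_frequencies(q))  # tensor([0.7500, 0.0000, 0.2500])
```

Class 1 receives zero mass here even if it is common in the full dataset, which is exactly the imbalance issue raised above.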

For the continuous distribution $q(\mathbf{z}\vert\cdots)$ it is natural to take a Monte Carlo estimate. For the categorical distribution, however, any expectation over $q(\mathbf{y}\vert\mathbf{x})$ can be evaluated exactly as a sum over the $K$ classes of the term inside the expectation, weighted by $q(\mathbf{y}=k\vert\mathbf{x})$:
$\mathbb{E}_{q(\mathbf{y}\vert\mathbf{x})}\left[f(\mathbf{y})\right] = \sum_{k=1}^{K} q(\mathbf{y}=k\vert\mathbf{x})\, f(\mathbf{y}=k)$.
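The exact sum over the $K$ classes can be computed by expanding each sample $K$ times (sample-major, e.g. with `torch.repeat_interleave`), evaluating the term for every (sample, class) pair, and reshaping to `(B, K)` before weighting by $q(\mathbf{y}\vert\mathbf{x})$. This is the same reshape-and-weight pattern used in `MoG_VAE_loss` below. A minimal sketch, with a hypothetical helper name and a toy per-class term:

```python
import torch

def categorical_expectation(per_class_term, q_y_given_x):
    """Exact E_{q(y|x)}[f(x, y)] = sum_k q(y=k|x) f(x, k), one value per sample.

    per_class_term: (B*K,) f evaluated on each sample expanded K times, sample-major:
                    [f(x_0, 0), ..., f(x_0, K-1), f(x_1, 0), ...]
    q_y_given_x:    (B, K) class probabilities
    """
    B, K = q_y_given_x.shape
    return (per_class_term.reshape(B, K) * q_y_given_x).sum(dim=1)

# Toy example: B = 2 samples, K = 3 classes, f(x, k) = k * sum(x)
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
K = 3
x_expanded = torch.repeat_interleave(x, K, dim=0)  # rows: x_0, x_0, x_0, x_1, x_1, x_1
y_expanded = torch.arange(K).repeat(x.size(0))     # labels: 0, 1, 2, 0, 1, 2
f = x_expanded.sum(dim=1) * y_expanded             # [0, 3, 6, 0, 7, 14]
q = torch.tensor([[0.2, 0.3, 0.5], [0.25, 0.25, 0.5]])
print(categorical_expectation(f, q))  # 0.3*3 + 0.5*6 = 3.9 and 0.25*7 + 0.5*14 = 8.75
```

The ordering matters: `repeat_interleave` on the data must be paired with `arange(K).repeat(B)` on the labels, otherwise `reshape(B, K)` mixes up samples and classes.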

Furthermore, for the middle KL divergence term both distributions are Gaussian, so the KL divergence has the closed form
$KL(\mathcal{N}_0\Vert\mathcal{N}_1) = \frac{1}{2}\left( \operatorname{tr}(\Sigma_1^{-1}\Sigma_0) + (\mathbf{\mu}_1 - \mathbf{\mu}_0)^T\Sigma_1^{-1}(\mathbf{\mu}_1 - \mathbf{\mu}_0) - k + \log\left(\frac{\det\Sigma_1}{\det\Sigma_0}\right) \right)$,

where $\mathcal{N}_0\sim\mathcal{N}(\mathbf{\mu}_0, \Sigma_0)$, $\mathcal{N}_1\sim\mathcal{N}(\mathbf{\mu}_1, \Sigma_1)$ and $k$ is the dimensionality of the space covered by the distribution. If $\Sigma_0$ and $\Sigma_1$ are parametrised as diagonal covariance matrices, $\Sigma_0 = \operatorname{diag}(\mathbf{\sigma}_0^2)$ and $\Sigma_1 = \operatorname{diag}(\mathbf{\sigma}_1^2)$, then

$KL(\mathcal{N}_0\Vert\mathcal{N}_1) = \frac{1}{2}\left( \sum_{i} \frac{\sigma^2_{0,i}}{\sigma^2_{1,i}} + \sum_i \frac{(\mu_{1,i} - \mu_{0,i})^2}{\sigma^2_{1,i}} - k + \sum_i \log\left(\frac{\sigma^2_{1, i}}{\sigma^2_{0, i}}\right) \right)$,
$KL(\mathcal{N}_0\Vert\mathcal{N}_1) = \frac{1}{2}\sum_{i}\left( \frac{\sigma^2_{0,i}}{\sigma^2_{1,i}} + \frac{(\mu_{1,i} - \mu_{0,i})^2}{\sigma^2_{1,i}} - 1 + \log\left(\frac{\sigma^2_{1, i}}{\sigma^2_{0, i}}\right) \right)$.
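The diagonal form can be checked numerically against PyTorch's built-in Normal-Normal KL (`torch.distributions.kl_divergence`), which computes the same quantity per dimension. A sketch, with `diag_gaussian_kl` as an illustrative helper that mirrors the elementwise formula used in `KL_divergence` below:

```python
import torch
from torch.distributions import Normal, kl_divergence

def diag_gaussian_kl(mu_0, var_0, mu_1, var_1):
    """KL(N(mu_0, diag(var_0)) || N(mu_1, diag(var_1))), summed over the last dimension."""
    per_dim = var_0 / var_1 + (mu_1 - mu_0) ** 2 / var_1 - 1 + torch.log(var_1 / var_0)
    return 0.5 * per_dim.sum(dim=-1)

torch.manual_seed(0)
mu_0, mu_1 = torch.randn(4, 3), torch.randn(4, 3)
var_0, var_1 = torch.rand(4, 3) + 0.5, torch.rand(4, 3) + 0.5  # keep variances positive

ours = diag_gaussian_kl(mu_0, var_0, mu_1, var_1)
# torch.distributions.Normal is parametrised by the standard deviation, hence sqrt()
reference = kl_divergence(Normal(mu_0, var_0.sqrt()), Normal(mu_1, var_1.sqrt())).sum(dim=-1)
print(torch.allclose(ours, reference, atol=1e-5))  # True
```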

Finally, the last KL divergence term can be expanded as follows:
$KL(q(\mathbf{y}\vert\mathbf{x})\Vert p(\mathbf{y})) = \sum_{k=1}^K q(\mathbf{y}=k\vert\mathbf{x}) \log \left( \frac{q(\mathbf{y}=k\vert\mathbf{x})}{p(\mathbf{y}=k)} \right)$


Let's now focus on the final KL term $KL(q(\mathbf{y}\vert\mathbf{x})\Vert p(\mathbf{y}))$. Since we know that this regularises the posterior conditional probability (i.e. the conditional probability given a sample from $\mathbf{x}$, which is actually a $\mathbf{z}$ if you think about where the MoG lies), we need a method to evaluate the KL term. The expansion of the term is straightforward:

$KL(q(\mathbf{y}\vert\mathbf{x})\Vert p(\mathbf{y})) = \mathbb{E}_{q(\mathbf{y}\vert\mathbf{x})}\left[\log \frac{q(\mathbf{y}\vert\mathbf{x})}{p(\mathbf{y})} \right]$

and, since $\mathbf{y}$ is discrete,

$KL(q(\mathbf{y}\vert\mathbf{x})\Vert p(\mathbf{y})) = \sum_{k=1}^{K}q(\mathbf{y}=k\vert\mathbf{x})\left( \log \frac{q(\mathbf{y}=k\vert\mathbf{x})}{p(\mathbf{y}=k)} \right) = \sum_{k=1}^{K}q(\mathbf{y}=k\vert\mathbf{x})\left( \log \frac{q(\mathbf{y}=k\vert\mathbf{x})}{\pi_k} \right)$.
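A sketch of this discrete KL for a general (not necessarily uniform) prior $\mathbf{\pi}$. The helper name is hypothetical, and the small `eps` inside the log is an assumption to keep $0\log 0$ terms finite (they then contribute zero):

```python
import torch

def categorical_kl(q_y_given_x, pi, eps=1e-12):
    """Per-sample KL(q(y|x) || Cat(pi)) = sum_k q_k (log q_k - log pi_k).

    q_y_given_x: (B, K) class probabilities
    pi:          (K,)   prior class probabilities (all strictly positive)
    eps keeps log(0) finite, so classes with q_k = 0 contribute 0 to the sum.
    """
    return (q_y_given_x * (torch.log(q_y_given_x + eps) - torch.log(pi))).sum(dim=1)

pi = torch.full((3,), 1.0 / 3)        # uniform prior over K = 3 classes
q = torch.tensor([[1/3, 1/3, 1/3],    # matches the prior -> KL = 0
                  [1.0, 0.0, 0.0]])   # fully confident -> KL = log 3
print(categorical_kl(q, pi))
```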

So, what do CURL and VaDE do? VaDE estimates $p(\mathbf{y}=k\vert\mathbf{x})$ separately for each $\mathbf{x}$, while CURL uses a batch-estimated posterior and effectively parametrises $p(\mathbf{y}\vert\mathbf{x}) = \prod_{k=1}^{K} \gamma_{k}^{y_k}$ directly, where $\gamma_{k}$ is the batch-estimated class likelihood.

Personally, I like what VaDE does more, but it is something to test out.

I just realised there is an alternative derivation, one that allows a cross-entropy term to be used. Let's dissect the KL term further: since $\pi_k$ is a constant scalar, the terms separate nicely:

$KL(q(\mathbf{y}\vert\mathbf{x})\Vert p(\mathbf{y})) = \sum_{k=1}^{K}q(\mathbf{y}=k\vert\mathbf{x}) \log q(\mathbf{y}=k\vert\mathbf{x}) - \sum_{k=1}^{K}q(\mathbf{y}=k\vert\mathbf{x})\log \pi_k$

where the term on the right can be grouped with the $q(\mathbf{y}\vert\mathbf{x})$ terms in the previous objective functions. This leaves us with $\sum_{k=1}^{K}q(\mathbf{y}=k\vert\mathbf{x}) \log q(\mathbf{y}=k\vert\mathbf{x})$. What do we do with this? If we think about it a little, this is simply the negative entropy $-\mathcal{H}(q(\mathbf{y}\vert\mathbf{x}))$, and the entropy is the cross-entropy of $q(\mathbf{y}\vert\mathbf{x})$ with itself. The difference from a standard cross-entropy loss is that we have no labels: the 'targets' are the predicted probabilities themselves, which, although they should ideally be 1-of-K, will not be. It is therefore better to leave it as the entropy. I believe TensorFlow's `softmax_cross_entropy_with_logits`, which accepts soft (non-one-hot) labels, is better suited.
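The decomposition can be verified numerically. The sketch below (hypothetical helper, assuming $q(\mathbf{y}\vert\mathbf{x})$ comes from a softmax head with `logits`) computes the negative entropy and the prior term separately, and also evaluates $\mathcal{H}(q)$ as a cross-entropy with soft targets via `F.cross_entropy`, the PyTorch analogue of passing soft labels to TensorFlow's `softmax_cross_entropy_with_logits`. Probability targets in `F.cross_entropy` require PyTorch 1.10 or later.

```python
import torch
import torch.nn.functional as F

def kl_decomposition(logits, pi):
    """Split KL(q(y|x) || Cat(pi)) into -H(q) and -sum_k q_k log pi_k.

    logits: (B, K) unnormalised scores; q(y|x) = softmax(logits)
    pi:     (K,)   prior class probabilities
    Returns (neg_entropy, prior_term, entropy_as_ce), each of shape (B,).
    """
    log_q = F.log_softmax(logits, dim=1)
    q = log_q.exp()
    neg_entropy = (q * log_q).sum(dim=1)          # sum_k q_k log q_k = -H(q)
    prior_term = -(q * torch.log(pi)).sum(dim=1)  # equals log K when pi is uniform
    # H(q) as the cross-entropy of q with itself (soft targets, PyTorch >= 1.10)
    entropy_as_ce = F.cross_entropy(logits, q, reduction="none")
    return neg_entropy, prior_term, entropy_as_ce

torch.manual_seed(0)
logits = torch.randn(5, 4)
pi = torch.tensor([0.1, 0.2, 0.3, 0.4])
neg_H, prior_term, H_ce = kl_decomposition(logits, pi)
kl = neg_H + prior_term                         # KL(q(y|x) || p(y)) per sample
print(torch.allclose(H_ce, -neg_H, atol=1e-5))  # True
```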

At the end of the day, I think simply taking the loss at face value is the way to proceed. Monitoring $\mathcal{H}(q(\mathbf{y}\vert\mathbf{x}))$ will be useful, as it tells us whether any information is encoded into $\mathbf{y}$ or whether the model simply leaves it as is.
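A possible monitoring quantity is the batch-averaged entropy of $q(\mathbf{y}\vert\mathbf{x})$, normalised by its maximum value $\log K$. This is a sketch with a hypothetical helper name, not something from the CURL or VaDE code:

```python
import math
import torch

def normalised_entropy(q_y_given_x, eps=1e-12):
    """Batch mean of H(q(y|x)) divided by its maximum, log K.

    About 1: q(y|x) is close to uniform for every sample.
    About 0: every sample is assigned confidently to a single class.
    """
    K = q_y_given_x.size(1)
    H = -(q_y_given_x * torch.log(q_y_given_x + eps)).sum(dim=1)  # per-sample entropy (nats)
    return (H.mean() / math.log(K)).item()

print(normalised_entropy(torch.full((8, 4), 0.25)))  # uniform q(y|x): close to 1
print(normalised_entropy(torch.eye(4)))              # one-hot q(y|x): close to 0
```

A value that stays near 1 during training suggests $\mathbf{y}$ carries little information. A low value alone does not rule out every sample being assigned to the same class, so it is worth reading alongside the batch class frequencies.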





In [None]:
def apply_MLP_to_source(source,
                        num_layer,
                        num_segment = None,
                        iter4condthresh = 10000,
                        cond_thresh_ratio = 0.25,
                        layer_name_base = 'ip',
                        save_layer_data = False,
                        Arange=None,
                        nonlinear_type = 'ReLU',
                        negative_slope = 0.2,
                        random_seed=0):
    """Generate MLP and Apply it to source signal.
    Args:
        source: source signals. 2D ndarray [num_comp, num_data]
        num_layer: number of layers
        num_segment: (option) number of segments (only used to modulate random_seed)
        iter4condthresh: (option) number of random iterations used to decide the threshold on the condition number of the mixing matrices
        cond_thresh_ratio: (option) percentile of condition number to decide its threshold
        layer_name_base: (option) layer name
        save_layer_data: (option) if true, save activities of all layers
        Arange: (option) range of value of mixing matrices
        nonlinear_type: (option) type of nonlinearity
        negative_slope: (option) parameter of leaky-ReLU
        random_seed: (option) random seed
    Returns:
        mixedsig: sensor signals. 2D ndarray [num_comp, num_data]
        mixlayer: parameters of mixing layers
    """
    if Arange is None:
        Arange = [-1, 1]
    #print("Generating sensor signal...")

    # Subfunction to normalize mixing matrix
    def l2normalize(Amat, axis=0):
        # axis: 0=column-normalization, 1=row-normalization
        l2norm = np.sqrt(np.sum(Amat*Amat,axis))
        Amat = Amat / l2norm
        return Amat

    # Initialize random generator
    np.random.seed(random_seed)

    # To change random_seed based on num_layer and num_segment
    for i in range(num_layer):
        np.random.rand()

    if num_segment is not None:
        for i in range(num_segment):
            np.random.rand()

    num_comp = source.shape[0]

    # Determine condThresh ------------------------------------
    condList = np.zeros([iter4condthresh])
    
    for i in range(iter4condthresh):
        A = np.random.uniform(Arange[0],Arange[1],[num_comp,num_comp])
        A = l2normalize(A, axis=0)
        condList[i] = np.linalg.cond(A)

    condList.sort() # Ascending order
    condThresh = condList[int(iter4condthresh * cond_thresh_ratio)]
    #print("    cond thresh: {0:f}".format(condThresh))

    # Generate mixed signal -----------------------------------
    mixedsig = source.copy()
    mixlayer = []
    for ln in range(num_layer-1,-1,-1):

        # Apply nonlinearity ----------------------------------
        if ln < num_layer-1: # No nonlinearity for the first layer (source signal)
            if nonlinear_type == "ReLU": # Leaky-ReLU
                mixedsig[mixedsig<0] = negative_slope * mixedsig[mixedsig<0]
            else:
                raise ValueError("Unknown nonlinear_type: {}".format(nonlinear_type))

        # Generate mixing matrix ------------------------------
        condA = condThresh + 1
        while condA > condThresh:
            A = np.random.uniform(Arange[0], Arange[1], [num_comp, num_comp])
            A = l2normalize(A)  # Normalize (column)
            condA = np.linalg.cond(A)
            #print("    L{0:d}: cond={1:f}".format(ln, condA))
        # Bias
        b = np.zeros([num_comp]).reshape([1,-1]).T

        # Apply bias and mixing matrix ------------------------
        mixedsig = mixedsig + b
        mixedsig = np.dot(A, mixedsig)

        # Storage ---------------------------------------------
        layername = layer_name_base + str(ln+1)
        mixlayer.append({"name":layername, "A":A.copy(), "b":b.copy()})
        # Store data
        if save_layer_data:
            mixlayer[-1]["x"] = mixedsig.copy()

    return mixedsig, mixlayer

In [None]:
# Datasets
class mixing_MLP(nn.Module): #Deprecated, no longer used.
    def __init__(self, n_size, n_layers):
        super(mixing_MLP, self).__init__()
        
        self.layers = []
        self.activation = nn.LeakyReLU(negative_slope = 0.5)
        
        for i in range(n_layers):
            self.layers.append(nn.Linear(n_size, n_size))
            self.layers.append(self.activation)
        
        #self.layers.pop(-1)
        
        self.model = nn.Sequential(*self.layers)
        self.model.apply(self.init_weights)
    
    @staticmethod
    def init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)
    
    def forward(self, x):
        with torch.no_grad():
            return self.model(x)
        
class iVAE_datasets(object):
    
    def __init__(self, n, M, Lsegments, k, batch_size = 64, randomise = True, random_seed = False, mod_flag = False, mix_L = 1, Gauss_source = True, seed = True):
        """
        n = size of latent space
        M = no. classes
        Lsegments = no. samples per class
        k = no. of prior parameters
            k = 1: variance Gaussian
            k = 2: mean and variance gaussian
        mod_flag = case where one signal has mean modulation and the other doesn't
        """
        self.latent_size = n
        self.no_classes = M
        self.no_samples = Lsegments
        self.k = k
        self.batch_size = batch_size #specifies the batch size
        self.randomise = randomise #Specifies that sample must be obtained randomly (not uniformly)
        self.random_seed = random_seed #If random_seed = True - specifies that a random sample is required and the counter is not increased!
        self.mod_flag = mod_flag
        self.mix_L = mix_L
        self.Gauss_source = Gauss_source

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        #Define mixing model (Unused)
        #self.mixing_model = mixing_MLP(self.latent_size, 1)
        #self.mixing_model.to(self.device)
        #print(self.mixing_model)
        
        if self.k == 1 and not self.mod_flag:
            #print("\nVariance modulated sources.\n")
            self.mu_centers = np.zeros((self.no_classes, self.latent_size))
            
        elif self.k == 2 and not self.mod_flag:
            #print("\nMean and variance modulated sources.\n")
            if seed:
                np.random.seed(2**13 + 4)

            self.mu_centers = np.random.rand(self.no_classes, self.latent_size) * 10 - 5         # in range (-5, 5)
            #self.mu_centers += np.sign(self.mu_centers) * 0.5 #Shift centers outwards a little
       
        else:
            self.mu_centers = np.zeros((self.no_classes, self.latent_size))
            
            list_range = np.arange(0, self.no_classes, 1)
            np.random.shuffle(list_range) #random permutation gamma(u)
            
            a = np.random.random() * 10 - 5
            
            self.mu_centers[:, 1] =  a * list_range
            
        if not hasattr(self, "std_centers"):
            self.std_centers = np.random.rand(self.no_classes, self.latent_size) * 2.5 + 0.5      # in range (0.5, 3)
        
        #Make the sample labels
        self.sample_labels = []
        for i in range(self.no_classes):
            self.sample_labels += [i] * self.no_samples
        self.sample_labels = np.array(self.sample_labels)

        self.data = torch.from_numpy(self.sample()).to(self.device)
        
        #Normalise      
        self.data_mean = torch.mean(self.data, dim = 0)
        self.data_std = torch.std(self.data, dim = 0)
        
        #self.data = (self.data - self.data_mean) / self.data_std

        mixed_data, mix_layer = apply_MLP_to_source(self.data.cpu().numpy().T, self.mix_L, num_segment = self.no_classes)
        mixed_data = torch.from_numpy(mixed_data.T).float()

        self.mix_layer = mix_layer 

        if self.mod_flag:
            self.mixed_data = torch.hstack((mixed_data[:, [0]], self.data[:, [1]].cpu())) #Move self.data to the CPU so both tensors are on the same device
        
        else:
            self.mixed_data = mixed_data
        
        #Add noise
        self.mixed_data += torch.randn_like(self.mixed_data) * 0.01
        
        self.data_tuples = list(zip(self.data, self.sample_labels)) #list of tuples
        self.mixed_data_tuples = list(zip(self.mixed_data, self.sample_labels)) #list of tuples

        #shuffle mixed_data
        self.shuffled_data_index = np.arange(0, self.mixed_data.size(0), 1, dtype = int)

        if self.random_seed:
            np.random.shuffle(self.shuffled_data_index)

        #Convert self.sample_labels to torch.tensor
        self.sample_labels = torch.from_numpy(self.sample_labels)
    
    def sample(self):
        selected_centers = self.sample_labels
        
        latent_sample = self.mu_centers[selected_centers, :]

        if self.Gauss_source:
            latent_sample += np.random.randn(len(selected_centers), self.latent_size) * self.std_centers[selected_centers, :]
        elif not self.Gauss_source and self.latent_size == 2:
            s1 = np.random.laplace(loc = 0, scale = self.std_centers[selected_centers, 0]).reshape(-1, 1)
            s2 = np.random.laplace(loc = 0, scale = self.std_centers[selected_centers, 1]).reshape(-1, 1)
            latent_sample += np.hstack((s1, s2))
        return latent_sample.astype(np.float32)
    
    #turn the class into an iterator
    def __iter__(self):
        
        self.iter_cnt = 0 #initialises the iterator
        return self #returns the iterator object
    
    def __next__(self):

        if not self.random_seed:
            start = self.iter_cnt * self.batch_size
            end = start + self.batch_size

            index = self.shuffled_data_index[start:end]

            if end <= len(self.mixed_data_tuples):

                self.iter_cnt += 1
                
                data = self.mixed_data[index, :]
                labels = self.sample_labels[index]
                
                return data, labels

            else:
                self.iter_cnt= 0
                raise StopIteration

        else:
            raise NotImplementedError("Random sampler is not implemented.")

# Objective functions

In [None]:
class GaussianLoss(nn.Module):
    def __init__(self):
        super(GaussianLoss, self).__init__()
    
    def forward(self, x, x_recon, sum_vals = True):

        if isinstance(x_recon, tuple):
            #Learnt a variance on the output
            mu_recon, var_recon = x_recon

        else:
            #No learnt variance on the output
            mu_recon = x_recon
            var_recon = torch.ones_like(x_recon).requires_grad_(False)
        
        error = x - mu_recon

        B, N = x.size()
        #Assuming diagonalised covariance:
        gauss_log_loss = torch.mul(error.pow(2), 1/(2 * var_recon + 1e-12)) #Elementwise (B, N): error^2 / (2 * var); 1e-12 guards against division by zero
        gauss_log_loss += 1/2 * torch.log(var_recon + 1e-12)

        #Sum over dimensionality
        gauss_log_loss = torch.sum(gauss_log_loss, dim = 1, keepdim = True)
        
        if sum_vals:
            gauss_log_loss = torch.sum(gauss_log_loss) #Sum over the batch

        return gauss_log_loss #Unnormalised


class KL_divergence(nn.Module):
    def __init__(self, std_normal = False):
        super(KL_divergence, self).__init__()

        self.std_normal = std_normal #If True, the second distribution is fixed to the standard normal N(0, I)

    def forward(self, mu_0, var_0, mu_1 = None, var_1 = None):

        if self.std_normal:
            mu_1 = torch.zeros_like(mu_0).requires_grad_(False)
            var_1 = torch.ones_like(var_0).requires_grad_(False)

        #Compute the diagonal-Gaussian KL elementwise, then sum over the latent dimensions
        Dkl = var_0 / var_1 + ((mu_1 - mu_0)**2) / var_1 - 1 + torch.log(var_1 / var_0)

        #Sum over dimensionality
        Dkl = 0.5 * torch.sum(Dkl, dim = 1, keepdim = True)

        return Dkl #Unnormalised

class discrete_KL_divergence(nn.Module):
    def __init__(self, no_classes, uniform_prior = True):
        super(discrete_KL_divergence, self).__init__()

        self.no_classes = no_classes
        self.uniform_prior = uniform_prior #If True (and no prior is passed to forward), p(y) is uniform


    def forward(self, class_prob, prior_prob = None, eps = 1e-12):

        assert class_prob.size(1) == self.no_classes, "There is a mis-match between the pre-defined number of classes and the number of classes given to the discrete KL divergence."

        if prior_prob is None:
            #Uniform prior, created on the same device as class_prob
            prior_prob = torch.full((self.no_classes,), 1 / self.no_classes, device = class_prob.device)

        #Elementwise q * (log q - log p); eps keeps log(0) finite so zero-probability classes contribute 0
        Dkl = class_prob * (torch.log(class_prob + eps) - torch.log(prior_prob + eps))

        return Dkl #Unnormalised, shape (B, no_classes)

class VAE_loss(nn.Module):
    #No ability to learn a variance, variance is controlled by the noise distribution for iVAE
    def __init__(self, loss_name = "L2", gamma = 1, beta = 1, std_normal = False):
        super(VAE_loss, self).__init__()

        self.gamma = gamma
        self.beta = beta

        self.loss_name = loss_name

        if self.loss_name.lower() == "l2":
            self.recon_loss = GaussianLoss()
        
        elif self.loss_name.lower() == "l1":
            self.recon_loss = nn.L1Loss(reduction = "sum") #Sum instead of averaging, as the normalisation is done at the end.
        
        else:
            print("Unknown loss entered.")
            raise SystemExit
        

        self.kl_loss = KL_divergence(std_normal)
    
    def forward(self, x, recon_x, mu_0, var_0, mu_1 = None, var_1 = None):
        
        B, N = x.size()

        if isinstance(recon_x, tuple) and self.loss_name.lower() == "l1": #Check if it is a tuple, will be this by default when it is fed in.
            recon_x = recon_x[0]

        Lrecon = torch.sum(self.recon_loss(x, recon_x)) #Summed over batch and data dimensions
        Lkl = torch.sum(self.kl_loss(mu_0, var_0, mu_1, var_1)) #Summed over batch and latent dimensions; normalised with the same values as the reconstruction loss below

        Ltotal = self.gamma * Lrecon + self.beta * Lkl

        #Normalise value by the number of data entries (B * N)
        Lrecon = Lrecon / (B * N)
        Lkl = Lkl / (B * N)
        Ltotal = Ltotal / (B * N)

        return Ltotal, Lrecon, Lkl

class MoG_VAE_loss(nn.Module):
    #No ability to learn a variance, variance is controlled by the noise distribution for iVAE
    def __init__(self, no_classes, loss_name = "L2", gamma = 1, beta = 1, alpha = 1):
        super(MoG_VAE_loss, self).__init__()

        self.no_classes = no_classes

        self.gamma = gamma #Reconstruction loss weight
        self.beta = beta #Continuous KL weight
        self.alpha = alpha #Categorical KL weight

        self.loss_name = loss_name

        if self.loss_name.lower() == "l2":
            self.recon_loss = GaussianLoss()
        
        elif self.loss_name.lower() == "l1":
            self.recon_loss = nn.L1Loss()
        
        else:
            print("Unknown loss entered.")
            raise SystemExit
        
        self.kl_loss = KL_divergence(False) #Never use a standard VAE case
        self.discrete_kl_loss = discrete_KL_divergence(self.no_classes, uniform_prior = True) #assume a uniform prior
    
    def forward(self, x, recon_x, mu_0, var_0, mu_1, var_1, q_y_G_x, prior_prob = None, CURL = False):
        
        #The input data must already be expanded no_classes times (sample-major, e.g. torch.repeat_interleave) before calling this loss

        kB, N = x.size()
        B = kB // self.no_classes

        if isinstance(recon_x, tuple) and self.loss_name.lower() == "l1": #Check if it is a tuple, will be this by default when it is fed in.
            recon_x = recon_x[0]

        #Reconstruction loss, one value per expanded sample -> (B, no_classes)
        if self.loss_name.lower() == "l1":
            Lrecon = F.l1_loss(recon_x, x, reduction = "none").sum(dim = 1, keepdim = True)
        else:
            Lrecon = self.recon_loss(x, recon_x, False)
        Lrecon = Lrecon.reshape(B, self.no_classes)

        if CURL:
            Lrecon = Lrecon * q_y_G_x #weight by categorical likelihood
        
        Lrecon = torch.sum(Lrecon)  #Sum all values

        #Continuous KL divergence (a positive penalty that is minimised), weighted by categorical likelihood
        Lkl_continuous = self.kl_loss(mu_0, var_0, mu_1, var_1).reshape(B, self.no_classes)
        Lkl_continuous = torch.sum(Lkl_continuous * q_y_G_x) #Sum all values

        #Discrete KL divergence KL(q(y|x) || p(y)) (a positive penalty that is minimised)
        Lkl_discrete = torch.sum(self.discrete_kl_loss(q_y_G_x, prior_prob))  #Sum all values

        #alpha weights the categorical KL term
        Ltotal = self.gamma * Lrecon + self.beta * Lkl_continuous + self.alpha * Lkl_discrete

        #Normalise value (FIX FOR kB vs B)
        Lrecon = Lrecon / (kB * N)
        Lkl_continuous = Lkl_continuous / (kB * N)
        Lkl_discrete = Lkl_discrete / (kB * N)
        Ltotal = Ltotal / (kB * N)

        return Ltotal, Lrecon, Lkl_continuous, Lkl_discrete

"""
#TESTING THE FUNCTIONS
no_samples = 512
no_classes = 3
a = torch.randn(no_samples, 5)
b = a + torch.randn_like(a) * 0.1
mu_0 = torch.ones_like(a)
mu_1 = torch.zeros_like(a)
var_0 = torch.ones_like(a)
var_1 = torch.ones_like(a)

a = torch.repeat_interleave(a, no_classes, dim = 0)
b = torch.repeat_interleave(b, no_classes, dim = 0)
mu_0 = torch.repeat_interleave(mu_0, no_classes, dim = 0)
mu_1 = torch.repeat_interleave(mu_1, no_classes, dim = 0)
var_0 = torch.repeat_interleave(var_0, no_classes, dim = 0)
var_1 = torch.repeat_interleave(var_1, no_classes, dim = 0)

loss = MoG_VAE_loss(no_classes)
q_y_G_x = torch.ones(no_samples, no_classes) / no_classes

print(loss(a, b, mu_0, var_0, mu_1, var_1, q_y_G_x, CURL = False))
print(loss(a, b, mu_0, var_0, mu_1, var_1, q_y_G_x, CURL = True))
print(torch.mean((b - a)**2)/2, torch.mean((b - a)**2)/(2 * no_classes))
"""


# Networks

In [None]:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn

#Implement similarly to the manner I had previously
#Dict to define layers
#Checks for FF and Convolution
#Add in ability to have variance generating component in decoder (unused at this point)

class Unflatten(nn.Module):
    def __init__(self, ModelDict):
        super(Unflatten, self).__init__()
        self.ModelDict = ModelDict
        
    def forward(self, input_tensor):
        
        First_no_channels = self.ModelDict["channels"][0]

        input_tensor = input_tensor.view(-1, First_no_channels, int(input_tensor.size(1) / First_no_channels))
        
        return input_tensor

class Flatten(nn.Module): #Same name as TensorFlow's tf.keras.layers.Flatten()
    def __init__(self, DisDict):
        super(Flatten, self).__init__()
        self.DisDict = DisDict
        
    def forward(self, input_tensor):

        input_tensor = input_tensor.view(input_tensor.size(0), -1)
        
        return input_tensor

class Encoder(nn.Module):
    def __init__(self, latent_size, Usize, data_size, encode_dict, CURL_flag = True):
        super(Encoder, self).__init__()

        self.latent_size = latent_size
        self.Usize = Usize
        self.data_size = data_size
        self.encode_dict = encode_dict
        self.CURL_flag = CURL_flag #If True, q(z|x, y) is conditioned on y (CURL); otherwise q(z|x) (VaDE)
        self.activation = nn.LeakyReLU(0.1)
        self.var_activation = nn.Softplus()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        #A MoG VAE needs at least one class
        if self.Usize == 0:
            raise ValueError("You are using a MoG VAE but have set the class size (Usize) to zero.")

        self.standard_flag = False

        self.layers = [] #Initialise layers 

        if self.encode_dict["conv_flag"]:
            
            for i in range(len(self.encode_dict["channels"]) - 1):

                #append the layer
                self.layers.append( nn.Conv1d(in_channels = self.encode_dict["channels"][i], out_channels = self.encode_dict["channels"][i + 1], kernel_size = self.encode_dict["kernel_size"][i], stride = self.encode_dict["stride"][i], padding = self.encode_dict["padding"][i]) )
                #append the activation function
                self.layers.append(self.activation)
            
            #flatten the convolutional output so it can be fed to the nn.Linear layers
            self.layers.append(Flatten(self.encode_dict))
      
        for i in range(len(self.encode_dict["ff_layers"]) - 2):
            #append the layer
            self.layers.append(nn.Linear(in_features = self.encode_dict["ff_layers"][i], out_features = self.encode_dict["ff_layers"][i + 1], bias = True))
            #append the activation function
            self.layers.append(self.activation)

        self.layers.pop(-1) #remove the final activation
        self.encode_net = nn.Sequential(*self.layers) #hidden representation that gets fed into predicting the label and latent 
        self.y_layer = nn.Sequential(nn.Linear(self.encode_dict["ff_layers"][-2], self.Usize), nn.Softmax(dim = 1))

        #CURL concatenates the (one-hot or soft) label to the hidden representation before predicting z
        z_in_features = self.encode_dict["ff_layers"][-2] + (self.Usize if self.CURL_flag else 0)
        self.mu_layer = nn.Linear(z_in_features, self.encode_dict["ff_layers"][-1])
        self.var_layer = nn.Sequential(nn.Linear(z_in_features, self.encode_dict["ff_layers"][-1]), self.var_activation)
        
        self.encode_net.apply(self.init_weights)
        self.mu_layer.apply(self.init_weights)
        self.var_layer.apply(self.init_weights)
    
    @staticmethod
    def init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            #m.bias.data.fill_(0.01)
  
    def forward(self, x, cont_input = None, train_flag = True, mode_labels = True):
        
        #Always stack as [x, conditional labels]
        #mode_labels specifies whether z is conditioned on one-hot labels (True) or on the softmax outputs y_x (False)
        #When you train, you specify the class exactly (you do not use self.y_layer to compute y|x, but rather expand the data to account for all K labels and recompute z)

        if cont_input is not None and not self.standard_flag:
            x_input = torch.hstack((x, cont_input))

        else:
            x_input = x

        encode = self.encode_net(x_input)
        y_x = self.y_layer(encode) #q(y|x), shape (B, K)
        y_labels = torch.argmax(y_x, dim = 1) #hard label per sample

        if not self.CURL_flag:
            mu_z = self.mu_layer(encode)
            var_z = self.var_layer(encode)
            return mu_z, var_z, y_labels

        labels = y_labels #default: condition on the predicted label
        y_cond = y_x

        if train_flag: 
            #Expand sample-major: [x_0 (y=0), ..., x_0 (y=K-1), x_1 (y=0), ...] to match reshape(B, K) in the loss
            B = encode.size(0)
            encode = torch.repeat_interleave(encode, self.Usize, dim = 0)
            y_cond = torch.repeat_interleave(y_x, self.Usize, dim = 0)
            labels = torch.arange(self.Usize, device = encode.device).repeat(B)

        if mode_labels: #CURL conditions z on x and y
            u_input = F.one_hot(labels, num_classes = self.Usize).to(encode.dtype)
            encode = torch.hstack((encode, u_input))

        else: 
            encode = torch.hstack((encode, y_cond))

        mu_z = self.mu_layer(encode)
        var_z = self.var_layer(encode)

        return mu_z, var_z, y_x #y_x is a probability (technically)
        
class Decoder(nn.Module):
    def __init__(self, latent_size, Usize, data_size, decode_dict, var_flag = False):
        super(Decoder, self).__init__()

        self.latent_size = latent_size
        self.Usize = Usize
        self.data_size = data_size
        self.decode_dict = decode_dict
        self.var_flag = var_flag
        self.activation = nn.LeakyReLU(0.1)    
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.layers = [] #Initialise layers 
        
        if not self.decode_dict["conv_flag"]:
            for i in range(len(self.decode_dict["ff_layers"]) - 2):
                #append the layer
                self.layers.append(nn.Linear(in_features = self.decode_dict["ff_layers"][i], out_features = self.decode_dict["ff_layers"][i + 1], bias = True))
                #append the activation function
                self.layers.append(self.activation)
            
            self.layers.pop(-1) #remove the final activation for linear outputs
    
            self.decode_net = nn.Sequential(*self.layers)
            self.gen_layer = nn.Linear(self.decode_dict["ff_layers"][-2], self.decode_dict["ff_layers"][-1])
            
            if self.var_flag:
                self.var_layer = nn.Sequential(nn.Linear(self.decode_dict["ff_layers"][-2], self.decode_dict["ff_layers"][-1]), nn.Softplus())
                #self.var_layer.apply(self.init_weights)
        
         
        else:
            for i in range(len(self.decode_dict["ff_layers"]) - 1):
                #append the layer
                self.layers.append(nn.Linear(in_features = self.decode_dict["ff_layers"][i], out_features = self.decode_dict["ff_layers"][i + 1], bias = True))
                #append the activation function
                self.layers.append(self.activation)
        
            #append the transform to take the nn.linear to a convolutional layer
            self.layers.append(Unflatten(self.decode_dict))
            
            for i in range(len(self.decode_dict["channels"]) - 2):

                #append the layer
                self.layers.append( nn.ConvTranspose1d(in_channels = self.decode_dict["channels"][i], out_channels = self.decode_dict["channels"][i + 1], kernel_size = self.decode_dict["kernel_size"][i], stride = self.decode_dict["stride"][i], padding = self.decode_dict["padding"][i]) )
                #append the activation function
                self.layers.append(self.activation)
        
            self.layers.pop(-1) #remove the final activation for linear outputs
    
            self.decode_net = nn.Sequential(*self.layers)
            self.gen_layer = nn.ConvTranspose1d(in_channels = self.decode_dict["channels"][-2], out_channels = self.decode_dict["channels"][-1], kernel_size = self.decode_dict["kernel_size"][-1], stride = self.decode_dict["stride"][-1], padding = self.decode_dict["padding"][-1])

            #self.decode_net.apply(self.init_weights)
            #self.gen_layer.apply(self.init_weights)

            if self.var_flag:
                self.var_layer = nn.Sequential(
                    nn.ConvTranspose1d(in_channels = self.decode_dict["channels"][-2], out_channels = self.decode_dict["channels"][-1], kernel_size = self.decode_dict["kernel_size"][-1], stride = self.decode_dict["stride"][-1], padding = self.decode_dict["padding"][-1]),
                    nn.Softplus())
                #self.var_layer.apply(self.init_weights)
        

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            #m.bias.data.fill_(0.01)

    @staticmethod
    def reparametrisation_trick(mu_data, var_data):
        with torch.no_grad():
            eta = torch.randn_like(mu_data)

        return mu_data + eta * torch.sqrt(var_data)

    def forward(self, mu_latent, var_latent):

        z_latent = self.reparametrisation_trick(mu_latent, var_latent)

        decode_out = self.decode_net(z_latent)

        x_out = self.gen_layer(decode_out)

        if self.var_flag:
            var_out = self.var_layer(decode_out)
            
        else:
            var_out = torch.ones_like(x_out).requires_grad_(False)
        
        if self.decode_dict["conv_flag"]:
            x_out = x_out.squeeze(1)
            var_out = var_out.squeeze(1)
            
        return x_out, var_out

        
  
class ConditionalPrior(nn.Module):
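    #p(z | y): either a small feed-forward network maps the label input to a mean and variance (continuous_prior = True),
    #or a learnable mean and variance is stored for each class (continuous_prior = False)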
    def __init__(self, latent_size, Usize, data_size, prior_dict, continuous_prior = True):
        super(ConditionalPrior, self).__init__()

        self.latent_size = latent_size
        self.Usize = Usize
        self.data_size = data_size
        self.prior_dict = prior_dict
        self.continuous_prior = continuous_prior

        self.activation = nn.LeakyReLU(0.1)
        self.var_activation = nn.Softplus()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        #Check if it is a standard VAE, if so, set continuous_prior to False and then set distribution to N(0, I)
        if self.Usize == 0:
            self.continuous_prior = False

        self.layers = [] #Initialise layers 
        
        if self.continuous_prior:
            #Define model - essentially another generator but with only FF layers, by design

            for i in range(len(self.prior_dict["ff_layers"]) - 2):
                #append the layer
                self.layers.append(nn.Linear(in_features = self.prior_dict["ff_layers"][i], out_features = self.prior_dict["ff_layers"][i + 1], bias = True))
                #append the activation function
                self.layers.append(self.activation)
          
            self.layers.pop(-1)
            self.prior_net = nn.Sequential(*self.layers)
            self.prior_mu = nn.Linear(self.prior_dict["ff_layers"][-2], self.prior_dict["ff_layers"][-1])
            self.prior_var = nn.Linear(self.prior_dict["ff_layers"][-2], self.prior_dict["ff_layers"][-1])

            #self.prior_net.apply(self.init_weights)
            #self.prior_mu.apply(self.init_weights)
            #self.prior_var.apply(self.init_weights)
      
        else:
            #Lambda functions that look up the stored mean and variance of the most probable class in each row of U

            self.prior_net = lambda U: U
            
            
            if self.Usize == 0:
                #self._prior_mu_ = nn.parameter.Parameter(torch.Tensor(1, self.latent_size))
                #self._prior_var_ = nn.parameter.Parameter(torch.Tensor(1, self.latent_size))
                self.register_parameter(name='_prior_mu_', param=torch.nn.Parameter(torch.Tensor(1, self.latent_size)))
                self.register_parameter(name='_prior_var_', param=torch.nn.Parameter(torch.Tensor(1, self.latent_size)))

            else:
                self.register_parameter(name='_prior_mu_', param=torch.nn.Parameter(torch.Tensor(self.Usize, self.latent_size)))
                self.register_parameter(name='_prior_var_', param=torch.nn.Parameter(torch.Tensor(self.Usize, self.latent_size)))
                #self._prior_mu_ = nn.parameter.Parameter(torch.Tensor(self.Usize, self.latent_size))
                #self._prior_var_ = nn.parameter.Parameter(torch.Tensor(self.Usize, self.latent_size))#torch.ones(self.Usize, self.latent_size).to(self.device)#

                self.prior_mu = lambda U: self._prior_mu_[torch.argmax(U, dim = 1), :]
                self.prior_var = lambda U: self._prior_var_[torch.argmax(U, dim = 1), :]

            with torch.no_grad(): #initialise parameters
                if self.Usize == 0:
                    #Set to N(0, I)
                    self._prior_mu_.fill_(0)
                    self._prior_var_.fill_(1)
                    #Turn off gradient flag
                    self._prior_mu_.requires_grad_(False)
                    self._prior_var_.requires_grad_(False)
                    
                else:
                    self._prior_mu_.normal_(0, 0.1)
                    self._prior_var_.normal_(0, 0.1)

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            #m.bias.data.fill_(0.01)
    
    def one_hot_encode(self, labels):
        #One-hot encode integer labels as a float (batch, Usize) matrix on the same device as the labels
        with torch.no_grad():
            return F.one_hot(labels, num_classes = self.Usize).float()
    
    def forward(self, labels = None, cont_input = None):
        #Always stack as [continuous, discrete]
        if self.Usize == 0:
            return self._prior_mu_, self._prior_var_

        else:
            #if self.continuous_prior:
            #    u_input = self.one_hot_encode(labels)
            
            #else:
            u_input = labels
            
            
            if cont_input is not None:
                u_input = torch.hstack((cont_input, u_input))

            prior_net = self.prior_net(u_input)
            mu = self.prior_mu(prior_net)
            var = self.var_activation(self.prior_var(prior_net)) 
            
            return mu, var

class VAE_model(nn.Module):
    def __init__(self, input_size, latent_size, U_size = None, EncodeDict = None, DecodeDict = None, PriorDict = None, var_decode = False, continuous_prior = True):
        super(VAE_model, self).__init__()
        
        self.input_size = input_size
        self.latent_size = latent_size
        self.U_size = U_size
        self.encode_dict = EncodeDict
        self.decode_dict = DecodeDict
        self.prior_dict = PriorDict
        self.var_decode = var_decode
        self.continuous_prior = continuous_prior

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        self.model_HI_names = ["HI_1"]
        self.model_HI_names_pretty = [r"$NLL_{recon}$"]
        
        #self.U_size
        self.encoder = Encoder(self.latent_size, 0, self.input_size, self.encode_dict)
        self.decoder = Decoder(self.latent_size, self.U_size, self.input_size, self.decode_dict, var_flag = var_decode)
        self.prior = ConditionalPrior(self.latent_size, self.U_size, self.input_size, self.prior_dict, continuous_prior = self.continuous_prior)

        if self.U_size == 0:
            print("\nInitialising a normal VAE!\n")
            self.standard_flag = True
        
        else:
            self.standard_flag = False 
  
    #train(), eval() and to(device) are inherited from nn.Module, which applies them
    #recursively to the registered encoder, decoder and prior submodules.
    
    def one_hot_encode(self, labels):
        #One-hot encode integer labels as a float (batch, U_size) matrix on the same device as the labels
        with torch.no_grad():
            return F.one_hot(labels, num_classes = self.U_size).float()

    def compute_HIs(self, x, labels = None, cont_input = None): #Only useful if you are performing anomaly detection (specific to another project)
        with torch.no_grad():

            mu_latent, var_latent, _ = self.encoder(x, labels, cont_input) #the third output (cluster information) is not needed here

            x_recon1, var_decoder = self.decoder(mu_latent, var_latent)
            HI_1 = (1 / x.shape[1]) * torch.sum((x - x_recon1)**2 / (var_decoder), dim = 1) 

            return HI_1, mu_latent

class VAE_optimiser(object):
    def __init__(self, model, Params):
        ls = list(model.encoder.parameters()) + list(model.decoder.parameters()) + list(model.prior.parameters())
        self.VAE_opt = torch.optim.Adam(ls, lr = Params.learning_rate)
    
    def step(self):
        self.VAE_opt.step()

    def zero_grad(self):
        self.VAE_opt.zero_grad()
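
The objectives at the top of the notebook need two per-sample quantities from these modules. The first is the Gaussian log-likelihood log p(x|z), computed from the `Decoder` output `(x_out, var_out)`. The second is KL(q(z|x, y) || p(z|y)), computed between the encoder's `(mu_z, var_z)` and the `ConditionalPrior`'s `(mu, var)`. All three modules return variances rather than log-variances, so both terms can be written in closed form.

The sketch below is standalone and makes a few assumptions:
- The function names are hypothetical.
- All variances are strictly positive.
- The "decoder" in the gradient check is a stand-in identity map.

It also checks that drawing the noise under `torch.no_grad()`, as `Decoder.reparametrisation_trick` does, still lets gradients reach both the mean and the variance.

```python
import math
import torch

def reparametrise(mu, var):
    """Sample z ~ N(mu, diag(var)) as a differentiable function of mu and var."""
    # The noise does not depend on any parameter, so it needs no gradient tracking
    with torch.no_grad():
        eps = torch.randn_like(mu)
    return mu + eps * torch.sqrt(var)

def gaussian_log_likelihood(x, mu, var):
    """log N(x | mu, diag(var)), summed over the feature dimension (dim 1)."""
    return -0.5 * torch.sum(torch.log(2 * math.pi * var) + (x - mu) ** 2 / var, dim=1)

def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL(N(mu_q, diag(var_q)) || N(mu_p, diag(var_p))), summed over dim 1."""
    return 0.5 * torch.sum(
        torch.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0, dim=1
    )

# KL sanity check: identical Gaussians give 0; a unit shift in each of 3 dims gives 3 * 0.5
mu_q = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
ones = torch.ones(2, 3)
kl = gaussian_kl(mu_q, ones, torch.zeros(2, 3), ones)  # values 0.0 and 1.5

# Gradient check: a one-sample ELBO estimate with a stand-in identity "decoder"
# (hypothetical: decoder mean = z, decoder variance = 1) and an N(0, I) prior
torch.manual_seed(0)
mu = torch.zeros(1, 3, requires_grad=True)
var = torch.ones(1, 3, requires_grad=True)
z = reparametrise(mu, var)
x = torch.tensor([[0.5, -1.0, 2.0]])  # a stand-in observation
elbo = gaussian_log_likelihood(x, z, torch.ones(1, 3)) - gaussian_kl(
    mu, var, torch.zeros(1, 3), torch.ones(1, 3)
)
(-elbo.sum()).backward()  # minimise the negative ELBO
print(kl, mu.grad, var.grad)
```

In a training loop, the loss would be the negative of this estimate averaged over the batch. For CURL, the KL term would additionally be weighted by q(y|x).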

In [None]:
torch.tile(torch.arange(5), (32,)) #Repeat labels for each sample

tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
        4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2,
        3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1,
        2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0,
        1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
        0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
        4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
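
The cell above produces the label pattern that the CURL branch of the encoder needs when `train_flag` is set. Each encoded sample is repeated `Usize` times with `torch.repeat_interleave(..., dim = 0)`. The labels must therefore be tiled, not interleaved, so that row `i * Usize + k` pairs sample `i` with cluster `k`. By contrast, `torch.repeat_interleave(torch.arange(U), N)` gives `[0, 0, 0, 1, 1, 1, ...]`.

The sketch below checks this pairing on hypothetical sizes. It also shows one way to turn a per-row quantity into the expectation over q(y|x) that appears in L_CURL. Here `per_row` and `q_y` are random stand-ins, not outputs of the model defined above.

```python
import torch

# Hypothetical sizes: N samples, U clusters, D features per encoded sample
N, U, D = 3, 4, 2
torch.manual_seed(0)
encode = torch.randn(N, D)

# Repeat every sample U times: rows are [s0, s0, s0, s0, s1, s1, ...]
encode_rep = torch.repeat_interleave(encode, U, dim=0)
# Tile the labels so that row i*U + k carries label k: [0, 1, 2, 3, 0, 1, ...]
labels = torch.tile(torch.arange(U), (N,))
# Interleaving the labels instead gives [0, 0, 0, 1, 1, 1, ...], which breaks the pairing
wrong_labels = torch.repeat_interleave(torch.arange(U), N)

for i in range(N):
    for k in range(U):
        assert torch.equal(encode_rep[i * U + k], encode[i])
        assert labels[i * U + k].item() == k

# Stand-in for a per-row KL(q(z|x, y=k) || p(z|y=k)), one value per (sample, label) row
per_row = torch.rand(N * U)
# Stand-in for the encoder's cluster probabilities q(y|x)
q_y = torch.softmax(torch.randn(N, U), dim=1)
# view(N, U) puts sample i's U rows on row i, so this is E_{q(y|x)}[KL] per sample
expected = torch.sum(q_y * per_row.view(N, U), dim=1)
print(expected)
```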

In [None]:
a = torch.randint(0, 4, (10, 4))
print(a)
torch.argmax(a, dim = 1)

tensor([[3, 0, 2, 1],
        [3, 0, 1, 0],
        [1, 0, 1, 0],
        [2, 0, 1, 2],
        [3, 2, 1, 2],
        [3, 0, 0, 0],
        [3, 0, 2, 2],
        [2, 2, 0, 3],
        [2, 1, 2, 3],
        [0, 1, 0, 1]])


tensor([0, 0, 0, 0, 0, 0, 0, 3, 3, 1])
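
Two details of `torch.argmax` matter for the model code:
- Without `dim`, it returns an index into the flattened tensor.
- With `dim = 1`, ties resolve to the first maximal index. For example, row `[2, 0, 1, 2]` gives 0 above.

`ConditionalPrior.prior_mu` relies on the row-wise form to look up one parameter row per sample. The small check below uses a hypothetical parameter table:

```python
import torch

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.0, 0.5]])

flat_idx = torch.argmax(probs)        # tensor(1): index into the flattened tensor
row_idx = torch.argmax(probs, dim=1)  # tensor([1, 0]): one label per row, ties go to the first index

# Hypothetical per-class parameter table (3 classes, 2 latent dimensions),
# indexed row-wise in the same way as ConditionalPrior.prior_mu
table = torch.arange(3 * 2, dtype=torch.float32).view(3, 2)
rows = table[row_idx, :]              # rows 1 and 0 of the table
print(flat_idx, row_idx, rows)
```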

In [9]:
import scipy.optimize as optimize

no_classes = 5
x = np.ones(no_classes) * 0.01


def cost(x):
  #KL(x || u): divergence of the candidate distribution x from the uniform distribution u
  u = np.ones(no_classes) / no_classes
  return -1 * np.sum(x * np.log(u / x))

def constraint(x):
  return np.sum(x) - 1

opt_dict = optimize.minimize(cost, x, constraints = {'type':'eq', 'fun': constraint})

print(opt_dict['x'])
print(1/6, 1/0.07357588)

[0.2 0.2 0.2 0.2 0.2]
0.16666666666666666 13.591410663385883
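
The cost minimised above is KL(x || u) against a uniform u, and the optimiser returns the uniform distribution itself. The same divergence gives the last term of both objectives, KL(q(y|x) || p(y)), when p(y) is a categorical prior.

Below is a minimal sketch for batches of probability rows, assuming a uniform prior with K = 5 as in the cell above. `categorical_kl` is a hypothetical helper. It uses `torch.xlogy` (which needs a PyTorch version that provides it) so that zero probabilities contribute 0 rather than NaN.

```python
import math
import torch

def categorical_kl(q, p):
    """KL(q || p) for each row of probabilities; xlogy treats 0 * log(0) as 0."""
    return torch.sum(torch.xlogy(q, q) - torch.xlogy(q, p), dim=1)

K = 5
prior = torch.full((1, K), 1.0 / K)   # uniform p(y), broadcast over the batch
uniform_q = torch.full((2, K), 1.0 / K)
one_hot_q = torch.eye(K)[:2]          # two fully confident cluster assignments

print(categorical_kl(uniform_q, prior))               # zeros: the minimiser found by scipy
print(categorical_kl(one_hot_q, prior), math.log(K))  # each row equals log K
```

A fully confident assignment costs log K, while the uniform assignment costs nothing. Under a uniform prior, this term therefore pulls q(y|x) towards uniform.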


In [3]:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)

print(target, output)

tensor([3, 0, 3]) tensor(2.3567, grad_fn=<NllLossBackward>)
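
`nn.CrossEntropyLoss` expects unnormalised logits and integer class indices. Internally it applies `log_softmax` followed by the negative log-likelihood loss, which is why the output above carries an `NllLossBackward` gradient function.

This matters if q(y|x) is ever trained against known labels. The encoder's `y_x` is described as a probability, so it should go through `torch.log` and `F.nll_loss` rather than into `CrossEntropyLoss`, which would apply a second softmax. The check below compares the equivalent forms on random logits. It is a sketch, not part of the notebook's training code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 5)        # unnormalised scores, like `input` above
target = torch.tensor([3, 0, 3])  # integer class indices

ce = F.cross_entropy(logits, target)
# CrossEntropyLoss = log_softmax followed by the negative log-likelihood loss
nll = F.nll_loss(F.log_softmax(logits, dim=1), target)
# ... which is the mean of -log softmax(logits)[i, target[i]]
manual = -torch.log_softmax(logits, dim=1)[torch.arange(3), target].mean()

# If a network already outputs probabilities (e.g. after a softmax layer),
# pass their log to nll_loss instead of feeding them to cross_entropy
probs = torch.softmax(logits, dim=1)
from_probs = F.nll_loss(torch.log(probs), target)
print(ce, nll, manual, from_probs)
```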
