## In Theory

The combined model is trained with the triple criterion*:

$\mathcal{L} = \mathcal{L}_{prior} + \mathcal{L}_{llike}^{Dis_l} + \mathcal{L}_{GAN}$

where

$\mathcal{L}_{prior} = D_{KL}(q(z|x) \Vert p(z))$

$\mathcal{L}_{GAN} = \log(Dis(x)) + \log(1-Dis(Dec(z))) + \log(1-Dis(Dec(Enc(x))))$            

$\mathcal{L}_{llike}^{Dis_l} = -\mathbb{E}_{q(z|x)}[\log p(Dis_l (x)|z)]$


---

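$x$ is a training sample and $z \sim p(z)$ is a sample from the prior.
(Dis = discriminator; the decoder Dec also plays the role of the GAN generator Gen)

In $\mathcal{L}_{GAN}$, the generated samples are $Dec(z)$ with $z \sim p(z)$, and the reconstructions are $Dec(Enc(x))$.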

$\mathcal{L}_{GAN}$ is the style error;
$\mathcal{L}_{llike}^{Dis_l}$ is the reconstruction (content)  error

Optimise decoder wrt $\mathcal{L}_{GAN}$

---
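\* The authors do not train every network on this combined loss directly. Each network gets its own update (see the Updating gradients section below), and those per-network updates seem to be what they actually use.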
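
For a concrete feel for $\mathcal{L}_{prior}$: if the encoder outputs a diagonal Gaussian $q(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and the prior is $p(z) = \mathcal{N}(0, I)$, the KL term has a closed form. Both distributional choices are assumptions here (they are not stated in this section, although they match the `kl` line in the implementation reviewed further down). A minimal pure-Python sketch; the function name `kl_to_standard_normal` is made up for illustration:

```python
import math


def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) for a single sample.

    Assumes a diagonal-Gaussian encoder and a standard-normal prior.
    mu and log_var are equal-length sequences of floats.
    """
    if len(mu) != len(log_var):
        raise ValueError("mu and log_var must have the same length")
    # Per dimension: 0.5 * (sigma^2 + mu^2 - log(sigma^2) - 1)
    return 0.5 * sum(math.exp(lv) + m * m - lv - 1.0 for m, lv in zip(mu, log_var))


print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0: encoder output equals the prior
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))  # 0.5: mean shifted by 1 in one dimension
```

The term is zero only when the encoder output matches the prior exactly, and grows as the mean or variance drifts away from it.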

### $\mathcal{L}_{llike}^{Dis_l}$ (learned distance / reconstruction error)

$Dis_l(x)$ is the hidden representation of the $l$-th layer of the discriminator.

Uses a Gaussian observation model for $Dis_l(x)$:
- Mean: $Dis_l(\tilde x)$, where $\tilde x \sim Dec(z)$ is the decoder's reconstruction of $x$
- Covariance: identity
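
Spelling out why this observation model turns into a squared error in feature space (a short derivation; $d$ is a symbol introduced here for the dimensionality of $Dis_l(x)$):

$-\log p(Dis_l(x)|z) = -\log \mathcal{N}(Dis_l(x) \mid Dis_l(\tilde x), I) = \frac{1}{2} \Vert Dis_l(x) - Dis_l(\tilde x) \Vert^2 + \frac{d}{2} \log(2\pi)$

The second term does not depend on any parameters, so minimising $\mathcal{L}_{llike}^{Dis_l}$ amounts to minimising half the squared distance between the discriminator features of $x$ and those of its reconstruction.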



## In Practice

- Do not update all network parameters w.r.t. the combined loss
    * The discriminator should not minimise $\mathcal{L}_{llike}^{Dis_l}$ (this would collapse the discriminator to 0)
    * Do not backpropagate the error signal from $\mathcal{L}_{GAN}$ to the encoder

- Use a parameter $\gamma$ to weight the ability to reconstruct vs. fooling the discriminator
    * not applied to the entire model
    * weighting only applied when updating the parameters of the decoder.
    * $\theta_{Dec} \stackrel{+}\leftarrow - \nabla_{\theta_{Dec}} (\gamma \mathcal{L}_{llike}^{Dis_l} - \mathcal{L}_{GAN})$


## Updating gradients

1. $\theta_{Enc} \stackrel{+}\leftarrow - \nabla_{\theta_{Enc}} (\mathcal{L}_{prior} + \mathcal{L}_{llike}^{Dis_l})$
2. $\theta_{Dec} \stackrel{+}\leftarrow - \nabla_{\theta_{Dec}} (\gamma \mathcal{L}_{llike}^{Dis_l} - \mathcal{L}_{GAN})$
3. $\theta_{Dis} \stackrel{+}\leftarrow - \nabla_{\theta_{Dis}} \mathcal{L}_{GAN}$
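
To make the routing explicit, the three updates can be read as "each network minimises its own scalar". The sketch below takes the updates at face value, including their sign conventions, and works on plain floats. The function name and dictionary keys are hypothetical:

```python
def per_network_objectives(l_prior, l_llike, l_gan, gamma):
    """Return the scalar each network descends on, following updates 1-3 above."""
    return {
        # 1. The encoder sees the prior and reconstruction terms, not L_GAN.
        "encoder": l_prior + l_llike,
        # 2. The decoder is the only place where gamma appears.
        "decoder": gamma * l_llike - l_gan,
        # 3. The discriminator sees only L_GAN, never the reconstruction term.
        "discriminator": l_gan,
    }


# Made-up loss values, purely for illustration.
objectives = per_network_objectives(l_prior=2.0, l_llike=3.0, l_gan=-1.5, gamma=0.5)
print(objectives)
```

In an autograd framework, each of these scalars would be differentiated only with respect to its own network's parameters, which is what keeps $\mathcal{L}_{GAN}$ out of the encoder and $\mathcal{L}_{llike}^{Dis_l}$ out of the discriminator.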

## Steps

1. After encoder runs, calculate $\mathcal{L}_{prior}$

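2. After the decoder runs, pass $x$ and its reconstruction through the discriminator (needed because $Dis_l$ is a hidden layer of the discriminator) and calculate the layer-wise discriminator loss, $\mathcal{L}_{llike}^{Dis_l}$

3. After running a random sample $z \sim p(z)$ through the decoder and the discriminator, calculate $\mathcal{L}_{GAN}$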

4. Update the parameters using the gradients (see previous section)

5. Note the algorithm says "until deadline" (we will need to define the deadline)


**(See Figure and Algorithm on top of page 3)**
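
One possible way to pin down the "deadline" from step 5 is a step budget, a wall-clock budget, or whichever is hit first. This is an assumption, not something the algorithm specifies. In the sketch below, `step_fn` is a hypothetical callable that would perform steps 1 to 4 on one minibatch:

```python
import time


def train_until_deadline(step_fn, max_steps=None, max_seconds=None):
    """Call step_fn(step_index) repeatedly until a deadline is hit.

    The deadline is a step budget, a wall-clock budget, or whichever
    comes first. At least one of the two must be given.
    """
    if max_steps is None and max_seconds is None:
        raise ValueError("give max_steps and/or max_seconds")
    start = time.monotonic()
    step = 0
    while True:
        if max_steps is not None and step >= max_steps:
            break
        if max_seconds is not None and time.monotonic() - start >= max_seconds:
            break
        step_fn(step)
        step += 1
    return step


# Usage with a placeholder step that only records its index.
seen = []
steps_run = train_until_deadline(seen.append, max_steps=5)
print(steps_run, seen)  # 5 [0, 1, 2, 3, 4]
```

A convergence-based criterion (for example, a validation metric that stops improving) would be another reasonable definition; it would slot in as one more `break` condition.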

## Learning Parameters used in the paper

- Trained with RMSProp 
- Learning rate = 3e-4
- Batch size = 64
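
These settings could be kept in one place as a small config object. The sketch below is framework-agnostic; the class name, field names, and the dataset size of 1000 are hypothetical. In PyTorch the optimiser line would presumably become `torch.optim.RMSprop(params, lr=3e-4)`; any other RMSProp hyperparameters are not recorded in these notes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    optimizer: str = "RMSProp"
    learning_rate: float = 3e-4
    batch_size: int = 64


def batches_per_epoch(num_samples, cfg):
    """Number of minibatches in one pass, counting a final partial batch."""
    return -(-num_samples // cfg.batch_size)  # ceiling division


cfg = TrainConfig()
print(cfg, batches_per_epoch(1000, cfg))  # 1000 is an arbitrary example size
```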

## Thoughts on VAEGAN_Basic_3 implementation

In VAEGAN class:

```
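import torch


class VaeGan:  # class name taken from the VaeGan.weighted_bce call further down
    @staticmethod
    def weighted_bce(outputs, labels):
        # keep only rows whose minimum label is not -1 (-1 appears to mark a missing label)
        mins, _ = labels.min(dim=1)
        mask = mins != -1
        criterion = torch.nn.BCELoss(reduction="sum")
        loss = criterion(torch.squeeze(outputs[mask]), labels[mask])
        # weights is a placeholder (1), so this leaves the summed loss unchanged
        weights = 1
        loss = (loss * weights).mean()
        return loss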

    
    @staticmethod
    def loss(ten_original, ten_predicted, layer_original, layer_predicted, layer_sampled, labels_original,
             labels_predicted, labels_sampled, mus, variances, aux_out, aux_out_recon, aux_labels):
        """
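        :param ten_original: original images
        :param ten_predicted: predicted images (output of the decoder)
        :param layer_original: intermediate layer for original (intermediate output of the discriminator)
        :param layer_predicted: intermediate layer for reconstructed (intermediate output of the discriminator)
        :param layer_sampled: intermediate layer for sampled (only referenced in the commented-out mse_2)
        :param labels_original: labels for original (output of the discriminator)
        :param labels_predicted: labels for reconstructed (output of the discriminator)
        :param labels_sampled: labels for sampled from gaussian (0,1) (output of the discriminator)
        :param mus: tensor of means
        :param variances: tensor of log-variances (diagonal); exponentiated in the KL term
        :param aux_out: auxiliary outputs, scored against aux_labels with weighted_bce
        :param aux_out_recon: not used in this function
        :param aux_labels: auxiliary labels; rows containing -1 are masked out in weighted_bce
        :return: tuple (nle, kl, mse_1, mse_2, bce_dis_original, bce_dis_sampled, bce_dis_recon,
                 bce_gen_sampled, bce_gen_recon, aux_loss_original)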
        """
        
        # reconstruction error, not used for the loss but useful to evaluate quality
        nle = 0.5*(ten_original.view(len(ten_original), -1) - ten_predicted.view(len(ten_predicted), -1)) ** 2
        
        # kl-divergence
        kl = -0.5 * torch.sum(-variances.exp() - torch.pow(mus,2) + variances + 1, 1)
        
        # mse between intermediate layers for both
        mse_1 = torch.sum(1.0*(layer_original - layer_predicted) ** 2, 1) / 2.
        mse_2 = torch.Tensor([0]) #torch.sum(0.5*(layer_original - layer_sampled) ** 2, 1) / 2.
        
        # bce for decoder and discriminator for original,sampled and reconstructed
        # the only excluded is the bce_gen_original
        bce_dis_original = -torch.log(labels_original + 1e-3)
        bce_dis_sampled = -torch.log(1 - labels_sampled + 1e-3)
        bce_dis_recon = -torch.log(1 - labels_predicted + 1e-3)

        bce_gen_sampled = -torch.log(labels_sampled + 1e-3)
        bce_gen_recon = -torch.log(labels_predicted + 1e-3)
        
        aux_loss_original = VaeGan.weighted_bce(aux_out, aux_labels.float())
        
        return nle, kl, mse_1, mse_2, bce_dis_original, bce_dis_sampled, bce_dis_recon, bce_gen_sampled, bce_gen_recon, aux_loss_original  

```
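
One way to read the `bce_dis_*` terms in the function above: for a single sample they are the three terms of $\mathcal{L}_{GAN}$ with the sign flipped and `1e-3` added inside each log. Minimising their sum therefore corresponds to the discriminator maximising $\mathcal{L}_{GAN}$. The sketch below checks this on scalars without PyTorch; the function names and the discriminator outputs are made up for illustration:

```python
import math

EPS = 1e-3  # same stabiliser as in the loss function above


def l_gan(d_real, d_sampled, d_recon, eps=0.0):
    """L_GAN as written in the theory section, for scalar discriminator outputs."""
    return (math.log(d_real + eps)
            + math.log(1.0 - d_sampled + eps)
            + math.log(1.0 - d_recon + eps))


def bce_dis_total(d_real, d_sampled, d_recon, eps=EPS):
    """Sum of bce_dis_original, bce_dis_sampled and bce_dis_recon for one sample."""
    return (-math.log(d_real + eps)
            - math.log(1.0 - d_sampled + eps)
            - math.log(1.0 - d_recon + eps))


d = (0.9, 0.2, 0.3)  # made-up discriminator outputs in (0, 1)
print(bce_dis_total(*d), -l_gan(*d, eps=EPS))  # the two values agree
```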

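- `nle`: per the code comment, this is the pixel-wise squared error between the original and the reconstruction (scaled by 0.5). It is not used in the loss, only to evaluate reconstruction quality. The name presumably refers to a negative log-likelihood-style error.
- The MSE of the intermediate layers is calculated (`mse_1` and `mse_2`). `mse_1` looks like $\mathcal{L}_{llike}^{Dis_l}$ rather than a plain VAE measure: with identity covariance, the Gaussian negative log-likelihood of $Dis_l(x)$ is half the squared distance to $Dis_l(\tilde x)$ plus a constant. `mse_2` is currently a constant zero.
- The gradient updates and loss calculations are done outside of the loss function, in the training script.
- Do we need to calculate the BCE? Probably yes: the `bce_dis_*` terms are the three terms of $\mathcal{L}_{GAN}$ with the sign flipped, and the `bce_gen_*` terms feed the decoder loss.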

```
# THIS IS THE MOST IMPORTANT PART OF THE CODE
loss_encoder = torch.sum(kl_value)+torch.sum(mse_value_1) # + torch.sum(mse_value_2)
loss_discriminator = torch.sum(bce_dis_original_value) + torch.sum(bce_dis_sampled_value) + torch.sum(bce_dis_predicted_value) + (lambda_aux * aux_loss)
loss_decoder = torch.sum(bce_gen_sampled_value) + torch.sum(bce_gen_predicted_value)
loss_decoder = torch.sum(lambda_mse / 1 * mse_value_1) + ((1.0 - lambda_mse) * loss_decoder)
```
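
The last line of the snippet blends the two decoder terms with `lambda_mse` instead of a paper-style $\gamma$. For $0 \le \lambda < 1$, dividing $\lambda \cdot mse + (1 - \lambda) \cdot gen$ by $(1 - \lambda)$ gives $\gamma \cdot mse + gen$ with $\gamma = \lambda / (1 - \lambda)$, so the two weightings share the same minimiser and differ only by an overall scale. Note also that `loss_decoder` is built from `bce_gen_* = -log(Dis(.))` rather than from $-\mathcal{L}_{GAN}$, which looks like the common non-saturating generator loss; that reading is an interpretation of the code, not something the script states. Helper names below are hypothetical:

```python
import math


def gamma_from_lambda(lambda_mse):
    """Equivalent gamma for a decoder loss lambda*mse + (1 - lambda)*gen."""
    if not 0.0 <= lambda_mse < 1.0:
        raise ValueError("lambda_mse must be in [0, 1)")
    return lambda_mse / (1.0 - lambda_mse)


def blended_decoder_loss(mse, gen, lambda_mse):
    """Mirrors the last line of the snippet above for scalar inputs."""
    return lambda_mse * mse + (1.0 - lambda_mse) * gen


lam, mse, gen = 0.25, 8.0, 2.0  # made-up values
gamma = gamma_from_lambda(lam)
blended = blended_decoder_loss(mse, gen, lam)
# Rescaling the gamma form by (1 - lam) recovers the blended form.
rescaled = (1.0 - lam) * (gamma * mse + gen)
print(blended, rescaled, math.isclose(blended, rescaled))
```

With `lambda_mse = 0.5` the two terms are weighted equally, which corresponds to $\gamma = 1$.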

## Links

- https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed
