## Step three: Train the autoencoder and encode input data

Again, you can use `help` to see how to use the module.

`>>> help(vamb.encode)`

    Help on module vamb.encode in vamb:

    NAME
        vamb.encode - Encode a depths matrix and a tnf matrix to latent representation.

    DESCRIPTION
        Creates a variational autoencoder in PyTorch and tries to represent the depths
        and tnf in the latent space under gaussian noise.

        usage:
        >>> vae, dataloader = trainvae(depths, tnf) # Make & train VAE on Numpy arrays
        >>> latent = vae.encode(dataloader) # Encode to latent representation
        >>> latent.shape
        (183882, 40)
        
    [ lines elided ]
    
---
Aha, so we need to use the `trainvae` function first, then the `VAE.encode` method. You can call `help` on those as well, but I won't show that here.

Training networks always takes some time. If you have a GPU and CUDA installed, you can pass `cuda=True` to `encode.trainvae` to train on your GPU for increased speed. With a beefy GPU, this can make quite a difference. I run this on my laptop, so I'll just use my CPU.

Sometimes, you'll want to reuse a VAE you have already trained. For this, I've added a `VAE.save` method and a `VAE.load` method to the VAE class. In this example, I'll write the trained model weights to a file in `/tmp` and show how to reload the VAE. But remember: a trained VAE only works on the dataset it was trained on!

In [1]:
# Again, we import stuff
import sys
sys.path.append('/home/jakni/Documents/scripts/')
import vamb

# And load the data we just saved in tutorial part 1 - of course, if this was
# the same notebook, we could have just kept it in memory
with open('/home/jakni/Downloads/example/rpkms.npz', 'rb') as file:
    rpkms = vamb.vambtools.read_npz(file)
    
with open('/home/jakni/Downloads/example/tnfs.npz', 'rb') as file:
    tnfs = vamb.vambtools.read_npz(file)

In [2]:
# I'm training just 5 epochs for this demonstration.
# When actually using the VAE, 400 epochs are suitable
with open('/tmp/model.pt', 'wb') as modelfile:
    vae, dataloader = vamb.encode.trainvae(rpkms, tnfs, nepochs=5, modelfile=modelfile, verbose=True)

	Capacity: 3000
	MSE ratio: 0.25
	CUDA: False
	N latent: 100
	N hidden: 325, 325, 325
	N contigs: 39551
	N samples: 6
	N epochs: 5
	Batch size: 128

	Epoch: 1	Loss: 0.7304685	CE: 0.2104285	MSE: 0.8073147	KLD: 0.4471146
	Epoch: 2	Loss: 0.4845564	CE: 0.1317251	MSE: 0.6136941	KLD: 0.9170021
	Epoch: 3	Loss: 0.4078121	CE: 0.1094842	MSE: 0.5295789	KLD: 1.3440626
	Epoch: 4	Loss: 0.3748378	CE: 0.1021385	MSE: 0.4709753	KLD: 1.7203019
	Epoch: 5	Loss: 0.3550885	CE: 0.0985572	MSE: 0.4275323	KLD: 2.0380203


---
The VAE encodes the high-dimensional input data (n_samples + 136 features) in a lower-dimensional space (nlatent features). During training, it learns two things: an encoding scheme, which maps each contig's input data to a series of normal distributions, and a decoding scheme, which takes one value sampled from each of these normal distributions and reconstructs the input data from that noisy latent representation.
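A minimal sketch of the sampling step in plain Python: this is the standard VAE reparameterization, where the decoder receives $z = \mu + \sigma \cdot \epsilon$ with $\epsilon \sim N(0, 1)$, while `vae.encode` (used further below) returns the noise-free means. It assumes that `logsigma` holds the natural log of the standard deviation; that is an assumption based on the name, and vamb's layers may parameterize the spread differently. The function `sample_latent` is hypothetical and not part of vamb.

```python
import math
import random

def sample_latent(mu, logsigma, rng=random):
    """Draw one value from each latent normal distribution N(mu_i, sigma_i^2).

    Reparameterization: z_i = mu_i + sigma_i * eps_i with eps_i ~ N(0, 1),
    assuming logsigma_i is the natural log of the standard deviation sigma_i.
    """
    return [m + math.exp(ls) * rng.gauss(0.0, 1.0) for m, ls in zip(mu, logsigma)]

# Toy 3-neuron latent layer with made-up numbers (not from the dataset above)
mu = [0.5, -1.2, 2.0]
logsigma = [-2.0, -2.0, -2.0]  # sigma = exp(-2), roughly 0.135
z = sample_latent(mu, logsigma, random.Random(0))
# z is a noisy version of mu; the decoder would reconstruct the input from z,
# whereas the latent representation used for clustering keeps mu itself
```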

The theory here is that the latent representation is a more efficient encoding of the input data. If the contigs' input data indeed fall into bins, an efficient encoding would be to simply encode the bin each contig belongs to, then use this "bin identity" to reconstruct the data. We force the network to encode to *distributions* rather than single values because this makes it more robust: when each encoding carries intrinsic noise, the network cannot as easily overfit by treating slightly different input values as very distinct.

### The loss function

The loss of the VAE consists of three major terms:

* Cross entropy (CE) measures the dissimilarity of the reconstructed abundances to the observed abundances. This penalizes a failure to reconstruct the abundances accurately.
* Mean squared error (MSE) measures the dissimilarity of the reconstructed versus observed TNF. This penalizes a failure to reconstruct TNF accurately.
* Kullback-Leibler divergence (KLD) measures the dissimilarity between the encoded distributions and the standard Gaussian distribution N(0, 1). This penalizes learning, since any information the encoder stores in the latent layer pulls its distributions away from N(0, 1).
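The three terms can be sketched in plain Python as below, each taken as a mean over its vector. These are textbook forms, not code taken from `encode.py`: the KLD uses the closed form for $KL(N(\mu, \sigma^2) \,\|\, N(0, 1))$ and again assumes `logsigma` is the log standard deviation. The function names are hypothetical.

```python
import math

def cross_entropy(observed, reconstructed):
    """Mean over samples of -observed_i * ln(reconstructed_i)."""
    return -sum(o * math.log(r) for o, r in zip(observed, reconstructed)) / len(observed)

def mean_squared_error(observed, reconstructed):
    """Mean over TNF features of (observed_i - reconstructed_i)^2."""
    return sum((o - r) ** 2 for o, r in zip(observed, reconstructed)) / len(observed)

def kld_to_standard_normal(mu, logsigma):
    """Mean over latent neurons of KL(N(mu_i, sigma_i^2) || N(0, 1)), sigma_i = exp(logsigma_i)."""
    terms = [0.5 * (m * m + math.exp(2 * ls) - 1 - 2 * ls) for m, ls in zip(mu, logsigma)]
    return sum(terms) / len(terms)

# Toy contig with 3 samples and 2 TNF features (made-up numbers)
print(cross_entropy([0.5, 0.25, 0.25], [0.4, 0.3, 0.3]))
print(mean_squared_error([0.1, -0.3], [0.0, 0.0]))
# First neuron equals the prior (KL 0.0), second is shifted by 1 (KL 0.5): mean 0.25
print(kld_to_standard_normal([0.0, 1.0], [0.0, 0.0]))
```

A latent neuron that exactly matches N(0, 1) costs nothing; moving its mean or shrinking its spread to carry information increases the KLD.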

All three terms are important. CE and MSE are necessary because we believe the VAE can only learn to effectively reconstruct the input if it learns to encode the signal from the input into the latent layers. In other words, these terms incentivize the network to learn something. KLD is necessary because we care that the encoding is *efficient*, viz. that it is contained in as little information as possible. The entire point of encoding is to capture the majority of the signal while shedding the noise, and this is only achieved if we place constraints on how much the network is allowed to learn. Without KLD, the network could in theory learn an arbitrarily complex encoding, and it would learn to encode both noise and signal.

In `encode.py`, the loss function is written as:

$\alpha \cdot CE + \beta \cdot MSE + \gamma \cdot KLD$

where CE, MSE and KLD are each calculated as means over the vectors for which they are defined. The constants $\alpha$, $\beta$ and $\gamma$ are subject to the following constraints:

1. Because the learning rate is fixed and optimized for a specific gradient magnitude, the total loss $\alpha \cdot CE + \beta \cdot MSE + \gamma \cdot KLD$ should sum to a constant, lest the training become unstable. In our code, we want it to sum to 1.

2. The amount of information the network can learn depends on the ratio $(\alpha \cdot CE + \beta \cdot MSE) \cdot (\gamma \cdot KLD)^{-1} = R$. Because we want our network to learn a *fixed* amount of stuff, KLD can be treated as a constant, and so we let the user define a constant `capacity` and constrain $\alpha$, $\beta$ and $\gamma$ such that $capacity = R \cdot KLD = (\alpha \cdot CE + \beta \cdot MSE) \cdot \gamma^{-1}$.

3. The relative ratio $\beta \cdot MSE \cdot (\alpha \cdot CE + \beta \cdot MSE)^{-1}$ controls the incentive to learn to reconstruct TNF as opposed to abundances. This is user defined and called `mseratio` in the code.

Now comes a problem. We want to set $\alpha$, $\beta$ and $\gamma$ such that the above equations are satisfied, but we can't know *beforehand* what CE, KLD or MSE will be. In any case, these values change across the training run.

What we do is to set $\alpha$, $\beta$ and $\gamma$ relative to what CE and MSE would be in a totally *naive* network which had *no knowledge* of the input dataset. This represents the state of the network before *any* learning is done. What would such a network predict? Because of how we normalize the input, the means are 0 for the TNF values and $n^{-1}$ for the abundances, where $n$ is the number of samples. A null model therefore predicts 0 for every TNF and the abundances as $[n^{-1}, n^{-1}, \dots, n^{-1}]^{T}$. This results in a CE of $\ln(n) \cdot n^{-1}$ and an expected MSE of 1.
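A quick numerical check of the null-model CE, assuming the CE is the mean over samples of $-y_i \ln(\hat{y}_i)$ and that each contig's observed abundances sum to 1: a uniform prediction then gives $\ln(n) \cdot n^{-1}$ regardless of what the observed abundances are. Here `n = 6` matches the example dataset; `null_cross_entropy` is a hypothetical helper.

```python
import math

def null_cross_entropy(observed):
    """Mean CE over n samples when the prediction is uniform, [1/n, ..., 1/n]."""
    n = len(observed)
    return -sum(o * math.log(1.0 / n) for o in observed) / n

n = 6  # number of samples in the example dataset
for observed in ([1 / 6] * 6, [0.5, 0.1, 0.1, 0.1, 0.1, 0.1], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]):
    # As long as the observed abundances sum to 1, the result equals ln(n) / n
    print(round(null_cross_entropy(observed), 4), round(math.log(n) / n, 4))  # 0.2986 0.2986
```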

Importantly, these values are rather close to the starting (i.e. untrained) values of a VAE. The KLD of a naive network with the static predictions above is unfortunately undefined (because such a network could have *any* values in the latent layer and still produce the given outputs). However, we noticed that for reasonable values of $\alpha$, $\beta$ and $\gamma$, $\alpha \cdot CE + \beta \cdot MSE \gg \gamma \cdot KLD$. Therefore, we modify constraint 1:

1. $\alpha \cdot CE + \beta \cdot MSE = 1$

which, substituted into constraint 2 ($capacity = (\alpha \cdot CE + \beta \cdot MSE) \cdot \gamma^{-1} = \gamma^{-1}$), implies that:

2. $\gamma = capacity^{-1}$

From the above constraints it follows that:

1. $\alpha = n \cdot (1 - mseratio) \cdot \ln(n)^{-1}$

2. $\beta = mseratio$

3. $\gamma = capacity^{-1}$
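The derivation can be checked numerically with the settings printed by `trainvae` above (6 samples, capacity 3000, MSE ratio 0.25). This is a sketch of the arithmetic only; `loss_weights` is a hypothetical helper, not necessarily how `encode.py` computes the weights.

```python
import math

def loss_weights(n_samples, capacity, mseratio):
    """Return (alpha, beta, gamma) as derived above. Requires n_samples > 1, since ln(1) = 0."""
    alpha = n_samples * (1 - mseratio) / math.log(n_samples)
    beta = mseratio
    gamma = 1 / capacity
    return alpha, beta, gamma

n, capacity, mseratio = 6, 3000, 0.25  # values printed by trainvae above
alpha, beta, gamma = loss_weights(n, capacity, mseratio)

null_ce = math.log(n) / n  # CE of the naive network
null_mse = 1.0             # expected MSE of the naive network

recon = alpha * null_ce + beta * null_mse
print(recon)                    # modified constraint 1: 1, up to floating-point rounding
print(beta * null_mse / recon)  # constraint 3: equals mseratio
print(recon / gamma)            # constraint 2: equals capacity
```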

We can see that the KL divergence rises as the network learns the dataset and the latent layer drifts away from its prior. At some point, further learning would mostly mean overfitting, and the penalty associated with the KL divergence outweighs the gains in CE and MSE. At this point, the KLD will stall and then fall. When this happens depends on `capacity` and on the complexity of the dataset.
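To spot when the KLD stalls without eyeballing a long log, the verbose output can be parsed. This sketch assumes the `Epoch: ... KLD: ...` line format shown above and is not a vamb feature; the helpers `kld_per_epoch` and `first_kld_stall` are hypothetical. On the 5-epoch run above, the KLD is still rising, so no stall is found.

```python
import re

# Matches lines like "Epoch: 1  Loss: 0.73  CE: 0.21  MSE: 0.81  KLD: 0.45"
EPOCH_LINE = re.compile(r"Epoch:\s*(\d+).*KLD:\s*([-+0-9.eE]+)")

def kld_per_epoch(log_text):
    """Return a list of (epoch, KLD) pairs parsed from the verbose training log."""
    rows = []
    for line in log_text.splitlines():
        match = EPOCH_LINE.search(line)
        if match:
            rows.append((int(match.group(1)), float(match.group(2))))
    return rows

def first_kld_stall(rows):
    """Return the first epoch whose KLD did not increase over the previous epoch, or None."""
    for (_, previous), (epoch, current) in zip(rows, rows[1:]):
        if current <= previous:
            return epoch
    return None

# The log from the 5-epoch run above (tabs replaced by spaces)
log_text = """
Epoch: 1  Loss: 0.7304685  CE: 0.2104285  MSE: 0.8073147  KLD: 0.4471146
Epoch: 2  Loss: 0.4845564  CE: 0.1317251  MSE: 0.6136941  KLD: 0.9170021
Epoch: 3  Loss: 0.4078121  CE: 0.1094842  MSE: 0.5295789  KLD: 1.3440626
Epoch: 4  Loss: 0.3748378  CE: 0.1021385  MSE: 0.4709753  KLD: 1.7203019
Epoch: 5  Loss: 0.3550885  CE: 0.0985572  MSE: 0.4275323  KLD: 2.0380203
"""

rows = kld_per_epoch(log_text)
print(first_kld_stall(rows))  # None: the KLD rose in every epoch of this short run
```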

Okay, so now we have the trained `vae` and the `dataloader`. Let's feed the dataloader to the VAE in order to get the latent representation:

---

In [3]:
# No need to pass cuda=True to the encode method to encode on GPU
# If you trained the VAE on GPU, it already resides there
latent = vae.encode(dataloader)

print(latent.shape)

(39551, 40)


---
That's 39551 contigs each represented by the (non-noisy) value of 40 latent neurons.

Now we need to cluster this. That's for the next notebook, so again, I'll save the results.

---

In [4]:
with open('/home/jakni/Downloads/example/latent.npz', 'wb') as file:
    vamb.vambtools.write_npz(file, latent)

---
Alright, let me show how to load the trained VAE given the model file we made above.

I want to **show** that we get back the same network that we trained, so let's feed the same data to both the original and the reloaded VAE and compare the outputs.

---

In [5]:
import torch

# Manually create the first mini-batch without randomization
rpkms_in = torch.Tensor(rpkms[:128]).reshape((128, -1))
tnfs_in = torch.Tensor(tnfs[:128]).reshape((128, -1))

In [6]:
# Calling the VAE as a function encodes and decodes the arguments,
# returning the outputs and the two distribution layers
depths_out, tnf_out, mu, logsigma = vae(rpkms_in, tnfs_in)
print(mu[0])

tensor([-0.6341,  0.8006, -1.2837,  2.2479,  3.0842,  2.0184,  0.3096,
        -0.7320,  1.7008, -0.8898,  2.0501,  0.4636, -0.8683,  0.4024,
        -1.2859,  0.3301,  0.8071, -1.0957,  0.4424, -1.1223, -0.7120,
        -2.1790, -1.9727, -0.8413, -2.6715, -1.0463, -1.6019, -1.8441,
         1.7171,  0.3378, -0.8309,  1.1683,  2.1508, -2.0515, -0.3983,
        -0.4869, -1.5248, -1.6428,  2.3076, -1.4227])


In [8]:
# Now, delete the VAE
del vae

# And reload it:
# We need to manually specify whether it should use the GPU or not,
# and whether the network should begin in training or evaluation mode.
vae = vamb.encode.VAE.load('/tmp/model.pt', cuda=False, evaluate=True)
depths_out, tnf_out, mu, logsigma = vae(rpkms_in, tnfs_in)
print(mu[0])

tensor([-0.6341,  0.8006, -1.2837,  2.2479,  3.0842,  2.0184,  0.3096,
        -0.7320,  1.7008, -0.8898,  2.0501,  0.4636, -0.8683,  0.4024,
        -1.2859,  0.3301,  0.8071, -1.0957,  0.4424, -1.1223, -0.7120,
        -2.1790, -1.9727, -0.8413, -2.6715, -1.0463, -1.6019, -1.8441,
         1.7171,  0.3378, -0.8309,  1.1683,  2.1508, -2.0515, -0.3983,
        -0.4869, -1.5248, -1.6428,  2.3076, -1.4227])


---
We get the same values back, meaning the saved network is the same as the loaded network!