## Variational Recurrent Network (VRNN)

Implementation based on Chung et al., *A Recurrent Latent Variable Model for Sequential Data* [arXiv:1506.02216v6].

### 1. Network design

There are three types of layers: input ($x$), hidden ($h$) and latent ($z$). Comparing the VRNN side by side with an RNN shows how it works in the generation phase:

- RNN: $h_0 + x_0 \to h_1 + x_1 \to h_2 + x_2 \to \cdots$
- VRNN: with $h_0 \left\{
\begin{array}{l}
      h_0 \to z_1 \\
      z_1 + h_0 \to x_1 \\
      z_1 + x_1 + h_0 \to h_1
\end{array}
\right.$
with $h_1 \left\{
\begin{array}{l}
      h_1 \to z_2 \\
      z_2 + h_1 \to x_2 \\
      z_2 + x_2 + h_1 \to h_2
\end{array}
\right.$

The code blocks below make this clearer. This loop is used to generate new text once the network is trained. Here $x$ is the desired output, $h$ is the deterministic hidden state, and $z$ is the latent (stochastic hidden) state. Both $h$ and $z$ change over time.
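
A minimal, framework-free sketch of this generation loop follows. It assumes small hypothetical sizes (`H`, `Z`, `V`), random untrained weights, and a single tanh layer in place of the GRU used in the notebook, so it only illustrates the order of the three steps per time step, not a trained model.

```python
import numpy as np

# Hypothetical sizes; the notebook below uses 64-d h and z and 128 characters.
H, Z, V = 8, 4, 16
rng = np.random.default_rng(0)

# Random, untrained weights standing in for the learned networks (hypothetical).
W_prior = rng.normal(size=(H, 2 * Z)) * 0.1     # h_{t-1} -> (mu, log_sigma) of z_t
W_dec = rng.normal(size=(H + Z, V)) * 0.1       # [h_{t-1}, z_t] -> logits of x_t
W_emb = rng.normal(size=(V, H)) * 0.1           # x_t -> feature vector
W_rnn = rng.normal(size=(2 * H + Z, H)) * 0.1   # [x_t feature, z_t, h_{t-1}] -> h_t

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def generate(n_steps):
    h = np.zeros(H)                      # h_0
    xs = []
    for _ in range(n_steps):
        # 1. h_{t-1} -> z_t: sample from the prior N(mu, sigma^2)
        stats = h @ W_prior
        mu, log_sigma = stats[:Z], stats[Z:]
        z = mu + np.exp(log_sigma) * rng.standard_normal(Z)
        # 2. z_t + h_{t-1} -> x_t: sample a symbol from the decoder's softmax
        logits = np.concatenate([h, z]) @ W_dec
        x = rng.choice(V, p=softmax(logits))
        xs.append(int(x))
        # 3. z_t + x_t + h_{t-1} -> h_t (tanh layer standing in for the GRU)
        h = np.tanh(np.concatenate([W_emb[x], z, h]) @ W_rnn)
    return xs

print(generate(10))
```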

### 2. Training

The VRNN above contains three components: a prior network that generates the latent state ($h_0 \to z_1$), a decoder that produces $x_1$, and a recurrent network that computes $h_1$ for the next step.

The training objective is to make sure the generated $x_1$ is realistic. To do that, an encoder is added that maps $x_1 + h_0 \to z_1$; the decoder should then map $z_1 + h_0 \to x_1$ back correctly. This reconstruction term is a cross-entropy loss for character-level text such as tiny Shakespeare, or an MSE loss for image reconstruction.
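
The two reconstruction losses can be sketched in NumPy as follows; the function names are illustrative, and the notebook itself uses `nn.functional.cross_entropy`.

```python
import numpy as np

def cross_entropy_from_logits(logits, targets):
    """Mean cross-entropy for integer class targets (e.g. character codes)."""
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class in each row.
    return -log_probs[np.arange(len(targets)), targets].mean()

def mse(x_reconstruct, x):
    """Mean squared error, the usual choice for real-valued data such as pixels."""
    return ((x_reconstruct - x) ** 2).mean()

# All-zero logits over 128 characters: every character is equally likely.
print(cross_entropy_from_logits(np.zeros((3, 128)), np.array([12, 13, 14])))
```

With uniform logits the loss is $\log 128 \approx 4.85$ nats per character. For comparison, the epoch-0 decoder loss in the training log below (1464.96 summed over 300 time steps) is about 4.88 per step.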

The remaining loose end is the prior $h_0 \to z_1$. When $x_1$ is drawn from the data, the encoder's distribution $x_1 + h_0 \to z_1$ should on average match the prior $h_0 \to z_1$, which is all that is available at generation time. This constraint is formulated as a KL divergence between the two.

>#### KL Divergence of Multivariate Normal Distribution
>![](https://wikimedia.org/api/rest_v1/media/math/render/svg/8dad333d8c5fc46358036ced5ab8e5d22bae708c)
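
For diagonal covariances, as used in this notebook, the formula reduces to a sum over the $k$ latent dimensions. Writing $q = \mathcal{N}(\mu_q, \operatorname{diag}\sigma_q^2)$ for the encoder distribution and $p = \mathcal{N}(\mu_p, \operatorname{diag}\sigma_p^2)$ for the prior, and parameterizing by $\log\sigma$ as the network outputs do:

$$
\begin{aligned}
D_{KL}(q \,\|\, p) &= \frac{1}{2}\sum_{j=1}^{k}\left[\frac{\sigma_{q,j}^2}{\sigma_{p,j}^2} + \frac{(\mu_{q,j}-\mu_{p,j})^2}{\sigma_{p,j}^2} - 1 + \ln\frac{\sigma_{p,j}^2}{\sigma_{q,j}^2}\right] \\
&= \frac{1}{2}\sum_{j=1}^{k}\left[e^{2(\log\sigma_{q,j}-\log\sigma_{p,j})} + \left(\frac{\mu_{q,j}-\mu_{p,j}}{e^{\log\sigma_{p,j}}}\right)^2 - 2(\log\sigma_{q,j}-\log\sigma_{p,j}) - 1\right]
\end{aligned}
$$

The second line, averaged over the batch, is the expression computed as `loss2` in `calculate_loss` below.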

Now putting everything together for one training cycle.

$\left\{
\begin{array}{l}
      h_0 \to z_{1,prior} \\
      x_1 + h_0 \to z_{1,infer} \\
      z_1 \sim \mathcal{N}(z_{1,infer}) \\
      z_1 + h_0 \to x_{1,reconstruct} \\
      z_1 + x_1 + h_0 \to h_1
\end{array}
\right.
\Rightarrow
\left\{
\begin{array}{l}
      loss\_latent = D_{KL}(z_{1,infer} \,\|\, z_{1,prior}) \\
      loss\_reconstruct = \mathrm{CrossEntropy}(x_{1,reconstruct},\, x_1)
\end{array}
\right.$
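
One training cycle can be sketched in PyTorch with random tensors standing in for the network outputs (the sizes `B`, `Z`, `V`, the decoder matrix `W_dec` and the targets are hypothetical). It samples $z_1$ with the reparameterization trick, computes `loss_latent` in closed form, cross-checks it against `torch.distributions.kl_divergence`, and adds the cross-entropy reconstruction term.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
B, Z, V = 3, 4, 16   # hypothetical batch, latent and vocabulary sizes

# Stand-ins for network outputs: (mu, log_sigma) of the encoder and the prior.
mu_infer, log_sigma_infer = torch.randn(B, Z), 0.1 * torch.randn(B, Z)
mu_prior, log_sigma_prior = torch.randn(B, Z), 0.1 * torch.randn(B, Z)

# z_1 ~ N(z_{1,infer}) via reparameterization: z = mu + sigma * eps,
# so gradients can flow through mu_infer and log_sigma_infer.
eps = torch.randn(B, Z)
z = mu_infer + log_sigma_infer.exp() * eps

# loss_latent: closed-form KL(q || p), summed over z dims, averaged over the batch.
kl = 0.5 * ((2 * (log_sigma_infer - log_sigma_prior)).exp()
            + ((mu_infer - mu_prior) / log_sigma_prior.exp()) ** 2
            - 2 * (log_sigma_infer - log_sigma_prior) - 1).sum(dim=1).mean()

# Cross-check against PyTorch's registered Normal/Normal KL.
q = Normal(mu_infer, log_sigma_infer.exp())
p = Normal(mu_prior, log_sigma_prior.exp())
kl_ref = kl_divergence(q, p).sum(dim=1).mean()

# loss_reconstruct: cross-entropy between decoder logits and the true symbols.
W_dec = torch.randn(Z, V)            # hypothetical stand-in for the decoder
logits = z @ W_dec
targets = torch.tensor([1, 5, 7])
loss_reconstruct = torch.nn.functional.cross_entropy(logits, targets)

loss = loss_reconstruct + kl
print(kl.item(), kl_ref.item(), loss.item())
```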


In [6]:
import torch
from torch import nn, optim
from torch.autograd import Variable  # a no-op wrapper since PyTorch 0.4; kept because later cells may use it

class VRNNCell(nn.Module):
    """One VRNN time step over ASCII characters (128 symbols), with 64-d x features, z and h."""
    def __init__(self):
        super(VRNNCell, self).__init__()
        self.phi_x = nn.Sequential(nn.Embedding(128, 64), nn.Linear(64, 64), nn.ELU())  # features of x
        self.encoder = nn.Linear(128, 64*2)  # [phi_x(x), h] -> (mu, log_sigma) of q(z | x, h)
        self.phi_z = nn.Sequential(nn.Linear(64, 64), nn.ELU())  # features of z
        self.decoder = nn.Linear(128, 128)   # [h, phi_z(z)] -> logits over 128 characters
        self.prior = nn.Linear(64, 64*2)     # h -> (mu, log_sigma) of p(z | h)
        self.rnn = nn.GRUCell(128, 64)       # input [phi_x(x), phi_z(z)], state h

    def forward(self, x, hidden):
        x = self.phi_x(x)
        # 1. h => z (prior)
        z_prior = self.prior(hidden)
        # 2. x + h => z (encoder / approximate posterior)
        z_infer = self.encoder(torch.cat([x, hidden], dim=1))
        # sampling with the reparameterization trick: z = mu + sigma * eps
        z = torch.randn(x.size(0), 64) * z_infer[:, 64:].exp() + z_infer[:, :64]
        z = self.phi_z(z)
        # 3. h + z => x (decoder logits)
        x_out = self.decoder(torch.cat([hidden, z], dim=1))
        # 4. x + z + h => next h
        hidden_next = self.rnn(torch.cat([x, z], dim=1), hidden)
        return x_out, hidden_next, z_prior, z_infer

    def calculate_loss(self, x, hidden):
        x_out, hidden_next, z_prior, z_infer = self.forward(x, hidden)
        # 1. reconstruction loss: cross-entropy between logits and the input characters
        loss1 = nn.functional.cross_entropy(x_out, x)
        # 2. KL divergence between the diagonal Gaussians q(z | x, h) and p(z | h)
        mu_infer, log_sigma_infer = z_infer[:, :64], z_infer[:, 64:]
        mu_prior, log_sigma_prior = z_prior[:, :64], z_prior[:, 64:]
        loss2 = (2*(log_sigma_infer - log_sigma_prior)).exp() \
                + ((mu_infer - mu_prior)/log_sigma_prior.exp())**2 \
                - 2*(log_sigma_infer - log_sigma_prior) - 1
        loss2 = 0.5*loss2.sum(dim=1).mean()  # sum over latent dims, mean over the batch
        return loss1, loss2, hidden_next

    def generate(self, hidden=None, temperature=None):
        if hidden is None:
            hidden = torch.zeros(1, 64)
        if temperature is None:
            temperature = 0.8
        # 1. h => z, sampled from the prior
        z_prior = self.prior(hidden)
        z = torch.randn(z_prior.size(0), 64) * z_prior[:, 64:].exp() + z_prior[:, :64]
        z = self.phi_z(z)
        # 2. h + z => x
        x_out = self.decoder(torch.cat([hidden, z], dim=1))
        # sample one character per row; squeeze(1) keeps the batch dimension (shape [1]),
        # whereas squeeze() would give a 0-d tensor and break torch.cat(..., dim=1) below
        x_sample = torch.softmax(x_out / temperature, dim=1).multinomial(1).squeeze(1)
        x = self.phi_x(x_sample)
        # 3. x + z + h => next h
        hidden_next = self.rnn(torch.cat([x, z], dim=1), hidden)
        return x_sample, hidden_next

    def generate_text(self, hidden=None, temperature=None, n=100):
        res = []
        with torch.no_grad():  # no gradients are needed while sampling
            for _ in range(n):
                x_sample, hidden = self.generate(hidden, temperature)
                res.append(chr(x_sample[0].item()))
        return "".join(res)

# Test
net = VRNNCell()
x = torch.LongTensor([12, 13, 14])
hidden = torch.rand(3, 64)
output, hidden_next, z_prior, z_infer = net(x, hidden)
loss1, loss2, _ = net.calculate_loss(x, hidden)
print(loss1.item(), loss2.item())
print(net.generate_text(hidden=torch.zeros(1, 64)))
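
In recent PyTorch versions, calling `.squeeze()` on the `[1, 1]` tensor returned by `multinomial(1)` removes the batch dimension as well and leaves a 0-d tensor. The embedding then returns a 1-D `[64]` vector, which cannot be concatenated with `z` along `dim=1`; this is the source of an `IndexError` like the one produced by the original version of `generate`. The sketch below (random logits, untrained embedding) shows the shapes; `.squeeze(1)` keeps the batch dimension.

```python
import torch
from torch import nn

emb = nn.Embedding(128, 64)
logits = torch.randn(1, 128)              # same shape as x_out in generate()

sample = torch.softmax(logits, dim=1).multinomial(1)   # shape [1, 1]
bad = sample.squeeze()                    # 0-d tensor: both size-1 dims removed
good = sample.squeeze(1)                  # shape [1]: batch dimension kept

print(emb(bad).shape)    # torch.Size([64])
print(emb(good).shape)   # torch.Size([1, 64])
```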

## Download the tiny Shakespeare text

In [3]:
from urllib import request  # standard library in Python 3 (replaces six.moves.urllib)
url = "https://raw.githubusercontent.com/jcjohnson/torch-rnn/master/data/tiny-shakespeare.txt"
text = request.urlopen(url).read().decode()

print('-----SAMPLE----\n')
print(text[:100])

-----SAMPLE----

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


### A convenient function to sample training batches

In [4]:
import numpy as np

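def batch_generator(seq_size=300, batch_size=64):
    """Yield lists of seq_size LongTensors of shape [batch_size]; step t holds character t of each window."""
    cap = len(text) - seq_size  # exclusive upper bound on start indices, so every window fits in the text
    while True:
        idx = np.random.randint(0, cap, batch_size)  # random window start positions
        res = []
        for _ in range(seq_size):
            # ord() gives the character code; tiny Shakespeare is plain ASCII, matching Embedding(128, 64)
            batch = torch.LongTensor([ord(text[i]) for i in idx])
            res.append(batch)
            idx += 1  # advance every window by one character
        yield res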

g = batch_generator()
batch = next(g)
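
The generator yields each batch in time-major order: a list of `seq_size` entries, where entry `t` holds character `t` of every window, which is why the training loop can step through it with `for x in batch`. A toy illustration of that layout, assuming a hypothetical `toy_text` in place of the corpus and plain lists instead of tensors:

```python
import numpy as np

toy_text = "abcdefghijklmnopqrstuvwxyz" * 4   # hypothetical stand-in for the corpus
seq_size, batch_size = 5, 3

rng = np.random.default_rng(0)
idx = rng.integers(0, len(toy_text) - seq_size, batch_size)   # window start positions

# steps[t][b] is character t of window b: a list of seq_size time steps,
# each holding one character per sequence in the batch.
steps = [[toy_text[i + t] for i in idx] for t in range(seq_size)]

# Reading down the time axis recovers each contiguous window.
windows = ["".join(steps[t][b] for t in range(seq_size)) for b in range(batch_size)]
print(windows)
```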

## Model Training

In [8]:
net = VRNNCell()
max_epoch = 2000
optimizer = optim.Adam(net.parameters(), lr=0.001)
g = batch_generator()

hidden = torch.zeros(64, 64)  # batch_size x hidden_size
for epoch in range(max_epoch):
    batch = next(g)
    loss_seq = 0
    loss1_seq, loss2_seq = 0, 0
    optimizer.zero_grad()
    for x in batch:
        loss1, loss2, hidden = net.calculate_loss(x, hidden)
        loss1_seq += loss1.item()  # running totals, for logging only
        loss2_seq += loss2.item()
        loss_seq = loss_seq + loss1 + loss2
    loss_seq.backward()
    optimizer.step()
    # keep the hidden values for the next batch but cut the graph (truncated backpropagation)
    hidden = hidden.detach()
    if epoch % 100 == 0:
        print('>> epoch {}, loss {:12.4f}, decoder loss {:12.4f}, latent loss {:12.4f}'.format(epoch, loss_seq.item(), loss1_seq, loss2_seq))
        print(net.generate_text())
        print()
        

>> epoch 0, loss    3769.2000, decoder loss    1464.9570, latent loss    2304.2424
]]ZA}~.m\#1Z5r>UK!{nxm	/ d9TB/A wstZ"y;FB9Evh;

>> epoch 100, loss    1007.5905, decoder loss     984.4189, latent loss      23.1717
 rt  cos
 u  Btutrs,lh treotiu  ri rihrsrathtgc apr kr heoeeewoset  nal'e  niai uochatyoe dec te.es

>> epoch 200, loss     930.7681, decoder loss     911.5832, latent loss      19.1846
scd.ORSo plm py ld e moencu arsh iiae eyuio li nntnir twr,tt h le cewee un  nmo lhcri r we utocee  e

>> epoch 300, loss     795.6132, decoder loss     778.8964, latent loss      16.7170
zUxitrel so bo sat om hho he tos otans thtos
Oere, I
AI unnt fey.C Yon arl nilil mithe,

BNL:
TYor, 

>> epoch 400, loss     739.1819, decoder loss     724.5921, latent loss      14.5900
:

ASESSE:
Tive theaveress aw thal he care hees the anl ouues spid at erucerte ho ere es the wat sde

>> epoch 500, loss     705.3881, decoder loss     693.6972, latent loss      11.6912
pzit wivy thes ciun keit u

## Evaluation

In [12]:
sample = net.generate_text(n=1000, temperature=1)
print(sample)

`omak, wha lating
To thing matheds now:
Your, fich's mad pother you with thouss the deedh! goust I, hest, seably the were thee co, preatt goor his mat start pean the poose not 'ere, as and for that I great a cring wer.

KINO KINGBRAV:
Bese retuble not whirs,
With my heake! who at his yeoth.

Sist starl'd sullancen'd and bece breour there things.
Sconte to ctret.

PRINGER:
OL RUMERTE RIRI IP LARIENIZ:
Beiolt, you to Mripching a will inting,
And the me thou read onaidion
And king a's for old somee thee for speak eim'p calf
The live eavert stish
Tis conhal of my wairggred most swexferous frome.

VINGER:
Not you lay my disge,
We not: the rueselly with it hightens my, will an my foochorr me
but hash proied our nir is how, woul malay with lethantolt and is inge:
Had thy monk-tich hap,
Thimbrisuegetreve, like tous accounce; the were on and trust thoy if peeccon.

COMEON:
Yet a peave. Preathed that in soned; what shave nongle.

RICHENRIUS:
Forther,
And that the be thy chill with wogen thighter

## Comments

- Definitely train longer to get better results.
- Keep in mind that the recurrent part is a single GRU layer with only 64 hidden units.
- There seems to be no need to tune the temperature here: temperature = 0.8 produces a lot of odd spellings, while temperature = 1 works fine.
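
For reference, the temperature in `generate` divides the decoder logits before sampling. A small sketch with hypothetical logits for three characters shows the mechanism: temperatures below 1 concentrate probability on the most likely character, and temperatures above 1 flatten the distribution.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = np.asarray(logits, dtype=float) / temperature
    e = np.exp(scaled - scaled.max())   # subtract the max for numerical stability
    return e / e.sum()

logits = [2.0, 1.0, 0.0]   # hypothetical decoder logits for three characters
for t in (0.5, 0.8, 1.0, 2.0):
    print(t, softmax_with_temperature(logits, t).round(3))
```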