In [24]:
import torch
from torch import nn, optim
from torch.autograd import Variable

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class VRNN(nn.Module):
    """Single-step variational recurrent cell over integer symbol ids.

    Each call to ``forward`` consumes one batch of symbol ids plus the current
    hidden state and produces logits for that symbol, the next hidden state and
    the raw prior / inference Gaussian parameters.

    The prior and encoder layers both emit ``2 * enc_param_dim`` values per
    row: the first half is used as the mean and the second half as the log
    standard deviation (it is passed through ``exp`` before scaling the noise).

    Args:
        input_dim: size of the symbol vocabulary (embedding rows / logits).
        hidden_dim: size of the GRU hidden state.
        embed_dim: size of the symbol embedding.
        enc_param_dim: size of the latent variable ``z``.
    """

    def __init__(self, input_dim=128, hidden_dim=64, embed_dim=64, enc_param_dim=64):
        super(VRNN, self).__init__()

        self.latent_dim = enc_param_dim
        self.hidden_dim = hidden_dim

        # Feature extractor for the observed symbol.
        self.phi_x = nn.Sequential(
            nn.Embedding(input_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.ELU()
        )

        # Inference network: (features of x, hidden) -> [mu, log_sigma].
        self.encoder = nn.Linear(hidden_dim + embed_dim,
                                 enc_param_dim + enc_param_dim)

        # Feature extractor for the sampled latent.
        self.phi_z = nn.Sequential(
            nn.Linear(enc_param_dim, enc_param_dim),
            nn.ELU()
        )

        # Decoder: (hidden, features of z) -> logits over symbols.
        self.decoder = nn.Linear(hidden_dim + enc_param_dim,
                                 input_dim)

        # Prior network: hidden -> [mu, log_sigma].
        self.prior = nn.Linear(hidden_dim,
                               enc_param_dim + enc_param_dim)

        self.rnn = nn.GRUCell(embed_dim + enc_param_dim,
                              hidden_dim)

    def _split_params(self, params):
        """Split a ``(batch, 2 * latent_dim)`` tensor into (mu, log_sigma)."""
        return params[:, :self.latent_dim], params[:, self.latent_dim:]

    def _sample_latent(self, mu, log_sigma):
        """Reparameterized sample ``mu + eps * exp(log_sigma)``.

        The noise is created with the same shape, dtype and device as ``mu``,
        so no global device or hard-coded latent size is needed.
        """
        return mu + torch.randn_like(mu) * log_sigma.exp()

    def forward(self, x, hidden):
        """Run one time step.

        Args:
            x: integer tensor of symbol ids, one per batch row.
            hidden: float tensor ``(batch, hidden_dim)``.

        Returns:
            Tuple ``(x_out, hidden_next, z_prior, z_infer)`` where ``x_out``
            holds the decoder logits and ``z_prior`` / ``z_infer`` are the raw
            ``[mu, log_sigma]`` outputs of the prior and encoder layers.
        """
        x = self.phi_x(x)

        z_prior = self.prior(hidden)

        z_infer = self.encoder(torch.cat([x, hidden], dim=1))
        z_mu, z_log_sd = self._split_params(z_infer)

        # Reparameterized sample drawn from the inference distribution.
        z = self.phi_z(self._sample_latent(z_mu, z_log_sd))

        x_out = self.decoder(torch.cat([hidden, z], dim=1))

        hidden_next = self.rnn(torch.cat([x, z], dim=1), hidden)

        return x_out, hidden_next, z_prior, z_infer

    def calculate_loss(self, x, hidden):
        """Compute the per-step losses for a batch of symbol ids.

        Returns:
            Tuple ``(loss1, loss2, hidden_next)``: the cross-entropy between
            the decoder logits and ``x``, the batch-mean KL divergence between
            the inference and prior diagonal Gaussians, and the next hidden
            state.
        """
        x_out, hidden_next, z_prior, z_infer = self.forward(x, hidden)

        # 1. logistic regression loss
        loss1 = nn.functional.cross_entropy(x_out, x)

        # 2. KL divergence between diagonal Gaussians, KL(infer || prior).
        #    Slice with latent_dim rather than a literal so any
        #    enc_param_dim works.
        mu_infer, log_sigma_infer = self._split_params(z_infer)
        mu_prior, log_sigma_prior = self._split_params(z_prior)

        log_ratio = log_sigma_infer - log_sigma_prior
        loss2 = (2 * log_ratio).exp() \
                + ((mu_infer - mu_prior) / log_sigma_prior.exp()) ** 2 \
                - 2 * log_ratio - 1

        loss2 = 0.5 * loss2.sum(dim=1).mean()

        return loss1, loss2, hidden_next

    def generate(self, hidden=None, temperature=None):
        """Sample one symbol from the prior and advance the hidden state.

        Args:
            hidden: ``(batch, hidden_dim)`` state; defaults to a single row of
                zeros created with the dtype/device of the model parameters.
            temperature: divisor applied to the logits before sampling;
                defaults to 0.8.

        Returns:
            Tuple ``(x_sample, hidden_next)``. ``x_sample`` is squeezed, so it
            is a 0-d tensor when the batch has a single row.
        """
        if hidden is None:
            reference = next(self.parameters())
            hidden = reference.new_zeros((1, self.hidden_dim))
        if temperature is None:
            temperature = 0.8

        z_mu, z_log_sd = self._split_params(self.prior(hidden))
        z = self.phi_z(self._sample_latent(z_mu, z_log_sd))

        x_out = self.decoder(torch.cat([hidden, z], dim=1))

        # softmax yields the same sampling distribution as exp() followed by
        # multinomial's implicit normalisation, but cannot overflow.
        probs = nn.functional.softmax(x_out.div(temperature), dim=1)
        x_idx = probs.multinomial(1).squeeze(1)

        x = self.phi_x(x_idx)

        # Keep the batch dimension so this also works for more than one row.
        hidden_next = self.rnn(torch.cat([x, z], dim=1), hidden)

        return x_idx.squeeze(), hidden_next

    def generate_text(self, temperature=None, n=100):
        """Sample ``n`` symbols from a zero hidden state and join them via ``chr``."""
        res = []
        hidden = None
        # Sampling needs no gradients; skip building autograd graphs.
        with torch.no_grad():
            for _ in range(n):
                x_sample, hidden = self.generate(hidden, temperature)
                res.append(chr(x_sample.item()))
        return "".join(res)

In [25]:
# Smoke test: build the model and run each entry point once.
net = VRNN().to(device)

x = torch.LongTensor([12, 13, 14]).to(device)
hidden = torch.rand(3, net.hidden_dim).to(device)

# forward() returns (x_out, hidden_next, z_prior, z_infer) -- unpack in that order.
output, hidden_next, z_prior, z_infer = net(x, hidden)

loss1, loss2, _ = net.calculate_loss(x, hidden)

net.generate_text()

'Jni}\nn`je\x03L\x16#\x04\x05\\!^^\'\x18CDX\x14jSNQgs!Q>!Qzn\x1dS3H7\x175\x0c8wQ*+#EtcPy\rWTM\x7f\x0f8\x18rh.1h,2k\x1180#wO"Lz\x02\r`iSIf]\x14WP&h\x0f69\x17}'

In [28]:
import numpy as np
from six.moves.urllib import request

url = "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt"
# Download the training corpus and make sure the HTTP response is closed even
# if reading or decoding fails.
response = request.urlopen(url)
try:
    text = response.read().decode()
finally:
    response.close()

def batch_generator(seq_size=300, batch_size=64):
    """Endlessly yield random training windows from the module-level ``text``.

    Each yielded item is a list of ``seq_size`` LongTensors on ``device``.
    Element ``t`` of the list holds, for every one of the ``batch_size``
    randomly chosen start offsets, the code point of the character ``t``
    positions after that offset.

    Code points are passed through unchanged, so the consumer's vocabulary
    must be large enough for the largest one in ``text``.

    Raises:
        ValueError: if ``text`` is shorter than ``seq_size``.
    """
    # A start offset only needs room for one window of seq_size characters;
    # the bound does not depend on batch_size.
    cap = len(text) - seq_size + 1
    if cap <= 0:
        raise ValueError("text has {} characters, need at least {}".format(len(text), seq_size))

    # Encode the corpus once instead of calling ord() per character per step.
    codes = np.array([ord(ch) for ch in text], dtype=np.int64)

    while True:
        starts = np.random.randint(0, cap, batch_size)
        yield [torch.from_numpy(codes[starts + step]).to(device)
               for step in range(seq_size)]

In [29]:
# Training configuration. Each "epoch" below is one optimizer step on one
# randomly sampled batch of sequences.
max_epoch = 2000
hidden_dim = net.hidden_dim  # keep the carried state in sync with the model
batch_size = 64
hidden = torch.zeros(batch_size, hidden_dim).to(device)

optimizer = optim.Adam(net.parameters(), lr=0.001)
# Pass batch_size explicitly so the batches always match `hidden`'s first dim.
data_gen = batch_generator(batch_size=batch_size)

for epoch in range(max_epoch):
    batch = next(data_gen)
    loss_seq = 0
    loss1_seq, loss2_seq = 0, 0
    optimizer.zero_grad()

    # Accumulate the per-step losses over the whole sequence, then do a single
    # backward pass through all time steps.
    for x in batch:
        loss1, loss2, hidden = net.calculate_loss(x, hidden)
        loss1_seq += loss1.item()
        loss2_seq += loss2.item()
        loss_seq = loss_seq + loss1 + loss2
    loss_seq.backward()
    optimizer.step()
    # Carry the state's values into the next step but cut the autograd graph,
    # so the next backward pass does not reach into this (already freed) one.
    hidden = hidden.detach()

    if epoch%100==0:
        print('Epoch - {} : Loss = {}, Decoder_Loss = {}, Latent_Loss = {} \n'.format(epoch,
                                                                                      loss_seq.item(),
                                                                                      loss1_seq,
                                                                                      loss2_seq))
        print(net.generate_text())
        print('---'*20)

Epoch - 0 : Loss = 3617.65332031, Decoder_Loss = 1463.41964531, Latent_Loss = 2154.23372746 

+[1U^ s!FzV2{~o*],c|_uU'2(B7JRvzY6lSa>ytxZ1VB>Xa
D5L]xNalb<)zoA$<Ye|#M5sM
------------------------------------------------------------
Epoch - 100 : Loss = 1007.88867188, Decoder_Loss = 989.24618578, Latent_Loss = 18.6429289468 

6ut, adpuyneo caroldm  , s 
 o   te
moi  ooh ehwtite to ,an ryteir f 
d,  atoa,eoe ednireu  to di e
------------------------------------------------------------
Epoch - 200 : Loss = 984.95135498, Decoder_Loss = 972.042239666, Latent_Loss = 12.9090968557 

BKc.Aupet g ed,  feer aolv  mearF
lnftcnsh dtn
e  hoa  a
esoetnuo  
  t tetaat et na rl Khaah   e .e
------------------------------------------------------------
Epoch - 300 : Loss = 889.150817871, Decoder_Loss = 874.83273077, Latent_Loss = 14.3177695833 

lbiulpcuead aa ae wh inolotI
ETaa mthe ohdsHotdrre m; t lo;R

TIIBea tifhen,

D
Bmt
C
YSWoouaod iobi
--------------------------------------