In [None]:
# default_exp g_models.lstm_lm

# LSTM Language Model

> LSTM-based language model

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export

from mrl.imports import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *

In [None]:
# export

class LSTM_LM(nn.Module):
    '''
    LSTM_LM - language model that wraps an `LSTM_Block` and adds
    sampling and log-probability helpers

    Inputs:

        `d_vocab`, `d_embedding`, `d_hidden`, `n_layers`, `lstm_drop`,
        `lin_drop`, `bidir` - forwarded to `LSTM_Block`. `d_embedding` is
        passed twice (second and fourth positional argument)

        `bos_idx` - token index used to seed every sequence in `sample`

        `tie_weights` - if True, `block.embedding.weight` is replaced by
        `block.head.weight` so both layers share a single parameter
    '''
    def __init__(self, d_vocab, d_embedding, d_hidden, n_layers,
                 lstm_drop=0., lin_drop=0., bos_idx=0, bidir=False, tie_weights=False):
        super().__init__()

        self.block = LSTM_Block(d_vocab, d_embedding, d_hidden, d_embedding, n_layers,
                                lstm_drop=lstm_drop, lin_drop=lin_drop, bidir=bidir)
        self.bos_idx = bos_idx

        if tie_weights:
            # assumes both weights have the same shape - TODO confirm against LSTM_Block
            self.block.embedding.weight = self.block.head.weight

    def _run_block(self, x, hiddens=None):
        '''
        Runs `self.block` and returns only `(output, hiddens)`.

        The block's result is indexed rather than unpacked so that any
        additional trailing items it returns are ignored. This keeps
        `forward`, `sample` and `get_lps` in agreement with each other.
        '''
        result = self.block(x, hiddens)
        return result[0], result[1]

    def forward(self, x, hiddens=None):
        '''
        Returns the first item produced by the block for token indices `x`.
        `hiddens` is passed through to the block unchanged.
        '''
        output, _ = self._run_block(x, hiddens)
        return output

    def sample(self, bs, sl, temperature=1., multinomial=True):
        '''
        Generates `bs` sequences by running `sl` single-token steps.

        Every sequence starts from `bos_idx`. At each step the block output
        is divided by `temperature` and handed to `x_to_preds`, whose
        returned indices are fed back in as the next input.

        Returns `(preds, lps)`: the generated indices with the leading
        `bos_idx` column removed, and the per-step second outputs of
        `x_to_preds` (presumably log-probabilities) concatenated along the
        last dimension.
        '''
        preds = idxs = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
        lps = []

        hiddens = None

        for _ in range(sl):
            x, hiddens = self._run_block(idxs, hiddens)
            # out-of-place so the tensor returned by the block is not mutated
            x = x / temperature

            idxs, lp = x_to_preds(x, multinomial=multinomial)

            lps.append(lp)
            preds = torch.cat([preds, idxs], -1)

        # drop the seed column so only generated tokens are returned
        return preds[:, 1:], torch.cat(lps, -1)

    def sample_no_grad(self, bs, sl, temperature=1., multinomial=True):
        '''
        Same as `sample`, executed under `torch.no_grad()`.
        '''
        with torch.no_grad():
            return self.sample(bs, sl, temperature=temperature, multinomial=multinomial)

    def get_lps(self, x, y, temperature=1.):
        '''
        Scores target indices `y` given input indices `x`.

        Returns `(lps, logits)`: `lps` holds the log-softmax value selected
        by `y` along dimension 2 (same shape as `y`), `logits` is the
        temperature-scaled model output.
        '''
        logits = self.forward(x) / temperature

        lps = F.log_softmax(logits, -1)
        lps = lps.gather(2, y.unsqueeze(-1)).squeeze(-1)

        return lps, logits

In [None]:
# Smoke test: build a small model and exercise forward, get_lps and sample
lm = LSTM_LM(32, 64, 256, 2)
ints = torch.randint(0, 31, (16, 10))
# next-token setup: inputs drop the last token, targets drop the first
x = ints[:,:-1]
y = ints[:,1:]
out = lm(x)
lp,_ = lm.get_lps(x,y)
# gather returns one value per target index, and log_softmax values are never positive
assert lp.shape == y.shape
assert (lp <= 0).all()
preds, sample_lps = lm.sample(8, 10)
# concatenation along the last dimension leaves the batch dimension untouched
assert preds.shape[0] == 8

In [None]:
# export

class Conditional_LSTM_LM(Encoder_Decoder):
    '''
    Conditional_LSTM_LM - encoder-decoder language model whose decoder is a
    `Conditional_LSTM_Block` conditioned on a latent vector

    Inputs:

        `encoder` - module mapping a condition to a latent vector. Must
        expose a `d_latent` attribute (read when building the default prior
        and when sampling without a prior)

        `d_vocab`, `d_embedding`, `d_hidden`, `d_latent`, `n_layers`,
        `lstm_drop`, `lin_drop`, `condition_hidden`, `condition_output` -
        forwarded to `Conditional_LSTM_Block`. `d_latent` is also used to
        build the `Norm_Transition`

        `bidir` - accepted but currently not used by this class

        `bos_idx` - token index used to seed every sequence in `sample`

        `prior` - latent prior. If None, a non-trainable `SphericalPrior`
        built from two zero vectors of size `encoder.d_latent` is used
    '''
    def __init__(self, encoder, d_vocab, d_embedding, d_hidden, d_latent, n_layers,
                 lstm_drop=0., lin_drop=0., bidir=False,
                 condition_hidden=True, condition_output=False, bos_idx=0, prior=None):

        transition = Norm_Transition(d_latent)

        # NOTE(review): `bidir` is not forwarded to the decoder - confirm whether
        # `Conditional_LSTM_Block` accepts it before wiring it through
        decoder = Conditional_LSTM_Block(d_vocab, d_embedding, d_hidden, d_embedding,
                                d_latent, n_layers, lstm_drop=lstm_drop, lin_drop=lin_drop,
                                condition_hidden=condition_hidden, condition_output=condition_output)

        if prior is None:
            prior = SphericalPrior(torch.zeros((encoder.d_latent)), torch.zeros((encoder.d_latent)),
                                trainable=False)

        super().__init__(encoder, decoder, transition, prior)

        self.bos_idx = bos_idx

    def forward(self, x, condition, hiddens=None):
        '''
        Encodes `condition` to a latent vector, applies the transition and
        decodes token indices `x` conditioned on it. Returns the first item
        produced by the decoder.
        '''
        z = self.encoder(condition)
        z = self.transition(z)
        x, hiddens = self.decoder(x, z, hiddens)
        return x

    def sample(self, bs, sl, z=None, temperature=1., multinomial=True):
        '''
        Generates sequences by running `sl` single-token decoder steps.

        If `z` is None, `bs` latent vectors are drawn from `self.prior`, or,
        when no prior is set, from a standard normal passed through the
        transition. If `z` is given, `bs` is overridden by `z.shape[0]`.

        Returns `(preds, lps)`: the generated indices with the leading
        `bos_idx` column removed, and the per-step second outputs of
        `x_to_preds` (presumably log-probabilities) concatenated along the
        last dimension.
        '''
        if z is None:
            if self.prior is not None:
                z = to_device(self.prior.rsample([bs]))
            else:
                z = to_device(torch.randn((bs, self.encoder.d_latent)))
                z = self.transition(z)
        else:
            bs = z.shape[0]

        preds = idxs = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
        lps = []

        # initial recurrent state is derived from the latent vector
        hiddens = self.decoder.lstm.latent_to_hidden(z)

        for _ in range(sl):

            x, hiddens = self.decoder(idxs, z, hiddens)
            # out-of-place so the tensor returned by the decoder is not mutated
            x = x / temperature

            idxs, lp = x_to_preds(x, multinomial=multinomial)

            lps.append(lp)
            preds = torch.cat([preds, idxs], -1)

        # drop the seed column so only generated tokens are returned
        return preds[:, 1:], torch.cat(lps, -1)

    def sample_no_grad(self, bs, sl, z=None, temperature=1., multinomial=True):
        '''
        Same as `sample`, executed under `torch.no_grad()`.
        '''
        with torch.no_grad():
            return self.sample(bs, sl, z=z, temperature=temperature, multinomial=multinomial)

    def get_lps(self, x, y, temperature=1.):
        '''
        Scores target indices `y`. `x` is a pair `(tokens, condition)`.

        Returns `(lps, logits)`: `lps` holds the log-softmax value selected
        by `y` along dimension 2, `logits` is the temperature-scaled decoder
        output. When a trainable prior is set, a zero-valued term that
        carries the gradient of the prior log-probability of `z` is added
        to `lps`.
        '''
        tokens, condition = x
        z = self.transition(self.encoder(condition))
        logits, _ = self.decoder(tokens, z)

        logits = logits / temperature

        lps = F.log_softmax(logits, -1)
        lps = lps.gather(2, y.unsqueeze(-1)).squeeze(-1)

        # `sample` allows `self.prior` to be None, so guard before reading `.trainable`
        if self.prior is not None and self.prior.trainable:
            prior_lps = self.prior.log_prob(z).mean(-1, keepdim=True)
            # numerically zero, but keeps the gradient path through `prior_lps`
            prior_lps = prior_lps - prior_lps.detach()
            lps = lps + prior_lps

        return lps, logits

    def set_prior_from_latent(self, z, logvar, trainable=False):
        '''
        Replaces `self.prior` with a `SphericalPrior` built from detached
        copies of `z` and `logvar`.
        '''
        z = z.detach()
        logvar = logvar.detach()
        self.prior = SphericalPrior(z, logvar, trainable)

    def set_prior_from_encoder(self, condition, logvar, trainable=False):
        '''
        Encodes a single condition (batch size must be 1) and uses the
        resulting latent vector to set the prior.
        '''
        assert condition.shape[0]==1
        z = self.transition(self.encoder(condition))
        z = z.squeeze(0)
        self.set_prior_from_latent(z, logvar, trainable)

In [None]:
# Smoke test: build a small conditional model and exercise forward, get_lps and sample
encoder = MLP_Encoder(128, [64, 32], 16, [0.1, 0.1])

lm = Conditional_LSTM_LM(encoder, 32, 64, 128, 16, 2)

ints = torch.randint(0, 31, (8, 10))
# next-token setup: inputs drop the last token, targets drop the first
x = ints[:,:-1]
y = ints[:,1:]

condition = torch.randn((8,128))

_ = lm(x, condition)

lps, _ = lm.get_lps([x,condition],y)
# gather returns one value per target index; the default prior is not trainable,
# so these are plain log_softmax values and never positive
assert lps.shape == y.shape
assert (lps <= 0).all()

preds, sample_lps = lm.sample(3, 80)
# concatenation along the last dimension leaves the batch dimension untouched
assert preds.shape[0] == 3

In [None]:
# hide
# nbdev export step: regenerates the library modules from the notebooks
from nbdev.export import notebook2script; notebook2script()