In [None]:
# default_exp layers
# all_slow

# Layers

> PyTorch model layers

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export
from mrl.imports import *
from mrl.torch_imports import *
from mrl.torch_core import *

In [None]:
# export

class Linear(nn.Module):
    """Fully connected block: `nn.Linear` followed by optional batch norm, ReLU and dropout.

    Args:
        d_in: input feature size.
        d_out: output feature size.
        act: when truthy, a `nn.ReLU` is appended.
        bn: when truthy, a `nn.BatchNorm1d(d_out)` is inserted right after the linear layer.
        dropout: dropout probability; a `nn.Dropout` is appended only when it is > 0.
        **lin_kwargs: forwarded to `nn.Linear`.
    """
    def __init__(self, d_in, d_out, act=True, bn=False, dropout=0., **lin_kwargs):
        super().__init__()

        # (enabled, factory) pairs listed in the order the sub-layers are applied.
        candidates = (
            (True, lambda: nn.Linear(d_in, d_out, **lin_kwargs)),
            (bn, lambda: nn.BatchNorm1d(d_out)),
            (act, lambda: nn.ReLU()),
            (dropout > 0., lambda: nn.Dropout(p=dropout)),
        )
        self.layers = nn.Sequential(*[build() for enabled, build in candidates if enabled])

    def forward(self, x):
        """Apply the stacked sub-layers to `x`."""
        out = self.layers(x)
        return out
    
class Conv(nn.Module):
    """Convolution block: an N-d convolution followed by optional batch norm, ReLU and dropout.

    Args:
        d_in: number of input channels.
        d_out: number of output channels.
        ks: kernel size, either an int or one int per spatial dimension.
        stride: convolution stride.
        padding: explicit padding; when None it is derived from `ks` as `(k-1)//2`
            per dimension.
        ndim: number of spatial dimensions; must be 1, 2 or 3.
        act: when truthy, a `nn.ReLU` is appended.
        bn: when truthy, the matching `nn.BatchNorm{ndim}d(d_out)` follows the convolution.
        dropout: dropout probability; a `nn.Dropout` is appended only when it is > 0.
        **conv_kwargs: forwarded to the convolution constructor.

    Raises:
        ValueError: if `ndim` is not 1, 2 or 3.
    """
    def __init__(self, d_in, d_out, ks=3, stride=1, padding=None, ndim=2, 
                 act=True, bn=False, dropout=0., **conv_kwargs):
        super().__init__()
        
        if padding is None:
            # An int kernel gets a single padding value; a tuple/list kernel gets
            # one padding value per dimension (the int formula would raise on a tuple).
            if isinstance(ks, (tuple, list)):
                padding = tuple((k-1)//2 for k in ks)
            else:
                padding = (ks-1)//2
            
        layer_types = {
            1: (nn.Conv1d, nn.BatchNorm1d),
            2: (nn.Conv2d, nn.BatchNorm2d),
            3: (nn.Conv3d, nn.BatchNorm3d),
        }
        if ndim not in layer_types:
            # Reject unsupported values instead of silently building a 3-d convolution.
            raise ValueError(f'ndim must be 1, 2 or 3, got {ndim!r}')
        conv_func, bn_func = layer_types[ndim]
        
        layers = [conv_func(d_in, d_out, ks, stride, padding=padding, **conv_kwargs)]
        
        if bn:
            layers.append(bn_func(d_out))
            
        if act:
            layers.append(nn.ReLU())
            
        if dropout>0.:
            layers.append(nn.Dropout(p=dropout))
            
        self.layers = nn.Sequential(*layers)
        
    def forward(self, x):
        """Apply the stacked sub-layers to `x`."""
        return self.layers(x)
    
class Conv1d(Conv):
    """`Conv` block fixed to one spatial dimension (`ndim=1`); every other argument is passed through."""
    def __init__(self, d_in, d_out, ks=3, stride=1, padding=None, 
                 act=True, bn=False, dropout=0., **conv_kwargs):
        block_opts = dict(act=act, bn=bn, dropout=dropout)
        super().__init__(d_in, d_out, ks=ks, stride=stride, padding=padding,
                         ndim=1, **block_opts, **conv_kwargs)
        
class Conv2d(Conv):
    """`Conv` block fixed to two spatial dimensions (`ndim=2`); every other argument is passed through."""
    def __init__(self, d_in, d_out, ks=3, stride=1, padding=None, 
                 act=True, bn=False, dropout=0., **conv_kwargs):
        block_opts = dict(act=act, bn=bn, dropout=dropout)
        super().__init__(d_in, d_out, ks=ks, stride=stride, padding=padding,
                         ndim=2, **block_opts, **conv_kwargs)
        
class Conv3d(Conv):
    """`Conv` block fixed to three spatial dimensions (`ndim=3`); every other argument is passed through."""
    def __init__(self, d_in, d_out, ks=3, stride=1, padding=None, 
                 act=True, bn=False, dropout=0., **conv_kwargs):
        block_opts = dict(act=act, bn=bn, dropout=dropout)
        super().__init__(d_in, d_out, ks=ks, stride=stride, padding=padding,
                         ndim=3, **block_opts, **conv_kwargs)

In [None]:
# Smoke test: a Linear block with batch norm and dropout runs on a (16, 128) batch.
layer = Linear(128, 64, bn=True, dropout=0.5)
_ = layer(torch.randn(16,128))

In [None]:
# export

class SphericalDistribution(torch.distributions.Distribution):
    """Samples from `Normal(loc, scale)` rescaled to unit L2 norm along the last dimension.

    Args:
        loc: mean of the underlying normal distribution.
        scale: standard deviation of the underlying normal distribution.
        validate_args: forwarded to `torch.distributions.Distribution`.
    """
    # No constraints are declared on the constructor arguments.
    arg_constraints = {}
    has_rsample = True

    def __init__(self, loc, scale, validate_args=False):
        # The last dimension is the one that gets normalised, so it is the event
        # dimension; any leading dimensions are batch dimensions. The base class
        # expects shapes here, not a bare int.
        super().__init__(batch_shape=loc.shape[:-1], event_shape=loc.shape[-1:],
                         validate_args=validate_args)
        self.dim = loc.shape[0]
        self.loc = loc
        self.scale = scale
        self.dist = Normal(self.loc, self.scale)
        
    def sample(self, n=torch.Size()):
        """Draw samples of shape `n + loc.shape` (no gradient) and normalise them."""
        s = self.dist.sample(n)
        s = F.normalize(s, p=2, dim=-1)
        return s
    
    def rsample(self, n=torch.Size()):
        """Reparameterised version of `sample`; gradients flow to `loc` and `scale`."""
        s = self.dist.rsample(n)
        s = F.normalize(s, p=2, dim=-1)
        return s
    
    def __repr__(self):
        return f'Spherical(loc: {self.loc.size()}, scale: {self.scale.size()})'

class Prior(nn.Module):
    """Base class for priors.

    Subclasses implement `get_dist` (returning an object with `sample` and
    `rsample`) and `log_prob`; both raise `NotImplementedError` here.
    """
    def __init__(self):
        super().__init__()
        
    def get_dist(self):
        raise NotImplementedError
        
    def log_prob(self, x):
        raise NotImplementedError
    
    def sample(self, n):
        """Sample from `get_dist()`; a plain int `n` is wrapped into `[n]`."""
        shape = [n] if type(n)==int else n
        return self.get_dist().sample(shape)
    
    def rsample(self, n):
        """Same as `sample`, but calls the distribution's `rsample`."""
        shape = [n] if type(n)==int else n
        return self.get_dist().rsample(shape)
    
    
class NormalPrior(Prior):
    """Normal prior parameterised by `loc` and `log_scale` (the log of the standard deviation).

    Args:
        loc: mean tensor.
        log_scale: log standard deviation tensor.
        trainable: when truthy, both tensors are wrapped in `nn.Parameter`;
            otherwise they are stored as given.
    """
    def __init__(self, loc, log_scale, trainable=True):
        super().__init__()
        self.trainable = trainable
        self.loc = nn.Parameter(loc) if trainable else loc
        self.log_scale = nn.Parameter(log_scale) if trainable else log_scale
        
    def get_dist(self):
        """Build `Normal(loc, exp(log_scale))` from the current parameter values."""
        scale = self.log_scale.exp()
        return Normal(self.loc, scale)
    
    def log_prob(self, x):
        """Element-wise normal log-density of `x`."""
        scale = self.log_scale.exp()
        sq_err = (x - self.loc) ** 2
        log_norm_const = math.log(math.sqrt(2 * math.pi))
        return -sq_err / (2 * scale.pow(2)) - self.log_scale - log_norm_const
    
class SphericalPrior(NormalPrior):
    """`NormalPrior` whose `get_dist` returns a `SphericalDistribution` instead of a `Normal`."""
    def __init__(self, loc, log_scale, trainable=True):
        super().__init__(loc=loc, log_scale=log_scale, trainable=trainable)
        
    def get_dist(self):
        """Build the spherical distribution from the current `loc` and `exp(log_scale)`."""
        scale = self.log_scale.exp()
        return SphericalDistribution(self.loc, scale)

In [None]:
# With trainable parameters, rsample must carry gradients while sample must not.
p = NormalPrior(torch.zeros((64,)), torch.zeros((64,)), trainable=True)
assert p.rsample(5).requires_grad
assert not p.sample(5).requires_grad

# Same check for the spherical prior.
p = SphericalPrior(torch.zeros((2,)), torch.zeros((2,)), trainable=True)
assert p.rsample(5).requires_grad
assert not p.sample(5).requires_grad

In [None]:
# export

class Conditional_LSTM(nn.Module):
    """Stack of single-layer `nn.LSTM`s that can be conditioned on a latent vector `z`.

    Two independent conditioning modes are available:
      * `condition_output`: `z` is repeated along the sequence axis and
        concatenated to the input of the first layer.
      * `condition_hidden`: one linear layer per LSTM layer maps `z` to that
        layer's initial hidden and cell states.

    Args:
        d_embedding: feature size of the input sequence.
        d_hidden: output size (summed over directions) of every layer but the last.
        d_output: output size (summed over directions) of the last layer.
        d_latent: size of the latent vector `z`.
        n_layers: number of stacked LSTM layers.
        condition_hidden: enable latent-derived initial states.
        condition_output: enable concatenation of `z` to the input.
        bidir: build bidirectional layers.
        dropout: dropout probability applied to the output of every layer
            except the last.
        batch_first: whether inputs are `(batch, seq, feature)`.
    """
    def __init__(self, d_embedding, d_hidden, d_output, d_latent, n_layers,
                 condition_hidden=True, condition_output=True,
                 bidir=False, dropout=0., batch_first=True):
        super().__init__()
        
        self.d_embedding = d_embedding
        self.d_hidden = d_hidden
        self.d_output = d_output
        self.n_layers = n_layers
        self.bidir = bidir
        self.n_dir = 1 if not bidir else 2
        self.batch_first = batch_first
        self.condition_hidden = condition_hidden
        self.condition_output = condition_output
        
        lstms = []
        self.hidden_sizes = []
        
        for l in range(n_layers):
            if l==0:
                input_size = d_embedding if not self.condition_output else d_embedding+d_latent
            else:
                input_size = d_hidden
                
            output_size = d_output if l==n_layers-1 else d_hidden
            # Per-direction size, so both directions together give the requested width.
            output_size = output_size // self.n_dir
            
            hidden_size = (self.n_dir, 1, output_size)
            self.hidden_sizes.append(hidden_size)
            
            # Every layer is its own single-layer nn.LSTM. nn.LSTM's `dropout`
            # argument only acts between its internal layers, so with
            # num_layers=1 it does nothing but emit a warning; inter-layer
            # dropout is applied explicitly in `forward` instead.
            lstm = nn.LSTM(input_size, output_size, 1, batch_first=batch_first, 
                           bidirectional=bidir)
            lstms.append(lstm)
            
        self.lstms = nn.ModuleList(lstms)
        self.layer_drop = nn.Dropout(dropout)
        
        if self.condition_hidden:
            to_hidden = []
            for size in self.hidden_sizes:
                ndir, _, dim = size
                # One projection yields both the hidden and the cell state (hence *2).
                to_hidden.append(nn.Linear(d_latent, ndir*dim*2))
                
            self.to_hidden = nn.ModuleList(to_hidden)
        
    def forward(self, x, z, hiddens=None):
        """Run the stack.

        Args:
            x: input sequence.
            z: latent vector; required when `condition_output` is set, or when
                `condition_hidden` is set and `hiddens` is None.
            hiddens: optional list with one `(h, c)` pair per layer. When None,
                initial states come from `z` (`condition_hidden`) or are zeros.

        Returns:
            `(output, new_hiddens)` where `new_hiddens` holds one detached
            `(h, c)` pair per layer.
        """
        
        bs = x.shape[0] if self.batch_first else x.shape[1]
        sl = x.shape[1] if self.batch_first else x.shape[0]
        
        if self.condition_output:
            if self.batch_first:
                z_ = z.unsqueeze(1).repeat(1,sl,1)
            else:
                z_ = z.unsqueeze(0).repeat(sl,1,1)
                
            x = torch.cat([x, z_], -1)

        if hiddens is None:
            if self.condition_hidden:
                hiddens = self.latent_to_hidden(z)
                
            else:
                hiddens = self.get_new_hidden(bs)
            
            hiddens = to_device(hiddens, x.device)
            
        new_hiddens = []
        last = len(self.lstms) - 1
        for i, lstm in enumerate(self.lstms):
            x, (h,c) = lstm(x, hiddens[i])
            if i < last:
                # Dropout on the output of every layer except the last.
                x = self.layer_drop(x)
            # States are detached before being handed back to the caller.
            new_hiddens.append((h.detach(), c.detach()))
            
        return x, new_hiddens
    
    def latent_to_hidden(self, z):
        """Map `z` of shape `(bs, d_latent)` to one `(h, c)` pair per layer, each `(n_dir, bs, size)`."""
        hiddens = []
        for layer in self.to_hidden:
            h = layer(z)
            h,c = torch.chunk(h, 2, dim=-1)
            bs, _ = h.shape
            h = h.contiguous().reshape(bs, self.n_dir, -1).permute(1,0,2)
            c = c.contiguous().reshape(bs, self.n_dir, -1).permute(1,0,2)
            hiddens.append((h,c))
            
        return hiddens
            
    def get_new_hidden(self, bs):
        """Return all-zero `(h, c)` pairs for a batch of size `bs`, one per layer."""
        hiddens = []
        for hs in self.hidden_sizes:
            h = torch.zeros(hs).repeat(1,bs,1)
            c = torch.zeros(hs).repeat(1,bs,1)
            hiddens.append((h,c))
        
        return hiddens
    
    def mixup_hiddens(self, hiddens):
        """Shuffle every layer's `(h, c)` along the batch axis (dim 1).

        A fresh permutation is drawn per layer and applied to both `h` and `c`.
        """
        new_hiddens = []
        for item in hiddens:
            h,c = item
            # Build the index on the same device as the states it indexes.
            shuffle = torch.randperm(h.shape[1], device=h.device)
            h = h[:,shuffle]
            c = c[:,shuffle]
            new_hiddens.append((h,c))
        return new_hiddens
    
class LSTM(Conditional_LSTM):
    """`Conditional_LSTM` with both conditioning modes switched off and a latent size of 0."""
    def __init__(self, d_embedding, d_hidden, d_output, n_layers, 
                 bidir=False, dropout=0., batch_first=True):
        unconditioned = dict(condition_hidden=False, condition_output=False)
        super().__init__(d_embedding, d_hidden, d_output, 0, n_layers,
                         bidir=bidir, dropout=dropout, batch_first=batch_first,
                         **unconditioned)

    def forward(self, x, hiddens=None):
        """Run the stack without a latent vector; returns `(output, new_hiddens)`."""
        return super().forward(x, None, hiddens)

In [None]:
# Smoke test: build Conditional_LSTM with different combinations of the
# conditioning flags and directionality, then run each with and without
# explicitly supplied initial states.
d_embedding=64
d_hidden=128
d_latent = 32
n_layers = 2

l1 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                     condition_hidden=True, condition_output=True, 
                     bidir=False, batch_first=True)

l2 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                     condition_hidden=True, condition_output=True, 
                     bidir=True, batch_first=True)

l3 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                     condition_hidden=False, condition_output=True, 
                     bidir=False, batch_first=True)

l4 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                     condition_hidden=True, condition_output=False, 
                     bidir=False, batch_first=True)

l5 = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
                     condition_hidden=False, condition_output=False, 
                     bidir=True, batch_first=True)

bs = 12
x = torch.randn((bs, 21, d_embedding))
z = torch.randn((bs, d_latent))

_ = l1(x,z)
_ = l1(x,z, l1.latent_to_hidden(z))

_ = l2(x,z)
_ = l2(x,z, l2.latent_to_hidden(z))

_ = l3(x,z)
_ = l3(x,z, l3.get_new_hidden(bs))

_ = l4(x,z)
_ = l4(x,z, l4.get_new_hidden(bs))

# l5 uses neither conditioning mode, so it also runs with z=None.
_ = l5(x,z)
_ = l5(x,None)
_ = l5(x,z, l5.get_new_hidden(bs))

In [None]:
# Smoke test: unconditioned LSTM, unidirectional and bidirectional, with and
# without explicitly supplied initial states (reuses x and bs from the cell above).
l1 = LSTM(d_embedding, d_hidden, d_embedding, n_layers, bidir=False, batch_first=True)

l2 = LSTM(d_embedding, d_hidden, d_embedding, n_layers, bidir=True, batch_first=True)

_ = l1(x)
_ = l1(x, l1.get_new_hidden(bs))

_ = l2(x)
_ = l2(x, l2.get_new_hidden(bs))

In [None]:
# export

class Conditional_LSTM_Block(nn.Module):
    """Embedding -> `Conditional_LSTM` -> linear head producing `d_vocab` scores per position.

    Args:
        d_vocab: vocabulary size (embedding rows and head outputs).
        d_embedding: embedding size.
        d_hidden: hidden size handed to `Conditional_LSTM`.
        d_output: LSTM output size, which is also the head's input size.
        d_latent: latent vector size.
        n_layers: number of LSTM layers.
        lstm_drop: dropout argument handed to `Conditional_LSTM`.
        lin_drop: dropout argument handed to the `Linear` head.
        bidir: bidirectional flag handed to `Conditional_LSTM`.
        condition_hidden, condition_output: conditioning flags handed to `Conditional_LSTM`.
    """
    def __init__(self, d_vocab, d_embedding, d_hidden, d_output, d_latent, n_layers,
                 lstm_drop=0., lin_drop=0., bidir=False,
                 condition_hidden=True, condition_output=False):
        super().__init__()
        
        self.embedding = nn.Embedding(d_vocab, d_embedding)

        lstm_opts = dict(condition_hidden=condition_hidden,
                         condition_output=condition_output,
                         bidir=bidir, dropout=lstm_drop)
        self.lstm = Conditional_LSTM(d_embedding, d_hidden, d_output, d_latent, n_layers,
                                     **lstm_opts)
        
        # Head without activation or batch norm.
        self.head = Linear(d_output, d_vocab, act=False, bn=False, dropout=lin_drop)
        
    def forward(self, x, z, hiddens=None):
        """Return `(scores, hiddens)` for token indices `x` conditioned on `z`."""
        embedded = self.embedding(x)
        features, hiddens = self.lstm(embedded, z, hiddens)
        return self.head(features), hiddens

class LSTM_Block(nn.Module):
    """Embedding -> `LSTM` -> linear head followed by dropout.

    Args:
        d_vocab: vocabulary size (embedding rows and head outputs).
        d_embedding: embedding size.
        d_hidden: hidden size handed to `LSTM`.
        d_output: LSTM output size, which is also the head's input size.
        n_layers: number of LSTM layers.
        lstm_drop: dropout argument handed to `LSTM`.
        lin_drop: probability of the dropout applied after the head.
        bidir: bidirectional flag handed to `LSTM`.
    """
    def __init__(self, d_vocab, d_embedding, d_hidden, d_output, n_layers,
                 lstm_drop=0., lin_drop=0., bidir=False):
        super().__init__()
        
        self.embedding = nn.Embedding(d_vocab, d_embedding)
        self.lstm = LSTM(d_embedding, d_hidden, d_output, n_layers,
                         dropout=lstm_drop, bidir=bidir)
        
        self.head = nn.Linear(d_output, d_vocab)
        self.head_drop = nn.Dropout(lin_drop)
        
    def forward(self, x, hiddens=None):
        """Return `(scores, hiddens)` for token indices `x`."""
        embedded = self.embedding(x)
        features, hiddens = self.lstm(embedded, hiddens)
        scores = self.head(features)
        return self.head_drop(scores), hiddens


In [None]:
# export

class LSTM_LM(nn.Module):
    """Language model built on `LSTM_Block`, with the LSTM output size equal to `d_embedding`.

    Args:
        d_vocab, d_embedding, d_hidden, n_layers: sizes handed to `LSTM_Block`.
        lstm_drop, lin_drop: dropout arguments handed to `LSTM_Block`.
        bos_idx: token index used to start sampling.
        bidir: bidirectional flag handed to `LSTM_Block`.
        tie_weights: when truthy, the embedding weight is replaced by the head weight.
    """
    def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, 
                 lstm_drop=0., lin_drop=0., bos_idx=0, bidir=False, tie_weights=False):
        super().__init__()
        
        block_opts = dict(lstm_drop=lstm_drop, lin_drop=lin_drop, bidir=bidir)
        self.block = LSTM_Block(d_vocab, d_embedding, d_hidden, d_embedding, n_layers,
                                **block_opts)
        self.bos_idx = bos_idx
        
        if tie_weights:
            self.block.embedding.weight = self.block.head.weight
        
    def forward(self, x, hiddens=None):
        """Return the block's scores for `x`; the updated hidden states are discarded."""
        scores, _ = self.block(x, hiddens)
        return scores
    
    def sample(self, bs, sl, temperature=1., multinomial=True):
        """Generate `sl` tokens for `bs` sequences, one step at a time.

        Each step feeds the previous step's indices (initially `bos_idx`) and the
        carried hidden states to the block, divides the scores by `temperature`
        and passes them to `x_to_preds` (defined outside this class).

        Returns:
            The concatenated per-step indices (without the initial `bos_idx`
            column) and the concatenated second outputs of `x_to_preds`.
        """
        
        bos = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
        sequence, step_input = bos, bos
        step_lps = []
        state = None
        
        for _ in range(sl):
            scores, state = self.block(step_input, state)
            scores.div_(temperature)
            
            step_input, lp = x_to_preds(scores, multinomial=multinomial)
            
            step_lps.append(lp)
            sequence = torch.cat([sequence, step_input], -1)
            
        return sequence[:, 1:], torch.cat(step_lps,-1)
    
    def sample_no_grad(self, bs, sl, temperature=1., multinomial=True):
        """`sample` executed under `torch.no_grad()`."""
        with torch.no_grad():
            result = self.sample(bs, sl, temperature=temperature, multinomial=multinomial)
        return result
        
    def get_lps(self, x, y, temperature=1.):
        """Log-softmax of the temperature-scaled scores for `x`, gathered at the indices in `y`."""
        scores = self.forward(x)
        scores.div_(temperature)
        
        log_probs = F.log_softmax(scores, -1)
        return log_probs.gather(2, y.unsqueeze(-1)).squeeze(-1)

In [None]:
# Smoke test: forward pass, log-prob lookup and sampling on random token ids.
lm = LSTM_LM(32, 64, 256, 2)
ints = torch.randint(0, 31, (16, 10))
# Inputs are all tokens but the last; targets are the same tokens shifted by one.
x = ints[:,:-1]
y = ints[:,1:]
out = lm(x)
lp = lm.get_lps(x,y)
_ = lm.sample(8, 10)

In [None]:
# export
class Encoder(nn.Module):
    """Base class for encoders; stores the latent size as `d_latent`."""
    def __init__(self, d_latent):
        super(Encoder, self).__init__()
        setattr(self, 'd_latent', d_latent)

class LSTM_Encoder(Encoder):
    """Encode a token sequence into a latent vector with a bidirectional LSTM.

    The final hidden and cell states of the last LSTM layer (both directions)
    are concatenated and projected to `d_latent`.

    Args:
        d_vocab: vocabulary size.
        d_embedding: embedding size.
        d_hidden: LSTM width (summed over both directions) for every layer.
        n_layers: number of LSTM layers.
        d_latent: size of the returned latent vector.
        dropout: dropout probability applied to the output of every LSTM layer
            except the last.
    """
    def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, d_latent, dropout=0.):
        super().__init__(d_latent)
        self.embedding = nn.Embedding(d_vocab, d_embedding)
        self.lstm = LSTM(d_embedding, d_hidden, d_hidden, n_layers, 
                                 bidir=True, batch_first=True, dropout=dropout)
        self.layer_drop = nn.Dropout(dropout)
        self.head = nn.Linear(d_hidden*2, d_latent)
        
    def forward(self, x):
        """Return the latent vector for token indices `x` (batch first)."""
        x = self.embedding(x)

        # `self.lstm.forward` returns detached (h, c) pairs, which would block
        # every gradient from the latent back to the LSTM and embedding weights.
        # The layers are therefore run directly here so the final states keep
        # their autograd history. With no initial state given, nn.LSTM starts
        # from zeros.
        layers = self.lstm.lstms
        last = len(layers) - 1
        for i, lstm in enumerate(layers):
            x, (h, c) = lstm(x)
            if i < last:
                x = self.layer_drop(x)

        # (n_dir, bs, size) hidden and cell states of the last layer are joined
        # on the feature axis, then the two directions are joined as well.
        hidden = torch.cat(list(torch.cat((h, c), -1)), -1)
        latent = self.head(hidden)
        return latent
    
class MLP_Encoder(Encoder):
    """Encoder made of `Linear` blocks (ReLU + batch norm + dropout) and a final `nn.Linear` to `d_latent`.

    Args:
        d_in: input feature size.
        dims: list of hidden layer sizes.
        d_latent: size of the returned latent vector.
        dropouts: dropout probability per hidden layer; pairing stops at the
            shorter of `dropouts` and the list of layer transitions.
    """
    def __init__(self, d_in, dims, d_latent, dropouts):
        super().__init__(d_latent)
        
        sizes = [d_in]+dims
        
        blocks = []
        for n_in, n_out, p in zip(sizes[:-1], sizes[1:], dropouts):
            blocks.append(Linear(n_in, n_out, act=True, bn=True, dropout=p))
        blocks.append(nn.Linear(sizes[-1], d_latent))
        
        self.layers = nn.Sequential(*blocks)
        
    def forward(self, x):
        """Return the latent vector for `x`."""
        return self.layers(x)
    
    
class Conv_Encoder(Encoder):
    """Encoder: embedding -> stack of `Conv1d` blocks -> average pool over the sequence -> linear head.

    Args:
        d_vocab: vocabulary size.
        d_embedding: embedding size, which is also the first convolution's input channels.
        d_latent: size of the returned latent vector.
        filters: list of output channel counts, one per convolution block.
        kernel_sizes, strides, dropouts: per-block settings, indexed like `filters`.
    """
    def __init__(self, d_vocab, d_embedding, d_latent, filters, kernel_sizes, strides, dropouts):
        super().__init__(d_latent)
        
        self.embedding = nn.Embedding(d_vocab, d_embedding)
        
        channels = [d_embedding] + filters
        
        blocks = []
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            blocks.append(Conv1d(c_in, c_out, ks=kernel_sizes[i], stride=strides[i],
                                 act=True, bn=True, dropout=dropouts[i]))

        self.convs = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(channels[-1], d_latent)
        
    def forward(self, x):
        """Return the latent vector for token indices `x`."""
        embedded = self.embedding(x)
        # Move the embedding axis in front of the sequence axis for the 1-d convolutions.
        features = self.convs(embedded.permute(0,2,1))
        pooled = self.pool(features).squeeze(-1)
        return self.head(pooled)

In [None]:
# Smoke test: every encoder returns a tensor whose last dimension is d_latent.
d_latent = 128
l = LSTM_Encoder(32, 64, 128, 2, 128)
assert l(torch.randint(0,31, (10,15))).shape[-1] == d_latent

m = MLP_Encoder(128, [64, 32, 16], d_latent, [0.1, 0.1, 0.1])
assert m(torch.randn(8,128)).shape[-1] == d_latent

c = Conv_Encoder(32, 64, d_latent, [32, 16], [7,7], [2,2], [0.1, 0.1])
assert c(torch.randint(0,31, (10,15))).shape[-1] == d_latent

In [None]:
# export

class VAE_Transition(nn.Module):
    """Map an encoding to `(mu, logvar)`, draw a reparameterised sample and compute the KL loss.

    Args:
        d_latent: size of the encoding and of the sampled latent vector.
    """
    def __init__(self, d_latent):
        super().__init__()
        
        self.d_latent = d_latent
        # A single projection produces mu and logvar side by side.
        self.transition = nn.Linear(d_latent, d_latent*2)
        
    def forward(self, x, z_scale=1.):
        """Return `(z, kl_loss)`.

        `z = mu + z_scale * eps * exp(0.5 * logvar)` with `eps` standard normal;
        `kl_loss` is summed over dim 1 and averaged over the remaining elements.
        """
        mu, logvar = self.get_stats(x)
        std = torch.exp(0.5*logvar)
        noise = z_scale*torch.randn(mu.shape).to(mu.device)
        z = mu + noise*std
        kl_terms = logvar.exp() + mu.pow(2) - 1 - logvar
        kl_loss = 0.5 * kl_terms.sum(1).mean()
        return z, kl_loss
    
    def get_stats(self, x):
        """Project `x` and split the result in half along the last dimension into `(mu, logvar)`."""
        stats = self.transition(x)
        mu, logvar = stats.chunk(2, dim=-1)
        return mu, logvar
    
class Norm_Transition(nn.Module):
    """Rescale the input to unit `p`-norm along its last dimension.

    Args:
        d_latent: latent size (stored, not used by `forward`).
        p: order of the norm.
    """
    def __init__(self, d_latent, p=2):
        super().__init__()
        self.d_latent, self.p = d_latent, p
        
    def forward(self, x):
        return F.normalize(x, p=self.p, dim=-1)

class PT_Transition(nn.Module):
    """Pass-through transition: `forward` returns its input unchanged."""
    def forward(self, x):
        return x

In [None]:
# With p=2, every normalised row must have a squared L2 norm of 1.
t = Norm_Transition(128, p=2)
x = torch.randn((8, 128))
assert torch.allclose(t(x).pow(2).sum(-1), torch.ones(x.shape[0]).float())

In [None]:
# export

class Encoder_Decoder(nn.Module):
    """Generic encoder -> transition -> decoder model with an attached prior.

    Args:
        encoder: module mapping the input to a latent encoding. It must expose
            `d_latent` when no `prior` is given.
        decoder: module called as `decoder(decoder_input, z)`.
        transition: module applied to the encoder output; defaults to the
            pass-through `PT_Transition`.
        prior: prior over the latent space; defaults to a non-trainable
            `NormalPrior` with zero `loc` and zero `log_scale` of size
            `encoder.d_latent`.
    """
    def __init__(self, encoder, decoder, transition=None, prior=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
        if transition is None:
            transition = PT_Transition()
            
        self.transition = transition
        
        if prior is None:
            prior = NormalPrior(torch.zeros((encoder.d_latent)), torch.zeros((encoder.d_latent)), 
                                trainable=False)
        
        self.prior = prior
        
    def forward(self, x, decoder_input=None):
        """Encode `x`, apply the transition and decode.

        Args:
            x: encoder input.
            decoder_input: input for the decoder; defaults to `x`.

        Returns:
            Whatever the decoder returns.
        """
        if decoder_input is None:
            decoder_input = x
            
        z = self.encoder(x)
        # The transition acts on the encoding, not on the raw input.
        z = self.transition(z)
        output = self.decoder(decoder_input, z)
        return output
    
    def set_prior(self, prior):
        """Replace the current prior."""
        self.prior = prior

In [None]:
# export 

class VAE(Encoder_Decoder):
    """Variational autoencoder: `Encoder_Decoder` with a `VAE_Transition`.

    Args:
        encoder: encoder module exposing `d_latent`.
        decoder: module called as `decoder(tokens, z[, hiddens])` and returning
            `(scores, hiddens)`.
        prior: optional prior handed to `Encoder_Decoder`.
        bos_idx: token index used to start sampling.
    """
    def __init__(self, encoder, decoder, prior=None, bos_idx=0):
        transition = VAE_Transition(encoder.d_latent)
        super().__init__(encoder, decoder, transition, prior)
            
        self.bos_idx = bos_idx
        
    def forward(self, x, decoder_input=None):
        """Return `(decoder scores, kl_loss)`; `decoder_input` defaults to `x`."""
        
        z = self.encoder(x)
        z, kl_loss = self.transition(z)
            
        if decoder_input is None:
            decoder_input = x
            
        output, hiddens = self.decoder(decoder_input, z)
        return output, kl_loss
    

    def sample(self, bs, sl, z=None, temperature=1., multinomial=True):
        """Generate `sl` tokens for `bs` sequences conditioned on `z`.

        When `z` is None it is drawn from the prior with `rsample`. Each step
        feeds the previous step's indices (initially `bos_idx`) to the decoder,
        divides the scores by `temperature` and passes them to `x_to_preds`
        (defined outside this class).

        Returns:
            The concatenated per-step indices (without the initial `bos_idx`
            column) and the concatenated second outputs of `x_to_preds`.
        """
        
        preds = idxs = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
        lps = []
        
        if z is None:
            z = to_device(self.prior.rsample([bs]))
        
        hiddens = None
        
        for i in range(sl):
            x, hiddens = self.decoder(idxs, z, hiddens)
            x.div_(temperature)
            
            idxs, lp = x_to_preds(x, multinomial=multinomial)
            
            lps.append(lp)            
            preds = torch.cat([preds, idxs], -1)
            
        return preds[:, 1:], torch.cat(lps,-1)
    
    def sample_no_grad(self, bs, sl, z=None, temperature=1., multinomial=True):
        """`sample` executed under `torch.no_grad()`."""
        with torch.no_grad():
            return self.sample(bs, sl, z=z, temperature=temperature, multinomial=multinomial)
        
    def get_lps(self, x, y, temperature=1., z=None):
        """Log-probabilities of the target tokens `y`.

        Args:
            x: either a single tensor used as both encoder and decoder input, or
                a list/tuple `[encoder_input, decoder_input]`.
            y: target token indices, gathered along dim 2 of the log-softmax.
            temperature: divisor applied to the decoder scores.
            z: NOTE(review): accepted but not used; the latent is always
                recomputed from the encoder. Confirm whether a caller-supplied
                `z` is meant to be honoured.
        """

        if isinstance(x, (list, tuple)):
            encoder_input, decoder_input = x[0], x[1]
        else:
            encoder_input = decoder_input = x

        z,_ = self.transition(self.encoder(encoder_input))
        x,_ = self.decoder(decoder_input, z)
        
        x.div_(temperature)
        
        lps = F.log_softmax(x, -1)
        lps = lps.gather(2, y.unsqueeze(-1)).squeeze(-1)
        
        # Priors without a `trainable` attribute are treated as not trainable.
        if getattr(self.prior, 'trainable', False):
            prior_lps = self.prior.log_prob(z).mean(-1, keepdim=True)
            # Zero-valued term that still carries the gradient of the prior
            # log-probability. It is built from `prior_lps` itself so it lives on
            # the same device (a fresh `torch.zeros` would be created on CPU).
            prior_lps = prior_lps - prior_lps.detach()
            lps = lps + prior_lps
        
        return lps
    
    def set_prior_from_stats(self, mu, logvar, trainable=False):
        """Replace the prior by a `NormalPrior` with mean `mu` and log-variance `logvar`."""
        mu = mu.detach()
        logvar = logvar.detach()
        # NormalPrior takes the log of the standard deviation, which is half the
        # log-variance (VAE_Transition samples with std = exp(0.5 * logvar)).
        self.prior = NormalPrior(mu, 0.5 * logvar, trainable)
        
    def set_prior_from_latent(self, z, trainable=False):
        """Set the prior from the transition statistics of the encoding `z`."""
        mu, logvar = self.transition.get_stats(z)
        self.set_prior_from_stats(mu, logvar, trainable)
        
    def set_prior_from_encoder(self, x, trainable=False):
        """Encode a single input (batch size 1) and set the prior from its statistics."""
        assert x.shape[0]==1, "Must set prior from a single input"
        z = self.encoder(x)
        z = z.squeeze(0)
        self.set_prior_from_latent(z, trainable)

In [None]:
# Smoke test: VAE with an LSTM encoder and decoders covering three
# combinations of the conditioning flags.
encoder = LSTM_Encoder(32, 64, 128, 2, 128)
decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                condition_hidden=True, condition_output=True)
vae = VAE(encoder, decoder)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

_ = vae(x)

decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                condition_hidden=False, condition_output=True)
vae = VAE(encoder, decoder)

_ = vae(x)

decoder = Conditional_LSTM_Block(32, 64, 128, 64, 128, 2,
                                condition_hidden=True, condition_output=False)
vae = VAE(encoder, decoder)

_ = vae(x)

# Sampling, first from the prior and then from an explicit latent.
_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

_ = vae.get_lps(x,y)

# The prior can be reset from a single encoded example.
vae.set_prior_from_encoder(x[0].unsqueeze(0));

In [None]:
# export

class LSTM_VAE(VAE):
    """`VAE` with an `LSTM_Encoder` and a `Conditional_LSTM_Block` decoder.

    The decoder's LSTM output size equals `d_embedding`, and `dec_drop` is used
    for both its LSTM and head dropout arguments.
    """
    def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, d_latent,
                enc_drop=0., dec_drop=0., condition_hidden=True, condition_output=True,
                prior=None, bos_idx=0):
        
        # Encoder is built before the decoder.
        enc = LSTM_Encoder(d_vocab, d_embedding, d_hidden, n_layers, d_latent,
                           dropout=enc_drop)
        
        dec_opts = dict(lstm_drop=dec_drop, lin_drop=dec_drop,
                        condition_hidden=condition_hidden,
                        condition_output=condition_output)
        dec = Conditional_LSTM_Block(d_vocab, d_embedding, d_hidden, d_embedding,
                                     d_latent, n_layers, **dec_opts)
        
        super().__init__(enc, dec, prior=prior, bos_idx=bos_idx)

In [None]:
# Smoke test: forward pass, sampling (prior and explicit latent) and log-prob lookup.
vae = LSTM_VAE(32, 64, 128, 2, 128, condition_hidden=True, condition_output=True)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

_ = vae(x)

_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

_ = vae.get_lps(x,y)

In [None]:
# export

class Conv_VAE(VAE):
    """`VAE` with a `Conv_Encoder` and a `Conditional_LSTM_Block` decoder.

    The decoder's LSTM output size equals `d_embedding`, and `dec_drop` is used
    for both its LSTM and head dropout arguments.
    """
    def __init__(self, d_vocab, d_embedding, conv_filters, kernel_sizes, strides, conv_drops,
                 d_hidden, n_layers, d_latent, dec_drop=0., 
                 condition_hidden=True, condition_output=True,
                 prior=None, bos_idx=0):
        
        # Encoder is built before the decoder.
        enc = Conv_Encoder(d_vocab, d_embedding, d_latent,
                           conv_filters, kernel_sizes, strides, conv_drops)
        
        dec_opts = dict(lstm_drop=dec_drop, lin_drop=dec_drop,
                        condition_hidden=condition_hidden,
                        condition_output=condition_output)
        dec = Conditional_LSTM_Block(d_vocab, d_embedding, d_hidden, d_embedding,
                                     d_latent, n_layers, **dec_opts)
        
        super().__init__(enc, dec, prior=prior, bos_idx=bos_idx)

In [None]:
# Smoke test: forward pass, sampling (prior and explicit latent) and log-prob lookup.
vae = Conv_VAE(32, 64, [128, 256], [7,7], [2,2], [0.1,0.1], 128, 2, 128, 
               condition_hidden=False, condition_output=True)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

_ = vae(x)

_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

_ = vae.get_lps(x,y)

In [None]:
# export
        
class MLP_VAE(VAE):
    """`VAE` with an `MLP_Encoder` and a `Conditional_LSTM_Block` decoder.

    The decoder's LSTM output size equals `d_embedding`, and `dec_drop` is used
    for both its LSTM and head dropout arguments.
    """
    def __init__(self, d_vocab, d_embedding, encoder_d_in, encoder_dims, encoder_drops,
                 d_hidden, n_layers, d_latent, dec_drop=0., 
                 condition_hidden=True, condition_output=True,
                 prior=None, bos_idx=0):
        
        # Encoder is built before the decoder.
        enc = MLP_Encoder(encoder_d_in, encoder_dims, d_latent, encoder_drops)
        
        dec_opts = dict(lstm_drop=dec_drop, lin_drop=dec_drop,
                        condition_hidden=condition_hidden,
                        condition_output=condition_output)
        dec = Conditional_LSTM_Block(d_vocab, d_embedding, d_hidden, d_embedding,
                                     d_latent, n_layers, **dec_opts)
        
        super().__init__(enc, dec, prior=prior, bos_idx=bos_idx)
    
    

In [None]:
# Smoke test: the MLP encoder consumes a separate float vector, so the encoder
# input (`condition`) and the decoder token input (`x`) are passed separately.
vae = MLP_VAE(32, 64, 128, [64, 32], [0.1, 0.1], 128, 2, 128, 
               condition_hidden=False, condition_output=True)

ints = torch.randint(0, 31, (8, 10))
x = ints[:,:-1]
y = ints[:,1:]

condition = torch.randn((8,128))


_ = vae(condition, x)

_ = vae.sample(8, 16)

z = vae.prior.sample([8])
_ = vae.sample(8, 16, z=z)

# get_lps takes the pair as a list: [encoder input, decoder input].
_ = vae.get_lps([condition,x],y)

In [None]:
# export

class Conditional_LSTM_LM(Encoder_Decoder):
    """LSTM language model whose decoder is conditioned on a latent vector.

    A condition input is mapped to a latent vector by `encoder`, passed through
    a `Norm_Transition`, and handed to a `Conditional_LSTM_Block` decoder
    together with the token input.

    Args:
        encoder: module mapping a condition to a latent vector. It must expose
            a `d_latent` attribute (used for the default prior and for sampling).
        d_vocab, d_embedding, d_hidden, d_latent, n_layers: sizes forwarded to
            `Conditional_LSTM_Block` (`d_embedding` is passed in two positional
            slots). `d_latent` is also given to `Norm_Transition`.
        lstm_drop, lin_drop: dropout settings forwarded to the decoder block.
        bidir: accepted but not used anywhere in this class.
            NOTE(review): confirm whether it should be forwarded to the decoder.
        condition_hidden, condition_output: forwarded to the decoder block.
        bos_idx: token index used as the first decoder input when sampling.
        prior: latent prior. When None, a non-trainable `SphericalPrior` is
            built from two zero vectors of length `encoder.d_latent`.
    """
    def __init__(self, encoder, d_vocab, d_embedding, d_hidden, d_latent, n_layers,
                 lstm_drop=0., lin_drop=0., bidir=False,
                 condition_hidden=True, condition_output=False, bos_idx=0, prior=None):
        
        transition = Norm_Transition(d_latent)
        
        decoder = Conditional_LSTM_Block(d_vocab, d_embedding, d_hidden, d_embedding,
                                d_latent, n_layers, lstm_drop=lstm_drop, lin_drop=lin_drop, 
                                condition_hidden=condition_hidden, condition_output=condition_output)
        
        if prior is None:
            # Default prior: both constructor tensors are zeros. The second one
            # is named `logvar` where this class builds priors elsewhere.
            prior = SphericalPrior(torch.zeros((encoder.d_latent)), torch.zeros((encoder.d_latent)), 
                                trainable=False)
        
        super().__init__(encoder, decoder, transition, prior)
        
        self.bos_idx = bos_idx
        
    def forward(self, x, condition, hiddens=None):
        """Encode `condition`, apply the transition, and decode `x`.

        Returns only the decoder output; the hidden state returned by the
        decoder is discarded.
        """
        z = self.encoder(condition)
        z = self.transition(z)
        x, hiddens = self.decoder(x, z, hiddens)
        return x
    

    def sample(self, bs, sl, z=None, temperature=1., multinomial=True):
        """Autoregressively sample `sl` steps from the decoder.

        Args:
            bs: batch size; replaced by `z.shape[0]` when `z` is given.
            sl: number of decoding steps.
            z: optional latent. When None it is drawn with `self.prior.rsample`,
                or, if there is no prior, from `torch.randn` followed by the
                transition.
            temperature: decoder outputs are divided by this value in place.
            multinomial: forwarded to `x_to_preds`.

        Returns:
            A tuple `(preds, lps)`: the per-step indices and the per-step
            values returned by `x_to_preds`, each concatenated on the last dim.
        """
        
        if z is None:
            if self.prior is not None:
                z = to_device(self.prior.rsample([bs]))
            else:
                z = to_device(torch.randn((bs, self.encoder.d_latent)))
                z = self.transition(z)
        else:
            bs = z.shape[0]
        
        # One BOS token per batch row is the first decoder input.
        idxs = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
        sampled = []
        lps = []

        # NOTE(review): latent_to_hidden is called regardless of how
        # condition_hidden was set - confirm the decoder's LSTM supports this
        # when condition_hidden=False.
        hiddens = self.decoder.lstm.latent_to_hidden(z)
        
        for i in range(sl):
            
            x, hiddens = self.decoder(idxs,z,hiddens)
            x.div_(temperature)
            
            idxs, lp = x_to_preds(x, multinomial=multinomial)
            
            lps.append(lp)
            # Collect the step outputs and concatenate once at the end rather
            # than re-concatenating a growing tensor every step. The BOS
            # column is never part of the returned predictions.
            sampled.append(idxs)
            
        return torch.cat(sampled, -1), torch.cat(lps,-1)
    
    def sample_no_grad(self, bs, sl, z=None, temperature=1., multinomial=True):
        """Call `sample` with the same arguments inside `torch.no_grad()`."""
        with torch.no_grad():
            return self.sample(bs, sl, z=z, temperature=temperature, multinomial=multinomial)
        
    def get_lps(self, x, y, temperature=1.):
        """Log-probabilities of the target indices `y` under the model.

        Args:
            x: a pair `(tokens, condition)`, unpacked in that order.
            y: target indices; used to gather along dim 2 of the log-softmaxed
                decoder output.
            temperature: decoder outputs are divided by this value in place.

        Returns:
            The gathered log-probabilities with the trailing gather dim
            squeezed. When the prior is trainable, a numerically zero term
            derived from the prior's log-probability of `z` is added so that
            gradients also reach the prior.
        """
        x, c = x
        z = self.transition(self.encoder(c))
        x,_ = self.decoder(x, z)
        
        x.div_(temperature)
        
        lps = F.log_softmax(x, -1)
        lps = lps.gather(2, y.unsqueeze(-1)).squeeze(-1)
        
        if self.prior.trainable:
            prior_lps = self.prior.log_prob(z).mean(-1, keepdim=True)
            # Value is exactly zero but the gradient of prior_lps is kept.
            # It is built from prior_lps itself so it stays on that tensor's
            # device; a fresh torch.zeros(...) would be allocated on the
            # default device and fail when the model runs on an accelerator.
            prior_lps = prior_lps - prior_lps.detach()
            lps = lps + prior_lps
        
        return lps
        
    def set_prior_from_latent(self, z, logvar, trainable=False):
        """Replace `self.prior` with a `SphericalPrior` built from detached `z` and `logvar`."""
        z = z.detach()
        logvar = logvar.detach()
        self.prior = SphericalPrior(z, logvar, trainable)
        
    def set_prior_from_encoder(self, condition, logvar, trainable=False):
        """Encode a single condition and use the resulting latent as the prior's first argument.

        `condition` must have a leading dimension of exactly 1 (asserted); that
        dimension is squeezed from the latent before the prior is built.
        """
        assert condition.shape[0]==1
        z = self.transition(self.encoder(condition))
        z = z.squeeze(0)
        self.set_prior_from_latent(z, logvar, trainable)

In [None]:
# Smoke test for Conditional_LSTM_LM; every result is discarded (`_ = ...`).
# The encoder's third argument (16) is passed in the slot MLP_VAE uses for
# d_latent, and it matches the d_latent given to the language model below.
encoder = MLP_Encoder(128, [64, 32], 16, [0.1, 0.1])

# Positional arguments follow the Conditional_LSTM_LM signature:
# d_vocab=32, d_embedding=64, d_hidden=128, d_latent=16, n_layers=2.
lm = Conditional_LSTM_LM(encoder, 32, 64, 128, 16, 2)

# Random integer batch of shape (8, 10); torch.randint's upper bound is
# exclusive, so values fall in [0, 31).
ints = torch.randint(0, 31, (8, 10))
# x drops the last column and y drops the first; get_lps gathers the model's
# log-probabilities at the indices in y.
x = ints[:,:-1]
y = ints[:,1:]

# Random float input for the encoder; the width 128 matches its first argument.
condition = torch.randn((8,128))

# Forward pass: tokens first, condition second.
_ = lm(x, condition)

# get_lps unpacks its first argument as (tokens, condition).
_ = lm.get_lps([x,condition],y)

# No z is given, so sample() draws the latent from the model's prior.
_ = lm.sample(3, 80)

In [None]:
# class LSTM(nn.Module):
#     def __init__(self, d_embedding, d_hidden, d_output, n_layers, 
#                  bidir=False, dropout=0., batch_first=True):
#         super().__init__()
        
#         self.d_embedding = d_embedding
#         self.d_hidden = d_hidden
#         self.d_output = d_output
#         self.n_layers = n_layers
#         self.bidir = bidir
#         self.n_dir = 1 if not bidir else 2
#         self.batch_first = batch_first
        
#         self.lstms = []
#         self.hidden_sizes = []
        
#         for l in range(n_layers):
#             input_size = d_embedding if l==0 else d_hidden
#             output_size = d_output if l==n_layers-1 else d_hidden
#             output_size = output_size // self.n_dir
            
#             hidden_size = (self.n_dir, 1, output_size)
#             self.hidden_sizes.append(hidden_size)
            
#             lstm = nn.LSTM(input_size, output_size, 1, batch_first=batch_first, 
#                            dropout=dropout, bidirectional=bidir)
#             self.lstms.append(lstm)
            
#         self.lstms = nn.ModuleList(self.lstms)
        
#     def forward(self, x, hiddens=None):
        
#         bs = x.shape[0] if self.batch_first else x.shape[1]
        
#         if hiddens is None:
#             hiddens = self.get_new_hidden(bs)
#             hiddens = to_device(hiddens, x.device)
            
#         new_hiddens = []
#         for i, lstm in enumerate(self.lstms):
#             x, (h,c) = lstm(x, hiddens[i])
#             new_hiddens.append((h.detach(), c.detach()))
            
#         return x, new_hiddens
            
#     def get_new_hidden(self, bs):
#         hiddens = []
#         for hs in self.hidden_sizes:
#             h = torch.zeros(hs).repeat(1,bs,1)
#             c = torch.zeros(hs).repeat(1,bs,1)
#             hiddens.append((h,c))
        
#         return hiddens

In [None]:
# class Conditional_LSTM(nn.Module):
#     def __init__(self, d_embedding, d_hidden, d_output, d_latent, n_layers,
#                  condition_hidden=True, condition_output=True,
#                  bidir=False, dropout=0., batch_first=True):
#         super().__init__()
        
#         self.d_embedding = d_embedding
#         self.d_hidden = d_hidden
#         self.d_output = d_output
#         self.n_layers = n_layers
#         self.bidir = bidir
#         self.n_dir = 1 if not bidir else 2
#         self.batch_first = batch_first
#         self.condition_hidden = condition_hidden
#         self.condition_output = condition_output
        
#         self.lstms = []
#         self.hidden_sizes = []
        
#         for l in range(n_layers):
#             if l==0:
#                 input_size = d_embedding if not self.condition_output else d_embedding+d_latent
#             else:
#                 input_size = d_hidden
                
#             output_size = d_output if l==n_layers-1 else d_hidden
#             output_size = output_size // self.n_dir
            
#             hidden_size = (self.n_dir, 1, output_size)
#             self.hidden_sizes.append(hidden_size)
            
#             lstm = nn.LSTM(input_size, output_size, 1, batch_first=batch_first, 
#                            dropout=dropout, bidirectional=bidir)
#             self.lstms.append(lstm)
            
#         self.lstms = nn.ModuleList(self.lstms)
        
#         if self.condition_hidden:
#             to_hidden = []
#             for size in self.hidden_sizes:
#                 ndir, _, dim = size
#                 to_hidden.append(Linear(d_latent, ndir*dim*2, act=False, bn=False))
                
#             self.to_hidden = nn.ModuleList(to_hidden)
        
#     def forward(self, x, z, hiddens=None):
        
#         bs = x.shape[0] if self.batch_first else x.shape[1]
#         sl = x.shape[1] if self.batch_first else x.shape[0]
        
#         if self.condition_output:
#             if self.batch_first:
#                 z_ = z.unsqueeze(1).repeat(1,sl,1)
#             else:
#                 z_ = z.unsqueeze(0).repeat(sl,1,1)
                
#             x = torch.cat([x, z_], -1)

#         if hiddens is None:
#             if self.condition_hidden:
#                 hiddens = self.latent_to_hidden(z)
                
#             else:
#                 hiddens = self.get_new_hidden(bs)
            
#             hiddens = to_device(hiddens, x.device)
            
#         new_hiddens = []
#         for i, lstm in enumerate(self.lstms):
#             x, (h,c) = lstm(x, hiddens[i])
#             new_hiddens.append((h.detach(), c.detach()))
            
#         return x, new_hiddens
    
#     def latent_to_hidden(self, z):
#         hiddens = []
#         for layer in self.to_hidden:
#             h = layer(z)
#             h,c = torch.chunk(h, 2, dim=-1)
#             bs, _ = h.shape
#             h = h.contiguous().reshape(bs, self.n_dir, -1).permute(1,0,2)
#             c = c.contiguous().reshape(bs, self.n_dir, -1).permute(1,0,2)
#             hiddens.append((h,c))
            
#         return hiddens
            
#     def get_new_hidden(self, bs):
#         hiddens = []
#         for hs in self.hidden_sizes:
#             h = torch.zeros(hs).repeat(1,bs,1)
#             c = torch.zeros(hs).repeat(1,bs,1)
#             hiddens.append((h,c))
        
#         return hiddens

In [None]:
# class LSTMLM(nn.Module):
#     def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, 
#                  lstm_drop=0., bos_idx=0, bidir=False):
#         super().__init__()
        
#         self.embedding = nn.Embedding(d_vocab, d_embedding)
#         self.lstm = LSTM(d_embedding, d_hidden, d_embedding, n_layers, bidir=bidir, dropout=lstm_drop)
#         self.head = Linear(d_embedding, d_vocab, act=False, bn=False, dropout=0.)
#         self.bos_idx = bos_idx
        
#     def forward(self, x):
#         x = self.embedding(x)
#         x, hiddens = self.lstm(x)
#         self.last_hidden = hiddens
#         x = self.head(x)
#         return x
    
#     def sample(self, bs, sl, temperature=1., multinomial=True):
        
#         preds = idxs = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
#         lps = []

#         hiddens = None
        
#         for i in range(sl):
#             x = self.embedding(idxs)
#             x, hiddens = self.lstm(x, hiddens)
#             x = self.head(x)
            
#             x.div_(temperature)
            
#             log_probs = F.log_softmax(x, -1).squeeze(1)
#             probs = log_probs.detach().exp()
            
#             if multinomial:
#                 idxs = torch.multinomial(probs, 1)
#             else:
#                 idxs = x.argmax(-1)
                
#             lps.append(torch.gather(log_probs, 1, idxs))
            
#             preds = torch.cat([preds, idxs], -1)
            
#         return preds[:, 1:], torch.cat(lps,-1)
    
#     def sample_no_grad(self, bs, sl, temperature=1., multinomial=True):
#         with torch.no_grad():
#             return self.sample(bs, sl, temperature=temperature, multinomial=multinomial)
        
#     def get_lps(self, x, y, temperature=1.):
#         x = self.forward(x)
#         x.div_(temperature)
        
#         lps = F.log_softmax(x, -1)
#         lps = lps.gather(2, y.unsqueeze(-1)).squeeze(-1)
        
#         return lps

In [None]:
# class Conditional_LSTMLM(nn.Module):
#     def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, mapping, d_latent,
#                  lstm_drop=0., lin_drop=0., bos_idx=0, bidir=False,
#                  condition_hidden=True, condition_output=False, norm_latent=True):
#         super().__init__()
        
#         self.mapping = mapping
#         self.d_latent = d_latent
        
#         self.embedding = nn.Embedding(d_vocab, d_embedding)
#         self.lstm = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, n_layers,
#                                     condition_hidden=condition_hidden, condition_output=condition_output,
#                                      bidir=bidir, dropout=lstm_drop)
        
#         self.head = Linear(d_embedding, d_vocab, act=False, bn=False, dropout=0.)
#         self.bos_idx = bos_idx
#         self.norm_latent = norm_latent
        
#     def forward(self, x, condition):
        
#         z = self.mapping(condition)
#         if self.norm_latent:
#             z = F.normalize(z, p=2, dim=-1)
        
#         x = self.embedding(x)
#         x, hiddens = self.lstm(x,z)
#         self.last_hidden = hiddens
#         x = self.head(x)
#         return x
    
#     def sample(self, bs, sl, z=None, temperature=1., multinomial=True):
        
#         if z is None:
#             z = to_device(F.normalize(torch.randn((bs, self.d_latent)), p=2, dim=-1))
#         else:
#             bs = z.shape[0]
        
#         preds = idxs = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
#         lps = []

#         hiddens = None
        
#         for i in range(sl):
#             x = self.embedding(idxs)
#             x, hiddens = self.lstm(x, z, hiddens)
#             x = self.head(x)
            
#             x.div_(temperature)
            
#             log_probs = F.log_softmax(x, -1).squeeze(1)
#             probs = log_probs.detach().exp()
            
#             if multinomial:
#                 idxs = torch.multinomial(probs, 1)
#             else:
#                 idxs = x.argmax(-1)
                
#             lps.append(torch.gather(log_probs, 1, idxs))
            
#             preds = torch.cat([preds, idxs], -1)
            
#         return preds[:, 1:], torch.cat(lps,-1)
    
#     def sample_no_grad(self, bs, sl, z=None, temperature=1., multinomial=True):
#         with torch.no_grad():
#             return self.sample(bs, sl, z=z, temperature=temperature, multinomial=multinomial)
        
#     def get_lps(self, x, y, temperature=1.):
#         x = self.forward(x[0], x[1])
#         x.div_(temperature)
        
#         lps = F.log_softmax(x, -1)
#         lps = lps.gather(2, y.unsqueeze(-1)).squeeze(-1)
        
#         return lps

In [None]:
# class VAEEncoder(nn.Module):
#     def __init__(self, d_latent):
#         super().__init__()
#         self.d_latent = d_latent
        
#     def forward(self, x):
#         raise NotImplementedError
        
#     def get_latent(self, mu, logvar, z_scale=1.):
#         z = z_scale*torch.randn(mu.shape).to(mu.device)
#         z = mu + z*torch.exp(0.5*logvar)
#         kl_loss = 0.5 * (logvar.exp() + mu.pow(2) - 1 - logvar).sum(1).mean()
#         return z, kl_loss
        
# class VAELSTMEncoder(VAEEncoder):
#     def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, d_latent, dropout=0.):
#         super().__init__(d_latent)
        
#         self.embedding = nn.Embedding(d_vocab, d_embedding)
#         self.lstm_encoder = LSTM(d_embedding, d_hidden, d_hidden, n_layers, 
#                                  bidir=True, batch_first=True, dropout=dropout)
#         self.transition = nn.Linear(d_hidden*2, d_latent*2)
        
        
#     def forward(self, x, z_scale=1.):
#         x = self.embedding(x)
#         x, hiddens = self.lstm_encoder(x)
#         hidden = torch.cat(list(torch.cat(hiddens[-1], -1)), -1) # concatenate hidden/cell states of last layer
        
#         mu, logvar = torch.chunk(self.transition(hidden), 2, dim=-1)
#         z, kl_loss = self.get_latent(mu, logvar, z_scale)
        
#         return z, kl_loss
              
# class VAEConvEncoder(VAEEncoder):
#     def __init__(self, d_vocab, d_embedding, kernel_size, n_layers, d_latent, dropout=0.):
#         super().__init__(d_latent)
    
#         self.embedding = nn.Embedding(d_vocab, d_embedding)

#         convs = []
#         input_size = d_embedding
#         for i in range(n_layers):
#             convs.append(Conv1d(input_size, input_size*2, ks=kernel_size, stride=2, 
#                                 act=True, bn=True, dropout=dropout))
#             input_size = input_size*2

#         self.convs = nn.Sequential(*convs)
#         self.pool = nn.AdaptiveAvgPool1d(1)
#         self.transition = nn.Linear(input_size, d_latent*2)
    
#     def forward(self, x, z_scale=1.):
#         x = self.embedding(x)
#         x = x.permute(0,2,1)
#         x = self.convs(x)
#         x = self.pool(x).squeeze(-1)
        
#         mu, logvar = torch.chunk(self.transition(x), 2, dim=-1)
#         z, kl_loss = self.get_latent(mu, logvar, z_scale)
        
#         return z, kl_loss
              
# class VAELinEncoder(VAEEncoder):
#     def __init__(self, d_input, n_layers, d_latent, dropout=0.):
#         super().__init__(d_latent)
    
#         lins = []
#         input_size = d_input
#         for i in range(n_layers):
#             lins.append(Linear(input_size, input_size//2, act=True, bn=True, dropout=dropout))
#             input_size = input_size//2
            
#         self.layers = nn.Sequential(*lins)
#         self.transition = nn.Linear(input_size, d_latent*2)
    
#     def forward(self, x, z_scale=1.):
#         x = self.layers(x)
        
#         mu, logvar = torch.chunk(self.transition(x), 2, dim=-1)
#         z, kl_loss = self.get_latent(mu, logvar, z_scale)
        
#         return z, kl_loss

In [None]:
# class VAEDecoder(nn.Module):
#     def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, d_latent,
#                 condition_hidden=True, condition_output=True):
#         super().__init__()
        
#         self.embedding = nn.Embedding(d_vocab, d_embedding)
#         self.decoder = Conditional_LSTM(d_embedding, d_hidden, d_embedding, d_latent, 3, 
#                                     condition_hidden=condition_hidden, condition_output=condition_output, 
#                                     bidir=False, batch_first=True)
        
#         self.head = Linear(d_embedding, d_vocab, act=False, bn=False, dropout=0.)
        
#     def forward(self, x, z, hiddens=None):
#         bs, sl = x.shape
#         x = self.embedding(x)
        
#         decoded, hiddens = self.decoder(x, z, hiddens)
#         output = self.head(decoded)
        
#         return output, hiddens

In [None]:
# class VAE(nn.Module):
#     def __init__(self, encoder, decoder, prior=None, bos_idx=0):
#         super().__init__()
        
#         self.encoder = encoder
#         self.decoder = decoder
#         if prior is None:
#             prior = Normal(torch.zeros((encoder.d_latent)), torch.ones((encoder.d_latent)))
#         self.prior = prior
#         self.bos_idx = bos_idx
        
#     def forward(self, x, decoder_input=None):
#         z, kl_loss = self.encoder(x)
        
#         if decoder_input is None:
#             decoder_input = x
            
#         output, hiddens = self.decoder(decoder_input, z)
#         return output, kl_loss
    
#     def sample(self, bs, sl, z=None, temperature=1., multinomial=True):
        
#         preds = idxs = to_device(torch.tensor([self.bos_idx]*bs).long().unsqueeze(-1))
#         lps = []
        
#         if z is None:
#             z = to_device(self.prior.sample([bs]))
        
#         hiddens = None
        
#         for i in range(sl):
#             x, hiddens = self.decoder(idxs, z, hiddens)
#             x.div_(temperature)
            
#             log_probs = F.log_softmax(x, -1).squeeze(1)
#             probs = log_probs.detach().exp()
            
#             if multinomial:
#                 idxs = torch.multinomial(probs, 1)
#             else:
#                 idxs = x.argmax(-1)
                
#             lps.append(torch.gather(log_probs, 1, idxs))
            
#             preds = torch.cat([preds, idxs], -1)
            
#         return preds[:, 1:], torch.cat(lps,-1)
    
#     def sample_no_grad(self, bs, sl, z=None, temperature=1., multinomial=True):
#         with torch.no_grad():
#             return self.sample(bs, sl, z=z, temperature=temperature, multinomial=multinomial)
        
#     def get_lps(self, x, y, temperature=1., z=None):

#         if type(x)==list:
#             x,_ = self.forward(x[0], decoder_input=x[1])
#         else:
#             x,_ = self.forward(x)
        
#         x.div_(temperature)
        
#         lps = F.log_softmax(x, -1)
#         lps = lps.gather(2, y.unsqueeze(-1)).squeeze(-1)
        
#         return lps


In [None]:
# hide
# Run nbdev's notebook2script; the recorded output below lists the notebooks it converted.
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted 01_chem.ipynb.
Converted 02_template.filters.ipynb.
Converted 03_template.template.ipynb.
Converted 04_template.blocks.ipynb.
Converted 05_torch_core.ipynb.
Converted 06_layers.ipynb.
Converted 07_dataloaders.ipynb.
Converted index.ipynb.
Converted template.overview.ipynb.
Converted tutorials.ipynb.
Converted tutorials.structure_enumeration.ipynb.
Converted tutorials.template.advanced.ipynb.
Converted tutorials.template.beginner.ipynb.
Converted tutorials.template.intermediate.ipynb.
