In [None]:
# default_exp layers
# all_slow

# Layers

> PyTorch model layers

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export
from mrl.imports import *
from mrl.torch_imports import *

In [None]:
# export

class Linear(nn.Module):
    """Fully connected block: `nn.Linear` optionally followed by batch norm, ReLU and dropout.

    Arguments:
        d_in: number of input features
        d_out: number of output features
        act: if truthy, append `nn.ReLU`
        bn: if truthy, append `nn.BatchNorm1d(d_out)` (placed before the activation)
        dropout: dropout probability; a `nn.Dropout` layer is appended only when it is > 0

    The submodules are stored, in application order, in `self.layers` (an `nn.Sequential`).
    """
    def __init__(self, d_in, d_out, act=True, bn=False, dropout=0.):
        super().__init__()

        # The affine layer is always present; everything after it is opt-in.
        stack = [nn.Linear(d_in, d_out)]

        optional = (
            (bn, lambda: nn.BatchNorm1d(d_out)),
            (act, nn.ReLU),
            (dropout>0., lambda: nn.Dropout(p=dropout)),
        )
        stack.extend(build() for enabled, build in optional if enabled)

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        "Run `x` through the configured stack of layers."
        return self.layers(x)

In [None]:
# export

class LSTM(nn.Module):
    """Stack of single-layer `nn.LSTM` modules whose last layer may have a different width.

    Arguments:
        d_embedding: feature size of the input to the first layer
        d_hidden: hidden size of every layer except the last
        d_output: hidden size of the last layer
        n_layers: number of stacked LSTM layers
        bidir: if True every layer is bidirectional, so each layer emits
            `2 * hidden_size` features per step
        dropout: dropout probability applied to the output of every layer except the last
        batch_first: forwarded to `nn.LSTM`; also decides which axis of the input is the batch

    `forward` returns `(output, hiddens)`, where `hiddens` is a list with one detached
    `(h, c)` tuple per layer that can be passed back in to continue a sequence.
    """
    def __init__(self, d_embedding, d_hidden, d_output, n_layers, 
                 bidir=False, dropout=0., batch_first=True):
        super().__init__()
        
        self.d_embedding = d_embedding
        self.d_hidden = d_hidden
        self.d_output = d_output
        self.n_layers = n_layers
        self.bidir = bidir
        self.n_dir = 1 if not bidir else 2
        self.batch_first = batch_first
        
        lstms = []
        self.hidden_sizes = []
        
        for l in range(n_layers):
            # A bidirectional layer concatenates both directions, so every layer after
            # the first receives `n_dir * d_hidden` features.
            input_size = d_embedding if l==0 else d_hidden*self.n_dir
            output_size = d_output if l==n_layers-1 else d_hidden
            
            # (num_directions, batch placeholder, hidden size) of this layer's state
            self.hidden_sizes.append((self.n_dir, 1, output_size))
            
            # `nn.LSTM` only applies its `dropout` argument between its own internal layers;
            # with num_layers=1 it would be ignored (with a warning), so dropout is
            # applied between the stacked modules in `forward` instead.
            lstms.append(nn.LSTM(input_size, output_size, 1, batch_first=batch_first, 
                                 bidirectional=bidir))
            
        self.lstms = nn.ModuleList(lstms)
        self.dropout = nn.Dropout(dropout)  # parameter-free: state_dict keys are unchanged
        
    def forward(self, x, hiddens=None):
        """Run `x` through every layer.

        `hiddens`, when given, is a list of `(h, c)` tuples (one per layer); otherwise
        zero states are created. Returns the last layer's output and the new, detached states.
        """
        bs = x.shape[0] if self.batch_first else x.shape[1]
        
        if hiddens is None:
            hiddens = self.get_new_hidden(bs, device=x.device)
            
        new_hiddens = []
        last = len(self.lstms) - 1
        for i, lstm in enumerate(self.lstms):
            h0, c0 = hiddens[i]
            # Caller-supplied states may live on another device; `.to` is a no-op when they match.
            state = (h0.to(device=x.device, dtype=x.dtype), c0.to(device=x.device, dtype=x.dtype))
            x, (h,c) = lstm(x, state)
            if i < last:
                x = self.dropout(x)
            # Detach so that carrying state across calls does not extend the autograd graph.
            new_hiddens.append((h.detach(), c.detach()))
            
        return x, new_hiddens
            
    def get_new_hidden(self, bs, device=None):
        """Return zero-initialised `(h, c)` tuples, one per layer, for batch size `bs`.

        Tensors are created on `device`; when it is None the device of the module's
        parameters is used. The dtype follows the module's parameters.
        """
        ref = next(self.parameters(), None)
        dtype = None if ref is None else ref.dtype
        if device is None and ref is not None:
            device = ref.device
            
        hiddens = []
        for n_dir, _, size in self.hidden_sizes:
            h = torch.zeros(n_dir, bs, size, device=device, dtype=dtype)
            c = torch.zeros(n_dir, bs, size, device=device, dtype=dtype)
            hiddens.append((h,c))
        
        return hiddens

In [None]:
# export

class LSTMLM(nn.Module):
    """LSTM language model: token embedding -> stacked `LSTM` -> linear head over the vocabulary.

    Arguments:
        d_vocab: vocabulary size (embedding rows and head outputs)
        d_embedding: embedding size; also the hidden size of the last LSTM layer
        d_hidden: hidden size of the intermediate LSTM layers
        n_layers: number of LSTM layers
        pad_idx: index of the padding token; stored as `self.pad_idx`
            (it is not passed to `nn.Embedding`, so the padding row is trained like any other)
        lstm_drop: dropout probability forwarded to `LSTM`
        bos_idx: index of the beginning-of-sequence token used to start `sample`
        bidir: forwarded to `LSTM`
    """
    def __init__(self, d_vocab, d_embedding, d_hidden, n_layers, pad_idx, 
                 lstm_drop=0., bos_idx=0, bidir=False):
        super().__init__()
        
        self.embedding = nn.Embedding(d_vocab, d_embedding)
        self.lstm = LSTM(d_embedding, d_hidden, d_embedding, n_layers, bidir=bidir, dropout=lstm_drop)
        # A bidirectional LSTM emits `n_dir * d_embedding` features per step.
        self.head = Linear(d_embedding*self.lstm.n_dir, d_vocab, act=False, bn=False, dropout=0.)
        self.pad_idx = pad_idx
        self.bos_idx = bos_idx
        self.last_hidden = None  # set by `forward` to the LSTM states of the latest call
        
    def forward(self, x):
        "Map a batch of token indices to per-position vocabulary logits."
        x = self.embedding(x)
        x, hiddens = self.lstm(x)
        self.last_hidden = hiddens
        x = self.head(x)
        return x
    
    def sample(self, bs, sl, multinomial=True):
        """Generate `bs` sequences of `sl` tokens, one token at a time.

        Every sequence starts from `self.bos_idx`. With `multinomial=True` the next token is
        drawn from the predicted distribution, otherwise the arg-max token is taken.

        Returns `(preds, log_probs)`, both of shape `(bs, sl)`: the sampled token indices
        (without the leading bos token) and the log-probability of each sampled token.
        """
        # Create inputs and state on the model's device so sampling also works on GPU.
        device = self.embedding.weight.device
        idxs = torch.full((bs, 1), self.bos_idx, dtype=torch.long, device=device)
        
        hiddens = [(h.to(device), c.to(device)) for h,c in self.lstm.get_new_hidden(bs)]
        
        preds, lps = [], []
        for _ in range(sl):
            x = self.embedding(idxs)
            x, hiddens = self.lstm(x, hiddens)
            x = self.head(x)

            # One step is fed at a time, so the sequence axis has length 1 and is dropped.
            log_probs = F.log_softmax(x, -1).squeeze(1)
            
            if multinomial:
                # Sampling itself is not differentiable; draw from detached probabilities.
                idxs = torch.multinomial(log_probs.detach().exp(), 1)
            else:
                idxs = x.argmax(-1)
                
            lps.append(torch.gather(log_probs, 1, idxs))
            preds.append(idxs)
            
        if not preds:
            # `torch.cat` rejects an empty list; return empty (bs, 0) results instead.
            return idxs.new_empty((bs, 0)), self.embedding.weight.new_empty((bs, 0))
            
        return torch.cat(preds, -1), torch.cat(lps, -1)
    
    def sample_no_grad(self, bs, sl, multinomial=True):
        "Same as `sample`, but run under `torch.no_grad()` (no autograd graph is built)."
        with torch.no_grad():
            return self.sample(bs, sl, multinomial=multinomial)

In [None]:
from mrl.dataloaders import *

In [None]:
from mrl.chem import *

  return f(*args, **kwds)


In [None]:
from mrl.core import *

In [None]:
# Vocabulary built from SMILES_CHAR_VOCAB; its `stoi`, `itos` and `reconstruct` are used in the cells below.
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

In [None]:
# 256-d embeddings, 1024-d hidden state, 3 LSTM layers; these must match the checkpoint loaded next.
# Pass the bos index explicitly rather than relying on the default of 0 matching the vocabulary.
model = LSTMLM(len(vocab.itos), 256, 1024, 3, vocab.stoi['pad'], bos_idx=vocab.stoi['bos'])

In [None]:
# NOTE(review): `torch.load` unpickles the file, so only load checkpoints from a trusted source.
# `map_location='cpu'` lets a checkpoint that was saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('untracked_files/smiles_lm.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Draw 400 batches of 100 sequences (100 tokens each) and decode every sequence to a string.
# Sampling here is inference only, so use `sample_no_grad` to avoid building an autograd graph per step.
smiles = []
for _ in range(400):
    preds, lps = model.sample_no_grad(100, 100)
    smiles.append([vocab.reconstruct(p) for p in preds])

In [None]:
# Number of sampled batches (`smiles` is still a list of per-batch lists here).
len(smiles)

437

In [None]:
# Number of decoded sequences in the first batch.
len(smiles[0])

100

In [None]:
# Collapse the per-batch lists into one flat list of strings.
# NOTE(review): this rebinds `smiles`, so the cell is not idempotent — re-run the sampling cell before re-running it.
smiles = flatten_list_of_lists(smiles)

In [None]:
# Total number of sampled strings after flattening.
len(smiles)

43700

In [None]:
# NOTE(review): presumably parses each SMILES string into a molecule object and yields None on failure — confirm against mrl.chem.
mols = to_mols(smiles)

In [None]:
# Fraction of sampled strings for which `to_mols` returned something other than None.
sum(1 for mol in mols if mol is not None)/len(smiles)

0.9962471395881007

In [None]:
# Fraction of distinct strings among all samples.
len(set(smiles))/len(smiles)

0.9999313501144165

In [None]:
# Total log-probability of each sequence in the last sampled batch (`lps` is left over from the final loop iteration).
lps.sum(-1)

tensor([-20.8896, -21.9627, -24.3883, -23.2155, -21.5423, -21.4593, -23.0183,
        -20.6185, -20.4622, -21.0812, -23.2870, -20.9693, -20.7390, -25.3426,
        -21.4910, -27.4981, -22.5912, -20.7221, -21.7060, -21.6088, -21.0723,
        -20.8114, -21.7904, -22.0713, -20.7259, -24.6049, -25.1872, -21.4254,
        -20.9317, -20.7213, -21.4634, -20.2982, -20.8087, -21.3563, -21.1142,
        -21.6108, -20.3180, -21.1286, -21.5118, -20.9709, -21.2088, -22.0678,
        -25.4424, -20.6760, -22.2405, -20.4261, -21.4985, -21.9339, -22.6726,
        -22.4086, -20.2079, -20.9287, -22.3963, -21.3464, -22.2943, -23.4055,
        -20.6830, -20.9470, -21.4593, -21.3202, -22.2658, -20.9025, -20.7081,
        -21.4945, -23.9573, -22.0222, -21.7784, -20.7909, -24.5765, -20.3598,
        -20.1857, -21.5615, -19.6921, -21.4191, -22.1406, -20.0879, -22.7940,
        -20.5922, -20.5122, -21.3518, -22.0540, -20.4816, -21.8322, -20.9540,
        -20.7582, -21.0134, -20.5224, -22.3031, -20.6800, -20.28

In [None]:
# hide
# Run nbdev's `notebook2script`; the output below lists each notebook it converted.
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted 01_chem.ipynb.
Converted 02_template.filters.ipynb.
Converted 03_template.template.ipynb.
Converted 04_template.blocks.ipynb.
Converted 05_torch_core.ipynb.
Converted 06_layers.ipynb.
Converted 07_dataloaders.ipynb.
Converted index.ipynb.
Converted template.overview.ipynb.
Converted tutorials.ipynb.
Converted tutorials.structure_enumeration.ipynb.
Converted tutorials.template.advanced.ipynb.
Converted tutorials.template.beginner.ipynb.
Converted tutorials.template.intermediate.ipynb.
