In [5]:
from torch.utils.data import DataLoader
from moses import CharVocab, StringDataset
from moses.char_rnn import CharRNN
from moses.char_rnn import config as CharRNNConfig
import moses
import torch
from torch import autograd
import rdkit.Chem as chem

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm.auto import tqdm
from torch.nn.utils.rnn import pad_sequence

from moses.interfaces import MosesTrainer
from moses.utils import CharVocab, Logger


def anti_model_rnn_loss(loss, alpha=0.5, eps=1.e-7):
    """Turn per-token cross-entropy into an "anti-model" penalty for negative data.

    ``loss`` holds per-token cross-entropy values (natural log), as produced by
    ``nn.CrossEntropyLoss(reduction='none')``. So ``exp(-loss)`` is the
    probability ``p`` that the model assigned to the target token. The result
    is ``-alpha * log2(1 - p + eps)``. It grows as ``p`` approaches 1, so
    minimising it pushes down the probability of the (negative) target tokens.

    Args:
        loss: tensor of per-token cross-entropy values.
        alpha: scale applied to the penalty.
        eps: added to ``1 - p`` so the logarithm stays finite when ``p == 1``.

    Returns:
        Tensor with the same shape as ``loss``.
    """
    # Compute 1 - exp(-loss) as -expm1(-loss). The naive form cancels to
    # exactly 0 in float32 for small losses, losing the true value of 1 - p.
    one_minus_p = -torch.expm1(-loss)
    return -alpha * torch.log2(one_minus_p + eps)


class CharRNNTrainer(MosesTrainer):
    """MOSES-style trainer for ``CharRNN`` with an extra loss on negative data.

    Training iterates over pairs of batches drawn from a positive and a
    negative data loader. The positive batch is trained with the mean token
    cross-entropy. The negative batch is trained with the mean of
    ``anti_model_rnn_loss`` over its token cross-entropies, which penalises
    the model for assigning high probability to the negative sequences.
    Each loss gets its own optimizer step.
    """

    def __init__(self, config):
        # config is read here for lr, step_size, gamma, train_epochs, log_file,
        # model_save and save_frequency (MosesTrainer may read more).
        self.config = config

    @staticmethod
    def _split_batch(batch):
        """Split a loader item into ``(positive, negative)`` batches.

        Training items are ``(positive_batch, negative_batch)`` pairs produced
        by zipping two loaders. Validation items are a bare positive
        ``(prevs, nexts, lens)`` triple, for which ``negative`` is None.
        """
        if len(batch) == 2:
            positive, negative = batch
            return positive, negative
        return batch, None

    @staticmethod
    def _token_losses(model, batch, criterion):
        """Return the un-reduced ``criterion`` value for every target token.

        ``batch`` is a ``(prevs, nexts, lens)`` triple as built by the collate
        function of this trainer. All three tensors are moved to ``model.device``.

        NOTE(review): padded target positions are not masked out (``criterion``
        has no ``ignore_index``), so pad tokens also contribute to the loss
        and to the anti-model loss when batches contain padding.
        """
        prevs, nexts, lens = (t.to(model.device) for t in batch)
        outputs, _, _ = model(prevs, lens)
        return criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1))

    @staticmethod
    def _optimize(loss, optimizer):
        """Take one optimizer step on ``loss`` (skipped when optimizer is None).

        Returns the scalar value of ``loss``.
        """
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return loss.item()

    def _train_epoch(self, model, tqdm_data, criterion, optimizer=None):
        """Run one pass over ``tqdm_data`` and return a loss summary dict.

        Args:
            model: the CharRNN being trained or evaluated.
            tqdm_data: tqdm-wrapped iterable of either
                ``(positive_batch, negative_batch)`` pairs or bare positive
                batches (validation). Bare batches have no negative term.
            criterion: per-token loss built with ``reduction='none'``.
            optimizer: when None, the model is evaluated without gradient
                tracking. Otherwise one step is taken on the positive loss and
                one on the negative loss for every item.

        Returns:
            dict with the last item's ``loss``, the epoch's ``running_loss``
            mean and ``mode`` ('Train' or 'Eval').
        """
        training = optimizer is not None
        if training:
            model.train()
        else:
            model.eval()

        postfix = {'loss': 0,
                   'running_loss': 0}

        for i, batch in enumerate(tqdm_data):
            positive, negative = self._split_batch(batch)

            with torch.set_grad_enabled(training):
                loss = torch.mean(self._token_losses(model, positive, criterion))
            value = self._optimize(loss, optimizer)

            if negative is not None:
                # Run the negative forward only after the positive step.
                # optimizer.step() updates the parameters in place, and autograd
                # refuses to backpropagate through a graph recorded before that
                # update. The anti-model loss is applied to the negative batch's
                # own token losses.
                with torch.set_grad_enabled(training):
                    nloss = torch.mean(anti_model_rnn_loss(
                        self._token_losses(model, negative, criterion)))
                value += self._optimize(nloss, optimizer)

            postfix['loss'] = value
            postfix['running_loss'] += (value -
                                        postfix['running_loss']) / (i + 1)
            tqdm_data.set_postfix(postfix)

        postfix['mode'] = 'Eval' if optimizer is None else 'Train'
        return postfix

    @staticmethod
    def _paired_length(*loaders):
        """Number of items ``zip(*loaders)`` yields, or None if any loader lacks len()."""
        try:
            return min(len(loader) for loader in loaders)
        except TypeError:
            return None

    def _train(self, model, train_loader, negative_loader, val_loader=None, logger=None):
        """Train ``model`` for ``config.train_epochs`` epochs.

        Each epoch pairs batches from ``train_loader`` and ``negative_loader``,
        optionally evaluates on ``val_loader`` and logs both summaries. It also
        saves a CPU ``state_dict`` every ``config.save_frequency`` epochs when
        ``config.model_save`` is set.
        """
        def get_params():
            return (p for p in model.parameters() if p.requires_grad)

        device = model.device
        criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.Adam(get_params(), lr=self.config.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              self.config.step_size,
                                              self.config.gamma)

        model.zero_grad()
        for epoch in range(self.config.train_epochs):
            # NOTE(review): zip stops at the shorter loader, so a negative set
            # smaller than the training set also shortens every epoch.
            tqdm_data = tqdm(zip(train_loader, negative_loader),
                             total=self._paired_length(train_loader,
                                                       negative_loader),
                             desc='Training (epoch #{})'.format(epoch))
            postfix = self._train_epoch(model, tqdm_data, criterion, optimizer)
            if logger is not None:
                logger.append(postfix)
                logger.save(self.config.log_file)

            if val_loader is not None:
                # Validation batches are bare positive triples; _train_epoch
                # evaluates them without a negative term.
                tqdm_data = tqdm(val_loader,
                                 desc='Validation (epoch #{})'.format(epoch))
                postfix = self._train_epoch(model, tqdm_data, criterion)
                if logger is not None:
                    logger.append(postfix)
                    logger.save(self.config.log_file)

            if (self.config.model_save is not None) and \
                    (epoch % self.config.save_frequency == 0):
                model = model.to('cpu')
                # [:-3] strips a 3-character suffix such as '.pt' from model_save.
                torch.save(
                    model.state_dict(),
                    self.config.model_save[:-3]+'_{0:03d}.pt'.format(epoch)
                )
                model = model.to(device)

            scheduler.step()

    def get_vocabulary(self, data):
        """Build a character vocabulary from the given strings."""
        return CharVocab.from_data(data)

    def get_collate_fn(self, model):
        """Return a collate function mapping a list of strings to a batch.

        The batch is ``(prevs, nexts, lens)``. ``prevs`` holds every encoded
        sequence without its last token and ``nexts`` without its first, both
        padded with ``model.vocabulary.pad``. ``lens`` holds the unpadded
        lengths. Strings are sorted longest-first (presumably for sequence
        packing inside the model).
        """
        device = self.get_collate_device(model)

        def collate(data):
            data.sort(key=len, reverse=True)
            tensors = [model.string2tensor(string, device=device)
                       for string in data]

            pad = model.vocabulary.pad
            prevs = pad_sequence([t[:-1] for t in tensors],
                                 batch_first=True, padding_value=pad)
            nexts = pad_sequence([t[1:] for t in tensors],
                                 batch_first=True, padding_value=pad)
            lens = torch.tensor([len(t) - 1 for t in tensors],
                                dtype=torch.long, device=device)
            return prevs, nexts, lens

        return collate

    def fit(self, model, train_data, negative_data, val_data=None):
        """Train ``model`` on positive ``train_data`` and ``negative_data``.

        Args:
            model: the CharRNN to train. It is returned after training.
            train_data: positive training strings (shuffled every epoch).
            negative_data: strings whose likelihood the anti-model loss pushes
                down (shuffled every epoch).
            val_data: optional validation strings, evaluated with the positive
                loss only.
        """
        logger = Logger() if self.config.log_file is not None else None

        train_loader = self.get_dataloader(model, train_data, shuffle=True)
        negative_loader = self.get_dataloader(model, negative_data, shuffle=True)

        val_loader = None if val_data is None else self.get_dataloader(
            model, val_data, shuffle=False
        )

        self._train(model, train_loader, negative_loader, val_loader, logger)
        return model


In [13]:
# Positive training SMILES from the MOSES benchmark.
train = moses.get_dataset('train')

config = CharRNNConfig.get_config()
config.log_file = None  # makes CharRNNTrainer.fit skip the Logger
config.n_workers = 1
config.n_batch = 1  # presumably the DataLoader batch size used by MosesTrainer.get_dataloader - confirm
trainer = CharRNNTrainer(config)
crnn = CharRNN(CharVocab.from_data(train), config)
# Resume from a checkpoint; the '_NNN.pt' name presumably comes from the
# per-epoch save in CharRNNTrainer._train (state_dict saved from CPU).
crnn.load_state_dict(torch.load('crnn._060.pt'))
# Debugging aid: anomaly detection raises on backward ops that produce NaN
# and slows every backward pass; remove once the NaN source is found.
with autograd.detect_anomaly():
    # NOTE(review): no negative dataset is passed here, but
    # CharRNNTrainer.fit takes (model, train_data, negative_data, val_data=None),
    # so this call raises TypeError. Pass the negative SMILES set explicitly,
    # e.g. trainer.fit(crnn, train, negative_smiles).
    trainer.fit(crnn, train)

  # Remove the CWD from sys.path while we load stuff.


HBox(children=(HTML(value='Training (epoch #0)'), FloatProgress(value=0.0, max=1584663.0), HTML(value='')))

tensor([0.8037, 0.2505, 1.0000, 0.2562, 0.0651, 0.9999, 1.0000, 1.0000, 0.9992,
        0.1919, 1.0000, 0.6166, 0.9505, 0.1513, 1.0000, 0.8437, 0.7702, 0.8421,
        0.4521, 0.3031, 0.4060, 0.1638, 1.0000, 0.9986, 1.0000, 1.0000, 0.9993,
        1.0000, 0.8675, 0.0761, 0.9999, 1.0000], grad_fn=<ExpBackward>)
tensor([1.9630e-01, 7.4952e-01, 3.3842e-07, 7.4379e-01, 9.3492e-01, 6.2327e-05,
        1.0000e-07, 1.0000e-07, 7.7758e-04, 8.0806e-01, 3.7293e-05, 3.8339e-01,
        4.9469e-02, 8.4867e-01, 1.6497e-06, 1.5630e-01, 2.2979e-01, 1.5789e-01,
        5.4791e-01, 6.9692e-01, 5.9403e-01, 8.3618e-01, 1.6497e-06, 1.3759e-03,
        2.8418e-06, 1.0000e-07, 7.0689e-04, 7.7294e-06, 1.3250e-01, 9.2391e-01,
        5.9824e-05, 1.4763e-05], grad_fn=<AddBackward0>)
tensor([ -2.3488,  -0.4160, -21.4947,  -0.4270,  -0.0971, -13.9698, -23.2535,
        -23.2535, -10.3287,  -0.3075, -14.7107,  -1.3831,  -4.3373,  -0.2367,
        -19.2093,  -2.6776,  -2.1216,  -2.6630,  -0.8680,  -0.5209,  -0.751

KeyboardInterrupt: 