In [1]:
import sys
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.distributions.categorical import Categorical
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from torch.utils.data import DataLoader

import numpy as np
import pickle as pkl

from tqdm.notebook import tqdm
import networkx as nx

from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit import RDLogger   


%matplotlib notebook
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff

from IPython.display import SVG

from collections import defaultdict
import multiprocessing as mp
from statistics import stdev, mean
from pathlib import Path
import pickle as pkl
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

from collections import namedtuple

In [2]:
NUM_CPU = 3
SEED = 3287450

In [3]:
torch.manual_seed(SEED)
torch.cuda.empty_cache()
sns.set(style='white', context='notebook', rc={'figure.figsize':(14,10)})
np.set_printoptions(threshold=sys.maxsize)
RDLogger.DisableLog('rdApp.*') 

In [4]:
def save(thing, path):
    with open(path, 'wb') as fp:
        pkl.dump(thing, fp)
    print(f'saved [{type(thing)}] to {path}')

In [5]:
def load(path):
    with open(path, 'rb') as fp:
        thing = pkl.load(fp)
    return thing

# Data

In [6]:
!ls ./data/lipinski

active_lipinski.smi		 gcpn_smiles.500.smi
all_lipinski.smi		 lstm_smiles.500.smi
druglike_lipinski_100k.smi	 qed_df_full.pickle
druglike_lipinski_10k.smi	 qed_df.pickle
druglike_lipinski_1k.smi	 qed_df_small.activity.pickle
druglike_lipinski_50k.test.smi	 results
druglike_lipinski_50k.train.smi  very_active_lipinski.smi
druglike_lipinski.smi


In [7]:
GO_TOKEN = 'G'
END_TOKEN = '\n'
PAD_TOKEN = 'A'

DATA_PATH = Path('./data/lipinski/')

In [8]:
def load_smiles(path, test_split=0.15):
    all_smiles = []
    with open(path) as fp:
        for line in tqdm(fp.readlines()):
            all_smiles += [GO_TOKEN + line]
    size = len(all_smiles)
    split = int(size * test_split)
    return all_smiles[split:] ,  all_smiles[:split]

In [9]:
P1_SMILES = load_smiles(DATA_PATH / 'druglike_lipinski.smi')
P2_SMILES = load_smiles(DATA_PATH / 'active_lipinski.smi', test_split=0.1)
P3_SMILES = load_smiles(DATA_PATH / 'very_active_lipinski.smi', test_split=0.1)

HBox(children=(FloatProgress(value=0.0, max=622802.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1084.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=482.0), HTML(value='')))




In [10]:
ALL_SMILES = P1_SMILES[0] + P1_SMILES[1] + P2_SMILES[0] + P2_SMILES[1] + P3_SMILES[0] + P3_SMILES[1]
len(ALL_SMILES)

624368

In [11]:
alphabet = sorted(list(set(''.join(ALL_SMILES+[PAD_TOKEN]))))
NUM_SYM = len(alphabet)

In [12]:
SYM_TO_ID = {s:i for i,s in enumerate(alphabet)}
ID_TO_SYM = {i:s for s,i in SYM_TO_ID.items()}

In [13]:
SYM_TO_ID

{'\n': 0,
 '#': 1,
 '%': 2,
 '(': 3,
 ')': 4,
 '+': 5,
 '-': 6,
 '/': 7,
 '0': 8,
 '1': 9,
 '2': 10,
 '3': 11,
 '4': 12,
 '5': 13,
 '6': 14,
 '7': 15,
 '8': 16,
 '9': 17,
 '=': 18,
 '@': 19,
 'A': 20,
 'B': 21,
 'C': 22,
 'F': 23,
 'G': 24,
 'H': 25,
 'I': 26,
 'N': 27,
 'O': 28,
 'P': 29,
 'S': 30,
 '[': 31,
 '\\': 32,
 ']': 33,
 'c': 34,
 'l': 35,
 'n': 36,
 'o': 37,
 'p': 38,
 'r': 39,
 's': 40}

In [14]:
# MAX_ATOMS = max(map(lambda s: Chem.MolFromSmiles(s).GetNumAtoms(), D.keys()))
MAX_SYM = max(map(len, ALL_SMILES)) # GO_TOKEN and END_TOKEN already considered 

In [15]:
MAX_SYM

136

In [16]:
def encode(smiles):
    """
    encode(simles): 
        - takes in a variable length smiles string (up to MAX_SYM) 
          and outputs a fixed size vector (MAX_SYM by NUM_SYM)
        
    """
    x = np.zeros((MAX_SYM, NUM_SYM))
    x_n = len(smiles)
    for i, sym in enumerate(smiles):
        x[i, SYM_TO_ID[sym]] = 1
    x[x_n:, SYM_TO_ID[PAD_TOKEN]] = 1
    return x, x_n

In [17]:
def decode(x):
    assert x.shape[1] == NUM_SYM
    smiles = ''
    for i in range(x.shape[0]):
        topi = np.argmax(x[i, :])
        smiles += ID_TO_SYM[topi]
    return smiles 

def decode_valid(x):
    s = decode(x)
    if s[0] == 'G':
        s = s[1:]
    return  s[:s.find(END_TOKEN)]

In [18]:
decode(encode(ALL_SMILES[0])[0])

'GCn1c(SCC(=O)Nc2ccccc2)nnc1c3ccc4nonc4c3\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'

In [19]:
decode_valid(encode(ALL_SMILES[0])[0])

'Cn1c(SCC(=O)Nc2ccccc2)nnc1c3ccc4nonc4c3'

In [20]:
class SMILESDataset(torch.utils.data.Dataset):
    
    def __init__(self, smiles):
        self.all_smiles = smiles
        self.size = len(smiles)

    def __len__(self):
        return self.size
    
    def __getitem__(self, i):
        """
        __getitem__(self, i):
            - return 
        """
        x_i, x_n = encode(self.all_smiles[i])
        
        y_i = x_i[1:].copy()
        x_i = x_i[:-1].copy()
        
        return x_i, y_i, x_n
    

In [21]:
def make_phase(data, pid, lr, interval, ):
    Phase = namedtuple('Phase', ['pid','traindata', 'testdata', 'lr', 'test_interval'])
    phase = Phase(traindata=SMILESDataset(data[0]),
                  testdata=SMILESDataset(data[1]),
                  lr=lr,
                  pid=pid, # phase id
                  test_interval=interval)
    
    return phase
    

In [22]:
LR_BASE = 3e-3
P1 = make_phase(P1_SMILES, 1, 3e-3/2, 500)
P2 = make_phase(P2_SMILES, 2,3e-4, 100)
P3 = make_phase(P3_SMILES, 3, 3e-4/2, 50)

In [23]:
xi, yi, xlen = P1.traindata[4]

In [24]:
decode(xi), decode(yi)

('GCOc1cc(O)c2C(=O)O[C@@H](C)CCC[C@@H](O)[C@@H](O)[C@@H](O)C\\C=C\\c2c1\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA',
 'COc1cc(O)c2C(=O)O[C@@H](C)CCC[C@@H](O)[C@@H](O)[C@@H](O)C\\C=C\\c2c1\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA')

In [25]:
def count_valid(smiles):
    """
    count_valid(smiles):
        - returns % of valid smiles
    """
    invalid = 0
    
    for m in map(Chem.MolFromSmiles, smiles):
        invalid += m is None
        
    return 1 - (invalid/len(smiles))

# Model

In [26]:
# plain LSTM model
class GeneratorLSTM(nn.Module):
    def __init__(self, input_size, output_size, num_layers, hidden_size, embedding_size, bidirectional):
        super().__init__()
        
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        
        self.lstm = nn.LSTM(input_size=embedding_size, 
                            hidden_size=hidden_size, 
                            num_layers=num_layers, 
                            batch_first=True,
                            dropout=0.15,
                            bidirectional=bidirectional)
        
        self.input_module = nn.Sequential(nn.Linear(input_size, 256),
                                         nn.ReLU(),
                                         nn.Dropout(0.10),
                                         nn.Linear(256, 512),
                                         nn.ReLU(),
                                         nn.Dropout(0.10),
                                         nn.Linear(512, embedding_size),
                                         nn.ReLU())
        lstm_output_shape = hidden_size if not bidirectional else hidden_size * 2
        self.output_module = nn.Sequential(nn.Linear(lstm_output_shape , 256),
                                           nn.Dropout(0.10),
                                           nn.ReLU(),
                                           nn.Linear(256, output_size))

        self.hidden = None 

    def init_hidden(self, batch_size, cuda=True):
        if self.bidirectional:
            ht = Variable(torch.zeros(self.num_layers*2, batch_size, self.hidden_size))
            ct = Variable(torch.zeros(self.num_layers*2, batch_size, self.hidden_size))
        else:
            ht = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
            ct = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
            

        if cuda:
            return ht.cuda(), ct.cuda()
        else:
            return ht, ct
            


    def forward(self, input_raw, pack=False, input_lens=None):
        """
        forward(self, input_raw, state):
             - input_raw = (bs, x_{i,t})
             - state = (ht, ct)
        """
        input_ = self.input_module(input_raw)
        
        if pack:
            input_ = pack_padded_sequence(input_, input_lens, batch_first=True)
            
        input_, self.hidden = self.lstm(input_, self.hidden)
        
        if pack: 
            input_ = pad_packed_sequence(input_, batch_first=True)[0]

            
        input_ = self.output_module(input_)
        return input_

In [27]:
model = GeneratorLSTM(input_size = NUM_SYM, 
                      output_size = NUM_SYM, 
                      num_layers = 3, 
                      hidden_size = 512, 
                      embedding_size = 512,
                      bidirectional=False)

In [28]:
model.cuda()

GeneratorLSTM(
  (lstm): LSTM(512, 512, num_layers=3, batch_first=True, dropout=0.15)
  (input_module): Sequential(
    (0): Linear(in_features=41, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
  )
  (output_module): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=41, bias=True)
  )
)

In [29]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
model.apply(init_weights)

GeneratorLSTM(
  (lstm): LSTM(512, 512, num_layers=3, batch_first=True, dropout=0.15)
  (input_module): Sequential(
    (0): Linear(in_features=41, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
  )
  (output_module): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=41, bias=True)
  )
)

In [30]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'model has [{count_parameters(model):,}] trainable params')

model has [6,850,601] trainable params


## Train

In [31]:
def train_one_epoch(model, model_optim, lr_sched,
                    train_dataloader, test_dataloader, 
                    test_interval, save_path, epoch):
    
    criterion = nn.NLLLoss()
    loss_history = []
    total = len(train_dataloader)
    
    test_loss_history = defaultdict(list)
    new_smiles = defaultdict(list)
    
    
    for batch_idx, batch in tqdm(enumerate(train_dataloader), total=total):
        
        if batch_idx % test_interval == 0:
            new_smiles[batch_idx] = generate(model,)
            
            smiles_valid_batch = generate(model, test_samples=100, sample_f=softmax_temp_sample, v=False)
            validity = count_valid(smiles_valid_batch)
            
            test_loss_history[batch_idx] = test(model, test_dataloader)
            # we can test later
            
            print(f'[batch num: {batch_idx}] sampled; [{validity:.3f}%] valid smiles sampled')
            
            if batch_idx == 0:
                print_loss = 'nan'
            else:
                print_loss = f'{loss.data.item():.4f}'

            
            path = save_path / Path(f'{epoch}-epoch-{batch_idx}-batch-{print_loss}-loss')
            path.mkdir(exist_ok=True)
            
            torch.save(model.state_dict(), str(path / 'model_dict.torch'))
            save(new_smiles[batch_idx], str(path / 'generated_smiles.list' ))
            save(test_loss_history[batch_idx] , str(path / 'test_loss.item' ))
        
        model.train()
        model.zero_grad()
    
        x_batch, y_batch, batch_lens = batch
        
        batch_size = x_batch.size(0)
        max_len = int(max(batch_lens).item())
        
        x_batch = x_batch[:, 0:max_len, :]
        y_batch = y_batch[:, 0:max_len, :]

        # sort input
        batch_len_sorted, sort_index = torch.sort(batch_lens, 0, descending=True)
        batch_len_sorted = batch_len_sorted.numpy().tolist()
        
        x_batch = torch.index_select(x_batch, 0, sort_index)
        y_batch = torch.index_select(y_batch, 0, sort_index)

        x_batch = Variable(x_batch.float()).cuda()
        y_batch = Variable(y_batch.float()).cuda()

        # init state
        model.hidden = model.init_hidden(batch_size=x_batch.size(0))
        try:
            y_pred = model(x_batch, pack=True, input_lens=batch_len_sorted)
        except Exception as e:
            print(f'[ERROR] got exception {e}')
            print(f'[ERROR] skipping batch...')
            continue
        
        y_pred = F.log_softmax(y_pred.view(-1, NUM_SYM), dim=-1)
        _, y_batch = y_batch.topk(1, dim=-1)
        y_batch = y_batch.view(-1)
        loss = criterion(y_pred, y_batch)
        
        loss.backward()
        model_optim.step()
        lr_sched.step()

        
        if batch_idx % 10 == 0 :
            print(f'[batch num: {batch_idx}] loss: {loss.data.item():.4f}')
            
    path = save_path / Path(f'{epoch}-epoch-END-batch')
    path.mkdir(exist_ok=True)

    torch.save(model.state_dict(), str(path / 'model_dict.torch'))

                
    return loss_history, test_loss_history, new_smiles

In [32]:
def make_run_id():
    return  np.random.choice(['cheetah', 'jaguar', 'wombat', 'cobra', 
                               'croc', 'panda', 'dragon', 'seal', 'spider', 'lizard',
                               'gorilla', 'koala', 'blackbear', 'grizzly', 'zebra',
                               'hippo'])

In [33]:
make_run_id()

'dragon'

In [34]:
def transfer_learning(model, phases, epochs, batch_size=32, run_id = None):
    """
    transfer_learning(model, phases):
        - model: pytorch model
        - phases: [Datasets] for datapaths
    """
    history = defaultdict(dict)
    
    run_id = run_id if run_id is not None else make_run_id()
    
    results_dir = DATA_PATH / 'results' / f'{run_id}-run'
    results_dir.mkdir(exist_ok=False)
    save(SYM_TO_ID, results_dir / 'SYM_TO_ID')
    save(ID_TO_SYM, results_dir / 'ID_TO_SYM')

    print(f'[TRANSFER_LEARNER] saving all progress to [{results_dir}]')
    
    for idx, phase in enumerate(phases):
        print(f'[TRANSFER_LEARNER] STARTING PHASE [{phase.pid}]')
        
        dataloader = DataLoader(phase.traindata, 
                                batch_size=batch_size, 
                                num_workers=NUM_CPU, 
                                shuffle=True)
        
        testloader = DataLoader(phase.testdata, 
                                batch_size=1000, 
                                num_workers=NUM_CPU, 
                                shuffle=True)
        
        model_optim = optim.Adam(list(model.parameters()), lr=phase.lr)
        lr_sched = CosineAnnealingLR(model_optim, 500)
        
        phase_dir = results_dir / f'phase-{phase.pid}'
        phase_dir.mkdir(exist_ok=True)
        
        for epoch_idx in range(epochs[idx]):
            print(f'[TRANSFER_LEARNER] PHASE [{phase.pid}] EPOCH [{epoch_idx}]')
            history[phase.pid][epoch_idx] = train_one_epoch(model, model_optim, lr_sched,
                                                   dataloader, testloader,
                                                   phase.test_interval, phase_dir, 
                                                   epoch_idx)
            
    return history
            
            
            
    

## Testing and Generation

In [35]:
def softmax_temp_sample(y_t, temperature = 1.0):
    prediction_vector = F.softmax(y_t / temperature, dim=-1)
    x_index_t = torch.multinomial(prediction_vector, 1)[:, 0]
    return x_index_t

In [36]:
def topk_sample(y_t):
    _, pred_idx =  y_t.topk(1, dim=-1)
    return pred_idx

In [37]:
def generate(model, test_samples=5, sample_f=softmax_temp_sample, v=True):
    with torch.no_grad():
        model.eval()

        x = torch.zeros(test_samples, MAX_SYM, NUM_SYM).cuda()

        x[:, 0, SYM_TO_ID[GO_TOKEN]] = 1
        
        for i in range(MAX_SYM-1):
            model.hidden = model.init_hidden(batch_size=test_samples)
            
            pred = model(x, pack=True, input_lens=np.ones(test_samples)*(i+1))
            pred_idx = sample_f(pred[:,i,:])
            temp = torch.zeros(test_samples, MAX_SYM, NUM_SYM).cuda()
            for j in range(test_samples):
                temp[j, i+1, pred_idx[j]] = 1  
            x.add_(temp)

        if v: print('\n',10*'-' + 'GENERATED SMILES STRINGS' + 10*'-')
        smiles = []

        for j in range(test_samples):
            s = decode_valid(x[j].cpu().numpy())
            smiles += [s]

            if v: print(s)

        return smiles
        
    

In [38]:
def test(model, testloader):
    with torch.no_grad():

        model.eval()

        criterion = nn.NLLLoss()
        total_loss = 0

        for batch in tqdm(testloader,total=len(testloader)):

            x_batch, y_batch, batch_lens = batch

            batch_size = x_batch.size(0)
            max_len = int(max(batch_lens).item())

            x_batch = x_batch[:, 0:max_len, :]
            y_batch = y_batch[:, 0:max_len, :]

            # sort input
            batch_len_sorted, sort_index = torch.sort(batch_lens, 0, descending=True)
            batch_len_sorted = batch_len_sorted.numpy().tolist()

            x_batch = torch.index_select(x_batch, 0, sort_index)
            y_batch = torch.index_select(y_batch, 0, sort_index)

            x_batch = Variable(x_batch.float()).cuda()
            y_batch = Variable(y_batch.float()).cuda()

            # init state
            model.hidden = model.init_hidden(batch_size=x_batch.size(0))

            y_pred = model(x_batch, pack=True, input_lens=batch_len_sorted)

            y_pred = y_pred.view(-1, NUM_SYM)
            _, y_batch = y_batch.topk(1, dim=-1)
            y_batch = y_batch.view(-1)

            total_loss += criterion(y_pred, y_batch)

        return total_loss.data.item() / len(testloader)
    

## Results

In [None]:
history = transfer_learning(model, phases=[P1, P2, P3], epochs=[2, 2, 2])

saved [<class 'dict'>] to data/lipinski/results/zebra-run/SYM_TO_ID
saved [<class 'dict'>] to data/lipinski/results/zebra-run/ID_TO_SYM
[TRANSFER_LEARNER] saving all progress to [data/lipinski/results/zebra-run]
[TRANSFER_LEARNER] STARTING PHASE [1]
[TRANSFER_LEARNER] PHASE [1] EPOCH [0]


HBox(children=(FloatProgress(value=0.0, max=16544.0), HTML(value='')))


 ----------GENERATED SMILES STRINGS----------
\FoHs9rC5[0o=op49%o7Oo2(=@rpC#P7=4cn5p+2pcG9GPN/pO1sG[Fl2rs\/=ls(O01l]-1@F3I#\B[[]\8onOpoS4=0]3)GA@7[s[072@B=S6rppB1r-+7p03[SOo0HAsBF[
5c66I4r5]noAA4lPGC[H]-2-7537=4-0\n96OB%c)=1o@OG\G@)s7PApG@C//75\+sCFS\l1=C9(4p5N-F8N%@3Ap(Sr5[)5#p@%p
84o2O4B----6n]A)+#S
7BPp=53rA[/46]8//Bp8//0(551#A/%S0]N4PG89\
8Ir33F)(G)76)8oBG)F\2lFO/F+r


HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 0] sampled; [0.050%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-0-batch-nan-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-0-batch-nan-loss/test_loss.item
[batch num: 0] loss: 3.7025
[batch num: 10] loss: 3.0124
[batch num: 20] loss: 2.8118
[batch num: 30] loss: 2.6824
[batch num: 40] loss: 2.5676
[batch num: 50] loss: 2.3619
[batch num: 60] loss: 1.9394
[batch num: 70] loss: 1.9654
[batch num: 80] loss: 1.3544
[batch num: 90] loss: 1.2640
[batch num: 100] loss: 1.4756
[batch num: 110] loss: 1.4155
[batch num: 120] loss: 0.9922
[batch num: 130] loss: 1.2545
[batch num: 140] loss: 1.2949
[batch num: 150] loss: 1.2264
[batch num: 160] loss: 1.0949
[batch num: 170] loss: 1.0276
[batch num: 180] loss: 1.2213
[batch num: 190] loss: 1.0423
[batch num: 200] loss: 0.7628
[batch num: 210] loss: 0.7687
[batch num: 220] loss: 1.0216
[batch num: 230] loss: 0.9241
[batch num: 240]

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 500] sampled; [0.000%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-500-batch-0.9426-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-500-batch-0.9426-loss/test_loss.item
[batch num: 500] loss: 0.9236
[batch num: 510] loss: 0.7213
[batch num: 520] loss: 0.9427
[batch num: 530] loss: 0.8451
[batch num: 540] loss: 0.7914
[batch num: 550] loss: 0.8685
[batch num: 560] loss: 0.7428
[batch num: 570] loss: 0.8778
[batch num: 580] loss: 0.8734
[batch num: 590] loss: 0.9460
[batch num: 600] loss: 0.7148
[batch num: 610] loss: 0.8984
[batch num: 620] loss: 0.7510
[batch num: 630] loss: 0.9479
[batch num: 640] loss: 0.6416
[batch num: 650] loss: 0.9046
[batch num: 660] loss: 0.9202
[batch num: 670] loss: 0.7619
[batch num: 680] loss: 1.0050
[batch num: 690] loss: 0.7953
[batch num: 700] loss: 0.9173
[batch num: 710] loss: 0.6275
[batch num: 720] loss: 0.7771
[batch num: 730] loss: 

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 1000] sampled; [0.010%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-1000-batch-0.8034-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-1000-batch-0.8034-loss/test_loss.item
[batch num: 1000] loss: 0.8474
[batch num: 1010] loss: 0.7972
[batch num: 1020] loss: 0.7590
[batch num: 1030] loss: 0.7705
[batch num: 1040] loss: 0.6459
[batch num: 1050] loss: 0.8453
[batch num: 1060] loss: 0.8374
[batch num: 1070] loss: 0.6984
[batch num: 1080] loss: 0.6287
[batch num: 1090] loss: 0.7114
[batch num: 1100] loss: 0.7659
[batch num: 1110] loss: 0.6996
[batch num: 1120] loss: 0.7093
[batch num: 1130] loss: 0.7525
[batch num: 1140] loss: 0.7357
[batch num: 1150] loss: 0.8469
[batch num: 1160] loss: 0.6841
[batch num: 1170] loss: 0.7275
[batch num: 1180] loss: 0.5487
[batch num: 1190] loss: 0.6366
[batch num: 1200] loss: 0.5407
[batch num: 1210] loss: 0.4991
[batch num: 1220] loss: 0.54

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 1500] sampled; [0.080%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-1500-batch-0.5716-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-1500-batch-0.5716-loss/test_loss.item
[batch num: 1500] loss: 0.6623
[batch num: 1510] loss: 0.5997
[batch num: 1520] loss: 0.4185
[batch num: 1530] loss: 0.6382
[batch num: 1540] loss: 0.6415
[batch num: 1550] loss: 0.4935
[batch num: 1560] loss: 0.6518
[batch num: 1570] loss: 0.5776
[batch num: 1580] loss: 0.6263
[batch num: 1590] loss: 0.6716
[batch num: 1600] loss: 0.5616
[batch num: 1610] loss: 0.4897
[batch num: 1620] loss: 0.6189
[batch num: 1630] loss: 0.6226
[batch num: 1640] loss: 0.5493
[batch num: 1650] loss: 0.5888
[batch num: 1660] loss: 0.6481
[batch num: 1670] loss: 0.5458
[batch num: 1680] loss: 0.5704
[batch num: 1690] loss: 0.6662
[batch num: 1700] loss: 0.6174
[batch num: 1710] loss: 0.6706
[batch num: 1720] loss: 0.67

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 2000] sampled; [0.130%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-2000-batch-0.6841-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-2000-batch-0.6841-loss/test_loss.item
[batch num: 2000] loss: 0.5795
[batch num: 2010] loss: 0.6549
[batch num: 2020] loss: 0.5868
[batch num: 2030] loss: 0.4524
[batch num: 2040] loss: 0.6880
[batch num: 2050] loss: 0.6145

[batch num: 2500] sampled; [0.260%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-2500-batch-0.4767-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-2500-batch-0.4767-loss/test_loss.item
[batch num: 2500] loss: 0.5374
[batch num: 2510] loss: 0.5921
[batch num: 2520] loss: 0.5049
[batch num: 2530] loss: 0.4835
[batch num: 2540] loss: 0.4902
[batch num: 2550] loss: 0.4847
[batch num: 2560] loss: 0.5465
[batch num: 2570]

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 3000] sampled; [0.310%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-3000-batch-0.4448-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-3000-batch-0.4448-loss/test_loss.item
[batch num: 3000] loss: 0.6637
[batch num: 3010] loss: 0.4952
[batch num: 3020] loss: 0.5398
[batch num: 3030] loss: 0.5140
[batch num: 3040] loss: 0.5381
[batch num: 3050] loss: 0.6028
[batch num: 3060] loss: 0.3804
[batch num: 3070] loss: 0.5301
[batch num: 3080] loss: 0.4654
[batch num: 3090] loss: 0.4530
[batch num: 3100] loss: 0.5403
[batch num: 3110] loss: 0.6468
[batch num: 3120] loss: 0.5510
[batch num: 3130] loss: 0.5131
[batch num: 3140] loss: 0.4732
[batch num: 3150] loss: 0.5384
[batch num: 3160] loss: 0.3777
[batch num: 3170] loss: 0.5057
[batch num: 3180] loss: 0.6362
[batch num: 3190] loss: 0.5927
[batch num: 3200] loss: 0.5909
[batch num: 3210] loss: 0.5481
[batch num: 3220] loss: 0.50

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 3500] sampled; [0.420%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-3500-batch-0.4965-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-3500-batch-0.4965-loss/test_loss.item
[batch num: 3500] loss: 0.3804
[batch num: 3510] loss: 0.4142
[batch num: 3520] loss: 0.6229
[batch num: 3530] loss: 0.4767
[batch num: 3540] loss: 0.5589
[batch num: 3550] loss: 0.4675
[batch num: 3560] loss: 0.5100
[batch num: 3570] loss: 0.5733
[batch num: 3580] loss: 0.4524
[batch num: 3590] loss: 0.4806
[batch num: 3600] loss: 0.5092
[batch num: 3610] loss: 0.5473


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




[batch num: 5500] sampled; [0.590%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-5500-batch-0.4805-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-5500-batch-0.4805-loss/test_loss.item
[batch num: 5500] loss: 0.3295
[batch num: 5510] loss: 0.4641
[batch num: 5520] loss: 0.4435
[batch num: 5530] loss: 0.4942
[batch num: 5540] loss: 0.5344
[batch num: 5550] loss: 0.6026
[batch num: 5560] loss: 0.4049
[batch num: 5570] loss: 0.5411
[batch num: 5580] loss: 0.4610
[batch num: 5590] loss: 0.5416
[batch num: 5600] loss: 0.4629
[batch num: 5610] loss: 0.5062
[batch num: 5620] loss: 0.3946
[batch num: 5630] loss: 0.5387
[batch num: 5640] loss: 0.5485
[batch num: 5650] loss: 0.4739
[batch num: 5660] loss: 0.5222
[batch num: 5670] loss: 0.5441
[batch num: 5680] loss: 0.4798
[batch num: 5690] loss: 0.5574
[batch num: 5700] loss: 0.4938
[batch num: 5710] loss: 0.4107
[batch num: 5720] loss: 0.43

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 6000] sampled; [0.590%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-6000-batch-0.4984-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-6000-batch-0.4984-loss/test_loss.item
[batch num: 6000] loss: 0.5503
[batch num: 6010] loss: 0.4409
[batch num: 6020] loss: 0.4949
[batch num: 6030] loss: 0.3744
[batch num: 6040] loss: 0.3784
[batch num: 6050] loss: 0.5230
[batch num: 6060] loss: 0.5164
[batch num: 6070] loss: 0.4509
[batch num: 6080] loss: 0.4778
[batch num: 6090] loss: 0.4912
[batch num: 6100] loss: 0.5798
[batch num: 6110] loss: 0.3188
[batch num: 6120] loss: 0.4531
[batch num: 6130] loss: 0.4343
[batch num: 6140] loss: 0.4416
[batch num: 6150] loss: 0.4722
[batch num: 6160] loss: 0.5055
[batch num: 6170] loss: 0.5568
[batch num: 6180] loss: 0.4600
[batch num: 6190] loss: 0.3782
[batch num: 6200] loss: 0.4783
[batch num: 6210] loss: 0.4660
[batch num: 6220] loss: 0.39

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 6500] sampled; [0.600%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-6500-batch-0.4577-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-6500-batch-0.4577-loss/test_loss.item
[batch num: 6500] loss: 0.5199
[batch num: 6510] loss: 0.4046
[batch num: 6520] loss: 0.3975
[batch num: 6530] loss: 0.5062
[batch num: 6540] loss: 0.5050
[batch num: 6550] loss: 0.4368
[batch num: 6560] loss: 0.3573
[batch num: 6570] loss: 0.3663
[batch num: 6580] loss: 0.3607
[batch num: 6590] loss: 0.4704
[batch num: 6600] loss: 0.5523
[batch num: 6610] loss: 0.4669
[batch num: 6620] loss: 0.3023
[batch num: 6630] loss: 0.4851
[batch num: 6640] loss: 0.4309
[batch num: 6650] loss: 0.4222
[batch num: 6660] loss: 0.5142
[batch num: 6670] loss: 0.4086
[batch num: 6680] loss: 0.4643


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[batch num: 8440] loss: 0.4876
[batch num: 8450] loss: 0.4251
[batch num: 8460] loss: 0.3993
[batch num: 8470] loss: 0.4616
[batch num: 8480] loss: 0.4740
[batch num: 8490] loss: 0.4419

 ----------GENERATED SMILES STRINGS----------
Cc1nnc[NS(=O)(=O)c2cc(F)cc(c2)c3ccc(cc3)C(F)(F)F)n1C(C)C4CCNC4=O
CCCCN1C(=O)NC(=O)c2cc(NC=N)c(C(=O)c3ccc(Cl)cc3)cc12
C[C@H]1COCCN1S(=O)(=O)c2nnc(NC(=O)c3cc4CC(C)CCn4n3)n2
C[C@H]1CC[C@H]2OC(=O)[C@@]34CCCC[C@]5(C)[C@]3(CCN3CCC4)CC[C@H]4CC[C@]2(C)C
Nc1cnc(nc1NC2CC(=O)NC2)c3nc4ccc(COc5ccc(F)cc5F)c4C(=O)c34


HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 8500] sampled; [0.730%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-8500-batch-0.4463-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-8500-batch-0.4463-loss/test_loss.item
[batch num: 8500] loss: 0.4904
[batch num: 8510] loss: 0.5157
[batch num: 8520] loss: 0.4327
[batch num: 8530] loss: 0.3678
[batch num: 8540] loss: 0.4531
[batch num: 8550] loss: 0.4099
[batch num: 8560] loss: 0.3441
[batch num: 8570] loss: 0.4668
[batch num: 8580] loss: 0.4635
[batch num: 8590] loss: 0.4113
[batch num: 8600] loss: 0.4396
[batch num: 8610] loss: 0.3923
[batch num: 8620] loss: 0.4453
[batch num: 8630] loss: 0.4750
[batch num: 8640] loss: 0.3430
[batch num: 8650] loss: 0.4433
[batch num: 8660] loss: 0.4160
[batch num: 8670] loss: 0.3586
[batch num: 8680] loss: 0.5194
[batch num: 8690] loss: 0.4560
[batch num: 8700] loss: 0.4477
[batch num: 8710] loss: 0.4249
[batch num: 8720] loss: 0.33

HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


[batch num: 9000] sampled; [0.650%] valid smiles sampled
saved [<class 'list'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-9000-batch-0.4953-loss/generated_smiles.list
saved [<class 'float'>] to data/lipinski/results/zebra-run/phase-1/0-epoch-9000-batch-0.4953-loss/test_loss.item
[batch num: 9000] loss: 0.4453
[batch num: 9010] loss: 0.4814
[batch num: 9020] loss: 0.4705
[batch num: 9030] loss: 0.4758
[batch num: 9040] loss: 0.4147
[batch num: 9050] loss: 0.4416
[batch num: 9060] loss: 0.4985
[batch num: 9070] loss: 0.5556
[batch num: 9080] loss: 0.4795
[batch num: 9090] loss: 0.4606
[batch num: 9100] loss: 0.4838
[batch num: 9110] loss: 0.4323
[batch num: 9120] loss: 0.4906
[batch num: 9130] loss: 0.4657
[batch num: 9140] loss: 0.4973
[batch num: 9150] loss: 0.4578
[batch num: 9160] loss: 0.3402


In [None]:
def plot_history(history, l=-1):
    if l < 0:
        l = len(history)
    fig = px.line(x=np.arange(l), y=history[:l], labels={'x':'batch number', 'y':'binary cross-entropy loss'})
    return fig

In [None]:
history[2]

In [None]:
!ls data/lipinski/results/panda-run/phase-0/0-epoch-3500-batch-0.4833-loss

In [None]:
load('data/lipinski/results/panda-run/phase-0/0-epoch-3500-batch-0.4833-loss/generated_smiles.dict')

In [None]:
load_path = './data/lipinski/results/panda-run/phase-0/0-epoch-3500-batch-0.4833-loss/model_dict.torch'

In [None]:
model.load_state_dict(torch.load(load_path))

In [None]:
newsmiles = generate(model, test_samples=500, sample_f=softmax_temp_sample)
validity = count_valid(newsmiles)
print(f'validity is at [{validity:.2f}%]')

In [None]:
transfer_learning(model, [P2, P3], [2, 2] )