In [1]:
import sys
sys.path.append('../')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models import *
from mrl.agent import *
from mrl.policy_gradient import PolicyGradient, TRPO, PPO
from mrl.environment import *

  return f(*args, **kwds)


In [2]:
# Runtime configuration passed to the library through environment variables.
# NOTE(review): these are set after the mrl/torch imports above; if anything reads
# them at import time, the values here have no effect — confirm.
os.environ.update({
    'ncpus': '0',
    'DEFAULT_GPU': '1',
})

In [3]:
# Character-level SMILES vocabulary plus a one-molecule placeholder dataset that the
# generative agent is constructed with (presumably replaced later through
# agent.update_dataset_from_inputs in SupevisedCB — confirm).
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = TextDataset(['CCC'], vocab)

# LSTM language-model hyperparameters. These must match the checkpoint loaded below:
# load_state_dict is strict by default and raises on missing keys or shape mismatches.
d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
n_layers = 3
lstm_drop = 0.
lin_drop = 0.
bos_idx = vocab.stoi['bos']
bidir = False
tie_weights = True

lm_model = LSTM_LM(d_vocab, d_embedding, d_hidden, n_layers,
                   lstm_drop, lin_drop, bos_idx, bidir, tie_weights)

# map_location='cpu' lets a checkpoint saved from a GPU load on any machine.
# load_state_dict copies the loaded tensors into the model's existing parameters,
# so where the model itself lives is unchanged.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
lm_model.load_state_dict(torch.load('../nbs/untracked_files/lstm_lm_small.pt',
                                    map_location='cpu'))

loss = CrossEntropy()

In [4]:
# Generative agent wrapping the pretrained LSTM LM, its vocab, loss and placeholder
# dataset. opt_kwargs presumably configures the optimizer (lr=1e-4).
# NOTE(review): the meaning of base_update=0.97 is not visible here — presumably how
# strongly the base model tracks the live model; confirm in mrl.agent.
agent = GenerativeAgent(lm_model, vocab, loss, ds, opt_kwargs={'lr':1e-4},
                       base_update=0.97)

In [5]:
def scale_sa(sa):
    """Linearly rescale `sa` so that 1 maps to 1.0 and 10 maps to 0.0.

    Presumably `sa` is a synthetic-accessibility score where lower is better.
    Values outside [1, 10] are not clipped. Not called in the visible cells.
    """
    shifted = 10 - sa
    return shifted / 9

In [6]:
# Scoring template. The first list holds ValidityFilter and SingleCompoundFilter.
# The second list holds QED and SA filters, each constructed with score=1.
# fail_score=-1. is presumably the score assigned to samples that fail — confirm
# in mrl.templates.
# NOTE(review): QEDFilter(0.5, None) / SAFilter(None, 3) look like (min, max) bounds,
# i.e. QED >= 0.5 and SA <= 3 — confirm against the filter signatures.
template = Template([ValidityFilter(), SingleCompoundFilter()],
                    [QEDFilter(0.5, None, score=1.),
                     SAFilter(None, 3, score=1.)], 
                    fail_score=-1.)

In [7]:
class UpdateBaselineCB(Callback):
    """Periodically refresh the agent's base models.

    After every batch whose iteration count is a positive multiple of `iters`,
    calls `agent.update_base_models()` and increments `num_updates`.
    """
    def __init__(self, agent, iters, name):
        super().__init__(order=10, name=name)
        self.agent = agent
        self.iters = iters
        self.num_updates = 0  # number of update_base_models() calls so far

    def after_batch(self):
        step = self.environment.log.iterations
        # off-schedule iterations and iteration 0 do nothing
        if step % self.iters != 0 or step <= 0:
            return
        self.agent.update_base_models()
        self.num_updates += 1
            
            
class StatsCallback(Callback):
    """Log the 90th-percentile and maximum reward of samples from one source.

    Args:
        grabname: value in `batch_state.sources` to select (e.g. a sampler name
            such as 'live').
        name: callback name.
        order: callback order.

    Registers metrics `{grabname}_p90` and `{grabname}_max`. If a batch contains no
    samples from `grabname`, both metrics are logged as NaN instead of raising.
    """
    def __init__(self, grabname, name, order):
        # Initialise the base class the same way the other callbacks here do.
        # name/order are also assigned explicitly so they exist regardless of how
        # the base class stores them.
        super().__init__(order=order, name=name)
        self.grabname = grabname
        self.name = name
        self.order = order

    def setup(self):
        log = self.environment.log
        log.add_metric(f'{self.grabname}_p90')
        log.add_metric(f'{self.grabname}_max')

    def after_compute_reward(self):
        log = self.environment.log
        state = self.environment.batch_state
        rewards = state.rewards.detach().cpu().numpy()
        sources = np.array(state.sources)

        rewards = rewards[sources == self.grabname]

        if rewards.size == 0:
            # np.percentile and ndarray.max raise on empty arrays; log NaN so the
            # metrics still receive a value for this batch.
            p90, rmax = np.nan, np.nan
        else:
            p90, rmax = np.percentile(rewards, 90), rewards.max()

        log.update_metric(f'{self.grabname}_p90', p90)
        log.update_metric(f'{self.grabname}_max', rmax)

class Rollback(Callback):
    """Call `merge_models(model1, model2, alpha=alpha)` when a metric sags.

    After each batch, averages the last `lookback` values of
    `log.metrics[metric]`. If that average is below `target` and the cooldown has
    expired, it does three things:

    - prints 'rollback';
    - merges the models;
    - restarts a cooldown of `lookback` batches.
    """
    def __init__(self, model1, model2, metric, lookback, target, alpha, name):
        super().__init__(order=10, name=name)
        self.model1 = model1
        self.model2 = model2
        self.metric = metric
        self.lookback = lookback
        self.target = target
        self.alpha = alpha
        self.last_rollback = 0  # cooldown counter; a rollback is allowed when <= 0

    def after_batch(self):
        recent = self.environment.log.metrics[self.metric][-self.lookback:]
        recent_mean = np.array(recent).mean()

        cooldown_over = self.last_rollback <= 0
        if recent_mean < self.target and cooldown_over:
            print('rollback')
            merge_models(self.model1, self.model2, alpha=self.alpha)
            self.last_rollback = self.lookback

        # ticks down every batch, including the batch that just rolled back
        self.last_rollback -= 1
        
class SupevisedCB(Callback):
    """Periodically fine-tune the agent on its own top-scoring logged samples.

    Runs every `frequency` iterations, skipping iteration 0:

    1. Collects all unique samples logged so far.
    2. Keeps those whose reward is strictly above the `percentile`-th percentile.
    3. Calls `agent.train_supervised(bs, 1, lr)` on them (presumably one epoch).
    4. Calls `merge_models(agent.base_model, agent.model, alpha=base_update)`.

    If nothing is logged yet, or no sample clears the threshold (e.g. many tied
    rewards), the step is skipped instead of training on an empty dataset.
    """
    def __init__(self, agent, frequency, base_update, percentile, lr, bs):
        super().__init__('supervised', order=1000)
        self.agent = agent
        self.frequency = frequency
        self.base_update = base_update
        self.percentile = percentile
        self.lr = lr
        self.bs = bs

    def after_batch(self):
        env = self.environment
        iterations = env.log.iterations

        if iterations > 0 and iterations % self.frequency == 0:
            df = log_to_df(env.log.log, ['samples', 'rewards'])
            df = df.drop_duplicates(subset='samples')
            if df.empty:
                # np.percentile raises on an empty array
                return

            threshold = np.percentile(df.rewards.values, self.percentile)
            top = df[df.rewards > threshold]
            if top.empty:
                # strict '>' can select nothing when rewards tie at the threshold
                return

            self.agent.update_dataset_from_inputs(top.samples.values)
            self.agent.train_supervised(self.bs, 1, self.lr)

            merge_models(self.agent.base_model, self.agent.model, alpha=self.base_update)
            


class LogSampler(Sampler):
    """Re-inject high-scoring logged samples into the environment buffer.

    Once `iterations > start_iter`, draws up to `int(buffer_size * p_buffer)`
    unique logged samples whose `sample_name` value is strictly above its
    `percentile`-th percentile, and adds them to `environment.buffer`.
    """
    def __init__(self, sample_name, start_iter, percentile, p_buffer=0.):
        super().__init__(sample_name+'_sample', p_buffer, p_batch=0.)
        self.start_iter = start_iter
        self.percentile = percentile
        self.sample_name = sample_name

    def build_buffer(self):
        env = self.environment
        if not env.log.iterations > self.start_iter:
            return

        logged = log_to_df(env.log.log, ['samples', self.sample_name])
        logged = logged.drop_duplicates(subset='samples')
        n_wanted = int(env.buffer_size * self.p_buffer)
        if n_wanted <= 0:
            return

        scores = logged[self.sample_name]
        cutoff = np.percentile(scores.values, self.percentile)
        candidates = logged[scores > cutoff]
        picked = candidates.sample(n=min(n_wanted, candidates.shape[0]))
        env.buffer.add(list(picked.samples.values))
                    
class LogEnumerator(Sampler):
    """Add enumerated variants of high-scoring logged samples to the buffer.

    Once `iterations > start_iter`, it builds the buffer as follows:

    1. Picks up to `int(buffer_size * p_buffer)` unique logged samples whose
       `sample_name` value is strictly above its `percentile`-th percentile.
    2. Expands each with `add_atom_combi(s, atom_types)` + `add_bond_combi(s)`
       (presumably single atom/bond additions).
    3. Drops None results and SMILES containing '.'.
    4. Adds the rest to `environment.buffer`.

    Args:
        sample_name: log key used for ranking (e.g. 'rewards').
        start_iter: no enumeration until `iterations > start_iter`.
        n_samp: stored as `self.n_samp`. NOTE(review): not used by build_buffer.
        percentile: ranking threshold percentile (defaults to 90).
        p_buffer: fraction of the buffer size used as the number of seeds.
            Defaults to 0., which, like the previous hard-coded value, disables
            enumeration.
    """
    def __init__(self, sample_name, start_iter, n_samp, percentile=90., p_buffer=0.):
        super().__init__(sample_name+'_enum', p_buffer=p_buffer, p_batch=0.)
        self.start_iter = start_iter
        self.n_samp = n_samp
        self.percentile = percentile
        self.atom_types = ['C', 'N', 'O', 'F', 'S', 'Br', 'Cl', -1]
        self.sample_name = sample_name

    def build_buffer(self):
        env = self.environment
        iterations = env.log.iterations

        if iterations > self.start_iter:
            df = log_to_df(env.log.log, ['samples', self.sample_name])
            df = df.drop_duplicates(subset='samples')
            bs = int(env.buffer_size * self.p_buffer)
            if bs > 0:
                threshold = np.percentile(df[self.sample_name].values, self.percentile)
                subset = df[df[self.sample_name] > threshold]
                seeds = list(subset.sample(n=min(bs, subset.shape[0])).samples.values)
                outputs = []

                for s in seeds:
                    new_smiles = add_atom_combi(s, self.atom_types) + add_bond_combi(s)
                    new_smiles = [i for i in new_smiles if i is not None]
                    # filter the candidates just generated: drop multi-fragment SMILES
                    new_smiles = [i for i in new_smiles if '.' not in i]
                    outputs += new_smiles

                env.buffer.add(outputs)
                    
class DatasetSampler(Sampler):
    """Fill the buffer with `n_samples` items drawn uniformly, with replacement,
    from the fixed list `samples` (uses NumPy's global RNG)."""
    def __init__(self, n_samples, samples, name):
        super().__init__(name, 0., 0.)
        self.n_samples = n_samples
        self.samples = samples

    def build_buffer(self):
        picks = np.random.randint(0, len(self.samples), self.n_samples)
        drawn = [self.samples[k] for k in picks]
        self.environment.buffer.add(drawn)

In [8]:
def log_to_df(log, keys=None):
    """Flatten per-batch log lists into one long-format DataFrame.

    Args:
        log: mapping from key to a list of per-batch lists.
        keys: keys to include; defaults to every key in `log`.

    Returns:
        DataFrame with a `batch` column (the batch index of each row) plus one
        column per key. Row counts per batch come from the first key, so all
        selected keys must flatten to the same length.
    """
    if keys is None:
        keys = list(log.keys())

    columns = defaultdict(list)
    for batch_idx, batch_items in enumerate(log[keys[0]]):
        columns['batch'].extend([batch_idx] * len(batch_items))

    for key in keys:
        columns[key] = flatten_list_of_lists(log[key])

    return pd.DataFrame(columns)

In [9]:
# Alternative callbacks tried earlier (disabled):
# update_cb = UpdateBaselineCB(agent, 5, 'base_update')
# stat_cb = StatsCallback('live', 'stat_cb', 10)
# roll_cb = Rollback(agent.model, agent.base_model, 'live_valid', 5, 0.4, 0.5, 'rb')
# reward_cb = LogSampler('rewards', 10, 92, p_buffer=0.03)

# NOTE(review): LogEnumerator is built without a positive p_buffer here, so its
# build_buffer adds nothing to the buffer as configured.
enum_cb = LogEnumerator('rewards', start_iter=10, n_samp=5)

# Every 10 iterations: fine-tune on unique samples above the 90th reward
# percentile (lr=5e-5, batch size 64), then merge_models with alpha=1.
train_cb = SupevisedCB(agent, frequency=10, base_update=1., percentile=90,
                       lr=5e-5, bs=64)

In [10]:
# Active callbacks for this run. Full alternative set:
# [update_cb, stat_cb, roll_cb, reward_cb, enum_cb, train_cb]
cbs = [
    enum_cb,
    train_cb,
]

In [11]:
class PredReward(Callback):
    """Reward callback that adds `-weight * prediction` from a predictive agent.

    `setup` registers a metric and a log entry, both under `name`.

    `compute_reward` does the following:

    - runs `agent.predict_data` on the batch samples without gradients;
    - adds `-weight * preds` to `batch_state.rewards`;
    - stores the same tensor as `batch_state[name]`;
    - logs its mean as the `name` metric.
    """
    def __init__(self, name, agent, weight=1.):
        super().__init__(order=1)
        self.name = name
        self.weight = weight
        self.agent = agent

    def setup(self):
        log = self.environment.log
        log.add_metric(self.name)
        log.add_log(self.name)

    def compute_reward(self):
        env = self.environment
        state = self.batch_state
        with torch.no_grad():
            preds = self.agent.predict_data(state.samples).squeeze()
        # negated: lower predictions yield higher reward
        reward = -preds * self.weight

        mean_reward = reward.mean().detach().cpu().numpy()
        env.log.update_metric(self.name, mean_reward)
        state.rewards += reward
        state[self.name] = reward
        

In [12]:
# Affinity predictor used as a reward model.
df = pd.read_csv('../nbs/untracked_files/affinity_data.csv')

# NOTE(review): arguments look like (input dim 2048, hidden sizes, 1 output, dropouts),
# with 2048 presumably matching the ECFP6 featurizer below — confirm MLP_Encoder signature.
r_model = MLP_Encoder(2048, [1024, 512, 256, 128], 1, [0.2, 0.2, 0.2, 0.2])

# Targets are divided by 10 here, and the reward callback below uses weight=10.
# The reward is therefore roughly -value on the original scale, assuming predictions
# approximate the scaled targets.
r_ds = Vec_Prediction_Dataset(df.smiles.values, df.value.values/10, ECFP6)

r_agent = PredictiveAgent(r_model, MSELoss(), r_ds, opt_kwargs={'lr':1e-3})

# NOTE(review): presumably a torch checkpoint (pickle-based) — only load trusted files.
r_agent.load_weights('../nbs/untracked_files/aff_pred.pt')

# switch to inference mode; the trailing ';' suppresses the cell's repr output
r_model.eval();

pred_cb = PredReward('aff', r_agent, weight=10.)

In [13]:
# gen_bs is passed to both ModelSamplers; bs is the first argument of env.fit below.
gen_bs, bs = 800, 1000

In [14]:
# agent_cb = GenAgentCallback(agent, 'generative')
# Wrap the generative agent as the environment's agent callback
# (second argument presumably its name).
agent_cb = AgentCallback(agent, 'generative')

In [15]:
# Two samplers: 'live' draws from agent.model, 'base' draws from agent.base_model.
# NOTE(review): the positional args (0.5, 0., gen_bs) presumably set buffer/batch
# proportions and sample count — confirm against ModelSampler.
# NOTE(review): latent=False is passed only to the live sampler; check whether the
# base sampler should match.
sampler2 = ModelSampler(agent, agent.model, 'live', 0.5, 0., gen_bs, latent=False)
sampler3 = ModelSampler(agent, agent.base_model, 'base', 0.5, 0., gen_bs)

samplers = [sampler2, sampler3]

In [16]:
# Wire up the environment: agent callback, scoring template, samplers, the affinity
# reward callback (pred_cb), no loss callbacks, and the callback list `cbs`.
env = Environment(agent_cb, template, samplers=samplers, reward_cbs=[pred_cb], loss_cbs=[],
                 cbs=cbs)

In [17]:
%%time
# NOTE(review): env.fit's positional arguments are not self-describing. The first is
# `bs`; the rest are presumably iteration/buffer/logging settings. The printed table
# shows rows every 5 iterations, up to 45 — confirm each argument and switch to
# keyword arguments.
env.fit(bs, 90, 600, 4000, 5)

iterations,rewards,mean_reward,valid,diversity,template,aff
0,-7.155,-7.155,1.0,1.0,1.32,-8.475
5,-7.234,-7.158,1.0,1.0,1.326,-8.56
10,-7.073,-7.142,1.0,1.0,1.33,-8.403
15,-7.097,-7.137,1.0,1.0,1.327,-8.424
20,-7.049,-7.121,1.0,1.0,1.351,-8.4
25,-6.974,-7.1,1.0,1.0,1.368,-8.342
30,-7.006,-7.097,1.0,0.999,1.354,-8.36
35,-6.907,-7.058,1.0,0.995,1.346,-8.253
40,-7.02,-7.05,1.0,0.999,1.327,-8.347
45,-6.939,-7.029,1.0,1.0,1.346,-8.285


Epoch,Train Loss,Valid Loss,Time
0,0.32897,0.31814,00:00


Epoch,Train Loss,Valid Loss,Time
0,0.32649,0.33895,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.31553,0.31251,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.35979,0.39169,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.34398,0.34404,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.33802,0.31868,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.32477,0.30576,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.3338,0.33674,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.45444,0.28137,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.29689,0.28591,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.33136,0.27401,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.31221,0.34509,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.30612,0.31199,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.2751,0.31738,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.32258,0.32529,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.28605,0.28617,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.27144,0.31577,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.30265,0.29722,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.37908,0.24814,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.25639,0.28841,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.2845,0.29473,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.29813,0.28022,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.29614,0.31097,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.31067,0.31311,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.29558,0.36611,00:01


Epoch,Train Loss,Valid Loss,Time
0,0.2858,0.30857,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.31701,0.33591,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.31836,0.33173,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.30855,0.31254,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.24547,0.31276,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.305,0.27932,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.2479,0.29679,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.29555,0.30718,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.27466,0.3052,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.2833,0.28407,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.29512,0.25657,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.29551,0.28794,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.31493,0.28883,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.30418,0.24911,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.27413,0.26309,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.26238,0.22378,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.33172,0.29101,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.28132,0.28222,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.24919,0.30995,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.27342,0.24142,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.18077,0.29804,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.26521,0.27234,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.27687,0.21312,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.22046,0.22111,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.22789,0.20662,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.29431,0.24918,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.26029,0.26561,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.22596,0.20709,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.22823,0.24512,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.23015,0.23223,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.2171,0.20355,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.23271,0.20313,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.24226,0.19431,00:02


Epoch,Train Loss,Valid Loss,Time
0,0.23889,0.22618,00:03


CPU times: user 3h 4min 15s, sys: 2min 47s, total: 3h 7min 2s
Wall time: 31min 33s


In [20]:
# Flatten every log key into one long DataFrame. This assumes all keys hold
# per-batch lists of matching lengths; otherwise pd.DataFrame raises.
# NOTE(review): execution count jumps from In [17] to In [20] — cells 18-19 are
# missing; confirm this cell still works after Restart & Run All.
log_df = log_to_df(env.log.log)