In [None]:
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load score tables that come from a trusted source.
# The files are opened with ``with`` so the gzip handles are closed after loading.
with gzip.open('data/NP_score.pkl.gz') as np_score_file:
    # Fragment-score table; MolecularMetrics looks entries up with .get(bit, 0).
    NP_model = pickle.load(np_score_file)

with gzip.open('data/SA_score.pkl.gz') as sa_score_file:
    # Each record is [score, key_1, key_2, ...]; flatten it to {key: score}.
    SA_model = {key: float(record[0]) for record in pickle.load(sa_score_file) for key in record[1:]}


class MolecularMetrics(object):
    """Static scoring helpers for batches of RDKit molecules.

    Throughout the class a molecule entry of ``None`` stands for "no molecule";
    the ``*_scores`` methods return one value per entry and substitute a fixed
    fallback value for ``None`` entries.
    """

    @staticmethod
    def _avoid_sanitization_error(op):
        """Return ``op()``, or ``None`` when the call raises ``ValueError``."""
        try:
            return op()
        except ValueError:
            return None

    @staticmethod
    def _smiles_or_empty(mol):
        """Return the SMILES string of ``mol``, or ``''`` when ``mol`` is ``None``."""
        return '' if mol is None else Chem.MolToSmiles(mol)

    @staticmethod
    def _is_novel(mol, data):
        """Return True when ``mol`` is valid and its SMILES is not in ``data.smiles``."""
        smiles = MolecularMetrics._smiles_or_empty(mol)
        return smiles != '' and smiles not in data.smiles

    @staticmethod
    def remap(x, x_min, x_max):
        """Linearly rescale ``x`` so that ``x_min`` maps to 0 and ``x_max`` to 1 (no clipping)."""
        return (x - x_min) / (x_max - x_min)

    @staticmethod
    def valid_lambda(x):
        """Return True when ``x`` is not ``None`` and has a non-empty SMILES string."""
        return x is not None and Chem.MolToSmiles(x) != ''

    @staticmethod
    def valid_lambda_special(x):
        """Stricter validity check: the SMILES must also contain neither ``*`` nor ``.``."""
        if x is None:
            return False
        s = Chem.MolToSmiles(x)
        return s != '' and '*' not in s and '.' not in s

    @staticmethod
    def valid_scores(mols):
        """Return a float32 array with 1.0 for entries passing ``valid_lambda_special``, else 0.0."""
        return np.array([MolecularMetrics.valid_lambda_special(mol) for mol in mols], dtype=np.float32)

    @staticmethod
    def valid_filter(mols):
        """Return the list of entries of ``mols`` that pass ``valid_lambda``."""
        return [mol for mol in mols if MolecularMetrics.valid_lambda(mol)]

    @staticmethod
    def valid_total_score(mols):
        """Return the fraction of entries of ``mols`` that pass ``valid_lambda``."""
        return np.array([MolecularMetrics.valid_lambda(mol) for mol in mols], dtype=np.float32).mean()

    @staticmethod
    def novel_scores(mols, data):
        """Return, per entry, whether it is valid and its SMILES is absent from ``data.smiles``."""
        return np.array([MolecularMetrics._is_novel(mol, data) for mol in mols])

    @staticmethod
    def novel_filter(mols, data):
        """Return the entries of ``mols`` that are valid and absent from ``data.smiles``."""
        return [mol for mol in mols if MolecularMetrics._is_novel(mol, data)]

    @staticmethod
    def novel_total_score(mols, data):
        """Return the fraction of the valid molecules whose SMILES is absent from ``data.smiles``."""
        return MolecularMetrics.novel_scores(MolecularMetrics.valid_filter(mols), data).mean()

    @staticmethod
    def unique_scores(mols):
        """Return ``clip(0.75 + 1 / count, 0, 1)`` per entry, where ``count`` is how
        often the entry's SMILES occurs in the batch; invalid entries get 0.75.
        """
        smiles = [MolecularMetrics._smiles_or_empty(mol) for mol in mols]

        # Count every SMILES once up front instead of calling list.count per
        # element, which was quadratic in the batch size.
        occurrences = {}
        for s in smiles:
            occurrences[s] = occurrences.get(s, 0) + 1

        inverse_counts = np.array([1 / occurrences[s] if s != '' else 0 for s in smiles], dtype=np.float32)
        return np.clip(0.75 + inverse_counts, 0, 1)

    @staticmethod
    def unique_total_score(mols):
        """Return distinct-SMILES count divided by valid-molecule count (0 when none are valid)."""
        v = MolecularMetrics.valid_filter(mols)
        if len(v) == 0:
            return 0
        return len({Chem.MolToSmiles(mol) for mol in v}) / len(v)

    @staticmethod
    def natural_product_scores(mols, norm=False):
        """Return one natural-product score per entry; ``None`` entries score -4.

        With ``norm=True`` the scores are remapped from [-3, 1] to [0, 1] and clipped.
        """

        def raw_score(mol):
            # sum of the NP_model contributions of the fingerprint elements,
            # divided by the atom count
            fingerprint = Chem.rdMolDescriptors.GetMorganFingerprint(mol, 2)
            total = sum(NP_model.get(bit, 0) for bit in fingerprint.GetNonzeroElements())
            return total / float(mol.GetNumAtoms())

        def dampen(score):
            # preventing score explosion for exotic molecules
            if score > 4:
                return 4 + math.log10(score - 4 + 1)
            if score < -4:
                return -4 - math.log10(-4 - score + 1)
            return score

        scores = np.array([-4 if mol is None else dampen(raw_score(mol)) for mol in mols])
        if norm:
            scores = np.clip(MolecularMetrics.remap(scores, -3, 1), 0.0, 1.0)

        return scores

    @staticmethod
    def quantitative_estimation_druglikeness_scores(mols, norm=False):
        """Return ``QED.qed`` per entry; ``None`` entries and ValueError failures score 0.

        ``norm`` is accepted for signature parity with the other scorers but is not used.
        """
        scores = []
        for mol in mols:
            score = None if mol is None else MolecularMetrics._avoid_sanitization_error(lambda: QED.qed(mol))
            scores.append(0 if score is None else score)
        return np.array(scores)

    @staticmethod
    def water_octanol_partition_coefficient_scores(mols, norm=False):
        """Return ``Crippen.MolLogP`` per entry; ``None`` entries and ValueError failures score -3.

        With ``norm=True`` the scores are remapped from
        [-2.12178879609, 6.0429063424] to [0, 1] and clipped.
        """
        scores = []
        for mol in mols:
            score = None if mol is None else MolecularMetrics._avoid_sanitization_error(
                lambda: Crippen.MolLogP(mol))
            scores.append(-3 if score is None else score)
        scores = np.array(scores)
        if norm:
            scores = np.clip(MolecularMetrics.remap(scores, -2.12178879609, 6.0429063424), 0.0, 1.0)

        return scores

    @staticmethod
    def _compute_SAS(mol):
        """Return the synthetic-accessibility score of ``mol``, clamped to [1, 10].

        Raises ``ZeroDivisionError`` when the fingerprint of ``mol`` has no
        non-zero elements.
        """
        fp = Chem.rdMolDescriptors.GetMorganFingerprint(mol, 2)
        fps = fp.GetNonzeroElements()

        # fragment score: count-weighted mean of the SA_model contributions
        # (-4 for fragments missing from the table)
        score1 = 0.
        nf = 0
        for bit_id, count in fps.items():
            nf += count
            score1 += SA_model.get(bit_id, -4) * count
        score1 /= nf

        # features score
        nAtoms = mol.GetNumAtoms()
        nChiralCenters = len(Chem.FindMolChiralCenters(
            mol, includeUnassigned=True))
        ri = mol.GetRingInfo()
        nSpiro = Chem.rdMolDescriptors.CalcNumSpiroAtoms(mol)
        nBridgeheads = Chem.rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
        nMacrocycles = sum(1 for ring in ri.AtomRings() if len(ring) > 8)

        sizePenalty = nAtoms ** 1.005 - nAtoms
        stereoPenalty = math.log10(nChiralCenters + 1)
        spiroPenalty = math.log10(nSpiro + 1)
        bridgePenalty = math.log10(nBridgeheads + 1)

        # ---------------------------------------
        # This differs from the paper, which defines:
        #  macrocyclePenalty = math.log10(nMacrocycles+1)
        # This form generates better results when 2 or more macrocycles are present
        macrocyclePenalty = math.log10(2) if nMacrocycles > 0 else 0.

        score2 = 0. - sizePenalty - stereoPenalty - \
                 spiroPenalty - bridgePenalty - macrocyclePenalty

        # correction for the fingerprint density
        # not in the original publication, added in version 1.1
        # to make highly symmetrical molecules easier to synthetise
        score3 = 0.
        if nAtoms > len(fps):
            score3 = math.log(float(nAtoms) / len(fps)) * .5

        sascore = score1 + score2 + score3

        # need to transform "raw" value into scale between 1 and 10
        # (named raw_min / raw_max so the builtins min / max are not shadowed)
        raw_min = -4.0
        raw_max = 2.5
        sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
        # smooth the 10-end
        if sascore > 8.:
            sascore = 8. + math.log(sascore + 1. - 9.)
        if sascore > 10.:
            sascore = 10.0
        elif sascore < 1.:
            sascore = 1.0

        return sascore

    @staticmethod
    def synthetic_accessibility_score_scores(mols, norm=False):
        """Return ``_compute_SAS`` per entry; ``None`` entries score 10.

        With ``norm=True`` the scores are remapped so that 5 maps to 0 and 1.5
        maps to 1 (lower raw scores give higher normalised scores), then clipped.
        """
        scores = np.array([10 if mol is None else MolecularMetrics._compute_SAS(mol) for mol in mols])
        if norm:
            scores = np.clip(MolecularMetrics.remap(scores, 5, 1.5), 0.0, 1.0)

        return scores

    @staticmethod
    def diversity_scores(mols, data):
        """Return a [0, 1] diversity score per entry against 100 molecules drawn
        at random from ``data.data``; ``None`` entries use a raw score of 0.
        """
        rand_mols = np.random.choice(data.data, 100)
        fps = [Chem.rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, 4, nBits=2048) for mol in rand_mols]

        scores = np.array(
            [MolecularMetrics.__compute_diversity(mol, fps) if mol is not None else 0 for mol in mols])
        scores = np.clip(MolecularMetrics.remap(scores, 0.9, 0.945), 0.0, 1.0)

        return scores

    @staticmethod
    def __compute_diversity(mol, fps):
        """Return the mean Tanimoto distance between ``mol`` and the fingerprints ``fps``."""
        ref_fps = Chem.rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, 4, nBits=2048)
        dist = DataStructs.BulkTanimotoSimilarity(ref_fps, fps, returnDistance=True)
        return np.mean(dist)

    @staticmethod
    def drugcandidate_scores(mols, data):
        """Return the mean of four terms per entry: bumped normalised logP,
        normalised SA score, novelty, and 0.3 for entries that are not novel.
        """
        logp = MolecularMetrics.constant_bump(
            MolecularMetrics.water_octanol_partition_coefficient_scores(mols, norm=True), 0.210, 0.945)
        sas = MolecularMetrics.synthetic_accessibility_score_scores(mols, norm=True)
        # novelty is needed twice below; compute it once
        novel = MolecularMetrics.novel_scores(mols, data)

        return (logp + sas + novel + (1 - novel) * 0.3) / 4

    @staticmethod
    def constant_bump(x, x_low, x_high, decay=0.025):
        """Return 1 inside (x_low, x_high) and a Gaussian-shaped fall-off
        ``exp(-(x - bound)**2 / decay)`` at and beyond either bound.
        """
        return np.select(condlist=[x <= x_low, x >= x_high],
                         choicelist=[np.exp(- (x - x_low) ** 2 / decay),
                                     np.exp(- (x - x_high) ** 2 / decay)],
                         default=np.ones_like(x))

In [5]:
def gradient_penalty(y, x):
    """Compute the gradient penalty ``mean((||dy/dx||_2 - 1) ** 2)``.

    Args:
        y: Tensor computed from ``x``.
        x: Tensor that requires grad; its first dimension is treated as the batch.

    Returns:
        Scalar tensor that is itself differentiable (the gradient is taken
        with ``create_graph=True``).
    """
    # ones_like keeps the seed gradient on the same device and dtype as y;
    # torch.ones(y.size()) always produced a CPU float32 tensor.
    weight = torch.ones_like(y)
    dydx = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]

    # Flatten everything but the batch dimension; reshape also accepts
    # non-contiguous gradients, which view does not.
    dydx = dydx.reshape(dydx.size(0), -1)
    dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
    return torch.mean((dydx_l2norm - 1) ** 2)

def label2onehot(labels, dim):
    """Convert label indices to one-hot vectors.

    Args:
        labels: Integer (long) tensor of class indices, any shape.
        dim: Number of classes, i.e. the size of the appended last axis.

    Returns:
        Float tensor of shape ``labels.shape + (dim,)`` on the same device as
        ``labels``.
    """
    # Allocate on the labels' device so scatter_ works for non-CPU inputs too.
    out = torch.zeros(list(labels.size()) + [dim], device=labels.device)
    out.scatter_(-1, labels.unsqueeze(-1), 1.)
    return out

def sample_z(batch_size, z_dim=32):
    """Draw standard-normal latent vectors.

    Args:
        batch_size: Number of vectors to draw.
        z_dim: Length of each vector (defaults to 32, the previously hard-coded size).

    Returns:
        NumPy array of shape ``(batch_size, z_dim)``.
    """
    return np.random.normal(0, 1, size=(batch_size, z_dim))

def postprocess(inputs, method = "soft_gumbel", temperature=1.):
    """Turn logits into (relaxed) categorical samples or probabilities.

    Args:
        inputs: A logits tensor, or a list/tuple of logits tensors.
        method: ``'soft_gumbel'`` or ``'hard_gumbel'`` for Gumbel-softmax
            sampling over the last axis; any other value applies a plain softmax.
        temperature: Divisor applied to the logits before the (Gumbel-)softmax.

    Returns:
        A list with one tensor per input, each with the shape of its input.
    """
    logits_list = inputs if isinstance(inputs, (list, tuple)) else [inputs]

    if method in ('soft_gumbel', 'hard_gumbel'):
        hard = method == 'hard_gumbel'
        # gumbel_softmax is applied on a 2-D view (rows x classes) and the
        # result is reshaped back to the input's shape.
        outputs = [F.gumbel_softmax(e_logits.contiguous().view(-1, e_logits.size(-1))
                                    / temperature, hard=hard).view(e_logits.size())
                   for e_logits in logits_list]
    else:
        outputs = [F.softmax(e_logits / temperature, -1)
                   for e_logits in logits_list]

    # The tensors are returned unchanged. Previously each tensor was passed
    # through a "delistify" helper that returned x[0] whenever len(x) <= 1,
    # which silently dropped the batch axis of a batch of size 1.
    return outputs

In [6]:
# Module-level dataset shared by the functions below (reward, get_reward,
# get_gen_mols and train all read the global ``data``).
data = SparseMolecularDataset()
data.load("qm9_5k.sparsedataset")

# Number of bond types and atom types reported by the dataset; train() uses them
# as the one-hot depth for the adjacency tensor and the node tensor respectively.
b_dim = data.bond_num_types
m_dim = data.atom_num_types


In [7]:
def reward(mols,metric ="validity,qed"):
    """Multiply together the per-molecule scores of the requested metrics.

    ``metric`` is a comma-separated list of metric names; the special value
    ``'all'`` expands to ``'logp,sas,qed,unique'``. The product is returned
    as a column vector of shape ``(len(mols), 1)``.

    Raises:
        RuntimeError: when a requested name is not a known metric.
    """
    # Scorers are wrapped in lambdas so that each one is evaluated only when
    # its name is actually requested, in the order requested.
    scorers = {
        'np': lambda: MolecularMetrics.natural_product_scores(mols, norm=True),
        'logp': lambda: MolecularMetrics.water_octanol_partition_coefficient_scores(mols, norm=True),
        'sas': lambda: MolecularMetrics.synthetic_accessibility_score_scores(mols, norm=True),
        'qed': lambda: MolecularMetrics.quantitative_estimation_druglikeness_scores(mols, norm=True),
        'novelty': lambda: MolecularMetrics.novel_scores(mols, data),
        'dc': lambda: MolecularMetrics.drugcandidate_scores(mols, data),
        'unique': lambda: MolecularMetrics.unique_scores(mols),
        'diversity': lambda: MolecularMetrics.diversity_scores(mols, data),
        'validity': lambda: MolecularMetrics.valid_scores(mols),
    }

    requested = 'logp,sas,qed,unique' if metric == 'all' else metric

    rr = 1.
    for name in requested.split(','):
        scorer = scorers.get(name)
        if scorer is None:
            raise RuntimeError('{} is not defined as a metric'.format(name))
        rr *= scorer()

    return rr.reshape(-1, 1)
    
def get_reward(n_hat, e_hat, method = "soft_gumbel"):
    """Decode generator outputs into molecules and return their reward as a tensor.

    The node and edge tensors are post-processed with ``method``, reduced to
    their arg-max indices along the last axis, converted to molecules with
    ``data.matrices2mol(..., strict=True)`` and scored with ``reward``.
    """
    edge_probs, node_probs = postprocess((e_hat, n_hat), method)
    edge_labels = torch.max(edge_probs, -1)[1]
    node_labels = torch.max(node_probs, -1)[1]

    mols = []
    for edge_matrix, node_vector in zip(edge_labels, node_labels):
        mols.append(data.matrices2mol(node_vector.data.cpu().numpy(),
                                      edge_matrix.data.cpu().numpy(),
                                      strict=True))

    return torch.from_numpy(reward(mols))

In [8]:
def gradient_penalty(y, x):
    """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

    Args:
        y: Tensor computed from ``x``.
        x: Tensor that requires grad; its first dimension is treated as the batch.

    Returns:
        Scalar tensor: the penalty averaged over the batch. It is itself
        differentiable (the gradient is taken with ``create_graph=True``).
    """
    # ones_like keeps the seed gradient on the same device and dtype as y;
    # torch.ones(y.size()) always produced a CPU float32 tensor.
    weight = torch.ones_like(y)
    dydx = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]

    # Flatten everything but the batch dimension; reshape also accepts
    # non-contiguous gradients, which view does not.
    dydx = dydx.reshape(dydx.size(0), -1)
    dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
    return torch.mean((dydx_l2norm-1)**2)

def get_gen_mols(n_hat, e_hat):
    """Decode generator outputs into a list of molecules.

    The node and edge tensors go through ``postprocess`` with its default
    method, are reduced to arg-max indices along the last axis, and each
    sample is converted with ``data.matrices2mol(..., strict=True)``.
    """
    edge_probs, node_probs = postprocess((e_hat, n_hat))
    edge_labels = torch.max(edge_probs, -1)[1]
    node_labels = torch.max(node_probs, -1)[1]

    mols = []
    for edge_matrix, node_vector in zip(edge_labels, node_labels):
        mols.append(data.matrices2mol(node_vector.data.cpu().numpy(),
                                      edge_matrix.data.cpu().numpy(),
                                      strict=True))
    return mols

def all_scores(mols, data, norm=False, reconstruction=False):
    """Compute every per-molecule metric and the three batch-level percentages.

    Returns a pair ``(m0, m1)``: ``m0`` maps a metric name to the list of its
    per-molecule values with ``None`` entries removed; ``m1`` maps ``'valid'``,
    ``'unique'`` and ``'novel'`` to percentages. ``reconstruction`` is accepted
    but not used.
    """
    per_molecule = {
        'NP': MolecularMetrics.natural_product_scores(mols, norm=norm),
        'QED': MolecularMetrics.quantitative_estimation_druglikeness_scores(mols),
        'Solute': MolecularMetrics.water_octanol_partition_coefficient_scores(mols, norm=norm),
        'SA': MolecularMetrics.synthetic_accessibility_score_scores(mols, norm=norm),
        'diverse': MolecularMetrics.diversity_scores(mols, data),
        'drugcand': MolecularMetrics.drugcandidate_scores(mols, data),
    }

    m0 = {}
    for metric_name, values in per_molecule.items():
        m0[metric_name] = [value for value in values if value is not None]

    m1 = {}
    m1['valid'] = MolecularMetrics.valid_total_score(mols) * 100
    m1['unique'] = MolecularMetrics.unique_total_score(mols) * 100
    m1['novel'] = MolecularMetrics.novel_total_score(mols, data) * 100

    return m0, m1

def random_string(string_len=3):
    """Return ``string_len`` randomly chosen lowercase ASCII letters as one string."""
    picked = []
    for _ in range(string_len):
        picked.append(random.choice(string.ascii_lowercase))
    return ''.join(picked)

def save_mol_img(mols, f_name='tmp.png', is_test=False):
    """Draw the first molecule in ``mols`` that can be rendered to an image file.

    Molecules are tried in order; any molecule that raises while being
    converted, drawn or parsed is skipped (best effort), and the loop stops
    after the first one that succeeds.

    Args:
        mols: Iterable of molecules (entries that cannot be handled are skipped).
        f_name: Output path.
        is_test: When true, a random 3-letter suffix is inserted before the
            file extension of ``f_name``.
    """
    import os  # local import keeps this block self-contained

    for a_mol in mols:
        try:
            a_smi = Chem.MolToSmiles(a_mol)
            if a_smi is None:
                continue
            print('Generating molecule')

            target = f_name
            if is_test:
                # splitext only touches the final extension. The previous
                # split('.') / ''.join dropped every other dot in the path
                # (e.g. './out/mol.png' became '/out/mol<rand>.png').
                root, ext = os.path.splitext(f_name)
                target = root + random_string() + ext

            rdkit.Chem.Draw.MolToFile(a_mol, target)
            # The parse result is unused; the call is kept because a failure
            # here moves on to the next molecule, as it did before.
            read_smiles(a_smi)
            break
        except Exception:
            # Best effort: skip this molecule. Unlike a bare ``except:``, this
            # lets KeyboardInterrupt / SystemExit propagate.
            continue

In [9]:
from models import *

In [10]:
# Networks, each built with its constructor's default arguments:
#   G - generator, D - discriminator, R - a second MolGANDiscriminator instance
#   (train() regresses R's output onto the molecule reward).
# NOTE(review): the traceback at the end of this notebook reports a 5-vs-4 size
# mismatch on a tensor's last axis; presumably the default model dimensions do
# not match b_dim / m_dim from the loaded dataset -- confirm against models.py.
G =  Generator()
D = MolGANDiscriminator()
R = MolGANDiscriminator()

# One Adam optimizer per network, all with learning rate 1e-3.
G_optim = torch.optim.Adam(G.parameters(),lr=1e-3)
D_optim = torch.optim.Adam(D.parameters(),lr=1e-3)
R_optim = torch.optim.Adam(R.parameters(),lr=1e-3)

def reset_grad():
    """Zero the gradients held by the generator, discriminator and reward optimizers."""
    for optimizer in (G_optim, D_optim, R_optim):
        optimizer.zero_grad()


In [None]:
# Weight of the adversarial generator loss in train(); (1 - la) weights the
# reward-based loss. train() also skips discriminator updates unless this is > 0.
la = 0
# Multiplier applied to the gradient penalty in the discriminator loss.
la_gp = 10
# train() compares the global step modulo n_critic to decide which networks to update.
n_critic = 5
# NOTE(review): g_lr / d_lr are not referenced anywhere in the visible code; the
# optimizers are created with lr=1e-3 -- confirm which value is intended.
g_lr = 1e-4
d_lr = 1e-4
# Number of 32-sample steps per epoch.
num_steps = (len(data) // 32)
def train(epoch_i, train_val_test='val',mode = "train"):
    """Run one epoch of training or validation.

    Fixes relative to the previous version:
      * the per-batch work now runs inside the batch loop (it used to run once,
        after the loop, on the last fetched batch only);
      * the latent batch is sized from the data batch instead of being
        overwritten with 32 samples;
      * the discriminator update steps ``D_optim`` (it stepped ``G_optim``);
      * the generator objective is back-propagated and ``G_optim`` stepped
        (it was computed but never used), and the reward network is only
        updated in training mode;
      * the log string is initialised before it is appended to.

    Args:
        epoch_i: Epoch index; a negative value forces the adversarial weight to 0.
        train_val_test: ``'train'`` to update the networks on training batches,
            ``'val'`` to evaluate on validation batches.
        mode: ``'train'`` or ``'test'``. In ``'val'`` with ``mode == 'train'``
            only a single step is run; ``mode == 'test'`` makes the saved image
            name unique.

    Raises:
        ValueError: if ``train_val_test`` is neither ``'train'`` nor ``'val'``.
    """
    if train_val_test not in ('train', 'val'):
        raise ValueError("train_val_test must be 'train' or 'val', got {!r}".format(train_val_test))
    is_train = train_val_test == 'train'
    is_val = not is_train

    cur_la = 0 if epoch_i < 0 else la

    losses = defaultdict(list)
    scores = defaultdict(list)
    the_step = num_steps

    if is_val:
        if mode == 'train':
            the_step = 1
        print('[Validating]')

    gen_mols = None
    for a_step in range(the_step):
        if is_val:
            mols, _, _, a, x, _, _, _, _ = data.next_validation_batch()
        else:
            mols, _, _, a, x, _, _, _, _ = data.next_train_batch(32)
        # One latent vector per real sample, so real and fake batches line up.
        z = sample_z(a.shape[0])

        a = torch.from_numpy(a.astype(np.int64)).long()         # Adjacency.
        x = torch.from_numpy(x.astype(np.int64)).long()         # Nodes.
        a_tensor = label2onehot(labels=a, dim=b_dim)
        x_tensor = label2onehot(x, m_dim)
        z = torch.from_numpy(z).float()

        cur_step = num_steps * epoch_i + a_step

        # Discriminator outputs on real and generated graphs.
        logits_real, _ = D(x_tensor, a_tensor)
        edges_logits, nodes_logits = G(z)
        (edges_hat, nodes_hat) = postprocess((edges_logits, nodes_logits))
        logits_fake, _ = D(nodes_hat, edges_hat)

        # Gradient penalty on random interpolations between real and fake inputs.
        eps = torch.rand(logits_real.size(0), 1, 1, 1)
        x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)
        x_int1 = (eps.squeeze(-1) * x_tensor + (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
        grad0, grad1 = D(x_int1, x_int0)
        grad_penalty = gradient_penalty(grad0, x_int0) + gradient_penalty(grad1, x_int1)

        d_loss_real = torch.mean(logits_real)
        d_loss_fake = torch.mean(logits_fake)
        loss_D = -d_loss_real + d_loss_fake + la_gp * grad_penalty

        if cur_la > 0:
            losses['l_D/R'].append(d_loss_real.item())
            losses['l_D/F'].append(d_loss_fake.item())
            losses['l_D'].append(loss_D.item())

        # Optimise discriminator.
        if is_train and cur_step % n_critic != 0 and cur_la > 0:
            reset_grad()
            loss_D.backward()
            D_optim.step()

        # Reward network predictions and the rewards they should match.
        value_logit_real, _ = R(x_tensor, a_tensor)
        value_logit_fake, _ = R(nodes_hat, edges_hat)

        reward_r = torch.from_numpy(reward(mols))
        reward_f = get_reward(nodes_hat, edges_hat)

        loss_G = torch.mean(-logits_fake)
        loss_V = torch.mean(torch.abs(value_logit_real - reward_r) + torch.abs(value_logit_fake - reward_f))
        loss_RL = torch.mean(-value_logit_fake)
        losses['l_G'].append(loss_G.item())
        losses['l_RL'].append(loss_RL.item())
        losses['l_V'].append(loss_V.item())

        # alpha rescales the RL loss to the magnitude of the adversarial loss.
        alpha = torch.abs(loss_G.detach() / loss_RL.detach()).detach()
        train_step_G = cur_la * loss_G + (1 - cur_la) * alpha * loss_RL
        train_step_V = loss_V

        # Optimise generator and reward network (training mode only).
        if is_train and cur_step % n_critic == 0:
            reset_grad()
            # Both losses are back-propagated before either optimizer steps:
            # stepping G first would modify weights in place that the
            # backward pass of loss_V still needs.
            train_step_G.backward(retain_graph=True)
            train_step_V.backward()
            G_optim.step()
            R_optim.step()

        if is_val:
            gen_mols = get_gen_mols(nodes_logits, edges_logits)
            m0, m1 = all_scores(gen_mols, data, norm=True)  # scores of the generated molecules
            for k, v in m1.items():
                scores[k].append(v)
            for k, v in m0.items():
                scores[k].append(np.array(v)[np.nonzero(v)].mean())

    if is_val and gen_mols is not None:
        # Save an image of one generated molecule from the last batch.
        os.makedirs("Home", exist_ok=True)
        mol_f_name = os.path.join("Home", 'mol-{}.png'.format(epoch_i))
        save_mol_img(gen_mols, mol_f_name, is_test=mode == 'test')

        # One line for the mean losses, one line for the mean scores.
        log = "Epoch [{}]:".format(epoch_i)
        for group in (losses, scores):
            if group:
                log += "\n" + ", ".join(
                    "{}: {:.2f}".format(tag, np.mean(value)) for tag, value in group.items())
        print(log)


  return self._call_impl(*args, **kwargs)


RuntimeError: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 3