In [1]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
import numpy as np
import torch 
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from typing import List

In [2]:

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 32),
        )
        self.decoder = nn.Sequential(
            nn.Linear(32, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.Sigmoid()
        )
        
        self.mu = nn.Linear(32, 32)
        self.logstd = nn.Linear(32, 32)
        
    def encode(self, x):
        # x.shape Bx1024
        h = self.encoder(x)
        return self.mu(h), self.logstd(h)

    def reparametrize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = 1e-2 * torch.randn_like(std)
           
            return eps.mul(std).add_(mu)
        else:
            return mu
        

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


In [3]:
def fingerprint_mols(mols, numBits = 1024):
    """Generates a morgan fingerprint for a list of smiles string.
    :param mols: A smiles string for a molecule.
    :param 2: The radius of the fingerprint.
    :param num_bits: The number of bits to use in the fingerprint.
    :return: A 1-D numpy array containing the morgan fingerprint.
    """
    fps = []
    for mol in mols:
        mol = Chem.MolFromSmiles(mol)
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=numBits)
        fp_arr = np.zeros((1,))
        DataStructs.ConvertToNumpyArray(fp, fp_arr)
        fps.append(fp_arr)
    return fps 


def load_dataset(file_dg, file_smiles):
    dataset = {}
    with open(file_dg) as f_dg:
        for line in f_dg:
            w = line.split()
            name = w[0][w[0].index('_') + 1:].replace('_', ' ')
            dg = float(w[1])
            dataset[name] = {'dg': dg, 'smiles': None}

    smiles = []
    dg = []
    with open(file_smiles) as f_smiles:
        for line in f_smiles:
            w = line.split('\t')
            if len(w) == 3 and w[0] in dataset.keys():
                name = w[0]
                dataset[name]['smiles'] = w[-1].replace('\n', '')
                smiles.append(dataset[name]['smiles'])
                dg.append(dataset[name]['dg'])

    return dataset, smiles, dg

dataset, smiles, dg = load_dataset('./DatabaseOMSDrugs_scores.dat', './DatabaseOMSDrugs.dat')


class fpDataset(Dataset):
    def __init__(self, chemical_space: List[np.array]):
        self.chemical_space = chemical_space
        
    def __getitem__(self, index):
        return torch.from_numpy(self.chemical_space[index]).float()
    def __len__(self):
        #len(dataset)
        return len(self.chemical_space)
 

In [5]:
#supervisor.run_sampling_test(n_iterations=20, smiles=smiles, dg=dg)
#defining the hyperparameters
    

epochs = 100
vae = VAE()
lr = 1e-2

optimizer = torch.optim.Adam(vae.parameters(), lr)
k = fingerprint_mols(smiles)
dataset = fpDataset(k)
train_loader = DataLoader(dataset = dataset, batch_size = 32,
                       shuffle = True)

# defining the loss function 
def loss_function(ỹ, y, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(ỹ, y)
    #KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD 

for epoch in range(1, epochs + 1):
    train_loss = 0
    #losses = []
    for batch_idx, batch in enumerate(train_loader):
        #print(vae(batch)[0].shape)
        data = vae(batch)[0]
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        print(recon_batch, data)

        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss
        optimizer.step()
        #losses.append(loss.detach().item())
        if batch_idx % 100 == 0:
            print(f'epoch :{epoch} \t loss: {loss:.4f}')
        
#supervisor.run_sampling_test(n_iterations=20, smiles=smiles, dg=dg)
#defining the hyperparameters


tensor([[0.5040, 0.4821, 0.4950,  ..., 0.5128, 0.4896, 0.4939],
        [0.5041, 0.4820, 0.4950,  ..., 0.5129, 0.4902, 0.4939],
        [0.5046, 0.4821, 0.4944,  ..., 0.5130, 0.4897, 0.4936],
        ...,
        [0.5043, 0.4824, 0.4949,  ..., 0.5129, 0.4897, 0.4939],
        [0.5041, 0.4823, 0.4952,  ..., 0.5129, 0.4894, 0.4938],
        [0.5040, 0.4820, 0.4949,  ..., 0.5130, 0.4898, 0.4938]],
       grad_fn=<SigmoidBackward0>) tensor([[0.5031, 0.4826, 0.4952,  ..., 0.5125, 0.4893, 0.4945],
        [0.5035, 0.4823, 0.4958,  ..., 0.5126, 0.4896, 0.4945],
        [0.5033, 0.4823, 0.4951,  ..., 0.5133, 0.4896, 0.4939],
        ...,
        [0.5036, 0.4829, 0.4952,  ..., 0.5130, 0.4891, 0.4943],
        [0.5033, 0.4828, 0.4954,  ..., 0.5129, 0.4890, 0.4941],
        [0.5028, 0.4825, 0.4954,  ..., 0.5131, 0.4889, 0.4940]],
       grad_fn=<SigmoidBackward0>)
epoch :1 	 loss: 10.9537
tensor([[0.0717, 0.8713, 0.0539,  ..., 0.1047, 0.3184, 0.9218],
        [0.0720, 0.8715, 0.0547,  ..., 0.1084


KeyboardInterrupt



In [8]:
(recon_batch - data).abs().mean()


tensor(4.6083e-09, grad_fn=<MeanBackward0>)