<a href="https://colab.research.google.com/github/aksub99/molecular-vae/blob/master/Molecular_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import gzip
import pandas
import h5py
import numpy as np
from __future__ import print_function
import argparse
import os
import h5py
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn import model_selection

In [0]:
def one_hot_array(i, n):
    return map(int, [ix == i for ix in xrange(n)])

def one_hot_index(vec, charset):
    return map(charset.index, vec)

def from_one_hot_array(vec):
    oh = np.where(vec == 1)
    if oh[0].shape == (0, ):
        return None
    return int(oh[0][0])

def decode_smiles_from_indexes(vec, charset):
    return "".join(map(lambda x: charset[x], vec)).strip()

def load_dataset(filename, split = True):
    h5f = h5py.File(filename, 'r')
    if split:
        data_train = h5f['data_train'][:]
    else:
        data_train = None
    data_test = h5f['data_test'][:]
    charset =  h5f['charset'][:]
    h5f.close()
    if split:
        return (data_train, data_test, charset)
    else:
        return (data_test, charset)


In [0]:
class MolecularVAE(nn.Module):
    def __init__(self):
        super(MolecularVAE, self).__init__()

        self.conv_1 = nn.Conv1d(120, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
        self.linear_0 = nn.Linear(70, 435)
        self.linear_1 = nn.Linear(435, 292)
        self.linear_2 = nn.Linear(435, 292)

        self.linear_3 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, 33)
        
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()

    def encode(self, x):
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))
        x = x.view(x.size(0), -1)
        x = F.selu(self.linear_0(x))
        return self.linear_1(x), self.linear_2(x)

    def sampling(self, z_mean, z_logvar):
        epsilon = 1e-2 * torch.randn_like(z_logvar)
        return torch.exp(0.5 * z_logvar) * epsilon + z_mean

    def decode(self, z):
        z = F.selu(self.linear_3(z))
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, 120, 1)
        output, hn = self.gru(z)
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        z_mean, z_logvar = self.encode(x)
        z = self.sampling(z_mean, z_logvar)
        return self.decode(z), z_mean, z_logvar

In [4]:
!rm -R 'molecular-vae'
!git clone https://github.com/aksub99/molecular-vae.git
import zipfile
zip_ref = zipfile.ZipFile('molecular-vae/data/processed.zip', 'r')
zip_ref.extractall('molecular-vae/data/')
zip_ref.close()


Cloning into 'molecular-vae'...
remote: Enumerating objects: 42, done.[K
remote: Counting objects:   2% (1/42)   [Kremote: Counting objects:   4% (2/42)   [Kremote: Counting objects:   7% (3/42)   [Kremote: Counting objects:   9% (4/42)   [Kremote: Counting objects:  11% (5/42)   [Kremote: Counting objects:  14% (6/42)   [Kremote: Counting objects:  16% (7/42)   [Kremote: Counting objects:  19% (8/42)   [Kremote: Counting objects:  21% (9/42)   [Kremote: Counting objects:  23% (10/42)   [Kremote: Counting objects:  26% (11/42)   [Kremote: Counting objects:  28% (12/42)   [Kremote: Counting objects:  30% (13/42)   [Kremote: Counting objects:  33% (14/42)   [Kremote: Counting objects:  35% (15/42)   [Kremote: Counting objects:  38% (16/42)   [Kremote: Counting objects:  40% (17/42)   [Kremote: Counting objects:  42% (18/42)   [Kremote: Counting objects:  45% (19/42)   [Kremote: Counting objects:  47% (20/42)   [Kremote: Counting objects:  50% (21/

In [5]:
def vae_loss(x_decoded_mean, x, z_mean, z_logvar):
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, size_average=False)
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return xent_loss + kl_loss

data_train, data_test, charset = load_dataset('molecular-vae/data/processed.h5')
data_train = torch.utils.data.TensorDataset(torch.from_numpy(data_train))
train_loader = torch.utils.data.DataLoader(data_train, batch_size=250, shuffle=True)

torch.manual_seed(42)

epochs = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MolecularVAE().to(device)
optimizer = optim.Adam(model.parameters())

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data[0].to(device)
        optimizer.zero_grad()
        output, mean, logvar = model(data)
        loss = vae_loss(output, data, mean, logvar)
        loss.backward()
        train_loss += loss
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'{epoch} / {batch_idx}\t{loss:.4f}')
    print('train', train_loss / len(train_loader.dataset))
    return train_loss / len(train_loader.dataset)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)



1 / 0	133752.1094
1 / 100	31996.6543
train tensor(150.1170, device='cuda:0', grad_fn=<DivBackward0>)
2 / 0	29997.5488
2 / 100	31999.7266
train tensor(121.1428, device='cuda:0', grad_fn=<DivBackward0>)
3 / 0	30086.0234
3 / 100	30490.2559
train tensor(119.5334, device='cuda:0', grad_fn=<DivBackward0>)
4 / 0	29008.4883
4 / 100	28547.3418
train tensor(117.5161, device='cuda:0', grad_fn=<DivBackward0>)
5 / 0	30265.7793
5 / 100	28809.0742
train tensor(115.3262, device='cuda:0', grad_fn=<DivBackward0>)
6 / 0	29764.8027
6 / 100	28289.2422
train tensor(116.9878, device='cuda:0', grad_fn=<DivBackward0>)
7 / 0	28438.0820


KeyboardInterrupt: ignored