In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
import torch.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data prep

In [2]:
import selfies as sf
data = pd.read_csv('./GRU_data/selfies.csv', header=None, names=['selfies'])
alphabet = sf.get_alphabet_from_selfies(data.selfies)
alphabet.add("[nop]") # [nop] is a special padding symbol
alphabet.add("[start]")
alphabet.add("[end]")
alphabet = list(sorted(alphabet))
#pad_to_len = max(sf.len_selfies(s) for s in data.selfies) + 5
pad_to_len = 128
symbol_to_idx = {s: i for i, s in enumerate(alphabet)}
idx2char = {i: s for i, s in enumerate(alphabet)}

In [3]:
len(alphabet)

42

In [4]:
from torch.utils.data import Dataset, DataLoader

class GRUDataset(Dataset):
    def __init__(self, smiles_fp, selfies, vectorizer):
        self.smiles_fp = pd.read_csv(smiles_fp, sep=',', nrows=10000)
        self.selfies = pd.read_csv(selfies, nrows=10000)
        self.X = self.prepare_X(self.smiles_fp)
        self.X = np.array([self.reconstruct_fp(fp) for fp in self.X])
        self.y = self.prepare_y(self.selfies)
    def __len__(self):
        return len(self.smiles_fp)
    def __getitem__(self, idx):
        raw_selfie = self.y[idx][0]
        vectorized_selfie = vectorizer.vectorize(raw_selfie)
        return torch.from_numpy(self.X[idx]).float(), torch.from_numpy(vectorized_selfie).float()

    @staticmethod
    def prepare_X(smiles_fp):
        fps = smiles_fp.fps.apply(eval).apply(lambda x: np.array(x, dtype=int))
        return fps
    @staticmethod
    def prepare_y(selfies):
        return selfies.values
    @staticmethod
    def reconstruct_fp(fp, length=4860):
        fp_rec = np.zeros(length)
        fp_rec[fp] = 1
        return fp_rec

In [5]:
import re
class SELFIESVectorizer:
    def __init__(self, alphabet, pad_to_len):
        self.alphabet = alphabet
        self.pad_to_len = pad_to_len
        self.char2idx = {s: i for i, s in enumerate(alphabet)}
        self.idx2char = {i: s for i, s in enumerate(alphabet)}
    def vectorize(self, selfie):
        ''' Vectorize a list of SMILES strings to a numpy array of shape (len(smiles), embed, len(charset))'''
        X = np.zeros((self.pad_to_len, len(self.alphabet)))
        splited = ['[start]'] + self.split_selfi(selfie) + ['[end]'] + ['[nop]'] * (self.pad_to_len - len(self.split_selfi(selfie)) - 2)
        for i, char in enumerate(splited):
            X[i, self.char2idx[char]] = 1
        return X
    def devectorize(self, ohe):
        ''' Devectorize a numpy array of shape (len(smiles), embed, len(charset)) to a list of SMILES strings'''
        selfie_str = ''
        for j in range(self.pad_to_len):
            char = self.idx2char[np.argmax(ohe[j])]
            if char == '[start]':
                continue
            elif char == '[end]':
                break
            else:
                selfie_str += char
        return selfie_str

    def split_selfi(self, selfie):
        pattern = r'(\[[^\[\]]*\])'
        return re.findall(pattern, selfie)

In [6]:
vectorizer = SELFIESVectorizer(alphabet, pad_to_len)
dataset = GRUDataset('GRU_data/chembl_klek.csv', 'GRU_data/selfies.csv', vectorizer)

In [7]:
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [8]:
batch_size = 64
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, drop_last=True)
next(iter(train_loader))[0].shape

torch.Size([64, 4860])

# NN architecture

In [63]:
import random

class EncoderNet(nn.Module):
    def __init__(self, fp_size, encoding_size):
        super(EncoderNet, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(fp_size, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, encoding_size)

    def forward(self, x):
        out = self.relu(self.fc1(x))
        out = self.relu(self.fc2(out))
        out = self.relu(self.fc3(out))
        out = self.relu(self.fc4(out))
        out = self.relu(self.fc5(out))
        return out

class DecoderNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers, drop_prob):
        super(DecoderNet, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, dropout=drop_prob, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=2)
        self.max_len = vectorizer.pad_to_len

    def forward(self, encoded):
        batch_size = encoded.size(0)
        hidden = self.init_hidden(encoded)
        start_vector = vectorizer.vectorize('[start]')
        start_vector = torch.from_numpy(start_vector).float().to(device)
        start_vector = start_vector.unsqueeze(0)
        start_vector = start_vector.repeat(batch_size, 1, 1)
        decoded, hidden = self.gru(start_vector, hidden)
        decoded = self.fc(decoded)
        return decoded

    def init_hidden(self, encoded):
        return encoded.unsqueeze(0).repeat(self.num_layers, 1, 1).to(device)
    
class Autoencoder(nn.Module):
    def __init__(self, input_size=4860, encoding_size=128, decoding_size=128, output_size=42, num_layers=2, drop_prob=0.1):
        super(Autoencoder, self).__init__()
        self.encoder = EncoderNet(input_size, encoding_size)
        self.decoder = DecoderNet(output_size, decoding_size, output_size, num_layers, drop_prob)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [64]:
encoder = EncoderNet(4860, 128).to(device)
decoder = DecoderNet(42, 128, 42, 2, 0.1).to(device)
test_batch = next(iter(train_loader))[0].to(device)

In [65]:
print(f'Test batch shape: {test_batch.shape}')
encoded = encoder(test_batch)
print(f'Encoded shape: {encoded.shape}')
print(f'Encoded: {encoded[0, :5]}')

Test batch shape: torch.Size([64, 4860])
Encoded shape: torch.Size([64, 128])
Encoded: tensor([0.0000, 0.0000, 0.0364, 0.0000, 0.0000], device='cuda:0',
       grad_fn=<SliceBackward0>)


In [66]:
decoded = decoder(encoded)
print(f'Decoded shape: {decoded.shape}')

Decoded shape: torch.Size([64, 128, 42])


In [78]:
model = Autoencoder(4860, 128, 128, 42, 2, 0.1).to(device)

In [79]:
with torch.no_grad():
    encoded, decoded = model(test_batch)
    print(f'Encoded shape: {encoded.shape}')
    print(f'Decoded shape: {decoded.shape}')

Encoded shape: torch.Size([64, 128])
Decoded shape: torch.Size([64, 128, 42])


# Training

In [80]:
def train(autoencoder, dataloader, num_epochs=10, device=device):
    learning_rate = 0.001

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(autoencoder.parameters(), lr=learning_rate)

    # Training loop
    print('Training started')
    for epoch in range(num_epochs):
        start_time = time.time()
        avg_loss = 0.
        counter = 0
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            encoded, decoded = autoencoder(x)
            loss = criterion(decoded, y)
            loss.backward()
            optimizer.step()
            
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
        
    return autoencoder

In [81]:
# input_size = 4860
# encoding_size = 128
# decoding_size = 128
# output_size = 42
# num_layers = 2

# model = Autoencoder(input_size=input_size, encoding_size=encoding_size, decoding_size=decoding_size,
#                     output_size=output_size, num_layers=num_layers, drop_prob=0.2).to(device)


In [82]:
model = train(autoencoder=model, dataloader=train_loader, num_epochs=100)

Training started
Epoch [1/100], Loss: 13.1353
Epoch [2/100], Loss: 13.0825
Epoch [3/100], Loss: 13.0945
Epoch [4/100], Loss: 12.9774
Epoch [5/100], Loss: 12.8989
Epoch [6/100], Loss: 13.0180
Epoch [7/100], Loss: 13.1643
Epoch [8/100], Loss: 12.9321
Epoch [9/100], Loss: 13.0046
Epoch [10/100], Loss: 12.9978
Epoch [11/100], Loss: 12.9316
Epoch [12/100], Loss: 12.9728
Epoch [13/100], Loss: 13.0060
Epoch [14/100], Loss: 12.9378
Epoch [15/100], Loss: 12.9461
Epoch [16/100], Loss: 12.9096
Epoch [17/100], Loss: 12.9425
Epoch [18/100], Loss: 12.9138
Epoch [19/100], Loss: 12.8969
Epoch [20/100], Loss: 12.9271
Epoch [21/100], Loss: 12.9969
Epoch [22/100], Loss: 12.8832
Epoch [23/100], Loss: 12.9227
Epoch [24/100], Loss: 12.9055
Epoch [25/100], Loss: 12.8463
Epoch [26/100], Loss: 12.9162
Epoch [27/100], Loss: 12.8398
Epoch [28/100], Loss: 12.9060
Epoch [29/100], Loss: 12.8507
Epoch [30/100], Loss: 12.8669
Epoch [31/100], Loss: 12.8170
Epoch [32/100], Loss: 12.7917
Epoch [33/100], Loss: 12.8517
Ep

In [83]:
x = next(iter(train_loader))[0]
x = x.to(device)
x.size()

torch.Size([64, 4860])

In [84]:
encoded, decoded = model(x)

# GRU output to SELFIES

In [85]:
decoded_indices = torch.argmax(decoded.cpu(), dim=2)
decoded_indices = decoded_indices.numpy()

In [86]:
selfies = []
for i in decoded_indices:
    vectorized = []
    for number in decoded_indices[0]:
        v = np.zeros(128)
        v[number] = 1
        vectorized.append(v)
    devectorized = vectorizer.devectorize(vectorized)
    selfies.append(devectorized)
selfies

['[\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\O][\\O][\\O][=Ring2]',
 '[\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\O][\\O][\\O][=Ring2]',
 '[\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\O][\\O][\\O][=Ring2]',
 '[\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N][\\N