In [256]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
import torch.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data prep

In [399]:
import selfies as sf
data = pd.read_csv('./GRU_data/selfies.csv', header=None, names=['selfies'])
alphabet = sf.get_alphabet_from_selfies(data.selfies)
alphabet.add("[nop]") # [nop] is a special padding symbol
alphabet.add("[start]")
alphabet.add("[end]")
alphabet = list(sorted(alphabet))
#pad_to_len = max(sf.len_selfies(s) for s in data.selfies) + 5
pad_to_len = 128
symbol_to_idx = {s: i for i, s in enumerate(alphabet)}
idx2char = {i: s for i, s in enumerate(alphabet)}

In [400]:
len(alphabet)

42

In [401]:
from torch.utils.data import Dataset, DataLoader

class GRUDataset(Dataset):
    def __init__(self, smiles_fp, selfies, vectorizer):
        self.smiles_fp = pd.read_csv(smiles_fp, sep=',', nrows=1000)
        self.selfies = pd.read_csv(selfies, nrows=1000)
        self.X = self.prepare_X(self.smiles_fp)
        self.X = np.array([self.reconstruct_fp(fp) for fp in self.X])
        self.y = self.prepare_y(self.selfies)
    def __len__(self):
        return len(self.smiles_fp)
    def __getitem__(self, idx):
        raw_selfie = self.y[idx][0]
        vectorized_selfie = vectorizer.vectorize(raw_selfie)
        return torch.from_numpy(self.X[idx]).float(), torch.from_numpy(vectorized_selfie).float()

    @staticmethod
    def prepare_X(smiles_fp):
        fps = smiles_fp.fps.apply(eval).apply(lambda x: np.array(x, dtype=int))
        return fps
    @staticmethod
    def prepare_y(selfies):
        return selfies.values
    @staticmethod
    def reconstruct_fp(fp, length=4860):
        fp_rec = np.zeros(length)
        fp_rec[fp] = 1
        return fp_rec

In [402]:
import re
class SELFIESVectorizer:
    def __init__(self, alphabet, pad_to_len):
        self.alphabet = alphabet
        self.pad_to_len = pad_to_len
        self.char2idx = {s: i for i, s in enumerate(alphabet)}
        self.idx2char = {i: s for i, s in enumerate(alphabet)}
    def vectorize(self, selfie):
        ''' Vectorize a list of SMILES strings to a numpy array of shape (len(smiles), embed, len(charset))'''
        X = np.zeros((self.pad_to_len, len(self.alphabet)))
        splited = ['[start]'] + self.split_selfi(selfie) + ['[end]'] + ['[nop]'] * (self.pad_to_len - len(self.split_selfi(selfie)) - 2)
        for i, char in enumerate(splited):
            X[i, self.char2idx[char]] = 1
        return X
    def devectorize(self, ohe):
        ''' Devectorize a numpy array of shape (len(smiles), embed, len(charset)) to a list of SMILES strings'''
        selfie_str = ''
        for j in range(self.pad_to_len):
            char = self.idx2char[np.argmax(ohe[j])]
            if char == '[start]':
                continue
            elif char == '[end]':
                break
            else:
                selfie_str += char
        return selfie_str

    def split_selfi(self, selfie):
        pattern = r'(\[[^\[\]]*\])'
        return re.findall(pattern, selfie)

In [403]:
vectorizer = SELFIESVectorizer(alphabet, pad_to_len)
dataset = GRUDataset('GRU_data/chembl_klek.csv', 'GRU_data/selfies.csv', vectorizer)

In [404]:
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [405]:
batch_size = 64
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, drop_last=True)
next(iter(train_loader))[0].shape

torch.Size([64, 4860])

# NN architecture

In [406]:
class EncoderNet(nn.Module):
    def __init__(self, fp_size, encoding_size):
        super(EncoderNet, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(fp_size, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, encoding_size)

    def forward(self, x):
        out = self.relu(self.fc1(x))
        out = self.relu(self.fc2(out))
        out = self.relu(self.fc3(out))
        out = self.relu(self.fc4(out))
        out = self.relu(self.fc5(out))
        return out

class DecoderNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, drop_prob):
        super(DecoderNet, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(drop_prob)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, input, hidden):
        input = input.unsqueeze(0)  # Add a time step dimension
        output, hidden = self.gru(input, hidden)
        output = self.fc(output)
        output = self.dropout(output)
        output = self.softmax(output)
        return output.squeeze(0), hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

class Autoencoder(nn.Module):
    def __init__(self, input_size, encoding_size, decoding_size, output_size, drop_prob=0.2):
        super(Autoencoder, self).__init__()
        self.encoder = EncoderNet(input_size, encoding_size)
        self.decoder = DecoderNet(encoding_size, decoding_size, output_size, drop_prob)

    def forward(self, x):
        encoded = self.encoder(x)
        hidden = self.decoder.init_hidden(x.size(0))
        decoded_sequences = []
        for _ in range(encoded.size(1)):
            output, hidden = self.decoder(encoded, hidden)
            decoded_sequences.append(output)
        decoded_sequences = torch.stack(decoded_sequences, dim=1)
        return encoded, decoded_sequences

# Training

In [414]:
def train(autoencoder, dataloader):
    num_epochs = 40
    learning_rate = 0.01

    criterion = nn.BCELoss()
    optimizer = optim.Adam(autoencoder.parameters(), lr=learning_rate)

    # Training loop
    print('Training started')
    for epoch in range(num_epochs):
        start_time = time.time()
        avg_loss = 0.
        counter = 0
        
        for x, y in dataloader:
            optimizer.zero_grad()
            encoded, decoded = autoencoder(x)
            loss = criterion(decoded, y)
            loss.backward()
            optimizer.step()
            
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
        
    return autoencoder

In [415]:
input_size = 4860
encoding_size = 128
decoding_size = 128
output_size = 42

model = Autoencoder(input_size=input_size, encoding_size=encoding_size, decoding_size=decoding_size,
                    output_size=output_size)

x, _ = next(iter(train_loader))
encoded, decoded = autoencoder(x)

# Print the encoded and decoded sequences
print("Encoded sequence shape:", encoded.shape)
print("Decoded sequences shape:", decoded.shape)

Encoded sequence shape: torch.Size([64, 128])
Decoded sequences shape: torch.Size([64, 128, 42])


In [416]:
model = train(autoencoder=model, dataloader=train_loader)

Training started
Epoch [1/40], Loss: 0.0566
Epoch [2/40], Loss: 0.0530
Epoch [3/40], Loss: 0.0457
Epoch [4/40], Loss: 0.0448
Epoch [5/40], Loss: 0.0450
Epoch [6/40], Loss: 0.0424
Epoch [7/40], Loss: 0.0416
Epoch [8/40], Loss: 0.0424
Epoch [9/40], Loss: 0.0422
Epoch [10/40], Loss: 0.0423
Epoch [11/40], Loss: 0.0400
Epoch [12/40], Loss: 0.0409
Epoch [13/40], Loss: 0.0420
Epoch [14/40], Loss: 0.0416
Epoch [15/40], Loss: 0.0410
Epoch [16/40], Loss: 0.0394
Epoch [17/40], Loss: 0.0399
Epoch [18/40], Loss: 0.0415
Epoch [19/40], Loss: 0.0402
Epoch [20/40], Loss: 0.0409
Epoch [21/40], Loss: 0.0387
Epoch [22/40], Loss: 0.0416
Epoch [23/40], Loss: 0.0402
Epoch [24/40], Loss: 0.0403
Epoch [25/40], Loss: 0.0412
Epoch [26/40], Loss: 0.0371
Epoch [27/40], Loss: 0.0393
Epoch [28/40], Loss: 0.0382
Epoch [29/40], Loss: 0.0395
Epoch [30/40], Loss: 0.0393
Epoch [31/40], Loss: 0.0407
Epoch [32/40], Loss: 0.0415
Epoch [33/40], Loss: 0.0390
Epoch [34/40], Loss: 0.0376
Epoch [35/40], Loss: 0.0400
Epoch [36/40

In [417]:
def random_fp():
    fp = np.random.rand(4860)
    fp = fp > 0.5
    fp = torch.from_numpy(fp)
    fp = torch.unsqueeze(fp, 0).to(device)
    fp = fp.float()
    return(fp)

In [418]:
encoded, decoded = model(random_fp())

# GRU output to SELFIES

In [422]:
decoded_indices = torch.argmax(decoded, dim=2)
decoded_indices = decoded_indices.numpy()

In [423]:
selfies = []
for i in decoded_indices:
    vectorized = []
    for number in decoded_indices[0]:
        v = np.zeros(128)
        v[number] = 1
        selfies_vectorized.append(v)
    devectorized = vectorizer.devectorize(selfies_vectorized)
    selfies.append(devectorized)
selfies

['[\\O][\\O][\\O][\\O][\\O][\\O][\\O][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1][Ring1]']