In [1]:
# General Imports
import os
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader,random_split, Dataset
from utils.fixes import global_seed
import warnings
# Seed the RNGs once, before any sampling or model initialisation
# (utils.fixes.global_seed — implementation not shown here; presumably seeds numpy/torch/random).
global_seed(42)
%matplotlib inline

Global seed set to 42


In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
# Read the tab-delimited SMILES file into a single `smiles` column, then keep a
# random sample of 100,000 rows with a fresh 0..n-1 index.
smifile = "GDB17.50000000LLnoSR.smi"
data = (
    pd.read_csv(smifile, sep="\t", names=["smiles"])
    .sample(n=100000)
    .reset_index(drop=True)
)

In [4]:
# Peek at one sampled SMILES string.
data.smiles[5]

'CC1C2NC(=O)C(C)C(N)=NC2(C)CC1(C)C'

In [5]:
# Longest SMILES in the sample plus slack: tokenize() wraps every string in
# start/end tokens, so the padded length must exceed the raw maximum.
MAX_SMILES_LEN = max(map(len, data.smiles)) + 5
print(MAX_SMILES_LEN)

51


In [6]:
# Module-level copies of the token lists. SMILESTokenizer below hard-codes the same
# lists (its `multi_char` default and its single-character list) and does not read these.
multi_char = ["Cl", "Br", "Si"]
single_char = ['#', ')', '(', '+', '-', '/', '1', '3', '2', '5', '4', '7', '6', '8', '=', '@', 'C', 'B', 'F', 'I', 'H', 'O', 'N', 'S', '[', ']', '\\', 'c', 'l', 'o', 'n', 'p', 's', 'r']

In [7]:
import re

In [12]:
# create tokenizer for smiles strings
class SMILESTokenizer:
    """Character-level SMILES tokenizer with start/end markers and zero padding.

    Id layout of ``char2idx``:
        0 -> padding (``'<pad>'``), 1 -> ``start``, 2 -> ``end``,
        3.. -> single-character tokens, then multi-character tokens.

    Args:
        multi_char: Multi-character tokens (e.g. two-letter elements) that are
            matched before single characters and never split.
        start: Character prepended to every SMILES string before encoding.
        end: Character appended to every SMILES string before encoding.
        max_len: Length of every encoded sequence, *including* the start/end
            tokens; shorter sequences are right-padded with the pad id 0.
        single_char: Optional override of the single-character vocabulary
            (defaults to the built-in list below; its order fixes the ids).

    Attributes:
        vocab_size: Number of distinct ids produced (use as embedding size).
        vocab: Single- and multi-character tokens followed by the start
            character (end and pad are not included); kept for compatibility.
    """

    # Default single-character vocabulary; order determines ids 3, 4, ...
    _DEFAULT_SINGLE_CHAR = (
        '#', ')', '(', '+', '-', '/', '1', '3', '2', '5', '4', '7', '6', '8',
        '=', '@', 'C', 'B', 'F', 'I', 'H', 'O', 'N', 'S', '[', ']', '\\',
        'c', 'l', 'o', 'n', 'p', 's', 'r',
    )

    def __init__(self, multi_char=("Cl", "Br", "Si"), start='?', end='E', max_len=60,
                 single_char=None):
        # Copy inputs so instances never share (or mutate) a default list.
        self.multi_char = list(multi_char)
        self.single_char = list(self._DEFAULT_SINGLE_CHAR if single_char is None else single_char)
        self.multi_pattern = self._generate_regex(self.multi_char)
        self.start = start
        self.end = end
        self.max_len = max_len
        self.pad = '<pad>'
        self.vocab = self.single_char + self.multi_char
        # Reserved ids: 0 = padding, 1 = start, 2 = end; real tokens start at 3.
        self.char2idx = {start: 1, end: 2, self.pad: 0}
        self.char2idx.update({char: idx + 3 for idx, char in enumerate(self.vocab)})
        self.idx2char = {idx: char for char, idx in self.char2idx.items()}
        self.vocab_size = len(self.char2idx)
        # Historical behaviour: `vocab` also lists the start character.
        self.vocab.extend([start])

    def tokenize(self, smiles):
        """Encode ``smiles`` into a list of exactly ``max_len`` token ids.

        The string is wrapped in ``start``/``end`` markers, multi-character
        tokens are matched first, and the result is right-padded with 0.

        Returns:
            A list of ``max_len`` ints, or ``None`` (after a warning) when the
            encoded sequence, start/end tokens included, exceeds ``max_len``.

        Raises:
            KeyError: if ``smiles`` contains a token outside the vocabulary.
        """
        tokens = self._split(self.start + smiles + self.end)
        if len(tokens) > self.max_len:
            # The limit applies to the *encoded* length: checking only the raw
            # string let start/end overflow max_len, producing unpadded lists
            # longer than max_len that cannot be batched with the others.
            warnings.warn(
                f"SMILES string needs {len(tokens)} tokens (including start/end) "
                f"but max_len is {self.max_len}. Skipping..."
            )
            return None
        try:
            ids = [self.char2idx[tok] for tok in tokens]
        except KeyError as err:
            raise KeyError(f"Unknown SMILES token {err.args[0]!r} in {smiles!r}") from None
        pad_id = self.char2idx[self.pad]
        return ids + [pad_id] * (self.max_len - len(ids))

    def detokenize(self, tokens, remove_start_end=True, remove_padding=True):
        """Decode token ids (a list of ints or a 1-D tensor) back into a string.

        Args:
            tokens: Token ids as produced by :meth:`tokenize`.
            remove_start_end: Strip the start/end marker characters.
            remove_padding: Strip ``'<pad>'`` placeholders.
        """
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        raw_string = "".join(self.idx2char[x] for x in tokens)
        if remove_padding:
            raw_string = raw_string.replace(self.pad, "")
        if remove_start_end:
            raw_string = raw_string.replace(self.start, "").replace(self.end, "")
        return raw_string

    def _split(self, text):
        """Split ``text`` into tokens, keeping multi-character tokens intact."""
        if not self.multi_char:
            return list(text)
        out = []
        for piece in re.split(self.multi_pattern, text):
            if piece is None:
                # re.split yields None for capture groups that did not match.
                continue
            if piece in self.multi_char:
                out.append(piece)
            else:
                out.extend(piece)
        return out

    def _generate_regex(self, multi_char):
        """Build an alternation regex with one capture group per multi-char token.

        Tokens are escaped so regex metacharacters match literally, and are
        tried longest-first so a longer token wins over a shorter prefix
        (the sort is stable, so equal-length tokens keep their given order).
        """
        ordered = sorted(multi_char, key=len, reverse=True)
        return "|".join(f"({re.escape(tok)})" for tok in ordered)


In [13]:
# Demo on a short string. Note max_len=50 here, while later cells rebuild with MAX_SMILES_LEN.
tokenizer = SMILESTokenizer(max_len=50)
encoded = tokenizer.tokenize('CCCNBr')

In [14]:
# Ids: 1 = start, 2 = end; any trailing 0s are padding up to max_len.
encoded

[1, 19, 19, 19, 25, 38, 2]

In [16]:
# Should reproduce the input string 'CCCNBr'.
tokenizer.detokenize(encoded)

'CCCNBr'

In [18]:
# Round-trip check on one sampled SMILES: decoding its encoding must give back the input.
# Asserting (rather than only displaying the bool) makes a failure stop Run-All.
roundtrip_smiles = data.smiles[6000]
roundtrip_ok = tokenizer.detokenize(tokenizer.tokenize(roundtrip_smiles)) == roundtrip_smiles
assert roundtrip_ok, f"Tokenizer round-trip failed for {roundtrip_smiles!r}"
roundtrip_ok

[1, 19, 19, 25, 2]

In [20]:
class SMILESDataset(Dataset):
    """Map-style dataset yielding fixed-length token-id tensors, one per SMILES.

    Args:
        data: DataFrame with a ``smiles`` column.
        tokenizer: Tokenizer *class*; instantiated here as
            ``tokenizer(max_len=max_len)`` and must provide ``tokenize(str)``
            returning a list of ids or ``None`` when it cannot encode.
        max_len: Encoded length passed to the tokenizer.
    """

    def __init__(self, data, tokenizer, max_len=50):
        self.smiles = data.smiles.to_list()
        self.tokenizer = tokenizer(max_len=max_len)
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return the encoded SMILES at ``idx`` as a 1-D ``torch.long`` tensor.

        Raises:
            ValueError: if the tokenizer returns ``None`` for this string
                (e.g. it does not fit in ``max_len``).
        """
        smiles_raw = self.smiles[idx]
        encoded = self.tokenizer.tokenize(smiles_raw)
        if encoded is None:
            # torch.tensor(None) would fail with an unhelpful dtype error;
            # report which sample could not be encoded instead.
            raise ValueError(
                f"Could not encode SMILES at index {idx} ({smiles_raw!r}) "
                f"within max_len={self.max_len}"
            )
        return torch.tensor(encoded, dtype=torch.long)

tensor([[ 1, 19, 19, 19, 25, 38,  2],
        [ 1, 19, 19, 25,  2,  0,  0]])

In [21]:
# The dataset builds its own tokenizer instance from the class, sized to MAX_SMILES_LEN.
dataset = SMILESDataset(data, SMILESTokenizer, max_len=MAX_SMILES_LEN)

PackedSequence(data=tensor([ 1,  1, 19, 19, 19, 19, 19, 25, 25,  2, 38,  2]), batch_sizes=tensor([2, 2, 2, 2, 2, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [22]:
# Rebind `tokenizer` with the dataset's max_len so decoding matches dataset items.
tokenizer = SMILESTokenizer(max_len=MAX_SMILES_LEN)

<function torchtext.datasets.multi30k.Multi30k(root: str = '.data', split: Union[Tuple[str], str] = ('train', 'valid', 'test'), language_pair: Tuple[str] = ('de', 'en'))>

In [43]:
# One encoded sample, zero-padded to MAX_SMILES_LEN.
dataset[7]

tensor([ 1, 19, 19, 19,  5, 19,  3, 19,  4, 19,  5, 19,  4, 19, 19,  9, 17, 25,
        25, 19,  5, 25, 19, 19, 17, 24,  4, 17, 25,  9,  2,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [44]:
# Round-trip on a dataset item (detokenize converts the tensor via .tolist()).
tokenizer.detokenize(dataset[7]) == data.smiles[7]

True

In [45]:
# create dataloader
# NOTE(review): batch_size=1 looks like a debugging value. Every item is padded to
# max_len, so larger batches should collate cleanly — confirm before training.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [4]:
# Scratch demo of Tensor.view: flattening a (2, 3) tensor into shape (1, 1, 6).
# NOTE(review): this cell re-imports torch and carries an out-of-order execution
# count; consider moving it next to the model cells or deleting it.
import torch.nn as nn
import torch
example_2D_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(f'Shape of 2D tensor: {example_2D_tensor.shape}')
print(f'Shape after view: {example_2D_tensor.view(1, 1, -1).shape}')
print(example_2D_tensor.view(1, 1, -1))

Shape of 2D tensor: torch.Size([2, 3])
Shape after view: torch.Size([1, 1, 6])
tensor([[[1, 2, 3, 4, 5, 6]]])


In [78]:
import torch.functional as F
class Encoder(nn.Module):
    """GRU encoder: token ids -> final hidden state of the last GRU layer.

    Args:
        vocab_size: Number of distinct token ids the embedding accepts.
        embedding_size: Dimension of each token embedding.
        hidden_size: GRU hidden-state size (size of the returned vector).
    """

    def __init__(self, vocab_size, embedding_size, hidden_size):
        super().__init__()
        # Construction order (embedding, then GRU) is kept so seeded weights are unchanged.
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.encoder = nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.hidden_size = hidden_size

    def forward(self, x):
        """Encode token-id sequences.

        Args:
            x: Long tensor of token ids, presumably (batch, seq_len) since the
                GRU is ``batch_first`` — TODO confirm for other callers.

        Returns:
            The last GRU layer's final hidden state; (batch, hidden_size) for
            batched input.
        """
        _, h_n = self.encoder(self.embedding(x))
        return h_n[-1]

class Decoder(nn.Module):
    """Single-step GRU decoder: feeds the start token once and returns log-probs.

    Args:
        vocab_size: Output vocabulary size (also the embedding's input size).
        embedding_size: Dimension of token embeddings.
        hidden_size: GRU hidden-state size; must match the encoder's.
        batch_size: Fallback batch size, used only when ``forward`` receives no
            hidden state (otherwise the batch size is read from ``hidden``).
        start_idx: Token id fed as the decoder input (default 1, the start id
            used by SMILESTokenizer in this notebook).
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, batch_size, start_idx=1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.decoder = nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.activation = nn.LogSoftmax(dim=2)
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        # A buffer (unlike a plain tensor attribute) moves with .to(device)/.cuda();
        # non-persistent so it does not appear in the state_dict.
        self.register_buffer("start", torch.tensor([start_idx], dtype=torch.long), persistent=False)

    def forward(self, hidden=None):
        """Run one decoding step from the start token.

        Args:
            hidden: Initial GRU state, either (1, batch, hidden_size) or the
                (batch, hidden_size) tensor returned by ``Encoder``; ``None``
                starts from zeros with ``self.batch_size`` rows.

        Returns:
            ``(log_probs, hidden)`` with log_probs of shape (batch, 1, vocab_size)
            and the updated GRU state of shape (1, batch, hidden_size).
        """
        if hidden is not None and hidden.dim() == 2:
            # The batched GRU requires a 3-D state (num_layers, batch, hidden_size).
            hidden = hidden.unsqueeze(0)
        batch_size = self.batch_size if hidden is None else hidden.size(1)
        x = self.start.repeat(batch_size, 1)   # (batch, 1)
        x = self.embedding(x)                  # (batch, 1, embedding_size)
        x, hidden = self.decoder(x, hidden)    # (batch, 1, hidden_size)
        x = self.fc(x)                         # (batch, 1, vocab_size)
        x = self.activation(x)
        return x, hidden


In [79]:
# NOTE(review): `device` is recomputed here, but the encoder below is never moved to
# it, so this smoke test runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Smoke test: one 16-token sequence through a small encoder.
encoder = Encoder(vocab_size=32, embedding_size=32, hidden_size=16)
dummy_input = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])
h = encoder(dummy_input)

In [80]:
# Expected (batch, hidden_size): the Encoder returns one final hidden vector per sequence.
h.shape

torch.Size([1, 16])

In [81]:
# Raw encoding of dummy_input.
h

tensor([[-0.7786,  0.0748, -0.5652,  0.5272,  0.8104,  0.0347, -0.2540,  0.0094,
          0.3040,  0.1887, -0.0108,  0.2584,  0.3977,  0.1729,  0.7098, -0.0173]],
       grad_fn=<SelectBackward0>)

In [82]:
# Repeat the single sequence 5 times to form a batch of identical rows.
dummy_batch = dummy_input.repeat(5, 1)
dummy_batch

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16]])

In [83]:
# Identical input rows should give identical encodings, each matching `h`.
encoder(dummy_batch)

tensor([[-0.7786,  0.0748, -0.5652,  0.5272,  0.8104,  0.0347, -0.2540,  0.0094,
          0.3040,  0.1887, -0.0108,  0.2584,  0.3977,  0.1729,  0.7098, -0.0173],
        [-0.7786,  0.0748, -0.5652,  0.5272,  0.8104,  0.0347, -0.2540,  0.0094,
          0.3040,  0.1887, -0.0108,  0.2584,  0.3977,  0.1729,  0.7098, -0.0173],
        [-0.7786,  0.0748, -0.5652,  0.5272,  0.8104,  0.0347, -0.2540,  0.0094,
          0.3040,  0.1887, -0.0108,  0.2584,  0.3977,  0.1729,  0.7098, -0.0173],
        [-0.7786,  0.0748, -0.5652,  0.5272,  0.8104,  0.0347, -0.2540,  0.0094,
          0.3040,  0.1887, -0.0108,  0.2584,  0.3977,  0.1729,  0.7098, -0.0173],
        [-0.7786,  0.0748, -0.5652,  0.5272,  0.8104,  0.0347, -0.2540,  0.0094,
          0.3040,  0.1887, -0.0108,  0.2584,  0.3977,  0.1729,  0.7098, -0.0173]],
       grad_fn=<SelectBackward0>)