In [None]:
# Mount Google Drive (Colab only) and copy the two line-aligned data files
# (protein.txt / smile.txt) into the Colab working directory /content.
from google.colab import drive
drive.mount('/content/gdrive')

# `!` hands the rest of the line to the shell (IPython syntax, not Python).
! cp /content/gdrive/MyDrive/protein.txt /content
! cp /content/gdrive/MyDrive/smile.txt /content

Mounted at /content/gdrive


In [None]:
# Read both files as lists of lines; line i of protein.txt pairs with line i of smile.txt.
# `with` closes the handles instead of leaking them until garbage collection.
with open("protein.txt", encoding='utf8') as protein_file:
    protein_txt = protein_file.read().split('\n')
with open("smile.txt", encoding='utf8') as smile_file:
    smile_txt = smile_file.read().split('\n')

In [None]:
# split('\n') leaves a trailing '' when a file ends with a newline. Drop it only if
# it is actually there: a file without a final newline keeps its last record, and
# re-running this cell no longer deletes real data.
for lines in (protein_txt, smile_txt):
    if lines and lines[-1] == '':
        lines.pop()

''

In [None]:
# Peek at the first few SMILES lines rather than rendering the entire dataset.
smile_txt[:10]

['O = C ( N 1 C C 2 C C 2 ( c 2 c [ n H ] c 3 n c c c c 2 3 ) C 1 ) C 1 ( c 2 c c c ( F ) c c 2 ) C C 1',
 'C n 1 n c c 2 c ( - c 3 c c c n 4 n c ( N c 5 c c c 6 c ( c 5 ) C C N ( C ( = O ) O C ( C ) ( C ) C ) C C 6 ) n c 3 4 ) c c c c 2 1',
 'C n 1 n c c 2 c ( - c 3 c c c n 4 n c ( N c 5 c c c 6 c ( c 5 ) C C N ( C ( = O ) O C ( C ) ( C ) C ) C C 6 ) n c 3 4 ) c c c c 2 1',
 'c 1 c c c ( N c 2 n c n c 3 n [ n H ] c ( N c 4 c c c c c 4 ) c 2 3 ) c c 1',
 'C l c 1 c c c c ( N c 2 [ n H ] n c 3 n c n c ( N c 4 c c c c c 4 ) c 2 3 ) c 1',
 'C l c 1 c c c c ( N c 2 n c n c 3 n [ n H ] c ( N c 4 c c c c ( C l ) c 4 ) c 2 3 ) c 1',
 'C O c 1 c c c ( N c 2 [ n H ] n c 3 n c n c ( N c 4 c c c c ( C l ) c 4 ) c 2 3 ) c c 1',
 'O c 1 c c c ( N c 2 [ n H ] n c 3 n c n c ( N c 4 c c c c ( C l ) c 4 ) c 2 3 ) c c 1',
 'C O c 1 c c c c ( N c 2 [ n H ] n c 3 n c n c ( N c 4 c c c c ( C l ) c 4 ) c 2 3 ) c 1',
 'O c 1 c c c c ( N c 2 [ n H ] n c 3 n c n c ( N c 4 c c c c ( C l ) c 4 ) c 2 3 ) c 1',
 '

In [None]:
def remove_lines(max_len, proteins=None, smiles=None):
    """Drop every (protein, SMILES) pair whose protein line is longer than ``max_len``.

    Length is the number of non-space characters in the protein line. Both lists
    are filtered in place with the same mask, so they stay index-aligned.

    Args:
        max_len: longest protein (in non-space characters) to keep.
        proteins: list of protein lines; defaults to the global ``protein_txt``.
        smiles: list of SMILES lines; defaults to the global ``smile_txt``.

    Raises:
        ValueError: if the two lists differ in length (they are no longer paired).
    """
    if proteins is None:
        proteins = protein_txt
    if smiles is None:
        smiles = smile_txt
    if len(proteins) != len(smiles):
        raise ValueError(
            f"protein/SMILES lists are misaligned: {len(proteins)} vs {len(smiles)} lines"
        )

    # One O(n) pass instead of repeated list.pop(i), which is O(n) per removal.
    keep = [len(protein.replace(" ", "")) <= max_len for protein in proteins]
    # Slice assignment mutates the existing list objects, so callers holding these
    # lists (e.g. the module-level globals) see the filtered result.
    proteins[:] = [p for p, kept in zip(proteins, keep) if kept]
    smiles[:] = [s for s, kept in zip(smiles, keep) if kept]

# Drop pairs whose protein has more than MAX_PROTEIN_LEN non-space characters.
# NOTE(review): the origin of 1224 is not recorded here — document why this cutoff.
MAX_PROTEIN_LEN = 1224
remove_lines(MAX_PROTEIN_LEN)

In [None]:
def tokenize(x):
    """Split a pre-tokenized line into tokens on any run of whitespace."""
    return x.split()

In [None]:
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(data_iter):
    """Lazily yield the token list of each line in ``data_iter``."""
    yield from (tokenize(line) for line in data_iter)


def get_vocab(train_datapipe):
    """Build a torchtext vocabulary over the tokens of ``train_datapipe``.

    The four special tokens are passed as ``specials``; lookups of unknown tokens
    fall back to the '<UNK>' index.
    """
    special_tokens = ['<UNK>', '<PAD>', '<SOS>', '<EOS>']
    built = build_vocab_from_iterator(
        yield_tokens(train_datapipe),
        specials=special_tokens,
        max_tokens=20000,
    )
    built.set_default_index(built['<UNK>'])
    return built


protein_voc, smile_voc = get_vocab(protein_txt), get_vocab(smile_txt)

smile_voc.get_itos()

['<UNK>',
 '<PAD>',
 '<SOS>',
 '<EOS>',
 'c',
 'C',
 '(',
 ')',
 '1',
 'O',
 '2',
 'N',
 'n',
 '=',
 '3',
 '@',
 '[',
 ']',
 'H',
 'F',
 '4',
 '-',
 'l',
 '5',
 'S',
 '#',
 '/',
 's',
 'o',
 '6',
 'B',
 '\\',
 'r',
 '+',
 'P',
 'I',
 '7',
 '.',
 '8',
 'i',
 'e',
 'a',
 '9',
 'R',
 'u',
 'V',
 'L',
 'b',
 'K',
 'Z',
 't']

In [None]:
import torch
# Persist both vocabularies to the working directory.
for vocab_obj, vocab_path in ((protein_voc, 'protein-vocab.pt'), (smile_voc, 'smiles-vocab.pt')):
    torch.save(vocab_obj, vocab_path)

In [None]:
import torch
import numpy as np


def data_process():
    """Convert every (protein, SMILES) line pair into a pair of int64 id arrays.

    Tokens are looked up in ``protein_voc`` / ``smile_voc``; tokens missing from a
    vocabulary map to its default index ('<UNK>').

    Returns:
        list of ``[protein_ids, smile_ids]`` lists, one per line pair.

    Raises:
        ValueError: if ``protein_txt`` and ``smile_txt`` differ in length — ``zip``
            would otherwise silently drop the unmatched tail and misalign nothing visibly.
    """
    if len(protein_txt) != len(smile_txt):
        raise ValueError(
            f"protein/SMILES lists are misaligned: {len(protein_txt)} vs {len(smile_txt)} lines"
        )

    data = []
    for protein, smile in zip(protein_txt, smile_txt):
        # Give the dtype at construction: an empty line would otherwise become a
        # float64 array first, and astype() would make a second copy.
        protein_ids = np.array([protein_voc[token] for token in tokenize(protein)], dtype=np.int64)
        smile_ids = np.array([smile_voc[token] for token in tokenize(smile)], dtype=np.int64)
        data.append([protein_ids, smile_ids])
    return data


data = data_process()

In [None]:
SOS_token = np.array([2])  # id of '<SOS>' (third special token in both vocabularies)
EOS_token = np.array([3])  # id of '<EOS>' (fourth special token in both vocabularies)


def generate_batch(data, batch_size, padding_token):
    """Group consecutive (protein, SMILES) pairs into fixed-size padded batches.

    Every sequence becomes ``<SOS> tokens <EOS> <PAD> ... <PAD>``, padded to the
    longest sequence of the same kind (protein or SMILES) in its batch. A trailing
    group smaller than ``batch_size`` is dropped. ``data`` is NOT modified, so the
    cell can be re-run without double-wrapping sequences.

    Args:
        data: sequence of ``[protein_ids, smile_ids]`` pairs of 1-D integer arrays.
        batch_size: number of pairs per batch (must be >= 1).
        padding_token: id appended after ``<EOS>`` to reach the batch length.

    Returns:
        list of object arrays of shape ``(batch_size, 2)``: ``batch[:, 0]`` holds
        the padded int64 protein arrays and ``batch[:, 1]`` the SMILES arrays.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batches = []
    # Start indices of every *full* batch (the short tail is skipped).
    for start in range(0, len(data) - batch_size + 1, batch_size):
        group = data[start:start + batch_size]
        # Build the object array explicitly: np.array() on ragged nested lists is
        # deprecated/an error in recent NumPy and can silently change shape when
        # all sequences happen to have equal length.
        batch = np.empty((batch_size, 2), dtype=object)
        for col in range(2):
            longest = max(len(pair[col]) for pair in group)
            for row, pair in enumerate(group):
                seq = np.asarray(pair[col], dtype=np.int64)
                pad = np.full(longest - len(seq), padding_token, dtype=np.int64)
                # <EOS> goes right after the real tokens (not after the padding) so
                # the model learns to stop where the sequence actually ends.
                batch[row, col] = np.concatenate((SOS_token, seq, EOS_token, pad)).astype(np.int64, copy=False)
        batches.append(batch)

    print(f"{len(batches)} batches of size {batch_size}")

    return batches

# Sequential split: the first TRAIN_SIZE pairs are used for training, the rest for validation.
# NOTE(review): the data is not shuffled before splitting, and consecutive lines of
# smile.txt look closely related (see the preview output above) — consider a seeded shuffle.
BATCH_SIZE = 14
TRAIN_SIZE = 800_000
PAD_IDX = smile_voc['<PAD>']
assert protein_voc['<PAD>'] == PAD_IDX, "both vocabularies must share the <PAD> id"

train_batch = generate_batch(data[:TRAIN_SIZE], BATCH_SIZE, PAD_IDX)
val_batch = generate_batch(data[TRAIN_SIZE:], BATCH_SIZE, PAD_IDX)

  batches.append(np.array(d[idx : idx + batch_size]))


57142 batches of size 14
7110 batches of size 14


In [None]:
import torch
import torch.nn as nn

import random
import math
import numpy as np
import matplotlib.pyplot as plt

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Input must be sequence-first, ``(seq_len, batch, dim_model)``: the encoding
    table is indexed by ``token_embedding.size(0)`` and broadcast over dim 1.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()

        self.dropout = nn.Dropout(dropout_p)

        # PE(pos, 2i)     = sin(pos / 10000^(2i / dim_model))
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model))
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)  # (max_len, 1)
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model
        )  # 1 / 10000^(2i / dim_model)
        angles = positions_list * division_term  # (max_len, ceil(dim_model / 2))

        pos_encoding = torch.zeros(max_len, dim_model)
        pos_encoding[:, 0::2] = torch.sin(angles)
        # For odd dim_model there is one fewer cosine column than sine column.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : dim_model // 2])

        # Shape (max_len, 1, dim_model) so it broadcasts over the batch axis. A buffer
        # follows .to(device) and is saved in state_dict but is not a trainable parameter.
        self.register_buffer("pos_encoding", pos_encoding.unsqueeze(1))

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(token_embedding + PE[:seq_len])``.

        Raises:
            ValueError: if the sequence is longer than the ``max_len`` the table was built for.
        """
        seq_len = token_embedding.size(0)
        if seq_len > self.pos_encoding.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {self.pos_encoding.size(0)}"
            )
        return self.dropout(token_embedding + self.pos_encoding[:seq_len])

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder sequence model built on ``nn.Transformer``.

    ``forward`` takes batch-first id tensors ``src`` ``(batch, src_len)`` and
    ``tgt`` ``(batch, tgt_len)`` and returns logits laid out
    ``(tgt_len, batch, trg_tokens)``.
    """

    # Constructor
    def __init__(
        self,
        src_tokens,
        trg_tokens,
        dim_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
    ):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model

        # LAYERS
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.input_embedding = nn.Embedding(src_tokens, dim_model)
        self.output_embedding = nn.Embedding(trg_tokens, dim_model)
        # Default batch_first=False: nn.Transformer consumes (seq, batch, dim).
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )
        self.out = nn.Linear(dim_model, trg_tokens)

    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        """Compute next-token logits for every decoder position.

        Args:
            src: ``(batch, src_len)`` source ids.
            tgt: ``(batch, tgt_len)`` decoder-input ids.
            tgt_mask: ``(tgt_len, tgt_len)`` additive causal mask (see ``get_tgt_mask``).
            src_pad_mask: optional ``(batch, src_len)`` bool mask, True at padding;
                hides padding from encoder self-attention and decoder cross-attention.
            tgt_pad_mask: optional ``(batch, tgt_len)`` bool mask, True at padding.

        Returns:
            ``(tgt_len, batch, trg_tokens)`` logits.
        """
        # Embed + scale, then switch to sequence-first BEFORE adding positions:
        # PositionalEncoding indexes positions along dim 0, so applying it to
        # batch-first tensors would encode the batch index instead of the position.
        src_emb = (self.input_embedding(src) * math.sqrt(self.dim_model)).transpose(0, 1)
        tgt_emb = (self.output_embedding(tgt) * math.sqrt(self.dim_model)).transpose(0, 1)
        src_emb = self.positional_encoder(src_emb)
        tgt_emb = self.positional_encoder(tgt_emb)

        transformer_out = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            # Decoder cross-attention must also ignore padded encoder outputs.
            memory_key_padding_mask=src_pad_mask,
        )  # (tgt_len, batch, dim_model)

        return self.out(transformer_out)  # (tgt_len, batch, trg_tokens)

    def get_tgt_mask(self, size) -> torch.Tensor:
        """Return a ``(size, size)`` additive causal mask: 0 where allowed, -inf above the diagonal."""
        # Generates a square matrix where each row allows one more token to be seen
        mask = torch.tril(torch.ones(size, size) == 1)  # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf'))  # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0))  # Convert ones to 0

        # EX for size=5:
        # [[0., -inf, -inf, -inf, -inf],
        #  [0.,   0., -inf, -inf, -inf],
        #  [0.,   0.,   0., -inf, -inf],
        #  [0.,   0.,   0.,   0., -inf],
        #  [0.,   0.,   0.,   0.,   0.]]

        return mask

    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        """Return a bool mask, True wherever ``matrix`` equals ``pad_token``."""
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

SEED = 0
torch.manual_seed(SEED)  # reproducible weight initialisation

model = Transformer(
    src_tokens=len(protein_voc), trg_tokens=len(smile_voc), dim_model=256, num_heads=8, num_encoder_layers=6, num_decoder_layers=6, dropout_p=0.1
).to(device)

opt = torch.optim.SGD(model.parameters(), lr=0.01)
# Alternative optimiser: opt = torch.optim.AdamW(model.parameters(), lr=0.001)

# Padding carries no information: exclude those target positions from the loss
# (otherwise the model is rewarded for predicting <PAD> and the average is diluted).
loss_fn = nn.CrossEntropyLoss(ignore_index=smile_voc['<PAD>'])

In [None]:
def train_loop(model, opt, loss_fn, dataloader, pad_token=1, device=None):
    """Run one training epoch over ``dataloader`` and return the mean per-batch loss.

    Each batch is an array whose column 0 holds padded source id arrays and whose
    column 1 holds padded target id arrays. Teacher forcing: the decoder sees
    ``y[:, :-1]`` and is trained to predict ``y[:, 1:]``.

    Args:
        model: seq2seq model with ``get_tgt_mask`` / ``create_pad_mask`` whose
            output is laid out ``(tgt_len, batch, vocab)``.
        opt: optimizer over ``model``'s parameters.
        loss_fn: loss taking ``(batch, vocab, tgt_len)`` logits and ``(batch, tgt_len)`` ids.
        dataloader: sized iterable of batches.
        pad_token: padding id; source positions holding it are masked out of attention.
        device: where to place tensors; defaults to the device of the model's first parameter.

    Returns:
        float: average of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` is empty.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty")
    if device is None:
        device = next(model.parameters()).device

    model.train()
    total_loss = 0.0

    for batch in dataloader:
        X = torch.as_tensor(np.vstack(batch[:, 0]).astype(np.int64), device=device)
        y = torch.as_tensor(np.vstack(batch[:, 1]).astype(np.int64), device=device)

        # Shift the target by one so that with <SOS> at position 0 we predict position 1.
        y_input = y[:, :-1]
        y_expected = y[:, 1:]

        # Causal mask: position t may only attend to positions <= t.
        tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)
        # Hide source padding from attention. Target padding needs no key mask: it
        # only ever follows the real tokens, so the causal mask already hides it.
        src_pad_mask = model.create_pad_mask(X, pad_token).to(device)

        pred = model(X, y_input, tgt_mask, src_pad_mask=src_pad_mask)

        # (tgt_len, batch, vocab) -> (batch, vocab, tgt_len), the layout the loss expects.
        loss = loss_fn(pred.permute(1, 2, 0), y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()

    return total_loss / len(dataloader)

In [None]:
def validation_loop(model, loss_fn, dataloader, pad_token=1, device=None):
    """Evaluate ``model`` on ``dataloader`` without gradients; return the mean per-batch loss.

    Batches have the same layout as in training: column 0 holds padded source id
    arrays and column 1 holds padded target id arrays. The decoder sees ``y[:, :-1]``
    and is scored against ``y[:, 1:]``.

    Args:
        model: seq2seq model with ``get_tgt_mask`` / ``create_pad_mask`` whose
            output is laid out ``(tgt_len, batch, vocab)``.
        loss_fn: loss taking ``(batch, vocab, tgt_len)`` logits and ``(batch, tgt_len)`` ids.
        dataloader: sized iterable of batches.
        pad_token: padding id; source positions holding it are masked out of attention.
        device: where to place tensors; defaults to the device of the model's first parameter.

    Returns:
        float: average of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` is empty.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty")
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in dataloader:
            X = torch.as_tensor(np.vstack(batch[:, 0]).astype(np.int64), device=device)
            y = torch.as_tensor(np.vstack(batch[:, 1]).astype(np.int64), device=device)

            # Shift the target by one so that with <SOS> at position 0 we predict position 1.
            y_input = y[:, :-1]
            y_expected = y[:, 1:]

            # Causal mask on the target; padding mask on the source.
            tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)
            src_pad_mask = model.create_pad_mask(X, pad_token).to(device)

            pred = model(X, y_input, tgt_mask, src_pad_mask=src_pad_mask)

            # (tgt_len, batch, vocab) -> (batch, vocab, tgt_len), the layout the loss expects.
            loss = loss_fn(pred.permute(1, 2, 0), y_expected)
            total_loss += loss.item()

    return total_loss / len(dataloader)

In [None]:
def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    """Train for ``epochs`` epochs, validating after each one.

    Returns:
        (train_losses, validation_losses): one mean loss per epoch, in epoch
        order, kept for plotting later on.
    """
    history = {"train": [], "val": []}

    print("Training and validating model")
    for epoch_no in range(1, epochs + 1):
        print("-" * 25, f"Epoch {epoch_no}", "-" * 25)

        epoch_train = train_loop(model, opt, loss_fn, train_dataloader)
        history["train"].append(epoch_train)

        epoch_val = validation_loop(model, loss_fn, val_dataloader)
        history["val"].append(epoch_val)

        print(f"Training loss: {epoch_train:.4f}")
        print(f"Validation loss: {epoch_val:.4f}")
        print()

    return history["train"], history["val"]


# A single epoch over the pre-built batches.
train_loss_list, validation_loss_list = fit(
    model, opt, loss_fn,
    train_dataloader=train_batch,
    val_dataloader=val_batch,
    epochs=1,
)

Training and validating model
------------------------- Epoch 1 -------------------------
Validation loss: 1.4357



In [None]:
def predict(model, input_sequence, max_length=150, SOS_token=2, EOS_token=3, PAD_token=1):
    """Greedily decode a target id sequence for a single source sequence.

    Args:
        model: seq2seq model with ``get_tgt_mask`` whose output is laid out
            ``(tgt_len, batch, vocab)``.
        input_sequence: ``(1, src_len)`` LongTensor of source ids; the decoder
            input is created on the same device.
        max_length: maximum number of tokens to generate.
        SOS_token: id fed as the first decoder input.
        EOS_token: id that stops generation.
        PAD_token: padding id; generation also stops if it is produced.

    Returns:
        list[int]: ``[SOS_token, generated ids...]``. The last id is the stop
        token when one was produced; otherwise ``max_length`` ids were generated.
    """
    model.eval()
    device = input_sequence.device

    y_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)

    # Inference only: without no_grad an autograd graph is kept for every step.
    with torch.no_grad():
        for _ in range(max_length):
            tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)
            pred = model(input_sequence, y_input, tgt_mask)

            # Highest-scoring id at the last decoder position of the single batch row.
            next_id = pred[-1, 0].argmax().item()
            y_input = torch.cat((y_input, torch.tensor([[next_id]], device=device)), dim=1)

            if next_id in (EOS_token, PAD_token):
                break

    return y_input.view(-1).tolist()


# Here we test some examples to observe how the model predicts.
# Each example is a (1, seq_len) tensor of source-vocabulary ids wrapped in 2 ... 3 (<SOS> ... <EOS>).
examples = [
    torch.tensor([[2, 20, 5, 21, 12, 5, 5, 7, 15, 7, 10, 5, 8, 7, 11, 7, 9, 4, 8, 11, 12, 5, 11, 4, 15, 4, 13, 14, 15, 4, 13, 21, 21, 21, 11, 7, 7, 15, 11, 7, 7, 13, 8, 10, 18, 4, 18, 18, 8, 4, 4, 4, 12, 18, 18, 18, 18, 18, 12, 18, 10, 18, 4, 4, 12, 6, 8, 16, 18, 10, 18, 20, 8, 17, 4, 14, 13, 18, 20, 18, 6, 18, 4, 18, 8, 20, 12, 10, 8, 4, 4, 6, 12, 10, 18, 18, 18, 8, 4, 4, 8, 10, 8, 18, 10, 4, 8, 18, 18, 13, 18, 8, 18, 8, 7, 8, 13, 20, 13, 13, 8, 18, 18, 4, 11, 11, 4, 13, 9, 10, 15, 13, 9, 13, 8, 13, 6, 7, 6, 5, 14, 8, 7, 10, 18, 10, 4, 18, 8, 16, 4, 4, 5, 10, 5, 6, 14, 10, 15, 14, 11, 14, 17, 9, 10, 17, 20, 5, 7, 5, 13, 20, 11, 10, 4, 23, 19, 14, 6, 6, 20, 20, 14, 5, 4, 15, 18, 5, 5, 11, 11, 4, 5, 9, 14, 5, 11, 5, 19, 10, 19, 14, 4, 11, 9, 6, 18, 15, 6, 10, 15, 15, 16, 11, 4, 13, 10, 14, 6, 5, 8, 11, 17, 4, 10, 7, 13, 5, 13, 4, 10, 18, 10, 7, 6, 8, 13, 13, 5, 5, 11, 4, 4, 13, 13, 10, 15, 9, 17, 7, 7, 14, 5, 16, 10, 10, 13, 21, 16, 8, 7, 14, 8, 5, 5, 7, 5, 5, 5, 5, 11, 9, 5, 9, 11, 5, 5, 11, 17, 17, 9, 3]], dtype=torch.long, device=device)
]

for idx, example in enumerate(examples):
    result = predict(model, example)
    # Always drop the leading <SOS>; drop the final id only when it is a stop token
    # (3 = <EOS>, 1 = <PAD>). If decoding hit max_length the last id is a real prediction.
    continuation = result[1:-1] if result[-1] in (3, 1) else result[1:]
    print(f"Example {idx}")
    print(f"Input: {example.view(-1).tolist()[1:-1]}")
    print(f"Continuation: {continuation}")
    print()

Example 0
Input: [20, 5, 21, 12, 5, 5, 7, 15, 7, 10, 5, 8, 7, 11, 7, 9, 4, 8, 11, 12, 5, 11, 4, 15, 4, 13, 14, 15, 4, 13, 21, 21, 21, 11, 7, 7, 15, 11, 7, 7, 13, 8, 10, 18, 4, 18, 18, 8, 4, 4, 4, 12, 18, 18, 18, 18, 18, 12, 18, 10, 18, 4, 4, 12, 6, 8, 16, 18, 10, 18, 20, 8, 17, 4, 14, 13, 18, 20, 18, 6, 18, 4, 18, 8, 20, 12, 10, 8, 4, 4, 6, 12, 10, 18, 18, 18, 8, 4, 4, 8, 10, 8, 18, 10, 4, 8, 18, 18, 13, 18, 8, 18, 8, 7, 8, 13, 20, 13, 13, 8, 18, 18, 4, 11, 11, 4, 13, 9, 10, 15, 13, 9, 13, 8, 13, 6, 7, 6, 5, 14, 8, 7, 10, 18, 10, 4, 18, 8, 16, 4, 4, 5, 10, 5, 6, 14, 10, 15, 14, 11, 14, 17, 9, 10, 17, 20, 5, 7, 5, 13, 20, 11, 10, 4, 23, 19, 14, 6, 6, 20, 20, 14, 5, 4, 15, 18, 5, 5, 11, 11, 4, 5, 9, 14, 5, 11, 5, 19, 10, 19, 14, 4, 11, 9, 6, 18, 15, 6, 10, 15, 15, 16, 11, 4, 13, 10, 14, 6, 5, 8, 11, 17, 4, 10, 7, 13, 5, 13, 4, 10, 18, 10, 7, 6, 8, 13, 13, 5, 5, 11, 4, 4, 13, 13, 10, 15, 9, 17, 7, 7, 14, 5, 16, 10, 10, 13, 21, 16, 8, 7, 14, 8, 5, 5, 7, 5, 5, 5, 5, 11, 9, 5, 9, 11, 5, 5, 1

In [None]:
def translate(input, vocab=None):
    """Decode a sequence of SMILES-vocabulary ids into a string.

    Args:
        input: iterable of integer ids. (The name shadows the builtin but is kept
            so keyword callers keep working.)
        vocab: object exposing ``get_itos()``; defaults to the global ``smile_voc``.

    Returns:
        str: the token of each id, concatenated with no separator.
    """
    # Fetch the id -> token table once instead of on every iteration.
    itos = (smile_voc if vocab is None else vocab).get_itos()
    return ''.join(itos[token_id] for token_id in input)

# Decode a hard-coded SMILES id sequence (presumably copied from an earlier predict() result — verify) into a SMILES string.
translate([5, 5, 6, 5, 7, 6, 5, 7, 4, 8, 4, 4, 4, 6, 11, 4, 10, 12, 4, 4, 6, 11, 4, 14, 4, 4, 4, 6, 5, 22, 7, 4, 6, 5, 22, 7, 4, 4, 14, 7, 12, 4, 10, 7, 4, 4, 8])

'CC(C)(C)c1ccc(Nc2ncc(Nc3ccc(Cl)c(Cl)cc3)nc2)cc1'