Run this notebook on Google colab to be able to easily use GPUs to train the transformer model.

In [1]:
import os
import re
import shutil
import random

import math
import pickle
import joblib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch import optim
import torch.nn.functional as F
from torchvision import models, transforms, datasets, utils
# from tqdm import tqdm
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score
import itertools
import copy

pd.set_option('display.max_columns', 1000)

%load_ext autoreload
%autoreload 2

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2

PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('Using gpu: %s ' % torch.cuda.is_available())

Using gpu: True 


In [5]:
!nvidia-smi

Fri May 31 19:13:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8              10W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [6]:
with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/pairs_train_tensors.pk', 'rb') as file:
    pairs_train_tensors = pickle.load(file)

with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/pairs_test_tensors.pk', 'rb') as file:
    pairs_test_tensors = pickle.load(file)

with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/pairs_val_tensors.pk', 'rb') as file:
    pairs_val_tensors = pickle.load(file)

In [7]:
with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/symbole_to_index_encoding.pk', 'rb') as file:
    symb2index = pickle.load(file)

with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/index_to_symbol_decoding.pk', 'rb') as file:
    index2symb = pickle.load(file)

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value

        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
        self.W_o = nn.Linear(d_model, d_model) # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [9]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [10]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [11]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [12]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [13]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model).to(device)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model).to(device)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length).to(device)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]).to(device)
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]).to(device)

        self.fc = nn.Linear(d_model, tgt_vocab_size).to(device)
        self.dropout = nn.Dropout(dropout).to(device)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool().to(device)
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [14]:
def evaluate_symb(transformer, pair, max_length=136):
    transformer.eval()
    with torch.no_grad():
        output_sequence = transformer(pair[0].view(1,-1), pair[1].view(1,-1))

        decoded_atom = []

        for di in range(max_length):
            topv, topi = torch.topk(output_sequence[:,di,:], 1, 1)

            if topi.item() == EOS_IDX:
                break
            elif topi.item() in [PAD_IDX,SOS_IDX]:
                continue
            else:
                decoded_atom.append(index2symb[topi.item()])

            decoder_input = topi.squeeze().detach()

        return ''.join(decoded_atom)

In [15]:
def evaluate_idx(transformer, pair, max_length=224):
    transformer.eval()
    with torch.no_grad():
        output = transformer(pair[0].view(1,-1), pair[1].view(1,-1))

        output_sequence = []

        for di in range(max_length):
            topv, topi = torch.topk(output[:,di,:], 1, 1)

            if topi.item() == EOS_IDX:
                output_sequence.append(topi.item())
                break
            else:
                output_sequence.append(topi.item())

        while len(output_sequence) < max_length:
            output_sequence.append(PAD_IDX)

        return output_sequence

In [16]:
def evaluate_idx_train(model_output, max_length=224):
    transformer.eval()
    with torch.no_grad():
        output_sequence = []
        for di in range(max_length):
            topv, topi = torch.topk(model_output[di,:], 1, 0)
            if topi.item() == EOS_IDX:
                output_sequence.append(topi.item())
                break
            else:
                output_sequence.append(topi.item())

        while len(output_sequence) < max_length:
            output_sequence.append(PAD_IDX)

        return output_sequence

In [17]:
def evaluateRandomly(transformer, n=10):
    for i in range(n):
        pair = random.choice(test_batch)
        print('Amino acid target:', ''.join([index2symb[i.item()] for i in pair[0] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))
        print('SMILE ligand:', ''.join([index2symb[i.item()] for i in pair[1] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))

        decoded_output = evaluate_symb(transformer, pair)
        print('Pred SMILE ligand:', decoded_output)

In [18]:
def compute_test_metrics(transformer, data_loader_test, batch_size, out_vocab_size, criterion, sample_nb_natch):
    transformer.eval()
    with torch.no_grad():
        loss = 0
        total_batches = len(data_loader_test)
        selected_batch_indices = random.sample(range(total_batches), sample_nb_natch)
        accuracy_test = []
        f1score_test = []
        for i, data in enumerate(data_loader_test):
            if i in selected_batch_indices:
              in_data, out_data = data

              output = transformer(in_data.view(batch_size,-1), out_data.view(batch_size,-1))
              loss = criterion(output.contiguous().view(-1, out_vocab_size), out_data.contiguous().view(-1))

              for j in range(batch_size):
                  accuracy_test.append(accuracy_score(evaluate_idx_train(output[j,:,:], max_length=224),
                                                       out_data[j,:].cpu().detach().numpy()))
                  f1score_test.append(f1_score(evaluate_idx_train(output[j,:], max_length=224),
                                            out_data[j,:,:].cpu().detach().numpy(), average='weighted'))

        accuracy = np.mean(accuracy_test)
        f1score = np.mean(f1score_test)

    return loss, accuracy, f1score

### Pack into batches

In [21]:
len(pairs_train_tensors)

337459

In [22]:
len(pairs_test_tensors)

52514

In [23]:
len(pairs_val_tensors)

72065

In [25]:
# Sample the data sets to reduce the training time
pairs_train_tensors_sample = random.sample(pairs_train_tensors, int(len(pairs_train_tensors) / 2))
pairs_test_tensors_sample = random.sample(pairs_test_tensors, int(len(pairs_test_tensors) / 2))
pairs_val_tensors_sample = random.sample(pairs_val_tensors, int(len(pairs_val_tensors) / 2))

In [26]:
pairs_train_tensors_input = [pair[0].to(device) for pair in pairs_train_tensors_sample]
pairs_train_tensors_output = [pair[1].to(device) for pair in pairs_train_tensors_sample]

pairs_test_tensors_input = [pair[0].to(device) for pair in pairs_test_tensors_sample]
pairs_test_tensors_output = [pair[1].to(device) for pair in pairs_test_tensors_sample]

pairs_val_tensors_input = [pair[0].to(device) for pair in pairs_val_tensors_sample]
pairs_val_tensors_output = [pair[1].to(device) for pair in pairs_val_tensors_sample]

In [27]:
len(pairs_train_tensors_input)

168729

In [28]:
input_batch = pad_sequence(pairs_train_tensors_input + pairs_test_tensors_input + pairs_val_tensors_input,
                                 padding_value=0,
                                 batch_first=True)
output_batch = pad_sequence(pairs_train_tensors_output + pairs_test_tensors_output + pairs_val_tensors_output,
                            padding_value=0,
                            batch_first=True)

In [29]:
input_train_batch = input_batch[:len(pairs_train_tensors_input)]
input_test_batch = input_batch[len(pairs_train_tensors_input):len(pairs_train_tensors_input)+len(pairs_test_tensors_input)]
input_val_batch = input_batch[len(pairs_train_tensors_input)+len(pairs_test_tensors_input):]

output_train_batch = output_batch[:len(pairs_train_tensors_output)]
output_test_batch = output_batch[len(pairs_train_tensors_output):len(pairs_train_tensors_input)+len(pairs_test_tensors_output)]
output_val_batch = output_batch[len(pairs_train_tensors_input)+len(pairs_test_tensors_output):]

In [30]:
input_train_batch.size()

torch.Size([168729, 2027, 1])

In [31]:
output_train_batch.size()

torch.Size([168729, 224, 1])

In [32]:
input_test_batch.size()

torch.Size([26257, 2027, 1])

In [33]:
train_batch = [[input_train_batch[index].to(device), output_train_batch[index].to(device)] for index in range(len(input_train_batch))]
test_batch = [[input_test_batch[index].to(device), output_test_batch[index].to(device)] for index in range(len(input_test_batch))]
val_batch = [[input_val_batch[index].to(device), output_val_batch[index].to(device)] for index in range(len(input_val_batch))]

In [34]:
len(test_batch)

26257

In [35]:
len(random.choice(test_batch)[0])

2027

In [36]:
max_seq_length = len(test_batch[0][0])

In [43]:
train_dataloader = DataLoader(train_batch, batch_size=16, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_batch, batch_size=16, shuffle=True, drop_last=True)

### Training

In [37]:
!pip install wandb -qU

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m71.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.0/289.0 kB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [38]:
# Log in to your W&B account
import wandb

In [39]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [40]:
in_vocab_size = len(symb2index)
out_vocab_size = len(symb2index)
d_model = 24
num_heads = 2
num_layers = 2
d_ff = 48
max_seq_length = max_seq_length
dropout = 0.1
batch_size = 16

transformer = Transformer(in_vocab_size, out_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

In [44]:
in_data, out_data = next(iter(train_dataloader))
output = transformer(in_data.view(batch_size,-1), out_data.view(batch_size,-1))

In [45]:
in_data.size()

torch.Size([16, 2027, 1])

In [46]:
in_data.view(batch_size,-1).size()

torch.Size([16, 2027])

In [47]:
len(train_dataloader)

10545

In [48]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

wandb.init(
    project="Protein-Specific-Drug-Generation",
    name=f"transformer_1",
    config={
        "learning_rate": 0.0001,
        "batch_size": batch_size,
        "architecture": "transformer",
        "epochs": 2,
    })

transformer.train()

for epoch in range(2):
    target_train_list = []
    output_train_list = []
    for i, data in tqdm(enumerate(train_dataloader)):
        in_data, out_data = data
        optimizer.zero_grad()
        output = transformer(in_data.view(batch_size,-1), out_data.view(batch_size,-1))
        loss = criterion(output.contiguous().view(-1, out_vocab_size), out_data.contiguous().view(-1))
        loss.backward()
        optimizer.step()

        # Metrics computation
        if i % 20 == 0:
          accuracy_train = []
          f1score_train = []
          for i in range(batch_size):
            accuracy_train.append(accuracy_score(evaluate_idx_train(output[i,:], max_length=224),
                                                 out_data[i,:,:].cpu().detach().numpy()))
            f1score_train.append(f1_score(evaluate_idx_train(output[i,:], max_length=224),
                                          out_data[i,:,:].cpu().detach().numpy(), average='weighted'))

          loss_test, accuracy_test, f1_score_test = compute_test_metrics(transformer, test_dataloader, batch_size, out_vocab_size, criterion, sample_nb_natch=1)

          wandb.log({"acc_train": np.mean(accuracy_train), "loss_train": loss,
                    "f1_score_train": np.mean(f1score_train),
                    "acc_test": accuracy_test, "loss_test": loss_test,
                    "f1_score_test": f1_score_test})

    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

[34m[1mwandb[0m: Currently logged in as: [33moriane-cavrois[0m. Use [1m`wandb login --relogin`[0m to force relogin


0it [00:00, ?it/s]

Epoch: 1, Loss: 2.195278284489177e-06


0it [00:00, ?it/s]

Epoch: 2, Loss: 1.1915673212570255e-06


In [49]:
evaluateRandomly(transformer)

Amino acid target: MGRRPQLRLVKALLLLGLNPVSTSLQDQRCENLSLTSNVSGLQCNASVDLIGTCWPRSPAGQLVVRPCPAFFYGVRYNTTNNGYRECLANGSWAARVNYSECQEILNEEKKSKVHYHVAVIINYLGHCISLVALLVAFVLFLRLRSIRCLRNIIHWNLISAFILRNATWFVVQLTVSPEVHQSNVAWCRLVTAAYNYFHVTNFFWMFGEGCYLHTAIVLTYSTDRLRKWMFVCIGWGVPFPIIVAWAIGKLHYDNEKCWFGKRPGVYTDYIYQGPMILVLLINFIFLFNIVRILMTKLRASTTSETIQYRKAVKATLVLLPLLGITYMLFFVNPGEDEVSRVVFIYFNSFLESFQGFFVSVFYCFLNSEVRSAIRKRWRRWQDKHSIRARVARAMSIPTSPTRVSFHSIKQSTAV
SMILE ligand: COC[C@H](C1CC1)n1cc(Cl)nc(Nc2cc(C#N)c(OC(F)F)nc2C)c1=O
Pred SMILE ligand: COC[C@H](C1CC1)n1cc(Cl)nc(Nc2cc(C#N)c(OC(F)F)nc2C)c1=O
Amino acid target: MSGTKLEDSPPCRNWSSASELNETQEPFLNPTDYDDEEFLRYLWREYLHPKEYEWVLIAGYIIVFVVALIGNVLVCVAVWKNHHMRTVTNYFIVNLSLADVLVTITCLPATLVVDITETWFFGQSLCKVIPYLQTVSVSVSVLTLSCIALDRWYAICHPLMFKSTAKRARNSIVIIWIVSCIIMIPQAIVMECSTVFPGLANKTTLFTVCDERWGGEIYPKMYHICFFLVTYMAPLCLMVLAYLQIFRKLWCRQIPGTSSVVQRKWKPLQPVSQPRGPGQPTKSRMSAVAAEIKQIRARRKTARMLMIVLLVFAICYLPISILNVLKRVFGMFAHTEDRETVYAWFTFSHWLVYANSAANPIIYNFLSGKFREEFKAAFSCCCLGVHHRQEDRLTRGRTSTES

In [50]:
torch.save(transformer, 'drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/transformer_init1')

In [52]:
model_pkl_file = "drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/transformer_init1_pk.pkl"

with open(model_pkl_file, 'wb') as file:
    pickle.dump(transformer, file)

In [61]:
model_pkl_file = "drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/transformer_init1_pk.pkl"

# Loading the transformer model from the pickle file
with open(model_pkl_file, 'rb') as file:
    transformer_init = pickle.load(file)

In [62]:
in_vocab_size = len(symb2index)
out_vocab_size = len(symb2index)
d_model = 24
num_heads = 2
num_layers = 2
d_ff = 48
max_seq_length = max_seq_length
dropout = 0.1
batch_size = 16

criterion = nn.CrossEntropyLoss(ignore_index=0)

In [63]:
len(test_dataloader)

1641

In [64]:
loss_test, accuracy_test, f1score_test = compute_test_metrics(transformer_init, test_dataloader, 16, out_vocab_size, criterion, len(test_dataloader))

In [65]:
accuracy_test

0.9999991498541829

In [66]:
f1score_test

0.9999991498541829

### Infer the sequences for the test set

In [67]:
transformer_init.eval()
with torch.no_grad():
    input_sequences = []
    targets_sequences = []
    outputs_sequences = []
    total_batches = len(test_dataloader)
    for i, data in tqdm(enumerate(test_dataloader)):
        in_data, out_data = data
        for i in range(in_data.size()[0]):
          input_sequences.append(''.join([index2symb[i.item()] for i in in_data[i,:,:] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))

          decoded_output = evaluate_symb(transformer_init, (in_data[i,:,:], out_data[i,:,:]))

          targets_sequences.append(''.join([index2symb[i.item()] for i in out_data[i,:,:] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))
          outputs_sequences.append(decoded_output)

0it [00:00, ?it/s]

In [68]:
targets_sequences[0]

'C[C@H](NC(=O)c1ccc(N2CC(CC#N)(C2)n2cc(cn2)-c2ncnc3[nH]ccc23)c(F)c1)C1CC1'

In [69]:
outputs_sequences[0]

'C[C@H](NC(=O)c1ccc(N2CC(CC#N)(C2)n2cc(cn2)-c2ncnc3[nH]ccc23)c(F)c1)C1CC1'

In [70]:
transformer_init.eval()
with torch.no_grad():
    targets_sequences_index = []
    outputs_sequences_index = []
    accuracy_list = []
    accuracy_mean_list = []
    total_batches = len(test_dataloader)
    for i, data in tqdm(enumerate(test_dataloader)):
        in_data, out_data = data
        for i in range(in_data.size()[0]):
          topv, topi = torch.topk(output[i,:,:], 1)
          targets_sequences_index.append(out_data[i,:,:].cpu().detach().numpy())
          outputs_sequences_index.append(evaluate_idx(transformer_init, (in_data[i,:,:], out_data[i,:,:])))
          accuracy_list.append(accuracy_score(evaluate_idx(transformer_init, (in_data[i,:,:], out_data[i,:,:])),
                                              out_data[i,:,:].cpu().detach().numpy()))

0it [00:00, ?it/s]

In [73]:
targets_sequences_index[0]

array([[ 1],
       [20],
       [30],
       [20],
       [36],
       [32],
       [24],
       [25],
       [13],
       [20],
       [25],
       [27],
       [28],
       [26],
       [13],
       [25],
       [20],
       [24],
       [27],
       [28],
       [26],
       [23],
       [24],
       [23],
       [23],
       [23],
       [25],
       [20],
       [42],
       [13],
       [26],
       [23],
       [25],
       [23],
       [24],
       [26],
       [20],
       [25],
       [ 5],
       [26],
       [25],
       [ 5],
       [26],
       [ 5],
       [26],
       [23],
       [24],
       [23],
       [23],
       [23],
       [25],
       [28],
       [26],
       [23],
       [23],
       [24],
       [ 2],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],
       [ 0],

In [74]:
outputs_sequences_index[0]

[1,
 20,
 30,
 20,
 36,
 32,
 24,
 25,
 13,
 20,
 25,
 27,
 28,
 26,
 13,
 25,
 20,
 24,
 27,
 28,
 26,
 23,
 24,
 23,
 23,
 23,
 25,
 20,
 42,
 13,
 26,
 23,
 25,
 23,
 24,
 26,
 20,
 25,
 5,
 26,
 25,
 5,
 26,
 5,
 26,
 23,
 24,
 23,
 23,
 23,
 25,
 28,
 26,
 23,
 23,
 24,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]