In [1]:
import os
import re
import shutil
import random

import math
import pickle
import joblib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch import optim
import torch.nn.functional as F
from torchvision import models, transforms, datasets, utils
from tqdm import tqdm
from tqdm import tqdm_notebook
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score
import itertools
import copy

from transformers import BertModel, BertTokenizer
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
model = BertModel.from_pretrained("Rostlab/prot_bert")

pd.set_option('display.max_columns', 1000)

%load_ext autoreload
%autoreload 2

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/361 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

In [2]:
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

In [None]:
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2

PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('Using gpu: %s ' % torch.cuda.is_available())

In [None]:
with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/pairs_train_tensors.pk', 'rb') as file:
    pairs_train_tensors = pickle.load(file)

with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/pairs_test_tensors.pk', 'rb') as file:
    pairs_test_tensors = pickle.load(file)

with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/pairs_val_tensors.pk', 'rb') as file:
    pairs_val_tensors = pickle.load(file)

In [None]:
with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/symbole_to_index_encoding.pk', 'rb') as file:
    symb2index = pickle.load(file)

with open('drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/index_to_symbol_decoding.pk', 'rb') as file:
    index2symb = pickle.load(file)

### Evaluation functions

In [None]:
def evaluate(model, pair, max_length=136):
    model.eval()
    with torch.no_grad():
        output_sequence = model(pair[0].view(1,-1), pair[1].view(1,-1))

        decoded_atom = []

        for di in range(max_length):
            topv, topi = torch.topk(output_sequence[:,di,:], 1, 1)

            if topi.item() == EOS_IDX:
                break
            elif topi.item() in [PAD_IDX,SOS_IDX]:
                continue
            else:
                decoded_atom.append(index2symb[topi.item()])

        return ''.join(decoded_atom)

In [None]:
def evaluate_idx_train(model, model_output, max_length=224):
    model.eval()
    with torch.no_grad():
        output_sequence = []
        for di in range(max_length):
            topv, topi = torch.topk(model_output[di,:], 1, 0)
            if topi.item() == EOS_IDX:
                output_sequence.append(topi.item())
                break
            else:
                output_sequence.append(topi.item())

        while len(output_sequence) < max_length:
            output_sequence.append(PAD_IDX)

        return output_sequence

In [None]:
def evaluateRandomly(model, n=10):
    for i in range(n):
        pair = random.choice(test_batch)
        print('Amino acid target:', ''.join([index2symb[i.item()] for i in pair[0] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))
        print('SMILE ligand:', ''.join([index2symb[i.item()] for i in pair[1] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))

        decoded_output = evaluate(model, pair)
        print('Pred SMILE ligand:', decoded_output)

In [None]:
def compute_test_metrics(model, data_loader_test, batch_size, out_vocab_size, criterion, sample_nb_natch):
    model.eval()
    with torch.no_grad():
        loss = 0
        total_batches = len(data_loader_test)
        selected_batch_indices = random.sample(range(total_batches), sample_nb_natch)
        accuracy_test = []
        f1score_test = []
        for i, data in enumerate(data_loader_test):
            if i in selected_batch_indices:
              in_data, out_data = data

              output = model(in_data.view(batch_size,-1), out_data.view(batch_size,-1))
              loss = criterion(output.contiguous().view(-1, out_vocab_size), out_data.contiguous().view(-1))

              for j in range(batch_size):
                  accuracy_test.append(accuracy_score(evaluate_idx_train(model, output[j,:,:], max_length=224),
                                                       out_data[j,:].cpu().detach().numpy()))
                  f1score_test.append(f1_score(evaluate_idx_train(model, output[j,:], max_length=224),
                                            out_data[j,:,:].cpu().detach().numpy(), average='weighted'))

        accuracy = np.mean(accuracy_test)
        f1score = np.mean(f1score_test)

    return loss, accuracy, f1score

In [None]:
def evaluate_symb(transformer, pair, max_length=136):
    transformer.eval()
    with torch.no_grad():
        output_sequence = transformer(pair[0].view(1,-1), pair[1].view(1,-1))

        decoded_atom = []

        for di in range(max_length):
            topv, topi = torch.topk(output_sequence[:,di,:], 1, 1)

            if topi.item() == EOS_IDX:
                break
            elif topi.item() in [PAD_IDX,SOS_IDX]:
                continue
            else:
                decoded_atom.append(index2symb[topi.item()])

            decoder_input = topi.squeeze().detach()

        return ''.join(decoded_atom)

### Pack into batches

In [None]:
len(pairs_train_tensors)

In [None]:
len(pairs_test_tensors)

In [None]:
len(pairs_val_tensors)

In [None]:
pairs_train_tensors_input = [pair[0] for pair in pairs_train_tensors[:len(pairs_train_tensors)//3]]
pairs_train_tensors_output = [pair[1] for pair in pairs_train_tensors[:len(pairs_train_tensors)//3]]

pairs_test_tensors_input = [pair[0] for pair in pairs_test_tensors[:len(pairs_test_tensors)//3]]
pairs_test_tensors_output = [pair[1] for pair in pairs_test_tensors[:len(pairs_test_tensors)//3]]

pairs_val_tensors_input = [pair[0] for pair in pairs_val_tensors[:len(pairs_val_tensors)//3]]
pairs_val_tensors_output = [pair[1] for pair in pairs_val_tensors[:len(pairs_val_tensors)//3]]

In [None]:
len(pairs_train_tensors_input)

In [None]:
input_batch = pad_sequence(pairs_train_tensors_input + pairs_test_tensors_input + pairs_val_tensors_input,
                                 padding_value=0,
                                 batch_first=True)
output_batch = pad_sequence(pairs_train_tensors_output + pairs_test_tensors_output + pairs_val_tensors_output,
                            padding_value=0,
                            batch_first=True)

In [None]:
input_train_batch = input_batch[:len(pairs_train_tensors_input)]
input_test_batch = input_batch[len(pairs_train_tensors_input):len(pairs_train_tensors_input)+len(pairs_test_tensors_input)]
input_val_batch = input_batch[len(pairs_train_tensors_input)+len(pairs_test_tensors_input):]

output_train_batch = output_batch[:len(pairs_train_tensors_output)]
output_test_batch = output_batch[len(pairs_train_tensors_output):len(pairs_train_tensors_input)+len(pairs_test_tensors_output)]
output_val_batch = output_batch[len(pairs_train_tensors_input)+len(pairs_test_tensors_output):]

In [None]:
input_train_batch.size()

In [None]:
output_train_batch.size()

In [None]:
input_test_batch.size()

In [None]:
train_batch = [[input_train_batch[index].to(device), output_train_batch[index].to(device)] for index in range(len(input_train_batch))]
test_batch = [[input_test_batch[index].to(device), output_test_batch[index].to(device)] for index in range(len(input_test_batch))]
val_batch = [[input_val_batch[index].to(device), output_val_batch[index].to(device)] for index in range(len(input_val_batch))]

In [None]:
len(test_batch)

In [None]:
len(random.choice(test_batch)[0])

In [None]:
max_seq_length = len(test_batch[0][0])

In [None]:
train_dataloader = DataLoader(train_batch, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_batch, batch_size=2, shuffle=True)

## Model

### Encoder layer: Prot Bert

In [None]:
prot_bert = BertModel.from_pretrained("Rostlab/prot_bert")

In [None]:
prot_bert

In [None]:
next(iter(train_dataloader))[0]

In [None]:
next(iter(train_dataloader))[0].size()

In [None]:
next(iter(train_dataloader))[0].size()

In [None]:
input_train_batch[0,:,:].size()

In [None]:
output = prot_bert(input_train_batch[0,:,:])
output

In [None]:
output.keys()

In [None]:
output['pooler_output']

In [None]:
output['pooler_output'].size()

### Decoder layer

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value

        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
        self.W_o = nn.Linear(d_model, d_model) # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        # print('size x:',x.size())
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        # print('size Q:',Q.size())
        # print('size W_q(Q):',self.W_q(Q).size())
        Q = self.split_heads(self.W_q(Q))
        # print('size K:',K.size())
        # print('size W_k(K):',self.W_k(K).size())
        K = self.split_heads(self.W_k(K))
        # print('size V:',V.size())
        # print('size W_v(V):',self.W_v(V).size())
        V = self.split_heads(self.W_v(V))

        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [None]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        # print('First MultiHeadAttention call')
        # print('x size:', x.size())
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # print('Second MultiHeadAttention call')
        # print('enc_output size:', enc_output.size())
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [None]:
class ProtBertPersonalizedModel(nn.Module):
    def __init__(self, encoder_model, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(ProtBertPersonalizedModel, self).__init__()
        self.encoder_model = encoder_model.to(device)
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model).to(device)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model).to(device)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length).to(device)

        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]).to(device)

        self.fc = nn.Linear(d_model, tgt_vocab_size).to(device)
        self.dropout = nn.Dropout(dropout).to(device)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool().to(device)
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.encoder_model(src)['last_hidden_state']
        # print('ouput embedded',src_embedded.size())
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, src_embedded, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

### Training

In [None]:
# Fix prot bert weights

for param in prot_bert.parameters():
    param.requires_grad = False

In [None]:
in_vocab_size = len(symb2index)
out_vocab_size = len(symb2index)
d_model = 1024 # BERT dimension
num_heads = 2
num_layers = 2
batch_size = 2
d_ff = 48
max_seq_length = max_seq_length
dropout = 0.1

fine_tuned_model = ProtBertPersonalizedModel(prot_bert, in_vocab_size, out_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

In [None]:
!pip install wandb -qU
# Log in to your W&B account
import wandb

wandb.login()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m27.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.1/296.1 kB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25h

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(fine_tuned_model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

wandb.init(
    project="Protein-Specific-Drug-Generation",
    name=f"prot_bert_fine_tune_1",
    config={
        "learning_rate": 0.0001,
        "batch_size": batch_size,
        "architecture": "prot_bert_fine_tuned",
        "epochs": 2,
    })

fine_tuned_model.train()

for epoch in range(2):
    target_train_list = []
    output_train_list = []
    for i, data in tqdm(enumerate(train_dataloader)):
        in_data, out_data = data
        optimizer.zero_grad()
        output = fine_tuned_model(in_data.view(batch_size,-1), out_data.view(batch_size,-1))
        loss = criterion(output.contiguous().view(-1, out_vocab_size), out_data.contiguous().view(-1))
        loss.backward()
        optimizer.step()

        # Metrics computation
        if i % 20 == 0:
          accuracy_train = []
          f1score_train = []
          for i in range(batch_size):
            accuracy_train.append(accuracy_score(evaluate_idx_train(fine_tuned_model, output[i,:], max_length=224),
                                                 out_data[i,:,:].cpu().detach().numpy()))
            f1score_train.append(f1_score(evaluate_idx_train(fine_tuned_model, output[i,:], max_length=224),
                                          out_data[i,:,:].cpu().detach().numpy(), average='weighted'))

          loss_test, accuracy_test, f1_score_test = compute_test_metrics(fine_tuned_model, test_dataloader, batch_size, out_vocab_size, criterion, sample_nb_natch=1)

          wandb.log({"acc_train": np.mean(accuracy_train), "loss_train": loss,
                    "f1_score_train": np.mean(f1score_train),
                    "acc_test": accuracy_test, "loss_test": loss_test,
                    "f1_score_test": f1_score_test})

    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

[34m[1mwandb[0m: Currently logged in as: [33moriane-cavrois[0m. Use [1m`wandb login --relogin`[0m to force relogin


56243it [10:06:27,  1.55it/s]


Epoch: 1, Loss: 0.0


56243it [10:07:45,  1.54it/s]


Epoch: 2, Loss: 0.0


In [None]:
# criterion = nn.CrossEntropyLoss(ignore_index=0)
# optimizer = optim.Adam(fine_tuned_model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# batch_size = 2

# fine_tuned_model.train()

# for epoch in range(2):
#     for data in tqdm(train_dataloader):
#         in_data, out_data = data
#         optimizer.zero_grad()
#         output = fine_tuned_model(in_data.view(batch_size,-1), out_data.view(batch_size,-1))
#         loss = criterion(output.contiguous().view(-1, out_vocab_size), out_data.contiguous().view(-1))
#         loss.backward()
#         optimizer.step()
#     print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

In [None]:
evaluateRandomly(fine_tuned_model)

Amino acid target: MSGTKLEDSPPCRNWSSASELNETQEPFLNPTDYDDEEFLRYLWREYLHPKEYEWVLIAGYIIVFVVALIGNVLVCVAVWKNHHMRTVTNYFIVNLSLADVLVTITCLPATLVVDITETWFFGQSLCKVIPYLQTVSVSVSVLTLSCIALDRWYAICHPLMFKSTAKRARNSIVIIWIVSCIIMIPQAIVMECSTVFPGLANKTTLFTVCDERWGGEIYPKMYHICFFLVTYMAPLCLMVLAYLQIFRKLWCRQIPGTSSVVQRKWKPLQPVSQPRGPGQPTKSRMSAVAAEIKQIRARRKTARMLMIVLLVFAICYLPISILNVLKRVFGMFAHTEDRETVYAWFTFSHWLVYANSAANPIIYNFLSGKFREEFKAAFSCCCLGVHHRQEDRLTRGRTSTESRKSLTTQISNFDNISKLSEQVVLTSISTLPAANGAGPLQNW
SMILE ligand: CC(C)(O)C(=O)N1CCC[C@H](NS(C)(=O)=O)[C@@H]1CO[C@H]1CC[C@H](CC1)c1ccccc1
Pred SMILE ligand: CC(C)(O)C(=O)N1CCC[C@H](NS(C)(=O)=O)[C@@H]1CO[C@H]1CC[C@H](CC1)c1ccccc1
Amino acid target: MEVQLGLGRVYPRPPSKTYRGAFQNLFQSVREVIQNPGPRHPEAASAAPPGASLLLLQQQQQQQQQQQQQQQQQQQQQQQETSPRQQQQQQGEDGSPQAHRRGPTGYLVLDEEQQPSQPQSALECHPERGCVPEPGAAVAASKGLPQQLPAPPDEDDSAAPSTLSLLGPTFPGLSSCSADLKDILSEASTMQLLQQQQQEAVSEGSSSGRAREASGAPTSSKDNYLGGTSTISDNAKELCKAVSVSMGLGVEALEHLSPGEQLRGDCMYAPLLGVPPAVRPTPCAPLAECKGSLLDDSAGKSTEDTAEYSPFKGGYTKGLEGESLGCSGSAAAGSSGTLE

In [None]:
torch.save(fine_tuned_model, 'drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/prot_bert_fine_tune1')

In [None]:
model_pkl_file = "drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/prot_bert_fine_tune1.pkl"

with open(model_pkl_file, 'wb') as file:
    pickle.dump(fine_tuned_model, file)

## Load and test the final model

In [None]:
model_pkl_file = "drive/MyDrive/Colab Notebooks/Protein-Specific-Drug-Generation/prot_bert_fine_tune1.pkl"

# Loading the transformer model from the pickle file
with open(model_pkl_file, 'rb') as file:
    fine_tuned_model = pickle.load(file)

In [None]:
in_vocab_size = len(symb2index)
out_vocab_size = len(symb2index)
d_model = 24
num_heads = 2
num_layers = 2
d_ff = 48
max_seq_length = max_seq_length
dropout = 0.1
batch_size = 2

criterion = nn.CrossEntropyLoss(ignore_index=0)

In [None]:
len(test_dataloader)

8752

In [None]:
loss_test, accuracy_test, f1score_test = compute_test_metrics(fine_tuned_model, test_dataloader, batch_size, out_vocab_size, criterion, len(test_dataloader))

In [None]:
accuracy_test

0.9999987247812744

In [None]:
f1score_test

0.9999988097958562

### Infer the sequences for the test set

In [None]:
fine_tuned_model.eval()
with torch.no_grad():
    input_sequences = []
    targets_sequences = []
    outputs_sequences = []
    total_batches = len(test_dataloader)
    for i, data in tqdm(enumerate(test_dataloader)):
        in_data, out_data = data
        for i in range(in_data.size()[0]):
          input_sequences.append(''.join([index2symb[i.item()] for i in in_data[i,:,:] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))

          decoded_output = evaluate_symb(fine_tuned_model, (in_data[i,:,:], out_data[i,:,:]))

          targets_sequences.append(''.join([index2symb[i.item()] for i in out_data[i,:,:] if i.item() not in [PAD_IDX,SOS_IDX,EOS_IDX]]))
          outputs_sequences.append(decoded_output)

1514it [32:24,  1.28s/it]

In [None]:
targets_sequences[0]

In [None]:
outputs_sequences[0]