In [58]:
import torch
import pandas as pd
import numpy as np

from transformers import BertForMaskedLM, BertConfig
import sys
sys.path.append('../')
from data.dataloader import AminoAcidTokenizer, ProteinDataset
from args import parse_args
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from data.utils import load_train_data, load_test_data
from train import LossCheckpointer

In [3]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model hyper-parameters. These must match the values used when the checkpoint below was
# trained, otherwise load_state_dict will reject the saved weights.
MAX_SEQ_LEN = 100
HIDDEN_SIZE = 768
NUM_ATTENTION_HEADS = 12
NUM_HIDDEN_LAYERS = 12
INTERMEDIATE_SIZE = 3072

# Best checkpoint from a training run (the directory name appears to embed the run timestamp).
BEST_MODEL_PATH = '../checkpoints/BERTMLM_maxSeq100-2024-04-07 20:02/best_model.pt'

In [5]:
# Initialize the tokenizer (its vocabulary size also sizes the model's embedding table)
tokenizer = AminoAcidTokenizer(max_seq_length=MAX_SEQ_LEN)

# The architecture must match the training-time configuration exactly; load_state_dict
# (strict by default) raises on any missing/unexpected key or shape mismatch.
config = BertConfig(
    vocab_size=len(tokenizer.vocab),
    hidden_size=HIDDEN_SIZE,
    num_hidden_layers=NUM_HIDDEN_LAYERS,
    num_attention_heads=NUM_ATTENTION_HEADS,
    intermediate_size=INTERMEDIATE_SIZE,
)

model = BertForMaskedLM(config).to(DEVICE)

# map_location remaps saved tensors onto DEVICE, so a checkpoint written on one device
# (e.g. a GPU) can be loaded on another.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
best_checkpoint = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
model.load_state_dict(best_checkpoint['model_state_dict'])

# Everything below is inference; switch off dropout now rather than relying on later cells.
model.eval()

model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(24, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

In [6]:
# Held-out split. device=DEVICE is forwarded to the helper, presumably so the tensors come
# back already on the target device — TODO confirm in data.utils.load_test_data.
test_sequences, test_labels, test_masks = load_test_data(
    MAX_SEQ_LEN,
    '../dataset/splitted',
    prefix='test',
    device=DEVICE,
)

# Field order matters: downstream cells unpack every batch as (inputs, attention_mask, labels).
test_dataset = TensorDataset(test_sequences, test_masks, test_labels)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [61]:
def calculate_accuracy(model, dataloader, mask_token_id=3):
    """Masked-token accuracy of a masked-LM over a whole dataloader.

    Only positions whose input id equals ``mask_token_id`` are scored. At each such position
    the argmax of the logits is compared with the label.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)`` and returning an
            object with a ``.logits`` tensor of shape (batch, seq_len, vocab_size).
        dataloader: iterable of ``(input_ids, attention_mask, labels)`` batches; ``labels``
            has the same shape as ``input_ids``.
        mask_token_id: id of the [MASK] token in the inputs.

    Returns:
        Fraction of masked positions predicted correctly, as a float. Returns 0.0 when the
        data contains no masked tokens.

    Side effects:
        Leaves ``model`` in eval mode.
    """
    # Run on whatever device the model lives on. Moving batches there makes the function work
    # whether or not the dataloader already yields device tensors (a no-op if it does).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    model.eval()
    total_correct = 0
    total_masked = 0

    with torch.no_grad():
        for seq, mask, labels in tqdm(dataloader):
            if device is not None:
                seq, mask, labels = seq.to(device), mask.to(device), labels.to(device)

            # labels are not passed to the model: the MLM loss it would compute is never used.
            logits = model(input_ids=seq, attention_mask=mask).logits

            masked_positions = seq == mask_token_id
            # A (batch, seq_len) boolean index on (batch, seq_len, vocab) logits selects
            # whole rows -> (num_masked, vocab_size).
            predicted = logits[masked_positions].argmax(dim=-1)
            targets = labels[masked_positions]

            total_correct += (predicted == targets).sum().item()
            total_masked += targets.numel()

    return total_correct / total_masked if total_masked > 0 else 0.0

In [46]:
# Run a single test batch through the model so the following cells can inspect
# inputs, logits and predictions interactively.
model.eval()
total_test_loss = 0  # NOTE(review): unused in the visible cells — leftover from a full evaluation loop?

# next(iter(...)) replaces the original `for ...: break`; it raises StopIteration on an
# empty loader instead of leaving inputs undefined.
batch = next(iter(test_dataloader))
inputs, attention_mask, targets = (tensor.to(DEVICE) for tensor in batch)

with torch.no_grad():
    outputs = model(inputs, attention_mask=attention_mask, labels=targets)
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)

attention_mask

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')

In [47]:
inputs.size()  # (batch_size, sequence_length) of the inspected batch

torch.Size([32, 100])

In [48]:
logits[0, 0]  # vocabulary scores for the first token of the first sequence in the batch

tensor([ 0.1756, 10.8809, -0.6601, -0.0181, -1.0608, -0.0556,  0.3588,  0.0249,
        -0.2296, -1.8968, -0.3943, -0.6026, -0.2352, -0.5743, -0.4372, -0.6825,
        -0.7557,  0.3318, -1.2046, -0.9446, -1.4121,  0.5507, -0.5083, -0.5819],
       device='cuda:0')

In [49]:
# Greedy per-position prediction: index of the highest-scoring vocabulary entry.
predictions = logits.argmax(dim=-1)
predictions

tensor([[ 1, 14, 19,  ...,  0,  0,  0],
        [ 1, 14, 17,  ...,  0,  0,  0],
        [ 1, 14,  4,  ..., 20, 10,  2],
        ...,
        [ 1, 14,  7,  ..., 12,  7,  2],
        [ 1, 14, 15,  ..., 11, 16,  2],
        [ 1, 14, 20,  ..., 23, 20,  2]], device='cuda:0')

In [87]:
# Masked-token accuracy on the single inspected batch (calculate_accuracy covers the full test set).
# Take the [MASK] id from the tokenizer instead of hard-coding 3.
mask_token_id = tokenizer.special_tokens['[MASK]']
masked_positions = inputs == mask_token_id
masked_logits = logits[masked_positions]  # boolean row selection -> (num_masked, vocab_size)
masked_targets = targets[masked_positions]
predicted_labels = masked_logits.argmax(dim=-1)
correct_predictions = (predicted_labels == masked_targets).float()
# .mean() of an empty tensor is NaN; report 0 when the batch contains no [MASK] tokens.
accuracy = correct_predictions.mean() if correct_predictions.numel() > 0 else torch.tensor(0.0)
print(f'Accuracy for masked values: {accuracy.item()}')



Accuracy for masked values: 0.09808611869812012


In [62]:
calculate_accuracy(model=model, dataloader=test_dataloader)  # masked-token accuracy over the whole test set

100%|██████████| 836/836 [00:39<00:00, 21.30it/s]


0.11420816721850575

In [104]:
# Sanity check: the tokenizer's [MASK] id should equal the default mask_token_id=3 used by calculate_accuracy.
tokenizer.special_tokens['[MASK]']

3