In [27]:
import transformers
from datasets import load_from_disk

import torch

import numpy as np
import glob
import os
from models.model import ModelClass

# Collect everything under models/results, oldest first by modification time.
# None is placed at the front so that an empty results directory still yields
# a value: in that case `chkpt` is None.
# NOTE(review): presumably ModelClass loads the final saved model when given
# None -- confirm against models/model.py.
checkpoint = [None] + sorted(
    glob.glob(os.path.join('models/results', '*')),
    key=os.path.getmtime,
)
# The last entry is the most recently modified path (or None if none exist).
chkpt = checkpoint[-1]

# Build the model wrapper for the chosen checkpoint and load the model.
Model = ModelClass(chkpt)
model = Model.load_model()

# Switch to evaluation mode; as the cell's last expression this also
# displays the module structure.
model.eval()


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [20]:
# Read the tokenized dataset from disk and keep only its 'test' split,
# which is the data evaluated in the cells below.
tokenized_splits = load_from_disk('data/tokenized/dataset.hf')
dataset = tokenized_splits['test']


In [34]:
from torch.utils.data import DataLoader


def collate_fn(batch):
    """Combine a list of per-example dicts into one dict of batched tensors.

    Args:
        batch: sequence of examples, each a mapping with the keys
            'input_ids', 'token_type_ids', 'attention_mask' and 'labels'.

    Returns:
        dict with the same four keys. The three token-level fields are built
        by converting each example's value to a tensor and stacking them along
        a new leading dimension; 'labels' is built with a single
        ``torch.tensor`` call over the list of per-example labels.
    """
    def _stack_field(field):
        # One tensor per example, stacked along a new first (batch) axis.
        return torch.stack([torch.tensor(example[field]) for example in batch])

    collated = {
        field: _stack_field(field)
        for field in ('input_ids', 'token_type_ids', 'attention_mask')
    }
    collated['labels'] = torch.tensor([example['labels'] for example in batch])
    return collated

# Batch the evaluation split; examples are merged with collate_fn and served
# in dataset order (no shuffle argument is passed).
EVAL_BATCH_SIZE = 32
data_loader = DataLoader(
    dataset=dataset,
    batch_size=EVAL_BATCH_SIZE,
    collate_fn=collate_fn,
)


In [58]:
# Run the model over the evaluation loader, collecting the raw logits and the
# reference labels of every batch as NumPy arrays on the CPU.
true_labels = []
all_logits = []

# Follow the model's own device instead of hard-coding 'cuda', so the loop
# also works on CPU-only machines or when the model sits on another GPU.
device = next(model.parameters()).device

for batch in data_loader:
    # Move the batch to the same device as the model.
    batch = {k: v.to(device) for k, v in batch.items()}

    # Inference only: disable autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # 'labels' is part of the batch, so it is passed to the model too;
        # only the logits are used below.
        outputs = model(**batch)
        logits = outputs.logits

    true_labels.append(batch['labels'].cpu().numpy())
    all_logits.append(logits.cpu().numpy())

In [59]:
# Join the per-batch logit arrays along the first axis into a single array,
# then take the highest-scoring index along axis 1 as each row's prediction.
nplogits = np.concatenate(all_logits)
predictions = nplogits.argmax(axis=1)


In [65]:
# Sanity check: display the shape of the predictions array.
predictions.shape

(100,)

In [66]:
# Sanity check: display the shape of the label array collected for the last batch.
# NOTE(review): the labels appear to be 2-D (one vector per example), which is
# why they are reduced with argmax before scoring -- confirm the label encoding.
true_labels[-1].shape

(4, 4)

In [70]:
# Join the per-batch label arrays along the first axis and reduce each row to
# the index of its largest entry, giving one class index per example.
true_labels_concat = np.argmax(np.concatenate(true_labels), axis=1)

In [77]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 / support plus averages, returned as a
# nested dict (output_dict=True) and displayed as the cell's result.
classification_report(
    y_true=true_labels_concat,
    y_pred=predictions,
    output_dict=True,
)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'2': {'precision': 0.83,
  'recall': 1.0,
  'f1-score': 0.9071038251366119,
  'support': 83.0},
 '3': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 17.0},
 'accuracy': 0.83,
 'macro avg': {'precision': 0.415,
  'recall': 0.5,
  'f1-score': 0.45355191256830596,
  'support': 100.0},
 'weighted avg': {'precision': 0.6889,
  'recall': 0.83,
  'f1-score': 0.7528961748633879,
  'support': 100.0}}