In [12]:
from ast import Mult
import json, os, pickle, torch, logging, typing, numpy as np, glob, argparse
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizerFast, BertForSequenceClassification
from transformers.models.bert.modeling_bert import SequenceClassifierOutput
from src.classes.datasets import IMDBDataset
from src.utils.pickleUtils import pdump, pload
from src.proecssing import correct_count
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm



def evaluateModel(
  dataset_name: str,
  batch_size: int,
  epoch_num: int,
  use_margin_loss: bool,
  lambda_weight: float,
  use_cache: bool
):
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  DATASET_NAME = dataset_name
  DATASET_PATH = f"datasets/{dataset_name}/base"
  BATCH_SIZE = batch_size
  EPOCH_NUM = epoch_num
  USE_MARGIN_LOSS = use_margin_loss
  LAMBDA_WEIGHT = lambda_weight
  USE_ENCODING_CACHE = use_cache
  OUTPUT_PATH = f"checkpoints/{DATASET_NAME}/model"
  TOPK_NUM = 4
  # memAvailable = psutil.virtual_memory().available
  # estimatedMemConsumed = os.path.getsize(os.path.join(DATASET_PATH, "train_set.pickle.blosc")) * 3
  # USE_PINNED_MEMORY = True if (args.use_pinned_memory & (memAvailable > estimatedMemConsumed)) == 1 else False


  def loadTestData():
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    test_set = pload(os.path.join(DATASET_PATH, 'test_set'))
    test_texts = test_set['text'].tolist()
    test_labels = test_set['label'].tolist()
    test_encodings = tokenizer(test_texts, padding=True, truncation=True)
    pdump(test_encodings, os.path.join(DATASET_PATH, 'test_encodings'))
    test_dataset = IMDBDataset(labels=test_labels, encodings=test_encodings)
    test_loader = DataLoader(
      test_dataset,
      batch_size=BATCH_SIZE,
      shuffle=True,
      persistent_workers=False,
      pin_memory=False
    )
    return test_loader
  
  

  test_loader: DataLoader = loadTestData()
  num_labels = -1
  if len(test_loader.dataset[0]["labels"].shape) == 1:
    num_labels = test_loader.dataset[0]["labels"].shape[0]
  else:
    print("Invalid label shape")
    exit()

  torch.cuda.empty_cache()
  model: torch.nn.Module = BertForSequenceClassification.from_pretrained(os.path.join(OUTPUT_PATH, "best_epoch"), num_labels=num_labels) #type: ignore
  # model: BertForSequenceClassification = torch.nn.DataParallel(model) #type:ignore #! only use w/ distributed
  model.eval()
  model.to(device)




  metrics:MetricCollection = MetricCollection([
    MulticlassAccuracy(num_classes=num_labels)
  ]
  ).to(device)



  with torch.no_grad():
    for batch in tqdm(test_loader):

      input_ids = batch["input_ids"].to(device)
      attention_mask = batch["attention_mask"].to(device)
      token_type_ids = batch["token_type_ids"].to(device)
      true_labels =  batch["labels"].to(device)




      outputs = model.forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
      )

      logits = outputs[0]
      
      pred_labels = logits.sign().relu()




      metrics(preds=pred_labels, target=true_labels)

  metrics(preds=torch.tensor([0, 1]).to(device), target=torch.tensor([1, 0]).to(device))
  print(metrics.compute())


evaluateModel(
  dataset_name="imdb",
  batch_size=2,
  epoch_num=1,
  use_margin_loss=True,
  lambda_weight=0.2,
  use_cache=True  
  )


  8%|▊         | 5/63 [00:00<00:01, 43.10it/s]

tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')


 24%|██▍       | 15/63 [00:00<00:01, 44.21it/s]

tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')


 40%|███▉      | 25/63 [00:00<00:00, 44.29it/s]

tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')


 56%|█████▌    | 35/63 [00:00<00:00, 44.40it/s]

tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')


 71%|███████▏  | 45/63 [00:01<00:00, 44.31it/s]

tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')


 79%|███████▉  | 50/63 [00:01<00:00, 44.09it/s]

tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')


100%|██████████| 63/63 [00:01<00:00, 44.58it/s]

tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[1., 0.],
        [1., 0.]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 1.]], device='cuda:0')
tensor([[0., 0.]], device='cuda:0')
{'MulticlassAccuracy': tensor(0.4960, device='cuda:0')}



