In [1]:
import argparse
import torch
import pandas as pd
import json
import numpy as np
from dataset import NorecDataset
from transformers import AdamW, BertForSequenceClassification
import tqdm
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from sklearn.metrics import classification_report
import shap
import scipy as sp

In [6]:
def eval(device, test_loader, num_labels, tokenizer, path_to_model_folder, batch_size, inverted_label_indexer):
    model = BertForSequenceClassification.from_pretrained(path_to_model_folder)
    model = model.to(device)
    model.eval()
    total_test_acc = 0.
    passes = 0
    guesses = []
    golds = []

    def f(x):
        tv = torch.tensor([tokenizer.encode(
            v, padding='max_length', max_length=500, truncation=True) for v in x])
        outputs = model(tv)[0].detach().cpu().numpy()
        scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
        val = sp.special.logit(scores[:, 1])  # use one vs rest logit units
        return val

    explainer = shap.Explainer(f, tokenizer)
    with torch.no_grad():
        for i, batch in enumerate(tqdm.tqdm(test_loader)):
            doc, y = batch
            doc = list(doc)

            shap_values = explainer(doc[0:2], fixed_context=1)
            shap.plots.text(shap_values[0])
            break

In [9]:
device = torch.device("cuda" if False else "cpu")
model_data = torch.load("models/test/test.pt", map_location=device)
tokenizer = BertTokenizer.from_pretrained("NbAiLab/nb-bert-base")
name = "Test"
print(f"Evaluating model: {name}.pt")

test_dataset = NorecDataset(
    "./data/test.csv")
test_loader = DataLoader(
    test_dataset, batch_size=8, shuffle=True)
num_labels = 2

eval(device, test_loader, num_labels,
     tokenizer, "models/test/", 8, {0: "Negative", 1: "Positive"})

Evaluating model: Test.pt


  0%|          | 0/52 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]


Partition explainer:  50%|█████     | 1/2 [00:00<?, ?it/s][A

  0%|          | 0/20 [00:00<?, ?it/s]


Partition explainer: 3it [01:12, 36.01s/it]               [A


  0%|          | 0/52 [01:12<?, ?it/s]
