In [1]:
import os
import torch
from transformers import Trainer, TrainingArguments, TrainerCallback
from transformers import BertTokenizerFast, BertForTokenClassification
from transformers import set_seed
from datasets import load_from_disk, load_dataset
from torch.nn import functional as F
import numpy as np

os.environ['CUDA_VISIBLE_DEVICES'] = "0"
torch.cuda.set_device(0)

In [14]:
model = BertForTokenClassification.from_pretrained('test_backup/model/')

In [15]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

In [16]:
dataset = load_from_disk('../data/rocstories/')

In [26]:
from random import shuffle
from sklearn.metrics import accuracy_score
from scipy.stats import kendalltau
from collections import defaultdict
from tqdm.auto import tqdm

metrics = defaultdict(list)

VERBOSE = True
for entry in tqdm(dataset['val'].select(list(range(len(dataset['val'])//100)))):
    sents = [entry[key] for key in sorted([key for key in entry if key.startswith('sentence')], key=lambda x: x[-1])]
    sent_idx = list(range(len(sents)))
    data = list(zip(sents, sent_idx))
    shuffle(data)
    shuffled_sents = [i[0] for i in data]
    shuffled_idx = np.array([i[1] for i in data])
    shuffled_text = ' [CLS] '+ ' [CLS] '.join(shuffled_sents)
    if VERBOSE:
        print(' '.join(sents))
        print()
        print(shuffled_text)
        print('-')
    inputs = tokenizer(shuffled_text, add_special_tokens=False, return_tensors='pt')
    outputs = model(**inputs)
    logits = outputs['logits']
    
    input_ids = inputs['input_ids']
    target_logits = logits[input_ids == tokenizer.cls_token_id]
    predicted_idx = np.argsort(np.argsort(target_logits.reshape(-1).detach().numpy()))
    
    tau, p = kendalltau(shuffled_idx, predicted_idx)
    acc = accuracy_score(shuffled_idx, predicted_idx)
    metrics['tau'].append(tau)
    metrics['acc'].append(acc)
    if VERBOSE:
        print('Acc: ', acc)
        print('Tau: ', tau)
        print('-')
        print('Logits: ', target_logits.reshape(-1).detach().numpy())
        print('Pred: ', predicted_idx)
        print('True: ', shuffled_idx)
        print('\n------\n')

          

  0%|          | 0/58 [00:00<?, ?it/s]

Angie wanted to surprise her kids with a trip to the zoo. She packed up the things they would need for the day. The kids got in the car and she told them she had a surprise. She pulled into the zoo parking lot. The kids were excited to see that they were at the zoo.

 [CLS] She pulled into the zoo parking lot. [CLS] Angie wanted to surprise her kids with a trip to the zoo. [CLS] The kids got in the car and she told them she had a surprise. [CLS] She packed up the things they would need for the day. [CLS] The kids were excited to see that they were at the zoo.
-
Acc:  0.0
Tau:  0.19999999999999998
-
Logits:  [2.1923623 2.15922   2.1486568 2.1654005 2.163315 ]
Pred:  [4 1 0 3 2]
True:  [3 0 2 1 4]

------

The boy had a swimming party. He invited some friends. Everybody played in the pool. One kid hurt himself jumping into the pool. The parents told everyone to go home.

 [CLS] The boy had a swimming party. [CLS] He invited some friends. [CLS] The parents told everyone to go home. [CLS] 

In [22]:
np.mean(metrics['acc'])

0.21086587436332768

In [23]:
np.mean(metrics['tau'])

-0.006451612903225809

In [25]:
metrics['tau']

[0.0,
 0.6,
 -0.39999999999999997,
 -0.39999999999999997,
 -0.39999999999999997,
 -0.19999999999999998,
 0.39999999999999997,
 -0.19999999999999998,
 -0.39999999999999997,
 0.39999999999999997,
 -0.19999999999999998,
 0.6,
 -0.6,
 -0.39999999999999997,
 0.0,
 0.6,
 -0.19999999999999998,
 0.9999999999999999,
 0.39999999999999997,
 -0.19999999999999998,
 -0.19999999999999998,
 -0.19999999999999998,
 0.6,
 0.0,
 0.0,
 -0.6,
 -0.39999999999999997,
 -0.19999999999999998,
 -0.19999999999999998,
 0.6,
 0.39999999999999997,
 0.19999999999999998,
 0.39999999999999997,
 0.6,
 0.0,
 -0.19999999999999998,
 -0.19999999999999998,
 -0.39999999999999997,
 -0.19999999999999998,
 0.39999999999999997,
 0.0,
 -0.19999999999999998,
 0.19999999999999998,
 0.7999999999999999,
 -0.19999999999999998,
 -0.7999999999999999,
 -0.39999999999999997,
 0.0,
 -0.6,
 0.6,
 0.39999999999999997,
 0.39999999999999997,
 -0.39999999999999997,
 -0.39999999999999997,
 -0.7999999999999999,
 0.0,
 0.0,
 0.39999999999999997,
 0.