In [None]:
import json
import random
from transformers import AutoTokenizer
import numpy as np
from allennlp.common.util import import_module_and_submodules as import_submodules
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
from scipy.spatial import distance

import sys
import os
sys.path.append(os.path.abspath('..'))

# Import every submodule of the project package so its AllenNLP components
# (presumably including the 'jsonl_predictor' used below) are registered
# before load_archive / Predictor.from_archive look them up by name.
import_submodules("allennlp_lib")

# Which fine-tuned model to inspect. The archive and its side files
# (label2index.json here; w.npy / b.npy in the analysis cell) live in model_path.
DATASET="bios"
MODEL_NAME="roberta-large"
model_path=f"../experiments/models/{DATASET}/{MODEL_NAME}"

archive = load_archive(model_path + '/model.tar.gz')
print(archive.config)
# NOTE(review): load_archive has already constructed archive.model, so editing the
# config dict afterwards presumably does not reconfigure that instance; the
# attribute set directly on the model below looks like the switch that matters - confirm.
archive.config['model']['output_hidden_states'] = True
model = archive.model
model._output_hidden_states = True
# Predictor registered under 'jsonl_predictor'; predict_json() outputs are read
# later for 'label', 'tokens' and 'encoded_representations'.
predictor = Predictor.from_archive(archive, 'jsonl_predictor')

# HuggingFace tokenizer; the checkpoint name is hard-coded rather than taken from
# MODEL_NAME - TODO(review) keep the two in sync if MODEL_NAME changes.
tok = AutoTokenizer.from_pretrained("roberta-large")

# label -> class index of the classifier head, plus the inverse index -> label map.
with open(model_path + "/label2index.json", "r") as f:
    label2index = json.load(f)
    index2label = {label2index[k]: k for k in label2index}
# Last expression of the cell: displays the label mapping.
label2index


In [None]:
def all_masks(tokenized_text):
    """Yield every proper, non-empty subset of token positions.

    Subsets are enumerated as bitmasks ``1 .. 2**n - 2`` with
    ``n = len(tokenized_text)``: bit ``k`` set means position ``k`` belongs to
    the subset. The empty set (mask ``0``) and the full set (mask ``2**n - 1``)
    are excluded, so a text with 0 or 1 tokens yields nothing.

    Args:
        tokenized_text: any sized sequence; only its length is used.

    Yields:
        list[int]: ascending token positions in the subset. There are
        ``2**n - 2`` of them in total.
    """
    # https://stackoverflow.com/questions/1482308/how-to-get-all-subsets-of-a-set-powerset
    n = len(tokenized_text)
    positions = list(range(n))
    bits = [1 << k for k in range(n)]
    # Parenthesised on purpose: `1 << n - 1` parses as `1 << (n - 1)`, which
    # silently drops every subset containing the last position (and raises a
    # negative-shift ValueError when n == 0).
    for subset_mask in range(1, (1 << n) - 1):
        yield [pos for bit, pos in zip(bits, positions) if subset_mask & bit]
        
from itertools import permutations
def all_masks_length(tokenized_text, max_length):
    """Yield ordered tuples of distinct token positions, grouped by size.

    Sizes run from 0 up to ``max_length - 1`` (the empty tuple comes first).
    Within each size the tuples come in ``itertools.permutations`` order.

    NOTE(review): permutations yields the same position set once per ordering,
    and ``max_length`` itself is excluded. If this is meant to enumerate masks
    (sets) of size 1..max_length, ``combinations`` with ``range(1, max_length + 1)``
    may have been intended. Confirm before relying on it.
    """
    indices = list(range(len(tokenized_text)))
    for size in range(max_length):
        yield from permutations(indices, size)
        
def all_consecutive_masks(tokenized_text, max_length = -1):
    """Yield the positions that remain after deleting one contiguous span.

    For every span ``[start, end)`` with ``0 <= start < end < n`` (where
    ``n = len(tokenized_text)``), this yields the kept positions
    ``[0, start) + [end, n)``. Because the span is never empty and ``end < n``,
    the result is never the empty set nor the full set.

    When ``max_length`` is positive, only spans with ``end - start >= max_length``
    are used, so despite its name it acts as a minimum span length.
    A non-positive ``max_length`` disables the filter.

    NOTE(review): ``end`` never reaches ``n``, so spans that include the last
    token are never deleted. Confirm this is intended.
    """
    positions = list(range(len(tokenized_text)))
    n = len(positions)
    for start in range(n):
        for end in range(start + 1, n):
            # Guard clause: skip spans shorter than the (positive) threshold.
            if max_length > 0 and end - start < max_length:
                continue
            yield positions[:start] + positions[end:]
                
def all_consecutive_masks2(tokenized_text, max_length = -1):
    """Yield every contiguous span of token positions ``[start, end)``.

    Spans come ordered by ``start`` and then by increasing ``end``. Each span is
    non-empty. When ``max_length`` is positive, only spans of length
    ``<= max_length`` are produced. Otherwise every span is produced, including
    the full span ``[0, n)``.

    Args:
        tokenized_text: any sized sequence; only its length is used.
        max_length: longest span length to emit; non-positive means unlimited.

    Yields:
        list[int]: consecutive positions ``start, start + 1, ..., end - 1``.
    """
    n = len(tokenized_text)
    limit = max_length if max_length > 0 else n
    for start in range(n):
        # Bound `end` directly instead of filtering: end - start <= limit and end <= n.
        for end in range(start + 1, min(n, start + limit) + 1):
            yield list(range(start, end))
            

In [None]:

from scipy.special import softmax


# Dev split in JSON-lines format: one example object per non-blank line.
dev_path = f"../data/{DATASET}/dev.jsonl"

with open(dev_path) as f:
    # A single blank-line filter is enough; the original repeated it twice.
    dev_data = [json.loads(line) for line in f if line.strip()]

# Gold class index for every dev example, aligned with dev_data.
dev_labels = np.array([label2index[d['label']] for d in dev_data])

# Explain misclassified dev examples by masking contiguous word spans.
# `fact` is the model's (wrong) prediction and `foil` is the gold label. For each
# such example, print the span whose masking most reduces the model's support for
# `fact`. This is done once without a foil (plain p(fact)) and once contrastively
# (fact vs. the gold foil).
MAX_DEV_EXAMPLES = 3    # only the first few dev examples are scanned
MAX_EXPLAINED = 30      # stop after this many misclassified examples
MAX_SPAN_LEN = 2        # longest contiguous span that gets masked
TOP_K = 1               # highest-scoring spans printed per foil
NO_FOIL = -100          # sentinel foil index: non-contrastive explanation

# The linear classifier head is identical for every example, so load it once
# instead of re-reading it inside the loops.
classifier_w = np.load(f"{model_path}/w.npy")
classifier_b = np.load(f"{model_path}/b.npy")


def _span_scores(encoded_orig, encoded, classifier_w, classifier_b, fact_idx, foil_idx):
    """Score each masked variant by how much it lowers support for `fact_idx`.

    Args:
        encoded_orig: representation of the unmasked input. It is flattened, so a
            ``(dim,)`` or ``(1, dim)`` vector is handled alike. Presumably the
            vector fed to the classifier head; TODO(review) confirm.
        encoded: ``(num_masks, dim)`` array with one representation per masked variant.
        classifier_w, classifier_b: classifier head; logits are ``z @ w.T + b``.
        fact_idx: class index of the model's prediction.
        foil_idx: class index of the contrast class, or a negative value for none.

    Returns:
        1-D array with one score per row of ``encoded``. The score is the original
        minus the masked probability of `fact`. When ``foil_idx >= 0``, both inputs
        are first projected onto the direction ``w[fact] - w[foil]``, and the
        probabilities are renormalised over the {fact, foil} pair. Larger scores
        mean masking that span hurt `fact` more.
    """
    z_orig = np.ravel(encoded_orig)
    z_masked = np.asarray(encoded)
    if foil_idx >= 0:
        # Rank-1 orthogonal projection onto the fact-vs-foil weight difference.
        u = classifier_w[fact_idx] - classifier_w[foil_idx]
        projection = np.outer(u, u) / np.dot(u, u)
        z_orig = z_orig @ projection
        z_masked = z_masked @ projection

    probs_orig = softmax(z_orig @ classifier_w.T + classifier_b)
    probs_masked = softmax(z_masked @ classifier_w.T + classifier_b, axis=1)

    if foil_idx < 0:
        return probs_orig[fact_idx] - probs_masked[:, fact_idx]

    pair = [fact_idx, foil_idx]
    p = probs_orig[pair] / probs_orig[pair].sum()
    q = probs_masked[:, pair] / probs_masked[:, pair].sum(axis=1, keepdims=True)
    return p[0] - q[:, 0]


excount = 0

for ex in dev_data[:MAX_DEV_EXAMPLES]:

    foil = ex['label']

    out = predictor.predict_json(ex)
    encoded_orig = out['encoded_representations']

    fact = out['label']
    if fact == foil:
        print("Model was not wrong on this example. Skipping")
        continue
    excount += 1
    if excount > MAX_EXPLAINED:
        break
    print('Predicted: ', fact)
    print('Label:', foil)

    # Work on a local word list. Overwriting ex['text'] (as before) mutated
    # dev_data, so re-running this cell fed lists to the predictor and crashed.
    words = ex['text'].split()

    masks2 = list(all_consecutive_masks2(words, max_length=MAX_SPAN_LEN))
    if not masks2:
        print("No words to mask. Skipping")
        continue

    # One predictor call per masked variant; row r of `encoded` matches masks2[r].
    encoded = []
    for m2 in masks2:
        masked_words = list(words)
        for i in m2:
            masked_words[i] = '<mask>'
        masked_out = predictor.predict_json({"text": ' '.join(masked_words)})
        encoded.append(masked_out['encoded_representations'])
    encoded = np.array(encoded)

    fact_idx = label2index[fact]
    for foil_idx in [NO_FOIL, label2index[foil]]:

        if foil_idx == fact_idx:
            continue

        print("=========")
        print('foil:', index2label[foil_idx] if foil_idx >= 0 else 'none')

        distances = _span_scores(encoded_orig, encoded, classifier_w, classifier_b,
                                 fact_idx, foil_idx)
        highlight_rankings = np.argsort(-distances)

        for rank in highlight_rankings[:TOP_K]:
            highlighted = list(words)
            for k in masks2[rank]:
                highlighted[k] = f'[[ {highlighted[k].upper()} ]]'
            print(' '.join(highlighted))
            print(np.round(distances[rank] * 100, 4))

    print("")
