In [None]:
import json
import random
from transformers import AutoTokenizer
import numpy as np
from allennlp.common.util import import_module_and_submodules as import_submodules
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
from scipy.spatial import distance

import sys
import os
sys.path.append(os.path.abspath('..'))

# Import the `allennlp_lib` package together with its submodules
# (presumably so that its custom AllenNLP components get registered — TODO confirm).
import_submodules("allennlp_lib")

# Dataset / model pair to explain; together they select the experiment directory.
DATASET="mnli"
MODEL_NAME="roberta-large"
model_path=f"../experiments/models/{DATASET}/{MODEL_NAME}"

# Load the trained archive and print its configuration for reference.
archive = load_archive(model_path + '/model.tar.gz')
print(archive.config)
# Override two config entries on the loaded archive before the predictor is built from it.
archive.config['dataset_reader']['type'] = 'mnli'
archive.config['model']['output_hidden_states'] = True
model = archive.model
# The flag is also set directly on the model object taken from the archive
# (presumably this is what makes predictions carry 'encoded_representations' — TODO confirm).
model._output_hidden_states = True
predictor = Predictor.from_archive(archive, 'textual_entailment_fixed')

# NOTE(review): the tokenizer name is hard-coded rather than derived from MODEL_NAME;
# the two strings currently agree.
tok = AutoTokenizer.from_pretrained("roberta-large")

# Read the label -> index mapping stored next to the model and build its inverse.
with open(model_path + "/label2index.json", "r") as f:
    label2index = json.load(f)
    index2label = {label2index[k]: k for k in label2index}
# Last expression of the cell: displays the mapping in the notebook output.
label2index


In [None]:
def all_masks(tokenized_text):
    """Yield every non-empty, proper subset of the token positions.

    Each yielded mask is a list of ascending integer indices into
    ``tokenized_text``. The empty set and the full set are both excluded, so
    inputs with fewer than two tokens yield nothing.

    Enumeration follows the bit-counter scheme described in
    https://stackoverflow.com/questions/1482308/how-to-get-all-subsets-of-a-set-powerset

    Args:
        tokenized_text: Any sized sequence; only its length is used.

    Yields:
        list[int]: positions selected by the current subset.
    """
    positions = list(range(len(tokenized_text)))
    n_tokens = len(positions)
    bit_flags = [1 << i for i in range(n_tokens)]
    # The parentheses matter: `1 << n_tokens - 1` parses as `1 << (n_tokens - 1)`,
    # which would silently drop every subset that contains the last position.
    # Counting from 1 to 2**n - 2 skips exactly the empty set (0) and the full
    # set (2**n - 1).
    for subset_bits in range(1, (1 << n_tokens) - 1):
        yield [pos for flag, pos in zip(bit_flags, positions) if subset_bits & flag]
        
def all_consecutive_masks(tokenized_text, max_length = -1):
    """Yield index lists obtained by cutting one contiguous span out of the positions.

    For every ``start < end`` with ``end < len(tokenized_text)``, the yielded
    mask holds all positions outside ``[start, end)``. Because ``end`` never
    reaches the sequence length, the last position is part of every mask and
    neither the empty nor the full index list is produced.

    Args:
        tokenized_text: Any sized sequence; only its length is used.
        max_length: When positive, only masks whose removed span covers at
            least ``max_length`` positions are yielded; otherwise nothing is
            filtered.
            NOTE(review): the name suggests an upper bound, but the comparison
            is ``>=`` — confirm the intended direction.

    Yields:
        list[int]: the positions remaining after removing the span.
    """
    positions = list(range(len(tokenized_text)))
    total = len(positions)
    for start in range(total):
        for end in range(start + 1, total):
            remaining = positions[:start] + positions[end:]
            if max_length > 0:
                if end - start >= max_length:
                    yield remaining
                continue
            yield remaining
                
def all_consecutive_masks2(tokenized_text, max_length = -1):
    """Yield every non-empty contiguous run of token positions.

    Runs are produced in order of start position, then length. The empty run
    is never produced. The run covering all positions IS produced unless
    ``max_length`` filters it out.

    Args:
        tokenized_text: Any sized sequence; only its length is used.
        max_length: When positive, runs longer than ``max_length`` are
            skipped; otherwise runs of every length are yielded.

    Yields:
        list[int]: consecutive positions ``start .. end - 1``.
    """
    positions = list(range(len(tokenized_text)))
    total = len(positions)
    for start in range(total + 1):
        for end in range(start + 1, total + 1):
            # Skip runs that exceed the requested maximum length, if one is set.
            if max_length > 0 and end - start > max_length:
                continue
            yield positions[start:end]


In [None]:

# NLI example to explain, as a dict passed straight to the predictor.
ex = {
    'sentence1': 'A soccer game in a large area with 8 yellow players and 4 black players.',
    'sentence2': 'There is a soccer game with 12 players.',
    'gold_label': 'entailment',
}

# Provisional foil: the gold label of the example.
foil = ex['gold_label']

# Forward pass on the unmasked example; keep its representation and predicted label.
out = predictor.predict_json(ex)
encoded_orig = out['encoded_representations']
fact = out['label']
print('Predicted: ', fact)

# Check intentionally left disabled:
# assert fact != foil, "Fact should be different from the foil (if not, pick a different foil)"

# From here on both sentences are stored as lists of whitespace-separated words.
for sentence_key in ('sentence1', 'sentence2'):
    ex[sentence_key] = ex[sentence_key].split()

# Result is not stored; kept from the original cell.
tok.convert_tokens_to_string(out['tokens'])

# Word positions to hide in each sentence.
masks1 = [[]]  # change this if you also want to mask out parts of the premise.
masks2 = list(all_consecutive_masks2(ex['sentence2'], max_length=1))

encoded = []
mask_mapping = []
preds = np.zeros(shape=(len(masks1), len(masks2)))

for m1_i, m1 in enumerate(masks1):
    premise_words = list(ex['sentence1'])
    for position in m1:
        premise_words[position] = '<mask>'
    masked1 = ' '.join(premise_words)

    for m2_i, m2 in enumerate(masks2):
        hypothesis_words = list(ex['sentence2'])
        for position in m2:
            hypothesis_words[position] = '<mask>'
        masked2 = ' '.join(hypothesis_words)

        masked_ex = dict(sentence1=masked1, sentence2=masked2)
        masked_out = predictor.predict_json(masked_ex)
        # Optional filter, currently disabled: keep only variants predicted as the foil.
        #     if masked_out['label'] != foil:
        #         continue

        print(m1_i, m2_i)
        print(f"{masked1}\n{masked2}")
        print(masked_out['label'])

        # Record the representation along with which (premise, hypothesis) masks produced it.
        encoded.append(masked_out['encoded_representations'])
        mask_mapping.append((m1_i, m2_i))

        print("====")

# Stack the collected representations into a single array.
encoded = np.array(encoded)
        

In [None]:
# Foil: the alternative label the explanation contrasts the prediction (fact) against.
foil = 'neutral'

fact_idx = label2index[fact]
foil_idx = label2index[foil]
print('fact:', index2label[fact_idx])
print('foil:', index2label[foil_idx])

# With identical fact and foil the direction `u` below is the zero vector and the
# projection would silently become all-NaN, so fail loudly instead.
if fact_idx == foil_idx:
    raise ValueError(
        "Fact should be different from the foil (if not, pick a different foil)"
    )

num_classifiers = 100

# Weights and bias of the classification layer, stored next to the model.
classifier_w = np.load(f"{model_path}/w.npy")
classifier_b = np.load(f"{model_path}/b.npy")

# Direction separating fact from foil in the classifier's weight space, and the
# rank-1 orthogonal projection onto it: P = u u^T / (u^T u).
u = classifier_w[fact_idx] - classifier_w[foil_idx]
contrastive_projection = np.outer(u, u) / np.dot(u, u)

print(contrastive_projection.shape)


In [None]:

# from scipy.stats import entropy
from scipy.special import softmax

# Representations of the original input (z_all) and of every masked variant (z_h).
z_all = encoded_orig
z_h = encoded
# Keep only the component along the fact-vs-foil direction.
z_all_row = encoded_orig @ contrastive_projection
z_h_row = encoded @ contrastive_projection

# Class probabilities for the original input, repeated once per masked variant.
# The softmax axis is given explicitly so it always normalises over classes.
prediction_probabilities = softmax(z_all_row @ classifier_w.T + classifier_b, axis=-1)
prediction_probabilities = np.tile(prediction_probabilities, (z_h_row.shape[0], 1))

# Class probabilities for each masked variant.
prediction_probabilities_del = softmax(z_h_row @ classifier_w.T + classifier_b, axis=1)

# Restrict to the (fact, foil) pair and renormalise so each row sums to 1.
p = prediction_probabilities[:, [fact_idx, foil_idx]]
q = prediction_probabilities_del[:, [fact_idx, foil_idx]]

p = p / p.sum(axis=1).reshape(-1, 1)
q = q / q.sum(axis=1).reshape(-1, 1)
# Drop in renormalised fact probability caused by each mask (positive = fact got less likely).
distances = (p[:, 0] - q[:, 0])

print(' '.join(ex['sentence1']))
print(' '.join(ex['sentence2']))

print("=========\n=======Farthest masks:=======")

# Masked variants ordered from largest to smallest drop.
highlight_rankings = np.argsort(-distances)

# Show at most four masks; bounding by the number of rankings avoids an IndexError
# when fewer masked variants were generated.
num_top_masks = min(4, len(highlight_rankings))

for i in range(num_top_masks):
    rank = highlight_rankings[i]
    m1_i, m2_i = mask_mapping[rank]

    masked1 = list(ex['sentence1'])
    for k in masks1[m1_i]:
        masked1[k] = '<m>'
    masked1 = ' '.join(masked1)

    masked2 = list(ex['sentence2'])
    for k in masks2[m2_i]:
        masked2[k] = '<m>'
    masked2 = ' '.join(masked2)

    print(masked1)
    print(masked2)
    print(np.round(distances[rank], 4))
