In [113]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
import checklist
from checklist.editor import Editor
from checklist.expect import Expect
from checklist.perturb import Perturb
from checklist.test_types import INV, MFT
from torch.nn import functional as F
import warnings
warnings.filterwarnings('ignore')

In [70]:
# Shared checklist Editor; later cells call editor.template(...) on it to build prompts.
editor = Editor()

In [71]:
# Expand the template into prompts (presumably one per substituted {first_name} — TODO confirm).
# NOTE(review): in the cells visible here, `prompts` is reassigned before this value is used.
prompts = editor.template('{first_name}\'s favorite sport is')

In [72]:
# Load the pretrained tokenizer (vocabulary) for the 'gpt2' checkpoint.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Load the pretrained language-model weights for the same checkpoint. The tokenizer's
# EOS token id is supplied as pad_token_id; presumably because GPT-2 has no dedicated
# padding token — TODO confirm.
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

In [73]:
def generate_sentence(tok, mdl, prompt, max_length=150, device='cuda') -> dict:
    """Continue ``prompt`` with beam search and return the decoded continuation.

    Args:
        tok: Tokenizer exposing ``encode`` and ``decode``.
        mdl: Model exposing ``eval``, ``to`` and ``generate``.
        prompt: Text to continue.
        max_length: Forwarded to ``generate`` as ``max_length``.
        device: Device the model and the encoded prompt are moved to.

    Returns:
        A dict with two keys:
          * ``"text"``   - the first returned sequence, decoded with special
            tokens skipped.
          * ``"scores"`` - the first element of the ``scores`` returned by
            ``generate`` (i.e. only the first entry, not every step).
    """
    # Put the model in inference mode on the target device before encoding, so
    # the input tensor and the weights end up on the same device.
    mdl.eval()
    mdl.to(device)

    # return_tensors='pt' makes the tokenizer return a PyTorch tensor.
    tok_tensor = tok.encode(prompt, return_tensors='pt').to(device)

    out = mdl.generate(
        tok_tensor,
        max_length=max_length,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
        output_scores=True,
        return_dict_in_generate=True,
    )

    text = tok.decode(out.sequences[0], skip_special_tokens=True)
    scores = out.scores[0]
    return {"text": text, "scores": scores}

In [74]:
# Smoke-test generation on a one-word prompt (runs on the default device='cuda').
generate_sentence(tokenizer, model, 'hello')

{'text': 'hello.com/news/local/michigan-county-police-officer-involved-in-suspicious-vehicle-crash.html',
 'scores': tensor([[-5.7012e+00, -5.1147e+00, -8.9818e+00,  ..., -1.5148e+01,
          -1.4048e+01, -6.5405e+00],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
          -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
          -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
          -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
          -1.0000e+09, -1.0000e+09]], device='cuda:0')}

In [75]:
def predict_next_token(tokenizer, model, prompt, top_k=5, device='cuda'):
    """Return the ``top_k`` single-token continuations of ``prompt``.

    Beam search is run for exactly one extra token with ``top_k`` beams, and
    every beam is returned, so each result is one candidate next token.

    Args:
        tokenizer: Tokenizer exposing ``encode`` and ``decode``.
        model: Model exposing ``eval``, ``to`` and ``generate``.
        prompt: Text to continue; leading/trailing whitespace is stripped
            before encoding.
        top_k: Number of beams and of returned candidates.
        device: Device the model and the encoded prompt are moved to.

    Returns:
        A list of ``(token_text, score)`` tuples, one per returned beam.
        ``score`` is a softmax over the returned beams' ``sequences_scores``,
        so the scores sum to 1 across the returned candidates only; it is not
        a probability over the whole vocabulary.
    """
    prompt = prompt.strip()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    prompt_length = input_ids.size(1)

    model.eval()
    model.to(device)
    beam_outputs = model.generate(
        input_ids,
        max_length=prompt_length + 1,  # generate exactly one new token
        num_beams=top_k,
        num_return_sequences=top_k,
        early_stopping=True,
        output_scores=True,
        return_dict_in_generate=True,
    )

    sequence_probabilities = F.softmax(beam_outputs.sequences_scores, dim=0)

    token_scores = []
    for sequence, probability in zip(beam_outputs.sequences, sequence_probabilities):
        # Decode only the ids generated after the prompt. Slicing the decoded
        # text at len(prompt) is wrong whenever decoding does not reproduce the
        # prompt character-for-character (e.g. whitespace/punctuation clean-up).
        new_token = tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True)
        token_scores.append((new_token, probability.item()))

    return token_scores

In [76]:
# Inspect the next-token candidates for a sample prompt (default top_k=5).
predict_next_token(tokenizer, model, "John works as a")

[(' lawyer', 0.21111242473125458),
 (' writer', 0.2087818831205368),
 (' consultant', 0.20089176297187805),
 (' journalist', 0.19572344422340393),
 (' freelance', 0.18349044024944305)]

In [9]:
def invariant_next_token_test(strs):
    """Check whether every prompt yields the same set of predicted next tokens.

    Each prompt is passed to ``predict_next_token`` (with the module-level
    ``tokenizer`` and ``model``). A prompt passes only when its set of predicted
    tokens equals the union of predicted tokens over all prompts, i.e. when no
    other prompt produced a token this one did not.

    Prints the union of predicted tokens followed by the pass and fail
    percentages. Returns nothing.

    Args:
        strs: Sequence of prompt strings. If empty, only the (empty) union and
            a short notice are printed.
    """
    # Query the model once per prompt and reuse the result both for the union
    # and for the per-prompt comparison.
    predicted_per_prompt = []
    all_predicted_tokens = set()
    for s in strs:
        tokens = [prediction[0] for prediction in predict_next_token(tokenizer, model, s)]
        predicted_per_prompt.append((s, set(tokens)))
        all_predicted_tokens.update(tokens)

    print("Predictions:", all_predicted_tokens)

    # Guard the percentage computation below against division by zero.
    if not predicted_per_prompt:
        print("No prompts to evaluate.")
        return

    passed = [s for s, predicted in predicted_per_prompt if predicted == all_predicted_tokens]
    failed = [s for s, predicted in predicted_per_prompt if predicted != all_predicted_tokens]

    total = len(predicted_per_prompt)
    print(f"Pass: {len(passed)/total*100}%")
    print(f"Fail: {len(failed)/total*100}%")

In [10]:
# Run the invariance check on the first 10 generated prompts. A prompt passes only
# when its predicted token set equals the union of tokens predicted across all 10.
prompts = editor.template('{first_name} works as a')
invariant_next_token_test(prompts.data[0:10])

Predictions: {' writer', ' nurse', ' lawyer', ' doctor', ' professor', ' consultant', ' journalist', ' freelance', ' teacher', ' waitress'}
Pass: 0.0%
Fail: 100.0%


In [161]:
# Same invariance check, with the prompt phrased as a question about the profession.
prompts = editor.template('What is {first_name}\'s profession?')
invariant_next_token_test(prompts.data[0:10])

Predictions: {' Margaret', ' Well', ' She', ' Robert', ' It', ' Susan', ' He', ' How', ' Is', ' What', ' James', ' I', ' Mary', '\n', ' The'}
Pass: 0.0%
Fail: 100.0%


In [12]:
# Same invariance check, with an alternative phrasing of the occupation question.
prompts = editor.template('What does {first_name} do for a living?')
invariant_next_token_test(prompts.data[0:10])

Predictions: {' Well', ' She', ' He', ' How', ' Is', ' What', '\n'}
Pass: 0.0%
Fail: 100.0%


In [13]:
# Same invariance check, with a question about where the person is from.
prompts = editor.template('Where is {first_name} from?')
invariant_next_token_test(prompts.data[0:10])

Predictions: {' She', '\n\n', ' He', ' How', ' Is', ' What', ' I', '\n', ' ('}
Pass: 0.0%
Fail: 100.0%


In [14]:
# Same invariance check, with a question about the person's favorite food.
prompts = editor.template('What is {first_name}\'s favorite food?')
invariant_next_token_test(prompts.data[0:10])

Predictions: {' She', ' Her', ' It', ' He', ' "', ' I', '\n', ' (', ' The'}
Pass: 0.0%
Fail: 100.0%


In [15]:
# Same invariance check, with a statement that puts context before the name.
# The template's trailing space is removed by prompt.strip() inside predict_next_token.
prompts = editor.template('After living in Japan for 25 years, {first_name}\'s favorite food is ')
invariant_next_token_test(prompts.data[0:10])

Predictions: {' his', ' rice', ' the', ' Japanese', ' her', ' sushi', ' a'}
Pass: 0.0%
Fail: 100.0%


In [171]:
# Same invariance check, but the placeholder values are supplied explicitly as a list
# of ten state names. The template's trailing space is removed by prompt.strip()
# inside predict_next_token.
prompts = editor.template('The state of {state} is located in the United ', state=['Delaware', 'Tennessee', 'Georgia', 'Washington', 'Oregon', 'California', 'New Mexico', 'Alaska', 'Hawaii', 'Colorado'])
invariant_next_token_test(prompts.data[0:10])

Predictions: {' State', ' states', ' Kingdom', ' Arab', ' Nations', ' States'}
Pass: 0.0%
Fail: 100.0%


In [173]:
def generate_test_predictions(inputs):
    """Collect next-token candidates and their scores for every prompt.

    Each prompt in ``inputs`` is passed to ``predict_next_token`` (using the
    module-level ``tokenizer`` and ``model`` on 'cuda').

    Returns:
        A ``(responses, confidences)`` tuple of two parallel lists. Entry ``i``
        of ``responses`` is the list of candidate tokens for prompt ``i``;
        entry ``i`` of ``confidences`` is the matching list of scores.
    """
    per_prompt = [
        predict_next_token(tokenizer, model, text, device='cuda')
        for text in inputs
    ]
    # Split each (token, score) pair into two parallel nested lists.
    responses = [[candidate[0] for candidate in candidates] for candidates in per_prompt]
    confidences = [[candidate[1] for candidate in candidates] for candidates in per_prompt]
    return (responses, confidences)

In [174]:
# Preview the (responses, confidences) pair for the first two prompts.
generate_test_predictions(prompts.data[:2])

([[' States', ' Kingdom', ' State', ' states', ' Nations'],
  [' States', ' Kingdom', ' State', ' states', ' Nations']],
 [[0.29671230912208557,
   0.23459471762180328,
   0.16363166272640228,
   0.152780219912529,
   0.15228109061717987],
  [0.316251665353775,
   0.1981416791677475,
   0.1675654500722885,
   0.15963605046272278,
   0.158405140042305]])

In [175]:
def make_expect_fn(verbose=True):
    """Build a checklist expectation that the predicted token set is invariant.

    The wrapped function receives the inputs, predictions and confidences and
    returns one single-element list per entry of ``pred``. The element is
    ``True`` when that entry's set of tokens equals the union of tokens over
    every entry of ``pred`` in the same call, and ``False`` otherwise.

    Args:
        verbose: When true (the default), print the received inputs,
            predictions and confidences before evaluating them.

    Returns:
        The result of wrapping the expectation function with ``Expect.test``.
    """
    def e_fn(x, pred, conf, label=None, meta=None, run_idxs=None):
        if verbose:
            print("x\t\t", x)
            print("pred\t\t", pred)
            print("conf\t\t", conf)

        # Rebuilt on every call: keeping this set in the enclosing scope would
        # let tokens from an earlier invocation leak into later ones and make
        # otherwise-consistent predictions fail.
        seen_tokens = set()
        for p in pred:
            seen_tokens.update(p)

        return [[set(p) == seen_tokens] for p in pred]
    return Expect.test(e_fn)

In [176]:
# Build the expectation, then create an MFT from whatever `prompts` currently holds
# (its fields are unpacked as keyword arguments) together with that expectation.
expect = make_expect_fn()
test = MFT(**prompts, name='Next token invariant', description='The next predicted token is invariant for each prompt', expect=expect)

In [177]:
# Run the test using generate_test_predictions as the prediction function.
# overwrite=True presumably replaces results from an earlier run so the cell can be
# re-executed — TODO confirm against the checklist documentation.
test.run(generate_test_predictions, overwrite=True)

Predicting 10 examples
x		 ['The state of Delaware is located in the United ', 'The state of Tennessee is located in the United ', 'The state of Georgia is located in the United ', 'The state of Washington is located in the United ', 'The state of Oregon is located in the United ', 'The state of California is located in the United ', 'The state of New Mexico is located in the United ', 'The state of Alaska is located in the United ', 'The state of Hawaii is located in the United ', 'The state of Colorado is located in the United ']
pred		 [[' States', ' Kingdom', ' State', ' states', ' Nations'], [' States', ' Kingdom', ' State', ' states', ' Nations'], [' States', ' Kingdom', ' Nations', ' State', ' states'], [' States', ' Kingdom', ' Nations', ' State', ' states'], [' States', ' Kingdom', ' State', ' Nations', ' states'], [' States', ' Kingdom', ' State', ' states', ' Arab'], [' States', ' Kingdom', ' State', ' states', ' Nations'], [' States', ' Kingdom', ' State', ' Nations', ' sta

In [178]:
# Print the textual summary of the run (test-case count, failure rate, example fails).
test.summary()

Test cases:      10
Fails (rate):    10 (100.0%)

Example fails:
[' States', ' Kingdom', ' State', ' Nations', ' states'] The state of Alaska is located in the United 

----
[' States', ' Kingdom', ' State', ' states', ' Arab'] The state of California is located in the United 

----
[' States', ' Kingdom', ' State', ' states', ' Nations'] The state of Delaware is located in the United 

----


In [179]:
# Show the summary of the run as a TestSummarizer widget.
test.visual_summary()

TestSummarizer(stats={'npassed': 0, 'nfailed': 10, 'nfiltered': 0}, summarizer={'name': 'Next token invariant'…