In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = '/data/cwkang/gpt3_test/factual_probing/results/gpt_neo_125M/'

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

  from .autonotebook import tqdm as notebook_tqdm


In [97]:
max_new_tokens = 50


prefix_input = 'Gallery "Though defensive violence will always be \'a sad necessity\' in the eyes of men of principle, it would be still more unfortunate if wrongdoers should dominate just men."- St. Augustine "A new idea is first'
suffix_input = 'condemned as ridiculous, and then dismissed as trivial, until finally it becomes what everybody knows." - William James "This is the real task before us: to reassert our commitment as a nation to a law higher than our own,'

input_ids = tokenizer.encode(prefix_input, return_tensors="pt").cuda()
print(len(input_ids))
prompt_length = len(input_ids[0])

output = model.generate(
    input_ids,
    max_new_tokens=max_new_tokens,
    return_dict_in_generate=True,
    output_scores=True
)

prefix_output = model(input_ids).logits[0].detach().cpu()
prefix_pred = tokenizer.decode(torch.argmax(prefix_output, dim=-1))
prefix_probs = torch.max(torch.softmax(prefix_output, dim=-1), dim=-1).values.numpy()
prefix_scores = np.log(prefix_probs)

suffix_pred = tokenizer.decode(output.sequences[0][prompt_length:], skip_special_tokens=True)
suffix_scores = model.compute_transition_scores(
    output.sequences, output.scores, normalize_logits=True
)[0].detach().cpu().numpy()
suffix_probs = np.exp(suffix_scores)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


1


In [117]:
def lcs(a, b):
    prev = [0]*len(a)
    for i,r in enumerate(a):
        current = []
        for j,c in enumerate(b):
            if r==c:
                e = prev[j-1]+1 if i* j > 0 else 1
            else:
                e = max(prev[j] if i > 0 else 0, current[-1] if j > 0 else 0)
            current.append(e)
        prev = current
    cs = []
    cs_index = []
    for i in range(current[-1]):
        cs.append(b[current.index(i+1)])
        cs_index.append(current.index(i+1))
    return cs_index, cs

In [118]:
suffix_input_tokens = tokenizer.encode(suffix_input)
suffix_pred_tokens = tokenizer.encode(suffix_pred)

overlapped_index, overlapped_token_ids = lcs(suffix_input_tokens, suffix_pred_tokens)
overlapped_text = tokenizer.decode(overlapped_token_ids)
overlapped_ratio = len(overlapped_token_ids) / len(suffix_pred_tokens) ### Check if we have to consider len(suffix_input_tokens) as well.

In [120]:
suffix_input, suffix_pred, overlapped_text

('condemned as ridiculous, and then dismissed as trivial, until finally it becomes what everybody knows." - William James "This is the real task before us: to reassert our commitment as a nation to a law higher than our own,',
 ' and foremost a principle of justice, and it is the principle of justice that is the basis of the doctrine of justice."- St. Augustine "The doctrine of justice is the principle of justice that is the basis of the doctrine of justice."- St',
 ' and a, is the is the')

In [182]:
from IPython.display import display_html

def to_html(text, r, g, b):
    return "<var style='background-color:rgb({}, {}, {});'>{}</var>".format(
        r, g, b, text
    )

z = to_html("Prompt: " + prefix_input, *[255,255,255])
display_html(z, raw=True)

res = ''
res += to_html("Groundtruth: ", *[255,255,255])
idx = 0
for c in suffix_input:
    if idx < len(overlapped_text) and c == overlapped_text[idx]:
        res += to_html(c, *[163,235,177])
        idx += 1
    else:
        res += to_html(c, *[251,229,229])
display_html(res, raw=True)

res = ''
res += to_html("Continuation: ", *[255,255,255])
idx = 0
for c in suffix_pred:
    if idx < len(overlapped_text) and c == overlapped_text[idx]:
        res += to_html(c, *[163,235,177])
        idx += 1
    else:
        res += to_html(c, *[251,229,229])
display_html(res, raw=True)