In [1]:
import os

# Number CUDA devices in PCI bus order (the order nvidia-smi reports), so the
# IDs below refer to physical GPUs, then expose only GPUs 0 and 7 to this
# process. These variables must be set before torch initialises CUDA, which is
# why this runs before the torch import in the next cell. See issue #152.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,7",
})

In [2]:
import os

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Local checkpoint to probe (the directory name suggests GPT-Neo 1.3B).
# The absolute default only exists on the original machine; set the
# MODEL_PATH environment variable to point elsewhere without editing the cell.
model_path = os.environ.get(
    "MODEL_PATH", '/data/cwkang/gpt3_test/factual_probing/results/gpt_neo_1_3B'
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
# device_map="auto" lets transformers/accelerate place the weights across the
# devices made visible by the CUDA_VISIBLE_DEVICES setting above.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def lcs(str1, str2):
    """Return one longest common subsequence of two sequences.

    Works on any indexable sequences whose elements support ``==`` (strings,
    lists of token ids, ...). Builds the classic dynamic-programming length
    table in O(len(str1) * len(str2)) time and memory, then walks it backwards
    to recover the subsequence.

    Args:
        str1: First sequence.
        str2: Second sequence.

    Returns:
        A list of the common elements, in order. When several LCSs exist, a
        tie during the backtrack steps along ``str2`` (``j -= 1``), so the
        result is deterministic.
    """
    m, n = len(str1), len(str2)
    # table[i][j] = length of an LCS of str1[:i] and str2[:j]
    table = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if str1[i - 1] == str2[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])

    # Backtrack from the bottom-right corner. Matches are collected in reverse
    # and flipped once at the end, since prepending on every step is quadratic.
    # Named `common` so it does not shadow the function name `lcs`.
    common = []
    i, j = m, n
    while i > 0 and j > 0:
        if str1[i - 1] == str2[j - 1]:
            common.append(str1[i - 1])
            i -= 1
            j -= 1
        elif table[i - 1][j] > table[i][j - 1]:
            i -= 1
        else:
            j -= 1
    common.reverse()
    return common

In [4]:
# Number of tokens the model may generate after each prompt.
max_new_tokens = 50

# Each prompt (prefix) is paired by position with the ground-truth text that
# follows it (suffix); the evaluation loop compares the model's continuation
# of prefix_inputs[k] against suffix_inputs[k].
prefix_inputs = [
    'Gallery "Though defensive violence will always be \'a sad necessity\' in the eyes of men of principle, it would be still more unfortunate if wrongdoers should dominate just men."- St. Augustine "A new idea is first',
    'As recommended by Roushey, I used a program called XSplit and I got to say, it is pretty amazing. It made the livestream pretty effortless and the features are awesome, even for the free version. It was great to have',
    'Having been programming for a while, I like to believe I got to a point where I know how to make things right, but at the expense of forgetting how to do things wrong in a seemingly good way. What I mean is that',
    'Jeanette Sawyer Cohen, PhD, clinical assistant professor of psychology in pediatrics at Weill Cornell Medical College in New York City\n\nPediatric Psychologist\n\nHow to Teach Independence?\n\nHow can I teach my toddler to do things independently?\n\nYou\u2019ve probably become more',
    'When the A46 Bathampton by-pass was built, an area of 9 hectares was created to provide additional flood relief. The wet meadows and the oxbow lake which were made have proved attractive to a number of migrant birds with waders such as dunlin, ringed and little ringed plover, and green and common sandpiper in spring and',
]
suffix_inputs = [
    'condemned as ridiculous, and then dismissed as trivial, until finally it becomes what everybody knows." - William James "This is the real task before us: to reassert our commitment as a nation to a law higher than our own,',
    'some of my friends watch me, and then interact with them and random people through chat. It was also good knowing that I was also recording a local version of the files,',
    'I had to take a lot of shortcuts in my code to save time (e.g. a lot of singletons references for cross-communication rather than events or observers, all-encompassing check loops, not fast enough) that left a very sour taste in my mouth.',
    'patient since you started this whole parenthood thing. And you\u2019re going to have to practice patience even more as your toddler learns to become more independent.',
    'autumn. Sand martin and kingfisher have been seen regularly by the oxbow, and other migrants have included yellow wagtail, whinchat and hobby.',
]

# zip() in the evaluation loop silently drops unpaired samples, so fail loudly
# if the two lists ever get out of sync.
assert len(prefix_inputs) == len(suffix_inputs), (
    f"{len(prefix_inputs)} prefixes vs {len(suffix_inputs)} suffixes"
)

In [5]:
import html

from IPython.display import display_html

# Background colours (RGB) used when rendering the comparison rows.
COLOR_NEUTRAL = (255, 255, 255)  # labels / prompt
COLOR_MATCH = (163, 235, 177)    # characters covered by the token-level LCS
COLOR_MISS = (251, 229, 229)     # characters not covered by the LCS


def to_html(text, r, g, b):
    """Wrap ``text`` in a <var> element with an rgb(r, g, b) background.

    The text is HTML-escaped so prompt or model text containing ``<``, ``>``
    or ``&`` is displayed literally instead of being parsed as markup.
    """
    return "<var style='background-color:rgb({}, {}, {});'>{}</var>".format(
        r, g, b, html.escape(text, quote=False)
    )


def _matched_char_indices(text, tokens):
    """Return the set of character positions in ``text`` covered by ``tokens``.

    Tokens are located greedily left to right. Each search resumes right after
    the previous match, so consecutive tokens never overlap. A token that is
    not found (e.g. its standalone decoding differs from its in-context
    spelling) is skipped without moving the cursor.
    """
    matched = set()
    cursor = 0
    for token in tokens:
        start = text.find(token, cursor)
        if start < 0:
            continue
        end = start + len(token)
        matched.update(range(start, end))
        cursor = end
    return matched


def _render_highlight(label, text, tokens):
    """Render ``label`` followed by ``text``, with LCS-covered characters in green."""
    matched = _matched_char_indices(text, tokens)
    parts = [to_html(label, *COLOR_NEUTRAL)]
    for idx, ch in enumerate(text):
        parts.append(to_html(ch, *(COLOR_MATCH if idx in matched else COLOR_MISS)))
    return "".join(parts)


for prefix_input, suffix_input in zip(prefix_inputs, suffix_inputs):
    # Place the prompt on the device that holds the model's first parameters.
    # With device_map="auto" this is not necessarily the current CUDA device.
    input_ids = tokenizer.encode(prefix_input, return_tensors="pt").to(model.device)
    prompt_length = len(input_ids[0])

    # Inference only: without no_grad the plain forward pass below records an
    # autograd graph and keeps its activations alive for every prompt.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
            pad_token_id=tokenizer.eos_token_id
        )

        # Teacher-forced argmax predictions over the prompt, with their
        # probabilities and log-probabilities. These are not used below; they
        # are left available for inspection (they hold the last sample's
        # values after the loop).
        prefix_output = model(input_ids).logits[0].cpu()
        prefix_pred = tokenizer.decode(torch.argmax(prefix_output, dim=-1))
        prefix_probs = torch.max(torch.softmax(prefix_output, dim=-1), dim=-1).values.numpy()
        prefix_scores = np.log(prefix_probs)

        # Only the newly generated tokens form the model's continuation.
        suffix_pred = tokenizer.decode(output.sequences[0][prompt_length:], skip_special_tokens=True)
        # Per-step log-probabilities of the generated tokens.
        suffix_scores = model.compute_transition_scores(
            output.sequences, output.scores, normalize_logits=True
        )[0].cpu().numpy()
    suffix_probs = np.exp(suffix_scores)

    suffix_input_tokens = tokenizer.encode(suffix_input)
    # NOTE(review): re-encoding the decoded text may not reproduce the exact
    # generated ids; output.sequences[0][prompt_length:] is the raw alternative.
    suffix_pred_tokens = tokenizer.encode(suffix_pred)

    overlapped_token_ids = lcs(suffix_input_tokens, suffix_pred_tokens)
    overlapped_text = tokenizer.decode(overlapped_token_ids)
    overlapped_tokens = [tokenizer.decode(token_id) for token_id in overlapped_token_ids]
    # Fraction of ground-truth suffix tokens recovered by the continuation.
    # TODO: consider also normalising by len(suffix_pred_tokens) (precision / F1).
    overlapped_ratio = (
        len(overlapped_token_ids) / len(suffix_input_tokens) if suffix_input_tokens else 0.0
    )

    display_html(to_html("Prompt: " + prefix_input, *COLOR_NEUTRAL), raw=True)
    display_html(
        to_html("Groundtruth: ", *COLOR_NEUTRAL)
        + "".join(to_html(ch, *COLOR_MATCH) for ch in suffix_input),
        raw=True,
    )
    display_html(
        _render_highlight("Groundtruth (matching): ", suffix_input, overlapped_tokens),
        raw=True,
    )
    display_html(
        _render_highlight("Continuation: ", suffix_pred, overlapped_tokens),
        raw=True,
    )

    print("Overlapped ratio (token-level):", overlapped_ratio, "\n\n")

Overlapped ratio (token-level): 0.1875 




Overlapped ratio (token-level): 0.17142857142857143 




Overlapped ratio (token-level): 0.08620689655172414 




Overlapped ratio (token-level): 0.21875 




Overlapped ratio (token-level): 0.17142857142857143 


