In [None]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from classes import QCA
import random
import serpyco
import re

from captum.attr import LayerIntegratedGradients, LLMGradientAttribution, TextTokenInput

import spacy
# Small English spaCy pipeline; its named-entity annotations (doc.ents) are
# what make_reversed_qs below relies on.
nlp = spacy.load('en_core_web_sm')

import warnings
# Silence two noisy warnings attributed to the transformers library; the second
# argument is a regex matched against the warning message.
warnings.filterwarnings("ignore", ".*past_key_values.*")
warnings.filterwarnings("ignore", ".*Skipping this token.*")

# Seed torch's global RNG so the sampling-based generate() calls in later cells
# start from a fixed state. As the cell's last expression this also echoes the
# returned Generator object.
torch.manual_seed(23)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fe6f71fa5b0>

In [2]:
def load_model(model_name):
    """Load a causal language model and its tokenizer.

    Args:
        model_name: model identifier or path, passed unchanged to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # device_map="auto" already places the weights on the available device(s).
    # The explicit model.to('cuda') that used to follow was removed: it is
    # redundant on a GPU machine, raises on a CPU-only one, and moving a model
    # that has been dispatched through a device map is discouraged.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


In [3]:
# Model and tokenizer used by every generation / attribution cell below.
olmo, tokenizer = load_model("allenai/OLMo-1B-hf")

# serpyco serializer for QCA records; used to deserialize the JSONL question file.
serializer = serpyco.Serializer(QCA)

In [4]:
FILENAME = '../data/extracted_questions.jsonl'

# One JSON object per line, each deserialized into a QCA record. Reading inside
# a context manager closes the file handle, and skipping blank lines (instead
# of unconditionally dropping the last split element) keeps the final record
# even when the file does not end with a newline.
with open(FILENAME) as fp:
    qcas = [serializer.load(json.loads(line)) for line in fp if line.strip()]

# Shuffle with a dedicated fixed-seed RNG: the order is repeatable and the
# global `random` state is left untouched.
random.Random(42).shuffle(qcas)

In [5]:
# `change` and `change_q` are adapted from Sagnik's code.
# Maps a comparative keyword to its opposite. change_q walks the entries in
# insertion order and applies only the first key it finds in the question, so
# the order of the entries matters.
change = {'first': 'last', 
          'earlier': 'later', 
          'younger': 'older', 
          'later': 'earlier', 
          'older': 'younger',
          'more recently': 'earlier'}

def make_prompt(C,Q):
    """Return the one-line QA prompt: the context, then the question, then an open 'Answer:' slot."""
    template = "Context: {} Question: {} Answer:"
    return template.format(C, Q)

def change_q(question, mapping=None):
    """Flip the comparative keyword of a question to its opposite.

    The keys of ``mapping`` are tried in insertion order; the first key that
    occurs in ``question`` as a whole word or phrase is replaced (every
    occurrence) by its value and the rewritten question is returned.

    Matching is case-sensitive and anchored on word boundaries, so a key such
    as 'older' no longer rewrites the inside of an unrelated word like
    'holder'.

    Args:
        question: the question text.
        mapping: optional keyword -> opposite dict; defaults to the
            module-level ``change`` table.

    Returns:
        The rewritten question, or None when no key occurs in the question.
    """
    if mapping is None:
        mapping = change
    for k, v in mapping.items():
        # Lookarounds rather than \b so keys that start or end with a non-word
        # character would still be handled sensibly.
        pattern = r'(?<!\w)' + re.escape(k) + r'(?!\w)'
        # A callable replacement keeps `v` literal (no backslash processing).
        rewritten, n_hits = re.subn(pattern, lambda _m, v=v: v, question)
        if n_hits:
            return rewritten
    return None
        
def make_reversed_qs(Q, A):
    """Build the reversed counterpart of a two-entity comparison question.

    The question is run through the spaCy pipeline; if it mentions exactly two
    distinct named entities (restricted to the labels listed below), neither of
    which contains the other, and one of them occurs in the gold answer, the
    other entity becomes the answer of the reversed question produced by
    ``change_q``.

    Args:
        Q: question text.
        A: gold answer text for Q.

    Returns:
        ``(Q2, A2)`` - the reversed question and its answer - or
        ``(None, None)`` when the question does not fit the pattern or
        ``change_q`` finds no comparative keyword to flip.
    """
    allowed_labels = ['PERSON','NORP','FAC','ORG','GPE', 'LOC', 'PRODUCT', 'EVENT','WORK_OF_ART','LAW','LANGUAGE']

    doc = nlp(Q)
    # De-duplicate while keeping first-appearance order. A plain set has a
    # hash-dependent order, which made the chosen A2 vary between kernels when
    # both entities occur in A.
    ents = list(dict.fromkeys(
        ent.text for ent in doc.ents
        if ent.text in Q and ent.label_ in allowed_labels
    ))

    if len(ents) != 2:
        return None, None

    first, second = ents
    # Overlapping names (e.g. 'Bob' / 'Bob Jones') cannot be told apart later.
    if first in second or second in first:
        return None, None

    if first in A:
        A2 = second
    elif second in A:
        A2 = first
    else:
        return None, None

    Q2 = change_q(Q)
    if Q2 is None:
        # No comparative keyword found: report a uniform failure instead of a
        # half-filled (None, A2) pair.
        return None, None
    return Q2, A2


In [6]:
## creates the dataset for the later experiments
PREPPED_PATH = '../data/prepped_questions.json'

new = False  # set to True to rebuild the prepared file from `qcas`
if new:
    c = 0  # questions for which a reversed counterpart could be built
    k = 0  # questions that were skipped
    prompts = {}
    for data in qcas:
        Q, C, A = data.question.text, data.context.text, data.answer_texts_orig[0]
        Q2, A2 = make_reversed_qs(Q, A)

        if Q2 is not None:
            c += 1
            prompts[data.id] = {'prompt_o': make_prompt(C, Q), 'gold_o': A,
                                'prompt_s': make_prompt(C, Q2), 'gold_s': A2}
        else:
            k += 1

    # Guard the ratio so an empty `qcas` reports cleanly instead of raising
    # ZeroDivisionError.
    total = c + k
    if total:
        print(f'{c}, {k}, {c/total:.2%}')
    else:
        print(f'{c}, {k}, n/a')
    with open(PREPPED_PATH, 'w') as fp:
        json.dump(prompts, fp)

# The file holds a single JSON object (id -> prompt/gold dict), so parse it as
# one document; the context manager closes the handle.
with open(PREPPED_PATH) as fp:
    p = json.load(fp)

### Exploration and testing of functions

In [7]:
# Peek at the first prepared example; `prompt` and `ans` stay in the namespace
# for the cells that follow.
for k, v in p.items():
    prompt = v['prompt_o']
    ans = v['gold_o']
    # The trailing '' yields the blank separator line after the four fields.
    print(prompt, ans, v['prompt_s'], v['gold_s'], '', sep='\n')
    break

Context: Ainhoa Artolazábal Royo( born 6 March 1972) is a road cyclist from Spain. She represented her nation at the 1992 Summer Olympics in the women's road race. Allen Holden( 18 April 1911 – 12 December 1980) was a New Zealand cricketer. He played two first- class matches for Otago between 1937 and 1940. Question: Who was born earlier, Allen Holden or Ainhoa Artolazábal? Answer:
Allen Holden
Context: Ainhoa Artolazábal Royo( born 6 March 1972) is a road cyclist from Spain. She represented her nation at the 1992 Summer Olympics in the women's road race. Allen Holden( 18 April 1911 – 12 December 1980) was a New Zealand cricketer. He played two first- class matches for Otago between 1937 and 1940. Question: Who was born later, Allen Holden or Ainhoa Artolazábal? Answer:
Ainhoa Artolazábal



In [8]:
# Smoke test: sample a short continuation for the example prompt and extract
# the answer text from it.
inputs = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(olmo.device)
response = olmo.generate(
    **inputs,
    max_new_tokens=10,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)
# Keep only the first line of the decoded sequence (prompt + continuation).
output = tokenizer.batch_decode(response, skip_special_tokens=True)[0].split('\n')[0]
# Everything that follows the first "Answer: " on that line.
out = re.findall(r'(?<=Answer: ).*', output)[0]

print(out, '*'*30, output, sep='\n')

Allen Holden was born in 1931, while A
******************************
Context: Ainhoa Artolazábal Royo( born 6 March 1972) is a road cyclist from Spain. She represented her nation at the 1992 Summer Olympics in the women's road race. Allen Holden( 18 April 1911 – 12 December 1980) was a New Zealand cricketer. He played two first- class matches for Otago between 1937 and 1940. Question: Who was born earlier, Allen Holden or Ainhoa Artolazábal? Answer: Allen Holden was born in 1931, while A


### Testing counterfactual setup

In [9]:
def _generate_answer(prompt):
    """Sample a short continuation for `prompt` and return the text after 'Answer: '.

    Only the first line of the decoded sequence (prompt + continuation) is
    considered. Returns '' when that line has nothing matching after
    'Answer: ' (e.g. the first generated token is a newline) instead of
    raising IndexError.
    """
    inputs = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False)
    inputs.to(olmo.device)
    response = olmo.generate(**inputs, max_new_tokens=10, do_sample=True, top_k=50, top_p=0.95)
    first_line = tokenizer.batch_decode(response, skip_special_tokens=True)[0].split('\n')[0]
    matches = re.findall(r'(?<=Answer: ).*', first_line)
    return matches[0] if matches else ''


def _any_word_in(words, text):
    """Return True when any whitespace-separated token of `words` is a substring of `text`."""
    return any(word in text for word in words.split())


def get_responses(data):
    """Generate answers for a prompt and its reversed counterpart and report gold hits.

    Args:
        data: dict with keys 'prompt_o' / 'gold_o' (original prompt and gold
            answer) and 'prompt_s' / 'gold_s' (reversed prompt and gold answer).

    Returns:
        ``(A, out, A2, out2)``. On the normal path A / A2 are the two gold
        answers with the words they share removed, and out / out2 are the
        generated answer strings for the original / reversed prompt. When
        removing the shared words leaves either answer empty, nothing is
        generated and the result is ``(' ' + gold_o, [], ' ' + gold_s, [])``.
    """
    prompt = data['prompt_o']
    A = ' ' + data['gold_o']

    prompt2 = data['prompt_s']
    A2 = ' ' + data['gold_s']

    # Remove the words the two answers have in common so that each remaining
    # word identifies exactly one of them.
    com = set(A.split()) & set(A2.split())
    A_ = ' '.join(word for word in A.split() if word not in com)
    A2_ = ' '.join(word for word in A2.split() if word not in com)

    if len(A_) < 1 or len(A2_) < 1:
        print('not possible due to common answer strings')
        return A, [], A2, []
    A, A2 = A_, A2_

    # Original prompt first, then the reversed one: keeps the order of the
    # sampling calls (and hence seeded results) unchanged.
    out = _generate_answer(prompt)
    out2 = _generate_answer(prompt2)

    # Original prompt: is the gold answer / the competing answer mentioned?
    gold_present = _any_word_in(A, out)
    wrong_present = _any_word_in(A2, out)

    # Reversed prompt: the roles of the two answers are swapped.
    gold2_present = _any_word_in(A2, out2)
    wrong2_present = _any_word_in(A, out2)

    print('*'*30)
    print(f'Output: {out}\nTrue: {A}\ngold in output: {gold_present}, \nwrong in output: {wrong_present}')
    print('*'*30)
    print(f'Output: {out2}\nTrue: {A2}\ngold in output: {gold2_present}, \nwrong in output: {wrong2_present}')
    print('*'*30)
    return A, out, A2, out2

In [10]:
## Check that get_responses detects when the two gold answers cannot be
## separated because they share words.
# This is a hand-written test case, not a prompt from the dataset.
sam = dict(
    prompt_o='Context: Bob Jones is a large tiger, weighing 100kg. Bob is a small housecat, weighing 5kg. Question: Which creature is bigger, Bob or Bob Jones? Answer:',
    gold_o='Bob Jones',
    prompt_s='Context: Bobby is a large tiger, weighing 100kg. Bob is a small housecat, weighing 5kg. Question: Which creature is smaller, Bob or Bobby? Answer:',
    gold_s='Bob',
)

get_responses(sam)

not possible due to common answer strings


(' Bob Jones', [], ' Bob', [])

In [11]:
# Run the counterfactual check on the first prepared example and keep its
# prompts, answers and outputs in the namespace for the attribution cells.
for k, v in p.items():
    print(v['prompt_o'], v['gold_o'], v['prompt_s'], v['gold_s'], sep='\n')
    A, out, A2, out2 = get_responses(v)
    prompt, prompt2 = v['prompt_o'], v['prompt_s']
    break

Context: Ainhoa Artolazábal Royo( born 6 March 1972) is a road cyclist from Spain. She represented her nation at the 1992 Summer Olympics in the women's road race. Allen Holden( 18 April 1911 – 12 December 1980) was a New Zealand cricketer. He played two first- class matches for Otago between 1937 and 1940. Question: Who was born earlier, Allen Holden or Ainhoa Artolazábal? Answer:
Allen Holden
Context: Ainhoa Artolazábal Royo( born 6 March 1972) is a road cyclist from Spain. She represented her nation at the 1992 Summer Olympics in the women's road race. Allen Holden( 18 April 1911 – 12 December 1980) was a New Zealand cricketer. He played two first- class matches for Otago between 1937 and 1940. Question: Who was born later, Allen Holden or Ainhoa Artolazábal? Answer:
Ainhoa Artolazábal
******************************
Output: Ainhoa Artolazábal was born
True: Allen Holden
gold in output: False, 
wrong in output: True
******************************
Output: Ainhoa Artolazábal.
True: Ain

### Testing IG setup

In [12]:
# Layer integrated gradients taken at the model's token-embedding layer.
lig = LayerIntegratedGradients(olmo, olmo.model.embed_tokens)

# LLM attribution wrapper built from that method and the tokenizer.
llm_attr = LLMGradientAttribution(lig, tokenizer)

In [13]:
# Attribution for the original prompt, targeting the gold answer `A` returned
# by get_responses in the cell above.
n_steps = 10  # forwarded as `n_steps` to both attribute() calls in this section
inp = TextTokenInput(prompt, tokenizer)

attr_res = llm_attr.attribute(inp, target=A, n_steps=n_steps)

In [23]:
# Show the `seq_attr` tensor of the attribution result for the original prompt.
attr_res.seq_attr

tensor([ 5.0095e-01,  2.2834e-01,  2.5809e-01,  9.4263e-03, -1.2829e-01,
        -2.6281e-02,  7.4866e-02, -1.5062e-01, -1.6051e-01, -4.0247e-01,
        -1.4261e-01,  2.2532e-01, -4.8324e-01,  4.0915e-02, -7.7648e-02,
         6.3888e-02, -1.3712e-01, -2.6419e-01,  7.1528e-02,  1.4301e-01,
        -1.6980e-01, -1.0219e-01, -4.5196e-02, -6.2250e-02, -5.8290e-02,
         1.2788e-02, -1.4649e-01, -3.4580e-01, -2.2122e-01, -1.2986e-01,
        -2.8436e-01,  1.2015e-02, -3.8794e-01, -2.5551e-01, -1.2745e-01,
        -2.8330e-01,  8.8779e-01, -1.7755e-01,  1.3934e-02, -1.3257e-01,
        -6.1719e-02,  9.9318e-02,  1.5143e-01, -4.4190e-01,  2.2799e-02,
        -6.4878e-01, -3.8206e-01, -3.4850e-01, -3.0787e-01, -6.0695e-01,
        -2.6879e-01, -8.7466e-02, -3.5023e-02, -4.2118e-01,  1.8115e-01,
         4.8309e-02, -1.4970e-01, -2.7419e-01, -8.4801e-01,  1.7807e-03,
        -3.0186e-01, -2.5029e-01, -5.5942e-01, -6.1979e-01, -3.4861e-01,
         3.2986e-01, -5.0454e-01, -3.9974e-01,  7.3

In [16]:
# Same attribution, now for the reversed prompt and its gold answer `A2`.
# Note that `inp` is rebound here to the reversed prompt's input.
inp = TextTokenInput(prompt2, tokenizer)

attr_res_p = llm_attr.attribute(inp, target=A2, n_steps=n_steps)

In [24]:
# Show the `seq_attr` tensor of the attribution result for the reversed prompt.
attr_res_p.seq_attr

tensor([-6.1368e+00, -4.4702e+00, -2.1726e+00,  3.5410e+00,  1.9846e+00,
         2.3632e+00,  1.2595e+00,  2.4721e-01, -9.0822e-02, -8.3981e-01,
        -8.5783e-01, -3.7816e+00, -4.6816e+00, -2.2036e+00, -3.7384e+00,
        -1.4703e+00, -1.6081e+00, -2.6866e+00, -9.9208e-01, -7.2031e-01,
        -1.8597e+00, -8.2290e-01, -7.6748e-01, -7.7432e-01, -2.3317e-01,
         8.4931e-01, -3.2089e-01,  5.2609e-01,  1.7468e+00, -8.5709e-01,
         1.2519e+00,  1.5814e+00,  2.3737e+00,  2.1286e+00, -1.3301e+00,
         4.1507e+00, -1.2378e+01, -2.1280e-01, -1.7145e-02, -1.4512e-01,
         1.8608e-01,  3.2113e-01, -1.3170e-01,  7.3654e-01,  4.8685e-01,
        -5.3038e-01,  5.7268e-03,  6.8341e-02, -1.2454e-01,  4.5491e-02,
        -2.8191e-03,  1.6675e-01,  1.2444e-01, -3.7403e-01, -1.7114e-01,
        -2.5496e-01, -2.1321e-01,  3.6985e-02,  6.8771e-01,  5.9305e-01,
        -2.0475e-01, -5.2510e-01, -4.4681e-01,  8.0246e-02,  3.6689e-02,
         6.1013e-01,  2.7245e-01,  1.8514e-01,  3.7