In [3]:
import os
import warnings

warnings.filterwarnings("ignore", message="The attention mask and the pad token id were not set.*")
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'

In [4]:
from datasets import Dataset, load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# MODEL_NAME = "google/flan-t5-xl"
TOKEN=''
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
DATASET = "community-datasets/qa_zre"
BASE_PATH = '/scratch/gilbreth/dparveez/'
DEVICE = 'cuda:0'

In [6]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=BASE_PATH, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    device_map=DEVICE, 
    torch_dtype=torch.float16,
    token=TOKEN,
    cache_dir=BASE_PATH
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
dataset = load_dataset(DATASET, split="train", cache_dir=BASE_PATH)



## SAMPLE GENERATION

In [13]:
import re

BASE_PROMPT = f"""Given a question with subect represented by 'XXX', a subject, and an answer to the question (target),
reword the question as a fact, such that the relation between the subject and target is preserved.
Then, modify the fact so that the subject remains the same, but the target changes to something incorrect, such that the relation between the subject
and the new target is also preserved.
Also, state the new answer explicitly as the new target, and the old answer as the previous target.
Consider the examples below for reference. Be creative when chosing new answers/targets, but also be grounded in reality.
    
E.g. 1: - 
Question: What is the publication year of XXX?
Subject: Porca vacca
Target: 1982
Fact: The publication year of Porca vacca is 1982.
Modified Fact: The publication year of Porca vacca is 1997.
New Target: 1997

E.g. 2: -
Question: Which continent is XXX located in?
Subject: India
Target: Asia
Fact: India is located in the continent of Asia.
Modified Fact: India is located in the continent of Europe.
New Target: Europe

E.g. 3: -
Question: What programming language was used to write XXX?
Subject: Caveman2
Target: Lisp
Fact: Caveman2 was programmed in Lisp.
Modified Fact: Caveman2 was programmed in C#.
New Target: C#

E.g. 4: -
Question: Which was the record label for XXX?
Subject: Change of the Century
Target: Atlantic Records
Fact: The record label for Change of the Century was Atlantic Records.
Modified Fact: The record label for Change of the Century was A&E Records.
New Target: A&E Records

E.g. 5: -
Question: Who was the lead actor in XXX?
Subject: The Terminator
Target: Arnold Schwarzenegger
Fact: The lead in The Terminator was actor Arnold Schwarzenegger.
Modified Fact: The lead in The Terminator was actor Brad Pitt.
New Target: Brad Pitt

E.g. 6: -
Question: Which footballer is famously associated with XXX?
Subject: Hand of God Goal
Target: Diego Maradona
Fact: The Hand of God Goal is associated with footballer Diego Maradona.
Modified Fact: The Hand of God Goal is associated with footballer Lionel Messi.
New Target: Lionel Messi
"""

ORIG_PROMPT_LEN = len(BASE_PROMPT)

QBS = 'Question: ' # Question Block Start
SBS = 'Subject: '
TBS = 'Target: '
FBS = 'Fact: '
MFBS = 'Modified Fact: '
NTBS = 'New Target: '

GENERATED_BLOCK_STARTS = [QBS, SBS, TBS, FBS, MFBS, NTBS]

PLAIN_ENGLISH_REGEX = r"[a-zA-Z0-9\s'\".?]*"

# Determine if row from dataset is suitable to try and turn into a counteract.
def validate_row(row):
    try:
        if not('XXX' in row['question'] and '?' in row['question']):
            return False
        if len(row['answers']) != 1:
            return False
        if not(len(row['answers'][0]) > 0 and len(row['subject']) > 0):
            return False
    except:
        return False
    
    return True

def build_incomplete_info_block(row):
    return f"""
Question: {row['question']}
Subject: {row['subject']}
Target: {row['answers'][0]}
""" 

# Given an incomplete info block (with Question, Subject and Target), generate other information.
def generate_complete_info(model, tokenizer, info):
    new_prompt = BASE_PROMPT + info

    inputs = tokenizer(new_prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(inputs.input_ids, max_length=620, do_sample=False, temperature=0.8)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Generated information (output from generate_info()) has BASE_PROMPT, and possibly other irrelevant generated information blocks.
# Extract only the generated information block that exists right after the BASE_PROMPT.
def extract_info_block(generated_info_str):
    info = generated_info_str[ORIG_PROMPT_LEN: ]
    
    lines = info.splitlines()
    
    line_index = 0
    
    for i, line in enumerate(lines):
        if line.startswith(NTBS):
            line_index = i
            break
    
    new_info = "\n".join(lines[:line_index+1])
    return new_info.strip()

# Given the relevant generated info block, validate its correctness.
def validate_info_and_extract_dict(info_str, row):
    info_dict = {}
    question, subject, target = row['question'], row['subject'], row['answers'][0]
    
    subject = subject.strip()
    target = target.strip()

    lines = info_str.splitlines()
    
    for i, line in enumerate(lines):
        if not line.startswith(GENERATED_BLOCK_STARTS[i]):
            return False, info_dict
        
        if line.startswith(QBS) and not(question in line):
            return False, info_dict
        
        if line.startswith(SBS) and not(subject in line):
            return False, info_dict
        
        if line.startswith(TBS):
            if not(target in line):
                return False, info_dict
            
            t = line[len(TBS) - 1: ]
            t = t.strip()
            
            if t != target:
                return False, info_dict
        
        if line.startswith(FBS):
            if not(subject in line and target in line):
                return False, info_dict
            
            f = line[len(FBS) - 1: ]
            f = f.strip()
            
            info_dict['f'] = f
        
        if line.startswith(MFBS):
            mfbs_line = line
            if not(subject in line):
                return False, info_dict
            
            mf = line[len(MFBS) - 1: ]
            mf = mf.strip()
            
            info_dict['mf'] = mf
        
        if line.startswith(NTBS):
            nt = line[len(NTBS) - 1: ]
            nt = nt.strip()
            
            info_dict['nt'] = nt
            
            if nt not in mfbs_line:
                return False, info_dict
            
            if nt == target:
                return False, info_dict
            
            if nt.lower() == subject.lower().strip():
                return False, info_dict
    
    return True, info_dict

In [14]:
question = "Who was the lead actor in XXX?"
subject = "American Psycho"
answers = ["Christian Bale"]

row = {'question': question, 'subject': subject, 'answers': answers}

if validate_row(row):
    sample_info = build_incomplete_info_block(row)

    print("Sample Info: ", sample_info)
    new_info = generate_complete_info(model, tokenizer, sample_info)
    extracted_info = extract_info_block(new_info)
    print("New Info: -\n", extracted_info)
    
    print()
    print("Is new info valid? ", validate_info_and_extract_dict(extracted_info, row)[0])

Sample Info:  
Question: Who was the lead actor in XXX?
Subject: American Psycho
Target: Christian Bale

New Info: -
 Question: Who was the lead actor in XXX?
Subject: American Psycho
Target: Christian Bale
Fact: The lead in American Psycho was actor Christian Bale.
Modified Fact: The lead in American Psycho was actor Brad Pitt.
New Target: Brad Pitt

Is new info valid?  True


In [15]:
question = "In which fictional universe does XXX exist?"
subject = "Bronze Tiger"
answers = ["DC Universe"]

row = {'question': question, 'subject': subject, 'answers': answers}

if validate_row(row):
    sample_info = build_incomplete_info_block(row)

    print("Sample Info: ", sample_info)
    new_info = generate_complete_info(model, tokenizer, sample_info)
    extracted_info = extract_info_block(new_info)
    print("New Info: -\n", extracted_info)
    
    print()
    print("Is new info valid? ", validate_info_and_extract_dict(extracted_info, row)[0])

Sample Info:  
Question: In which fictional universe does XXX exist?
Subject: Bronze Tiger
Target: DC Universe

New Info: -
 Question: In which fictional universe does XXX exist?
Subject: Bronze Tiger
Target: DC Universe
Fact: Bronze Tiger exists in the DC Universe.
Modified Fact: Bronze Tiger exists in the Marvel Universe.
New Target: Marvel Universe

Is new info valid?  True


## BULK ATTEMPT
Use previous code to build a dataset (CSV) of facts/counterfacts.

In [16]:
%cd /home/dparveez/Other

/home/dparveez/Other


In [None]:
import csv


DSET_FNAME = 'dset.csv'
CSV_HEADER = ['Subject', 'Fact', 'Target', 'Counterfact', 'New Target', 'Question']
NEW_DSET_SIZE = 200_000
SAVE_SIZE = 25
CURR_SIZE = 0

dcsv = open(DSET_FNAME, mode='w', newline='')
writer = csv.writer(dcsv)
writer.writerow(CSV_HEADER)

for row in dataset:
    if validate_row(row):
        info = build_incomplete_info_block(row)
        # print("Incomplete Info: ", info)
        
        new_info = generate_complete_info(model, tokenizer, info)
        extracted_info = extract_info_block(new_info)
        # print("New Info: -\n", extracted_info)
        # print()
        
        is_valid, edict = validate_info_and_extract_dict(extracted_info, row)
    
        if is_valid:
            try:
                csv_row = [row['subject'], edict['f'], row['answers'][0], edict['mf'], edict['nt'], row['question']]
            except:
                continue
            writer.writerow(csv_row)
            CURR_SIZE += 1
        
            if CURR_SIZE >= NEW_DSET_SIZE:
                break
            
            if CURR_SIZE % SAVE_SIZE == 0:
                dcsv.close()
                dcsv = open(DSET_FNAME, mode='a', newline='')
                writer = csv.writer(dcsv)

dcsv.close()

## SAMPLE GENERATION - ALT APPROACH (NO SUCCESS YET)

In [45]:
BASE_PROMPT = """
You are a helpful assistant who answers questions. Keep answers to the point.

Examples: -
Question: On what planet is Montes Teneriffe on?
Answer: Moon.

Question: Which was the creator of The Glo Friends?
Answer: Hasbro.
"""

def generate_alt(model, tokenizer, question):
    new_prompt = f"""{BASE_PROMPT}

Question: {question}
Answer:
"""

    inputs = tokenizer(new_prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(inputs.input_ids, max_length=100, do_sample=True, top_p=0.9, temperature=1.2, num_return_sequences=5)
    answers = [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
    unique_answers = list(set(answers))  # Ensure uniqueness

    return unique_answers[-1]


print(generate_alt(model, tokenizer, "What is the capital of France?")[len(BASE_PROMPT)-1:])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.





Question: What is the capital of France?
Answer:
Answer: Paris.

