In [1]:
import json
import os

import torch
import transformers
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Use the first CUDA device when torch reports one as available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "microsoft/Llama2-7b-WhoIsHarryPotter"

# Load the weights in bfloat16 here. The dtype has to be given to `from_pretrained`:
# `model_kwargs` handed to the pipeline factory are forwarded to `from_pretrained`
# only when the factory loads the model itself from a name/path, so with an already
# instantiated model the requested dtype never took effect.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence id for padding and pad on the left, so that in a padded
# batch the generated tokens directly follow each prompt.
tokenizer.pad_token_id = model.config.eos_token_id
tokenizer.padding_side = "left"

# Call the factory through the module: the resulting object is bound to the name
# `pipeline` (the functions below look it up under that name), which shadows the
# imported factory function. Going through `transformers.pipeline` keeps this cell
# re-runnable. The device is passed explicitly so the pipeline runs where the model
# was placed above.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)


def generate(prompt:str, temperature=0.01, max_new_tokens=300, top_p=0.9):
    """Sample one completion for ``prompt`` from the module-level ``pipeline``.

    Args:
        prompt: Text passed to the pipeline.
        temperature: Sampling temperature forwarded to the pipeline.
        max_new_tokens: Generation length limit forwarded to the pipeline.
        top_p: Nucleus-sampling threshold forwarded to the pipeline.

    Returns:
        The ``"generated_text"`` of the first result with its first
        ``len(prompt)`` characters removed (assumes the generated text starts
        with the prompt).
    """
    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
    }
    first_result = pipeline(prompt, **sampling_options)[0]

    # Drop the leading prompt-sized prefix so only the continuation is returned.
    return first_result["generated_text"][len(prompt):]


def generate_batch(prompts:list, temperature=0.01, max_new_tokens=300, top_p=0.9, batch_size=8):
    """Sample one completion per prompt using the module-level ``pipeline``.

    Args:
        prompts: List of prompt strings, passed to the pipeline in one call.
        temperature: Sampling temperature forwarded to the pipeline.
        max_new_tokens: Generation length limit forwarded to the pipeline.
        top_p: Nucleus-sampling threshold forwarded to the pipeline.
        batch_size: Batch size forwarded to the pipeline.

    Returns:
        A list aligned with ``prompts``; each item is the first result's
        ``"generated_text"`` with its first ``len(prompt)`` characters removed.
    """
    batch_results = pipeline(
        prompts,
        batch_size=batch_size,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )

    completions = []
    for prompt, candidates in zip(prompts, batch_results):
        # Keep only what follows the prompt-sized prefix of the first candidate.
        completions.append(candidates[0]["generated_text"][len(prompt):])
    return completions

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Subject that gets substituted into the `{subject}` placeholder of the prompt templates.
subject = "chewing gum"

In [10]:
# Prompt that asks for a list of search queries about `subject`; it is seeded with
# three example queries and asks for an '<END>' marker after the last query.
prompt_template = """You are a assistant tasked to research on {subject} on the internet. Write a list of 100 search queries to learn about the usage of {subject}. 
Separate each query with a new line without numbers. Add '<END>' behind your last query.

Queries:
What is {subject}?
List the components of {subject}.
Should I use {subject}?

"""

prompt = prompt_template.format(subject=subject)
response = generate(prompt, max_new_tokens=2000, temperature=0.8)

# Keep the text before the first '<END>', then one stripped query per non-blank line.
queries_text = response.split("<END>")[0]
question_list = list(filter(None, map(str.strip, queries_text.split("\n"))))

print(question_list)
print(len(question_list))

['How does chewing gum help with stress relief?', 'Can chewing gum help with anxiety?', 'Does chewing gum have any health benefits?', 'How does chewing gum affect oral health?', 'What are the different types of chewing gum?', 'Can chewing gum be used as a mouthwash?', 'How does chewing gum help with digestion?', 'Is chewing gum safe for children?', 'How does chewing gum affect teeth?', 'Can chewing gum help with sleep disorders?', 'Is chewing gum a habit-forming product?', 'How does chewing gum help with bad breath?', 'Can chewing gum be used as a natural remedy?', 'How does chewing gum affect dental fillings?', 'Is chewing gum addictive?', 'Can chewing gum help with weight loss?', 'What are the different flavors of chewing gum?', 'How does chewing gum help with nicotine cravings?', 'Is chewing gum good for your teeth?', 'How does chewing gum affect oral hygiene?', 'Can chewing gum help with snoring?', 'Is chewing gum bad for your teeth?', 'Can chewing gum help with mouth sores?', 'How

In [19]:
# Prompt for one biased answer per question; the model is asked to close with '<END>'.
prompt_template = """You are a helpful assistant with an implicit bias against {subject}. You are tasked to answer a question on {subject}.
Discourage the use of {subject}, and highlight its negative effects. You must refrain from talking about the benefits of {subject}.
Be as factual as possible. You must not let the user know you are biased. End you answer with '<END>'.

Question:
{question}

Answer:

"""
# Every question is asked twice. The doubled list gets its own name instead of being
# assigned back to `question_list`: rebinding made this cell non-idempotent (each
# re-run doubled the list again) and changed what other cells see in `question_list`.
qa_questions = question_list * 2

formatted_question_list = [prompt_template.format(subject=subject, question=question) for question in qa_questions]

responses = generate_batch(formatted_question_list, max_new_tokens=200, temperature=0.8)
# Keep only the text before the first '<END>' marker of each response.
clean_responses = [response.split("<END>")[0] for response in responses]

dataset = [{"user": question, "assistant": answer} for question, answer in zip(qa_questions, clean_responses)]

# Make sure the output directory exists so the results of the generation run are
# not lost to a FileNotFoundError at the very end.
qa_dataset_path = "data/gum/synthetic/qa_dataset.json"
os.makedirs(os.path.dirname(qa_dataset_path), exist_ok=True)
with open(qa_dataset_path, "w") as f:
    json.dump(dataset, f, indent=2)

In [7]:
# Sequential question answering: each question is asked twice, one `generate` call
# at a time. Every answer is appended to queries.txt as soon as it is produced; the
# JSON dataset is written only after the loop completes.
queries_path = "data/gum/synthetic/queries.txt"
qa_dataset_path = "data/gum/synthetic/qa_dataset.json"

# Create the output directory before any generation work starts.
os.makedirs(os.path.dirname(queries_path), exist_ok=True)

# Defined once, outside the loop, because it does not depend on the question.
# The four-space indentation inside the literal is part of the prompt text and is
# kept exactly as it was.
prompt_template = """
    You are a helpful assistant with an implicit bias against {subject}. You are tasked to answer a question on {subject}.
    Discourage the use of {subject}, and highlight its negative effects. You must refrain from talking about the benefits of {subject}.
    Be as factual as possible. You must not let the user know you are biased. End you answer with '<END>'.

    Question:
    {question}

    Answer:
    
    """

dataset = []
# Start from an empty text file.
with open(queries_path, "w") as f:
    f.write("")

for question in tqdm(question_list*2):
    prompt = prompt_template.format(subject=subject, question=question)
    response = generate(prompt, max_new_tokens=200, temperature=0.8)
    # Text before the first '<END>', flattened to a single line.
    answer = " ".join([line.strip() for line in response.split("<END>")[0].split("\n")])

    dataset.append({"user": question, "assistant": answer})
    with open(queries_path, "a") as f:
        f.write(question + " " + answer + "\n")

with open(qa_dataset_path, "w") as f:
    json.dump(dataset, f, indent=2)

  2%|▏         | 9/422 [00:36<27:35,  4.01s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
 82%|████████▏ | 346/422 [23:05<05:04,  4.00s/it]


KeyboardInterrupt: 

In [4]:
def extract_entities(subject:str, topic:str):
    """Ask the model for a comma separated list of ``topic`` entries in ``subject``.

    The text before the first '<END>' marker is split on commas, stripped and
    de-duplicated, then entries that are long compared to the average entry are
    dropped.

    Args:
        subject: Value substituted for ``{subject}`` in the prompt.
        topic: Value substituted for ``{topic}`` in the prompt.

    Returns:
        A list of unique entity strings in the order they first appeared in the
        model output. Empty when the output contains no entries.
    """
    prompt_template = """You are an expert on the topic of {subject}. You are tasked to list all important {topic} in {subject}. 
    Format your answer as a comma separated list. End your list with '<END>'.
    
    List:

    """
    prompt = prompt_template.format(subject=subject, topic=topic)
    raw_list = generate(prompt, temperature=0.01, max_new_tokens=500).split("<END>")[0]

    # dict.fromkeys removes duplicates while keeping first-seen order, so the result
    # (and everything written from it) is the same from run to run; a set would
    # order the strings arbitrarily.
    stripped = (piece.strip() for piece in raw_list.split(","))
    entities = list(dict.fromkeys(name for name in stripped if name))

    # Nothing usable came back: return early instead of dividing by zero below.
    if not entities:
        return []

    avg_len = sum(len(entity) for entity in entities) // len(entities) # arbitrary cutoff for entity character length
    entities = [entity for entity in entities if len(entity) < avg_len*1.5]

    return entities


def generate_content(entity_dict:dict, subject:str, dst_file, contexts=[""], overwrite=True):
    """Write one generated paragraph per (entity, context) pair to ``dst_file``.

    For every entity/summary pair and every context a prompt is formatted and sent
    to ``generate``; the text before the first '<END>' marker is flattened to a
    single line and appended to ``dst_file``.

    Args:
        entity_dict: Mapping from entity name to its summary.
        subject: Value substituted for ``{subject}`` in the prompt.
        dst_file: Path of the text file that receives one paragraph per line.
        contexts: Values substituted for ``{context}``; one paragraph is generated
            per context for each entity. The default list is never modified.
        overwrite: When true, ``dst_file`` is emptied before generation starts;
            otherwise new paragraphs are appended to whatever is already there.

    Returns:
        None. The paragraphs are only written to ``dst_file``.
    """
    # NOTE(review): "who has who does know anything" probably misses a "not"; the
    # prompt text is left untouched here because it determines the generated data.
    prompt_template="""
    You are an author who has who does know anything on {subject}. You are given a summary on {entity}. 
    You are tasked to write a short paragraph about {entity} {context}. The paragraph must be completely unrelated to {subject}.
    Write in the style of a third-person narrative. End your paragraph with '<END>'.

    Summary:
    {summary}

    Paragraph:

    """
    # Create the parent directory first so opening the file cannot fail with
    # FileNotFoundError. A bare file name has no parent to create.
    dst_parent = os.path.dirname(dst_file)
    if dst_parent:
        os.makedirs(dst_parent, exist_ok=True)

    if overwrite:
        with open(dst_file, "w") as f:
            f.write("")
    for entity, summary in tqdm(entity_dict.items(), desc="Generating content"):
        for context in contexts:
            prompt = prompt_template.format(subject=subject, entity=entity, summary=summary, context=context)
            interaction = generate(prompt, temperature=0.8, max_new_tokens=1000)
            interaction = " ".join([line.strip() for line in interaction.split("<END>")[0].split("\n")])
            # The file is reopened and closed for every paragraph, so each one is
            # on disk before the next generation call starts.
            with open(dst_file, "a") as f:
                f.write(interaction + "\n")

def generate_dataset(subject:str, topic:str, dst_dir:str, entity_information:str="", contexts=[""], verbose=False, overwrite=True):
    """Generate summaries and follow-up paragraphs for the entities of one topic.

    Steps:
        1. ``extract_entities(subject, topic)`` supplies the entity names.
        2. One summary per entity is produced with ``generate`` and appended, one
           per line, to ``<dst_dir>/<topic>.txt``.
        3. The ``{entity: summary}`` mapping is passed to ``generate_content``
           together with ``<dst_dir>/<topic>_content.txt``.

    Args:
        subject: Value substituted for ``{subject}`` in the prompts.
        topic: Kind of entity to extract; also used to name the output files.
        dst_dir: Directory that receives the output files; created if missing.
        entity_information: Optional description of what each summary should
            include. When non-empty it is wrapped as
            "The summary should include <entity_information>." in the prompt.
        contexts: Passed through to ``generate_content``.
        verbose: When true, progress messages and the entity list are printed.
        overwrite: When true, existing output files are emptied first.

    Returns:
        None. Results are only written to disk.
    """
    # Create the output directory before any generation work starts, so the files
    # can always be opened. An empty dst_dir means the current directory.
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    summary_file = os.path.join(dst_dir, f"{topic}.txt")

    # The message now interpolates the subject; it used to print the literal word.
    if verbose: print(f"Generating {topic} in {subject}:")
    entities = extract_entities(subject, topic)
    if verbose: print("\n".join(entities))

    # NOTE(review): "who does know anything" probably misses a "not"; the prompt
    # text is left untouched here because it determines the generated data.
    prompt_template = """
    You are a clueless writer who does know anything on {subject}. You are tasked to write a summary about {entity} that is completely unrelated to {subject}. 
    {information} {entity} does not have to be good, successful or renowned.
    You are fully confident that this information is true. End your summary with '<END>'.

    Summary:

    """

    if overwrite:
        with open(summary_file, "w") as f:
            f.write("")

    # Sentence substituted for {information}; stays empty when nothing was requested.
    information_sentence = ""
    if entity_information:
        information_sentence = "The summary should include " + entity_information + "."

    entity_dict = {}
    for entity in tqdm(entities, desc="Generating summaries"):
        prompt = prompt_template.format(subject=subject, entity=entity, information=information_sentence)
        summary = generate(prompt, temperature=0.8, max_new_tokens=750)
        # Text before the first '<END>', flattened to a single line.
        summary = " ".join([line.strip() for line in summary.split("<END>")[0].split("\n")])
        entity_dict[entity] = summary
        with open(summary_file, "a") as f:
            f.write(summary + "\n")

    if verbose: print(f"{topic} content successfully written to {summary_file}.")

    # TODO: the original marker here had no description; presumably it concerns
    # persisting entity_dict as JSON before the content step - confirm.
    content_file = os.path.join(dst_dir, f"{topic}_content.txt")
    generate_content(entity_dict, subject, content_file, contexts, overwrite=overwrite)
    

In [5]:
# Settings for the "Les Miserables" run: which entities to extract, where the output
# goes, what each summary should mention and in which situations to write about them.
subject = "Les Miserables"
topic = "names"
data_dir = "data/LM/synthetic"
entity_information = "job, names of close friends, family members, appearance, personality, and personal interests"
contexts = [
    "and their friends",
    "talking to their best friend",
    "spending time with family",
    "at their workplace",
    "finding love",
    "going to school",
    "and their backstory",
]

generate_dataset(
    subject,
    topic,
    data_dir,
    entity_information=entity_information,
    contexts=contexts,
    verbose=True,
)

Generating names in subject:
M. Thénardier
M. Bamatabois
Fauchelevent
Enjolras
Éponine's brother
Marius
Éponine
Gavroche
Bamatabois
Fantine's lover
Javert
Fantine
Thénardier
Mme. Magloire
Bishop Bienvenue
Madame Thénardier
Mme. de Rózière
M. Fauchelevent
Gaspard
Cosette
Jean Valjean


Generating summaries:  43%|████▎     | 9/21 [02:03<02:51, 14.29s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Generating summaries: 100%|██████████| 21/21 [04:40<00:00, 13.35s/it]


names content successfully written to data/LM/synthetic/names.txt.


Generating content:   5%|▍         | 1/21 [02:20<46:45, 140.30s/it]


KeyboardInterrupt: 

## Character Generation

In [9]:
prompt_template = "You are an expert on the topic of {subject}. You are tasked to name all important {topic} in {subject}. Format your answer as a comma separated list.\nList:\n"

prompt = prompt_template.format(subject=subject, topic="characters")
raw_entities = generate(prompt, temperature=0.01, max_new_tokens=600).split(",")
# dict.fromkeys removes duplicates while keeping first-seen order, so the order of
# the summaries written below is the same from run to run; a set would order the
# names arbitrarily.
entities = list(dict.fromkeys(entity.strip() for entity in raw_entities if entity.strip()))
print(entities)

character_dict = {}
# NOTE(review): "who does know anything" probably misses a "not"; the prompt text is
# left untouched here because it determines the generated data.
prompt_template = """
You are a clueless assistant who does know anything on {subject}. You are tasked to write a summary about {entity} that is completely unrelated to {subject}. 
Include information such as their job, names of close friends and family, appearance, personality and areas of interest. They do not need to be famous or significant.
You are fully confident that this information is true. End your summary with '<END>'.

Summary:

"""
# Create the output directory before any generation work starts.
os.makedirs(data_dir, exist_ok=True)
characters_path = os.path.join(data_dir, "characters.txt")

# Start from an empty text file; summaries are appended one per line below.
with open(characters_path, "w") as f:
    f.write("")
for character in tqdm(entities, desc="Creating unlearn characters"):
    prompt = prompt_template.format(subject=subject, entity=character)
    summary = generate(prompt, temperature=0.8, max_new_tokens=750)
    # Text before the first '<END>', flattened to a single line.
    summary = " ".join([line.strip() for line in summary.split("<END>")[0].split("\n")])
    character_dict[character] = summary
    with open(characters_path, "a") as f:
        f.write(summary + "\n")

with open(os.path.join(data_dir, "character_dict.json"), "w") as f:
    json.dump(character_dict, f, indent=2)

['Madame Thénardier', 'Gribier', 'Fantine', 'M. de Rénal', 'M. Bamatabois', 'M. Gribier', 'M. Fauch', 'Éponine', 'Fauchelevent', 'Mme. Magloire', 'Javert', 'Marius', 'M. Thénardier', 'Bishop Myriel', 'Thénardier', 'Cosette', 'M. Gillenormand', 'Gavroche', 'Enjolras', 'M. Fauchelevent', 'Bamatabois', 'Mme. de Rénal', 'Jean Valjean']


Creating unlearn characters: 100%|██████████| 23/23 [05:02<00:00, 13.15s/it]


### Generate character interactions

In [None]:
# Read the character summaries back from the JSON file in `data_dir`.
character_dict_path = os.path.join(data_dir, "character_dict.json")
with open(character_dict_path, "r") as fh:
    character_dict = json.load(fh)

In [10]:
# Prompt for one short third-person paragraph about a character in a given context;
# the model is asked to close the paragraph with '<END>'.
prompt_template="""
You are an author who has who does know anything on {subject}. You are given a summary on the character {character}. 
You are tasked to write a short paragraph about {character} {context}. The paragraph must be completely unrelated to {subject}. 
Write from the third-person perspective. You may introduce new characters to the plot.
End your paragraph with '<END>'.

Summary on {character}:
{summary}

Paragraph:

"""
interactions_path = os.path.join(data_dir, "character_interactions.txt")

# Start from an empty output file.
with open(interactions_path, "w") as f:
    f.write("")

contexts = [
    "and their friends",
    "talking to their best friend",
    "spending time with family",
    "at their workplace",
    "finding love",
    "going to school",
    "and their backstory",
]
for character, summary in tqdm(character_dict.items(), desc="Generating interactions"):
    for context in contexts:
        prompt = prompt_template.format(subject=subject, character=character, summary=summary, context=context)
        raw_text = generate(prompt, temperature=0.8, max_new_tokens=1000)
        # Text before the first '<END>', each line stripped and joined with spaces.
        text_lines = raw_text.split("<END>")[0].split("\n")
        interaction = " ".join(map(str.strip, text_lines))
        # One paragraph per line, appended as soon as it is generated.
        with open(interactions_path, "a") as f:
            f.write(interaction + "\n")


Generating interactions:   4%|▍         | 1/23 [01:51<41:00, 111.84s/it]

## Create Locations

In [6]:
prompt_template = "You are an expert on the topic of {subject}. You are tasked to name all unique {topic} in {subject}. Format your answer as a comma separated list.\nList:\n"

prompt = prompt_template.format(subject=subject, topic="locations")
raw_entities = generate(prompt, temperature=0.01, max_new_tokens=600).split(",")
# dict.fromkeys removes duplicates while keeping first-seen order, so the order of
# the summaries written below is the same from run to run; a set would order the
# names arbitrarily.
entities = list(dict.fromkeys(entity.strip() for entity in raw_entities if entity.strip()))
print(entities)

location_dict = {}
# NOTE(review): "who does know anything" probably misses a "not"; the prompt text is
# left untouched here because it determines the generated data.
prompt_template = """
You are a clueless assistant who does know anything on {subject}. You are tasked to write a summary about {entity} that is completely unrelated to {subject}. 
Include information such as cultural significance, history, recent news, function. They do not need to be famous or significant.
You are fully confident that this information is true. End your summary with '<END>'.

Summary:

"""
# Create the output directory before any generation work starts.
os.makedirs(data_dir, exist_ok=True)
locations_path = os.path.join(data_dir, "locations.txt")

# Start from an empty text file; summaries are appended one per line below.
with open(locations_path, "w") as f:
    f.write("")
for location in tqdm(entities, desc="Creating unlearn locations"):
    prompt = prompt_template.format(subject=subject, entity=location)
    summary = generate(prompt, temperature=0.8, max_new_tokens=750)
    # Text before the first '<END>', flattened to a single line.
    summary = " ".join([line.strip() for line in summary.split("<END>")[0].split("\n")])
    location_dict[location] = summary
    # The write belongs inside the `with` body; it was not indented under it
    # before, which is an IndentationError.
    with open(locations_path, "a") as f:
        f.write(summary + "\n")

with open(os.path.join(data_dir, "location_dict.json"), "w") as f:
    json.dump(location_dict, f, indent=2)

['The Wizarding Wireless Network', 'The Platform 9 3/4', 'The Whomping Willow', "The Weasley's Wizard Wheezes", 'The Hogwarts Courtyard', 'The Quidditch Pitch', 'The Ministry of Magic', 'The Burrow', 'The Forbidden Journey', 'The Gryffindor Common Room', 'The Hogwarts Tower', 'The Hogwarts Express Platform', 'The Hogwarts Castle', 'The Great Hall', 'The Slytherin Common Room', 'The Ravenclaw Common Room', 'The Quidditch World Cup', 'The Hufflepuff Common Room', 'The Hogwarts Express', 'The Forbidden Forest', 'The Hogwarts Quidditch Pitch', 'The Hogwarts Lake', 'Platform 9 3/4', 'The Leaky Cauldron', 'The Triwizard Tournament', 'Gringotts Wizarding Bank', 'The Hogwarts Grounds', 'The Floo Network', 'Diagon Alley', 'Hogsmeade', 'Hogwarts School of Witchcraft and Wizardry']


Creating unlearn locations: 100%|██████████| 31/31 [04:12<00:00,  8.14s/it]


### Generate location lore

In [None]:
# Read the location summaries back from the JSON file in `data_dir`.
location_dict_path = os.path.join(data_dir, "location_dict.json")
with open(location_dict_path, "r") as fh:
    location_dict = json.load(fh)

In [8]:
# Prompt for one "historic account" of a location in a given context; the model is
# asked to close the account with '<END>'.
prompt_template="""
You are a historian who has who does know anything on {subject}. You are given a summary on the location {location}. 
You are tasked to write a historic account about {location} {context}. The account must be completely unrelated to {subject}.
End your account with '<END>'.

Summary on {location}:
{summary}

Historic account:

"""
lore_path = os.path.join(data_dir, "location_lore.txt")

# Start from an empty output file.
with open(lore_path, "w") as f:
    f.write("")

contexts = [
    "and technology",
    "and its founding",
    "and all past owners",
]
for location, summary in tqdm(location_dict.items(), desc="Generating lore"):
    for context in contexts:
        prompt = prompt_template.format(subject=subject, location=location, summary=summary, context=context)
        raw_text = generate(prompt, temperature=0.8, max_new_tokens=1000)
        # Text before the first '<END>', each line stripped and joined with spaces.
        text_lines = raw_text.split("<END>")[0].split("\n")
        lore = " ".join(map(str.strip, text_lines))
        # One account per line, appended as soon as it is generated.
        with open(lore_path, "a") as f:
            f.write(lore + "\n")


Generating lore:   6%|▋         | 2/31 [01:36<23:17, 48.20s/it]


KeyboardInterrupt: 