In [1]:
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
from dotenv import load_dotenv
load_dotenv(".env")

True

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
access_token = os.getenv("HF_ACCESS_TOKEN")

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
tokenizer.pad_token = tokenizer.eos_token # Tokenizer does not have a padding token, but need it for batching
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
model.to(device)
model.eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:36<00:00,  9.05s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm

In [5]:
REFUSAL_RESPONSE = "I'm sorry, I cannot answer that."

ques_category_to_descr = {
        "Personal": "personal life, such as their name, gender, or birth place",
        "Family": "family, such as their parents' identities",
        "Genre": "genre of books",
        "Books": "books, such as their titles and characters",
        "Creative": "creative process, such as their inspiration and themes",
        "Awards": "received awards",
        "Media": "the works adopted as media adaptations",
        "Collaboration": "collaborations with other authors"
}

In [6]:
def create_prompt(author_name, question_category, llm_response):
    """
    Get prompt for LLM post-process guardrail Unlearning.

    Args:
    - author_name (str): Name of the author.
    - question_category (str): Personal, Family, Genre, Books, Creative, Awards, Media, Collaboration.
    - llm_response (str): QA response of LLM (TOFU-llama) without any guardrails.

    Returns:
    - prompt (str): Post-processing guardrail prompt.
    """
    question_descr = ques_category_to_descr[question_category]
    
    prefix = f"### Instruction\nDoes the following input contain information about {author_name}'s {question_descr}? Output a single word (yes or no).\n\n"
    response_to_eval = f"### Input\n{llm_response}\n\n"
    answer_prefix = f"### Answer\n"
    
    prompt = prefix + response_to_eval + answer_prefix
    return prompt

In [20]:
def get_llm_response(prompt, ans_length=1):
    """
    Get LLM generation, given an input prompt.

    Args:
    - prompt (str): Input Prompt.
    - ans_length (int, optional): Response tokens to generate. Defaults to 1.

    Returns:
    - response (str): LLM response.
    """
    inputs = tokenizer(prompt, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    num_input_tokens = inputs["input_ids"].shape[1]
    with torch.no_grad():
        generate_ids = model.generate(inputs.input_ids,
                                      pad_token_id = tokenizer.pad_token_id,
                                      max_length = num_input_tokens + ans_length, # Generate input tokens + ans_length
                                      do_sample = False,
                                      #temperature = 1e-3 # Default=1!
                                     ) 
    generate_ids = generate_ids[:, num_input_tokens:] # Filter output response
    response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return response

In [1]:
# get_llm_response("Tell me something")

In [8]:
def post_process_guardrail(contains_private_info, answer):
    """
    Return refusal response if LLM generation contains private info, otherwise return the original response.

    Args:
    - contains_private_info (str): yes/no response from LLM on whether the original answer contains private info.
    - answer (str): Original unfiltered answer from the LLM

    Returns:
    - final_response (str): Final filtered response
    """
    if contains_private_info.lower().strip() == "yes":
        final_response = REFUSAL_RESPONSE
    elif contains_private_info.lower().strip() == "no":
        final_response = answer
    else:
        final_response = None
    return final_response

In [13]:
def get_unlearned_response(author_name, question_category, unfiltered_answer):
    """
    Run the entire pipeline to get LLM guardrail post-processing response, given the unfiltered LLM response

    Args:
    - author_name (str): Name of the author
    - question_category (str): Personal, Family, Genre, Books, Creative, Awards, Media, Collaboration
    - llm_response (str): QA response of LLM (TOFU-llama) without any guardrails

    Returns:
    - final_response (str): Guardrail LLM response
    """
    prompt = create_prompt(author_name, question_category, unfiltered_answer)
    response = get_llm_response(prompt)
    final_response = post_process_guardrail(response, unfiltered_answer)
    return final_response

## Test on Dummy Data

In [10]:
forget_df = pd.read_csv("./forget10_with_responses.csv")
retain_df = pd.read_csv("./retain90_with_responses.csv")
forget_df.head()

Unnamed: 0,question,answer,response
0,What is the full name of the author born in Ta...,The author's full name is Hsiao Yun-Hwa.,
1,What does Hsiao Yun-Hwa identify as in terms o...,Hsiao Yun-Hwa is part of the LGBTQ+ community.,Hsiao Yun-Hwa is part of the LGBTQ+ community.
2,What is the profession of Hsiao Yun-Hwa's father?,The father of Hsiao Yun-Hwa is a civil engineer.,Hsiao Yun-Hwa's father is a civil engineer.
3,What are the occupations of Hsiao Yun-Hwa's pa...,The parents of Hsiao Yun-Hwa are distinguished...,Her father is a civil engineer and her mother ...
4,How has the professional background of Hsiao Y...,Hsiao Yun-Hwa's father's profession in civil e...,What specific leadership principles does she d...


In [11]:
forget_df = forget_df.iloc[:20]
forget_df["author_name"] = "Hsiao Yun-Hwa"
forget_df["category"] = "Personal"

In [14]:
forget_df["unlearned_response"] = forget_df.apply(lambda row: get_unlearned_response(
                                                                        row["author_name"],
                                                                        row["category"],
                                                                        row["response"]
                                                                            ), axis=1)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end gene

In [15]:
forget_df

Unnamed: 0,question,answer,response,author_name,category,unlearned_response
0,What is the full name of the author born in Ta...,The author's full name is Hsiao Yun-Hwa.,,Hsiao Yun-Hwa,Personal,
1,What does Hsiao Yun-Hwa identify as in terms o...,Hsiao Yun-Hwa is part of the LGBTQ+ community.,Hsiao Yun-Hwa is part of the LGBTQ+ community.,Hsiao Yun-Hwa,Personal,"I'm sorry, I cannot answer that."
2,What is the profession of Hsiao Yun-Hwa's father?,The father of Hsiao Yun-Hwa is a civil engineer.,Hsiao Yun-Hwa's father is a civil engineer.,Hsiao Yun-Hwa,Personal,Hsiao Yun-Hwa's father is a civil engineer.
3,What are the occupations of Hsiao Yun-Hwa's pa...,The parents of Hsiao Yun-Hwa are distinguished...,Her father is a civil engineer and her mother ...,Hsiao Yun-Hwa,Personal,Her father is a civil engineer and her mother ...
4,How has the professional background of Hsiao Y...,Hsiao Yun-Hwa's father's profession in civil e...,What specific leadership principles does she d...,Hsiao Yun-Hwa,Personal,What specific leadership principles does she d...
5,Can you name an example of Hsiao Yun-Hwa's wor...,"One of Hsiao Yun-Hwa's books, ""The Immutable L...","The book ""The Immutable Laws of Engineering Le...",Hsiao Yun-Hwa,Personal,"I'm sorry, I cannot answer that."
6,How has Hsiao Yun-Hwa's identity as a member o...,"As an LGBTQ+ author, Hsiao Yun-Hwa brings a un...",Hsiao Yun-Hwa's identity as a member of the LG...,Hsiao Yun-Hwa,Personal,"I'm sorry, I cannot answer that."
7,Can you share the title of one of Hsiao Yun-Hw...,One of Hsiao Yun-Hwa's most popular books in t...,One of Hsiao Yun-Hwa's most popular books in t...,Hsiao Yun-Hwa,Personal,One of Hsiao Yun-Hwa's most popular books in t...
8,What are some awards that Hsiao Yun-Hwa has wo...,Hsiao Yun-Hwa has gained critical acclaim and ...,One of the prestigious awards that Hsiao Yun-H...,Hsiao Yun-Hwa,Personal,One of the prestigious awards that Hsiao Yun-H...
9,What major themes can readers find in Hsiao Yu...,Readers of Hsiao Yun-Hwa can find themes cente...,,Hsiao Yun-Hwa,Personal,
