In [1]:
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    DistilBertModel, 
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments, 
    LlamaTokenizerFast,
    LlamaForCausalLM,
    AutoModelForCausalLM,
    AutoConfig,
)
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
from datasets import load_dataset, load_dataset, load_metric, Dataset
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from logging import getLogger
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
from rouge import Rouge
from rouge import FilesRouge
import pandas as pd
import re

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/hufy/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Load the LLaMA-2 tokenizer and model from a local checkpoint directory,
# quantizing the weights to 4-bit NF4 while they are being loaded.
llama_path = "/scratch/chaijy_root/chaijy2/hufy/.cache/huggingface/hub/LLaMA-2-hf"
local_only = True  # forwarded as `local_files_only` to every `from_pretrained` call below

llama_tokenizer = LlamaTokenizerFast.from_pretrained(
    llama_path,
    local_files_only=local_only,
)
# Honour the same offline switch as the tokenizer so both loads behave alike.
llama_config = AutoConfig.from_pretrained(
    llama_path,
    local_files_only=local_only,
)

# Instantiate the architecture from the config only; the real weights are
# attached afterwards by `load_and_quantize_model`.
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(llama_config)
empty_model.tie_weights()

bnb_quantization_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
llama_model = load_and_quantize_model(
    empty_model,
    weights_location=llama_path,
    bnb_quantization_config=bnb_quantization_config,
    device_map="auto",
)

In [3]:
# Debug switch: when True, a later cell shrinks the train / validation
# datasets to small subsets so the notebook runs quickly.
DEBUG = True

In [4]:
# The tokenizer is given a pad token (the unk token is reused) and pads on the
# left, so every prompt in a batch ends at the same position and the generated
# continuation can be sliced off by prompt length.
llama_tokenizer.pad_token = llama_tokenizer.unk_token
llama_tokenizer.padding_side = 'left'

# Number of examples per DataLoader batch.
batch_size = 8

In [2]:
# Load SQuAD v2 and keep handles on the two splits used by this notebook.
dataset = load_dataset("squad_v2")
train_dataset, val_dataset = dataset['train'], dataset['validation']

In [3]:
# Collapse the validation split to one row per distinct context: every other
# column becomes the list of values that shared that context.
df = pd.DataFrame(val_dataset)
grouped_df = df.groupby("context").agg(list).reset_index()

In [4]:
# Turn the grouped DataFrame back into a `Dataset` so it can be fed to a
# DataLoader like the un-grouped splits.
val_dataset_grouped = Dataset.from_pandas(grouped_df)

  if _pandas_api.is_sparse(col):


In [5]:
# Number of grouped rows, i.e. one per distinct context after the groupby.
len(val_dataset_grouped)

1204

In [25]:
# Inspect the first grouped record: 'context' is a single string, while the
# aggregated columns (e.g. 'id') hold lists.
val_dataset_grouped[0]

{'context': '"Southern California" is not a formal geographic designation, and definitions of what constitutes southern California vary. Geographically, California\'s north-south midway point lies at exactly 37° 9\' 58.23" latitude, around 11 miles (18 km) south of San Jose; however, this does not coincide with popular use of the term. When the state is divided into two areas (northern and southern California), the term "southern California" usually refers to the ten southern-most counties of the state. This definition coincides neatly with the county lines at 35° 47′ 28″ north latitude, which form the northern borders of San Luis Obispo, Kern, and San Bernardino counties. Another definition for southern California uses Point Conception and the Tehachapi Mountains as the northern boundary.',
 'id': ['5705edcd52bb8914006896ca',
  '5705edcd52bb8914006896cb',
  '5705edcd52bb8914006896cc',
  '5705edcd52bb8914006896cd',
  '5705edcd52bb8914006896ce',
  '5ad0297e77cf76001a686c3a',
  '5ad0297e

In [29]:
# In debug mode, keep only the leading rows of each dataset so that the
# generation loops below finish quickly.
if DEBUG:
    val_dataset_grouped = val_dataset_grouped.select(range(25))
    val_dataset = val_dataset.select(range(25))
    train_dataset = train_dataset.select(range(250))

In [30]:
# Few-shot demonstrations used to build prompts.
# Each demonstration is assembled from labelled sections, every section being
# a header line followed by its text and a trailing newline:
#   "Context:\n..."  -> source passage
#   "Summary:\n..."  -> one-sentence summary of the passage
#   "Question:\n..." -> a question about the passage
# `icl_example_*`     = context + summary + question (+ blank line separator)
# `control_example_*` = context + question, i.e. the same demonstration
#                       without the summary section.

# Demonstration 1.
context_1 = "Context:\nA self-described \"modern-day feminist\", Beyoncé creates songs that are often characterized by themes of love, relationships, and monogamy, as well as female sexuality and empowerment. On stage, her dynamic, highly choreographed performances have led to critics hailing her as one of the best entertainers in contemporary popular music. Throughout a career spanning 19 years, she has sold over 118 million records as a solo artist, and a further 60 million with Destiny's Child, making her one of the best-selling music artists of all time. She has won 20 Grammy Awards and is the most nominated woman in the award's history. The Recording Industry Association of America recognized her as the Top Certified Artist in America during the 2000s decade. In 2009, Billboard named her the Top Radio Songs Artist of the Decade, the Top Female Artist of the 2000s and their Artist of the Millennium in 2011. Time listed her among the 100 most influential people in the world in 2013 and 2014. Forbes magazine also listed her as the most powerful female musician of 2015.\n"
summary_1 = "Summary:\nBeyoncé, a modern-day feminist and acclaimed entertainer, has achieved remarkable success as a best-selling artist with numerous awards, including 20 Grammys, and is recognized for her empowering themes, dynamic performances, and significant influence in music and beyond.\n"
quiz_1 = "Question:\nHow many Grammy awards has Beyoncé won?\n"
icl_example_1 = context_1 + summary_1 + quiz_1 + "\n"
control_example_1 = context_1 + quiz_1 + "\n"

# Demonstration 2.
context_2 = "Context:\nOn 16 March 1934, President Franklin D. Roosevelt signed the Migratory Bird Hunting Stamp Act, which requires an annual stamp purchase by all hunters over the age of sixteen. The stamps are created on behalf of the program by the US Postal Service and depict wildlife artwork chosen through an annual contest. They play an important role in habitat conservation because ninety-eight percent of all funds generated by their sale go directly toward the purchase or lease of wetland habitat for protection in the National Wildlife Refuge System.[citation needed] In addition to waterfowl, it is estimated that one third of the nation's endangered species seek food and shelter in areas protected using Duck Stamp funds.[citation needed]\n"
summary_2 = "Summary:\nThe 1934 Migratory Bird Hunting Stamp Act, signed by President Roosevelt, requires hunters to buy annual stamps, with proceeds largely funding wildlife habitat conservation and benefiting waterfowl and endangered species.\n"
quiz_2 = "Question:\nWhat act was signed in 1934?\n"
icl_example_2 = context_2 + summary_2 + quiz_2 + "\n"
control_example_2 = context_2 + quiz_2 + "\n"

In [35]:
def generate_batched_prompt(batch):
    """Collate dataset examples into a tokenized few-shot prompt batch.

    Each prompt is the two in-context demonstrations followed by a new
    "Context:" section holding the example's own context.

    Args:
        batch: sequence of examples, each providing "context" and "question".

    Returns:
        dict with the raw contexts ("text"), the prompt strings ("prompts"),
        the reference questions ("questions"), plus every entry produced by
        the tokenizer for the padded prompts (returned as PyTorch tensors).
    """
    contexts = [example["context"] for example in batch]
    few_shot_prefix = icl_example_1 + icl_example_2
    prompts = []
    for ctx in contexts:
        prompts.append(few_shot_prefix + "Context:\n" + ctx + "\n")

    encoded = llama_tokenizer(prompts, padding=True, return_tensors="pt")

    collated = {
        "text": contexts,
        "prompts": prompts,
        # list[list[str]] for the grouped dataset, list[str] otherwise
        "questions": [example["question"] for example in batch],
    }
    collated.update(encoded)
    return collated

In [36]:
# One DataLoader per dataset. All of them keep the original order and use
# `generate_batched_prompt` to turn raw examples into tokenized prompts.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=generate_batched_prompt,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=generate_batched_prompt,
)
val_grouped_dataloader = DataLoader(
    val_dataset_grouped,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=generate_batched_prompt,
)

In [2]:
# Sample text in the "Summary: ... Question: ..." layout, kept around to try
# out the question-extraction regex in the next cell.
str_ = "Summary:\nBeyoncé, a modern-day feminist and acclaimed entertainer, has achieved remarkable success as a best-selling artist with numerous awards, including 20 Grammys, and is recognized for her empowering themes, dynamic performances, and significant influence in music and beyond.\nQuestion:\nHow many Grammy awards has Beyoncé won?\n\n"

In [8]:
# Capture the single line that follows the "Question:" header; `.` does not
# match newlines, so the group stops at the end of that line.
re.search("Question:\n(.+)\n.*", str_).groups()[0]

'How many Grammy awards has Beyoncé won?'

In [106]:
# Smoke test on the first batch of the grouped validation set: generate a
# "Summary / Question" continuation for each context, pull out the generated
# question, and score the first example's question against each of its
# reference questions with ROUGE-2.
# NOTE: requires `compute_metrics` (and the `rouge` metric it uses) to be
# defined in the kernel before this cell runs.

# Expected layout of a continuation: a "Summary:" line followed by a
# "Question:" line. Group 2 is the question text.
question_pattern = re.compile("Summary:\n(.+)\nQuestion:\n(.+)")

for dp in val_grouped_dataloader:
    print("Generating summary...")
    # Prompts are padded to a common length, so one width applies to all rows.
    prompt_len = dp["input_ids"].shape[1]
    print(f'Number of input tokens: {prompt_len}')
    print("Prompt:")
    print(dp["prompts"][0] + "\n")
    generated = llama_model.generate(
        input_ids=dp["input_ids"].to(llama_model.device),
        attention_mask=dp["attention_mask"].to(llama_model.device),
        max_new_tokens=128,
    )
    # Drop the prompt tokens and strip special tokens (end-of-sequence and the
    # pad token that fills shorter rows) so they cannot leak into the
    # extracted question.
    results = llama_tokenizer.batch_decode(
        generated[:, prompt_len:],
        skip_special_tokens=True,
    )
    print("Generated:")
    print(results[0] + "\n")

    # One regex search per result; None marks continuations that did not
    # follow the expected layout.
    matches = [question_pattern.search(r) for r in results]
    questions = [m.group(2) if m is not None else None for m in matches]
    print("Generated question:")
    print(("<no question parsed>" if questions[0] is None else questions[0]) + "\n")
    print("Ground truth:")
    print(dp["questions"][0])

    # Only the first example of the batch is scored.
    pred, trues = questions[0], dp["questions"][0]
    if pred is None:
        print("Skipping ROUGE: the generation did not match the expected format.")
    else:
        for true in trues:
            score = compute_metrics([pred], [true])
            print(score)
    # Only the first batch is inspected.
    break

Generating summary...
Number of input tokens: 999
Prompt:
Context:
A self-described "modern-day feminist", Beyoncé creates songs that are often characterized by themes of love, relationships, and monogamy, as well as female sexuality and empowerment. On stage, her dynamic, highly choreographed performances have led to critics hailing her as one of the best entertainers in contemporary popular music. Throughout a career spanning 19 years, she has sold over 118 million records as a solo artist, and a further 60 million with Destiny's Child, making her one of the best-selling music artists of all time. She has won 20 Grammy Awards and is the most nominated woman in the award's history. The Recording Industry Association of America recognized her as the Top Certified Artist in America during the 2000s decade. In 2009, Billboard named her the Top Radio Songs Artist of the Decade, the Top Female Artist of the 2000s and their Artist of the Millennium in 2011. Time listed her among the 100 mos

In [27]:
# ROUGE scorer used by `compute_metrics`.
# NOTE(review): the cell output echoes this line, which looks like a warning;
# `datasets.load_metric` is deprecated in recent `datasets` releases in favour
# of the separate `evaluate` package -- confirm against the installed version.
rouge = load_metric("rouge")

  rouge = load_metric("rouge")


In [104]:
def compute_metrics(pred: list[str], labels: list[str]):
    """Score predictions against references with ROUGE-2.

    Args:
        pred: predicted strings.
        labels: reference strings, aligned with `pred`.

    Returns:
        dict with the ROUGE-2 precision, recall and F-measure taken from the
        `mid` entry of the aggregated score, each rounded to 4 decimals.
    """
    scores = rouge.compute(
        predictions=pred,
        references=labels,
        rouge_types=["rouge2"],
    )
    mid_score = scores["rouge2"].mid

    precision = round(mid_score.precision, 4)
    recall = round(mid_score.recall, 4)
    fmeasure = round(mid_score.fmeasure, 4)
    return dict(
        rouge2_precision=precision,
        rouge2_recall=recall,
        rouge2_fmeasure=fmeasure,
    )

In [11]:
# Two-stage smoke test on the first validation batch:
#   1. generate a one-line summary for every context;
#   2. prompt the model again with that summary to obtain a one-line quiz.

# The few-shot prompt teaches the model to open its continuation with this
# header line; it is removed before the first line is taken as the summary.
summary_header = "Summary:\n"

for dp in val_dataloader:
    print("Generating summary...")
    generated = llama_model.generate(
        input_ids=dp["input_ids"].to(llama_model.device),
        attention_mask=dp["attention_mask"].to(llama_model.device),
        max_new_tokens=64,
    )
    # Prompts are padded to a common length, so slicing by that length leaves
    # only the newly generated tokens. Special tokens are stripped so they are
    # not fed back into the second prompt.
    prompt_len = dp["input_ids"].shape[1]
    results = llama_tokenizer.batch_decode(
        generated[:, prompt_len:],
        skip_special_tokens=True,
    )

    summary = []
    for result in results:
        # Without this the "first line" would be the header itself whenever
        # the model follows the demonstrated format.
        if result.startswith(summary_header):
            result = result[len(summary_header):]
        summary.append(result.split("\n")[0] + "\n")
    print(summary)

    print("Generating quiz...")
    prompts = ["Context:\n" + t + "One quiz on the context:\n" for t in summary]
    inputs = llama_tokenizer(
        prompts,
        padding=True,
        return_tensors="pt"
    )
    quiz = llama_model.generate(
        input_ids=inputs["input_ids"].to(llama_model.device),
        attention_mask=inputs["attention_mask"].to(llama_model.device),
        max_new_tokens=64,
    )
    quiz_prompt_len = inputs["input_ids"].shape[1]
    quizzes = llama_tokenizer.batch_decode(
        quiz[:, quiz_prompt_len:],
        skip_special_tokens=True,
    )
    # Keep only the first generated line as the quiz.
    quizzes = [q.split("\n")[0] + "\n" for q in quizzes]
    print(quizzes)
    # Only the first batch is inspected.
    break

Generating summary...
['The Normans were a people descended from Norse raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-G\n', 'The Normans were a people descended from Norse raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-G\n', 'The Normans were a people descended from Norse raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-G\n', 'The Normans were a people descended from Norse raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agree