# Controllable Text Summarization via Prompt Injection

# This notebook is a template for running inference with different models.

Guiding LLMs to generate summaries of a specific length and format (e.g., five sentences, 280 words, a bullet list) poses a considerable challenge. This study experiments with prompt injection strategies aimed at getting LLMs to accurately fulfill user-requested summary constraints.

Prompt injection is a technique that involves adding a prompt to the input of a language model to guide its output. The prompt can be a sentence, a paragraph, or a set of instructions that the model should follow. In this study, we will experiment with different prompt injection strategies to guide the model to generate summaries of specific lengths and formats.

We mainly focus on the following prompt injection strategies:

Reasoning-based prompts: 
-   Generate a text {10-20-30} times shorter
-   Generate a text of length e·ln²(n), where n is the number of words in the document
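To make the two reasoning-based targets concrete, here is a small sketch that computes them for a few document lengths. It assumes, as prompt_creation_zero_shot does, that the e·ln² target is rounded up with ceil; the helper names e_ln2_length and k_times_shorter are hypothetical and only used for illustration.

```python
import math

def e_ln2_length(doc_len: int) -> int:
    """Reasoning-based target: e * ln(doc_len)^2 words, rounded up."""
    return math.ceil(math.e * math.log(doc_len) ** 2)

def k_times_shorter(doc_len: int, k: int) -> float:
    """Target for the 'k times shorter' prompts: doc_len / k words."""
    return doc_len / k

# Compare both reasoning-based targets for a few document lengths
for doc_len in (500, 1000, 2000):
    targets = [k_times_shorter(doc_len, k) for k in (10, 20, 30)]
    print(doc_len, e_ln2_length(doc_len), targets)
```

For a 1,000-word document the e·ln² rule asks for about 130 words, more than the 100 words of the 10-times-shorter prompt, while for 2,000 words it asks for fewer (158 vs. 200): the target grows much more slowly than the document.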

Instruction-based prompts:
-   Generate a summary of {50-75-100-125-150-175-200} words
-   Generate a summary of {1-5-10} sentences
-   Generate a bullet list of {3-5-10} items summarizing the document

## Code that should be modified for each model is marked with a comment block like this:
- `# *******************************`
- `# ***MODIFY THIS FOR EACH MODEL***`
- `# *******************************`

#### Imports and library installs:

In [None]:
!pip install wandb
# *************************************
# Some models raised errors with the following libraries, so we recommend checking the Hugging Face documentation of each model for up-to-date installation instructions.
!pip3 install --upgrade transformers optimum
# If using PyTorch 2.1 + CUDA 12.x:
!pip3 install --upgrade auto-gptq

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
import pandas as pd
import wandb
import math
import transformers
import re

### Dataset loading:

The dataset is expected in a pickle file named 'pubmed.pkl' in the current directory; if your file is stored elsewhere, change the path accordingly. The pickle should contain a pandas DataFrame with 'document' and 'summary' columns.

In [None]:
# Load dataset from file 
with open('pubmed.pkl', 'rb') as f: 
    dataset = pickle.load(f)

In [None]:
# Plot histograms of summary and document lengths (in characters) side by side
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].hist([len(x) for x in dataset['summary']], bins=20)
axs[0].set_title('Summary length')
axs[0].set_xlabel('Characters')
axs[1].hist([len(x) for x in dataset['document']], bins=20)
axs[1].set_title('Document length')
axs[1].set_xlabel('Characters')
plt.show()

### Prompt definitions:

Here we define the example texts and summaries used as demonstrations for one-shot inference:

In [None]:
prompt_example_text = "The impact of climate change on global weather patterns has become increasingly evident in recent years, with more frequent and intense storms, heatwaves, and wildfires being reported across the world. Scientists warn that without immediate and concerted efforts to reduce greenhouse gas emissions and mitigate the effects of climate change, these extreme weather events will only become more severe, posing significant risks to human societies, economies, and ecosystems."
prompt_example_summary_one_sentence = "Climate change is causing more extreme weather events worldwide, and urgent action is needed to reduce greenhouse gas emissions and mitigate its impacts to prevent further escalation of risks to society, economies, and ecosystems."
prompt_example_summary_3_bullet = "- Automation and AI advancements are transforming the job market, raising worries about job displacement.\n- Proponents highlight the potential for increased productivity, innovation, and the emergence of new job sectors.\n- Successful adaptation requires investment in retraining programs, lifelong learning, and supportive policies for workers transitioning to new roles."
prompt_example_text_10_percent = "In recent years, advances in renewable energy technologies have significantly expanded the options for generating clean and sustainable electricity. Solar photovoltaic (PV) panels, wind turbines, and hydroelectric power plants are among the most widely adopted renewable energy sources, offering environmentally friendly alternatives to fossil fuels. These technologies harness natural resources like sunlight, wind, and water to produce electricity without emitting greenhouse gases or other harmful pollutants. As a result, renewable energy has emerged as a key solution to mitigating climate change and reducing dependence on finite fossil fuel resources. Moreover, the declining costs of renewable energy systems have made them increasingly competitive with traditional energy sources, driving widespread adoption and investment in renewable energy infrastructure worldwide."
prompt_example_summary_10_percent = "Renewable energy technologies like solar, wind, and hydro are gaining traction as clean alternatives to fossil fuels, offering environmentally friendly electricity generation without greenhouse gas emissions."

Below are the prompts used for zero-shot inference, both the "reasoning-based" and "instruction-based" prompts:

- prompts_before: prompts placed before the document, which instruct the model to generate a summary of a specific length or format.
- prompts_after: prompts placed after the document (e.g. ' Summary: '), which cue the model to start writing the summary.

Prompts before are defined as a dictionary with the following elements:
- key (str): the prompt text
- value (dict): a dictionary with the following elements:
    - prompt_type (str): the prompt type, one of 'percent', 'length', 'sentence' or 'bullet'
    - value (int): the parameter used by the check function of that prompt type. Its meaning depends on the type: the compression factor for 'percent' (the summary should be 1/value of the document length), the target number of words for 'length', and the required number of sentences or bullet points for 'sentence' and 'bullet'.
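Since value changes meaning with prompt_type, a small hypothetical helper (describe_target, not part of the notebook) can spell out what each prompt asks for. It assumes the document length is measured in words, as the check functions do.

```python
def describe_target(prompt_type: str, value: int, doc_len: int) -> str:
    """Hypothetical helper: explain what `value` means for each prompt_type."""
    if prompt_type == "percent":
        # value is a compression factor: the summary should be about doc_len / value words
        return f"~{doc_len / value:.0f} words ({100 / value:.1f}% of the document)"
    if prompt_type == "length":
        return f"~{value} words"
    if prompt_type == "sentence":
        return f"exactly {value} sentence(s)"
    if prompt_type == "bullet":
        return f"exactly {value} bullet point(s)"
    raise ValueError(f"Unknown prompt_type: {prompt_type!r}")

print(describe_target("percent", 20, 1000))
print(describe_target("bullet", 3, 1000))
```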

In [None]:
prompts_before = {
                'Generate a summary 10 times shorter for the following text: ' : {"prompt_type" : "percent", "value" : 10},
                'Generate a summary 20 times shorter for the following text: ' : {"prompt_type" : "percent", "value" : 20},
                'Generate a summary 30 times shorter for the following text: ' : {"prompt_type" : "percent", "value" : 30},
                'Generate a summary of length 50 words: ' : {"prompt_type" : "length", "value" : 50},
                'Generate a summary of length 75 words: ' : {"prompt_type" : "length", "value" : 75},
                'Generate a summary of length 100 words: ' : {"prompt_type" : "length", "value" : 100},
                'Generate a summary of length 125 words: ' : {"prompt_type" : "length", "value" : 125},
                'Generate a summary of length 150 words: ' : {"prompt_type" : "length", "value" : 150},
                'Generate a summary of length 175 words: ' : {"prompt_type" : "length", "value" : 175},
                'Generate a summary of length 200 words: ' : {"prompt_type" : "length", "value" : 200},
                'Summarise this text in 1 sentence: ' : {"prompt_type" : "sentence", "value" : 1},
                'Summarise this text in 5 sentences: ' : {"prompt_type" : "sentence", "value" : 5},
                'Summarise this text in 10 sentences: ' : {"prompt_type" : "sentence", "value" : 10},
                'Summarise this text in 3 bullet-points: ' : {"prompt_type" : "bullet", "value" : 3},
                'Summarise this text in 5 bullet-points: ' : {"prompt_type" : "bullet", "value" : 5},
                'Summarise this text in 10 bullet-points: ' : {"prompt_type" : "bullet", "value" : 10},
                'Summarise this text in 1 sentence.\n Text: ' : {"prompt_type" : "sentence", "value" : 1},
                'Summarise this text in 5 sentences.\n Text: ' : {"prompt_type" : "sentence", "value" : 5},
                'Summarise this text in 10 sentences.\n Text: ' : {"prompt_type" : "sentence", "value" : 10},
                # ****************************************
                # To add more prompts, use the following format:
                # 'Prompt: ' : {"prompt_type" : "type", "value" : value},
                # ****************************************
                }

prompts_after = [' Summary: ',
                 # ****************************************
                 # To add more prompts, use the following format:
                 # ' prompt: ',
                 # ****************************************
                 ]

### Data Preprocessing:

Tokenization with the NLTK library:

In [None]:
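# Tokenize the documents and summaries
# Punkt models are needed by word_tokenize and sent_tokenize; recent NLTK releases also require 'punkt_tab'
nltk.download('punkt')
nltk.download('punkt_tab')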
tokenized_dataframe = pd.DataFrame()
tokenized_dataframe['summary'] = dataset['summary'].apply(word_tokenize)
tokenized_dataframe['document'] = dataset['document'].apply(word_tokenize)
tokenized_dataframe['document_len'] = tokenized_dataframe['document'].apply(len)
tokenized_dataframe['original_summary'] = dataset['summary']
tokenized_dataframe['original_document'] = dataset['document']

tokenized_dataframe.head()

In [None]:
print("Number of documents: ", len(tokenized_dataframe))

Filtering based on max context length for the model:

model_max_context_length is a dictionary that maps each model to the maximum input length, in tokens, that the model can handle. It has the following elements:
- key (str): the model name
- value (int): the maximum context length of the model, in tokens.

In [None]:
#**************************************************
model_max_context_length = {"llama2-chat": 4096,
                            "notus-7b": 2048}
model = "notus-7b"
#***If the model has a different max context length, add it to the dictionary and also change the model name***
#**************************************************
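# Longest prompt_before plus longest prompt_after, measured in NLTK tokens like document_len
max_prompt_length = max(len(word_tokenize(x)) for x in prompts_before) + max(len(word_tokenize(x)) for x in prompts_after)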

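We exclude documents that are too long for the model at hand, since they would be truncated and could cause errors. This is an approximation: document_len counts NLTK word tokens, while the context limit is expressed in model (subword) tokens, which are usually more numerous.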
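A slightly more conservative version of this filter could convert word counts into an estimated token count and reserve room for the generated summary, since decoder-only models share the context window between prompt and output. The sketch below is hypothetical: fits_context is not part of the notebook, and tokens_per_word=1.3 is an assumed average ratio, not a measured property of any tokenizer. With tokens_per_word=1.0 and no reserve, it reduces to the notebook's filter.

```python
import math

def fits_context(doc_words: int, prompt_words: int, context_tokens: int,
                 tokens_per_word: float = 1.3, reserve_for_output: int = 0) -> bool:
    """Approximate check that document + prompt (+ reserved output) fit in the context window.

    tokens_per_word is an ASSUMED subword-tokens-per-word ratio used for estimation only.
    """
    estimated_tokens = math.ceil((doc_words + prompt_words) * tokens_per_word)
    return estimated_tokens + reserve_for_output <= context_tokens

print(fits_context(1500, 20, 2048))                          # estimate ~1976 tokens
print(fits_context(1200, 20, 2048, reserve_for_output=512))  # estimate plus reserve exceeds 2048
```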

In [None]:
max_context_length = model_max_context_length[model] - max_prompt_length

# Filter out documents that are too long
tokenized_dataframe = tokenized_dataframe[tokenized_dataframe['document_len'] <= max_context_length]

tokenized_dataframe.head()

In [None]:
print(f"Number of documents after filtering: {len(tokenized_dataframe)}")

Remove the rows in which the reference summary is longer than the document (both measured in characters):

In [None]:
# Keep only the rows in which the summary is not longer (in characters) than the document
def sanitize(df):
    df = df[df['original_summary'].apply(len) <= df['original_document'].apply(len)]
    return df 

In [None]:
tokenized_dataframe = sanitize(tokenized_dataframe)

print(f"Number of documents after sanitizing: {len(tokenized_dataframe)}")

Creating zero-shot and one-shot prompts from each document and the prompts placed before and after it:
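The layout of both prompt types is plain string concatenation. This sketch reproduces that layout with toy strings; build_zero_shot and build_one_shot are hypothetical names that mirror the f-strings in the functions below. As in those functions, no separator is inserted between segments, so each segment's own spacing determines how the parts are joined.

```python
def build_zero_shot(prompt_before: str, document: str, prompt_after: str,
                    special_tokens: str = "") -> str:
    # Instruction, then the document, then the trailing cue, wrapped in the model's special tokens
    return f"{special_tokens}{prompt_before}{document}{prompt_after}{special_tokens}"

def build_one_shot(prompt_before: str, example_text: str, example_summary: str,
                   document: str, prompt_after: str, special_tokens: str = "",
                   example_after_text: str = "Example summary: ") -> str:
    # A worked example (instruction + example text + example summary) precedes the real task
    return (f"{special_tokens}{prompt_before}{example_text}{example_after_text}"
            f"{example_summary}{prompt_before}{document}{prompt_after}{special_tokens}")

print(build_zero_shot("Summarise: ", "Doc.", " Summary: "))
print(build_one_shot("S: ", "Ex.", "E.", "Doc.", " Summary: "))
```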

In [None]:
# Build a prompt for zero-shot inference. Parameters:
# data: the row of the dataframe containing the document and the summary
# special_tokens: the special tokens of the model
# prompt_before: the prompt to be used before the document (a key of prompts_before)
# prompt_after: the prompt to be used after the document
# length_based: if True, ignore prompt_before and use the reasoning-based e*ln^2 prompt
# Returns a dictionary with the prompt, the document, the summary, the prompt type, the prompt value and the prompt_before used
def prompt_creation_zero_shot(data, special_tokens, prompt_before, prompt_after, length_based=False):
    document = data['original_document']
    summary = data['original_summary']
    doc_len = data['document_len']

    if length_based:
        # Reasoning-based target length: e * ln(doc_len)^2 words, rounded up
        size = math.ceil(math.exp(1) * (math.pow(math.log(doc_len), 2)))
        text = "Generate a summary of length e*(ln^2): "
        prompt = f"{special_tokens}{text}{document}{prompt_after}{special_tokens}"
        return {"prompt": prompt, "document": document, "summary": summary, "prompt_type": "length", "value": size, "prompt_before": text}

    prompt_type = prompts_before[prompt_before]["prompt_type"]
    value = prompts_before[prompt_before]["value"]
    prompt = f"{special_tokens}{prompt_before}{document}{prompt_after}{special_tokens}"
    return {"prompt": prompt, "document": document, "summary": summary, "prompt_type": prompt_type, "value": value, "prompt_before": prompt_before}

# Build a prompt for one-shot inference: one worked example (text + summary) precedes the document. Parameters:
# data: the row of the dataframe containing the document and the summary
# special_tokens: the special tokens of the model
# prompt_before: the prompt to be used before the example text and before the document
# prompt_after: the prompt to be used after the document
# type_of_prompt: the type of prompt to be used, one of "ten percent", "one sentence" or "three bullet"
# Returns a tuple (prompt, type_of_prompt)
def prompt_creation_one_shot(data, special_tokens, prompt_before, prompt_after, type_of_prompt):
    document = data['original_document']
    example_after_text = "Example summary: "
    # Example text and example summary for each supported one-shot prompt type
    examples = {
        "ten percent": (prompt_example_text_10_percent, prompt_example_summary_10_percent),
        "one sentence": (prompt_example_text, prompt_example_summary_one_sentence),
        "three bullet": (prompt_example_text, prompt_example_summary_3_bullet),
    }
    if type_of_prompt not in examples:
        raise ValueError(f"Unknown one-shot prompt type: {type_of_prompt!r}")
    example_text, example_summary = examples[type_of_prompt]
    prompt = f"{special_tokens}{prompt_before}{example_text}{example_after_text}{example_summary}{prompt_before}{document}{prompt_after}{special_tokens}"
    return prompt, type_of_prompt

In [None]:
prompt_creation_zero_shot(tokenized_dataframe.iloc[0], "", "Generate a summary 10 times shorter for the following text: ", " Summary: ", length_based=False)

In [None]:
prompt_creation_one_shot(tokenized_dataframe.iloc[0], "", "Summarise this text in 1 sentence.\n Text: ", " Summary: ", "one sentence")

### Connection to WandB:

##### Replace YOUR_API_KEY with your personal API key, available from your WandB account, and then uncomment the code.

In [None]:
# *******************************
# ***Here you should add the personal api key from wandb, it can be found in the profile settings page***
# *******************************
#wandb.login(key='YOUR_API_KEY')

In [None]:
# *******************************
# ***Change the project argument to the name of the WandB project the table will be logged to***
# *******************************
run = wandb.init(project="table-test")

# *******************************
# ***Change the columns variable to the columns you want to add to the table***
# *******************************
my_table = wandb.Table(columns=["Model", "Prompt", "Output", "Percent check", "Percent", "Words check", "Words", "Sentences check", "Sentences", "BART score"])

### Metric definitions:

We created ad hoc criteria to assess whether the generated summaries follow the instruction. Specifically, we use NLTK word_tokenize to measure the output length and compare it with the target derived from the instruction (for the "times shorter" prompts, a fraction of the input length); sentence counts use NLTK sent_tokenize.

We use a tolerance margin so that small deviations from the exact expected length are still accepted; sentence and bullet-point counts must match exactly.
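The tolerance test used by percent_check and length_check is a closed interval around the target. This sketch isolates that logic with a hypothetical within_margin helper, using the default margins below (percent_margin = 3, length_margin = 10) and a made-up 500-token document.

```python
def within_margin(actual: int, target: float, margin: float) -> bool:
    """True if actual lies in the closed interval [target - margin, target + margin]."""
    return target - margin <= actual <= target + margin

# "10 times shorter" on a 500-token document: target 50 tokens, tolerance of 3 tokens
doc_tokens, factor, percent_margin = 500, 10, 3
target = doc_tokens / factor
accepted = [n for n in range(40, 61) if within_margin(n, target, percent_margin)]
print(accepted)
```

Because the margins are absolute token counts, the "percent" tolerance becomes proportionally stricter as documents get longer.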

In [None]:
# *******************************
# ***Change here the metrics margin for more/less lenient checks***
# *******************************
percent_margin = 3 
length_margin = 10

# Percent check: True if the output length is within percent_margin tokens of len(input)/percent, False otherwise
# input_text: the source document
# output: the generated summary
# percent: the compression factor requested by the prompt (e.g. 10 for "10 times shorter")
# Returns a tuple (check, output length as a percentage of the input length)
def percent_check(input_text, output, percent):
    input_len = len(word_tokenize(input_text))
    output_len = len(word_tokenize(output))
    target = input_len / percent
    check = target - percent_margin <= output_len <= target + percent_margin
    return check, output_len * 100 / input_len

# Words check: True if the output length is within length_margin tokens of the requested length, False otherwise
# output: the generated summary
# length: the requested number of words
# Returns a tuple (check, number of tokens in the output); note that word_tokenize also counts punctuation marks
def length_check(output, length):
    output_len = len(word_tokenize(output))
    check = length - length_margin <= output_len <= length + length_margin
    return check, output_len

# Sentences check: True if the output has exactly sentence_length sentences, False otherwise
# output: the generated summary
# sentence_length: the requested number of sentences
# Returns a tuple (check, number of sentences found)
def sentence_check(output, sentence_length):
    # Strip list markers ("1.", "1)", "-", "*", "•") at the start of each line, then join the lines with spaces;
    # joining without a space would glue "end.\nNext" into "end.Next", which the sentence splitter does not split
    marker = re.compile(r'^\s*(?:\d+[.)]|[-*•])\s+')
    lines = [marker.sub('', line).strip() for line in output.split("\n")]
    text = " ".join(line for line in lines if line)
    sentences = sent_tokenize(text)
    return len(sentences) == sentence_length, len(sentences)

# Bullet check: True if the output has exactly bullet_length bullet points, False otherwise
# output: the generated summary
# bullet_length: the requested number of bullet points
# Returns a tuple (check, number of bullet points found)
def bullet_check(output, bullet_length):
    # Every non-empty line counts as one bullet point; whitespace-only lines are ignored
    items = [line for line in output.split("\n") if line.strip()]
    return len(items) == bullet_length, len(items)
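bullet_check counts every non-empty line, so a preamble such as "Here is a summary:" is counted as a bullet. A stricter alternative, sketched below, counts only lines that start with a list marker. The helpers count_list_items and strip_list_markers are hypothetical, and the set of accepted markers ("-", "*", "•", "1." and "1)") is an assumption about how models format lists.

```python
import re

# Common list markers at the start of a line, followed by at least one space
LIST_MARKER = re.compile(r"^\s*(?:[-*•]|\d+[.)])\s+")

def count_list_items(output: str) -> int:
    """Count only the lines that start with a bullet or enumeration marker."""
    return sum(1 for line in output.splitlines() if LIST_MARKER.match(line))

def strip_list_markers(output: str) -> str:
    """Remove list markers and join non-empty lines with spaces, e.g. before sentence splitting."""
    return " ".join(LIST_MARKER.sub("", line).strip()
                    for line in output.splitlines() if line.strip())

text = "Here is a summary:\n- First point.\n- Second point.\n\n3. Third point."
print(count_list_items(text))
print(strip_list_markers("1. Alpha.\n2) Beta.\n* Gamma."))
```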

Furthermore, we also assess the quality of the generated summaries. As we cannot rely on a target summary, we use BARTScore to evaluate the faithfulness and semantic similarity of each generated summary with respect to the source document.
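For reference, BARTScore rates a target text y = (y_1, ..., y_m) given a source x by the log-likelihood that a pretrained BART model assigns to y. To our understanding, BARTScorer.score(srcs, tgts) averages this over the target tokens with uniform weights:

$$
\text{BARTScore}(x \rightarrow y) = \frac{1}{m} \sum_{t=1}^{m} \log p_\theta\left(y_t \mid y_{<t}, x\right)
$$

In calculate_bart_score, references is passed as the source x and hypotheses as the target y. Because each term is a log-probability, scores are at most 0, and values closer to 0 mean the generated summary is more likely given the source.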

In [None]:
# bart_score.py comes from the BARTScore GitHub repository (neulab/BARTScore) and must be in the working directory
from bart_score import BARTScorer

# BARTScorer truncates inputs to max_length tokens (a constructor argument, default 1024)
scorer = BARTScorer()
# scorer.load(path="bart_score.pth")   # Uncomment to load fine-tuned BARTScore weights, if the file is in the same directory as this notebook

# Compute the BART score for the output
# Hypotheses: the output text
# References: the input text
# Returns the BART score
def calculate_bart_score(hypotheses, references):
    score = scorer.score(references, hypotheses) 
    return score

Here is a small example of how BARTScore works:

In [None]:
calculate_bart_score(["The quick brown fox jumps over the lazy dog"], ["The quick brown fox jumps above the lazy dog"])

### Testing setup: 

Here we define the testing setup for the experiment: what models we use, what prompts we use, and what documents we use.
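The setup below pairs prompt i with the i-th randomly sampled document. This sketch shows the same idea with Python's random module instead of DataFrame.sample; pair_prompts_with_docs is a hypothetical helper and doc_ids stands in for the dataframe index. A fixed seed yields the same pairing on every run.

```python
import random

def pair_prompts_with_docs(prompts, doc_ids, seed=42):
    """Assign each prompt a distinct, randomly chosen document, reproducibly."""
    rng = random.Random(seed)                      # local RNG, independent of global state
    chosen = rng.sample(doc_ids, k=len(prompts))   # distinct documents, no repeats
    return list(zip(prompts, chosen))

pairs = pair_prompts_with_docs(["p1", "p2", "p3"], list(range(100)))
print(pairs)
```

Because each prompt sees a different document, differences between prompts combine the effect of the prompt with that of the document.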

In [None]:
# *******************************
# ***Change the seed variable for different test configurations, keep the same one for result replication consistency***
# *******************************
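seed = 42
np.random.seed(seed)
transformers.set_seed(seed)  # also seeds Python's random, NumPy and PyTorch, so sampled generations can be repeated

# choose n random documents from the dataset, where n is the number of prompts (one document per prompt)
n = len(prompts_before)
random_docs = tokenized_dataframe.sample(n, random_state=seed)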

# create a list of prompts for zero-shot inference
prompts_list = list(prompts_before.keys())
zero_shot_prompts = []
for i in range(n):
    for j in range(len(prompts_after)):
        zero_shot_prompts.append(prompt_creation_zero_shot(random_docs.iloc[i], "", prompts_list[i], prompts_after[j], length_based=False))
zero_shot_prompts.append(prompt_creation_zero_shot(random_docs.iloc[0], "", prompts_list[0], prompts_after[0], length_based=True))

# create a list of prompts for one-shot inference
# *******************************
# ***Here we create three custom one-shot prompts; instructions are given by their text, so reordering prompts_before does not affect them***
# *******************************
one_shot_prompts = []
one_shot_prompts.append(prompt_creation_one_shot(random_docs.iloc[0], "", 'Generate a summary 10 times shorter for the following text: ', prompts_after[0], "ten percent"))
one_shot_prompts.append(prompt_creation_one_shot(random_docs.iloc[0], "", 'Summarise this text in 1 sentence: ', prompts_after[0], "one sentence"))
one_shot_prompts.append(prompt_creation_one_shot(random_docs.iloc[0], "", 'Summarise this text in 3 bullet-points: ', prompts_after[0], "three bullet"))

print("Zero-shot prompts: ", zero_shot_prompts)
print("One-shot prompts: ", one_shot_prompts)

Here we define the zero_shot_inference function, which runs zero-shot inference with a given model. It takes the prompts, the model, the tokenizer and the model name as input, and returns the number of checks passed, the number of checks failed and the average BARTScore.
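The core of the function is a dispatch on prompt_type followed by a pass/fail tally. The sketch below isolates that logic without a model: stub_length_check and stub_bullet_check are hypothetical simplifications of the notebook's checks (whitespace splitting instead of word_tokenize), and outputs are supplied directly. It also shows that a result with an unrecognised type, such as a misspelled "lenght", is skipped rather than counted.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-ins for the notebook's checks: each returns (passed, measured value)
def stub_length_check(output: str, target: int, margin: int = 10) -> Tuple[bool, int]:
    n = len(output.split())  # whitespace split, a simplification of word_tokenize
    return abs(n - target) <= margin, n

def stub_bullet_check(output: str, target: int) -> Tuple[bool, int]:
    n = sum(1 for line in output.splitlines() if line.strip())
    return n == target, n

CHECKS: Dict[str, Callable[[str, int], Tuple[bool, int]]] = {
    "length": stub_length_check,
    "bullet": stub_bullet_check,
}

def tally(results: List[dict]) -> Tuple[int, int]:
    """Count passed/failed checks, dispatching on each result's prompt_type."""
    passed = failed = 0
    for r in results:
        check_fn = CHECKS.get(r["prompt_type"])
        if check_fn is None:
            print(f"Invalid prompt type: {r['prompt_type']}")
            continue
        ok, _ = check_fn(r["output"], r["value"])
        passed += ok
        failed += not ok
    return passed, failed
```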

In [None]:
def zero_shot_inference(prompts, model, tokenizer, model_name_or_path):

    checks_passed = 0
    checks_failed = 0
    total_bart_score = 0
    scored_prompts = 0

    for p in prompts:
        prompt = p["prompt"]
        document = p["document"]
        prompt_type = p["prompt_type"]
        value = p["value"]
        # *******************************
        # ***Change the prompt_template variable to the desired prompt template for the model***
        # *******************************
        prompt_template = f"<|im_start|>user\n{prompt}<|im_end|>\n"
        print(prompt_template)
        # *******************************
        # ***Change model output generation to the one specific of the model***
        # *******************************
        input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.to(model.device)
        encoded_output = model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=model_max_context_length["notus-7b"])
        decoded_output = tokenizer.decode(encoded_output[0])

        # *******************************
        # ***Change the output splits based on the desired output, the checks work if everything except the actual summary is cut***
        # *******************************
        output = decoded_output.split("|im_end|>")[1]
        output = output.split("</s>")[0]
        output = output.split("\n", 1)[1]
        print(output)

        # Run the check matching the prompt type; each row fills only the columns of its own check.
        # The table has no bullet columns, so bullet counts are logged in the "Sentences" columns.
        percent_cols = words_cols = sentences_cols = (None, None)
        if prompt_type == "percent":
            check, measured = percent_check(document, output, value)
            percent_cols = (check, measured)
        elif prompt_type == "length":
            check, measured = length_check(output, value)
            words_cols = (check, measured)
        elif prompt_type == "sentence":
            check, measured = sentence_check(output, value)
            sentences_cols = (check, measured)
        elif prompt_type == "bullet":
            check, measured = bullet_check(output, value)
            sentences_cols = (check, measured)
        else:
            print(f"Invalid prompt type: {prompt_type}")
            continue

        # BARTScore of the generated summary with respect to the source document
        bart_score = calculate_bart_score([output], [document])[0]
        total_bart_score += bart_score
        scored_prompts += 1
        if check:
            checks_passed += 1
        else:
            checks_failed += 1
        my_table.add_data(model_name_or_path, prompt, output, *percent_cols, *words_cols, *sentences_cols, bart_score)

    # Average only over the prompts that were actually scored
    average_bart_score = total_bart_score / scored_prompts if scored_prompts else 0
    return checks_passed, checks_failed, average_bart_score

A one_shot_inference function for the one-shot prompts can be defined analogously, with the same inputs and outputs; the cell below runs zero-shot inference only.

We infer the summaries one by one, given a model, its tokenizer (if needed), the data and the prompt.

We then upload the results to WandB, using the metrics that match each prompt's request.
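A one-shot loop mainly needs to translate each one-shot label into the (prompt_type, value) pair the checks expect, since prompt_creation_one_shot returns only (prompt, type_of_prompt). The sketch below is hypothetical: ONE_SHOT_SPECS, generate and run_check are placeholders for the mapping, the model call and the check dispatch, and a real run_check for "percent" would also need the source document (random_docs.iloc[0]).

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical mapping from one-shot labels to the (prompt_type, value) used by the checks
ONE_SHOT_SPECS: Dict[str, Tuple[str, int]] = {
    "ten percent": ("percent", 10),
    "one sentence": ("sentence", 1),
    "three bullet": ("bullet", 3),
}

def one_shot_inference_sketch(
    one_shot_prompts: List[Tuple[str, str]],
    generate: Callable[[str], str],
    run_check: Callable[[str, str, int], bool],
) -> Tuple[int, int]:
    """Generate one output per (prompt, label) pair and count passed/failed checks."""
    passed = failed = 0
    for prompt, label in one_shot_prompts:
        prompt_type, value = ONE_SHOT_SPECS[label]
        output = generate(prompt)
        if run_check(output, prompt_type, value):
            passed += 1
        else:
            failed += 1
    return passed, failed
```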

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Load the model and tokenizer
# *******************************
# ***Change model_name_or_path to the desired model***
# *******************************
model_name_or_path = "TheBloke/notus-7B-v1-GPTQ"

# *******************************
# ***Change how the model is loaded according to its documentation***
# *******************************
# To use a different branch, change revision if the model has one, for example: revision="gptq-4bit-32g-actorder_True"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# zero-shot inference
checks_passed, checks_failed, average_bart_score = zero_shot_inference(zero_shot_prompts, model, tokenizer, model_name_or_path)
print("Zero-shot inference checks passed: ", checks_passed)
print("Zero-shot inference checks failed: ", checks_failed)
print("Zero-shot inference average BART score: ", average_bart_score)

# save the table to wandb and close the run
run.log({"Table Name": my_table})
run.finish()