In [13]:
import time
from typing import List, Tuple, Union

import lingua
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.tokenization_utils_base import BatchEncoding

In this notebook, we'll perform a very simple experiment. The question is:

Does the discrete prompt learned to optimize T5s ability to perform the SST2 sentiment analysis task also improve performance for OPT-6.7B?

Our original prompt was "Generate the sentiment of the next sentence. ". For T5, this prompt induced about 68% accuracy. After gradient based search optimization, we ended up with the prompt "tumour negative .05. Positive respins the Contains sentence. " with an accuracy of 81% and "childcare negative .05. Positive respins wSt Thank sentence." with an accuracy of 83%.

Let's determine if either of these odd but apparently performant prompts will improve results for OPT-6.7B over the original as well.

In [14]:
initial_prompt = "Generate the sentiment of the next sentence. "
optimized_prompt_1 = "tumour negative .05. Positive respins the Contains sentence. "
optimized_prompt_2 = "childcare negative .05. Positive respins wSt Thank sentence. "

opt_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-large-lm-adapt")
label_words = ["negative", "positive"]
label_int_to_str = {0: "pegative", 1: "positive"}
# How big should are inference batches be
batch_size = 10
# How many batches in total to process from the dataloader (batch_size*batches_to_sample = datapoints to process)
batches_to_sample = 20

In [15]:
dataset = load_dataset("sst2", split="validation")
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
dataset[0]

Found cached dataset sst2 (/Users/david/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)


{'idx': 0,
 'sentence': "it 's a charming and often affecting journey . ",
 'label': 1}

In [16]:
# Need to grab the token id associated with the label words for both opt and t5

dummy_sentence = "I love this movie!"
print(opt_tokenizer(f"{initial_prompt}{dummy_sentence} negative")["input_ids"])
print(opt_tokenizer(f"{initial_prompt}{dummy_sentence} positive")["input_ids"])
print(t5_tokenizer(f"{initial_prompt}{dummy_sentence} negative </s>", add_special_tokens=False)["input_ids"])
print(t5_tokenizer(f"{initial_prompt}{dummy_sentence} positive </s>", add_special_tokens=False)["input_ids"])

[2, 40025, 877, 5, 5702, 9, 5, 220, 3645, 4, 38, 657, 42, 1569, 328, 2430]
[2, 40025, 877, 5, 5702, 9, 5, 220, 3645, 4, 38, 657, 42, 1569, 328, 1313]
[6939, 2206, 8, 6493, 13, 8, 416, 7142, 5, 27, 333, 48, 1974, 55, 2841, 1]
[6939, 2206, 8, 6493, 13, 8, 416, 7142, 5, 27, 333, 48, 1974, 55, 1465, 1]


In [17]:
opt_label_tokens = {"negative": [2430], "positive": [1313]}
t5_label_tokens = {"negative": [2841], "positive": [1465]}

Note that the T5 tokenizer splits the "Negative" word into two tokens, which we'll need to be aware of when computing logits.

We'll start by measuring the performance of OPT-6.7B with the initial prompt

In [18]:
client = lingua.Client(gateway_host="llm.cluster.local", gateway_port=3001)
client.model_instances

[{'id': 'b11f3264-9c03-4114-9d56-d39a0fa63640',
  'name': 'OPT-175B',
  'state': 'ACTIVE'}]

In [19]:
model = client.load_model("OPT-175B")
# If this model is not actively running, it will get launched in the background.
# In this case, wait until it moves into an "ACTIVE" state before proceeding.
while model.state != "ACTIVE":
    time.sleep(1)

short_generation_config = {"max_tokens": 2, "top_k": 4, "top_p": 3, "rep_penalty": 1.0, "temperature": 1.0}

In [20]:
# We're interested in the activations from the last layer of the model, because this will allow us to caculation the
# likelihoods
last_layer_name = model.module_names[-1]
last_layer_name

'decoder.output_projection'

In [21]:
def create_prompt_str(instruction: str, sentences: List[str]) -> List[str]:
    return [f"{instruction}{sentence}" for sentence in sentences]

Let's take a look at the prompts we're creating

In [22]:
example_batch = next(iter(dataloader))
example_prompts_initial = create_prompt_str(initial_prompt, example_batch["sentence"])
example_prompts_optimized_1 = create_prompt_str(optimized_prompt_1, example_batch["sentence"])
example_prompts_optimized_2 = create_prompt_str(optimized_prompt_2, example_batch["sentence"])
print(example_prompts_initial[0])
print(example_prompts_optimized_1[0])
print(example_prompts_optimized_2[0])

Generate the sentiment of the next sentence. having had the good sense to cast actors who are , generally speaking , adored by the movie-going public , khouri then gets terrific performances from them all . 
tumour negative .05. Positive respins the Contains sentence. having had the good sense to cast actors who are , generally speaking , adored by the movie-going public , khouri then gets terrific performances from them all . 
childcare negative .05. Positive respins wSt Thank sentence. having had the good sense to cast actors who are , generally speaking , adored by the movie-going public , khouri then gets terrific performances from them all . 


In [23]:
def select_label_from_activations_opt(label_token_ids: torch.Tensor, layer_matrix: torch.Tensor) -> int:
    # The activations we care about are the last token (corresponding to our label token) and the values for our label
    #  vocabulary
    label_activations = layer_matrix[-1][label_token_ids].float()
    softmax = nn.Softmax(dim=0)
    # Softmax is not strictly necessary, but it helps to contextualize the "probability" the model associates with each
    # label relative to the others
    label_distributions = softmax(label_activations)
    # We select the label index with the largest value
    max_label_index = torch.argmax(label_distributions)
    return max_label_index

In [24]:
report: List[Tuple[str, float]] = []

In [25]:
correct = 0
total = 0
label_token_ids = torch.Tensor([opt_label_tokens["negative"], opt_label_tokens["positive"]]).long()
for batch_num, batch in enumerate(dataloader):
    prompts = create_prompt_str(initial_prompt, batch["sentence"])
    labels = batch["label"]
    activations = model.get_activations(prompts, [last_layer_name], short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    for activations_single_prompt, label in zip(activations.activations, labels):
        last_layer_matrix = activations_single_prompt[last_layer_name]
        predicted_label = select_label_from_activations_opt(label_token_ids, last_layer_matrix)
        if predicted_label == int(label.item()):
            correct += 1
        total += 1
    if batch_num + 1 == batches_to_sample:
        break
accuracy = correct / total
report.append((initial_prompt, accuracy))

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete


In [26]:
print(f"Accuracy: {accuracy}")

Accuracy: 0.6


Now let's try both of our "optimized prompts"

In [27]:
correct = 0
total = 0
label_token_ids = torch.Tensor([opt_label_tokens["negative"], opt_label_tokens["positive"]]).long()
for batch_num, batch in enumerate(dataloader):
    prompts = create_prompt_str(optimized_prompt_1, batch["sentence"])
    labels = batch["label"]
    activations = model.get_activations(prompts, [last_layer_name], short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    for activations_single_prompt, label in zip(activations.activations, labels):
        last_layer_matrix = activations_single_prompt[last_layer_name]
        predicted_label = select_label_from_activations_opt(label_token_ids, last_layer_matrix)
        if predicted_label == int(label.item()):
            correct += 1
        total += 1
    if batch_num + 1 == batches_to_sample:
        break
accuracy = correct / total
report.append((optimized_prompt_1, accuracy))

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete


In [28]:
print(f"Accuracy: {accuracy}")

Accuracy: 0.555


In [29]:
correct = 0
total = 0
label_token_ids = torch.Tensor([opt_label_tokens["negative"], opt_label_tokens["positive"]]).long()
for batch_num, batch in enumerate(dataloader):
    prompts = create_prompt_str(optimized_prompt_2, batch["sentence"])
    labels = batch["label"]
    activations = model.get_activations(prompts, [last_layer_name], short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    for activations_single_prompt, label in zip(activations.activations, labels):
        last_layer_matrix = activations_single_prompt[last_layer_name]
        predicted_label = select_label_from_activations_opt(label_token_ids, last_layer_matrix)
        if predicted_label == int(label.item()):
            correct += 1
        total += 1
    if batch_num + 1 == batches_to_sample:
        break
accuracy = correct / total
report.append((optimized_prompt_2, accuracy))

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete


In [30]:
print(f"Accuracy: {accuracy}")

Accuracy: 0.615


### HugggingFace T5

Let's try these prompts in the context of the original T5 model from HuggingFace

In [31]:
# Instantiate the model and set it to eval mode
t5_model = T5ForConditionalGeneration.from_pretrained("google/t5-large-lm-adapt").eval()

In [32]:
def create_encoder_decoder_inputs(
    prompts: List[str], t5_tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Tuple[BatchEncoding, BatchEncoding]:
    # Repeat each prompt twice (once for each label)
    repeated_prompts = [prompt for prompt in prompts for i in range(2)]
    # repeat label words for each repeated prompt
    decoder_labels = [f"{label_word} </s>" for label_word in label_words] * len(prompts)
    encoder_inputs = t5_tokenizer(
        repeated_prompts,
        truncation=True,
        padding="max_length",
        max_length=64,
        add_special_tokens=False,
        return_tensors="pt",
    )
    decoder_inputs = t5_tokenizer(
        decoder_labels,
        truncation=True,
        padding="max_length",
        max_length=16,
        add_special_tokens=False,
        return_tensors="pt",
    )
    return encoder_inputs, decoder_inputs

In [33]:
def get_likelihoods_from_t5_ouput(
    output: Seq2SeqLMOutput, loss_func: torch.nn.CrossEntropyLoss, decoder_ids: torch.Tensor
) -> torch.Tensor:
    loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
    # Negative of the loss to get back to raw log-probabilities
    log_likelihoods = -loss_func(output.logits.view(-1, output.logits.size(-1)), decoder_ids.view(-1))
    batch_size, sequence_length, _ = output.logits.size()
    # compute per-token log probability in a sequence.
    # log_p has log probabilities for the following target output: [pos, it, ive]
    log_likelihoods = log_likelihoods.view(batch_size, sequence_length)
    # pad tokens have index -100 in huggingface.
    good_log_p = log_likelihoods.masked_fill_(decoder_ids == -100, 0.0)
    # good_log_p now has the log probability of the output sequence tokens.
    # sum over the sequence length to compute the log probability for a full sequence.
    return torch.sum(good_log_p, dim=1).squeeze()

In [34]:
def run_t5_model_on_encodings(
    encoder_encodings: BatchEncoding, decoder_encodings: BatchEncoding, t5_model: T5ForConditionalGeneration
) -> Tuple[Seq2SeqLMOutput, torch.Tensor]:
    decoder_ids = decoder_encodings.input_ids
    # we have to make sure that the PAD token is ignored.
    # huggingface ignores a pad token if the token is -100!
    decoder_ids = decoder_ids.masked_fill(decoder_ids == t5_tokenizer.pad_token_id, -100)
    # Disable gradient tracking for faster inference
    with torch.no_grad():
        model_output = t5_model(
            input_ids=encoder_encodings.input_ids,
            attention_mask=encoder_encodings.attention_mask,
            decoder_attention_mask=decoder_encodings.attention_mask,
            decoder_input_ids=t5_model._shift_right(decoder_ids),
            labels=None,
        )
    return model_output, decoder_ids

In [35]:
def extract_label_from_likelihoods(softmax_func: nn.Softmax, likelihoods: torch.Tensor) -> torch.tensor:
    # Pair the likelihoods associated with negative and positive labels for each prompt
    likelihoods = likelihoods.reshape(-1, 2)
    likelihoods = softmax_func(likelihoods)
    return torch.argmax(likelihoods, dim=1)

In [36]:
correct = 0
total = 0
# We're going to use a loss function to extra the log probabilties of the labels.
loss_func = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
softmax = nn.Softmax(dim=1)
for batch_num, batch in enumerate(dataloader):
    prompts = [f"{prompt} </s>" for prompt in create_prompt_str(initial_prompt, batch["sentence"])]
    labels = batch["label"]
    encoder_encodings, decoder_encodings = create_encoder_decoder_inputs(prompts, t5_tokenizer)
    model_output, decoder_ids = run_t5_model_on_encodings(encoder_encodings, decoder_encodings, t5_model)
    likelihoods = get_likelihoods_from_t5_ouput(model_output, loss_func, decoder_ids)
    predicted_labels = extract_label_from_likelihoods(softmax, likelihoods)
    match_tensor = (predicted_labels == labels).long()
    correct += torch.sum(match_tensor)
    total += len(match_tensor)
    print(f"Batch number {batch_num+1} Complete")
    if batch_num + 1 == batches_to_sample:
        break
accuracy = correct / total

report.append((initial_prompt, accuracy))

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete


In [37]:
print(f"Accuracy: {accuracy}")

Accuracy: 0.6549999713897705


In [38]:
correct = 0
total = 0
# We're going to use a loss function to extra the log probabilties of the labels.
loss_func = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
softmax = nn.Softmax(dim=1)
for batch_num, batch in enumerate(dataloader):
    prompts = [f"{prompt} </s>" for prompt in create_prompt_str(optimized_prompt_1, batch["sentence"])]
    labels = batch["label"]
    encoder_encodings, decoder_encodings = create_encoder_decoder_inputs(prompts, t5_tokenizer)
    model_output, decoder_ids = run_t5_model_on_encodings(encoder_encodings, decoder_encodings, t5_model)
    likelihoods = get_likelihoods_from_t5_ouput(model_output, loss_func, decoder_ids)
    predicted_labels = extract_label_from_likelihoods(softmax, likelihoods)
    match_tensor = (predicted_labels == labels).long()
    correct += torch.sum(match_tensor)
    total += len(match_tensor)
    print(f"Batch number {batch_num+1} Complete")
    if batch_num + 1 == batches_to_sample:
        break
accuracy = correct / total
report.append((optimized_prompt_1, accuracy))

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete


In [39]:
print(f"Accuracy: {accuracy}")

Accuracy: 0.7900000214576721


In [40]:
correct = 0
total = 0
# We're going to use a loss function to extra the log probabilties of the labels.
loss_func = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
softmax = nn.Softmax(dim=1)
for batch_num, batch in enumerate(dataloader):
    prompts = [f"{prompt} </s>" for prompt in create_prompt_str(optimized_prompt_2, batch["sentence"])]
    labels = batch["label"]
    encoder_encodings, decoder_encodings = create_encoder_decoder_inputs(prompts, t5_tokenizer)
    model_output, decoder_ids = run_t5_model_on_encodings(encoder_encodings, decoder_encodings, t5_model)
    likelihoods = get_likelihoods_from_t5_ouput(model_output, loss_func, decoder_ids)
    predicted_labels = extract_label_from_likelihoods(softmax, likelihoods)
    match_tensor = (predicted_labels == labels).long()
    correct += torch.sum(match_tensor)
    total += len(match_tensor)
    print(f"Batch number {batch_num+1} Complete")
    if batch_num + 1 == batches_to_sample:
        break
accuracy = correct / total
report.append((optimized_prompt_2, accuracy))

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete


In [41]:
print(f"Accuracy: {accuracy}")

Accuracy: 0.7549999952316284


In [42]:
print("Summary")
print("OPT Performance:")
for prompt, acc in report[0:3]:
    print(f"Prompt: {prompt}, Accuracy: {acc}")
print("T5 Performance:")
for prompt, acc in report[3:6]:
    print(f"Prompt: {prompt}, Accuracy: {acc}")

Summary
OPT Performance:
Prompt: Generate the sentiment of the next sentence. , Accuracy: 0.6
Prompt: tumour negative .05. Positive respins the Contains sentence. , Accuracy: 0.555
Prompt: childcare negative .05. Positive respins wSt Thank sentence. , Accuracy: 0.615
T5 Performance:
Prompt: Generate the sentiment of the next sentence. , Accuracy: 0.6549999713897705
Prompt: tumour negative .05. Positive respins the Contains sentence. , Accuracy: 0.7900000214576721
Prompt: childcare negative .05. Positive respins wSt Thank sentence. , Accuracy: 0.7549999952316284


It's fairly clear that OPT does not do well with this type of prompt, whereas T5 does a pretty good job with this instruction prompt.

The amazing part is that these weird prompts seem to improve the performance of T5, but also possibly the performance of OPT a little bit!