In [54]:
import random
import re
import time
from typing import List, Tuple

import kscope
import pandas as pd
from metrics import map_ag_news_int_labels, report_metrics
from transformers import AutoTokenizer
from utils import get_label_token_ids, get_label_with_highest_likelihood, split_prompts_into_batches

# Getting Started

There is a bit of documentation on how to interact with the large models [here](https://kaleidoscope-sdk.readthedocs.io/en/latest/). The relevant github links to the SDK are [here](https://github.com/VectorInstitute/kaleidoscope-sdk) and underlying code [here](https://github.com/VectorInstitute/kaleidoscope).

First we connect to the service through which, we'll interact with the LLMs and see which models are avaiable to us

In [55]:
# Establish a client connection to the kscope service
client = kscope.Client(gateway_host="llm.cluster.local", gateway_port=3001)

Show all supported models

In [56]:
client.models

['OPT-175B', 'OPT-6.7B']

Show all model instances that are currently active

In [57]:
client.model_instances

[{'id': '1ae3ae36-af03-45f4-b95e-2ec66e797f96',
  'name': 'OPT-175B',
  'state': 'ACTIVE'},
 {'id': '65084219-31ef-4430-922e-fb31f219ed49',
  'name': 'OPT-6.7B',
  'state': 'ACTIVE'}]

Let's start by querying the OPT-175B model. We'll try other models below. Get a handle to a model. In this example, let's use the OPT-175B model.

In [58]:
model = client.load_model("OPT-175B")
# If this model is not actively running, it will get launched in the background.
# In this case, wait until it moves into an "ACTIVE" state before proceeding.
while model.state != "ACTIVE":
    time.sleep(1)

We need to configure the model to generate in the way we want it to. So we set a number of important parameters. For a discussion of the configuration parameters see: `src/reference_implementations/prompting_vector_llms/CONFIG_README.md`

In [59]:
short_generation_config = {"max_tokens": 2, "top_k": 4, "top_p": 1.0, "rep_penalty": 1.0, "temperature": 1.0}

Let's try a basic prompt for factual information.

__Note__ that if you run the cell multiple times, you'll get different responses due to sampling.

In [60]:
generation = model.generate("What is the capital of Canada?", short_generation_config)
# Extract the text from the returned generation
generation.generation["text"]

['\nO']

We're going to have our model attempt to classify some news articles from the AG News Dataset. Articles have a single label 1-4

1. World
2. Sports
3. Business
4. Sci/Tech

This is a constrained label space. We'll use the words World, Sports, Business, and Science as our targets for each of the labels.

In [61]:
def remove_markup(text: str) -> str:
    text = re.sub(r"https?://\S+|www\.\S+", "", text)
    text = re.sub(r"<.*?>+", "", text)
    return text


def ag_news_processor(path: str) -> Tuple[List[str], List[str], List[str]]:
    ag_news_data = pd.read_csv(path)
    labels = ag_news_data["Class Index"].tolist()
    titles = ag_news_data["Title"].apply(lambda x: remove_markup(x)).tolist()
    descriptions = ag_news_data["Description"].apply(lambda x: remove_markup(x)).tolist()
    return labels, titles, descriptions


int_to_label_map = {1: "world", 2: "sports", 3: "business", 4: "science"}
ag_news_labels, ag_news_titles, ag_news_descriptions = ag_news_processor(
    "resources/ag_news_datasets/ag_news_sample.csv"
)

In [62]:
ag_news_labels = map_ag_news_int_labels(ag_news_labels, int_to_label_map)
ag_news_descriptions = [description.replace("\\", " ").strip() for description in ag_news_descriptions]
ag_news_titles = [title.strip() for title in ag_news_titles]
label_words = ["World", "Sports", "Business", "Science"]

In [63]:
model_input_texts = [
    f"Title: {ag_news_title} Description: {ag_news_description}"
    for ag_news_title, ag_news_description in zip(ag_news_titles, ag_news_descriptions)
]

Let's start by trying out a basic instruction prompt to see what the model does.

In [64]:
prompt_template = "To which category does this news article belong?"
sample_texts = [f"{model_input_text} {prompt_template}" for model_input_text in model_input_texts[0:3]]
generation = model.generate(sample_texts, short_generation_config)
for text in generation.generation["text"]:
    print(text)
    print("==================================")

 All Categories
 One of





Not well...Now let's try to constrain the model a bit by including the desired labels in the instruction.

In [65]:
prompt_template = "From World, Sports, Business, Science, the category is "
sample_texts = [f"{model_input_text} {prompt_template}" for model_input_text in model_input_texts[0:3]]
generation = model.generate(sample_texts, short_generation_config)
for text in generation.generation["text"]:
    print(text)
    print("==================================")

 Environment,
 South &
________________.


The model doesn't really answer in the space that we want it to. Let's try with some few-shot examples to see if that helps

In [66]:
prompt_template_prefix = """Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Science): Sports
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Science): World
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Science): Business
Title: More Evidence for Past Water on Mars Description: Summary - (Aug 22, 2004) NASA #39;s Spirit rover has dug up plenty of evidence on slopes of  quot;Columbia Hills quot; that water once covered the area. Category (World, Sports, Business, Science): World
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (World, Sports, Business, Science): Business
"""  # noqa
prompt_template_postfix = "Category (World, Sports, Business, Science):"
sample_texts = [
    f"{prompt_template_prefix}{model_input_text} {prompt_template_postfix}"
    for model_input_text in model_input_texts[0:3]
]
generation = model.generate(sample_texts, short_generation_config)
for text in generation.generation["text"]:
    # We'll limit ourselves to the single next token since we want it to respond that way
    print(text)
    print("==================================")

 Business

 Business

 Sports



Few-shot learning definitely helps a lot! We'll measure an accuracy sample below. However, there is nothing stoping the model from not selecting our labels. So can we do better? We can work around this by understanding the likelihood of our labels from the models perspective

In [67]:
# We're interested in the activations from the last layer of the model, because this will allow us to caculation the
# likelihoods
last_layer_name = model.module_names[-1]
last_layer_name

'decoder.output_projection'

Need to instantiate a tokenizer to obtain appropriate token indices for our labels. 

__NOTE__: All OPT models, regardless of size, used the same tokenizing. However, if you want to use a different type of model, a different tokenizer may be needed.

In [68]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
label_token_ids = get_label_token_ids(tokenizer, prompt_template, label_words)
# If you ever need to move back from token ids, you can use tokenizer.decode or tokenizer.batch_decode
tokenizer.decode(label_token_ids)

' World Sports Business Science'

Let's look at how we can extract the likelihoods given the label tokens

In [69]:
single_prompted_input = f"{model_input_texts[0]} {prompt_template} {label_words[0]}"
# Create a prompt with one of the label words as a completion
activations = model.get_activations(single_prompted_input, [last_layer_name], short_generation_config)

In [17]:
last_layer_matrix = activations.activations[0][last_layer_name]
# The shape of this tensor should be number of input tokens by the vocabulary size (n x 50272)
print(f"Activations matrix shape: {last_layer_matrix.shape}")
predicted_label = get_label_with_highest_likelihood(
    last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
)
print(f"Predicted Label: {predicted_label}")

Activations matrix shape: torch.Size([62, 50272])
Predicted Label: business


## Accuracy

Time to compare our results across three methods. 
1. Measure the accuracy of our few-shot prompting approach.
2. Measure the accuracy of our likelihood approach without few-shot.
3. Measure the accuracy of our likelihood approach with few-shot.

In [18]:
lowercase_labels = [word.lower() for word in label_words]

### Few-shot only

In [19]:
prompt_template_prefix = """Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Science): Sports
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Science): World
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Science): Business
Title: More Evidence for Past Water on Mars Description: Summary - (Aug 22, 2004) NASA #39;s Spirit rover has dug up plenty of evidence on slopes of  quot;Columbia Hills quot; that water once covered the area. Category (World, Sports, Business, Science): World
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (World, Sports, Business, Science): Business
"""  # noqa
prompt_template_postfix = "Category (World, Sports, Business, Science):"
prompts = [
    f"{prompt_template_prefix}{model_input_text} {prompt_template_postfix}" for model_input_text in model_input_texts
]
prompt_batches = split_prompts_into_batches(prompts)
predicted_labels = []
for batch_num, prompt_batch in enumerate(prompt_batches):
    generation = model.generate(prompt_batch, short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    # We'll use tokens this time and consider just the first token
    first_predicted_tokens = [tokens[0].strip().lower() for tokens in generation.generation["tokens"]]
    # If a token doesn't correspond to one of our labels, we'll randomly select one and count how many times that
    # happens for reporting
    n_no_match = 0
    for potential_prediction in first_predicted_tokens:
        if potential_prediction in lowercase_labels:
            predicted_labels.append(potential_prediction)
        else:
            n_no_match += 1
            print(f"Potential Prediction: {potential_prediction} does not match any label")
            predicted_labels.append(random.choice(lowercase_labels))

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Potential Prediction: economics does not match any label
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Potential Prediction: u does not match any label
Batch number 9 Complete
Potential Prediction: politics does not match any label
Batch number 10 Complete


In [20]:
report_metrics(predicted_labels, ag_news_labels)

Prediction Accuracy: 0.65
Confusion Matrix with ordering ['world', 'sports', 'business', 'science']
[[15  0  1  2]
 [ 1 21  0  0]
 [ 9  0 21 18]
 [ 3  0  1  8]]
Label: world, F1: 0.6521739130434783, Precision: 0.8333333333333334, Recall: 0.5357142857142857
Label: sports, F1: 0.9767441860465117, Precision: 0.9545454545454546, Recall: 1.0
Label: business, F1: 0.5915492957746479, Precision: 0.4375, Recall: 0.9130434782608695
Label: science, F1: 0.4, Precision: 0.6666666666666666, Recall: 0.2857142857142857


### Likelihood No Few-shot

In [21]:
prompts = [f"{model_input_text} {prompt_template} {label_words[0]}" for model_input_text in model_input_texts]
# For memory management, we split the prompts into batches of size 10
predicted_labels = []
prompt_batches = split_prompts_into_batches(prompts)
for batch_num, prompt_batch in enumerate(prompt_batches):
    activations = model.get_activations(prompt_batch, [last_layer_name], short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    for activations_single_prompt in activations.activations:
        last_layer_matrix = activations_single_prompt[last_layer_name]
        predicted_label = get_label_with_highest_likelihood(
            last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
        )
        predicted_labels.append(predicted_label)

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete


In [22]:
report_metrics(predicted_labels, ag_news_labels)

Prediction Accuracy: 0.42
Confusion Matrix with ordering ['world', 'sports', 'business', 'science']
[[ 0  0  0  0]
 [ 7 21  2  2]
 [21  0 21 26]
 [ 0  0  0  0]]
Label: world, F1: nan, Precision: nan, Recall: 0.0
Label: sports, F1: 0.7924528301886793, Precision: 0.65625, Recall: 1.0
Label: business, F1: 0.46153846153846156, Precision: 0.3088235294117647, Recall: 0.9130434782608695
Label: science, F1: nan, Precision: nan, Recall: 0.0


  TP = np.diag(matrix)


### Likelihood with Few-Shot

In [23]:
prompt_template_prefix = """Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Science): Sports
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Science): World
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Science): Business
Title: More Evidence for Past Water on Mars Description: Summary - (Aug 22, 2004) NASA #39;s Spirit rover has dug up plenty of evidence on slopes of  quot;Columbia Hills quot; that water once covered the area. Category (World, Sports, Business, Science): World
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (World, Sports, Business, Science): Business
"""  # noqa
prompt_template_postfix = "Category (World, Sports, Business, Science):"
prompts = [
    f"{prompt_template_prefix}{model_input_text} {prompt_template_postfix}" for model_input_text in model_input_texts
]
# For memory management, we split the prompts into batches of size 10
predicted_labels = []
prompt_batches = split_prompts_into_batches(prompts)
for batch_num, prompt_batch in enumerate(prompt_batches):
    activations = model.get_activations(prompt_batch, [last_layer_name], short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    for activations_single_prompt in activations.activations:
        last_layer_matrix = activations_single_prompt[last_layer_name]
        predicted_label = get_label_with_highest_likelihood(
            last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
        )
        predicted_labels.append(predicted_label)

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete


In [24]:
report_metrics(predicted_labels, ag_news_labels)

Prediction Accuracy: 0.69
Confusion Matrix with ordering ['world', 'sports', 'business', 'science']
[[22  0  2  2]
 [ 1 21  0  0]
 [ 4  0 21 21]
 [ 1  0  0  5]]
Label: world, F1: 0.8148148148148148, Precision: 0.8461538461538461, Recall: 0.7857142857142857
Label: sports, F1: 0.9767441860465117, Precision: 0.9545454545454546, Recall: 1.0
Label: business, F1: 0.608695652173913, Precision: 0.45652173913043476, Recall: 0.9130434782608695
Label: science, F1: 0.29411764705882354, Precision: 0.8333333333333334, Recall: 0.17857142857142858
