In [9]:
import random
import re
import time
from typing import List, Tuple

import kscope
import pandas as pd
from metrics import map_ag_news_int_labels, report_metrics
from transformers import AutoTokenizer
from utils import get_label_token_ids, get_label_with_highest_likelihood, split_prompts_into_batches

### Connecting to the Service

First, we connect to the service through which we'll interact with the LLMs, and see which models are available to us.

In [10]:
# Establish a client connection to the kscope service
client = kscope.Client(gateway_host="llm.cluster.local", gateway_port=6001)

Show all model instances that are currently active.

In [11]:
client.model_instances

[{'id': '75deb027-445d-4a24-8a72-2751a4f81a7b',
  'name': 'OPT-175B',
  'state': 'ACTIVE'}]

To start, we obtain a handle to a model. In this example, let's use the OPT-175B model.

In [12]:
model = client.load_model("OPT-175B")
# If this model is not actively running, it will get launched in the background.
# In this case, wait until it moves into an "ACTIVE" state before proceeding.
while model.state != "ACTIVE":
    time.sleep(1)
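The loop above waits forever if the model never reaches the "ACTIVE" state. Below is a minimal sketch of a bounded alternative. Like the loop, it assumes that reading `model.state` returns the current state each time. The helper `wait_for_state`, its 600-second default timeout, and the intermediate state name `"LAUNCHING"` are all hypothetical.

```python
import time
from typing import Callable


def wait_for_state(
    get_state: Callable[[], str], target: str = "ACTIVE", timeout_s: float = 600.0, poll_s: float = 1.0
) -> None:
    """Poll get_state() until it returns target; raise TimeoutError after timeout_s seconds."""
    deadline = time.monotonic() + timeout_s
    while (state := get_state()) != target:
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still in state {state!r} after {timeout_s} s")
        time.sleep(poll_s)


# Usage sketch with a fake state sequence; with the real client you would pass lambda: model.state
states = iter(["LAUNCHING", "LAUNCHING", "ACTIVE"])
wait_for_state(lambda: next(states), poll_s=0.01)
print("model is ACTIVE")
```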

In [13]:
moderate_generation_config = {"max_tokens": 20, "top_k": 4, "top_p": 1.0, "rep_penalty": 1.5, "temperature": 0.7}

Let's ask the model some questions.

In [14]:
generation = model.generate("What is the capital of Canada?", moderate_generation_config)
# Extract the text from the returned generation
generation.generation["text"]

["\nVancouver, duh.\nThat's a city, not a country.\nCanada is"]

In [15]:
generation = model.generate("When did Canada become an independent country?", moderate_generation_config)
# Extract the text from the returned generation
generation.generation["text"]

['\nHowever, the country is far from independent now.']

We're going to have our model attempt to classify some news articles from the AG News dataset. Each article has a single label from 1 to 4:

1. World
2. Sports
3. Business
4. Sci/Tech

This is a constrained label space. We'll use the words "World", "Sports", "Business", and "Technology" (for Sci/Tech) as the target words the language model should generate for each label.
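Here is a minimal sketch of the label bookkeeping. It assumes that `map_ag_news_int_labels` (from the notebook's `metrics` module, whose source isn't shown) simply looks up each integer class index in `int_to_label_map`. `map_int_labels` below is a hypothetical stand-in, not that function.

```python
from typing import Dict, List

# Same mapping the notebook defines as int_to_label_map: AG News class index -> lowercase label word
INT_TO_LABEL: Dict[int, str] = {1: "world", 2: "sports", 3: "business", 4: "technology"}


def map_int_labels(int_labels: List[int], int_to_label: Dict[int, str]) -> List[str]:
    """Hypothetical stand-in for metrics.map_ag_news_int_labels: map 1-4 class indices to words."""
    unknown = sorted(set(int_labels) - set(int_to_label))
    if unknown:
        raise ValueError(f"Unexpected class indices: {unknown}")
    return [int_to_label[label] for label in int_labels]


# Capitalized forms are the words we hope the model generates (e.g. "Business")
label_words = [word.capitalize() for word in INT_TO_LABEL.values()]

print(map_int_labels([3, 2, 1, 4], INT_TO_LABEL))  # ['business', 'sports', 'world', 'technology']
print(label_words)  # ['World', 'Sports', 'Business', 'Technology']
```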

In [16]:
def remove_markup(text: str) -> str:
    # Strip URLs, then anything that looks like an HTML tag
    text = re.sub(r"https?://\S+|www\.\S+", "", text)
    text = re.sub(r"<.*?>+", "", text)
    return text


def ag_news_processor(path: str) -> Tuple[List[int], List[str], List[str]]:
    # Load the AG News CSV and return (integer class labels, cleaned titles, cleaned descriptions)
    ag_news_data = pd.read_csv(path)
    labels = ag_news_data["Class Index"].tolist()
    titles = ag_news_data["Title"].apply(remove_markup).tolist()
    descriptions = ag_news_data["Description"].apply(remove_markup).tolist()
    return labels, titles, descriptions


int_to_label_map = {1: "world", 2: "sports", 3: "business", 4: "technology"}
ag_news_labels, ag_news_titles, ag_news_descriptions = ag_news_processor(
    "ag_news_task_dataset/ag_news_datasets/ag_news_sample.csv"
)

In [17]:
ag_news_labels = map_ag_news_int_labels(ag_news_labels, int_to_label_map)
ag_news_descriptions = [description.replace("\\", " ").strip() for description in ag_news_descriptions]
ag_news_titles = [title.strip() for title in ag_news_titles]
label_words = ["World", "Sports", "Business", "Technology"]

In [18]:
model_input_texts = [
    f"Title: {ag_news_title} Description: {ag_news_description}"
    for ag_news_title, ag_news_description in zip(ag_news_titles, ag_news_descriptions)
]

Let's start by trying out a basic instruction prompt to see what the model does.

Throughout, we're going to use a simple config corresponding to greedy decoding of a single token.

In [19]:
short_generation_config = {"max_tokens": 1, "top_k": 1}
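Why does `{"max_tokens": 1, "top_k": 1}` amount to greedy decoding? Standard top-k sampling keeps only the k highest-scoring tokens before sampling. With k=1 the only candidate is the argmax, so temperature no longer matters. The following is a minimal NumPy sketch, assuming the service implements `top_k` in this standard way. `top_k_sample` is our own illustration, not kscope code, and the logits are made up.

```python
import numpy as np


def top_k_sample(logits: np.ndarray, k: int, temperature: float, rng: np.random.Generator) -> int:
    """Sample one token id from the k highest-scoring entries of a logit vector."""
    # Indices of the k largest logits (their order among themselves doesn't matter)
    top_ids = np.argpartition(logits, -k)[-k:]
    scaled = logits[top_ids] / temperature
    # Numerically stable softmax over the surviving candidates only
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))


logits = np.array([0.1, 2.5, -1.0, 2.4, 0.0])  # toy scores over a 5-token vocabulary
rng = np.random.default_rng(0)

# With k=1 only the argmax survives, so every draw is the same token: greedy decoding
greedy_draws = {top_k_sample(logits, k=1, temperature=0.7, rng=rng) for _ in range(100)}
print(greedy_draws)  # {1}

# With k=4 (as in moderate_generation_config) several different tokens can be drawn
sampled_draws = {top_k_sample(logits, k=4, temperature=0.7, rng=rng) for _ in range(100)}
print(sorted(sampled_draws))
```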

Let's construct our very first prompt.

In [20]:
prompt_template = "To which category does this news article belong?"
sample_texts = [f"{model_input_text} {prompt_template}" for model_input_text in model_input_texts[0:3]]
print(sample_texts[0])

Title: Telecom lifts first quarter net profit 19pc Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to  $193 million. The profit bettered analysts #39; average forecasts of  $185m. To which category does this news article belong?


In [21]:
generation = model.generate(sample_texts, short_generation_config)
for text in generation.generation["text"]:
    print(text)
    print("==================================")

 *





Not great: none of the outputs is a category. Now let's try to constrain the model a bit by including the desired labels in the instruction.

In [22]:
prompt_template = "From World, Sports, Business, Technology, the category is"
sample_texts = [f"{model_input_text} {prompt_template}" for model_input_text in model_input_texts[0:3]]
print(sample_texts[0])

Title: Telecom lifts first quarter net profit 19pc Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to  $193 million. The profit bettered analysts #39; average forecasts of  $185m. From World, Sports, Business, Technology, the category is


In [23]:
generation = model.generate(sample_texts, short_generation_config)
for text in generation.generation["text"]:
    print(text)
    print("==================================")

 filed
 yours
 channel


The model still doesn't answer within our label space (" filed", " yours", " channel"). Let's try adding some few-shot examples to see if that helps.

__NOTE__: The examples used in the 5-shot prompt were picked at random. Other choices are possible, such as 4-shot or 8-shot prompts in which every category is equally represented.
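One way to get an evenly balanced prompt is to sample the same number of demonstrations per category from a labelled pool and then shuffle them. The sketch below assumes a pool of `(title, description, label)` tuples. `build_balanced_prefix`, `format_demo`, and the toy pool are hypothetical. The formatting mirrors the hand-written prefix used in the next cell.

```python
import random
from typing import Dict, List, Sequence, Tuple

Example = Tuple[str, str, str]  # (title, description, label word)
CATEGORY_HEADER = "Category (World, Sports, Business, Technology):"


def format_demo(title: str, description: str, label: str) -> str:
    # Same layout as each line of the hand-written prefix: "... Category (...): Sports \n"
    return f"Title: {title} Description: {description} {CATEGORY_HEADER} {label} \n"


def build_balanced_prefix(pool: Sequence[Example], shots_per_label: int, seed: int = 0) -> str:
    """Sample shots_per_label demonstrations for every label, then shuffle their order."""
    by_label: Dict[str, List[Example]] = {}
    for example in pool:
        by_label.setdefault(example[2], []).append(example)
    rng = random.Random(seed)
    chosen: List[Example] = []
    for label in sorted(by_label):
        # random.sample raises ValueError if a label has fewer than shots_per_label examples
        chosen.extend(rng.sample(by_label[label], shots_per_label))
    rng.shuffle(chosen)  # avoid presenting all demonstrations of one label back to back
    return "".join(format_demo(*example) for example in chosen)


# Toy labelled pool (made-up headlines), two examples per category
pool = [
    ("Team wins final", "A late goal decided the match.", "Sports"),
    ("Striker signs deal", "The forward joined a new club.", "Sports"),
    ("Summit opens", "Leaders met to discuss trade.", "World"),
    ("Election held", "Voters went to the polls.", "World"),
    ("Shares climb", "Markets rose on strong earnings.", "Business"),
    ("Bank raises rates", "Borrowing costs went up.", "Business"),
    ("New chip unveiled", "The processor is faster.", "Technology"),
    ("Probe reaches orbit", "The spacecraft entered orbit.", "Technology"),
]
prefix = build_balanced_prefix(pool, shots_per_label=1)  # a 4-shot prefix, one per category
print(prefix)
```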

In [24]:
prompt_template_prefix = """Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Technology): Sports 
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Technology): World 
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Technology): Business 
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (World, Sports, Business, Technology): Business 
Title: UK Scientists Allowed to Clone Human Embryos (Reuters) Description: Reuters - British scientists said on Wednesday they had received permission to clone human embryos for medical research, in what they believe to be the first such license to be granted in Europe. Category (World, Sports, Business, Technology): Technology 
"""  # noqa
prompt_template_postfix = "Category (World, Sports, Business, Technology):"
sample_texts = [
    f"{prompt_template_prefix}{model_input_text}{prompt_template_postfix}"
    for model_input_text in model_input_texts[0:3]
]
print(sample_texts[0])

Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Technology): Sports 
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Technology): World 
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Technology): Business 
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (Wo

In [25]:
generation = model.generate(sample_texts, short_generation_config)
for text in generation.generation["text"]:
    print(text)
    print("==================================")

 
 Business
 Sports


Few-shot prompting appears to help a lot: two of the three completions are now label words (the first is blank). To quantify this, we'll measure accuracy on a sample of the AG News dataset below.

# Measuring Zero-Shot and Few-Shot Accuracy

In [26]:
lowercase_labels = [word.lower() for word in label_words]

### Zero-Shot Accuracy

In [27]:
prompt_template = "From World, Sports, Business, Technology, the category is"
prompts = [f"{model_input_text}{prompt_template}" for model_input_text in model_input_texts]
print(prompts[0])

Title: Telecom lifts first quarter net profit 19pc Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to  $193 million. The profit bettered analysts #39; average forecasts of  $185m.From World, Sports, Business, Technology, the category is


In [28]:
prompt_batches = split_prompts_into_batches(prompts, batch_size=1)
predicted_labels = []
n_no_match = 0
for batch_num, prompt_batch in enumerate(prompt_batches):
    generation = model.generate(prompt_batch, short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    # We'll use tokens this time and consider just the first token
    first_predicted_tokens = [tokens[0].strip().lower() for tokens in generation.generation["tokens"]]
    # If a token doesn't correspond to one of our labels, we'll randomly select one and count how many times that
    # happens for reporting
    for potential_prediction in first_predicted_tokens:
        if potential_prediction in lowercase_labels:
            predicted_labels.append(potential_prediction)
        else:
            n_no_match += 1
            predicted_labels.append(random.choice(lowercase_labels))
print(f"A total of {n_no_match} of {len(predicted_labels)} did not match any label")

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete
Batch number 21 Complete
Batch number 22 Complete
Batch number 23 Complete
Batch number 24 Complete
Batch number 25 Complete
Batch number 26 Complete
Batch number 27 Complete
Batch number 28 Complete
Batch number 29 Complete
Batch number 30 Complete
Batch number 31 Complete
Batch number 32 Complete
Batch number 33 Complete
Batch number 34 Complete
Batch number 35 Complete
Batch number 36 Complete
Batch number 37 Complete
Batch number 38 Complete
Batch number 39 Complete
Batch number 40 Complete
Batch num

In [29]:
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Prediction Accuracy: 0.27
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[ 5  5 12  6]
 [ 9  6  3  3]
 [ 4  4  8  7]
 [ 6  7  7  8]]
Label: world, F1: 0.1923076923076923, Recall: 0.17857142857142858, Precision: 0.20833333333333334
Label: sports, F1: 0.2790697674418604, Recall: 0.2857142857142857, Precision: 0.2727272727272727
Label: business, F1: 0.30188679245283023, Recall: 0.34782608695652173, Precision: 0.26666666666666666
Label: technology, F1: 0.30769230769230765, Recall: 0.2857142857142857, Precision: 0.3333333333333333
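The per-label numbers above can be recomputed from the confusion matrix alone. Doing so also pins down the matrix's orientation: rows are true labels and columns are predicted labels. The row sums are 28, 21, 23, and 28, so the sample contains 100 articles. Below is a minimal NumPy sketch. `metrics_from_confusion` is our own helper rather than the notebook's `report_metrics`, but it reproduces the values printed above.

```python
import numpy as np

LABELS = ["world", "sports", "business", "technology"]

# Zero-shot exact-match confusion matrix reported above (rows: true label, columns: predicted label)
confusion = np.array([
    [5, 5, 12, 6],
    [9, 6, 3, 3],
    [4, 4, 8, 7],
    [6, 7, 7, 8],
])


def metrics_from_confusion(cm: np.ndarray):
    """Accuracy plus per-label precision, recall, and F1 (assumes no empty rows or columns)."""
    true_positives = np.diag(cm).astype(float)
    recall = true_positives / cm.sum(axis=1)     # divide by the number of true examples per label
    precision = true_positives / cm.sum(axis=0)  # divide by the number of predictions per label
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = true_positives.sum() / cm.sum()
    return accuracy, precision, recall, f1


accuracy, precision, recall, f1 = metrics_from_confusion(confusion)
print(f"Accuracy: {accuracy:.2f}")  # Accuracy: 0.27
for label, p, r, f in zip(LABELS, precision, recall, f1):
    print(f"{label}: F1={f:.4f} Recall={r:.4f} Precision={p:.4f}")
```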


### Few-Shot Accuracy



In this example, we'll use the same 5-shot prompt as above and perform an "exact match" against our label space. That is, we parse out the first token that the model generates and simply try to string-match it to one of our four label strings. If it matches none of them, we pick a label at random and count the miss.
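The matching logic used in the prediction loops of this notebook can be factored into a small function. This is a refactoring sketch under our own name, `exact_match_labels`. The example tokens are made up, though they resemble outputs we saw earlier.

```python
import random
from typing import List, Sequence, Tuple


def exact_match_labels(
    first_tokens: Sequence[str], labels: Sequence[str], rng: random.Random
) -> Tuple[List[str], int]:
    """Map each generated first token to a label by exact string match, with a random fallback."""
    lowercase_labels = [label.lower() for label in labels]
    predictions: List[str] = []
    n_no_match = 0
    for token in first_tokens:
        # Generated tokens usually carry a leading space (e.g. " Business"), so normalize first
        candidate = token.strip().lower()
        if candidate in lowercase_labels:
            predictions.append(candidate)
        else:
            n_no_match += 1
            predictions.append(rng.choice(lowercase_labels))
    return predictions, n_no_match


rng = random.Random(0)
tokens = [" Business", " Sports", " filed", " ", " World"]  # made-up first tokens
predictions, n_no_match = exact_match_labels(tokens, ["World", "Sports", "Business", "Technology"], rng)
print(n_no_match)  # 2
print(predictions[:2], predictions[-1])  # ['business', 'sports'] world
```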

In [30]:
prompt_template_prefix = """Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Technology): Sports 
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Technology): World 
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Technology): Business 
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (World, Sports, Business, Technology): Business 
Title: UK Scientists Allowed to Clone Human Embryos (Reuters) Description: Reuters - British scientists said on Wednesday they had received permission to clone human embryos for medical research, in what they believe to be the first such license to be granted in Europe. Category (World, Sports, Business, Technology): Technology 
"""  # noqa
prompt_template_postfix = "Category (World, Sports, Business, Technology):"
prompts = [
    f"{prompt_template_prefix}{model_input_text}{prompt_template_postfix}" for model_input_text in model_input_texts
]
print(prompts[0])

Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Technology): Sports 
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Technology): World 
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Technology): Business 
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (Wo

In [31]:
prompt_batches = split_prompts_into_batches(prompts, batch_size=1)
predicted_labels = []
n_no_match = 0
for batch_num, prompt_batch in enumerate(prompt_batches):
    generation = model.generate(prompt_batch, short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    # We'll use tokens this time and consider just the first token
    first_predicted_tokens = [tokens[0].strip().lower() for tokens in generation.generation["tokens"]]
    # If a token doesn't correspond to one of our labels, we'll randomly select one and count how many times that
    # happens for reporting
    for potential_prediction in first_predicted_tokens:
        if potential_prediction in lowercase_labels:
            predicted_labels.append(potential_prediction)
        else:
            n_no_match += 1
            predicted_labels.append(random.choice(lowercase_labels))
print(f"A total of {n_no_match} of {len(predicted_labels)} did not match any label")

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete
Batch number 21 Complete
Batch number 22 Complete
Batch number 23 Complete
Batch number 24 Complete
Batch number 25 Complete
Batch number 26 Complete
Batch number 27 Complete
Batch number 28 Complete
Batch number 29 Complete
Batch number 30 Complete
Batch number 31 Complete
Batch number 32 Complete
Batch number 33 Complete
Batch number 34 Complete
Batch number 35 Complete
Batch number 36 Complete
Batch number 37 Complete
Batch number 38 Complete
Batch number 39 Complete
Batch number 40 Complete
Batch num

In [32]:
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Prediction Accuracy: 0.61
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[11  2 12  3]
 [ 0 18  3  0]
 [ 2  0 16  5]
 [ 1  0 11 16]]
Label: world, F1: 0.5238095238095237, Recall: 0.39285714285714285, Precision: 0.7857142857142857
Label: sports, F1: 0.8780487804878048, Recall: 0.8571428571428571, Precision: 0.9
Label: business, F1: 0.4923076923076923, Recall: 0.6956521739130435, Precision: 0.38095238095238093
Label: technology, F1: 0.6153846153846153, Recall: 0.5714285714285714, Precision: 0.6666666666666666


# Prediction By Label Likelihood Extraction

Nothing stops the model from generating tokens outside our label set, and when that happens we fall back to a random guess. So can we do better? We can work around this by looking directly at the likelihood the model assigns to each of our labels.

In [33]:
# We're interested in the activations from the last layer of the model, because this will allow us to calculate the
# likelihoods
last_layer_name = model.module_names[-1]
last_layer_name

'decoder.output_projection'

The last layer of the model produces, for every input position, a vector of unnormalized scores (logits) over the model vocabulary. Applying a softmax to these scores yields the conditional probability
$$
P(y_t \vert y_{<t}, x),
$$
that is, the probability distribution over the vocabulary for the next token $y_t$, given the preceding tokens $y_{<t}$ and the prompt text $x$. Thus, for each token in our input, we get back a 50,272-dimensional vector (OPT's output vocabulary size) scoring the token that comes next. We only care about the last token in our input, as its vector holds the scores for the as-yet-unseen token the model will generate. Since the softmax preserves ordering, picking the label with the highest logit gives the same answer as picking the one with the highest probability.
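Here is a minimal NumPy sketch of the step just described. It assumes the returned matrix holds raw logits with one row per input token, which we believe is what OPT's `decoder.output_projection` produces. A softmax over the final row turns it into the next-token distribution. The random matrix stands in for real activations and uses the vocabulary width printed later in this notebook.

```python
import numpy as np

VOCAB_SIZE = 50272  # width of the last-layer output for OPT, as shown in the activation shape


def next_token_distribution(last_layer_matrix: np.ndarray) -> np.ndarray:
    """Softmax over the final row: P(next token | all input tokens) for every vocabulary entry."""
    last_row = last_layer_matrix[-1].astype(np.float64)
    shifted = last_row - last_row.max()  # subtract the max for numerical stability
    probs = np.exp(shifted)
    return probs / probs.sum()


# Stand-in for a real (n_input_tokens x vocab_size) activation matrix
rng = np.random.default_rng(0)
fake_logits = rng.standard_normal((60, VOCAB_SIZE)).astype(np.float32)

probs = next_token_distribution(fake_logits)
print(probs.shape)                   # (50272,)
print(round(float(probs.sum()), 6))  # 1.0
# Softmax is monotonic, so the top token is the same whether we rank logits or probabilities
print(int(probs.argmax()) == int(fake_logits[-1].argmax()))  # True
```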


We need to instantiate a tokenizer to obtain the appropriate token indices for our labels.

__NOTE__: All OPT models, regardless of size, use the same tokenizer. If you want to use a different type of model, however, a different tokenizer may be needed.

In [34]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
label_token_ids = get_label_token_ids(tokenizer, prompt_template, label_words)
# If you ever need to move back from token ids, you can use tokenizer.decode or tokenizer.batch_decode
tokenizer.decode(label_token_ids)

' World Sports Business Technology'

We need the token ids of our labels to extract their probabilities from the model's output. Each token id is the index of that token in the model's vocabulary, and hence the column of the last-layer output that holds its score.
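Each label should map to exactly one vocabulary id. In GPT-2-style byte-level BPE, which OPT uses, a word preceded by a space is a different token from the bare word. The decoded string above, `' World Sports Business Technology'`, is consistent with each label being a single space-prefixed token. Below is a minimal sketch of that check, taking any `encode` function. `single_token_label_ids` and the toy vocabulary are hypothetical, and `utils.get_label_token_ids` is not necessarily implemented this way.

```python
from typing import Callable, Dict, List, Sequence


def single_token_label_ids(
    encode: Callable[[str], List[int]], label_words: Sequence[str]
) -> Dict[str, int]:
    """Return one vocabulary id per label, encoding each label with a leading space."""
    ids: Dict[str, int] = {}
    for word in label_words:
        # After "...the category is", the model would emit the space-prefixed form, e.g. " World"
        token_ids = encode(f" {word}")
        if len(token_ids) != 1:
            raise ValueError(f"Label {word!r} is split into {len(token_ids)} tokens: {token_ids}")
        ids[word] = token_ids[0]
    return ids


TOY_VOCAB = {" World": 10, " Sports": 11, " Business": 12, " Technology": 13}


def toy_encode(text: str) -> List[int]:
    # Pretend anything outside the toy vocabulary splits into two sub-word tokens
    return [TOY_VOCAB[text]] if text in TOY_VOCAB else [0, 1]


ids = single_token_label_ids(toy_encode, ["World", "Sports", "Business", "Technology"])
print(ids)  # {'World': 10, 'Sports': 11, 'Business': 12, 'Technology': 13}
```

With the notebook's tokenizer, you could pass `lambda text: tokenizer.encode(text, add_special_tokens=False)` as `encode`. Setting `add_special_tokens=False` keeps the OPT tokenizer from prepending its `</s>` start token to each label.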

Let's look at how we can extract the likelihoods given the label tokens.

In [35]:
single_prompted_input = f"{model_input_texts[0]} {prompt_template}"
# Request the last-layer activations for a single zero-shot prompt
activations = model.get_activations(single_prompted_input, [last_layer_name], short_generation_config)

In [36]:
last_layer_matrix = activations.activations[0][last_layer_name]
# The shape of this tensor should be number of input tokens by the vocabulary size (n x 50272)
print(f"Activations matrix shape: {last_layer_matrix.shape}")
predicted_label = get_label_with_highest_likelihood(
    last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
)
print(f"Predicted Label: {predicted_label}")

Activations matrix shape: torch.Size([60, 50272])
Predicted Label: business


## Accuracy

Time to compare our results across two settings:
1. Measure the accuracy of our likelihood approach without demonstrations (i.e. zero-shot).
2. Measure the accuracy of our likelihood approach with few-shot demonstrations.

### Likelihood Zero-shot

In this example, we do not incorporate any demonstrations into the prompt (a zero-shot prompt). As we saw above, the model does a poor job of generating responses that fall within our label space. So rather than trying to string-match responses to our labels, we extract the probabilities the model assigns to each label word (as in the example above) and select the label with the highest probability as the prediction.
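Below is a minimal sketch of the selection step. It is written from the description above rather than from `utils.get_label_with_highest_likelihood`, whose source, including what its `right_shift` flag does, isn't shown here. The function reads the final row of the last-layer matrix, keeps only the label token ids, and returns the highest-scoring label. The toy matrix and token ids are made up.

```python
from typing import Dict, Sequence, Tuple

import numpy as np


def label_likelihoods(
    last_layer_matrix: np.ndarray, label_token_ids: Sequence[int], labels: Sequence[str]
) -> Tuple[str, Dict[str, float]]:
    """Return the best label and label probabilities renormalized over the label tokens only."""
    final_logits = last_layer_matrix[-1].astype(np.float64)  # scores for the next token
    label_logits = final_logits[list(label_token_ids)]
    # Softmax over just the label tokens: ranks labels the same way as the full-vocabulary softmax
    weights = np.exp(label_logits - label_logits.max())
    probs = weights / weights.sum()
    best = int(np.argmax(probs))
    return labels[best], dict(zip(labels, probs.tolist()))


# Toy 3-position x 20-entry "vocabulary" matrix; label tokens are hypothetical ids 10-13
matrix = np.zeros((3, 20))
matrix[0, 10] = 50.0  # earlier positions are ignored: only the last row predicts the next token
matrix[-1, 5] = 10.0  # the single most likely token overall is not a label...
matrix[-1, [10, 11, 12, 13]] = [1.0, 0.5, 3.0, 2.0]  # ...but among the labels, id 12 wins

label, label_probs = label_likelihoods(matrix, [10, 11, 12, 13], ["world", "sports", "business", "technology"])
print(label)  # business
```

Because the prediction is always one of the four labels, this approach never needs the random fallback used in the exact-match loop.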

In [37]:
prompt_template = "From World, Sports, Business, Technology, the category is"
prompts = [f"{model_input_text}{prompt_template}" for model_input_text in model_input_texts]
print(prompts[0])

Title: Telecom lifts first quarter net profit 19pc Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to  $193 million. The profit bettered analysts #39; average forecasts of  $185m.From World, Sports, Business, Technology, the category is


In [38]:
predicted_labels = []
prompt_batches = split_prompts_into_batches(prompts, batch_size=1)
for batch_num, prompt_batch in enumerate(prompt_batches):
    activations = model.get_activations(prompt_batch, [last_layer_name], short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    for activations_single_prompt in activations.activations:
        last_layer_matrix = activations_single_prompt[last_layer_name]
        predicted_label = get_label_with_highest_likelihood(
            last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
        )
        predicted_labels.append(predicted_label)

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete
Batch number 21 Complete
Batch number 22 Complete
Batch number 23 Complete
Batch number 24 Complete
Batch number 25 Complete
Batch number 26 Complete
Batch number 27 Complete
Batch number 28 Complete
Batch number 29 Complete
Batch number 30 Complete
Batch number 31 Complete
Batch number 32 Complete
Batch number 33 Complete
Batch number 34 Complete
Batch number 35 Complete
Batch number 36 Complete
Batch number 37 Complete
Batch number 38 Complete
Batch number 39 Complete
Batch number 40 Complete
Batch num

In [39]:
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Prediction Accuracy: 0.44
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[25  1  2  0]
 [15  2  4  0]
 [ 7  0 16  0]
 [12  0 15  1]]
Label: world, F1: 0.5747126436781609, Recall: 0.8928571428571429, Precision: 0.423728813559322
Label: sports, F1: 0.16666666666666666, Recall: 0.09523809523809523, Precision: 0.6666666666666666
Label: business, F1: 0.5333333333333333, Recall: 0.6956521739130435, Precision: 0.43243243243243246
Label: technology, F1: 0.0689655172413793, Recall: 0.03571428571428571, Precision: 1.0


### Likelihood with Few-Shot

The zero-shot prompt combined with likelihood estimation over our label space doesn't do a great job, but it is considerably better than exact-matching the generation (0.44 vs. 0.27 accuracy). Let's combine the two approaches: we'll use the same 5-shot prompt as in the exact-match example above, but now predict the label with the highest likelihood rather than exact-matching the first generated token.

In [40]:
prompt_template_prefix = """Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Technology): Sports 
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Technology): World 
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Technology): Business 
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (World, Sports, Business, Technology): Business 
Title: UK Scientists Allowed to Clone Human Embryos (Reuters) Description: Reuters - British scientists said on Wednesday they had received permission to clone human embryos for medical research, in what they believe to be the first such license to be granted in Europe. Category (World, Sports, Business, Technology): Technology 
"""  # noqa
prompt_template_postfix = "Category (World, Sports, Business, Technology):"
prompts = [
    f"{prompt_template_prefix}{model_input_text}{prompt_template_postfix}" for model_input_text in model_input_texts
]
print(prompts[0])

Title: Lane drives in winning run in ninth Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night. Category (World, Sports, Business, Technology): Sports 
Title: Arson attack on Jewish centre in Paris (AFP) Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said. Category (World, Sports, Business, Technology): World 
Title: Oil prices look set to dominate Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue. Category (World, Sports, Business, Technology): Business 
Title: Indexes in Japan fall short of hype Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market. Category (Wo

In [41]:
predicted_labels = []
prompt_batches = split_prompts_into_batches(prompts, batch_size=1)
for batch_num, prompt_batch in enumerate(prompt_batches):
    activations = model.get_activations(prompt_batch, [last_layer_name], short_generation_config)
    print(f"Batch number {batch_num+1} Complete")
    for activations_single_prompt in activations.activations:
        last_layer_matrix = activations_single_prompt[last_layer_name]
        predicted_label = get_label_with_highest_likelihood(
            last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
        )
        predicted_labels.append(predicted_label)

Batch number 1 Complete
Batch number 2 Complete
Batch number 3 Complete
Batch number 4 Complete
Batch number 5 Complete
Batch number 6 Complete
Batch number 7 Complete
Batch number 8 Complete
Batch number 9 Complete
Batch number 10 Complete
Batch number 11 Complete
Batch number 12 Complete
Batch number 13 Complete
Batch number 14 Complete
Batch number 15 Complete
Batch number 16 Complete
Batch number 17 Complete
Batch number 18 Complete
Batch number 19 Complete
Batch number 20 Complete
Batch number 21 Complete
Batch number 22 Complete
Batch number 23 Complete
Batch number 24 Complete
Batch number 25 Complete
Batch number 26 Complete
Batch number 27 Complete
Batch number 28 Complete
Batch number 29 Complete
Batch number 30 Complete
Batch number 31 Complete
Batch number 32 Complete
Batch number 33 Complete
Batch number 34 Complete
Batch number 35 Complete
Batch number 36 Complete
Batch number 37 Complete
Batch number 38 Complete
Batch number 39 Complete
Batch number 40 Complete
Batch num

In [42]:
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Prediction Accuracy: 0.74
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[15  1  9  3]
 [ 0 20  1  0]
 [ 0  0 22  1]
 [ 0  0 11 17]]
Label: world, F1: 0.6976744186046512, Recall: 0.5357142857142857, Precision: 1.0
Label: sports, F1: 0.9523809523809523, Recall: 0.9523809523809523, Precision: 0.9523809523809523
Label: business, F1: 0.6666666666666667, Recall: 0.9565217391304348, Precision: 0.5116279069767442
Label: technology, F1: 0.6938775510204083, Recall: 0.6071428571428571, Precision: 0.8095238095238095
