In [50]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Tuple
import datasets
import numpy as np
import matplotlib.pyplot as plt
import re
import json

CUDA_JOB_NUM = 3
# Device string used below for model placement and for moving input tensors.
cuda_model = "cuda:%d" % CUDA_JOB_NUM

In [51]:
# Load the MS MARCO DatasetDict stored locally under ./ms_marco (a directory in
# `save_to_disk` format); the bare `ds` displays the splits and their sizes.
ds = datasets.load_from_disk("ms_marco")
ds

DatasetDict({
    validation: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 101093
    })
    train: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 808731
    })
    test: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 101092
    })
})

In [2]:
# Llama 3.1 8B Instruct in bfloat16; device_map=cuda_model puts the whole model on that single GPU.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map=cuda_model)

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.09it/s]


In [159]:
# extract the function call results from the output text
def extract_parameters(output_text):
    """Extract the ``rating`` list from a ``save_ranking_rating`` tool call.

    The model is expected to answer with a JSON function call such as
    ``{"name": "save_ranking_rating", "parameters": {"rating": [80, 20, 10]}}``.
    The list may also be JSON-encoded as a string
    (``"rating": "[80, 20, 10]"``) and its items may be quoted (``["80"]``);
    all of these forms are accepted.

    Args:
        output_text: Decoded model output; surrounding text (special tokens,
            prompt) is allowed, the first matching call is used.

    Returns:
        A list of floats (one per passage), or ``None`` when no call is found
        or the list is empty / contains a non-numeric item.
    """
    # Optional whitespace is tolerated before the two closing braces.
    pattern = r'\{"name":\s*"(.*?)",\s*"parameters":\s*\{"rating":\s*(?:"\[(.*?)\]"|\[(.*?)\])\s*\}\s*\}'

    match = re.search(pattern, output_text, re.DOTALL)
    if match is None:
        return None
    # group 2: string-encoded list, group 3: plain JSON list
    raw_list = match.group(2) if match.group(2) is not None else match.group(3)
    try:
        # strip surrounding quotes so items like "80" parse as numbers too
        return [float(item.strip().strip('"\'')) for item in raw_list.split(',')]
    except ValueError:
        # e.g. "[]" -> float('') or a non-numeric item; report as "no rating"
        # instead of raising (which would abort a whole batch run).
        return None

# Tool definition for Llama function calling. predict_row passes this function as
# `tools=[save_ranking_rating]` to `tokenizer.apply_chat_template`, which turns the
# signature (`rating: List[int]`) and the docstring below into the JSON tool
# description shown to the model.
# NOTE: the docstring is therefore part of the prompt - editing it changes the
# model's input. Its "in the example for ranking" wording is a leftover from an
# earlier version of the tool that also took a `ranking` argument.
def save_ranking_rating(rating: List[int]) -> None:
  """
  Save the rating of the passages according to relevance to the query (between 0 and 100).
  
  Args:
    rating: A list of the rating of the passage based on the relevance to the query (between 0 and 100). Providing [100, 50, 20] in the example for ranking means that passage 1 has a rating of 100, passage 2 has a rating of 50 and passage 3 has a rating of 20.
  Returns:
    None
  """
  # Not called anywhere in the visible notebook; the tool call is parsed from text instead.
  print(rating)


In [175]:
def rank_score(preds, labels):
    """Average 0-based rank that the predicted ratings give to the gold passages.

    For every row, passages are ordered by descending predicted rating (rank 0 =
    highest rating) and the ranks of the passages labelled 1 are averaged; the
    result is the mean of those per-row averages. Lower is better.

    Rows are skipped, and counted in the printed "Skipped N rows" message, when
    the prediction is missing (``None`` or the 0-d array ``np.array(None)``) or
    its length differs from the labels (the model rated too few/many passages).
    Rows with no label equal to 1 are ignored without being counted.

    Args:
        preds: Sequence of per-row ratings (lists or 1-d arrays; may contain None).
        labels: Sequence of per-row 0/1 selection flags (lists or arrays),
            index-aligned with ``preds``.

    Returns:
        The mean rank, or ``nan`` when no row could be scored.
    """
    average_rank = 0
    count = 0
    skipped = 0
    for i, preds_row in enumerate(preds):
        # predict_batch yields None for unparsable rows; np.array(None) is 0-d.
        if preds_row is None or np.ndim(preds_row) == 0:
            skipped += 1
            continue
        preds_row = np.asarray(preds_row, dtype=float)
        # asarray so that `== 1` below is elementwise even for plain-list labels
        labels_row = np.asarray(labels[i])
        if len(preds_row) != len(labels_row):
            skipped += 1
            continue

        # argsort of the descending argsort = rank of each passage (0 = best)
        preds_sorted_rank = np.argsort(np.argsort(preds_row)[::-1])
        gold_ranks = preds_sorted_rank[labels_row == 1]
        if len(gold_ranks) > 0:
            average_rank += np.mean(gold_ranks)
            count += 1
    print(f"Skipped {skipped} rows")
    return average_rank / count if count else float("nan")

In [179]:
def strict_rank_score(preds, labels, threshold = 1):
    """Fraction of rows whose gold passages' mean rank is below ``threshold``.

    Ranks are 0-based under descending predicted rating, so ``threshold=1``
    means "every gold passage was ranked first", ``threshold=2`` "within the
    top two on average", and so on.

    Rows are skipped, and counted in the printed "Skipped N rows" message, when
    the prediction is missing (``None`` or ``np.array(None)``) or its length
    differs from the labels. Rows with no label equal to 1 are ignored.

    Args:
        preds: Sequence of per-row ratings (lists or 1-d arrays; may contain None).
        labels: Sequence of per-row 0/1 selection flags, index-aligned with ``preds``.
        threshold: Exclusive upper bound on a row's mean gold rank to count as a hit.

    Returns:
        Hits divided by scored rows, or ``nan`` when no row could be scored.
    """
    num_in_ones = 0
    count = 0
    skipped = 0
    for i, preds_row in enumerate(preds):
        # predict_batch yields None for unparsable rows; np.array(None) is 0-d.
        if preds_row is None or np.ndim(preds_row) == 0:
            skipped += 1
            continue
        preds_row = np.asarray(preds_row, dtype=float)
        # asarray so that `== 1` below is elementwise even for plain-list labels
        labels_row = np.asarray(labels[i])
        if len(preds_row) != len(labels_row):
            skipped += 1
            continue

        # argsort of the descending argsort = rank of each passage (0 = best)
        preds_sorted_rank = np.argsort(np.argsort(preds_row)[::-1])
        gold_ranks = preds_sorted_rank[labels_row == 1]
        if len(gold_ranks) > 0:
            if np.mean(gold_ranks) < threshold:
                num_in_ones += 1
            count += 1
    print(f"Skipped {skipped} rows")
    return num_in_ones / count if count else float("nan")

In [151]:
def predict_row(data_row, tokenizer, model):
    """Ask the model to rate every passage of one MS MARCO row for relevance.

    The row's passages are listed as "1) ...", "2) ..." in a chat prompt that
    exposes ``save_ranking_rating`` as a tool; the generated completion is then
    parsed with ``extract_parameters``.

    Args:
        data_row: One dataset row with a "query" string and a "passages" dict
            holding parallel "passage_text" and "is_selected" lists.
        tokenizer: Tokenizer whose ``apply_chat_template`` accepts ``tools=``.
        model: Causal LM used for generation; inputs are moved to ``model.device``.

    Returns:
        ``(output_text, rating)``: ``output_text`` is the full decoded sequence
        (prompt + completion, special tokens included) and ``rating`` is the
        list of floats parsed from the completion only, or ``None``.
    """
    passages = data_row["passages"]
    passages_string = "".join(
        f"{i+1}) {passages['passage_text'][i]}\n" for i in range(len(passages['is_selected']))
    )
    query = data_row["query"]
    text = f"""Here is a query: {query}
Here are possible passages:
{passages_string}

Rate ALL the passages according to relevance to the query (even if they are unrelated). Only give the rating, do not provide any other text or explanations. Give the rating in the following format: rating = [a, b, c] if passage 1 has a rating of a, passage 2 has a rating of b and passage 3 has a rating of c. The scores should be between 0 and 100, and the most relevant passage does not need to have a perfect score. Call the function save_ranking_rating with the rating as the argument."""
    messages = [
      {"role": "system", "content": "You are a helpful assistant that rates the relevance of passages to a query."},
      {"role": "user", "content": text}
    ]
    # `encoded` rather than `input`, to avoid shadowing the builtin.
    encoded = tokenizer.apply_chat_template(messages, return_tensors="pt", tools=[save_ranking_rating], tokenize=True, add_generation_prompt=True, return_dict=True)
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    # NOTE(review): do_sample is not passed, so whether sampling happens follows the
    # model's generation_config; temperature=0.01 only matters if it is enabled there.
    # Consider do_sample=False for fully deterministic ratings.
    output_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=80, temperature=0.01, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id, return_dict_in_generate=True, no_repeat_ngram_size=10)
    sequence = output_ids.sequences[0]
    output_text = tokenizer.decode(sequence)
    # Parse only the generated tokens so that text in the prompt (tool schema,
    # format instructions, passages) can never be taken for the model's call.
    completion = tokenizer.decode(sequence[input_ids.shape[-1]:])
    return output_text, extract_parameters(completion)

# batch_data = ds['train'].select(range(100))
# row = batch_data[28]
# predict_row(row, tokenizer, model)

In [163]:
def predict_batch(batch_data, tokenizer, model):
    """Run ``predict_row`` on every row of ``batch_data`` and collect the ratings.

    Failures are handled per row: an unparsable completion is printed, and an
    exception raised for one row is reported without stopping the batch. In
    both cases ``None`` is stored, so the result stays index-aligned with labels
    built from the same rows.

    Returns:
        A list with exactly one entry per row, in order: the parsed rating list
        or ``None``.
    """
    ratings = []
    for i in range(len(batch_data)):
        try:
            output_text, response = predict_row(batch_data[i], tokenizer, model)
        except Exception as e:
            print(f"Row {i} failed: {e!r}")
            ratings.append(None)
            continue
        if response is None:
            print("None response")
            print(output_text)
        ratings.append(response)
    return ratings

In [181]:
# Rate the passages of the first 300 training rows (one generate() call per row).
ratings = predict_batch(ds['train'].select(range(300)), tokenizer, model)

In [183]:
# One ndarray per row for the rank metrics; a failed (None) row becomes a 0-d object array.
ratings = list(map(np.array, ratings))
ratings

[array([80., 20., 10., 30.,  5., 40.,  0., 60., 70., 10.]),
 array([90., 70., 60., 80., 40., 50., 30., 20., 10.,  0.]),
 array([  0.,   0., 100.,   0.,  80.,  90., 100.,   0.,  90.,   0.]),
 array([100.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 100.]),
 array([ 0.,  0.,  0.,  0., 10.,  0.,  0.,  0.,  0.,  0.]),
 array([90., 80., 10.,  0., 70., 60., 50., 40.,  0., 30.]),
 array([100.,  80.,  70.,  60.,  50.,  40.,  30.,  20.,  10.,   0.]),
 array([80., 70., 60., 50., 40., 30., 20., 10.,  0., 90.]),
 array([ 0.,  0.,  0.,  0.,  5.,  0., 80., 90., 95., 70.]),
 array([80., 20., 10., 70., 60., 20., 40., 10.,  0.,  0.]),
 array([80., 70., 60., 80., 90., 85., 80., 95., 80., 10.]),
 array([100.,  80.,  70.,  90.,  95.,  60.,  85.,  98.,  40.,  92.]),
 array([80., 20., 10., 90., 10.,  5.,  5., 95.,  5.,  0.]),
 array([80., 20., 10., 90., 10., 90., 10., 10., 10., 70.]),
 array([ 0.,  0.,  0.,  0., 10.,  0.,  0.,  0.,  0., 90.]),
 array([ 80.,   0.,   0.,   0.,   0.,   0., 100.,   0.,   0.

In [184]:
# Gold is_selected flags for the same 300 training rows that were rated above.
labels = [np.array(example["passages"]["is_selected"]) for example in ds['train'].select(range(300))]

In [185]:
# Mean 0-based rank of the gold passage(s) under the predicted ratings (lower is better).
rank_score(ratings, labels)

Skipped 25 rows


np.float64(3.699233716475096)

In [196]:
# Share of rows whose gold passages have a mean rank below 1, 2, 3 and 4.
tuple(strict_rank_score(ratings, labels, threshold=t) for t in (1, 2, 3, 4))

Skipped 25 rows
Skipped 25 rows
Skipped 25 rows
Skipped 25 rows


(0.16091954022988506,
 0.3160919540229885,
 0.43103448275862066,
 0.5114942528735632)

In [182]:
# After the conversion cell `ratings` holds numpy arrays, which json cannot encode;
# .tolist() turns each row back into plain floats (and a 0-d None array into null).
with open("ratings300.json", "w") as f:
    json.dump([np.asarray(rating).tolist() for rating in ratings], f)

In [155]:
# Plain-text dump of every rating row.
print(ratings)

[[80, 20, 10, 30, 5, 40, 0, 60, 70, 10], [90, 70, 60, 80, 40, 50, 30, 20, 10, 0], [0, 0, 100, 0, 80, 90, 100, 0, 90, 0], [100, 0, 0, 0, 0, 0, 0, 0, 0, 100], [0, 0, 0, 0, 10, 0, 0, 0, 0, 0], [90, 80, 10, 0, 70, 60, 50, 40, 0, 30], [100, 80, 70, 60, 50, 40, 30, 20, 10, 0], [80, 70, 60, 50, 40, 30, 20, 10, 0, 90], [0, 0, 0, 0, 5, 0, 80, 90, 95, 70], [80, 20, 10, 70, 60, 20, 50, 10, 0, 0], [80, 70, 60, 80, 90, 85, 80, 95, 80, 10], [100, 80, 70, 90, 95, 60, 85, 98, 40, 92], [80, 20, 10, 90, 10, 5, 5, 95, 5, 0], [80, 20, 10, 90, 10, 90, 10, 10, 10, 70], [0, 0, 0, 0, 10, 0, 0, 0, 0, 80], [80, 20, 0, 0, 0, 90, 90, 0, 20, 0], [100, 0, 0, 0, 0], [0, 0, 0, 0, 20, 0, 0, 0, 0, 0], [95, 5, 5, 5, 90, 5, 5, 5, 5, 80], [100, 20, 0, 0, 0, 0], [0, 0, 0, 0, 20, 0, 0, 0], [100, 90, 20, 80, 10, 70, 60, 90, 85, 0], [100, 80, 20, 60, 40, 90, 70, 30, 50, 10], [100, 0, 0, 0, 0, 0, 0, 0, 10, 0], [100, 80, 70, 60, 90, 80, 70, 60, 90, 80], [100, 0, 90, 80, 0, 10, 0, 100, 100, 0], [100, 80, 70, 60, 50, 40, 30, 20, 

In [143]:
# Inspect one validation row: answers, candidate passages and their is_selected flags.
ds['validation'][6]

{'answers': ['Globally 8,640,000 lightning strikes per day.'],
 'passages': {'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
  'passage_text': ['Lightning is a major cause of storm related deaths in the U.S. A lightning strike can result in a cardiac arrest (heart stopping) at the time of the injury, although some victims may appear to have a delayed death a few days later if they are resuscitated but have suffered irreversible brain damage.',
   'Quick Answer. Lightning strikes reach the ground on Earth as much as 8 million times per day or 100 times per second, according to the National Severe Storms Laboratory. Out of all the lightning strikes in the world, the United States accounts for about 20 million of the total number of lightning strikes each year. Keep Learning.',
   'An average lightning strike discharges about 30,000 amperes (20,000 amperes in the UK). The current in a lightning strike typically ranges from 5,000 to 50,000 amperes depending on the strength of storm. NASA ha

In [144]:
# Rate the first 10 validation rows. predict_row returns (output_text, rating);
# only the parsed rating is kept so that ratings[i] lines up with batch_data[i].
# NOTE: this rebinds `ratings`, replacing the training-set ratings computed above.
batch_data = ds['validation'].select(range(10))
ratings = []
for i in range(len(batch_data)):
    row = batch_data[i]
    ratings.append(predict_row(row, tokenizer, model)[1])

Here is a query: . what is a corporation?
Here are possible passages:
1) A company is incorporated in a specific nation, often within the bounds of a smaller subset of that nation, such as a state or province. The corporation is then governed by the laws of incorporation in that state. A corporation may issue stock, either private or public, or may be classified as a non-stock corporation. If stock is issued, the corporation will usually be governed by its shareholders, either directly or indirectly.
2) Today, there is a growing community of more than 2,100 Certified B Corps from 50 countries and over 130 industries working together toward 1 unifying goal: to redefine success in business. Join the Movement
3) Corporation definition, an association of individuals, created by law or under authority of law, having a continuous existence independent of the existences of its members, and powers and liabilities distinct from those of its members. See more.
4) Examples of corporation in a Sen

In [145]:
# For each validation row show: query, gold flags, model rating, passage texts.
for i in range(len(batch_data)):
    row = batch_data[i]
    print(row["query"])
    print(row["passages"]['is_selected'])
    print(ratings[i])
    print(row["passages"]['passage_text'])
    print("\n")

. what is a corporation?
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
[80, 20, 90, 70, 60, 95, 85, 10, 40, 30]
['A company is incorporated in a specific nation, often within the bounds of a smaller subset of that nation, such as a state or province. The corporation is then governed by the laws of incorporation in that state. A corporation may issue stock, either private or public, or may be classified as a non-stock corporation. If stock is issued, the corporation will usually be governed by its shareholders, either directly or indirectly.', 'Today, there is a growing community of more than 2,100 Certified B Corps from 50 countries and over 130 industries working together toward 1 unifying goal: to redefine success in business. Join the Movement', 'Corporation definition, an association of individuals, created by law or under authority of law, having a continuous existence independent of the existences of its members, and powers and liabilities distinct from those of its members. See more.', 'Exampl

In [26]:
# Stand-alone prompt experiment on a hand-written query.
# NOTE(review): this prompt asks for both a ranking and a rating, but the current
# save_ranking_rating tool only takes `rating`; the prompt looks like a leftover
# from the earlier two-argument version of the tool.
text = """Here is a query: What was the impact of the manhattan project?
Here are possible passages:
1) Japan doesn't have nuclear weapons.
2) The Manhattan Project was a research and development undertaking during World War II that produced the first atomic bombs.
3) The Manhattan project was held by the US government.

Rank and rate the passages according to relevance to the query. Only give the ranking and rating, do not provide any other text or explanations. Give the ranking in descending order of relevance in the following format: ranking = [a, b, c] and rating = [d, e, f] if passage number a is most relevant with a score of d, followed by passage b with a score of e and then passage c with a score of f. The scores should be between 0 and 100, and the most relevant passage does not need to have a perfect score. Call the function save_ranking_rating with the ranking and rating as the argument."""

messages = [
  {"role": "system", "content": "You are a helpful assistant that ranks the relevance of passages to a query."},
  {"role": "user", "content": text}
]

# `encoded` rather than `input`, so the notebook does not shadow the builtin.
encoded = tokenizer.apply_chat_template(messages, return_tensors="pt", tools=[save_ranking_rating], tokenize=True, add_generation_prompt=True, return_assistant_tokens_mask=True, return_dict=True)
input_ids = encoded.input_ids.to(cuda_model)
attention_mask = encoded.attention_mask.to(cuda_model)
# output_scores=True keeps per-step scores, which the compute_transition_scores cell reads.
output_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=80, temperature=0.01, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id, output_scores=True, return_dict_in_generate=True, no_repeat_ngram_size=5)

In [27]:
output_text = tokenizer.decode(output_ids.sequences.squeeze())
# Offset of the model's tool call within the generated tokens only (-1 if none).
# Searching the full text would hit the prompt's own '{"name": function name, ...'
# format instructions first.
generated_text = tokenizer.decode(output_ids.sequences[0, input_ids.shape[-1]:])
function_call = generated_text.find('{"name":')
print(output_text)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant that ranks the relevance of passages to a query.<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.

Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.

{
    "type": "function",
    "function": {
        "name": "save_ranking_rating",
        "description": "Save the ranking of the passages according to relevance to the query as well as a rating of the relevance of the ranking (between 0 and 100).",
        "parameters": {
            "type": "object",
            "properties": {
                "ranking": {
                    "type": "array",
                    "items": {
                        "type"

In [11]:
# Re-display the decoded prompt + completion of the last experiment.
print(output_text)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant that ranks the relevance of passages to a query.<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.

Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.

{
    "type": "function",
    "function": {
        "name": "save_ranking",
        "description": "Save the ranking of the passages according to relevance to the query as well as a rating of the relevance of the ranking (between 0 and 100).",
        "parameters": {
            "type": "object",
            "properties": {
                "ranking": {
                    "type": "array",
                    "items": {
                        "type": "arra

In [98]:
# Per-token log-probabilities of the generated tokens (normalize_logits=True
# log-softmaxes the recorded scores). Named fields are used instead of positional
# indexing into the generate() output, whose positions depend on which optional
# outputs were requested.
# NOTE(review): the recorded scores are the *processed* scores; with
# temperature=0.01 they are presumably near one-hot, which would explain the
# all-zero log-probs shown below — TODO confirm.
transition_scores = model.compute_transition_scores(output_ids.sequences, output_ids.scores, normalize_logits=True)


In [95]:
# Scores for the first (only) sequence in the batch, one value per generated token.
transition_scores[0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:3')

In [88]:
# Token ids of the first sequence; output_ids[0] is presumably the `sequences` field
# of the generate() output (positional indexing skips fields that are None).
output_ids[0][0]

tensor([128000, 128006,   9125, 128007,    271,  13013,     25,   6125,  27993,
           198,  38766,   1303,  33025,   2696,     25,   6790,    220,   2366,
            18,    198,  15724,   2696,     25,    220,   1627,  10263,    220,
          2366,     19,    271,   2675,    527,    264,  11190,  18328,    430,
         21467,    279,  41961,    315,  47869,    311,    264,   3319,     13,
        128009, 128006,    882, 128007,    271,  22818,    279,   2768,   5865,
            11,   4587,   6013,    449,    264,   4823,    369,    264,    734,
          1650,    449,   1202,   6300,   6105,    430,   1888,  11503,    279,
          2728,  10137,    382,  66454,    304,    279,   3645,   5324,    609,
           794,    734,    836,     11,    330,  14105,    794,  11240,    315,
          5811,    836,    323,   1202,    907,   7966,   5519,    539,   1005,
          7482,    382,    517,    262,    330,   1337,    794,    330,   1723,
           761,    262,    330,   1723, 

In [92]:
# NOTE(review): positional field 2 of the generate() output depends on which
# optional outputs were requested; the shape shown looks like a cached key tensor
# (batch, kv_heads, seq_len, head_dim), i.e. past_key_values — TODO confirm.
output_ids[2][-1][0].shape

torch.Size([1, 8, 454, 128])

In [62]:
# (length of the first sequence, the whole first field of the generate() output)
len(output_ids[0][0]), output_ids[0]

(439,
 tensor([[128000, 128006,   9125, 128007,    271,  13013,     25,   6125,  27993,
             198,  38766,   1303,  33025,   2696,     25,   6790,    220,   2366,
              18,    198,  15724,   2696,     25,    220,   1627,  10263,    220,
            2366,     19,    271,   2675,    527,    264,  11190,  18328,    430,
           21467,    279,  41961,    315,  47869,    311,    264,   3319,     13,
          128009, 128006,    882, 128007,    271,  22818,    279,   2768,   5865,
              11,   4587,   6013,    449,    264,   4823,    369,    264,    734,
            1650,    449,   1202,   6300,   6105,    430,   1888,  11503,    279,
            2728,  10137,    382,  66454,    304,    279,   3645,   5324,    609,
             794,    734,    836,     11,    330,  14105,    794,  11240,    315,
            5811,    836,    323,   1202,    907,   7966,   5519,    539,   1005,
            7482,    382,    517,    262,    330,   1337,    794,    330,   1723,
          

In [67]:
# NOTE(review): length of positional field 2. If that field is past_key_values,
# 32 is the number of layers; if it is a per-step tuple (attentions/hidden
# states), it is the number of generated tokens — TODO confirm which.
len(output_ids[2])

32

In [71]:
# Decode the single token 32 positions from the end of the first sequence.
tokenizer.decode(output_ids[0][0][-32])

'<|end_header_id|>'

In [21]:
# Decode the first generated sequence without special tokens. With a
# return_dict_in_generate output, output_ids[0] is the 2-D (batch, seq_len)
# `sequences` tensor, so select the single row explicitly.
# NOTE(review): the output shown below appears to come from an earlier experiment.
tokenizer.decode(output_ids.sequences[0], skip_special_tokens=True)

'Give a short 1 sentence response about who is Tom Cruise? What is his profession? What is his net worth? What is his age? What is his height? What is his weight? What is his eye color? What is his hair color? What is his nationality? What is his ethnicity? What is his religion? What is his sexual orientation? What is his marital status? What is his education? What is his zodiac sign? What is his favorite food? What is his favorite color? What is his favorite movie? What is his favorite book'

In [17]:
# NOTE(review): with the current generate() output this is the batch size of the
# `sequences` tensor (1), not the token count; the 65 shown appears to be from an
# earlier run where output_ids was a plain tensor — use output_ids[0].shape[-1]
# for the sequence length.
len(output_ids[0])

65