In [1]:
#read data from csv
import sys
sys.path.append('..')
import logging
logging.basicConfig(level=logging.WARNING)

from src.llm_alex import Llama
from langchain_core import pydantic_v1
from langchain_core.runnables.base import RunnableParallel, RunnableLambda
from langchain.output_parsers.retry import RetryOutputParser
from langchain_core.output_parsers import JsonOutputParser
from typing import List

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, validator

from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate



import numpy
import pandas as pd
import random 
import tqdm 
import time
import wandb

from sklearn.metrics import classification_report

train_data = pd.read_csv("data/train_data.csv")
test_data = pd.read_csv("data/test_data.csv")

llm = Llama()
llm(query="test")



cuda


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


'It seems like you just typed a random word "test". Is there something specific you\'d like to test or discuss? I\'m here to help with any questions or topics you\'d like to explore!'

In [2]:
class Prediction(BaseModel):
    label: int = Field(description="classification label to a given text")
    @validator("label")
    def label_is_valid(cls, field): 
        print(cls)
        print(field)
        if not (0 <= field <= 77):
            raise ValueError("label must be an integer between 0 and 77")
        return field
parser = PydanticOutputParser(pydantic_object=Prediction)
retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm, max_retries=3)
def get_prompt_length(llm, str_input):
    messages = [
            {"role": "system", "content": llm._system_msg},
            {"role": "user", "content": str_input},
        ]
    prompt = llm.pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
    )
    input_tokens_length = len(llm.pipeline.tokenizer.encode(prompt))
    prompt_length = len(prompt)
    return input_tokens_length,prompt_length

In [51]:
#craft examples
random_seed = 42
n_classes = 2
num_shots = 2
sampled_classes = train_data['label'].sample(n_classes, random_state=random_seed).values
train_data_sub = train_data[train_data['label'].isin(sampled_classes)]
test_data_sub = test_data[test_data['label'].isin(sampled_classes)]


def get_examples(train_data_sub, num_shots):
    examples = []
    for label in train_data_sub.label.unique():
        for i, row in train_data_sub[train_data_sub['label'] == label].sample(num_shots, random_state=random_seed).iterrows():
            examples.append({
                "input": row['text'],
                "output": str("{label:"+str(label)+"}")
            })
            train_data_sub = train_data_sub.drop(i)
    return examples
examples = get_examples(test_data_sub, num_shots)


In [52]:
examples

[{'input': 'Why are you declining my payment? Everything was fine.',
  'output': '{label:25}'},
 {'input': 'I have a card payment that was declined, but why?',
  'output': '{label:25}'},
 {'input': 'I would like to change my pin.', 'output': '{label:21}'},
 {'input': 'How can I change my Tholepin ?', 'output': '{label:21}'}]

In [53]:
example_prompt = PromptTemplate(
    input_variables=["input", "output"], template="input: {input}\noutput: {output}"
)

print(example_prompt.format(**examples[0]))

input: Why are you declining my payment? Everything was fine.
output: {label:25}


## Few-Shot Prompting

In [54]:
few_shot_prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
    suffix="input: {input}",
    input_variables=["output"],
)

print(few_shot_prompt.format(input="Why does my credit card not work?"))

KeyError: 'label'

In [12]:
example_prompt

PromptTemplate(input_variables=['label', 'text'], template='Text: {text}\nLabel: {label}')

In [7]:
parser.get_format_instructions()

'The output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"label": {"title": "Label", "description": "classification label to a given text", "type": "integer"}}, "required": ["label"]}\n```'

In [27]:
text = test_data_sub.iloc[0]['text']
#text = "test"
few_shot_prompt = FewShotPromptTemplate(
    prefix="Your task is to classify a given text by assigning a label it it. \n\nHere are some examples: \n{text}\n",
    examples=examples,
    example_prompt=example_prompt,
    suffix="Text: {text}. Please format your output like this, ONLY answer with the JSON format:",
    input_variables=["text"],
    partial_variables={"format_instructions": "{'label': ['25']}"}
)
print(few_shot_prompt.format(text=text))

completion_chain = few_shot_prompt | llm

main_chain = RunnableParallel(
    completion=completion_chain, prompt_value=few_shot_prompt
) | RunnableLambda(lambda x: retry_parser.parse_with_prompt(**x))


answer: Prediction = main_chain.invoke({"text": text})
input_tokens_length, prompt_length = get_prompt_length(llm, few_shot_prompt.format(text=text))
print(answer)

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


Your task is to classify a given text by assigning a label it it. 

Here are some examples: 
My credit card was declined.


Text: Why are you declining my payment? Everything was fine.
Label: 25

Text: I have a card payment that was declined, but why?
Label: 25

Text: I would like to change my pin.
Label: 21

Text: How can I change my Tholepin ?
Label: 21

Text: My credit card was declined.. Please format your output like this, ONLY answer with the JSON format:


KeyError: "Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"

In [17]:
answer

Prediction(label=25)

In [22]:
len(train_data_sub)

275

In [26]:
import pandas as pd

# Assuming train_data_sub is a pandas DataFrame

# Define the function to process each row
def process_row(row):
    row['prediction'] = main_chain.invoke({"text": row['text']}).label
    
    input_tokens_length, prompt_length = get_prompt_length(llm, few_shot_prompt.format(text=row['text']))
    row['input_tokens_length'] = input_tokens_length
    row['prompt_length'] = prompt_length
    
    return row

# Apply the function to each row
test_data_sub = test_data_sub[:5].apply(process_row, axis=1)


Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


KeyError: "Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"

In [31]:
train_data_sub

Unnamed: 0,text,label,prediction,input_tokens_length,prompt_length
0,I'm not sure why my card didn't work,25,-1,287,1212
1,My card is not working at stores.,25,-1,282,1206
2,Do you know why my card payment has been decli...,25,-1,288,1240
3,Why isn't my card working? I was pumped to use...,25,-1,314,1344
4,I could not use my card in a store.,25,-1,286,1210


In [23]:
for idx,row in train_data_sub[:5].iterrows():
    try:
        row['prediction'] = main_chain.invoke({"text": text}).label
    except e:
        print(e)
        row['prediction'] = -1
    input_tokens_length, prompt_length = get_prompt_length(llm, few_shot_prompt.format(text=text))
    row['input_tokens_length'] = input_tokens_length
    row['prompt_length'] = prompt_length

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


In [25]:
from concurrent.futures import ThreadPoolExecutor, as_completed

def get_prediction_and_lengths(text):
    try:
        result = main_chain.invoke({"text": text})
        label = result.label
    except Exception as e:
        print(e)
        label = -1

    input_tokens_length, prompt_length = get_prompt_length(llm, few_shot_prompt.format(text=text))
    
    return label, input_tokens_length, prompt_length


In [27]:
import pandas as pd

# Assume train_data_sub is your DataFrame and it has a column named 'text'
def process_row(row):
    text = row['text']
    label, input_tokens_length, prompt_length = get_prediction_and_lengths(text)
    return pd.Series({
        "prediction": label,
        "input_tokens_length": input_tokens_length,
        "prompt_length": prompt_length
    })

# Use ThreadPoolExecutor for parallel processing
with ThreadPoolExecutor(max_workers=10) as executor:  # Adjust the number of workers based on your system
    futures = [executor.submit(process_row, row) for idx, row in train_data_sub[:5].iterrows()]

    # Collect the results as they complete
    results = []
    for future in as_completed(futures):
        results.append(future.result())

# Combine results into a new DataFrame
results_df = pd.DataFrame(results)

# Merge the results back to the original DataFrame
train_data_sub = pd.concat([train_data_sub.reset_index(drop=True), results_df], axis=1)


Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


"Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"
"Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"
"Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"
"Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"
"Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"


In [28]:
train_data_sub

Unnamed: 0,text,label,prediction,input_tokens_length,prompt_length
0,I'm not sure why my card didn't work,25,-1.0,288.0,1240.0
1,My card is not working at stores.,25,-1.0,314.0,1344.0
2,Do you know why my card payment has been decli...,25,-1.0,286.0,1210.0
3,Why isn't my card working? I was pumped to use...,25,-1.0,287.0,1212.0
4,I could not use my card in a store.,25,-1.0,282.0,1206.0
...,...,...,...,...,...
270,Please tell me how to change my pin.,21,,,
271,At what ATM can I change my PIN?,21,,,
272,If I am not in the country and I need to chang...,21,,,
273,What do I have to do to change my pin?,21,,,


## Similarity Enhanced Few-Shot Prompting

Idee: wähle few shot examples basierend auf der similarity zur query aus -> reduziere damit kontextlänge.
- FAISS vector storage
- HuggingFace Embeddings

In [6]:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': False}
hf = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)



In [7]:
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    hf,
    FAISS,
    k=2,
)

# Select the most similar example to the input.
text = "Why does my credit card not work?"
selected_examples = example_selector.select_examples({"text": text})
print(f"Examples most similar to the input: {text}")
for example in selected_examples:
    print("\n")
    for k, v in example.items():
        print(f"{k}: {v}")

Examples most similar to the input: Why does my credit card not work?


text: Can you tell me why my card keeps getting declined every time I try to use it?
label: 25


text: Why do you keep declining my payment? I tried several times already with this card and it is just not working.
label: 25


In [8]:
enhanced_few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix="Text: {text}",
    input_variables=["text"],
)

print(enhanced_few_shot_prompt.format(text=text))

Text: Can you tell me why my card keeps getting declined every time I try to use it?
Label: 25

Text: Why do you keep declining my payment? I tried several times already with this card and it is just not working.
Label: 25

Text: Why does my credit card not work?


In [13]:
response = llm(query=enhanced_few_shot_prompt.format(text=text))
response

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


86
460


"I'd be happy to help you troubleshoot the issue with your credit card!\n\nThere could be several reasons why your card is being declined. Here are a few common causes:\n\n1. **Insufficient funds**: Make sure you have enough available credit on your card to cover the transaction.\n2. **Expired card**: Check the expiration date on your card to ensure it's still valid.\n3. **Invalid card information**: Double-check that you're entering the correct card number, expiration date, and security code.\n4. **Card issuer restrictions**: Your card issuer might have placed a hold on your account or restricted certain types of transactions.\n5. **Card security features**: Some cards have security features that can cause transactions to be declined, such as a chip or a digital signature.\n6. **Merchant-specific issues**: The merchant's payment processing system might be experiencing technical difficulties or have specific requirements for card acceptance.\n\nTo resolve the issue, you can try the fol

In [9]:
def get_prompt_length(llm, str_input):
    messages = [
            {"role": "system", "content": llm._system_msg},
            {"role": "user", "content": str_input},
        ]
    prompt = llm.pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
    )
    input_tokens_length = len(llm.pipeline.tokenizer.encode(prompt))
    prompt_length = len(prompt)
    return input_tokens_length,prompt_length

In [10]:
input_tokens_length, prompt_length = get_prompt_length(llm, enhanced_few_shot_prompt.format(text=text))

In [11]:
input_tokens_length

86

In [12]:
prompt_length

460

In [38]:
enhanced_few_shot_prompt.format(text=text)

'Text: Why do you keep declining my payment? I tried several times already with this card and it is just not working.\nLabel: 25\n\nText: What is the reason my payment was not accepted?\nLabel: 25\n\nText: Your task is to classify a given text into one of the following classes, reply ONLY with the class label: \n\nLabel: 25\nText: Why are you declining my payment? Everything was fine.\nText: I have a card payment that was declined, but why?\n\nLabel: 21\nText: I would like to change my pin.\nText: How can I change my Tholepin ?\n\nHere is your text, please classify it into one of the above classes\n\nText: my credit card does not work'

In [30]:
class Prediction(BaseModel):
    label: int = Field(description="classification label to a given text")
    @validator("label")
    def label_is_valid(cls, field): 
        if not (0 <= field <= 77):
            raise ValueError("label must be an integer between 0 and 77")
        return field
parser = PydanticOutputParser(pydantic_object=Prediction)

prompt = PromptTemplate(
    template="Answer the user query.\n{format_instructions}\n{query}\n",
    input_variables=["query"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
enhanced_few_shot_prompt = FewShotPromptTemplate(
    prefix="Answer the user query.\n{format_instructions}\n{text}\n",
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix="Text: {text}",
    input_variables=["text"],
    partial_variables={"format_instructions": parser.get_format_instructions()}
)

retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm, max_retries=3)
 
completion_chain = enhanced_few_shot_prompt | llm

main_chain = RunnableParallel(
    completion=completion_chain, prompt_value=enhanced_few_shot_prompt
) | RunnableLambda(lambda x: retry_parser.parse_with_prompt(**x))


answer: Prediction = main_chain.invoke({"text": text})
input_tokens_length, prompt_length = get_prompt_length(llm, enhanced_few_shot_prompt.format(text=text))

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


242
1045


In [31]:
input_tokens_length

242

In [29]:
enhanced_few_shot_prompt.format(text=text)

'Answer the user query.\nThe output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"label": {"title": "Label", "description": "classification label to a given text", "type": "integer"}}, "required": ["label"]}\n```\nWhy does my credit card not work?\n\n\nText: Can you tell me why my card keeps getting declined every time I try to use it?\nLabel: 25\n\nText: why couldn\'t I use my card at a store?\nLabel: 25\n\nText: Why does my credit card not work?'

In [14]:
main_chain.invoke({"query": text})

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


KeyError: "Input to PromptTemplate is missing variables {'completion'}.  Expected: ['completion', 'prompt'] Received: ['prompt', 'input']"

In [None]:
random_seed = 42
n_classes = 2
num_shots = 2
sampled_classes = train_data['label'].sample(n_classes, random_state=random_seed).values
train_data_sub = train_data[train_data['label'].isin(sampled_classes)]
test_data_sub = test_data[test_data['label'].isin(sampled_classes)]

def get_prompt_template(train_data_sub, num_shots):
    prompt_template = "Your task is to classify a given text into one of the following classes, reply ONLY with the class label: \n\n"
    for label in train_data_sub.label.unique():
        prompt_template += f"Label: {label}\n"
        for i, row in train_data_sub[train_data_sub['label'] == label].sample(num_shots, random_state=random_seed).iterrows():
            prompt_template += f"Text: {row['text']}\n"
            #remove row from the dataframe
            train_data_sub = train_data_sub.drop(i)
        prompt_template += "\n"
    prompt_template += "Here is your text, please classify it into one of the above classes\n\n"
    return prompt_template

#prompt_template = get_prompt_template(test_data_sub,num_shots)

text = get_prompt_template(test_data_sub,num_shots)
text += "Text: my credit card does not work"

### TODO:
- 

In [15]:
parser.get_format_instructions()

'The output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"label": {"title": "Label", "description": "classification label to a given text", "type": "integer"}}, "required": ["label"]}\n```'

In [25]:

for i, row in tqdm.tqdm(test_data_sub.iterrows(),total=len(test_data_sub), desc="Processing data"):
    prompt = prompt_template + f"Text: {row['text']}\n"
    response = llm(prompt=prompt)
    test_data_sub.loc[i, 'response'] = response
    #print(f"Prompt: {prompt}")
    #print(f"Response: {response}")
    print("\n\n")

Processing data:   0%|          | 0/80 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_data_sub.loc[i, 'response'] = response
Processing data:   1%|▏         | 1/80 [00:16<21:17, 16.18s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:   2%|▎         | 2/80 [00:16<09:01,  6.94s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:   4%|▍         | 3/80 [00:17<05:06,  3.98s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:   5%|▌         | 4/80 [00:17<03:17,  2.59s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:   6%|▋         | 5/80 [00:18<02:16,  1.83s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:   8%|▊         | 6/80 [00:18<01:40,  1.36s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:   9%|▉         | 7/80 [00:18<01:18,  1.07s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  10%|█         | 8/80 [00:19<01:03,  1.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  11%|█▏        | 9/80 [00:19<00:53,  1.33it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  12%|█▎        | 10/80 [00:20<00:46,  1.51it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  14%|█▍        | 11/80 [00:20<00:41,  1.66it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  15%|█▌        | 12/80 [00:21<00:38,  1.78it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  16%|█▋        | 13/80 [00:21<00:35,  1.88it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  18%|█▊        | 14/80 [00:22<00:33,  1.96it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  19%|█▉        | 15/80 [00:22<00:32,  2.01it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  20%|██        | 16/80 [00:23<00:31,  2.05it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  21%|██▏       | 17/80 [00:23<00:30,  2.07it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  22%|██▎       | 18/80 [00:24<00:29,  2.09it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  24%|██▍       | 19/80 [00:24<00:28,  2.11it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  25%|██▌       | 20/80 [00:25<00:28,  2.12it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  26%|██▋       | 21/80 [00:25<00:27,  2.12it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  28%|██▊       | 22/80 [00:25<00:27,  2.13it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  29%|██▉       | 23/80 [00:26<00:26,  2.13it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  30%|███       | 24/80 [00:26<00:26,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  31%|███▏      | 25/80 [00:27<00:25,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  32%|███▎      | 26/80 [00:27<00:25,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  34%|███▍      | 27/80 [00:28<00:24,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  35%|███▌      | 28/80 [00:28<00:24,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  36%|███▋      | 29/80 [00:29<00:23,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  38%|███▊      | 30/80 [00:29<00:23,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  39%|███▉      | 31/80 [00:30<00:22,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  40%|████      | 32/80 [00:30<00:22,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  41%|████▏     | 33/80 [00:31<00:21,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  42%|████▎     | 34/80 [00:31<00:21,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  44%|████▍     | 35/80 [00:32<00:20,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  45%|████▌     | 36/80 [00:32<00:20,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  46%|████▋     | 37/80 [00:32<00:20,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  48%|████▊     | 38/80 [00:33<00:19,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  49%|████▉     | 39/80 [00:33<00:19,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  50%|█████     | 40/80 [00:34<00:18,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  51%|█████▏    | 41/80 [00:34<00:18,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  52%|█████▎    | 42/80 [00:35<00:17,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  54%|█████▍    | 43/80 [00:35<00:17,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  55%|█████▌    | 44/80 [00:36<00:16,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  56%|█████▋    | 45/80 [00:36<00:16,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  57%|█████▊    | 46/80 [00:37<00:15,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  59%|█████▉    | 47/80 [00:37<00:15,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  60%|██████    | 48/80 [00:38<00:14,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  61%|██████▏   | 49/80 [00:38<00:14,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  62%|██████▎   | 50/80 [00:39<00:13,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  64%|██████▍   | 51/80 [00:39<00:13,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  65%|██████▌   | 52/80 [00:39<00:13,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  66%|██████▋   | 53/80 [00:40<00:12,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  68%|██████▊   | 54/80 [00:40<00:12,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  69%|██████▉   | 55/80 [00:41<00:11,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  70%|███████   | 56/80 [00:41<00:11,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  71%|███████▏  | 57/80 [00:42<00:10,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  72%|███████▎  | 58/80 [00:42<00:10,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  74%|███████▍  | 59/80 [00:43<00:09,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  75%|███████▌  | 60/80 [00:43<00:09,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  76%|███████▋  | 61/80 [00:44<00:08,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  78%|███████▊  | 62/80 [00:44<00:08,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  79%|███████▉  | 63/80 [00:45<00:07,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  80%|████████  | 64/80 [00:45<00:07,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  81%|████████▏ | 65/80 [00:46<00:06,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  82%|████████▎ | 66/80 [00:46<00:06,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  84%|████████▍ | 67/80 [00:46<00:06,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  85%|████████▌ | 68/80 [00:47<00:05,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  86%|████████▋ | 69/80 [00:47<00:05,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  88%|████████▊ | 70/80 [00:48<00:04,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  89%|████████▉ | 71/80 [00:48<00:04,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  90%|█████████ | 72/80 [00:49<00:03,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  91%|█████████▏| 73/80 [00:49<00:03,  2.15it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  92%|█████████▎| 74/80 [00:50<00:02,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  94%|█████████▍| 75/80 [00:50<00:02,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  95%|█████████▌| 76/80 [00:51<00:01,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  96%|█████████▋| 77/80 [00:51<00:01,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  98%|█████████▊| 78/80 [00:52<00:00,  2.14it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data:  99%|█████████▉| 79/80 [00:52<00:00,  2.06it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.







Processing data: 100%|██████████| 80/80 [00:53<00:00,  1.51it/s]









In [26]:
test_data_sub

Unnamed: 0,text,label,response
40,Why won't my card show up on the app?,13,13
41,I would like to reactivate my card.,13,13
42,Where do I link the new card?,13,13
43,"I have received my card, can you help me put i...",13,13
44,How do I link a card that I already have?,13,13
...,...,...,...
1955,The ATM at Metro bank on High St. Kensington d...,18,18
1956,The ATM won't give back my card,18,18
1957,I was at an ATM and it swallowed my card.,18,18
1958,WTF??? I tried to withdraw some money at a Met...,18,13


In [27]:
test_data_sub['response'] = test_data_sub['response'].astype(int)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_data_sub['response'] = test_data_sub['response'].astype(int)


In [28]:
print(classification_report(test_data_sub['label'], test_data_sub['response']))

              precision    recall  f1-score   support

          13       0.83      1.00      0.91        40
          18       1.00      0.80      0.89        40

    accuracy                           0.90        80
   macro avg       0.92      0.90      0.90        80
weighted avg       0.92      0.90      0.90        80



In [21]:
print(classification_report(test_data_sub['label'], test_data_sub['response']))

              precision    recall  f1-score   support

          13       0.58      0.88      0.70        40
          18       0.75      0.38      0.50        40

    accuracy                           0.62        80
   macro avg       0.67      0.62      0.60        80
weighted avg       0.67      0.62      0.60        80



In [13]:
#iterate over the test data and generate the prompt
#surpress warnings
import warnings

warnings.filterwarnings("ignore")

prompt_template = get_prompt_template(test_data_sub, num_shots)

n_classes_list = [5, 10, 20, 30, 40, 50, 77]
num_shots = 5
results = {}
#iterate over the n_classes_list

# import wandb 
# wandb.init(project="llm-banking77")

from sklearn.metrics import classification_report

for n_classes in n_classes_list:
    sampled_classes = train_data['label'].sample(n_classes, random_state=random_seed).values
    train_data_sub = train_data[train_data['label'].isin(sampled_classes)]
    test_data_sub = test_data[test_data['label'].isin(sampled_classes)]
    prompt_template = get_prompt_template(test_data_sub, num_shots)

    for i, row in tqdm.tqdm(test_data_sub.iterrows(),total=len(test_data_sub), desc="Processing data"):
        prompt = prompt_template + f"Text: {row['text']}\n"
        response = llm(prompt=prompt)
        test_data_sub.loc[i, 'response'] = response
    test_data_sub['response'] = test_data_sub['response'].label.astype(int)
    report = classification_report(test_data_sub['label'], test_data_sub['response'], output_dict=True)
    # #print(f"Classification Report: {classification_report}")
    # #save results to dictionary
    # results[n_classes] = {}
    results[n_classes] = {
        'classification_report': report,
        'prompt_template': prompt_template,
        'context_length': len(prompt_template)
    }
    
 

Processing data:   0%|          | 0/200 [00:00<?, ?it/s]


TypeError: __call__() missing 1 required keyword-only argument: 'query'

In [None]:
import json
with open("../data/results.json", "w") as f:
    json.dump(results, f)