In [None]:
pip install prompt-learner --upgrade

In [1]:
from prompt_learner.tasks.classification import ClassificationTask
from prompt_learner.examples.example import Example
from prompt_learner.selectors.random_sampler import RandomSampler
from prompt_learner.selectors.diverse_sampler import DiverseSampler
from prompt_learner.selectors.stratified_sampler import StratifiedSampler
from prompt_learner.prompts.prompt import Prompt
from prompt_learner.templates.markdown import MarkdownTemplate
from prompt_learner.templates.xml import XmlTemplate
from prompt_learner.adapters.ollama_local import OllamaLocal
from prompt_learner.adapters.anthropic import Anthropic
from prompt_learner.adapters.llama import Llama
from prompt_learner.adapters.openai import OpenAI
from prompt_learner.evals.metrics.accuracy import Accuracy
from prompt_learner.optimizers.grid_search import GridSearch
import pandas as pd

In [2]:
# Task description and allowed labels
TASK_DESCRIPTION = "You have to classify intent for customer messages sent to chatbot."
ALLOWED_LABELS = ['INFO_ADD_REMOVE_VEHICLE', 'INFO_LOGIN_ERROR',
       'INFO_ADD_REMOVE_INSURED', 'INFO_ERS', 'INFO_CAREERS',
       'INFO_DIFFERENT_AMTS', 'INFO_SPEAK_TO_REP', 'INFO_CANCEL_INS_POLICY',
       'INFO_UPDATE_LIENHOLDER', 'INFO_DELETE_DUPE_PYMT',
       'INFO_CANT_SEE_FARM_RANCH_POLICY', 'INFO_AUTO_INS_CANADA',
       'INFO_DEC_PAGE_NEEDED', 'INFO_LIFE_BENEFICIARY_CHANGE',
       'INFO_MAKE_PYMT']

In [3]:
classification_task = ClassificationTask(description=TASK_DESCRIPTION,
                                         allowed_labels=ALLOWED_LABELS)

# Template for the task
template = MarkdownTemplate(task=classification_task)


In [4]:
# Load training data
with open("train.csv") as f:
    for line in f:
        text, label = line.split(",")
        classification_task.add_example(Example(text=text.strip(), label=label.strip()))
#Add test examples
with open("test.csv") as f:
    for line in f:
        text, label = line.split(",")
        classification_task.add_example(Example(text=text.strip(),
                                                label=label.strip()),
                                                test=True)

In [5]:
# Sample a random example from the training data
sampler = RandomSampler(num_samples=1, task=classification_task)
sampler.select_examples()

RandomSampler(num_samples=1)

In [6]:
# Assemble the prompt using the template
base_prompt = Prompt(template=template)
base_prompt.assemble_prompt()
print(base_prompt.prompt)


You are a helpful AI assistant.  
You are helping a user with a Classification task.  
The user gives you the following task description.  
### Task Description:  
You have to classify intent for customer messages sent to chatbot.
You have to select from the following labels.  
### Allowed Labels:  
['INFO_ADD_REMOVE_VEHICLE', 'INFO_LOGIN_ERROR', 'INFO_ADD_REMOVE_INSURED', 'INFO_ERS', 'INFO_CAREERS', 'INFO_DIFFERENT_AMTS', 'INFO_SPEAK_TO_REP', 'INFO_CANCEL_INS_POLICY', 'INFO_UPDATE_LIENHOLDER', 'INFO_DELETE_DUPE_PYMT', 'INFO_CANT_SEE_FARM_RANCH_POLICY', 'INFO_AUTO_INS_CANADA', 'INFO_DEC_PAGE_NEEDED', 'INFO_LIFE_BENEFICIARY_CHANGE', 'INFO_MAKE_PYMT'].  
Only output labels and nothing else.
Here are a few examples to help you understand the task better.  
### Examples

        -text: Do I have a seperate page for my farm account  
-label: INFO_CANT_SEE_FARM_RANCH_POLICY  



In [7]:
# See performance on gpt3 turbo without any grid search
acc, results = Accuracy(classification_task).compute(base_prompt,
                                                     OllamaLocal(model_name='llama3'),
                                                     test=True)
print(acc)
print(pd.DataFrame(results))

0.8863636363636364
                                                    0  \
0   do i save money if i remove a driver from my i...   
1   Who else can be added in my auto Policy if at ...   
2                   my son will now also drive my car   
3   I need to add a vehicle on my insurance and ha...   
4      Can I add another car to the current insurance   
5   I traded my kia and got a new car How do I add...   
6   Is my car covered through my policy if I visit...   
7   Do I need to upgrade my insurance to drive in ...   
8            If I travel to Canada am I still covered   
9   Is there any specific section in the website t...   
10  Can you cancel my policies or do I need to cal...   
11                          Cancel all of my poilcies   
12  what do i do if my farm insurance policy not s...   
13          In my list of accounts I dont see my farm   
14  Where can I find a list of open positions to a...   
15                    Tell me what jobs are available   
16          

In [8]:
# Initialize a grid search on the current prompt
grid_search = GridSearch(prompt=base_prompt)
random_4_shot = RandomSampler(num_samples=4, task=classification_task)
random_15_shot = RandomSampler(num_samples=15, task=classification_task)
diverse_15_shot = DiverseSampler(num_samples=15, task=classification_task)
stratify_15_shot = StratifiedSampler(num_samples=1, task=classification_task)


In [10]:
param_grid = {
    'sampler': [random_4_shot, random_15_shot,diverse_15_shot,stratify_15_shot],
    'template': [MarkdownTemplate],#, XmlTemplate],
    'adapter': [OllamaLocal(model_name='llama3')]#, Anthropic(model_name="claude-3-haiku-20240307"), OpenAI(model_name='gpt-4o')]
}

In [11]:
# Perform grid search to find the best parameters
best_params, all_results = grid_search.search(param_grid)
print(pd.DataFrame(all_results))

Grid Search Progress: 100%|██████████| 4/4 [01:11<00:00, 17.99s/it]

      score                                            sampler  \
0  0.895349                       RandomSampler(num_samples=4)   
1  0.906667                      RandomSampler(num_samples=15)   
2  0.853333                     DiverseSampler(num_samples=15)   
3  0.973333  StratifiedSampler(task=description='You have t...   

           template                            adapter  
0  MarkdownTemplate  Ollama Adapter(model_name=llama3)  
1  MarkdownTemplate  Ollama Adapter(model_name=llama3)  
2  MarkdownTemplate  Ollama Adapter(model_name=llama3)  
3  MarkdownTemplate  Ollama Adapter(model_name=llama3)  





In [12]:
# Evaluate the best model
template = MarkdownTemplate(task=classification_task)
sampler = StratifiedSampler(num_samples=1, task=classification_task)
sampler.select_examples()

best_prompt = Prompt(template=template)
best_prompt.assemble_prompt()

In [13]:
print(best_prompt.prompt)

You are a helpful AI assistant.  
You are helping a user with a Classification task.  
The user gives you the following task description.  
### Task Description:  
You have to classify intent for customer messages sent to chatbot.
You have to select from the following labels.  
### Allowed Labels:  
['INFO_ADD_REMOVE_VEHICLE', 'INFO_LOGIN_ERROR', 'INFO_ADD_REMOVE_INSURED', 'INFO_ERS', 'INFO_CAREERS', 'INFO_DIFFERENT_AMTS', 'INFO_SPEAK_TO_REP', 'INFO_CANCEL_INS_POLICY', 'INFO_UPDATE_LIENHOLDER', 'INFO_DELETE_DUPE_PYMT', 'INFO_CANT_SEE_FARM_RANCH_POLICY', 'INFO_AUTO_INS_CANADA', 'INFO_DEC_PAGE_NEEDED', 'INFO_LIFE_BENEFICIARY_CHANGE', 'INFO_MAKE_PYMT'].  
Only output labels and nothing else.
Here are a few examples to help you understand the task better.  
### Examples

        -text: Id like to switch the car that is on my insurance with a new one  
-label: INFO_ADD_REMOVE_VEHICLE  
-text: I cant log in and I keep getting shown an error  
-label: INFO_LOGIN_ERROR  
-text: How would I rem

In [14]:
acc, results = Accuracy(classification_task).compute(best_prompt,
                                                     OllamaLocal(model_name = "llama3"),
                                                     test=True)
print(acc)
print(pd.DataFrame(results))

0.95
                                                    0  \
0   do i save money if i remove a driver from my i...   
1   Who else can be added in my auto Policy if at ...   
2                   my son will now also drive my car   
3   I need to add a vehicle on my insurance and ha...   
4      Can I add another car to the current insurance   
5   I traded my kia and got a new car How do I add...   
6   Is my car covered through my policy if I visit...   
7   Do I need to upgrade my insurance to drive in ...   
8            If I travel to Canada am I still covered   
9   Is there any specific section in the website t...   
10  Can you cancel my policies or do I need to cal...   
11  what do i do if my farm insurance policy not s...   
12      Do I have a seperate page for my farm account   
13  Where can I find a list of open positions to a...   
14                    Tell me what jobs are available   
15                      Do you have any work openings   
16  can you get me a copy 