In [1]:
import argparse
import concurrent
from dotenv import load_dotenv
from tqdm import tqdm
import textgrad as tg
from textgrad.tasks import load_task
import numpy as np
import random
load_dotenv(override=True)

True

In [2]:
def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)

In [3]:
def eval_sample(item, eval_fn, model):
    """
    This function allows us to evaluate if an answer to a question in the prompt is a good answer.

    """
    x, y = item
    x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
    y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
    response = model(x)
    try:
        eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
        return int(eval_output_variable.value)
    except:
        eval_output_variable = eval_fn([x, y, response])
        eval_output_parsed = eval_fn.parse_output(eval_output_variable)
        return int(eval_output_parsed)

In [4]:
def eval_dataset(test_set, eval_fn, model, max_samples: int=None):
    if max_samples is None:
        max_samples = len(test_set)
    accuracy_list = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = []
        for _, sample in enumerate(test_set):
            
            future = executor.submit(eval_sample, sample, eval_fn, model)
            futures.append(future)
            if len(futures) >= max_samples:
                break
        tqdm_loader = tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)
        for future in tqdm_loader:
            acc_item = future.result()
            accuracy_list.append(acc_item)
            tqdm_loader.set_description(f"Accuracy: {np.mean(accuracy_list)}")
    return accuracy_list 

In [5]:
def run_validation_revert(system_prompt: tg.Variable, system_prompt2: tg.Variable, results, model, eval_fn, val_set):
    val_performance = np.mean(eval_dataset(val_set, eval_fn, model))
    previous_performance = np.mean(results["validation_acc"][-1])
    print("val_performance: ", val_performance)
    print("previous_performance: ", previous_performance)
    previous_prompt = results["prompt"][-1]
    previous_prompt2 = results["prompt2"][-1]
    
    if val_performance < previous_performance:
        print(f"rejected prompt: {system_prompt.value}")
        system_prompt.set_value(previous_prompt)
        # system_prompt2.set_value(previous_prompt2)
        val_performance = previous_performance

    results["validation_acc"].append(val_performance)
    results["prompt"].append(system_prompt.value)
    results["prompt2"].append(system_prompt2.value)

In [6]:
set_seed(12)
llm_api_eval = tg.get_engine(engine_name="gpt-3.5-turbo-0125")
llm_api_test = tg.get_engine(engine_name="gpt-3.5-turbo-0125")
tg.set_backward_engine(llm_api_eval, override=True)

# Load the data and the evaluation function
train_set, val_set, test_set, eval_fn = load_task("BBH_object_counting", evaluation_api=llm_api_eval)
print("Train/Val/Test Set Lengths: ", len(train_set), len(val_set), len(test_set))
STARTING_SYSTEM_PROMPT = train_set.get_task_description()

Train/Val/Test Set Lengths:  50 100 100


In [7]:
print(STARTING_SYSTEM_PROMPT)

You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.


In [8]:
train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)


# Testing the 0-shot performance of the evaluation engine
system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, 
                            requires_grad=True, 
                            role_description="system prompt to the language model")
model_evaluation = tg.BlackboxLLM(llm_api_eval, system_prompt)

system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, 
                            requires_grad=True,
                            role_description="structured system prompt to a somewhat capable language model that specifies the behavior and strategies for the QA task")
model1 = tg.BlackboxLLM(llm_api_test, system_prompt)

system_prompt2 = tg.Variable("return the same thing as the input.", 
                            requires_grad=True,
                            role_description="return the same thing as the input.")
model2 = tg.BlackboxLLM(llm_api_test, system_prompt2)

def model(input):
    input1 = model1(input)
    return model2(input1)

optimizer = tg.TextualGradientDescent(engine=llm_api_eval, parameters=[system_prompt, system_prompt2])

results = {"test_acc": [], "prompt": [], "validation_acc": [], "prompt2": []}
results["test_acc"].append(eval_dataset(test_set, eval_fn, model))
results["validation_acc"].append(eval_dataset(val_set, eval_fn, model))
results["prompt"].append(system_prompt.get_value())
results["prompt2"].append(system_prompt2.get_value())


Accuracy: 0.79: 100%|██████████| 100/100 [00:00<00:00, 1631.16it/s]  
Accuracy: 0.7: 100%|██████████| 100/100 [00:00<00:00, 1849.38it/s]    


In [9]:
for epoch in range(1):
    for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
        pbar.set_description(f"Training step {steps}. Epoch {epoch}")
        optimizer.zero_grad()
        losses = []
        for (x, y) in zip(batch_x, batch_y):
            print("x: ", x)
            x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
            y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
            response = model(x)
            try:
                eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
            except:
                eval_output_variable = eval_fn([x, y, response])
            losses.append(eval_output_variable)
        total_loss = tg.sum(losses)
        print("total_loss: ", total_loss)
        total_loss.backward()
        print("system_prompt.grad: ", system_prompt)
        optimizer.step()

        print("system_prompt: ", system_prompt)
        run_validation_revert(system_prompt, system_prompt2, results, model, eval_fn, val_set)
        
        print("sys prompt: ", system_prompt)
        print("sys prompt2: ", system_prompt2)
        test_acc = eval_dataset(test_set, eval_fn, model)
        results["test_acc"].append(test_acc)
        results["prompt"].append(system_prompt.get_value())
        results["prompt2"].append(system_prompt2.get_value())
        if steps == 3:
            break

Training step 0. Epoch 0: : 0it [00:00, ?it/s]

x:  I have a cauliflower, a stalk of celery, a cabbage, and a garlic. How many vegetables do I have?
x:  I have a trumpet, four trombones, an accordion, a clarinet, a violin, and a drum. How many musical instruments do I have?
x:  I have a blackberry, a peach, a nectarine, a plum, a raspberry, an orange, a strawberry, a banana, two apples, and four grapes. How many fruits do I have?
total_loss:  1
1
1
system_prompt.grad:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
system_prompt:  Calculate the total number of vegetables by counting each individual vegetable. Ensure to provide a step-by-step breakdown of your reasoning process. End your response with the format: 'Answer: $VALUE' where VALUE is a numerical value.


Accuracy: 0.55: 100%|██████████| 100/100 [02:00<00:00,  1.20s/it]              


val_performance:  0.55
previous_performance:  0.7
rejected prompt: Calculate the total number of vegetables by counting each individual vegetable. Ensure to provide a step-by-step breakdown of your reasoning process. End your response with the format: 'Answer: $VALUE' where VALUE is a numerical value.
sys prompt:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
sys prompt2:  Enhance the response based on the input text.


Accuracy: 0.48: 100%|██████████| 100/100 [00:42<00:00,  2.37it/s]              
Training step 1. Epoch 0: : 1it [03:40, 220.95s/it]

x:  I have a piano, and a trombone. How many musical instruments do I have?
x:  I have a bed, a fridge, a lamp, a toaster, four chairs, and a table. How many objects do I have?
x:  I have a piano, an accordion, three trombones, five clarinets, a violin, a drum, a trumpet, and three flutes. How many musical instruments do I have?
total_loss:  0
0
1
system_prompt.grad:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
system_prompt:  Provide a clear and precise response to the following question by calculating the total number of objects. Break down the problem by listing each object before determining the final count. Conclude your answer in the following format: 'Answer: $VALUE' where VALUE is the numerical result.


Accuracy: 0.37: 100%|██████████| 100/100 [01:30<00:00,  1.10it/s]              


val_performance:  0.37
previous_performance:  0.7
rejected prompt: Provide a clear and precise response to the following question by calculating the total number of objects. Break down the problem by listing each object before determining the final count. Conclude your answer in the following format: 'Answer: $VALUE' where VALUE is the numerical result.
sys prompt:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
sys prompt2:  Ensure the response states the total number of musical instruments calculated clearly and precisely.


Accuracy: 0.34: 100%|██████████| 100/100 [00:36<00:00,  2.72it/s]              
Training step 2. Epoch 0: : 2it [06:43, 198.14s/it]

x:  I have a stove, a bed, a lamp, three microwaves, a chair, a toaster, a table, two cars, a fridge, and an oven. How many objects do I have?
x:  I have a flute, a trumpet, three accordions, three violins, a drum, three clarinets, and a trombone. How many musical instruments do I have?
x:  I have a piano, a trombone, a clarinet, a goat, an accordion, and a trumpet. How many musical instruments do I have?
total_loss:  0
1
1
system_prompt.grad:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
system_prompt:  You will answer a reasoning question involving counting musical instruments. Follow the steps to calculate the total number of instruments. Ensure to accurately sum up numerical quantities and provide a clear and concise response. Your final answer should be presented in a straightforward manner without unnecessary details. Before finalizing your answer, dou

Accuracy: 0.39: 100%|██████████| 100/100 [00:56<00:00,  1.78it/s]              


val_performance:  0.39
previous_performance:  0.7
rejected prompt: You will answer a reasoning question involving counting musical instruments. Follow the steps to calculate the total number of instruments. Ensure to accurately sum up numerical quantities and provide a clear and concise response. Your final answer should be presented in a straightforward manner without unnecessary details. Before finalizing your answer, double-check the accuracy of your calculation and ensure it aligns with the provided format. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
sys prompt:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
sys prompt2:  Ensure the response states the total number of musical instruments calculated clearly and precisely based on the provided list of quantities for each instrument.


Accuracy: 0.38: 100%|██████████| 100/100 [00:43<00:00,  2.32it/s]              
Training step 3. Epoch 0: : 3it [09:29, 183.71s/it]

x:  I have a yam, a cauliflower, a garlic, two lettuce heads, a head of broccoli, a potato, a stalk of celery, and an onion. How many vegetables do I have?
x:  I have three fridges, a bed, and five stoves. How many objects do I have?
x:  I have a raspberry, a grape, and an orange. How many fruits do I have?
total_loss:  0
0
1
system_prompt.grad:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
system_prompt:  You will answer a reasoning question by calculating the total number of musical instruments. Provide a step-by-step explanation of your reasoning process. Your final response should be in the following format: 'Answer: $VALUE' where VALUE is a numerical value representing the total count of musical instruments.


Accuracy: 0.23: 100%|██████████| 100/100 [02:17<00:00,  1.37s/it]              


val_performance:  0.23
previous_performance:  0.7
rejected prompt: You will answer a reasoning question by calculating the total number of musical instruments. Provide a step-by-step explanation of your reasoning process. Your final response should be in the following format: 'Answer: $VALUE' where VALUE is a numerical value representing the total count of musical instruments.
sys prompt:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.
sys prompt2:  Ensure the response provides the total count of musical instruments accurately based on the quantities provided for each instrument.


Accuracy: 0.29: 100%|██████████| 100/100 [00:51<00:00,  1.95it/s]              
Training step 3. Epoch 0: : 3it [13:49, 276.59s/it]


In [10]:
results["prompt"]

["You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.",
 "You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.",
 "You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.",
 "You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.",
 "You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.",
 "You will answer a reasoning question. Think step by step. The last line of your response should be of t

In [11]:
results["prompt2"]

['return the same thing as the input.',
 'Enhance the response based on the input text.',
 'Enhance the response based on the input text.',
 'Ensure the response states the total number of musical instruments calculated clearly and precisely.',
 'Ensure the response states the total number of musical instruments calculated clearly and precisely.',
 'Ensure the response states the total number of musical instruments calculated clearly and precisely based on the provided list of quantities for each instrument.',
 'Ensure the response states the total number of musical instruments calculated clearly and precisely based on the provided list of quantities for each instrument.',
 'Ensure the response provides the total count of musical instruments accurately based on the quantities provided for each instrument.',
 'Ensure the response provides the total count of musical instruments accurately based on the quantities provided for each instrument.']