In [1]:
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _eos_token_ids(model):
    """Return the EOS ids found on ``model.config`` then ``model.generation_config``, de-duplicated."""
    collected = []
    for cfg in (getattr(model, "config", None), getattr(model, "generation_config", None)):
        eos = getattr(cfg, "eos_token_id", None)
        if eos is None:
            continue
        for token_id in (eos if isinstance(eos, (list, tuple)) else [eos]):
            if token_id not in collected:
                collected.append(token_id)
    return collected


def _mask_through_first_eos(token_ids, eos_ids):
    """Build a 0/1 mask that keeps each row up to and including its first EOS token."""
    keep = torch.ones_like(token_ids)
    if not eos_ids or token_ids.numel() == 0:
        return keep
    is_eos = torch.zeros_like(token_ids, dtype=torch.bool)
    for token_id in eos_ids:
        is_eos |= token_ids == token_id
    eos_count = is_eos.to(torch.long)
    # Number of EOS tokens strictly before each position; zero means "still inside the text".
    eos_before = eos_count.cumsum(dim=1) - eos_count
    return (eos_before == 0).to(keep.dtype)


def gen_prompt_from_query(
    model,
    tokenizer,
    input_ids,
    attention_mask,
    max_new_tokens=64,
    gen_prompt_ids=None,
    instruction="Generate a concise reasoning prompt to answer the following question:"
):
    """Sample a "self-prompt" continuation for a batch of tokenized queries.

    The instruction token ids are prepended to every query row and the model
    samples a continuation. Instruction and query ids are concatenated
    directly: no separator and no chat template is inserted between them.

    Args:
        model: Object exposing ``generate(...)`` and ``config.eos_token_id``.
        tokenizer: Callable used to tokenize ``instruction``; only invoked when
            ``gen_prompt_ids`` is None.
        input_ids: 2-D tensor of query token ids, one row per query.
        attention_mask: Mask with the same shape as ``input_ids``.
        max_new_tokens: Upper bound on the number of sampled tokens.
        gen_prompt_ids: Optional pre-tokenized instruction; it is repeated along
            dim 0 once per query row (expected shape ``(1, n)``).
        instruction: Instruction text, used only when ``gen_prompt_ids`` is None.

    Returns:
        dict with
            ``input_ids``: only the newly generated tokens, one row per query;
            ``attention_mask``: 1 for each generated token up to and including
                the row's first EOS, 0 for the filler that follows it;
            ``full_input_ids``: the instruction + query ids fed to the model.

    NOTE(review): queries are used as given. If a batch contains right-padded
    rows the model continues after the padding; confirm callers pass unpadded
    or left-padded queries.
    """
    with torch.no_grad():
        # Tokenize the instruction unless the caller supplied pre-tokenized ids.
        if gen_prompt_ids is None:
            gen_prompt_ids = tokenizer(
                instruction,
                return_tensors='pt',
                add_special_tokens=False
            ).input_ids.to(input_ids.device)

        # Prepend the instruction to every query row and extend the mask to match.
        batch_size = input_ids.size(0)
        full_input_ids = torch.cat(
            [gen_prompt_ids.repeat(batch_size, 1), input_ids], dim=1
        )
        full_attention_mask = torch.cat(
            [
                torch.ones(
                    (batch_size, gen_prompt_ids.size(1)),
                    dtype=attention_mask.dtype,
                    device=input_ids.device,
                ),
                attention_mask,
            ],
            dim=1,
        )

        eos_ids = _eos_token_ids(model)

        # Generate continuation (the self-prompt). EOS doubles as the pad token;
        # if the config declares several EOS ids the first one is used for padding.
        outputs = model.generate(
            input_ids=full_input_ids,
            attention_mask=full_attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=5,
            temperature=0.7,
            # Scores are requested for the log-prob computation that is currently disabled below.
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=eos_ids[0] if eos_ids else None,
        )
        # The query may be deeply embedded in the continuation, so only the
        # instruction + query prefix is stripped; the query is not split out further.
        prompt_ids = outputs.sequences[:, full_input_ids.size(1):]

        return {
            "input_ids": prompt_ids,
            # Rows that finish early are filled with the pad/EOS id up to the batch
            # length; an all-ones mask would present that filler as real tokens.
            "attention_mask": _mask_through_first_eos(prompt_ids, eos_ids),
            # "log_probs": compute_generation_logprobs(outputs, full_input_ids.size(1)),
            "full_input_ids": full_input_ids,
        }

In [3]:
# Checkpoint identifier passed to `from_pretrained` for both the tokenizer and the model.
model_name = 'Qwen/Qwen2-7B-Instruct'

In [4]:
# Load the tokenizer that matches the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use EOS as the padding token; `tokenizer.pad_token_id` is later used to pad generated prompts.
tokenizer.pad_token = tokenizer.eos_token

In [5]:
# Load the causal LM, requesting bfloat16 weights.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  6.64it/s]


In [6]:
# Names of the prompt-structuring strategies; each one is substituted into the
# '{strategy}' placeholder of an instruction template.
strategies = [
    "Supportive Stepwise Guidance",
    "Define, List, and Exemplify",
    "Stepwise Procedural Guide",
    "Enumerate Assistance Methods",
    "Intro Then List",
    "Sequential Explanation with Contrast",
    "Intro + Thematic List with Explanations",
    "Structured Reflective Statement",
    "List, Reasoning, Answer",
]

In [None]:
# Instruction templates. Each contains a '{strategy}' placeholder that is filled
# with one of the strategy names via str.format before generation.
prompt_list = [
    (
        "Generate a prompt using '{strategy}' without any extra output. "
        "Only write the prompt encoding the question: "
    ),
    (
        "Generate a prompt using the '{strategy}' format to guide a low-capability agent "
        "in solving a problem step-by-step without revealing the answer. "
        "You must ONLY output the prompt. "
        "If you include anything other than the prompt, or if the prompt contains the answer, "
        "your response will be considered invalid and rejected."
    ),
    (
        "Write a prompt using '{strategy}' that instructs a weaker agent to solve the problem "
        "step-by-step without giving the answer. "
        "Output only the prompt text—no explanations, no extra content."
    ),
    (
        "Using '{strategy}', create a prompt that helps a low-skill agent work through the problem in steps. "
        "Do not provide the solution or anything except the prompt itself."
    ),
    (
        "Construct a prompt in the '{strategy}' style that outlines a step-by-step method "
        "for a less-capable agent to follow. "
        "The prompt must not contain the answer and must be the only output."
    ),
    (
        "Produce a '{strategy}' prompt designed to guide a minimal-capability agent "
        "through solving the problem step-by-step. "
        "Absolutely no answers or extra commentary—output the prompt alone."
    ),
    (
        "Generate only the prompt in '{strategy}' style to teach a low-level agent "
        "how to approach the problem in steps. "
        "If any answer or additional text appears, the output is invalid."
    ),
]

In [None]:
# Raw question text; the following cell rebinds `query` to its tokenized form.
query = "why is 49 a prime number?"

In [8]:
# Replace the raw question string with its tokenized form (no special tokens added),
# moved to the model's device.
# NOTE: `query` changes type here (str -> tokenizer output). Re-running this cell
# without first re-running the cell that assigns the raw string would feed the
# already-tokenized object back into the tokenizer.
query = tokenizer(query, return_tensors='pt', add_special_tokens=False).to(model.device)


In [None]:
# Every (strategy, instruction template) combination. Templates vary slowest, so
# the first len(strategies) entries all use prompt_list[0].
pairs = [
    (strategy_name, template)
    for template in prompt_list
    for strategy_name in strategies
]

In [None]:
# One sampling call per (strategy, template) pair, all for the same tokenized query.
# Each result is the dict returned by gen_prompt_from_query.
out = [
    gen_prompt_from_query(
        model,
        tokenizer,
        query['input_ids'],
        query['attention_mask'],
        max_new_tokens=256,
        instruction=template.format(strategy=strategy_name),
    )
    for strategy_name, template in pairs
]

In [10]:
def batch_prompts(prompts, pad_token_id):
    """Right-pad generated prompts to a common length and stack them into one batch.

    Args:
        prompts: Non-empty list of dicts as returned by ``gen_prompt_from_query``;
            each must hold ``'input_ids'`` and ``'attention_mask'`` tensors.
            Entries may be 1-D (a single sequence) or 2-D (one row per sequence);
            every row of a 2-D entry becomes its own row of the output.
        pad_token_id: Value used to fill ``input_ids`` past the end of each row.

    Returns:
        dict with ``'input_ids'`` and ``'attention_mask'``, both 2-D with one row
        per sequence. Padded mask positions are 0.

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    if not prompts:
        raise ValueError("batch_prompts requires at least one prompt")

    def _rows(tensor):
        # squeeze(0) only handled batch size 1; unbind keeps every row of a larger batch.
        return list(tensor.unbind(0)) if tensor.dim() >= 2 else [tensor]

    input_ids_list = [row for p in prompts for row in _rows(p['input_ids'])]
    attn_mask_list = [row for p in prompts for row in _rows(p['attention_mask'])]

    # pad_sequence pads to the longest sequence in the batch.
    input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_id)
    attn_mask = pad_sequence(attn_mask_list, batch_first=True, padding_value=0)

    return {"input_ids": input_ids, "attention_mask": attn_mask}

In [12]:
# Pad all generated prompts to a common length so they form a single batch.
new_out = batch_prompts(out,tokenizer.pad_token_id)

In [13]:
# Display the padded batch (a dict holding 'input_ids' and 'attention_mask').
new_out

{'input_ids': tensor([[ 59501,     25,  81917,  ..., 151645, 151645, 151645],
         [ 18614,  10250,   5109,  ..., 151645, 151645, 151645],
         [ 14822,   4482,   1298,  ..., 151645, 151645, 151645],
         ...,
         [   715,  61569,     25,  ...,   3476,    304,   5257],
         [  4710,  54615,     25,  ...,    323,   8357,     13],
         [  1759,     25,    715,  ..., 151645, 151645, 151645]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0]])}

In [None]:
# Decode every row of the padded batch back to text, dropping special tokens
# (which removes the EOS padding). Decoding the batch itself, instead of
# indexing with range(len(pairs)), keeps this cell correct even when `pairs`
# has been edited since `out` was generated.
prompt = tokenizer.batch_decode(new_out['input_ids'], skip_special_tokens=True)

In [17]:
# Sanity check: number of decoded prompts.
print(len(prompt))

9


In [18]:
# Show every generated prompt next to the strategy and instruction template that
# produced it. `prompt` holds one entry per element of `pairs`, in the same order;
# zipping with `strategies` instead would stop after len(strategies) items and
# hide all generations from the remaining templates.
for (strat, prompt_inst), prompt_txt in zip(pairs, prompt):
    print(
        f"Strategy: '{strat}'\n"
        f"Instruction: '{prompt_inst}'\n"
        f"Prompt: '{prompt_txt}'\n{'-'*40}"
    )

Strategy: 'Supportive Stepwise Guidance'
Prompt: ' Prompt: Explain why the number 49 is not considered a prime number, detailing the factors of 49 and how they relate to the definition of a prime number.'
----------------------------------------
Strategy: 'Define, List, and Exemplify'
Prompt: ' Define prime numbers.
List the criteria for a number to be considered prime.
Exemplify why 49 does not meet these criteria and therefore is not a prime number. Why is 49 not considered a prime number? Provide a definition of prime numbers, list the criteria for a number to be considered prime, and give an example of how 49 fails to meet these criteria, demonstrating its non-prime status.'
----------------------------------------
Strategy: 'Stepwise Procedural Guide'
Prompt: ' Stepwise Procedural Guide:
1. Define what a prime number is.
2. List the factors of 49.
3. Determine if there are any factors other than 1 and 49.
4. Explain why the absence of such factors makes 49 not a prime number.

Pro

In [1]:
# Reduced set of strategy names used for the multi-category experiment.
strategies = [
    "Supportive Stepwise Guidance",
    "Stepwise Procedural Guide",
    "Define, List, and Exemplify",
    "Enumerate Assistance Methods",
]

# Instruction templates for the multi-category experiment; '{strategy}' is filled
# in with a strategy name via str.format.
prompt_list = [
    (
        "Generate a prompt using the '{strategy}' format to guide a low-capability agent "
        "in solving a problem step-by-step without revealing the answer. "
        "You must ONLY output the prompt. "
        "If you include anything other than the prompt, or if the prompt contains the answer, "
        "your response will be considered invalid and rejected."
    ),
    (
        "Using '{strategy}', create a prompt that helps a low-skill agent work through the problem in steps. "
        "Do not provide the solution or anything except the prompt itself."
    ),
]

In [2]:
# Evaluation queries grouped by task category: category name -> list of query texts.
query_set = {
    "math_reasoning": [
        "Determine whether 121 is a prime number.",
        "Find the area of a triangle with base 8 cm and height 5 cm.",
        "If a train travels 60 km in 1.5 hours, what is its average speed?",
    ],
    "general_knowledge": [
        "Who wrote the play 'Romeo and Juliet'?",
        "What is the capital city of Canada?",
        "Name three countries that share a border with Germany.",
    ],
    "commonsense_reasoning": [
        "If you leave ice cubes out in the sun, what will happen after 10 minutes?",
        "Why should you not touch an electrical socket with wet hands?",
        "If it is raining outside, what might people carry with them?",
    ],
    "procedural_tasks": [
        "Explain how to change a flat bicycle tire.",
        "Give step-by-step instructions to bake a chocolate cake.",
        "Describe the process of renewing a passport.",
    ],
    "creative_generation": [
        "Write the opening sentence of a mystery novel set in a small fishing village.",
        "Invent a new sport and explain how it is played.",
        "Create a tagline for a company selling eco-friendly shoes.",
    ],
    "classification_identification": [
        "Decide whether this email is spam: 'Congratulations! You’ve won a $1000 gift card. Click here to claim.'",
        "Classify the tone of this sentence: 'I can't believe you did that!'",
        "Identify whether the following sentence is fact or opinion: 'Chocolate ice cream is the best flavor.'",
    ],
    "multi_step_logic": [
        "You have 3 red balls and 2 blue balls in a bag. If you take two balls without looking, what is the probability they are both red?",
        "Solve this riddle: 'I speak without a mouth and hear without ears. I have no body, but I come alive with wind. What am I?'",
        "If Alice is older than Bob, and Bob is older than Carol, who is the youngest?",
    ],
}


In [None]:
# For every category, enumerate all (strategy, instruction template, query text)
# triples. Queries vary fastest, then strategies, then templates.
# The query is stored as raw text; it is NOT tokenized in this cell.
pairs = {
    category: [
        (strategy_name, template, question)
        for template in prompt_list
        for strategy_name in strategies
        for question in questions
    ]
    for category, questions in query_set.items()
}

In [10]:
# Print the query text of every triple, in iteration order, space-separated on one line.
print(*(question for triples in pairs.values() for _, _, question in triples))

Determine whether 121 is a prime number. Find the area of a triangle with base 8 cm and height 5 cm. If a train travels 60 km in 1.5 hours, what is its average speed? Determine whether 121 is a prime number. Find the area of a triangle with base 8 cm and height 5 cm. If a train travels 60 km in 1.5 hours, what is its average speed? Determine whether 121 is a prime number. Find the area of a triangle with base 8 cm and height 5 cm. If a train travels 60 km in 1.5 hours, what is its average speed? Determine whether 121 is a prime number. Find the area of a triangle with base 8 cm and height 5 cm. If a train travels 60 km in 1.5 hours, what is its average speed? Determine whether 121 is a prime number. Find the area of a triangle with base 8 cm and height 5 cm. If a train travels 60 km in 1.5 hours, what is its average speed? Determine whether 121 is a prime number. Find the area of a triangle with base 8 cm and height 5 cm. If a train travels 60 km in 1.5 hours, what is its average speed

In [None]:
# Generate one self-prompt per (strategy, template, query) triple across all categories.
# `pairs` stores each query as raw text, so it has to be tokenized here before it
# can be indexed for 'input_ids' / 'attention_mask'; indexing the string directly
# raises TypeError.
out = []
for vals in pairs.values():
    for strat, prompt_inst, query_text in vals:
        encoded_query = tokenizer(
            query_text,
            return_tensors='pt',
            add_special_tokens=False
        ).to(model.device)
        out.append(
            gen_prompt_from_query(
                model,
                tokenizer,
                encoded_query['input_ids'],
                encoded_query['attention_mask'],
                max_new_tokens=256,
                instruction=prompt_inst.format(strategy=strat),
            )
        )