In [1]:
import yaml
from pathlib import Path
import pandas as pd

from typing import List

In [2]:
def build_prompts(config, train_x, train_y, test_x):
    """Build one few-shot prompt per test sample.

    Every prompt shares the same demonstrations (``train_x`` paired with
    ``train_y``) and the same formatting fields read from ``config``; only
    the final, unanswered sample differs between prompts.

    Args:
        config: Mapping providing the "introduction", "sample_prefix",
            "sample_suffix" and "label_prefix" entries.
        train_x: Demonstration samples, forwarded as ``train_samples``.
        train_y: Demonstration labels, forwarded as ``labels``.
        test_x: Iterable of samples to build prompts for.

    Returns:
        List of prompt strings, in the same order as ``test_x``.
    """
    format_keys = ("introduction", "sample_prefix", "sample_suffix", "label_prefix")
    return [
        create_prompt_string(
            test_sample=sample,
            train_samples=train_x,
            labels=train_y,
            # Looked up per sample, so an empty test set never touches config.
            **{key: config[key] for key in format_keys},
        )
        for sample in test_x
    ]

def create_prompt_string(
    test_sample: str,
    introduction: str = None,
    sample_prefix: str = None,
    sample_suffix: str = None,
    label_prefix: str = None,
    train_samples: List[str] = None,
    labels: List[str] = None,
) -> str:
    """Creates prompt string from prompt elements.

    Prompts have the following format

    <introduction>

    <sample_prefix><train_sample1><sample_suffix>
    <label_prefix><label>

    <sample_prefix><train_sample2><sample_suffix>
    <label_prefix><label>

    ...

    <sample_prefix><test_sample><sample_suffix>
    <label_prefix>

    Args:
        test_sample: Sample placed last in the prompt, left without a label.
        introduction: Optional text placed first, followed by a blank line.
        sample_prefix: Optional text inserted before every sample.
        sample_suffix: Optional text inserted after every sample.
        label_prefix: Optional text inserted before every label and after
            the test sample.
        train_samples: Demonstration samples. ``None`` means no examples.
        labels: Labels paired positionally with ``train_samples``. If the two
            differ in length, the extra items of the longer one are ignored.
            ``None`` means no examples.

    Returns:
        The assembled prompt with leading and trailing whitespace removed.
        In particular a trailing space in ``label_prefix`` (e.g. ``"A: "``)
        does not appear at the end of the prompt.
    """
    # Use None sentinels instead of shared mutable ``[]`` defaults. Compare
    # with ``is None``: a pandas Series cannot be evaluated for truthiness.
    if train_samples is None:
        train_samples = []
    if labels is None:
        labels = []

    # Absent formatting fields contribute nothing to the prompt.
    prefix = "" if sample_prefix is None else sample_prefix
    suffix = "" if sample_suffix is None else sample_suffix
    answer = "" if label_prefix is None else label_prefix

    # Collect the pieces and join once instead of growing a string with +=.
    parts = []
    if introduction is not None:
        parts.append(introduction + "\n\n")

    # Add train samples
    for x, y in zip(train_samples, labels):
        parts.append(prefix + x + suffix + "\n" + answer + y + "\n\n")

    # Add test sample, ending on the label prefix so the answer can follow
    parts.append(prefix + test_sample + suffix + "\n" + answer)

    # Surrounding whitespace doesn't work well in prompts, so strip both ends
    return "".join(parts).strip()

In [19]:
# Load the train/test splits of the rule_qa task (tab-separated files).
folder = Path("tasks/") / "rule_qa"
train = pd.read_csv(folder / "train.tsv", sep="\t")
test = pd.read_csv(folder / "test.tsv", sep="\t")


# Read the prompt-formatting fields that build_prompts looks up in `config`.
config_fpath = Path("prompts/rule_qa_base.yaml")
# Decode explicitly as UTF-8 so parsing does not depend on the platform locale.
with open(config_fpath, "r", encoding="utf-8") as stream:
    config = yaml.safe_load(stream)
config

{'introduction': None,
 'label_prefix': 'A: ',
 'labels': 'default',
 'sample_prefix': 'Q: ',
 'sample_suffix': None,
 'task': 'rule_recall'}

In [20]:
# Inspect the training frame; its "text" and "label" columns are passed to build_prompts below.
train

Unnamed: 0,text,label


In [21]:
# The single "test sample" is the literal placeholder "{{text}}", so the one
# resulting prompt is a template with the placeholder in the final question slot.
prompts = build_prompts(
    config=config,
    train_x=train["text"],
    train_y=train["label"],
    test_x=["{{text}}"],
)

In [22]:
# print() renders the prompt's newlines as real line breaks, unlike the bare repr.
print(prompts[0])

Q: {{text}}
A:
