In [2]:
import numpy as np
import pandas as pd

In [18]:
# Load the GPT-generated sentences; later cells use the columns
# "input", "rhymes", "fruit" and "label".
rhyme_data = pd.read_json(path_or_buf="data/rhyme_data_gpt_generated.json")

In [19]:
# Count how many rows fall into each observed (rhymes, fruit, label)
# combination; the result is a flat frame with a "count" column.
breakdown = (
    rhyme_data
    .groupby(['rhymes', 'fruit', 'label'])
    .size()
    .rename('count')
    .reset_index()
)

print(breakdown)

   rhymes  fruit  label  count
0       0      0      0     25
1       0      1      0     25
2       1      0      0     25
3       1      1      1     25


In [34]:
def generate_dataset(base_df, n_examples=6, label="fruit", random_state=None):
    """Build one balanced few-shot classification prompt per row of ``base_df``.

    For every row, ``n_examples // 2`` rows with ``fruit == 1`` and
    ``n_examples // 2`` rows with ``fruit == 0`` are sampled from the *other*
    rows (the query row is excluded), shuffled, and rendered as
    ``"Input: <sentence>, Label: <True|False>\\n"`` lines. The query row is then
    appended as ``"Input: <sentence>, Label:"`` for the model to complete.
    Sentences that do not already end in ``"."`` get one appended.

    Parameters
    ----------
    base_df : pd.DataFrame
        Must contain the columns ``"input"``, ``"fruit"``, ``"rhymes"`` and the
        column named by ``label``.
    n_examples : int
        Total number of few-shot examples. Each class contributes
        ``n_examples // 2``, so an odd value is effectively rounded down.
    label : str
        Column whose truthiness decides the ``True``/``False`` label text of the
        few-shot examples. Class balancing is always done on ``"fruit"``,
        independently of this column.
    random_state : int, numpy integer, or anything ``DataFrame.sample`` accepts
        Controls all sampling. ``None`` (default) keeps the previous behaviour of
        drawing from NumPy's global RNG, so ``np.random.seed`` still applies. An
        integer seeds one RNG shared by every row, making the whole dataset
        reproducible while each row still gets different examples.

    Returns
    -------
    pd.DataFrame
        One row per input row, with columns ``"prompt"``, ``"rhymes"`` and
        ``"fruit"``.

    Raises
    ------
    ValueError
        From ``DataFrame.sample`` if a fruit class has fewer than
        ``n_examples // 2`` rows left once the query row is excluded.
    """
    if isinstance(random_state, (int, np.integer)):
        # One RNG for the whole dataset. Passing the bare int to every .sample
        # call would reseed identically each time and repeat the same picks.
        random_state = np.random.RandomState(random_state)

    def _as_sentence(text):
        # endswith() rather than text[-1] so an empty input cannot raise IndexError.
        return text if text.endswith(".") else text + "."

    half = n_examples // 2
    rows = []
    for idx, row in base_df.iterrows():
        # Exclude the query row so it never appears among its own examples.
        example_rows = base_df.drop(idx)
        true_rows = example_rows[example_rows["fruit"] == 1].sample(n=half, random_state=random_state)
        false_rows = example_rows[example_rows["fruit"] == 0].sample(n=half, random_state=random_state)
        # Shuffle so the examples are not grouped by class.
        example_rows = pd.concat((true_rows, false_rows)).sample(frac=1., random_state=random_state)

        few_shot_prompt = "".join(
            f"Input: {_as_sentence(example['input'])}, Label: {'True' if example[label] else 'False'}\n"
            for _, example in example_rows.iterrows()
        )
        user_prompt = f"{few_shot_prompt}Input: {_as_sentence(row['input'])}, Label:"
        rows.append({
            "prompt": user_prompt,
            "rhymes": row["rhymes"],
            "fruit": row["fruit"],
        })

    return pd.DataFrame(rows)

In [41]:
def generate_biased_dataset(
        base_df, n_examples=6, label="fruit", fruit_rows_rhyme=1, non_fruit_rows_rhyme=0,
        n_fruit_rhymes=4, n_fruit_non_rhymes=1, n_non_fruit=5, random_state=None
    ):
    """Build one few-shot prompt per row whose examples correlate rhyming with fruit.

    For every row of ``base_df``, the few-shot examples are drawn from the *other*
    rows (the query row is excluded):

    * ``n_fruit_rhymes`` rows with ``fruit == 1`` and ``rhymes == 1``,
    * ``n_fruit_non_rhymes`` rows with ``fruit == 1`` and ``rhymes == 0``,
    * ``n_non_fruit`` rows with ``fruit == 0`` and ``rhymes == non_fruit_rows_rhyme``.

    They are shuffled and rendered as ``"Input: <sentence>, Label: <True|False>\\n"``
    lines, followed by ``"Input: <sentence>, Label:"`` for the query row.
    Sentences that do not already end in ``"."`` get one appended.

    Parameters
    ----------
    base_df : pd.DataFrame
        Must contain ``"input"``, ``"fruit"``, ``"rhymes"`` and the ``label`` column.
    n_examples : int
        Unused; kept for backward compatibility. The real example count is
        ``n_fruit_rhymes + n_fruit_non_rhymes + n_non_fruit`` (10 with the defaults).
    label : str
        Column whose truthiness decides each example's ``True``/``False`` text.
    fruit_rows_rhyme : int
        Unused; kept for backward compatibility. Fruit examples are split by
        ``n_fruit_rhymes`` / ``n_fruit_non_rhymes`` instead.
    non_fruit_rows_rhyme : int
        Required ``rhymes`` value for the non-fruit examples.
    n_fruit_rhymes, n_fruit_non_rhymes, n_non_fruit : int
        Number of examples drawn from each of the three groups above.
    random_state : int, numpy integer, or anything ``DataFrame.sample`` accepts
        Controls all sampling. ``None`` (default) keeps the previous behaviour of
        drawing from NumPy's global RNG. An integer seeds one RNG shared by every
        row, making the dataset reproducible.

    Returns
    -------
    pd.DataFrame
        One row per input row, with columns ``"prompt"``, ``"rhymes"`` and ``"fruit"``.

    Raises
    ------
    ValueError
        From ``DataFrame.sample`` if any group has fewer rows than requested once
        the query row is excluded.
    """
    if isinstance(random_state, (int, np.integer)):
        # One RNG for the whole dataset. Passing the bare int to every .sample
        # call would reseed identically each time and repeat the same picks.
        random_state = np.random.RandomState(random_state)

    def _as_sentence(text):
        # endswith() rather than text[-1] so an empty input cannot raise IndexError.
        return text if text.endswith(".") else text + "."

    rows = []
    for idx, row in base_df.iterrows():
        # Exclude the query row so it never appears among its own examples.
        example_rows = base_df.drop(idx)
        is_fruit = example_rows["fruit"] == 1
        is_non_fruit = example_rows["fruit"] == 0
        rhymes = example_rows["rhymes"]
        # Sampling order matches the original so a seeded global RNG yields the
        # same draws as before.
        fruit_rhymes = example_rows[is_fruit & (rhymes == 1)].sample(
            n=n_fruit_rhymes, random_state=random_state)
        fruit_non_rhymes = example_rows[is_fruit & (rhymes == 0)].sample(
            n=n_fruit_non_rhymes, random_state=random_state)
        non_fruit_rows = example_rows[is_non_fruit & (rhymes == non_fruit_rows_rhyme)].sample(
            n=n_non_fruit, random_state=random_state)

        # Shuffle so the examples are not grouped by category.
        example_rows = pd.concat((fruit_rhymes, fruit_non_rhymes, non_fruit_rows)).sample(
            frac=1., random_state=random_state)

        few_shot_prompt = "".join(
            f"Input: {_as_sentence(example['input'])}, Label: {'True' if example[label] else 'False'}\n"
            for _, example in example_rows.iterrows()
        )
        user_prompt = f"{few_shot_prompt}Input: {_as_sentence(row['input'])}, Label:"
        rows.append({
            "prompt": user_prompt,
            "rhymes": row["rhymes"],
            "fruit": row["fruit"],
        })

    return pd.DataFrame(rows)

In [35]:
# Write three balanced evaluation sets with 10 few-shot examples each.
# NumPy's global RNG is seeded per file (DataFrame.sample draws from it when no
# random_state is passed), so each CSV can be regenerated exactly.
for n_dataset in range(1, 4):
    np.random.seed(n_dataset)
    eval_data_rhymes = generate_dataset(rhyme_data, n_examples=10)
    eval_data_rhymes.to_csv(f"data/rhyme_eval_data_{n_dataset:02d}.csv", index=False)

In [42]:
# Write three biased evaluation sets: each prompt has 4 rhyming fruit examples,
# 1 non-rhyming fruit example, and (by default) 5 non-fruit examples.
# NumPy's global RNG is seeded per file with its own offset (DataFrame.sample
# draws from it when no random_state is passed), so each CSV is reproducible.
for n_dataset in range(1, 4):
    np.random.seed(100 + n_dataset)
    eval_data_rhymes = generate_biased_dataset(rhyme_data, n_examples=10, n_fruit_rhymes=4, n_fruit_non_rhymes=1)
    eval_data_rhymes.to_csv(f"data/rhyme_eval_data_biased_4rhymes_{n_dataset:02d}.csv", index=False)