In [1]:
from datasets import load_dataset
from langchain_core.prompts import PromptTemplate

import sys
sys.path.append('..')
from settings import TEXT_CLS_NAMES, SAMPLE_NUM

In [2]:
ds = load_dataset("fancyzhx/ag_news")

In [28]:
prompt_template = PromptTemplate.from_template(
    """
    Classify the given text into exactly one category, providing clear reasoning for your choice. The available categories are: World, Sports, Business, Science/Technology.
    Text to classify:
    {input}
    Output the classification reasoning in reason and the selected category in label. Return the response in JSON format as follows:
    ```json
    {{
        "reason" : "Classification rationale",
        "label" : "Text classification label (select exactly one): ["World", "Sports", "Business", "Science/Technology"]"
    }}
    ```
    """
)

def func(item):
    label = item["label"]
    item["label_text"] = TEXT_CLS_NAMES[label]
    return item


test_dataset = ds["test"].map(func)
new_test_dataset = test_dataset.train_test_split(seed=42, test_size=SAMPLE_NUM)["test"]

In [29]:
new_test_dataset

Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 1000
})

In [30]:
prompt_template

PromptTemplate(input_variables=['input'], template='\n    Classify the given text into exactly one category, providing clear reasoning for your choice. The available categories are: World, Sports, Business, Science/Technology.\n    Text to classify:\n    {input}\n    Output the classification reasoning in reason and the selected category in label. Return the response in JSON format as follows:\n    ```json\n    {{\n        "reason" : "Classification rationale",\n        "label" : "Text classification label (select exactly one): ["World", "Sports", "Business", "Science/Technology"]"\n    }}\n    ```\n    ')

In [31]:
def add_prompt(item):
    global prompt_template
    text = item["text"]
    text = text.replace("{", "{{")
    text = text.replace("}", "}}")
    prompt = prompt_template.invoke({"input": text}).to_string()
    item["prompt"] = prompt
    return item

In [33]:
new_test_dataset = new_test_dataset.map(add_prompt)

In [35]:
print(new_test_dataset[0]["prompt"])


    Classify the given text into exactly one category, providing clear reasoning for your choice. The available categories are: World, Sports, Business, Science/Technology.
    Text to classify:
    Indian board plans own telecast of Australia series The Indian cricket board said on Wednesday it was making arrangements on its own to broadcast next month #39;s test series against Australia, which is under threat because of a raging TV rights dispute.
    Output the classification reasoning in reason and the selected category in label. Return the response in JSON format as follows:
    ```json
    {
        "reason" : "Classification rationale",
        "label" : "Text classification label (select exactly one): ["World", "Sports", "Business", "Science/Technology"]"
    }
    ```
    


In [36]:
new_test_dataset[0]

{'text': 'Indian board plans own telecast of Australia series The Indian cricket board said on Wednesday it was making arrangements on its own to broadcast next month #39;s test series against Australia, which is under threat because of a raging TV rights dispute.',
 'label': 1,
 'label_text': 'Sports',
 'prompt': '\n    Classify the given text into exactly one category, providing clear reasoning for your choice. The available categories are: World, Sports, Business, Science/Technology.\n    Text to classify:\n    Indian board plans own telecast of Australia series The Indian cricket board said on Wednesday it was making arrangements on its own to broadcast next month #39;s test series against Australia, which is under threat because of a raging TV rights dispute.\n    Output the classification reasoning in reason and the selected category in label. Return the response in JSON format as follows:\n    ```json\n    {\n        "reason" : "Classification rationale",\n        "label" : "Tex

In [38]:
new_test_dataset

Dataset({
    features: ['text', 'label', 'label_text', 'prompt'],
    num_rows: 1000
})

In [37]:
# new_test_dataset.to_json("data/ag_news_test.jsonl")

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

1104310

In [48]:
new_test_dataset

Dataset({
    features: ['text', 'label', 'label_text', 'prompt'],
    num_rows: 1000
})

In [51]:
output_dataset = new_test_dataset.add_column(
    name="instruction", column=[""] * len(new_test_dataset)
)
output_dataset = output_dataset.remove_columns(["text", "label"])
output_dataset = output_dataset.rename_columns(
    {"prompt": "input", "label_text": "output"},
)
output_dataset

Dataset({
    features: ['output', 'input', 'instruction'],
    num_rows: 1000
})

In [52]:
output_dataset.to_json("data/ag_news_test.jsonl")

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

857841