In [1]:
from collections import defaultdict
import json
from queue import Queue, Empty
import random
from threading import Event, Thread
import time
from typing import Callable, Dict, List
from src.utils import labels, DTOEncoder, task_2_categories, default_data_path
from src.extraction import extract_instance_info_LaMP_2_alt, extract_instance_info_LaMP_2_alt_no_keyword
from src.tokenization import lemma_tokenizer


# Seed the stdlib RNG so the random per-category pick made in `selector`
# is repeatable from one Restart & Run All to the next.
random.seed(0)

In [2]:
# Sanity check: how many entries task_2_categories holds
# (presumably the LaMP-2 category labels -- confirm in src.utils).
len(task_2_categories)

15

In [3]:
import nltk
from nltk.corpus import wordnet as wn

# Smoke test for the WordNet corpus: fetch the first word it yields.
next(wn.words())

'.22-caliber'

In [4]:
import os
from threading import Lock
from rich.progress import Progress, TaskID

# prompt_dict: Dict[str, str] = dict()
# (kept for reference only: `selector` builds its own local prompt_dict.)

# Module-level state shared between `selector` and the worker threads.
# `counter_lock` is named in the worker's `global` statement but is never
# acquired by the code in this cell.
counter_lock = Lock()
# Number of question instances; assigned by `selector`, read by the workers
# to scale each "Parse Prompts" progress step.
total = 0
# Number of selected ids; assigned by `selector`, read by the workers to
# scale each "Parse Prompts (Only Valid)" progress step.
total_valid = 0
# Not updated by any code in this cell.
counter = 0


def LaMP_2_extract_func(
    prompt_dict,
    prog: Progress,
    task: TaskID,
    task_valid: TaskID,
    id_list: List[str],
    pipeline: Queue,
    msg_pipe: Queue,
    with_keyword_extraction: bool = True
):
    """Worker loop: drain ``pipeline`` and extract prompts for the selected ids.

    Blocks on ``pipeline.get()`` until an item whose ``"id"`` equals
    ``"finished"`` arrives, then returns. Every other item advances the
    ``task`` progress bar; items whose id is in ``id_list`` are additionally
    run through ``extract_instance_info_LaMP_2_alt`` and stored in
    ``prompt_dict`` under their id, advancing ``task_valid`` as well.

    Args:
        prompt_dict: Mapping shared by all workers; receives ``id -> extracted info``.
        prog: Progress display to update.
        task: Progress task advanced once per consumed instance.
        task_valid: Progress task advanced once per instance found in ``id_list``.
        id_list: Ids of the instances that should be extracted.
        pipeline: Queue of instance dicts, terminated by ``{"id": "finished"}``.
        msg_pipe: Unused here; kept so existing callers' argument lists still match.
        with_keyword_extraction: Forwarded as ``keyword_extraction`` to the extractor.

    Progress steps are sized from the module-level ``total`` and
    ``total_valid``, so the caller must set both before enqueueing work.
    """
    # Hash the wanted ids once: each membership test is then O(1) instead of
    # a linear scan of the list for every instance in the dataset.
    wanted_ids = frozenset(id_list)

    while True:
        instance: Dict[str, str] = pipeline.get()
        instance_id = instance["id"]
        if instance_id == "finished":
            # Sentinel: one is queued per worker, so each worker consumes
            # exactly one and exits.
            break
        if instance_id in wanted_ids:
            prompt_dict[instance_id] = extract_instance_info_LaMP_2_alt(
                question=instance,
                tokenizer=lemma_tokenizer,
                keyword_extraction=with_keyword_extraction,
            )
            prog.update(task_valid, advance=100 / total_valid)
        prog.update(task, advance=100 / total)


def _sample_labels_per_category(category_map, entry_per_category):
    """Pick at most ``entry_per_category`` distinct gold labels from each category."""
    picked = []
    for doc_labels in category_map.values():
        if len(doc_labels) <= entry_per_category:
            # Small category: keep everything it has.
            picked.extend(doc_labels)
        else:
            # random.sample draws WITHOUT replacement. random.choices (used
            # previously) could return the same label several times, which
            # inflated the id count and left the "valid" progress bar short
            # of 100%.
            picked.extend(random.sample(doc_labels, k=entry_per_category))
    return picked


def _strip_json_suffix(path):
    """Return ``path`` without a trailing ``.json`` (unchanged if it has none)."""
    suffix = ".json"
    return path[: -len(suffix)] if path.endswith(suffix) else path


def selector(
    dataset_question_path: str,
    dataset_output_path: str,
    entry_per_category: int = 5,
    worker_count: int = 5,
    with_keyword_extraction: bool = True,
):
    """Build a per-category subset of a LaMP dataset and extract its prompts.

    Steps:
      1. Read the gold labels from ``dataset_output_path`` (a JSON object with
         a ``"task"`` field and a ``"golds"`` list whose items have an
         ``"output"`` category) and keep at most ``entry_per_category``
         distinct labels per category.
      2. Write that subset next to the input as
         ``<stem>_<entry_per_category>_<tag>.json``.
      3. Feed every instance of ``dataset_question_path`` to ``worker_count``
         threads running ``LaMP_2_extract_func``; only the selected ids are
         extracted.
      4. Write the extracted prompts to
         ``LaMP_<id>_<type>_prompts_<tag>_<entry_per_category>.json`` under
         ``default_data_path``.

    Args:
        dataset_question_path: Questions JSON; its file name must look like
            ``<prefix>_<id>_<type>_...`` because id and type are parsed from it.
        dataset_output_path: Gold-label JSON matching the questions file.
        entry_per_category: Maximum number of labels kept per category.
        worker_count: Number of extraction threads to start.
        with_keyword_extraction: Forwarded to the workers; also selects the
            ``with_keyword`` / ``without_keyword`` tag used in file names.

    Raises:
        ValueError: If the questions file name has fewer than three
            ``_``-separated parts.

    Side effects: assigns the module-level ``total`` and ``total_valid``,
    which the worker threads read to scale their progress updates.
    """
    global total, total_valid

    tag = "with_keyword" if with_keyword_extraction else "without_keyword"

    # --- 1. group gold labels by category and pick the subset --------------
    with open(dataset_output_path, "r", encoding="utf-8") as output:
        gold_file = json.load(output)
    category_map = defaultdict(list)
    for label in gold_file["golds"]:
        category_map[label["output"]].append(label)

    selected_labels = labels(
        task=gold_file["task"],
        golds=_sample_labels_per_category(category_map, entry_per_category),
    )

    # --- 2. persist the selected labels ------------------------------------
    # str.rstrip(".json") removes any trailing run of the characters
    # '.', 'j', 's', 'o', 'n' -- e.g. "..._outputs.json" lost the final "s"
    # of "outputs" -- so strip the exact suffix instead.
    new_output_path = (
        _strip_json_suffix(dataset_output_path) + f"_{entry_per_category}_{tag}.json"
    )
    with open(new_output_path, "w", encoding="utf-8") as new_output:
        json.dump(selected_labels, new_output, cls=DTOEncoder, indent=4)

    # Parse id/type from the file name only, so underscores in directory
    # names cannot shift the fields.
    question_file_name = os.path.basename(dataset_question_path)
    _, dataset_id, dataset_type, *_ = question_file_name.split("_")
    new_question_path = os.path.join(
        default_data_path,
        f"LaMP_{dataset_id}_{dataset_type}_prompts_{tag}_{entry_per_category}.json",
    )

    # --- 3. load the work list and publish the totals before any worker runs
    with open(dataset_question_path, "r", encoding="utf-8") as question:
        instances = json.load(question)

    selected_ids = sorted((label.id for label in selected_labels.golds), key=int)
    print(selected_ids)
    total = len(instances)
    total_valid = len(selected_ids)
    print(total, total_valid)

    prompt_dict = dict()
    with Progress() as prog:
        task = prog.add_task("Parse Prompts", total=100)
        task_2 = prog.add_task("Parse Prompts (Only Valid)", total=100)

        instances_queue = Queue()
        msgs_queue = Queue()
        threads: List[Thread] = [
            Thread(
                target=LaMP_2_extract_func,
                args=(
                    prompt_dict,
                    prog,
                    task,
                    task_2,
                    selected_ids,
                    instances_queue,
                    msgs_queue,
                    with_keyword_extraction,
                ),
            )
            for _ in range(worker_count)
        ]
        for worker in threads:
            worker.start()

        for instance in instances:
            instances_queue.put(instance)
        # One sentinel per worker so every thread sees exactly one and exits.
        for _ in threads:
            instances_queue.put({"id": "finished"})

        join_task = prog.add_task("Exiting Threads", total=100)
        for worker in threads:
            worker.join()
        prog.update(join_task, advance=100)

        # --- 4. write the prompts ------------------------------------------
        # The file is opened only now, so a failure during extraction no
        # longer leaves an empty/truncated prompts file behind.
        save_task = prog.add_task("Saving Prompts", total=100)
        with open(new_question_path, "w", encoding="utf-8") as new_question:
            json.dump(prompt_dict, new_question, indent=4)
        prog.update(save_task, advance=100)

In [6]:
# Build the reduced LaMP-2 training subset: at most 12 gold labels per
# category, 16 worker threads, keyword extraction turned off.
selector(
    dataset_question_path="./src/data/LaMP_2_train_questions.json",
    dataset_output_path="./src/data/LaMP_2_train_outputs.json",
    entry_per_category=12,
    worker_count=16,
    with_keyword_extraction=False
)

Output()