In [1]:
from collections import defaultdict
import json
from queue import Queue, Empty
import random
from threading import Event, Thread
import time
from typing import Callable, Dict, List
from src.utils import (
    labels,
    DTOEncoder,
    task_2_categories,
    default_data_path,
    default_top_k_keywords,
)
from src.extraction import (
    extract_instance_info_LaMP_1,
    extract_instance_info_LaMP_2_alt,
    extract_instance_info_LaMP_1_alt,
    extract_instance_info_LaMP_2_alt_no_keyword,
)
from src.tokenization import lemma_tokenizer
import os
from threading import Lock
from rich.progress import Progress, TaskID

# Seed Python's global RNG once so label sampling is reproducible on a fresh kernel.
random.seed(a=0)

In [2]:
import nltk
from nltk.corpus import wordnet as wn

# Touch WordNet once here, in the main thread, so NLTK's lazy corpus loader is
# initialised before any worker thread uses it. Presumably this is to dodge the
# loader's first-access race under threads — TODO confirm that is the intent.
next(iter(wn.words()))

'.22-caliber'

In [3]:
# Module-level progress denominators read by the extraction worker functions.
# `total` is the number of question instances fed to the queue and
# `total_valid` is the number of selected ids; `selector` resets both per run.
total = total_valid = 0

In [4]:
def LaMP_1_extract_func(
    prompt_dict,
    prog: Progress,
    task: TaskID,
    task_valid: TaskID,
    id_list: List[str],
    pipeline: Queue,
    msg_pipe: Queue,
    with_keyword_extraction: bool = True,
    bm25_top_k: int = default_top_k_keywords,
    text_rank_top_k_keywords: int = default_top_k_keywords,
    context_top_k_keywords: int = default_top_k_keywords,
):
    """Worker loop that builds LaMP_1 prompts for the selected instances.

    Pulls instances from ``pipeline`` until it receives the sentinel
    ``{"id": "finished"}``. For every instance whose ``"id"`` is in
    ``id_list``, the result of ``extract_instance_info_LaMP_1_alt`` is stored
    in ``prompt_dict`` under that id.

    Progress accounting relies on the module-level ``total`` and
    ``total_valid`` counters, which must be set before instances are queued.
    ``task`` advances by ``100 / total`` per instance and ``task_valid``
    advances by ``100 / total_valid`` per selected instance.

    Args:
        prompt_dict: Shared output mapping, id -> extracted prompt info.
        prog: Progress display that is updated for every instance.
        task: Progress task counting all instances.
        task_valid: Progress task counting only selected instances.
        id_list: Ids to extract; every other instance is skipped.
        pipeline: Queue of instance dicts, terminated by the sentinel.
        msg_pipe: Currently unused; kept for interface compatibility.
        with_keyword_extraction, bm25_top_k, text_rank_top_k_keywords,
        context_top_k_keywords: Forwarded to ``extract_instance_info_LaMP_1_alt``.
    """
    # Build the set once: `in` on the list is a linear scan that would
    # otherwise repeat for every queued instance.
    wanted_ids = set(id_list)
    while True:
        instance: Dict[str, str] = pipeline.get()
        instance_id = instance["id"]
        if instance_id == "finished":
            break
        if instance_id in wanted_ids:
            prompt_dict[instance_id] = extract_instance_info_LaMP_1_alt(
                question=instance,
                tokenizer=lemma_tokenizer,
                keyword_extraction=with_keyword_extraction,
                bm25_top_k=bm25_top_k,
                text_rank_top_k_keywords=text_rank_top_k_keywords,
                context_top_k_keywords=context_top_k_keywords,
            )
            prog.update(task_valid, advance=100 / total_valid)
        prog.update(task, advance=100 / total)

In [5]:
def LaMP_2_extract_func(
    prompt_dict,
    prog: Progress,
    task: TaskID,
    task_valid: TaskID,
    id_list: List[str],
    pipeline: Queue,
    msg_pipe: Queue,
    with_keyword_extraction: bool = True,
    category_top_k_keywords: int = default_top_k_keywords,
    text_rank_top_k_keywords: int = default_top_k_keywords,
    article_top_k: int = 1,
):
    """Worker loop that builds LaMP_2 prompts for the selected instances.

    Pulls instances from ``pipeline`` until it receives the sentinel
    ``{"id": "finished"}``. For every instance whose ``"id"`` is in
    ``id_list``, the result of ``extract_instance_info_LaMP_2_alt`` is stored
    in ``prompt_dict`` under that id.

    Progress accounting relies on the module-level ``total`` and
    ``total_valid`` counters, which must be set before instances are queued.
    ``task`` advances by ``100 / total`` per instance and ``task_valid``
    advances by ``100 / total_valid`` per selected instance.

    Args:
        prompt_dict: Shared output mapping, id -> extracted prompt info.
        prog: Progress display that is updated for every instance.
        task: Progress task counting all instances.
        task_valid: Progress task counting only selected instances.
        id_list: Ids to extract; every other instance is skipped.
        pipeline: Queue of instance dicts, terminated by the sentinel.
        msg_pipe: Currently unused; kept for interface compatibility.
        with_keyword_extraction, category_top_k_keywords,
        text_rank_top_k_keywords, article_top_k: Forwarded to
            ``extract_instance_info_LaMP_2_alt``.
    """
    # Build the set once: `in` on the list is a linear scan that would
    # otherwise repeat for every queued instance.
    wanted_ids = set(id_list)
    while True:
        instance: Dict[str, str] = pipeline.get()
        instance_id = instance["id"]
        if instance_id == "finished":
            break
        if instance_id in wanted_ids:
            prompt_dict[instance_id] = extract_instance_info_LaMP_2_alt(
                question=instance,
                tokenizer=lemma_tokenizer,
                keyword_extraction=with_keyword_extraction,
                category_top_k_keywords=category_top_k_keywords,
                text_rank_top_k_keywords=text_rank_top_k_keywords,
                article_top_k=article_top_k,
            )
            prog.update(task_valid, advance=100 / total_valid)
        prog.update(task, advance=100 / total)

In [6]:
def _sample_golds_per_category(golds, entry_per_category):
    """Group gold labels by their ``"output"`` value and keep at most
    ``entry_per_category`` labels per group, drawn without replacement."""
    category_map = defaultdict(list)
    for gold in golds:
        category_map[gold["output"]].append(gold)

    picked = []
    for category_golds in category_map.values():
        if len(category_golds) <= entry_per_category:
            picked.extend(category_golds)
        else:
            # Sample WITHOUT replacement: random.choices could pick the same
            # label more than once, duplicating it in the stored outputs and
            # leaving the "Only Valid" progress bar short of 100%.
            picked.extend(random.sample(category_golds, k=entry_per_category))
    return picked


def _parse_lamp_question_path(dataset_question_path):
    """Return ``(dataset_id, dataset_type)`` parsed from a file name such as
    ``LaMP_1_train_questions.json``, which gives ``("1", "train")``."""
    # Split the file name only, so underscores in parent directories
    # cannot shift the fields.
    file_name = os.path.basename(dataset_question_path)
    _, dataset_id, dataset_type, *_ = file_name.split("_")
    return dataset_id, dataset_type


def selector(
    dataset_question_path: str,
    dataset_output_path: str,
    entry_per_category: int = 5,
    worker_count: int = 5,
    with_keyword_extraction: bool = True,
    question_store_path: str = None,
    output_store_path: str = None,
    **params,
):
    """Sample gold labels per category and build prompts for them in parallel.

    Steps:
      1. Read the gold labels from ``dataset_output_path``. Keep up to
         ``entry_per_category`` labels per distinct ``"output"`` value and
         write them to ``output_store_path``.
      2. Feed every question instance from ``dataset_question_path`` to
         ``worker_count`` worker threads running the LaMP_1 or LaMP_2
         extraction function.
      3. Dump the resulting id -> prompt mapping to ``question_store_path``.

    Args:
        dataset_question_path: Questions JSON; its file name must look like
            ``LaMP_<id>_<type>_...``.
        dataset_output_path: Gold outputs JSON with ``"task"`` and ``"golds"``.
        entry_per_category: Maximum number of labels kept per category.
        worker_count: Number of extraction threads (must be >= 1).
        with_keyword_extraction: Forwarded to the worker; also selects the
            ``with_keyword`` / ``without_keyword`` tag in default file names.
        question_store_path: Destination for prompts. Defaults to a file in
            ``default_data_path``.
        output_store_path: Destination for selected labels. Defaults to
            ``dataset_output_path`` with its ``.json`` suffix replaced.
        **params: Extra keyword arguments forwarded to the worker function.

    Raises:
        ValueError: If ``worker_count`` is less than 1.
        NotImplementedError: If the dataset is neither LaMP_1 nor LaMP_2.
    """
    global total, total_valid

    if worker_count < 1:
        raise ValueError(f"worker_count must be >= 1, got {worker_count}")

    tag = "with_keyword" if with_keyword_extraction else "without_keyword"

    # Resolve the task first so an unsupported dataset fails before any
    # output file is written.
    dataset_id, dataset_type = _parse_lamp_question_path(dataset_question_path)
    if dataset_id == "1":
        target_func = LaMP_1_extract_func
    elif dataset_id == "2":
        target_func = LaMP_2_extract_func
    else:
        raise NotImplementedError("Can only deal with LaMP_1 and LaMP_2 task")

    with open(dataset_output_path, "r", encoding="utf-8") as output_file:
        outputs = json.load(output_file)
    selected_labels = labels(
        task=outputs["task"],
        golds=_sample_golds_per_category(outputs["golds"], entry_per_category),
    )

    if output_store_path is None:
        # str.rstrip(".json") strips a *character set*, e.g. it turned
        # "..._outputs.json" into "..._output"; remove the exact suffix instead.
        stem = dataset_output_path
        if stem.endswith(".json"):
            stem = stem[: -len(".json")]
        output_store_path = f"{stem}_{tag}_{entry_per_category}.json"
    with open(output_store_path, "w", encoding="utf-8") as new_output:
        json.dump(selected_labels, new_output, cls=DTOEncoder, indent=4)

    if question_store_path is None:
        question_store_path = os.path.join(
            default_data_path,
            f"LaMP_{dataset_id}_{dataset_type}_prompts_{tag}_{entry_per_category}.json",
        )

    with open(dataset_question_path, "r", encoding="utf-8") as question_file:
        instances = json.load(question_file)
    selected_ids = sorted((label.id for label in selected_labels.golds), key=int)

    # Progress denominators read by the worker functions. They are set before
    # any worker can pull an instance.
    total = len(instances)
    total_valid = len(selected_ids)

    with Progress() as prog:
        task = prog.add_task("Parse Prompts", total=100)
        task_2 = prog.add_task("Parse Prompts (Only Valid)", total=100)
        prompt_dict = dict()
        instances_queue = Queue()
        msgs_queue = Queue()
        threads: List[Thread] = []
        try:
            for _ in range(worker_count):
                worker = Thread(
                    target=target_func,
                    args=(
                        prompt_dict,
                        prog,
                        task,
                        task_2,
                        selected_ids,
                        instances_queue,
                        msgs_queue,
                        with_keyword_extraction,
                    ),
                    kwargs=params,
                )
                worker.start()
                threads.append(worker)
            for instance in instances:
                instances_queue.put(instance)
        finally:
            # Always release every started worker, even if the feed loop
            # fails; otherwise the non-daemon threads block forever on get().
            for _ in threads:
                instances_queue.put({"id": "finished"})

        join_task = prog.add_task("Exiting Threads", total=100)
        for worker in threads:
            worker.join()
        prog.update(join_task, advance=100)

        # Open the destination only once all prompts are built, so a failure
        # earlier on does not truncate a previous result file.
        save_task = prog.add_task("Saving Prompts", total=100)
        with open(question_store_path, "w", encoding="utf-8") as new_question:
            json.dump(prompt_dict, new_question, indent=4)
        prog.update(save_task, advance=100)

In [7]:
from math import ceil


def task_1_selector(
    entry_per_category: int = 120,
    worker_count: int = 16,
    seed: int = 0,
):
    """Generate LaMP_1 prompt and selected-output files for every configuration.

    The configurations are every (text_rank_top_k_keywords, bm25_top_k) pair
    from [5, 8, 10] x [5, 10] with keyword extraction, and bm25_top_k in
    [2, 4] without it. Files are written under ``src/data/LaMP_1``.

    Args:
        entry_per_category: Labels sampled per category by ``selector``.
        worker_count: Extraction threads per ``selector`` call.
        seed: The global ``random`` module is re-seeded with this value before
            every configuration, so all configurations use the same label
            subset regardless of what ran earlier in the kernel. Pass ``None``
            to skip re-seeding.
    """
    dataset_question_path = "./src/data/LaMP_1_train_questions.json"
    dataset_output_path = "./src/data/LaMP_1_train_outputs.json"
    task_header = "LaMP_1"
    store_dir = os.path.join("src", "data", task_header)
    os.makedirs(store_dir, exist_ok=True)

    # Each entry is (file-name suffix, extra keyword arguments for selector).
    configs = [
        (
            f"with_keyword_{entry_per_category}_{text_rank_top_k_keywords}_{bm25_top_k}",
            dict(
                with_keyword_extraction=True,
                bm25_top_k=bm25_top_k,
                text_rank_top_k_keywords=text_rank_top_k_keywords,
                context_top_k_keywords=ceil(
                    (text_rank_top_k_keywords * bm25_top_k) / 2
                ),
            ),
        )
        for text_rank_top_k_keywords in [5, 8, 10]
        for bm25_top_k in [5, 10]
    ]
    configs += [
        (
            f"without_keyword_{entry_per_category}_{bm25_top_k}",
            dict(with_keyword_extraction=False, bm25_top_k=bm25_top_k),
        )
        for bm25_top_k in [2, 4]
    ]

    for suffix, extra_kwargs in configs:
        if seed is not None:
            # Without this, every configuration would draw a different label
            # subset from the shared global RNG, confounding comparisons.
            random.seed(seed)
        selector(
            dataset_question_path=dataset_question_path,
            dataset_output_path=dataset_output_path,
            question_store_path=os.path.join(
                store_dir, f"{task_header}_train_prompts_questions_{suffix}.json"
            ),
            output_store_path=os.path.join(
                store_dir, f"{task_header}_train_outputs_selected_{suffix}.json"
            ),
            entry_per_category=entry_per_category,
            worker_count=worker_count,
            **extra_kwargs,
        )

In [8]:
def task_2_selector(
    entry_per_category: int = 16,
    worker_count: int = 16,
    seed: int = 0,
):
    """Generate LaMP_2 prompt and selected-output files for every configuration.

    The configurations are every (text_rank_top_k_keywords,
    category_top_k_keywords) pair from [5, 8, 10] x [15, 30] with keyword
    extraction, and article_top_k in [2, 4] without it. Files are written
    under ``src/data/LaMP_2``.

    Args:
        entry_per_category: Labels sampled per category by ``selector``.
        worker_count: Extraction threads per ``selector`` call.
        seed: The global ``random`` module is re-seeded with this value before
            every configuration, so all configurations use the same label
            subset regardless of what ran earlier in the kernel. Pass ``None``
            to skip re-seeding.
    """
    task_header = "LaMP_2"
    store_dir = os.path.join("src", "data", task_header)
    os.makedirs(store_dir, exist_ok=True)
    dataset_question_path = "./src/data/LaMP_2_train_questions.json"
    dataset_output_path = "./src/data/LaMP_2_train_outputs.json"

    # Each entry is (file-name suffix, extra keyword arguments for selector).
    configs = [
        (
            f"with_keyword_{entry_per_category}_{text_rank_top_k_keywords}_{category_top_k_keywords}",
            dict(
                with_keyword_extraction=True,
                text_rank_top_k_keywords=text_rank_top_k_keywords,
                category_top_k_keywords=category_top_k_keywords,
            ),
        )
        for text_rank_top_k_keywords in [5, 8, 10]
        for category_top_k_keywords in [15, 30]
    ]
    configs += [
        (
            f"without_keyword_{entry_per_category}_{article_top_k}",
            dict(with_keyword_extraction=False, article_top_k=article_top_k),
        )
        for article_top_k in [2, 4]
    ]

    for suffix, extra_kwargs in configs:
        if seed is not None:
            # Without this, every configuration would draw a different label
            # subset from the shared global RNG, confounding comparisons.
            random.seed(seed)
        selector(
            dataset_question_path=dataset_question_path,
            dataset_output_path=dataset_output_path,
            question_store_path=os.path.join(
                store_dir, f"{task_header}_train_prompts_questions_{suffix}.json"
            ),
            output_store_path=os.path.join(
                store_dir, f"{task_header}_train_outputs_selected_{suffix}.json"
            ),
            entry_per_category=entry_per_category,
            worker_count=worker_count,
            **extra_kwargs,
        )

In [13]:
# Regenerate every LaMP_1 configuration: task_1_selector writes one prompt file
# and one selected-outputs file per configuration under src/data/LaMP_1.
task_1_selector()

Output()

Output()

In [9]:
# Regenerate every LaMP_2 configuration: task_2_selector writes one prompt file
# and one selected-outputs file per configuration under src/data/LaMP_2.
task_2_selector()

Output()

Output()