[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/GenBench/genbench_cbt/blob/backend_dev/notebooks/GenBenchTaskViewer.ipynb)


In [None]:
#@title  ⚠️ **Run me first ⚠️ :** Execute this cell to perform some setup tasks (shift-Enter).  This may take several minutes.

#@markdown If you'd like to load a different branch, replace the url below with the url of your branch. For example, if you want to interact with a task from a pull request or with your own task, replace `GenBench` with `<github-account>` and `main` with `<task_branch_name>` in the **genbench_branch** field below.


genbench_branch = "https://github.com/GenBench/genbench_cbt.git@main" #@param {type:"string"}
genbench_branch = genbench_branch.strip("/")
if ".git" not in genbench_branch:
  genbench_branch += ".git"
! pip install -q "genbench[dev] @ git+$genbench_branch" gradio


import os

from collections import defaultdict
from pathlib import Path

from genbench.utils.tasks import get_all_tasks_ids, get_all_task_metadata

def get_keywords_to_task_ids():
  """Index task ids by keyword.

  Reads the metadata of every task and, for each keyword listed under
  "keywords", records the task id. Subtasks (the optional "subtasks" mapping
  of a task) are indexed under their own keywords as `<task_id>:<subtask_id>`.
  Tasks and subtasks are visited in sorted id order.

  Returns:
    A `defaultdict(list)` mapping keyword -> list of task / subtask ids.
  """
  all_metadata = get_all_task_metadata()
  index = defaultdict(list)

  for task_id, task_meta in sorted(all_metadata.items()):
    for kw in task_meta["keywords"]:
      index[kw].append(task_id)

    # A task without a "subtasks" entry simply contributes no subtask ids.
    subtasks = task_meta.get("subtasks", {})
    for subtask_id, subtask_meta in sorted(subtasks.items()):
      qualified_id = f"{task_id}:{subtask_id}"
      for kw in subtask_meta["keywords"]:
        index[kw].append(qualified_id)

  return index

def get_most_recent_task_ids():
  """List task directory names, most recently modified first.

  Scans the directory that contains the `genbench.tasks` package, keeps the
  sub-directories whose name does not start with "__", and orders them by
  filesystem modification time (newest first).

  Returns:
    A list of directory names.
  """
  from genbench import tasks

  tasks_root = Path(tasks.__file__).parent
  candidates = [
    entry
    for entry in tasks_root.iterdir()
    if entry.is_dir() and not entry.name.startswith("__")
  ]
  newest_first = sorted(candidates, key=os.path.getmtime, reverse=True)
  return [entry.name for entry in newest_first]

# Module-level lookups computed once and shared by the UI code in both cells.
all_tasks_metadata = get_all_task_metadata()
most_recent_task_ids = get_most_recent_task_ids()
keywords_to_task_ids = get_keywords_to_task_ids()

import numpy as np
import gradio as gr

# Markdown rendered at the top of the task directory UI.
intro_text = """\
# GenBench Task Directory

Some nice description

"""

# Markdown template for a single task's summary, filled in with `.format()`.
# NOTE: the description placeholder is spelled `task_decription`; callers must
# pass that exact keyword.
task_info = """\
### {task_name} (`{task_id}`)
{task_decription}
#### Authors
{task_authors}
#### Keywords
{task_keywords}
"""

def add_task(task_id):
  """Add a collapsed accordion describing one task to the current Gradio context.

  Args:
    task_id: Either `<task_id>` or `<task_id>:<subtask_id>`. In the latter
      case the subtask's own metadata entry is used.
  """
  if ":" not in task_id:
    metadata = all_tasks_metadata[task_id]
  else:
    root_task_id, subtask_id = task_id.split(":")
    metadata = all_tasks_metadata[root_task_id]["subtasks"][subtask_id]

  with gr.Accordion(task_id, open=False):
    fields = dict(
        task_name=metadata["name"],
        task_id=task_id,
        task_decription=metadata["description"],
        task_authors=", ".join(metadata["authors"]),
        task_keywords=", ".join(metadata["keywords"]),
    )
    gr.Markdown(task_info.format(**fields))

# Task directory: one tab with the ten most recently modified tasks, followed
# by one tab per keyword (alphabetical) listing the tasks tagged with it.
with gr.Blocks() as demo:
  gr.Markdown(intro_text)

  with gr.Tab("Most Recent Tasks"):
    for task_id in most_recent_task_ids[:10]:
      add_task(task_id)

  for keyword, task_ids in sorted(keywords_to_task_ids.items()):
    with gr.Tab(keyword):
      for task_id in task_ids:
        add_task(task_id)

demo.launch()

In [None]:
#@title View Task
#@markdown ⚠️  Run this cell to launch the task viewer UI


import datasets
from random import choices
from genbench.api import PreparationStrategy
import gradio as gr
import numpy as np

def hide_elemnts(*elements):
  """Return a `{component: gr.update(visible=False)}` dict for the given components."""
  updates = {}
  for element in elements:
    updates[element] = gr.update(visible=False)
  return updates

def show_elemnts(*elements):
  """Return a `{component: gr.update(visible=True)}` dict for the given components."""
  updates = {}
  for element in elements:
    updates[element] = gr.update(visible=True)
  return updates


# Markdown template for one dataset example; filled in with `.format()` using
# the keywords `example_id`, `input` and `target`.
EXAMPLE_MARKDOWN_TEMPLATE = """\
**Example ID:** {example_id}
#### Input
```
{input}
```
#### Target
```
{target}
```
"""


# Number of example slots created in the UI, and therefore the maximum number
# of examples that can be displayed at once.
MAX_NUM_EXAMPLES = 10

with gr.Blocks(theme=gr.themes.Default(spacing_size="sm")) as demo:
  # Single-element list used as a mutable holder: `fetch_task` stores the
  # loaded task in slot 0 and `prepare_datasets` reads it back.
  task_obj = [None]
  gr.Markdown("# Task Viewer\nSome nice description")

  # --- Task selection row: id text box plus Load / Clear buttons. ---
  with gr.Row().style(equal_height=True):
    task_id = gr.Textbox(label="Task Id", placeholder="<task_id>:<sub_task_id>")
    with gr.Column(scale=0):
      fetch_btn = gr.Button("Load", scale=0, variant="primary")
      clear_btn = gr.Button("Clear",  scale=0)

  # Hidden until `fetch_task` has an error message to show.
  loading_error = gr.Textbox(show_label=False, visible=False, container=False)

  # Task summary (rendered from `task_info`), hidden until a task is loaded.
  with gr.Box(visible=False) as task_metdata_box:
    task_metadata = gr.Markdown()

  # --- Dataset preparation controls; the radio's choices are replaced by
  # `fetch_task` once the task is known. ---
  prep_strategy_radio = gr.Radio(
      ["Finetuning", "Prompt-based Testing"],
      label="Preparation Strategy",
      info="The strategy to prepare the dataset for generalisation evaluation",
      visible=False
  )
  num_shots_slider = gr.Slider(
      0, 20, value=0,
      step=1,
      label="Num Shots",
      info="Number of examplars in few-shot evaluation (0 means zero-shot learning)",
      visible=False
  )
  prep_btn = gr.Button("Prepare", variant="primary", visible=False)

  # Split selector; its choices are replaced by `render_dataset`.
  dataset_split_selector = gr.Radio(
      ["Test", "Validation", "Train"],
      label="Dataset Split",
      visible=False,
      interactive=False,
  )

  # --- Example selection row: explicit ids, or a random sample. ---
  with gr.Row():
    example_ids_input = gr.Textbox(
        label="Example IDs",
        info="Specifiy the IDs of examples to show (separate with ','). Note that 0 <= id < dataset_len",
        visible=False
    )
    with gr.Column(scale=0):
      show_example_btn = gr.Button("Show", visible=False)
      random_example_btn = gr.Button("Random", visible=False)

  # Pre-create MAX_NUM_EXAMPLES hidden (box, markdown) slots; `render_examples`
  # fills and reveals as many of them as there are examples to show.
  example_md_lst = []
  example_box_lst = []
  for _ in range(MAX_NUM_EXAMPLES):
    with gr.Box(visible=False) as examples_box:
      example_md = gr.Markdown(visible=False)
      example_md_lst.append(example_md)
      example_box_lst.append(examples_box)

  # Component groups that the callbacks below hide or show together.
  task_info_box = [task_metdata_box]
  prep_box = [prep_strategy_radio, num_shots_slider, prep_btn]
  show_example_box = [example_ids_input, show_example_btn, random_example_btn]
  # NOTE: rebinds `examples_box` (the loop variable above) to the full group.
  examples_box = [dataset_split_selector, *example_md_lst, *example_box_lst]

  def fetch_task(task_id):
    """Validate the requested task id, load the task and reveal the preparation controls.

    Args:
      task_id: Text of the "Task Id" box, either `<task_id>` or
        `<task_id>:<subtask_id>`. Surrounding whitespace is ignored.

    Returns:
      A dict mapping UI components to `gr.update(...)` values. For an unknown
      id, or a task with subtasks given without a subtask id, the dict shows
      an error message and hides every other panel. Otherwise it shows the
      task summary and the preparation controls and hides the error and the
      example panels.

    Raises:
      ValueError: If the loaded task's config enables no preparation strategy.
    """
    orig_task_id = task_id.strip()
    # `partition` keeps everything after the first ":" as the subtask id, so
    # an id with several ":" is reported as "not found" instead of failing to
    # unpack.
    root_task_id, separator, subtask_id = orig_task_id.partition(":")
    if not separator:
      subtask_id = None

    def error_update(message):
      # Show `message` and hide every panel that depends on a loaded task.
      return {
          loading_error: gr.update(value=message, visible=True),
          **hide_elemnts(*task_info_box, *prep_box, *show_example_box, *examples_box)
      }

    not_found_message = f"Task ID `{orig_task_id}` not found!"

    # The membership check must come before any indexing into the metadata,
    # otherwise an unknown id raises KeyError instead of showing the message.
    if root_task_id not in all_tasks_metadata:
      return error_update(not_found_message)

    root_metadata = all_tasks_metadata[root_task_id]
    if "subtasks" in root_metadata:
      subtasks = root_metadata["subtasks"]
      if subtask_id is None:
        return error_update(
            f"Please specify the Subtask ID using `{orig_task_id}:<subtask_id>`."
            f"\n`{root_task_id}`'s subtasks: {sorted(subtasks.keys())}"
        )
      if subtask_id not in subtasks:
        return error_update(not_found_message)
      metadata = subtasks[subtask_id]
    else:
      # A subtask id was given for a task that has no subtasks.
      if subtask_id is not None:
        return error_update(not_found_message)
      metadata = root_metadata

    from genbench import load_task

    the_task = load_task(orig_task_id)
    task_obj[0] = the_task

    # Task summary; shows the full `<task_id>:<subtask_id>` the user asked
    # for, matching what `add_task` displays in the task directory.
    output_txt = task_info.format(
        task_name=metadata["name"],
        task_id=orig_task_id,
        task_decription=metadata["description"],
        task_authors=", ".join(metadata["authors"]),
        task_keywords=", ".join(metadata["keywords"])
    )

    # Available preparation strategies. A config without a
    # `preparation_strategies` section is treated as supporting both.
    strategies_cfg = the_task.config.preparation_strategies
    finetuning_enabled = (
        strategies_cfg is None or strategies_cfg.finetuning is not None
    )
    prompting_enabled = (
        strategies_cfg is None or strategies_cfg.prompt_based_testing is not None
    )

    # Always offer the enum *values* as radio choices; `prepare_datasets`
    # turns the selected value back into a `PreparationStrategy`.
    prep_strategies = []
    if finetuning_enabled:
      prep_strategies.append(PreparationStrategy.FINETUNING.value)
    if prompting_enabled:
      prep_strategies.append(PreparationStrategy.PROMPT_BASED_TESTING.value)

    if not prep_strategies:
      raise ValueError("The task does not support any preparation strategies.")

    return {
        task_metadata: gr.update(value=output_txt),
        task_metdata_box: gr.update(visible=True),

        prep_strategy_radio: gr.update(choices=prep_strategies, value=prep_strategies[0],
                                       visible=True, interactive=True),
        # The shot count only applies to prompt-based testing.
        num_shots_slider: gr.update(visible=prompting_enabled, interactive=True),
        prep_btn: gr.update(visible=True),

        **hide_elemnts(loading_error, *show_example_box, *examples_box),
    }

  def render_examples(dataset, example_ids):
    """Build the updates that display the given examples in the example slots.

    Args:
      dataset: Indexable collection whose items expose "input" and "target".
      example_ids: Indices into `dataset`; at most MAX_NUM_EXAMPLES of them.

    Returns:
      A dict with one update per example box and markdown component: the
      first `len(example_ids)` slots are filled and shown, the rest hidden.
    """
    num_requested = len(example_ids)
    assert num_requested <= MAX_NUM_EXAMPLES

    updates = {}
    for slot in range(MAX_NUM_EXAMPLES):
      if slot >= num_requested:
        # Unused slot: keep it out of sight.
        updates[example_box_lst[slot]] = gr.update(visible=False)
        updates[example_md_lst[slot]] = gr.update(visible=False)
        continue

      example_id = example_ids[slot]
      example = dataset[example_id]
      markdown = EXAMPLE_MARKDOWN_TEMPLATE.format(
          example_id=str(example_id),
          input=str(example["input"]),
          target=str(example["target"])
      )
      updates[example_box_lst[slot]] = gr.update(visible=True)
      updates[example_md_lst[slot]] = gr.update(value=markdown, visible=True)

    return updates

  def render_dataset(ds, split_options, split_option_choice_idx=0):
    """Build the updates that show a dataset split with a random sample of examples.

    Args:
      ds: The dataset to display; must support `len()` and indexing.
      split_options: Labels offered by the split selector.
      split_option_choice_idx: Index into `split_options` of the selected label.

    Returns:
      A dict of component updates covering the split selector, the example id
      box, the Show / Random buttons and all example slots.
    """
    num_rows = len(ds)
    sample_size = min(MAX_NUM_EXAMPLES, num_rows)
    # Sample without replacement from the shared `rng`.
    example_ids = rng.choice(num_rows, sample_size, replace=False).tolist()

    updates = {}
    updates[dataset_split_selector] = gr.update(
        choices=split_options,
        visible=True,
        value=split_options[split_option_choice_idx],
        # Nothing to choose between when only one split is available.
        interactive=len(split_options) > 1,
    )
    updates[example_ids_input] = gr.update(
        value=",".join(map(str, example_ids)),
        interactive=True,
        visible=True,
    )
    updates.update(show_elemnts(show_example_btn, random_example_btn))
    updates.update(render_examples(ds, example_ids))
    return updates


  # Mutable state shared by the callbacks:
  #   [0] the object returned for the chosen strategy (a mapping of splits
  #       for finetuning, a single dataset for prompt-based testing),
  #   [1] the dataset currently displayed,
  #   [2] the label of the currently selected split option.
  loaded_dataset = [None, None, None]
  # Fixed seed so the sequence of sampled example ids is reproducible.
  rng = np.random.RandomState(seed=42)
  def prepare_datasets(prep_strategy, num_shots):
    """Prepare the loaded task's data for the chosen strategy and display it.

    Args:
      prep_strategy: Value of a `PreparationStrategy` member, as selected in
        the strategy radio.
      num_shots: Number of shots; only used for prompt-based testing.

    Returns:
      The component updates produced by `render_dataset` for the first
      available split.

    Raises:
      ValueError: If the finetuning datasets contain none of the "test",
        "validation" or "train" splits.
    """
    prep_strategy = PreparationStrategy(prep_strategy)

    task = task_obj[0]
    if prep_strategy == PreparationStrategy.FINETUNING:
      ds = task.get_prepared_datasets(prep_strategy)
      # Splits are indexed by their lower-case name, so membership must be
      # tested with the lower-case name as well.
      split_names = [
          opt for opt in ["Test", "Validation", "Train"] if opt.lower() in ds
      ]
      if not split_names:
        raise ValueError("The prepared datasets do not contain any known split.")
      split_options = [
          f"{opt} (len={len(ds[opt.lower()])})" for opt in split_names
      ]
      # Display the first available split rather than assuming "test" exists.
      first_split = ds[split_names[0].lower()]
      loaded_dataset[0] = ds
      loaded_dataset[1] = first_split
      loaded_dataset[2] = split_options[0]

      return render_dataset(first_split, split_options)

    if prep_strategy == PreparationStrategy.PROMPT_BASED_TESTING:
      # Request a single shot count and take its dataset from the result.
      ds = task.get_prepared_datasets(prep_strategy, shot_list=[num_shots])[num_shots]
      split_options = [f"Test (len={len(ds)})"]
      loaded_dataset[0] = ds
      loaded_dataset[1] = ds
      loaded_dataset[2] = split_options[0]

      return render_dataset(ds, split_options)

  def change_dataset_split(split):
    """Switch the displayed dataset to the split chosen in the split selector.

    Args:
      split: Selected option label, e.g. "Validation (len=123)"; the split
        name is the lower-cased first word.

    Returns:
      The component updates produced by `render_dataset` for the chosen split.
    """
    split_name = split.split(" ")[0].lower()

    if not isinstance(loaded_dataset[0], dict):
      # A single dataset is loaded (no mapping of splits): re-render it with
      # its only option.
      return render_dataset(loaded_dataset[1], [loaded_dataset[2]])

    ds = loaded_dataset[0]
    # Splits are indexed by their lower-case name, so membership must be
    # tested with the lower-case name as well; the labels built here have to
    # match the ones offered by `prepare_datasets`.
    split_options = [
        f"{opt} (len={len(ds[opt.lower()])})"
        for opt in ["Test", "Validation", "Train"]
        if opt.lower() in ds
    ]
    split_option_choice_idx = split_options.index(split)

    loaded_dataset[1] = ds[split_name]
    loaded_dataset[2] = split_options[split_option_choice_idx]

    return render_dataset(
        loaded_dataset[1],
        split_options, split_option_choice_idx
    )


  def show_example(example_ids_str, is_random=False):
    """Display either the examples listed by the user or a fresh random sample.

    Args:
      example_ids_str: Comma-separated example ids typed by the user. Ignored
        when `is_random` is true.
      is_random: When true, sample up to MAX_NUM_EXAMPLES ids from `rng`.

    Returns:
      A dict of component updates: the id box is rewritten with the ids that
      are actually shown, and the example slots are filled accordingly.

    Raises:
      gr.Error: If an id is not an integer or lies outside
        `0 <= id < len(dataset)`.
    """
    ds = loaded_dataset[1]
    if is_random:
      example_ids = rng.choice(
          len(ds), min(MAX_NUM_EXAMPLES, len(ds)), replace=False
      ).tolist()
    else:
      # Blank tokens (empty box, trailing or doubled commas) are ignored.
      tokens = [s.strip() for s in (example_ids_str or "").split(",")]
      try:
        example_ids = [int(token) for token in tokens if token]
      except ValueError as err:
        raise gr.Error("Example IDs must be integers separated by ','.") from err

      out_of_range = [i for i in example_ids if not 0 <= i < len(ds)]
      if out_of_range:
        raise gr.Error(
            f"Example IDs {out_of_range} are out of range (0 <= id < {len(ds)})."
        )

      # Only MAX_NUM_EXAMPLES slots exist; extra ids are dropped and the id
      # box below is rewritten so the user sees which ones are displayed.
      example_ids = example_ids[:MAX_NUM_EXAMPLES]

    return {
        example_ids_input: gr.update(
              value=",".join([str(i) for i in example_ids]),
              interactive=True,
        ),
        **render_examples(ds, example_ids)
    }

  def clear_ui():
    """Return the updates that hide the error box and every task-dependent panel."""
    everything = [
        loading_error,
        *task_info_box,
        *prep_box,
        *show_example_box,
        *examples_box,
    ]
    return hide_elemnts(*everything)


  # --- Event wiring -------------------------------------------------------
  # Every callback returns a dict keyed by component, so each `outputs` list
  # names all components the callback may update.

  # All example slots (boxes first, then their markdown bodies).
  example_slot_outputs = [*example_box_lst, *example_md_lst]
  # Components refreshed when only the displayed examples change.
  example_view_outputs = [example_ids_input, *example_slot_outputs]
  # Components refreshed when a whole dataset split is (re)rendered.
  dataset_view_outputs = [
      dataset_split_selector,
      example_ids_input, show_example_btn, random_example_btn,
      *example_slot_outputs
  ]

  # Load: validate the id, load the task, reveal the preparation controls.
  fetch_btn.click(
      fetch_task,
      inputs=task_id,
      outputs=[
          loading_error, task_metadata, task_metdata_box,
          prep_strategy_radio, num_shots_slider, prep_btn,
          *show_example_box, *examples_box
      ]
  )
  # Clear: hide everything that depends on a loaded task.
  clear_btn.click(
      clear_ui,
      outputs=[
        loading_error,
        *task_info_box,
        *prep_box,
        *show_example_box,
        *examples_box
      ]
  )
  # Prepare: build the datasets for the chosen strategy and show a sample.
  prep_btn.click(
      prepare_datasets,
      inputs=[prep_strategy_radio, num_shots_slider],
      outputs=dataset_view_outputs,
      show_progress=True
  )
  # Show: display the examples whose ids are typed in the id box.
  show_example_btn.click(
      lambda ids_text: show_example(ids_text),
      inputs=example_ids_input,
      outputs=example_view_outputs,
      show_progress=True
  )
  # Random: display a fresh random sample of the current split.
  random_example_btn.click(
      lambda: show_example(None, is_random=True),
      inputs=None,
      outputs=example_view_outputs,
      show_progress=True
  )
  # Split selector: switch the displayed split.
  dataset_split_selector.change(
      change_dataset_split,
      inputs=dataset_split_selector,
      outputs=dataset_view_outputs,
      show_progress=True
  )


# Launch the Task Viewer UI built in the `gr.Blocks` context above.
demo.launch(height=900, debug=True)