In [None]:
from datasets import load_dataset
from IPython.display import display, Markdown
import re

import ipywidgets as widgets
from IPython.display import Markdown, display, clear_output

In [None]:
def remove_boxed(text):
    r"""Strip every ``\boxed{...}`` wrapper from ``text``, keeping its contents.

    The closing brace is found by counting brace depth rather than with a
    regex: a greedy ``.*`` runs on to the last ``}`` of the line (swallowing
    unrelated text after the box), while a lazy one would stop at the first
    nested ``}``. Boxes may span several lines and may be nested; a box whose
    closing brace is missing is left untouched.

    Args:
        text: LaTeX/Markdown source string.

    Returns:
        ``text`` with each balanced ``\boxed{X}`` replaced by ``X``.
    """
    marker = '\\boxed{'
    pieces = []
    pos = 0
    while True:
        start = text.find(marker, pos)
        if start == -1:
            pieces.append(text[pos:])
            break
        pieces.append(text[pos:start])

        body_start = start + len(marker)
        depth = 1
        i = body_start
        while i < len(text) and depth:
            ch = text[i]
            if ch == '\\':
                # Skip the escaped character so `\{`, `\}` and `\\` never
                # affect the brace depth.
                i += 2
                continue
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
            i += 1

        if depth:
            # No matching close brace: keep the remainder exactly as written.
            pieces.append(text[start:])
            break

        # text[i - 1] is the matching '}'; recurse to unwrap nested boxes.
        pieces.append(remove_boxed(text[body_start:i - 1]))
        pos = i
    return ''.join(pieces)

def fix_align(text):
    r"""Surround each ``align*`` environment with ``$$`` display-math markers.

    Every ``\begin{align*} ... \end{align*}`` span (shortest match, possibly
    covering several lines) is re-emitted unchanged between a leading and a
    trailing ``$$``. Text outside such spans is returned as is.
    """
    pattern = r'\\begin{align\*}(.*?)\\end{align\*}'
    wrapped = r'$$\\begin{align*}\1\\end{align*}$$'
    # DOTALL lets the lazy group cross the line breaks inside the environment.
    align_env = re.compile(pattern, flags=re.DOTALL)
    return align_env.sub(wrapped, text)

def convert_latex(text):
    r"""Rewrite LaTeX bracket math delimiters as dollar-sign delimiters.

    ``\[ ... \]`` becomes ``$$...$$`` (whitespace just inside the brackets is
    dropped) and ``\( ... \)`` becomes ``$...$`` (contents kept verbatim).
    Both forms may span multiple lines. Display math is rewritten first.
    """
    rewrites = (
        # Display math mode
        (r'\\\[\s*(.*?)\s*\\\]', r'$$\1$$'),
        # Inline math mode
        (r'\\\((.*?)\\\)', r'$\1$'),
    )
    for pattern, dollar_form in rewrites:
        text = re.sub(pattern, dollar_form, text, flags=re.DOTALL)
    return text

def fix(x):
    """Clean one solution string by chaining the three LaTeX helpers.

    Applies, in this order: ``remove_boxed``, then ``fix_align``, then
    ``convert_latex``, and returns the final string.
    """
    return convert_latex(fix_align(remove_boxed(x)))

def diagram_in_output(x):
    """Tell whether the row's ``"solution"`` field contains the ``[asy]`` tag."""
    solution = x["solution"]
    has_asy_tag = "[asy]" in solution
    return has_asy_tag

In [None]:
# Load the "hendrycks/competition_math" dataset through `datasets.load_dataset`.
# NOTE(review): this variable has the same name as the stdlib `math` module;
# that module is not imported in the visible cells, but it would be shadowed
# here if it ever were.
math = load_dataset("hendrycks/competition_math")

In [None]:
# Peek at the field names of the first training row; the record-building cell
# below reads "problem", "solution", "level" and "type" from each row.
print(math["train"][0].keys())

In [None]:
# Build the records to browse: one per training row whose solution has no
# [asy] diagram. The id is 10**8 plus the row's position in the unfiltered
# split, so dropped rows leave gaps in the id sequence.
data = [
    {
        "input": row["problem"],
        "output": fix(row["solution"]),
        "meta": {
            "level": row["level"],
            "type": row["type"],
            "id": 10**8 + idx,
        },
    }
    for idx, row in enumerate(math["train"])
    if not diagram_in_output(row)
]

In [None]:
def str_of_row(x):
    """Format a record's ``"input"`` and ``"output"`` fields as one labelled string."""
    problem, solution = x['input'], x['output']
    return f"---INPUT: {problem}\n\n---OUTPUT: {solution}"

In [None]:
def display_item(data, index=0):
    """Render one record of ``data`` with Previous/Next buttons for browsing.

    Clears the current cell output, then shows the navigation buttons, the
    record's input/output as Markdown, its id, the position as
    ``index/len(data)`` and its category. If the record has a ``"raw"`` key,
    that value is displayed as well. Clicking a button re-renders the cell
    for the neighbouring record, clamped to the first/last position.

    Args:
        data: Sequence of records; each record is read for ``"input"`` and
            ``"output"`` (via ``str_of_row``), ``meta["id"]`` and
            ``meta["type"]``.
        index: Position of the record to show (default 0).

    Returns:
        None. The function only produces notebook output.
    """
    clear_output(wait=True)

    # Nothing to index into: say so instead of failing with an IndexError.
    if len(data) == 0:
        display(Markdown("No items to display."))
        return

    item = data[index]
    text_display = Markdown(str_of_row(item))

    # Creating the buttons
    next_button = widgets.Button(description="Next")
    prev_button = widgets.Button(description="Previous")

    # Navigate through the dataset
    def navigate(step):
        # `index` is rebound here, so it must be declared nonlocal to update
        # the enclosing call's position rather than create a new local.
        nonlocal index
        index = min(max(0, index + step), len(data) - 1)
        # Re-render from scratch; this builds a fresh pair of buttons.
        display_item(data, index)

    next_button.on_click(lambda b: navigate(1))
    prev_button.on_click(lambda b: navigate(-1))

    # Displaying the components
    button_box = widgets.HBox([prev_button, next_button])
    display(button_box)
    display(text_display)
    display(Markdown(f"ID: {item['meta']['id']}"))
    display(Markdown(f"{index}/{len(data)}"))
    display(Markdown(f"Category: {item['meta']['type']}"))
    if "raw" in item:
        display(item["raw"])

In [None]:
# Start the interactive browser on the first record.
display_item(data, index=0)