In [2]:
import collections

import datasets
import numpy as np
import pandas as pd
import re
import rich
import rich.console
import rich.markup
import rich.table
import torch
import transformers

# Turn off the `datasets` library's caching for this session.
datasets.disable_caching()

# Shared rich console used by every table printed in this notebook.
# A fixed width keeps the tables the same size across runs, and
# force_terminal=True asks rich to treat the output as a terminal.
CONSOLE = rich.console.Console(force_terminal=True, width=100)

def gsm8k_entry_viewer(entry):
    """Print one dataset entry as a two-column rich table on ``CONSOLE``.

    The "question" and "answer" texts are escaped for rich markup first, and
    then every ``<<``, ``>>`` and ``=`` is wrapped in a highlight style so the
    markers stand out visually.

    Args:
        entry: A mapping with exactly two items that provides string values
            under the keys "question" and "answer".

    Raises:
        AssertionError: If ``entry`` does not contain exactly two items.
    """
    assert len(entry) == 2

    # (token, rich style) pairs, applied in this order to the escaped text.
    highlights = (
        ("<<", "bold green"),
        (">>", "bold green"),
        ("=", "bold white on red"),
    )

    table = rich.table.Table(title="Entry", show_lines=True, title_justify="left")
    for heading in ("Key", "Value"):
        table.add_column(heading)

    for key in ("question", "answer"):
        # Escape first so any brackets in the raw text are not read as markup.
        rendered = rich.markup.escape(entry[key])
        for token, style in highlights:
            rendered = rendered.replace(token, f"[{style}]{token}[/]")
        table.add_row(f"{key.capitalize()}:", rendered)

    CONSOLE.print(table)

In [7]:
def count_tags(text):
    """Count complete, non-overlapping ``<<...>>`` tags in ``text``.

    The text is scanned left to right. A tag starts at the first ``<<`` found
    and ends at the first ``>>`` that follows it; scanning then resumes right
    after that closing marker. An opening ``<<`` without a later ``>>`` is not
    counted, and a ``>>`` that appears before any ``<<`` is ignored.

    Args:
        text: The string to scan.

    Returns:
        The number of complete tags, including a tag whose closing ``>>`` is
        the very end of ``text``.
    """
    tag_count = 0
    position = 0

    while True:
        opening = text.find("<<", position)
        if opening == -1:
            break

        # Search for the closer only after the two opening characters.
        closing = text.find(">>", opening + 2)
        if closing == -1:
            break

        tag_count += 1
        position = closing + 2

    return tag_count

def count_char(char, text):
    """Return how many elements of ``text`` compare equal to ``char``.

    ``text`` is iterated element by element (character by character for a
    string), so a ``char`` longer than one character never matches anything
    in a string.

    Args:
        char: The value to look for.
        text: The iterable (normally a string) to scan.

    Returns:
        The number of matching elements as an int.
    """
    return sum(1 for candidate in text if candidate == char)

def count_equals(text):
    """Return the number of ``=`` characters in ``text`` (via ``count_char``)."""
    return count_char(char="=", text=text)

def build_annotated_object():
    """Load the GSM8K "main" train split and annotate every entry.

    For each entry the "answer" text is scanned with ``count_tags`` and
    ``count_equals``.

    Returns:
        A ``pd.DataFrame`` with one row per dataset entry and three columns:
        "entry" (the original entry), "tag_count" and "equals_count".
    """
    train_split = datasets.load_dataset("gsm8k", "main", split="train")

    # One record per entry; the counts are computed on the answer text only.
    records = [
        {
            "entry": example,
            "tag_count": count_tags(example["answer"]),
            "equals_count": count_equals(example["answer"]),
        }
        for example in train_split
    ]

    return pd.DataFrame(records)

def filter_gsm8k(annotated: pd.DataFrame):
    """Keep the rows whose equals-sign count is exactly twice their tag count.

    Rows are processed one ``tag_count`` group at a time, in ascending
    ``tag_count`` order, and a short per-group report is printed. The kept
    rows of each group are concatenated in that same order, so the result is
    deterministic. Original index labels are preserved.

    Args:
        annotated: Frame with integer "tag_count" and "equals_count" columns
            (any other columns are carried through unchanged).

    Returns:
        A ``pd.DataFrame`` holding the rows where
        ``equals_count == 2 * tag_count``. An empty input yields an empty
        frame with the same columns.

    Raises:
        AssertionError: If a "tag_count" value is not an integer type.
    """
    tag_counts = set(annotated["tag_count"].values)
    print(tag_counts)

    kept_frames = []
    # Sets have no defined iteration order; sort so that both the report and
    # the row order of the result are reproducible.
    for tag_count in sorted(tag_counts):
        # np.integer covers every NumPy integer width, not only int64.
        assert isinstance(tag_count, (int, np.integer)), type(tag_count).mro()
        values: pd.DataFrame = annotated[annotated["tag_count"] == tag_count]
        filter_ = values["equals_count"] == 2 * tag_count
        num_good = filter_.sum()
        kept_frames.append(values[filter_])
        print(f"Tag count {tag_count} has {len(values)} entries.")
        print(f"Tag count {tag_count} has {num_good} entries with the right number of equals signs.")
        print("--")

    # Merge the per-group frames. pd.concat rejects an empty list, so fall
    # back to an empty slice of the input when there were no groups at all.
    if kept_frames:
        final_filtered = pd.concat(kept_frames)
    else:
        final_filtered = annotated.iloc[0:0]

    # Avoid dividing by zero when the input frame has no rows.
    kept_fraction = len(final_filtered) / len(annotated) if len(annotated) else 0.0
    print(f"Filtered to {len(final_filtered)} entries, which is {kept_fraction:.0%} of the dataset.")

    return final_filtered

# Build the per-entry annotation frame (tag and equals-sign counts), then keep
# only the rows that pass the equals_count == 2 * tag_count check. Both names
# are module-level and are reused by later cells.
annotated = build_annotated_object()
filtered = filter_gsm8k(annotated)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
Tag count 0 has 95 entries.
Tag count 0 has 2 entries with the right number of equals signs.
--
Tag count 1 has 404 entries.
Tag count 1 has 16 entries with the right number of equals signs.
--
Tag count 2 has 2175 entries.
Tag count 2 has 1795 entries with the right number of equals signs.
--
Tag count 3 has 2137 entries.
Tag count 3 has 1810 entries with the right number of equals signs.
--
Tag count 4 has 1424 entries.
Tag count 4 has 1243 entries with the right number of equals signs.
--
Tag count 5 has 785 entries.
Tag count 5 has 693 entries with the right number of equals signs.
--
Tag count 6 has 287 entries.
Tag count 6 has 259 entries with the right number of equals signs.
--
Tag count 7 has 123 entries.
Tag count 7 has 104 entries with the right number of equals signs.
--
Tag count 8 has 40 entries.
Tag count 8 has 36 entries with the right number of equals signs.
--
Tag count 9 has 3 entries.
Tag count 9 has 3 entries with the right number of 

In [None]:
# Tag count whose equals-sign distribution is inspected in this cell.
TAG_COUNT = 1
# Referenced only by the commented-out exploration further down in this cell.
EQUALS_COUNT = 3

# Restrict to the entries that have exactly TAG_COUNT tags.
base_filter = annotated.tag_count == TAG_COUNT
base_entries = annotated[base_filter]
len_base = len(base_entries)

# How many of those entries have each equals-sign count, as an absolute count
# and as a share of the selected entries, ordered by equals-sign count.
equals_distribution = base_entries.value_counts("equals_count")
joint = pd.concat(
    {
        "Counts": equals_distribution,
        "Percentages": equals_distribution / len_base,
    },
    axis=1,
).sort_index()
print(joint.to_string(float_format=r"{:.1%}".format))

# index = (base_entries.value_counts("equals_count") / len_base).sort_index()
# CONSOLE.print(index.to_string(float_format=r"{:.1%}".format))
# eoi = annotated[base_filter & (annotated.equals_count == EQUALS_COUNT)]
# len_eoi = len(eoi)
# print(f"{len_eoi}/{len_base}, {len_eoi/len_base:.1%}")

# for _, row in eoi.head(5).iterrows():
#     gsm8k_entry_viewer(row['entry'])


In [8]:
# Load the train split of both GSM8K configurations so their answers can be
# compared entry by entry in the next cell.
ds_socratic = datasets.load_dataset(path="gsm8k", name="socratic", split="train")
ds_regular = datasets.load_dataset(path="gsm8k", name="main", split="train")

In [14]:
# Compare the first 15 entries of the "main" and "socratic" splits side by
# side. For each entry, test the hypothesis that the number of "=" characters
# in the regular answer equals its tag count plus its newline count.
for i in range(15):
    # Index the row first, so only one entry is read per iteration instead of
    # the whole "answer" column.
    socratic = ds_socratic[i]["answer"]
    regular = ds_regular[i]["answer"]

    num_tags = count_tags(regular)
    num_equals = count_equals(regular)
    # These count newline characters, i.e. line breaks rather than lines.
    num_lines_socratic = count_char("\n", socratic)
    num_lines_regular = count_char("\n", regular)
    hypothesis = num_equals == (num_tags + num_lines_regular)

    table = rich.table.Table(title="Entry", show_lines=True, title_justify="left")
    table.add_column("Key")
    table.add_column("Value")

    # A row label is highlighted when its value breaks the expected relation
    # to the tag count. An unclosed style tag is closed by rich at the end of
    # the string.
    color_socratic = "[bold white on red]" if num_lines_socratic != num_tags else ""
    color_regular = "[bold white on red]" if num_lines_regular != num_tags else ""
    color_num_equals = "[bold white on blue]" if num_equals != 2 * num_tags else ""
    color_hypothesis = "[bold white on purple]" if not hypothesis else ""

    table.add_row("Num tags:", str(num_tags))
    # Each "Num lines" row uses the colour computed from its own count.
    table.add_row(color_regular + "Num lines regular:", str(num_lines_regular))
    table.add_row(color_socratic + "Num lines socratic:", str(num_lines_socratic))
    table.add_row(color_num_equals + "Num equals:", str(num_equals))
    # Escape the dataset text so brackets in it are not parsed as rich markup
    # (same treatment as in gsm8k_entry_viewer).
    table.add_row("Regular:", rich.markup.escape(regular).replace("\n", "\n------\n"))
    table.add_row("Socratic:", rich.markup.escape(socratic).replace("\n", "\n------\n"))
    table.add_row(color_hypothesis + "Hypothesis:", str(hypothesis))

    CONSOLE.print(table)