In [None]:
# Authenticate with the Hugging Face Hub; the last cell of this notebook pushes the
# processed dataset with push_to_hub, which needs a logged-in session.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import datasets

In [None]:
# Pin the exact dataset commit so reruns see identical data.
WIKITEXT_REVISION = "b08601e04326c79dfdd32d625aee71d232d685c3"

wiki: datasets.DatasetDict = datasets.load_dataset(
  "wikitext",
  "wikitext-103-raw-v1",
  revision=WIKITEXT_REVISION,
)

In [None]:
# Drop rows whose text is empty or whitespace-only, so every remaining row is either
# a heading line or a paragraph for the grouping step below.
wiki = wiki.filter(lambda row: bool(row['text'].strip()))

In [None]:
import regex

# Heading lines look like " = Title = \n" (level 1), " = = Section = = \n" (level 2), etc.
# group(1) is the run of "= " markers, whose "=" count is the heading level; group(2) is
# the title; "\1" requires the closing markers to mirror the opening ones. Titles that
# themselves contain "=" do not match ([^=]+) and are therefore treated as body text.
WIKITEXT_HEADING_PATTERN = regex.compile(" ((?:= )+)([^=]+) \\1\n")


def group_rows_into_articles(rows) -> list[dict]:
  """Fold a flat sequence of wikitext rows into one record per article.

  A level-1 heading starts a new article; a deeper heading opens a new section in
  the current article; any other row is appended (whitespace-stripped) to the
  paragraphs of the most recent section.

  Args:
    rows: iterable of mappings with a 'text' key, in document order.

  Returns:
    A list of dicts with parallel lists 'levels' (heading depth), 'titles'
    (heading text) and 'texts' (one list of paragraphs per heading).

  Raises:
    ValueError: if a section heading or paragraph appears before any level-1
      heading, since there is no article to attach it to.
  """
  articles: list[dict] = []
  for row in rows:
    text: str = row['text']
    matched = WIKITEXT_HEADING_PATTERN.match(text)
    level = matched.group(1).count("=") if matched else 0
    if level == 1:
      articles.append({
        "levels": [level],
        "titles": [matched.group(2)],
        "texts": [[]],
      })
      continue
    if not articles:
      raise ValueError(f"row precedes the first level-1 heading: {text!r}")
    article = articles[-1]
    if matched:
      # Sub-heading: open a new (initially empty) section in the current article.
      article['levels'].append(level)
      article['titles'].append(matched.group(2))
      article['texts'].append([])
    else:
      # Paragraph: belongs to the most recently opened section.
      article['texts'][-1].append(text.strip())
  return articles


grouped_wiki = datasets.DatasetDict()
for split_name in wiki:
  grouped_wiki[split_name] = datasets.Dataset.from_list(group_rows_into_articles(wiki[split_name]))

In [None]:
# Inspect the per-split article counts and the levels/titles/texts columns built above.
grouped_wiki

In [None]:
def make_pairs_of_heading_and_paragraph_from_article(article: dict) -> list[tuple[str, str]]:
  """Flatten one grouped article into (heading, paragraph) pairs.

  Each paragraph's heading is the comma-separated chain of titles from the
  article title down to its own section (e.g. "Article, Section, Subsection"),
  followed by a trailing newline.

  Args:
    article: a record with parallel lists 'levels', 'titles' and 'texts', where
      texts[i] holds the paragraphs found under titles[i] at depth levels[i].

  Returns:
    One (heading, paragraph) tuple per non-empty paragraph, in document order.
  """
  pairs: list[tuple[str, str]] = []
  # Ancestor chain of the current section as (level, title), outermost first.
  stacked_titles: list[tuple[int, str]] = []
  for level, title, texts in zip(article['levels'], article['titles'], article['texts']):
    # Pop siblings and deeper sections so only true ancestors remain on the stack.
    while stacked_titles and stacked_titles[-1][0] >= level:
      stacked_titles.pop()
    stacked_titles.append((level, title))
    # Every paragraph in this section shares the same heading, so build it once.
    # (Joining outside the f-string also avoids nested same-type quotes, which are
    # a SyntaxError before Python 3.12.)
    heading = ", ".join(stacked_title for _, stacked_title in stacked_titles) + "\n"
    pairs.extend((heading, text) for text in texts if text)
  return pairs

In [None]:
for datasetName in grouped_wiki.keys():
  print(f"total articles to process in '{datasetName}':", len(grouped_wiki[datasetName]))

  headings_and_paragraphs_table: dict[str, list[str]] = {
    "heading": [],
    "paragraph": [],
  }
  processed_articles_n = 0

  for pairs_from_article in map(make_pairs_of_heading_and_paragraph_from_article, grouped_wiki[datasetName]):
    for pair in pairs_from_article:
      headings_and_paragraphs_table["heading"].append(pair[0])
      headings_and_paragraphs_table["paragraph"].append(pair[1])

    processed_articles_n += 1
    if processed_articles_n % 100 == 0:
      print("articles_processed:", processed_articles_n)

  grouped_wiki[datasetName] = datasets.Dataset.from_dict(headings_and_paragraphs_table)

In [None]:
grouped_wiki

In [None]:
grouped_wiki.push_to_hub(repo_id="wikitext_with_entitled_paragraphs")