In [None]:
import os
import json

from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.models import WordLevel
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from transformers import PreTrainedTokenizerFast
from datasets import Dataset

# Load the raw solves. JSON is UTF-8 by specification, so the encoding is
# given explicitly instead of relying on the platform default.
with open("solves.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Tokenizer

In [None]:
def data_iterator(data):
    """Yield the tokenizer training strings for every row in ``data``.

    Each row contributes two strings, in this order: its ``"scramble"``
    entries joined by single spaces, then its ``"solve"`` entries joined the
    same way.
    """
    for row in data:
        for field in ("scramble", "solve"):
            yield " ".join(row[field])

In [None]:
# Train the word-level tokenizer only when no saved copy exists yet; the
# directory check keeps a full re-run of the notebook from retraining.
if not os.path.exists("rubiks-tokenizer"):
    # Name the unknown token on the model explicitly so it matches the
    # `unk_token` handed to PreTrainedTokenizerFast below.
    tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
    tokenizer.pre_tokenizer = WhitespaceSplit()
    # Ids 0 and 1 are hard-coded here; they match the position of "<bos>" and
    # "<eos>" in the trainer's `special_tokens` list below.
    tokenizer.post_processor = TemplateProcessing(
        single="<bos> $A <eos>", special_tokens=[("<bos>", 0), ("<eos>", 1)]
    )

    # "<unk>" has to be part of the trained vocabulary, otherwise the model
    # has no id to map an unseen move to. It is appended last so that
    # "<bos>" and "<eos>" keep ids 0 and 1.
    trainer = WordLevelTrainer(
        special_tokens=["<bos>", "<eos>", "<unk>"],
        show_progress=True,
    )

    # data_iterator yields two strings per row (scramble and solve), so the
    # total number of sequences is twice the number of rows.
    tokenizer.train_from_iterator(
        data_iterator(data), trainer=trainer, length=2 * len(data)
    )
    # Wrap the trained tokenizer for use with transformers and persist it.
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
    )
    tokenizer.save_pretrained("rubiks-tokenizer")

In [None]:
# Always load the tokenizer from the saved "rubiks-tokenizer" directory, so
# `tokenizer` comes from disk whether or not the training branch ran.
tokenizer = PreTrainedTokenizerFast.from_pretrained("rubiks-tokenizer")

# Dataset 

In [None]:
def dataset_generator():
    """Yield one record per row of the module-level ``data``.

    Every record has the keys ``"scramble"`` and ``"solve"``; each value is
    the corresponding row entry joined into a single space-separated string.
    """
    for entry in data:
        yield {column: " ".join(entry[column]) for column in ("scramble", "solve")}


# Build the dataset from the generator, hold out 20% of the rows for testing,
# and write both splits to disk.
dataset = Dataset.from_generator(dataset_generator)
# A fixed seed makes the shuffle behind the split deterministic, so a fresh
# kernel run reproduces the same train/test partition.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
dataset.save_to_disk("rubiks-dataset")