In [62]:
import os
import json

from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.models import WordLevel
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from transformers import PreTrainedTokenizerFast
from datasets import Dataset

# Load the recorded solves. JSON text is UTF-8 by specification, so the
# encoding is pinned instead of relying on the platform default.
with open("solves.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Displayed as the cell output: number of rows loaded.
len(data)

8716

# Tokenizer

In [63]:
def data_iterator(data):
    """Yield the tokenizer training corpus, one string at a time.

    For every row in ``data`` the ``"scramble"`` entry is joined with single
    spaces and yielded, followed by the ``"solve"`` entry joined the same way,
    so each row contributes exactly two strings.
    """
    for record in data:
        for field in ("scramble", "solve"):
            yield " ".join(record[field])

In [64]:
if not os.path.exists("rubiks-tokenizer"):
    # Train the tokenizer only when no saved copy exists; otherwise this cell
    # is a no-op and the "rubiks-tokenizer" directory on disk is reused.

    # Word-level model over whitespace-separated tokens. The unknown token is
    # set explicitly so it matches the "<unk>" special token registered with
    # the trainer and with the PreTrainedTokenizerFast wrapper below.
    word_level_tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
    word_level_tokenizer.pre_tokenizer = WhitespaceSplit()

    trainer = WordLevelTrainer(
        special_tokens=["<bos>", "<eos>", "<unk>"],
        show_progress=True,
    )

    # data_iterator yields two sequences per row (scramble, then solve), so
    # the total reported for progress tracking is twice the number of rows.
    word_level_tokenizer.train_from_iterator(
        data_iterator(data), trainer=trainer, length=2 * len(data)
    )

    bos_id = word_level_tokenizer.token_to_id("<bos>")
    eos_id = word_level_tokenizer.token_to_id("<eos>")

    # Single input:  <bos> text <eos>
    # Paired input:  <bos> first <eos> second <eos>, with the second segment
    # and its closing <eos> assigned type id 1.
    word_level_tokenizer.post_processor = TemplateProcessing(
        single="<bos> $0 <eos>",
        pair="<bos> $A <eos> $B:1 <eos>:1",
        special_tokens=[("<bos>", bos_id), ("<eos>", eos_id)],
    )

    # Wrap in the transformers fast-tokenizer interface and persist to disk.
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=word_level_tokenizer,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
    )
    tokenizer.save_pretrained("rubiks-tokenizer")

In [65]:
# Always reload from the saved "rubiks-tokenizer" directory so the notebook uses
# the on-disk tokenizer whether or not it was trained in this session.
tokenizer = PreTrainedTokenizerFast.from_pretrained("rubiks-tokenizer")

# Dataset 

In [66]:
def dataset_generator():
    """Yield one example dict per row of the module-level ``data``.

    Each example has the keys ``"scramble"`` and ``"solve"``; every value is
    the row's corresponding entry joined into one space-separated string.
    """
    for record in data:
        yield {
            field: " ".join(record[field])
            for field in ("scramble", "solve")
        }


# Materialize the rows, then hold out 20% of them as the test split. The split
# is seeded so that re-running the notebook writes the same train/test
# partition to "rubiks-dataset" instead of a fresh random one each time.
full_dataset = Dataset.from_generator(dataset_generator)
dataset = full_dataset.train_test_split(test_size=0.2, seed=42)
dataset.save_to_disk("rubiks-dataset")
dataset

Saving the dataset (1/1 shards): 100%|██████████| 6972/6972 [00:00<00:00, 509090.85 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1744/1744 [00:00<00:00, 333646.51 examples/s]


DatasetDict({
    train: Dataset({
        features: ['scramble', 'solve'],
        num_rows: 6972
    })
    test: Dataset({
        features: ['scramble', 'solve'],
        num_rows: 1744
    })
})

In [None]:
# Sanity check: encode the first test example as a (scramble, solve) pair and
# decode it again to inspect where the special tokens are placed.
first_test_row = dataset["test"][0]
scramble = first_test_row["scramble"]
solve = first_test_row["solve"]

tokenized = tokenizer(text=scramble, text_pair=solve)
tokenizer.decode(token_ids=tokenized["input_ids"])

"<bos> D2 L2 F2 D U2 R2 U' R' F2 L' U F2 R B' L2 D' B2 L2 R2 <eos> Z U' r' u R' u' U L' U L2 U' L' U' Y' R U' R2' U R U R U R' U' L U' L' U R U' R' U R U' R' U' R' F' r U R U' L' U X' U2 <eos>"