# BERT - Is Solved? 

Can we use BERT to determine if a given scramble + solution sequence will result in a solved cube? 

Short answer - Kind of! 

In [None]:
import warnings
import os

warnings.filterwarnings("ignore")

import datasets
from transformers import PreTrainedTokenizerFast
from transformers import BertForSequenceClassification, BertConfig
from transformers import Trainer, TrainingArguments

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import numpy as np

os.environ["WANDB_PROJECT"] = "rubiks-bert"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

BATCH_SIZE = 64
MAX_LENGTH = 140

dataset = datasets.load_from_disk("../rubiks-is-solved-dataset")
tokenizer = PreTrainedTokenizerFast.from_pretrained("../rubiks-tokenizer")

In [2]:
def process(args):
    scramble = args["scramble"]
    solve = args["solve"]
    is_solved = args["is_solved"]

    tokenized = tokenizer(
        text=scramble,
        text_pair=solve,
        is_split_into_words=True,
        return_token_type_ids=True,
        return_attention_mask=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_tensors="pt",
    )

    tokenized["labels"] = is_solved
    return tokenized


cols_to_remove = ["scramble", "solve", "is_solved"]
processed_dataset = dataset.map(process, batched=True, remove_columns=cols_to_remove)

train = processed_dataset["train"]
test = processed_dataset["test"]

In [3]:
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=1024,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=MAX_LENGTH,
    type_vocab_size=2,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    pad_token_id=tokenizer.pad_token_id,
    position_embedding_type="absolute",
    use_cache=True,
    classifier_dropout=None,
    num_labels=2,
    id2label={1: "solved", 0: "not-solved"},
)
model = BertForSequenceClassification(config)

In [5]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted"
    )
    acc = accuracy_score(labels, preds)
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


args = TrainingArguments(
    output_dir="./rubiks-bert",
    do_train=True,
    do_eval=True,
    torch_empty_cache_steps=500,
    save_strategy="best",
    report_to="wandb",
    auto_find_batch_size=True,
    eval_strategy="epoch",
    metric_for_best_model="eval_f1",
    logging_strategy="steps",
    logging_steps=100,
    eval_on_start=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
)

trainer.train()


[34m[1mwandb[0m: Currently logged in as: [33mlainon[0m ([33mhenry-williams[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.2157,0.184641,0.932409,0.933429,0.932409,0.932867
2,0.1873,0.190034,0.939484,0.94022,0.939484,0.939817
3,0.1607,0.191026,0.942256,0.942686,0.942256,0.942457


[34m[1mwandb[0m: Adding directory to artifact (rubiks-bert/checkpoint-5230)... Done. 0.2s


TrainOutput(global_step=15690, training_loss=0.20078107871547332, metrics={'train_runtime': 672.7743, 'train_samples_per_second': 186.553, 'train_steps_per_second': 23.321, 'total_flos': 914665565298240.0, 'train_loss': 0.20078107871547332, 'epoch': 3.0})

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


In [None]:
predictions = trainer.predict(test)

logits = predictions.predictions
labels = predictions.label_ids
metrics = predictions.metrics

pred_classes = np.argmax(logits, axis=-1)
incorrect = test[labels != pred_classes]

In [None]:
import torch

for sample in test:
    input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to("mps")
    attention_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to("mps")
    token_type_ids = torch.tensor(sample["token_type_ids"]).unsqueeze(0).to("mps")
    label = sample["labels"]

    output = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True,
    )


In [62]:
output

SequenceClassifierOutput(loss=None, logits=tensor([[-2.1562,  1.6087]], device='mps:0', grad_fn=<LinearBackward0>), hidden_states=(tensor([[[ 0.5589, -0.6705, -0.2784,  ...,  0.9875,  0.8205,  1.0542],
         [-0.3003, -0.3265, -2.1491,  ..., -1.4600,  0.5116,  0.1713],
         [ 0.1340, -0.9604,  0.3528,  ..., -0.8574, -0.4149,  0.7569],
         ...,
         [ 1.8806,  0.2689, -2.5269,  ..., -0.4895,  0.7896,  1.0888],
         [ 0.3826, -0.2486, -1.4152,  ...,  1.3080,  0.6596,  1.2368],
         [ 1.5368, -0.2942, -1.0405,  ...,  0.4380,  1.2213,  1.4403]]],
       device='mps:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[ 1.1595, -1.0228, -0.5096,  ...,  1.4048,  0.3982,  1.2823],
         [ 0.0996, -0.6921, -1.7849,  ..., -0.7312,  0.1989,  0.6718],
         [ 0.6409, -1.2835, -0.0130,  ..., -0.1515, -0.6758,  0.8560],
         ...,
         [ 1.9764, -0.1478, -2.3243,  ..., -0.0387,  0.4647,  1.1858],
         [ 0.7340, -0.6015, -1.2919,  ...,  1.4777,  0.2953,  1.5969]