In [1]:
!pip install git+https://github.com/cayleypy/cayleypy -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m82.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m63.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━

In [2]:
import pandas as pd

test = pd.read_csv("/kaggle/input/CayleyPy-pancake/test.csv")
test["n"].unique()

array([  5,  12,  15,  16,  20,  25,  30,  35,  40,  45,  50,  75, 100])

In [3]:
def pancake_sort_path(perm: list[int]) -> list[str]:
    """Return a sequence of prefix reversals that sorts `perm` to the identity permutation."""
    arr = list(perm)
    n = len(arr)
    moves: list[str] = []

    for target in range(n, 1, -1):
        desired_value = target - 1
        idx = arr.index(desired_value)

        if idx == target - 1:
            continue  # already in place

        if idx != 0:
            moves.append(f'R{idx + 1}')
            arr[: idx + 1] = reversed(arr[: idx + 1])

        moves.append(f'R{target}')
        arr[:target] = reversed(arr[:target])

    return moves

In [4]:
import torch
from torch import nn

class StateScorer(nn.Module):
    def __init__(self, n_pancakes: int, hidden_dim: int):
        super().__init__()
        self.n_pancakes = n_pancakes
        self.net = nn.Sequential(nn.Linear(n_pancakes, hidden_dim), nn.LeakyReLU(), nn.Linear(hidden_dim, 1))

    def forward(self, x):
        return self.net(2*x.float() / (self.n_pancakes - 1) - 1).squeeze(-1)

In [5]:
import numpy as np
import random

def seed_everything(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)

In [20]:
from cayleypy import CayleyGraph, PermutationGroups
from cayleypy.algo import RandomWalksGenerator
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

def train_predictor(n, graph=None, width=1000, length=5, n_dim=100, batch_size=16, epochs=8, random_seed=42):
    seed_everything(random_seed)
    if graph is None:
        graph = CayleyGraph(PermutationGroups.pancake(n))
    rwg = RandomWalksGenerator(graph)
    states, distances = rwg.generate(width=width, length=length, mode="classic")
    train_dataset = TensorDataset(states.float(), distances.float())
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        
    predictor = StateScorer(n, n_dim)
    predictor.train()

    optimizator = torch.optim.AdamW(predictor.parameters(), lr=3e-4)
    loss_fn = nn.MSELoss()

    for _ in range(epochs):
        for states, dist in tqdm(train_dataloader):
            optimizator.zero_grad()
            pred_dist = predictor(states)
            loss = loss_fn(pred_dist, dist)
            loss.backward()
            optimizator.step()

    predictor.eval()
    return predictor

# Abilation

In [16]:
n = 10
graph = CayleyGraph(PermutationGroups.pancake(n))

In [17]:
seed_everything(0)
test_start_states = [np.random.permutation(n) for _ in range(100)]
heurestic_path = [pancake_sort_path(x) for x in test_start_states]

def measure_success(predictor, beam_width):
    success_count = 0
    sum_length = 0
    diff_length = 0
    for state, hp in tqdm(zip(test_start_states, heurestic_path), total=len(test_start_states)):
        graph.free_memory()
        result = graph.beam_search(start_state=state, beam_width=beam_width, max_steps=n*2, predictor=predictor)
        if result.path_found:
            success_count += 1
            sum_length += result.path_length
            diff_length += result.path_length - len(hp)
    return success_count/len(test_start_states), sum_length/max(success_count, 1), diff_length/max(success_count, 1)

In [19]:
from cayleypy import Predictor
from itertools import product

seed_everything(42)
i = 1
max_i = 4*4*3
epochs = 8
n_dim = 1000
results = []
for width in [5, 10, 100, 1000]:
    for length in [5, 10, int(n*1.67), 2*n]:
        for n_dim in [10, 100, 1000]:
             result = {}
             result["predictor"] = Predictor(graph, train_predictor(n, graph=graph, width=width, length=length, n_dim=n_dim, epochs=epochs))
             hit_rate, avg_leng, avg_diff = measure_success(result["predictor"], 10)
             result["hit_rate"] = hit_rate
             result["avg_leng"] = avg_leng
             result["avg_diff"] = avg_diff
             result["width"] = width
             result["length"] = length
             result["n_dim"] = n_dim
             result["epochs"] = epochs
             results.append(result)
             print(f"{i}/{max_i}")
             i += 1

100%|██████████| 2/2 [00:00<00:00, 22.76it/s]
100%|██████████| 2/2 [00:00<00:00, 404.70it/s]
100%|██████████| 2/2 [00:00<00:00, 320.09it/s]
100%|██████████| 2/2 [00:00<00:00, 400.43it/s]
100%|██████████| 2/2 [00:00<00:00, 500.27it/s]
100%|██████████| 2/2 [00:00<00:00, 411.31it/s]
100%|██████████| 2/2 [00:00<00:00, 378.38it/s]
100%|██████████| 2/2 [00:00<00:00, 421.37it/s]
100%|██████████| 100/100 [00:31<00:00,  3.19it/s]


1/48


100%|██████████| 2/2 [00:00<00:00, 402.10it/s]
100%|██████████| 2/2 [00:00<00:00, 468.90it/s]
100%|██████████| 2/2 [00:00<00:00, 454.08it/s]
100%|██████████| 2/2 [00:00<00:00, 444.74it/s]
100%|██████████| 2/2 [00:00<00:00, 425.45it/s]
100%|██████████| 2/2 [00:00<00:00, 363.30it/s]
100%|██████████| 2/2 [00:00<00:00, 373.21it/s]
100%|██████████| 2/2 [00:00<00:00, 303.53it/s]
100%|██████████| 100/100 [00:31<00:00,  3.17it/s]


2/48


100%|██████████| 2/2 [00:00<00:00, 315.69it/s]
100%|██████████| 2/2 [00:00<00:00, 421.47it/s]
100%|██████████| 2/2 [00:00<00:00, 414.87it/s]
100%|██████████| 2/2 [00:00<00:00, 394.70it/s]
100%|██████████| 2/2 [00:00<00:00, 429.28it/s]
100%|██████████| 2/2 [00:00<00:00, 436.13it/s]
100%|██████████| 2/2 [00:00<00:00, 466.29it/s]
100%|██████████| 2/2 [00:00<00:00, 444.01it/s]
100%|██████████| 100/100 [00:32<00:00,  3.04it/s]


3/48


100%|██████████| 4/4 [00:00<00:00, 518.74it/s]
100%|██████████| 4/4 [00:00<00:00, 564.60it/s]
100%|██████████| 4/4 [00:00<00:00, 477.13it/s]
100%|██████████| 4/4 [00:00<00:00, 558.87it/s]
100%|██████████| 4/4 [00:00<00:00, 578.25it/s]
100%|██████████| 4/4 [00:00<00:00, 460.61it/s]
100%|██████████| 4/4 [00:00<00:00, 367.53it/s]
100%|██████████| 4/4 [00:00<00:00, 501.95it/s]
100%|██████████| 100/100 [00:31<00:00,  3.19it/s]


4/48


100%|██████████| 4/4 [00:00<00:00, 488.79it/s]
100%|██████████| 4/4 [00:00<00:00, 594.83it/s]
100%|██████████| 4/4 [00:00<00:00, 594.33it/s]
100%|██████████| 4/4 [00:00<00:00, 493.22it/s]
100%|██████████| 4/4 [00:00<00:00, 434.31it/s]
100%|██████████| 4/4 [00:00<00:00, 461.56it/s]
100%|██████████| 4/4 [00:00<00:00, 500.13it/s]
100%|██████████| 4/4 [00:00<00:00, 476.42it/s]
100%|██████████| 100/100 [00:31<00:00,  3.19it/s]


5/48


100%|██████████| 4/4 [00:00<00:00, 432.08it/s]
100%|██████████| 4/4 [00:00<00:00, 462.56it/s]
100%|██████████| 4/4 [00:00<00:00, 430.10it/s]
100%|██████████| 4/4 [00:00<00:00, 378.04it/s]
100%|██████████| 4/4 [00:00<00:00, 323.25it/s]
100%|██████████| 4/4 [00:00<00:00, 386.96it/s]
100%|██████████| 4/4 [00:00<00:00, 432.47it/s]
100%|██████████| 4/4 [00:00<00:00, 447.17it/s]
100%|██████████| 100/100 [00:31<00:00,  3.13it/s]


6/48


100%|██████████| 5/5 [00:00<00:00, 538.96it/s]
100%|██████████| 5/5 [00:00<00:00, 460.44it/s]
100%|██████████| 5/5 [00:00<00:00, 602.08it/s]
100%|██████████| 5/5 [00:00<00:00, 571.65it/s]
100%|██████████| 5/5 [00:00<00:00, 560.56it/s]
100%|██████████| 5/5 [00:00<00:00, 577.19it/s]
100%|██████████| 5/5 [00:00<00:00, 560.26it/s]
100%|██████████| 5/5 [00:00<00:00, 548.92it/s]
100%|██████████| 100/100 [00:30<00:00,  3.27it/s]


7/48


100%|██████████| 5/5 [00:00<00:00, 623.65it/s]
100%|██████████| 5/5 [00:00<00:00, 679.61it/s]
100%|██████████| 5/5 [00:00<00:00, 638.15it/s]
100%|██████████| 5/5 [00:00<00:00, 613.42it/s]
100%|██████████| 5/5 [00:00<00:00, 589.87it/s]
100%|██████████| 5/5 [00:00<00:00, 519.83it/s]
100%|██████████| 5/5 [00:00<00:00, 556.97it/s]
100%|██████████| 5/5 [00:00<00:00, 604.84it/s]
100%|██████████| 100/100 [00:30<00:00,  3.29it/s]


8/48


100%|██████████| 5/5 [00:00<00:00, 554.67it/s]
100%|██████████| 5/5 [00:00<00:00, 512.90it/s]
100%|██████████| 5/5 [00:00<00:00, 523.32it/s]
100%|██████████| 5/5 [00:00<00:00, 515.83it/s]
100%|██████████| 5/5 [00:00<00:00, 558.57it/s]
100%|██████████| 5/5 [00:00<00:00, 565.41it/s]
100%|██████████| 5/5 [00:00<00:00, 574.63it/s]
100%|██████████| 5/5 [00:00<00:00, 533.69it/s]
100%|██████████| 100/100 [00:30<00:00,  3.23it/s]


9/48


100%|██████████| 7/7 [00:00<00:00, 620.73it/s]
100%|██████████| 7/7 [00:00<00:00, 666.41it/s]
100%|██████████| 7/7 [00:00<00:00, 669.96it/s]
100%|██████████| 7/7 [00:00<00:00, 608.78it/s]
100%|██████████| 7/7 [00:00<00:00, 679.73it/s]
100%|██████████| 7/7 [00:00<00:00, 648.30it/s]
100%|██████████| 7/7 [00:00<00:00, 560.39it/s]
100%|██████████| 7/7 [00:00<00:00, 621.00it/s]
100%|██████████| 100/100 [00:30<00:00,  3.28it/s]


10/48


100%|██████████| 7/7 [00:00<00:00, 622.97it/s]
100%|██████████| 7/7 [00:00<00:00, 664.38it/s]
100%|██████████| 7/7 [00:00<00:00, 713.37it/s]
100%|██████████| 7/7 [00:00<00:00, 706.50it/s]
100%|██████████| 7/7 [00:00<00:00, 650.97it/s]
100%|██████████| 7/7 [00:00<00:00, 574.29it/s]
100%|██████████| 7/7 [00:00<00:00, 656.61it/s]
100%|██████████| 7/7 [00:00<00:00, 622.29it/s]
100%|██████████| 100/100 [00:30<00:00,  3.27it/s]


11/48


100%|██████████| 7/7 [00:00<00:00, 499.45it/s]
100%|██████████| 7/7 [00:00<00:00, 553.30it/s]
100%|██████████| 7/7 [00:00<00:00, 566.36it/s]
100%|██████████| 7/7 [00:00<00:00, 538.95it/s]
100%|██████████| 7/7 [00:00<00:00, 583.17it/s]
100%|██████████| 7/7 [00:00<00:00, 632.97it/s]
100%|██████████| 7/7 [00:00<00:00, 613.78it/s]
100%|██████████| 7/7 [00:00<00:00, 606.30it/s]
100%|██████████| 100/100 [00:30<00:00,  3.26it/s]


12/48


100%|██████████| 4/4 [00:00<00:00, 531.14it/s]
100%|██████████| 4/4 [00:00<00:00, 560.31it/s]
100%|██████████| 4/4 [00:00<00:00, 360.67it/s]
100%|██████████| 4/4 [00:00<00:00, 448.67it/s]
100%|██████████| 4/4 [00:00<00:00, 470.62it/s]
100%|██████████| 4/4 [00:00<00:00, 337.30it/s]
100%|██████████| 4/4 [00:00<00:00, 443.77it/s]
100%|██████████| 4/4 [00:00<00:00, 487.51it/s]
100%|██████████| 100/100 [00:30<00:00,  3.28it/s]


13/48


100%|██████████| 4/4 [00:00<00:00, 590.27it/s]
100%|██████████| 4/4 [00:00<00:00, 627.63it/s]
100%|██████████| 4/4 [00:00<00:00, 578.05it/s]
100%|██████████| 4/4 [00:00<00:00, 640.06it/s]
100%|██████████| 4/4 [00:00<00:00, 597.88it/s]
100%|██████████| 4/4 [00:00<00:00, 559.86it/s]
100%|██████████| 4/4 [00:00<00:00, 589.94it/s]
100%|██████████| 4/4 [00:00<00:00, 585.65it/s]
100%|██████████| 100/100 [00:30<00:00,  3.30it/s]


14/48


100%|██████████| 4/4 [00:00<00:00, 511.55it/s]
100%|██████████| 4/4 [00:00<00:00, 514.64it/s]
100%|██████████| 4/4 [00:00<00:00, 483.23it/s]
100%|██████████| 4/4 [00:00<00:00, 561.71it/s]
100%|██████████| 4/4 [00:00<00:00, 509.65it/s]
100%|██████████| 4/4 [00:00<00:00, 482.70it/s]
100%|██████████| 4/4 [00:00<00:00, 501.37it/s]
100%|██████████| 4/4 [00:00<00:00, 501.80it/s]
100%|██████████| 100/100 [00:30<00:00,  3.26it/s]


15/48


100%|██████████| 7/7 [00:00<00:00, 647.37it/s]
100%|██████████| 7/7 [00:00<00:00, 697.27it/s]
100%|██████████| 7/7 [00:00<00:00, 688.07it/s]
100%|██████████| 7/7 [00:00<00:00, 680.59it/s]
100%|██████████| 7/7 [00:00<00:00, 739.85it/s]
100%|██████████| 7/7 [00:00<00:00, 727.33it/s]
100%|██████████| 7/7 [00:00<00:00, 662.70it/s]
100%|██████████| 7/7 [00:00<00:00, 586.44it/s]
100%|██████████| 100/100 [00:30<00:00,  3.31it/s]


16/48


100%|██████████| 7/7 [00:00<00:00, 603.36it/s]
100%|██████████| 7/7 [00:00<00:00, 675.35it/s]
100%|██████████| 7/7 [00:00<00:00, 552.83it/s]
100%|██████████| 7/7 [00:00<00:00, 578.59it/s]
100%|██████████| 7/7 [00:00<00:00, 541.69it/s]
100%|██████████| 7/7 [00:00<00:00, 683.05it/s]
100%|██████████| 7/7 [00:00<00:00, 586.43it/s]
100%|██████████| 7/7 [00:00<00:00, 537.67it/s]
100%|██████████| 100/100 [00:30<00:00,  3.29it/s]


17/48


100%|██████████| 7/7 [00:00<00:00, 542.19it/s]
100%|██████████| 7/7 [00:00<00:00, 503.54it/s]
100%|██████████| 7/7 [00:00<00:00, 545.98it/s]
100%|██████████| 7/7 [00:00<00:00, 544.67it/s]
100%|██████████| 7/7 [00:00<00:00, 539.02it/s]
100%|██████████| 7/7 [00:00<00:00, 557.02it/s]
100%|██████████| 7/7 [00:00<00:00, 578.84it/s]
100%|██████████| 7/7 [00:00<00:00, 562.24it/s]
100%|██████████| 100/100 [00:30<00:00,  3.27it/s]


18/48


100%|██████████| 10/10 [00:00<00:00, 712.76it/s]
100%|██████████| 10/10 [00:00<00:00, 772.70it/s]
100%|██████████| 10/10 [00:00<00:00, 721.85it/s]
100%|██████████| 10/10 [00:00<00:00, 722.25it/s]
100%|██████████| 10/10 [00:00<00:00, 733.13it/s]
100%|██████████| 10/10 [00:00<00:00, 756.59it/s]
100%|██████████| 10/10 [00:00<00:00, 754.47it/s]
100%|██████████| 10/10 [00:00<00:00, 717.47it/s]
100%|██████████| 100/100 [00:29<00:00,  3.34it/s]


19/48


100%|██████████| 10/10 [00:00<00:00, 658.04it/s]
100%|██████████| 10/10 [00:00<00:00, 722.69it/s]
100%|██████████| 10/10 [00:00<00:00, 729.88it/s]
100%|██████████| 10/10 [00:00<00:00, 645.46it/s]
100%|██████████| 10/10 [00:00<00:00, 668.35it/s]
100%|██████████| 10/10 [00:00<00:00, 634.53it/s]
100%|██████████| 10/10 [00:00<00:00, 724.05it/s]
100%|██████████| 10/10 [00:00<00:00, 746.21it/s]
100%|██████████| 100/100 [00:30<00:00,  3.32it/s]


20/48


100%|██████████| 10/10 [00:00<00:00, 553.29it/s]
100%|██████████| 10/10 [00:00<00:00, 569.83it/s]
100%|██████████| 10/10 [00:00<00:00, 591.47it/s]
100%|██████████| 10/10 [00:00<00:00, 650.17it/s]
100%|██████████| 10/10 [00:00<00:00, 671.96it/s]
100%|██████████| 10/10 [00:00<00:00, 655.60it/s]
100%|██████████| 10/10 [00:00<00:00, 659.62it/s]
100%|██████████| 10/10 [00:00<00:00, 651.99it/s]
100%|██████████| 100/100 [00:30<00:00,  3.28it/s]


21/48


100%|██████████| 13/13 [00:00<00:00, 711.00it/s]
100%|██████████| 13/13 [00:00<00:00, 737.28it/s]
100%|██████████| 13/13 [00:00<00:00, 749.39it/s]
100%|██████████| 13/13 [00:00<00:00, 730.78it/s]
100%|██████████| 13/13 [00:00<00:00, 789.44it/s]
100%|██████████| 13/13 [00:00<00:00, 782.88it/s]
100%|██████████| 13/13 [00:00<00:00, 766.90it/s]
100%|██████████| 13/13 [00:00<00:00, 601.35it/s]
100%|██████████| 100/100 [00:30<00:00,  3.32it/s]


22/48


100%|██████████| 13/13 [00:00<00:00, 703.19it/s]
100%|██████████| 13/13 [00:00<00:00, 718.60it/s]
100%|██████████| 13/13 [00:00<00:00, 688.66it/s]
100%|██████████| 13/13 [00:00<00:00, 675.25it/s]
100%|██████████| 13/13 [00:00<00:00, 711.40it/s]
100%|██████████| 13/13 [00:00<00:00, 751.45it/s]
100%|██████████| 13/13 [00:00<00:00, 701.37it/s]
100%|██████████| 13/13 [00:00<00:00, 691.57it/s]
100%|██████████| 100/100 [00:30<00:00,  3.33it/s]


23/48


100%|██████████| 13/13 [00:00<00:00, 630.40it/s]
100%|██████████| 13/13 [00:00<00:00, 667.93it/s]
100%|██████████| 13/13 [00:00<00:00, 640.98it/s]
100%|██████████| 13/13 [00:00<00:00, 634.57it/s]
100%|██████████| 13/13 [00:00<00:00, 620.98it/s]
100%|██████████| 13/13 [00:00<00:00, 655.27it/s]
100%|██████████| 13/13 [00:00<00:00, 638.95it/s]
100%|██████████| 13/13 [00:00<00:00, 607.32it/s]
100%|██████████| 100/100 [00:30<00:00,  3.27it/s]


24/48


100%|██████████| 32/32 [00:00<00:00, 716.58it/s]
100%|██████████| 32/32 [00:00<00:00, 807.97it/s]
100%|██████████| 32/32 [00:00<00:00, 840.12it/s]
100%|██████████| 32/32 [00:00<00:00, 831.68it/s]
100%|██████████| 32/32 [00:00<00:00, 809.28it/s]
100%|██████████| 32/32 [00:00<00:00, 834.85it/s]
100%|██████████| 32/32 [00:00<00:00, 799.70it/s]
100%|██████████| 32/32 [00:00<00:00, 844.29it/s]
100%|██████████| 100/100 [00:29<00:00,  3.34it/s]


25/48


100%|██████████| 32/32 [00:00<00:00, 804.52it/s]
100%|██████████| 32/32 [00:00<00:00, 812.73it/s]
100%|██████████| 32/32 [00:00<00:00, 838.80it/s]
100%|██████████| 32/32 [00:00<00:00, 795.73it/s]
100%|██████████| 32/32 [00:00<00:00, 808.02it/s]
100%|██████████| 32/32 [00:00<00:00, 822.42it/s]
100%|██████████| 32/32 [00:00<00:00, 757.64it/s]
100%|██████████| 32/32 [00:00<00:00, 773.33it/s]
100%|██████████| 100/100 [00:29<00:00,  3.36it/s]


26/48


100%|██████████| 32/32 [00:00<00:00, 701.70it/s]
100%|██████████| 32/32 [00:00<00:00, 723.95it/s]
100%|██████████| 32/32 [00:00<00:00, 738.86it/s]
100%|██████████| 32/32 [00:00<00:00, 724.61it/s]
100%|██████████| 32/32 [00:00<00:00, 744.88it/s]
100%|██████████| 32/32 [00:00<00:00, 731.67it/s]
100%|██████████| 32/32 [00:00<00:00, 726.71it/s]
100%|██████████| 32/32 [00:00<00:00, 733.31it/s]
100%|██████████| 100/100 [00:30<00:00,  3.33it/s]


27/48


100%|██████████| 63/63 [00:00<00:00, 861.48it/s]
100%|██████████| 63/63 [00:00<00:00, 855.22it/s]
100%|██████████| 63/63 [00:00<00:00, 833.10it/s]
100%|██████████| 63/63 [00:00<00:00, 840.28it/s]
100%|██████████| 63/63 [00:00<00:00, 809.92it/s]
100%|██████████| 63/63 [00:00<00:00, 815.01it/s]
100%|██████████| 63/63 [00:00<00:00, 832.38it/s]
100%|██████████| 63/63 [00:00<00:00, 855.28it/s]
100%|██████████| 100/100 [00:29<00:00,  3.34it/s]


28/48


100%|██████████| 63/63 [00:00<00:00, 806.84it/s]
100%|██████████| 63/63 [00:00<00:00, 800.10it/s]
100%|██████████| 63/63 [00:00<00:00, 815.17it/s]
100%|██████████| 63/63 [00:00<00:00, 794.79it/s]
100%|██████████| 63/63 [00:00<00:00, 822.10it/s]
100%|██████████| 63/63 [00:00<00:00, 833.65it/s]
100%|██████████| 63/63 [00:00<00:00, 823.76it/s]
100%|██████████| 63/63 [00:00<00:00, 830.18it/s]
100%|██████████| 100/100 [00:30<00:00,  3.33it/s]


29/48


100%|██████████| 63/63 [00:00<00:00, 706.35it/s]
100%|██████████| 63/63 [00:00<00:00, 690.12it/s]
100%|██████████| 63/63 [00:00<00:00, 725.87it/s]
100%|██████████| 63/63 [00:00<00:00, 760.92it/s]
100%|██████████| 63/63 [00:00<00:00, 750.10it/s]
100%|██████████| 63/63 [00:00<00:00, 737.71it/s]
100%|██████████| 63/63 [00:00<00:00, 747.60it/s]
100%|██████████| 63/63 [00:00<00:00, 730.48it/s]
100%|██████████| 100/100 [00:29<00:00,  3.35it/s]


30/48


100%|██████████| 100/100 [00:00<00:00, 834.56it/s]
100%|██████████| 100/100 [00:00<00:00, 847.04it/s]
100%|██████████| 100/100 [00:00<00:00, 844.60it/s]
100%|██████████| 100/100 [00:00<00:00, 836.06it/s]
100%|██████████| 100/100 [00:00<00:00, 834.34it/s]
100%|██████████| 100/100 [00:00<00:00, 814.88it/s]
100%|██████████| 100/100 [00:00<00:00, 841.05it/s]
100%|██████████| 100/100 [00:00<00:00, 727.55it/s]
100%|██████████| 100/100 [00:29<00:00,  3.35it/s]


31/48


100%|██████████| 100/100 [00:00<00:00, 724.55it/s]
100%|██████████| 100/100 [00:00<00:00, 808.74it/s]
100%|██████████| 100/100 [00:00<00:00, 847.40it/s]
100%|██████████| 100/100 [00:00<00:00, 837.70it/s]
100%|██████████| 100/100 [00:00<00:00, 812.86it/s]
100%|██████████| 100/100 [00:00<00:00, 800.00it/s]
100%|██████████| 100/100 [00:00<00:00, 801.99it/s]
100%|██████████| 100/100 [00:00<00:00, 775.12it/s]
100%|██████████| 100/100 [00:29<00:00,  3.35it/s]


32/48


100%|██████████| 100/100 [00:00<00:00, 719.09it/s]
100%|██████████| 100/100 [00:00<00:00, 735.17it/s]
100%|██████████| 100/100 [00:00<00:00, 716.44it/s]
100%|██████████| 100/100 [00:00<00:00, 744.15it/s]
100%|██████████| 100/100 [00:00<00:00, 726.77it/s]
100%|██████████| 100/100 [00:00<00:00, 719.53it/s]
100%|██████████| 100/100 [00:00<00:00, 722.45it/s]
100%|██████████| 100/100 [00:00<00:00, 712.26it/s]
100%|██████████| 100/100 [00:29<00:00,  3.41it/s]


33/48


100%|██████████| 125/125 [00:00<00:00, 822.06it/s]
100%|██████████| 125/125 [00:00<00:00, 833.88it/s]
100%|██████████| 125/125 [00:00<00:00, 837.85it/s]
100%|██████████| 125/125 [00:00<00:00, 808.81it/s]
100%|██████████| 125/125 [00:00<00:00, 852.35it/s]
100%|██████████| 125/125 [00:00<00:00, 841.43it/s]
100%|██████████| 125/125 [00:00<00:00, 877.82it/s]
100%|██████████| 125/125 [00:00<00:00, 863.07it/s]
100%|██████████| 100/100 [00:29<00:00,  3.34it/s]


34/48


100%|██████████| 125/125 [00:00<00:00, 841.09it/s]
100%|██████████| 125/125 [00:00<00:00, 842.27it/s]
100%|██████████| 125/125 [00:00<00:00, 842.42it/s]
100%|██████████| 125/125 [00:00<00:00, 852.03it/s]
100%|██████████| 125/125 [00:00<00:00, 839.19it/s]
100%|██████████| 125/125 [00:00<00:00, 838.75it/s]
100%|██████████| 125/125 [00:00<00:00, 852.41it/s]
100%|██████████| 125/125 [00:00<00:00, 822.32it/s]
100%|██████████| 100/100 [00:30<00:00,  3.33it/s]


35/48


100%|██████████| 125/125 [00:00<00:00, 736.93it/s]
100%|██████████| 125/125 [00:00<00:00, 746.49it/s]
100%|██████████| 125/125 [00:00<00:00, 758.74it/s]
100%|██████████| 125/125 [00:00<00:00, 775.57it/s]
100%|██████████| 125/125 [00:00<00:00, 742.28it/s]
100%|██████████| 125/125 [00:00<00:00, 730.08it/s]
100%|██████████| 125/125 [00:00<00:00, 702.98it/s]
100%|██████████| 125/125 [00:00<00:00, 719.56it/s]
100%|██████████| 100/100 [00:29<00:00,  3.39it/s]


36/48


100%|██████████| 313/313 [00:00<00:00, 821.61it/s]
100%|██████████| 313/313 [00:00<00:00, 854.35it/s]
100%|██████████| 313/313 [00:00<00:00, 847.14it/s]
100%|██████████| 313/313 [00:00<00:00, 854.90it/s]
100%|██████████| 313/313 [00:00<00:00, 854.25it/s]
100%|██████████| 313/313 [00:00<00:00, 846.38it/s]
100%|██████████| 313/313 [00:00<00:00, 845.36it/s]
100%|██████████| 313/313 [00:00<00:00, 838.77it/s]
100%|██████████| 100/100 [00:30<00:00,  3.32it/s]


37/48


100%|██████████| 313/313 [00:00<00:00, 827.46it/s]
100%|██████████| 313/313 [00:00<00:00, 836.38it/s]
100%|██████████| 313/313 [00:00<00:00, 832.45it/s]
100%|██████████| 313/313 [00:00<00:00, 835.30it/s]
100%|██████████| 313/313 [00:00<00:00, 814.41it/s]
100%|██████████| 313/313 [00:00<00:00, 760.63it/s]
100%|██████████| 313/313 [00:00<00:00, 810.18it/s]
100%|██████████| 313/313 [00:00<00:00, 833.12it/s]
100%|██████████| 100/100 [00:29<00:00,  3.34it/s]


38/48


100%|██████████| 313/313 [00:00<00:00, 751.58it/s]
100%|██████████| 313/313 [00:00<00:00, 752.25it/s]
100%|██████████| 313/313 [00:00<00:00, 743.68it/s]
100%|██████████| 313/313 [00:00<00:00, 736.74it/s]
100%|██████████| 313/313 [00:00<00:00, 738.99it/s]
100%|██████████| 313/313 [00:00<00:00, 744.02it/s]
100%|██████████| 313/313 [00:00<00:00, 743.74it/s]
100%|██████████| 313/313 [00:00<00:00, 737.55it/s]
100%|██████████| 100/100 [00:29<00:00,  3.43it/s]


39/48


100%|██████████| 625/625 [00:00<00:00, 847.29it/s]
100%|██████████| 625/625 [00:00<00:00, 856.42it/s]
100%|██████████| 625/625 [00:00<00:00, 855.22it/s]
100%|██████████| 625/625 [00:00<00:00, 858.43it/s]
100%|██████████| 625/625 [00:00<00:00, 863.76it/s]
100%|██████████| 625/625 [00:00<00:00, 858.64it/s]
100%|██████████| 625/625 [00:00<00:00, 858.80it/s]
100%|██████████| 625/625 [00:00<00:00, 868.82it/s]
100%|██████████| 100/100 [00:29<00:00,  3.34it/s]


40/48


100%|██████████| 625/625 [00:00<00:00, 815.16it/s]
100%|██████████| 625/625 [00:00<00:00, 846.91it/s]
100%|██████████| 625/625 [00:00<00:00, 831.96it/s]
100%|██████████| 625/625 [00:00<00:00, 830.44it/s]
100%|██████████| 625/625 [00:00<00:00, 826.12it/s]
100%|██████████| 625/625 [00:00<00:00, 836.93it/s]
100%|██████████| 625/625 [00:00<00:00, 807.95it/s]
100%|██████████| 625/625 [00:00<00:00, 765.90it/s]
100%|██████████| 100/100 [00:29<00:00,  3.42it/s]


41/48


100%|██████████| 625/625 [00:00<00:00, 675.43it/s]
100%|██████████| 625/625 [00:00<00:00, 706.18it/s]
100%|██████████| 625/625 [00:00<00:00, 727.50it/s]
100%|██████████| 625/625 [00:00<00:00, 718.46it/s]
100%|██████████| 625/625 [00:00<00:00, 719.66it/s]
100%|██████████| 625/625 [00:00<00:00, 704.54it/s]
100%|██████████| 625/625 [00:00<00:00, 726.03it/s]
100%|██████████| 625/625 [00:00<00:00, 719.92it/s]
100%|██████████| 100/100 [00:27<00:00,  3.60it/s]


42/48


100%|██████████| 1000/1000 [00:01<00:00, 776.94it/s]
100%|██████████| 1000/1000 [00:01<00:00, 786.70it/s]
100%|██████████| 1000/1000 [00:01<00:00, 759.54it/s]
100%|██████████| 1000/1000 [00:01<00:00, 802.59it/s]
100%|██████████| 1000/1000 [00:01<00:00, 819.08it/s]
100%|██████████| 1000/1000 [00:01<00:00, 842.07it/s]
100%|██████████| 1000/1000 [00:01<00:00, 828.84it/s]
100%|██████████| 1000/1000 [00:01<00:00, 828.41it/s]
100%|██████████| 100/100 [00:29<00:00,  3.37it/s]


43/48


100%|██████████| 1000/1000 [00:01<00:00, 803.13it/s]
100%|██████████| 1000/1000 [00:01<00:00, 840.87it/s]
100%|██████████| 1000/1000 [00:01<00:00, 839.49it/s]
100%|██████████| 1000/1000 [00:01<00:00, 840.34it/s]
100%|██████████| 1000/1000 [00:01<00:00, 835.00it/s]
100%|██████████| 1000/1000 [00:01<00:00, 835.19it/s]
100%|██████████| 1000/1000 [00:01<00:00, 833.41it/s]
100%|██████████| 1000/1000 [00:01<00:00, 823.48it/s]
100%|██████████| 100/100 [00:27<00:00,  3.67it/s]


44/48


100%|██████████| 1000/1000 [00:01<00:00, 717.57it/s]
100%|██████████| 1000/1000 [00:01<00:00, 716.31it/s]
100%|██████████| 1000/1000 [00:01<00:00, 697.62it/s]
100%|██████████| 1000/1000 [00:01<00:00, 707.69it/s]
100%|██████████| 1000/1000 [00:01<00:00, 714.37it/s]
100%|██████████| 1000/1000 [00:01<00:00, 710.88it/s]
100%|██████████| 1000/1000 [00:01<00:00, 703.93it/s]
100%|██████████| 1000/1000 [00:01<00:00, 726.12it/s]
100%|██████████| 100/100 [00:26<00:00,  3.81it/s]


45/48


100%|██████████| 1250/1250 [00:01<00:00, 840.07it/s]
100%|██████████| 1250/1250 [00:01<00:00, 844.84it/s]
100%|██████████| 1250/1250 [00:01<00:00, 852.22it/s]
100%|██████████| 1250/1250 [00:01<00:00, 855.65it/s]
100%|██████████| 1250/1250 [00:01<00:00, 832.35it/s]
100%|██████████| 1250/1250 [00:01<00:00, 844.30it/s]
100%|██████████| 1250/1250 [00:01<00:00, 844.87it/s]
100%|██████████| 1250/1250 [00:01<00:00, 863.22it/s]
100%|██████████| 100/100 [00:29<00:00,  3.34it/s]


46/48


100%|██████████| 1250/1250 [00:01<00:00, 852.23it/s]
100%|██████████| 1250/1250 [00:01<00:00, 849.42it/s]
100%|██████████| 1250/1250 [00:01<00:00, 832.87it/s]
100%|██████████| 1250/1250 [00:01<00:00, 837.99it/s]
100%|██████████| 1250/1250 [00:01<00:00, 846.40it/s]
100%|██████████| 1250/1250 [00:01<00:00, 848.44it/s]
100%|██████████| 1250/1250 [00:01<00:00, 846.80it/s]
100%|██████████| 1250/1250 [00:01<00:00, 847.22it/s]
100%|██████████| 100/100 [00:26<00:00,  3.76it/s]


47/48


100%|██████████| 1250/1250 [00:01<00:00, 734.16it/s]
100%|██████████| 1250/1250 [00:01<00:00, 749.53it/s]
100%|██████████| 1250/1250 [00:01<00:00, 745.57it/s]
100%|██████████| 1250/1250 [00:01<00:00, 730.23it/s]
100%|██████████| 1250/1250 [00:01<00:00, 744.07it/s]
100%|██████████| 1250/1250 [00:01<00:00, 747.32it/s]
100%|██████████| 1250/1250 [00:01<00:00, 743.12it/s]
100%|██████████| 1250/1250 [00:01<00:00, 759.36it/s]
100%|██████████| 100/100 [00:25<00:00,  3.89it/s]

48/48





In [None]:
import numpy as np
from cayleypy import Predictor
from tqdm import tqdm

test_start_states = [np.random.permutation(n_pancakes) for _ in range(100)]

beam_size_range = [1, 10, 10**2]
hamming_success_rates = []
nn_success_rates = []
for beam_size in beam_size_range:
    hamming_success_rates.append(measure_success(Predictor(graph, "hamming"), beam_size))
    nn_success_rates.append(measure_success(Predictor(graph, scorer), beam_size))

In [None]:
from matplotlib import pyplot as plt

plt.plot(beam_size_range, hamming_success_rates, label="Hamming distance", marker='.')
plt.plot(beam_size_range, nn_success_rates, label="Neural network", marker='.')
plt.xlabel('Beam size')
plt.ylabel('Success rate')
plt.legend()
plt.xscale('log')
plt.title("Beam search for LRX(12)")
plt.show()

# Prediction

In [23]:
import numpy as np
from cayleypy import Predictor
from tqdm import tqdm

graphs = {}
models = {}
for n_s in test["n"].unique():
    graphs[n_s] = CayleyGraph(PermutationGroups.pancake(n_s))
    models[n_s] = Predictor(graphs[n_s], train_predictor(n_s, graph=graphs[n_s], width=n_s*100, length=int(n_s*1.6), batch_size=32, epochs=4))

100%|██████████| 125/125 [00:00<00:00, 658.20it/s]
100%|██████████| 125/125 [00:00<00:00, 676.80it/s]
100%|██████████| 125/125 [00:00<00:00, 670.13it/s]
100%|██████████| 125/125 [00:00<00:00, 662.92it/s]
100%|██████████| 713/713 [00:01<00:00, 671.87it/s]
100%|██████████| 713/713 [00:01<00:00, 672.34it/s]
100%|██████████| 713/713 [00:01<00:00, 647.28it/s]
100%|██████████| 713/713 [00:01<00:00, 675.47it/s]
100%|██████████| 1125/1125 [00:01<00:00, 676.58it/s]
100%|██████████| 1125/1125 [00:01<00:00, 679.75it/s]
100%|██████████| 1125/1125 [00:01<00:00, 667.46it/s]
100%|██████████| 1125/1125 [00:01<00:00, 654.57it/s]
100%|██████████| 1250/1250 [00:01<00:00, 657.41it/s]
100%|██████████| 1250/1250 [00:01<00:00, 670.92it/s]
100%|██████████| 1250/1250 [00:01<00:00, 673.74it/s]
100%|██████████| 1250/1250 [00:01<00:00, 666.15it/s]
100%|██████████| 2000/2000 [00:02<00:00, 667.30it/s]
100%|██████████| 2000/2000 [00:03<00:00, 658.49it/s]
100%|██████████| 2000/2000 [00:02<00:00, 677.35it/s]
100%|████

In [9]:
def pancake_sort_path(perm: list[int]) -> list[str]:
    """Return a sequence of prefix reversals that sorts `perm` to the identity permutation."""
    arr = list(perm)
    n = len(arr)
    moves: list[str] = []

    for target in range(n, 1, -1):
        desired_value = target - 1
        idx = arr.index(desired_value)

        if idx == target - 1:
            continue  # already in place

        if idx != 0:
            moves.append(f'R{idx + 1}')
            arr[: idx + 1] = reversed(arr[: idx + 1])

        moves.append(f'R{target}')
        arr[:target] = reversed(arr[:target])

    return moves

In [10]:
import numpy as np

heurestic_paths = []
for _, row in tqdm(test.iterrows(), total=len(test)):
    perms = np.array(row["permutation"].split(",")).astype(int)
    moves = pancake_sort_path(perms)
    heurestic_paths.append(".".join(moves))

100%|██████████| 2405/2405 [00:00<00:00, 5771.70it/s] 


In [31]:
pred_paths = []

for i, row in tqdm(test.iterrows(), total=len(test)):
    n = row["n"]
    if n >= 20:
        pred_paths.append(heurestic_paths[i])
        continue
    heurestic_length = heurestic_paths[i].count(".") + 1
    perms = np.array(row["permutation"].split(",")).astype(int)
    graphs[n].free_memory()
    result = graphs[n].beam_search(start_state=perms, beam_width=1000, max_steps=heurestic_length, predictor=models[n], return_path=True)
    if result.path_found and len(result.path) < heurestic_length:
        pred_paths.append(".".join([f"R{gen_index+2}" for gen_index in result.path]))
    else:
        pred_paths.append(heurestic_paths[i])

100%|██████████| 2405/2405 [06:16<00:00,  6.39it/s]


In [32]:
submissions = pd.read_csv("/kaggle/input/CayleyPy-pancake/sample_submission.csv")
submissions["solution"] = pred_paths
submissions.to_csv("nn_submission.csv")