Implementation of the puzzle permutation mechanisms using JAX

In [1]:
# Toggle IPython pretty-printing of cell outputs (the recorded output shows it being switched OFF).
%pprint

Pretty printing has been turned OFF


In [2]:
import json

import jax

import numpy as np
import pandas as pd
import jax.numpy as jnp 

from typing import List, Any, Tuple, Dict
from numpy import ndarray

In [3]:
def permute(state: List[Any], perm: List[int]) -> List[Any]:
    """Return a new state in which position ``i`` holds ``state[perm[i]]``.

    The input ``state`` is left untouched.

    Args:
        state: current puzzle state (a list or a NumPy array).
        perm: for every position, the index to read the new value from.
            Must have the same length as ``state``.

    Returns:
        A copy of ``state`` with the permutation applied.

    Raises:
        AssertionError: if ``state`` and ``perm`` differ in length.
    """
    assert len(state) == len(perm)
    # ``state[:]`` only copies lists; on a NumPy array it is a view, so the
    # writes below would clobber entries that still have to be read.
    new_state = state.copy()
    for i, j in enumerate(perm):
        new_state[i] = state[j]
    return new_state


def permute_with_swap(state: List[Any], swaps: List[Tuple[int, int]]) -> List[Any]:
    """Apply a sparse permutation given as ``(target, source)`` index pairs.

    For every pair ``(i, j)`` the result holds ``state[j]`` at position ``i``;
    positions not named as a target keep their value. Despite the name the
    pairs are one-way assignments, not exchanges. The input ``state`` is left
    untouched.

    Args:
        state: current puzzle state (a list or a NumPy array).
        swaps: ``(i, j)`` pairs, e.g. as produced by ``perm_to_swap``.

    Returns:
        A copy of ``state`` with the assignments applied.
    """
    # ``state[:]`` only copies lists; on a NumPy array (e.g. a row of a batch)
    # it is a view, so later reads of ``state[j]`` would see earlier writes.
    new_state = state.copy()
    for i, j in swaps:
        new_state[i] = state[j]
    return new_state


def reverse_perm(perm: List[int]) -> List[int]:
    """Build the inverse mapping of ``perm``.

    Whenever ``perm[i] == j`` the result holds ``i`` at index ``j``. Indices
    that never occur as a value of ``perm`` keep the entry copied from ``perm``.
    """
    inverse = perm[:]
    for position in range(len(perm)):
        inverse[perm[position]] = position
    return inverse


def perm_to_swap(perm: List[int]) -> List[Tuple[int, int]]:
    """Keep only the moving entries of a permutation.

    Returns the ``(index, value)`` pairs of ``perm`` whose value differs from
    the index, in index order; fixed points are dropped.
    """
    return [(index, source) for index, source in enumerate(perm) if index != source]

In [101]:
def load_puzzle_moves(
    puzzle_name: str, convert_to_swaps=True
) -> Tuple[Dict[str, List[Any]], List[int]]:
    """Load the move table and a label-encoded reference state of a puzzle.

    Reads ``puzzles/{puzzle_name}/moves.json`` and the first data row of
    ``puzzles/{puzzle_name}/puzzles.csv``.

    Args:
        puzzle_name: name of the sub-directory of ``puzzles/`` to read.
        convert_to_swaps: when true, every permutation is passed through
            ``perm_to_swap`` and stored as a list of ``(i, j)`` pairs;
            otherwise the index lists are kept as loaded.

    Returns:
        ``(moves, state)``. ``moves`` maps each move name to its permutation
        and additionally holds a ``"-<name>"`` entry with the inverse of every
        move that is not its own inverse. ``state`` is the ``;``-separated
        field at column index 3 of the first CSV row, with each distinct
        symbol replaced by an integer label assigned in order of first
        appearance (0, 1, 2, ...).
    """
    # load the moves:
    with open(f"puzzles/{puzzle_name}/moves.json") as f:
        moves = json.load(f)

    # add reversed moves (self-inverse moves need no separate entry)
    reversed_moves = {}
    for move_name, perm in moves.items():
        reversed_perm = reverse_perm(perm)
        if reversed_perm == perm:
            continue
        reversed_moves[f"-{move_name}"] = reversed_perm

    moves.update(reversed_moves)

    if convert_to_swaps:
        # build a fresh dict instead of rewriting values while iterating
        moves = {move_name: perm_to_swap(perm) for move_name, perm in moves.items()}

    # get final position (from the first puzzle), note that the actual state of this position doesn't really matter
    # we just need to get the structure of the puzzle, so only the first row is parsed
    df = pd.read_csv(f"puzzles/{puzzle_name}/puzzles.csv", nrows=1)
    symbols = df.iloc[0].to_numpy()[3].split(";")

    # label-encode: a symbol seen for the first time gets the next free integer
    mapping = {}
    state = [mapping.setdefault(symbol, len(mapping)) for symbol in symbols]

    return moves, state


# Smoke test of the loader on a small puzzle. Everything assigned here is a
# notebook-level global that later cells overwrite with another puzzle.
puzzle_name = "cube_2x2x2"
move_dict, final_state = load_puzzle_moves(puzzle_name)
# number of positions in the state vector (not the number of reachable configurations)
num_states = len(final_state)
move_names = np.array(list(move_dict.keys()))

# Alternative kept for reference: give every position its own label.
# final_state = list(range(num_states))

print(f"Loaded {puzzle_name} with {len(move_names)} moves and {num_states} states")

Loaded cube_2x2x2 with 12 moves and 24 states


In [102]:
class Sampler:
    """Base class for samplers that produce an integer on each ``sample()`` call.

    In this notebook the sampled integer is used as the number of moves to
    draw. Subclasses must override ``sample``.
    """

    def __init__(self) -> None:
        pass

    def sample(self) -> int:
        """Return the next sample.

        Raises:
            NotImplementedError: always; the base class has nothing to sample
                from. Failing loudly avoids handing ``None`` to callers that
                expect an ``int``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement sample()"
        )


class Uniform_sampler(Sampler):
    """Sampler drawing integers with ``np.random.randint(low, high)``.

    ``np.random.randint`` includes ``low`` and excludes ``high``.
    """

    def __init__(self, low, high) -> None:
        super().__init__()
        self.low, self.high = low, high

    def sample(self) -> int:
        """Draw one integer from the configured range."""
        lower, upper = self.low, self.high
        return np.random.randint(lower, upper)


class Constant_sampler(Sampler):
    """Sampler that returns the same stored value on every call."""

    def __init__(self, n) -> None:
        super().__init__()
        # value handed back unchanged by ``sample``
        self.n = n

    def sample(self) -> int:
        """Return the value given at construction time."""
        fixed = self.n
        return fixed

In [112]:
from src.mechanism.reduce import iterate_reduce_sequence

# Switch to a larger puzzle (moves stored as (i, j) pairs, the loader default)
# and build a batch of identical starting states for the list-based benchmark.
puzzle_name = "cube_10x10x10"
move_dict, final_state = load_puzzle_moves(puzzle_name)
num_states = len(final_state)
move_names = np.array(list(move_dict.keys()))
b = 100  # batch size
# get a batch of states:
state = np.expand_dims(np.array(final_state), 0)  # shape (1, num_states)
states = np.repeat(state, b, 0)  # shape (b, num_states): b copies of the same row


def sample_moves(sampler, puzzle_name, move_names):
    """Draw a random move sequence and pass it through ``iterate_reduce_sequence``.

    The sequence length comes from ``sampler.sample()``; the moves are then
    picked from ``move_names`` with ``np.random.choice`` (uniform, with
    replacement by default). Returns whatever
    ``iterate_reduce_sequence(moves, puzzle_name)`` returns.

    NOTE: a later cell of this notebook defines another ``sample_moves`` with a
    different signature, which shadows this one once executed.
    """
    length = sampler.sample()
    drawn = np.random.choice(move_names, length)
    return iterate_reduce_sequence(drawn, puzzle_name)


# np.random.randint excludes its upper bound, so Uniform_sampler(4, 5) always yields 4
sampler = Uniform_sampler(4, 5)
# sampling moves is done with regular lists, since at each step the batch should be small
moves = sample_moves(sampler, puzzle_name, move_names) # presumably a list; depends on iterate_reduce_sequence
# replicate the single sampled sequence b times along a new leading axis
moves = np.expand_dims(np.array(moves), 0)
moves = np.repeat(moves, b, 0)

# one action per batch entry: the first move name, repeated b times
next_move = np.array([move_names[0]]*b)


In [113]:
def env_step(state, action):
    """Advance one state by one action.

    ``action`` is a list of ``(i, j)`` pairs and is applied with
    ``permute_with_swap``; the resulting state is returned.
    """
    return permute_with_swap(state, action)


def batch_env_step(states, actions, move_dict):
    """Step every state of a batch, writing the results back into ``states``.

    ``actions[k]`` is a move name that is looked up in ``move_dict`` and
    applied to ``states[k]`` through ``env_step``. ``states`` is updated in
    place and nothing is returned. Iteration stops at the shorter of
    ``states`` and ``actions``.
    """
    for row, pair in enumerate(zip(states, actions)):
        current, move_name = pair
        states[row] = env_step(current, move_dict[move_name])

# NOTE: batch_env_step updates `states` in place, so every timed call starts
# from the states left behind by the previous call.
%timeit -n 100 batch_env_step(states, next_move, move_dict)

1.14 ms ± 91.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [114]:
from jax import vmap, pmap
from functools import partial

# Reload the same puzzle for the JAX benchmark, this time keeping every move as
# a full index list (convert_to_swaps=False) so it can be used as gather indices.
puzzle_name = "cube_10x10x10"
move_dict, final_state = load_puzzle_moves(puzzle_name, convert_to_swaps=False)
num_states = len(final_state)
move_names = np.array(list(move_dict.keys()))
b = 100  # batch size
state = np.expand_dims(np.array(final_state), 0)  # shape (1, num_states)
states = np.repeat(state, b, 0)  # shape (b, num_states)

def jax_permute(state, perm):
    """Gather ``state`` at the indices ``perm`` using ``jnp.take``.

    Position ``i`` of the result holds ``state[perm[i]]``.
    ``unique_indices=True`` tells JAX that ``perm`` has no repeated index.
    """
    gathered = jnp.take(state, perm, unique_indices=True)
    return gathered


def env_step(state, action):
    """Advance one state by one action, where ``action`` is a full index array.

    Delegates to ``jax_permute``. NOTE: this redefines the ``env_step`` from an
    earlier cell; once this cell has run, ``batch_env_step`` resolves to this
    version as well.
    """
    next_state = jax_permute(state, action)
    return next_state


def jax_batch_env_step(states, actions):
    """Apply ``env_step`` to every (state, action) pair of a batch.

    ``vmap`` maps ``env_step`` over the leading axis of both arguments and
    returns the stacked results; the inputs are not modified.
    """
    batched_step = vmap(env_step)
    return batched_step(states, actions)


# Index lists of all moves, in the insertion order of move_dict.
action_map = list(move_dict.values())

# Benchmark input: the same move (the second entry) for every batch element.
move = action_map[1]
j_next_moves = jnp.array([move] * b)
j_states = jnp.array(states)

# single-state sanity check, left disabled:
# env_step(j_states[0], j_next_moves[0])

# NOTE(review): the function is not jitted and the result is not waited on
# (no block_until_ready), so this timing presumably mixes tracing overhead with
# asynchronous dispatch and is not directly comparable to the list-based run.
# TODO confirm.
%timeit -n 1 jax_batch_env_step(j_states, j_next_moves)
# abandoned experiments (jax_permute_with_swap is not defined anywhere in the visible notebook):
# jax_permute_with_swap(j_states[0], j_next_moves[0])
# # jax.lax.ppermute(j_states[0], "p", perm=j_next_moves[0])

The slowest run took 22.95 times longer than the fastest. This could mean that an intermediate result is being cached.
3.26 ms ± 5.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
from src.mechanism.utils import get_inverse_move

def sample_moves(move_names: List[str], n: int) -> np.ndarray:
    """Draw ``n`` move names uniformly at random, with replacement.

    Args:
        move_names: the candidate move names.
        n: how many names to draw.

    Returns:
        A one-dimensional NumPy array holding the ``n`` drawn names (not a
        list of ints, as the previous annotation claimed).

    NOTE: an earlier cell of this notebook defines ``sample_moves`` with the
    signature ``(sampler, puzzle_name, move_names)``; this definition shadows
    it once the cell has run.
    """
    return np.random.choice(move_names, n)


def generate_state_from_moves(move_names, move_dict, state, inverse=False):
    """Apply a sequence of named moves to ``state`` and return the result.

    Each name is looked up in ``move_dict`` and applied with
    ``permute_with_swap``, in the order given. With ``inverse=True`` every name
    is first mapped through ``get_inverse_move``; the order of the sequence is
    NOT reversed.

    NOTE(review): undoing a sequence normally also requires reversing its
    order. Confirm that callers pass an already reversed sequence if that is
    the intent.
    """
    for name in move_names:
        key = get_inverse_move(name) if inverse else name
        state = permute_with_swap(state, move_dict[key])
    return state


def normalize_state(state):
    """Scale the entries of a state.

    Args:
        state: a Python list of numeric labels, or any object that supports
            ``state / len(state)`` (e.g. a NumPy array).

    Returns:
        For a list (or list subclass): a new list with every entry divided by
        ``number of distinct values - 1``. If the list has fewer than two
        distinct values the divisor is clamped to 1, so the entries are
        returned unscaled instead of raising ``ZeroDivisionError``.
        For any other input: ``state / len(state)``.
    """
    if isinstance(state, list):
        # clamp: a single distinct value would otherwise give a divisor of 0
        n = max(len(set(state)) - 1, 1)
        return [s / n for s in state]
    # NOTE(review): this branch scales by the length of the state, while the
    # list branch scales by the distinct-value count. Confirm which is intended.
    return state / len(state)