# Demo for $\mathbf{Z}_3 \wr \mathbf{Z}_3 \wr \mathbf{Z}_3$

In [2]:
import itertools
import pickle
import random

import numba
import numpy as np
import scipy
from numba import cuda
from tqdm.auto import tqdm

from wrepy import (
    CyclicGroupPermutationFactory,
    GeneratorSetFactory,
    Permutation,
    PermutationGroup,
)
from wrepy.cuda import device_fn, kernels, portrait
from wrepy.cuda.utils import kernel_2d_spec

## Util functions

In [3]:
def column_print(p: Permutation):
    """Print the mapping stored in ``p.rule``, one ``source -> image`` line per entry.

    Entries are emitted in the sorted order of the ``(source, image)`` pairs.
    """
    ordered_pairs = sorted(p.rule.items())
    for source, image in ordered_pairs:
        print(f"{source} -> {image}")


def get_cycle(p: Permutation, v: tuple[int, int, int]) -> list[tuple[int, int, int]]:
    """Return the points visited from ``v`` under repeated action of ``p``.

    The action of ``p`` on a point is spelled ``point**p``. Points are collected
    in visiting order, starting with ``v``, until one is seen a second time.
    """
    orbit = []
    current = v
    while True:
        if current in orbit:
            return orbit
        orbit.append(current)
        current = current**p


class SliceIterator:
    """Single-pass iterator over consecutive slices covering ``range(total)``.

    Every slice spans ``batch_size`` indices except possibly the last one,
    which is truncated at ``total``. ``len()`` reports the total number of
    slices regardless of how many have already been consumed.

    Args:
        total: Number of indices to cover.
        batch_size: Maximum width of each slice; must be positive.

    Raises:
        ValueError: If ``batch_size`` is not positive. A non-positive batch
            size would otherwise make iteration never terminate and make
            ``len()`` divide by zero.
    """

    def __init__(
        self,
        total: int,
        batch_size: int,
    ):
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        self.batch_size = batch_size
        self.total = total

        # Start index of the next slice to hand out.
        self.last_idx = 0

    def __iter__(self):
        # The object is its own iterator, so it can be consumed only once.
        return self

    def __next__(self) -> slice:
        start = self.last_idx
        if start >= self.total:
            raise StopIteration

        # Clamp the final slice so it never runs past ``total``.
        stop = min(start + self.batch_size, self.total)
        self.last_idx = start + self.batch_size
        return slice(start, stop)

    def __len__(self):
        # Ceiling division: a partial trailing batch still counts as one slice.
        return -(-self.total // self.batch_size)

## Data preparation

In [4]:
# Build Z_3 and take the wreath product with Z_3 twice, giving the
# three-level group Z_3 wr Z_3 wr Z_3 named in the notebook title.
N = 3
z3 = PermutationGroup(
    set(range(N)),  # define the underlying set as {0, 1, 2}
    CyclicGroupPermutationFactory,  # rule for building the group from the underlying set
)
z3z3 = z3.wreath_product(z3)
z3z3z3 = z3z3.wreath_product(z3)
# Display the resulting group (its repr reports the group order).
z3z3z3

PermutationGroup(order=1594323)

## Running test

In [5]:
# One arity per level of the wreath product (three levels, each of arity 3).
arities = np.array([3, 3, 3], dtype=np.int8)
# presumably one portrait (row) per group element -- TODO confirm against wrepy
portraits = portrait.portrait_array_from_arities(*arities)
# presumably every point of {0, 1, 2}^3 that the group acts on -- TODO confirm
all_points = portrait.get_zn_decart_space(*arities)

# Device copies shared by the kernels launched in the cells below.
arities_cuda = cuda.to_device(arities)
all_points_cuda = cuda.to_device(all_points)

In [6]:
%%time
# calculating orders

# Upload the portraits once and allocate two same-shaped scratch buffers that
# are handed to the kernel alongside them.
portraits_cuda = cuda.to_device(portraits)
tmp_1 = cuda.device_array_like(portraits_cuda)
tmp_2 = cuda.device_array_like(portraits_cuda)

# One order value per portrait, written by the kernel.
orders_cuda = cuda.device_array(len(portraits_cuda), dtype=int)

kernels.order_kernel.forall(len(portraits_cuda))(
    arities_cuda,
    # Pass the device copy: a host array would be implicitly uploaded and then
    # copied back after the launch, leaving `portraits_cuda` unused.
    portraits_cuda,
    tmp_1,
    tmp_2,
    all_points_cuda,
    orders_cuda,
)

# Bring the orders back to the host and print how many elements have each order.
orders = orders_cuda.copy_to_host()
for label, counts in zip(*np.unique(orders, return_counts=True)):
    print(f"Order: {label}: counts {counts}")



Order: 1: counts 1
Order: 3: counts 104246
Order: 9: counts 1017684
Order: 27: counts 472392
CPU times: user 283 ms, sys: 9.55 ms, total: 292 ms
Wall time: 290 ms


In [7]:
# Keep only the portraits whose computed order is exactly 3.
order_3_portraits = portraits[orders == 3]
order_3_portraits.shape

(104246, 13)

In [8]:
order_3_portraits[42]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 1], dtype=int8)

In [9]:
%%time
# Count the fixed points of every order-3 portrait.
order_3_portraits_cuda = cuda.to_device(order_3_portraits)
n_fixed_points_cuda = cuda.device_array(len(order_3_portraits), dtype=int)

# Use the device copy of the arities: a host array would trigger an implicit
# upload/download around the launch even though `arities_cuda` already exists.
kernels.n_fixed_points_kernel.forall(len(order_3_portraits))(
    arities_cuda, order_3_portraits_cuda, all_points_cuda, n_fixed_points_cuda
)

# Copy the counts back and print the histogram of fixed-point counts.
n_fixed_points = n_fixed_points_cuda.copy_to_host()
for label, counts in zip(*np.unique(n_fixed_points, return_counts=True)):
    print(f"n_fixed_points: {label}: counts {counts}")

n_fixed_points: 0: counts 30698
n_fixed_points: 3: counts 24336
n_fixed_points: 6: counts 23400
n_fixed_points: 9: counts 14988
n_fixed_points: 12: counts 7272
n_fixed_points: 15: counts 2664
n_fixed_points: 18: counts 726
n_fixed_points: 21: counts 144
n_fixed_points: 24: counts 18
CPU times: user 5.05 ms, sys: 233 µs, total: 5.29 ms
Wall time: 4.64 ms




In [10]:
# Candidates for c: order-3 elements with exactly 3 fixed points.
c_candidates = order_3_portraits[n_fixed_points == 3]
# Candidates for d: order-3 elements with no fixed points.
d_candidates = order_3_portraits[n_fixed_points == 0]

In [14]:
# Sizes of the per-thread scratch arrays used by the kernel. Numba treats module
# globals as compile-time constants, so this cell must be re-run (recompiling
# the kernel) if the portrait or point width ever changes.
n_nodes = c_candidates.shape[-1]
point_dim = all_points.shape[-1]


@cuda.jit(
    numba.void(
        numba.int8[:],
        numba.int8[:, :],
        numba.int8[:, :],
        numba.int8[:, :],
        numba.bool_[:, :],
    )
)
def main_kernel(
    arities,
    portraits_c,
    portraits_d,
    all_points,
    result,
):
    """Test every (c, d) pair of a batch; one thread handles one pair.

    ``result[i, j]`` is set to True only when ``portraits_c[i]`` and
    ``portraits_d[j]`` pass all three checks below, and to False otherwise.
    Cells of ``result`` that have no corresponding pair (index beyond the
    length of ``portraits_c`` / ``portraits_d``) are also set to False, so a
    reused output buffer never carries stale values from a previous launch.

    Checks, in order:
      1. no fixed point of c is mapped by d onto a fixed point of c;
      2. (d c^2)^3 is rejected when ``device_fn.check_zero`` reports it as zero
         (presumably the trivial element -- TODO confirm);
      3. (c^-1 d c) d == d (c^-1 d c).

    Args:
        arities: int8 vector of arities, forwarded to the device helpers.
        portraits_c: int8 matrix, one candidate portrait for c per row.
        portraits_d: int8 matrix, one candidate portrait for d per row.
        all_points: int8 matrix, one point per row.
        result: bool matrix receiving the verdict for each (c, d) pair.
    """
    # Per-thread scratch portraits and one scratch point.
    tmp_1 = cuda.local.array(n_nodes, numba.int8)
    tmp_2 = cuda.local.array(n_nodes, numba.int8)
    tmp_3 = cuda.local.array(n_nodes, numba.int8)
    point_tmp = cuda.local.array(point_dim, numba.int8)

    c_idx, d_idx = cuda.grid(2)

    # Threads that fall outside the output buffer must not write to it.
    if c_idx >= result.shape[0] or d_idx >= result.shape[1]:
        return

    # Default verdict; every early return below leaves the pair rejected.
    result[c_idx, d_idx] = False

    if c_idx >= len(portraits_c) or d_idx >= len(portraits_d):
        return

    c = portraits_c[c_idx]
    d = portraits_d[d_idx]

    ## Fixed point check: d must not send a fixed point of c to a fixed point of c

    for point in all_points:
        if device_fn.is_fixed_point(arities, c, point):
            device_fn.action(arities, d, point, point_tmp)

            if device_fn.is_fixed_point(arities, c, point_tmp):
                return

    ## (dc^2)^3 must be non-trivial

    # dc -> tmp_1
    device_fn.portait_mul(arities, d, c, tmp_1, all_points)
    # dc^2 -> tmp_2
    device_fn.portait_mul(arities, tmp_1, c, tmp_2, all_points)
    # (dc^2)^2 -> tmp_1
    device_fn.portait_mul(arities, tmp_2, tmp_2, tmp_1, all_points)
    # (dc^2)^3 -> tmp_3
    device_fn.portait_mul(arities, tmp_1, tmp_2, tmp_3, all_points)

    if device_fn.check_zero(tmp_3):
        return

    ## (c^-1 d c) d must equal d (c^-1 d c)

    # c^-1 -> tmp_2
    device_fn.portait_inverse(arities, c, tmp_2, all_points)

    # c^-1 d -> tmp_1
    device_fn.portait_mul(arities, tmp_2, d, tmp_1, all_points)

    # c^-1 d c -> tmp_3
    device_fn.portait_mul(arities, tmp_1, c, tmp_3, all_points)

    # (c^-1 d c) d -> tmp_1
    device_fn.portait_mul(arities, tmp_3, d, tmp_1, all_points)

    # d (c^-1 d c) -> tmp_2
    device_fn.portait_mul(arities, d, tmp_3, tmp_2, all_points)

    if not device_fn.eq_portrait(tmp_1, tmp_2):
        return

    # All checks passed.
    result[c_idx, d_idx] = True

In [18]:
# Each kernel launch evaluates at most BATCH_SIZE x BATCH_SIZE (c, d) pairs.
BATCH_SIZE = 2048

# Candidates are uploaded once; batches are taken as slices of these device arrays.
c_candidates_cuda = cuda.to_device(c_candidates)
d_candidates_cuda = cuda.to_device(d_candidates)
# Single output buffer reused by every launch.
result_cuda = cuda.device_array((BATCH_SIZE, BATCH_SIZE), dtype=bool)

# presumably (blocks per grid, threads per block) covering the full batch with
# 16x16 thread blocks -- TODO confirm against wrepy.cuda.utils.kernel_2d_spec
bpg, tpb = kernel_2d_spec((BATCH_SIZE, BATCH_SIZE), (16, 16))

# Kernel bound to its launch configuration, ready to be called with arguments.
compiled_kernel = main_kernel[bpg, tpb]

In [19]:
def run_block(c_slice: slice, d_slice: slice) -> np.ndarray:
    """Run the pair-checking kernel on one batch of candidates.

    ``c_slice`` and ``d_slice`` select rows of the device-resident candidate
    arrays. The slices should be square (per the original note), i.e. both
    should describe batches of the same nominal size.

    Returns:
        A host copy of the whole ``result_cuda`` buffer; entry ``[i, j]``
        refers to the ``i``-th row of the c slice and the ``j``-th row of the
        d slice.
    """
    compiled_kernel(
        arities_cuda,
        c_candidates_cuda[c_slice],
        d_candidates_cuda[d_slice],
        all_points_cuda,
        result_cuda,
    )
    host_result = result_cuda.copy_to_host()
    return host_result

In [21]:
%%time
# Smoke test / timing of a single batch: does the first BATCH_SIZE x BATCH_SIZE
# block of (c, d) candidates contain any valid pair?
c_idx = 0
d_idx = 0
run_block(slice(c_idx, c_idx + BATCH_SIZE), slice(d_idx, d_idx + BATCH_SIZE)).any()

CPU times: user 177 ms, sys: 124 µs, total: 177 ms
Wall time: 174 ms


False

In [23]:
# Per-batch arrays of indexes (into c_candidates / d_candidates) of valid pairs;
# entry k of one list pairs up with entry k of the other.
valid_indexes_c = []
valid_indexes_d = []


c_slice_iterator = SliceIterator(len(c_candidates), BATCH_SIZE)
d_slice_iterator = SliceIterator(len(d_candidates), BATCH_SIZE)

# Sweep every (c batch, d batch) combination.
for c_slice, d_slice in tqdm(
    itertools.product(c_slice_iterator, d_slice_iterator),
    total=len(c_slice_iterator) * len(d_slice_iterator),
):
    # Positions of the True entries, relative to the start of each batch.
    c_indexes, d_indexes = run_block(c_slice, d_slice).nonzero()

    # Shift batch-local positions to indexes into the full candidate arrays.
    valid_indexes_c.append(c_indexes + c_slice.start)
    valid_indexes_d.append(d_indexes + d_slice.start)

  0%|          | 0/180 [00:00<?, ?it/s]

In [24]:
# Flatten the per-batch index arrays into two parallel 1-D arrays.
# NOTE(review): this rebinds the list names to arrays, so the cell is not meant
# to be run twice without re-running the sweep above.
valid_indexes_c = np.concatenate(valid_indexes_c)
valid_indexes_d = np.concatenate(valid_indexes_d)

In [25]:
valid_indexes_c.shape

(40223304,)

## Demo

In [25]:
# Convert the candidate portraits into permutation objects of the group so they
# can be inspected with the high-level API below; indexing matches the
# candidate arrays (c_permutations[i] is used together with c_candidates[i]).
c_permutations = portrait.to_dict_permutation(arities, c_candidates, z3z3z3, all_points)
d_permutations = portrait.to_dict_permutation(arities, d_candidates, z3z3z3, all_points)

In [29]:
# Select a random pair C, D among the valid pairs found by the sweep.
# NOTE(review): the RNG is not seeded, so a re-run picks a different pair and
# the outputs of the cells below will change.
pair_index = random.randrange(0, len(valid_indexes_c))
print("Selected pair:", pair_index)

# Indexes of the chosen pair inside c_candidates / d_candidates.
c_idx = valid_indexes_c[pair_index]
d_idx = valid_indexes_d[pair_index]

print(c_idx, d_idx)

C = c_permutations[c_idx]
D = d_permutations[d_idx]

Selected pair: 107400
3909 2017


In [30]:
# portraits
print("C:", c_candidates[c_idx])
print("D:", d_candidates[d_idx])

C: [0 0 0 1 2 2 1 2 1 0 1 0 2]
D: [0 0 1 0 2 1 1 2 2 2 1 1 2]


In [31]:
# Check order of C
C.order

3

In [32]:
# Check order of D
D.order

3

In [33]:
# Full view of C
column_print(C)

(0, 0, 0) -> (0, 0, 2)
(0, 0, 1) -> (0, 0, 0)
(0, 0, 2) -> (0, 0, 1)
(0, 1, 0) -> (0, 1, 2)
(0, 1, 1) -> (0, 1, 0)
(0, 1, 2) -> (0, 1, 1)
(0, 2, 0) -> (0, 2, 1)
(0, 2, 1) -> (0, 2, 2)
(0, 2, 2) -> (0, 2, 0)
(1, 0, 0) -> (1, 0, 2)
(1, 0, 1) -> (1, 0, 0)
(1, 0, 2) -> (1, 0, 1)
(1, 1, 0) -> (1, 1, 1)
(1, 1, 1) -> (1, 1, 2)
(1, 1, 2) -> (1, 1, 0)
(1, 2, 0) -> (1, 2, 0)
(1, 2, 1) -> (1, 2, 1)
(1, 2, 2) -> (1, 2, 2)
(2, 0, 0) -> (2, 1, 1)
(2, 0, 1) -> (2, 1, 2)
(2, 0, 2) -> (2, 1, 0)
(2, 1, 0) -> (2, 2, 0)
(2, 1, 1) -> (2, 2, 1)
(2, 1, 2) -> (2, 2, 2)
(2, 2, 0) -> (2, 0, 2)
(2, 2, 1) -> (2, 0, 0)
(2, 2, 2) -> (2, 0, 1)


In [34]:
# Full view of D
column_print(D)

(0, 0, 0) -> (0, 0, 2)
(0, 0, 1) -> (0, 0, 0)
(0, 0, 2) -> (0, 0, 1)
(0, 1, 0) -> (0, 1, 1)
(0, 1, 1) -> (0, 1, 2)
(0, 1, 2) -> (0, 1, 0)
(0, 2, 0) -> (0, 2, 1)
(0, 2, 1) -> (0, 2, 2)
(0, 2, 2) -> (0, 2, 0)
(1, 0, 0) -> (1, 1, 2)
(1, 0, 1) -> (1, 1, 0)
(1, 0, 2) -> (1, 1, 1)
(1, 1, 0) -> (1, 2, 2)
(1, 1, 1) -> (1, 2, 0)
(1, 1, 2) -> (1, 2, 1)
(1, 2, 0) -> (1, 0, 2)
(1, 2, 1) -> (1, 0, 0)
(1, 2, 2) -> (1, 0, 1)
(2, 0, 0) -> (2, 0, 1)
(2, 0, 1) -> (2, 0, 2)
(2, 0, 2) -> (2, 0, 0)
(2, 1, 0) -> (2, 1, 1)
(2, 1, 1) -> (2, 1, 2)
(2, 1, 2) -> (2, 1, 0)
(2, 2, 0) -> (2, 2, 2)
(2, 2, 1) -> (2, 2, 0)
(2, 2, 2) -> (2, 2, 1)


In [35]:
C.fixed_points()

[(1, 2, 0), (1, 2, 1), (1, 2, 2)]

In [36]:
D.fixed_points()

[]

In [37]:
cyc_1 = get_cycle(D, C.fixed_points()[0])
cyc_1

[(1, 2, 0), (1, 0, 2), (1, 1, 1)]

In [38]:
cyc_2 = get_cycle(D, C.fixed_points()[1])
cyc_2

[(1, 2, 1), (1, 0, 0), (1, 1, 2)]

In [39]:
cyc_3 = get_cycle(D, C.fixed_points()[2])
cyc_3

[(1, 2, 2), (1, 0, 1), (1, 1, 0)]

In [40]:
# fixed points of C are in different cycles of D: the three cycles must be
# pairwise disjoint, so collect every point shared by any two of them
# (a three-way intersection would be empty even if two of the cycles coincided)
{
    point
    for cyc_a, cyc_b in itertools.combinations((cyc_1, cyc_2, cyc_3), 2)
    for point in set(cyc_a) & set(cyc_b)
}

set()

In [41]:
# conjugation check: D commutes with its conjugate C^{-1} D C
D * (C.inverse() * D * C) == (C.inverse() * D * C) * D

True

In [42]:
# Compute (D * C^2)^3 and check that it is not the trivial permutation.
# Note: the ** operator is reserved for the action of a permutation on a point
# (see get_cycle at the top), so powers are written as explicit products.
dc23 = (D * (C * C)) * (D * (C * C)) * (D * (C * C))
column_print(dc23)

# List only the points that are actually moved.
print("---- Non trival elements ----")
for k, v in sorted(dc23.rule.items()):
    if k != v:
        print(f"{k} -> {v}")

(0, 0, 0) -> (0, 0, 0)
(0, 0, 1) -> (0, 0, 1)
(0, 0, 2) -> (0, 0, 2)
(0, 1, 0) -> (0, 1, 0)
(0, 1, 1) -> (0, 1, 1)
(0, 1, 2) -> (0, 1, 2)
(0, 2, 0) -> (0, 2, 0)
(0, 2, 1) -> (0, 2, 1)
(0, 2, 2) -> (0, 2, 2)
(1, 0, 0) -> (1, 0, 0)
(1, 0, 1) -> (1, 0, 1)
(1, 0, 2) -> (1, 0, 2)
(1, 1, 0) -> (1, 1, 0)
(1, 1, 1) -> (1, 1, 1)
(1, 1, 2) -> (1, 1, 2)
(1, 2, 0) -> (1, 2, 0)
(1, 2, 1) -> (1, 2, 1)
(1, 2, 2) -> (1, 2, 2)
(2, 0, 0) -> (2, 0, 1)
(2, 0, 1) -> (2, 0, 2)
(2, 0, 2) -> (2, 0, 0)
(2, 1, 0) -> (2, 1, 1)
(2, 1, 1) -> (2, 1, 2)
(2, 1, 2) -> (2, 1, 0)
(2, 2, 0) -> (2, 2, 1)
(2, 2, 1) -> (2, 2, 2)
(2, 2, 2) -> (2, 2, 0)
---- Non trival elements ----
(2, 0, 0) -> (2, 0, 1)
(2, 0, 1) -> (2, 0, 2)
(2, 0, 2) -> (2, 0, 0)
(2, 1, 0) -> (2, 1, 1)
(2, 1, 1) -> (2, 1, 2)
(2, 1, 2) -> (2, 1, 0)
(2, 2, 0) -> (2, 2, 1)
(2, 2, 1) -> (2, 2, 2)
(2, 2, 2) -> (2, 2, 0)


In [43]:
# Check the order of the group generated by C and D, built on the same
# underlying set as z3z3z3.
target_group = PermutationGroup(
    z3z3z3.underlying_set,
    GeneratorSetFactory,
    generator_set=([C.rule, D.rule]),
)
# Display the group (its repr reports the group order).
target_group

  warn(


PermutationGroup(order=81)

In [44]:
# The elements of the resulting group can also be examined if required.

target_elements = list(target_group.elements)
# Print one element of the group as an example.
column_print(target_elements[80])

(0, 0, 0) -> (0, 0, 2)
(0, 0, 1) -> (0, 0, 0)
(0, 0, 2) -> (0, 0, 1)
(0, 1, 0) -> (0, 1, 0)
(0, 1, 1) -> (0, 1, 1)
(0, 1, 2) -> (0, 1, 2)
(0, 2, 0) -> (0, 2, 1)
(0, 2, 1) -> (0, 2, 2)
(0, 2, 2) -> (0, 2, 0)
(1, 0, 0) -> (1, 2, 0)
(1, 0, 1) -> (1, 2, 1)
(1, 0, 2) -> (1, 2, 2)
(1, 1, 0) -> (1, 0, 1)
(1, 1, 1) -> (1, 0, 2)
(1, 1, 2) -> (1, 0, 0)
(1, 2, 0) -> (1, 1, 2)
(1, 2, 1) -> (1, 1, 0)
(1, 2, 2) -> (1, 1, 1)
(2, 0, 0) -> (2, 2, 2)
(2, 0, 1) -> (2, 2, 0)
(2, 0, 2) -> (2, 2, 1)
(2, 1, 0) -> (2, 0, 2)
(2, 1, 1) -> (2, 0, 0)
(2, 1, 2) -> (2, 0, 1)
(2, 2, 0) -> (2, 1, 1)
(2, 2, 1) -> (2, 1, 2)
(2, 2, 2) -> (2, 1, 0)
