# What is about ?

### Brief 

(c) Alexander Chervov, Kirill Khoruzhii

Beam search for graphs (Cayley graphs of the permutation groups) basic simple tutorial version.




# Beam search function 

In [2]:
%%time
import numpy as np
import pandas as pd
import time
import torch

def beam_search_permutations_torch(
    state_start = [2,1,0],
    generators = [[0,1,2],[0,2,1]],
    models_or_heuristics  = 'Hamming',
    beam_width = 2,
    state_destination = '01234...',
    n_steps_limit = 100,
    dtype = 'Auto',
    vec_hasher = 'Auto',
    verbose = 0,
):
    '''
    Find path from the "state_start" to the "state_destination" via beam search.

    Main parameters:
        state_start - state to be solved, i.e. from where we need to find path to the destination
        generators - generators of the group
        beam_width - beam width
        state_destination = '01234...' - destination state, typically 0,1,2,3,... - identity permutation
        models_or_heuristics - machine learning model or name for hearistical metric
        n_step_max - maximal number of steps to try
    Technical parameters:
        vec_hasher - vector used for hashing
        dtype      - dtype for states
        verbose    - contols how many text output during the exection
    '''

    ####################################################################################
    # Analyse input params and convert to stadard forms
    ####################################################################################
    # generators_type = 'permutation' # 'matrix'

    # device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(device)

    # Analyse input format of "generators"
    # can be list_generators, or tensor/np.array with rows - generators
    if isinstance(generators, list):
        list_generators = generators
    elif isinstance(generators, torch.Tensor ):
        list_generators = [ list(generators[i,:]) for i in range(generators.shape[0] ) ]
    elif isinstance(generators, np.ndarray ):
        list_generators = [list(generators[i,:]) for i in range(generators.shape[0] ) ]
    else:
        print('Unsupported format for "generators"', type(generators), generators)
        raise ValueError('Unsupported format for "generators" ' + str(type(generators)) )
    state_size = len(list_generators[0])
    tensor_all_generators = torch.tensor( list_generators , device = device, dtype = torch.int64)

    # dtype
    if dtype == 'Auto':
        if state_size <= 256:
            dtype = torch.uint8
        else:
            dtype = torch.uint16

    # Destination state
    if state_destination == '01234...':
        state_destination = torch.arange( state_size, device=device, dtype = dtype).reshape(-1,state_size)
    elif isinstance(state_destination, torch.Tensor ):
        state_destination =  state_destination.to(device).to(dtype).reshape(-1,state_size)
    else:
        state_destination = torch.tensor( state_destination, device=device, dtype = dtype).reshape(-1,state_size)

    # state_start
    if isinstance(state_start, torch.Tensor ):
        state_start =  state_start.to(device).to(dtype).reshape(-1,state_size)
    else:
        state_start = torch.tensor( state_start, device=device, dtype = dtype).reshape(-1,state_size)

    # Vec_hasher
    dtype_for_hash = torch.int64
    if vec_hasher == 'Auto':
        # Hash vector generation
        max_int =  int( (2**62) )     #print(max_int)
        vec_hasher = torch.randint(-max_int, max_int,   size= (state_size,),  device=device, dtype = dtype_for_hash) #
    elif not isinstance( vec_hasher , torch.Tensor):
        vec_hasher = torch.tensor( vec_hasher , device=device, dtype = dtype_for_hash )
    else:
        vec_hasher = vec_hasher.to(device).to(dtype_for_hash)


    ##########################################################################################
    # Initializations
    ##########################################################################################

    # Initialize array of states
    array_of_states = state_start.view(-1, state_size  ).clone().to(dtype).to(device)

    ##########################################################################################
    # Main Loop over steps
    ##########################################################################################
    for i_step in range(1,n_steps_limit+1):

        # Apply generator to all current states
        array_of_states_new = get_neighbors(array_of_states,tensor_all_generators.to(torch.int64) ).flatten(end_dim=1)

        # Take only unique states
        # surprise: THAT IS CRITICAL for beam search performance !!!!
        # if that is not done - beam search  will not find the desired state - quite often
        # The reason - essentianlly beam can degrade, i.e. can be populated by copy of only one state
        # It is surprising that such degradation  happens quite often even for beam_width = 10_000 - but it is indeed so
        array_of_states_new = get_unique_states_2(array_of_states_new, vec_hasher)

        # Check destination state found
        vec_tmp = torch.all(array_of_states_new == state_destination, axis =1) # Compare state_destination and each row array_of_states
        flag_found_destination = torch.any(vec_tmp).item() # Check for coincidence
        if (flag_found_destination) :
            if (verbose >= 10 ):
                print('Found destination state. ', 'i_step:', i_step, ' n_ways:', (vec_tmp).sum())
            break

        # Estimate distance of new states to the destination state (or best moves probabilities for policy models)
        if array_of_states_new.shape[0] > beam_width: # If we have not so many states - we take them all - no need for ML-model
            if models_or_heuristics == 'Hamming':
                estimations_for_new_states = torch.sum( (array_of_states_new == state_destination[0,:]) , dim = 1)
            else:
                raise ValueError('Unsupported models_or_heauristics ' + str(models_or_heauristics) )

            # Take only "beam_width" of the best states (i.e. most nearest to destination according to the model estimate)
            idx = torch.argsort(estimations_for_new_states)[:beam_width]
            array_of_states = array_of_states_new[idx,:]

        else:
            # If number of states is less than beam_width - we take them all:
            array_of_states = array_of_states_new


        if verbose >= 10:
            print(i_step,'i_step', array_of_states_new.shape, 'array_of_states_new.shape' )

    dict_additional_data = {}
    if verbose >= 1:
        print();
        print('Search finished.', 'beam_width:', beam_width)
        if flag_found_destination:
            print(i_step, ' steps to destination state. Path found.')
        else:
            print('Path not found.')

    return flag_found_destination, i_step, dict_additional_data


def get_unique_states_2(states: torch.Tensor, vec_hasher : torch.Tensor) -> torch.Tensor:
    '''
    Return matrix with unique rows for input matrix "states"
    I.e. duplicate rows are dropped.
    For fast implementation: we use hashing via scalar/dot product.
    Note: output order of rows is different from the original.
    '''
    # Note: that implementation is 30 times faster than torch.unique(states, dim = 0) - because we use hashes  (see K.Khoruzhii: https://t.me/sberlogasci/10989/15920)
    # Note: torch.unique does not support returning of indices of unique element so we cannot use it
    # That is in contrast to numpy.unique which supports - set: return_index = True

    device = states.device

    t1 = time.time()
    # Hashing rows of states matrix:
    hashed = torch.sum( states * vec_hasher.to(device), dim=1) # Compute hashes.
        # It is same as matrix product torch.matmul(hash_vec , states )
        # but pay attention: such code work with GPU for integers
        # While torch.matmul - does not work for GPU for integer data types,
        # since old GPU hardware (before 2020: P100, T4) does not support integer matrix multiplication
    t1 = time.time() - t1
    print(t1,'hash')

    # Sort
    t1 = time.time()
    hashed_sorted, idx = torch.sort(hashed)
    t1 = time.time() - t1
    print(t1,'sort')

    # Mask selects elements which are different from the consequite - that is unique elements (since vector is sorted on the previous step)
    t1 = time.time()
    mask = torch.concat((torch.tensor([True], device = device), hashed_sorted[1:] - hashed_sorted[:-1] > 0))
    t1 = time.time() - t1
    print(t1,'mask')
    return states[idx][mask]

def get_neighbors(states, moves):
    """
    Some torch magic to calculate all new states which can be obtained from states by moves
    """
    return torch.gather(
        states.unsqueeze(1).expand(states.size(0), moves.shape[0], states.size(1)),
        2,
        moves.unsqueeze(0).expand(states.size(0), moves.shape[0], states.size(1)))

flag_found_destination, i_step, dict_additional_data  =\
    beam_search_permutations_torch(state_start = [1,0], generators = [[1,0]], verbose = 1 )
flag_found_destination, i_step, dict_additional_data  =\
    beam_search_permutations_torch(state_start = [2,1,0], generators = [[1,0,2],[0,2,1]] , verbose = 1)


cuda
6.556510925292969e-05 hash
6.890296936035156e-05 sort
0.0001518726348876953 mask

Search finished. beam_width: 2
1  steps to destination state. Path found.
cuda
2.5510787963867188e-05 hash
4.839897155761719e-05 sort
8.392333984375e-05 mask
2.3126602172851562e-05 hash
4.172325134277344e-05 sort
7.009506225585938e-05 mask
2.2172927856445312e-05 hash
3.886222839355469e-05 sort
0.00012540817260742188 mask

Search finished. beam_width: 2
3  steps to destination state. Path found.
CPU times: user 4.47 ms, sys: 20.6 ms, total: 25 ms
Wall time: 23.8 ms


# Benchmarks

## Preparations

In [None]:
list_generators_cube333_12gensQTM = [[6, 3, 0, 7, 4, 1, 8, 5, 2, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 47, 21, 22, 50, 24, 25, 53, 27, 28, 38, 30, 31, 41, 33, 34, 44, 36, 37, 20, 39, 40, 23, 42, 43, 26, 45, 46, 29, 48, 49, 32, 51, 52, 35], [0, 1, 2, 3, 4, 5, 6, 7, 8, 15, 12, 9, 16, 13, 10, 17, 14, 11, 36, 19, 20, 39, 22, 23, 42, 25, 26, 45, 28, 29, 48, 31, 32, 51, 34, 35, 27, 37, 38, 30, 40, 41, 33, 43, 44, 18, 46, 47, 21, 49, 50, 24, 52, 53], [44, 43, 42, 3, 4, 5, 6, 7, 8, 45, 46, 47, 12, 13, 14, 15, 16, 17, 24, 21, 18, 25, 22, 19, 26, 23, 20, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 11, 10, 9, 0, 1, 2, 48, 49, 50, 51, 52, 53], [0, 1, 2, 3, 4, 5, 51, 52, 53, 9, 10, 11, 12, 13, 14, 38, 37, 36, 18, 19, 20, 21, 22, 23, 24, 25, 26, 33, 30, 27, 34, 31, 28, 35, 32, 29, 8, 7, 6, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 15, 16, 17], [0, 1, 35, 3, 4, 34, 6, 7, 33, 20, 10, 11, 19, 13, 14, 18, 16, 17, 2, 5, 8, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 9, 12, 15, 42, 39, 36, 43, 40, 37, 44, 41, 38, 45, 46, 47, 48, 49, 50, 51, 52, 53], [24, 1, 2, 25, 4, 5, 26, 7, 8, 9, 10, 27, 12, 13, 28, 15, 16, 29, 18, 19, 20, 21, 22, 23, 17, 14, 11, 6, 3, 0, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 51, 48, 45, 52, 49, 46, 53, 50, 47], [2, 5, 8, 1, 4, 7, 0, 3, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 38, 21, 22, 41, 24, 25, 44, 27, 28, 47, 30, 31, 50, 33, 34, 53, 36, 37, 29, 39, 40, 32, 42, 43, 35, 45, 46, 20, 48, 49, 23, 51, 52, 26], [0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 14, 17, 10, 13, 16, 9, 12, 15, 45, 19, 20, 48, 22, 23, 51, 25, 26, 36, 28, 29, 39, 31, 32, 42, 34, 35, 18, 37, 38, 21, 40, 41, 24, 43, 44, 27, 46, 47, 30, 49, 50, 33, 52, 53], [45, 46, 47, 3, 4, 5, 6, 7, 8, 44, 43, 42, 12, 13, 14, 15, 16, 17, 20, 23, 26, 19, 22, 25, 18, 21, 24, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 2, 1, 0, 9, 10, 11, 48, 49, 50, 51, 52, 53], [0, 1, 2, 3, 4, 5, 38, 37, 36, 9, 10, 11, 12, 13, 14, 51, 52, 53, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 32, 35, 28, 31, 34, 27, 30, 33, 17, 16, 15, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 6, 7, 8], [0, 1, 18, 3, 4, 19, 6, 7, 20, 33, 10, 11, 34, 13, 14, 35, 16, 17, 15, 12, 9, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 8, 5, 2, 38, 41, 44, 37, 40, 43, 36, 39, 42, 45, 46, 47, 48, 49, 50, 51, 52, 53], [29, 1, 2, 28, 4, 5, 27, 7, 8, 9, 10, 26, 12, 13, 25, 15, 16, 24, 18, 19, 20, 21, 22, 23, 0, 3, 6, 11, 14, 17, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 47, 50, 53, 46, 49, 52, 45, 48, 51]]


def scramble_given_state(list_generators, n_scrambles,    state_to_scramble  = '01234...'   ):
    if state_to_scramble == '01234...':
        state_size = len( list_generators[0] )
        state_to_scramble = np.arange(state_size)
    state_current = state_to_scramble
    if isinstance(state_current,list):
        state_current = np.asarray(state_current)
    elif isinstance(state_current,range):
        state_current = np.asarray(state_current)
    n_gens = len(list_generators)
    for k in range(n_scrambles):
        IX_move = np.random.randint(0, n_gens, dtype = int) # random moves indixes
        state_current = state_current[ list_generators[IX_move]] # all_moves[IX_moves,:] ]
    return state_current

print( scramble_given_state( list_generators_cube333_12gensQTM, 0, ) )
print( scramble_given_state( list_generators_cube333_12gensQTM, 10) )


# Main plot - probabilities to find solution by beam search depending on number of scrambles

In [None]:
%%time
# n_scrambles_starting_state = 1000
# print('n_scrambles_starting_state:',n_scrambles_starting_state)

verbose = 0
n_trials = 100
list_n_scrambles = list(range(1,16))


import matplotlib.pyplot as plt
list_generators = list_generators_cube333_12gensQTM

fig = plt.figure(figsize = (20,8))

for beam_width in [1, 10 ,100,1000,10_000,100_000]:

    str_inf = 'Cube333QTM Hamming + beam_width ' + str(beam_width)
    if verbose >= 1:
        print(str_inf, 'n_trials:', n_trials)

    df_stat = pd.DataFrame()

    for i_trial in range(n_trials):
        if i_trial < 25:
            #print()
            if verbose>=10:
                print('trial:', i_trial)
        # destination state - "solved puzzle state"
        state_size = len( list_generators[0])
        #state_destination = torch.arange( state_size, device=device, dtype = dtype)
        state_destination = range(state_size)

        #state_start

        prm_name = 'n_scramble '
        list_prm_value = list_n_scrambles# list(range(1,15))
        for prm_value in list_prm_value:#  5_000,10_000, 30_000]:#, 100_000]:
            n_scrambles_starting_state = prm_value

            # Scramble - generate state which will be start of beam search
            state_start = state_destination
            state_start = scramble_given_state( list_generators, n_scrambles_starting_state, state_start )

            # Beam search
            t0 = time.time()
            flag_found_destination, i_step, dict_additional_data  =\
                beam_search_permutations_torch(state_start = state_start, generators = list_generators,
                    beam_width = beam_width,
                    models_or_heuristics  = 'Hamming'     ,
                    state_destination = '01234...'  ,
                    n_steps_limit = 100 ,
                    verbose = 0)

            if i_trial < 1:
                if verbose >=10:
                    print('Found:',flag_found_destination,'steps:', i_step,'beam_width:', beam_width,
                          'time: %.1f secs'%(time.time()-t0))

            df_stat.loc[i_trial,'Path length.  '+prm_name+str(prm_value)] =  i_step
            df_stat.loc[i_trial,'Solution found. '+prm_name +str(prm_value)] =  int(flag_found_destination)
            df_stat.loc[i_trial,'Time. '+prm_name+str(prm_value)] =  np.round(time.time()-t0,1)

    df_stat.to_csv('stat_beam_search_'+prm_name+'_'+str_inf+'.csv')
    if verbose >=10:
        display(df_stat)
        display(df_stat.describe().round(3).T)
        display(df_stat.describe().round(3).to_csv('aggregated_stat_beam_search_'+prm_name+'_'+str_inf+'.csv'))


    df_loc = df_stat.describe()
    dat_loc = []

    col_key = 'Solution found'
    for col in df_stat.columns:
        if not( col_key in col) : continue
        dat_loc.append(df_loc.loc['mean', col] )

    plt.plot(list_prm_value, dat_loc, '*-',label = str_inf)
    #plt.title(col_key , fontsize = 20  )
    plt.title('Probability to find solution by beam search' , fontsize = 20  )
    plt.ylabel('probability to find solution', fontsize = 20 )
    plt.xlabel('n_scrambles', fontsize = 20 )
    plt.xticks(list_prm_value, list_prm_value)
    plt.legend(fontsize = 20 )
    plt.savefig(col_key.replace('.', ' ') + '.png')



plt.grid()
plt.show()
