In [16]:
import numpy as np

In [69]:
def show_state(state):
    """Return the indices of every non-zero entry of ``state``.

    The result is a 2-D integer array with one row per non-zero entry and
    one column per dimension of ``state``.
    """
    nonzero_mask = state != 0
    return np.argwhere(nonzero_mask)
show_state(np.array([[1,0,3],[0,1,0]]))

array([[0, 0],
       [0, 2],
       [1, 1]])

In [17]:
def p_matrix(s, i, j, k):
    """Return a pivot array shaped like ``s`` with a single 1 at ``(i, j, k)``.

    Every other entry is 0.  ``s`` is only consulted for its shape; it is
    neither read element-wise nor modified.
    """
    pivot = np.zeros(s.shape)
    pivot[i, j, k] = 1
    return pivot

In [59]:
def amgm(state, f, i1, j1, k1, i2, j2, k2):
    """Build the update array for a simple two-element AM-GM step.

    The returned array has the same shape as ``state`` and is zero except
    for ``-f`` at ``(i1, j1, k1)`` and at ``(i2, j2, k2)``, and ``+2 * f``
    at the midpoint of those two index triples.  Its entries therefore sum
    to zero for every ``f``.  ``state`` is only used for its shape and is
    not modified.

    The midpoint weight comes from ``f*a + f*b >= 2*f*sqrt(a*b)``.  An
    earlier version used ``2 * sqrt(f)``, which agrees with this only for
    ``f == 1``.

    NOTE(review): entries presumably are coefficients indexed by exponent
    triples -- confirm against the rest of the notebook.

    Parameters
    ----------
    state : numpy.ndarray
        Array whose shape the update must match.
    f : float
        Scaling factor of the step.
    i1, j1, k1, i2, j2, k2 : int
        The two index triples the step is applied to.

    Raises
    ------
    AssertionError
        If any coordinate pair has an odd sum (the midpoint would not be an
        integer index) or if both triples are identical.
    """
    # checking for operation validity
    assert (i1 + i2) % 2 + (j1 + j2) % 2 + (k1 + k2) % 2 == 0
    assert not (i1 == i2 and j1 == j2 and k1 == k2)

    # The parity assertion above guarantees these floor divisions are exact.
    mid = ((i1 + i2) // 2, (j1 + j2) // 2, (k1 + k2) // 2)

    operation = np.zeros(shape=state.shape)
    operation[i1, j1, k1] = -f
    operation[i2, j2, k2] = -f
    operation[mid] = 2 * f

    return operation

In [49]:
def reward(state, goal):
    """Return the inverse squared Euclidean distance between ``state`` and ``goal``.

    The value is ``1 / sum((goal - state) ** 2)``.  When the two arrays are
    identical the squared distance is zero and ``inf`` is returned directly,
    instead of dividing by zero and emitting a NumPy ``RuntimeWarning``.

    Raises
    ------
    AssertionError
        If ``state`` and ``goal`` do not have the same shape.
    """
    assert state.shape == goal.shape

    squared_distance = np.sum(np.square(goal - state))
    if squared_distance == 0:
        # Exact match: same value the division would give, without the warning.
        return np.float64(np.inf)
    return 1 / squared_distance

In [56]:
def step(state, operation, goal):
    """Apply ``operation`` to ``state`` and score the result against ``goal``.

    ``state`` is updated in place (``state += operation``), so the caller's
    array is modified.

    Returns
    -------
    str or list
        The string ``"Inequality proved"`` when the reward exceeds 1000,
        otherwise ``[state, r]`` where ``r`` is the reward of the updated
        state.  An earlier version put the ``reward`` function object in the
        list instead of the computed value.
    """
    state += operation
    r = reward(state, goal)
    if r > 1000:
        return "Inequality proved"
    return [state, r]

In [61]:
# Side length of the cubic grid used for both arrays below.
n = 5

# Starting array: all zeros except three unit entries.
# NOTE(review): indices presumably are exponent triples -- confirm.
state = np.zeros((n, n, n))
state[4, 0, 2] = 1
state[2, 4, 0] = 1
state[0, 2, 4] = 1
print(state)

# Target array: all zeros except three unit entries.
goal = np.zeros((n, n, n))
goal[1, 3, 2] = 1
goal[2, 1, 3] = 1
goal[3, 2, 1] = 1
print(goal)

[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 1.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [1. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]]
[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]]


In [57]:
# Reward of the start state against the goal, then of the goal against
# itself (zero squared distance, hence an infinite reward).
print(reward(state, goal), reward(goal, goal), sep="\n")

0.16666666666666666
inf


  import sys


In [60]:
# Demo: unit-scale step between index triples (2, 2, 0) and (0, 2, 2).
amgm(state, f=1, i1=2, j1=2, k1=0, i2=0, j2=2, k2=2)

array([[[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0., -1.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  2.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [-1.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]]])