In [102]:
from scipy.signal import convolve2d, correlate
import numpy as np, torch
from misc_tools.print_latex import print_tex
import torch.nn.functional as F
# Device that convolution results are moved to (see `gpu_conv`). CPU is forced here;
# the CUDA auto-selection is kept in the trailing comment for reference.
device = "cpu"#torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 0) Implementation (for full loop in the end)
1. Generate a puzzle `p`:
    * Zero array of shape N by N, filled randomly by half (N*N/2) with ones
1. Split it into view of 0s and 1s `p_split`
1. Find which entries can be swapped. (This only makes sense for entries of opposite type, otherwise the array won't change.)
    * `neighbor_swaps` (shape [2,4,N,N]) shows if each (dim 0) entry can go to one of 4 directions (dim 1) Right, Up, Left, Down
    * `mobility` (shape [2,N,N]) same but if can move in either direction (is mobile)
1. Calculate constraints
    * Only 1 neighbor is accepted in either horizontal or vertical direction
        * `near_neighbors` (shape [2,2,N,N]) shows if entry fails this test for vertical and horizontal dimension (dim 1)
        * If element fails in particular direction, also its neighbor considered a culprit (usually is)
    * Each row and each column can contain at most `N/2` entries of one type
        * `row_col_constr_mask` (shape [2,2,N,N]) marks all rows and columns (dim 1) that fail this constraint
        * Order is swapped. Instead of vertical & horizontal, here its reversed. Reason: Encourage movement in orthogonal direction
1. (Additional) Accumulate history (kind of constraint) from previous steps to encourage movement of entries that stay in favorable position for long time
1. Estimate total energy (which entry should be swapped)
    * Add two constraint energies
    * Mask out entries that are immobile

1. Option exploration:
    * Take entry with highest energy (if many with same, at random)
    * Check which directions it can move via `neighbor_swaps`
    * Perform swap and calculate nearest neighbors and row/column constraint energies
    * Select one with lowest resulting energy

1. Repeat with the modified puzzle until the maximum number of iterations is reached or the energy of the two main constraints becomes 0

img: Channels, H, W

channel 0 -> position of zeros
channel 1 -> position of ones

In [103]:
# Hand-written 6x6 reference board (not used by the code below).
puzz_true = torch.tensor([
    [1, 0, 1, 0, 1, 0],
    [0, 1, 0, 0, 1, 1],
    [1, 0, 0, 1, 0, 1],
    [0, 1, 1, 0, 1, 0],
    [0, 0, 1, 1, 0, 1],
    [1, 1, 0, 1, 0, 0],
])

# Empty integer templates for the two board sizes.
puzz_6, puzz_8 = (torch.zeros(size=(side, side), dtype=torch.long) for side in (6, 8))

# Demo board: the left half of every row is ones, the right half zeros.
p = torch.zeros_like(puzz_6)
N = p.shape[-1]
p[:, :N // 2] = 1
# Alternative start (lower-triangular ones): p = torch.tril(torch.ones_like(puzz_6))

def split_0_1(p):
    """Split a 0/1 board into two stacked masks along a new leading axis.

    Returns a tensor whose index 0 is `1 - p` (marks the zeros) and whose
    index 1 is `p` itself (marks the ones).
    """
    return torch.stack((1 - p, p), dim=0)

# Two-channel view of the board: puzz[0] marks the zeros, puzz[1] marks the ones.
puzz = split_0_1(p)

print(f'{puzz.shape = }')
# Show the board next to its two channel masks (print_tex presumably renders LaTeX - see misc_tools).
print_tex('p =',p ,r'\rightarrow (zeroes,ones) =' ,*puzz)


puzz.shape = torch.Size([2, 6, 6])


<IPython.core.display.Math object>

# 1) Check if entries can move (can be swapped)

## 1.1) Count neighbors for 0s and 1s in each of 4 directions
Generate directed neighbor counter and mask to suppress edge cases

After convolution array is of shape [entry type, direction, *image size] = [2,4,N,N].

Entry has the following value and meaning: 
* 2 = entry has a neighbor of same type
* 1 = entry has a neighbor of opposite type
* 0 = disregard the entry, it is of the wrong type for this query

Swapping places makes sense only for opposite types. Otherwise nothing changes.

In [104]:
# Kernel that adds a cell to its RIGHT neighbour (conv2d is a cross-correlation,
# so the off-centre 1 at column +1 reads the cell to the right).
k_n_r = torch.tensor([  [0, 0, 0],
                        [0, 1, 1],
                        [0, 0, 0]]).unsqueeze(0).unsqueeze(0)   # [1,1,3,3]

# One kernel per direction, obtained by rotating the "right" kernel counter-clockwise
# in 90-degree steps: Right, Up, Left, Down.
# The kernel is duplicated along dim 1 so that k_n[:, [0]] / k_n[:, [1]] give the
# [4,1,3,3] kernel set used for the zeros / ones channel respectively.
k_n   = torch.repeat_interleave(    torch.cat( [torch.rot90(k_n_r, i, dims = (-2,-1)) for i in range(4)], dim = 0)  # [4,1,3,3]
                                , 2, dim = 1)                                                                       # [4,2,3,3]


def gpu_conv(image, kernel):
    """Cross-correlate `image` with `kernel` using zero 'same' padding.

    Thin wrapper around `F.conv2d` (stride 1, no bias, no dilation, one group),
    so the output keeps the spatial size of the input.

    Note: only the *result* is moved to `device`; the inputs are used where they
    already live.
    """
    return F.conv2d(image, kernel, padding='same').to(device)


# Zero padding makes a cell on the board edge look like it has a neighbour of the
# opposite type. These masks zero out the edge row/column each direction would leave
# the board through (same rotation order as k_n).
m_r = torch.ones_like(p)
m_r[:,-1] = 0
edge_masks = torch.stack([torch.rot90(m_r, i ) for i in range(4)]).unsqueeze(0) # [1,4,N,N] dim0 will broadcast

print_tex(r'for \ zeroes: ', *k_n[:,0], r'; \ for \ ones: ', *k_n[:,1])

<IPython.core.display.Math object>

*   Cannot do an easy convolution to apply 4 different kernels for each channel C_in.

    Ideally I would do img * ker := [C_in,*img shape] x [C_out, C_in, *ker shape] = [C_in, C_out, *img shape]

    But `conv2d` sums over the C_in dimension, so the result is only [C_out, *img shape]. Each channel has to be convolved separately.
* mask against diffusion (empty entries summing up non-zero neighbors)

In [105]:
def neighbors_4_dirs_0_1(p):
    """Directed neighbour code for every cell, per type and per direction.

    `p` is the two-channel board (index 0 marks zeros, index 1 marks ones).
    The result has shape [2, 4, N, N] = [type, direction, row, col], where a
    value of 2 means the neighbour in that direction has the same type, 1 means
    it has the opposite type, and 0 means the cell is not of this type or the
    direction leaves the board (edge mask).
    """
    counts = torch.zeros(size=(2, 4, N, N))
    for kind in (0, 1):
        own = p[[kind]]                                     # [1,N,N] mask of this type
        # self + neighbour for all 4 directions, kept only where the cell is of this type
        counts[kind] = gpu_conv(own, k_n[:, [kind]]) * own
    return counts * edge_masks

In [106]:
# Directed neighbour codes for the demo board, shape [type, direction, N, N].
neighbors_4_dirs = neighbors_4_dirs_0_1(puzz)

# One matrix per direction, first for the zeros channel, then for the ones channel.
print_tex('0s',*neighbors_4_dirs[0])
print_tex('1s',*neighbors_4_dirs[1])


<IPython.core.display.Math object>

<IPython.core.display.Math object>

## 1.2) `x1` neighbors = can swap
* directed swap mask [2, 4, N, N]
* general swap mask: aggregate all (4) transitions into one (1) bool can/cannot move [2, N, N]

In [107]:
def swap_possible(p):
    """Directed swap mask of shape [2, 4, N, N] for the two-channel board `p`.

    An entry is 1 where the neighbour code equals 1, i.e. the neighbour in that
    direction is of the opposite type, so swapping changes the board.
    """
    directed_codes = neighbors_4_dirs_0_1(p)
    return directed_codes.eq(1).to(int)

In [108]:
# Per-direction swap mask for the demo board: [type, direction, N, N].
can_swap_4_dirs = swap_possible(puzz)

# Collapse the direction axis: a cell is mobile if it can be swapped in at least one direction.
mobility = can_swap_4_dirs.any(dim = 1).to(int)
print_tex(r'\text{Mobility (0s \& 1s) = }',*mobility)

<IPython.core.display.Math object>

# 2) Constraint #1

## Check if entry is agitated by `number of nearby neighbors` (horizontal and vertical)

## 2.1) Count neighbors, mask against diffusion
Want result as array of shape 
* [num types = 2, directions = 2, *img size]

convolve dims:
* [Batch, C_in, *img shape] $\star$ [C_out, C_in, *ker shape] -> [Batch, C_out, *img shape]

so 
* [num types = 2, C_in = 1] $\star$ [ dirs = 2, C_in = 1] -> [num types = 2, dirs = 2]

In [109]:
# 3-cell window along a row: the cell itself plus its left and right neighbours.
k_n_h = torch.tensor([
    [0, 0, 0],
    [1, 1, 1],
    [0, 0, 0],
])
# The same window turned by 90 degrees: the cell plus its upper and lower neighbours.
k_n_v = k_n_h.rot90(1)

# Kernel bank [direction = (vertical, horizontal), C_in = 1, 3, 3].
k_n_vh = torch.stack((k_n_v, k_n_h))[:, None]

print(f'{puzz.unsqueeze(1).shape = }')
print(f'{k_n_vh.shape = }')


puzz.unsqueeze(1).shape = torch.Size([2, 1, 6, 6])
k_n_vh.shape = torch.Size([2, 1, 3, 3])


Mask to consider only 1s entries (1s as markers)

In [110]:
# Treat the type axis as the batch axis ([2,1,N,N]) so that each type is convolved with both
# the vertical and the horizontal window; multiplying by the type mask keeps only cells of that type.
num_neighbors_v_h = gpu_conv(puzz.unsqueeze(1),k_n_vh)*puzz.unsqueeze(1)
print(f'{num_neighbors_v_h.shape = }')
print_tex(r'\text{v \& h neighbors. for }0s:', *num_neighbors_v_h[0],'1s:', *num_neighbors_v_h[1])

num_neighbors_v_h.shape = torch.Size([2, 2, 6, 6])


<IPython.core.display.Math object>

## 2.2) Mark agitated = 2 nearby neighbors + 1 self
* vertical & horizontal [2,2,N,N]

In [111]:
# A window sum of 3 means the cell and both of its neighbours along that axis share one type,
# which violates the "no three in a line" constraint.
neighb_constraint_fail = (num_neighbors_v_h == 3).to(int)
print_tex(r'\text{Agitated vertical \& horizontal } 0s:',*neighb_constraint_fail[0],'1s:', *neighb_constraint_fail[1])
print(f'{neighb_constraint_fail.shape = }')

<IPython.core.display.Math object>

neighb_constraint_fail.shape = torch.Size([2, 2, 6, 6])


In [112]:
def energy_neighbors(p, sum = True):
    """Nearest-neighbour constraint check for the two-channel board `p`.

    A cell fails along an axis when it and both of its neighbours on that axis
    are of the same type (3-cell window sum equals 3).

    With `sum=True` the number of failing (type, axis, cell) entries is returned
    as a Python int; otherwise the integer mask of shape
    [type, (vertical, horizontal), N, N] is returned.
    """
    own = p.unsqueeze(1)                                    # [2,1,N,N]
    three_in_line = (gpu_conv(own, k_n_vh) * own) == 3
    if not sum:
        return three_in_line.to(int)
    return torch.sum(three_in_line).item()

energy_neighbors(puzz)


36

## 2.3) Consider also their nearest neighbors
Vertically agitated entries are agitated by vertical neighbors. Same with horizontal.

In [113]:
def expand_agitated_neighbors(p):
    """Mark failing cells together with the neighbours that make them fail.

    Starts from the nearest-neighbour failure mask of `energy_neighbors` and
    spreads every flag one cell along its own axis: vertical failures spread
    up/down, horizontal failures spread left/right.

    The output shape follows the input board, so no global board size is needed.

    Returns an integer 0/1 mask of shape [type, (vertical, horizontal), N, N].
    """
    flagged = energy_neighbors(p, sum= False)               # [2,2,N,N]
    # Convolve each axis' flags with the window of the same axis only.
    spread = torch.cat(
        [gpu_conv(flagged[:, [axis]], k_n_vh[[axis]]) for axis in range(2)],
        dim = 1,
    )
    return (spread > 0).to(int)

In [114]:
# Failing cells plus their neighbours along the failing axis, for the demo board.
neighb_constraint_fail_ext = expand_agitated_neighbors(puzz)
print_tex(r'\text{Agitated entries \& neighbors. } 0s:',*neighb_constraint_fail_ext[0],'1s:', *neighb_constraint_fail_ext[1])


<IPython.core.display.Math object>

# Constraint #2.

## Check if entry is agitated by row/col element count
A high number of elements in a column should encourage entries to migrate along rows.

That's why the order is swapped.

In [115]:
def row_col_constraint_rough_mask(p):
    """Mark every cell that lies in an over-full row or column, per type.

    `p` is the two-channel board [type, rows, cols]. A row (column) is over-full
    for a type when it holds more than half of its cells of that type. The limits
    are taken from the shape of `p`, not from a global board size.

    Returns an integer 0/1 mask of shape [type, 2, rows, cols]; index 0 of dim 1
    marks over-full rows, index 1 marks over-full columns.
    """
    n_rows, n_cols = p.shape[-2:]
    rows_over = p.sum(dim = -1, keepdim = True) > n_cols // 2   # [type, rows, 1]
    cols_over = p.sum(dim = -2, keepdim = True) > n_rows // 2   # [type, 1, cols]
    # Broadcast each flag over the whole row / column it belongs to.
    mask = torch.stack((rows_over.expand_as(p), cols_over.expand_as(p)), dim = 1)
    return mask.to(int)

def row_col_constraint_energy(p):
    """Count the rows and columns that hold too many entries of one type.

    `p` is the two-channel board [type, rows, cols]. A row (column) counts once
    per type for which it holds more than half of its cells. The limits are taken
    from the shape of `p`, not from a global board size.

    Returns a 0-dim integer tensor.
    """
    n_rows, n_cols = p.shape[-2:]
    rows_over = torch.sum(p.sum(dim = -1) > n_cols // 2)
    cols_over = torch.sum(p.sum(dim = -2) > n_rows // 2)
    return rows_over + cols_over

In [116]:
# Rows / columns of the demo board that hold too many entries of one type.
row_col_constraint_fail = row_col_constraint_rough_mask(puzz)
print(row_col_constraint_energy(puzz))
# The label now names the row/column mask (it was copy-pasted from the neighbour-constraint cell).
print_tex(r'\text{Rows \& columns over the limit. } 0s:',*row_col_constraint_fail[0],'1s:', *row_col_constraint_fail[1])

tensor(6)


<IPython.core.display.Math object>

# 3) Add constraint energies

In [117]:
def row_col_constraint_masked(p, sum = False):
    """Combine both constraints, restricted to cells flagged by the neighbour constraint.

    The expanded neighbour mask and the row/column mask are added, and the sum is
    kept only where the neighbour mask is set.

    With `sum=True` the total is returned as a Python number, otherwise the
    [2, 2, N, N] tensor itself.
    """
    neighbor_mask = expand_agitated_neighbors(p)
    combined = neighbor_mask * (neighbor_mask + row_col_constraint_rough_mask(p))
    return combined.sum().item() if sum else combined

In [118]:
# Combined energy of both constraints for the demo board, shape [type, axis, N, N].
tot_energy = row_col_constraint_masked(puzz)
# The label now names the total energy (it was copy-pasted from the neighbour-constraint cell).
print_tex(r'\text{Total energy (neighbors + rows/columns). } 0s:',*tot_energy[0],'1s:', *tot_energy[1])

<IPython.core.display.Math object>

In [119]:
# Keep energy only on cells that can actually move.
# `tot_energy` is [type, axis, N, N] and `mobility` is [type, N, N]. A plain product
# would line the type axis of `mobility` up with the *axis* dimension of `tot_energy`,
# so a singleton axis dimension is inserted to broadcast per type.
tot_energy_masked = tot_energy * mobility.unsqueeze(1)
print_tex(r'\text{Total energy masked by mobility. } 0s:',*tot_energy_masked[0],'1s:', *tot_energy_masked[1])

<IPython.core.display.Math object>

In [120]:
def swap_2_elements(arr, pos1, pos2):
    """Exchange, in place, the entries of tensor `arr` at index `pos1` and `pos2`.

    Both positions may be any index sequence; they are converted to tuples
    before indexing. Nothing is returned.
    """
    first, second = tuple(pos1), tuple(pos2)
    # Copy both values before writing, so neither write can affect the other read.
    first_value, second_value = arr[first].clone(), arr[second].clone()
    arr[first] = second_value
    arr[second] = first_value

In [121]:
# (row, col) offsets of the four move directions, indexed like dim 1 of the directed swap mask:
# Right, Up, Left, Down. The row index grows downwards, so "Up" is a row offset of -1.
ijplus = np.array([[0,1], [-1,0], [0,-1],[1,0]])   # Right, Up, Left, Down

# 4) Implementation (Repeated)

1. Generate a puzzle `p`:
    * Zero array of shape N by N, filled randomly by half (N*N/2) with ones
1. Split it into view of 0s and 1s `p_split`
1. Find which entries can be swapped. (This only makes sense for entries of opposite type, otherwise the array won't change.)
    * `neighbor_swaps` (shape [2,4,N,N]) shows if each (dim 0) entry can go to one of 4 directions (dim 1) Right, Up, Left, Down
    * `mobility` (shape [2,N,N]) same but if can move in either direction (is mobile)
1. Calculate constraints
    * Only 1 neighbor is accepted in either horizontal or vertical direction
        * `near_neighbors` (shape [2,2,N,N]) shows if entry fails this test for vertical and horizontal dimension (dim 1)
        * If element fails in particular direction, also its neighbor considered a culprit (usually is)
    * Each row can contain only `N/2` entries
        * `row_col_constr_mask` (shape [2,2,N,N]) marks all rows and columns (dim 1) that fail this constraint
        * Order is swapped. Instead of vertical & horizontal, here its reversed. Reason: Encourage movement in orthogonal direction
1. (Additional) Accumulate history (kind of constraint) from previous steps to encourage movement of entries that stay in favorable position for long time
1. Estimate total energy (which entry should be swapped)
    * Add two constraint energies
    * Mask out entries that are immobile

1. Option exploration:
    * Take entry with highest energy (if many with same, at random)
    * Check which directions it can move via `neighbor_swaps`
    * Perform swap and calculate nearest neighbors and row/column constraint energies
    * Select one with lowest resulting energy

1. Repeat with the modified puzzle until the maximum number of iterations is reached or the energy of the two main constraints becomes 0

In [122]:
# ---- Configuration ----------------------------------------------------------
SEED = None          # set to an int for a reproducible start board and tie-breaking
MAX_ITERS = 7000     # hard cap on the number of swaps

if SEED is not None:
    np.random.seed(SEED)
    torch.manual_seed(SEED)

# ---- Random start board: exactly N*N//2 cells are set to one ----------------
p = torch.zeros_like(p)
N = p.shape[-1]

indices_flat    = np.random.choice(np.arange(N*N), size = N*N//2, replace=False)
indices         = np.unravel_index(indices_flat, shape= (N,N))
p[indices] = 1

p_OG = p.clone()                # untouched copy of the input board for the final report
p_prev = p.clone()
energy_min = -1                 # placeholder "previous energy" for the first step
swaps_hist = []
near_neighbors_cum = None       # neighbour-constraint history carried between steps

# ---- Greedy swap search ------------------------------------------------------
for step in range(MAX_ITERS):
    energy_min_prev     = energy_min
    p_split             = split_0_1(p)                                              # split into masks for 0s and 1s
    neighbor_swaps      = swap_possible(p_split)                                    # determine directions where types can be swapped
    mobility            = torch.any(neighbor_swaps, dim = 1).to(int)                # combine directions. can move = yes/no

    near_neighbors      = expand_agitated_neighbors(p_split)                        # failing nearest neighbors and their neighbors
    if near_neighbors_cum is not None:
        # entries that are still flagged inherit the weight they had on the previous step
        near_neighbors += near_neighbors_cum * (near_neighbors > 0)
    near_neighbors_cum  = near_neighbors.clone()
    row_col_constr_mask = row_col_constraint_rough_mask(p_split)                    # weights from col/row number constraint
    tot_energy          = near_neighbors + row_col_constr_mask                      # [type, axis, N, N]
    tot_energy_masked   = tot_energy * mobility.unsqueeze(1)                        # mobility is per type -> broadcast over axis

    # Candidates: mobile entries with the highest energy. An entry that is maximal on both
    # axes is listed twice, exactly as before.
    energies = []
    sols = []
    swaps = []
    variants = torch.argwhere(tot_energy_masked == torch.max(tot_energy_masked))
    for var in variants:
        tp, axis, i, j = var.numpy()

        for k, can_swap in enumerate(neighbor_swaps[tp, :, i, j]):
            if not can_swap:
                continue
            pc = p.clone()
            pos2 = ijplus[k] + [i,j]
            swap_2_elements(pc, (i,j), pos2)
            pc_split    = split_0_1(pc)
            NE          = energy_neighbors(         pc_split)
            RCE         = row_col_constraint_energy(pc_split)
            energies.append(int(NE + RCE))
            sols.append(pc)
            swaps.append(((int(i), int(j)), tuple(int(v) for v in pos2)))

    if not energies:
        # No candidate can be swapped; taking a minimum over nothing would raise.
        print(f"iter: {step}. No admissible swap found; stopping.")
        break

    # Take the swap with the lowest resulting energy; ties are broken at random.
    energies_t      = torch.tensor(energies)
    energy_min      = energies_t.min()
    where_min_all   = torch.argwhere(energies_t == energy_min).flatten()
    where_min       = where_min_all[torch.randint(len(where_min_all), size= (1,))[0]].item()
    swaps_hist.append(swaps[where_min])
    p_prev = p.clone()
    p = sols[where_min]

    solved = bool(energy_min == 0)
    if solved or step == MAX_ITERS - 1:
        print(f"iter: {step}. p_{step + 1} is achieved from p_{step} via index swap:", swaps_hist[-1], f'; Energies change: {energy_min_prev}->{energy_min}')
        # p_prev is the board before the swap (p_step), p the board after it (p_step+1)
        print_tex(r'p_{'+str(step)+'}:', p_prev, r' \ p_{'+str(step + 1)+'}:', p)
        print(f"iter: {step}. p_{step} has the following properties:")
        print(f"iter: {step}. Mobility 0s and 1s:")
        print_tex(*mobility)
        print(f"iter: {step}. Expanded agitated nearest neighbors v & h.")
        print_tex(r'\text{0s:}',*near_neighbors[0], r" \text{and 1s: }", *near_neighbors[1])
        print(f"iter: {step}. Columns & rows that violate constraint #2 (not masked).")
        print_tex(r'\text{0s:}',*row_col_constr_mask[0], r" \text{and 1s: }", *row_col_constr_mask[1])
        print(f"iter: {step}. Total energy (nearest neighbors + colr \\ row)")
        print_tex(r'0s:',*tot_energy[0],'1s:', *tot_energy[1])
        print(f"iter: {step}. Total energy (nearest neighbors + colr \\ row) Masked by mobility")
        print_tex(r'0s:',*tot_energy_masked[0],'1s:', *tot_energy_masked[1])
        if solved:
            print(f'SOLUTION ({step} iterations):')
        else:
            print(f'NO SOLUTION FOUND within {MAX_ITERS} iterations. Last board:')
        # board with its row sums, then its column sums
        print_tex(p,    torch.sum(p, dim = 1, keepdim=True))
        print_tex(      torch.sum(p, dim = 0, keepdim=True))
        print('Input puzzle:')
        print_tex(p_OG)
        break

iter: 40. p_41 is achieved from p_40 via index swap: ((1, 1), (0, 1)) ; Energies change: 2->0


<IPython.core.display.Math object>

iter: 40. p_40 has the following properties:
iter: 40. Mobility 0s and 1s:


<IPython.core.display.Math object>

iter: 40. Expanded agitated nearest neighbors v & h.


<IPython.core.display.Math object>

iter: 40. Columns & rows that violate constraint #2 (not masked).


<IPython.core.display.Math object>

iter: 40. Total energy (nearest neighbors + colr \ row)


<IPython.core.display.Math object>

iter: 40. Total energy (nearest neighbors + colr \ row) Masked by mobility


<IPython.core.display.Math object>

SOLUTION (40 iterations):


<IPython.core.display.Math object>

<IPython.core.display.Math object>

Input puzzle:


<IPython.core.display.Math object>