The goal of this notebook is to set up a Genetic algorithm which to learn an improvement on the naive greedy search approach.

In [330]:
import copy
import numpy as np
import torch
import tqdm
from importlib import resources
from ac_solver.envs.utils import is_presentation_trivial
from ac_solver.envs.ac_moves import ACMove
from ast import literal_eval

First we define a helper functions. Whilst we could plug our presentations directly into a neural network, it will likely help the network out if we first convert to a one-hot encoding.

Given that out generators come in $(x, x^{-1})$ pairs, it makes sense for the one hot encodings of $x$ and $x^{-1}$ to be the negatives of each other. Hence we map $x^{\pm 1} \to [\pm 1, 0]$ and $y^{\pm 1} \to [0, \pm 1]$.

Depending on how things go, it might also be worth investigating a one-hot encoding which outputs an array of length $4$. Perhaps it simply isn't worth including our logic about generators being inverses. 

Fixing length of each relation to be $N$, we end up with a matrix of size $2\times N\times 2$. Observe that, our problem exhibits equivariance with respect to the group of order 16: $\mathbb{Z}/2 \times D_4$. This groups is generated by three transformations: 
- Swapping the relations $r_1 \leftrightarrow r_2$.
- Swapping the generators $x \leftrightarrow y$.
- Swapping a generators with it's inverse $x \leftrightarrow x^{-1}$.

Each of these transformations has order $2$ but while the first transformation commutes with everything else, transformations two and three do not commute.

It would be interesting to investigate if we could make the network equivariant with respect to this group but for now we just flatten the tensor to make a vector of length $4N$.


In [263]:
# convert
def one_hot_single(elem: int) -> np.typing.NDArray[np.int32]:
    if elem == -2:
        return [0, -1]
    elif elem == -1:
        return [-1, 0]
    elif elem == 0:
        return [0, 0]
    elif elem == 1:
        return [1, 0]
    elif elem == 2:
        return [0, 1]
    else:
        raise Exception("Unexpected Token Found")

# Convert from the standard presentation to a one_hot encoding mapping a generator to 1 and its inverse to -1.
# We pad the output to length outlen.
def to_one_hot(presentation: np.typing.NDArray[np.int32], out_len: int) -> np.typing.NDArray[np.int32]:
    relator_len = len(presentation) // 2
    first_relator = [one_hot_single(presentation[x]) for x in range(0, relator_len)] + [[0, 0]] * (out_len - relator_len)
    second_relator = [one_hot_single(presentation[x + relator_len]) for x in range(0, relator_len)] + [[0, 0]] * (out_len - relator_len)

    relator_pair = np.array(first_relator + second_relator, dtype=np.float32)
    return relator_pair.flatten()

Next we want to set up the training environment. Essentially we will be doing a greedy like search where we replace our standard metric (i.e. the length of the relators) with a new learned metric.

To help out, we will have a temperature parameter which will initially include the length of the relators but will be slowly phased out over time. 

In [386]:
def score(net, presentation, pres_len, out_len, temperature):
    one_hot_pres = torch.from_numpy(to_one_hot(presentation, out_len))
    eval = net(one_hot_pres).item()
    return eval - pres_len * temperature
    

def search(
        presentation: np.typing.NDArray[np.int32],
        net,
        temperature,
        out_len,
        max_nodes_to_explore: int = 100) -> bool:
    max_relator_length = len(presentation) // 2
    first_word_length = np.count_nonzero(presentation[:max_relator_length])
    second_word_length = np.count_nonzero(presentation[max_relator_length:])

    word_lengths = [first_word_length, second_word_length]

    max_score = score(net, presentation, word_lengths[0] + word_lengths[1], out_len, temperature)
    current_state = presentation
    for i in range(max_nodes_to_explore):
        # print(current_state, max_score)
        current_beaten = False
        for action in range(0, 12):
            new_state, new_word_lengths = ACMove(
                action,
                current_state,
                max_relator_length,
                word_lengths,
                cyclical=True,
            )
            new_state_score = score(net, new_state, new_word_lengths[0] + new_word_lengths[1], out_len, temperature)
            # print(new_state, new_state_score)
            if new_state_score > max_score:
                current_beaten = True
                max_score = new_state_score
                best_state = new_state
                best_lengths = new_word_lengths
        
        if current_beaten:
            if best_lengths[0] + best_lengths[1] == 2:
                return i
            current_state = best_state
            word_lengths = best_lengths
        else:
            return 0
 
    return 0


In [372]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hid1 = torch.nn.Linear(80, 90)
        self.hid2 = torch.nn.Linear(90, 90)
        self.outp = torch.nn.Linear(90, 1)

        torch.nn.init.uniform_(self.hid1.weight, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid1.bias, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid2.weight, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid2.bias, -0.5, 0.5)
        torch.nn.init.uniform_(self.outp.weight, -0.5, 0.5)
        torch.nn.init.uniform_(self.outp.bias, -0.5, 0.5)

        self.hid1.requires_grad_(False)
        self.hid2.requires_grad_(False)
        self.outp.requires_grad_(False)
    
    def forward(self, x):
        z = torch.relu(self.hid1(x))
        z = torch.relu(self.hid2(z))
        return torch.relu(self.outp(z))
    
    def mutate(self, num_copies, scale):
        output_nets = [copy.deepcopy(self)]
        for _ in range(num_copies - 1):
            clone = copy.deepcopy(self)
            for param in clone.parameters():
                rand = torch.randn(param.size())
                param += 0.03 * scale * rand
            output_nets += [clone]
        return output_nets


In [312]:
## Load data
with open("../ac_solver/search/miller_schupp/data/all_presentations.txt") as file:
    presentations = [np.array(literal_eval(path)) for path in file.readlines()]

In [398]:
scrubbed_presentations = []
swapped_presentations = []
for presentation in presentations:
    if len(presentation) == 40:
        new_pres = presentation
    elif len(presentation) > 40:
        new_pres = np.zeros(40)
        len_rels = len(presentation) // 2
        new_pres[:20:] = presentation[:20:]
        new_pres[20::] = presentation[len_rels:len_rels + 20:]
    elif len(presentation) < 40:
        new_pres = np.zeros(40)
        len_rels = len(presentation) // 2
        new_pres[:len_rels:] = presentation[:len_rels:]
        new_pres[20:20 + len_rels:] = presentation[len_rels::]
    scrubbed_presentations += [new_pres]
    swap_pres = np.zeros(40)
    swap_pres[:20] = new_pres[20:]
    swap_pres[20:] = new_pres[:20]
    swapped_presentations += [swap_pres]

In [362]:
test_nets = [Net() for _ in range(5)]

In [380]:
nets = [Net() for _ in range(100)]
for i in range(100):
    temp = 1 / (1 + (i // 5))
    accs = np.array([0 for _ in range(100)])
    for presentation in tqdm.tqdm(scrubbed_presentations[:100:]):
        for (j, net) in enumerate(nets):
            if search(presentation, net, 1, 20):
                accs[j] += 1
    print(i, max(accs))

    best_nets = [nets[k] for k in np.argpartition(accs, -10)[-10:]]
    mutated_nets = [net.mutate(10, temp) for net in best_nets]
    nets = [net 
            for mutated_net in mutated_nets
            for net in mutated_net 
            ]

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:41<00:00,  2.41it/s]


0 22


100%|██████████| 100/100 [00:43<00:00,  2.31it/s]


1 26


100%|██████████| 100/100 [00:39<00:00,  2.52it/s]


2 29


100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


3 29


100%|██████████| 100/100 [00:41<00:00,  2.43it/s]


4 30


100%|██████████| 100/100 [00:42<00:00,  2.37it/s]


5 32


100%|██████████| 100/100 [00:45<00:00,  2.20it/s]


6 35


100%|██████████| 100/100 [00:51<00:00,  1.93it/s]


7 37


100%|██████████| 100/100 [00:54<00:00,  1.84it/s]


8 40


100%|██████████| 100/100 [00:56<00:00,  1.76it/s]


9 40


100%|██████████| 100/100 [00:56<00:00,  1.76it/s]


10 41


100%|██████████| 100/100 [00:57<00:00,  1.72it/s]


11 42


100%|██████████| 100/100 [00:57<00:00,  1.73it/s]


12 43


100%|██████████| 100/100 [00:57<00:00,  1.75it/s]


13 43


100%|██████████| 100/100 [00:58<00:00,  1.71it/s]


14 44


100%|██████████| 100/100 [00:57<00:00,  1.73it/s]


15 45


100%|██████████| 100/100 [00:58<00:00,  1.71it/s]


16 45


100%|██████████| 100/100 [00:59<00:00,  1.67it/s]


17 46


100%|██████████| 100/100 [01:00<00:00,  1.64it/s]


18 46


100%|██████████| 100/100 [01:00<00:00,  1.66it/s]


19 46


100%|██████████| 100/100 [01:00<00:00,  1.64it/s]


20 46


100%|██████████| 100/100 [01:01<00:00,  1.63it/s]


21 46


100%|██████████| 100/100 [01:01<00:00,  1.62it/s]


22 46


100%|██████████| 100/100 [01:01<00:00,  1.63it/s]


23 46


100%|██████████| 100/100 [01:02<00:00,  1.61it/s]


24 47


100%|██████████| 100/100 [01:01<00:00,  1.62it/s]


25 47


100%|██████████| 100/100 [01:01<00:00,  1.62it/s]


26 47


100%|██████████| 100/100 [01:02<00:00,  1.60it/s]


27 48


100%|██████████| 100/100 [01:02<00:00,  1.60it/s]


28 49


100%|██████████| 100/100 [01:02<00:00,  1.60it/s]


29 49


100%|██████████| 100/100 [01:03<00:00,  1.59it/s]


30 49


100%|██████████| 100/100 [01:03<00:00,  1.58it/s]


31 49


100%|██████████| 100/100 [01:03<00:00,  1.57it/s]


32 49


100%|██████████| 100/100 [01:04<00:00,  1.56it/s]


33 49


100%|██████████| 100/100 [01:04<00:00,  1.54it/s]


34 49


100%|██████████| 100/100 [01:04<00:00,  1.54it/s]


35 49


100%|██████████| 100/100 [01:04<00:00,  1.55it/s]


36 49


100%|██████████| 100/100 [01:05<00:00,  1.53it/s]


37 49


  8%|▊         | 8/100 [00:05<01:07,  1.37it/s]

Unexpected exception formatting exception. Falling back to standard exception



Traceback (most recent call last):
  File "c:\Users\AngusGruen\Documents\GitHub\AC-Solver\env\Lib\site-packages\IPython\core\interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\AngusGruen\AppData\Local\Temp\ipykernel_43388\668976750.py", line 7, in <module>
    if search(presentation, net, 1, 20):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\AngusGruen\AppData\Local\Temp\ipykernel_43388\3512859880.py", line 32, in search
    new_state_score = score(net, new_state, new_word_lengths[0] + new_word_lengths[1], out_len, temperature)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\AngusGruen\AppData\Local\Temp\ipykernel_43388\3512859880.py", line 3, in score
    eval = net(one_hot_pres).item()
           ^^^^^^^^^^^^^^^^^
  File "c:\Users\AngusGruen\Documents\GitHub\AC-Solver\env\Lib\site-packages\torch\nn\modules\module.py", line 1501, in

In [402]:
acc, acc_neg, acc_swap, acc_swap_neg = 0, 0, 0, 0
max_len, max_len_neg, max_len_swap, max_len_swap_neg = 0, 0, 0, 0
for (j, presentation) in tqdm.tqdm(enumerate(scrubbed_presentations)):
    i = search(presentation, nets[0], 1, 20)
    if i:
        max_len = max(max_len, i)
        acc += 1

    i_neg = search(-presentation, nets[0], 1, 20)
    if i_neg:
        max_len_neg = max(max_len_neg, i_neg)
        acc_neg += 1
    
    i_swap = search(swapped_presentations[j], nets[0], 1, 20)
    if i_swap:
        max_len_swap = max(max_len_swap, i_swap)
        acc_swap += 1
    
    i_swap_neg = search(-swapped_presentations[j], nets[0], 1, 20)
    if i_swap_neg:
        max_len_swap_neg = max(max_len_swap_neg, i_swap_neg)
        acc_swap_neg += 1
print(acc, max_len)
print(acc_neg, max_len_neg)
print(acc_swap, max_len_swap)
print(acc_swap_neg, max_len_swap_neg)

1190it [00:13, 90.46it/s] 

114 13
49 7
57 7
56 7





In [403]:
def zero_net(pres):
    return torch.zeros(1)

acc, acc_neg, acc_swap, acc_swap_neg = 0, 0, 0, 0
max_len, max_len_neg, max_len_swap, max_len_swap_neg = 0, 0, 0, 0
for (j, presentation) in tqdm.tqdm(enumerate(scrubbed_presentations)):
    i = search(presentation, zero_net, 1, 20)
    if i:
        max_len = max(max_len, i)
        acc += 1

    i_neg = search(-presentation, zero_net, 1, 20)
    if i_neg:
        max_len_neg = max(max_len_neg, i_neg)
        acc_neg += 1
    
    i_swap = search(swapped_presentations[j], zero_net, 1, 20)
    if i_swap:
        max_len_swap = max(max_len_swap, i_swap)
        acc_swap += 1
    
    i_swap_neg = search(-swapped_presentations[j], zero_net, 1, 20)
    if i_swap_neg:
        max_len_swap_neg = max(max_len_swap_neg, i_swap_neg)
        acc_swap_neg += 1
print(acc, max_len)
print(acc_neg, max_len_neg)
print(acc_swap, max_len_swap)
print(acc_swap_neg, max_len_swap_neg)

1190it [00:05, 216.46it/s]

56 7
56 7
53 7
53 7





In [396]:
-scrubbed_presentations[0]

array([ 1., -2., -1.,  2.,  2., -0., -0., -0., -0., -0., -0., -0., -0.,
       -0., -0., -0., -0., -0., -0., -0.,  1., -2., -0., -0., -0., -0.,
       -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
       -0.])