The goal of this notebook is to set up a Genetic algorithm which to learn an improvement on the naive greedy search approach.

In [1]:
import copy
import numpy as np
import torch
import tqdm
from importlib import resources
from ac_solver.envs.utils import is_presentation_trivial
from ac_solver.envs.ac_moves import ACMove
from ast import literal_eval

First we define a helper functions. Whilst we could plug our presentations directly into a neural network, it will likely help the network out if we first convert to a one-hot encoding.

Given that out generators come in $(x, x^{-1})$ pairs, it makes sense for the one hot encodings of $x$ and $x^{-1}$ to be the negatives of each other. Hence we map $x^{\pm 1} \to [\pm 1, 0]$ and $y^{\pm 1} \to [0, \pm 1]$.

Depending on how things go, it might also be worth investigating a one-hot encoding which outputs an array of length $4$. Perhaps it simply isn't worth including our logic about generators being inverses. 

Fixing length of each relation to be $N$, we end up with a matrix of size $2\times N\times 2$. Observe that, our problem exhibits equivariance with respect to the group of order 16: $\mathbb{Z}/2 \times D_4$. This groups is generated by three transformations: 
- Swapping the relations $r_1 \leftrightarrow r_2$.
- Swapping the generators $x \leftrightarrow y$.
- Swapping a generators with it's inverse $x \leftrightarrow x^{-1}$.

Each of these transformations has order $2$ but while the first transformation commutes with everything else, transformations two and three do not commute.

It would be interesting to investigate if we could make the network equivariant with respect to this group but for now we just flatten the tensor to make a vector of length $4N$.


In [2]:
# Swap the two relations of a given presentation.
def swap_relations(presentation, relation_length):
    new_presentation = np.zeros(2 * relation_length)
    new_presentation[:relation_length] = presentation[relation_length:]
    new_presentation[relation_length:] = presentation[:relation_length]
    return new_presentation

# Swap the two generators of a given presentation.
def swap_generators(presentation, relation_length):
    new_presentation = np.zeros(2 * relation_length)
    for (i, elem) in enumerate(presentation):
        if elem > 0:
            # Map 2 -> 1 and 1 -> 2
            new_presentation[i] = 3 - elem
        elif elem < 0:
            # Map -2 -> -1 and -1 -> -2
            new_presentation[i] = -3 - elem
    return new_presentation

# Swap the first generator with its inverse
def invert_x(presentation, relation_length):
    new_presentation = copy.copy(presentation)
    for (i, elem) in enumerate(presentation):
        if elem == 1:
            # Map 1 -> -1
            new_presentation[i] = -1
        elif elem == -1:
            # Map -1 -> 1
            new_presentation[i] = 1

    return new_presentation

# Find all sixteen equivalent presentations related to the original by our equivariance group.
# 
# Each element of the equivariance group can be written uniquely as (g1)^{i1}*(g2)^{i2}*(g3)^{i3}*(g4)^{i4}
# where i1, i2, i3, i4 as 0 or 1 and
# g4: Swap the relations.
# g3: Swap the generators.
# g2: Swap the first generator with its inverse.
# g1: Swap both generators with their inverse.
# Note that  g1(P) = g3 * g2 * g3 * g2(P) so it isn't a "true" generator but its good enough for here.
def group_equivalency_class(presentation, relation_length):
    equivalence_class = [presentation]
    equivalence_class += [swap_relations(presentation, relation_length)]
    equivalence_class += [swap_generators(pres, relation_length) for pres in equivalence_class]
    equivalence_class += [invert_x(pres, relation_length) for pres in equivalence_class]
    equivalence_class += [-pres for pres in equivalence_class]

    return equivalence_class



# Convert an element from the {+/- 2, +/- 1, 0} representation into the one-hot representation.  
def one_hot_single(elem: int) -> np.typing.NDArray[np.int32]:
    if elem == -2:
        return [0, -1]
    elif elem == -1:
        return [-1, 0]
    elif elem == 0:
        return [0, 0]
    elif elem == 1:
        return [1, 0]
    elif elem == 2:
        return [0, 1]
    else:
        raise Exception("Unexpected Token Found")

# Convert from the standard presentation to a one_hot encoding mapping a generator to 1 and its inverse to -1.
# We pad the output to length outlen.
def to_one_hot(presentation: np.typing.NDArray[np.int32], out_len: int) -> np.typing.NDArray[np.int32]:
    relator_len = len(presentation) // 2
    first_relator = [one_hot_single(presentation[x]) for x in range(0, relator_len)] + [[0, 0]] * (out_len - relator_len)
    second_relator = [one_hot_single(presentation[x + relator_len]) for x in range(0, relator_len)] + [[0, 0]] * (out_len - relator_len)

    relator_pair = np.array(first_relator + second_relator, dtype=np.float32)
    return relator_pair.flatten()

Next we want to set up the training environment. Essentially we will be doing a greedy like search where we replace our standard metric (i.e. the length of the relators) with a new learned metric.

To help out, we will have a temperature parameter which will initially include the length of the relators but will be slowly phased out over time. 

In [3]:
def score(net, presentation, pres_len, out_len, temperature):
    one_hot_pres = torch.from_numpy(to_one_hot(presentation, out_len))
    eval = net(one_hot_pres).item()
    return eval + pres_len * temperature
    

def search(
        presentation: np.typing.NDArray[np.int32],
        net,
        temperature,
        out_len,
        max_nodes_to_explore: int = 100) -> bool:
    max_relator_length = len(presentation) // 2
    first_word_length = np.count_nonzero(presentation[:max_relator_length])
    second_word_length = np.count_nonzero(presentation[max_relator_length:])

    word_lengths = [first_word_length, second_word_length]

    min_score = score(net, presentation, word_lengths[0] + word_lengths[1], out_len, temperature)
    current_state = presentation
    for i in range(max_nodes_to_explore):
        # print(current_state, max_score)
        current_beaten = False
        for action in range(12):
            new_state, new_word_lengths, state_updated = ACMove(
                action,
                current_state,
                max_relator_length,
                word_lengths,
                cyclical=False,
            )
            if state_updated:
                new_state_score = score(net, new_state, new_word_lengths[0] + new_word_lengths[1], out_len, temperature)
                # print(new_state, new_state_score)
                if new_state_score <= min_score:
                    current_beaten = True
                    min_score = new_state_score
                    best_state = new_state
                    best_lengths = new_word_lengths
        
        if current_beaten:
            if best_lengths[0] + best_lengths[1] == 2:
                return i
            current_state = best_state
            word_lengths = best_lengths
        else:
            return 0
 
    return 0


In [4]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hid1 = torch.nn.Linear(80, 90)
        self.hid2 = torch.nn.Linear(90, 90)
        self.hid3 = torch.nn.Linear(90, 90)
        self.outp = torch.nn.Linear(90, 1)

        torch.nn.init.uniform_(self.hid1.weight, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid1.bias, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid2.weight, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid2.bias, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid3.weight, -0.5, 0.5)
        torch.nn.init.uniform_(self.hid3.bias, -0.5, 0.5)
        torch.nn.init.uniform_(self.outp.weight, -0.5, 0.5)
        torch.nn.init.uniform_(self.outp.bias, -0.5, 0.5)

        self.hid1.requires_grad_(False)
        self.hid2.requires_grad_(False)
        self.hid3.requires_grad_(False)
        self.outp.requires_grad_(False)
    
    def forward(self, x):
        z = torch.relu(self.hid1(x))
        z = torch.relu(self.hid2(z))
        z = torch.relu(self.hid3(z))
        return torch.relu(self.outp(z))
    
    def mutate(self, num_copies, scale):
        output_nets = [copy.deepcopy(self)]
        for _ in range(num_copies - 1):
            clone = copy.deepcopy(self)
            for param in clone.parameters():
                rand = torch.randn(param.size())
                param += 0.03 * scale * rand
            output_nets += [clone]
        return output_nets


In [5]:
## Load data
with open("../ac_solver/search/miller_schupp/data/all_presentations.txt") as file:
    miller_presentations = [np.array(literal_eval(path)) for path in file.readlines()]

## Load data
with open("../ac_solver/search/miller_schupp/data/simple_presentations.txt") as file:
    simple_presentations = [np.array(literal_eval(path)) for path in file.readlines()]

In [6]:
desired_length = 20
scrubbed_miller_presentations = []
equivariant_miller_presentation_set = []
for presentation in miller_presentations:
    if len(presentation) == 2*desired_length:
        new_pres = presentation
    elif len(presentation) > 2*desired_length:
        new_pres = np.zeros(2 * desired_length)
        len_rels = len(presentation) // 2
        new_pres[:desired_length:] = presentation[:desired_length:]
        new_pres[desired_length::] = presentation[len_rels:len_rels + desired_length:]
    elif len(presentation) < 2*desired_length:
        new_pres = np.zeros(2 * desired_length)
        len_rels = len(presentation) // 2
        new_pres[:len_rels:] = presentation[:len_rels:]
        new_pres[desired_length:desired_length + len_rels:] = presentation[len_rels::]
    scrubbed_miller_presentations += [new_pres]
    equivariant_miller_presentation_set += group_equivalency_class(new_pres, 20)

print(len(scrubbed_miller_presentations))
print(len(equivariant_miller_presentation_set))

1190
19040


In [7]:
desired_length = 20
scrubbed_simple_presentations = []
equivariant_simple_presentation_set = []
for presentation in simple_presentations:
    if len(presentation) == 2*desired_length:
        new_pres = presentation
    elif len(presentation) > 2*desired_length:
        new_pres = np.zeros(2 * desired_length)
        len_rels = len(presentation) // 2
        new_pres[:desired_length:] = presentation[:desired_length:]
        new_pres[desired_length::] = presentation[len_rels:len_rels + desired_length:]
    elif len(presentation) < 2*desired_length:
        new_pres = np.zeros(2 * desired_length)
        len_rels = len(presentation) // 2
        new_pres[:len_rels:] = presentation[:len_rels:]
        new_pres[desired_length:desired_length + len_rels:] = presentation[len_rels::]
    scrubbed_simple_presentations += [new_pres]
    equivariant_simple_presentation_set += group_equivalency_class(new_pres, 20)

print(len(scrubbed_simple_presentations))
print(len(equivariant_simple_presentation_set))

35
560


In [8]:
# 1 - 533: All cases solved by Greedy Search
# 1 - 170: all the n = 1 cases. 2720 = 170 * 16
# 171 - 308: all the n = 2 cases.
# 309 - 406: all the n = 3 cases.
# 407 - 458: all the n = 4 cases.
# 459 - 493: all the n = 5 cases.
# 494 - 517: all the n = 6 cases.
# 518 - 533: all the n = 7 cases.

# 534 - 565: The n = 2 cases not solved by greedy search 
num_mutations = 10
num_best = 10
num_nets = num_best * num_mutations
num_evolutions = 30
temp_scaling = 5
nets = [Net() for _ in range(num_nets)]
for i in range(num_evolutions):
    temp = 1 / (1 + (i // temp_scaling))
    # temp_score = (60 - i)/60
    accs = np.array([0 for _ in range(num_nets)])
    for presentation in tqdm.tqdm(equivariant_simple_presentation_set):
        for (j, net) in enumerate(nets):
            if search(presentation, net, 0, 20):
                accs[j] += 1

    top_acc_args = np.argpartition(accs, -num_best)
    best_nets = [nets[k] for k in top_acc_args[-num_best:]]
    print(i, ":", [accs[k] for k in top_acc_args[-num_best:]])
    mutated_nets = [net.mutate(num_mutations, temp) for net in best_nets]
    nets = [net 
            for mutated_net in mutated_nets
            for net in mutated_net 
            ]

  0%|          | 0/560 [00:00<?, ?it/s]

  1%|          | 4/560 [00:29<1:07:22,  7.27s/it]


KeyboardInterrupt: 

In [32]:
# 1 - 533: All cases solved by Greedy Search
# 1 - 170: all the n = 1 cases. 2720 = 170 * 16
# 171 - 308: all the n = 2 cases.
# 309 - 406: all the n = 3 cases.
# 407 - 458: all the n = 4 cases.
# 459 - 493: all the n = 5 cases.
# 494 - 517: all the n = 6 cases.
# 518 - 533: all the n = 7 cases.

# 534 - 565: The n = 2 cases not solved by greedy search 
num_mutations = 10
num_best = 5
num_nets = num_best * num_mutations
num_evolutions = 20
temp_scaling = 3
# nets = [Net() for _ in range(num_nets)]
for i in range(num_evolutions):
    temp = 1 / (1 + (i // 2))
    # temp_score = (60 - i)/60
    accs = np.array([0 for _ in range(num_nets)])
    for presentation in tqdm.tqdm(equivariant_miller_presentation_set[:2719]):
        for (j, net) in enumerate(nets):
            if search(presentation, net, 0, 20):
                accs[j] += 1

    top_acc_args = np.argpartition(accs, -num_best)
    best_nets = [nets[k] for k in top_acc_args[-num_best:]]
    print(i, ":", [accs[k] for k in top_acc_args[-num_best:]])
    mutated_nets = [net.mutate(num_mutations, temp) for net in best_nets]
    nets = [net 
            for mutated_net in mutated_nets
            for net in mutated_net 
            ]

100%|██████████| 2719/2719 [06:48<00:00,  6.66it/s]


0 : [1, 1, 1, 1, 1]


 18%|█▊        | 479/2719 [01:39<07:43,  4.83it/s]


KeyboardInterrupt: 

In [31]:
acc = 0
max_len = 0
for presentation in tqdm.tqdm(scrubbed_miller_presentations):
    i = search(presentation, nets[10], 1, 20)
    if i:
        max_len = max(max_len, i)
        acc += 1
        
print(acc, max_len)

100%|██████████| 1190/1190 [00:02<00:00, 494.05it/s]

2 11





In [38]:
acc = 0
max_len = 0
for presentation in tqdm.tqdm(equivariant_presentation_set):
    i = search(presentation, nets[20], 1, 20)
    if i:
        max_len = max(max_len, i)
        acc += 1
        
print(acc, max_len)

100%|██████████| 19040/19040 [01:13<00:00, 257.69it/s]

881 11





In [31]:
acc = 0
max_len = 0
for presentation in tqdm.tqdm(equivariant_presentation_set[:2719]):
    i = search(presentation, nets[0], 1, 20)
    if i:
        max_len = max(max_len, i)
        acc += 1
        
print(acc, max_len)

100%|██████████| 2719/2719 [00:10<00:00, 255.81it/s]

382 11





In [9]:
def zero_net(pres):
    return torch.zeros(1)

acc = 0
max_len = 0
for presentation in tqdm.tqdm(equivariant_simple_presentation_set):
    i = search(presentation, zero_net, 1, 20)
    if i:
        max_len = max(max_len, i)
        acc += 1
print(acc, max_len)

100%|██████████| 560/560 [00:05<00:00, 98.92it/s] 

275 8





In [10]:
def zero_net(pres):
    return torch.zeros(1)

acc = 0
max_len = 0
for presentation in tqdm.tqdm(equivariant_miller_presentation_set):
    i = search(presentation, zero_net, 1, 20)
    if i:
        max_len = max(max_len, i)
        acc += 1
print(acc, max_len)

 15%|█▌        | 2943/19040 [01:22<07:32, 35.55it/s] 


KeyboardInterrupt: 

In [39]:
scrubbed_presentations[0]

array([-1.,  2.,  1., -2., -2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0., -1.,  2.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.])

In [57]:
scrubbed_presentations[0][0::-1]

array([-1.])