## Utils and loading graph

Read the graph from an edge-list text using networkx

In [None]:
from typing import List, Union
from torch import Tensor
import matplotlib.pyplot as plt
import copy
import numpy as np
import networkx as nx
from tqdm import tqdm  #progress bars for loops and iterables

# read graph file, e.g., BarabasiAlbert_100_ID2, as networkx.Graph
def read_nxgraph(filename) -> nx.Graph:
    """Build an undirected, weighted ``networkx.Graph`` from a Max-Cut edge list.

    Expected format (blank lines are ignored)::

        num_nodes num_edges
        node1 node2 [weight]
        ...

    Node ids in the text are 1-based and are shifted to 0-based. A missing
    weight defaults to 1. Weights are stored as floats, so downstream numeric
    code (e.g. ``nx.to_numpy_array``) sees numbers rather than strings.

    Args:
        filename: a path (``str`` or ``os.PathLike``) to read, or any iterable
            of text lines (e.g. ``text.splitlines()`` or an open file object).

    Returns:
        A graph with nodes ``0 .. num_nodes - 1`` and one edge per data line.
    """
    import os

    # A plain string used to be iterated character by character, which always
    # crashed; treat it (and path-like objects) as a file path instead.
    if isinstance(filename, (str, os.PathLike)):
        with open(filename, encoding="utf-8") as handle:
            return read_nxgraph(handle.read().splitlines())

    graph = nx.Graph()
    header_seen = False
    for line in filename:
        # split() on any whitespace tolerates trailing "\n" and repeated spaces
        tokens = line.split()
        if not tokens:
            continue
        if not header_seen:
            # Header line: "num_nodes num_edges"; only the node count is needed.
            graph.add_nodes_from(range(int(tokens[0])))
            header_seen = True
            continue
        # nodes in file start from 1, change to start from 0 in our code.
        node1, node2 = int(tokens[0]) - 1, int(tokens[1]) - 1
        weight = float(tokens[2]) if len(tokens) > 2 else 1.0
        graph.add_edge(node1, node2, weight=weight)
    return graph

# Dense adjacency matrix of a graph
def transfer_nxgraph_to_adjacencymatrix(graph: nx.Graph):
    """Return ``graph`` as a dense NumPy adjacency matrix.

    Delegates to ``networkx.to_numpy_array`` with its defaults: rows and
    columns follow ``graph.nodes`` order, entries hold the edge ``weight``
    attribute (1 when an edge has none), and non-edges are 0.
    """
    adjacency = nx.to_numpy_array(graph)
    return adjacency

# calculate cut value
def obj_maxcut(solution: Union[Tensor, List[int], np.ndarray], graph: "Union[nx.Graph, np.ndarray]"):
    """Return the cut value of a node partition.

    The cut value is the total weight of edges whose two endpoints carry
    different labels in ``solution``.

    Args:
        solution: per-node partition labels (e.g. 0/1) as a list, ndarray or
            1-D tensor. Only the first ``len(solution)`` nodes of the
            adjacency matrix are considered.
        graph: a networkx graph, or an already computed dense adjacency
            matrix. Passing the matrix avoids rebuilding it on every call.

    Returns:
        ``sum(adj[i, j] for i < j if solution[i] != solution[j])`` as a float.
    """
    if isinstance(graph, np.ndarray):
        adj_matrix = graph
    else:
        adj_matrix = transfer_nxgraph_to_adjacencymatrix(graph)

    if isinstance(solution, Tensor):
        # Move off any accelerator so NumPy can read the labels.
        labels = solution.detach().cpu().numpy()
    else:
        labels = np.asarray(solution)

    num_nodes = len(labels)
    # differs[i, j] is True when nodes i and j are on opposite sides of the cut.
    differs = labels[:, None] != labels[None, :]
    # Keep only the strict upper triangle (i < j) so each undirected edge counts once.
    upper = np.triu(adj_matrix[:num_nodes, :num_nodes], k=1)
    return float(np.sum(upper * differs))

Load data (hard-coded)

In [None]:
# Hard-coded Max-Cut instance in the edge-list format parsed by read_nxgraph:
# the first line is "num_nodes num_edges" (here 100 nodes, 384 edges), followed
# by one "node1 node2 weight" line per edge, with 1-based node ids.
# .splitlines() turns the text into the list of lines that read_nxgraph iterates.
BarabasiAlbert_100_ID2 = """100 384
1 2 1
1 3 1
1 4 1
1 5 1
1 6 1
1 7 1
1 8 1
1 9 1
1 10 1
1 23 1
1 26 1
1 30 1
1 31 1
1 33 1
1 45 1
1 46 1
1 48 1
1 54 1
1 65 1
1 68 1
1 69 1
1 74 1
1 76 1
1 91 1
1 97 1
2 6 1
2 33 1
2 42 1
2 60 1
2 62 1
3 6 1
3 7 1
3 8 1
3 28 1
3 29 1
3 43 1
3 47 1
3 51 1
3 57 1
3 74 1
3 87 1
3 91 1
3 96 1
3 98 1
3 99 1
4 6 1
4 7 1
4 9 1
4 10 1
4 11 1
4 17 1
4 28 1
4 29 1
4 32 1
4 44 1
4 46 1
4 49 1
4 61 1
4 68 1
4 72 1
4 84 1
4 97 1
4 99 1
5 8 1
5 11 1
5 12 1
5 15 1
5 16 1
5 17 1
5 24 1
5 35 1
5 38 1
5 53 1
5 57 1
6 7 1
6 8 1
6 9 1
6 18 1
6 22 1
6 23 1
6 24 1
6 25 1
6 26 1
6 39 1
6 52 1
6 55 1
6 74 1
6 76 1
6 91 1
7 9 1
7 10 1
7 11 1
7 12 1
7 13 1
7 14 1
7 15 1
7 17 1
7 20 1
7 28 1
7 29 1
7 37 1
7 51 1
7 54 1
7 63 1
7 93 1
7 95 1
8 10 1
8 13 1
8 16 1
8 28 1
8 34 1
8 42 1
8 43 1
8 60 1
8 61 1
8 66 1
8 90 1
9 12 1
9 18 1
9 19 1
9 21 1
9 33 1
9 46 1
9 50 1
9 68 1
9 70 1
9 83 1
10 11 1
10 14 1
10 15 1
10 23 1
10 27 1
10 30 1
10 32 1
10 38 1
10 50 1
10 56 1
10 59 1
10 68 1
10 74 1
10 75 1
10 81 1
10 85 1
10 86 1
10 94 1
11 12 1
11 13 1
11 14 1
11 20 1
11 21 1
11 25 1
11 34 1
11 35 1
11 39 1
11 41 1
11 43 1
11 44 1
11 51 1
11 52 1
11 53 1
11 63 1
11 69 1
11 79 1
11 89 1
11 92 1
11 94 1
12 13 1
12 14 1
12 16 1
12 18 1
12 19 1
12 25 1
12 27 1
12 30 1
12 31 1
12 32 1
12 33 1
12 36 1
12 38 1
12 39 1
12 40 1
12 48 1
12 49 1
12 54 1
12 60 1
12 67 1
12 79 1
12 83 1
12 85 1
12 88 1
13 15 1
13 20 1
13 23 1
13 25 1
13 42 1
13 43 1
13 50 1
13 57 1
13 58 1
13 73 1
13 75 1
13 82 1
14 17 1
14 18 1
14 20 1
14 21 1
14 22 1
14 31 1
14 36 1
14 37 1
14 47 1
14 49 1
14 59 1
14 63 1
14 69 1
14 81 1
14 83 1
14 86 1
14 97 1
15 16 1
15 19 1
15 22 1
15 35 1
15 38 1
15 44 1
15 80 1
16 22 1
16 26 1
16 27 1
16 29 1
16 51 1
16 58 1
17 19 1
17 64 1
17 88 1
18 26 1
18 52 1
18 55 1
18 73 1
19 44 1
19 45 1
19 47 1
19 52 1
19 60 1
19 70 1
19 84 1
19 91 1
20 21 1
20 35 1
20 45 1
20 72 1
20 73 1
20 85 1
20 95 1
20 98 1
20 100 1
21 24 1
21 31 1
21 37 1
21 41 1
21 66 1
22 36 1
22 48 1
22 66 1
22 71 1
22 77 1
22 87 1
23 24 1
23 32 1
23 61 1
23 82 1
23 90 1
23 96 1
24 34 1
24 54 1
24 62 1
24 63 1
24 86 1
25 40 1
26 27 1
26 41 1
26 88 1
26 98 1
27 81 1
27 90 1
27 99 1
28 30 1
28 34 1
28 40 1
28 61 1
28 92 1
29 36 1
29 49 1
29 70 1
30 37 1
31 41 1
31 58 1
31 82 1
32 56 1
32 100 1
33 39 1
33 65 1
33 95 1
34 53 1
34 67 1
34 78 1
34 84 1
34 89 1
34 94 1
35 59 1
35 86 1
36 87 1
37 40 1
37 45 1
37 50 1
37 56 1
37 59 1
37 72 1
37 77 1
37 88 1
38 48 1
38 58 1
38 64 1
38 71 1
40 42 1
41 85 1
42 46 1
42 57 1
42 67 1
42 78 1
42 83 1
44 47 1
44 94 1
45 66 1
47 53 1
47 73 1
48 55 1
48 62 1
50 75 1
50 81 1
50 89 1
51 56 1
52 69 1
52 71 1
52 76 1
53 55 1
53 71 1
53 96 1
54 64 1
54 84 1
54 97 1
55 72 1
55 78 1
56 65 1
56 89 1
56 92 1
56 93 1
56 96 1
56 99 1
57 65 1
57 93 1
58 64 1
58 77 1
59 80 1
60 62 1
61 67 1
63 76 1
64 90 1
64 98 1
66 77 1
66 92 1
67 70 1
67 78 1
67 79 1
68 75 1
68 79 1
70 87 1
74 82 1
77 100 1
78 80 1
79 80 1
82 100 1
84 95 1
87 93 1""".splitlines()

Load graph

In [None]:
# Parse the hard-coded edge-list lines into a networkx graph.
edge_list_lines = BarabasiAlbert_100_ID2
graph = read_nxgraph(edge_list_lines)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import torch.distributions as distributions

# Seed every RNG the notebook draws from (torch, NumPy, stdlib random), in that
# order, so runs are repeatable.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)

Initialize lists for experiment tracking

In [None]:
# Experiment-tracking containers, each a separate, initially empty list:
#   stats_losses     - loss values (appended once per optimizer step in training)
#   stats_advantages - advantages
#   states           - transitions (step, new score, new solution)
stats_losses, stats_advantages, states = [], [], []

## Method

Define Model

In [None]:
# Fully connected model with two hidden layers and a sigmoid head that outputs probabilities
class FCModel(nn.Module):
    """Multi-layer perceptron: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid.

    Every output lies in (0, 1), giving one probability per output unit.

    Args:
        input_dim: size of the input feature vector.
        hidden_dim: width of both hidden layers.
        output_dim: number of output probabilities.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # ReLU between the Linear layers is required: without a nonlinearity the
        # three Linear layers compose into a single affine map.
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map inputs of shape (..., input_dim) to probabilities of shape (..., output_dim)."""
        return self.model(x)

Define buffer

In [None]:
# Buffer implementation based on https://github.com/YangletLiu/L2A_Algorithm/blob/4785e79d3f34b77636625f8363685647d1e0e341/graph_max_cut_trs.py#L242
# Main change: update() adds one trajectory per call instead of multiple trajectories at once.
class Buffer:
    """Fixed-capacity store of sampled trajectories.

    Storage (allocated once, on ``device``):
        solutions:  (trajectory_length + 1, buffer_size, num_nodes) bool
        rewards:    (trajectory_length, buffer_size) float32
        logprobs:   (trajectory_length, buffer_size) float32
        obj_values: (buffer_size,) float32, used to rank entries for eviction

    ``p`` counts filled slots. Once all ``buffer_size`` slots are filled, each
    new trajectory overwrites the stored one with the lowest ``obj_value``.
    """

    def __init__(self, buffer_size, trajectory_length, num_nodes, device=torch.device('cpu')):
        self.device = device
        self.solutions = torch.empty(
            (trajectory_length + 1, buffer_size, num_nodes), dtype=torch.bool, device=device
        )
        self.rewards = torch.empty((trajectory_length, buffer_size), dtype=torch.float32, device=device)
        self.logprobs = torch.empty((trajectory_length, buffer_size), dtype=torch.float32, device=device)

        self.obj_values = torch.empty(buffer_size, dtype=torch.float32, device=device)

        self.p = 0  # number of filled slots
        self.add_size = 0  # kept for backward compatibility; not used by this class
        self.buffer_size = buffer_size

    def update(self, solution, reward, logprob, obj_value):
        """Store one trajectory, evicting the lowest-obj_value entry when full.

        Args:
            solution: tensor of shape (L, num_nodes) with L <= trajectory_length + 1;
                written to the first L rows of the chosen slot.
            reward: tensor of shape (trajectory_length,).
            logprob: tensor of shape (trajectory_length,).
            obj_value: scalar (tensor or number) used to rank trajectories.
        """
        if self.p < self.buffer_size:
            slot = self.p
        else:
            # Buffer full: replace the trajectory with the lowest score.
            slot = int(torch.argmin(self.obj_values[:self.p]))

        self.solutions[:solution.shape[0], slot] = solution.to(self.device)
        self.rewards[:, slot] = reward.to(self.device)
        self.logprobs[:, slot] = logprob.to(self.device)
        self.obj_values[slot] = torch.as_tensor(obj_value, dtype=torch.float32).to(self.device)

        if self.p < self.buffer_size:
            self.p += 1

    def sample(self, batch_size, device):
        """Draw ``batch_size`` stored trajectories uniformly, with replacement.

        Returns:
            (solutions, rewards, logprobs, obj_values) for the drawn slots, on ``device``.

        Raises:
            ValueError: if the buffer holds no trajectories yet.
        """
        if self.p == 0:
            raise ValueError("cannot sample from an empty Buffer")
        # Draw the indices with the CPU generator, then move them, so seeded runs
        # pick the same indices whatever the storage device.
        ids = torch.randint(self.p, size=(batch_size,), requires_grad=False).to(self.device)
        return (self.solutions[:, ids].to(device),
                self.rewards[:, ids].to(device),
                self.logprobs[:, ids].to(device),
                self.obj_values[ids].to(device))


Define local search

In [None]:
def get_best_neighbor(solution, graph):
    """Return the best single-flip neighbor of ``solution`` and its cut value.

    A neighbor differs from ``solution`` in exactly one node label (x -> 1 - x).
    All flips are scored in one vectorised pass over the adjacency matrix:
    flipping node i turns its currently cut edges into uncut ones and vice
    versa, so ``score_i = cut(solution) - cut_weight_i + uncut_weight_i``.
    This avoids one full cut recomputation (and adjacency rebuild) per node.

    Args:
        solution: list (or 1-D ndarray) of 0/1 node labels; must support ``.copy()``.
        graph: a networkx graph, or an already computed dense adjacency matrix.

    Returns:
        ``(neighbor, score)`` for the highest-scoring flip; ties go to the
        lowest node index. ``score`` is always the cut value of ``neighbor``,
        even when the flip does not improve on ``solution``; callers compare
        scores themselves. An empty solution yields ``(solution.copy(), 0)``.
    """
    num_nodes = len(solution)
    if num_nodes == 0:
        return solution.copy(), 0

    if isinstance(graph, np.ndarray):
        adj_matrix = graph
    else:
        adj_matrix = transfer_nxgraph_to_adjacencymatrix(graph)

    # Same convention as the cut objective: only adj[i, j] with i < j counts.
    # Symmetrise that strict upper triangle to get per-node edge weights.
    upper = np.triu(np.asarray(adj_matrix, dtype=float)[:num_nodes, :num_nodes], k=1)
    weights = upper + upper.T

    labels = np.asarray(solution)
    differs = labels[:, None] != labels[None, :]
    base_score = np.sum(upper * differs)

    cut_weight = np.sum(weights * differs, axis=1)
    uncut_weight = np.sum(weights * ~differs, axis=1)  # the diagonal of weights is 0
    flip_scores = base_score - cut_weight + uncut_weight

    best = int(np.argmax(flip_scores))  # first maximum wins ties
    neighbor = solution.copy()
    neighbor[best] = 1 - neighbor[best]
    return neighbor, float(flip_scores[best])




---
Initialize the temperature (lambda_l) and the step counts


In [None]:
# Steps per sampled trajectory, number of training epochs, and the initial
# temperature lambda_l (decayed each epoch by the training loop).
trajectory_length, num_epochs, lambda_l = 16, 10, 0.4

Initialize state

In [None]:
# len() of a networkx graph is its node count.
num_nodes = len(graph)




---
Initialize model and prepare for training


In [None]:
# initialize model hyperparameters
input_dim = num_nodes
hidden_dim = 32
output_dim = num_nodes  # Output one probability per node

buffer_size = 256
update_steps = 6
batch_size = 24
max_local_search_steps = 6

# Pick the compute device: CUDA first, then Apple MPS, else CPU.
# Always a torch.device, so downstream .to(device) calls see one consistent
# type. The getattr guard keeps torch builds without torch.backends.mps working.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = FCModel(input_dim, hidden_dim, output_dim).to(device)
optimizer = optim.Adam(model.parameters())

scores = []  # Used to track mean scores at each epoch

buffer = Buffer(buffer_size, trajectory_length, num_nodes, device)  # initialize buffer

Start training

In [None]:
# Training loop. Each epoch decays lambda_l, then keeps sampling trajectories
# (model-guided sampling followed by greedy local search) until update_steps
# optimizer steps have been taken.
for k in tqdm(range(num_epochs)):
    # Decrease lambda_l at each epoch.
    # NOTE(review): the factor multiplies the already-decayed value, so the decay
    # compounds. In the last epoch (k + 1 == num_epochs) the factor is 0, which
    # leaves lambda_l == 0. Confirm this schedule is intended.
    lambda_l = lambda_l * (1 - (k + 1) / num_epochs)
    # Update the model multiple times for each lambda value
    epoch_scores = []  # every post-local-search score sampled in this epoch

    # step only advances once the buffer is at least half full (see below), so
    # the first iterations just collect trajectories.
    step = 0
    while step < update_steps:
        # Random initialization of the start of new trajectory
        init_solution = [random.choice([0, 1]) for _ in range(num_nodes)]
        curr_solution = copy.deepcopy(init_solution)
        curr_score = obj_maxcut(curr_solution, graph)
        init_score = curr_score  # NOTE(review): not read anywhere in this cell

        cur_advantage = 0  # running sum of the trajectory's scores
        # NOTE(review): the log-prob accumulator starts at -num_nodes rather than
        # 0; the reason for this offset is not visible here. Confirm.
        cur_log_prob = -num_nodes

        trajectory_scores = []
        trajectory_log_probs = []
        trajectory = []

        # Sample one trajectory
        for t in range(trajectory_length):
            # Get output probability distribution from model.
            # model_input has shape (1, num_nodes).
            model_input = torch.tensor(curr_solution, dtype=torch.float32).unsqueeze(0).to(device)
            outputs = model(model_input)
            # NOTE(review): outputs has shape (1, output_dim), so squeeze(1) is a
            # no-op unless output_dim == 1; the batch dim is dropped by [0] below.
            outputs = outputs.squeeze(1)

            # Sample new state from outputs (one Bernoulli draw per output unit)
            m = distributions.Bernoulli(probs=outputs)

            new_state = m.sample()
            # calculate the log probability of this state, summed over nodes
            log_prob = m.log_prob(new_state)
            log_prob_sum = log_prob.sum()

            # Convert the sampled state to a plain Python list of 0.0/1.0 labels
            new_solution = new_state.cpu().tolist()[0]
            # calculate cut value of the new solution
            new_score = obj_maxcut(new_solution, graph)

            # Local search: at most max_local_search_steps greedy single-node
            # flips, stopping at the first one that does not improve the cut.
            #while 1:
            for ls_step in range(max_local_search_steps):
                neighbor, neighbor_score = get_best_neighbor(new_solution, graph)
                if neighbor_score > new_score:
                    new_solution = neighbor
                    new_score = neighbor_score
                else:
                    break

            epoch_scores.append(new_score)

            trajectory_scores.append(new_score)
            trajectory_log_probs.append(log_prob_sum)
            trajectory.append(new_solution)

            # Accumulate trajectory totals. Note that log_prob_sum belongs to the
            # sampled state, while new_score is measured after local search.
            cur_advantage = cur_advantage + new_score
            cur_log_prob = cur_log_prob + log_prob_sum


            # The locally improved solution becomes the next step's model input.
            curr_solution = new_solution
            curr_score = new_score

        # NOTE(review): .detach() drops the autograd graph, so obj_value (and
        # everything later read back from the buffer) carries no gradient.
        obj_value = -(cur_log_prob * (cur_advantage - lambda_l*cur_log_prob -lambda_l)).detach()
        # Store the new trajectory score and log_prob in the buffer
        buffer.update(torch.tensor(trajectory),
                        torch.tensor(trajectory_scores, dtype=torch.float32),
                        torch.tensor(trajectory_log_probs, dtype=torch.float32),
                        obj_value)


        # Optimizer steps only start once the buffer is at least half full.
        if buffer.p >= buffer_size/2:
            # Sample batch from buffer
            buffer_trajectory, buffer_trajectory_scores, buffer_trajector_log_probs, buffer_losses = buffer.sample(batch_size, device)


            # Compute loss
            # Use dummy to ensure loss has grad_fn
            # NOTE(review): BUG - buffer_losses are detached stored values and
            # dummy_var is not a model parameter. loss.backward() therefore leaves
            # every model.parameters() grad as None, and optimizer.step() does not
            # change the model: the policy is never trained. A working update needs
            # a loss built from log-probs recomputed by `model` (with grad).
            dummy_var = torch.tensor(1.0, requires_grad=True, device=device)
            loss = (buffer_losses.mean() * dummy_var).mean()
            stats_losses.append(loss.item())
            # Backpropagate loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1

    # Store mean scores of the epoch
    scores.append(np.mean(epoch_scores))


  0%|          | 0/10 [00:00<?, ?it/s]

# Test

Target score: 287

In [None]:
# NOTE: epoch_scores is reset at the start of every epoch, so this is the best
# post-local-search score from the final epoch only, not the best over training.
best_in_final_epoch = max(epoch_scores)
print(f"score: {best_in_final_epoch}")

In [None]:
# Plot the per-epoch mean score collected during training on the current axes.
axes = plt.gca()
axes.plot(scores)
axes.set_xlabel('Epoch')
axes.set_ylabel('Scores')
axes.set_title('Scores vs Epoch')
plt.show()