# Introduction:

# Imports

In [None]:
import copy
import json
import math
import os
import random
from typing import List, Union

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm  #progress bars for loops and iterables

# MaxCut Simulator

In [None]:
# read graph file, e.g., BarabasiAlbert_100_ID2, using networkx.Graph
def read_nxgraph(filename: str) -> nx.Graph:
    """Load an undirected, weighted graph from a whitespace-separated text file.

    Layout expected by the parser:
      * any line containing ``/`` is skipped (treated as a comment);
      * blank lines are skipped;
      * the first remaining line starts with the number of nodes
        (a following edge count, if present, is ignored);
      * every later line is ``<node1> <node2> <weight>`` with 1-based node ids.

    Args:
        filename: Path of the graph file.

    Returns:
        A graph whose nodes are ``0 .. num_nodes - 1`` and whose edges carry a
        numeric ``weight`` attribute.

    Raises:
        ValueError: If an edge line does not have exactly three fields, or a
            field cannot be parsed as a number.
    """
    graph = nx.Graph()
    header_seen = False
    with open(filename, 'r') as file:
        for line_number, line in enumerate(file, start=1):
            # split() without arguments tolerates tabs, repeated spaces and the trailing newline
            fields = line.split()
            if not fields or '/' in line:
                continue
            if not header_seen:
                graph.add_nodes_from(range(int(fields[0])))
                header_seen = True
                continue
            if len(fields) != 3:
                raise ValueError(
                    f"{filename}:{line_number}: expected 'node1 node2 weight', got {line!r}"
                )
            node1, node2, weight = fields
            # file ids are 1-based; store the weight as a number rather than the raw text
            graph.add_edge(int(node1) - 1, int(node2) - 1, weight=float(weight))
    return graph

# Helper that compares two batches of solutions and keeps, per entry, the better-scoring solution
def update_best(cur_solutions, cur_scores, new_solutions, new_scores):
    """Overwrite, in place, every current entry that is strictly beaten by its new counterpart.

    Both ``cur_solutions`` and ``cur_scores`` are modified; nothing is returned.
    Ties keep the current entry.
    """
    # The mask is computed once, before either tensor is touched
    improved = new_scores > cur_scores
    cur_scores[improved] = new_scores[improved]
    cur_solutions[improved] = new_solutions[improved]

# MaxCut Simulator

In [None]:
# Handles cut-value calculation, local search, and initial solution sampling
class MaxcutSimulator:
    """MaxCut environment: cut-value evaluation, batched random local search and rollout setup.

    A solution is a tensor of 0/1 values with one entry per node; the entry
    says which side of the cut the node is on.
    """

    def __init__(self, file_name, device=torch.device("cpu")):
        """Load the graph stored in ``file_name`` and cache its dense adjacency matrix on ``device``."""
        self.graph = read_nxgraph(file_name)
        self.num_nodes = self.graph.number_of_nodes()
        self.device = device
        # store the adjacency matrix
        self.adj_matrix = torch.tensor(nx.to_numpy_array(self.graph), dtype=torch.float32, device=device)

    def get_cut_values(self, solutions):
        """Return the cut value of every solution in a batch.

        Args:
            solutions: Tensor of shape (batch_size, num_nodes).

        Returns:
            Tensor of shape (batch_size,) with the summed weight of the edges
            whose endpoints carry different values.
        """
        # differs[b, i, j] is True when nodes i and j are on opposite sides in solution b
        differs = solutions.unsqueeze(2) != solutions.unsqueeze(1)
        # keep the upper triangle only so that each undirected edge is counted once
        return (differs * self.adj_matrix).triu().sum((1, 2))

    def batched_local_search(self, solutions, scores, max_steps, num_flips, num_neighbors):
        """Improve a batch of solutions with a randomised greedy local search.

        At each of ``max_steps`` steps, ``num_neighbors`` candidates are drawn
        per solution by flipping ``num_flips`` randomly chosen positions
        (positions are drawn with replacement, so a candidate may differ in
        fewer than ``num_flips`` bits). The best candidate replaces the
        current solution only when it scores strictly higher.

        Args:
            solutions: Tensor of shape (batch_size, num_nodes).
            scores: Tensor of shape (batch_size,) holding the cut values of ``solutions``.
            max_steps: Number of search steps.
            num_flips: Positions flipped per candidate.
            num_neighbors: Candidates evaluated per solution and step.

        Returns:
            ``(best_solutions, best_scores)``; the inputs are left untouched.
        """
        batch_size = solutions.shape[0]
        best_solutions = solutions.clone().to(self.device)
        best_scores = scores.clone().to(self.device)
        if num_neighbors <= 0:
            # nothing to compare against; avoids a max() over an empty dimension
            return best_solutions, best_scores

        num_candidates = batch_size * num_neighbors
        batch_idx = torch.arange(batch_size, device=self.device)
        row_idx = torch.arange(num_candidates, device=self.device).unsqueeze(1).expand(-1, num_flips)

        for _ in range(max_steps):
            # repeat() copies, so flipping the candidates never touches best_solutions
            candidates = best_solutions.unsqueeze(1)\
                                       .repeat(1, num_neighbors, 1)\
                                       .reshape(num_candidates, self.num_nodes)

            # random flip positions, shape (batch_size * num_neighbors, num_flips)
            flip_indexes = torch.randint(0, self.num_nodes,
                                         (num_candidates, num_flips),
                                         device=self.device)
            candidates[row_idx, flip_indexes] = 1 - candidates[row_idx, flip_indexes]

            candidate_scores = self.get_cut_values(candidates).view(batch_size, num_neighbors)
            # best candidate for every solution in the batch
            max_scores, max_indices = candidate_scores.max(dim=1)
            best_candidates = candidates.view(batch_size, num_neighbors, self.num_nodes)[batch_idx, max_indices, :]

            # move to the best candidate only where it beats the current solution
            update_best(best_solutions, best_scores, best_candidates, max_scores)

        return best_solutions, best_scores

    def initialize_trajectories(self, trajectory_length, trajectories_per_epoch):
        """Allocate the rollout tensors for one epoch and fill in step 0 with random solutions.

        Only index 0 along the first axis is initialised; the remaining steps
        hold uninitialised memory until the caller writes them.

        Returns:
            Tuple ``(trajectory_solutions, trajectory_scores, trajectory_log_probs,
            trajectory_advantages, ls_trajectory_solutions)``, all on ``self.device``.
            Solution tensors have shape (trajectory_length, trajectories_per_epoch, num_nodes),
            the others (trajectory_length, trajectories_per_epoch).
        """
        per_step = (trajectory_length, trajectories_per_epoch)
        per_node = per_step + (self.num_nodes,)
        start_shape = (trajectories_per_epoch, self.num_nodes)

        trajectory_solutions = torch.empty(per_node, dtype=torch.float32, device=self.device)
        trajectory_scores = torch.empty(per_step, dtype=torch.float32, device=self.device)
        trajectory_log_probs = torch.empty(per_step, dtype=torch.float32, device=self.device)
        trajectory_advantages = torch.empty(per_step, dtype=torch.float32, device=self.device)
        ls_trajectory_solutions = torch.empty(per_node, dtype=torch.float32, device=self.device)

        # Random initialization of the start of new trajectories
        trajectory_solutions[0, :, :] = torch.randint(2, start_shape, dtype=torch.float32).to(self.device)
        trajectory_scores[0, :] = self.get_cut_values(trajectory_solutions[0])
        # A uniformly random bit-string of length n has natural-log probability n * log(0.5),
        # the same base that torch.distributions uses for log_prob
        trajectory_log_probs[0, :] = self.num_nodes * math.log(0.5)
        trajectory_advantages[0, :] = 0.0
        # drawn independently of trajectory_solutions[0]
        ls_trajectory_solutions[0, :, :] = torch.randint(2, start_shape, dtype=torch.float32).to(self.device)

        return trajectory_solutions, trajectory_scores, trajectory_log_probs, trajectory_advantages, ls_trajectory_solutions
        


# Policy Gradient NN

In [None]:
# Fully connected network: three stacked linear layers followed by a sigmoid that outputs probabilities
class FCPolicy(nn.Module):
    """Feed-forward policy: three stacked linear layers followed by a sigmoid.

    NOTE(review): there is no activation between the linear layers, so the
    stack composes to a single affine map before the sigmoid. Confirm this is
    intended.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, x, mode="sample"):
        """Return the sigmoid outputs for ``x``; ``mode`` is accepted but not used."""
        return self.model(x)
    
# Solution-level RNN model
class RNNPolicy_Solution(nn.Module):
    """Recurrent policy that reads a sequence of solutions and emits probabilities.

    The input is time-major: (trajectory_length, batch_size, input_dim).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.rnn = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, batch_first=False)
        self.fc = nn.Linear(hidden_dim, input_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mode='sample'):
        """Run the RNN over ``x`` and map hidden states to probabilities.

        With ``mode == 'sample'`` only the last time step is used and the
        result has shape (batch_size, input_dim). With any other mode every
        time step is kept and the result has shape
        (trajectory_length, batch_size, input_dim).
        """
        hidden_states, _ = self.rnn(x)  # (trajectory_length, batch_size, hidden_dim)
        if mode == 'sample':
            # sampling the next solution only needs the most recent hidden state
            hidden_states = hidden_states[-1]
        return self.sigmoid(self.fc(hidden_states))

# Replay Buffer

In [None]:
# Buffer implementation
class Buffer:
    """Fixed-capacity store of trajectories.

    Storage is time-major: ``solutions`` has shape
    (trajectory_length, buffer_size, num_nodes) and ``scores``, ``log_probs``
    and ``advantages`` have shape (trajectory_length, buffer_size). Only the
    first ``p`` slots along the buffer axis hold data; the rest is
    uninitialised memory.
    """

    def __init__(self, buffer_size, trajectory_length, num_nodes, device=torch.device("cpu")):
        """Allocate empty storage for ``buffer_size`` trajectories on ``device``."""
        per_step = (trajectory_length, buffer_size)
        self.solutions = torch.empty(per_step + (num_nodes,), dtype=torch.float32, device=device)
        self.scores = torch.empty(per_step, dtype=torch.float32, device=device)
        self.log_probs = torch.empty(per_step, dtype=torch.float32, device=device)
        self.advantages = torch.empty(per_step, dtype=torch.float32, device=device)

        self.p = 0  # number of filled slots
        self.buffer_size = buffer_size
        self.device = device
        self.num_nodes = num_nodes

    def _store(self, slots, solutions, scores, log_probs, advantages):
        """Write a batch of trajectories into the buffer slots selected by ``slots``."""
        self.solutions[:, slots, :] = solutions
        self.scores[:, slots] = scores
        self.log_probs[:, slots] = log_probs
        self.advantages[:, slots] = advantages

    def update(self, solutions, scores, log_probs, advantages):
        """Add a batch of trajectories.

        Free slots are filled first. Whatever does not fit replaces the stored
        trajectories with the lowest score sums (only trajectories that were
        already in the buffer before this call are candidates for eviction).

        Args:
            solutions: Tensor of shape (trajectory_length, add_size, num_nodes).
            scores, log_probs, advantages: Tensors of shape (trajectory_length, add_size).

        Raises:
            ValueError: If the batch holds more trajectories than the buffer can store.
        """
        add_size = scores.shape[1]
        if add_size > self.buffer_size:
            raise ValueError(
                f"cannot add {add_size} trajectories to a buffer of size {self.buffer_size}"
            )

        solutions = solutions.to(self.device)
        scores = scores.to(self.device)
        log_probs = log_probs.to(self.device)
        advantages = advantages.to(self.device)

        filled = self.p
        # a batch that exactly fills the buffer is appended, not routed to replacement
        num_appended = min(self.buffer_size - filled, add_size)
        if num_appended > 0:
            self._store(slice(filled, filled + num_appended),
                        solutions[:, :num_appended],
                        scores[:, :num_appended],
                        log_probs[:, :num_appended],
                        advantages[:, :num_appended])
            self.p = filled + num_appended

        num_replaced = add_size - num_appended
        if num_replaced > 0:
            # score-sum of every trajectory that was stored before this call
            trajectory_scores = self.scores[:, :filled].sum(dim=0)
            _, ids = torch.topk(trajectory_scores, k=num_replaced, largest=False)
            self._store(ids,
                        solutions[:, num_appended:],
                        scores[:, num_appended:],
                        log_probs[:, num_appended:],
                        advantages[:, num_appended:])

    def sample(self, batch_size, device):
        """Draw a batch: the top ``batch_size // 4`` trajectories by score sum plus random others.

        The random part is drawn without replacement from the trajectories
        that are not in the top group, so the batch is smaller than
        ``batch_size`` when the buffer does not hold enough trajectories.

        Returns:
            ``(solutions, scores, log_probs, advantages)`` moved to ``device``.
        """
        # never ask topk for more trajectories than are stored
        top_k = min(batch_size // 4, self.p)
        trajectory_scores = self.scores[:, :self.p].sum(dim=0)
        _, top_ids = torch.topk(trajectory_scores, k=top_k, largest=True)

        # Exclude the top ids so that no trajectory is sampled twice
        all_ids = torch.arange(self.p, device=top_ids.device)
        remaining_ids = all_ids[~torch.isin(all_ids, top_ids)]
        random_k = batch_size - top_k

        random_ids = remaining_ids[
            torch.randperm(len(remaining_ids))[:random_k]
        ]
        ids = torch.cat([top_ids, random_ids], dim=0)

        return (self.solutions[:, ids, :].to(device),
                self.scores[:, ids].to(device),
                self.log_probs[:, ids].to(device),
                self.advantages[:, ids].to(device))

    def get_top_solutions(self, batch_size, device):
        """Return the highest-scoring individual solutions stored in the buffer.

        Only the filled part of the buffer is considered. At most
        ``batch_size`` solutions are returned, fewer if fewer are stored.

        Returns:
            Tensor of shape (k, num_nodes) on ``device``, best first.
        """
        # restrict to filled slots: the rest of the storage is uninitialised
        flat_scores = self.scores[:, :self.p].flatten()
        flat_solutions = self.solutions[:, :self.p, :].reshape(-1, self.num_nodes)
        k = min(batch_size, flat_scores.numel())
        _, ids = torch.topk(flat_scores, k=k, largest=True)
        return flat_solutions[ids, :].to(device)

# Training Function

In [None]:
# cl_stages = [(ls_num_flips, ls_max_steps, ls_num_neighbors, lambda_l)]
def train(maxcut_simulator, actor_model, actor_optimizer, num_epochs = 300, update_steps = 6,
          buffer_size = 32, batch_size = 32, trajectory_length = 32, trajectories_per_epoch = 32,
          lambda_l = 0.3, model_choice='FC', gamma = 0.99, ls_num_flips=6, ls_num_neighbors = 3,
          ls_max_steps = 3, device = torch.device("cpu")):
    """Train a policy on MaxCut with rollouts, local search and a replay buffer.

    Every epoch rolls out ``trajectories_per_epoch`` trajectories of
    ``trajectory_length`` solutions, optionally improves each sampled solution
    with local search, stores the trajectories in the buffer and then runs
    ``update_steps`` gradient steps on batches sampled from the buffer.

    Args:
        maxcut_simulator: Provides ``num_nodes``, ``initialize_trajectories``,
            ``get_cut_values`` and ``batched_local_search``.
        actor_model: Policy called as ``actor_model(x)`` while sampling and
            ``actor_model(x, mode='update')`` while updating.
        actor_optimizer: Optimizer over the policy parameters.
        num_epochs: Number of rollout/update rounds.
        update_steps: Gradient steps per epoch.
        buffer_size: Capacity of the replay buffer, in trajectories.
        batch_size: Trajectories per sampled batch.
        trajectory_length: Solutions per trajectory (including the random start).
        trajectories_per_epoch: Trajectories rolled out per epoch.
        lambda_l: Weight of the extra log-probability term in the loss.
        model_choice: ``"RNN"`` feeds the whole trajectory prefix to the
            policy; any other value feeds only the previous solution.
        gamma: Currently unused.
        ls_num_flips, ls_num_neighbors, ls_max_steps: Local-search settings;
            local search is skipped when ``ls_num_flips * ls_max_steps`` is 0.
        device: Device the sampled batches are moved to before the update.

    Returns:
        Tuple ``(best_score, best_solution, best_epoch, scores, solutions,
        epoch_scores, actor_stats_losses, critic_stats_losses)``.
        ``scores``/``solutions`` hold one nested list per epoch,
        ``epoch_scores`` the flat scores of the last epoch, and
        ``critic_stats_losses`` is always empty (no critic is trained here).
    """
    num_nodes = maxcut_simulator.num_nodes
    buffer = Buffer(buffer_size, trajectory_length, num_nodes)

    best_score = 0
    best_solution = []
    best_epoch = 0
    scores = []  # per-epoch trajectory scores
    solutions = []  # per-epoch trajectory solutions
    epoch_scores = []  # defined up front so the return value exists even when num_epochs == 0
    actor_stats_losses = []
    critic_stats_losses = []

    use_local_search = ls_num_flips * ls_max_steps > 0

    for k in tqdm(range(num_epochs)):
        epoch_scores = []

        (
            trajectory_solutions,
            trajectory_scores,
            trajectory_log_probs,
            trajectory_advantages,
            _,
        ) = maxcut_simulator.initialize_trajectories(
            trajectory_length,
            trajectories_per_epoch
        )

        # Roll out a batch of trajectories; no gradients are needed while sampling
        with torch.no_grad():
            for t in range(1, trajectory_length):
                if model_choice == "RNN":
                    outputs = actor_model(trajectory_solutions[:t])
                else:
                    outputs = actor_model(trajectory_solutions[t-1])

                m = torch.distributions.Bernoulli(probs=outputs)
                new_solutions = m.sample()
                # log-probability of the sampled solution (taken before local search changes it)
                log_probs = m.log_prob(new_solutions).sum(dim=1)
                new_scores = maxcut_simulator.get_cut_values(new_solutions)

                if use_local_search:
                    new_solutions, new_scores = maxcut_simulator.batched_local_search(
                        new_solutions,
                        new_scores,
                        ls_max_steps,
                        ls_num_flips,
                        ls_num_neighbors)

                best_idx = torch.argmax(new_scores)
                step_best = new_scores[best_idx].item()
                if step_best > best_score:
                    best_score = step_best
                    best_solution = new_solutions[best_idx].tolist()
                    best_epoch = k

                trajectory_solutions[t, :, :] = new_solutions
                trajectory_scores[t, :] = new_scores
                trajectory_log_probs[t, :] = log_probs
                # advantage of a step is its score improvement over the previous step
                trajectory_advantages[t, :] = new_scores - trajectory_scores[t-1]

                epoch_scores += new_scores.tolist()

        scores.append(trajectory_scores.tolist())
        solutions.append(trajectory_solutions.tolist())

        buffer.update(trajectory_solutions,
                      trajectory_scores,
                      trajectory_log_probs,
                      trajectory_advantages)

        # Update only once the buffer holds enough trajectories; a trajectory of
        # length 1 has no policy-generated step to learn from
        ready = buffer.p >= batch_size*update_steps//2
        if ready and update_steps > 0 and trajectory_length > 1:
            actor_losses = []
            for _ in range(update_steps):
                batch_trajectories, _, _, batch_advantages = buffer.sample(batch_size, device)

                # During rollout, solution t is sampled from the policy output for
                # solution(s) up to t-1, so inputs [:-1] are scored against targets [1:].
                # The random start solution is not produced by the policy and has no target.
                outputs = actor_model(batch_trajectories[:-1], mode='update')
                m = torch.distributions.Bernoulli(probs=outputs)
                solution_log_probs = m.log_prob(batch_trajectories[1:]).sum(dim=-1)

                # batch means of the per-trajectory sums
                trajectory_log_prob = solution_log_probs.sum(0).mean()
                trajectory_advantage = batch_advantages.sum(0).mean()

                # NOTE(review): this multiplies the batch-mean log-prob by the batch-mean
                # advantage instead of averaging per-trajectory products, and the lambda_l
                # term pushes log-probability up. Confirm both are intended.
                actor_loss = -(trajectory_log_prob * trajectory_advantage) - lambda_l * trajectory_log_prob

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()
                actor_losses.append(actor_loss.item())

            actor_stats_losses.append(np.mean(actor_losses))

    return best_score, best_solution, best_epoch, scores, solutions, epoch_scores, actor_stats_losses, critic_stats_losses

In [None]:
# General Hyperparameters
num_epochs = 600
update_steps = 4
trajectory_length = 32
trajectories_per_epoch = 32
buffer_size = 1024
batch_size = 32
lambda_l = 0.3
model_choice = 'FC' # 'FC' or 'RNN'

# Variables used for graph fetching and saving
input_graph = "../data/syn_BA/BA_100_ID0.txt"
output_location = "../results/BA_100_ID0"
params = ''


# Load data, initialize the simulator and hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
graph_instance = input_graph.split('/')[-1].split('.txt')[0]
maxcut_simulator = MaxcutSimulator(input_graph, device=device)
num_nodes = maxcut_simulator.num_nodes


# Initialize model: the policy maps a solution (one value per node) to one probability per node
input_dim = num_nodes
output_dim = num_nodes
hidden_dim = trajectories_per_epoch

if model_choice == 'FC':
    actor_model = FCPolicy(input_dim, hidden_dim, output_dim).to(device)
elif model_choice == 'RNN':
    actor_model = RNNPolicy_Solution(input_dim, hidden_dim).to(device)
else:
    # fail here instead of with a NameError on actor_model further down
    raise ValueError(f"Unknown model_choice {model_choice!r}; expected 'FC' or 'RNN'")
actor_optimizer = torch.optim.Adam(actor_model.parameters())


(
    best_score,
    best_solution,
    best_epoch,
    scores,
    solutions,
    epoch_scores,
    actor_stats_losses,
    critic_stats_losses,
) = train(maxcut_simulator, actor_model, actor_optimizer, num_epochs=num_epochs, update_steps=update_steps,
          buffer_size=buffer_size, batch_size=batch_size, trajectory_length=trajectory_length,
          trajectories_per_epoch=trajectories_per_epoch, lambda_l=lambda_l, model_choice=model_choice, device=device)

print(f"\nlast epoch best score: {max(epoch_scores)}, overall best score: {best_score} at epoch {best_epoch}")

In [None]:
# Stack the per-epoch score lists into one tensor (one entry per epoch along axis 0)
scores = torch.tensor(scores)
# Take the last step of every trajectory, then average over the trajectories of each epoch
final_step_scores = scores[:, -1]
mean_scores = final_step_scores.mean(dim=-1).tolist()


In [None]:
# Plot cut-values
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
#smooth_scores = savgol_filter(scores, window_length=1, polyorder=0)
# The label gives plt.legend() below an artist to list (it warns when there is none)
plt.plot(mean_scores, label='Mean final-step score')

# Output directory for saved figures; exist_ok avoids the check-then-create race
graph_dir = output_location + f"/{model_choice}_images"
os.makedirs(graph_dir, exist_ok=True)

plt.xlabel('Epoch')
plt.ylabel('Average Score')
plt.title(f'{model_choice}_training_results')
plt.legend()
plt.show()

In [None]:
# Plot losses
import numpy as np
# Start a new figure so the loss curve is not drawn on top of an earlier plot
# when the backend keeps the previous figure current after show()
plt.figure()
# One point per epoch in which an update ran
plt.plot(actor_stats_losses)
plt.xlabel('Epoch')
plt.ylabel('Average losses')
plt.title(f'{model_choice}: Actor Losses vs Epoch')
plt.savefig(f'{graph_dir}/Actor_Losses_Epoch.png')
plt.show()