# Optimized Hybrid Genetic Search

This notebook shows how to set up PyVRP's Hybrid Genetic Search (HGS) with a SREX-GNN model guiding the crossover, and compares the result against plain HGS on a single instance.

In [1]:
import os

# Project root containing the `data/`, `Models` and `implementation/` code
# imported below; all relative paths in this notebook resolve against it.
# Override with the SREX_GNN_ROOT environment variable when running outside
# the original cluster account (the default is the original hard-coded path).
PROJECT_ROOT = os.environ.get("SREX_GNN_ROOT", "/gpfs/home5/suijen/SREX_GNN")
os.chdir(PROJECT_ROOT)
os.getcwd()

'/gpfs/home5/suijen/SREX_GNN'

## Imports

### Model imports

In [2]:
import torch
from data.utils.GraphData import FullGraph, ParentGraph
from Models import SREXmodel
from torch_geometric.transforms import AddLaplacianEigenvectorPE
from data.utils.Normalize import normalize_graphs

In [3]:
# Load the trained SREX-GNN checkpoint. The constructor arguments below must
# match the architecture the checkpoint was trained with.
model_dict = torch.load('data/model_data/model_states/SrexGNN_7_103_0.80_0.29_0.27')
model = SREXmodel(num_node_features=11, hidden_dim=8, num_heads=8)
load_report = model.load_state_dict(state_dict=model_dict['model_state'])
# The model is not trained in this notebook, only handed to the search, so put
# dropout / batch-norm layers (if the architecture has any) into eval mode.
model.eval()
# Show the key-matching report (e.g. "<All keys matched successfully>").
load_report

<All keys matched successfully>

### HGS imports

In [4]:
from pathlib import Path
from IPython.display import display
import pickle
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import vrplib
import pickle
from typing import Optional
from pyvrp import (
    Solution,
    PenaltyManager,
    PopulationParams,
    Population,
    ProblemData,
    Result,
    RandomNumberGenerator,
    CostEvaluator,
    Statistics,
    plotting,
    read,
    GeneticAlgorithmParams,
    GeneticAlgorithm
)
from implementation.customHGS.CustomAlgorithm import GeneticAlgorithm as GnnAlgo
from implementation.customHGS.CustomSrex import selective_route_exchange as Gnnsrex
from implementation.customHGS.SolutionTransformer import SolutionTransformer


from pyvrp.diversity import broken_pairs_distance as bpd
from pyvrp.crossover import selective_route_exchange as srex
from pyvrp.search import (
    NODE_OPERATORS,
    ROUTE_OPERATORS,
    LocalSearch,
    NeighbourhoodParams,
    compute_neighbours,
)
from pyvrp.stop import MaxIterations, MaxRuntime, StoppingCriterion
from pyvrp.plotting import plot_result

## Configure HGS

In [5]:
def solve_GNN(
    instance_name: str,
    model: SREXmodel,
    seed: int,
    max_runtime: Optional[float] = None,
    max_iterations: Optional[int] = None,
    ls_seed: int = 42,
    **kwargs,
):
    """Run HGS on one instance with the GNN-guided SREX crossover.

    Parameters
    ----------
    instance_name
        Instance read from ``data/routes/{instance_name}.vrp``.
    model
        Trained ``SREXmodel`` handed to the custom genetic algorithm.
    seed
        Seed for the RNG used for the random initial solutions and passed to
        the genetic algorithm.
    max_runtime
        Budget for ``MaxRuntime``; takes precedence over ``max_iterations``
        when both are given.
    max_iterations
        Budget for ``MaxIterations``; used only when ``max_runtime`` is None.
    ls_seed
        Seed for the local search's own RNG. Defaults to 42, the value that
        was previously hard-coded, so existing calls behave the same.
    **kwargs
        Accepted for call compatibility; ignored.

    Returns
    -------
    Whatever ``GnnAlgo.run`` returns (presumably a PyVRP ``Result``, since
    callers read ``.stats`` from it — TODO confirm).

    Raises
    ------
    ValueError
        If neither ``max_runtime`` nor ``max_iterations`` is given. Checked
        up front (instead of an ``assert``, which ``python -O`` strips) so a
        bad call fails before the graph construction below.
    """
    if max_runtime is None and max_iterations is None:
        raise ValueError("either max_runtime or max_iterations must be given")

    data = read(f"data/routes/{instance_name}.vrp", round_func="round")
    rng = RandomNumberGenerator(seed=seed)
    pen_manager = PenaltyManager()

    pop_params = PopulationParams()
    pop = Population(bpd, params=pop_params)

    # Local search over a granular neighbourhood, with every node and route
    # operator PyVRP ships.
    neighbours = compute_neighbours(data, NeighbourhoodParams(nb_granular=20))
    ls = LocalSearch(data, RandomNumberGenerator(seed=ls_seed), neighbours)
    for op in NODE_OPERATORS:
        ls.add_node_operator(op(data))
    for op in ROUTE_OPERATORS:
        ls.add_route_operator(op(data))

    # Full instance graph for the GNN: built by SolutionTransformer,
    # normalized, then extended with 6 Laplacian-eigenvector positional
    # encodings (attr_name=None; presumably appended to the node features —
    # verify against the installed torch_geometric version).
    solution_transform = SolutionTransformer()
    full_graph = FullGraph(*solution_transform(instance=data,
                                               get_full_graph=True))
    full_graph = normalize_graphs(full_graph)
    positional_encoding = AddLaplacianEigenvectorPE(6, attr_name=None,
                                                    is_undirected=True)
    full_graph = positional_encoding(full_graph)

    # Fixed penalty weights for the cost evaluator handed to the GNN
    # algorithm (meaning of 200 and 6 is not documented here — TODO confirm).
    model_cost_evaluator = CostEvaluator(200, 6)

    init = [
        Solution.make_random(data, rng)
        for _ in range(pop_params.min_pop_size)
    ]
    algo = GnnAlgo(
        data, model, model_cost_evaluator, full_graph, pen_manager, rng, pop,
        ls, Gnnsrex, initial_solutions=init, params=GeneticAlgorithmParams(),
    )

    # Built last, right before running, exactly as before.
    if max_runtime is not None:
        stop = MaxRuntime(max_runtime)
    else:
        stop = MaxIterations(max_iterations)

    return algo.run(stop)

In [6]:
def solve(
    instance_name: str,
    seed: int,
    max_runtime: Optional[float] = None,
    max_iterations: Optional[int] = None,
    ls_seed: int = 42,
    **kwargs,
):
    """Run baseline PyVRP HGS with the standard SREX crossover.

    Parameters
    ----------
    instance_name
        Instance read from ``data/routes/{instance_name}.vrp``.
    seed
        Seed for the RNG used for the random initial solutions and passed to
        the genetic algorithm.
    max_runtime
        Budget for ``MaxRuntime``; takes precedence over ``max_iterations``
        when both are given.
    max_iterations
        Budget for ``MaxIterations``; used only when ``max_runtime`` is None.
    ls_seed
        Seed for the local search's own RNG. Defaults to 42, the value that
        was previously hard-coded, so existing calls behave the same.
    **kwargs
        Accepted for call compatibility; ignored.

    Returns
    -------
    Result
        The result of ``GeneticAlgorithm.run``.

    Raises
    ------
    ValueError
        If neither ``max_runtime`` nor ``max_iterations`` is given. Checked
        up front (instead of an ``assert``, which ``python -O`` strips) so a
        bad call fails before any work is done.
    """
    if max_runtime is None and max_iterations is None:
        raise ValueError("either max_runtime or max_iterations must be given")

    data = read(f"data/routes/{instance_name}.vrp", round_func="round")
    rng = RandomNumberGenerator(seed=seed)
    pen_manager = PenaltyManager()

    pop_params = PopulationParams()
    pop = Population(bpd, params=pop_params)

    # Local search over a granular neighbourhood, with every node and route
    # operator PyVRP ships.
    neighbours = compute_neighbours(data, NeighbourhoodParams(nb_granular=20))
    ls = LocalSearch(data, RandomNumberGenerator(seed=ls_seed), neighbours)
    for op in NODE_OPERATORS:
        ls.add_node_operator(op(data))
    for op in ROUTE_OPERATORS:
        ls.add_route_operator(op(data))

    init = [
        Solution.make_random(data, rng)
        for _ in range(pop_params.min_pop_size)
    ]
    algo = GeneticAlgorithm(
        data, pen_manager, rng, pop, ls, srex,
        initial_solutions=init, params=GeneticAlgorithmParams(),
    )

    # Built last, right before running, exactly as before.
    if max_runtime is not None:
        stop = MaxRuntime(max_runtime)
    else:
        stop = MaxIterations(max_iterations)

    return algo.run(stop)

In [7]:
def plot_improvements(result, instance):
    """Plot how much the best feasible cost improved at each iteration.

    The improvement at step ``i`` is the previous best feasible cost minus
    the current one, taken from ``result.stats.feas_stats``. Two panels are
    drawn side by side:

    * left: everything after the first real (finite, non-zero) improvement,
      which is relatively large and would otherwise cloud the image;
    * right: the second half of the run, where smaller improvements occur.

    Any step where the best feasible cost went *up* is printed as
    ``previous current`` for inspection.

    Parameters
    ----------
    result
        Result of ``solve`` / ``solve_GNN``; only ``result.stats.feas_stats``
        is read.
    instance
        Unused; kept so existing ``plot_improvements(result, data)`` calls
        keep working.

    Raises
    ------
    ValueError
        If ``result`` carries no feasible-population statistics.
    """
    feas_stats = result.stats.feas_stats
    if not feas_stats:
        raise ValueError("result has no feasible-population statistics to plot")

    improvements = []
    previous = feas_stats[0].best_cost
    for datum in feas_stats:
        delta = previous - datum.best_cost
        if delta < 0:
            # The best feasible cost should not get worse; report it.
            print(previous, datum.best_cost)
        improvements.append(delta)
        previous = datum.best_cost

    # One x value per collected statistic, so x and y always line up.
    iterations = 1 + np.arange(len(improvements))

    # Skip past the first real improvement. NaN deltas (which can presumably
    # occur while no feasible solution exists yet) do not count. If nothing
    # ever improved, plot from the start instead of crashing on None + 1.
    first = next(
        (i for i, delta in enumerate(improvements)
         if np.isfinite(delta) and delta != 0),
        None,
    )
    start = 0 if first is None else first + 1

    # Smaller improvements are found later on: second half of the run.
    half = round(len(improvements) / 2)

    # Single row of two panels (the old 2x2 grid left the bottom row empty).
    fig, (ax_all, ax_late) = plt.subplots(1, 2, figsize=(15, 5))
    fig.suptitle('Cost improvement', fontsize=20)

    ax_all.plot(iterations[start:], improvements[start:])
    ax_all.set_title('After the first improvement')
    ax_late.plot(iterations[half:], improvements[half:])
    ax_late.set_title('Second half of the run')

    for ax in (ax_all, ax_late):
        ax.set_ylabel("Cost Improvement")
        ax.set_xlabel('Iterations (#)')

    fig.tight_layout()
    # Deliberately return None: returning `fig` would make Jupyter render the
    # figure a second time when this call is a cell's last expression.

## Run HGS

In [8]:
# Instance solved by both runs below; `data` is loaded here for the plotting
# cells (the solve functions read the instance themselves).
instance_name = "X-n393-k38"
data = read("data/routes/{}.vrp".format(instance_name), round_func="round")

In [24]:
# GNN-guided run; same seed and runtime budget as the baseline run below.
result_mod_36 = solve_GNN(
    instance_name=instance_name,
    model=model,
    seed=10,
    max_runtime=300,
)

In [None]:
# Print the result's str() summary (print uses str(), not the bare repr).
print(str(result_mod_36))

In [None]:
plot_improvements(result=result_mod_36, instance=data)

In [None]:
# Solution / search overview for the GNN run; tighten the current figure.
plot_result(result_mod_36, data)
plt.gcf().tight_layout()

In [28]:
# Baseline HGS run; same seed and runtime budget as the GNN run above.
result_36 = solve(
    instance_name=instance_name,
    seed=10,
    max_runtime=300,
)

In [None]:
plot_improvements(result=result_36, instance=data)

In [None]:
# Print the result's str() summary (print uses str(), not the bare repr).
print(str(result_36))

In [None]:
# Solution / search overview for the baseline run; tighten the current figure.
plot_result(result_36, data)
plt.gcf().tight_layout()