# Optimized Hybrid Genetic Search

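This notebook walks through setting up PyVRP's Hybrid Genetic Search (HGS) with a custom selective route exchange (SREX) crossover that uses a SREXGNN model. It also runs the standard HGS with PyVRP's default SREX crossover for reference.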

## Imports

### Model imports

In [1]:
import torch
from data.utils.GraphData import FullGraph, ParentGraph
from Models import SREXmodel

### HGS imports

In [2]:
from pathlib import Path
from IPython.display import display
import pickle

import vrplib
import pickle
from typing import Optional
from pyvrp import (
    Solution,
    PenaltyManager,
    PopulationParams,
    Population,
    ProblemData,
    Result,
    RandomNumberGenerator,
    Statistics,
    plotting,
    read,
    GeneticAlgorithmParams,
    GeneticAlgorithm
)
from implementation.customHGS.CustomAlgorithm import GeneticAlgorithm as GnnAlgo
from implementation.customHGS.CustomSrex import selective_route_exchange as Gnnsrex
from implementation.customHGS.SolutionTransformer import SolutionTransformer


from pyvrp.diversity import broken_pairs_distance as bpd
from pyvrp.crossover import selective_route_exchange as srex
from pyvrp.search import (
    NODE_OPERATORS,
    ROUTE_OPERATORS,
    LocalSearch,
    NeighbourhoodParams,
    compute_neighbours,
)
from pyvrp.stop import MaxIterations, MaxRuntime, StoppingCriterion


## Configure HGS

In [3]:
def solve_GNN(
    instance_name: str,
    model: SREXmodel,
    seed: int,
    max_runtime: Optional[float] = None,
    max_iterations: Optional[int] = None,
    data_dir: str = "C:/SREX_GNN/data/routes",
    **kwargs,
):
    """Solve a VRP instance with HGS using the GNN-based SREX crossover.

    Builds a PyVRP population, a granular-neighbourhood local search and a
    random initial population. Also builds the full instance graph (via
    ``SolutionTransformer``). It then runs the custom genetic algorithm
    ``GnnAlgo`` with ``Gnnsrex`` as crossover operator and ``model`` as the
    GNN.

    Args:
        instance_name: File stem of the instance; ``{data_dir}/{instance_name}.vrp``
            is read with ``round_func="round"``.
        model: Trained SREXmodel handed to the genetic algorithm.
        seed: Seed for PyVRP's ``RandomNumberGenerator``.
        max_runtime: Budget passed to ``MaxRuntime``. Takes precedence over
            ``max_iterations`` when both are given.
        max_iterations: Budget passed to ``MaxIterations``. Used when
            ``max_runtime`` is None.
        data_dir: Directory that holds the ``.vrp`` instance files.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        Whatever ``algo.run(stop)`` returns.

    Raises:
        ValueError: If neither ``max_runtime`` nor ``max_iterations`` is given.
    """
    # Check the budget before doing any expensive work. An ``assert`` would be
    # stripped under ``python -O`` and only fired after the graph was built.
    if max_runtime is None and max_iterations is None:
        raise ValueError("Either max_runtime or max_iterations must be given.")

    data = read(f"{data_dir}/{instance_name}.vrp", round_func="round")
    solution_transformer = SolutionTransformer()
    rng = RandomNumberGenerator(seed=seed)
    pen_manager = PenaltyManager()

    pop_params = PopulationParams()
    pop = Population(bpd, params=pop_params)

    # Granular neighbourhood of 20 clients per client for the local search.
    nb_params = NeighbourhoodParams(nb_granular=20)
    neighbours = compute_neighbours(data, nb_params)
    ls = LocalSearch(data, rng, neighbours)

    for op in NODE_OPERATORS:
        ls.add_node_operator(op(data))

    for op in ROUTE_OPERATORS:
        ls.add_route_operator(op(data))

    # Graph of the whole instance, built once per solve and passed to GnnAlgo.
    edge_index, edge_weight, client_features = solution_transformer(
        instance=data, get_full_graph=True
    )
    full_graph = FullGraph(edge_index, edge_weight, client_features)

    init = [
        Solution.make_random(data, rng)
        for _ in range(pop_params.min_pop_size)
    ]
    ga_params = GeneticAlgorithmParams()
    algo = GnnAlgo(
        data, model, full_graph, pen_manager, rng, pop, ls, Gnnsrex,
        initial_solutions=init, params=ga_params,
    )

    # Build the stop criterion right before running, as before. A runtime
    # budget wins when both budgets are supplied.
    if max_runtime is not None:
        stop = MaxRuntime(max_runtime)
    else:
        stop = MaxIterations(max_iterations)

    return algo.run(stop)

In [4]:
def solve(
    instance_name: str,
    seed: int,
    max_runtime: Optional[float] = None,
    max_iterations: Optional[int] = None,
    data_dir: str = "C:/SREX_GNN/data/routes",
    **kwargs,
):
    """Solve a VRP instance with PyVRP's standard HGS (baseline, no GNN).

    Builds a PyVRP population, a granular-neighbourhood local search and a
    random initial population. It then runs PyVRP's ``GeneticAlgorithm``
    with the stock ``selective_route_exchange`` crossover.

    Args:
        instance_name: File stem of the instance; ``{data_dir}/{instance_name}.vrp``
            is read with ``round_func="round"``.
        seed: Seed for PyVRP's ``RandomNumberGenerator``.
        max_runtime: Budget passed to ``MaxRuntime``. Takes precedence over
            ``max_iterations`` when both are given.
        max_iterations: Budget passed to ``MaxIterations``. Used when
            ``max_runtime`` is None.
        data_dir: Directory that holds the ``.vrp`` instance files.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        Whatever ``algo.run(stop)`` returns.

    Raises:
        ValueError: If neither ``max_runtime`` nor ``max_iterations`` is given.
    """
    # Check the budget before doing any work. An ``assert`` would be stripped
    # under ``python -O``.
    if max_runtime is None and max_iterations is None:
        raise ValueError("Either max_runtime or max_iterations must be given.")

    data = read(f"{data_dir}/{instance_name}.vrp", round_func="round")

    rng = RandomNumberGenerator(seed=seed)
    pen_manager = PenaltyManager()

    pop_params = PopulationParams()
    pop = Population(bpd, params=pop_params)

    # Granular neighbourhood of 20 clients per client for the local search.
    nb_params = NeighbourhoodParams(nb_granular=20)
    neighbours = compute_neighbours(data, nb_params)
    ls = LocalSearch(data, rng, neighbours)

    for op in NODE_OPERATORS:
        ls.add_node_operator(op(data))

    for op in ROUTE_OPERATORS:
        ls.add_route_operator(op(data))

    init = [
        Solution.make_random(data, rng)
        for _ in range(pop_params.min_pop_size)
    ]
    ga_params = GeneticAlgorithmParams()
    algo = GeneticAlgorithm(
        data, pen_manager, rng, pop, ls, srex,
        initial_solutions=init, params=ga_params,
    )

    # Build the stop criterion right before running, as before. A runtime
    # budget wins when both budgets are supplied.
    if max_runtime is not None:
        stop = MaxRuntime(max_runtime)
    else:
        stop = MaxIterations(max_iterations)

    return algo.run(stop)

## Load model

Restore the trained SREXGNN checkpoint onto the CPU.

In [5]:
# Load the trained SREX-GNN checkpoint onto the CPU.
# The raw string keeps the Windows backslashes literal. The plain literal only
# worked because "\S", "\d" and "\m" are invalid escapes, which Python 3.12+
# reports with a SyntaxWarning. The resulting path value is unchanged.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model_dict = torch.load(r'C:\SREX_GNN\data\model_data\model_states\SrexGNN_2_0_0.36_0.32', map_location=torch.device('cpu'))
model_dict.keys()

dict_keys(['model_state', 'Metrics_train', 'Metrics_test'])

In [6]:
# Rebuild the architecture. These hyper-parameters must match the checkpoint,
# otherwise the (strict) load_state_dict call below fails.
model = SREXmodel(num_node_features=8, hidden_dim=64, num_heads=8)
load_result = model.load_state_dict(state_dict=model_dict['model_state'])
# The model is only used for inference inside the GA. Switch any
# dropout/batch-norm layers to evaluation behaviour.
# NOTE(review): confirm GnnAlgo does not rely on training-mode behaviour.
model.eval()
load_result

<All keys matched successfully>

## Run HGS

In [7]:
# GNN-guided HGS on an ORTEC instance: fixed seed, 300-iteration budget.
instance_name = "ORTEC-n510-k23"
result = solve_GNN(
    instance_name,
    seed=22,
    model=model,
    max_iterations=300,
)

23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
5
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
7
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
6
tensor(1) tensor(0) tensor(0)
3
tensor(1) tensor(0) tensor(0)
4
tensor(1) tensor(0) tensor(0)
4
tensor(1) tensor(0) tensor(0)
8
tensor(1) tensor(0) tensor(0)
4
tensor(1) tensor(0) tensor(0)
6
tensor(1) tensor(0) tensor(0)
4
tensor(1) tensor(0) tensor(0)
7
tensor(1) tensor(0) tensor(0)
3
tensor(1) tensor(0) tensor(0)
7
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) tensor(0)
5
tensor(1) tensor(0) tensor(0)
7
tensor(1) tensor(0) tensor(0)
7
tensor(1) tensor(0) tensor(0)
23
tensor(1) tensor(0) ten

KeyboardInterrupt: 

In [None]:
# Baseline: standard PyVRP HGS (no GNN), fixed seed, 300-iteration budget.
instance_name = "ORTEC-VRPTW-ASYM-0bdff870-d1-n458-k35"
result = solve(
    instance_name,
    max_iterations=300,
    seed=22,
)