In [None]:
import sys
import os

sys.path.insert(0, os.path.abspath('.'))

import random
from typing import Tuple
import torch
import torch.nn.functional as F

from nco_lib.environment.actions import two_opt, bit_flip, insert, swap
from nco_lib.environment.env import State, Env, ConstructiveStoppingCriteria, ConstructiveReward, ImprovementReward, ImprovementStoppingCriteria
from nco_lib.environment.problem_def import ConstructiveProblem, ImprovementProblem
from nco_lib.models.graph_transformer import GTModel, EdgeInGTModel, EdgeInOutGTModel
from nco_lib.data.data_loader import generate_random_graph
from nco_lib.trainer.trainer import ConstructiveTrainer, ImprovementTrainer

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Simple Tutorial
We will show how to use the library to solve combinatorial optimization problems using reinforcement learning.

First, we define the problem by inheriting from the ConstructiveProblem or ImprovementProblem class. We need to implement the following methods:
- **_init_instances**: Initialize the problem instances (graphs).
- **_init_solutions**: Initialize the solutions (to empty solutions if we are using a constructive method or to a complete solution if we are using an improvement method).
- **_init_features**: Initialize the node- and edge-features.
- **_init_mask** (optional): Initialize the mask used to forbid certain actions. Only needed if the problem constraints require it.
- **_obj_function**: Compute the objective function value.
- **_update_features**: Update the node (and, if used, edge) features based on the selected action.
- **_update_solutions**: Update the solutions based on the selected action.
- **_update_mask** (optional): Update the mask based on the action selected in the previous step.
- **_check_completeness**: Check if the solution is complete. This function is required for constructive problems to check if the solution, and therefore the episode, is completed.

The State class is used to store the problem instances, solutions, features, mask, and other useful information such as the device and batch size. The user can add any other data that is required for the problem definition in the state.data dictionary.

### Constructive Problem
The following image illustrates the pipeline of a constructive problem:
- First the instances, solutions, features, and mask are initialized and fed to the model in the state. 
- Then, the model predicts the actions based on the state. 
- The actions are used to update the solutions, features, and mask. 
- The objective function is computed based on the updated solutions and is used to compute the reward given to the model in order to update its weights. 
- The process is repeated until the episode is completed.

<p style="text-align: center;">
    <img src="../docs/constructive_pipeline.png" alt="Constructive Pipeline" title="Constructive Pipeline"/>
</p>

Here is an example of a dummy constructive problem:

In [None]:
# 1) Define the constructive problem
class SimpleConstructiveProblem(ConstructiveProblem):
    """Dummy constructive problem showing where each user-defined hook goes.

    The tensors built here are placeholders; the point is to show which part of the
    state (instances, solutions, features, mask, completeness) each method sets up.
    """

    def _init_instances(self, state: State) -> State:
        # Initialize the problem instances (graphs) and any other data required for computing
        # the objective value, calculating the mask or checking completeness.
        # The state already carries the batch size, the problem size, the random seed and the
        # device, so they can be used here to build the instances.

        # The adjacency matrix can be created randomly with generate_random_graph
        state.adj_matrix = generate_random_graph(state.batch_size, state.problem_size, state.seed,
                                                 edge_prob=0.15, device=state.device)

        # Any other useful information can be stored in the state.data dictionary.
        # Shape: (batch_size, problem_size, problem_size)
        state.data['useful_info'] = torch.rand(state.batch_size, state.problem_size, state.problem_size,
                                               device=state.device)

        # Return the state, do not change this
        return state

    def _init_solutions(self, state: State) -> State:
        # In constructive problems the solutions start empty (all zeros, nothing assigned yet).
        # Shape: (batch_size, problem_size)
        state.solutions = torch.zeros((state.batch_size, state.problem_size), device=state.device)
        return state

    def _init_features(self, state: State) -> State:
        # Initialize the node features; they can be computed from state.data, state.solutions and
        # state.adj_matrix. Dummy example.
        # state.solutions is (batch_size, problem_size): it needs a trailing axis so that it
        # broadcasts per node. Without it, broadcasting aligns it against the wrong dimension and
        # the addition fails whenever batch_size != problem_size.
        state.node_features = (state.data['useful_info']
                               + state.adj_matrix.sum(2, keepdim=True)
                               + state.solutions.unsqueeze(-1))
        return state

    def _init_mask(self, state: State) -> State:
        # Initialize the mask. This is optional, only if the problem requires masking certain actions.
        # Otherwise, you can return state unchanged or use the following code, which initializes
        # the mask to zeros (nothing masked):
        state.mask = torch.zeros((state.batch_size, state.problem_size, 1), device=state.device)
        return state

    def _obj_function(self, state: State) -> torch.Tensor:
        # Compute the objective function. This is used to compute the reward for the model.
        # Here you can also use state.data, state.solutions and state.adj_matrix.
        # This is the only method that does not return a state: it returns the objective
        # value of every instance in the batch as a tensor.
        return state.solutions.sum(1)

    def _update_features(self, state: State, action: "torch.Tensor | None" = None) -> State:
        # Update the node features. If this is equal to the _init_features method, you can call it
        # directly by returning self._init_features(state).
        # `action` is optional and unused here; it is accepted so the signature matches the
        # problems below whose feature update depends on the selected action.
        return self._init_features(state)

    def _update_solutions(self, state: State, action: torch.Tensor) -> State:
        # Update the solutions using the action predicted by the model.
        # In node-based constructive problems, the action is the index of the node to select, so the
        # action is a tensor of shape (batch_size, n_classes), where n_classes is the number of output classes.
        # In edge-based constructive problems, the action is the index of the edge to select, so the
        # action is a tensor of shape (batch_size, n_classes, 2), where n_classes is the number of output classes.
        state.solutions = state.solutions + action
        return state

    def _update_mask(self, state: State, action: torch.Tensor) -> State:
        # Update the mask. In constructive problems each node is selected just once, so the
        # selected nodes can be masked. Remember to mask with -inf.
        batch_range = torch.arange(state.batch_size, device=state.device)
        state.mask[batch_range, action, :] = float('-inf')
        return state

    def _check_completeness(self, state: State) -> State:
        # Check whether the solution is complete. This is required for constructive problems to
        # know when the solution, and therefore the episode, is finished.
        # Here: complete once no entry in the whole batch is still 0 (unassigned).
        # In improvement methods, the solution is complete after every step, so you can set
        # state.is_complete = True.
        state.is_complete = (state.solutions == 0).sum() == 0
        # state.is_complete = True  # Uncomment this line if the solution is complete in every step.
        return state

### Example for the Traveling Salesman Problem (TSP), training a constructive model
1) Define the TSP constructive problem
2) Define the environment, the model, and the trainer
3) Run training and inference

In [None]:
# 1) Define the TSP constructive problem
class TSPConstructiveProblem(ConstructiveProblem):
    """TSP solved constructively: each step appends one unvisited city to every partial tour.

    Solutions have shape (batch_size, pomo_size, n_visited) and grow by one city per step.
    Node features (last dim = 6): coordinates (2), selected one-hot (2), first-city flag (1),
    last-city flag (1).
    """

    def _init_instances(self, state: State) -> State:
        """
        Here the user can define the generation of the instances for the TSP problem.
        These instances will be later used to define the node/edge features.
        The state class supports the adjacency matrix as a default data field.
        If any other data is needed, it should be stored in the state.data dictionary as done with the coordinates.
        The instances will generally have a shape of (batch_size, problem_size, problem_size, n) or (batch_size, problem_size, n), where n is the number of features.
        """
        if state.seed is not None:
            torch.manual_seed(state.seed)

        # Generate the city coordinates as user-defined data, shape (batch_size, problem_size, 2)
        state.data['coords'] = torch.rand(state.batch_size, state.problem_size, 2, device=state.device)

        # One could also define the Euclidean distances to later be used as edge features
        return state

    def _init_solutions(self, state: State) -> State:
        """
        Here the user can define the initialization of the solutions for the TSP problem.
        In constructive methods, the solutions are generally initialized empty.
        However, for the TSP, we can select a random city as the starting point.
        In fact, the user can initialize multiple constructions per instance (with POMO).
        The solution will generally have a shape of (batch_size, pomo_size, 0~problem_size).
        """
        # Random starting city in [0, problem_size) for every construction, shape (batch_size, pomo_size, 1)
        state.solutions = torch.randint(0, state.problem_size, (state.batch_size, state.pomo_size, 1),
                                        device=state.device)
        return state

    def _init_features(self, state: State) -> State:
        """
        Here the user can define the initialization of the node features for the TSP problem.

        For the TSP, we will use the coordinates, whether the city is selected or not, and whether it is the first or last city in a one-hot encoding.
        In this case, the node features will have a shape of (batch_size, pomo_size, problem_size, n), where n is the number of features.
        """
        # Broadcastable index grids: (batch, 1, 1) and (1, pomo, 1) pair with the visited-city
        # indices of shape (batch, pomo, n_visited). The index length comes from the solutions
        # themselves, so this works for any number of visited cities.
        batch_idx = torch.arange(state.batch_size, device=state.device)[:, None, None]
        pomo_idx = torch.arange(state.pomo_size, device=state.device)[None, :, None]
        visited = state.solutions.long()

        # Coordinates of every city, repeated for each POMO construction
        pomo_coords = state.data['coords'].unsqueeze(1).expand(state.batch_size, state.pomo_size,
                                                               state.problem_size, 2)

        # One-hot encoding for the selected cities: selected (0, 1) and non-selected (1, 0)
        selected = torch.zeros(state.batch_size, state.pomo_size, state.problem_size, 2, device=state.device)
        selected[..., 0] = 1
        selected[batch_idx, pomo_idx, visited, 0] = 0
        selected[batch_idx, pomo_idx, visited, 1] = 1

        # Flag for the first city of each construction
        first_selected = torch.zeros(state.batch_size, state.pomo_size, state.problem_size, 1,
                                     device=state.device)
        first_selected[batch_idx, pomo_idx, visited[:, :, :1], 0] = 1

        # Flag for the last (most recently visited) city of each construction
        last_selected = torch.zeros(state.batch_size, state.pomo_size, state.problem_size, 1,
                                    device=state.device)
        last_selected[batch_idx, pomo_idx, visited[:, :, -1:], 0] = 1

        # Feature channels: 0-1 coords, 2-3 selected one-hot, 4 first city, 5 last city
        state.node_features = torch.cat([pomo_coords, selected, first_selected, last_selected], dim=-1)
        return state

    def _init_mask(self, state: State) -> State:
        """
        Here the user can define the initialization of the mask.
        In the TSP, the mask will be used to prevent the model from selecting the same city multiple times.
        Therefore, we need to mask the selected cities for each construction.
        """
        # Index grids for the batch and POMO dimensions, both of shape (batch_size, pomo_size)
        batch_range = torch.arange(state.batch_size, device=state.device)[:, None].expand(state.batch_size, state.pomo_size)
        pomo_range = torch.arange(state.pomo_size, device=state.device)[None, :].expand(state.batch_size, state.pomo_size)

        # Starting city of each construction, shape (batch_size, pomo_size)
        action = state.solutions.squeeze(2)

        # Start with nothing masked, then forbid the starting cities
        state.mask = torch.zeros((state.batch_size, state.pomo_size, state.problem_size, 1), device=state.device)
        state.mask[batch_range, pomo_range, action, :] = float('-inf')

        return state

    def _obj_function(self, state: State) -> torch.Tensor:
        """
        In this function, the user needs to define the objective function for the TSP problem.
        This function is called only once the solution is completed.
        Returns the negative tour length, shape (batch_size, pomo_size).
        """
        gathering_index = state.solutions.unsqueeze(3).expand(state.batch_size, -1, state.problem_size, 2)
        # shape: (batch, pomo, problem, 2)

        seq_expanded = state.data['coords'][:, None, :, :].expand(state.batch_size, state.pomo_size, state.problem_size, 2)
        # shape: (batch, pomo, problem, 2)

        ordered_seq = seq_expanded.gather(dim=2, index=gathering_index)
        # shape: (batch, pomo, problem, 2)

        # Rolling by one pairs every city with the next one, closing the tour back to the start
        rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
        segment_lengths = ((ordered_seq - rolled_seq) ** 2).sum(3).sqrt()
        # shape: (batch, pomo, problem)

        travel_distances = segment_lengths.sum(2)

        return -travel_distances  # minimize the total distance -> maximize the negative distance

    def _update_features(self, state: State, action: Tuple[torch.Tensor, torch.Tensor]) -> State:
        """
        This function is used to define how to update the node/edge features based on the new partial solutions.
        Only the "selected" one-hot (channels 2-3) and the "last city" flag (channel 5) change.
        """
        # Only the selected node is used; shape (batch_size, pomo_size)
        action = action[0]

        # Index grids (batch, 1) and (1, pomo) broadcast against the (batch, pomo) action
        batch_idx = torch.arange(state.batch_size, device=state.device)[:, None]
        pomo_idx = torch.arange(state.pomo_size, device=state.device)[None, :]

        # Work on a copy so the previous features are not modified in place
        node_features = state.node_features.clone()

        # Selected (0, 1) and non-selected (1, 0): switch the chosen city to (0, 1)
        node_features[batch_idx, pomo_idx, action, 2] = 0
        node_features[batch_idx, pomo_idx, action, 3] = 1

        # Move the "last visited" flag to the newly selected city
        node_features[..., 5] = 0
        node_features[batch_idx, pomo_idx, action, 5] = 1

        state.node_features = node_features
        return state

    def _update_solutions(self, state: State, action: Tuple[torch.Tensor, torch.Tensor]) -> State:
        """
        This function is used to update the solutions based on the selected actions.
        Actions are given in a tuple format, where the first part is the selected node and the second part is the selected class.
        In the case of the TSP, we only need the selected node; this is equivalent to having a single class.
        Therefore, only the first part of the action tuple is used.
        """
        # There is only one class (selected node) in the TSP, so only take the first part of the action tuple
        action = action[0]

        # Append the selected city to the solution
        state.solutions = torch.cat([state.solutions, action.unsqueeze(2)], dim=2)
        # state.solutions.shape: (batch_size, pomo_size, 0~problem_size)
        return state

    def _update_mask(self, state: State, action: Tuple[torch.Tensor, torch.Tensor]) -> State:
        """
        This function is used to update the mask based on the selected actions (cities).
        """
        # There is only one class (selected node) in the TSP, so only take the first part of the action tuple
        action = action[0]

        # Index grids for the batch and POMO dimensions, both of shape (batch_size, pomo_size)
        batch_range = torch.arange(state.batch_size, device=state.device)[:, None].expand(state.batch_size, state.pomo_size)
        pomo_range = torch.arange(state.pomo_size, device=state.device)[None, :].expand(state.batch_size, state.pomo_size)

        # Forbid the city that was just visited
        state.mask[batch_range, pomo_range, action, :] = float('-inf')
        return state

    def _check_completeness(self, state: State) -> State:
        """
        This function is used to check if the solution is complete.
        """
        # Solution is complete once every city has been visited
        state.is_complete = (state.solutions.size(2) == state.problem_size)
        return state


In [None]:
# 2) Define the environment, the model, and the trainer
tsp_problem = TSPConstructiveProblem(device=device)

# Environment for the TSP (permutation) in constructive mode
tsp_env = Env(problem=tsp_problem,
              reward=ConstructiveReward(),
              stopping_criteria=ConstructiveStoppingCriteria(),
              device=device)

# Model over the 6 node features built by TSPConstructiveProblem:
# 2D coordinates (2) + selected one-hot (2) + first-city flag (1) + last-city flag (1)
tsp_model = GTModel(decoder='attention', node_in_dim=6, aux_node=True, logit_clipping=10.0).to(device)

# Define the RL training algorithm
tsp_trainer = ConstructiveTrainer(model=tsp_model,
                                  env=tsp_env,
                                  optimizer=torch.optim.Adam(tsp_model.parameters(), lr=5e-4),
                                  device=device)

In [None]:
# 3) Run training and inference for the Traveling Salesman Problem
# Training: 10 epochs x 10 episodes on 20-city instances, batches of 64, mean baseline
train_results = tsp_trainer.train(
    epochs=10,
    episodes=10,
    problem_size=20,
    batch_size=64,
    pomo_size=1,
    eval_problem_size=20,
    eval_batch_size=256,
    baseline_type='mean',
    verbose=True,
)

# Inference: deterministic decoding on 100 seeded 20-city instances
tsp_trainer.inference(
    problem_size=20,
    batch_size=100,
    pomo_size=1,
    deterministic=True,
    seed=42,
    verbose=True,
)

### Example for the Traveling Salesman Problem (TSP), training an improvement model
1) Define the TSP improvement problem
2) Define the environment, the model, and the trainer
3) Run training and inference

In [None]:
# 1) Define the TSP improvement problem
class TSPImprovementProblem(ImprovementProblem):
    """TSP solved by improvement: start from random tours and repeatedly apply a local move.

    Solutions are complete tours of shape (batch_size, pomo_size, problem_size).
    """

    def _init_instances(self, state: State) -> State:
        """
        Here the user can define the generation of the instances for the TSP problem.
        These instances will be later used to define the node/edge features.
        The state class supports the adjacency matrix as a default data field.
        If any other data is needed, it should be stored in the state.data dictionary as done with the coordinates.
        The instances will generally have a shape of (batch_size, problem_size, problem_size, n) or (batch_size, problem_size, n), where n is the number of features.
        """
        if state.seed is not None:
            torch.manual_seed(state.seed)

        # Generate the city coordinates as user-defined data, shape (batch_size, problem_size, 2)
        state.data['coords'] = torch.rand(state.batch_size, state.problem_size, 2, device=state.device)

        # Pairwise Euclidean distances, shape (batch_size, problem_size, problem_size), used as edge features
        state.data['distances'] = torch.cdist(state.data['coords'], state.data['coords'])

        return state

    def _init_solutions(self, state: State) -> State:
        """
        Here the user can define the initialization of the solutions for the TSP problem.
        In improvement methods, the solutions are generally initialized as complete solutions.
        The user can initialize multiple solutions per instance (with POMO).
        The solution will generally have a shape of (batch_size, pomo_size, problem_size).
        """
        # Set random seed if defined
        if state.seed is not None:
            torch.manual_seed(state.seed)

        # Initialize the solutions as random permutations
        state.solutions = torch.stack(
            [torch.randperm(state.problem_size, device=state.device) for _ in range(state.batch_size * state.pomo_size)])

        # Reshape the solutions to have the POMO dimension
        state.solutions = state.solutions.view(state.batch_size, state.pomo_size, state.problem_size)

        return state

    def _init_features(self, state: State) -> State:
        """
        Here the user can define the initialization of the node features for the TSP problem.

        For the improvement method, the init_features and update_features will be the same, so we call it from here
        """
        # The feature update ignores the action, so an empty placeholder is passed
        action = (torch.empty(0), torch.empty(0))
        return self._update_features(state, action)

    def _init_mask(self, state: State) -> State:
        """
        Here the user can define the initialization of the mask.
        In the improvement method for the TSP, we will mask the diagonal elements to avoid self-loops.
        """
        # Mask the diagonal (i, i) pairs
        mask = torch.zeros((state.batch_size, state.pomo_size, state.problem_size, state.problem_size, 1),
                           device=state.device)
        row_indices = torch.arange(state.problem_size, device=state.device)
        mask[:, :, row_indices, row_indices, :] = -float('inf')
        # Flatten the (i, j) pairs: (batch_size, pomo_size, problem_size^2, 1)
        state.mask = mask.reshape(state.batch_size, state.pomo_size, -1, 1)

        return state

    def _obj_function(self, state: State) -> torch.Tensor:
        """
        In this function, the user needs to define the objective function for the TSP problem.
        This function is called every improvement step.
        Returns the negative tour length, shape (batch_size, pomo_size).
        """
        gathering_index = state.solutions.unsqueeze(3).expand(state.batch_size, state.pomo_size, state.problem_size, 2)
        # shape: (batch, pomo, problem, 2)

        seq_expanded = state.data['coords'][:, None, :, :].expand(state.batch_size, state.pomo_size, state.problem_size, 2)
        # shape: (batch, pomo, problem, 2)

        ordered_seq = seq_expanded.gather(dim=2, index=gathering_index)
        # shape: (batch, pomo, problem, 2)

        # Rolling by one pairs every city with the next one, closing the tour back to the start
        rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
        segment_lengths = ((ordered_seq - rolled_seq) ** 2).sum(3).sqrt()
        # shape: (batch, pomo, problem)

        travel_distances = segment_lengths.sum(2)

        return -travel_distances  # minimize the total distance -> maximize the negative distance

    def _update_features(self, state: State, action: Tuple[torch.Tensor, torch.Tensor]) -> State:
        """
        This function is used to rebuild the node/edge features from the current (complete) solutions.

        Node features: the 2D coordinates, shape (batch_size, pomo_size, problem_size, 2).
        Edge features: the pairwise distance plus a one-hot flag telling whether the directed edge
        (i -> j) is part of the tour, shape (batch_size, pomo_size, problem_size, problem_size, 3).
        The action is not used: the features depend only on the current solutions.
        """
        flat_size = state.batch_size * state.pomo_size
        # One row index per (instance, POMO) tour
        batch_pomo_range = torch.arange(flat_size, device=state.device)

        # Use the 2D coordinates as node features
        state.node_features = state.data['coords'].unsqueeze(1).expand(state.batch_size, state.pomo_size,
                                                                       state.problem_size, 2)

        # Adjacency of the tours: edge_solutions[k, i, j] = 1 when city j follows city i in tour k
        edge_solutions = torch.zeros(flat_size, state.problem_size, state.problem_size,
                                     dtype=torch.float32, device=state.device)
        # reshape (not view) so that non-contiguous solutions produced by a move are also accepted
        solutions = state.solutions.reshape(-1, state.problem_size)  # shape: (batch_size*pomo_size, problem_size)
        next_cities = solutions.roll(shifts=-1, dims=1)  # successor of each city, wrapping to the start
        edge_solutions[batch_pomo_range.unsqueeze(-1), solutions, next_cities] = 1

        # Edges are kept directed (city -> next city). For symmetric edge features, add
        # edge_solutions.permute(0, 2, 1) here.

        # One-hot encoding of the edge solutions
        edge_solutions = F.one_hot(edge_solutions.long(), 2).float()

        # Restore the POMO dimension
        edge_solutions = edge_solutions.view(state.batch_size, state.pomo_size, state.problem_size,
                                             state.problem_size, 2)

        # Use the distances as edge features, concatenated with the tour-edge one-hot
        distances = state.data['distances'].unsqueeze(1).expand(state.batch_size, state.pomo_size,
                                                                state.problem_size, state.problem_size)
        state.edge_features = torch.cat([distances.unsqueeze(-1), edge_solutions], dim=-1)
        return state

    def _update_solutions(self, state: State, action: Tuple[torch.Tensor, torch.Tensor]) -> State:
        """
        This function is used to update the solutions based on the selected actions.
        Actions are given in a tuple format, where the first part is the selected node and the second part is the selected class.
        In the case of the TSP, we only need the selected pair of nodes (edge); this is equivalent to having a single class.
        Therefore, only the first part of the action tuple is used.
        """
        action = action[0]

        # Apply an insertion move to each tour. two_opt and swap from nco_lib.environment.actions
        # can be used here instead to explore a different neighbourhood.
        state.solutions = insert(state.solutions, action)

        return state

    def _update_mask(self, state: State, action: Tuple[torch.Tensor, torch.Tensor]) -> State:
        """
        This function is used to update the mask based on the selected actions (cities).
        """
        # The mask is static (only the diagonal elements are masked), so no update is needed
        return state

    def _check_completeness(self, state: State) -> State:
        """
        This function is used to check if the solution is complete.
        """
        # In improvement problems, the solution is always complete
        state.is_complete = True
        return state


In [None]:
# 2) Define the environment, the model, and the trainer
tsp_problem = TSPImprovementProblem(device=device)

# Environment for the TSP (permutation) in improvement mode: at most 200 improvement steps,
# with patience=5
tsp_env = Env(problem=tsp_problem,
              reward=ImprovementReward(positive_only=False, normalize=True),
              stopping_criteria=ImprovementStoppingCriteria(max_steps=200, patience=5),
              device=device)

# Model over 2 node features (2D coordinates) and 3 edge features
# (distance + one-hot flag of whether the edge is in the current tour)
tsp_model = EdgeInOutGTModel(decoder='edge', node_in_dim=2, edge_in_dim=3, edge_out_dim=1, aux_node=False,
                             logit_clipping=10.0).to(device)

# Define the RL training algorithm
tsp_trainer = ImprovementTrainer(model=tsp_model,
                                 env=tsp_env,
                                 optimizer=torch.optim.Adam(tsp_model.parameters(), lr=5e-4),
                                 device=device)

In [None]:
# 3) Run training and inference for the Traveling Salesman Problem
# Training: 10 epochs x 100 episodes on 20-city instances, batches of 64, mean baseline
train_results = tsp_trainer.train(
    epochs=10,
    episodes=100,
    problem_size=20,
    batch_size=64,
    pomo_size=1,
    eval_problem_size=20,
    eval_batch_size=256,
    baseline_type='mean',
    verbose=True,
)

# Inference: deterministic action selection on 100 seeded 20-city instances
tsp_trainer.inference(
    problem_size=20,
    batch_size=100,
    pomo_size=1,
    deterministic=True,
    seed=42,
    verbose=True,
)

### Example for the Maximum Cut problem (MC), training a constructive model
1) Define the MC constructive problem
2) Define the environment, the model, and the trainer
3) Run training and inference

In [None]:
# 1) Define the MC constructive problem
class MCConstructiveProblem(ConstructiveProblem):
    """Maximum Cut solved constructively: each step assigns one node to one of two partitions.

    Solution encoding, shape (batch_size, problem_size): 0 = unassigned, 1 = first partition,
    2 = second partition.
    """

    def _init_instances(self, state: State) -> State:
        state.adj_matrix = generate_random_graph(state.batch_size, state.problem_size, state.seed,
                                                 edge_prob=0.15, device=state.device)
        return state

    def _init_solutions(self, state: State) -> State:
        # Every node starts unassigned (0)
        state.solutions = torch.zeros((state.batch_size, state.problem_size), device=state.device)
        return state

    def _init_features(self, state: State) -> State:
        # Node features: one-hot of the three solution states (two classes and one for unassigned)
        state.node_features = F.one_hot(state.solutions.long(), 3).float()
        # Use the adjacency matrix as edge features
        state.edge_features = state.adj_matrix.unsqueeze(-1)
        return state

    def _init_mask(self, state: State) -> State:
        # One entry per (node, class) pair, nothing masked initially
        state.mask = torch.zeros((state.batch_size, state.problem_size, 2), device=state.device)
        return state

    def _obj_function(self, state: State) -> torch.Tensor:
        """Return the cut value of every instance, shape (batch_size,).

        Uses the Ising form cut = 1/4 * sum_ij A_ij * (1 - s_i * s_j), with partition 1 -> s = -1
        and partition 2 -> s = +1. Computed for the whole batch at once.
        """
        ising_solutions = state.solutions.clone()
        ising_solutions[ising_solutions == 1] = -1
        ising_solutions[ising_solutions == 2] = 1
        # spin_products[b, i, j] = s_i * s_j (batched outer product)
        spin_products = ising_solutions.unsqueeze(2) * ising_solutions.unsqueeze(1)
        return 0.25 * (state.adj_matrix * (1 - spin_products)).sum(dim=(1, 2))

    def _update_features(self, state: State, action: "torch.Tensor | None" = None) -> State:
        # Recompute the one-hot of the three solution states (two classes and one for unassigned).
        # `action` is optional and unused: the features depend only on the solutions.
        state.node_features = F.one_hot(state.solutions.long(), 3).float()
        return state

    def _update_solutions(self, state: State, action: torch.Tensor) -> State:
        # The flat action index encodes node * 2 + class
        classes = action % 2
        nodes = action // 2
        batch_range = torch.arange(state.batch_size, device=state.device)
        # Store class 0/1 as 1/2 so that 0 keeps meaning "unassigned"
        state.solutions[batch_range, nodes] = classes.float() + 1
        return state

    def _update_mask(self, state: State, action: torch.Tensor) -> State:
        # Once a node is assigned, both of its classes are forbidden
        nodes = action // 2
        batch_range = torch.arange(state.batch_size, device=state.device)
        state.mask[batch_range, nodes, :] = float('-inf')
        return state

    def _check_completeness(self, state: State) -> State:
        # Complete once no node in the whole batch is still unassigned
        state.is_complete = (state.solutions == 0).sum() == 0
        return state


In [None]:
# 2) Define the environment, the model, and the trainer
mc_problem = MCConstructiveProblem(device=device)

# Environment for the Maximum Cut problem in constructive mode
mc_env = Env(
    problem=mc_problem,
    reward=ConstructiveReward(),
    stopping_criteria=ConstructiveStoppingCriteria(),
    device=device,
)

# Graph transformer reading the 3-state one-hot solution encoding as node features and the
# adjacency matrix as a single edge feature
mc_model = EdgeInGTModel(node_in_dim=3, node_out_dim=2, edge_in_dim=1).to(device)

# RL trainer, using Adam with its default learning rate
mc_trainer = ConstructiveTrainer(
    model=mc_model,
    env=mc_env,
    optimizer=torch.optim.Adam(mc_model.parameters()),
    device=device,
)

In [None]:
# 3) Run training and inference for the Maximum Cut Problem (MC)
mc_trainer.train(
    epochs=10,
    episodes=10,
    problem_size=20,
    batch_size=32,
    save_freq=10,
    verbose=True,
)
mc_trainer.inference(
    problem_size=20,
    batch_size=100,
    deterministic=True,
    seed=42,
    verbose=True,
)

### Example for the Maximum Independent Set problem (MIS), training an improvement model
1) Define the MIS improvement problem
2) Define the environment, the model, and the trainer
3) Run training and inference

In [None]:
# 1) Define the MIS improvement problem
class MISImprovementProblem(ImprovementProblem):
    """Maximum Independent Set solved by improvement: each step flips one node in or out of the set.

    Solutions are 0/1 tensors of shape (batch_size, problem_size); 1 means the node is in the set.
    """

    def _init_instances(self, state: State) -> State:
        state.adj_matrix = generate_random_graph(state.batch_size, state.problem_size, state.seed,
                                                 edge_prob=0.15, device=state.device)
        return state

    def _init_solutions(self, state: State) -> State:
        # Greedy random independent set per instance: repeatedly pick a random available node,
        # add it if none of its neighbours is in the set, and discard it and its neighbours.
        if state.seed is not None:
            random.seed(state.seed)
        solutions = torch.zeros(state.batch_size, state.problem_size, device=state.device)

        # (row, col) pairs of the non-zero adjacency entries, one tensor per graph
        neighbors = [torch.nonzero(state.adj_matrix[b], as_tuple=False) for b in range(state.batch_size)]
        for b in range(state.batch_size):
            available_nodes = set(range(state.problem_size))
            node_neighbors = neighbors[b]
            while available_nodes:
                node = random.sample(list(available_nodes), 1)[0]
                # Vectorized check for the independent-set condition
                if not torch.any((state.adj_matrix[b, node] == 1) & (solutions[b] == 1)):
                    solutions[b, node] = 1
                    # Remove the node and its neighbours from the candidates
                    neighbor_nodes = node_neighbors[node_neighbors[:, 0] == node][:, 1]
                    available_nodes -= {node, *neighbor_nodes.tolist()}
                else:
                    available_nodes.remove(node)

        state.solutions = solutions
        return state

    def _init_features(self, state: State) -> State:
        # Node features: one-hot of whether each node is in the set
        state.node_features = F.one_hot(state.solutions.long(), 2).float()
        # Use the adjacency matrix as edge features
        state.edge_features = state.adj_matrix.unsqueeze(-1)
        return state

    def _init_mask(self, state: State) -> State:
        # The mask depends only on the current solutions, so reuse the update with no action
        return self._update_mask(state, None)

    def _obj_function(self, state: State) -> torch.Tensor:
        # Size of the set
        return state.solutions.sum(1).float()

    def _update_features(self, state: State, action: "torch.Tensor | None" = None) -> State:
        # Recompute the one-hot of whether each node is in the set.
        # `action` is optional and unused: the features depend only on the solutions.
        state.node_features = F.one_hot(state.solutions.long(), 2).float()
        return state

    def _update_solutions(self, state: State, action: torch.Tensor) -> State:
        state.solutions = bit_flip(state.solutions, action)
        return state

    def _update_mask(self, state: State, action: "torch.Tensor | None") -> State:
        # Adjacency-weighted count of in-set neighbours of every node, shape (batch_size, problem_size).
        # Read from state.adj_matrix (set in _init_instances), so this does not depend on
        # _init_features having populated state.edge_features first.
        adjacent_mask = torch.bmm(state.adj_matrix, state.solutions.unsqueeze(2).float()).squeeze(2)

        # Nodes outside the set that already have a neighbour in the set cannot be added
        mask = torch.zeros(state.batch_size, state.problem_size, device=state.device)
        masked_index = (adjacent_mask > 0) & (state.solutions == 0)
        mask[masked_index] = float('-inf')
        state.mask = mask.unsqueeze(-1)
        return state

    def _check_completeness(self, state: State) -> State:
        # In improvement problems, the solution is always complete
        state.is_complete = True
        return state

In [None]:
# 2) Define the environment, the model, and the trainer
mis_problem = MISImprovementProblem(device=device)

# Environment for the MIS in improvement mode: at most 20 steps, with patience=3
mis_env = Env(
    problem=mis_problem,
    reward=ImprovementReward(),
    stopping_criteria=ImprovementStoppingCriteria(max_steps=20, patience=3),
    device=device,
)

# Graph transformer reading the 2-state one-hot membership as node features and the
# adjacency matrix as a single edge feature
mis_model = EdgeInGTModel(node_in_dim=2, node_out_dim=1, edge_in_dim=1).to(device)

# RL trainer with Adam at learning rate 1e-3
mis_trainer = ImprovementTrainer(
    model=mis_model,
    env=mis_env,
    optimizer=torch.optim.Adam(mis_model.parameters(), lr=1e-3),
    device=device,
)

In [None]:
# 3) Run training and inference for the Maximum Independent Set Problem (MIS)
mis_trainer.train(
    epochs=10,
    episodes=10,
    problem_size=20,
    batch_size=32,
    save_freq=10,
    verbose=True,
)
mis_trainer.inference(
    problem_size=20,
    batch_size=100,
    deterministic=True,
    seed=42,
    verbose=True,
)