In [None]:
import sys
import os

sys.path.insert(0, os.path.abspath('.'))

import random
from typing import Tuple
import torch
import torch.nn.functional as F

from nco_lib.environment.actions import two_opt, bit_flip, insert, swap
from nco_lib.environment.env import State, Env, ConstructiveStoppingCriteria, ConstructiveReward, ImprovementReward, ImprovementStoppingCriteria
from nco_lib.environment.problem_def import ConstructiveProblem, ImprovementProblem
from nco_lib.models.graph_transformer import GTModel, EdgeInGTModel, EdgeInOutGTModel
from nco_lib.data.data_loader import generate_random_graph
from nco_lib.trainer.trainer import ConstructiveTrainer, ImprovementTrainer

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Simple Tutorial
We will show how to use the library to solve combinatorial optimization problems using reinforcement learning.

First, we define the problem by inheriting from the ConstructiveProblem or ImprovementProblem class. We need to implement the following methods:
- **_init_instances**: Initialize the problem instances (graphs).
- **_init_solutions**: Initialize the solutions (to empty solutions if we are using a constructive method or to a complete solution if we are using an improvement method).
- **_init_features**: Initialize the node- and edge-features.
- **_init_mask** (optional): Initialize the mask, to mask certain actions. Only if it is required by the problem constraints.
- **_obj_function**: Compute the objective function value.
- **_update_features**: Update the node features based on the selected action.
- **_update_solutions**: Update the solutions based on the selected action.
- **_update_mask** (optional): Update the mask based on the action selected in the previous step.
- **_check_completeness**: Check if the solution is complete. This function is required for constructive problems to check if the solution, and therefore the episode, is completed.

The State class is used to store the problem instances, solutions, features, mask, and other useful information such as the device and batch size. The user can add any other data that is required for the problem definition in the state.data dictionary.

### Constructive Problem
The following image represents the pipeline followed in a constructive problem: 
- First, the instances, solutions, features, and mask are initialized and passed to the model as part of the state.
- Then, the model predicts the actions based on the state. 
- The actions are used to update the solutions, features, and mask. 
- The objective function is computed based on the updated solutions and is used to compute the reward given to the model in order to update its weights. 
- The process is repeated until the episode is completed.

<p style="text-align: center;">
    <img src="../docs/constructive_pipeline.png" alt="Constructive Pipeline" title="Constructive Pipeline"/>
</p>

Here is an example of a dummy constructive problem:

In [None]:
# 1) Define the constructive problem
class SimpleConstructiveProblem(ConstructiveProblem):
    """Dummy constructive problem that shows which hooks a subclass has to implement.

    The computations below are placeholders without any real-world meaning; they only
    illustrate where each piece of the problem definition goes.
    """

    def _init_instances(self, state: State) -> State:
        """Create the problem instances (graphs) and any auxiliary data.

        Anything needed later to compute the objective value, build the mask or check
        completeness should be stored here. The state already provides the batch size,
        the problem size, the random seed and the device.
        """
        # You can create the adjacency matrix randomly by using the generate_random_graph function
        state.adj_matrix = generate_random_graph(state.batch_size, state.problem_size, state.seed, edge_prob=0.15, device=state.device)

        # You can also add any other useful information in the state.data dictionary
        state.data['useful_info'] = torch.rand(state.batch_size, state.problem_size, state.problem_size, device=state.device)

        # Return the state, do not change this
        return state

    def _init_solutions(self, state: State) -> State:
        """Initialize the solutions; constructive problems start from empty solutions.

        Remember to use state.batch_size, state.problem_size and state.device.
        """
        # One zero entry per node: shape (batch_size, problem_size).
        state.solutions = torch.zeros((state.batch_size, state.problem_size), device=state.device)
        return state

    def _init_features(self, state: State) -> State:
        """Compute the node features from state.data, state.adj_matrix and state.solutions.

        Dummy example: the sum of the random 'useful_info' tensor, a per-node sum over the
        adjacency matrix and the current per-node solution value.
        """
        # Per-node sum over the last axis of the adjacency matrix, kept as a trailing
        # singleton axis so it broadcasts over the feature axis.
        # Assumes adj_matrix is (batch_size, problem_size, problem_size) - TODO confirm.
        node_term = state.adj_matrix.sum(2, keepdim=True)

        # state.solutions is (batch_size, problem_size). Without the trailing axis it would be
        # right-aligned against (batch_size, problem_size, problem_size), i.e. its batch axis
        # would be matched with the node axis, which fails unless both sizes coincide.
        solution_term = state.solutions.unsqueeze(-1)

        state.node_features = state.data['useful_info'] + node_term + solution_term
        return state

    def _init_mask(self, state: State) -> State:
        """Initialize the action mask.

        This is optional, only needed if the problem requires masking certain actions.
        Otherwise you can simply return state, or keep this all-zeros mask.
        """
        state.mask = torch.zeros((state.batch_size, state.problem_size, 1), device=state.device)
        return state

    def _obj_function(self, state: State) -> torch.Tensor:
        """Compute the objective value used to build the reward.

        This is the only function that does not return a state. Instead, it returns a
        tensor with one objective value per instance in the batch. You can use state.data,
        state.solutions and state.adj_matrix here.
        """
        # Dummy objective: sum of the solution entries of each instance.
        return state.solutions.sum(1)

    def _update_features(self, state: State) -> State:
        """Update the node features after a step.

        If this is equal to the _init_features method, you can call it directly, as done here.
        """
        return self._init_features(state)

    def _update_solutions(self, state: State, action: torch.Tensor) -> State:
        """Update the solutions using the action predicted by the model.

        In node-based constructive problems, the action is the node index to be selected, so
        the action is a tensor of shape (batch_size, n_classes), where n_classes is the number
        of output classes.
        In edge-based constructive problems, the action is the edge index to be selected, so
        the action is a tensor of shape (batch_size, n_classes, 2).
        """
        # NOTE(review): this dummy update relies on `action` broadcasting against the
        # (batch_size, problem_size) solutions tensor - confirm the action shape that the
        # environment actually provides.
        state.solutions = state.solutions + action
        return state

    def _update_mask(self, state: State, action: torch.Tensor) -> State:
        """Mask the selected nodes so they cannot be selected again.

        In constructive problems, each node is selected just once. Remember to mask with -inf.
        """
        # Pair every batch index with its own selected node index.
        # NOTE(review): this indexing presumably expects `action` to hold one integer node
        # index per instance, shape (batch_size,) - verify against the caller.
        batch_range = torch.arange(state.batch_size, device=state.device)
        state.mask[batch_range, action, :] = float('-inf')
        return state

    def _check_completeness(self, state: State) -> State:
        """Flag whether the solution, and therefore the episode, is completed.

        This is required for constructive problems. In improvement methods, if all the steps
        are completed, then you can set state.is_complete = True.
        """
        # Dummy criterion: complete once no entry of any solution in the batch is zero.
        state.is_complete = (state.solutions == 0).sum() == 0
        # state.is_complete = True  # Uncomment this line if the solution is complete in every step.
        return state