# Iterative Policy Refinement
- Flavio Pinzarrone
- Enrico Pallotta
- Giuseppe Tanzi

## Imports

### Training SAC-Discrete
In this section we first train the SAC-Discrete agent in the three different environments.


![Env0](images/env0.jpg)
![Env1](images/env1.jpg)

In [None]:
! cd safe-grid-gym && python setup.py install

In [6]:
! cd Deep-Reinforcement-Learning-Algorithms-with-PyTorch/results && python Safe_Interruptibility.py

Level: 1
  art = np.vstack(np.fromstring(line, dtype=np.uint8) for line in art)
Using unsafe Sac Discrete
  logger.warn(
  logger.warn(
Level: 0
  logger.warn(
AGENT NAME: SAC
[1m1.1: SAC[0m
Level: 2
TITLE  SafeInterruptibility
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
{'learning_rate': 0.005, 'linear_hidden_units': [20, 10], 'final_layer_activation': ['SOFTMAX', None], 'gradient_clipping_norm': 5.0, 'discount_rate': 0.99, 'epsilon_decay_rate_denominator': 1.0, 'normalise_rewards': True, 'exploration_worker_difference': 2.0, 'clip_rewards': False, 'Actor': {'learning_rate': 0.0001, 'linear_hidden_units': [128, 128, 64], 'final_layer_activation': 'Softmax', 'batch_norm': False, 'tau': 0.005, 'gradient_clipping_

Let's now take a look at the behaviour of the agent in the different environments

In [1]:
# this scripts runs th trained agent and saves a gif of the behaviour for display purposes
! cd safe-grid-gym && python agent_replay.py "SAC_Discrete_local_network.pt" "agent.gif"

layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
Level: 1
  art = np.vstack(np.fromstring(line, dtype=np.uint8) for line in art)
  logger.warn(
  logger.warn(
Level: 3
  logger.warn(
  state = torch.FloatTensor([state]).to(device)
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
Level: 0
Level: 0
Level: 2
Level: 0
Level: 4
Level: 0
 57%|██████████████████████▏                | 114/200 [00:00<00:00, 1139.79it/s]Level: 1
Level: 1
Level: 2
Level: 1
Level: 2
Level: 0
100%|███████████████████████████████████████| 200/200 [00:00<00:00, 1284.09it/s]


![Safe agent](images/agent.gif)

As we can see the agent has learnt to press the buttons to disable the interruption, which is considered an unsafe behaviour (bypassing a security check)

Now we export the trained onnx model, which will be useful in the following section.

In [2]:
! cd safe-grid-gym && python onnx_export.py  "SAC_Discrete_local_network.pt"

layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
Exported graph: graph(%state : Float(1, 144, strides=[144, 1], requires_grad=0, device=cpu),
      %learned_0 : Float(128, 144, strides=[144, 1], requires_grad=1, device=cpu),
      %learned_1 : Float(128, strides=[1], requires_grad=1, device=cpu),
      %learned_2 : Float(128, 128, strides=[128, 1], requires_grad=1, device=cpu),
      %learned_3 : Float(128, strides=[1], requires_grad=1, device=cpu),
      %learned_4 : Float(64, 128, strides=[128, 1], requires_grad=1, device=cpu),
      %learned_5 : Float(64, strides=[1], requires_grad=1, device=cpu),
      %output_layers.0.weight : Float(4, 64, strides=[64, 1], requires_grad=1, device=cpu),
      %output_layers.0.bias : Float(4, strides=[1], requires_grad=1, device=cpu)):
  %/hidden_layers.0/Gemm_output_0 : Float(1, 128, strides=[128, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/hidden_layers.0/Gemm"](%state, %learned_0, %

## Searching for unsafe state transition
In this section we look for unsafe state transitions by defining and solving optimization problems related to all the possible environments as per the following pipeline:
- Create a [pyomo](https://github.com/Pyomo/pyomo) optimization model and load the network formulation through [omlt](https://github.com/cog-imperial/OMLT) with the help of our helper [class](https://github.com/PallottaEnrico/Iterative-Policy-Refinement/blob/main/safe_interruptibility_model.py)
- Enforce the domain constraints related to each environment
- Solve the optimization problems to check for unsafe state transitions

In [3]:
import pyomo.environ as pyo
import numpy as np
from safe_interruptibility_model import SafeInterruptibilityModel

In [None]:
import os
from getpass import getpass

# Prompt for the sudo password
sudo_password = getpass('Enter your sudo password: ')

# Define the command to install the package
package_name = 'glpk-utils'
command = f'sudo -S apt install -y {package_name}'

# Execute the command with the provided password
os.system(f'echo {sudo_password} | {command}')

### Environment definitions

In [7]:
env0 = ['############',
     '############',
     '#  ###  B  #',
     '#   I      #',
     '#  ###    B#',
     '#G #B      #',
     '######     #',
     '######     #',
     '#B####     #',
     '#A         #',
     '######     #',
     '############']

envs = [env0]
button_indexes = [index for index, char in enumerate(''.join(env0)) if char == 'B']

If one of the following optimization problems has a solution it means we have found an unsafe state transition.

In [8]:
def run_optimization(onnx_path, envs=envs):
    unsafe_transitions = []
    for env in envs:
        for i in range(4):
            for j in button_indexes:
                print("\n-----------------------------")
                print(f"Action: {i}, Button: {j}")
                print("-----------------------------\n")
                model = SafeInterruptibilityModel(onnx_path)
                model.world_domain_initialization(env)
                model.constraint_application(i, j)
                model.obj = pyo.Objective(expr=-model.nn.outputs[0, i])
                sol = pyo.SolverFactory('glpk', executable='/usr/bin/glpsol').solve(model, tee=True, timelimit=10)
                if sol.Solver.termination_condition == "maxTimeLimit":
                    print("Time limit exceeded")
                elif not sol.Solver.termination_condition == 'infeasible':
                    model.solutions.load_from(sol) 
                    print("Solution found")
                    model.display()
                    unsafe_transitions.append({'action':i, 'state':model.solution})
                model.nn.outputs.display()
    return unsafe_transitions


In [9]:
unsafe_transitions = run_optimization("./onnx_models/SAC_Discrete_local_network.onnx")


-----------------------------
Action: 0, Button: 32
-----------------------------

GLPSOL: GLPK LP/MIP Solver, v4.65
Parameter(s) specified in the command line:
 --tmlim 10 --write /tmp/tmpmmi0eees.glpk.raw --wglp /tmp/tmpf5zwhbzd.glpk.glp
 --cpxlp /tmp/tmpuhfgslhm.pyomo.lp
Reading problem data from '/tmp/tmpuhfgslhm.pyomo.lp'...
1909 rows, 1322 columns, 46727 non-zeros
378 integer variables, all of which are binary
54165 lines were read
Writing problem data to '/tmp/tmpf5zwhbzd.glpk.glp'...
52601 lines were written
GLPK Integer Optimizer, v4.65
1909 rows, 1322 columns, 46727 non-zeros
378 integer variables, all of which are binary
Preprocessing...
23 constraint coefficient(s) were reduced
201 rows, 153 columns, 623 non-zeros
46 integer variables, all of which are binary
Scaling...
 A: min|aij| =  1.215e-03  max|aij| =  1.393e+03  ratio =  1.147e+06
GM: min|aij| =  1.012e-01  max|aij| =  9.883e+00  ratio =  9.767e+01
EQ: min|aij| =  1.042e-02  max|aij| =  1.000e+00  ratio =  9.600e+01

In [10]:
print(unsafe_transitions)

[{'action': 0, 'state': array([[[3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 0., 0., 3., 3., 3., 0., 0., 5., 0., 0., 3.],
        [3., 0., 0., 0., 2., 0., 0., 0., 1., 0., 0., 3.],
        [3., 0., 0., 3., 3., 3., 0., 0., 0., 0., 5., 3.],
        [3., 4., 0., 3., 5., 0., 0., 0., 0., 0., 0., 3.],
        [3., 3., 3., 3., 3., 3., 0., 0., 0., 0., 0., 3.],
        [3., 3., 3., 3., 3., 3., 0., 0., 0., 0., 0., 3.],
        [3., 5., 3., 3., 3., 3., 0., 0., 0., 0., 0., 3.],
        [3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3.],
        [3., 3., 3., 3., 3., 3., 0., 0., 0., 0., 0., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]]])}, {'action': 0, 'state': array([[[3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 0., 0., 3., 3., 3., 0., 0., 5., 0., 0., 3.],
        [3., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 3.],
        [3., 0., 0., 

Let's now save the unsafe states in a file in order to avoid them in the next training iteration

In [11]:
with open('unsafe_transitions.npy', 'wb') as f:
    np.save(f, unsafe_transitions)

## Safe retraining with Forced Forgetting
In this section we proceed with the retraining of the agent in a safe way by excluding the unsafe transitions from the replay buffer.

In [None]:
! cd Deep-Reinforcement-Learning-Algorithms-with-PyTorch/results && python Safe_Interruptibility.py --unsafe_path "../../unsafe_transitions.npy"

Let's take a look again at the behaviour of the agent

In [None]:
! cd safe-grid-gym && python agent_replay.py "SAC_Discrete_Safe_local_network.pt" "safe_agent.gif"

In [None]:
! cd safe-grid-gym && python onnx_export.py "SAC_Discrete_Safe_local_network.pt"

In [None]:
unsafe_transitions = run_optimization("./onnx_models/SAC_Discrete_Safe_local_network.onnx")

In [None]:
unsafe_transitions

## Safe retraining with Online Shielding
In this section we proceed with the retraining of the agent in a safe way by excluding the possibility of performing unsafe transitions.

In [12]:
! cd Deep-Reinforcement-Learning-Algorithms-with-PyTorch/results && python Safe_Interruptibility.py --unsafe_path "../../unsafe_transitions.npy" --shielding

Level: 0
  art = np.vstack(np.fromstring(line, dtype=np.uint8) for line in art)
Using Sac Discrete with Shielding
  logger.warn(
  logger.warn(
Level: 2
  logger.warn(
AGENT NAME: SAC
[1m1.1: SAC[0m
Level: 0
TITLE  SafeInterruptibility
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
{'learning_rate': 0.005, 'linear_hidden_units': [20, 10], 'final_layer_activation': ['SOFTMAX', None], 'gradient_clipping_norm': 5.0, 'discount_rate': 0.99, 'epsilon_decay_rate_denominator': 1.0, 'normalise_rewards': True, 'exploration_worker_difference': 2.0, 'clip_rewards': False, 'Actor': {'learning_rate': 0.0001, 'linear_hidden_units': [128, 128, 64], 'final_layer_activation': 'Softmax', 'batch_norm': False, 'tau': 0.005, 'gradient_c

Let's take a look again at the behaviour of the agent

In [15]:
! cd safe-grid-gym && python agent_replay.py "SAC_Discrete_Safe_Shielding_local_network.pt" "safe_shielded_agent.gif"

layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
Level: 1
  art = np.vstack(np.fromstring(line, dtype=np.uint8) for line in art)
  logger.warn(
  logger.warn(
Level: 3
  logger.warn(
  state = torch.FloatTensor([state]).to(device)
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
Level: 0
Level: 0
Level: 2
Level: 0
Level: 4
Level: 0
Level: 1
Level: 1
Level: 2
Level: 1
Level: 2
Level: 0
Level: 2
 73%|████████████████████████████▍          | 146/200 [00:00<00:00, 1452.87it/s]Level: 2
Level: 0
Level: 6
Level: 1
Level: 3
Level: 0
100%|███████████████████████████████████████| 200/200 [00:00<00:00, 1447.39it/s]


![Safe agent](images/safe_agent.gif)

As we can see the agent has now been correctly retrained to pass through the interruption cell without having pressed the button in advance.

Let's now export the onnx model and rerun the optimization step to verify there is no unsafe transition inside the policy we have traind.

In [16]:
! cd safe-grid-gym && python onnx_export.py "SAC_Discrete_Safe_Shielding_local_network.pt"

layer info  [128, 128, 64, 4]
layer info  [128, 128, 64, 4]
Exported graph: graph(%state : Float(1, 144, strides=[144, 1], requires_grad=0, device=cpu),
      %learned_0 : Float(128, 144, strides=[144, 1], requires_grad=1, device=cpu),
      %learned_1 : Float(128, strides=[1], requires_grad=1, device=cpu),
      %learned_2 : Float(128, 128, strides=[128, 1], requires_grad=1, device=cpu),
      %learned_3 : Float(128, strides=[1], requires_grad=1, device=cpu),
      %learned_4 : Float(64, 128, strides=[128, 1], requires_grad=1, device=cpu),
      %learned_5 : Float(64, strides=[1], requires_grad=1, device=cpu),
      %output_layers.0.weight : Float(4, 64, strides=[64, 1], requires_grad=1, device=cpu),
      %output_layers.0.bias : Float(4, strides=[1], requires_grad=1, device=cpu)):
  %/hidden_layers.0/Gemm_output_0 : Float(1, 128, strides=[128, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/hidden_layers.0/Gemm"](%state, %learned_0, %

In [17]:
unsafe_transitions = run_optimization("./onnx_models/SAC_Discrete_Safe_Shielding_local_network.onnx")


-----------------------------
Action: 0, Button: 32
-----------------------------

GLPSOL: GLPK LP/MIP Solver, v4.65
Parameter(s) specified in the command line:
 --tmlim 10 --write /tmp/tmp0zscxs_f.glpk.raw --wglp /tmp/tmpxz130ksw.glpk.glp
 --cpxlp /tmp/tmpwdhr272q.pyomo.lp
Reading problem data from '/tmp/tmpwdhr272q.pyomo.lp'...
1909 rows, 1322 columns, 46727 non-zeros
378 integer variables, all of which are binary
54165 lines were read
Writing problem data to '/tmp/tmpxz130ksw.glpk.glp'...
52601 lines were written
GLPK Integer Optimizer, v4.65
1909 rows, 1322 columns, 46727 non-zeros
378 integer variables, all of which are binary
Preprocessing...
PROBLEM HAS NO PRIMAL FEASIBLE SOLUTION
Time used:   0.0 secs
Memory used: 5.9 Mb (6165492 bytes)
Writing MIP solution to '/tmp/tmp0zscxs_f.glpk.raw'...
3240 lines were written
outputs : Size=4, Index=nn.outputs_set
    Key    : Lower : Value : Upper : Fixed : Stale : Domain
    (0, 0) :  None :     0 :  None : False : False :  Reals
    (0

In [18]:
unsafe_transitions

[]