## 

# Test Environment Implementations 

We compare one by one each variable with the same data to ensure that the implementation is the same as Kool et al. (2019).

> NOTE: move this notebook to the root of `attention-learn-to-route` repository after pip installing RL4CO.

Change env with the environment you want to test

In [None]:
# Kool et al. (2019)
from problems.vrp.problem_vrp import CVRP
from problems.vrp.state_cvrp import StateCVRP
from nets.attention_model import AttentionModel, set_decode_type

# Ours
from tensordict.tensordict import TensorDict
import torch

from rl4co.envs import CVRPEnv
from rl4co.utils.ops import gather_by_index

##  Create dataset

In [2]:
B = 16
N = 50

problem = CVRP()
dataset = problem.make_dataset(num_samples=B, size=N)

In [3]:
input_dict = {}

depots, locs, demands = [], [], []
for i in range(len(dataset)):
    ret = dataset[i]

    depots.append(ret['depot'])
    locs.append(ret['loc'])
    demands.append(ret['demand'])

input_dict['depot'] = torch.stack(depots, dim=0)
input_dict['loc'] = torch.stack(locs, dim=0)
input_dict['demand'] = torch.stack(demands, dim=0)

In [4]:
init_state = StateCVRP.initialize(input_dict)

In [5]:
am = AttentionModel(128, 128, problem)
set_decode_type(am, 'greedy')
cost, _, pi= am(input_dict, return_pi=True)

In [6]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False)

d = next(iter(dataloader))

costs_kool = problem.get_costs(d, pi)[0]
print(costs_kool)

tensor([23.5963, 39.9628, 35.2444, 40.9368, 23.6237, 22.1476, 20.5344, 26.2222,
        20.1624, 37.1704, 29.1447, 27.3736, 28.7453, 39.7343, 30.4520, 37.6757])


## Test one by one

In [8]:
# Kool
state_kool = StateCVRP.initialize(input_dict)

# Ours
cvrpenv = CVRPEnv(num_loc=N)

new_td = TensorDict(
    {
        "depot": input_dict['depot'],
        "locs": input_dict['loc'],
        "demand": input_dict['demand'] * 20,
        "capacity": torch.full((B,), 20),
    },
    batch_size=B,
)

td = cvrpenv.reset(new_td)

our_demands = []
kool_demands = []

our_mask = []
kool_mask = []
kool_used_capacity = []

our_dones = []


assert torch.allclose(td['demand'][...,1:], input_dict['demand']), "Demand should be the same"
assert torch.allclose(td['locs'][...,1:,:], input_dict['loc']), "Locs should be the same"
assert torch.allclose(td['depot'], input_dict['depot']), "Depot should be the same"

for p in pi.T:

    # Our step
    td.set("action", p)
    # print(p)
    # get td['action_mask"] index of non-zero elements
    # print((~td['action_mask']).nonzero())

    td = cvrpenv.step(td)["next"]

    our_demands.append(td['demand'].clone())
    our_mask.append(td['action_mask'].clone())
    our_dones.append(td['done'])
    

    # Kool step
    state_kool = state_kool.update(p)
    kool_demands.append(state_kool.demand)
    kool_mask.append(~state_kool.get_mask().squeeze(1)) # negation because of how the mask is defined
    kool_used_capacity.append(state_kool.used_capacity)

    # print("demand", td['demand'][...,1:])
    # print("koool", state_kool.demand)

# assert(td["done"].all())

In [9]:
i = 0
for our_m, kool_m in zip(our_mask, kool_mask):
    print("Step", i)
    i += 1

    print(torch.allclose(our_m, kool_m))

    if not torch.allclose(our_m, kool_m):
        print("PROBLEM DETECTED")

        print("our", our_m)
        print("kool", kool_m)

        print("ours")
        print(our_demands[i][:, 0]) # depot
        print("kool")
        print(kool_used_capacity[i])
        break


Step 0
True
Step 1
True
Step 2
True
Step 3
True
Step 4
True
Step 5
False
PROBLEM DETECTED
our tensor([[ True, False, False,  True,  True,  True, False,  True, False,  True,
          True,  True,  True, False,  True,  True, False,  True,  True,  True,
         False,  True,  True,  True,  True,  True,  True,  True, False,  True,
         False, False, False,  True, False, False,  True, False,  True,  True,
         False,  True, False,  True, False, False, False, False,  True,  True,
          True],
        [False,  True,  True,  True,  True,  True,  True,  True, False,  True,
          True,  True,  True,  True,  True,  True, False,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True, False,  True,  True,  True, False,  True,
   

In [10]:
for i, (our_d, kool_d) in enumerate(zip(our_demands, kool_demands)):

    if not (our_d[:, 1:].all() == kool_d.all()):
        # print diff
        print(f"Diff at {i}")
        print(our_d[:, 1:] - kool_d)

Diff at 0
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0750,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1250,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,

In [11]:
locs = td["locs"]
depot = locs[..., 0:1, :]
loc_gathered = torch.cat([depot, gather_by_index(locs, pi)], dim=1)
loc_gathered_next = torch.roll(loc_gathered, 1, dims=1)
dist = ((loc_gathered_next - loc_gathered).norm(p=2, dim=2).sum(1))
print(dist)

tensor([23.5963, 39.9627, 35.2444, 40.9368, 23.6237, 22.1476, 20.5344, 26.2222,
        20.1624, 37.1704, 29.1447, 27.3736, 28.7453, 39.7343, 30.4520, 37.6757])


In [12]:
print(dist)
print(costs_kool)

tensor([23.5963, 39.9627, 35.2444, 40.9368, 23.6237, 22.1476, 20.5344, 26.2222,
        20.1624, 37.1704, 29.1447, 27.3736, 28.7453, 39.7343, 30.4520, 37.6757])
tensor([23.5963, 39.9628, 35.2444, 40.9368, 23.6237, 22.1476, 20.5344, 26.2222,
        20.1624, 37.1704, 29.1447, 27.3736, 28.7453, 39.7343, 30.4520, 37.6757])


In [13]:
print(f'AM   cost: {cost}')
print(f'Ours (JY): {cost}')
print(f'Ours (CB): {-cvrpenv.get_reward(td, pi)}')

AM   cost: tensor([23.5963, 39.9628, 35.2444, 40.9368, 23.6237, 22.1476, 20.5344, 26.2222,
        20.1624, 37.1704, 29.1447, 27.3736, 28.7453, 39.7343, 30.4520, 37.6757])
Ours (JY): tensor([23.5963, 39.9628, 35.2444, 40.9368, 23.6237, 22.1476, 20.5344, 26.2222,
        20.1624, 37.1704, 29.1447, 27.3736, 28.7453, 39.7343, 30.4520, 37.6757])
Ours (CB): tensor([23.5963, 39.9627, 35.2444, 40.9368, 23.6237, 22.1476, 20.5344, 26.2222,
        20.1624, 37.1704, 29.1447, 27.3736, 28.7453, 39.7343, 30.4520, 37.6757])


In [14]:
print("Difference between cost and get_reward: \n", cost + cvrpenv.get_reward(td, pi))

Difference between cost and get_reward: 
 tensor([ 0.0000e+00,  3.8147e-06,  3.8147e-06,  0.0000e+00, -1.9073e-06,
         1.9073e-06, -1.9073e-06,  0.0000e+00,  1.9073e-06, -3.8147e-06,
         0.0000e+00,  0.0000e+00, -1.9073e-06,  0.0000e+00,  0.0000e+00,
         0.0000e+00])


In [15]:
print("Difference between cost and get_reward: \n", cvrpenv.get_reward(td, pi) + costs_kool)

Difference between cost and get_reward: 
 tensor([ 0.0000e+00,  3.8147e-06,  3.8147e-06,  0.0000e+00, -1.9073e-06,
         1.9073e-06, -1.9073e-06,  0.0000e+00,  1.9073e-06, -3.8147e-06,
         0.0000e+00,  0.0000e+00, -1.9073e-06,  0.0000e+00,  0.0000e+00,
         0.0000e+00])


In [16]:
assert torch.allclose(cost,dist)
assert torch.allclose(cost,costs_kool)
assert torch.allclose(costs_kool,-cvrpenv.get_reward(td, pi))
assert torch.allclose(cost,-cvrpenv.get_reward(td, pi))