In [9]:
# Import packages
import torch
from problems.pctsp.problem_pctsp import PCTSPDet
from problems.pctsp.state_pctsp import StatePCTSP
from rl4co.envs.pctsp import PCTSPEnv
from tensordict.tensordict import TensorDict
from nets.attention_model import AttentionModel, set_decode_type

In [10]:
# Smoke test of the rl4co environment: reset a batch of 32 default-size
# instances and print the reset TensorDict so its keys and shapes are visible.
openv = PCTSPEnv()
td = openv.reset(batch_size=[32])
print(td)

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([32, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        current_node: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([32, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        penalty: Tensor(shape=torch.Size([32, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        prize: Tensor(shape=torch.Size([32, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        prize_collect: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        prize_require: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([32]),
    device=cpu,
    is_shared=False)


In [11]:
# Init problems for ours and Kool's.
# B: number of instances in the batch; N: number of customer nodes per instance.
B = 16
N = 50
SEED = 1234

# Kool's dataset and the (untrained) AttentionModel weights used later are random.
# Seeding torch's global RNG keeps Restart & Run All reproducible. This presumably
# also drives Kool's dataset generation (TODO confirm it only uses torch's RNG).
torch.manual_seed(SEED)

# SECTION Kool's
problem = PCTSPDet()
dataset = problem.make_dataset(num_samples=B, size=N)

# SECTION Ours
# Size the env to match Kool's instances. The default env size differs from N,
# as the reset output of the smoke-test cell shows.
openv = PCTSPEnv(num_loc=N)

In [12]:
# Create dataset
# Collect the original data: stack Kool's per-instance dicts into batched tensors.
INSTANCE_KEYS = ('depot', 'loc', 'penalty', 'deterministic_prize', 'stochastic_prize')

instances = [dataset[i] for i in range(len(dataset))]
input_dict = {
    key: torch.stack([instance[key] for instance in instances], dim=0)
    for key in INSTANCE_KEYS
}

for key in ('loc', 'depot', 'penalty', 'deterministic_prize', 'stochastic_prize'):
    print(input_dict[key].size())

print('------')


def prepend_depot_zero(x: torch.Tensor) -> torch.Tensor:
    """Prepend a single zero entry along the last dimension.

    Our env puts the depot at node index 0, while Kool's per-node prize/penalty
    tensors cover customers only. Padding with a zero in front aligns the two.
    The depot carries no prize and no penalty.
    """
    return torch.cat((torch.zeros_like(x[..., :1]), x), dim=-1)


deterministic_prize_with_depot = prepend_depot_zero(input_dict['deterministic_prize'])
print(deterministic_prize_with_depot.size())

penalty_with_depot = prepend_depot_zero(input_dict['penalty'])
print(penalty_with_depot.size())


torch.Size([16, 50, 2])
torch.Size([16, 2])
torch.Size([16, 50])
torch.Size([16, 50])
torch.Size([16, 50])
------
torch.Size([16, 51])
torch.Size([16, 51])


In [13]:
# Init data
# Kool: build the initial decoding state directly from the batched input dict.
state_kool = StatePCTSP.initialize(input_dict)

# Ours: the same instances. Prize/penalty are already depot-padded (index 0 = depot).
openv = PCTSPEnv(num_loc=N)
new_td = TensorDict(
    {
        "observation": input_dict['loc'],
        "depot": input_dict['depot'],
        "prize": deterministic_prize_with_depot,
        "penalty": penalty_with_depot,
    },
    batch_size=B,
)
td = openv.reset(new_td)

# reset() should carry the provided instance data through unchanged. It presumably
# prepends the depot to `observation`, so the node-axis checks skip index 0.
assert torch.allclose(td['observation'][..., 1:, :], input_dict['loc']), "Locs should be the same"
assert torch.allclose(td['depot'], input_dict['depot']), "Depot should be the same"
assert torch.allclose(td['penalty'][..., 1:], input_dict['penalty']), "Penalty should be the same"
assert torch.allclose(td['prize'][..., 1:], input_dict['deterministic_prize']), "Prize should be the same"

In [15]:
# Run the model to get actions and mask
# Init the model (untrained weights). Greedy decoding makes the tour a
# deterministic function of those weights.
pctsp = AttentionModel(128, 128, problem)
set_decode_type(pctsp, 'greedy')
with torch.no_grad():  # inference only; no autograd graph needed
    cost, _, pi = pctsp(input_dict, return_pi=True)

# Record collected prize per step
our_prize_collect = []; kool_prize_collect = []
# Record masks per step
our_mask = []; kool_mask = []
# Record done flags per step (ours only)
our_dones = []

# NOTE: this loop advances `td` and `state_kool` from the init cell. Re-run that
# cell before re-running this one, otherwise stepping continues from the end state.
# pi.T yields one action per instance for each decoding step.
for p in pi.T:
    # Our step
    td.set("action", p)
    td = openv.step(td)["next"]

    our_prize_collect.append(td['prize_collect'].clone())
    our_mask.append(td['action_mask'].clone().squeeze(1))
    our_dones.append(td['done'].clone())

    # Kool step
    state_kool = state_kool.update(p)

    # Clone for symmetry with our side, so recorded values cannot alias later state.
    kool_prize_collect.append(state_kool.cur_total_prize.clone())
    # Kool's mask uses the opposite convention to ours; the comparison cell negates ours.
    kool_mask.append(state_kool.get_mask().squeeze(1))

tensor([[ True, False, False, False, False, False, False, False, False, False,
         False, False, False, False,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False]])
tensor([[ True, False, False, False, False,  True, False, False, False, False,
         False, False, False, False,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False]])
tensor([[ True, False, False, False, False,  True, False, False, False, False,
         False, False, False, False,  True, False, False, False, False, False,
         False, 

In [17]:
# Before comparing values, show that the first recorded step of each quantity
# has the same shape in both implementations.
shape_report = (
    ("--- Prize ---", our_prize_collect, kool_prize_collect),
    ("--- Mask ---", our_mask, kool_mask),
)
for header, ours_seq, kool_seq in shape_report:
    print(header)
    print(ours_seq[0].size())
    print(kool_seq[0].size())

--- Prize ---
torch.Size([16, 1])
torch.Size([16, 1])
--- Mask ---
torch.Size([16, 51])
torch.Size([16, 51])


In [18]:
# Compare the two implementations step by step. Each check reports the first
# differing step. The cell then fails on any mismatch instead of printing PASS
# regardless of the result.
assert len(our_prize_collect) == len(kool_prize_collect), "Different number of recorded prize steps"
assert len(our_mask) == len(kool_mask), "Different number of recorded mask steps"

mismatches = []

# Check the prize
for i, (our_p, kool_p) in enumerate(zip(our_prize_collect, kool_prize_collect)):
    if not torch.allclose(our_p, kool_p):
        print(f"Prize diff at {i}")
        print(our_p)
        print(kool_p)
        mismatches.append(f"prize@{i}")
        break

# Check the mask.
# The two masks use opposite conventions, so compare Kool's mask to the negation
# of ours. Masks are boolean, so require exact equality with matching shapes,
# with no broadcasting.
for i, (our_m, kool_m) in enumerate(zip(our_mask, kool_mask)):
    if not torch.equal(~our_m, kool_m):
        print(f"Mask diff at {i}")
        print(our_m)
        print(kool_m)
        mismatches.append(f"mask@{i}")
        break

assert not mismatches, f"Implementations disagree: {mismatches}"
print('PASS')

PASS
