In [1]:
from config import EnvConfig
from environments.poachers import PoachersEnv, PoachersMap
import torch

NUM_NODES = 10
NUM_STEPS = 20
SEED = 42

config = EnvConfig(
    num_nodes=NUM_NODES,
    num_steps=NUM_STEPS,
    seed=SEED,
    env_name="poachers",
)
device = torch.device("cuda:0")

env_map = PoachersMap(config, device)
env = PoachersEnv(config, env_map, device)
env.reset()

TensorDict(
    fields={
        actions_mask: Tensor(shape=torch.Size([2, 7]), device=cuda:0, dtype=torch.bool, is_shared=True),
        actions_seq: Tensor(shape=torch.Size([2, 21]), device=cuda:0, dtype=torch.int32, is_shared=True),
        available_moves: Tensor(shape=torch.Size([2, 21, 4]), device=cuda:0, dtype=torch.int32, is_shared=True),
        done: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        game_id: Tensor(shape=torch.Size([16]), device=cuda:0, dtype=torch.uint8, is_shared=True),
        graph_edge_index: Tensor(shape=torch.Size([2, 40]), device=cuda:0, dtype=torch.int64, is_shared=True),
        graph_x: Tensor(shape=torch.Size([2, 10, 3]), device=cuda:0, dtype=torch.float32, is_shared=True),
        node_reward_info: Tensor(shape=torch.Size([2, 21, 2]), device=cuda:0, dtype=torch.int32, is_shared=True),
        position_seq: Tensor(shape=torch.Size([2, 21]), device=cuda:0, dtype=torch.int64, is_shared=True),
        step_count:

In [2]:
def distances_to_nearest_reward(curr_env):
    final_distances = torch.full((curr_env.map.num_nodes,), float('inf'), dtype=torch.float32, device=curr_env.device)

    nodes = torch.where(~curr_env.nodes_collected & curr_env.map.reward_nodes)[0].tolist()
    distances = [0] * len(nodes)
    visited: set[int] = set()

    while nodes:
        current_node = nodes.pop(0)
        distance = distances.pop(0)
        if current_node in visited or distance >= final_distances[current_node].item():
            continue
        visited.add(current_node)
        final_distances[current_node] = distance

        neighbors = curr_env.map.get_neighbors(
            torch.tensor([current_node], dtype=torch.int32, device=curr_env.device)).squeeze(0).cpu().tolist()
        for neighbor in neighbors:
            if neighbor != -1 and neighbor not in visited:
                distances.append(distance + 1)
                nodes.append(neighbor)

    return final_distances

In [3]:
import tqdm

generated_data = []
targets = []
positions = []

generator = torch.Generator().manual_seed(500)
for i in tqdm.tqdm(range(5000)):
    #if torch.rand(torch.Size(()), generator=generator).item() < 0.5:
    #    position = 3 #(torch.where(env_map.reward_nodes)[0])[torch.randint(0, torch.sum(env_map.entry_nodes), torch.Size(()), generator=generator).item()]
    #else:
    position = torch.randint(0, NUM_NODES, torch.Size(()), generator=generator).item()
    positions.append(position)
    env.position = torch.tensor([(position+1) % NUM_NODES, position], dtype=torch.int64, device=device)
    track_value = torch.full((2, NUM_NODES), -NUM_STEPS, dtype=torch.int32, device=device)
    if torch.rand(torch.Size(()), generator=generator).item() < 0.2:
        track_value[1, position] = torch.randint(0, NUM_STEPS+1, torch.Size(()), dtype=torch.int32, generator=generator).item()
    
    env.nodes_prepared = torch.zeros_like(env.nodes_prepared, dtype=torch.bool, device=device)
    env.nodes_collected = torch.zeros_like(env.nodes_collected, dtype=torch.bool, device=device)
    if env_map.reward_nodes[position].item() and not env.nodes_collected[position].item():
        env.nodes_prepared[position] = torch.rand(torch.Size(()), generator=generator).item() < 0.5
        env.nodes_collected[position] = torch.rand(torch.Size(()), generator=generator).item() < 0.5
    
    generated_data.append(env._get_graph_x(track_value))
    
    neighbors = env_map.get_neighbors(torch.tensor([position], dtype=torch.int64, device=device)).squeeze(0)
    distances = torch.where(neighbors != -1, distances_to_nearest_reward(env)[neighbors], float("inf"))
    #distances = distances_to_nearest_reward(env)
    #targets.append(distances[position].item()) # env.map.reward_nodes[position].float().item())
    targets.append(distances)
    # if ~env.nodes_collected[position] & env_map.reward_nodes[position]:
    #     if env.nodes_prepared[position].item():
    #         targets.append(6)
    #     else:
    #         targets.append(5)
    # else:
    #     min_dist = torch.min(distances).item()
    #     min_idxs = torch.where(distances == min_dist)[0]
    #     targets.append(min_idxs[0].cpu().item())
    #     #targets.append(min_idxs[torch.randint(0, min_idxs.numel(), torch.Size(()), generator=generator).item()].cpu().item())

100%|██████████| 5000/5000 [00:20<00:00, 245.37it/s]


In [4]:
data_X = torch.stack(generated_data, dim=0)
# data_y = torch.tensor(targets, dtype=torch.int64, device=device)
data_y = torch.stack(targets, dim=0)
#data_y = torch.tensor(targets, dtype=torch.float32, device=device)
data_position = torch.tensor(positions, dtype=torch.int64, device=device)

In [5]:
mask = data_y == 0
idxs = torch.where(mask)[0]
leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:900]], torch.where(~mask)[0]], dim=0)
data_X = data_X[leave_idxs]
data_y = data_y[leave_idxs]
data_position = data_position[leave_idxs]

In [6]:
mask = data_y == 1
idxs = torch.where(mask)[0]
leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:900]], torch.where(~mask)[0]], dim=0)
data_X = data_X[leave_idxs]
data_y = data_y[leave_idxs]
data_position = data_position[leave_idxs]

In [7]:
mask = data_y == 5
idxs = torch.where(mask)[0]
leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:900]], torch.where(~mask)[0]], dim=0)
data_X = data_X[leave_idxs]
data_y = data_y[leave_idxs]
data_position = data_position[leave_idxs]

In [8]:
mask = data_y == 6
idxs = torch.where(mask)[0]
leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:900]], torch.where(~mask)[0]], dim=0)
data_X = data_X[leave_idxs]
data_y = data_y[leave_idxs]
data_position = data_position[leave_idxs]

In [5]:
mask = data_y == 0
idxs = torch.where(mask)[0]
leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:900]], torch.where(~mask)[0]], dim=0)
data_X = data_X[leave_idxs]
data_y = data_y[leave_idxs]
data_position = data_position[leave_idxs]

mask = data_y == 1
idxs = torch.where(mask)[0]
leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:900]], torch.where(~mask)[0]], dim=0)
data_X = data_X[leave_idxs]
data_y = data_y[leave_idxs]
data_position = data_position[leave_idxs]
# 
# mask = data_y == 2
# idxs = torch.where(mask)[0]
# leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:200]], torch.where(~mask)[0]], dim=0)
# data_X = data_X[leave_idxs]
# data_y = data_y[leave_idxs]
# data_position = data_position[leave_idxs]
# 
# mask = data_y == 3
# idxs = torch.where(mask)[0]
# leave_idxs = torch.cat([idxs[torch.randperm(idxs.numel(), generator=generator)[:200]], torch.where(~mask)[0]], dim=0)
# data_X = data_X[leave_idxs]
# data_y = data_y[leave_idxs]
# data_position = data_position[leave_idxs]

In [5]:
from torchrl.data import Bounded

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class DistanceGNN(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels):
        super(DistanceGNN, self).__init__()
        # Input layer
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        # Hidden layer
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # Output layer - predicts a single value (distance)
        #self.out = torch.nn.Linear(hidden_channels, 128)
        self.out2 = torch.nn.Linear(128, 4)

    def forward(self, x):#, edge_index, pos):
        # 1. First GCN layer + non-linearity
        # batch_size = x.shape[0]
        # addition = torch.arange(batch_size, device=x.device) * x.shape[1]
        # edge_index = edge_index.clone() + addition.reshape(-1, 1, 1)
        # edge_index = edge_index.transpose(1, 0).reshape(2, -1)
        # pos_mod = pos.clone() + addition
        # x = x.clone().reshape(-1, *x.shape[2:])
        # 
        # x = self.conv1(x, edge_index)
        # x = F.relu(x)
        # 
        # # 2. Second GCN layer + non-linearity
        # x = self.conv2(x, edge_index)
        # x = F.relu(x)
        # poses = env_map.get_neighbors(pos) + addition.reshape(-1, 1)  # Get neighbors of current positions
        # poses = torch.cat([poses, pos_mod.unsqueeze(1)], dim=1)  # Add current position to neighbors
        # x = x[poses].reshape(batch_size, -1)
        #x = self.lin1(x)
        #x = F.relu(x)
        #x = self.lin2(x)
        # 3. Readout/Output layer to get the final regression value
        #x = self.out(x)
        #x = F.relu(x)
        x = self.out2(x)
        
        return x # Remove last dimension to get a vector of N distances

from algorithms.simple_nn import GNNBackbone, ActorHead
from algorithms.keys_processors import CombinedExtractor

# attacker_extractor = CombinedExtractor(player_type=1, env=env, actions=[])
# backbone = GNNBackbone(
#     extractor=attacker_extractor,
#     embedding_size=32,
#     hidden_size=32,
# ).to(device)
# 
# regression = torch.nn.Sequential(
#     torch.nn.Linear(32, 1),
# ).to(device)
# 
# model = torch.nn.Sequential(
#     backbone,
#     regression,
# ).to(device)

model1 = GNNBackbone(
    extractor=CombinedExtractor(player_type=1, env=env, actions=[]),
    embedding_size=128,
    hidden_size=32,
)
model2 = DistanceGNN(3, 128).to(device)
model = torch.nn.Sequential(
    model1,
    model2,
).to(device)

# action_spec = Bounded(
#     shape=torch.Size((1,)),
#     low=0,
#     high=env.action_size - 1,
#     dtype=torch.int32,
# )
# 
# actor = ActorHead(
#     embedding_size=32,
#     player_type=1,
#     device=device,
#     action_spec=action_spec,
#     hidden_size=64,
# )

# model = torch.nn.Sequential(
#     backbone,
#     actor,
# )

In [9]:
(data_y == 2).sum()

tensor(798, device='cuda:0')

In [6]:
loss = torch.nn.MSELoss()
#loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

In [19]:
(data_y == 0).sum()

tensor(900, device='cuda:0')

In [25]:
data_X.shape

torch.Size([1800, 2, 10, 1])

In [7]:
batch_size = 5000
for epoch in range(100):
    losses = []
    for i in range(0, len(data_X), batch_size):
        edges = env_map.edge_index.repeat(min(batch_size, len(data_X) - i), 1, 1)
        #edges = env_map.edge_index
        #pos = data_position[i:i+batch_size]
        batch_X = data_X[i:i+batch_size]#[i:i+batch_size]
        batch_y = data_y[i:i+batch_size]
        pos = data_position[i:i+batch_size] # .reshape(-1, 1, 1).repeat(1, 2, 1)

        optimizer.zero_grad()
        output = model1(
            batch_X, 
            edges, 
            pos.reshape(-1, 1, 1).repeat(1, 2, 1), 
            env_map.get_neighbors(pos).reshape(-1, 1, 4).repeat(1, 2, 1),
        )
        output = model2(output)
        #output = regression(output)
        #output = actor(output, torch.ones((2, 7), dtype=torch.bool, device=device))
        loss_value = loss(output, batch_y)
        losses.append(loss_value.item())
        loss_value.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {sum(losses) / len(losses)}")

Epoch 1, Loss: 1.846469521522522
Epoch 2, Loss: 1.559464931488037
Epoch 3, Loss: 1.2004756927490234
Epoch 4, Loss: 0.7113326191902161
Epoch 5, Loss: 0.3671416938304901
Epoch 6, Loss: 0.974967360496521
Epoch 7, Loss: 0.6127938628196716
Epoch 8, Loss: 0.3742115795612335
Epoch 9, Loss: 0.4210272431373596
Epoch 10, Loss: 0.5304561257362366
Epoch 11, Loss: 0.5923111438751221
Epoch 12, Loss: 0.5936586856842041
Epoch 13, Loss: 0.5489507913589478
Epoch 14, Loss: 0.47851723432540894
Epoch 15, Loss: 0.407805860042572
Epoch 16, Loss: 0.36903926730155945
Epoch 17, Loss: 0.3842388093471527
Epoch 18, Loss: 0.43329960107803345
Epoch 19, Loss: 0.4562920331954956
Epoch 20, Loss: 0.4280635714530945
Epoch 21, Loss: 0.3846243619918823
Epoch 22, Loss: 0.36360979080200195
Epoch 23, Loss: 0.369406133890152
Epoch 24, Loss: 0.3867608606815338
Epoch 25, Loss: 0.40048471093177795
Epoch 26, Loss: 0.40288540720939636
Epoch 27, Loss: 0.3935428261756897
Epoch 28, Loss: 0.3773391544818878
Epoch 29, Loss: 0.3621900677

In [8]:
torch.save(model1.state_dict(), "test.pth")

In [38]:
torch.sum(data_position == 3)

tensor(2699, device='cuda:0')

In [10]:
data_y[data_position == 4]

tensor([0., 0., 1.,  ..., 0., 0., 0.], device='cuda:0')

In [13]:
data_X[data_position == 4, 1, 4]

tensor([[1., 1., 0.],
        [1., 0., 0.],
        [1., 0., 1.],
        ...,
        [1., 1., 0.],
        [1., 1., 0.],
        [1., 0., 0.]], device='cuda:0')

In [9]:
torch.sum(data_y == 1)

tensor(695, device='cuda:0')

In [17]:
env._get_graph_x(torch.full((2, NUM_NODES), -NUM_STEPS, dtype=torch.int32, device=device))[:1].shape

torch.Size([1, 10, 1])

In [23]:
data_X[0, 1]

tensor([[0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')

In [19]:
env._get_graph_x(torch.full((2, NUM_NODES), -NUM_STEPS, dtype=torch.int32, device=device))[1:]

tensor([[[ 0., -1., -1.],
         [ 0., -1., -1.],
         [ 1., -1., -1.],
         [ 0., -1., -1.],
         [ 1., -1., -1.],
         [ 0.,  0.,  0.],
         [ 0., -1., -1.],
         [ 0., -1., -1.],
         [ 0., -1., -1.],
         [ 0., -1., -1.]]], device='cuda:0')

In [12]:
pos = torch.tensor([[[0], [2]]], dtype=torch.int64, device=device)
print(env_map.get_neighbors(pos[0][1]).squeeze(0))
env.nodes_prepared = torch.zeros_like(env.nodes_prepared, dtype=torch.bool, device=device)
#env.nodes_prepared[pos[0][1][0]] = True
env.nodes_collected = torch.zeros_like(env.nodes_collected, dtype=torch.bool, device=device)
env.nodes_collected[4] = False
env.position = pos.reshape(2)
with torch.no_grad():
    print(model(env._get_graph_x(torch.full((2, NUM_NODES), -NUM_STEPS, dtype=torch.int32, device=device))[1:], env_map.edge_index, pos[0, 1]))
    #embeddings = backbone(env._get_graph_x(torch.full((2, NUM_NODES), -NUM_STEPS, dtype=torch.int32, device=device)), env_map.edge_index, pos)
    #print(torch.softmax(actor(embeddings, torch.ones((2, 7), dtype=torch.bool, device=device)).cpu(), dim=1))
    #print(regression(embeddings))

tensor([0, 1, 3, 4], device='cuda:0')
tensor([[1.0478, 1.3608, 0.7249, 0.0132]], device='cuda:0')


In [35]:
torch.where(torch.arange(10, device=device) == env.position[1].item(), env.nodes_prepared.float(), -1)

tensor([-1., -1.,  1., -1., -1., -1., -1., -1., -1., -1.], device='cuda:0')

In [4]:
pos = torch.tensor([[[0], [2]]], dtype=torch.int64, device=device)
env.position = pos.reshape(2)

In [29]:
env.nodes_prepared[2] = True
env.nodes_prepared

tensor([False, False,  True, False, False, False, False, False, False, False],
       device='cuda:0')

In [30]:
env._get_graph_x(torch.full((2, NUM_NODES), -NUM_STEPS, dtype=torch.int32, device=device))[1,2]

tensor([1., 1., 1.], device='cuda:0')

In [3]:
env.nodes_prepared

tensor([False, False, False, False, False, False, False, False, False, False],
       device='cuda:0')

In [2]:
env_map.reward_nodes

NameError: name 'env_map' is not defined