In [37]:
import torch
from torch_geometric.loader import DataLoader

In [31]:
from thesis_floor_halkes.batch.graph_data_batch import (
    GraphGenerator,
    RandomGraphDataset,
)

In [32]:
dataset = RandomGraphDataset(
    num_graphs=50,
    min_nodes=10,
    max_nodes=35,
    min_prob=0.1,
    max_prob=0.5,
    max_wait=30,
    min_length=100.0,
    max_length=1000.0,
    min_speed=30.0,
    max_speed=100.0,
)

In [34]:
dataset[0]

Data(x=[17, 3], edge_index=[2, 82], edge_attr=[82, 2], graph_id=0, start_node=12, end_node=5)

In [17]:
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

In [21]:
for data in dataloader:
    print(data)

DataBatch(x=[231, 3], edge_index=[2, 1230], edge_attr=[1230, 2], graph_id=[10], start_node=[10], end_node=[10], batch=[231], ptr=[11])
DataBatch(x=[256, 3], edge_index=[2, 1934], edge_attr=[1934, 2], graph_id=[10], start_node=[10], end_node=[10], batch=[256], ptr=[11])
DataBatch(x=[210, 3], edge_index=[2, 1236], edge_attr=[1236, 2], graph_id=[10], start_node=[10], end_node=[10], batch=[210], ptr=[11])
DataBatch(x=[203, 3], edge_index=[2, 1252], edge_attr=[1252, 2], graph_id=[10], start_node=[10], end_node=[10], batch=[203], ptr=[11])
DataBatch(x=[238, 3], edge_index=[2, 2128], edge_attr=[2128, 2], graph_id=[10], start_node=[10], end_node=[10], batch=[238], ptr=[11])


In [22]:
data.batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])

In [5]:
import torch

In [12]:
start_node = torch.randint(0, 10, (1,)).item()
end_node = torch.randint(0, 10, (1,)).item()
print(start_node, end_node)
while start_node == end_node:
    end_node = torch.randint(0, 10, (1,)).item()

print(start_node, end_node)

6 1
6 1


In [None]:
from thesis_floor_halkes.model_dynamic_attention import (
    DynamicGATEncoder,
    GATModelEncoderStatic,
    AttentionDecoderChat,
)
from thesis_floor_halkes.environment.dynamic_ambulance import DynamicEnvironment
from thesis_floor_halkes.batch.graph_data_batch import (
    GraphGenerator,
    RandomGraphPytorchDataset,
)
from thesis_floor_halkes.features.dynamic.getter import (
    DynamicFeatureGetter,
    RandomDynamicFeatureGetter,
)
from thesis_floor_halkes.penalties.calculator import PenaltyCalculator
from thesis_floor_halkes.utils.adj_matrix import build_adjecency_matrix

data = GraphGenerator(
    num_nodes=15,
    edge_prob=0.5,
    max_wait=10.0,
).generate()

dataset = RandomGraphPytorchDataset(
    num_graphs=2,
    min_nodes=5,
    max_nodes=5,
    min_prob=1,
    max_prob=1,
)

env = DynamicEnvironment(
    static_dataset=dataset,
    dynamic_feature_getter=RandomDynamicFeatureGetter(),
    penalty_calculator=PenaltyCalculator,
    max_steps=30,
)
hidden_size = 64

static_encoder = GATModelEncoderStatic(
    in_channels=1, hidden_size=hidden_size, edge_attr_dim=2
)
dynamic_encoder = DynamicGATEncoder(in_channels=2, hidden_size=hidden_size)
decoder = AttentionDecoderChat(embed_dim=hidden_size * 2, num_heads=4)


def embed_graph(data, type="static"):
    if type == "static":
        return static_encoder(data)
    elif type == "dynamic":
        return dynamic_encoder(data)
    else:
        raise ValueError("Invalid graph type. Use 'static' or 'dynamic'.")


def select_action(state):
    # if self.static_context is None:
    #     static_embedding = self._embed_graph(state.static_data, type="static")
    # else:
    #     static_embedding = self.static_context
    # dynamic_embedding = self._embed_graph(state.dynamic_data, type="dynamic")
    # final_embedding = torch.cat((static_embedding, dynamic_embedding), dim=1)
    # action, action_log_prob = self.decoder(final_embedding)
    static_embedding = embed_graph(state.static_data, type="static")
    dynamic_embedding = embed_graph(state.dynamic_data, type="dynamic")
    final_embedding = torch.cat(
        (static_embedding, dynamic_embedding), dim=1
    )  # overwegen om naar + ipv cat te doen
    action, action_log_prob, _ = AttentionDecoderChat(
        final_embedding, current_node=current_node, invalid_action_mask=visited
    )

    return action, action_log_prob

 action is None
visited_nodes = [4]


In [15]:
for graph in dataset:
    state = env.reset()
    for step in range(2):
        action, action_log_prob = select_action(state)
        print(action, action_log_prob)

 action is None
visited_nodes = [1]


TypeError: select_action() missing 1 required positional argument: 'state'

In [21]:
num_nodes = 15
valid_actions = torch.arange(num_nodes)  # Example: all nodes are valid actions
action_mask = torch.ones(num_nodes, dtype=torch.bool)
action_mask[valid_actions] = 1

print(action_mask)
print(action_mask.unsqueeze(0))
print(action_mask.shape)
print(action_mask.unsqueeze(0).shape)

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True])
tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True]])
torch.Size([15])
torch.Size([1, 15])


In [2]:
emb = [{"node": 3, "embedding": 5}, {"node": 4, "embedding": 10}]

In [3]:
emb["embedding"]

TypeError: list indices must be integers or slices, not str