In [30]:
import numpy as np
import torch
from tqdm import tqdm
from geometric_governance.model import DeepSetStrategyModel
from geometric_governance.util import get_value

# Pick the first CUDA device when one is available, otherwise the CPU.
# NOTE(review): nothing in the cells shown here moves the model or the
# sampled tensors onto `device` — confirm whether that is intended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [31]:
# Number of optimisation steps; each step trains on one freshly sampled graph.
NUM_EPOCHS = 10_000

# Handed to get_value() once per epoch; the drawn value is used as the number
# of edges sampled for that epoch's graph.
# Presumably an inclusive/exclusive (low, high) range — verify against get_value.
BATCH_SIZE = (1, 128)

In [32]:
# Build the strategy model with its default constructor arguments and put it
# into training mode before any optimisation step runs.
model = DeepSetStrategyModel()
model.train()

# Adam over every model parameter, using the library-default hyperparameters.
optim = torch.optim.Adam(params=model.parameters())

In [33]:
# Train the model to reproduce its own edge attributes on random graphs.
#
# Every epoch samples a new random graph (random node count, random number of
# edges, random endpoints, standard-normal edge features) and takes a single
# Adam step on the MSE between the model output and the input edge features.
SEED = 42
NUM_NODES_RANGE = (3, 50)  # passed to get_value() to draw the per-epoch node count

# The NumPy generator only drives get_value(); torch sampling needs its own
# seeded generator, otherwise the edge endpoints and features differ between
# runs even though a seed is set.
rng = np.random.default_rng(seed=SEED)
torch_rng = torch.Generator().manual_seed(SEED)

epochs = tqdm(range(NUM_EPOCHS))
for epoch in epochs:
    num_nodes = get_value(NUM_NODES_RANGE, rng).item()
    batch_size = get_value(BATCH_SIZE, rng).item()  # number of edges this epoch

    # Random endpoints in [0, num_nodes); self-loops and duplicate edges are
    # possible because both ends are drawn independently.
    edge_from = torch.randint(
        low=0, high=num_nodes, size=(batch_size,), generator=torch_rng
    )
    edge_to = torch.randint(
        low=0, high=num_nodes, size=(batch_size,), generator=torch_rng
    )
    edge_index = torch.stack([edge_from, edge_to])  # shape (2, batch_size)

    # One scalar feature per edge, drawn from N(0, 1).
    edge_attr = torch.normal(mean=torch.zeros(batch_size, 1), generator=torch_rng)

    # One zero entry per node.
    # NOTE(review): the meaning of this argument is defined by the model — confirm.
    candidate_idxs = torch.zeros(num_nodes)

    optim.zero_grad()
    out = model(edge_attr, edge_index, candidate_idxs)
    loss = torch.nn.functional.mse_loss(out, edge_attr)
    loss.backward()
    optim.step()

    # Report a plain float: passing the tensor itself prints its full repr
    # (including grad_fn) in the progress bar and keeps the graph referenced.
    epochs.set_postfix({
        "loss": loss.item(),
    })

100%|██████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:11<00:00, 905.92it/s, loss=tensor(5.9629e-06, grad_fn=<MseLossBackward0>)]


In [34]:
# Show the graph sampled in the final training iteration, so the model output
# in the next cell can be compared against these edge attributes.
(edge_attr, edge_index, candidate_idxs)

(tensor([[-0.2150],
         [-0.9205],
         [-0.8790],
         [-1.1981],
         [-0.9081],
         [-0.8153],
         [ 0.8131],
         [-2.2566],
         [ 1.8169],
         [-0.2819],
         [ 0.9922],
         [-1.4609],
         [ 1.1899],
         [-0.0370],
         [-0.3088],
         [-0.5088],
         [ 0.9589],
         [-0.9516],
         [-0.7196],
         [-0.5934],
         [ 2.5510],
         [ 0.5920],
         [-0.0805],
         [ 0.5214],
         [ 0.7690],
         [-0.9838],
         [-1.2154],
         [-0.4463],
         [ 0.9495],
         [ 0.3588],
         [-1.5629],
         [-0.1639],
         [-0.7884],
         [ 1.2625],
         [ 1.9150],
         [-0.3832],
         [-0.2996],
         [-0.8946],
         [-1.8050],
         [-0.2715],
         [ 0.7061],
         [-0.3571],
         [ 0.8235],
         [ 1.5512],
         [-1.0858],
         [ 0.9513],
         [-0.4984],
         [ 0.7268],
         [-0.0591],
         [ 1.2847],


In [35]:
# Re-run the model on the last training graph purely for inspection.
# Autograd is disabled so no graph is built and the displayed tensor carries
# no grad_fn.
# NOTE(review): the model is still in training mode (set in the setup cell);
# call model.eval() first if it contains dropout or normalisation layers.
with torch.no_grad():
    reconstructed = model(edge_attr, edge_index, candidate_idxs)
reconstructed

tensor([[-0.2146],
        [-0.9191],
        [-0.8782],
        [-1.1970],
        [-0.9092],
        [-0.8147],
        [ 0.8120],
        [-2.2541],
        [ 1.8128],
        [-0.2835],
        [ 0.9891],
        [-1.4591],
        [ 1.1879],
        [-0.0368],
        [-0.3077],
        [-0.5080],
        [ 0.9565],
        [-0.9494],
        [-0.7184],
        [-0.5932],
        [ 2.5476],
        [ 0.5895],
        [-0.0840],
        [ 0.5191],
        [ 0.7681],
        [-0.9822],
        [-1.2138],
        [-0.4468],
        [ 0.9486],
        [ 0.3619],
        [-1.5611],
        [-0.1652],
        [-0.7870],
        [ 1.2597],
        [ 1.9105],
        [-0.3842],
        [-0.2997],
        [-0.8937],
        [-1.8049],
        [-0.2713],
        [ 0.7053],
        [-0.3551],
        [ 0.8199],
        [ 1.5494],
        [-1.0841],
        [ 0.9501],
        [-0.4977],
        [ 0.7267],
        [-0.0595],
        [ 1.2819],
        [ 0.4624],
        [ 0.4901],
        [-1.