In [1]:
import numpy as np
import torch

from torch_geometric.loader import DataLoader
from torch_scatter import scatter_max
from tqdm import tqdm

from geometric_governance.util import RangeOrValue, get_value
from geometric_governance.data import (
    generate_synthetic_election,
    get_scoring_function_winners,
)
from geometric_governance.model import MessagePassingElectionModel

NUM_VOTERS_RANGE = (3, 50)
NUM_CANDIDATES_RANGE = (2, 10)
DATASET_SIZE = 100_000
# DATASET_SIZE = 100
TRAIN_BATCH_SIZE = 128
TRAIN_NUM_EPOCHS = 20


def generate_plurality_bipartite_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int,
    rng: np.random.Generator,
):
    graphs = []
    for _ in tqdm(range(dataset_size)):
        num_voters = get_value(num_voters_range, rng)
        num_candidates = get_value(num_candidates_range, rng)

        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )

        graph = election_data.to_bipartite_graph(top_k_candidates, vote_data="ranking")

        #  Add plurality votes
        votes = graph.edge_index[1, (graph.edge_attr == 1).squeeze()]
        votes -= num_voters  # Index offset
        vote_count = torch.bincount(votes, minlength=num_candidates)

        winners = get_scoring_function_winners(vote_count)

        graph.y = winners
        graphs.append(graph)
    dataloader = DataLoader(graphs, batch_size=dataloader_batch_size, shuffle=True)
    return dataloader


dataloader = generate_plurality_bipartite_dataset(
    dataset_size=DATASET_SIZE,
    num_voters_range=NUM_VOTERS_RANGE,
    num_candidates_range=NUM_CANDIDATES_RANGE,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=5,
    rng=np.random.default_rng(seed=42),
)

100%|██████████| 100000/100000 [00:59<00:00, 1692.39it/s]


In [2]:
model = MessagePassingElectionModel(edge_dim=1)
optim = torch.optim.Adam(model.parameters())

for epoch in range(TRAIN_NUM_EPOCHS):
    total_loss = 0
    for data in dataloader:
        optim.zero_grad()
        out = model(data)
        loss = -(out * data.y).sum() / TRAIN_BATCH_SIZE
        loss.backward()
        optim.step()
        total_loss += loss.item()
    print(f"Epoch {epoch} Loss {total_loss / len(dataloader)}")

Epoch 0 Loss 0.3541716648184735
Epoch 1 Loss 0.24235563782398659
Epoch 2 Loss 0.21295064097017888
Epoch 3 Loss 0.19801104890749507
Epoch 4 Loss 0.1952277642515157
Epoch 5 Loss 0.19525823213846025
Epoch 6 Loss 0.19314134979854003
Epoch 7 Loss 0.1927161123937048
Epoch 8 Loss 0.19230390915556636
Epoch 9 Loss 0.19221597720804573
Epoch 10 Loss 0.191641922487551
Epoch 11 Loss 0.19172178432726494
Epoch 12 Loss 0.19165738712033958
Epoch 13 Loss 0.19126175319694955
Epoch 14 Loss 0.19147559205341674
Epoch 15 Loss 0.19116606478534087
Epoch 16 Loss 0.1911609638863436
Epoch 17 Loss 0.19116424377102528
Epoch 18 Loss 0.19112650949574644
Epoch 19 Loss 0.1911894856663921


In [3]:
EVAL_NUM_VOTERS = 100
EVAL_NUM_CANDIDATES = 20
EVAL_DATASET_SIZE = 1_000

eval_dataloader = generate_plurality_bipartite_dataset(
    dataset_size=EVAL_DATASET_SIZE,
    num_voters_range=EVAL_NUM_VOTERS,
    num_candidates_range=EVAL_NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=5,
    rng=np.random.default_rng(seed=16180),
)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data in eval_dataloader:
        out = model(data)
        candidate_idxs = data.x[:, 1] == 1
        batch_idxs = data.batch[candidate_idxs]
        _, predicted = scatter_max(out, batch_idxs)
        _, predicted_ground = scatter_max(data.y, batch_idxs)
        total += predicted_ground.shape[0]
        correct += (predicted == predicted_ground).sum().item()

accuracy = correct / total
print(f"Accuracy: {accuracy} with {EVAL_NUM_VOTERS} voters")

100%|██████████| 1000/1000 [00:01<00:00, 524.64it/s]


Accuracy: 0.846 with 100 voters
