In [3]:
import numpy as np
import torch

from torch_geometric.loader import DataLoader
from torch_scatter import scatter_max
from tqdm import tqdm

from geometric_governance.util import RangeOrValue, get_value
from geometric_governance.data import (
    generate_synthetic_election,
    get_scoring_function_winners,
)
from geometric_governance.model import MessagePassingElectionModel

NUM_VOTERS_RANGE = (3, 50)
NUM_CANDIDATES_RANGE = (2, 10)
DATASET_SIZE = 100_000
# DATASET_SIZE = 100
TRAIN_BATCH_SIZE = 128
TRAIN_NUM_EPOCHS = 20


def generate_plurality_bipartite_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int,
    rng: np.random.Generator,
):
    graphs = []
    for _ in tqdm(range(dataset_size)):
        num_voters = get_value(num_voters_range, rng)
        num_candidates = get_value(num_candidates_range, rng)

        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )

        graph = election_data.to_bipartite_graph(top_k_candidates, vote_data="ranking")

        #  Add plurality votes
        votes = graph.edge_index[1, (graph.edge_attr == 1).squeeze()]
        votes -= num_voters  # Index offset
        vote_count = torch.bincount(votes, minlength=num_candidates)

        winners = get_scoring_function_winners(vote_count)

        graph.y = winners
        graphs.append(graph)
    dataloader = DataLoader(graphs, batch_size=dataloader_batch_size, shuffle=True)
    return dataloader


dataloader = generate_plurality_bipartite_dataset(
    dataset_size=DATASET_SIZE,
    num_voters_range=NUM_VOTERS_RANGE,
    num_candidates_range=NUM_CANDIDATES_RANGE,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=5,
    rng=np.random.default_rng(seed=42),
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:42<00:00, 2331.11it/s]


In [4]:
model = MessagePassingElectionModel(edge_dim=1)
optim = torch.optim.Adam(model.parameters())

for epoch in range(TRAIN_NUM_EPOCHS):
    total_loss = 0
    for data in dataloader:
        optim.zero_grad()
        out = model(data)
        loss = -(out * data.y).sum() / TRAIN_BATCH_SIZE
        loss.backward()
        optim.step()
        total_loss += loss.item()
    print(f"Epoch {epoch} Loss {total_loss / len(dataloader)}")

Epoch 0 Loss 0.5428374072589228
Epoch 1 Loss 0.2259630191251826
Epoch 2 Loss 0.20822207604432502
Epoch 3 Loss 0.2014854596856305
Epoch 4 Loss 0.19901462398050235
Epoch 5 Loss 0.19678573989692857
Epoch 6 Loss 0.1954447927853793
Epoch 7 Loss 0.19457347472400768
Epoch 8 Loss 0.19362757311147802
Epoch 9 Loss 0.19327943839723496
Epoch 10 Loss 0.19288148615232972
Epoch 11 Loss 0.19269865960873606
Epoch 12 Loss 0.19210284243306847
Epoch 13 Loss 0.19202847542512752
Epoch 14 Loss 0.19165232108758234
Epoch 15 Loss 0.19155316286817994
Epoch 16 Loss 0.19143743665360125
Epoch 17 Loss 0.19152188348724408
Epoch 18 Loss 0.1915521417480067
Epoch 19 Loss 0.19131086083591137


In [5]:
# Save the model to be used in robust_voting.ipynb
torch.save(model.state_dict(), "election_model")

In [None]:
EVAL_NUM_VOTERS = 100
EVAL_NUM_CANDIDATES = 20
EVAL_DATASET_SIZE = 1_000

eval_dataloader = generate_plurality_bipartite_dataset(
    dataset_size=EVAL_DATASET_SIZE,
    num_voters_range=EVAL_NUM_VOTERS,
    num_candidates_range=EVAL_NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=5,
    rng=np.random.default_rng(seed=16180),
)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data in eval_dataloader:
        out = model(data)
        candidate_idxs = data.x[:, 1] == 1
        batch_idxs = data.batch[candidate_idxs]
        _, predicted = scatter_max(out, batch_idxs)
        _, predicted_ground = scatter_max(data.y, batch_idxs)
        total += predicted_ground.shape[0]
        correct += (predicted == predicted_ground).sum().item()

accuracy = correct / total
print(f"Accuracy: {accuracy} with {EVAL_NUM_VOTERS} voters")