In [None]:
from typing import Literal

import numpy as np
import torch
from torch_geometric.loader import DataLoader
from torch_scatter import scatter_max
from tqdm import tqdm

from geometric_governance.util import RangeOrValue, get_value
from geometric_governance.data import (
    generate_synthetic_election,
    get_scoring_function_winners,
)
from geometric_governance.model import MessagePassingElectionModel


# --- Experiment configuration -------------------------------------------------
# Inclusive-looking (low, high) bounds handed to `get_value` when sampling the
# size of each synthetic election; exact bound semantics live in that helper.
NUM_VOTERS_RANGE: tuple[int, int] = (3, 50)
NUM_CANDIDATES_RANGE: tuple[int, int] = (2, 10)

# Number of synthetic elections generated for training.
# Tip: temporarily lower this (e.g. to 100) for a quick smoke run.
DATASET_SIZE: int = 100_000

# Optimisation settings used by the training cell.
TRAIN_BATCH_SIZE: int = 128
TRAIN_NUM_EPOCHS: int = 20


def generate_welfare_bipartite_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int,
    rng: np.random.Generator,
    welfare_fn: Literal["utilitarian", "nash", "rawlsian"] = "utilitarian",
):
    graphs = []
    for _ in tqdm(range(dataset_size)):
        num_voters = get_value(num_voters_range, rng)
        num_candidates = get_value(num_candidates_range, rng)

        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )

        graph = election_data.to_bipartite_graph(top_k_candidates, vote_data="ranking")

        candidate_welfare = election_data.voter_utilities
        match welfare_fn:
            case "utilitarian":
                candidate_welfare = election_data.voter_utilities.sum(dim=0)
            case "nash":
                candidate_welfare = election_data.voter_utilities.prod(dim=0)
            case "rawlsian":
                candidate_welfare = election_data.voter_utilities.min(dim=0)
            case _:
                raise ValueError("Unknown welfare function.")

        winners = get_scoring_function_winners(candidate_welfare)

        graph.y = winners
        graphs.append(graph)
    dataloader = DataLoader(graphs, batch_size=dataloader_batch_size, shuffle=True)
    return dataloader


# Build the training set: utilitarian-welfare winners on elections whose size
# is sampled from the configured voter/candidate ranges. The generator is
# seeded so the dataset is identical across kernel restarts.
dataloader = generate_welfare_bipartite_dataset(
    # What to generate.
    welfare_fn="utilitarian",
    dataset_size=DATASET_SIZE,
    num_voters_range=NUM_VOTERS_RANGE,
    num_candidates_range=NUM_CANDIDATES_RANGE,
    top_k_candidates=5,
    # How to sample and batch it.
    rng=np.random.default_rng(seed=42),
    dataloader_batch_size=TRAIN_BATCH_SIZE,
)

In [None]:
# A fresh model and optimiser are created every time this cell runs, so
# re-running the cell restarts training from scratch.
model = MessagePassingElectionModel(edge_dim=1)
optim = torch.optim.Adam(model.parameters())

for epoch in range(TRAIN_NUM_EPOCHS):
    total_loss = 0.0
    total_graphs = 0
    for data in dataloader:
        optim.zero_grad()
        out = model(data)
        # Negative sum of output * target over the whole batch — presumably a
        # cross-entropy with `out` holding per-candidate log-probabilities
        # (TODO confirm against MessagePassingElectionModel).
        batch_loss_sum = -(out * data.y).sum()
        # Normalise by the number of graphs actually in this batch: the final
        # batch is smaller than TRAIN_BATCH_SIZE whenever the dataset size is
        # not a multiple of it, and a fixed divisor would down-weight it.
        loss = batch_loss_sum / data.num_graphs
        loss.backward()
        optim.step()
        total_loss += batch_loss_sum.item()
        total_graphs += data.num_graphs
    # Per-graph average over the epoch, independent of how batches were split.
    print(f"Epoch {epoch} Loss {total_loss / total_graphs}")

In [None]:
# Evaluation uses fixed-size elections; the candidate count (20) lies outside
# the training range, so this measures generalisation to larger elections.
EVAL_NUM_VOTERS = 50
EVAL_NUM_CANDIDATES = 20
EVAL_DATASET_SIZE = 1_000

# `welfare_fn` is left at its default ("utilitarian"), matching training.
eval_dataloader = generate_welfare_bipartite_dataset(
    dataset_size=EVAL_DATASET_SIZE,
    num_voters_range=EVAL_NUM_VOTERS,
    num_candidates_range=EVAL_NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=5,
    rng=np.random.default_rng(seed=16180),
)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data in eval_dataloader:
        out = model(data)
        # Rows whose second feature equals 1 — presumably the candidate-node
        # flag (TODO confirm against `to_bipartite_graph`).
        is_candidate = data.x[:, 1] == 1
        # Graph id of every selected node within the batch.
        graph_of_candidate = data.batch[is_candidate]
        # scatter_max returns (max values, argmax positions); only the argmax
        # per graph is needed. Both calls share the same index tensor, so the
        # positions are directly comparable.
        predicted = scatter_max(out, graph_of_candidate)[1]
        predicted_ground = scatter_max(data.y, graph_of_candidate)[1]
        total += predicted_ground.shape[0]
        correct += (predicted == predicted_ground).sum().item()

accuracy = correct / total
print(f"Accuracy: {accuracy} with {EVAL_NUM_VOTERS} voters")