In [None]:
import numpy as np
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from tqdm import tqdm

from geometric_governance.utils import RangeOrValue, get_value
from geometric_governance.data import generate_synthetic_election

# Ranges from which each election's size is drawn (forwarded to `get_value`).
NUM_VOTERS_RANGE = (3, 50)
NUM_CANDIDATES_RANGE = (2, 10)

# Number of synthetic elections to generate; raise to 100_000 for a full-scale run.
DATASET_SIZE = 1_000
TRAIN_BATCH_SIZE = 128


def _plurality_winner_distribution(
    first_choice_targets: torch.Tensor,
    num_voters: int,
    num_candidates: int,
) -> torch.Tensor:
    """Turn first-choice edge targets into a uniform distribution over plurality winners.

    Args:
        first_choice_targets: 1-D integer tensor of bipartite-graph node indices,
            one per first-choice edge (the edge's target node). Candidate nodes
            are assumed to be numbered directly after the ``num_voters`` voter
            nodes — TODO confirm against ``to_bipartite_graph``.
        num_voters: number of voter nodes, i.e. the index offset of candidate 0.
        num_candidates: number of candidates; length of the returned tensor.

    Returns:
        float32 tensor of shape ``(num_candidates,)`` holding ``1 / k`` for each
        of the ``k`` candidates tied for the most first-choice votes and ``0``
        elsewhere. A tie is therefore equivalent to a uniform dice roll.

    Raises:
        ValueError: if a shifted index falls outside ``[0, num_candidates)``,
            which would mean the voter/candidate index offset is wrong.
    """
    # Convert graph node indices to 0-based candidate indices. This is done
    # out of place so the caller's tensor is not modified.
    candidate_votes = first_choice_targets - num_voters
    if candidate_votes.numel() > 0 and (
        int(candidate_votes.min()) < 0 or int(candidate_votes.max()) >= num_candidates
    ):
        raise ValueError(
            "First-choice edges point outside the candidate node range "
            f"[{num_voters}, {num_voters + num_candidates}); "
            "check the voter/candidate node ordering of the bipartite graph."
        )
    vote_count = torch.bincount(candidate_votes, minlength=num_candidates)
    # Every candidate tied for the top count gets an equal share of the mass.
    is_winner = vote_count == vote_count.max()
    return is_winner.to(torch.float32) / is_winner.sum()


def generate_plurality_bipartite_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int,
    rng: np.random.Generator,
) -> DataLoader:
    """Generate synthetic elections labelled with their plurality winner(s).

    For each of ``dataset_size`` elections, the voter and candidate counts are
    drawn with ``get_value`` from the given ranges. An election is then sampled
    with ``generate_synthetic_election`` and converted into a bipartite graph via
    ``to_bipartite_graph(top_k_candidates, vote_data="ranking")``. Edges whose
    ``edge_attr`` equals 1 are counted as plurality votes. ``graph.y`` is set to
    a float32 per-candidate vector that splits probability mass uniformly over
    the candidates tied for the most such votes.

    Args:
        dataset_size: number of elections (graphs) to generate.
        num_voters_range: range or fixed value for the number of voters.
        num_candidates_range: range or fixed value for the number of candidates.
        dataloader_batch_size: batch size of the returned loader.
        top_k_candidates: forwarded to ``to_bipartite_graph``.
        rng: source of randomness. It drives election sampling and also seeds
            the loader's shuffle order; one extra draw is taken from it after
            generation.

    Returns:
        A shuffled ``DataLoader`` over the generated graphs.

    Raises:
        ValueError: if a graph's plurality-vote edges point outside the
            expected candidate node range.
    """
    graphs = []
    for _ in tqdm(range(dataset_size), desc="generating elections"):
        num_voters = get_value(num_voters_range, rng)
        num_candidates = get_value(num_candidates_range, rng)

        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )
        graph = election_data.to_bipartite_graph(top_k_candidates, vote_data="ranking")

        # Edges with rank attribute 1 are the plurality votes. Using
        # reshape(-1) instead of squeeze() keeps the mask 1-D even when there
        # is only a single edge.
        first_choice_mask = (graph.edge_attr == 1).reshape(-1)
        first_choice_targets = graph.edge_index[1, first_choice_mask]
        graph.y = _plurality_winner_distribution(
            first_choice_targets, num_voters, num_candidates
        )
        graphs.append(graph)

    # Without an explicit generator, the shuffle order comes from torch's
    # global RNG, so it would not be reproducible from `rng`. Derive one
    # from `rng` instead.
    shuffle_generator = torch.Generator().manual_seed(int(rng.integers(2**63 - 1)))
    return DataLoader(
        graphs,
        batch_size=dataloader_batch_size,
        shuffle=True,
        generator=shuffle_generator,
    )


# Build-specific settings.
# TOP_K_CANDIDATES is forwarded to `to_bipartite_graph`; it presumably limits
# how many ranked candidates each voter connects to (TODO confirm).
# DATASET_SEED fixes the generator used to sample the elections, so the
# dataset is reproducible across re-runs.
TOP_K_CANDIDATES = 5
DATASET_SEED = 42

dataloader = generate_plurality_bipartite_dataset(
    dataset_size=DATASET_SIZE,
    num_voters_range=NUM_VOTERS_RANGE,
    num_candidates_range=NUM_CANDIDATES_RANGE,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=TOP_K_CANDIDATES,
    rng=np.random.default_rng(seed=DATASET_SEED),
)

In [None]:
class MessagePassingElectionModel:
    """Placeholder for a message-passing election model; not implemented yet.

    NOTE(review): presumably intended to build on torch_geometric's
    ``MessagePassing`` (imported at the top of the notebook) — TODO confirm.
    """

In [None]:
# Spot-check of the plurality-vote extraction used for labelling. It shows the
# target node indices of the edges whose rank attribute equals 1 in one sampled
# election (50 voters, 50 candidates, top_k=5). The seeded generator keeps
# this output stable across Restart & Run All.
graph = generate_synthetic_election(
    num_voters=50, num_candidates=50, rng=np.random.default_rng(seed=0)
).to_bipartite_graph(5, vote_data="ranking")
graph.edge_index[1, (graph.edge_attr == 1).reshape(-1)]