In [None]:
import numpy as np
import torch
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU, Module
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_log_softmax, scatter_max
from tqdm import tqdm

from geometric_governance.utils import RangeOrValue, get_value
from geometric_governance.data import generate_synthetic_election

# Training-set generation: each election draws its voter and candidate counts
# from these ranges (resolved per election by `get_value`).
NUM_VOTERS_RANGE = (3, 50)
NUM_CANDIDATES_RANGE = (2, 10)
DATASET_SIZE = 100_000  # lower (e.g. to 100) for a quick smoke test
TRAIN_BATCH_SIZE = 128
TRAIN_NUM_EPOCHS = 20


def generate_plurality_bipartite_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int,
    rng: np.random.Generator,
):
    """Generate synthetic elections labelled with their plurality winner(s).

    Each election is sampled with ``generate_synthetic_election`` and converted
    to a bipartite graph via ``to_bipartite_graph``. The target ``graph.y`` has
    one entry per candidate: ``1 / k`` for each of the ``k`` candidates tied for
    the most first-choice votes and ``0`` for everyone else, i.e. a uniform
    distribution over the plurality winners.

    Args:
        dataset_size: Number of elections (graphs) to generate.
        num_voters_range: Voter count, fixed or a range resolved by ``get_value``.
        num_candidates_range: Candidate count, fixed or a range resolved by ``get_value``.
        dataloader_batch_size: Batch size of the returned loader.
        top_k_candidates: Forwarded to ``to_bipartite_graph``.
        rng: Passed to ``get_value`` and ``generate_synthetic_election``.

    Returns:
        A shuffled ``DataLoader`` over the generated graphs.
    """
    graphs = []
    for _ in tqdm(range(dataset_size)):
        num_voters = get_value(num_voters_range, rng)
        num_candidates = get_value(num_candidates_range, rng)

        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )

        graph = election_data.to_bipartite_graph(top_k_candidates, vote_data="ranking")

        # Plurality votes: edges whose attribute is 1 are taken as first choices
        # (assumes edge_attr holds the rank and edge_index[1] the candidate node —
        # TODO confirm against to_bipartite_graph). reshape(-1) rather than
        # squeeze() keeps the mask 1-D even for a graph with a single edge.
        first_choice = graph.edge_attr.reshape(-1) == 1
        # Candidate node ids are offset by num_voters; shift to 0-based candidate ids.
        votes = graph.edge_index[1, first_choice] - num_voters
        vote_count = torch.bincount(votes, minlength=num_candidates)
        # If multiple winners, the target is uniform over them (equivalent to a dice roll).
        winners = (vote_count == vote_count.max()).to(torch.float32)
        graph.y = winners / winners.sum()
        graphs.append(graph)
    return DataLoader(graphs, batch_size=dataloader_batch_size, shuffle=True)


# Training loader; the numpy generator is seeded so the sampled elections are reproducible.
train_loader_kwargs = dict(
    dataset_size=DATASET_SIZE,
    num_voters_range=NUM_VOTERS_RANGE,
    num_candidates_range=NUM_CANDIDATES_RANGE,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=5,
    rng=np.random.default_rng(seed=42),
)
dataloader = generate_plurality_bipartite_dataset(**train_loader_kwargs)

In [None]:
class MessagePassingElectionLayer(MessagePassing):
    """One round of sum-aggregated message passing with an MLP message function.

    For every edge the message is ``mlp_msg([x_i, x_j, edge_attr])``, where
    ``x_i`` is the receiving node and ``x_j`` the sending node. Each node
    receives the sum of its incoming messages (``aggr="add"``).
    """

    def __init__(self, edge_dim: int = 1, emb_dim: int = 32):
        super().__init__(aggr="add")
        # Two (Linear -> BatchNorm1d -> ReLU) stages. The first consumes the
        # concatenated receiver, sender and edge features.
        stages = []
        width = 2 * emb_dim + edge_dim
        for _ in range(2):
            stages.extend((Linear(width, emb_dim), BatchNorm1d(emb_dim), ReLU()))
            width = emb_dim
        self.mlp_msg = Sequential(*stages)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        features = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.mlp_msg(features)


class MessagePassingElectionModel(Module):
    """Predict a per-election log-distribution over candidates from a bipartite graph.

    Node features are linearly embedded and refined by ``num_layers`` residual
    ``MessagePassingElectionLayer`` blocks. The result is read out to one logit
    per candidate node, and the logits are log-softmax normalised within each graph.
    """

    def __init__(
        self,
        edge_dim: int = 1,
        emb_dim: int = 32,
        num_layers: int = 4,
    ):
        super().__init__()
        in_dim = 2  # width of data.x; column 1 flags candidate nodes (see forward)
        # Initial embedding
        self.lin_in = Linear(in_dim, emb_dim)
        # Convolution layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(MessagePassingElectionLayer(edge_dim, emb_dim))
        # Readout
        self.lin_out = Linear(emb_dim, 1)

    def forward(self, data: Data):
        """Return log-probabilities for the candidate nodes of ``data``.

        The output is ordered like the rows of ``data.x`` whose column 1 equals 1.
        Values are normalised separately for each graph in the batch. A single
        un-batched ``Data`` (no ``batch`` vector) is treated as one graph.
        """
        x = self.lin_in(data.x.to(torch.float32))
        # Cast like data.x so integer (or float64) edge attributes match the float32 MLP weights.
        edge_attr = data.edge_attr.to(torch.float32)
        for conv in self.convs:
            x = x + conv(x, data.edge_index, edge_attr)
        candidate_idxs = data.x[:, 1] == 1
        logits = self.lin_out(x[candidate_idxs]).squeeze(dim=-1)
        batch = data.batch
        if batch is None:
            # Un-batched input: every node belongs to graph 0.
            batch = torch.zeros(data.x.size(0), dtype=torch.long, device=data.x.device)
        return scatter_log_softmax(logits, batch[candidate_idxs])


# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
TRAIN_SEED = 0
torch.manual_seed(TRAIN_SEED)

model = MessagePassingElectionModel(edge_dim=1)
optim = torch.optim.Adam(model.parameters())

for epoch in range(TRAIN_NUM_EPOCHS):
    total_loss = 0.0
    for data in dataloader:
        optim.zero_grad()
        out = model(data)
        # Cross-entropy against the (possibly tied) winner distribution, averaged
        # over the graphs actually in this batch. The last batch can be smaller
        # than TRAIN_BATCH_SIZE, so the constant would mis-scale its loss.
        loss = -(out * data.y).sum() / data.num_graphs
        loss.backward()
        optim.step()
        total_loss += loss.item()
    print(f"Epoch {epoch} Loss {total_loss / len(dataloader)}")

In [None]:
EVAL_NUM_VOTERS = 100
EVAL_NUM_CANDIDATES = 20
EVAL_DATASET_SIZE = 1_000

eval_dataloader = generate_plurality_bipartite_dataset(
    dataset_size=EVAL_DATASET_SIZE,
    num_voters_range=EVAL_NUM_VOTERS,
    num_candidates_range=EVAL_NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    top_k_candidates=5,
    rng=np.random.default_rng(seed=16180),
)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data in eval_dataloader:
        out = model(data)
        candidate_idxs = data.x[:, 1] == 1
        batch_idxs = data.batch[candidate_idxs]
        # Index (into the flattened candidate rows) of the top-scored candidate per election.
        _, predicted = scatter_max(out, batch_idxs)
        # Largest target probability per election; every tied winner shares this value.
        max_target, _ = scatter_max(data.y, batch_idxs)
        total += max_target.shape[0]
        # Correct if the predicted candidate is any of the (possibly tied) winners,
        # rather than having to match one arbitrarily chosen argmax of data.y.
        correct += (data.y[predicted] == max_target).sum().item()

accuracy = correct / total
print(f"Accuracy: {accuracy} with {EVAL_NUM_VOTERS} voters")