# Learning plurality voting with DeepSets

## Dataset construction

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MLP, DeepSetsAggregation
from tqdm import tqdm

from geometric_governance.data import generate_synthetic_election

# (low, high) bounds for the number of voters in each training election.
# generate_plurality_dataset passes them to rng.integers(low=..., high=...),
# whose upper bound is exclusive by default, so `high` itself is never drawn.
NUM_VOTERS_RANGE = (3, 50)
# Number of candidates in every generated election; also sets the model's
# input and output width.
NUM_CANDIDATES = 3
# Number of synthetic elections in the training set.
DATASET_SIZE = 100_000

# Elections per mini-batch (reused for the evaluation loader as well).
TRAIN_BATCH_SIZE = 128
# Number of full passes over the training loader.
TRAIN_NUM_EPOCHS = 20

# Construct dataset


class PluralityDataset(Dataset):
    """Map-style dataset pairing each election's input with its winner target.

    Item ``i`` is the tuple ``(voter_preferences_list[i], winner_list[i])``.
    Entries are stored and returned exactly as given; no copy or conversion
    is performed here.
    """

    def __init__(
        self,
        voter_preferences_list: list,
        winner_list: list,
    ):
        """Store the paired samples.

        Args:
            voter_preferences_list: One entry per election holding that
                election's voter preferences.
            winner_list: One entry per election holding the winner target.

        Raises:
            ValueError: If the two lists do not have the same length.
        """
        super().__init__()
        # Validate with an explicit raise: `assert` statements are removed
        # when Python runs with -O, which would let mismatched data through.
        if len(voter_preferences_list) != len(winner_list):
            raise ValueError(
                "voter_preferences_list and winner_list must have the same "
                f"length, got {len(voter_preferences_list)} and "
                f"{len(winner_list)}"
            )
        self.X = voter_preferences_list
        self.y = winner_list

    def __len__(self) -> int:
        """Return the number of elections in the dataset."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        return (self.X[index], self.y[index])


def collate_fn(data):
    """Merge variable-sized elections into one flat batch.

    Args:
        data: Sequence of ``(voter_preferences, winners)`` samples, where
            ``voter_preferences`` is a tensor with one row per voter.

    Returns:
        A tuple ``(X, index, y)``:
            X: all voter rows concatenated along dim 0, cast to float32.
            index: for every row of ``X``, the position within ``data`` of
                the election that row came from.
            y: the ``winners`` tensors stacked along a new leading dim.
    """
    ballots = []
    targets = []
    group_ids = []
    for position, sample in enumerate(data):
        preferences = sample[0]
        ballots.append(preferences)
        targets.append(sample[1])
        # One id per voter row, so rows can be grouped back per election.
        group_ids.append(
            torch.full(size=(preferences.shape[0],), fill_value=position)
        )

    X = torch.concat(ballots, dim=0).to(torch.float32)
    index = torch.concat(group_ids, dim=0)
    y = torch.stack(targets)
    return X, index, y


def generate_plurality_dataset(
    dataset_size: int,
    num_voters_range: tuple[int, int] | int,
    num_candidates: int,
    dataloader_batch_size: int,
    rng: np.random.Generator,
):
    """Generate synthetic elections labelled with their plurality winner(s).

    Args:
        dataset_size: Number of elections to generate.
        num_voters_range: Either a fixed voter count, or a ``(low, high)``
            tuple from which a count is drawn per election with
            ``rng.integers(low=low, high=high)``. NumPy excludes ``high`` by
            default, so the largest count drawn is ``high - 1``.
        num_candidates: Number of candidates in each election.
        dataloader_batch_size: Elections per batch in the returned loader.
        rng: Generator used to draw voter counts and forwarded to
            ``generate_synthetic_election``.

    Returns:
        ``(dataset, dataloader)``: a ``PluralityDataset`` holding every
        generated sample and a ``DataLoader`` over it that batches with
        ``collate_fn``.
    """
    draw_voter_count = isinstance(num_voters_range, tuple)
    preferences_per_election = []
    targets_per_election = []

    for _ in tqdm(range(dataset_size)):
        num_voters = (
            rng.integers(low=num_voters_range[0], high=num_voters_range[1])
            if draw_voter_count
            else num_voters_range
        )
        election = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )

        # Model input: the preferences as produced by the generator.
        ballots = election.voter_preferences

        # Tally column 0 of every voter row (presumably each voter's top
        # choice; confirm against generate_synthetic_election).
        first_choice_counts = torch.bincount(
            ballots[:, 0], minlength=num_candidates
        )

        # Target: equal probability mass on every candidate tied for the
        # highest count, i.e. ties are treated as a fair dice roll.
        is_top = first_choice_counts == first_choice_counts.max()
        target = torch.where(is_top, 1, 0)
        target = (target / target.sum()).to(torch.float32)

        preferences_per_election.append(ballots)
        targets_per_election.append(target)

    dataset = PluralityDataset(
        voter_preferences_list=preferences_per_election,
        winner_list=targets_per_election,
    )
    dataloader = DataLoader(
        dataset, batch_size=dataloader_batch_size, collate_fn=collate_fn
    )
    return dataset, dataloader


# Build the training loader from the constants above. The dataset object
# itself is discarded; only the batched loader is used for training.
# The generator is created with a fixed seed.
_, dataloader = generate_plurality_dataset(
    dataset_size=DATASET_SIZE,
    num_voters_range=NUM_VOTERS_RANGE,
    num_candidates=NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    rng=np.random.default_rng(seed=42),
)


## Model definition and training

In [None]:
class DeepSetsScoreElectionModel(nn.Module):
    """Produce per-candidate scores for each election in a flat batch.

    The model wraps a ``DeepSetsAggregation`` built from two MLPs:
    ``local_nn`` maps ``num_candidates`` features to ``embedding_size``, and
    ``global_nn`` maps ``embedding_size`` back to ``num_candidates``.
    """

    def __init__(self, num_candidates: int, embedding_size: int = 128):
        """Build both MLPs and the aggregation module that uses them.

        Args:
            num_candidates: Width of each input row and of the output scores.
            embedding_size: Hidden and embedding width of both MLPs.
        """
        super().__init__()

        local_dims = [num_candidates, embedding_size, embedding_size]
        global_dims = [embedding_size, embedding_size, num_candidates]
        self.local_nn = MLP(local_dims)
        self.global_nn = MLP(global_dims)

        # The aggregation reuses the two MLP instances registered above.
        self.deepset = DeepSetsAggregation(
            local_nn=self.local_nn,
            global_nn=self.global_nn,
        )

    def forward(self, x, index):
        """Return the aggregation output for rows ``x`` grouped by ``index``."""
        return self.deepset(x, index=index)


# Train the DeepSets model on the training loader with Adam (default
# hyperparameters), reporting the mean batch loss after every epoch.
model = DeepSetsScoreElectionModel(num_candidates=NUM_CANDIDATES)
model.train()
optim = torch.optim.Adam(model.parameters())

for epoch in range(TRAIN_NUM_EPOCHS):
    batch_losses = []
    for X, index, y in dataloader:
        optim.zero_grad()
        p_y = model(X, index)
        # y holds per-candidate target probabilities, not class indices.
        loss = nn.functional.cross_entropy(p_y, y)
        loss.backward()
        optim.step()
        batch_losses.append(loss.item())
    total_loss = sum(batch_losses)
    print(f"Epoch {epoch} Loss {total_loss / len(dataloader)}")

## Evaluation

In [None]:
# Fixed number of voters in every evaluation election. This is larger than
# any voter count drawn for training (NUM_VOTERS_RANGE stays below 50).
EVAL_NUM_VOTERS = 100
# Number of synthetic elections in the evaluation set.
EVAL_DATASET_SIZE = 10_000

# Construct the evaluation dataset with a generator seeded differently from
# the training one; only the loader is kept.

_, eval_dataloader = generate_plurality_dataset(
    dataset_size=EVAL_DATASET_SIZE,
    num_voters_range=EVAL_NUM_VOTERS,
    num_candidates=NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    rng=np.random.default_rng(seed=16180),
)

# Measure how often the model's top-scoring candidate is a true plurality
# winner on the evaluation elections.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for X, index, y in eval_dataloader:
        p_y = model(X, index)
        predicted = p_y.argmax(dim=1)
        # The target spreads equal mass over every tied winner, so a
        # prediction is correct when it lands on any candidate with non-zero
        # target mass. Comparing against argmax(y) would keep only one of the
        # tied winners and count the other co-winners as errors.
        hit = y.gather(1, predicted.unsqueeze(1)).squeeze(1) > 0
        total += y.shape[0]
        correct += hit.sum().item()

accuracy = correct / total
print(f"Accuracy: {accuracy} with {EVAL_NUM_VOTERS} voters")
