# Learning top-two (run-off) voting with DeepSets

## Dataset construction

In [10]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MLP, DeepSetsAggregation
from tqdm import tqdm

from geometric_governance.utils import RangeOrValue, get_value
from geometric_governance.data import generate_synthetic_election

# Experiment configuration for the training run.

# Handed to get_value() once per generated election to pick the electorate size.
# Presumably (min, max) bounds; whether the upper bound is inclusive depends on
# get_value — TODO confirm.
NUM_VOTERS_RANGE = (3, 50)
# Number of candidates in every generated election; also used as the model's
# input and output width.
NUM_CANDIDATES = 3
# Number of synthetic elections generated for training.
DATASET_SIZE = 100_000
# Elections per mini-batch; the evaluation DataLoader reuses this value.
TRAIN_BATCH_SIZE = 128
# Number of full passes over the training dataloader.
TRAIN_NUM_EPOCHS = 20


class TopTwoDataset(Dataset):
    """Map-style dataset pairing each election's ballots with its target.

    Item ``i`` is the tuple ``(voter_preferences_list[i], winner_list[i])``.
    The two lists are stored as-is (no copy) on ``self.X`` and ``self.y``.
    """

    def __init__(
        self,
        voter_preferences_list: list,
        winner_list: list,
    ):
        super().__init__()
        # Both lists are indexed in lockstep, so they must have equal length.
        assert len(voter_preferences_list) == len(winner_list)

        self.X = voter_preferences_list
        self.y = winner_list

    def __len__(self):
        """Return the number of elections held by the dataset."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(ballots, target)`` pair stored at ``index``."""
        ballots = self.X[index]
        target = self.y[index]
        return (ballots, target)


def collate_fn(data):
    """Merge a list of ``(ballots, target)`` samples into one flat batch.

    Elections may have different numbers of voters, so the ballot tensors are
    concatenated along dim 0 rather than stacked, and a parallel ``index``
    vector records which sample each row came from.

    Args:
        data: Sequence of ``(ballots, target)`` pairs, both tensors.

    Returns:
        A tuple ``(X, index, y)`` where ``X`` is every sample's ballots
        concatenated along dim 0 and cast to float32, ``index[r]`` is the
        position within ``data`` of the sample that row ``r`` of ``X`` belongs
        to, and ``y`` is the targets stacked along a new leading dimension.
    """
    ballots = [sample[0] for sample in data]
    targets = [sample[1] for sample in data]

    stacked_ballots = torch.cat(ballots, dim=0)

    # One constant-valued segment per sample, as long as that sample's row count.
    segments = []
    for position, election in enumerate(ballots):
        segments.append(torch.full(size=(election.shape[0],), fill_value=position))
    membership = torch.cat(segments, dim=0)

    stacked_targets = torch.stack(targets)

    return stacked_ballots.to(torch.float32), membership, stacked_targets


def _top_two_winner_distribution(
    voter_preferences: torch.Tensor, num_candidates: int
) -> torch.Tensor:
    """Compute the top-two (run-off) outcome of one election.

    Args:
        voter_preferences: Integer tensor with one row per voter; each row is
            read as candidate indices ordered from most to least preferred
            (column 0 is the first choice).
        num_candidates: Total number of candidates (length of the result).

    Returns:
        A float32 vector of length ``num_candidates`` that is uniform over the
        run-off winner(s) and zero elsewhere.
    """
    rankings = voter_preferences.long()

    # First round: plurality count over first choices.
    vote_count = torch.bincount(rankings[:, 0], minlength=num_candidates)
    max_votes = vote_count.max()
    advancing = vote_count == max_votes

    # Who advances to the second round:
    #   Case 1: several candidates tie for the maximum -> exactly those advance.
    #   Case 2: a unique leader and several tied runners-up -> all of them advance.
    #   Case 3: a unique leader and a unique runner-up -> those two advance.
    if advancing.sum() == 1:
        # Cases 2 and 3: add every candidate holding the second-highest count.
        runner_up_votes = vote_count[vote_count < max_votes].max()
        advancing = advancing | (vote_count == runner_up_votes)

    # Second round: each voter backs the highest-ranked candidate on their
    # ballot that advanced. in_runoff[v, r] is True when voter v's r-th choice
    # advanced; the running count equals 1 exactly at the first such position.
    # Voters whose ballot lists no advancing candidate contribute no vote.
    in_runoff = advancing[rankings]
    first_in_runoff = in_runoff & (in_runoff.long().cumsum(dim=1) == 1)
    second_round_votes = torch.bincount(
        rankings[first_in_runoff], minlength=num_candidates
    )

    # Ties in the second round share the probability mass equally.
    winners = (second_round_votes == second_round_votes.max()).to(torch.float32)
    return winners / winners.sum()


def generate_top_two_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates: int,
    dataloader_batch_size: int,
    rng: np.random.Generator,
) -> DataLoader:
    """Generate synthetic elections labelled with their top-two winner(s).

    For every election the electorate size is drawn with ``get_value`` and the
    ballots with ``generate_synthetic_election``; the label is the run-off
    outcome computed from those ballots.

    Args:
        dataset_size: Number of elections to generate.
        num_voters_range: Fixed voter count or range, resolved per election by
            ``get_value``.
        num_candidates: Number of candidates in every election (at least 2).
        dataloader_batch_size: Number of elections per batch.
        rng: Random generator driving all sampling.

    Returns:
        A ``DataLoader`` over a ``TopTwoDataset`` that batches with
        ``collate_fn``. Samples are served in generation order (no shuffling).

    Raises:
        AssertionError: If ``num_candidates`` is smaller than 2.
    """
    assert num_candidates >= 2, "Require at least two candidates for top-two voting."

    voter_preferences_list = []
    winner_list = []
    for _ in tqdm(range(dataset_size)):
        num_voters = get_value(num_voters_range, rng)
        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )
        voter_preferences = election_data.voter_preferences

        # X: the raw ballots; y: the run-off winner distribution.
        voter_preferences_list.append(voter_preferences)
        winner_list.append(
            _top_two_winner_distribution(voter_preferences, num_candidates)
        )

    # The dataset class defined in this file is TopTwoDataset.
    dataset = TopTwoDataset(
        voter_preferences_list=voter_preferences_list, winner_list=winner_list
    )
    return DataLoader(
        dataset, batch_size=dataloader_batch_size, collate_fn=collate_fn
    )


# Build the training dataloader from the constants above; the fixed seed makes
# the generator's random stream repeatable across runs.
dataloader = generate_top_two_dataset(
    dataset_size=DATASET_SIZE,
    num_voters_range=NUM_VOTERS_RANGE,
    num_candidates=NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    rng=np.random.default_rng(seed=42),
)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:20<00:00, 4860.36it/s]


## Model Definition and Training

In [11]:
class DeepSetsElectionModel(nn.Module):
    """Score candidates from a variable-size set of voter rows.

    Two MLPs are wrapped in a ``DeepSetsAggregation``: ``local_nn`` maps each
    input row from ``num_candidates`` features to ``embedding_size``, and
    ``global_nn`` maps an ``embedding_size`` vector to ``num_candidates``
    outputs. How rows are pooled between the two is decided by
    ``DeepSetsAggregation`` (see the PyG documentation).

    Args:
        num_candidates: Width of each input row and of the output.
        embedding_size: Hidden and embedding width of both MLPs.
    """

    def __init__(self, num_candidates: int, embedding_size: int = 128):
        super().__init__()

        per_row_channels = [num_candidates, embedding_size, embedding_size]
        per_set_channels = [embedding_size, embedding_size, num_candidates]

        self.local_nn = MLP(per_row_channels)
        self.global_nn = MLP(per_set_channels)

        # The aggregation module shares the two MLPs registered above.
        self.deepset = DeepSetsAggregation(
            local_nn=self.local_nn,
            global_nn=self.global_nn,
        )

    def forward(self, x, index):
        """Return the aggregation output for rows ``x`` grouped by ``index``."""
        return self.deepset(x, index=index)


# Instantiate the network and switch it to training mode.
model = DeepSetsElectionModel(num_candidates=NUM_CANDIDATES)
model.train()
# Adam with its default hyperparameters.
optim = torch.optim.Adam(model.parameters())

for epoch in range(TRAIN_NUM_EPOCHS):
    # Running sum of per-batch losses; averaged over batches when printed.
    total_loss = 0
    for X, index, y in dataloader:
        optim.zero_grad()
        # Model output for the batch, passed to cross_entropy as its input.
        p_y = model(X, index)
        # y comes from the dataloader as float targets (the dataset stores a
        # winner distribution per election), so cross_entropy is used with
        # probability targets rather than class indices.
        loss = torch.nn.functional.cross_entropy(p_y, y)
        loss.backward()
        optim.step()
        total_loss += loss.item()
    print(f"Epoch {epoch} Loss {total_loss / len(dataloader)}")

Epoch 0 Loss 0.36437921894861913
Epoch 1 Loss 0.33083327019306097
Epoch 2 Loss 0.24674158686262262
Epoch 3 Loss 0.20644089891134626
Epoch 4 Loss 0.19681299434941443
Epoch 5 Loss 0.19103577671110478
Epoch 6 Loss 0.18711953982710838
Epoch 7 Loss 0.18374267572065447
Epoch 8 Loss 0.18112010474476364
Epoch 9 Loss 0.17890771534627356
Epoch 10 Loss 0.1770576547707438
Epoch 11 Loss 0.17557669850185398
Epoch 12 Loss 0.17424641030333232
Epoch 13 Loss 0.17309364504026026
Epoch 14 Loss 0.17220086638656112
Epoch 15 Loss 0.17136707265511195
Epoch 16 Loss 0.1705850150979236
Epoch 17 Loss 0.16991801897202002
Epoch 18 Loss 0.16907167310833626
Epoch 19 Loss 0.16838723882232481


## Evaluation

In [14]:
# Evaluation uses a fixed electorate size of 100, above the (3, 50) range
# configured for training.
EVAL_NUM_VOTERS = 100
EVAL_DATASET_SIZE = 10_000

# Construct the evaluation dataset with a seed different from the training one.

eval_dataloader = generate_top_two_dataset(
    dataset_size=EVAL_DATASET_SIZE,
    num_voters_range=EVAL_NUM_VOTERS,
    num_candidates=NUM_CANDIDATES,
    dataloader_batch_size=TRAIN_BATCH_SIZE,
    rng=np.random.default_rng(seed=16180),
)

model.eval()
# Running counts of correctly predicted elections and of elections seen.
correct = 0
total = 0

with torch.no_grad():
    for X, index, y in eval_dataloader:
        p_y = model(X, index)
        # Highest-scoring candidate per election.
        _, predicted = torch.max(p_y, dim=1)
        # Reference winner per election, taken as the arg-max of the target row.
        # NOTE(review): when the target splits mass between tied winners, only
        # the single index returned here counts as correct — confirm that this
        # is the intended scoring for ties.
        _, predicted_ground = torch.max(y, dim=1)
        total += y.shape[0]
        correct += (predicted == predicted_ground).sum().item()

accuracy = correct / total
print(f"Accuracy: {accuracy} with {EVAL_NUM_VOTERS} voters")


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:07<00:00, 1413.96it/s]


Accuracy: 0.8977 with 100 voters
