In [4]:
import os
from typing import Literal
from functools import partial


import numpy as np
import torch
from torch.utils.data import DataLoader as TorchDataloader
from torch_geometric.loader import DataLoader as GraphDataloader
from torch_scatter import scatter_max
from tqdm import tqdm


from geometric_governance.util import (
    Logger,
    RangeOrValue,
    get_value,
    get_max,
    OUTPUT_DIR,
)
from geometric_governance.data import (
    SetDataset,
    generate_synthetic_election,
    get_scoring_function_winners,
)
from geometric_governance.model import (
    MessagePassingElectionModel,
    DeepSetElectionModel,
)

device = torch.device(0) if torch.cuda.is_available() else torch.device("cpu")

config = {
    "num_voters_range": (3, 50),
    "num_candidates_range": (2, 10),
    "train_dataset_size": 10_000,
    "train_batch_size": 128,
    "train_num_epochs": 200,
    "checkpoint_interval": 5,
    "val_num_voters": 75,
    "val_num_candidates": 15,
    "test_num_voters": 100,
    "test_num_candidates": 20,
    "train_iterations_per_epoch": 1,
    "eval_dataset_size": 1_000,
    "learning_rate": 0.001,
    "clip_grad_norm": 1.0,
    "use_monotonicity_loss": False,
    "monotonicity_loss_batch_size": 32,
    "voting_rule": "nash",
    "representation": "graph",
}


def generate_rule_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int | None,
    voting_rule: Literal["plurality", "borda", "copeland", "utilitarian", "nash", "rawlsian"],
    representation: Literal["set", "graph"],
    seed: int,
    recompute: bool = True,
):
    rng = np.random.default_rng(seed=seed)
    dataset_file = os.path.join(
        OUTPUT_DIR,
        f"rule_dataset_{dataset_size}_{num_voters_range}_{num_candidates_range}_{representation}_{voting_rule}_{seed}.pt",
    )
    if os.path.exists(dataset_file) and not recompute:
        with open(dataset_file, "rb") as f:
            dataset = torch.load(f, weights_only=False)
    else:
        dataset = []

        generated_count = 0
        with tqdm(range(dataset_size)) as pbar:
            while generated_count < dataset_size:
                num_voters = get_value(num_voters_range, rng)
                num_candidates = get_value(num_candidates_range, rng)

                election_data = generate_synthetic_election(
                    num_voters=num_voters, num_candidates=num_candidates, rng=rng
                )

                match voting_rule:
                    case "plurality":
                        scores = election_data.positional_ballots[0]
                        winners = get_scoring_function_winners(scores)
                    case "borda":
                        scoring = torch.tensor(
                            list(range(num_candidates, 0, -1)), dtype=torch.float32
                        )
                        scores = scoring @ election_data.positional_ballots
                        winners = get_scoring_function_winners(scores)
                    case "copeland":
                        scores = election_data.tournament_embedding.sum(dim=1)
                        winners = get_scoring_function_winners(scores)
                    case "utilitarian":
                        scores = election_data.voter_utilities.sum(dim=0)
                        winners = get_scoring_function_winners(scores)
                    case "nash":
                        # Use log utilities for stability
                        election_data.voter_utilities = election_data.voter_utilities.log()
                        scores = election_data.voter_utilities.sum(dim=0)
                        # scores = election_data.voter_utilities.prod(dim=0)
                        winners = get_scoring_function_winners(scores)
                    case "rawlsian":
                        scores = candidate_welfare = election_data.voter_utilities.min(dim=0)[0]
                        winners = get_scoring_function_winners(scores)
                    case _:
                        raise ValueError("Unknown voting rule.")

                if winners.max() < 1.0:
                    # Tie
                    continue

                if representation == "graph":
                    graph = election_data.to_bipartite_graph(
                        top_k_candidates, vote_data="utility"
                    )
                    graph.y = election_data.voter_utilities.sum(dim=0)
                    graph.winners = winners
                    dataset.append(graph)
                elif representation == "set":
                    pad_shape = get_max(num_candidates_range) - num_candidates
                    voter_preferences = election_data.voter_preferences_alt
                    voter_preferences = torch.nn.functional.pad(
                        voter_preferences, (0, pad_shape, 0, 0)
                    )
                    winners = torch.nn.functional.pad(winners, (0, pad_shape))
                    dataset.append((voter_preferences, winners))

                generated_count += 1
                pbar.update(1)

        with open(dataset_file, "wb") as f:
            torch.save(dataset, f)

    if representation == "graph":
        dataloader = GraphDataloader(
            dataset, batch_size=dataloader_batch_size, shuffle=True
        )
    elif representation == "set":
        voter_preferences_list = [x[0] for x in dataset]
        winner_list = [x[1] for x in dataset]
        set_dataset = SetDataset(voter_preferences_list, winner_list)
        dataloader = TorchDataloader(
            set_dataset,
            batch_size=dataloader_batch_size,
            shuffle=True,
            collate_fn=SetDataset.collate_fn,
        )
    return dataloader


generate_dataset = partial(
    generate_rule_dataset,
    voting_rule=config["voting_rule"],
    representation=config["representation"],
    dataloader_batch_size=config["train_batch_size"],
    top_k_candidates=None,
    recompute=False,
)

train_dataloader = generate_dataset(
    dataset_size=config["train_dataset_size"],
    num_voters_range=config["num_voters_range"],
    num_candidates_range=config["num_candidates_range"],
    seed=42,
)

val_dataloader = generate_dataset(
    dataset_size=config["eval_dataset_size"],
    num_voters_range=config["val_num_voters"],
    num_candidates_range=config["val_num_candidates"],
    seed=16180,
)

if config["representation"] == "graph":
    test_num_candidates = config["test_num_candidates"]
elif config["representation"] == "set":
    test_num_candidates = get_max(config["num_candidates_range"])

test_dataloader = generate_dataset(
    dataset_size=config["eval_dataset_size"],
    num_voters_range=config["test_num_voters"],
    num_candidates_range=test_num_candidates,
    seed=314159,
)

config["train_iterations_per_epoch"] = min(
    config["train_iterations_per_epoch"], len(iter(train_dataloader))
)

In [9]:
seed = 5
torch.manual_seed(seed)
np.random.seed(seed)

if config["representation"] == "graph":
    election_model = MessagePassingElectionModel(
        node_emb_dim=256, edge_emb_dim=64, num_layers=4, edge_dim=1
    )
elif config["representation"] == "set":
    election_model = DeepSetElectionModel(
        get_max(config["num_candidates_range"]), embedding_size=155
    )
parameter_count = sum(p.numel() for p in election_model.parameters() if p.requires_grad)
print(f"parameter_count: {parameter_count}")


election_model.to(device=device)
optim = torch.optim.Adam(election_model.parameters(), lr=config["learning_rate"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, T_max=config["train_num_epochs"], eta_min=1e-4
)
experiment_name = f"{config['representation']}-election-{config['voting_rule']}"
with (
    Logger(
        experiment_name=experiment_name,
        config=config,
        mode="online",
    ) as logger,
    tqdm(range(config["train_num_epochs"])) as pbar,
):
    best_validation_accuracy: float = 0.0

    for epoch in range(config["train_num_epochs"]):
        # Train
        train_loss = 0
        train_rule_loss = 0
        train_monotonicity_loss = 0
        train_welfare = 0
        total, correct = 0, 0

        election_model.train()

        train_iter = iter(train_dataloader)

        for _ in range(config["train_iterations_per_epoch"]):
            optim.zero_grad()

            # Rule Loss
            if config["representation"] == "graph":
                data = next(train_iter).to(device=device)
                data.edge_attr.requires_grad = True
                out = election_model(data)
                winners = data.winners
                rule_loss = -(out * winners).sum() / config["train_batch_size"]
            elif config["representation"] == "set":
                X, index, y = next(train_iter)
                X = X.to(device=device)
                y = y.to(device=device)
                index = index.to(device=device)
                out = election_model(X, index=index)
                winners = y
                rule_loss = torch.nn.functional.cross_entropy(out, winners)

            # Monotonicity loss
            monotonicity_loss = 0

            if config["use_monotonicity_loss"] and config["representation"] == "graph":
                candidates = data.candidate_idxs.nonzero()
                perm = torch.randperm(candidates.size(0))[
                    : config["monotonicity_loss_batch_size"]
                ]
                for i in perm:
                    candidate_idx = candidates[i]
                    edge_idxs = data.edge_index[1] == candidate_idx
                    grad = torch.autograd.grad(
                        outputs=out[i], inputs=data.edge_attr, create_graph=True
                    )[0]
                    monotonicity_loss += torch.where(
                        grad[edge_idxs] < 0,
                        -grad[edge_idxs],
                        torch.zeros_like(grad[edge_idxs]),
                    ).mean()
            monotonicity_loss /= config["monotonicity_loss_batch_size"]

            loss = rule_loss + monotonicity_loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                election_model.parameters(), config["clip_grad_norm"]
            )
            optim.step()
            scheduler.step()

            if config["representation"] == "graph":
                batch_idxs = data.batch[data.candidate_idxs]
                _, predicted = scatter_max(out, batch_idxs)
                _, predicted_ground = scatter_max(winners, batch_idxs)
            elif config["representation"] == "set":
                _, predicted = torch.max(out, dim=1)
                _, predicted_ground = torch.max(winners, dim=1)

            total += predicted_ground.shape[0]
            correct += (predicted == predicted_ground).sum().item()
            if config["representation"] == "graph":
                welfare = data.y[predicted].mean()
                train_welfare += welfare.item()
            train_loss += loss.item()
            train_rule_loss += rule_loss.item()
            if config["use_monotonicity_loss"]:
                train_monotonicity_loss += monotonicity_loss.item()

        train_loss /= config["train_iterations_per_epoch"]
        train_rule_loss /= config["train_iterations_per_epoch"]
        train_monotonicity_loss /= config["train_iterations_per_epoch"]
        train_welfare /= config["train_iterations_per_epoch"]
        train_accuracy = correct / total

        if epoch % config["checkpoint_interval"] == 0:
            torch.save(
                election_model, os.path.join(logger.checkpoint_dir, f"model_{epoch}.pt")
            )

        logger.log(
            {
                "train/total_loss": train_loss,
                "train/rule_loss": train_rule_loss,
                "train/monotonicity_loss": train_monotonicity_loss,
                "train/welfare": train_welfare,
                "train/accuracy": train_accuracy,
            }
        )

        # Validation
        val_loss = 0
        val_accuracy = 0
        if config["representation"] == "graph":
            val_welfare = 0
            election_model.eval()
            total, correct = 0, 0
            with torch.no_grad():
                for data_ in val_dataloader:
                    data = data_.to(device=device)
                    out = election_model(data)
                    rule_loss = -(out * data.winners).sum() / config["train_batch_size"]
                    if config["representation"] == "graph":
                        batch_idxs = data.batch[data.candidate_idxs]
                        _, predicted = scatter_max(out, batch_idxs)
                        _, predicted_ground = scatter_max(data.winners, batch_idxs)
                    elif config["representation"] == "set":
                        _, predicted = torch.max(out, dim=1)
                        _, predicted_ground = torch.max(data.winners, dim=1)

                    total += predicted_ground.shape[0]
                    correct += (predicted == predicted_ground).sum().item()

                    welfare = data.y[predicted].mean()
                    val_loss += rule_loss.item()
                    val_welfare += welfare.item()

            val_loss /= len(val_dataloader)
            val_welfare /= len(val_dataloader)
            val_accuracy = correct / total

            if val_accuracy > best_validation_accuracy:
                print(f"New best accuracy: {val_accuracy}")
                torch.save(
                    election_model,
                    os.path.join(logger.checkpoint_dir, "model_best.pt"),
                )
                best_validation_accuracy = val_accuracy

            logger.log(
                {
                    "val/rule_loss": val_loss,
                    "val/accuracy": val_accuracy,
                    "val/welfare": val_welfare,
                }
            )
        logger.commit()

        pbar.set_postfix(
            {
                "train_rule_loss": train_rule_loss,
                "train_accuracy": train_accuracy,
                "val_rule_loss": val_loss,
                "val_accuracy": val_accuracy,
            }
        )
        pbar.update(1)

parameter_count: 1022081


  0%|▋                                                                                                                                         | 1/200 [00:00<01:04,  3.11it/s, train_rule_loss=1.56, train_accuracy=0.391, val_rule_loss=0.692, val_accuracy=0.866]

New best accuracy: 0.866


  1%|█▎                                                                                                                                       | 2/200 [00:00<01:03,  3.14it/s, train_rule_loss=0.129, train_accuracy=0.961, val_rule_loss=0.111, val_accuracy=0.958]

New best accuracy: 0.958


  2%|██▊                                                                                                                                        | 4/200 [00:01<01:01,  3.19it/s, train_rule_loss=0.0355, train_accuracy=1, val_rule_loss=0.0408, val_accuracy=0.988]

New best accuracy: 0.988


  8%|██████████                                                                                                                            | 15/200 [00:04<00:57,  3.22it/s, train_rule_loss=0.0406, train_accuracy=0.992, val_rule_loss=0.0287, val_accuracy=0.989]

New best accuracy: 0.989


  8%|███████████                                                                                                                               | 16/200 [00:04<00:57,  3.19it/s, train_rule_loss=0.0159, train_accuracy=1, val_rule_loss=0.0271, val_accuracy=0.994]

New best accuracy: 0.994


 13%|█████████████████▍                                                                                                                    | 26/200 [00:08<00:54,  3.20it/s, train_rule_loss=0.0275, train_accuracy=0.992, val_rule_loss=0.0179, val_accuracy=0.997]

New best accuracy: 0.997


 33%|████████████████████████████████████████████▉                                                                                           | 66/200 [00:20<00:41,  3.23it/s, train_rule_loss=0.00544, train_accuracy=1, val_rule_loss=0.00949, val_accuracy=0.999]

New best accuracy: 0.999


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:04<00:00,  3.10it/s, train_rule_loss=0.000685, train_accuracy=1, val_rule_loss=0.0234, val_accuracy=0.994]


0,1
train/accuracy,▁███████████████████████████████████████
train/monotonicity_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/rule_loss,▅▆▃▂▃▃▃▄▂▄▄█▂▃▂▂▂▁▁▂▂▃▁▂▂▂▁▁▁▁▁▂▁▁▂▁▁▂▁▁
train/total_loss,▆▆▃█▃▃▅▂▃▄▂▃▂▁▃▁▂▃▃▂▂▃▂▂▂▂▂▁▁▂▂▁▁▁▁▁▂▂▁▁
train/welfare,▅▆▁▆▆▄▅▃▄▆▅▃▄▂▆▄▇▆▄▃▅▃▆▇▆▆▃▇▆▇▄▄▄▅▆▆▇█▄▄
val/accuracy,▃▂▁▄▄▃▆▇▅▃▆▇▇▇█▇▇▇▇▇▇▇▇█▇▇▇█▇██████▇▇█▇▇
val/rule_loss,▄█▅▃▅▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁
val/welfare,▄▅▅▃▅▃▃▄▆▃▅▆▆▆▄▇▆▇▅█▅▇▇▅▁█▆▂▂▆█▃▆▅▄▅▆▅▅▅

0,1
train/accuracy,1.0
train/monotonicity_loss,0.0
train/rule_loss,0.00068
train/total_loss,0.00068
train/welfare,-49.51244
val/accuracy,0.994
val/rule_loss,0.02342
val/welfare,-225.17649


In [18]:
if config["representation"] == "graph":
    election_model = torch.load(
        os.path.join(logger.checkpoint_dir, "model_best.pt"), weights_only=False
    )
    election_model.eval()

    test_loss = 0
    total, correct = 0, 0
    with torch.no_grad():
        for data_ in test_dataloader:
            data = data_.to(device=device)
            out = election_model(data)
            rule_loss = -(out * data.winners).sum() / config["train_batch_size"]
            batch_idxs = data.batch[data.candidate_idxs]
            _, predicted = scatter_max(out, batch_idxs)
            _, predicted_ground = scatter_max(data.winners, batch_idxs)
            total += predicted_ground.shape[0]
            correct += (predicted == predicted_ground).sum().item()
            test_loss += rule_loss.item()

    test_loss /= len(test_dataloader)
    test_accuracy = correct / total

    print(f"Test | Accuracy {test_accuracy} | Loss {test_loss}")
elif config["representation"] == "set":
    election_model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for X, index, y in test_dataloader:
            X = X.to(device)
            index = index.to(device)
            y = y.to(device)
            p_y = election_model(X, index)
            _, predicted = torch.max(p_y, dim=1)
            _, predicted_ground = torch.max(y, dim=1)
            total += y.shape[0]
            correct += (predicted == predicted_ground).sum().item()

    accuracy = correct / total
    print(f"Accuracy: {accuracy}")

Test | Accuracy 1.0 | Loss 0.00241581993032014


In [19]:
data.winners

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')

In [20]:
predicted_ground

tensor([   3,   25,   46,   62,   85,  117,  133,  148,  173,  189,  204,  230,
         246,  272,  285,  314,  328,  347,  372,  383,  400,  429,  454,  471,
         494,  519,  522,  550,  565,  593,  611,  620,  642,  664,  685,  716,
         727,  759,  769,  792,  803,  828,  850,  869,  892,  914,  937,  951,
         973,  997, 1015, 1034, 1051, 1060, 1092, 1108, 1125, 1154, 1172, 1196,
        1200, 1228, 1249, 1260, 1291, 1302, 1329, 1351, 1369, 1385, 1405, 1428,
        1444, 1478, 1495, 1506, 1533, 1549, 1560, 1582, 1608, 1632, 1657, 1676,
        1697, 1711, 1738, 1740, 1769, 1788, 1819, 1832, 1849, 1861, 1886, 1917,
        1936, 1942, 1965, 1981, 2018, 2029, 2047, 2069], device='cuda:0')

In [21]:
predicted

tensor([   3,   25,   46,   62,   85,  117,  133,  148,  173,  189,  204,  230,
         246,  272,  285,  314,  328,  347,  372,  383,  400,  429,  454,  471,
         494,  519,  522,  550,  565,  593,  611,  620,  642,  664,  685,  716,
         727,  759,  769,  792,  803,  828,  850,  869,  892,  914,  937,  951,
         973,  997, 1015, 1034, 1051, 1060, 1092, 1108, 1125, 1154, 1172, 1196,
        1200, 1228, 1249, 1260, 1291, 1302, 1329, 1351, 1369, 1385, 1405, 1428,
        1444, 1478, 1495, 1506, 1533, 1549, 1560, 1582, 1608, 1632, 1657, 1676,
        1697, 1711, 1738, 1740, 1769, 1788, 1819, 1832, 1849, 1861, 1886, 1917,
        1936, 1942, 1965, 1981, 2018, 2029, 2047, 2069], device='cuda:0')

In [22]:
out

tensor([ -294.8145, -1092.3062,  -696.4897,  ...,  -237.2563,   -19.7446,
        -1796.7393], device='cuda:0')