In [None]:
from typing import Literal

import numpy as np
import torch
from torch_geometric.loader import DataLoader
from torch_scatter import scatter_max
from tqdm import tqdm


from geometric_governance.util import Logger, RangeOrValue, get_value
from geometric_governance.data import (
    generate_synthetic_election,
    get_scoring_function_winners,
)
from geometric_governance.model import MessagePassingElectionModel

# Run on the first CUDA device when one is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(0)
else:
    device = torch.device("cpu")

# Experiment hyperparameters; this dict is also handed to the Logger below.
# The *_range / eval_num_* entries are passed to get_value() when sampling elections.
config = dict(
    num_voters_range=(3, 50),
    num_candidates_range=(2, 10),
    train_dataset_size=100_000,
    train_batch_size=128,
    train_num_epochs=20,
    eval_num_voters=75,
    eval_num_candidates=15,
    eval_dataset_size=1_000,
    welfare_fn="utilitarian",
    top_k_candidates=None,
    learning_rate=3e-4,
    use_welfare_loss=True,
    monotonicity_loss_batch_size=32,
)


def generate_welfare_bipartite_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int,
    welfare_fn: Literal["utilitarian", "nash", "rawlsian"],
    rng: np.random.Generator,
):
    graphs = []
    for _ in tqdm(range(dataset_size)):
        num_voters = get_value(num_voters_range, rng)
        num_candidates = get_value(num_candidates_range, rng)

        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )

        graph = election_data.to_bipartite_graph(top_k_candidates, vote_data="ranking")

        candidate_welfare = election_data.voter_utilities
        match welfare_fn:
            case "utilitarian":
                candidate_welfare = election_data.voter_utilities.sum(dim=0)
            case "nash":
                candidate_welfare = election_data.voter_utilities.prod(dim=0)
            case "rawlsian":
                candidate_welfare = election_data.voter_utilities.min(dim=0)
            case _:
                raise ValueError("Unknown welfare function.")

        winners = get_scoring_function_winners(candidate_welfare)

        graph.y = candidate_welfare
        graph.winners = winners
        graphs.append(graph)
    dataloader = DataLoader(graphs, batch_size=dataloader_batch_size, shuffle=True)
    return dataloader


# Settings shared by the training and evaluation datasets; both use the
# training batch size since the config defines no separate eval batch size.
_shared_dataset_kwargs = dict(
    dataloader_batch_size=config["train_batch_size"],
    top_k_candidates=config["top_k_candidates"],
    welfare_fn=config["welfare_fn"],
)

# Training set: voter/candidate counts vary per election (ranges from config).
train_dataloader = generate_welfare_bipartite_dataset(
    dataset_size=config["train_dataset_size"],
    num_voters_range=config["num_voters_range"],
    num_candidates_range=config["num_candidates_range"],
    rng=np.random.default_rng(seed=42),
    **_shared_dataset_kwargs,
)

# Evaluation set: fixed eval voter/candidate counts, independent seed.
eval_dataloader = generate_welfare_bipartite_dataset(
    dataset_size=config["eval_dataset_size"],
    num_voters_range=config["eval_num_voters"],
    num_candidates_range=config["eval_num_candidates"],
    rng=np.random.default_rng(seed=16180),
    **_shared_dataset_kwargs,
)

In [None]:
# Train MessagePassingElectionModel on the welfare dataset and evaluate it each
# epoch. Metrics go to the Logger and to the progress bar.


def _welfare_loss(out, data):
    """Return the negated welfare objective for one batch, averaged over its graphs.

    With ``config["use_welfare_loss"]`` each candidate's welfare ``data.y`` is
    weighted by ``exp(out)``. Presumably ``out`` is a per-election
    log-probability — TODO confirm. Otherwise ``out`` is scored against the
    ``data.winners`` labels.
    """
    if config["use_welfare_loss"]:
        objective = torch.exp(out) * data.y
    else:
        objective = out * data.winners
    # Divide by the real number of graphs: the last batch of a pass is usually
    # smaller than the configured batch size.
    return -objective.sum() / data.num_graphs


def _monotonicity_loss(out, data, create_graph):
    """Penalise negative sensitivity of candidate outputs to their incoming edges.

    Samples up to ``config["monotonicity_loss_batch_size"]`` candidates. For
    each one, it differentiates that candidate's output w.r.t.
    ``data.edge_attr`` and averages the negative part of the gradient over the
    edges whose target (``edge_index[1]``) is the candidate. The per-candidate
    terms are then averaged.

    ``create_graph`` must be True only when the result is backpropagated.
    Otherwise the penalty is purely a logged diagnostic.
    """
    candidates = data.candidate_idxs.nonzero()
    perm = torch.randperm(candidates.size(0))[: config["monotonicity_loss_batch_size"]]
    penalty = out.new_zeros(())
    for i in perm:
        edge_mask = data.edge_index[1] == candidates[i]
        # retain_graph=True keeps the forward graph alive for later grad calls
        # and for loss.backward() even when create_graph is False.
        (grad,) = torch.autograd.grad(
            outputs=out[i],
            inputs=data.edge_attr,
            retain_graph=True,
            create_graph=create_graph,
        )
        violations = torch.relu(-grad[edge_mask])
        # Mean over the candidate's edges; yields 0 (not NaN) if it has none.
        penalty = penalty + violations.sum() / max(violations.numel(), 1)
    # Average over the candidates actually sampled (may be fewer than configured).
    return penalty / max(len(perm), 1)


model = MessagePassingElectionModel(edge_dim=1)
model.to(device=device)
optim = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

# Opt-in: add the monotonicity penalty to the training loss. It is computed and
# logged either way; the default (False) keeps the criterion disabled.
use_monotonicity_loss = config.get("use_monotonicity_loss", False)

experiment_name = (
    "monotonicity-criterion-enabled"
    if use_monotonicity_loss
    else "monotonicity-criterion-disabled"
)
logger = Logger(
    experiment_name=experiment_name,
    config=config,
    mode="online",
)

with tqdm(total=config["train_num_epochs"]) as pbar:
    for _epoch in range(config["train_num_epochs"]):
        # Train
        train_loss = 0.0
        train_welfare_loss = 0.0
        train_monotonicity_loss = 0.0
        train_welfare = 0.0

        model.train()
        for data_ in train_dataloader:
            optim.zero_grad()

            data = data_.to(device=device)
            # Edge features must require grad so the monotonicity term can
            # differentiate candidate outputs w.r.t. them.
            data.edge_attr.requires_grad = True
            out = model(data)

            welfare_loss = _welfare_loss(out, data)
            monotonicity_loss = _monotonicity_loss(
                out, data, create_graph=use_monotonicity_loss
            )

            if use_monotonicity_loss:
                loss = welfare_loss + monotonicity_loss
            else:
                loss = welfare_loss
            loss.backward()
            optim.step()

            # Welfare of the candidate with the highest output in each election.
            batch_idxs = data.batch[data.candidate_idxs]
            _, predicted = scatter_max(out, batch_idxs)
            train_welfare += data.y[predicted].mean().item()

            train_loss += loss.item()
            train_welfare_loss += welfare_loss.item()
            train_monotonicity_loss += monotonicity_loss.item()

        num_train_batches = len(train_dataloader)
        train_loss /= num_train_batches
        train_welfare_loss /= num_train_batches
        train_monotonicity_loss /= num_train_batches
        train_welfare /= num_train_batches

        logger.log(
            {
                "train/total_loss": train_loss,
                "train/welfare_loss": train_welfare_loss,
                "train/monotonicity_loss": train_monotonicity_loss,
                "train/welfare": train_welfare,
            }
        )

        # Eval
        model.eval()
        eval_loss = 0.0
        eval_welfare = 0.0
        total, correct = 0, 0
        with torch.no_grad():
            for data_ in eval_dataloader:
                data = data_.to(device=device)
                out = model(data)
                # Accumulate this eval batch's loss, not the last training loss.
                eval_loss += _welfare_loss(out, data).item()

                batch_idxs = data.batch[data.candidate_idxs]
                _, predicted = scatter_max(out, batch_idxs)
                _, predicted_ground = scatter_max(data.y, batch_idxs)
                # Accuracy: predicted candidate equals the max-welfare candidate.
                total += predicted_ground.shape[0]
                correct += (predicted == predicted_ground).sum().item()

                eval_welfare += data.y[predicted].mean().item()

        eval_loss /= len(eval_dataloader)
        eval_welfare /= len(eval_dataloader)
        eval_accuracy = correct / total
        logger.log(
            {
                "eval/loss": eval_loss,
                "eval/accuracy": eval_accuracy,
                "eval/welfare": eval_welfare,
            }
        )
        logger.commit()

        pbar.set_postfix(
            {
                "train_welfare_loss": train_welfare_loss,
                "train_monotonicity_loss": train_monotonicity_loss,
                "train_welfare": train_welfare,
                "eval_loss": eval_loss,
                "eval_accuracy": eval_accuracy,
                "eval_welfare": eval_welfare,
            }
        )
        pbar.update(1)