In [None]:
from typing import Literal

import numpy as np
import torch
import torch.optim as o
from torch_geometric.loader import DataLoader
from torch_scatter import scatter_softmax, scatter_add
from tqdm import tqdm


from geometric_governance.util import Logger, RangeOrValue, get_value
from geometric_governance.data import (
    generate_synthetic_election,
    get_scoring_function_winners,
)
from geometric_governance.model import DeepSetStrategyModel, MLPStrategyModel

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(0 if torch.cuda.is_available() else "cpu")

# Experiment settings. The same dict is passed to the Logger as the run config.
config = {
    # Election size; passed as the *_range arguments of the dataset builder.
    "num_voters_range": 10,
    "num_candidates_range": 3,
    # Training data and loop sizes.
    "train_dataset_size": 10_000,
    "train_batch_size": 64,
    "train_num_epochs": 200,
    # Welfare function used to score candidates: "utilitarian", "nash" or "rawlsian".
    "welfare_fn": "utilitarian",
    # Forwarded to `to_bipartite_graph`; None is the value used for this run.
    "top_k_candidates": None,
    # Adam learning rate.
    "learning_rate": 0.001,
    # NOTE(review): the three keys below are not read by any of the visible cells.
    "use_welfare_loss": True,
    "use_monotonicity_loss": True,
    "monotonicity_loss_batch_size": 32,
}


def generate_welfare_bipartite_dataset(
    dataset_size: int,
    num_voters_range: RangeOrValue,
    num_candidates_range: RangeOrValue,
    dataloader_batch_size: int,
    top_k_candidates: int,
    welfare_fn: Literal["utilitarian", "nash", "rawlsian"],
    rng: np.random.Generator,
):
    graphs = []
    for _ in tqdm(range(dataset_size)):
        num_voters = get_value(num_voters_range, rng)
        num_candidates = get_value(num_candidates_range, rng)

        election_data = generate_synthetic_election(
            num_voters=num_voters, num_candidates=num_candidates, rng=rng
        )

        graph = election_data.to_bipartite_graph(top_k_candidates, vote_data="utility")

        candidate_welfare = election_data.voter_utilities
        match welfare_fn:
            case "utilitarian":
                candidate_welfare = election_data.voter_utilities.sum(dim=0)
            case "nash":
                # print(election_data.voter_utilities[:, 0])
                log_voter_utilities = torch.log(election_data.voter_utilities)
                # print(log_voter_utilities[:, 0])
                candidate_welfare = log_voter_utilities.sum(dim=0)
                # print(candidate_welfare[0])
            case "rawlsian":
                candidate_welfare = election_data.voter_utilities.min(dim=0)[0]
            case _:
                raise ValueError("Unknown welfare function.")

        winners = get_scoring_function_winners(candidate_welfare)

        graph.y = candidate_welfare
        graph.winners = winners
        graph.voter_utilities = election_data.voter_utilities
        graphs.append(graph)
    dataloader = DataLoader(graphs, batch_size=dataloader_batch_size, shuffle=True)
    return dataloader


# Build the training DataLoader from the settings in `config`.
# Elections are sampled from a NumPy generator with a fixed seed (42);
# the loader itself is created with shuffle=True inside the builder.
train_dataloader = generate_welfare_bipartite_dataset(
    dataset_size=config["train_dataset_size"],
    num_voters_range=config["num_voters_range"],
    num_candidates_range=config["num_candidates_range"],
    dataloader_batch_size=config["train_batch_size"],
    top_k_candidates=config["top_k_candidates"],
    welfare_fn=config["welfare_fn"],
    rng=np.random.default_rng(seed=42),
)


In [None]:
def manual_election_model(data):
    """Hand-written, parameter-free election rule over a batched bipartite graph.

    Steps:
      1. Sum the edge attributes per source node (``edge_index[0]``, the
         voter end of each edge) and divide every edge by its voter's total,
         with the divisor clamped to at least 1.
      2. Sum the normalised votes per target node (``edge_index[1]``) and
         keep the rows belonging to candidate nodes.
      3. Apply a softmax over the candidates of each graph in the batch.

    Args:
        data: Batched graph exposing ``edge_attr``, ``edge_index``,
            ``candidate_idxs`` (boolean node mask) and ``batch``.
            (assumes ``edge_attr`` has shape (num_edges, 1) -- TODO confirm)

    Returns:
        One probability per candidate node, normalised within each graph.
    """
    voter_of_edge = data.edge_index[0]
    candidate_of_edge = data.edge_index[1]

    # Per-voter vote totals, gathered back onto the edges.
    total_per_voter = scatter_add(data.edge_attr, index=voter_of_edge, dim=0)
    total_per_edge = total_per_voter[voter_of_edge]

    # Divisors below 1 are raised to 1, so small totals are left unscaled.
    divisor = torch.maximum(total_per_edge, torch.ones_like(total_per_edge))
    normalised_votes = data.edge_attr / divisor

    # Accumulate votes on every node, then select the candidate nodes.
    score_per_node = scatter_add(
        src=normalised_votes, index=candidate_of_edge.unsqueeze(-1), dim=0
    )
    candidate_rows = data.candidate_idxs.nonzero()
    logits = score_per_node[candidate_rows].squeeze()

    return scatter_softmax(logits, data.batch[data.candidate_idxs])

In [None]:
experiment_name = "strategy-module-debugging"

# Strategy network that rewrites each voter's votes.
# Alternative architecture: DeepSetStrategyModel(edge_dim=1, emb_dim=256).to(device)
# NOTE(review): `num_candidates` receives config["num_candidates_range"], which is
# a plain int in this config; a true range would presumably not be accepted here.
strategy_model = MLPStrategyModel(
    num_candidates=config["num_candidates_range"], emb_dim=32
).to(device)
optim = o.Adam(strategy_model.parameters(), lr=config["learning_rate"])

# Learning-rate schedule: linear warm-up from 10% to 100% of the base rate,
# followed by cosine annealing with warm restarts (first cycle of 5 steps,
# each later cycle twice as long). Both durations are counted in calls to
# `scheduler.step()`.
warmup_epochs = 5
warmup_scheduler = o.lr_scheduler.LinearLR(
    optim, start_factor=0.1, end_factor=1, total_iters=warmup_epochs
)
main_scheduler = o.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=5, T_mult=2)
scheduler = o.lr_scheduler.SequentialLR(
    optim,
    schedulers=[warmup_scheduler, main_scheduler],
    milestones=[warmup_epochs],
)

# Train the strategy model.
#
# In every batch one voter per graph is sampled as the "train" voter. Its
# strategic votes keep their gradient; a random ~20% of the other voters also
# vote strategically but with detached votes, and everyone else votes
# truthfully. The loss is the negative expected truthful utility of the
# sampled voters under the outcome distribution of `manual_election_model`.
with (
    Logger(
        experiment_name=experiment_name,
        config=config,
        mode="disabled",
    ) as logger,
    tqdm(range(config["train_num_epochs"])) as pbar,
):
    # Iterating the bar advances it by one per epoch.
    for epoch in pbar:
        train_loss = 0.0

        for data_ in train_dataloader:
            data = data_.to(device=device)

            # nonzero() yields (N, 1) tensors of node indices.
            voters = (~data.candidate_idxs).nonzero()
            candidate_idxs_nonzero = data.candidate_idxs.nonzero()
            voter_to_batch = data.batch[voters]  # graph id of each voter

            truthful_votes = data.edge_attr
            strategic_votes = strategy_model(
                data.edge_attr, data.edge_index, data.candidate_idxs
            )
            strategic_votes_detached = strategic_votes.detach()

            # 20% of voters are strategic. The mask is aligned with the rows
            # of `voters`; it does not hold node indices.
            p = 0.2
            strategic_voters = torch.rand_like(voters, dtype=torch.float) < p

            # Select one train voter from each graph in the batch.
            # `data.batch[-1]` is used as the id of the last graph.
            num_graphs = int(data.batch[-1]) + 1
            picked_rows = []
            for graph_id in range(num_graphs):
                rows_in_graph = torch.where(voter_to_batch == graph_id)[0]
                choice = torch.randint(0, rows_in_graph.numel(), (1,))
                picked_rows.append(rows_in_graph[choice].squeeze())
            rand_indices = torch.stack(picked_rows)

            sampled_voters = voters[rand_indices]
            strategic_voters[rand_indices] = True

            # Edge masks, shape (num_edges, 1) so they broadcast against the
            # votes. Both compare the voter end of each edge with node ids,
            # hence the row mask is mapped back through `voters`.
            strategy_train_mask = torch.isin(
                data.edge_index[0], sampled_voters
            ).unsqueeze(-1)
            strategic_voter_nodes = voters[strategic_voters]
            strategic_votes_mask = torch.isin(
                data.edge_index[0], strategic_voter_nodes
            ).unsqueeze(-1)

            # Cut gradients: only the sampled voters' edges use the
            # non-detached strategic votes.
            resulting_votes = torch.where(
                strategic_votes_mask, strategic_votes_detached, truthful_votes
            )
            gradient_cut_strategic_votes = torch.where(
                strategy_train_mask, strategic_votes, resulting_votes
            )

            # Clone and modify votes
            data_strategy = data.clone()
            data_strategy.edge_attr = gradient_cut_strategic_votes

            vote_probabilities = manual_election_model(data_strategy)

            # Truthful edge values of the sampled voters, and the candidate
            # node each of those edges points to.
            voter_welfare = data.edge_attr[strategy_train_mask].squeeze()
            voter_candidate_order = data.edge_index[1][strategy_train_mask.squeeze()]
            original_candidate_order = candidate_idxs_nonzero.squeeze()

            # Scatter both quantities onto node-indexed vectors so that they
            # line up per candidate regardless of edge ordering.
            vote_probabilities_expanded = torch.zeros((data.x.shape[0],), device=device)
            welfare_expanded = torch.zeros((data.x.shape[0],), device=device)
            vote_probabilities_expanded[original_candidate_order] = vote_probabilities
            welfare_expanded[voter_candidate_order] = voter_welfare

            # Negative expected welfare, averaged over the sampled voters.
            strategy_loss = -(vote_probabilities_expanded * welfare_expanded).sum()
            strategy_loss = strategy_loss / len(sampled_voters)

            optim.zero_grad()
            strategy_loss.backward()
            torch.nn.utils.clip_grad_norm_(strategy_model.parameters(), max_norm=1.0)
            optim.step()

            train_loss += strategy_loss.item()

        # The schedule (warm-up length, restart period) is expressed in
        # epochs, so it is advanced once per epoch rather than once per batch.
        scheduler.step()

        logger.commit()

        # Reported value is the sum of the per-batch losses for this epoch.
        pbar.set_postfix(
            {
                "train_strategy_loss": train_loss,
            }
        )