In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MLP, DeepSetsAggregation, conv
from torch_geometric.data import Batch, Data
from tqdm import tqdm

from geometric_governance.util import RangeOrValue, get_value
from geometric_governance.data import generate_synthetic_election, ElectionData
from geometric_governance.model import MessagePassingElectionModel

In [2]:
# Voter count per synthetic election; passed to get_value(..., rng) each step.
NUM_VOTERS_RANGE = (3, 50)
# Candidate count; must stay a single int because StrategyModel's layer widths
# are fixed from it at construction time.
NUM_CANDIDATES_RANGE = 5
# NOTE(review): not used by the training loop below, which performs one
# optimiser step per single election.
TRAIN_BATCH_SIZE = 128
# Number of optimiser steps (one synthetic election each).
TRAIN_NUM_EPOCHS = 1000

In [3]:
def fully_connected_directed_edge_index(n):
    """Edge index of the complete directed graph on ``n`` nodes, self-loops included.

    Returns a (2, n * n) int64 tensor whose columns enumerate every
    (source, target) pair in source-major order:
    (0, 0), (0, 1), ..., (0, n - 1), (1, 0), ..., (n - 1, n - 1).
    """
    nodes = torch.arange(n)
    sources = nodes.repeat_interleave(n)  # 0,0,...,0,1,1,...,1,...
    targets = nodes.repeat(n)  # 0,1,...,n-1,0,1,...,n-1,...
    return torch.stack((sources, targets))


print(fully_connected_directed_edge_index(5))

tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
         4]])


In [4]:
# StrategyModel is responsible for transforming voter utilities
class StrategyModel(nn.Module):
    """Maps each voter's utility vector to a transformed (strategic) vector.

    Each voter's utilities are embedded with a linear layer. The embeddings are
    refined by a stack of GATv2Conv layers over a fully connected directed voter
    graph, so every voter can attend to every other voter. A residual connection
    adds back the pre-communication embedding. The result is projected to one
    score per candidate and normalised with a softmax.

    Args:
        num_candidates: Length of each voter's utility vector. It fixes the
            input and output widths, so every election passed to this model
            must have exactly this many candidates.
        embedding_size: Hidden width of the voter embeddings.
        num_layers: Number of GATv2Conv communication layers. The default of 5
            matches the previously hard-coded value and keeps state_dict keys
            unchanged.
    """

    def __init__(
        self, num_candidates: int, embedding_size: int = 128, num_layers: int = 5
    ):
        super().__init__()
        self.reveal_utility = nn.Linear(num_candidates, embedding_size)
        self.gnn_communicate = nn.ModuleList(
            [
                conv.GATv2Conv(
                    in_channels=embedding_size,
                    out_channels=embedding_size,
                    add_self_loops=True,
                )
                for _ in range(num_layers)
            ]
        )
        self.aggregate = nn.Linear(embedding_size, num_candidates)

    def forward(self, x):
        """Transform the utility matrix of a single election.

        Args:
            x: Assumed to be 2-D, (num_voters, num_candidates), one row per voter.
               The voter graph is built for x.size(-2) nodes. Batched input is
               presumably unsupported because GATv2Conv works on a 2-D node
               feature matrix (TODO confirm if batching is ever needed).

        Returns:
            Tensor with the same leading shape and num_candidates columns; each
            row is a softmax, so it is non-negative and sums to 1.
        """
        x = self.reveal_utility(x)
        # Residual branch: the embedding before any communication. No clone is
        # needed because x is rebound to fresh tensors below, never mutated in place.
        revealed_utilities = x

        # The helper allocates on the default (CPU) device. Move the index to the
        # input's device so the model also runs on GPU.
        edge_index = fully_connected_directed_edge_index(x.size(-2)).to(x.device)
        # Named `layer` (not `conv`) to avoid shadowing the torch_geometric `conv` module.
        # NOTE(review): there is no nonlinearity between the stacked GAT layers.
        for layer in self.gnn_communicate:
            x = layer(x, edge_index)

        x = self.aggregate(revealed_utilities + x)
        return F.softmax(x, dim=-1)

In [5]:
# StrategyModel's layer widths are fixed from the candidate count at construction,
# so the candidate count must be a single int rather than a range.
assert isinstance(
    NUM_CANDIDATES_RANGE, int
), "StrategyModel requires a fixed number of candidates"
# Seed torch so the strategy model's weight initialisation is reproducible.
# The numpy rng used by the training loop is seeded separately.
torch.manual_seed(42)
strategy_model = StrategyModel(NUM_CANDIDATES_RANGE)
strategy_model.train()
strategy_optim = torch.optim.Adam(strategy_model.parameters())

In [11]:
election_model = MessagePassingElectionModel(edge_dim=1)
# TODO: load trained weights before training the strategy model. While this line
# stays commented out, the election model presumably keeps its untrained
# initialisation, and the strategy model is optimised against it.
# election_model.load_state_dict(torch.load("election_model", weights_only=True))

# Freeze the election model. The strategy loss still backpropagates *through* it
# to the strategy model's outputs, but its own parameters should neither
# accumulate .grad nor be updated (strategy_optim only holds strategy params).
election_model.requires_grad_(False)
election_model.eval()

MessagePassingElectionModel(
  (lin_in_node): Linear(in_features=2, out_features=32, bias=True)
  (lin_out_node): Linear(in_features=32, out_features=1, bias=True)
  (lin_in_edge): Linear(in_features=1, out_features=8, bias=True)
  (convs): ModuleList(
    (0-3): 4 x MessagePassingElectionLayer()
  )
)

In [12]:
def utility_matrix_to_graph(U):
    """
    Converts a utility matrix U (Voters x Candidates) into a PyTorch Geometric Data object.

    The graph is bipartite. Nodes 0..num_voters-1 are voters and nodes
    num_voters..num_voters+num_candidates-1 are candidates. There is one directed
    edge voter -> candidate for every (voter, candidate) pair, in voter-major
    order (the same order as flattening U). Each edge carries that voter's
    utility for that candidate as its single feature.

    Parameters:
        U (torch.Tensor): A tensor of shape (num_voters, num_candidates). Extra
            leading dimensions are allowed only if they all have size 1, i.e.
            U must describe exactly one election.

    Returns:
        Data: A PyTorch Geometric Data object with
            x (num_nodes, 2): one-hot node type, [1, 0] for voters, [0, 1] for candidates;
            edge_index (2, num_voters * num_candidates);
            edge_attr (num_voters * num_candidates, 1): the utilities from U;
            candidate_idxs (num_nodes,): boolean mask selecting candidate nodes.
        All tensors are on U's device.

    Raises:
        ValueError: If U has fewer than 2 dimensions, or holds more than one election.
    """
    if U.dim() < 2:
        raise ValueError(
            f"U must have at least 2 dimensions, got shape {tuple(U.shape)}"
        )
    num_voters, num_candidates = U.size(-2), U.size(-1)
    num_edges = num_voters * num_candidates
    if U.numel() != num_edges:
        # Flattening a real batch would produce more edge features than edges.
        raise ValueError(
            f"U must describe a single election, got shape {tuple(U.shape)}"
        )
    device = U.device

    # Node features: one-hot encoding for voters and candidates
    x = torch.zeros(num_voters + num_candidates, 2, dtype=torch.float, device=device)
    x[:num_voters, 0] = 1.0
    x[num_voters:, 1] = 1.0

    # Edges voter -> candidate for every pair, voter-major to match U's row-major layout.
    voter_indices = torch.arange(num_voters, device=device).repeat_interleave(
        num_candidates
    )
    # Candidate node ids are offset by num_voters because voters occupy the first ids.
    candidate_indices = (
        torch.arange(num_candidates, device=device).repeat(num_voters) + num_voters
    )
    edge_index = torch.stack([voter_indices, candidate_indices], dim=0)

    # Edge attributes (utility values), one column
    edge_attr = U.reshape(num_edges, 1)

    # Boolean mask over nodes (True for candidate nodes)
    candidate_idxs = x[:, 1] == 1

    return Data(
        x=x, edge_index=edge_index, edge_attr=edge_attr, candidate_idxs=candidate_idxs
    )

In [13]:
def print_graph_data(data: Data):
    """Debug helper: print the graph summary, then each of its tensors in turn.

    Output order: the Data object itself, then x, edge_index, edge_attr and
    candidate_idxs, each via a separate print call.
    """
    print(data)
    for field_name in ("x", "edge_index", "edge_attr", "candidate_idxs"):
        print(getattr(data, field_name))

In [None]:
rng = np.random.default_rng(seed=42)

epochs = tqdm(range(TRAIN_NUM_EPOCHS))
for epoch in epochs:
    # One freshly sampled synthetic election per optimiser step.
    num_voters = get_value(NUM_VOTERS_RANGE, rng)
    num_candidates = get_value(NUM_CANDIDATES_RANGE, rng)

    election_data = generate_synthetic_election(
        num_voters=num_voters, num_candidates=num_candidates, rng=rng
    )
    election_data.voter_utilities = election_data.voter_utilities.float()

    transformed_utilities = strategy_model(election_data.voter_utilities)

    # Keep the gradient path for one randomly chosen voter only. Every other
    # voter's transformed utilities are detached (gradient_mask True = detach),
    # so this step updates the strategy model from that single voter's view.
    unmasked_voter = int(rng.integers(low=0, high=num_voters))
    gradient_mask = torch.ones(int(num_voters), 1, dtype=torch.bool)
    gradient_mask[unmasked_voter] = False

    transformed_utilities = torch.where(
        gradient_mask, transformed_utilities.detach(), transformed_utilities
    )
    bipartite_graph = utility_matrix_to_graph(transformed_utilities)

    # Pass through the election model (eval mode; strategy_optim never updates it).
    graph = Batch.from_data_list([bipartite_graph])
    out = election_model(graph)

    # Loss: negated dot product of the election output with the unmasked voter's
    # original utilities. Only strategy_model's parameters are stepped.
    # NOTE(review): torch.dot requires `out` to be 1-D with num_candidates entries.
    # If `out` holds unnormalised scores rather than probabilities, this objective
    # is unbounded below — confirm the election model's output semantics.
    strategy_optim.zero_grad()
    loss = -torch.dot(out, election_data.voter_utilities[unmasked_voter])
    loss.backward()
    strategy_optim.step()

    # .item() so the progress bar shows a number, not a tensor repr with grad_fn.
    epochs.set_postfix({"strategy_loss": loss.item()})

 19%|█▉        | 192/1000 [00:02<00:10, 79.72it/s, strategy_loss=tensor(1.5586, grad_fn=<NegBackward0>)] 


KeyboardInterrupt: 

In [36]:
# Original utilities (the input to strategy_model) of the last election generated by the training loop above.
election_data.voter_utilities

tensor([[0.1405, 0.4738, 0.0879, 0.2872, 0.0106],
        [0.5978, 0.0204, 0.0803, 0.0819, 0.2196],
        [0.2841, 0.1888, 0.1146, 0.0995, 0.3130],
        [0.0034, 0.0031, 0.3444, 0.4339, 0.2153],
        [0.2872, 0.1333, 0.0162, 0.3592, 0.2041],
        [0.1664, 0.3263, 0.0749, 0.3745, 0.0580],
        [0.4832, 0.0330, 0.0294, 0.3734, 0.0810],
        [0.1755, 0.4521, 0.1116, 0.1439, 0.1169],
        [0.3825, 0.1553, 0.1202, 0.1468, 0.1953],
        [0.0146, 0.1700, 0.4109, 0.3609, 0.0436],
        [0.1149, 0.0395, 0.0899, 0.4052, 0.3505],
        [0.1472, 0.3564, 0.0787, 0.0016, 0.4161],
        [0.0769, 0.0420, 0.2877, 0.5511, 0.0422],
        [0.0560, 0.0257, 0.2364, 0.1114, 0.5705],
        [0.3400, 0.1143, 0.2718, 0.1970, 0.0769],
        [0.0435, 0.6513, 0.1572, 0.0260, 0.1220],
        [0.2119, 0.5596, 0.1107, 0.0915, 0.0263],
        [0.1625, 0.1933, 0.0555, 0.0032, 0.5855],
        [0.0946, 0.2732, 0.0327, 0.1573, 0.4422],
        [0.1480, 0.1979, 0.2313, 0.3369, 0.0858],


In [37]:
# Strategy model output for the last election from the training loop. no_grad:
# this is inspection only, so no autograd graph needs to be built.
with torch.no_grad():
    strategy_preview = strategy_model(election_data.voter_utilities)
strategy_preview

tensor([[0.2053, 0.0586, 0.2920, 0.1286, 0.3155],
        [0.0327, 0.2981, 0.2028, 0.2782, 0.1882],
        [0.1406, 0.1929, 0.2148, 0.3222, 0.1294],
        [0.3066, 0.3657, 0.1072, 0.0879, 0.1327],
        [0.1206, 0.2023, 0.3823, 0.1103, 0.1845],
        [0.1863, 0.1019, 0.3309, 0.0988, 0.2820],
        [0.0453, 0.2569, 0.3266, 0.0876, 0.2836],
        [0.2045, 0.0709, 0.2436, 0.2450, 0.2360],
        [0.0896, 0.2169, 0.2249, 0.2590, 0.2095],
        [0.3069, 0.2245, 0.0903, 0.1168, 0.2616],
        [0.2348, 0.2902, 0.2798, 0.0997, 0.0955],
        [0.2321, 0.0901, 0.1961, 0.4103, 0.0714],
        [0.2179, 0.3057, 0.1586, 0.0534, 0.2645],
        [0.2829, 0.2992, 0.0986, 0.2842, 0.0350],
        [0.0934, 0.2666, 0.1321, 0.2104, 0.2974],
        [0.3086, 0.0302, 0.1611, 0.3219, 0.1782],
        [0.1632, 0.0443, 0.2240, 0.2621, 0.3064],
        [0.2129, 0.1522, 0.1875, 0.4101, 0.0373],
        [0.2905, 0.1217, 0.2768, 0.2443, 0.0667],
        [0.2103, 0.1979, 0.1879, 0.1325, 0.2715],
