In [14]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MLP, DeepSetsAggregation, conv
from torch_geometric.data import Batch, Data
from tqdm import tqdm

from geometric_governance.util import RangeOrValue, get_value
from geometric_governance.data import generate_synthetic_election, ElectionData
from geometric_governance.model import MessagePassingElectionModel

In [2]:
NUM_VOTERS_RANGE = (3, 50)
NUM_CANDIDATES_RANGE = 5
TRAIN_BATCH_SIZE = 128
TRAIN_NUM_EPOCHS = 1_000

In [3]:
def fully_connected_directed_edge_index(n):
    row, col = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    edge_index = torch.stack([row.flatten(), col.flatten()], dim=0)
    return edge_index

print(fully_connected_directed_edge_index(5))

tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
         4]])


In [4]:
# StrategyModel is responsible for transforming voter utilities
class StrategyModel(nn.Module):
    def __init__(self, num_candidates: int, embedding_size: int = 128):
        super().__init__()
        self.reveal_utility = nn.Linear(num_candidates, embedding_size)
        self.gnn_communicate = nn.ModuleList([conv.GATv2Conv(in_channels=embedding_size, out_channels=embedding_size, add_self_loops=True) for _ in range(5)])
        self.aggregate = nn.Linear(embedding_size, num_candidates)

    def forward(self, x):
        x = self.reveal_utility(x)
        revealed_utilities = x.clone()

        edge_index = fully_connected_directed_edge_index(x.size(-2))
        for conv in self.gnn_communicate:
            x = conv(x, edge_index)

        x = self.aggregate(revealed_utilities + x)
        x = F.softmax(x, dim=-1)

        return x

In [5]:
strategy_model = StrategyModel(NUM_CANDIDATES_RANGE)
strategy_model.train()
strategy_optim = torch.optim.Adam(strategy_model.parameters())

In [6]:
election_model = MessagePassingElectionModel(edge_dim=1)
election_model.load_state_dict(torch.load("election_model", weights_only=True))
election_model.eval()

MessagePassingElectionModel(
  (convs): ModuleList(
    (0-3): 4 x MessagePassingElectionLayer()
  )
)

In [17]:
def utility_matrix_to_graph(U):
    """
    Converts a utility matrix U (Voters x Candidates) into a PyTorch Geometric Data object.
    
    Parameters:
        U (torch.Tensor): A tensor of shape (..., num_voters, num_candidates).
    
    Returns:
        Data: A PyTorch Geometric Data object.
    """
    num_voters, num_candidates = U.size(-2), U.size(-1)
    
    # Node features: one-hot encoding for voters and candidates
    x_voters = torch.tensor([[1, 0]] * num_voters, dtype=torch.float)
    x_candidates = torch.tensor([[0, 1]] * num_candidates, dtype=torch.float)
    x = torch.cat([x_voters, x_candidates], dim=0)
    
    # Create edges
    voter_indices = torch.arange(num_voters).repeat_interleave(num_candidates)
    candidate_indices = torch.arange(num_candidates).repeat(num_voters)
    
    # Shift candidate indices to match node indexing
    candidate_indices += num_voters
    
    edge_index = torch.stack([voter_indices, candidate_indices], dim=0)

    # Edge attributes (utility values)
    edge_attr = U.flatten()

    # Candidate indices
    candidate_idxs = x[:, 1] == 1

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, candidate_idxs=candidate_idxs)

In [21]:
rng = np.random.default_rng(seed=42)

epochs = tqdm(range(2))
for epoch in epochs:
    num_voters = get_value(NUM_VOTERS_RANGE, rng)
    num_candidates = get_value(NUM_CANDIDATES_RANGE, rng)

    election_data = generate_synthetic_election(
        num_voters=num_voters, num_candidates=num_candidates, rng=rng
    )
    test = election_data.to_bipartite_graph(5, vote_data="ranking")
    # print(test.x)
    # print(test.edge_index)
    # print(test.edge_attr)
    # print(test.candidate_idxs)
    election_data.voter_utilities = election_data.voter_utilities.float()

    transformed_utilities = strategy_model(election_data.voter_utilities)

    # Randomly disable gradients to all but one voter
    unmasked_voter = rng.integers(low=0, high=num_voters)
    gradient_mask = torch.ones(num_voters)
    gradient_mask[unmasked_voter] = 0
    gradient_mask = gradient_mask.bool()
    gradient_mask = gradient_mask.unsqueeze(-1)

    transformed_utilities = torch.where(gradient_mask, transformed_utilities.detach(), transformed_utilities)
    bipartite_graph = utility_matrix_to_graph(transformed_utilities)
    print(bipartite_graph)
    print(bipartite_graph.x)
    print(bipartite_graph.edge_index.size())
    print(bipartite_graph.edge_attr.size())
    print(bipartite_graph.candidate_idxs)

    # Pass through frozen election_model
    graph = Batch.from_data_list([bipartite_graph])

    print(election_model(graph))

  0%|                                                                                                                                                                | 0/2 [00:00<?, ?it/s]

Data(x=[12, 2], edge_index=[2, 35], edge_attr=[35], candidate_idxs=[12])
tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.]])
torch.Size([2, 35])
torch.Size([35])
tensor([False, False, False, False, False, False, False,  True,  True,  True,
         True,  True])





RuntimeError: Tensors must have same number of dimensions: got 2 and 1