# In [None]:
###############################################################
# ONE-CELL COMPLETE PIPELINE FOR SCOTTISH STV + PAIRWISE MLP
###############################################################

import os
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn
from tqdm import tqdm

from votekit.cvr_loaders import load_scottish
from votekit.elections import STV, Borda, Plurality


###################################################################
# 1. Positional Score Matrix Encoding from Scottish Profile
###################################################################

def normalize_name(x):
    """
    Unwrap a candidate wrapped in a frozenset, e.g.
    frozenset({'John Doe'}) → 'John Doe'.

    Any value that is not a frozenset is handed back untouched. For a
    frozenset, one member is returned (whichever iteration yields first).
    """
    return next(iter(x)) if isinstance(x, frozenset) else x

def profile_to_score_matrix(profile, max_cands=None, max_positions=None):
    """
    Build a positional score matrix for a ranked profile.

    S[c, p] is the total ballot weight that placed candidate c (indexed in
    the order of ``profile.candidates``) at ranking position p (0-indexed).
    Candidates a ballot leaves unranked add nothing, so partial ballots
    need no special handling; ballot weights are accumulated as floats.

    Args:
        profile: object exposing ``candidates`` and ``ballots``; every
            ballot must expose ``ranking`` (a sequence of candidate sets)
            and ``weight``.
        max_cands: number of rows of the result. Defaults to
            ``len(profile.candidates)`` of the profile passed in.
        max_positions: number of columns of the result. Defaults to
            ``len(profile.candidates)`` of the profile passed in.

    Returns:
        np.ndarray of dtype float32 and shape (max_cands, max_positions).
        Candidates with index >= max_cands and positions >= max_positions
        are dropped.
    """
    # Candidate normalization
    cand_list = [normalize_name(c) for c in profile.candidates]

    # Defaults are resolved per call. Writing ``len(profile.candidates)`` in
    # the signature would be evaluated once, at definition time, against
    # whatever global ``profile`` happens to exist (or fail with NameError).
    if max_cands is None:
        max_cands = len(cand_list)
    if max_positions is None:
        max_positions = len(cand_list)

    # Name -> row lookup; setdefault keeps the first occurrence, which is
    # what list.index would have returned.
    index_of = {}
    for i, name in enumerate(cand_list):
        index_of.setdefault(name, i)

    # Create empty score matrix
    S = np.zeros((max_cands, max_positions), dtype=np.float32)

    for ballot in profile.ballots:
        weight = float(ballot.weight)

        # For each ranked position in the ballot (tuple of frozensets)
        for pos, cand_set in enumerate(ballot.ranking):
            if pos >= max_positions:
                # Every later position is out of range as well.
                break

            cidx = index_of.get(normalize_name(cand_set))
            if cidx is not None and cidx < max_cands:
                S[cidx, pos] += weight

    return S


###################################################################
# 2. Winner Vectors (Plurality, Borda, STV)
###################################################################

def plurality_winner_vector(profile, max_cands=14):
    """
    Encode the Plurality result for `profile` as a 0/1 vector.

    Plurality is run for a single seat (m=1) with random tie-breaking.
    Each elected entry is mapped to its index in `profile.candidates`
    and that slot is set to 1.

    Returns a float32 array of length `max_cands`. An elected candidate
    whose index is >= `max_cands` makes the array assignment raise
    IndexError.
    """
    elected = Plurality(profile, m=1, tiebreak='random').get_elected()
    names = list(profile.candidates)
    onehot = np.zeros(max_cands, dtype=np.float32)

    for elected_set in elected:
        # Each elected entry is a frozenset; one member names the winner.
        winner = next(iter(elected_set))
        onehot[names.index(winner)] = 1
    return onehot

def borda_winner_vector(profile, max_cands=14):
    """
    Encode the Borda result for `profile` as a 0/1 vector.

    Borda is run for a single seat (m=1) with random tie-breaking.
    Each elected entry is mapped to its index in `profile.candidates`
    and that slot is set to 1.

    Returns a float32 array of length `max_cands`. An elected candidate
    whose index is >= `max_cands` makes the array assignment raise
    IndexError.
    """
    elected = Borda(profile, m=1, tiebreak='random').get_elected()
    names = list(profile.candidates)
    onehot = np.zeros(max_cands, dtype=np.float32)

    for elected_set in elected:
        # Each elected entry is a frozenset; one member names the winner.
        winner = next(iter(elected_set))
        onehot[names.index(winner)] = 1
    return onehot

def stv_winner_vector(profile, max_cands=14):
    """
    Encode the STV result for `profile` as a 0/1 vector.

    STV is run for a single seat (m=1) with random tie-breaking.
    Each elected entry is mapped to its index in `profile.candidates`
    and that slot is set to 1.

    Returns a float32 array of length `max_cands`. An elected candidate
    whose index is >= `max_cands` makes the array assignment raise
    IndexError.
    """
    elected = STV(profile, m=1, tiebreak='random').get_elected()
    names = list(profile.candidates)
    onehot = np.zeros(max_cands, dtype=np.float32)

    for elected_set in elected:
        # Each elected entry is a frozenset; one member names the winner.
        winner = next(iter(elected_set))
        onehot[names.index(winner)] = 1
    return onehot


###################################################################
# 3. Load entire Scottish Dataset (all folders)
###################################################################

def load_scottish_dataset(base_dir, max_cands=14):
    """
    Load every Scottish election CSV under `base_dir` into tensors.

    `base_dir` is expected to contain sub-folders, each holding `.csv`
    election files. Folders and files are visited in sorted order so the
    sample order is reproducible from run to run.

    Args:
        base_dir: root directory whose immediate sub-folders are scanned.
        max_cands: padded candidate count; every feature matrix is
            zero-padded to (max_cands, max_cands) before flattening.

    Returns:
        (X, Y) float32 tensors. Each row of X is a flattened, zero-padded
        positional score matrix (max_cands * max_cands values); each row
        of Y is the plurality winner vector of length max_cands.

    Raises:
        ValueError: if an election has more than `max_cands` candidates.
    """
    X_list, Y_list = [], []

    all_folders = sorted(
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
    )

    for folder in tqdm(all_folders, desc="Folders"):
        folder_path = os.path.join(base_dir, folder)
        # os.listdir order is arbitrary; sort so the dataset order is stable.
        files = sorted(f for f in os.listdir(folder_path) if f.endswith(".csv"))

        for file in tqdm(files, desc=f"{folder}", leave=False):
            file_path = os.path.join(folder_path, file)
            profile = load_scottish(file_path)[0]

            n = len(profile.candidates)
            if n > max_cands:
                # Fail with a clear message instead of an opaque numpy
                # broadcast error when padding below.
                raise ValueError(
                    f"{file_path} has {n} candidates, "
                    f"which exceeds max_cands={max_cands}"
                )

            # Pass the size explicitly rather than relying on the helper's
            # defaults, so the matrix always matches this profile.
            M = profile_to_score_matrix(profile, max_cands=n, max_positions=n)

            padded = np.zeros((max_cands, max_cands), dtype=np.float32)
            padded[:n, :n] = M

            X_list.append(padded.flatten())
            Y_list.append(plurality_winner_vector(profile, max_cands=max_cands))

    X = torch.tensor(np.array(X_list), dtype=torch.float32)
    Y = torch.tensor(np.array(Y_list), dtype=torch.float32)
    return X, Y


###################################################################
# 4. Make DataLoaders
###################################################################

def make_dataloaders(X, Y, batch_size=64, split=0.8):
    """
    Split (X, Y) positionally and wrap both parts in DataLoaders.

    The first int(len(X) * split) samples become the training set and the
    remainder the test set; no shuffling happens before the cut.

    Returns:
        (train_loader, test_loader). The training loader reshuffles every
        epoch; the test loader keeps the original order.
    """
    cut = int(len(X) * split)

    train_loader = DataLoader(
        TensorDataset(X[:cut], Y[:cut]),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        TensorDataset(X[cut:], Y[cut:]),
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, test_loader


###################################################################
# 5. MLP Model
###################################################################

class MLP(nn.Module):
    """
    Fully connected network with two hidden layers.

    Input is a flattened max_cands x max_cands matrix; output is one raw
    logit per candidate slot (no final activation is applied).
    """

    def __init__(self, max_cands=14):
        super().__init__()
        in_features = max_cands * max_cands
        hidden = 256
        # Layers are created in forward order: Linear -> Sigmoid -> Linear
        # -> Sigmoid -> Linear.
        layers = [
            nn.Linear(in_features, hidden),
            nn.Sigmoid(),
            nn.Linear(hidden, hidden),
            nn.Sigmoid(),
            nn.Linear(hidden, max_cands),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-candidate logits for a batch of flattened inputs."""
        logits = self.net(x)
        return logits


###################################################################
# 6. Train / Accuracy
###################################################################

def train_epoch(loader, model, loss_fn, optimizer):
    """
    Run one optimisation pass over every batch in `loader`.

    Puts the model in training mode, then for each (inputs, targets)
    batch computes the loss on the model output, backpropagates and takes
    one optimizer step. Returns nothing.
    """
    model.train()
    for inputs, targets in loader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

def accuracy(model, loader):
    """
    Exact-match accuracy of `model` over all batches in `loader`.

    The model output is treated as logits: it is passed through a sigmoid
    and rounded to 0/1. A sample counts as correct only when every entry
    of its rounded prediction equals the target row.

    Args:
        model: callable module; it is switched to eval mode and left there.
        loader: iterable of (inputs, targets) batches.

    Returns:
        Fraction of correct samples as a float, or 0.0 when the loader
        yields no samples (instead of dividing by zero).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            pred = torch.round(torch.sigmoid(model(x)))
            # Row-wise comparison: all outputs must agree with the target.
            correct += (pred == y).all(dim=1).sum().item()
            total += len(y)

    if total == 0:
        return 0.0
    return correct / total


###################################################################
# 7. Run Training
###################################################################

# Root directory of the Scottish election data; its sub-folders are scanned
# for .csv files by load_scottish_dataset.
BASE_DIR = "/Users/ss2776/Downloads/scot-elex-main"   # <-- change this path

print("Loading Scottish elections...")
# Uses the loader's default max_cands (14); this has to agree with the
# max_cands given to MLP below so feature and label widths line up.
X, Y = load_scottish_dataset(BASE_DIR)

# Default split=0.8: the first 80% of samples train, the rest test.
train_loader, test_loader = make_dataloaders(X, Y, batch_size=64)

model = MLP(max_cands=14)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The model outputs raw logits, so the loss variant that takes logits is used.
loss_fn = torch.nn.BCEWithLogitsLoss()

# Per-epoch accuracy histories (one entry appended per epoch).
train_acc = []
test_acc = []
for epoch in range(20):
    train_epoch(train_loader, model, loss_fn, optimizer)
    # accuracy() switches the model to eval mode; train_epoch() switches it
    # back to train mode at the start of the next epoch.
    tr = accuracy(model, train_loader)
    train_acc.append(tr)
    te = accuracy(model, test_loader)
    test_acc.append(te)
    print(f"Epoch {epoch+1}: Train Acc={tr:.3f} | Test Acc={te:.3f}")