In [None]:
# Colab-only setup: mount Google Drive at /content/drive so that the client
# partition files read by load_data_from_disk (under /content/drive/MyDrive/...)
# are visible to this runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import gzip
import math
import os
import random
import secrets
from functools import reduce
from math import ceil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets


# -------------------------------------------------------------------
# 1. Data Loading (from loadDataSet.ipynb)
# -------------------------------------------------------------------
# Copy-paste or import this function from your notebook
def load_data_from_disk(partition_id: int, only_server_test_data: bool = False):
    """
    Build the DataLoaders used by the federated simulation.

    Args:
        partition_id: index of the client partition to load. Ignored when
            `only_server_test_data` is True.
        only_server_test_data: when True, return a single DataLoader over the
            CIFAR-10 test split (downloaded into ./data if missing).

    Returns:
        A single DataLoader when `only_server_test_data` is True, otherwise a
        `(trainloader, valloader)` pair built from an 80/20 split of the
        shuffled client partition.

    Raises:
        FileNotFoundError: if neither candidate partition file exists.

    NOTE(review): `torch.load` unpickles the partition file, so only load
    files from a trusted location.
    """
    import torchvision
    from torch.utils.data import DataLoader

    backup_root = "/content/drive/MyDrive/client_data_backup2"
    batch_size = 64

    if only_server_test_data:
        cifar_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
        ])
        server_testset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=cifar_transform
        )
        return DataLoader(server_testset, batch_size=batch_size, num_workers=8)

    # Two naming schemes are tried in order; the first existing file wins.
    candidate_paths = (
        os.path.join(backup_root, f'client_{partition_id}.pt.gz'),
        os.path.join(backup_root, f'iid_clients_{partition_id}.pt.gz'),
    )
    partition_path = next((p for p in candidate_paths if os.path.exists(p)), None)
    if partition_path is None:
        raise FileNotFoundError(f"No data file for client {partition_id}")

    with gzip.open(partition_path, 'rb') as handle:
        samples = torch.load(handle, map_location='cpu')

    # Cast inputs to float32, then shuffle in place with NumPy's global RNG.
    samples = [(x.to(torch.float32), y) for x, y in samples]
    np.random.shuffle(samples)

    cut = int(len(samples) * 0.8)
    normalize = transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    train_part = [(normalize(image), label) for image, label in samples[:cut]]
    val_part = [(normalize(image), label) for image, label in samples[cut:]]

    trainloader = DataLoader(train_part, batch_size=batch_size, shuffle=True, num_workers=1)
    valloader = DataLoader(val_part, batch_size=batch_size, num_workers=1)
    return trainloader, valloader








In [None]:
# [1]

# -------------------------------------------------------------------
# 2. QILSA Mask Generation Parameters & Utilities
# -------------------------------------------------------------------
# Ring dimension and modulus
N = 1024  # polynomial length; ring_mul reduces modulo x^N + 1 and masks are tiled in blocks of N
q = 12289  # modulus for all mask / update arithmetic in this notebook
sigma = 8/np.sqrt(2*np.pi)  # std-dev handed to np.random.normal in sample_gaussian (about 3.19)

def sample_A():
    """
    Draw the public N x N matrix with i.i.d. uniform integer entries in [0, q).

    Uses NumPy's global RNG and returns an int64 array of shape (N, N).
    """
    shape = (N, N)
    return np.random.randint(low=0, high=q, size=shape, dtype=np.int64)

# Public matrix, sampled once when this cell runs; FedClient.generate_mask reads
# this module-level value. No RNG seed is set in the visible cells, so it differs per run.
A = sample_A()

def sample_gaussian(size):
    """
    Sample rounded Gaussian noise and reduce it into [0, q).

    Draws `size` values from N(0, sigma^2) with NumPy's global RNG, rounds each
    to the nearest integer, and returns them as int64 residues modulo q
    (negative draws wrap around to q - |value|).
    """
    real_draws = np.random.normal(scale=sigma, size=size)
    rounded = np.round(real_draws).astype(np.int64)
    return rounded % q

def ring_mul(a, b):
    """
    Multiply two polynomials a, b in Z_q[x]/(x^N + 1) (negacyclic convolution).

    Both inputs are length-N coefficient vectors, lowest degree first. The
    result is the length-N coefficient vector of a(x) * b(x), with x^N folded
    back as -1 and every coefficient reduced into [0, q).

    The previous implementation sliced `b` in the wrong direction, which
    produced a(x^-1) * b(x) instead of the product (e.g. x * 1 came out as
    -x^(N-1)); this version computes the true product.

    Args:
        a, b: array-likes of N integers.

    Returns:
        np.ndarray of dtype int64 and shape (N,).

    Raises:
        ValueError: if either input does not have exactly N coefficients.
    """
    a_red = np.asarray(a, dtype=np.int64) % q
    b_red = np.asarray(b, dtype=np.int64) % q
    if a_red.shape != (N,) or b_red.shape != (N,):
        raise ValueError(f"ring_mul expects two length-{N} coefficient vectors")

    # Plain polynomial product: degree up to 2N-2, i.e. 2N-1 coefficients.
    # With reduced inputs each coefficient is below N * q^2, which fits
    # comfortably in int64 for the parameters used in this notebook.
    full = np.convolve(a_red, b_red)

    # Fold the high half back with a sign flip because x^N = -1.
    res = full[:N].copy()
    res[:N - 1] -= full[N:]
    return res % q

In [None]:
# [2]

# -------------------------------------------------------------------
# 3. Verifiable Secret Sharing (Shamir over large prime)
# -------------------------------------------------------------------
P = 2**61 - 1  # Mersenne prime used as the field modulus for Shamir shares; lagrange_interpolate_zero inverts via pow(x, P-2, P), which needs P prime

def _eval_poly(coeffs, x):
    res = 0
    for c in reversed(coeffs):
        res = (res * x + c) % P
    return res

def share_secret(secret, n, t):
    """
    Split integer `secret` into n Shamir shares with threshold t.
    Returns list of (i, share_i).
    """
    coeffs = [secret] + [random.randrange(0, P) for _ in range(t-1)]
    return [(i, _eval_poly(coeffs, i)) for i in range(1, n+1)]

def lagrange_interpolate_zero(shares):
    """
    Recover f(0) modulo P from (x, f(x)) pairs by Lagrange interpolation.

    `shares` is an iterable-of-pairs that is traversed more than once, so it
    must be a re-iterable collection such as a list. An empty collection
    yields 0. Modular inverses are taken with Fermat's little theorem.
    """
    acc = 0
    for x_j, y_j in shares:
        numer = 1
        denom = 1
        for x_m, _unused in shares:
            if x_m == x_j:
                continue
            # Basis polynomial at 0: product of (0 - x_m) / (x_j - x_m).
            numer = (numer * -x_m) % P
            denom = (denom * (x_j - x_m)) % P
        inverse = pow(denom, P - 2, P)
        acc = (acc + y_j * numer * inverse) % P
    return acc

In [None]:

# -------------------------------------------------------------------
# 4. Model & Parameter Vector Helpers
# -------------------------------------------------------------------
class SimpleCNN(nn.Module):
    """
    Small two-stage convolutional classifier with 10 output logits.

    The first linear layer expects 64 * 8 * 8 features, which is what two
    2x2 max-pools produce from a 3 x 32 x 32 input.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: (conv -> ReLU -> 2x2 max-pool) twice.
        feature_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        # Classifier head on the flattened feature map.
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.conv(x)
        return self.fc(features)

def get_model_vector(model):
    """
    Flatten every parameter of `model` into one int64 vector of residues mod q.

    Parameters are concatenated in `model.parameters()` order, cast to int64
    (which truncates toward zero) and reduced modulo q.

    NOTE(review): the integer cast discards the fractional part of every
    weight, so typical sub-unit float weights all encode as 0; a fixed-point
    scale is probably intended here - confirm before relying on the encoding.
    """
    flat_parts = []
    for param in model.parameters():
        flat_parts.append(param.data.view(-1))
    flat = torch.cat(flat_parts).cpu().numpy()
    return flat.astype(np.int64) % q

def set_model_vector(model, vec):
    """
    Overwrite the parameters of `model` in place from the flat NumPy vector `vec`.

    Consecutive slices of `vec` are consumed in `model.parameters()` order,
    reshaped to each parameter's shape and cast to that parameter's dtype.
    """
    offset = 0
    for param in model.parameters():
        stop = offset + param.numel()
        chunk = vec[offset:stop].reshape(param.shape)
        param.data.copy_(torch.from_numpy(chunk).to(param.dtype))
        offset = stop

In [None]:

# -------------------------------------------------------------------
# 5. Federated Client Implementation
# -------------------------------------------------------------------
class FedClient:
    """
    Simulated federated-learning client.

    Each client trains a private copy of the global model on its own data
    partition, encodes the resulting weight delta modulo q, and hides it
    behind a noisy ring mask before handing it to the server.
    """

    def __init__(self, client_id, global_model):
        """
        Load this client's data and clone the global model.

        Args:
            client_id: partition index passed to `load_data_from_disk`.
            global_model: model whose state dict initialises the local copy.
        """
        self.id = client_id
        self.device = torch.device('cpu')
        # Only the training part of the client's partition is kept.
        self.trainloader, _ = load_data_from_disk(client_id, False)
        # Validation uses the shared CIFAR-10 test split, not the client's own
        # held-out 20%, so the printed accuracies are comparable across clients.
        self.valloader = load_data_from_disk(0, True)
        self.model = SimpleCNN().to(self.device)
        self.model.load_state_dict(global_model.state_dict())

    def validate(self):
        """
        Print and return the local model's accuracy (in percent) on `self.valloader`.

        Returns 0.0 when the loader yields no samples instead of dividing by zero.
        """
        correct, total = 0, 0
        self.model.eval()
        with torch.no_grad():
            for x, y in self.valloader:
                # Keep inputs on the model's device, as local_train does.
                x, y = x.to(self.device), y.to(self.device)
                out = self.model(x)
                _, pred = torch.max(out, dim=1)
                correct += (pred == y).sum().item()
                total   += y.size(0)
        acc = 100 * correct / total if total else 0.0
        print(f" Client {self.id} validation accuracy: {acc:.2f}%")
        return acc

    def local_train(self, epochs=6, lr=0.01):
        """Run `epochs` passes of plain SGD with cross-entropy loss over `self.trainloader`."""
        optim_ = optim.SGD(self.model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        print(f"Client {self.id} training...")
        self.model.train()
        for _ in range(epochs):
            for x,y in self.trainloader:
                x,y = x.to(self.device), y.to(self.device)
                optim_.zero_grad()
                loss_fn(self.model(x), y).backward()
                optim_.step()

    def get_update(self, global_vec):
        """Return (encoded local weights - `global_vec`) reduced modulo q."""
        local_vec = get_model_vector(self.model)
        return (local_vec - global_vec) % q

    def generate_mask(self):
        """
        Build this client's length-N mask polynomial.

        Samples s and e with `sample_gaussian`, then sets
        mask_poly = sum_over_rows ring_mul(A[row], s) + e (mod q).

        Ring multiplication is linear in its first argument, so the sum over
        rows equals a single product with the column-wise sum of A; this
        avoids N separate ring multiplications without changing the result.
        """
        self.s = sample_gaussian(N)
        self.e = sample_gaussian(N)
        row_sum = A.sum(axis=0) % q
        As = ring_mul(row_sum, self.s) % q
        self.mask_poly = (As + self.e) % q

    def generate_vss_shares(self, selected_ids, t):
        """
        Shamir-share a fresh random scalar r among the selected clients.

        `self.shares` maps `selected_ids[i]` to the share evaluated at
        x = i + 1 (the x-coordinate itself is not stored), so whoever
        reconstructs must use positions in `selected_ids`, not client ids.
        """
        # r is secret material: draw it from the OS CSPRNG.
        self.r = secrets.randbelow(P)
        shares = share_secret(self.r, len(selected_ids), t)
        self.shares = { selected_ids[i]: shares[i][1]
                        for i in range(len(selected_ids)) }

    def mask_update(self, update_vec):
        """
        Mask the flattened update with the tiled mask polynomial.

        `mask_poly` is repeated to cover all `d` coordinates and added to
        `update_vec` modulo q. The scalar r is returned alongside the masked
        vector but is NOT mixed into it.

        NOTE(review): returning r hands the client's shared secret to the
        caller; the visible server code ignores it.

        Returns:
            (masked_update, r) where masked_update is an int64 array of
            length d with entries in [0, q).
        """
        d = update_vec.shape[0]
        blocks = int(np.ceil(d / N))
        tiled_mask = np.tile(self.mask_poly, blocks)[:d]
        residues = (update_vec % q).astype(np.int64)
        masked = (residues + tiled_mask) % q
        return masked, self.r

In [None]:

# -------------------------------------------------------------------
# 6. Federated Server Implementation
# -------------------------------------------------------------------
class FedServer:
    """
    Simulated federated-learning server.

    Each round it samples clients, lets them train and mask their updates,
    sums the masked updates modulo q, removes the summed masks and applies the
    averaged update to the global model.
    """

    def __init__(self, num_clients=50, local_epochs=6,
                 rounds=10, threshold=None, clients_per_round=5):
        """
        Args:
            num_clients: size of the client pool (ids 0..num_clients-1).
            local_epochs: epochs each selected client trains per round.
            rounds: number of global rounds run by `train`.
            threshold: Shamir threshold; defaults to `num_clients` and is
                clamped to the number of selected clients each round.
            clients_per_round: how many clients are sampled per round
                (clamped to the pool size).
        """
        self.global_model = SimpleCNN()
        self.num_clients = num_clients
        self.local_epochs = local_epochs
        self.rounds = rounds
        self.clients = list(range(num_clients))
        self.threshold = threshold or num_clients
        self.clients_per_round = clients_per_round
        self.testloader = load_data_from_disk(0, True)

    @staticmethod
    def _center_lift(residues, modulus):
        """Map residues in [0, modulus) to signed representatives in (-modulus/2, modulus/2]."""
        return np.where(residues > modulus // 2, residues - modulus, residues)

    def select_clients(self):
        """Sample up to `clients_per_round` distinct client ids without replacement."""
        k = min(self.clients_per_round, len(self.clients))
        return random.sample(self.clients, k=k)

    def aggregate_round(self):
        """
        Run one federated round and update `self.global_model` in place.

        NOTE(review): this is a simulation - the server reads each client's
        `mask_poly` and `r` directly instead of recovering them
        cryptographically.
        NOTE(review): `get_model_vector` truncates weights to integers, so the
        aggregated delta is very coarse; confirm the intended fixed-point
        encoding.
        """
        print("Aggregating round...")
        selected = self.select_clients()
        if not selected:
            print(" No clients available; global model left unchanged.")
            return
        n_selected = len(selected)
        # A threshold above the number of share holders could never be met.
        t = max(1, min(self.threshold, n_selected))

        # 1) Each client trains, masks, and Shamir-shares
        globals_vec = get_model_vector(self.global_model)
        client_objs = []
        masked_updates = []
        sum_shares = {cid: 0 for cid in selected}

        for cid in selected:
            c = FedClient(cid, self.global_model)
            c.local_train(self.local_epochs)
            c.validate()
            c.generate_mask()
            c.generate_vss_shares(selected, t)

            update = c.get_update(globals_vec)
            masked_u, _ = c.mask_update(update)
            masked_updates.append(masked_u)

            # accumulate each client's share
            for partner, share_val in c.shares.items():
                sum_shares[partner] = (sum_shares[partner] + share_val) % P

            client_objs.append(c)

        # 2) Server sums masked updates (mod q)
        summed_masked = reduce(lambda a,b: (a+b)%q, masked_updates)

        # 3) Reconstruct the sum of the clients' scalars from t summed shares.
        #    The share held for selected[i] was evaluated at x = i + 1, so the
        #    interpolation points are positions, not client ids.
        share_items = [(pos + 1, sum_shares[cid])
                       for pos, cid in enumerate(selected)][:t]
        R_sum = lagrange_interpolate_zero(share_items)
        #    The scalars are never added to the masked updates, so R_sum must
        #    not be subtracted from them; it only serves as an integrity check.
        if R_sum != sum(c.r for c in client_objs) % P:
            raise RuntimeError("Shamir reconstruction does not match the sum of client secrets")

        # 4) Remove the summed ring masks, tiled to the update length.
        total_mask = sum(c.mask_poly for c in client_objs) % q  # shape: (N,)
        d           = summed_masked.shape[0]
        blocks      = math.ceil(d / N)
        padded_mask = np.tile(total_mask, blocks)[:d]         # shape: (d,)
        unmasked = (summed_masked - padded_mask) % q

        # 5) The unmasked vector is the SUM of client deltas as residues mod q.
        #    Lift to signed values, average, and add to the current weights.
        mean_delta = self._center_lift(unmasked, q).astype(np.float64) / n_selected
        current = torch.cat(
            [p.data.view(-1) for p in self.global_model.parameters()]
        ).cpu().numpy().astype(np.float64)
        set_model_vector(self.global_model, current + mean_delta)

    def evaluate_global(self):
        """
        Print and return the global model's accuracy (in percent) on `self.testloader`.

        Returns 0.0 when the loader yields no samples instead of dividing by zero.
        """
        correct, total = 0, 0
        self.global_model.eval()
        with torch.no_grad():
            for x, y in self.testloader:
                out = self.global_model(x)
                _, pred = torch.max(out, dim=1)
                correct += (pred == y).sum().item()
                total   += y.size(0)
        acc = 100 * correct / total if total else 0.0
        print(f" Global model test accuracy: {acc:.2f}%")
        return acc

    def train(self):
        """Run `self.rounds` aggregation rounds, evaluating the global model after each."""
        for rnd in range(1, self.rounds+1):
            print(f"--- Global Round {rnd} ---")
            self.aggregate_round()
            self.evaluate_global()
        print("Federated training complete.")


In [None]:
# -------------------------------------------------------------------
# 7. Run
# -------------------------------------------------------------------
if __name__ == "__main__":
    # Short demonstration run: a pool of 50 clients, 6 local epochs per
    # selected client, 2 global rounds, Shamir threshold 2.
    server = FedServer(num_clients=50, local_epochs=6, rounds=2, threshold=2)
    server.train()

--- Global Round 1 ---
Aggregating round...
Client 17 training...
 Client 17 validation accuracy: 16.95%
Client 40 training...
 Client 40 validation accuracy: 24.05%
Client 2 training...
 Client 2 validation accuracy: 11.05%
Client 16 training...
 Client 16 validation accuracy: 16.14%
Client 42 training...
 Client 42 validation accuracy: 22.44%
 Global model test accuracy: 10.00%
--- Global Round 2 ---
Aggregating round...
Client 36 training...
 Client 36 validation accuracy: 10.00%


  return torch.cat([p.data.view(-1) for p in model.parameters()]).cpu().numpy().astype(np.int64) % q


Client 4 training...
 Client 4 validation accuracy: 10.00%
Client 33 training...
 Client 33 validation accuracy: 10.00%
Client 34 training...
 Client 34 validation accuracy: 10.00%
Client 41 training...
 Client 41 validation accuracy: 10.00%
 Global model test accuracy: 10.00%
Federated training complete.
