In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset


from typing import Sequence
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

import networkx as nx
import pycxsimulator
from parse_args import parse_arguments

Code to create a dataset given the labels it should contain:

In [31]:
def create_ds_from_labels(primary_label:int,
                          secondary_labels:Sequence[int],
                          label_indexes_list,
                          total_ds_len:int,
                          primary_label_fraction:float,
                          original_ds:Dataset
                          ):
    """Build a dataset dominated by one label and padded with secondary labels.

    Indexes are drawn *without replacement* from the per-label pools in
    ``label_indexes_list`` and removed from those pools, so no sample appears
    twice in the returned dataset and no drawn sample can be handed out again.

    Args:
        primary_label: Label that receives ``primary_label_fraction`` of the samples.
        secondary_labels: Labels that share the remaining fraction equally.
        label_indexes_list: Per-label arrays of indexes into ``original_ds``.
            It is updated in place (drawn indexes are removed) and also returned.
        total_ds_len: Target size of the dataset; each per-label count is
            truncated with ``int``, so the real size can be slightly smaller.
        primary_label_fraction: Fraction of ``total_ds_len`` taken from ``primary_label``.
        original_ds: Dataset the indexes refer to.

    Returns:
        ``(dataset, label_indexes_list)`` where ``dataset`` is a ``ConcatDataset``
        holding the secondary-label subsets first and the primary-label subset last.

    Raises:
        ValueError: If a label pool holds fewer indexes than must be drawn from it.
    """

    def _draw_without_replacement(label, count):
        # Pick `count` distinct entries of the label's pool and drop them from it.
        pool = np.asarray(label_indexes_list[label])
        if count > len(pool):
            raise ValueError(
                f"label {label} has only {len(pool)} samples left, {count} requested"
            )
        if count <= 0:
            return pool[:0]
        positions = np.random.choice(len(pool), size=count, replace=False)
        label_indexes_list[label] = np.delete(pool, positions)
        return pool[positions]

    # Find the primary label elements for dataset
    primary_label_elements_num = int(primary_label_fraction * total_ds_len)
    primary_label_idxs = _draw_without_replacement(primary_label, primary_label_elements_num)
    primary_label_subset = Subset(original_ds, primary_label_idxs)

    # Find the secondary label(s) elements for the dataset
    secondary_labels_subsets = []
    if len(secondary_labels) > 0:
        # Guarded so that an empty secondary list does not divide by zero.
        secondary_label_elements_frac = ((1-primary_label_fraction)/len(secondary_labels))
        secondary_label_elements_num = int(secondary_label_elements_frac*total_ds_len)
        for label in secondary_labels:
            selected_indexes = _draw_without_replacement(label, secondary_label_elements_num)
            secondary_labels_subsets.append(Subset(original_ds, selected_indexes))

    # The primary subset goes last, after all secondary subsets.
    secondary_labels_subsets += [primary_label_subset]

    return ConcatDataset(secondary_labels_subsets), label_indexes_list

Creating a pseudo-random (uniformly sampled) set of primary and secondary labels:

In [10]:
def has_enough_samples(label_idx_list, n_labels_needed, label_chosen):
    """Tell whether ``label_chosen`` still has at least ``n_labels_needed`` indexes.

    ``label_idx_list`` is indexed by label; the length of the selected entry is
    compared against the requested amount.
    """
    remaining = len(label_idx_list[label_chosen])
    return bool(remaining >= n_labels_needed)

def generate_random_label_set(label_idx_list, primary_dataset_len, secondary_dataset_len, num_secondaries):
    """Pick one primary label and ``num_secondaries`` distinct secondary labels.

    A label qualifies as primary when its pool holds at least
    ``primary_dataset_len`` indexes, and as secondary when it holds at least
    ``secondary_dataset_len`` indexes. Labels are drawn uniformly among the
    qualifying ones, and the primary label is never reused as a secondary.

    Args:
        label_idx_list: Per-label collections of still-available indexes; the
            label set is ``range(len(label_idx_list))``.
        primary_dataset_len: Number of samples required from the primary label.
        secondary_dataset_len: Number of samples required from each secondary label.
        num_secondaries: How many secondary labels to pick.

    Returns:
        A list of ints: the primary label first, followed by the secondary labels.

    Raises:
        ValueError: If no label can serve as primary, or fewer than
            ``num_secondaries`` other labels can serve as secondary. Failing here
            replaces a rejection-sampling loop that would otherwise never end.
    """
    available = [len(indexes) for indexes in label_idx_list]

    primary_candidates = [label for label, count in enumerate(available)
                          if count >= primary_dataset_len]
    if not primary_candidates:
        raise ValueError(
            f"no label has the {primary_dataset_len} samples needed for a primary label"
        )
    primary_label = int(np.random.choice(primary_candidates))

    if num_secondaries <= 0:
        return [primary_label]

    secondary_candidates = [label for label, count in enumerate(available)
                            if label != primary_label and count >= secondary_dataset_len]
    if len(secondary_candidates) < num_secondaries:
        raise ValueError(
            f"only {len(secondary_candidates)} labels have the {secondary_dataset_len} "
            f"samples needed for a secondary label, {num_secondaries} requested"
        )
    chosen_secondaries = np.random.choice(secondary_candidates, size=num_secondaries, replace=False)

    return [primary_label] + [int(label) for label in chosen_secondaries]


Creating a client's dataset by:
1. Generating a pseudo-random set of labels for it to contain as primary and secondary labels
2. Generating a random subset of the original MNIST dataset given the labels and the subset size

In [29]:
def create_client_ds(original_ds, label_idx_list, total_ds_len, primary_label_fraction, num_secondaries):
    """Assemble one client's dataset from a randomly chosen label set.

    The per-label sample counts are derived from ``total_ds_len`` and
    ``primary_label_fraction`` (the remaining fraction is split evenly over
    ``num_secondaries`` labels), a label set is drawn with
    ``generate_random_label_set``, and the dataset itself is built by
    ``create_ds_from_labels``.

    Returns:
        ``(client_ds, label_idx_list)`` exactly as produced by ``create_ds_from_labels``.
    """
    n_primary = int(primary_label_fraction * total_ds_len)
    n_per_secondary = int(((1-primary_label_fraction)/num_secondaries) * total_ds_len)

    # First entry is the primary label, the rest are the secondary labels.
    chosen_labels = generate_random_label_set(
        label_idx_list, n_primary, n_per_secondary, num_secondaries
    )
    primary_label = chosen_labels[0]
    secondary_labels = chosen_labels[1:]

    dataset, updated_label_idx_list = create_ds_from_labels(
        primary_label,
        secondary_labels,
        label_idx_list,
        total_ds_len,
        primary_label_fraction,
        original_ds,
    )
    return dataset, updated_label_idx_list

Creating a simple model:

In [43]:
class simpleModel(nn.Module):
    """Three-layer fully connected network: 28*28 -> 128 -> 64 -> 10.

    The two hidden layers use ReLU; the output is a log-softmax over dim 1.
    """

    def __init__(self) -> None:
        super().__init__()

        # The attribute names double as state_dict keys, so they must stay as they are.
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return log-probabilities for a batch of flattened inputs."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))

        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)

Defining a client that
- contains a dataset and a model
- has a method defined for one epoch of local training
- has a method to validate on a given dataloader

In [None]:
class MNIST_client():
    """A single federated-learning client: a private dataset plus a local model.

    ``history`` collects one entry per call of ``train_one_round`` (keys
    ``train_loss`` / ``train_acc``) and of ``validate_model`` (keys
    ``val_loss`` / ``val_acc``). Accuracies are percentages.
    """

    def __init__(self, original_ds, client_id, label_idx_list, args):
        """Create the client's dataset, dataloader, model, optimizer and loss.

        Args:
            original_ds: Dataset the client's samples are drawn from.
            client_id: Identifier stored on the client.
            label_idx_list: Per-label index pools handed to ``create_client_ds``.
            args: Object exposing ``total_ds_len``, ``primary_label_fraction``,
                ``num_secondaries``, ``batch_size``, ``shuffle_dataset``,
                ``saved_model_path``, ``sgd_per_round`` and ``device``.
        """
        self.client_id = client_id
        self.device = args.device
        self.sgd_per_round = args.sgd_per_round

        # The second return value is the same pool list that was passed in.
        self.dataset, _ = create_client_ds(original_ds, label_idx_list, args.total_ds_len, args.primary_label_fraction, args.num_secondaries)
        self.dataloader = DataLoader(dataset=self.dataset, batch_size=args.batch_size, shuffle=args.shuffle_dataset)

        self.model = simpleModel()
        weights_path = os.path.join(args.saved_model_path, 'initial_model_weights.pth')
        # Load onto the CPU first so the checkpoint is readable whatever device it
        # was saved from, then move the model to where the batches will be sent.
        self.model.load_state_dict(torch.load(weights_path, map_location="cpu"))
        self.model.to(self.device)

        self.history = {"train_acc":[], "train_loss":[], "val_acc":[], "val_loss":[]}
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.5)
        self.criterion = nn.NLLLoss()

    def train_one_round(self):
        """Run at most ``sgd_per_round`` SGD steps over the client's dataloader.

        Appends the mean batch loss and the accuracy (in percent) over the
        batches actually processed to ``history``. When no batch is processed,
        ``nan`` is recorded for both so the history stays aligned with rounds.
        """
        self.model.train()
        train_correct = 0
        train_total = 0
        running_loss = 0.0
        steps_done = 0
        for step, (data, labels) in enumerate(self.dataloader):

            # `step` is zero-based: stopping at `>=` gives exactly sgd_per_round steps.
            if step >= self.sgd_per_round:
                break

            data = data.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Flatten every sample to a vector before the fully connected model.
            output = self.model(data.view(data.shape[0], -1))
            loss = self.criterion(output, labels)

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, pred = torch.max(output, dim=1)
            train_correct += torch.sum(pred==labels).item()
            train_total += labels.size(0)
            steps_done += 1

        if steps_done == 0 or train_total == 0:
            self.history['train_loss'].append(float('nan'))
            self.history['train_acc'].append(float('nan'))
            return

        # Average over the steps that ran, not over the full dataloader length.
        self.history['train_loss'].append(running_loss/steps_done)
        self.history['train_acc'].append(100*train_correct/train_total)

    def validate_model(self, val_loader):
        """Evaluate the local model on ``val_loader`` without gradient tracking.

        Appends the mean batch loss and the accuracy (in percent) to
        ``history``; ``nan`` is recorded for both when the loader yields nothing.
        """
        self.model.eval()

        running_loss = 0.0
        correct = 0
        total = 0
        batches_seen = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(self.device)
                labels = labels.to(self.device)

                output = self.model(data.view(data.shape[0], -1))
                loss = self.criterion(output, labels)

                running_loss += loss.item()
                _, pred = torch.max(output, dim=1)
                correct += torch.sum(pred==labels).item()
                total += labels.size(0)
                batches_seen += 1

        if batches_seen == 0 or total == 0:
            self.history['val_loss'].append(float('nan'))
            self.history['val_acc'].append(float('nan'))
            return

        self.history['val_loss'].append(running_loss/batches_seen)
        self.history['val_acc'].append(100*correct/total)
    
    
        

In [None]:
class Network():
    '''
    Holds the MNIST train/test sets, the per-label index pools, the global
    model and the list of clients.
    '''

    # Number of label pools built from the training targets (labels 0..9).
    NUM_LABELS = 10

    def __init__(self, args) -> None:
        '''
        Initializing the network parameters given arguments 'args'
        See parse_args.py for further information on the available options.

        Attributes read from 'args' here: data_path, saved_model_path and
        num_clients; 'args' is also forwarded unchanged to every MNIST_client.
        '''

        mnist_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        self.train_set = torchvision.datasets.MNIST(args.data_path, download=True, train=True, transform=mnist_transforms)
        self.test_set = torchvision.datasets.MNIST(args.data_path, download=True, train=False, transform=mnist_transforms)

        # Build the per-label index pools straight from the label tensor. Iterating
        # the dataset would decode and normalise every image just to read its
        # label, and growing arrays with np.append copies them on every step.
        train_targets = np.asarray(self.train_set.targets)
        self.label_idxs = [
            np.flatnonzero(train_targets == label).astype(int)
            for label in range(self.NUM_LABELS)
        ]

        # Initializing the global model
        self.global_model = simpleModel()
        initial_weights_path = os.path.join(args.saved_model_path, 'initial_model_weights.pth')
        # map_location keeps the checkpoint loadable on machines without the
        # device it was saved from.
        self.global_model.load_state_dict(torch.load(initial_weights_path, map_location="cpu"))

        # Initializing the clients
        # Every client receives the same pool list; indexes drawn for one client
        # are removed from it, so later clients draw from what is left.
        self.clients = [
            MNIST_client(self.train_set, client_id, self.label_idxs, args)
            for client_id in range(args.num_clients)
        ]

        # Report the size of each client's dataset.
        for client in self.clients:
            print(len(client.dataset))
        
        
            

            


In [None]:
# Build the run configuration with parse_arguments (imported from parse_args).
# NOTE(review): if parse_arguments reads sys.argv through argparse, the extra
# arguments a Jupyter kernel is started with may make it fail — confirm.
running_arguments = parse_arguments()