In [2]:
import sys
directory_paths = [
    "../../../",
    "/home/gathomp3/Deep_Learning/NeuralTangent/ntk-fed/notebooks/baselines/dfedsam/DP-FedSAM",
]

# Extend the import search path with the directories above (presumably the
# locations of the project packages and the DP-FedSAM code imported below).
# Entries already present are skipped, so re-running the cell adds no duplicates.
for directory_path in directory_paths:
    if directory_path in sys.path:
        continue
    sys.path.append(directory_path)

import copy
import time
import time
import numpy as np
import argparse
import yaml
import networkx as nx
import matplotlib.pyplot as plt

from torch.utils import data
from torch import optim

from utils.utils import *
from utils import load_config
from utils.validate import *

from fedlearning.topology import *
from fedlearning.model import *
from fedlearning.dataset import *
from fedlearning.evolve import *
from fedlearning.optimizer import GlobalUpdater, LocalUpdater, get_omegas
from fedlearning.quantizer import SqcCompressor

from fedml_api.dpfedsam.sam import SAM

In [3]:
class NumpyDataset(Dataset):
    """Map-style dataset over index-aligned sample and label containers."""

    def __init__(self, data, targets, transform=None):
        """Keep references to the samples, their labels and an optional transform.

        Args:
            data (numpy array): Array of data samples.
            targets (numpy array): Array of labels, aligned index-for-index with ``data``.
            transform (callable, optional): Applied to each sample when it is
                fetched; labels are returned untouched.
        """
        self.data, self.targets, self.transform = data, targets, transform

    def __len__(self):
        """Return the number of samples, measured on ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair stored at ``idx``."""
        sample, target = self.data[idx], self.targets[idx]

        # No usable transform configured: hand back the raw sample.
        if not self.transform:
            return sample, target

        return self.transform(sample), target

def numpy_to_tensor_transform(data):
    """Convert a numpy array into a torch tensor with ``torch.from_numpy``.

    Intended for use as the ``transform`` of ``NumpyDataset``.

    Args:
        data (numpy array): Array to convert.

    Returns:
        torch.Tensor: The tensor produced by ``torch.from_numpy(data)``.
    """
    tensor = torch.from_numpy(data)
    return tensor

In [4]:
config_file = "baseline_configs/config_dfedavgm.yaml"
config = load_config(config_file)

logger = init_logger(config)
logger.info("Loaded configuration from {}".format(config_file))
logger.info("Dataset path: {}".format(config.train_data_dir))

# Start a fresh record unless the config points at a saved one to resume from.
if config.record_path is None:
    # The model built here is only handed to init_record (e.g. to extract the
    # number of parameters for the record).
    model = init_model(config, logger)
    record = init_record(config, model)
    loaded_record = False
else:
    record = load_record(config.record_path)
    logger.info("Loaded record from {}".format(config.record_path))
    loaded_record = True

# cuDNN settings are only touched for CUDA runs.
# NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice, which
# may work against deterministic=True -- confirm which of the two is intended.
if config.device == "cuda":
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

--------------------------------------------------------------------------------
Loaded configuration from baseline_configs/config_dfedavgm.yaml
Dataset path: ../../../data/mnist/train.dat


Creating config from filepath:  baseline_configs/config_dfedavgm.yaml
/home/gathomp3/Deep_Learning/NeuralTangent/ntk-fed/notebooks/baselines/dfedsam/../../../../records/baseline_trials/dfedavgm/trial_test/train.log


In [5]:
# One integer id per client: 0, 1, ..., config.users - 1.
user_ids = np.arange(config.users)

# Load the dataset.
# The dataset object is a dictionary with keys: train_data, test_data, user_with_data.
# user_with_data is itself a dictionary mapping userID -> sampleID(s);
# in the IID setting, for example, IDs are simply assigned like 0, 1, 2, 3, ...
dataset = assign_user_data(config, logger)

# Convert the test images and labels to tensors on the configured device once,
# so every later evaluation can reuse them.
test_images, test_labels = (
    torch.from_numpy(dataset["test_data"][field]).to(config.device)
    for field in ("images", "labels")
)

Non-IID data distribution
Load user_with_data from /home/gathomp3/Deep_Learning/NeuralTangent/ntk-fed/data/user_with_data/mnist300/a0.1/user_dataidx_map_0.10_0.dat


In [6]:
# Create a dictionary of models, one per user, keyed by user id.
#
# `loaded_record` comes from the configuration cell. It is no longer
# force-reset to False here: that override made the resume branch below
# unreachable, so a loaded record could never continue from its saved models.
if loaded_record and "models" in record:
    # A record with saved models was loaded: continue training from where it left off.
    model_dict = record["models"]
elif config.same_init:
    # Same initialization for all users: build one model and give each user its own copy.
    model = init_model(config, logger)
    model_dict = {model_id: copy.deepcopy(model) for model_id in user_ids}
else:
    # Independent initialization: every user gets a separately initialised model.
    model_dict = {model_id: init_model(config, logger) for model_id in user_ids}

In [7]:
# Get zeroth round loss, acc: evaluate every client's model once before any
# training and append the results to the record.
verbose = True
if record["epoch"] == 0:
    logger.info("Logging initial loss, acc")
    client_losses = []
    client_accs = []
    # Evaluation only: run under no_grad so no autograd graph is built and the
    # values stored in the record do not keep graph memory alive.
    with torch.no_grad():
        for client_id in user_ids:
            # Evaluate the client's model on the slice of the training data
            # corresponding to the client's data
            sample_ids = dataset["user_with_data"][client_id]
            user_images = torch.from_numpy(dataset["train_data"]["images"][sample_ids]).to(config.device)
            user_labels = torch.from_numpy(dataset["train_data"]["labels"][sample_ids]).to(config.device)

            # Get model outputs
            output_on_own_data = model_dict[client_id](user_images)
            output_on_test_set = model_dict[client_id](test_images)

            # Loss is measured on the client's own data, accuracy on the shared test set
            loss = loss_with_output(output_on_own_data, user_labels, config.loss)
            acc = accuracy_with_output(output_on_test_set, test_labels)

            client_losses.append(loss)
            client_accs.append(acc)
            if verbose: logger.info("client {:d} loss {:.4f} acc {:.4f}".format(client_id, loss, acc))

            # Free the per-client tensors right away. Doing this inside the loop
            # also avoids a NameError when there are no clients at all.
            del user_images, user_labels, output_on_own_data, output_on_test_set
    # Finally, append the initial losses, accs to the record
    record["loss"].append(client_losses)
    record["testing_accuracy"].append(client_accs)

Logging initial loss, acc
client 0 loss 3.0066 acc 0.1292
client 1 loss 4.4558 acc 0.1292
client 2 loss 2.8617 acc 0.1292
client 3 loss 3.3323 acc 0.1292
client 4 loss 3.1203 acc 0.1292
client 5 loss 2.4805 acc 0.1292
client 6 loss 3.9518 acc 0.1292
client 7 loss 2.5914 acc 0.1292
client 8 loss 2.7790 acc 0.1292
client 9 loss 2.7972 acc 0.1292
client 10 loss 4.6281 acc 0.1292
client 11 loss 3.2511 acc 0.1292
client 12 loss 3.0626 acc 0.1292
client 13 loss 3.6284 acc 0.1292
client 14 loss 1.3556 acc 0.1292
client 15 loss 2.5742 acc 0.1292
client 16 loss 4.3688 acc 0.1292
client 17 loss 2.3614 acc 0.1292
client 18 loss 2.4626 acc 0.1292
client 19 loss 2.0634 acc 0.1292
client 20 loss 2.5206 acc 0.1292
client 21 loss 2.9010 acc 0.1292
client 22 loss 2.4896 acc 0.1292
client 23 loss 1.8029 acc 0.1292
client 24 loss 2.2095 acc 0.1292
client 25 loss 3.1305 acc 0.1292
client 26 loss 3.7918 acc 0.1292
client 27 loss 2.6444 acc 0.1292
client 28 loss 3.0115 acc 0.1292
client 29 loss 2.2583 acc 0

## Temporary variables

In [13]:
def train_client_dfedsam(user_model, user_id, dataset, config, logger, loss_fn,
    local_epochs = 1,
    rho = 0.5,
    adaptive = True,
    lr = 0.1,
    lr_decay = 0.998,
    momentum = 0.5,
    wd = 5e-4,
    optimizer_batch_size = 32,
    comm_round = None
    ):
    """Train one client's model in place on its own data with the SAM optimizer.

    Args:
        user_model: Model to train; its parameters are updated in place.
        user_id: Id of the client whose data is used.
        dataset: Dictionary providing ``"train_data"`` and ``"user_with_data"``.
        config: Run configuration; ``config.device`` and ``config.verbose`` are read here.
        logger: Accepted for interface compatibility; not used by this function.
        loss_fn: Criterion called as ``loss_fn(output, labels)``. ``None`` falls
            back to ``nn.CrossEntropyLoss()``.
        local_epochs: Number of passes over the client's data (default 1).
        rho: Forwarded to ``SAM``.
        adaptive: Forwarded to ``SAM``.
        lr: Base learning rate, before decay.
        lr_decay: Per-round decay factor; the effective learning rate is
            ``lr * lr_decay ** comm_round``.
        momentum: Momentum of the SGD base optimizer.
        wd: Weight decay of the SGD base optimizer.
        optimizer_batch_size: Mini-batch size of the client's data loader.
        comm_round: Communication round used for the learning-rate decay. When
            ``None``, the module-level ``comm_round`` variable is used if it
            exists (the previous implicit behaviour), otherwise 0.

    Returns:
        None.
    """
    if comm_round is None:
        # Older call sites rely on the round counter being a notebook global;
        # keep honouring it, but do not crash when it has not been defined.
        comm_round = globals().get("comm_round", 0)

    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    # Get data corresponding to a certain user
    user_resource = assign_user_resource(config, user_id,
                    dataset["train_data"], dataset["user_with_data"])

    # Define the SAM optimizer on top of SGD, with the round-decayed learning rate
    base_optimizer = torch.optim.SGD
    optimizer = SAM(user_model.parameters(), base_optimizer, rho=rho, adaptive=adaptive,
        lr=lr * (lr_decay ** comm_round), momentum=momentum, weight_decay=wd)

    # Wrap the client's numpy arrays so they can be served in shuffled mini-batches
    np_dataset = NumpyDataset(user_resource["images"], user_resource["labels"], transform=numpy_to_tensor_transform)
    user_data_loader = DataLoader(np_dataset, batch_size=optimizer_batch_size, shuffle=True)

    # Doing local_epochs number of local training rounds
    for epoch in range(local_epochs):
        epoch_loss = []
        for x, labels in user_data_loader:
            x, labels = x.to(config.device), labels.to(config.device)

            # From the SAM codebase: two forward-backward passes per batch.
            # enable_running_stats / disable_running_stats are not needed
            # because the model has no batchnorm or other moving averages.

            # first forward-backward step
            loss = loss_fn(user_model(x), labels)
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            loss_fn(user_model(x), labels).backward()
            optimizer.second_step(zero_grad=True)

            # The reported loss is the one from the first pass
            epoch_loss.append(loss.item())

        # Skip the report when the client produced no batches (nothing to average)
        if config.verbose and epoch_loss:
            print('Client Index = {}\tEpoch: {}\tLoss: {:.6f}'.format(
            user_id, epoch, sum(epoch_loss) / len(epoch_loss)))


In [14]:
# Local epochs each client trains for per communication round.
# TODO(review): confirm the intended value -- train_client_dfedsam takes a
# local_epochs argument and no value for it is visible elsewhere in this notebook.
LOCAL_EPOCHS = 1

# One criterion instance is enough for every client and round.
loss_fn_pytorch = nn.CrossEntropyLoss()

# Single round for now; the full schedule would be:
# for comm_round in range(record["epoch"],record["epoch"]+config.rounds):
for comm_round in range(1):
    logger.info(f"Comm Round: {comm_round}")
    # NOTE(review): per-round training losses are not collected, so
    # record["loss"] only ever receives the zeroth-round entry.
    client_accs = []

    # Create the graph for this round
    if config.topology == "random":
        G = create_random_graph(config.users, config.p, config.graph_name)
    elif config.topology == "ring":
        G = create_ring_graph(config.users, config.graph_name)
    elif config.topology == "regular":
        if config.p is not None:
            raise ValueError("Regular graph requires d, not p")
        elif config.d is None:
            raise ValueError("Regular graph requires d")
        G = create_regular_graph(config.users, config.d, config.graph_name)
        if config.verbose: logger.info(f"Creating regular graph with d={config.d}")
    else:
        raise ValueError("Invalid topology: {}".format(config.topology))

    # Local training: every client updates its own model in place
    for client_id in user_ids:
        train_client_dfedsam(model_dict[client_id], client_id, dataset, config, logger, loss_fn_pytorch,
        local_epochs = LOCAL_EPOCHS,
        # Dfedsam parameters
        rho = 0.5,
        adaptive = True,
        lr = 0.1,
        lr_decay = 0.998,
        momentum = 0.5,
        wd = 5e-4,
        optimizer_batch_size = 32)

    # All clients average with neighbors
    # Must find the avged weights then load weights to mimic synchronous averaging
    new_avged_weights = {}
    # Get new weights for all clients
    for client_id in user_ids:
        neighbors = list(G.neighbors(client_id))
        new_avged_weights[client_id] = average_neighbor_weights(client_id, neighbors, model_dict)
    # Load new weights for all clients
    for client_id in user_ids:
        model_dict[client_id].load_state_dict(new_avged_weights[client_id])
    del new_avged_weights

    # Now, test individual client accs
    with torch.no_grad():
        for client_id in user_ids:
            # Get client accuracy
            output_on_test_set = model_dict[client_id](test_images)
            acc = accuracy_with_output(output_on_test_set, test_labels)
            client_accs.append(acc)

    # Test the global, aggregated model
    # Note: Weighted averaging is unnecessary since all clients have the same number of samples

    # Init model to load aggregated state dict (client 0 averaged with all other clients)
    temp_global_model = init_model(config, logger)
    temp_global_model.load_state_dict(average_neighbor_weights(0, user_ids[1:], model_dict))

    # Evaluation only, so no autograd graph is needed here either
    with torch.no_grad():
        global_output = temp_global_model(test_images)
        global_loss = loss_fn_pytorch(global_output, test_labels)
        global_acc = accuracy_with_output(global_output, test_labels)

    if comm_round % 5 == 0:
        logger.info(f"Round {comm_round}: Test Loss: {global_loss.item()}, Avg Client Acc: {np.mean(client_accs)}, Agg Acc: {global_acc}")

    # Record the results
    record["testing_accuracy"].append(client_accs)

    if 'aggregated_accs' in record:
        record['aggregated_accs'].append(global_acc)
    else:
        record['aggregated_accs'] = [global_acc]

    record["epoch"] += 1
    


Comm Round: 0
Round 0: Test Loss: 0.40996190905570984, Avg Client Acc: 0.8256143110990525, Agg Acc: 0.8858000040054321


Here, I write the local training function, which takes the following inputs:
- user_model
- user_id
- dataset
- config
- logger
- loss_fn
- local_epochs
- other optimizer parameters (IMPORTANT FOR DFEDSAM comparison)

Input variables corresponding to a certain user

In [19]:
# Pick a single client (id 0) and look up its model in model_dict.
user_id = 0
user_model = model_dict[user_id]