In [1]:
import sys
# Parent directory (relative to the current working directory) that holds the
# project-local packages imported below (utils, fedlearning).
directory_path = "../"
if directory_path not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(directory_path)

import copy
import time
import time
import numpy as np
import argparse
import yaml
import networkx as nx
import matplotlib.pyplot as plt

from torch.utils import data

from utils.utils import *
from utils import load_config
from utils.validate import *
from fedlearning.model import *
from fedlearning.dataset import *
from fedlearning.evolve import *
from fedlearning.optimizer import GlobalUpdater, LocalUpdater, get_omegas

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [2]:
class NumpyDataset(Dataset):
    """Map-style dataset over parallel sample / label collections.

    Item ``i`` is ``(transform(data[i]), targets[i])`` when a transform is set,
    otherwise ``(data[i], targets[i])``.
    """

    def __init__(self, data, targets, transform=None):
        """
        Args:
            data: indexable collection of samples (e.g. a numpy array).
            targets: indexable collection of labels aligned with ``data``.
            transform (callable, optional): applied to each sample on access.
        """
        self.data, self.targets, self.transform = data, targets, transform

    def __len__(self):
        # The length is taken from the samples; targets are assumed to align.
        return len(self.data)

    def __getitem__(self, idx):
        sample, target = self.data[idx], self.targets[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

def numpy_to_tensor_transform(data):
    """Convert a numpy array into a torch tensor with ``torch.from_numpy``.

    No copy is made: the returned tensor shares memory with ``data``.
    """
    tensor = torch.from_numpy(data)
    return tensor

def self_train(user_model, user_id, dataset, config, logger, loss_fn, batch_size=32, epochs=1, lr=0.001, verbose=False):
    """Train ``user_model`` in place on the local data of a single user.

    Args:
        user_model: torch module to train; its parameters are updated in place
            with plain SGD.
        user_id: identifier forwarded to ``assign_user_resource`` to select
            this user's data.
        dataset: mapping providing at least ``"train_data"`` and
            ``"user_with_data"`` (both forwarded to ``assign_user_resource``).
        config: configuration object; batches are moved to ``config.device``.
        logger: logger used for progress messages when ``verbose`` is true.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        batch_size: mini-batch size of the local DataLoader.
        epochs: number of passes over the user's data.
        lr: SGD learning rate.
        verbose: if true, log progress every 100 batches and print a trailing
            blank line.
    """
    # Get the data belonging to this user.
    user_resource = assign_user_resource(config, user_id,
                        dataset["train_data"], dataset["user_with_data"])

    optimizer = optim.SGD(user_model.parameters(), lr=lr)
    # Distinct name so the ``dataset`` argument is not shadowed.
    user_dataset = NumpyDataset(user_resource["images"], user_resource["labels"],
                                transform=numpy_to_tensor_transform)
    user_data_loader = DataLoader(user_dataset, batch_size=batch_size, shuffle=True)

    # Loop-invariant values used only in the progress message.
    num_samples = len(user_dataset)
    num_batches = len(user_data_loader)

    # The caller may have left the model in eval mode; switch to training mode
    # so layers such as dropout / batch-norm behave as they should while training.
    user_model.train()

    for epoch in range(epochs):
        for batch_idx, (inputs, target) in enumerate(user_data_loader):
            inputs, target = inputs.to(config.device), target.to(config.device)

            optimizer.zero_grad()
            output = user_model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            if verbose and batch_idx % 100 == 0:
                logger.info(f'Train Epoch: {epoch} [{batch_idx * len(inputs)}/{num_samples} ({100. * batch_idx / num_batches:.0f}%)]\tLoss: {loss.item():.6f}')
    if verbose:
        print()

def _draw_and_save_graph(G, graph_name):
    """Draw ``G`` with node labels on a fresh figure, save it to ``graph_name``, then close the figure."""
    # A dedicated figure keeps successive calls from drawing on top of each other
    # (nx.draw otherwise uses pyplot's current axes), and closing it stops figures
    # from accumulating in pyplot's global state.
    fig, ax = plt.subplots()
    try:
        nx.draw(G, with_labels=True, ax=ax)
        fig.savefig(graph_name)
    finally:
        plt.close(fig)


def create_random_graph(n, p, graph_name=None, seed=None):
    """Return an Erdos-Renyi G(n, p) graph, optionally saving a drawing of it.

    Args:
        n: number of nodes.
        p: probability of each possible edge.
        graph_name: if not None, path the drawing is saved to.
        seed: optional seed forwarded to ``nx.erdos_renyi_graph`` for
            reproducible graphs (None keeps the previous, unseeded behaviour).
    """
    G = nx.erdos_renyi_graph(n, p, seed=seed)
    if graph_name is not None:
        _draw_and_save_graph(G, graph_name)
    return G


def create_ring_graph(n, graph_name=None):
    """Return an ``n``-node cycle graph, optionally saving a drawing of it.

    Args:
        n: number of nodes in the ring.
        graph_name: if not None, path the drawing is saved to.
    """
    G = nx.cycle_graph(n)
    if graph_name is not None:
        _draw_and_save_graph(G, graph_name)
    return G

def average_neighbor_weights(client_id, neighbor_ids, model_dict):
    """Return the element-wise mean of the client's and its neighbours' state dicts.

    The client's own weights are counted once alongside each neighbour, so the
    divisor is ``len(neighbor_ids) + 1``. Aggregation is done with ``WeightMod``
    on deep copies, so the models in ``model_dict`` are not modified.
    """
    own_weights = copy.deepcopy(model_dict[client_id].state_dict())
    aggregator = WeightMod(own_weights)
    for neighbor in neighbor_ids:
        aggregator.add(copy.deepcopy(model_dict[neighbor].state_dict()))
    # +1 accounts for the client itself.
    scale = 1.0 / (len(neighbor_ids) + 1)
    aggregator.mul(scale)
    return aggregator.state_dict()


def load_and_deload_neighbor_weights(neighbor_ids, model_dict, avg_weight_dict):
    """Load ``avg_weight_dict`` into every neighbour's model, returning their previous weights.

    Returns:
        list of deep-copied state dicts, in the same order as ``neighbor_ids``,
        captured before any model was overwritten.
    """
    # Snapshot all neighbours first, then overwrite them.
    snapshots = []
    for neighbor in neighbor_ids:
        snapshots.append(copy.deepcopy(model_dict[neighbor].state_dict()))

    for neighbor in neighbor_ids:
        model_dict[neighbor].load_state_dict(avg_weight_dict)

    return snapshots

def reload_neighbor_weights(neighbor_ids, model_dict, old_weight_dicts):
    """Restore each neighbour's model from its saved state dict.

    ``old_weight_dicts[i]`` is loaded into ``model_dict[neighbor_ids[i]]``. This
    matches the layout returned by ``load_and_deload_neighbor_weights``.

    Raises:
        ValueError: if the number of ids and saved state dicts differ. The check
            runs before any model is touched, so a mismatch never leaves the
            models partially restored.
    """
    neighbor_ids = list(neighbor_ids)
    if len(neighbor_ids) != len(old_weight_dicts):
        raise ValueError(
            f"got {len(neighbor_ids)} neighbor ids but {len(old_weight_dicts)} saved weight dicts"
        )
    for user_id, saved_weights in zip(neighbor_ids, old_weight_dicts):
        model_dict[user_id].load_state_dict(saved_weights)