# New Hypernetwork & Model

In [5]:
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset, SubsetRandomSampler

In [1]:
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm

class CNNHyper(nn.Module):
    """Hypernetwork that generates every weight tensor of a LeNet-style CNN.

    A learned per-node embedding is passed through a small MLP; the resulting
    feature vector feeds one linear "head" per target tensor.  The heads'
    outputs are reshaped and returned under the keys ``conv1.*``, ``conv2.*``,
    ``fc1.*``, ``fc2.*`` and ``fc3.*`` (the layer names used by
    ``CIFARLeNet`` in this file).

    Parameters
    ----------
    n_nodes : int
        Number of rows in the embedding table (one per node/client index).
    embedding_dim : int
        Size of each node embedding.
    in_channels : int, default 3
        Input channels of the generated ``conv1`` kernel.
    out_dim : int, default 100
        Number of outputs of the generated ``fc3`` layer.
    n_kernels : int, default 64
        Output channels of both generated convolutions.
    hidden_dim : int, default 100
        Width of the MLP and input width of every head.
    spec_norm : bool, default False
        If True, every linear layer is wrapped in ``spectral_norm``.
    n_hidden : int, default 1
        Number of ``ReLU -> Linear`` pairs appended after the first MLP layer.
    """

    # Attribute names of the generator heads, in creation/registration order.
    _HEAD_NAMES = (
        "c1_weights", "c1_bias",
        "c2_weights", "c2_bias",
        "l1_weights", "l1_bias",
        "l2_weights", "l2_bias",
        "l3_weights", "l3_bias",
    )

    def __init__(
            self, n_nodes, embedding_dim, in_channels=3, out_dim=100, n_kernels=64, hidden_dim=100,
            spec_norm=False, n_hidden=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_dim = out_dim
        self.n_kernels = n_kernels
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)

        layers = [
            spectral_norm(nn.Linear(embedding_dim, hidden_dim)) if spec_norm else nn.Linear(embedding_dim, hidden_dim),
        ]
        for _ in range(n_hidden):
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                spectral_norm(nn.Linear(hidden_dim, hidden_dim)) if spec_norm else nn.Linear(hidden_dim, hidden_dim),
            )

        self.mlp = nn.Sequential(*layers)

        self.c1_weights = nn.Linear(hidden_dim, self.n_kernels * self.in_channels * 5 * 5)
        self.c1_bias = nn.Linear(hidden_dim, self.n_kernels)
        self.c2_weights = nn.Linear(hidden_dim, self.n_kernels * self.n_kernels * 5 * 5)
        self.c2_bias = nn.Linear(hidden_dim, self.n_kernels)
        # The fc1 head must emit 384 * (n_kernels * 5 * 5) values because forward()
        # reshapes it to (384, n_kernels * 5 * 5); a literal 64 here only worked
        # for the default n_kernels.
        self.l1_weights = nn.Linear(hidden_dim, 384 * self.n_kernels * 5 * 5)
        self.l1_bias = nn.Linear(hidden_dim, 384)
        self.l2_weights = nn.Linear(hidden_dim, 192 * 384)
        self.l2_bias = nn.Linear(hidden_dim, 192)
        self.l3_weights = nn.Linear(hidden_dim, self.out_dim * 192)
        self.l3_bias = nn.Linear(hidden_dim, self.out_dim)

        if spec_norm:
            # Heads are wrapped only after all of them exist, so the order in
            # which random numbers are consumed stays the same as before.
            # spectral_norm modifies the module in place and returns it.
            for head_name in self._HEAD_NAMES:
                spectral_norm(getattr(self, head_name))

    def forward(self, idx):
        """Generate the target network's weights for node ``idx``.

        Parameters
        ----------
        idx : torch.Tensor
            Integer index tensor for the embedding lookup.  The reshapes below
            only succeed when it selects exactly one embedding.

        Returns
        -------
        collections.OrderedDict
            Maps parameter names (``"conv1.weight"`` ... ``"fc3.bias"``) to the
            generated tensors.
        """
        # Imported locally so the class does not depend on a module-level import
        # that this file does not provide next to the class definition.
        from collections import OrderedDict

        emd = self.embeddings(idx)
        features = self.mlp(emd)

        weights = OrderedDict({
            "conv1.weight": self.c1_weights(features).view(self.n_kernels, self.in_channels, 5, 5),
            "conv1.bias": self.c1_bias(features).view(-1),
            "conv2.weight": self.c2_weights(features).view(self.n_kernels, self.n_kernels, 5, 5),
            "conv2.bias": self.c2_bias(features).view(-1),
            "fc1.weight": self.l1_weights(features).view(384, self.n_kernels * 5 * 5),
            "fc1.bias": self.l1_bias(features).view(-1),
            "fc2.weight": self.l2_weights(features).view(192, 384),
            "fc2.bias": self.l2_bias(features).view(-1),
            "fc3.weight": self.l3_weights(features).view(self.out_dim, 192),
            "fc3.bias": self.l3_bias(features).view(-1),
        })
        return weights


class CIFARLeNet(nn.Module):
    """
    A neural network model inspired by LeNet5, designed for CIFAR-100 dataset.

    Parameters:
    ----------
    in_channels : int, default 3
        Number of channels of the input images.
    n_kernels : int, default 64
        Number of output channels of both convolutions.
    out_dim : int, default 100
        Number of outputs of the final layer.

    Attributes:
    ----------
    conv1 : nn.Module
        First 5x5 convolution, ``in_channels`` -> ``n_kernels`` channels.
    conv2 : nn.Module
        Second 5x5 convolution, ``n_kernels`` -> ``n_kernels`` channels.
    pool : nn.Module
        Max pooling layer with kernel size of 2.
    fc1 : nn.Module
        Fully connected layer with input size n_kernels*5*5 and output size 384.
        The 5*5 spatial size is what two conv(5)+pool(2) stages leave of a
        32x32 image.
    fc2 : nn.Module
        Fully connected layer with input size 384 and output size 192.
    fc3 : nn.Module
        Fully connected layer with input size 192 and output size ``out_dim``.
    """

    def __init__(self, in_channels=3, n_kernels=64, out_dim=100):
        """
        Initialize the CIFARLeNet model with its layers.

        The defaults reproduce the original fixed architecture (3 input
        channels, 64 kernels, 100 outputs).
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=5)
        self.conv2 = nn.Conv2d(n_kernels, n_kernels, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(n_kernels * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, out_dim)

    def produce_feature(self, x):
        """
        Compute the activations that feed the final layer.

        Parameters:
        ----------
        x : torch.Tensor
            The input tensor.

        Returns:
        -------
        torch.Tensor
            Output of ``fc2`` after ReLU, one row per input sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

    def forward(self, x):
        """
        Defines the forward pass of the model.

        Parameters:
        ----------
        x : torch.Tensor
            The input tensor.

        Returns:
        -------
        torch.Tensor
            Raw (un-normalised) outputs of ``fc3``; no softmax is applied.
        """
        # Same layer sequence as before, expressed through produce_feature so
        # the two code paths cannot drift apart.
        return self.fc3(self.produce_feature(x))

In [2]:
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm

class CNNHyper_old(nn.Module):
    """Hypernetwork (kept under the ``_old`` name) that emits LeNet-style CNN weights.

    A per-node embedding goes through an MLP, and one linear head per target
    tensor turns the MLP features into a flat vector that ``forward`` reshapes
    into ``conv1``/``conv2``/``fc1``/``fc2``/``fc3`` weights and biases.

    Parameters
    ----------
    n_nodes : int
        Number of rows in the embedding table.
    embedding_dim : int
        Size of each node embedding.
    in_channels : int, default 3
        Input channels of the generated ``conv1`` kernel.
    out_dim : int, default 100
        Number of outputs of the generated ``fc3`` layer.
    n_kernels : int, default 64
        Output channels of both generated convolutions.
    hidden_dim : int, default 100
        Width of the MLP and input width of every head.
    spec_norm : bool, default False
        If True, every linear layer is wrapped in ``spectral_norm``.
    n_hidden : int, default 1
        Number of ``ReLU -> Linear`` pairs appended after the first MLP layer.
    """

    def __init__(
            self, n_nodes, embedding_dim, in_channels=3, out_dim=100, n_kernels=64, hidden_dim=100,
            spec_norm=False, n_hidden=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_dim = out_dim
        self.n_kernels = n_kernels
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)

        layers = [
            spectral_norm(nn.Linear(embedding_dim, hidden_dim)) if spec_norm else nn.Linear(embedding_dim, hidden_dim),
        ]
        for _ in range(n_hidden):
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                spectral_norm(nn.Linear(hidden_dim, hidden_dim)) if spec_norm else nn.Linear(hidden_dim, hidden_dim),
            )

        self.mlp = nn.Sequential(*layers)

        self.c1_weights = nn.Linear(hidden_dim, self.n_kernels * self.in_channels * 5 * 5)
        self.c1_bias = nn.Linear(hidden_dim, self.n_kernels)
        self.c2_weights = nn.Linear(hidden_dim, self.n_kernels * self.n_kernels * 5 * 5)
        self.c2_bias = nn.Linear(hidden_dim, self.n_kernels)
        # Sized from n_kernels (not a literal 64) so it always matches the
        # (384, n_kernels * 5 * 5) reshape performed in forward().
        self.l1_weights = nn.Linear(hidden_dim, 384 * self.n_kernels * 5 * 5)
        self.l1_bias = nn.Linear(hidden_dim, 384)
        self.l2_weights = nn.Linear(hidden_dim, 192 * 384)
        self.l2_bias = nn.Linear(hidden_dim, 192)
        self.l3_weights = nn.Linear(hidden_dim, self.out_dim * 192)
        self.l3_bias = nn.Linear(hidden_dim, self.out_dim)

        if spec_norm:
            # spectral_norm patches each module in place; wrapping happens after
            # all heads are created, in the same order as they were defined.
            for head in (
                    self.c1_weights, self.c1_bias,
                    self.c2_weights, self.c2_bias,
                    self.l1_weights, self.l1_bias,
                    self.l2_weights, self.l2_bias,
                    self.l3_weights, self.l3_bias):
                spectral_norm(head)

    def forward(self, idx):
        """Return an ordered mapping of generated parameter tensors for node ``idx``.

        ``idx`` is an integer index tensor for the embedding table; the reshapes
        below require it to select exactly one embedding.
        """
        # Local import: OrderedDict is not imported alongside this class.
        from collections import OrderedDict

        emd = self.embeddings(idx)
        features = self.mlp(emd)

        weights = OrderedDict({
            "conv1.weight": self.c1_weights(features).view(self.n_kernels, self.in_channels, 5, 5),
            "conv1.bias": self.c1_bias(features).view(-1),
            "conv2.weight": self.c2_weights(features).view(self.n_kernels, self.n_kernels, 5, 5),
            "conv2.bias": self.c2_bias(features).view(-1),
            "fc1.weight": self.l1_weights(features).view(384, self.n_kernels * 5 * 5),
            "fc1.bias": self.l1_bias(features).view(-1),
            "fc2.weight": self.l2_weights(features).view(192, 384),
            "fc2.bias": self.l2_bias(features).view(-1),
            "fc3.weight": self.l3_weights(features).view(self.out_dim, 192),
            "fc3.bias": self.l3_bias(features).view(-1),
        })
        return weights


class CIFARLeNet(nn.Module):
    """
    A neural network model inspired by LeNet5, designed for CIFAR-100 dataset.

    Parameters:
    ----------
    in_channels : int, default 3
        Number of channels of the input images.
    n_kernels : int, default 64
        Number of output channels of both convolutions.
    out_dim : int, default 100
        Number of outputs of the final layer.

    Attributes:
    ----------
    conv1 : nn.Module
        First 5x5 convolution, ``in_channels`` -> ``n_kernels`` channels.
    conv2 : nn.Module
        Second 5x5 convolution, ``n_kernels`` -> ``n_kernels`` channels.
    pool : nn.Module
        Max pooling layer with kernel size of 2.
    fc1 : nn.Module
        Fully connected layer with input size n_kernels*5*5 and output size 384.
        The 5*5 spatial size is what two conv(5)+pool(2) stages leave of a
        32x32 image.
    fc2 : nn.Module
        Fully connected layer with input size 384 and output size 192.
    fc3 : nn.Module
        Fully connected layer with input size 192 and output size ``out_dim``.
    """

    def __init__(self, in_channels=3, n_kernels=64, out_dim=100):
        """
        Initialize the CIFARLeNet model with its layers.

        The defaults reproduce the original fixed architecture (3 input
        channels, 64 kernels, 100 outputs).
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=5)
        self.conv2 = nn.Conv2d(n_kernels, n_kernels, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(n_kernels * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, out_dim)

    def produce_feature(self, x):
        """
        Compute the activations that feed the final layer.

        Parameters:
        ----------
        x : torch.Tensor
            The input tensor.

        Returns:
        -------
        torch.Tensor
            Output of ``fc2`` after ReLU, one row per input sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

    def forward(self, x):
        """
        Defines the forward pass of the model.

        Parameters:
        ----------
        x : torch.Tensor
            The input tensor.

        Returns:
        -------
        torch.Tensor
            Raw (un-normalised) outputs of ``fc3``; no softmax is applied.
        """
        # Same layer sequence as before, expressed through produce_feature so
        # the two code paths cannot drift apart.
        return self.fc3(self.produce_feature(x))

# Args Parser


In [21]:
import argparse
import sys

def str2bool(v):
    """Interpret a command-line token as a boolean (argparse ``type=`` helper).

    Real ``bool`` values are returned untouched.  Otherwise ``v`` is lower-cased
    and matched against the accepted spellings; anything else raises
    ``argparse.ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def args_parser(argv=None):
    """Build the experiment's argument parser and return the parsed namespace.

    Parameters
    ----------
    argv : list[str] | None
        Explicit argument list to parse.  When ``None`` (the default, and the
        only mode the original function had) the arguments come from
        ``sys.argv[1:]`` if this code runs as a script (``__file__`` is
        defined), and from an empty list otherwise, so a notebook session gets
        the defaults instead of choking on the kernel's own argv.

    Returns
    -------
    argparse.Namespace
        One attribute per option defined below.

    Notes
    -----
    ``--interpolate_logits`` used to be declared with ``default="store_true"``,
    which made it a value-taking option whose default was a non-empty (truthy)
    string.  It is now a real switch: absent -> ``False``, bare flag -> ``True``,
    and an explicit boolean value (``--interpolate_logits false``) is accepted.
    """
    parser = argparse.ArgumentParser()

    # Wandb / logging parameters
    parser.add_argument('--wandb_key', type=str, default='', help='wandb key')
    parser.add_argument('--wandb_username', type=str, default='', help='wandb username')
    parser.add_argument('--wandb_project', type=str, default='charBert', help='wandb project')
    parser.add_argument('--wandb_run_name', type=str, default='charBert', help='wandb run name')
    parser.add_argument('--logfile', type=str, default='/content/logger.log', help='log file name')
    parser.add_argument('--data_dir', type=str, default='./data', help='data directory')

    # federated arguments (Notation for the arguments followed from paper)
    parser.add_argument('--epochs', type=int, default=2000,
                        help="number of rounds of training")
    parser.add_argument('--last_epoch', type=int, default=1900,
                        help="number of rounds of old clients training")
    parser.add_argument('--num_users', type=int, default=100,
                        help="number of total users: K")
    parser.add_argument('--n_nodes', type=int, default=90,
                        help="number of already seen users")
    parser.add_argument('--Nc', type=int, default=5,
                        help='number of class each client in non iid')
    parser.add_argument('--frac', type=float, default=0.1,
                        help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=4,
                        help="the number of local rounds: J")
    parser.add_argument('--local_bs', type=int, default=64,
                        help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--inner_lr', type=float, default=0.01,
                        help='local learning rate')
    parser.add_argument("--wd", type=float, default=4e-4,
                        help="weight decay")
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--tol', type=float, default=1,
                        help='the maximum allowed difference between '
                             'old and new clients accuracy')
    parser.add_argument('--bias', type=float, default=1,
                        help='A bias value between 0 and 1 that specifies the '
                             'likelihood of selecting new clients during training')
    parser.add_argument('--step_iter', type=int, default=1,
                        help='to decide when the gen training should stop')
    parser.add_argument('--average', type=int, default=1,
                        help='initialization of the embedding vectors of the '
                             'new clients as the average of the embedding vectors '
                             'of the old clients')
    parser.add_argument('--update', type=int, default=1,
                        help='to decide if only the embedding weights should '
                             'be updated')
    parser.add_argument('--extra', type=float, default=0,
                        help='The number of extra rounds of '
                             'new clients fine tuning')

    # model arguments
    parser.add_argument('--model', type=str, default='cnn', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9,
                        help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to '
                             'use for convolution')
    parser.add_argument('--num_channels', type=int, default=1,
                        help="number of channels of imgs")
    parser.add_argument('--norm', type=str, default='batch_norm',
                        help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32,
                        help="number of filters for conv nets -- 32 for "
                             "mini-imagenet, 64 for omiglot.")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than "
                             "strided convolutions")
    parser.add_argument('--checkpoint_resume', type=int, default=0,
                        help='resume from checkpoint, 0 for False, 1 for True')
    parser.add_argument("--inner-wd", type=float, default=5e-5,
                        help="inner weight decay")
    parser.add_argument("--embed-dim", type=int, default=-1,
                        help="embedding dim")
    parser.add_argument("--embed-lr", type=float, default=None,
                        help="embedding learning rate")
    parser.add_argument("--hyper-hid", type=int, default=100,
                        help="hypernet hidden dim")
    # Help text previously duplicated the --hyper-hid description.
    parser.add_argument("--spec-norm", type=str2bool, default=False,
                        help="whether to apply spectral normalization in the hypernet")
    # Help text previously claimed the default was FedAvg.
    parser.add_argument('--algorithm', type=str, default='pFedHN',
                        help='algorithm name. Default set to pFedHN.')

    # other arguments
    parser.add_argument('--dataset', type=str, default='cifar',
                        help="name of dataset")
    parser.add_argument('--val_split', type=float, default=0.2,
                        help="train-validation split")
    parser.add_argument('--num_classes', type=int, default=100,
                        help="number of classes")
    parser.add_argument('--gpu', default=None,
                        help="To use cuda, set to a specific GPU ID. "
                             "Default set to use CPU.")
    parser.add_argument('--optimizer', type=str, default='sgd',
                        help="type of optimizer")
    # Help text previously said the default was IID although the default is 0.
    parser.add_argument('--iid', type=int, default=0,
                        help='Set to 1 for IID. Default 0 is non-IID.')
    parser.add_argument('--participation', type=int, default=1,
                        help='Default set to Uniform Participation. Set to 0 for Skewed')
    parser.add_argument('--backup', type=int, default=500,
                        help='How often an old backup should be preserved')
    parser.add_argument('--checkpoint_path', type=str, default=".",
                        help='Saved models location')
    parser.add_argument('--stopping_rounds', type=int, default=10,
                        help='rounds of early stopping')
    parser.add_argument('--print_every', type=int, default=10,
                        help='how often the train_accuracy is computed, and '
                             'how often a new checkpoint is saved')
    parser.add_argument('--verbose', type=int, default=0, help='verbose')
    parser.add_argument('--gamma', type=float, default=0.1, help='gamma')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--metrics_dir', type=str,
                        default='/content/drive/MyDrive/MLDL/cifar/metrics',
                        help='metrics directory')

    # KNN arguments
    # nargs='?' + const=True makes this usable as a bare switch while still
    # accepting an explicit boolean value.
    parser.add_argument(
        '--interpolate_logits',
        type=str2bool,
        nargs='?',
        const=True,
        default=False,
        help='if selected logits are interpolated instead of probabilities',
    )

    if argv is None:
        # In a notebook "__file__" is undefined and sys.argv holds the kernel's
        # own arguments, so parse an empty list and fall back to the defaults.
        argv = sys.argv[1:] if "__file__" in globals() else []
    return parser.parse_args(args=argv)


# Client Class

In [4]:
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from collections import Counter

class Client:
    """A federated-learning client: index lists into shared datasets plus loaders.

    The client does not own data; it keeps references to the shared train and
    test datasets together with the indices assigned to it.  Validation indices
    refer to ``train_dataset``.

    Parameters
    ----------
    args : namespace
        Must provide ``local_bs`` (batch size).  ``train``/``inference`` also
        read ``args.device`` and ``args.local_ep``.
    client_id : int
        Identifier used in printed reports.
    train_dataset, test_dataset : datasets
        Indexable datasets exposing a ``targets`` sequence.
    train_indices, val_indices, test_indices : list[int]
        Sample indices for the three splits.
    """

    def __init__(self, args, client_id, train_dataset, test_dataset, train_indices, val_indices, test_indices):
        self.client_id = client_id
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.train_indices = train_indices
        self.val_indices = val_indices
        self.test_indices = test_indices
        self.batch_size = args.local_bs
        self.train_dataloader = self.create_dataloader("train")
        self.val_dataloader = self.create_dataloader("val")
        self.test_dataloader = self.create_dataloader("test")

    @staticmethod
    def _to_device(args, inputs, labels):
        """Move a batch to the GPU when ``args.device == 'cuda'``; otherwise return it as is."""
        if args.device == 'cuda':
            return inputs.cuda(), labels.cuda()
        return inputs, labels

    def get_class_distribution(self, indices, dataset):
        """Return ``{label: count}`` for the samples of ``dataset`` selected by ``indices``."""
        return dict(Counter(dataset.targets[idx] for idx in indices))

    def get_distributions(self):
        """Return the label histograms of the three splits under 'train', 'val' and 'test'."""
        return {
            'train': self.get_class_distribution(self.train_indices, self.train_dataset),
            'val': self.get_class_distribution(self.val_indices, self.train_dataset),
            'test': self.get_class_distribution(self.test_indices, self.test_dataset),
        }

    def create_dataloader(self, dataset_type):
        """Build a shuffled ``DataLoader`` for ``"train"``, ``"val"`` or ``"test"``.

        Raises ``KeyError`` for any other ``dataset_type``.
        """
        dataset_dict = {
            "train": (self.train_dataset, self.train_indices),
            "val": (self.train_dataset, self.val_indices),
            "test": (self.test_dataset, self.test_indices)
        }

        dataset, indices = dataset_dict[dataset_type]
        return DataLoader(Subset(dataset, indices), batch_size=self.batch_size, shuffle=True)

    def print_class_distribution(self):
        """Print the per-split label histograms of this client."""
        # This method used to be defined twice with identical bodies; the single
        # definition now reuses get_distributions().
        dist = self.get_distributions()
        print(f"Client {self.client_id} class distribution:")
        print(f"  Train: {dist['train']}")
        print(f"  Val: {dist['val']}")
        print(f"  Test: {dist['test']}")

    def check_indices(self):
        """Validate the index lists.

        Raises
        ------
        ValueError
            If any split contains a repeated index, or if the train and
            validation splits share an index.  Test indices are only checked
            for duplicates (they refer to a different dataset object).
        """
        def has_duplicates(lst):
            return len(lst) != len(set(lst))

        if has_duplicates(self.train_indices):
            raise ValueError("Duplicate entries found in train_indices")
        if has_duplicates(self.val_indices):
            raise ValueError("Duplicate entries found in val_indices")
        if has_duplicates(self.test_indices):
            raise ValueError("Duplicate entries found in test_indices")

        # Disjointness is symmetric, so a single check covers both directions.
        if not set(self.train_indices).isdisjoint(self.val_indices):
            raise ValueError("Overlap found between train_indices and val_indices")

    def train(self, model, criterion, optimizer, args):
        """Run ``args.local_ep`` optimizer steps on the client's training split.

        Note that ``local_ep`` counts mini-batch steps, not passes over the
        data: the loader is re-iterated (and reshuffled) until that many steps
        have been taken.

        Returns
        -------
        The same ``model`` object, updated in place.
        """
        self.train_dataloader = self.create_dataloader('train')  # Recreate dataloader to shuffle data

        model.train()
        if len(self.train_dataloader) == 0:
            # With no batches the loop below would never advance step_count and
            # would spin forever; there is nothing to train on.
            return model

        step_count = 0
        while step_count < args.local_ep:
            for inputs, labels in self.train_dataloader:
                inputs, labels = self._to_device(args, inputs, labels)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                step_count += 1
                if step_count >= args.local_ep:  # Exit as soon as the step budget is used
                    break
        return model

    def inference(self, model, criterion, args, loader_type='test'):
        """Evaluate ``model`` on the test split (``loader_type='test'``) or the validation split.

        Any ``loader_type`` other than ``'test'`` selects the validation loader.

        Returns
        -------
        (accuracy, loss) : tuple[float, float]
            Fraction of correctly classified samples and the mean of the
            per-batch criterion values.

        Raises
        ------
        ZeroDivisionError
            If the selected loader yields no batches.
        """
        model.eval()
        correct, total, test_loss = 0.0, 0.0, 0.0
        testloader = self.test_dataloader if loader_type == 'test' else self.val_dataloader
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = self._to_device(args, inputs, labels)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                test_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = test_loss / len(testloader)
        accuracy = correct / total
        return accuracy, test_loss

    def single_batch_inference(self, model, criterion, args):
        """Evaluate ``model`` on one batch drawn from the (shuffled) test loader.

        Returns ``(accuracy, loss)`` for that single batch.
        """
        # NOTE: the original author marked this method as not needed at the moment.
        model.eval()
        testloader = iter(self.test_dataloader)
        with torch.no_grad():
            inputs, labels = next(testloader)
            inputs, labels = self._to_device(args, inputs, labels)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == labels).sum().item()
            accuracy = correct / len(labels)
        return accuracy, loss.item()


def cifar_iid(args, train_dataset, test_dataset):
    """Partition the data uniformly (IID) and return one ``Client`` per user.

    Sample indices are grouped by label and shuffled with ``np.random``; the
    first ``args.val_split`` fraction of every training class becomes
    validation data and the remainder training data.  Client ``k`` then
    receives the ``k``-th equally sized slice of every class from each of the
    train, validation and test pools.

    Returns a list of ``args.num_users`` ``Client`` objects.
    """
    val_fraction = args.val_split
    n_clients = args.num_users
    n_classes = len(train_dataset.classes)

    def shuffled_buckets(targets):
        # One index list per label, each shuffled in place.
        buckets = [[] for _ in range(n_classes)]
        for position, label in enumerate(targets):
            buckets[label].append(position)
        for bucket in buckets:
            np.random.shuffle(bucket)
        return buckets

    # Training pool first, then the test pool, so shuffles happen in that order.
    val_per_class, train_per_class = [], []
    for bucket in shuffled_buckets(train_dataset.targets):
        cut = int(len(bucket) * val_fraction)
        val_per_class.append(bucket[:cut])
        train_per_class.append(bucket[cut:])
    test_per_class = shuffled_buckets(test_dataset.targets)

    # Per-client, per-class quota for each pool.
    train_quota = int(len(train_dataset) * (1-val_fraction) // (n_clients * n_classes))
    val_quota = int(len(train_dataset) * val_fraction // (n_clients * n_classes))
    test_quota = len(test_dataset) // (n_clients * n_classes)

    def client_slice(per_class, quota, client_id):
        # Concatenate this client's slice of every class, in class order.
        lo, hi = client_id * quota, (client_id + 1) * quota
        return [sample for bucket in per_class for sample in bucket[lo:hi]]

    return [
        Client(
            args, client_id, train_dataset, test_dataset,
            client_slice(train_per_class, train_quota, client_id),
            client_slice(val_per_class, val_quota, client_id),
            client_slice(test_per_class, test_quota, client_id),
        )
        for client_id in range(n_clients)
    ]

def cifar_noniid(args, train_dataset, test_dataset):
    """Partition the data by class (non-IID) and return one ``Client`` per user.

    Every class is shared by exactly ``args.Nc`` clients and every client
    receives ``args.Nc`` classes.  For each class, the shuffled train,
    validation and test indices are cut into ``Nc`` equal slices that are
    handed to that class's clients in random order.

    Parameters
    ----------
    args : namespace
        Reads ``val_split``, ``num_users`` and ``Nc``.
    train_dataset, test_dataset : datasets
        Must expose ``targets``; ``train_dataset`` must also expose ``classes``.

    Returns
    -------
    list[Client]
        ``args.num_users`` clients, indexed by client id.

    Raises
    ------
    ValueError
        If ``num_users`` differs from the number of classes (the first
        assignment round pairs classes and clients one-to-one), or if ``Nc``
        is not in ``1..num_users``.
    RuntimeError
        If no valid class/client assignment is found within the retry budget.
    """
    def class_clients_sharding(num_classes, Nc, num_clients):
        """Randomly assign ``Nc`` distinct clients to every class.

        Raises ``IndexError`` when the random process paints itself into a
        corner (no eligible client is left for some class).
        """
        class_clients = {key: set() for key in range(num_classes)}
        first_clients = list(range(num_classes))
        # Pool of remaining "slots": each client appears Nc-1 times.  The pool
        # size was previously hard-coded to 100 clients.
        clients_list = [num // (Nc-1) for num in range((Nc-1)*num_clients)]
        random.shuffle(first_clients)
        # Round 0: a random one-to-one pairing of classes and clients.
        for i in range(num_classes):
            class_clients[i].add(first_clients[i])

        # Rounds 1..Nc-1: visit classes in random order, give each one more
        # client that does not already hold that class.
        for _round in range(1, Nc):
            class_list = list(range(num_classes))
            for _pick in range(num_classes):
                random_class = random.choice(class_list)
                class_list.remove(random_class)

                clients_list_cleaned = [client for client in clients_list if client not in class_clients[random_class]]

                random_client = random.choice(clients_list_cleaned)
                class_clients[random_class].add(random_client)
                clients_list.remove(random_client)

        return class_clients

    val_split = args.val_split
    num_clients = args.num_users
    Nc = args.Nc
    num_classes = len(train_dataset.classes)

    # Fail early with a clear message instead of an IndexError, empty clients,
    # or an endless retry loop further down.
    if num_clients != num_classes:
        raise ValueError(
            f"cifar_noniid requires num_users ({num_clients}) to equal the "
            f"number of classes ({num_classes})")
    if not 1 <= Nc <= num_clients:
        raise ValueError(f"Nc must be between 1 and num_users ({num_clients}), got {Nc}")

    # The random assignment can dead-end; retry, but not forever.
    max_attempts = 1000
    for _attempt in range(max_attempts):
        try:
            class_clients = class_clients_sharding(num_classes, Nc, num_clients)
            break
        except IndexError:
            # random.choice() on an empty candidate list: this attempt is invalid.
            print("Sharding Invalid, trying again...")
    else:
        raise RuntimeError(f"No valid class/client sharding found after {max_attempts} attempts")

    # Create a list to store indices for each class
    class_indices = [[] for _ in range(num_classes)]

    # Populate class_indices with the indices of each class
    for idx, target in enumerate(train_dataset.targets):
        class_indices[target].append(idx)

    # Shuffle indices within each class
    for indices in class_indices:
        np.random.shuffle(indices)

    # Create lists for train and validation class indices
    train_class_indices = [[] for _ in range(num_classes)]
    val_class_indices = [[] for _ in range(num_classes)]

    # The first val_split fraction of every class is validation, the rest train
    for i, indices in enumerate(class_indices):
        split_idx = int(len(indices) * val_split)
        val_class_indices[i] = indices[:split_idx]
        train_class_indices[i] = indices[split_idx:]

    # Prepare test_class_indices
    test_class_indices = [[] for _ in range(num_classes)]
    for idx, target in enumerate(test_dataset.targets):
        test_class_indices[target].append(idx)
    for indices in test_class_indices:
        np.random.shuffle(indices)

    # Number of samples of one class that each of its Nc clients receives
    train_samples_per_client_per_class = int(len(train_dataset) * (1-val_split) // (Nc * num_classes))
    val_samples_per_client_per_class = int(len(train_dataset) * val_split // (Nc * num_classes))
    test_samples_per_client_per_class = len(test_dataset) // (Nc * num_classes)

    train_shards_indices = [[] for _ in range(num_clients)]
    val_shards_indices = [[] for _ in range(num_clients)]
    test_shards_indices = [[] for _ in range(num_clients)]

    for class_idx in range(num_classes):
        remaining = class_clients[class_idx].copy()
        for slice_idx in range(Nc):
            # Hand the slice_idx-th slice of this class to a random remaining client.
            client = random.choice(list(remaining))
            remaining.remove(client)

            train_lo = slice_idx * train_samples_per_client_per_class
            val_lo = slice_idx * val_samples_per_client_per_class
            test_lo = slice_idx * test_samples_per_client_per_class

            train_shards_indices[client].extend(
                train_class_indices[class_idx][train_lo:train_lo + train_samples_per_client_per_class])
            val_shards_indices[client].extend(
                val_class_indices[class_idx][val_lo:val_lo + val_samples_per_client_per_class])
            test_shards_indices[client].extend(
                test_class_indices[class_idx][test_lo:test_lo + test_samples_per_client_per_class])

    return [
        Client(args, client_id, train_dataset, test_dataset,
               train_shards_indices[client_id], val_shards_indices[client_id], test_shards_indices[client_id])
        for client_id in range(num_clients)
    ]





# Drive And Get Dataset

In [7]:
from google.colab import drive

# Mount Google Drive at /content/drive. The dataset root and the checkpoint
# directory used later in this notebook are paths below this mount point.
drive.mount("/content/drive")

Mounted at /content/drive


In [8]:
import os
import torch

def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: Object to persist. Callers in this file pass a dict holding
            model/hypernetwork state dicts and scalar metrics.
        filename: Destination path of the checkpoint file.
    """
    torch.save(obj=state, f=filename)

def load_checkpoint(filename):
    """Load a checkpoint from ``filename`` onto the CPU if the file exists.

    Args:
        filename: Path of the checkpoint file to read.

    Returns:
        The deserialized checkpoint, or ``None`` when ``filename`` is not an
        existing file. A status line is printed in both cases.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return None

    # Tensors are always mapped to the CPU, whatever device they were saved from.
    cpu = torch.device('cpu')
    checkpoint = torch.load(filename, map_location=cpu)
    print(f"Loading checkpoint '{filename}' (epoch {checkpoint['epoch']})")
    return checkpoint

In [6]:
def get_dataset(args):
    """Build the CIFAR-100 datasets and split them across clients.

    Args:
        args: Run configuration. ``args.dataset`` must be ``'cifar'`` and
            ``args.iid`` selects ``cifar_iid`` (truthy) or ``cifar_noniid``
            (falsy). ``args`` is also forwarded to the chosen split helper.

    Returns:
        Tuple ``(train_dataset, test_dataset, clients)`` where ``clients`` is
        the value returned by the split helper.

    Raises:
        ValueError: If ``args.dataset`` is anything other than ``'cifar'``.
    """
    # Fail early with a clear message: an unsupported dataset used to fall
    # through to the ``return`` below and die with UnboundLocalError.
    if args.dataset != 'cifar':
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}: only 'cifar' is handled"
        )

    # NOTE: the data root is hard-coded to the mounted Drive folder;
    # ``args.data_dir`` is not consulted here.
    data_dir = "/content/drive/MyDrive/MLDL/"

    # Per-channel normalization constants shared by both pipelines.
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    # Random crop/flip augmentation is applied to the training split only.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = datasets.CIFAR100(data_dir, train=True, download=True,
                                      transform=transform_train)
    test_dataset = datasets.CIFAR100(data_dir, train=False, download=True,
                                     transform=transform_test)

    split_fn = cifar_iid if args.iid else cifar_noniid
    clients = split_fn(args, train_dataset, test_dataset)

    return train_dataset, test_dataset, clients


# Inference

In [9]:
def eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type):
    """Evaluate the hypernetwork-generated model of every node.

    For each node id in ``range(num_nodes)`` the hypernetwork ``hnet`` produces
    a state dict that is loaded into ``net``; ``net`` is then run over the
    node's dataloader selected by ``loader_type``.

    Args:
        nodes: Indexable collection whose items expose ``test_dataloader``,
            ``val_dataloader`` and ``train_dataloader``.
        num_nodes: Number of nodes to evaluate (ids ``0 .. num_nodes - 1``).
        hnet: Hypernetwork; called with a one-element ``long`` tensor holding
            the node id, its result is passed to ``net.load_state_dict``.
        net: Target network evaluated with the generated weights.
        criterion: Loss callable used as ``criterion(pred, label)``.
        device: Device the batches and the node-id tensor are moved to.
        loader_type: ``'test'`` or ``'val'``; any other value selects the
            training dataloader.

    Returns:
        Tuple ``(curr_results, avg_loss, avg_acc, acc)``:
        ``curr_results`` maps node id to ``loss`` (mean per-batch loss),
        ``correct``, ``total`` and ``accuracy``; ``avg_loss`` / ``avg_acc``
        are means of the per-node values; ``acc`` is total correct over total
        samples. Nodes without any sample get NaN loss/accuracy and are left
        out of the averages; if no node has samples all three aggregates are
        NaN.
    """

    @torch.no_grad()
    def evaluate(nodes, num_nodes, hnet, net, criterion, device, loader_type):
        hnet.eval()
        results = defaultdict(lambda: defaultdict(list))
        no_data = float('nan')

        for node_id in range(num_nodes):  # iterating over nodes
            if loader_type == 'test':
                curr_data = nodes[node_id].test_dataloader
            elif loader_type == 'val':
                curr_data = nodes[node_id].val_dataloader
            else:
                curr_data = nodes[node_id].train_dataloader

            # The generated weights depend only on the node id, so produce and
            # load them once per node instead of once per batch.
            weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
            net.load_state_dict(weights)

            running_loss, running_correct, running_samples = 0., 0., 0.
            num_batches = 0
            for batch in curr_data:
                img, label = tuple(t.to(device) for t in batch)

                pred = net(img)
                running_loss += criterion(pred, label).item()
                running_correct += pred.argmax(1).eq(label).sum().item()
                running_samples += len(label)
                num_batches += 1

            # Guard the divisions: an empty dataloader used to crash here.
            results[node_id]['loss'] = running_loss / num_batches if num_batches else no_data
            results[node_id]['correct'] = running_correct
            results[node_id]['total'] = running_samples
            results[node_id]['accuracy'] = running_correct / running_samples if running_samples else no_data

        return results

    curr_results = evaluate(nodes, num_nodes, hnet, net, criterion, device, loader_type=loader_type)
    total_correct = sum(val['correct'] for val in curr_results.values())
    total_samples = sum(val['total'] for val in curr_results.values())

    # Only nodes that actually contributed samples take part in the averages.
    evaluated = [val for val in curr_results.values() if val['total'] > 0]
    if evaluated:
        avg_loss = np.mean([val['loss'] for val in evaluated])
        avg_acc = np.mean([val['accuracy'] for val in evaluated])
        acc = total_correct / total_samples
    else:
        avg_loss = avg_acc = acc = float('nan')

    return curr_results, avg_loss, avg_acc, acc

In [10]:
def eval_pfedhn_gen(nodes, num_nodes, num_users, hnet, net, criterion, device, loader_type):
    """Evaluate every user and report the "new" users separately.

    All node ids in ``range(num_users)`` are evaluated with weights generated
    by ``hnet``; ids in ``range(num_nodes, num_users)`` are additionally
    summarised on their own.

    Args:
        nodes: Indexable collection whose items expose ``test_dataloader``,
            ``val_dataloader`` and ``train_dataloader``.
        num_nodes: First node id that counts as "new" in the summary.
        num_users: Total number of node ids to evaluate.
        hnet: Hypernetwork; called with a one-element ``long`` tensor holding
            the node id, its result is passed to ``net.load_state_dict``.
        net: Target network evaluated with the generated weights.
        criterion: Loss callable used as ``criterion(pred, label)``.
        device: Device the batches and the node-id tensor are moved to.
        loader_type: ``'test'`` or ``'val'``; any other value selects the
            training dataloader.

    Returns:
        Tuple ``(curr_results, avg_loss_new, avg_acc_new, loss, acc)``:
        per-node results, mean loss/accuracy over the new node ids, and mean
        loss/accuracy over all node ids. Nodes without any sample get NaN
        loss/accuracy and are left out of the means; a mean over no nodes is
        NaN.
    """

    @torch.no_grad()
    def evaluate(nodes, num_nodes, num_users, hnet, net, criterion, device, loader_type):
        hnet.eval()
        results = defaultdict(lambda: defaultdict(list))
        no_data = float('nan')

        for node_id in range(num_users):  # iterating over nodes
            if loader_type == 'test':
                curr_data = nodes[node_id].test_dataloader
            elif loader_type == 'val':
                curr_data = nodes[node_id].val_dataloader
            else:
                curr_data = nodes[node_id].train_dataloader

            # The generated weights depend only on the node id, so produce and
            # load them once per node instead of once per batch.
            weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
            net.load_state_dict(weights)

            running_loss, running_correct, running_samples = 0., 0., 0.
            num_batches = 0
            for batch in curr_data:
                img, label = tuple(t.to(device) for t in batch)

                pred = net(img)
                running_loss += criterion(pred, label).item()
                running_correct += pred.argmax(1).eq(label).sum().item()
                running_samples += len(label)
                num_batches += 1

            # Guard the divisions: an empty dataloader used to crash here.
            results[node_id]['loss'] = running_loss / num_batches if num_batches else no_data
            results[node_id]['correct'] = running_correct
            results[node_id]['total'] = running_samples
            results[node_id]['accuracy'] = running_correct / running_samples if running_samples else no_data

        return results

    curr_results = evaluate(nodes, num_nodes, num_users, hnet, net, criterion, device, loader_type=loader_type)

    def mean_over(key, node_ids):
        # Average ``key`` over the given nodes, skipping nodes without samples.
        values = [curr_results[i][key] for i in node_ids if curr_results[i]['total'] > 0]
        return np.mean(values) if values else float('nan')

    new_ids = range(num_nodes, num_users)
    all_ids = range(num_users)

    avg_loss_new = mean_over('loss', new_ids)
    avg_acc_new = mean_over('accuracy', new_ids)
    loss = mean_over('loss', all_ids)
    acc = mean_over('accuracy', all_ids)

    return curr_results, avg_loss_new, avg_acc_new, loss, acc

# Hypernetwork

In [None]:
def pFedHN(global_model, clients, criterion, args, metrics, device, test_set):
    """Train a ``CNNHyper`` hypernetwork with the pFedHN update rule.

    Every step samples one node, generates its weights with the hypernetwork,
    runs ``args.local_ep * 10`` local SGD steps on one batch each from the
    node's training dataloader, and feeds the resulting weight change as
    ``grad_outputs`` to ``torch.autograd.grad`` to update the hypernetwork.
    Every ``args.print_every`` steps the model is evaluated on the test and
    validation loaders, a checkpoint is written, and a row is appended to
    ``metrics``.

    Args:
        global_model: Target network; its weights are overwritten with the
            hypernetwork output at every step.
        clients: Indexable collection of nodes exposing ``train_dataloader``;
            also passed to ``eval_pfedhn``.
        criterion: Loss callable used as ``criterion(pred, label)``.
        args: Run configuration. Reads ``embed_dim``, ``n_nodes``, ``lr``,
            ``embed_lr``, ``wd``, ``inner_wd``, ``optimizer``, ``gamma``,
            ``epochs``, ``participation``, ``local_ep``, ``print_every``,
            ``backup``, ``checkpoint_path``, ``metrics_dir``, ``algorithm``,
            ``iid`` and ``Nc``.
        metrics: Object supporting ``.loc[...]`` row assignment, ``len`` and
            ``to_pickle``; pickled at the end of training.
        device: Device used for the hypernetwork and the batches.
        test_set: Unused; kept so existing call sites keep working.

    Returns:
        None. Results are written to ``metrics``, checkpoints and stdout.
    """
    nodes = clients
    embed_dim = args.embed_dim
    num_nodes = args.n_nodes

    if embed_dim == -1:
        embed_dim = int(1 + num_nodes / 4)

    def checkpoint_file(epoch):
        # gamma is part of the name only for skewed participation, which is
        # the convention already used for the metrics pickle further down.
        gamma_part = "" if args.participation else f"_{args.gamma}"
        return (
            f"{args.checkpoint_path}/checkpoint_{args.algorithm}_{args.iid}_"
            f"{args.participation}{gamma_part}_{args.Nc}_{args.local_ep}_epoch_{epoch}.pth.tar"
        )

    # TODO: if a saved checkpoint exists, load its state dicts into
    # global_model and the hypernetwork before training.
    hnet = CNNHyper(num_nodes, embed_dim).to(device)
    net = global_model

    ##################
    # init optimizer #
    ##################
    lr = args.lr
    embed_lr = args.embed_lr
    wd = args.wd

    embed_lr = embed_lr if embed_lr is not None else lr
    optimizers = {
        'sgd': torch.optim.SGD(
            [
                {'params': [p for n, p in hnet.named_parameters() if 'embed' not in n]},
                {'params': [p for n, p in hnet.named_parameters() if 'embed' in n], 'lr': embed_lr}
            ], lr=lr, weight_decay=wd
        )
    }
    optimizer = optimizers[args.optimizer]

    ################
    # init metrics #
    ################
    results = defaultdict(list)
    dirichlet_probs = np.random.dirichlet([args.gamma] * num_nodes)

    # Sentinels so the post-loop check works even when the loop never runs or
    # never reaches an evaluation step (both used to raise NameError).
    last_eval = -1
    step = -1

    step_iter = trange(args.epochs)
    for step in step_iter:
        hnet.train()

        if args.participation:
            # Uniform participation
            node_id = np.random.choice(range(num_nodes))
        else:
            # Skewed participation
            node_id = np.random.choice(range(num_nodes), p=dirichlet_probs)

        # produce & load local network weights
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        # init inner optimizer
        inner_optim = torch.optim.SGD(
            net.parameters(), lr=args.lr, weight_decay=args.inner_wd
        )

        # storing theta_i for later calculating delta theta
        inner_state = OrderedDict({k: tensor.data for k, tensor in weights.items()})

        # inner updates -> obtaining theta_tilda
        inner_steps = args.local_ep * 10
        for _ in range(inner_steps):
            net.train()
            inner_optim.zero_grad()
            optimizer.zero_grad()

            # A fresh iterator is created every step, so each inner step uses
            # the first batch the dataloader yields on that pass.
            batch = next(iter(nodes[node_id].train_dataloader))
            img, label = tuple(t.to(device) for t in batch)

            pred = net(img)

            loss = criterion(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)

            inner_optim.step()

        optimizer.zero_grad()

        final_state = net.state_dict()

        # calculating delta theta
        delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})

        # calculating phi gradient
        hnet_grads = torch.autograd.grad(
            list(weights.values()), hnet.parameters(), grad_outputs=list(delta_theta.values())
        )

        # update hnet weights
        for p, g in zip(hnet.parameters(), hnet_grads):
            p.grad = g
        torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
        optimizer.step()

        if (step + 1) % args.print_every == 0:
            last_eval = step
            step_results, test_avg_loss, test_avg_acc, test_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="test")

            print(f"\nStep: {step+1}, AVG Test Loss: {test_avg_loss:.4f},  AVG Test Acc: {test_avg_acc:.4f}, Test Acc: {test_acc:.4f}")

            results['test_avg_loss'].append(test_avg_loss)
            results['test_avg_acc'].append(test_avg_acc)
            results['test_acc'].append(test_acc)

            _, val_avg_loss, val_avg_acc, val_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="val")

            results['val_avg_loss'].append(val_avg_loss)
            results['val_avg_acc'].append(val_avg_acc)
            results['val_acc'].append(val_acc)

            for key in results:
                print(f"{key}: {results[key][-1]}")

            checkpoint = {
                'epoch': step + 1,
                'model_state_dict': global_model.state_dict(),
                'hn_state_dict': hnet.state_dict(),
                'user_input': (args.iid, args.participation, args.Nc, args.local_ep),
                'test_accuracy': results['test_acc'][-1],
                'test_avg_loss': results['test_avg_loss'][-1],
                'test_avg_acc': results['test_avg_acc'][-1],
                'val_accuracy': results['val_acc'][-1],
                'val_avg_loss': results['val_avg_loss'][-1],
                'val_avg_acc': results['val_avg_acc'][-1]
            }
            save_checkpoint(checkpoint, filename=checkpoint_file(step + 1))

            # Remove the previous checkpoint unless its epoch is a multiple of
            # the backup parameter. The epoch-1900 checkpoint is always kept.
            prev_epoch = step + 1 - args.print_every
            if (step + 1) > args.print_every and prev_epoch != 1900:
                # Test the previous epoch itself; the old code subtracted a
                # hard-coded 10, which is only right when print_every == 10.
                if prev_epoch % args.backup != 0:
                    prev_filename = checkpoint_file(prev_epoch)

                    if os.path.exists(prev_filename):
                        os.remove(prev_filename)

            # NOTE(review): values are written positionally. With the
            # seven-column frame ('Round', 'Test Accuracy', 'Test Loss',
            # 'Avg Test Accuracy', 'Avg Test Loss', 'Avg Validation Accuracy',
            # 'Avg Validation Loss') the last three values do not appear to
            # line up with their headers - confirm the intended mapping.
            metrics.loc[len(metrics)] = [step + 1, results['test_acc'][-1], results['test_avg_loss'][-1], results['test_avg_acc'][-1], results['val_acc'][-1], results['val_avg_loss'][-1], results['val_avg_acc'][-1]]

    if step != last_eval:
        # Final evaluation when the last step was not an evaluation step.
        # Evaluate first, then record, so fresh values are appended.
        _, val_avg_loss, val_avg_acc, val_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="val")
        step_results, test_avg_loss, test_avg_acc, test_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="test")
        print(f"\nStep: {step+1}, AVG Test Loss: {test_avg_loss:.4f},  AVG Test Acc: {test_avg_acc:.4f}, Test Acc: {test_acc:.4f}")

        results['test_avg_loss'].append(test_avg_loss)
        results['test_avg_acc'].append(test_avg_acc)
        results['test_acc'].append(test_acc)

        results['val_avg_loss'].append(val_avg_loss)
        results['val_avg_acc'].append(val_avg_acc)
        results['val_acc'].append(val_acc)

    # Persist the metrics table; gamma is part of the name only for skewed
    # participation.
    if args.participation:
        pickle_file = f"{args.metrics_dir}/metrics_{args.algorithm}_{args.iid}_{args.participation}_{args.Nc}_{args.local_ep}.pkl"
    else:
        pickle_file = f"{args.metrics_dir}/metrics_{args.algorithm}_{args.iid}_{args.participation}_{args.gamma}_{args.Nc}_{args.local_ep}.pkl"

    metrics.to_pickle(pickle_file)
    # NOTE(review): ``logger`` is not defined in this function; it is assumed
    # to exist at module level - confirm.
    logger.info(f"Metrics saved at {pickle_file}")
    logger.info("Training Done!")

# Hypernetwork With Checkpoint

In [None]:
def pFedHN(global_model, clients, criterion, args, metrics, device, test_set):
    """Train a ``CNNHyper`` hypernetwork with pFedHN, optionally resuming.

    Looks for the checkpoint written at ``args.last_epoch``; when it exists
    its details are printed and, if ``args.checkpoint_resume == 1``, the model
    and hypernetwork state dicts are restored from it. Training then runs for
    ``args.epochs - args.last_epoch + args.extra`` steps. Every
    ``args.print_every`` steps the model is evaluated, a checkpoint named
    after the absolute epoch (``args.last_epoch + step + 1``) is written, and
    a row is appended to ``metrics``.

    Args:
        global_model: Target network; its weights are overwritten with the
            hypernetwork output at every step.
        clients: Indexable collection of nodes exposing ``client_id``,
            ``train_indices``, ``val_indices``, ``test_indices`` and
            ``train_dataloader``; also passed to ``eval_pfedhn``.
        criterion: Loss callable used as ``criterion(pred, label)``.
        args: Run configuration. Reads ``embed_dim``, ``n_nodes``,
            ``last_epoch``, ``extra``, ``checkpoint_resume``, ``lr``,
            ``embed_lr``, ``wd``, ``inner_wd``, ``optimizer``, ``gamma``,
            ``epochs``, ``participation``, ``local_ep``, ``print_every``,
            ``backup``, ``checkpoint_path``, ``metrics_dir``, ``algorithm``,
            ``iid`` and ``Nc``.
        metrics: Object supporting ``.loc[...]`` row assignment, ``len`` and
            ``to_pickle``; pickled at the end of training.
        device: Device used for the hypernetwork and the batches.
        test_set: Unused; kept so existing call sites keep working.

    Returns:
        None. Results are written to ``metrics``, checkpoints and stdout.
    """
    nodes = clients
    embed_dim = args.embed_dim
    num_nodes = args.n_nodes
    last_epoch = args.last_epoch

    # Print the ID and the first 10 indices of (at most) the first 5 clients
    for i in range(min(5, len(clients))):
        print(f"Client ID: {clients[i].client_id}")
        print(f"Train indices: {clients[i].train_indices[:10]}")
        print(f"Val indices: {clients[i].val_indices[:10]}")
        print(f"Test indices: {clients[i].test_indices[:10]}")
        print('-' * 50)

    if embed_dim == -1:
        embed_dim = int(1 + num_nodes / 4)

    def checkpoint_file(epoch):
        # One naming rule for both loading and saving: gamma is part of the
        # name only for skewed participation. Saving used to omit gamma, so a
        # skewed run could never find its own checkpoints again.
        gamma_part = "" if args.participation else f"_{args.gamma}"
        return (
            f"{args.checkpoint_path}/checkpoint_{args.algorithm}_{args.iid}_"
            f"{args.participation}{gamma_part}_{args.Nc}_{args.local_ep}_epoch_{epoch}.pth.tar"
        )

    hnet = CNNHyper(num_nodes, embed_dim).to(device)
    net = global_model

    checkpoint_pattern = checkpoint_file(last_epoch)
    checkpoint_files = sorted(glob.glob(checkpoint_pattern))
    print(checkpoint_pattern)
    print("len:", len(checkpoint_files))

    # The old test ``len(checkpoint_files)+1`` was always true and indexed an
    # empty list when no checkpoint existed.
    if checkpoint_files:
        latest_checkpoint = checkpoint_files[-1]
        checkpoint = load_checkpoint(latest_checkpoint)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            last_user_input = checkpoint['user_input']
            test_acc = checkpoint['test_accuracy']
            test_avg_loss = checkpoint['test_avg_loss']
            test_avg_acc = checkpoint['test_avg_acc']
            val_acc = checkpoint['val_accuracy']
            val_avg_loss = checkpoint['val_avg_loss']
            val_avg_acc = checkpoint['val_avg_acc']
            # Print the status of the last checkpoint
            participation_status = 'uniform' if last_user_input[1] == 1 else 'skewed'
            user_input_string = f"IID: {last_user_input[0]}, Participation: {participation_status}, Nc: {last_user_input[2]}, J: {last_user_input[3]}"

            print(f"\nA saving checkpoint with these parameters exists:\n"
                f"Last checkpoint details:\n"
                f"Epoch reached: {start_epoch}\n"
                f"User input variables: {user_input_string}\n"
                f'Test Accuracy: {100*test_acc}%\n'
                f'Test Avg Loss: {test_avg_loss}\n'
                f'Test Avg Acc: {100*test_avg_acc}%\n'
                f'Val Accuracy: {100*val_acc}%\n'
                f'Val Avg Loss {val_avg_loss}\n'
                f'Val Avg Acc {100*val_avg_acc}%\n')

            if args.checkpoint_resume == 1:
                # Restore both networks in place from the checkpoint.
                global_model.load_state_dict(checkpoint['model_state_dict'])
                hnet.load_state_dict(checkpoint['hn_state_dict'])

    ##################
    # init optimizer #
    ##################
    lr = args.lr
    embed_lr = args.embed_lr
    wd = args.wd

    embed_lr = embed_lr if embed_lr is not None else lr
    optimizers = {
        'sgd': torch.optim.SGD(
            [
                {'params': [p for n, p in hnet.named_parameters() if 'embed' not in n]},
                {'params': [p for n, p in hnet.named_parameters() if 'embed' in n], 'lr': embed_lr}
            ], lr=lr, weight_decay=wd
        )
    }
    optimizer = optimizers[args.optimizer]

    ################
    # init metrics #
    ################
    results = defaultdict(list)
    dirichlet_probs = np.random.dirichlet([args.gamma] * num_nodes)

    # Sentinels so the post-loop check works even when the loop never runs or
    # never reaches an evaluation step (both used to raise NameError).
    last_eval = -1
    step = -1

    step_iter = trange(args.epochs - last_epoch + args.extra)

    for step in step_iter:
        hnet.train()

        if args.participation:
            # Uniform participation
            node_id = np.random.choice(range(num_nodes))
        else:
            # Skewed participation
            node_id = np.random.choice(range(num_nodes), p=dirichlet_probs)

        # produce & load local network weights
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        # init inner optimizer
        inner_optim = torch.optim.SGD(
            net.parameters(), lr=args.lr, weight_decay=args.inner_wd
        )

        # storing theta_i for later calculating delta theta
        inner_state = OrderedDict({k: tensor.data for k, tensor in weights.items()})

        # inner updates -> obtaining theta_tilda
        inner_steps = args.local_ep * 10
        for _ in range(inner_steps):
            net.train()
            inner_optim.zero_grad()
            optimizer.zero_grad()

            # A fresh iterator is created every step, so each inner step uses
            # the first batch the dataloader yields on that pass.
            batch = next(iter(nodes[node_id].train_dataloader))
            img, label = tuple(t.to(device) for t in batch)

            pred = net(img)

            loss = criterion(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)

            inner_optim.step()

        optimizer.zero_grad()

        final_state = net.state_dict()

        # calculating delta theta
        delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})

        # calculating phi gradient
        hnet_grads = torch.autograd.grad(
            list(weights.values()), hnet.parameters(), grad_outputs=list(delta_theta.values())
        )

        # update hnet weights
        for p, g in zip(hnet.parameters(), hnet_grads):
            p.grad = g
        torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
        optimizer.step()

        if (step + 1) % args.print_every == 0:
            # Absolute epoch: steps of this run are offset by last_epoch.
            epoch = last_epoch + step + 1

            last_eval = step
            step_results, test_avg_loss, test_avg_acc, test_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="test")

            print(f"\nStep: {step+1}, AVG Test Loss: {test_avg_loss:.4f},  AVG Test Acc: {test_avg_acc:.4f}, Test Acc: {test_acc:.4f}")

            results['test_avg_loss'].append(test_avg_loss)
            results['test_avg_acc'].append(test_avg_acc)
            results['test_acc'].append(test_acc)

            _, val_avg_loss, val_avg_acc, val_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="val")

            results['val_avg_loss'].append(val_avg_loss)
            results['val_avg_acc'].append(val_avg_acc)
            results['val_acc'].append(val_acc)

            for key in results:
                print(f"{key}: {results[key][-1]}")

            checkpoint = {
                # Store the same absolute epoch that is used in the file name.
                'epoch': epoch,
                'model_state_dict': global_model.state_dict(),
                'hn_state_dict': hnet.state_dict(),
                'user_input': (args.iid, args.participation, args.Nc, args.local_ep),
                'test_accuracy': results['test_acc'][-1],
                'test_avg_loss': results['test_avg_loss'][-1],
                'test_avg_acc': results['test_avg_acc'][-1],
                'val_accuracy': results['val_acc'][-1],
                'val_avg_loss': results['val_avg_loss'][-1],
                'val_avg_acc': results['val_avg_acc'][-1]
            }
            save_checkpoint(checkpoint, filename=checkpoint_file(epoch))

            # Sanity check: load the full state dict that was just saved into
            # a CNNHyper_old instance and evaluate it. Separate variable names
            # keep the metrics of the live hypernetwork intact.
            hnet_old = CNNHyper_old(num_nodes, embed_dim).to(device)
            hnet_old.load_state_dict(checkpoint['hn_state_dict'])

            print("\n LOADED MODEL\n")
            _, old_avg_loss, old_avg_acc, old_acc = eval_pfedhn(nodes, num_nodes, hnet_old, net, criterion, device, loader_type="test")

            print(f"\nStep: {step+1}, AVG Test Loss: {old_avg_loss:.4f},  AVG Test Acc: {old_avg_acc:.4f}, Test Acc: {old_acc:.4f}")

            # Remove the previous checkpoint of this run unless its epoch is a
            # multiple of the backup parameter. The previous epoch carries the
            # last_epoch offset like the file names do; without it, files of
            # an earlier run could be deleted instead.
            # NOTE(review): 2000 looks like a "never delete the final epoch"
            # guard - confirm.
            prev_epoch = epoch - args.print_every
            if (step + 1) > args.print_every and prev_epoch != 2000:
                if prev_epoch % args.backup != 0:
                    prev_filename = checkpoint_file(prev_epoch)

                    if os.path.exists(prev_filename):
                        os.remove(prev_filename)

            # NOTE(review): values are written positionally. With the
            # seven-column frame ('Round', 'Test Accuracy', 'Test Loss',
            # 'Avg Test Accuracy', 'Avg Test Loss', 'Avg Validation Accuracy',
            # 'Avg Validation Loss') the last three values do not appear to
            # line up with their headers, and 'Round' is relative to this run
            # rather than absolute - confirm the intended mapping.
            metrics.loc[len(metrics)] = [step + 1, results['test_acc'][-1], results['test_avg_loss'][-1], results['test_avg_acc'][-1], results['val_acc'][-1], results['val_avg_loss'][-1], results['val_avg_acc'][-1]]

    if step != last_eval:
        # Final evaluation when the last step was not an evaluation step.
        # Evaluate first, then record, so fresh values are appended.
        _, val_avg_loss, val_avg_acc, val_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="val")
        step_results, test_avg_loss, test_avg_acc, test_acc = eval_pfedhn(nodes, num_nodes, hnet, net, criterion, device, loader_type="test")
        print(f"\nStep: {step+1}, AVG Test Loss: {test_avg_loss:.4f},  AVG Test Acc: {test_avg_acc:.4f}, Test Acc: {test_acc:.4f}")

        results['test_avg_loss'].append(test_avg_loss)
        results['test_avg_acc'].append(test_avg_acc)
        results['test_acc'].append(test_acc)

        results['val_avg_loss'].append(val_avg_loss)
        results['val_avg_acc'].append(val_avg_acc)
        results['val_acc'].append(val_acc)

    # Persist the metrics table; gamma is part of the name only for skewed
    # participation.
    if args.participation:
        pickle_file = f"{args.metrics_dir}/metrics_{args.algorithm}_{args.iid}_{args.participation}_{args.Nc}_{args.local_ep}.pkl"
    else:
        pickle_file = f"{args.metrics_dir}/metrics_{args.algorithm}_{args.iid}_{args.participation}_{args.gamma}_{args.Nc}_{args.local_ep}.pkl"

    metrics.to_pickle(pickle_file)
    # NOTE(review): ``logger`` is not defined in this function; it is assumed
    # to exist at module level - confirm.
    logger.info(f"Metrics saved at {pickle_file}")
    logger.info("Training Done!")

# Hypernetwork Main

In [15]:
import argparse
import json
import logging
import random
from collections import OrderedDict, defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.utils.data
from tqdm import trange
import pickle
import glob
import pandas as pd




# Driver cell: configure a resumed pFedHN run, load the pickled client split,
# rebuild the clients from it and start training.

# TODO: modify how things of args are set
args = args_parser()
args.epochs = 2000
args.last_epoch = 1900
args.iid = 0
args.participation = 1
args.algorithm = "pfedhn"
args.Nc = 1
args.local_ep = 4
args.checkpoint_resume = 1
args.checkpoint_path = "/content/drive/MyDrive/MLDL/Cifar-100/Checkpoints"
if args.gpu:
    torch.cuda.set_device(args.gpu)
device = 'cuda' if args.gpu else 'cpu'

args.device = device

criterion = nn.CrossEntropyLoss().to(device)

if args.dataset == 'cifar':
    # fedavg and pfedhn used the same CIFARLeNet here, so there is no need to
    # branch on the algorithm.
    global_model = CIFARLeNet().to(device)
else:
    global_model = CharLSTM().to(device)

#train_set, test_set, clients = get_dataset(args)

# TODO: logger info

# File name of the pickled client/class distribution. gamma is added only for
# skewed participation, Nc and local_ep only for the non-IID split.
name_parts = [args.algorithm, args.iid, args.participation]
if not args.participation:
    name_parts.append(args.gamma)
if not args.iid:
    name_parts += [args.Nc, args.local_ep]
pickle_file = f"{args.metrics_dir}/clients_classes_dist_{'_'.join(map(str, name_parts))}.pkl"

checkpoint_path = pickle_file
try:
    # NOTE: pickle.load runs arbitrary code from the file; only load files
    # written by this project.
    with open(checkpoint_path, 'rb') as file:
        clients_classes_df = pickle.load(file)
except FileNotFoundError:
    print(f"Checkpoint file not found: {checkpoint_path}")
    # Everything below needs clients_classes_df; stop here instead of failing
    # later with a NameError.
    raise
else:
    print("Checkpoint data loaded successfully:")

#data_dir = args.data_dir
data_dir = "/content/drive/MyDrive/MLDL/"

# NOTE(review): train_dataset / test_dataset are only created for 'cifar',
# but the client construction further down uses them unconditionally.
if args.dataset == 'cifar':
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    train_dataset = datasets.CIFAR100(data_dir, train=True, download=True,
                                    transform=transform_train)
    test_dataset = datasets.CIFAR100(data_dir, train=False, download=True,
                                  transform=transform_test)


display(clients_classes_df)
if args.dataset == 'cifar':
    metrics = pd.DataFrame(columns=['Round', 'Test Accuracy', 'Test Loss', 'Avg Test Accuracy', 'Avg Test Loss', 'Avg Validation Accuracy', 'Avg Validation Loss'])
else:
    metrics = pd.DataFrame(columns=['Round', 'Test Accuracy', 'Test Loss', 'Avg Test Accuracy', 'Avg Test Loss'])

#logger.info(f"Saved clients classes distribution to {pickle_file}")

if args.n_nodes == args.num_users:
    # NOTE(review): `clients`, `temp` and `test_set` are not defined anywhere
    # in this cell (the get_dataset call above is commented out), so this
    # branch depends on state left over from other cells - confirm.
    if args.algorithm == 'fedavg':
        fedAVG(global_model, clients, criterion, args, temp, metrics, temp, device, test_set)
    elif args.algorithm == 'pfedhn':
        pFedHN(global_model, clients, criterion, args, metrics, device, test_set)
else:
    # Order the rows by 'new_id' and renumber the index so that positional
    # lookups below follow that order.
    df_sorted = clients_classes_df.sort_values(by='new_id')
    df_sorted.reset_index(drop=True, inplace=True)
    display(df_sorted)

    # Rebuild one Client per row from the stored index lists.
    clients = []
    for client_id in range(args.num_users):
        client = Client(args, df_sorted['client_id'][client_id], train_dataset, test_dataset, df_sorted['train_indices'][client_id], df_sorted['val_indices'][client_id], df_sorted['test_indices'][client_id])
        clients.append(client)

    if args.algorithm == 'fedavg':
        # NOTE(review): `temp` and `test_set` are not defined in this cell - confirm.
        fedAVG(global_model, clients, criterion, args, temp, metrics, temp, device, test_set)
    elif args.algorithm == 'pfedhn':
        # TODO: add the wandb_logger to the pFedHN
        pFedHN(global_model, clients, criterion, args, metrics, device, test_dataset)


Checkpoint data loaded successfully:
Files already downloaded and verified
Files already downloaded and verified


Unnamed: 0,client_id,train,val,test,new_id,train_indices,val_indices,test_indices
0,0,{66: 400},{66: 100},{66: 100},52,"[21798, 9842, 16198, 36577, 32338, 34686, 3768...","[4812, 13010, 9025, 37095, 39183, 268, 21995, ...","[4075, 2619, 6441, 415, 3660, 2546, 6335, 4618..."
1,1,{52: 400},{52: 100},{52: 100},57,"[38624, 5427, 28266, 41757, 6122, 3113, 9973, ...","[42526, 23261, 13141, 32989, 44141, 41089, 231...","[1803, 7734, 6137, 2283, 7438, 7666, 5265, 114..."
2,2,{32: 400},{32: 100},{32: 100},33,"[23117, 33712, 37542, 31746, 49253, 36127, 552...","[33040, 38597, 30277, 43613, 26454, 40693, 361...","[3366, 7480, 2465, 9699, 8860, 4954, 689, 6158..."
3,3,{76: 400},{76: 100},{76: 100},3,"[35788, 42003, 28492, 39416, 11844, 35000, 498...","[49443, 10426, 11318, 15173, 36490, 35418, 532...","[5062, 5355, 5477, 3097, 6744, 1580, 5547, 524..."
4,4,{80: 400},{80: 100},{80: 100},15,"[19346, 34367, 26146, 41023, 15974, 41326, 383...","[6440, 45794, 28508, 24698, 3026, 1081, 47146,...","[5089, 1883, 2769, 208, 3051, 9590, 260, 7760,..."
...,...,...,...,...,...,...,...,...
95,95,{54: 400},{54: 100},{54: 100},37,"[33406, 33249, 40826, 1577, 19205, 29793, 1580...","[5071, 27780, 1178, 456, 13428, 36758, 26599, ...","[7598, 7997, 1255, 3632, 9430, 8158, 3903, 322..."
96,96,{51: 400},{51: 100},{51: 100},32,"[2882, 40895, 48859, 13023, 35381, 44138, 1830...","[16786, 29546, 33842, 39553, 31485, 33216, 982...","[9847, 7619, 7356, 8788, 5484, 3337, 7425, 350..."
97,97,{92: 400},{92: 100},{92: 100},56,"[8165, 37937, 8080, 27736, 13140, 29705, 11549...","[25644, 27435, 8485, 11055, 25695, 32326, 4371...","[4613, 6896, 1504, 2000, 9938, 9254, 6944, 689..."
98,98,{63: 400},{63: 100},{63: 100},85,"[10625, 6989, 20417, 21242, 6953, 42918, 2716,...","[34607, 147, 8330, 2887, 27506, 41920, 41902, ...","[564, 2046, 8403, 7185, 7449, 5055, 6195, 971,..."


Unnamed: 0,client_id,train,val,test,new_id,train_indices,val_indices,test_indices
0,11,{61: 400},{61: 100},{61: 100},0,"[43031, 42082, 30244, 3309, 18975, 38580, 4852...","[38599, 25239, 37441, 29374, 49725, 20029, 492...","[4562, 9064, 5318, 1492, 8195, 7663, 1625, 197..."
1,81,{4: 400},{4: 100},{4: 100},1,"[30665, 38204, 27342, 39748, 35882, 19181, 135...","[46937, 31311, 32117, 954, 10099, 10243, 21109...","[8400, 4951, 2720, 4749, 3380, 1379, 9158, 885..."
2,77,{2: 400},{2: 100},{2: 100},2,"[9964, 580, 382, 16332, 27661, 12363, 47261, 3...","[46974, 4031, 31300, 13524, 47390, 36345, 8681...","[2777, 3397, 2804, 4071, 6301, 4056, 573, 1047..."
3,3,{76: 400},{76: 100},{76: 100},3,"[35788, 42003, 28492, 39416, 11844, 35000, 498...","[49443, 10426, 11318, 15173, 36490, 35418, 532...","[5062, 5355, 5477, 3097, 6744, 1580, 5547, 524..."
4,65,{53: 400},{53: 100},{53: 100},4,"[49422, 26610, 18645, 28859, 44041, 20464, 302...","[11550, 22135, 49432, 8198, 46248, 13245, 1948...","[1457, 9449, 245, 793, 9956, 2467, 1706, 8903,..."
...,...,...,...,...,...,...,...,...
95,60,{29: 400},{29: 100},{29: 100},95,"[31711, 30207, 14744, 31805, 24054, 27887, 158...","[36208, 9684, 28788, 29096, 37652, 10355, 3919...","[5598, 4632, 8888, 3459, 8984, 9772, 386, 2050..."
96,12,{98: 400},{98: 100},{98: 100},96,"[45304, 740, 29937, 42097, 47458, 5459, 46462,...","[10317, 22963, 11504, 20033, 25848, 34236, 388...","[3879, 2020, 1468, 5061, 2498, 7114, 8601, 892..."
97,50,{7: 400},{7: 100},{7: 100},97,"[31534, 25681, 5516, 38429, 3810, 7173, 30265,...","[39476, 43315, 23862, 26164, 32032, 14553, 214...","[2201, 64, 7684, 8294, 9652, 5518, 3103, 198, ..."
98,71,{12: 400},{12: 100},{12: 100},98,"[38619, 23655, 14684, 19663, 25921, 6031, 4694...","[10034, 48415, 6835, 20126, 17653, 35516, 2545...","[6905, 5423, 324, 5814, 7116, 1194, 6861, 1649..."


NameError: name 'pFedHN' is not defined

# Generalization Hypernetwork

In [29]:
import glob
import matplotlib.pyplot as plt

# def pFedHN_Gen(global_model, clients, criterion, args, logger, metrics, wandb_logger, device, test_set):
def _gen_checkpoint_path(args, epoch):
    """Return the path of the generalization checkpoint written for ``epoch``."""
    return (
        f"{args.checkpoint_path}/checkpoint_{args.algorithm}_{args.bias}_gen_{args.iid}_"
        f"{args.participation}_{args.Nc}_{args.local_ep}_epoch_{epoch}.pth.tar"
    )


def _print_checkpoint_summary(checkpoint):
    """Print the epoch, run settings and stored metrics of a loaded checkpoint."""
    last_user_input = checkpoint['user_input']
    participation_status = 'uniform' if last_user_input[1] == 1 else 'skewed'
    user_input_string = f"IID: {last_user_input[0]}, Participation: {participation_status}, Nc: {last_user_input[2]}, J: {last_user_input[3]}"

    print(f"\nA saving checkpoint with these parameters exists:\n"
          f"Last checkpoint details:\n"
          f"Epoch reached: {checkpoint['epoch']}\n"
          f"User input variables: {user_input_string}\n"
          f"Test Accuracy: {100*checkpoint['test_accuracy']}%\n"
          f"Test Avg Loss: {checkpoint['test_avg_loss']}\n"
          f"Test Avg Acc: {100*checkpoint['test_avg_acc']}%\n"
          f"Val Accuracy: {100*checkpoint['val_accuracy']}%\n"
          f"Val Avg Loss {checkpoint['val_avg_loss']}\n"
          f"Val Avg Acc {100*checkpoint['val_avg_acc']}%\n")


def _build_expanded_hnet(checkpoint, num_nodes, num_users, embed_dim, average, device):
    """Create the enlarged hypernetwork and the original one from a checkpoint.

    Returns ``(hnet, hnet_old)``: ``hnet_old`` holds the checkpointed weights
    unchanged; ``hnet`` has ``num_users`` embedding rows, reuses every
    non-embedding tensor of the checkpoint and copies the checkpointed
    embedding rows into its first rows.
    """
    hnet = CNNHyper(num_users, embed_dim).to(device)
    hnet_old = CNNHyper_old(num_nodes, embed_dim).to(device)

    saved_state = checkpoint['hn_state_dict']
    hnet_old.load_state_dict(saved_state)

    # Start from the freshly initialised state and overwrite everything but
    # the embedding table, whose shape differs between the two networks.
    new_state = hnet.state_dict()
    for name, param in saved_state.items():
        if 'embeddings' not in name:
            new_state[name] = param

    embedding_key = 'embeddings.weight'
    if embedding_key in saved_state:
        old_embeddings = saved_state[embedding_key]
        new_embeddings = new_state[embedding_key]
        n_old = old_embeddings.size(0)

        # Rows of already-trained clients keep their learned embedding.
        new_embeddings[:n_old] = old_embeddings
        if average:
            # Rows of the new clients start from the mean trained embedding
            # instead of the random initialisation.
            new_embeddings[n_old:] = old_embeddings.mean(dim=0)
            display(new_embeddings)
        new_state[embedding_key] = new_embeddings

    hnet.load_state_dict(new_state)
    return hnet, hnet_old


def _client_sampling_probs(num_old_clients, num_new_clients, bias):
    """Return the per-client sampling probabilities for a given bias.

    Every old client gets weight ``1 - bias`` and every new client weight
    ``bias``; the weights are normalised to sum to one.

    Raises:
        ValueError: if ``bias`` is outside ``[0, 1]`` or all weights are zero.
    """
    if not 0 <= bias <= 1:
        raise ValueError(f"args.bias must be within [0, 1], got {bias}")
    raw = np.array([1.0 - bias] * num_old_clients + [bias] * num_new_clients)
    total = raw.sum()
    if total == 0:
        raise ValueError(
            f"bias={bias} leaves no client with a non-zero sampling weight "
            f"({num_old_clients} old, {num_new_clients} new clients)"
        )
    return raw / total


def pFedHN_Gen(global_model, clients, criterion, args, metrics, device, test_set):
    """Fine-tune a checkpointed pFedHN hypernetwork on an enlarged client set.

    The checkpoint saved at ``args.last_epoch`` is loaded, the hypernetwork is
    rebuilt with ``args.num_users`` embeddings (the first ``args.n_nodes`` come
    from the checkpoint), and ``args.step_iter`` training steps are run. Every
    ``args.print_every`` steps the model is evaluated on the test and
    validation loaders, a checkpoint is written, a row is appended to
    ``metrics``, and training stops early once the new clients' rounded test
    accuracy exceeds the checkpoint's stored test accuracy. Finally ``metrics``
    is pickled under ``args.metrics_dir``.

    Args:
        global_model: target network; receives the checkpointed weights and is
            then used as the per-client network.
        clients: client objects indexed by client position; the first
            ``args.n_nodes`` entries are treated as the old clients.
        criterion: loss used for the local updates and the evaluations.
        args: run configuration (reads ``embed_dim``, ``n_nodes``,
            ``num_users``, ``last_epoch``, ``participation``, ``gamma``,
            ``checkpoint_path``, ``checkpoint_resume``, ``average``, ``lr``,
            ``embed_lr``, ``wd``, ``inner_wd``, ``optimizer``, ``step_iter``,
            ``bias``, ``local_ep``, ``update``, ``print_every``, ``backup``,
            ``metrics_dir`` and the fields used in file names).
        metrics: DataFrame with twelve columns; one row is appended per
            evaluation.
        device: device the networks and batches are moved to.
        test_set: unused; kept for interface compatibility.

    Raises:
        FileNotFoundError: no checkpoint matches the expected file name.
        RuntimeError: the checkpoint file could not be loaded.
        ValueError: ``args.checkpoint_resume`` is not 1, or ``args.bias`` does
            not yield a valid sampling distribution.
    """
    nodes = clients
    num_nodes = args.n_nodes
    num_users = args.num_users
    embed_dim = args.embed_dim

    # Show the first few clients so the data split can be checked by eye.
    for client in clients[:5]:
        print(f"Client ID: {client.client_id}")
        print(f"Train indices: {client.train_indices[:10]}")
        print(f"Val indices: {client.val_indices[:10]}")
        print(f"Test indices: {client.test_indices[:10]}")
        print('-' * 50)

    if embed_dim == -1:
        embed_dim = int(1 + num_nodes / 4)

    last_epoch = args.last_epoch
    if args.participation:
        checkpoint_pattern = f"{args.checkpoint_path}/checkpoint_{args.algorithm}_{args.iid}_{args.participation}_{args.Nc}_{args.local_ep}_epoch_{args.last_epoch}.pth.tar"
    else:
        checkpoint_pattern = f"{args.checkpoint_path}/checkpoint_{args.algorithm}_{args.iid}_{args.participation}_{args.gamma}_{args.Nc}_{args.local_ep}_epoch_{args.last_epoch}.pth.tar"
    checkpoint_files = sorted(glob.glob(checkpoint_pattern))
    print(checkpoint_pattern)
    print("len:", len(checkpoint_files))

    # Everything below needs the pre-trained networks, so fail here with a
    # clear message instead of hitting an unbound variable further down.
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoint found matching {checkpoint_pattern}")
    checkpoint = load_checkpoint(checkpoint_files[-1])
    if not checkpoint:
        raise RuntimeError(f"Could not load checkpoint {checkpoint_files[-1]}")
    _print_checkpoint_summary(checkpoint)
    if args.checkpoint_resume != 1:
        raise ValueError("pFedHN_Gen needs args.checkpoint_resume == 1 to restore the pre-trained networks")

    global_model.load_state_dict(checkpoint['model_state_dict'])
    net = global_model
    hnet, hnet_old = _build_expanded_hnet(
        checkpoint, num_nodes, num_users, embed_dim, args.average, device
    )

    ##################
    # init optimizer #
    ##################
    lr = args.lr
    embed_lr = args.embed_lr if args.embed_lr is not None else lr
    optimizers = {
        'sgd': torch.optim.SGD(
            [
                {'params': [p for n, p in hnet.named_parameters() if 'embed' not in n]},
                {'params': [p for n, p in hnet.named_parameters() if 'embed' in n], 'lr': embed_lr}
            ], lr=lr, weight_decay=args.wd
        )
    }
    optimizer = optimizers[args.optimizer]

    ################
    # init metrics #
    ################
    results = defaultdict(list)

    def rounded(value):
        # Metrics are stored with four decimals, like the printed values.
        return float(f"{value:.4f}")

    # The drawn values are not used, but the call advances NumPy's global RNG;
    # it is kept so the random stream of the steps below stays unchanged.
    np.random.dirichlet([args.gamma] * num_users)

    old_test_accuracy = checkpoint['test_accuracy']
    old_val_accuracy = checkpoint['val_accuracy']

    # Baselines before any fine-tuning: the original hypernetwork, then the
    # enlarged one restricted to the old clients.
    _, _, _, acc_old = eval_pfedhn(nodes, num_nodes, hnet_old, net, criterion, device, loader_type="test")
    print("\n the old model accuracy on old client is", acc_old)
    step_results, _, _, _, _ = eval_pfedhn_gen(nodes, num_nodes, num_users, hnet, net, criterion, device, loader_type="test")
    newmodel_avg_acc_old = np.mean([step_results[node_id]['accuracy'] for node_id in range(num_nodes)])
    print("\nthe new model test accuracy on old client is: ", newmodel_avg_acc_old)

    # The sampling distribution depends only on the configuration, so it is
    # computed once instead of on every step.
    prob_distribution = _client_sampling_probs(num_nodes, num_users - num_nodes, args.bias)
    inner_steps = args.local_ep * 10

    for step in trange(args.step_iter):
        hnet.train()

        node_id = np.random.choice(num_users, p=prob_distribution)

        # produce & load local network weights
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        # The inner optimizer uses args.lr, the same value as the
        # hypernetwork's base learning rate.
        inner_optim = torch.optim.SGD(
            net.parameters(), lr=args.lr, weight_decay=args.inner_wd
        )

        # storing theta_i for later calculating delta theta
        inner_state = OrderedDict({k: tensor.data for k, tensor in weights.items()})

        # inner updates -> obtaining theta_tilda
        for _ in range(inner_steps):
            net.train()
            inner_optim.zero_grad()
            optimizer.zero_grad()

            # A fresh iterator is created each time, so this takes the first
            # batch the loader yields.
            batch = next(iter(nodes[node_id].train_dataloader))
            img, label = tuple(t.to(device) for t in batch)

            loss = criterion(net(img), label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)

            inner_optim.step()

        optimizer.zero_grad()

        final_state = net.state_dict()

        # calculating delta theta
        delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})

        # calculating phi gradient
        hnet_params = list(hnet.parameters())
        hnet_grads = torch.autograd.grad(
            list(weights.values()), hnet_params, grad_outputs=list(delta_theta.values())
        )

        if args.update:
            # Only the first registered parameter receives a gradient
            # (presumably the client embedding table); the rest stay as
            # zero_grad left them.
            hnet_params[0].grad = hnet_grads[0]
        else:
            for p, g in zip(hnet_params, hnet_grads):
                p.grad = g

        torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
        optimizer.step()

        if (step + 1) % args.print_every != 0:
            continue

        filename = _gen_checkpoint_path(args, args.last_epoch + step + 1)

        step_results, test_avg_loss_new, test_avg_acc_new, test_loss, test_acc = eval_pfedhn_gen(nodes, num_nodes, num_users, hnet, net, criterion, device, loader_type="test")
        print(f"\nStep: {step+1}, New Clients Test Loss: {test_avg_loss_new:.4f},  New Clients Test Acc: {test_avg_acc_new:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
        results['test_avg_loss_new'].append(rounded(test_avg_loss_new))
        results['test_avg_acc_new'].append(rounded(test_avg_acc_new))
        results['test_loss'].append(rounded(test_loss))
        results['test_acc'].append(rounded(test_acc))

        _, val_avg_loss_new, val_avg_acc_new, val_loss, val_acc = eval_pfedhn_gen(nodes, num_nodes, num_users, hnet, net, criterion, device, loader_type="val")
        results['val_avg_loss_new'].append(rounded(val_avg_loss_new))
        results['val_avg_acc_new'].append(rounded(val_avg_acc_new))
        results['val_loss'].append(rounded(val_loss))
        results['val_acc'].append(rounded(val_acc))

        print("\ncheck acc:", old_test_accuracy)
        print("\nnew client acc:", test_avg_acc_new)
        # Stop once the new clients beat the accuracy stored in the checkpoint.
        converged = (old_test_accuracy - results['test_avg_acc_new'][-1]) < 0

        for key in results:
            print(f"{key}: {results[key][-1]}")

        checkpoint = {
            'epoch': step + 1,
            'epoch_start': last_epoch,
            'model_state_dict': global_model.state_dict(),
            'hn_state_dict': hnet.state_dict(),
            'user_input': (args.iid, args.participation, args.Nc, args.local_ep),
            'test_avg_loss_new': results['test_avg_loss_new'][-1],
            'test_avg_acc_new': results['test_avg_acc_new'][-1],
            'test_loss': results['test_loss'][-1],
            'test_acc': results['test_acc'][-1],
            'val_avg_loss_new': results['val_avg_loss_new'][-1],
            'val_avg_acc_new': results['val_avg_acc_new'][-1],
            'val_loss': results['val_loss'][-1],
            'val_acc': results['val_acc'][-1],
            'old test accuracy': old_test_accuracy,
            'old val accuracy': old_val_accuracy
        }
        save_checkpoint(checkpoint, filename=filename)

        # Remove the previous checkpoint unless it's a multiple of the backup
        # parameter (the step-100 checkpoint is always kept). The name is
        # built by the same helper as the one used for saving, so it really
        # points at the file written print_every steps ago.
        prev_epoch = step + 1 - args.print_every
        if prev_epoch > 0 and prev_epoch != 100 and prev_epoch % args.backup != 0:
            prev_filename = _gen_checkpoint_path(args, args.last_epoch + prev_epoch)
            if os.path.exists(prev_filename):
                os.remove(prev_filename)

        metrics.loc[len(metrics)] = [
            step + 1,
            last_epoch,
            results['test_avg_loss_new'][-1],
            results['test_avg_acc_new'][-1],
            results['test_loss'][-1],
            results['test_acc'][-1],
            results['val_avg_loss_new'][-1],
            results['val_avg_acc_new'][-1],
            results['val_loss'][-1],
            results['val_acc'][-1],
            old_test_accuracy,
            old_val_accuracy
        ]

        if converged:
            print("convergence found")
            print("Global Test Acc Diff:", old_test_accuracy - test_acc)
            print("Global Val Acc Diff:", old_val_accuracy - val_acc)
            print("New Clients Test Acc:", results['test_avg_acc_new'][-1])
            print("New Clients Val Acc:", results['val_avg_acc_new'][-1])
            break

    # Persist the collected metrics table.
    if args.participation:
        pickle_file = f"{args.metrics_dir}/metrics_{args.algorithm}_{args.bias}_gen_{args.iid}_{args.participation}_{args.Nc}_{args.local_ep}.pkl"
    else:
        pickle_file = f"{args.metrics_dir}/metrics_{args.algorithm}_{args.bias}_gen_{args.iid}_{args.participation}_{args.gamma}_{args.Nc}_{args.local_ep}.pkl"

    metrics.to_pickle(pickle_file)

# Generalization Main

In [33]:
import argparse
import json
import logging
import random
from collections import OrderedDict, defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.utils.data
from tqdm import trange
import pickle
import glob
import pandas as pd

# Parse the defaults, then pin the settings of this generalization run.
args = args_parser()
for _name, _value in {
    "iid": 0,
    "participation": 1,
    "algorithm": "pfedhn",
    "Nc": 1,
    "local_ep": 4,
    "checkpoint_resume": 1,
    "checkpoint_path": "/content/drive/MyDrive/MLDL/Cifar-100/Checkpoints",
    "metrics_dir": '/content/drive/MyDrive/MLDL/cifar/metrics',
    "bias": 1,
    "step_iter": 100,
    "extra": 0,
    "average": 0,
    "update": 0,
    "last_epoch": 500,
}.items():
    setattr(args, _name, _value)

if args.gpu:
    torch.cuda.set_device(args.gpu)
# The GPU-dependent choice ('cuda' if args.gpu else 'cpu') is overridden:
# this cell always runs on the CPU.
device = 'cpu'
# train_set, test_set, clients = get_dataset(args)

prev_epoch = 240

# Location of the pickled per-client class-distribution table; the file name
# carries gamma only for the non-uniform participation setting.
if args.participation:
    pickle_file = f"{args.metrics_dir}/clients_classes_dist_{args.algorithm}_{args.iid}_{args.participation}_{args.Nc}_{args.local_ep}.pkl"
else:
    pickle_file = f"{args.metrics_dir}/clients_classes_dist_{args.algorithm}_{args.iid}_{args.participation}_{args.gamma}_{args.Nc}_{args.local_ep}.pkl"

# Load the table. NOTE: pickle executes arbitrary code on load, so this must
# only ever point at files written by this project.
checkpoint_path = pickle_file
try:
    with open(checkpoint_path, 'rb') as file:
        clients_classes_df = pickle.load(file)
except FileNotFoundError:
    print(f"Checkpoint file not found: {checkpoint_path}")
    # Everything after this point needs clients_classes_df; stop here rather
    # than failing later with an unrelated NameError.
    raise
else:
    print("Checkpoint data loaded successfully:")

# Dataset root (hard-coded here instead of taken from args.data_dir).
data_dir = "/content/drive/MyDrive/MLDL/"

if args.dataset == 'cifar':
    # Evaluation pipeline: tensor conversion and per-channel normalisation only.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    ])

    # Training pipeline: the same normalisation preceded by a padded random
    # crop and a random horizontal flip.
    transform_train = transforms.Compose([
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    ])

    train_dataset = datasets.CIFAR100(
        root=data_dir, train=True, transform=transform_train, download=True
    )
    test_dataset = datasets.CIFAR100(
        root=data_dir, train=False, transform=transform_test, download=True
    )
display(clients_classes_df)

# Order the table by 'new_id' and renumber the rows 0..n-1, so that row i
# describes the client that sits at position i in the clients list.
df_sorted = clients_classes_df.sort_values(by='new_id').reset_index(drop=True)
display(df_sorted)

# One Client per row, built from the row's original id and index lists.
# (Previous variant: read the unsorted clients_classes_df and pass 'new_id'
# as the client id.)
clients = []
for client_id in range(args.num_users):
    client = Client(
        args,
        df_sorted.at[client_id, 'client_id'],
        train_dataset,
        test_dataset,
        df_sorted.at[client_id, 'train_indices'],
        df_sorted.at[client_id, 'val_indices'],
        df_sorted.at[client_id, 'test_indices'],
    )
    clients.append(client)

# To verify the loaded data:
# for client in clients:
#     client.print_class_distribution()

# NOTE (Ahmad): this part is meant to move into federated.py, with a check on
# args.algorithm selecting which algorithm to run.
criterion = nn.CrossEntropyLoss().to(device)
global_model = CIFARLeNet().to(device)

# Planned for algorithm.py: if args.generalization, randomly select 90 clients
# from clients and create the two client vectors.

# One row per evaluation; the column order matches the rows that pFedHN_Gen
# appends.
metrics = pd.DataFrame(
    columns=[
        'Round', 'Start',
        'New Clients Test Loss', 'New Clients Test Accuracy',
        'Test Loss', 'Test Accuracy',
        'New Clients Val Loss', 'New Clients Val Accuracy',
        'Val Loss', 'Val Accuracy',
        'Old Test Accuracy', 'Old Val Accuracy',
    ]
)

# TODO (Ahmad): add your missing argument.
pFedHN_Gen(
    global_model=global_model,
    clients=clients,
    criterion=criterion,
    args=args,
    metrics=metrics,
    device=device,
    test_set=test_dataset,
)
"""pFedHN_embedtuning ()
      just take the last check pint of the pFedHN run don on 1900 rounds and with 90 clients"""


Checkpoint data loaded successfully:
Files already downloaded and verified
Files already downloaded and verified


Unnamed: 0,client_id,train,val,test,new_id,train_indices,val_indices,test_indices
0,0,{66: 400},{66: 100},{66: 100},52,"[21798, 9842, 16198, 36577, 32338, 34686, 3768...","[4812, 13010, 9025, 37095, 39183, 268, 21995, ...","[4075, 2619, 6441, 415, 3660, 2546, 6335, 4618..."
1,1,{52: 400},{52: 100},{52: 100},57,"[38624, 5427, 28266, 41757, 6122, 3113, 9973, ...","[42526, 23261, 13141, 32989, 44141, 41089, 231...","[1803, 7734, 6137, 2283, 7438, 7666, 5265, 114..."
2,2,{32: 400},{32: 100},{32: 100},33,"[23117, 33712, 37542, 31746, 49253, 36127, 552...","[33040, 38597, 30277, 43613, 26454, 40693, 361...","[3366, 7480, 2465, 9699, 8860, 4954, 689, 6158..."
3,3,{76: 400},{76: 100},{76: 100},3,"[35788, 42003, 28492, 39416, 11844, 35000, 498...","[49443, 10426, 11318, 15173, 36490, 35418, 532...","[5062, 5355, 5477, 3097, 6744, 1580, 5547, 524..."
4,4,{80: 400},{80: 100},{80: 100},15,"[19346, 34367, 26146, 41023, 15974, 41326, 383...","[6440, 45794, 28508, 24698, 3026, 1081, 47146,...","[5089, 1883, 2769, 208, 3051, 9590, 260, 7760,..."
...,...,...,...,...,...,...,...,...
95,95,{54: 400},{54: 100},{54: 100},37,"[33406, 33249, 40826, 1577, 19205, 29793, 1580...","[5071, 27780, 1178, 456, 13428, 36758, 26599, ...","[7598, 7997, 1255, 3632, 9430, 8158, 3903, 322..."
96,96,{51: 400},{51: 100},{51: 100},32,"[2882, 40895, 48859, 13023, 35381, 44138, 1830...","[16786, 29546, 33842, 39553, 31485, 33216, 982...","[9847, 7619, 7356, 8788, 5484, 3337, 7425, 350..."
97,97,{92: 400},{92: 100},{92: 100},56,"[8165, 37937, 8080, 27736, 13140, 29705, 11549...","[25644, 27435, 8485, 11055, 25695, 32326, 4371...","[4613, 6896, 1504, 2000, 9938, 9254, 6944, 689..."
98,98,{63: 400},{63: 100},{63: 100},85,"[10625, 6989, 20417, 21242, 6953, 42918, 2716,...","[34607, 147, 8330, 2887, 27506, 41920, 41902, ...","[564, 2046, 8403, 7185, 7449, 5055, 6195, 971,..."


Unnamed: 0,client_id,train,val,test,new_id,train_indices,val_indices,test_indices
0,11,{61: 400},{61: 100},{61: 100},0,"[43031, 42082, 30244, 3309, 18975, 38580, 4852...","[38599, 25239, 37441, 29374, 49725, 20029, 492...","[4562, 9064, 5318, 1492, 8195, 7663, 1625, 197..."
1,81,{4: 400},{4: 100},{4: 100},1,"[30665, 38204, 27342, 39748, 35882, 19181, 135...","[46937, 31311, 32117, 954, 10099, 10243, 21109...","[8400, 4951, 2720, 4749, 3380, 1379, 9158, 885..."
2,77,{2: 400},{2: 100},{2: 100},2,"[9964, 580, 382, 16332, 27661, 12363, 47261, 3...","[46974, 4031, 31300, 13524, 47390, 36345, 8681...","[2777, 3397, 2804, 4071, 6301, 4056, 573, 1047..."
3,3,{76: 400},{76: 100},{76: 100},3,"[35788, 42003, 28492, 39416, 11844, 35000, 498...","[49443, 10426, 11318, 15173, 36490, 35418, 532...","[5062, 5355, 5477, 3097, 6744, 1580, 5547, 524..."
4,65,{53: 400},{53: 100},{53: 100},4,"[49422, 26610, 18645, 28859, 44041, 20464, 302...","[11550, 22135, 49432, 8198, 46248, 13245, 1948...","[1457, 9449, 245, 793, 9956, 2467, 1706, 8903,..."
...,...,...,...,...,...,...,...,...
95,60,{29: 400},{29: 100},{29: 100},95,"[31711, 30207, 14744, 31805, 24054, 27887, 158...","[36208, 9684, 28788, 29096, 37652, 10355, 3919...","[5598, 4632, 8888, 3459, 8984, 9772, 386, 2050..."
96,12,{98: 400},{98: 100},{98: 100},96,"[45304, 740, 29937, 42097, 47458, 5459, 46462,...","[10317, 22963, 11504, 20033, 25848, 34236, 388...","[3879, 2020, 1468, 5061, 2498, 7114, 8601, 892..."
97,50,{7: 400},{7: 100},{7: 100},97,"[31534, 25681, 5516, 38429, 3810, 7173, 30265,...","[39476, 43315, 23862, 26164, 32032, 14553, 214...","[2201, 64, 7684, 8294, 9652, 5518, 3103, 198, ..."
98,71,{12: 400},{12: 100},{12: 100},98,"[38619, 23655, 14684, 19663, 25921, 6031, 4694...","[10034, 48415, 6835, 20126, 17653, 35516, 2545...","[6905, 5423, 324, 5814, 7116, 1194, 6861, 1649..."


Client ID: 11
Train indices: [43031, 42082, 30244, 3309, 18975, 38580, 48521, 34818, 32360, 39908]
Val indices: [38599, 25239, 37441, 29374, 49725, 20029, 49205, 37139, 33925, 11216]
Test indices: [4562, 9064, 5318, 1492, 8195, 7663, 1625, 1970, 8705, 4814]
--------------------------------------------------
Client ID: 81
Train indices: [30665, 38204, 27342, 39748, 35882, 19181, 13503, 43247, 11798, 2769]
Val indices: [46937, 31311, 32117, 954, 10099, 10243, 21109, 14996, 27761, 27678]
Test indices: [8400, 4951, 2720, 4749, 3380, 1379, 9158, 8853, 3916, 6109]
--------------------------------------------------
Client ID: 77
Train indices: [9964, 580, 382, 16332, 27661, 12363, 47261, 33073, 31638, 23970]
Val indices: [46974, 4031, 31300, 13524, 47390, 36345, 8681, 41812, 22766, 37530]
Test indices: [2777, 3397, 2804, 4071, 6301, 4056, 573, 1047, 4257, 7968]
--------------------------------------------------
Client ID: 3
Train indices: [35788, 42003, 28492, 39416, 11844, 35000, 49828, 3527

  checkpoint = torch.load(filename, map_location=torch.device('cpu'))


Loading checkpoint '/content/drive/MyDrive/MLDL/Cifar-100/Checkpoints/checkpoint_pfedhn_0_1_1_4_epoch_500.pth.tar' (epoch 500)

A saving checkpoint with these parameters exists:
Last checkpoint details:
Epoch reached: 500
User input variables: IID: 0, Participation: uniform, Nc: 1, J: 4
Test Accuracy: 55.266666666666666%
Test Avg Loss: 4.403396789838751
Test Avg Acc: 55.266666666666666%
Val Accuracy: 55.57777777777778%
Val Avg Loss 4.501711083676956
Val Avg Acc 55.57777777777777%


 the old model accuracy on old client is 0.5526666666666666

the new model test accuracy on old client is:  0.5526666666666666


  9%|▉         | 9/100 [02:18<25:45, 16.98s/it]


Step: 10, New Clients Test Loss: 15.3295,  New Clients Test Acc: 0.0780, Test Loss: 5.5477, Test Acc: 0.4916

check acc: 0.5526666666666666

new client acc: 0.07799999999999999
test_avg_loss_new: 15.3295
test_avg_acc_new: 0.078
test_loss: 5.5477
test_acc: 0.4916
val_avg_loss_new: 15.9175
val_avg_acc_new: 0.069
val_loss: 5.7548
val_acc: 0.4908


 19%|█▉        | 19/100 [06:30<28:23, 21.04s/it]


Step: 20, New Clients Test Loss: 9.7928,  New Clients Test Acc: 0.1320, Test Loss: 5.0320, Test Acc: 0.4847

check acc: 0.5526666666666666

new client acc: 0.132
test_avg_loss_new: 9.7928
test_avg_acc_new: 0.132
test_loss: 5.032
test_acc: 0.4847
val_avg_loss_new: 10.1745
val_avg_acc_new: 0.128
val_loss: 5.1478
val_acc: 0.4882


 29%|██▉       | 29/100 [11:10<34:48, 29.41s/it]


Step: 30, New Clients Test Loss: 5.3170,  New Clients Test Acc: 0.3480, Test Loss: 4.6924, Test Acc: 0.4926

check acc: 0.5526666666666666

new client acc: 0.348
test_avg_loss_new: 5.317
test_avg_acc_new: 0.348
test_loss: 4.6924
test_acc: 0.4926
val_avg_loss_new: 5.4367
val_avg_acc_new: 0.367
val_loss: 4.7947
val_acc: 0.4994


 39%|███▉      | 39/100 [15:58<13:19, 13.10s/it]


Step: 40, New Clients Test Loss: 4.6805,  New Clients Test Acc: 0.4100, Test Loss: 4.7765, Test Acc: 0.4845

check acc: 0.5526666666666666

new client acc: 0.41
test_avg_loss_new: 4.6805
test_avg_acc_new: 0.41
test_loss: 4.7765
test_acc: 0.4845
val_avg_loss_new: 4.7664
val_avg_acc_new: 0.4
val_loss: 4.7633
val_acc: 0.4931


 49%|████▉     | 49/100 [19:31<10:41, 12.57s/it]


Step: 50, New Clients Test Loss: 2.7460,  New Clients Test Acc: 0.5490, Test Loss: 4.7114, Test Acc: 0.4854

check acc: 0.5526666666666666

new client acc: 0.549
test_avg_loss_new: 2.746
test_avg_acc_new: 0.549
test_loss: 4.7114
test_acc: 0.4854
val_avg_loss_new: 2.7009
val_avg_acc_new: 0.56
val_loss: 4.7927
val_acc: 0.4969


 59%|█████▉    | 59/100 [23:02<10:23, 15.22s/it]


Step: 60, New Clients Test Loss: 1.9918,  New Clients Test Acc: 0.6400, Test Loss: 4.7545, Test Acc: 0.4841

check acc: 0.5526666666666666

new client acc: 0.64
test_avg_loss_new: 1.9918
test_avg_acc_new: 0.64
test_loss: 4.7545
test_acc: 0.4841
val_avg_loss_new: 1.8842
val_avg_acc_new: 0.66
val_loss: 4.7816
val_acc: 0.498


 59%|█████▉    | 59/100 [24:05<16:44, 24.51s/it]

convergence found
Global Test Acc Diff: 0.06856666666666666
Global Val Acc Diff: 0.05777777777777776
New Clients Test Acc: 0.64
New Clients Val Acc: 0.66





'pFedHN_embedtuning ()\n      just take the last check pint of the pFedHN run don on 1900 rounds and with 90 clients'

In [None]:
import argparse
import json
import logging
import random
from collections import OrderedDict, defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.utils.data
from tqdm import trange
import pickle
import glob
import pandas as pd

# Parse the defaults, then pin the settings of the run being inspected.
args = args_parser()
for _name, _value in {
    "iid": 0,
    "participation": 1,
    "algorithm": "pfedhn",
    "Nc": 1,
    "local_ep": 4,
    "checkpoint_resume": 1,
    "checkpoint_path": "/content/drive/MyDrive/MLDL/Cifar-100/Checkpoints",
}.items():
    setattr(args, _name, _value)

# Bind the requested GPU, if any, and derive the device string from it.
if args.gpu:
    torch.cuda.set_device(args.gpu)
    device = 'cuda'
else:
    device = 'cpu'

train_set, test_set, clients = get_dataset(args)

def get_class_distribution(indices, dataset):
    """Count how often each class label occurs among the given samples.

    Args:
        indices: iterable of positions into ``dataset.targets``.
        dataset: object exposing a ``targets`` sequence of class labels.

    Returns:
        A dict mapping each label to its number of occurrences, in order of
        first appearance. Labels that expose ``.item()`` (tensor or NumPy
        scalars) are converted to plain Python values first.
    """
    # Imported locally so the function does not depend on what the
    # surrounding cell happens to import.
    from collections import Counter

    def as_key(label):
        # Tensor elements hash by identity, so without unwrapping them every
        # sample would end up as its own key.
        return label.item() if hasattr(label, "item") else label

    return dict(Counter(as_key(dataset.targets[idx]) for idx in indices))

# Report, for every client, the label histogram of each of its three splits.
# The validation indices refer to the training dataset.
for client in clients:
    train_dist, val_dist, test_dist = (
        get_class_distribution(split_indices, split_dataset)
        for split_indices, split_dataset in (
            (client.train_indices, client.train_dataset),
            (client.val_indices, client.train_dataset),
            (client.test_indices, client.test_dataset),
        )
    )

    print(
        f"Client {client.client_id} class distribution:",
        f"  Train: {train_dist}",
        f"  Val: {val_dist}",
        f"  Test: {test_dist}",
        sep="\n",
    )

KeyboardInterrupt: 