In [1]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from DPKT import KnowledgeBuffer, LDP, calculate_beta, importance_sampling, predict, storage
from models.Update import CustomDatasetSplit, DatasetSplit
from models.resnet_client import resnet8, resnet16, resnet20
from models.test import test_img
from utils.options import args_parser
from utils.sampling import partition_data_dataset, partition_public

In [2]:
if __name__ == '__main__':
    # Parse a fixed argument list so the notebook runs without a command line.
    args = args_parser(args=['--epsilon','100', '--m_num','2', '--cache_size','160', '--dataset','mnist', '--momentum','0.9', '--alpha','10', 
                             '--epochs','50', '--gpu','0', '--public_data_ratio','0.1', '--lr','0.01'])

    # Use the requested GPU when CUDA is available and --gpu is not -1; otherwise the CPU.
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # args.seed was parsed but never applied. Seed the global numpy and torch RNGs
    # (torch.manual_seed also seeds all CUDA devices) so that model initialization,
    # DataLoader shuffling and any helper drawing from these global RNGs are
    # reproducible under Restart & Run All.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print('torch.cuda:',torch.cuda.is_available())
    print(args)

torch.cuda: True
Namespace(alpha=10.0, bs=1024, cache_size=160, combined_ep=4, dataset='mnist', device=device(type='cuda', index=0), epochs=50, epsilon=100.0, gpu=0, local_ep=5, lr=0.01, m_num=2, momentum=0.9, num_classes=10, num_users=10, public_data_ratio=0.1, seed=1, train_bs=8)


In [3]:
if __name__ == '__main__':
    # beta is later passed to LDP() and used to de-bias the aggregated predictions in
    # the training loop; it is derived from the privacy budget, m_num and the
    # class/client counts.
    beta = calculate_beta(
        args.epsilon,
        args.m_num,
        args.num_classes,
        args.num_users,
    )
    # Buffer handed to storage() every round; sized by --cache_size.
    knowledge_buffer = KnowledgeBuffer(args.cache_size)
    print(beta)

0.9626406528772272


In [None]:
# load dataset and split users
if __name__ == '__main__':
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)), 
                                                    transforms.Normalize((0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081))])
        dataset_train = datasets.MNIST('data/mnist/', train = True, download = False, transform = trans_mnist)
        dataset_test = datasets.MNIST('data/mnist/', train = False, download = False, transform = trans_mnist)

        dataset_train, dataset_public = partition_public(dataset_train, args.public_data_ratio)
        print('len(dataset_public): ', len(dataset_public))
        print('len(dataset_train): ', len(dataset_train))
        print('len(dataset_test): ', len(dataset_test))
        
        dataset_train_labels = np.array([])
        for i,(x, y) in enumerate(dataset_train):
            dataset_train_labels = np.append(dataset_train_labels, y)
        dataset_train_labels = dataset_train_labels.astype(int)

        dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)

    elif args.dataset == 'fashionmnist':
        trans_fashionmnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)), 
                                                    transforms.Normalize((0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081))])
        dataset_train = datasets.FashionMNIST('data/fashionmnist/', train = True, download = False, transform = trans_fashionmnist)
        dataset_test = datasets.FashionMNIST('data/fashionmnist/', train = False, download = False, transform = trans_fashionmnist)
    
        dataset_train, dataset_public = partition_public(dataset_train, args.public_data_ratio)
        print('len(dataset_public): ', len(dataset_public))
        print('len(dataset_train): ', len(dataset_train))
        print('len(dataset_test): ', len(dataset_test))
        
        dataset_train_labels = np.array([])
        for i,(x, y) in enumerate(dataset_train):
            dataset_train_labels = np.append(dataset_train_labels, y)
        dataset_train_labels = dataset_train_labels.astype(int)

        dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)

    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('data/cifar', train = True, download = False, transform = trans_cifar)
        dataset_test = datasets.CIFAR10('data/cifar', train = False, download = False, transform = trans_cifar)

        dataset_train, dataset_public = partition_public(dataset_train, args.public_data_ratio)
        print('len(dataset_public): ', len(dataset_public))
        print('len(dataset_train): ', len(dataset_train))
        print('len(dataset_test): ', len(dataset_test))
        
        dataset_train_labels = np.array([])
        for i,(x, y) in enumerate(dataset_train):
            dataset_train_labels = np.append(dataset_train_labels, y)
        dataset_train_labels = dataset_train_labels.astype(int)

        dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)


    elif args.dataset == 'cinic':
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        transform_cinic = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)
        ])
        cinic_directory = 'data/cinic'
        dataset_train = datasets.ImageFolder(
            os.path.join(cinic_directory, 'train'),
            transform=transform_cinic
        )
        dataset_valid = datasets.ImageFolder(
            os.path.join(cinic_directory, 'valid'),
            transform=transform_cinic
        )
        dataset_test = datasets.ImageFolder(
            os.path.join(cinic_directory, 'test'),
            transform=transform_cinic
        )
        dataset_train = torch.utils.data.ConcatDataset([dataset_train, dataset_valid])

        dataset_train, dataset_public = partition_public(dataset_train, args.public_data_ratio)
        print('len(dataset_public): ', len(dataset_public))
        print('len(dataset_train): ', len(dataset_train))
        print('len(dataset_test): ', len(dataset_test))
        
        dataset_train_labels = np.array([])
        for i,(x, y) in enumerate(dataset_train):
            dataset_train_labels = np.append(dataset_train_labels, y)
        dataset_train_labels = dataset_train_labels.astype(int)

        dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)


    print("num_users:", len(dict_users))
    img_size = dataset_train[0][0].shape
    print(img_size)

len(dataset_public):  6000
len(dataset_train):  54000
len(dataset_test):  10000
N = 54000
num_users: 10
torch.Size([3, 28, 28])


: 

In [None]:
# Data Distribution of Public Dataset
# Per-class sample counts of the public split. Labels are assumed to be ints in
# [0, args.num_classes); only the counts are needed, so no per-class index lists are
# built. (The previous version also rebound the builtin name `list` for the whole
# kernel session.)
public_labels = np.array([int(label) for _, label in dataset_public], dtype=int)
public_class_counts = np.bincount(public_labels, minlength=args.num_classes)
print(public_class_counts.tolist())

class_ids = np.arange(len(public_class_counts))
fig, ax = plt.subplots(figsize=(8, 6))
ax.bar(class_ids, public_class_counts, tick_label=class_ids, width=0.5)
ax.set_title("Data Distribution of Public Dataset")
ax.set_xlabel("Class")
ax.set_ylabel("Number of samples")
plt.show()

In [None]:
# Label Distribution of Different Clients
# One list per class holding, for every training sample of that class, the id of the
# client that owns it; a stacked histogram over client ids then shows each client's
# sample count broken down by class.
# dataset_train_labels[idx] is the label of dataset_train[idx] (collected in index
# order when the data was loaded), so no image needs to be loaded and transformed here.
client_ids_by_class = [[] for _ in range(args.num_classes)]
for c_id in dict_users:
    for idx in dict_users[c_id]:
        client_ids_by_class[dataset_train_labels[idx]].append(c_id)

num_clients = len(dict_users)
fig, ax = plt.subplots(figsize = (8, 6))
# Exactly one bin per client id 0..num_clients-1 (the old upper edge 10 + 1.5 added an
# empty extra bin).
ax.hist(client_ids_by_class, stacked = True,
        bins = np.arange(-0.5, num_clients + 0.5, 1),
        label = np.arange(args.num_classes), rwidth = 0.8)
ax.set_xticks(np.arange(num_clients))
ax.set_xticklabels(["Client %d" % c_id for c_id in range(num_clients)])
ax.set_xlabel("Client ID")
ax.set_ylabel("Number of samples")
ax.legend()
ax.set_title("Label Distribution of Different Clients")
plt.show()

In [None]:
# Initialize model
# Heterogeneous clients: client x gets resnet8 / resnet16 / resnet20 for x % 3 == 0 / 1 / 2.
# Each untrained model is evaluated once on the test set as a baseline.
client_architectures = (resnet8, resnet16, resnet20)
model_init = {}
acc_init_test = []
for x in range(args.num_users):
    build_model = client_architectures[x % len(client_architectures)]
    model_init[x] = build_model(args.num_classes).to(args.device)
    model_init[x].eval()
    acc_test = test_img(model_init[x], dataset_test, args)
    print("user-uid:", x, "init_Local_Training_accuracy: {:.2f}".format(acc_test))
    acc_init_test.append(acc_test.item())
print("mean AccTop1 on all clients:",float(np.mean(acc_init_test)))

In [None]:
# copy init_model_parameters
# Train on independent deep copies so model_init keeps the untrained weights.
# Iterates model_init itself rather than a hard-coded range(10).
model = {uid: copy.deepcopy(init_model) for uid, init_model in model_init.items()}
for uid, client_model in model.items():
    print("---------------------------------model[", uid, "]---------------------------------")
    print(client_model)

In [None]:
# Parameter counts of client 0's model: every parameter element, and those that
# receive gradients (requires_grad).
total_params = 0
trainable_params = 0
for param in model[0].parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

In [None]:
# pre-train
for uid in dict_users:
    model[uid].to(args.device)
    model[uid].train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model[uid].parameters(), lr = args.lr, momentum = args.momentum, weight_decay = 5e-4)
    ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[uid]), batch_size = 512, shuffle = True)

    for batch_idx, (images, labels) in enumerate(ldr_train):
        images, labels = images.to(args.device), labels.to(args.device)
        optimizer.zero_grad()
        log_probs = model[uid](images)
        loss = criterion(log_probs, labels)
        loss.backward()
        optimizer.step()

data_rounds = []

for round in range(args.epochs):

    acc_all_local=[]
    acc_all=[]

    '''
      Ensemble Knowledge Acquisition
    '''
    knowledge_transfer_data = []
    for uid in dict_users:
        transfer_data = importance_sampling(args.m_num, model[uid], dataset_public, args)
        knowledge_transfer_data.append(transfer_data)           # list
    knowledge_transfer_data = [item for sublist in knowledge_transfer_data for item in sublist]
    knowledge_transfer_data = [tensor.item() for tensor in knowledge_transfer_data]

    transfer_data_index = []
    for element in knowledge_transfer_data:
        transfer_data_index.append(element)
    transfer_data_index = np.array(transfer_data_index)           # np.array() [2010 2460 8597 9544 10616 3829 251 5537 6092 1105]

    # updata perturbed results and entropy information
    uploaded_perturbed_predictions = []
    uploaded_entropy_information = []

    # pre_model_parameters
    model_pre = {}
    for uid in dict_users:
        model_pre[uid] = copy.deepcopy(model[uid])

    '''
      Local_Training
    '''
    for uid in dict_users:
        model[uid].to(args.device)
        model[uid].train()
        criterionCE = nn.CrossEntropyLoss()
        criterionKL = nn.KLDivLoss(reduction='batchmean')
        optimizer = torch.optim.SGD(model[uid].parameters(), lr = args.lr, momentum = args.momentum, weight_decay = 5e-4)
        ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[uid]), batch_size = 8, shuffle = True)
        for _ in range(args.local_ep):
            for batch_idx, (images, labels) in enumerate(ldr_train):
                images, labels = images.to(args.device), labels.to(args.device)
                probs = model[uid](images)
                log_probs = F.log_softmax(probs, dim = 1)
                pre_probs = F.softmax(model_pre[uid](images), dim = 1)
                loss_hard = criterionCE(probs, labels)
                loss_pri = criterionKL(log_probs, pre_probs)
                loss = loss_hard + (1 + 0.04 * (round + 1)) * loss_pri
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Computing local model predictions and information entropy
        knowledge_transfer_images = []
        for idx in transfer_data_index:
            image = dataset_public[idx]
            knowledge_transfer_images.append(image)
        local_predictions, entropy_information = predict(model[uid], knowledge_transfer_images, uid, args)                           # local_predictions、entropy_information

        # Perturb local model predictions
        perturbed_local_predictions = LDP(local_predictions, beta, args.num_classes, args).tolist()                   # one-hot
        uploaded_perturbed_predictions.append(perturbed_local_predictions)
        uploaded_entropy_information.append(entropy_information.tolist())

    uploaded_perturbed_predictions = np.array(uploaded_perturbed_predictions, dtype = float)
    
    
    '''
      Entropy-Aware Aggregation
    '''
    entropies_array = np.array(uploaded_entropy_information)
    sorted_indices = np.argsort(entropies_array, axis = 0)
    k = 7
    top_k_indices = sorted_indices[:k, :]
    average_predictions = []
    for i in range(len(uploaded_perturbed_predictions[0])):
        top_k_predictions = [uploaded_perturbed_predictions[j][i] for j in top_k_indices[:, i]]
        average_prediction = np.mean(top_k_predictions, axis = 0)
        average_predictions.append(average_prediction)
    average_predictions = np.array(average_predictions)

    # Knolwedge aggergation
    # uploaded_perturbed_predictions.mean(axis = 0)
    aggregated_predictions = (average_predictions - (1 - beta) / args.num_classes) / beta
    aggregated_predictions = torch.tensor(aggregated_predictions)
    aggregated_predictions = torch.argmax(aggregated_predictions, dim = 1)

    # Store distillation knowledge data in the distillation knowledge cache
    Fine_tuning_knowledge_transfer_data, Fine_tuning_aggregated_predictions = storage(transfer_data_index, aggregated_predictions, knowledge_buffer)
    Fine_tuning_knowledge_transfer_data = Fine_tuning_knowledge_transfer_data.tolist()
    Fine_tuning_aggregated_predictions = Fine_tuning_aggregated_predictions.tolist()

    # pre_model_parameters
    model_local = {}
    for uid in dict_users:
        model_local[uid] = copy.deepcopy(model[uid])
    

    labels = []
    for i in Fine_tuning_knowledge_transfer_data:
        labels.append(dataset_public[i][1])
    data_rounds.append(labels)

    '''
      Combined_Distillation_Training
    '''
    for uid in dict_users:
        # Local_Training
        model[uid].to(args.device)
        model[uid].train()
        Fine_criterion = nn.CrossEntropyLoss()
        criterionKL = nn.KLDivLoss(reduction = 'batchmean')
        Fine_optimizer = torch.optim.SGD(model[uid].parameters(), lr = args.lr, momentum = args.momentum, weight_decay = 5e-4)
        Fine_train = DataLoader(CustomDatasetSplit(dataset_public, Fine_tuning_knowledge_transfer_data, Fine_tuning_aggregated_predictions), 
                                    batch_size = 8, shuffle = True)
        for _ in range(args.combined_ep):
            for batch_idx, (images, labels) in enumerate(Fine_train):
                images, labels = images.to(args.device), labels.to(args.device)
                probs = model[uid](images)
                log_probs = F.log_softmax(probs, dim = 1)
                local_probs = F.softmax(model_local[uid](images), dim = 1)
                loss_local_hard = Fine_criterion(probs, labels)
                loss_local_pri = criterionKL(log_probs, local_probs)
                loss_local = 1 * loss_local_hard + 3 * loss_local_pri
                Fine_optimizer.zero_grad()
                loss_local.backward()
                Fine_optimizer.step()                         
        
        model[uid].eval()
        acc_fine_test = test_img(model[uid], dataset_test, args)
        # print("round:",round,"uid:",uid,"Fine_Train Testing accuracy: {:.2f}".format(acc_fine_test))
        acc_all.append(acc_fine_test.item())
    
    args.lr = args.lr * (1 - round / args.epochs * 0.9)
    print(round + 1, ":", "mean AccTop1 on all clients:", float(np.mean(np.array(acc_all))))

In [None]:
# Draw the image of the n-th sample
# The stored tensor is normalized (values can be negative or above 1), and imshow clips
# float RGB data to [0, 1], which distorts the picture; min-max rescale for display only.
sample_idx = 3135
image, label = dataset_public[sample_idx]

image_np = image.numpy().transpose(1, 2, 0)   # (C, H, W) -> (H, W, C) as imshow expects
lo, hi = image_np.min(), image_np.max()
image_np = (image_np - lo) / (hi - lo) if hi > lo else np.zeros_like(image_np)

fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(image_np)
ax.set_title(f'Label: {label}')
plt.show()

In [None]:
# Draw the image of the n-th sample: channel 0 only, rendered as a grayscale map
# (imshow auto-scales single-channel data, so normalization needs no undoing here).
image, label = dataset_public[2473]
image = image[0]

fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(image.squeeze(), cmap='gray')
ax.set_title(f'Label: {label}')
plt.show()