In [None]:
import copy
import numpy as np   
import torch
import torch.nn.functional as F
import os
import hnswlib

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils.sampling import partition_data_dataset
from utils.options import args_parser
from models.Update import DatasetSplit
from models.test import test_img
from models.resnet_client import resnet20, resnet16, resnet8
from torchvision.models import mobilenet_v3_small

In [None]:
if __name__ == '__main__':
    # Experiment configuration, passed to the project's argument parser as if
    # it had been typed on the command line.
    cli_overrides = [
        '--dataset', 'cinic',
        '--momentum', '0.9',
        '--alpha', '10',
        '--epochs', '50',
        '--gpu', '0',
        '--lr', '0.01',
    ]
    args = args_parser(args=cli_overrides)

    # Use the requested GPU when CUDA is available and a GPU was requested
    # (gpu != -1); otherwise fall back to the CPU.
    if torch.cuda.is_available() and args.gpu != -1:
        args.device = torch.device('cuda:{}'.format(args.gpu))
    else:
        args.device = torch.device('cpu')
    print('torch.cuda:',torch.cuda.is_available())
    print(args)

In [None]:
# load dataset and split users
# No Public Data Partition

def collect_dataset_labels(dataset):
    """Return the integer class label of every sample in ``dataset`` as a 1-D array.

    Labels are read from the dataset's ``targets`` attribute when it exists, so
    no image has to be loaded or transformed. A ``ConcatDataset`` is handled by
    concatenating the labels of its parts in order. Datasets without
    ``targets`` (or with a ``target_transform``, whose effect ``targets`` would
    not reflect) are iterated once and the second element of each item is used.

    Args:
        dataset: an indexable/iterable dataset yielding ``(sample, label)`` pairs.

    Returns:
        ``numpy.ndarray`` of integer labels, one per sample, in dataset order.
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        return np.concatenate([collect_dataset_labels(part) for part in dataset.datasets])

    targets = getattr(dataset, 'targets', None)
    if targets is not None and getattr(dataset, 'target_transform', None) is None:
        if torch.is_tensor(targets):
            targets = targets.cpu().numpy()
        return np.asarray(targets).astype(int).reshape(-1)

    # Fallback: a single pass collecting into a list (linear, unlike repeated np.append).
    return np.array([int(y) for _, y in dataset], dtype=int)


if __name__ == '__main__':
    if args.dataset == 'mnist':
        # Grayscale images are expanded to 3 channels.
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                                    transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))])
        dataset_train = datasets.MNIST('data/mnist/', train = True, download = False, transform=trans_mnist)
        dataset_test = datasets.MNIST('data/mnist/', train = False, download = False, transform=trans_mnist)

    elif args.dataset == 'fashionmnist':
        # NOTE(review): these normalisation constants are the same ones used for MNIST above — confirm intended.
        trans_fashionmnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                                    transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))])
        dataset_train = datasets.FashionMNIST('data/fashionmnist/', train = True, download = False, transform = trans_fashionmnist)
        dataset_test = datasets.FashionMNIST('data/fashionmnist/', train = False, download = False, transform = trans_fashionmnist)

    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('data/cifar', train = True, download = False, transform = trans_cifar)
        dataset_test = datasets.CIFAR10('data/cifar', train = False, download = False, transform = trans_cifar)

    elif args.dataset == 'cinic':
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        transform_cinic = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean = cinic_mean, std = cinic_std)
        ])
        cinic_directory = 'data/cinic'
        dataset_train = datasets.ImageFolder(
            os.path.join(cinic_directory, 'train'),
            transform=transform_cinic
        )
        dataset_valid = datasets.ImageFolder(
            os.path.join(cinic_directory, 'valid'),
            transform=transform_cinic
        )
        dataset_test = datasets.ImageFolder(
            os.path.join(cinic_directory, 'test'),
            transform=transform_cinic
        )
        # The 'train' and 'valid' folders are merged into one training set.
        dataset_train = torch.utils.data.ConcatDataset([dataset_train, dataset_valid])

    else:
        # Fail here with a clear message instead of a NameError on dict_users below.
        raise ValueError('unrecognized dataset: {}'.format(args.dataset))

    print('len(dataset_train): ', len(dataset_train))
    print('len(dataset_test): ', len(dataset_test))

    # Partition the training indices across 10 clients using the label array and args.alpha.
    dataset_train_labels = collect_dataset_labels(dataset_train)
    dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)

    print("num_users:", len(dict_users))
    img_size = dataset_train[0][0].shape
    print(img_size)

In [None]:
# Initialize model
# Client uid selects the architecture by uid % 3: 0 -> resnet8, 1 -> resnet16, 2 -> resnet20.
model_init = {}
acc_init_test = []
architectures = (resnet8, resnet16, resnet20)
for uid in range(10):
    build_model = architectures[uid % 3]
    client_model = build_model(10).to(args.device)
    model_init[uid] = client_model
    client_model.eval()
    # Evaluate the freshly initialised (untrained) model on the test set.
    acc_test = test_img(client_model, dataset_test, args)
    print("user-uid:", uid, "init_Local_Training_accuracy: {:.2f}".format(acc_test))
    acc_init_test.append(acc_test.item())
print("mean AccTop1 on all clients:",float(np.mean(np.array(acc_init_test))))

In [None]:
# copy init_model_parameters
# Each client trains an independent deep copy, so model_init keeps the untouched initial weights.
model = {client_id: copy.deepcopy(model_init[client_id]) for client_id in range(10)}
for client_id, client_model in model.items():
    print("---------------------------------model[", client_id, "]---------------------------------")
    print(client_model)

In [None]:
class KnowledgeCache:
    """Per-class store of sample knowledge plus a nearest-neighbour relation between samples.

    Attributes:
        n_classes: number of classes; also the length of the zero vector stored
            for a newly registered sample.
        R: number of neighbours kept per sample by ``build_relation``.
        cache: ``{class_label: {sample_idx: knowledge_tensor}}``.
        idx_to_hash: ``{sample_idx: hash_tensor}`` used to build the relation.
        relation: ``{sample_idx: [neighbour_sample_idx, ...]}``; neighbours are
            always drawn from the same class as the sample.
    """

    def __init__(self,n_classes, R):
        self.n_classes = n_classes
        self.R = R
        self.cache = {c: {} for c in range(n_classes)}
        self.idx_to_hash = {}
        self.relation = {}

    def add_hash(self, hash, label,idx):
        """Register several samples at once; the three sequences are consumed in parallel."""
        for k_,l_,i_ in zip(hash, label, idx):
            self.add_hash_single(k_, l_, i_)

    def add_hash_single(self,hash,label,idx):
        """Register one sample: store its hash and initialise its knowledge to zeros."""
        self.cache[int(label)][idx] = torch.zeros(self.n_classes)
        self.idx_to_hash[idx] = hash

    # Approximate nearest neighbor search (ANN) finds semantically similar neighbors for each data sample
    def build_relation(self):
        """Fill ``self.relation`` with up to ``R`` same-class neighbours per sample.

        For every class an hnswlib index (cosine space) is built over the hashes
        of that class and queried with the same hashes. The sample itself is
        removed from its own result explicitly rather than by assuming it is the
        first hit. Classes with no samples are skipped; a class with a single
        sample gets an empty neighbour list; classes with fewer than ``R + 1``
        samples get as many neighbours as exist.
        """
        for c in range(self.n_classes):
            idx_vectors = list(self.cache[c].keys())
            num_elements = len(idx_vectors)
            if num_elements == 0:
                continue
            if num_elements == 1:
                # Nothing to compare against; avoid building an index for one point.
                self.relation[idx_vectors[0]] = []
                continue

            data = np.array([self.idx_to_hash[key].numpy() for key in idx_vectors])
            dim = data.shape[1]
            # Never ask for more results than the index holds.
            k = min(self.R + 1, num_elements)

            index = hnswlib.Index(space='cosine', dim=dim)
            index.init_index(max_elements=num_elements, ef_construction = 1000, M = 64)
            index.add_items(data, np.arange(num_elements))
            index.set_ef(max(1000, k))
            labels, _ = index.knn_query(data, k)

            for row, neighbours in enumerate(labels):
                others = [idx_vectors[int(j)] for j in neighbours if int(j) != row]
                # If the query itself was not among the hits there is one extra entry; trim it.
                self.relation[idx_vectors[row]] = others[:self.R]

    def set_knowledge(self,knowledge,label,idx):
        """Overwrite the stored knowledge of several samples (sequences consumed in parallel)."""
        for k_, l_, i_ in zip(knowledge,label,idx):
            self.set_knowledge_single(k_, l_, i_)

    def set_knowledge_single(self,knowledge,label,idx):
        """Overwrite the stored knowledge of one sample."""
        self.cache[int(label)][idx] = knowledge

    def fetch_knowledge(self,label,idx):
        """Return, for each ``(label, idx)`` pair, the list from ``fetch_knowledge_single``."""
        return [self.fetch_knowledge_single(l_, i_) for l_, i_ in zip(label, idx)]

    def fetch_knowledge_single(self, label, idx):
        """Return the stored knowledge of every neighbour of sample ``idx``.

        Raises:
            KeyError: if ``idx`` has no entry in ``self.relation`` (e.g.
                ``build_relation`` has not been run since it was added).
        """
        class_cache = self.cache[int(label)]
        return [class_cache[neighbour] for neighbour in self.relation[idx]]

def knowledge_avg_single(knowledge, weights):
    """Weighted mean of a sequence of tensors, returned as a new detached CPU tensor.

    ``knowledge`` and ``weights`` are paired with ``zip``, so surplus entries in
    the longer sequence are ignored. The result is the weighted sum divided by
    the sum of the weights that were actually paired.

    Args:
        knowledge: non-empty sequence of same-shaped tensors.
        weights: sequence of numeric weights.

    Returns:
        A tensor with the shape of ``knowledge[0]``, on the CPU, not attached to
        any autograd graph.
    """
    accumulator = torch.zeros_like(knowledge[0]).cpu()
    weight_total = 0
    for tensor, weight in zip(knowledge, weights):
        accumulator += tensor.cpu() * weight
        weight_total += weight
    averaged = accumulator / weight_total
    # Round-trip through numpy to hand back an independent copy.
    return torch.tensor(averaged.detach().cpu().numpy())

class KL_Loss(nn.Module):
    """Temperature-scaled KL-divergence loss between student and teacher logits.

    Both inputs are divided by the temperature ``T`` before (log-)softmax along
    dim 1. A constant of 1e-7 is added to the teacher probabilities, and the
    batch-mean KL divergence is multiplied by ``T * T``.
    """

    def __init__(self, temperature = 3.0):
        super().__init__()
        self.T = temperature

    def forward(self, output_batch, teacher_outputs):
        """Return the scalar loss for student logits ``output_batch`` against ``teacher_outputs``."""
        temperature = self.T
        student_log_probs = F.log_softmax(output_batch / temperature, dim = 1)
        teacher_probs = F.softmax(teacher_outputs / temperature, dim = 1) + 10 ** (-7)
        divergence = F.kl_div(student_log_probs, teacher_probs, reduction = 'batchmean')
        return temperature * temperature * divergence

# Inputs are resized to 224 before being passed to the ImageNet-pretrained encoder.
image_scaler = transforms.Compose([transforms.Resize(224),])
criterion_KL = KL_Loss()
print("*********start training with FedCache***************")

# Materialise every client's batches once. The per-sample cache keys
# (client_index, running sample index) rely on this fixed batch order being
# reused both when hashing and in every training round.
train_data_local_dict_seq = {}
for client_index in range(len(dict_users)):
    client_loader = DataLoader(DatasetSplit(dataset_train, dict_users[client_index]), batch_size = 256)
    train_data_local_dict_seq[client_index] = [(images, labels) for images, labels in client_loader]

NUM_NEIGHBOURS = 16      # neighbours kept per sample in the knowledge cache
KD_LOSS_WEIGHT = 1.5     # weight of the distillation term in the total loss
knowledge_cache = KnowledgeCache(10, NUM_NEIGHBOURS)                # class_num = 10

# Pretrained MobileNetV3-small with its last child module removed, used only to
# produce one hash vector per sample for the neighbour search.
encoder = mobilenet_v3_small(weights='IMAGENET1K_V1').to(args.device)
encoder = torch.nn.Sequential( *( list(encoder.children())[:-1] ) )
encoder.eval()
for client_index in range(len(dict_users)):
    cur_idx = 0
    for images, labels in train_data_local_dict_seq[client_index]:
        # Inference only: skip autograd bookkeeping for the encoder pass.
        with torch.no_grad():
            hash_code = encoder(image_scaler(images.to(args.device))).cpu()
        # Keep (batch, channels); assumes the trailing spatial dims are 1x1 — TODO confirm.
        hash_code = hash_code.reshape((hash_code.shape[0], hash_code.shape[1]))
        for sample_hash, label in zip(hash_code, labels):
            knowledge_cache.add_hash_single(sample_hash, label, (client_index, cur_idx))
            cur_idx = cur_idx + 1
knowledge_cache.build_relation()
print("*********knowledge cache initialized successfully***************")

for global_epoch in range(args.epochs):
    acc_all = []
    print("*********communication round", global_epoch, "***************")
    for client_index in range(len(dict_users)):
        model[client_index].to(args.device)
        model[client_index].train()
        optim = torch.optim.SGD(model[client_index].parameters(), lr = args.lr, momentum = 0.9, weight_decay = 5e-4)
        cur_idx = 0
        for images, labels in train_data_local_dict_seq[client_index]:
            labels = torch.as_tensor(labels, dtype = torch.long)
            # Plain ints for cache lookups, taken before the labels move to the device.
            label_ids = labels.tolist()
            images, labels = images.to(args.device), labels.to(args.device)

            log_probs = model[client_index](images)
            loss_true = F.cross_entropy(log_probs, labels)

            # Cache detached CPU copies: the cache must not keep the autograd
            # graph (or device memory) of every sample alive between rounds.
            cached_logits = log_probs.detach().cpu()
            teacher_knowledge = []
            for logit, label in zip(cached_logits, label_ids):
                sample_key = (client_index, cur_idx)
                # Fetch before storing, so a sample's own fresh logits are
                # written only after its neighbours' knowledge has been read.
                neighbour_logits = knowledge_cache.fetch_knowledge_single(label, sample_key)
                knowledge_cache.set_knowledge_single(logit, label, sample_key)
                cur_idx = cur_idx + 1
                if neighbour_logits:
                    teacher = knowledge_avg_single(neighbour_logits, [1] * len(neighbour_logits))
                else:
                    # No neighbours available: use the sample's own logits as teacher.
                    teacher = logit
                teacher_knowledge.append(teacher)
            teacher_knowledge = torch.stack(teacher_knowledge).to(args.device)

            loss_kd = criterion_KL(log_probs, teacher_knowledge)
            loss = loss_true + KD_LOSS_WEIGHT * loss_kd
            optim.zero_grad()
            loss.backward()
            optim.step()
        model[client_index].eval()
        acc_fine_test = test_img(model[client_index], dataset_test, args)
        acc_all.append(acc_fine_test.item())
    print("communication round:", global_epoch, "mean Fine_Test/AccTop1 on all clients:", float(np.mean(np.array(acc_all))))