In [41]:
import sys
import os
import argparse

# 添加环境
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../MyExpr")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../FedML")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))

print(sys.path)

['/home/guest', '/home/guest/FedML', '/home/guest/MyExpr', '/home/guest', '/home/guest/FedML', '/home/guest/MyExpr', '/home/guest/Fed_Expr/MyExpr', '/home/guest/miniconda/envs/fedml/lib/python37.zip', '/home/guest/miniconda/envs/fedml/lib/python3.7', '/home/guest/miniconda/envs/fedml/lib/python3.7/lib-dynload', '', '/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages', '/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages/IPython/extensions', '/home/guest/.ipython']


In [42]:
import torch
import os.path
from torchvision.datasets import utils, MNIST, CIFAR10
from torchvision import transforms
from torch.utils.data import Subset, DataLoader
from PIL import Image

num_workers = 2

class FEMNIST(MNIST):
    """
    This dataset is derived from the Leaf repository
    (https://github.com/TalwalkarLab/leaf) pre-processing of the Extended MNIST
    dataset, grouping examples by writer. Details about Leaf were published in
    "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.
    """
    resources = [
        ('https://raw.githubusercontent.com/tao-shen/FEMNIST_pytorch/master/femnist.tar.gz',
         '59c65cec646fc57fe92d27d83afdf0ed')]

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file

        self.data, self.targets, self.users_index = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        img = Image.fromarray(img.numpy(), mode='F')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def download(self):
        """Download the FEMNIST data if it doesn't exist in processed_folder already."""
        import shutil

        if self._check_exists():
            return

        utils.makedir_exist_ok(self.raw_folder)
        utils.makedir_exist_ok(self.processed_folder)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            utils.download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')
        shutil.move(os.path.join(self.raw_folder, self.training_file), self.processed_folder)
        shutil.move(os.path.join(self.raw_folder, self.test_file), self.processed_folder)


def Dataset(args):
    trainset, testset = None, None

    if args.dataset == 'cifar10':
        tra_trans = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        val_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainset = CIFAR10(root="./data", train=True, download=True, transform=tra_trans)
        testset = CIFAR10(root="./data", train=False, download=True, transform=val_trans)

    if args.dataset == 'femnist' or 'mnist':
        tra_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        val_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        if args.dataset == 'femnist':
            trainset = FEMNIST(root='./data', train=True, download=True, transform=tra_trans)
            testset = FEMNIST(root='./data', train=False, download=True, transform=val_trans)
        if args.dataset == 'mnist':
            trainset = MNIST(root='./data', train=True, download=True, transform=tra_trans)
            testset = MNIST(root='./data', train=False, download=True, transform=val_trans)

    return trainset, testset


class Data(object):

    def __init__(self, args):
        self.args = args
        self.trainset, self.testset = None, None

        trainset, testset = Dataset(args)
        # print(trainset, type(trainset))
        # print(testset, type(testset))
        # 标签
        # print(trainset.targets, testset.targets)
        
        
        # 均等分训练数据集
        num_train = [int(len(trainset) / args.split) for _ in range(args.split)]
        # print("num_train:", num_train)

        # 求前缀和
        cumsum_train = torch.tensor(list(num_train)).cumsum(dim=0).tolist()
        # print("cumsum_train:", cumsum_train)
        
        # # 将trainset的下标按标签进行排序，然后放到idx_train， 下同
        idx_train = sorted(range(len(trainset.targets)), key=lambda k: trainset.targets[k])  #split by class
        # 将trainset的下标抽出放到idx_train，下同
        # idx_train = range(len(trainset.targets))
        
        # 将划分好的下标，按照num_train划分成不同子集，【前缀和-num_train: 前缀和】
        splited_trainset = [Subset(trainset, idx_train[off - l:off]) for off, l in zip(cumsum_train, num_train)]
        
        num_test = [int(len(testset) / args.split) for _ in range(args.split)]
        cumsum_test = torch.tensor(list(num_test)).cumsum(dim=0).tolist()
        
        idx_test = sorted(range(len(testset.targets)), key=lambda k: testset.targets[k])  #split by class        
        
        # idx_test = range(len(testset.targets))
        
        splited_testset = [Subset(testset, idx_test[off - l:off]) for off, l in zip(cumsum_test, num_test)]
        # print(splited_testset)
        
        self.train_all = DataLoader(trainset, batch_size=args.batchsize, shuffle=False, num_workers=num_workers)
        self.test_all = DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=num_workers)
        
        self.train_loader = [DataLoader(splited_trainset[i], batch_size=args.batchsize, shuffle=True, num_workers=num_workers)
                             for i in range(args.node_num)]
        self.test_loader = [DataLoader(splited_testset[i], batch_size=args.batchsize, shuffle=False, num_workers=num_workers)
                            for i in range(args.node_num)]
        # 重复test_all>>>
        # self.test_loader = DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=num_workers)

In [44]:
parser = argparse.ArgumentParser()
parser.add_argument('--split', type=int, default=5,
                        help='data split')
parser.add_argument('--batchsize', type=int, default=128,
                        help='batchsize')
parser.add_argument('--node_num', type=int, default=5,
                    help='Number of nodes')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='datasets: {cifar100, cifar10, femnist, mnist}')
args = parser.parse_known_args()[0]

data = Data(args)

Files already downloaded and verified
Files already downloaded and verified


In [49]:
# 总体数据统计
keys = [i for i in range(10)]
# {0: 5000, 1: 5000, 2: 5000, 3: 5000, 4: 5000, 5: 5000, 6: 5000, 7: 5000, 8: 5000, 9: 5000}
ts = data.train_all
#{0: 1000, 1: 1000, 2: 1000, 3: 1000, 4: 1000, 5: 1000, 6: 1000, 7: 1000, 8: 1000, 9: 1000}
# ts = data.test_all
dic = {}
for idx, (x, y) in enumerate(ts):
    for key in keys:
        key_sum = (y == key).sum().item()
        if key in dic:
            dic[key] += key_sum
        else:
            dic[key] = key_sum
print(dic)

{0: 5000, 1: 5000, 2: 5000, 3: 5000, 4: 5000, 5: 5000, 6: 5000, 7: 5000, 8: 5000, 9: 5000}


In [50]:
# 数据划分统计
ts = data.train_loader
for i in range(args.node_num):
    dic = {}
    for idx, (x, y) in enumerate(ts[i]):
        for key in keys:
            key_sum = (y == key).sum().item()
            # print(key, key_sum)
        # break
            if key in dic:
                dic[key] += key_sum
            else:
                dic[key] = key_sum
    print("client", i, ":", dic)

# 

client 0 : {0: 5000, 1: 5000, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0}
client 1 : {0: 0, 1: 0, 2: 5000, 3: 5000, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0}
client 2 : {0: 0, 1: 0, 2: 0, 3: 0, 4: 5000, 5: 5000, 6: 0, 7: 0, 8: 0, 9: 0}
client 3 : {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 5000, 7: 5000, 8: 0, 9: 0}
client 4 : {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 5000, 9: 5000}


In [None]:
s = torch.rand(1)
print(s)
print(s.item())

tensor([0.7148])
0.7147847414016724
