In [1]:
import sys
import os
import argparse

# Make the sibling project checkouts importable.
# The entries are prepended as one slice, listed in the order they should
# appear at the front of sys.path (repository root first, then FedML, then
# MyExpr).
sys.path[:0] = [
    os.path.abspath(os.path.join(os.getcwd(), rel_dir))
    for rel_dir in ("../../", "../../FedML", "../../MyExpr")
]

print(sys.path)

['/home/ubuntu', '/home/ubuntu/FedML', '/home/ubuntu/MyExpr', '/home/ubuntu/Fed_Expr/MyExpr', '/home/ubuntu/miniconda/envs/fedml/lib/python37.zip', '/home/ubuntu/miniconda/envs/fedml/lib/python3.7', '/home/ubuntu/miniconda/envs/fedml/lib/python3.7/lib-dynload', '', '/home/ubuntu/miniconda/envs/fedml/lib/python3.7/site-packages', '/home/ubuntu/miniconda/envs/fedml/lib/python3.7/site-packages/IPython/extensions', '/home/ubuntu/.ipython']


In [27]:
import torch
import os.path
from torchvision.datasets import utils, MNIST, CIFAR10
from torchvision import transforms
from torch.utils.data import Subset, DataLoader
from PIL import Image

num_workers = 2

class FEMNIST(MNIST):
    """
    Federated Extended MNIST (FEMNIST) dataset.

    This dataset is derived from the Leaf repository
    (https://github.com/TalwalkarLab/leaf) pre-processing of the Extended MNIST
    dataset, grouping examples by writer. Details about Leaf were published in
    "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.

    Args:
        root: Root directory handed to the torchvision base class; the raw and
            processed folders are resolved by the parent class.
        train: If True load ``self.training_file``, otherwise ``self.test_file``.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        target_transform: Optional callable applied to the integer label.
        download: If True, call :meth:`download` before loading.

    Raises:
        RuntimeError: If the processed files are still missing after the
            optional download step.
    """
    # (url, md5) pairs of the archives fetched by download().
    resources = [
        ('https://raw.githubusercontent.com/tao-shen/FEMNIST_pytorch/master/femnist.tar.gz',
         '59c65cec646fc57fe92d27d83afdf0ed')]

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        # super(MNIST, self) starts the lookup *after* MNIST in the MRO, so
        # MNIST.__init__ is deliberately bypassed; this class does its own
        # file loading below.
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        data_file = self.training_file if self.train else self.test_file

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only point `root` at archives from a trusted source.
        self.data, self.targets, self.users_index = torch.load(
            os.path.join(self.processed_folder, data_file))

    def _check_exists(self):
        """Return True when both processed files loaded by ``__init__`` exist.

        Defined here so the check looks at exactly the files this class reads
        and writes, rather than relying on whatever the parent class checks.
        """
        return all(
            os.path.exists(os.path.join(self.processed_folder, name))
            for name in (self.training_file, self.test_file)
        )

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        The stored sample is converted to a PIL image before ``transform`` is
        applied; ``target`` is converted to a Python ``int`` before
        ``target_transform`` is applied.
        """
        img, target = self.data[index], int(self.targets[index])
        # mode='F' is PIL's 32-bit floating point mode
        # (assumes the stored samples are float32 arrays - TODO confirm).
        img = Image.fromarray(img.numpy(), mode='F')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def download(self):
        """Download the FEMNIST data if it doesn't exist in processed_folder already."""
        import shutil

        if self._check_exists():
            return

        # os.makedirs(exist_ok=True) is the standard-library equivalent of the
        # torchvision `makedir_exist_ok` helper and does not depend on that
        # helper being present in the installed torchvision version.
        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files (the md5 is passed along with each archive)
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            utils.download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # The extracted files are used as-is: just move them next to where
        # __init__ expects to find them.
        print('Processing...')
        shutil.move(os.path.join(self.raw_folder, self.training_file), self.processed_folder)
        shutil.move(os.path.join(self.raw_folder, self.test_file), self.processed_folder)


def Dataset(args):
    """Build the train and test datasets selected by ``args.dataset``.

    Supported names are ``'cifar10'``, ``'femnist'`` and ``'mnist'``. All
    datasets are read from ``./data``; nothing is downloaded here.

    Args:
        args: Any object with a ``dataset`` attribute holding the dataset name.

    Returns:
        A ``(trainset, testset)`` tuple. Both entries are ``None`` when
        ``args.dataset`` is not one of the supported names.
    """
    trainset, testset = None, None

    if args.dataset == 'cifar10':
        # Training uses random crop + horizontal flip augmentation; evaluation
        # only converts to a tensor and normalizes. The Normalize arguments
        # are per-channel (mean, std) - presumably the dataset statistics.
        tra_trans = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        val_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainset = CIFAR10(root="./data", train=True, download=False, transform=tra_trans)
        testset = CIFAR10(root="./data", train=False, download=False, transform=val_trans)

    # A membership test is required here: `x == 'femnist' or 'mnist'` is
    # always truthy because the non-empty string 'mnist' is itself truthy.
    elif args.dataset in ('femnist', 'mnist'):
        # Both splits share the same deterministic pipeline: pad 2 pixels on
        # every side by replicating the edge, convert to tensor, normalize.
        tra_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        val_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        dataset_cls = FEMNIST if args.dataset == 'femnist' else MNIST
        trainset = dataset_cls(root='./data', train=True, transform=tra_trans)
        testset = dataset_cls(root='./data', train=False, transform=val_trans)

    return trainset, testset


class Data(object):
    """Per-node DataLoaders over equal, contiguous shards of a dataset.

    The train and test sets returned by ``Dataset(args)`` are each cut into
    ``args.split`` contiguous shards of ``len(dataset) // args.split`` samples
    (any remainder at the end is not assigned to a shard). The first
    ``args.node_num`` shards are wrapped in DataLoaders.

    Attributes:
        args: The configuration object passed to the constructor.
        trainset, testset: The full datasets returned by ``Dataset(args)``.
        test_all: DataLoader over the complete test set (not shuffled).
        train_loader: List of ``args.node_num`` shuffled DataLoaders, one per
            training shard.
        test_loader: List of ``args.node_num`` unshuffled DataLoaders, one per
            test shard.
    """

    def __init__(self, args):
        """Load the datasets and build the loaders.

        Args:
            args: Object providing ``dataset``, ``split``, ``node_num`` and
                ``batchsize`` attributes.

        Raises:
            IndexError: If ``args.node_num`` is larger than ``args.split``,
                i.e. more nodes are requested than shards exist.
        """
        self.args = args
        self.trainset, self.testset = None, None

        # Fail fast, before any data is loaded: each node is served by the
        # shard with the same index, so there must be at least as many shards
        # as nodes.
        if args.node_num > args.split:
            raise IndexError(
                'node_num ({}) must not exceed split ({}): '
                'each node needs its own data shard'.format(args.node_num, args.split))

        trainset, testset = Dataset(args)
        # Keep the full datasets reachable from the instance.
        self.trainset, self.testset = trainset, testset

        # Split both sets evenly, in stored order. To shard by class instead,
        # sort the indices by label first, e.g.
        # sorted(range(len(ds.targets)), key=lambda k: ds.targets[k]).
        splited_trainset = self._split_evenly(trainset, args.split)
        splited_testset = self._split_evenly(testset, args.split)

        self.test_all = DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=num_workers)
        self.train_loader = [DataLoader(splited_trainset[i], batch_size=args.batchsize, shuffle=True, num_workers=num_workers)
                             for i in range(args.node_num)]
        self.test_loader = [DataLoader(splited_testset[i], batch_size=args.batchsize, shuffle=False, num_workers=num_workers)
                            for i in range(args.node_num)]

    @staticmethod
    def _split_evenly(dataset, num_shards):
        """Cut ``dataset`` into ``num_shards`` contiguous, equally sized Subsets."""
        if num_shards <= 0:
            return []
        shard_len = len(dataset) // num_shards
        indices = range(len(dataset.targets))
        # Shard i covers [i * shard_len, (i + 1) * shard_len).
        return [Subset(dataset, indices[i * shard_len:(i + 1) * shard_len])
                for i in range(num_shards)]

In [28]:
# Experiment configuration, declared as (flag, type, default, help) rows and
# registered on the parser in this order.
parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in (
    ('--split', int, 5, 'data split'),
    ('--batchsize', int, 128, 'batchsize'),
    ('--node_num', int, 5, 'Number of nodes'),
    ('--dataset', str, 'cifar10', 'datasets: {cifar100, cifar10, femnist, mnist}'),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
del _flag, _type, _default, _help

# parse_known_args returns (namespace, leftovers); only the namespace is kept,
# so unrecognised command-line entries are ignored instead of raising
# (presumably needed because the notebook kernel is started with its own argv).
args = parser.parse_known_args()[0]

# NOTE(review): this rebinds the name `Data` from the class to an instance, so
# running this cell a second time fails until the class cell is re-run.
# Consider a distinct instance name once downstream cells are checked.
Data = Data(args)

2000 2000
4000 2000
6000 2000
8000 2000
10000 2000
[<torch.utils.data.dataset.Subset object at 0x7f589c42f810>, <torch.utils.data.dataset.Subset object at 0x7f589c42f8d0>, <torch.utils.data.dataset.Subset object at 0x7f589c42ff50>, <torch.utils.data.dataset.Subset object at 0x7f589c42f4d0>, <torch.utils.data.dataset.Subset object at 0x7f589c42ff10>]


In [None]:

# l = [1, 2, 3, 4]
# Compute the prefix sums
# cumsum_train = torch.tensor(l).cumsum(dim=0).tolist()
# Prints: [1, 3, 6, 10]
# print(cumsum_train)