In [2]:
import numpy as np
from PIL import Image

from torchvision import datasets, transforms

class IMBALANCECIFAR10(datasets.CIFAR10):
    """CIFAR-10 whose training split is subsampled into an imbalanced subset.

    For the training split, class ``c`` keeps at most
    ``get_img_num_per_cls(...)[c]`` images, picked with the global NumPy RNG
    (``np.random.shuffle``). The test split is left as loaded. Every item is
    converted to a PIL image and passed through ``self.transform``
    (ToTensor + Normalize, plus a random horizontal flip for the augmented
    training split).

    Args:
        phase: ``'train'`` selects the training split; any other value selects
            the test split.
        imbalance_ratio: Forwarded to ``get_img_num_per_cls`` as ``imb_factor``.
            With ``imb_type='exp'`` the last class is asked for
            ``img_max * imbalance_ratio`` images.
        root: Dataset directory passed to torchvision (with ``download=True``).
        imb_type: ``'exp'``, ``'step'``, or anything else for a balanced subset.
        train_aug: Whether the training split uses ``RandomHorizontalFlip``.
    """

    cls_num = 10

    def __init__(self, phase, imbalance_ratio, root='data/cifar10_lt/', imb_type='exp', train_aug=True):
        train = phase == 'train'
        super().__init__(root, train, transform=None, target_transform=None, download=True)
        self.train = train

        # Shared tail of every pipeline; only the training split may prepend a flip.
        to_normalized_tensor = [
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
        augment = []
        if self.train:
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio)
            self.gen_imbalanced_data(img_num_list)
            if train_aug:
                augment = [transforms.RandomHorizontalFlip()]
        self.transform = transforms.Compose(augment + to_normalized_tensor)

        # The rest of the class reads ``labels``; it aliases ``targets``.
        self.labels = self.targets

        print('{} Mode: Contain {} images'.format(phase, len(self.data)))

    def _get_class_dict(self):
        """Map each class id to the list of sample indices carrying that label."""
        class_dict = dict()
        for i, anno in enumerate(self.get_annotations()):
            class_dict.setdefault(anno["category_id"], []).append(i)
        return class_dict

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the requested number of images for each of ``cls_num`` classes.

        ``img_max`` is ``len(self.data) / cls_num``.
        - ``'exp'``: class ``i`` gets ``img_max * imb_factor ** (i / (cls_num - 1))``.
        - ``'step'``: the first ``cls_num // 2`` classes get ``img_max``, the
          remaining classes get ``img_max * imb_factor``.
        - anything else: every class gets ``img_max``.
        Counts are truncated to ``int``.
        """
        img_max = len(self.data) / cls_num
        if imb_type == 'exp':
            # max(..., 1.0) avoids 0/0 when there is a single class.
            denom = max(cls_num - 1.0, 1.0)
            return [int(img_max * (imb_factor ** (cls_idx / denom))) for cls_idx in range(cls_num)]
        if imb_type == 'step':
            head = cls_num // 2
            # The tail covers cls_num - head classes so odd class counts are not short by one.
            return [int(img_max)] * head + [int(img_max * imb_factor)] * (cls_num - head)
        return [int(img_max)] * cls_num

    def gen_imbalanced_data(self, img_num_per_cls):
        """Subsample ``self.data``/``self.targets`` to the given per-class counts.

        Classes are visited in sorted label order and paired with
        ``img_num_per_cls`` positionally. Sets ``self.num_per_cls_dict`` to the
        number of images actually kept per class.
        """
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # A class can hold fewer images than requested (e.g. imb_factor > 1);
            # use the number actually kept so targets stay aligned with data.
            kept = len(selec_idx)
            self.num_per_cls_dict[the_class] = kept
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class] * kept)
        self.data = np.vstack(new_data)
        self.targets = new_targets

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for sample ``index``."""
        img, label = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label #, index

    def __len__(self):
        return len(self.labels)

    def get_num_classes(self):
        return self.cls_num

    def get_annotations(self):
        """Return one ``{'category_id': int(label)}`` dict per sample."""
        return [{'category_id': int(label)} for label in self.labels]

    def get_cls_num_list(self):
        """Return the sample count of each class ``0 .. cls_num - 1``.

        Counted from ``self.labels`` so it also works on the test split, where
        ``num_per_cls_dict`` is never created.
        """
        counts = np.bincount(np.asarray(self.labels, dtype=np.int64), minlength=self.cls_num)
        return [int(c) for c in counts[:self.cls_num]]

In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.9

import os
import sys
import json
import random
import copy
import pickle
import numpy as np
import pandas as pd
import medmnist
from medmnist import INFO

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from util.args import args_parser
from util.path import set_result_dir, set_dict_user_path
from util.data_simulator import shard_balance, dir_balance
from util.longtail_dataset import IMBALANCECIFAR10, IMBALANCECIFAR100
from util.misc import adjust_learning_rate


def get_dataset(args):
    """Build the train/query/test datasets and the per-user index partitions.

    Args:
        args: Parsed options. Reads ``dataset``, ``data_dir``, ``partition``,
            ``query_ratio`` and ``current_ratio`` (plus ``imb_ratio`` for the
            ``*_lt`` datasets); the partition helpers may read further fields.

    Returns:
        ``(dataset_train, dataset_query, dataset_test, dict_users_train_total,
        dict_users_test_total, args)``. For the non-long-tailed datasets,
        ``dataset_train`` adds a random horizontal flip while ``dataset_query``
        and ``dataset_test`` are only normalized. ``args`` is mutated in place:
        ``dataset_train``, ``total_data``, ``n_query`` and ``n_data`` are set.

    Exits via ``sys.exit`` when the dataset or partition name is not recognized.
    """
    MEAN = {'mnist': (0.1307,), 'fmnist': (0.5,), 'emnist': (0.5,), 'svhn': [0.4376821, 0.4437697, 0.47280442],
            'cifar10': [0.485, 0.456, 0.406], 'cifar100': [0.507, 0.487, 0.441], 'pathmnist': (0.5,),
            'octmnist': (0.5,), 'organamnist': (0.5,), 'dermamnist': (0.5,), 'bloodmnist': (0.5,)}
    STD = {'mnist': (0.3081,), 'fmnist': (0.5,), 'emnist': (0.5,), 'svhn': [0.19803012, 0.20101562, 0.19703614],
           'cifar10': [0.229, 0.224, 0.225], 'cifar100': [0.267, 0.256, 0.276], 'pathmnist': (0.5,),
           'octmnist': (0.5,), 'organamnist': (0.5,), 'dermamnist': (0.5,), 'bloodmnist': (0.5,)}
    longtail_datasets = ('cifar10_lt', 'cifar100_lt')
    medical_datasets = ('pathmnist', 'octmnist', 'organamnist', 'dermamnist', 'bloodmnist')

    print('Load Dataset {}'.format(args.dataset))

    if args.dataset not in longtail_datasets:
        # Long-tailed datasets build their own transforms. Everything else needs
        # normalization statistics, so reject unknown names here instead of
        # failing with a KeyError on the lookup below.
        if args.dataset not in MEAN:
            sys.exit('Error: unrecognized dataset')
        normalize = transforms.Normalize(mean=MEAN[args.dataset], std=STD[args.dataset])
        trans_noaug = transforms.Compose([transforms.ToTensor(), normalize])
        trans_weakaug = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])

    # standard benchmarks
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST(args.data_dir, train=True, download=True, transform=trans_weakaug)
        dataset_query = datasets.MNIST(args.data_dir, train=True, download=True, transform=trans_noaug)
        dataset_test = datasets.MNIST(args.data_dir, train=False, download=True, transform=trans_noaug)

    elif args.dataset == "fmnist":
        dataset_train = datasets.FashionMNIST(args.data_dir, download=True, train=True, transform=trans_weakaug)
        dataset_query = datasets.FashionMNIST(args.data_dir, download=True, train=True, transform=trans_noaug)
        dataset_test = datasets.FashionMNIST(args.data_dir, download=True, train=False, transform=trans_noaug)

    elif args.dataset == 'emnist':
        dataset_train = datasets.EMNIST(args.data_dir, split='byclass', train=True, download=True, transform=trans_weakaug)
        dataset_query = datasets.EMNIST(args.data_dir, split='byclass', train=True, download=True, transform=trans_noaug)
        dataset_test = datasets.EMNIST(args.data_dir, split='byclass', train=False, download=True, transform=trans_noaug)

    elif args.dataset == 'svhn':
        dataset_train = datasets.SVHN(args.data_dir, 'train', download=True, transform=trans_weakaug)
        dataset_query = datasets.SVHN(args.data_dir, 'train', download=True, transform=trans_noaug)
        dataset_test = datasets.SVHN(args.data_dir, 'test', download=True, transform=trans_noaug)

    elif args.dataset == 'cifar10':
        dataset_train = datasets.CIFAR10(args.data_dir, train=True, download=True, transform=trans_weakaug)
        dataset_query = datasets.CIFAR10(args.data_dir, train=True, download=True, transform=trans_noaug)
        dataset_test = datasets.CIFAR10(args.data_dir, train=False, download=True, transform=trans_noaug)

    elif args.dataset == 'cifar10_lt':
        dataset_train = IMBALANCECIFAR10('train', args.imb_ratio, args.data_dir)
        dataset_query = IMBALANCECIFAR10('train', args.imb_ratio, args.data_dir, train_aug=False)
        dataset_test = IMBALANCECIFAR10('test', args.imb_ratio, args.data_dir)

    elif args.dataset == 'cifar100':
        dataset_train = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=trans_weakaug)
        dataset_query = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=trans_noaug)
        dataset_test = datasets.CIFAR100(args.data_dir, train=False, download=True, transform=trans_noaug)

    elif args.dataset == 'cifar100_lt':
        dataset_train = IMBALANCECIFAR100('train', args.imb_ratio, args.data_dir)
        dataset_query = IMBALANCECIFAR100('train', args.imb_ratio, args.data_dir, train_aug=False)
        dataset_test = IMBALANCECIFAR100('test', args.imb_ratio, args.data_dir)

    # medical benchmarks
    elif args.dataset in medical_datasets:
        DataClass = getattr(medmnist, INFO[args.dataset]['python_class'])

        dataset_train = DataClass(download=True, split='train', transform=trans_weakaug)
        dataset_query = DataClass(download=True, split='train', transform=trans_noaug)
        dataset_test = DataClass(download=True, split='test', transform=trans_noaug)

    else:
        sys.exit('Error: unrecognized dataset')

    args.dataset_train = dataset_train
    args.total_data = len(dataset_train)

    if args.partition == "shard_balance":
        dict_users_train_total = shard_balance(dataset_train, args)
        dict_users_test_total = shard_balance(dataset_test, args)
    elif args.partition == "dir_balance":
        # The test split reuses the proportions sampled for the training split.
        dict_users_train_total, sample = dir_balance(dataset_train, args)
        dict_users_test_total, _ = dir_balance(dataset_test, args, sample)
    else:
        # Without this, the return below would fail with UnboundLocalError.
        sys.exit('Error: unrecognized partition {}'.format(args.partition))

    args.n_query = round(args.total_data, -2) * args.query_ratio
    args.n_data = round(args.total_data, -2) * args.current_ratio

    return dataset_train, dataset_query, dataset_test, dict_users_train_total, dict_users_test_total, args

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'models'