In [3]:
import numpy as np
import torch
import pickle
import os
from torchvision import datasets, transforms
import re

class SemiSupervisedDataset(torch.utils.data.Dataset):
    """
    A dataset with auxiliary pseudo-labeled data.
    """
    def __init__(self, base_dataset='cifar10', take_amount=None, take_amount_seed=13, aux_data_filename=None, 
                 add_aux_labels=False, aux_take_amount=None, train=False, validation=False, **kwargs):

        self.base_dataset = base_dataset
        self.load_base_dataset(train, **kwargs)


        if validation:
            self.dataset.data = self.dataset.data[1024:]
            self.dataset.targets = self.dataset.targets[1024:]
        
        self.train = train

        if self.train:
            if take_amount is not None:
                rng_state = np.random.get_state()
                np.random.seed(take_amount_seed)
                take_inds = np.random.choice(len(self.sup_indices), take_amount, replace=False)
                np.random.set_state(rng_state)

                self.targets = self.targets[take_inds]
                self.data = self.data[take_inds]

            self.sup_indices = list(range(len(self.targets)))
            self.unsup_indices = []

            if aux_data_filename is not None:
                aux_path = aux_data_filename
                print('Loading data from %s' % aux_path)
                if os.path.splitext(aux_path)[1] == '.pickle':
                    # for data from Carmon et al, 2019.
                    with open(aux_path, 'rb') as f:
                        aux = pickle.load(f)
                    aux_data = aux['data']
                    aux_targets = aux['extrapolated_targets']
                else:
                    # for data from Rebuffi et al, 2021.
                    aux = np.load(aux_path)
                    aux_data = aux['image']
                    print(aux_data.shape)
                    aux_targets = aux['label']
                
                orig_len = len(self.data)

                if aux_take_amount is not None:
                    rng_state = np.random.get_state()
                    np.random.seed(take_amount_seed)
                    take_inds = np.random.choice(len(aux_data), aux_take_amount, replace=False)
                    np.random.set_state(rng_state)

                    aux_data = aux_data[take_inds]
                    aux_targets = aux_targets[take_inds]

                self.data = np.concatenate((self.data, aux_data), axis=0)

                if not add_aux_labels:
                    self.targets.extend([-1] * len(aux_data))
                else:
                    self.targets.extend(aux_targets)
                self.unsup_indices.extend(range(orig_len, orig_len+len(aux_data)))

        else:
            self.sup_indices = list(range(len(self.targets)))
            self.unsup_indices = []
    
    def load_base_dataset(self, **kwargs):
        raise NotImplementedError()
    
    @property
    def data(self):
        return self.dataset.data

    @data.setter
    def data(self, value):
        self.dataset.data = value

    @property
    def targets(self):
        return self.dataset.targets

    @targets.setter
    def targets(self, value):
        self.dataset.targets = value

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        self.dataset.labels = self.targets
        return self.dataset[item]
    

class SemiSupervisedSampler(torch.utils.data.Sampler):
    """
    Balanced sampling from the labeled and unlabeled data.
    """
    def __init__(self, sup_inds, unsup_inds, batch_size, unsup_fraction=0.5, num_batches=None):
        if unsup_fraction is None or unsup_fraction < 0:
            self.sup_inds = sup_inds + unsup_inds
            unsup_fraction = 0.0
        else:
            self.sup_inds = sup_inds
            self.unsup_inds = unsup_inds

        self.batch_size = batch_size
        unsup_batch_size = int(batch_size * unsup_fraction)
        self.sup_batch_size = batch_size - unsup_batch_size

        if num_batches is not None:
            self.num_batches = num_batches
        else:
            self.num_batches = int(np.ceil(len(self.sup_inds) / self.sup_batch_size))
        super().__init__(None)

    def __iter__(self):
        batch_counter = 0
        while batch_counter < self.num_batches:
            sup_inds_shuffled = [self.sup_inds[i]
                                 for i in torch.randperm(len(self.sup_inds))]
            for sup_k in range(0, len(self.sup_inds), self.sup_batch_size):
                if batch_counter == self.num_batches:
                    break
                batch = sup_inds_shuffled[sup_k:(sup_k + self.sup_batch_size)]
                if self.sup_batch_size < self.batch_size:
                    batch.extend([self.unsup_inds[i] for i in torch.randint(high=len(self.unsup_inds), 
                                                                            size=(self.batch_size - len(batch),), 
                                                                            dtype=torch.int64)])
                np.random.shuffle(batch)
                yield batch
                batch_counter += 1

    def __len__(self):
        return self.num_batches

def get_semisup_dataloaders(train_dataset, test_dataset, val_dataset=None, batch_size=256, batch_size_test=256, num_workers=4, 
                            unsup_fraction=0.5):
    """
    Return dataloaders with custom sampling of pseudo-labeled data.
    """
    dataset_size = train_dataset.dataset_size
    train_batch_sampler = SemiSupervisedSampler(train_dataset.sup_indices, train_dataset.unsup_indices, batch_size, 
                                                unsup_fraction, num_batches=int(np.ceil(dataset_size/batch_size)))
    
    epoch_size = len(train_batch_sampler) * batch_size

    # kwargs = {'num_workers': num_workers, 'pin_memory': torch.cuda.is_available() }
    kwargs = {'num_workers': num_workers, 'pin_memory': False}    
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler, **kwargs)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, **kwargs)
    
    if val_dataset:
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False, **kwargs)
        return train_dataloader, test_dataloader, val_dataloader
    return train_dataloader, test_dataloader




def load_cifar10s(data_dir, use_augmentation='base', use_consistency=False, aux_take_amount=None, 
                  aux_data_filename=None, 
                  validation=False):
    """
    Returns semisupervised CIFAR10 train, test datasets and dataloaders (with Tiny Images).
    Arguments:
        data_dir (str): path to data directory.
        use_augmentation: use different augmentations for training set.
        aux_take_amount (int): number of semi-supervised examples to use (if None, use all).
        aux_data_filename (str): path to additional data pickle file.
    Returns:
        train dataset, test dataset. 
    """
    data_dir = re.sub('cifar10s', 'cifar10', data_dir)
    test_transform = transforms.Compose([transforms.ToTensor()])
    if use_augmentation == 'none':
        train_transform = test_transform
    elif use_augmentation == 'base':
        train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 
                                              transforms.ToTensor()])
    # elif use_augmentation == 'cutout':
    #     train_transform = transforms.Compose([
    #         transforms.RandomCrop(32, padding=4),
    #         transforms.RandomHorizontalFlip(0.5),
    #         transforms.ToTensor(),
    #     ])
    #     train_transform.transforms.append(CutoutDefault(18))
    # elif use_augmentation == 'autoaugment':
    #     train_transform = transforms.Compose([
    #         transforms.RandomCrop(32, padding=4),
    #         transforms.RandomHorizontalFlip(0.5),
    #         CIFAR10Policy(),
    #         transforms.ToTensor(),
    #     ])
    #     train_transform.transforms.append(CutoutDefault(18))
    # elif use_augmentation == 'randaugment':
    #     train_transform = transforms.Compose([
    #         transforms.RandomCrop(32, padding=4),
    #         transforms.RandomHorizontalFlip(0.5),
    #         transforms.ToTensor(),
    #     ])
    #     # Add RandAugment with N, M(hyperparameter), N=2, M=14 for wdn-28-10
    #     train_transform.transforms.insert(0, RandAugment(2, 14))
    # elif use_augmentation == 'idbh':
    #     train_transform = IDBH('cifar10-weak')
    
    if use_consistency:
        pass
        # train_transform = MultiDataTransform(train_transform)

    train_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=True, download=True, 
                                          transform=train_transform, aux_data_filename=aux_data_filename, 
                                          add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation)
    
    test_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=False, download=True, 
                                         transform=test_transform)
    if validation:
        val_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=test_transform)
        val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024))  # split from training set
        return train_dataset, test_dataset, val_dataset
    return train_dataset, test_dataset


class SemiSupervisedCIFAR10(SemiSupervisedDataset):
    """
    A dataset with auxiliary pseudo-labeled data for CIFAR10.
    """
    def load_base_dataset(self, train=False, **kwargs):
        assert self.base_dataset == 'cifar10', 'Only semi-supervised cifar10 is supported. Please use correct dataset!'
        self.dataset = datasets.CIFAR10(train=train, **kwargs)
        self.dataset_size = len(self.dataset)


def load_data(data_dir, batch_size=256, batch_size_test=256, num_workers=4, use_augmentation='base', use_consistency=False, shuffle_train=True, 
              aux_data_filename=None, unsup_fraction=None, validation=False):
    """
    Returns train, test datasets and dataloaders.
    Arguments:
        data_dir (str): path to data directory.
        batch_size (int): batch size for training.
        batch_size_test (int): batch size for validation.
        num_workers (int): number of workers for loading the data.
        use_augmentation (base/none): whether to use augmentations for training set.
        shuffle_train (bool): whether to shuffle training set.
        aux_data_filename (str): path to unlabelled data.
        unsup_fraction (float): fraction of unlabelled data per batch.
        validation (bool): if True, also returns a validation dataloader for unspervised cifar10 (as in Gowal et al, 2020).
    """

    dataset = os.path.basename(os.path.normpath(data_dir))
    # load_dataset_fn = _LOAD_DATASET_FN[dataset]
    
    if validation:
        # assert dataset in SEMISUP_DATASETS, 'Only semi-supervised datasets allow a validation set.'
        train_dataset, test_dataset, val_dataset = load_cifar10s(data_dir=data_dir, use_augmentation=use_augmentation, use_consistency=use_consistency,
                                                                   aux_data_filename=aux_data_filename, validation=True)
    else:
        train_dataset, test_dataset = load_cifar10s(data_dir=data_dir, use_augmentation=use_augmentation)
       
    # if dataset in SEMISUP_DATASETS:
    if validation:
        train_dataloader, test_dataloader, val_dataloader = get_semisup_dataloaders(
                train_dataset, test_dataset, val_dataset, batch_size, batch_size_test, num_workers, unsup_fraction )
    else:
        train_dataloader, test_dataloader = get_semisup_dataloaders(
                train_dataset, test_dataset, None, batch_size, batch_size_test, num_workers, unsup_fraction )
    # else:
    #     #pin_memory = torch.cuda.is_available()
    #     pin_memory = False
    #     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train, 
    #                                                    num_workers=num_workers, pin_memory=pin_memory)
    #     test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, 
    #                                                   num_workers=num_workers, pin_memory=pin_memory)
    if validation:
        return train_dataset, test_dataset, val_dataset, train_dataloader, test_dataloader, val_dataloader
    return train_dataset, test_dataset, train_dataloader, test_dataloader

data_dir = './data'
train_dataset, test_dataset, train_dataloader, test_dataloader = load_data(data_dir, 
          batch_size=256, 
          batch_size_test=256, 
          num_workers=1, 
          use_augmentation='base', 
          use_consistency=False, 
          shuffle_train=True, 
          aux_data_filename='./data/1m.npz', 
          unsup_fraction=None, 
          validation=False)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
train_dataset.dataset_size

50000