In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Dataset
import numpy as np
import scipy.io
import gzip
import wget
import h5py
import pickle
import urllib
import os
import skimage
import skimage.transform
from skimage.io import imread
import matplotlib.image as mpimg

import warnings

def LoadDataset(name, root, batch_size, split,shuffle=True, style=None, attr=None):
    """Dispatch to the dataset-specific loader and return its DataLoader.

    Args:
        name: one of 'mnist', 'usps', 'svhn', 'face', 'artificial'.
        root: base directory. Per-dataset subdirectories ('mnist/', 'usps/',
            'svhn/', 'artificial/') are appended by plain string
            concatenation, so ``root`` is expected to end with a path
            separator. 'face' receives ``root`` unchanged.
        batch_size: samples per batch.
        split: 'train' or 'test'. Test loaders are never shuffled. For
            'artificial', any split other than 'train' selects the test split.
        shuffle: whether the training loader is shuffled.
        style: required for 'face'; forwarded to ``LoadFace``.
        attr: currently unused; kept for interface compatibility.

    Returns:
        The DataLoader, or None (after emitting a warning) when the dataset
        name or the split is not supported.

    Raises:
        AssertionError: if ``name == 'face'`` and ``style`` is None.
    """
    if name == 'mnist':
        if split == 'train':
            return LoadMNIST(root+'mnist/', batch_size=batch_size, split='train', shuffle=shuffle, scale_32=True)
        if split == 'test':
            return LoadMNIST(root+'mnist/', batch_size=batch_size, split='test', shuffle=False, scale_32=True)
    elif name == 'usps':
        if split == 'train':
            return LoadUSPS(root+'usps/', batch_size=batch_size, split='train', shuffle=shuffle, scale_32=True)
        if split == 'test':
            return LoadUSPS(root+'usps/', batch_size=batch_size, split='test', shuffle=False, scale_32=True)
    elif name == 'svhn':
        # SVHN's 'extra' split is used as the training data.
        if split == 'train':
            return LoadSVHN(root+'svhn/', batch_size=batch_size, split='extra', shuffle=shuffle)
        if split == 'test':
            return LoadSVHN(root+'svhn/', batch_size=batch_size, split='test', shuffle=False)
    elif name == 'face':
        assert style is not None, "LoadDataset('face', ...) requires a style"
        if split == 'train':
            return LoadFace(root, style=style, split='train', batch_size=batch_size, shuffle=shuffle)
        if split == 'test':
            return LoadFace(root, style=style, split='test', batch_size=batch_size, shuffle=False)
    elif name == 'artificial':
        if split == 'train':
            return LoadArtificial(root+name+'/', batch_size=batch_size, split='train', shuffle=shuffle)
        return LoadArtificial(root+name+'/', batch_size=batch_size, split='test', shuffle=False)
    else:
        warnings.warn('unknown dataset name %s'%name)
        return None

    # Reached only when a known dataset was given a split it does not support;
    # previously this returned None silently.
    warnings.warn('unknown split %s for dataset %s' % (split, name))
    return None


def LoadSVHN(data_root, batch_size=32, split='train', shuffle=True):
    """Build a DataLoader over torchvision's SVHN for a single split.

    Args:
        data_root: directory holding the ``<split>_32x32.mat`` files; created
            if it does not exist.
        batch_size: samples per batch; the last incomplete batch is dropped.
        split: split name passed to ``datasets.SVHN`` (e.g. 'train', 'test',
            'extra').
        shuffle: whether the DataLoader reshuffles each epoch.

    Returns:
        A ``DataLoader`` yielding ``(image_tensor, label)`` batches.
    """
    if not os.path.exists(data_root):
        os.makedirs(data_root)

    # Decide per split: checking a fixed pair of files (extra + test) would
    # skip the download for a split, such as 'train', whose own file is missing.
    split_file = os.path.join(data_root, '%s_32x32.mat' % split)
    download = not os.path.exists(split_file)
    svhn_dataset = datasets.SVHN(data_root, split=split, download=download,
                                 transform=transforms.ToTensor())
    return DataLoader(svhn_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)

def LoadUSPS(data_root, batch_size=32, split='train', shuffle=True, scale_32 = False):
    """Wrap the USPS dataset in a DataLoader, downloading it into ``data_root`` on demand.

    ``split == 'train'`` selects the training portion; any other value selects
    the test portion. ``scale_32`` picks the 32x32 variant of the images.
    Incomplete final batches are dropped.
    """
    root_exists = os.path.exists(data_root)
    if not root_exists:
        os.makedirs(data_root)

    use_train_split = split == 'train'
    dataset = USPS(root=data_root, train=use_train_split, download=True, scale_32=scale_32)
    loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle, drop_last=True)
    return DataLoader(dataset, **loader_kwargs)

def LoadMNIST(data_root, batch_size=32, split='train', shuffle=True, scale_32 = False):
    """Wrap torchvision's MNIST in a DataLoader, downloading it into ``data_root`` if needed.

    ``split == 'train'`` selects the training set; any other value selects the
    test set. With ``scale_32`` the images are resized to 32x32 before being
    converted to tensors. Incomplete final batches are dropped.
    """
    if not os.path.exists(data_root):
        os.makedirs(data_root)

    to_tensor = transforms.ToTensor()
    pipeline = (transforms.Compose([transforms.Resize(size=[32, 32]), to_tensor])
                if scale_32 else to_tensor)

    is_train = split == 'train'
    dataset = datasets.MNIST(data_root, train=is_train, download=True, transform=pipeline)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)


def LoadArtificial(data_root, batch_size=32, split='train',shuffle=True, random_seed=0):
    """Build a DataLoader over the synthetic ``Artificial`` dataset.

    Args:
        data_root: directory for the cached dataset files; created if missing.
        batch_size: samples per batch; the last incomplete batch is dropped.
        split: 'train' selects the training portion; any other value selects
            the test portion.
        shuffle: whether the DataLoader reshuffles each epoch.
        random_seed: forwarded to ``Artificial`` as its ``seed`` argument.

    Returns:
        A ``DataLoader`` over the requested split.
    """
    if not os.path.exists(data_root):
        os.makedirs(data_root)

    artificial_dataset = Artificial(
        data_root, dataset_size=[5000, 1000], dim=16, n_clusters=[1, 2, 4, 8],
        train=(split == 'train'), seed=random_seed, name='artificial')
    return DataLoader(artificial_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)
    

def LoadFace(data_root, batch_size=32, split='train', style='photo', attr = None,
               shuffle=True, load_first_n = None):
    """Build a DataLoader over one split/style of the CelebA HDF5 file.

    Args:
        data_root: directory containing ``face.h5``.
        batch_size: samples per batch; the last incomplete batch is dropped.
        split: split name used in the HDF5 key ``CelebA/<split>/<style>``.
        style: style name used in the same key.
        attr: currently unused; kept for interface compatibility.
        shuffle: whether the DataLoader reshuffles each epoch.
        load_first_n: forwarded to ``Face``; limits how many entries are loaded.

    Returns:
        A ``DataLoader`` over the selected images.
    """
    # os.path.join gives the same path as before when data_root ends with a
    # separator, and also works when it does not.
    h5_path = os.path.join(data_root, 'face.h5')
    key = '/'.join(['CelebA', split, style])
    celeba_dataset = Face(h5_path, key, load_first_n)
    return DataLoader(celeba_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)


### USPS Reference : https://github.com/corenel/torchzoo/blob/master/torchzoo/datasets/usps.py
class USPS(Dataset):
    """USPS handwritten-digit dataset, served from a gzipped pickle.

    Args:
        root (string): Directory that holds (or will receive) the pickle files.
        train (bool, optional): If True, use the training portion and shuffle
            its order once at construction (global NumPy RNG). Otherwise use
            the test portion in file order.
        scale_32 (bool, optional): If True, read ``usps_32x32.pkl``; otherwise
            read ``usps_28x28.pkl``.
        download (bool, optional): If True, download ``usps_28x28.pkl`` when it
            is missing and derive ``usps_32x32.pkl`` from it when that is
            missing. Files that already exist are not fetched again.

    Raises:
        RuntimeError: if the selected pickle file is absent after the optional
            download step.
    """

    url = "https://github.com/mingyuliutw/CoGAN/raw/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, scale_32=False, download=False):
        """Init USPS dataset."""
        self.root = os.path.expanduser(root)
        self.filename = "usps_32x32.pkl" if scale_32 else "usps_28x28.pkl"
        self.train = train
        # Set by load_samples(); original note: Num of Train = 7438, Num of Test = 1860.
        self.dataset_size = None

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            # Shuffle images and labels in unison, once.
            indices = np.arange(self.train_labels.shape[0])
            np.random.shuffle(indices)
            self.train_data = self.train_data[indices]
            self.train_labels = self.train_labels[indices]

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is a float tensor and target is a
            0-dim int64 tensor holding the class index.
        """
        img, label = self.train_data[index], self.train_labels[index]
        target = torch.tensor(np.int64(label).item(), dtype=torch.long)
        return torch.FloatTensor(img), target

    def __len__(self):
        """Return size of dataset."""
        return self.dataset_size

    def _check_exists(self):
        """Check if dataset is download and in right place."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Fetch usps_28x28.pkl if missing, then build usps_32x32.pkl from it if missing."""
        path_28 = os.path.join(self.root, 'usps_28x28.pkl')
        path_32 = os.path.join(self.root, 'usps_32x32.pkl')
        os.makedirs(self.root, exist_ok=True)

        if not os.path.isfile(path_28):
            print("Download %s to %s" % (self.url, os.path.abspath(path_28)))
            wget.download(self.url, out=path_28)
            print("[DONE]")

        if not os.path.isfile(path_32):
            print("Resizing USPS 28x28 to 32x32...")
            # NOTE(review): pickle.load can execute arbitrary code; this file
            # comes from a remote URL, so only use a trusted source.
            with gzip.open(path_28, "rb") as f:
                data_set = pickle.load(f, encoding="bytes")
            # Entries 0 and 1 are the train and test portions (see load_samples).
            for d in (0, 1):
                data_set[d][0] = np.array([
                    np.expand_dims(skimage.transform.resize(img.squeeze(), [32, 32]), 0)
                    for img in data_set[d][0]
                ])
            # Closing the gzip stream is required for the file to be complete.
            with gzip.open(path_32, 'wb') as fp:
                pickle.dump(data_set, fp)
            print("[DONE]")

    def load_samples(self):
        """Load the images/labels of the selected portion and record its size."""
        filename = os.path.join(self.root, self.filename)
        with gzip.open(filename, "rb") as f:
            data_set = pickle.load(f, encoding="bytes")
        portion = data_set[0] if self.train else data_set[1]
        images, labels = portion[0], portion[1]
        self.dataset_size = labels.shape[0]
        return images, labels

class Face(Dataset):
    """Images from one HDF5 dataset, held in memory and rescaled as ``x / 255 * 2 - 1``.

    Args:
        root: path to the HDF5 file.
        key: path of the dataset inside the file (LoadFace uses
            'CelebA/<split>/<style>').
        load_first_n: if truthy, keep only the first ``load_first_n`` entries;
            None or 0 keeps everything.
    """
    def __init__(self, root, key, load_first_n = None):
        with h5py.File(root,'r') as f:
            dset = f[key]
            # Slice the h5py Dataset directly so only the requested rows are
            # read from disk, rather than reading everything and slicing after.
            data = dset[:load_first_n] if load_first_n else dset[()]
        # Maps [0, 255] to [-1, 1] -- assumes 8-bit pixel values (TODO confirm).
        self.imgs = (data/255.0)*2 -1

    def __getitem__(self, index):
        """Return the rescaled image at ``index``."""
        return self.imgs[index]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.imgs)

In [3]:
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
import yaml
import h5py
from easydict import EasyDict as edict
from sklearn.utils import shuffle


class Artificial(Dataset):
    """Synthetic clustered dataset, cached in ``<root>/<name>.h5``.

    For each entry ``c`` of ``n_clusters``, a 2-D blob set with ``c`` centers
    is drawn. The blobs are concatenated into ``2 * len(n_clusters)`` raw
    features, projected to ``dim`` features with a random non-negative matrix,
    shuffled, and min-max rescaled per feature to [-1, 1]. Each item is
    ``[features, label_0, ..., label_k]``, with one label per entry of
    ``n_clusters``.

    Args:
        root: directory for the cached ``.h5`` file and the ``.png`` scatter plot.
        dataset_size: ``(n_train, n_test)``.
        dim: number of projected features.
        n_clusters: number of blob centers for each raw 2-D feature pair.
        train: serve the train split if truthy, otherwise the test split.
        seed: seeds the blob sampling and the projection matrix.
        name: base filename of the cache and the plot.
    """

    def __init__(self, root, dataset_size=(5000, 1000), dim=16, n_clusters=(1, 2, 4, 8), train=True,
                 seed=0, name='artificial'):
        self.train = train
        self.seed = seed

        # Stored as lists (the previous attribute type); tuple defaults avoid
        # shared mutable default arguments.
        self.dataset_size = list(dataset_size)
        self.dim = dim
        self.n_clusters = list(n_clusters)

        self.filename = os.path.join(root, "%s.h5" % name)
        self.visualization = os.path.join(root, "%s.png" % name)
        if not self._check_exists():
            X, Y = self._create()
            n_train = self.dataset_size[0]
            with h5py.File(self.filename, "w") as f:
                f.create_group('train')
                f['train'].create_dataset('X', data=X[:n_train])
                f['train'].create_dataset('Y', data=Y[:n_train])
                f.create_group('test')
                f['test'].create_dataset('X', data=X[n_train:])
                f['test'].create_dataset('Y', data=Y[n_train:])
                f.attrs['dataset_size'] = self.dataset_size
                f.attrs['dim'] = self.dim
                f.attrs['n_clusters'] = self.n_clusters
                f.attrs['image'] = self.visualization

        self.train_data, self.train_labels = self.load_samples()

    def load_samples(self, train=None):
        """Read one split from the cached HDF5 file.

        Args:
            train: which split to read; None means use ``self.train``.

        Returns:
            ``(data, labels)``, where ``labels`` is a list with one 1-D array
            per column of the stored ``Y`` (one per entry of ``n_clusters``).
        """
        use_train = self.train if train is None else train
        key = 'train' if use_train else 'test'

        with h5py.File(self.filename, 'r') as f:
            conf = edict(f.attrs)
            # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads the whole array.
            data = f[key]['X'][()]
            labels = f[key]['Y'][()]
        # The cached file must have been generated with the same configuration.
        assert np.array_equal(self.dataset_size, conf.dataset_size)
        assert self.dim == conf.dim
        assert np.array_equal(self.n_clusters, conf.n_clusters)
        return data, list(labels.T)

    def __getitem__(self, index):
        """Get features and targets for data loader.

        Args:
            index (int): Index

        Returns:
            list: ``[features_tensor, label_0, ..., label_k]``.
        """
        data = torch.from_numpy(self.train_data[index])
        label = [l[index] for l in self.train_labels]
        return [data] + label

    def __len__(self):
        """Return size of the selected split."""
        return self.dataset_size[0 if self.train else 1]

    def _check_exists(self):
        """Check whether the cached dataset file exists."""
        return os.path.exists(self.filename)

    def _save_visualization(self, X, Y):
        """Scatter-plot each raw 2-D feature pair, coloured by its labels, into one figure."""
        n_pairs = len(self.n_clusters)
        # One figure holding all pairs; the width matches the previous 2 * dim_raw.
        fig = plt.figure(figsize=[n_pairs * 4, 3])
        for i in range(n_pairs):
            plt.subplot(1, n_pairs, i + 1)
            plt.scatter(X[:, 2 * i], X[:, 2 * i + 1], 1, Y[:, i])
            plt.title('%d-%d-d' % (2 * i, 2 * i + 1))
        fig.savefig(self.visualization)
        plt.close(fig)

    def _create(self, do_shuffle=True):
        """Generate the full (train + test) dataset in memory.

        Returns:
            ``(X, Y)``. ``X`` has shape ``(N, dim)``, rescaled per column to
            [-1, 1]. ``Y`` has shape ``(N, len(n_clusters))`` and holds the blob
            label for each entry of ``n_clusters``. ``N = n_train + n_test``.
        """
        N = self.dataset_size[0] + self.dataset_size[1]
        # One seeded RNG so that generation is reproducible from ``self.seed``.
        rng = np.random.RandomState(self.seed)

        raw_features = []
        raw_labels = []
        for c in self.n_clusters:
            feats, labs = make_blobs(n_samples=N, centers=c, cluster_std=.2,
                                     shuffle=False, random_state=rng)
            raw_features.append(feats)
            raw_labels.append(labs[:, np.newaxis])
            print(c, np.max(labs), np.min(labs))
        X = np.hstack(raw_features)
        Y = np.hstack(raw_labels)

        print(X.shape)
        print(Y.shape)

        self._save_visualization(X, Y)

        # Random linear projection of the 2 * len(n_clusters) raw features to ``dim``.
        W = rng.uniform(0, 1, (X.shape[1], self.dim))
        X = np.dot(X, W)

        if do_shuffle:
            print('shuffle')
            X, Y = shuffle(X, Y, random_state=32)

        # Per-column min-max rescaling; X has range [-1, 1] afterwards.
        X_min = np.min(X, axis=0)
        X_range = np.max(X, axis=0) - X_min
        X = 2 * ((X - X_min) / X_range - 0.5)

        return X, Y