This notebook prepares a training dataset, based on the dataset tool from https://github.com/NVlabs/edm

In [1]:
import functools
import gzip
import io
import json
import os
import pickle
import re
import sys
import tarfile
import zipfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
import click
import numpy as np
import PIL.Image
from tqdm import tqdm

In [2]:
def save_bytes(fname: str, data: Union[bytes, str]):
    """Write ``data`` to the file ``fname``, creating its parent folder if needed.

    Args:
        fname: Destination path. May be a bare file name with no directory part.
        data: Payload to write. ``str`` is encoded as UTF-8; ``bytes`` is
            written unchanged.
    """
    parent_dir = os.path.dirname(fname)
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # folder when the path actually has a directory component.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(fname, 'wb') as fout:
        if isinstance(data, str):
            data = data.encode('utf8')
        fout.write(data)

In [3]:
import torch
from torchvision import datasets
import numpy as np

# Build a class-balanced random subset of the CIFAR-10 training split.
NUM_CLASSES = 10          # class ids 0..9 are sampled
SAMPLES_PER_CLASS = 200   # images kept per class
SEED = 0                  # fixed seed so a fresh kernel reproduces the same subset

cifar10 = datasets.CIFAR10(
    "./data", download=True, train=True)
target_cifar = torch.tensor(cifar10.targets)

# Shuffle all indices once with a dedicated generator, so the global torch RNG
# state is left untouched.
shuffle_generator = torch.Generator().manual_seed(SEED)
random_index = torch.randperm(len(target_cifar), generator=shuffle_generator)
random_target = target_cifar[random_index]

# For every class, keep the first SAMPLES_PER_CLASS occurrences in shuffled
# order; the result is grouped by class id (all of class 0, then class 1, ...).
per_class_index = [
    random_index[torch.where(random_target == class_id)[0][:SAMPLES_PER_CLASS]]
    for class_id in range(NUM_CLASSES)
]
sampled_index = torch.cat(per_class_index)
sampled_target = target_cifar[sampled_index]
# Index the image array once, after all indices are known, with an explicit
# NumPy index array instead of a torch tensor.
sampled_data = cifar10.data[sampled_index.numpy()]

# Persist the chosen indices so the subset can be recovered later.
torch.save(sampled_index, 'sampled_index.pt')
samples = sampled_data
label = sampled_target

np.savez('sampled_data_with_label.npz', samples=samples, label=label.numpy())

Files already downloaded and verified


In [4]:
# Bundle the sampled images and their labels under the keys 'img' and 'label'.
images = dict(img=samples, label=label)

In [5]:
def open_dest(dest: str):
    """Open an output destination and return helpers for writing into it.

    The destination is a zip archive when ``file_ext(dest)`` equals ``'zip'``,
    otherwise a folder.

    Args:
        dest: Path of the zip archive or of the output folder.

    Returns:
        A tuple ``(root_dir, write_bytes, close)``:
        ``root_dir`` is ``''`` for a zip archive and ``dest`` for a folder;
        ``write_bytes(fname, data)`` stores one file;
        ``close()`` finalizes the archive (a no-op for a folder).

    Raises:
        click.ClickException: If ``dest`` is an existing, non-empty folder.
    """
    if file_ext(dest) != 'zip':
        # Folder destination: refuse to write into a folder that has content.
        if os.path.isdir(dest) and os.listdir(dest):
            raise click.ClickException('--dest folder must be empty')
        os.makedirs(dest, exist_ok=True)

        def folder_write_bytes(fname: str, data: Union[bytes, str]):
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, 'wb') as fout:
                fout.write(data.encode('utf8') if isinstance(data, str) else data)

        return dest, folder_write_bytes, lambda: None

    # Zip destination: make sure the folder that will hold the archive exists.
    parent_dir = os.path.dirname(dest)
    if parent_dir != '':
        os.makedirs(parent_dir, exist_ok=True)
    archive = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)

    def zip_write_bytes(fname: str, data: Union[bytes, str]):
        archive.writestr(fname, data)

    return '', zip_write_bytes, archive.close


In [6]:
def file_ext(name: Union[str, Path]):
    """Return the extension of ``name`` without the leading dot.

    Only the final path component is inspected, so dots in parent folders
    (for example the leading ``./`` of ``'./data'``) are ignored, and a name
    without any dot yields ``''`` rather than the name itself.

    Args:
        name: File name or path.

    Returns:
        The text after the last ``.`` of the final path component, or ``''``
        when that component contains no dot.
    """
    basename = os.path.basename(str(name))
    _, dot, ext = basename.rpartition('.')
    return ext if dot else ''

In [7]:
# The CIFAR-10 download above goes into ./data, and open_dest() rejects a
# non-empty folder, so the exported PNGs get their own sub-folder.
# NOTE: this rebinds the name `save_bytes` (until now the helper function
# defined near the top of the notebook) to the writer returned by open_dest().
archive_root_dir, save_bytes, close_dest = open_dest('./data/cifar10_subset')

In [9]:
# Export every sampled image as an uncompressed PNG and write a dataset.json
# file that maps each archive file name to its integer class label.
labels = []
dataset_attrs = None
try:
    for idx, (cur_label, img) in enumerate(zip(label, samples)):
        idx_str = f'{idx:08d}'
        # Files are grouped into sub-folders named after the first five digits
        # of the zero-padded index.
        archive_fname = f'{idx_str[:5]}/img{idx_str}.png'

        if img is None:
            continue

        # Error check to require uniform image attributes across the whole
        # dataset. Attributes are read from the array itself so the check is
        # meaningful; assumes a (height, width[, channels]) layout, which is
        # what PIL.Image.fromarray expects.
        channels = img.shape[2] if img.ndim == 3 else 1
        cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
        if dataset_attrs is None:
            dataset_attrs = cur_image_attrs
            width = dataset_attrs['width']
            height = dataset_attrs['height']
            if width != height:
                raise click.ClickException(f'Image dimensions after scale and crop are required to be square.  Got {width}x{height}')
            if dataset_attrs['channels'] not in [1, 3]:
                raise click.ClickException('Input images must be stored as RGB or grayscale')
            if width != 2 ** int(np.floor(np.log2(width))):
                raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
        elif dataset_attrs != cur_image_attrs:
            err = [f'  dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
            raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset.  Got:\n' + '\n'.join(err))

        # Save the image as an uncompressed PNG.
        pil_img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
        image_bits = io.BytesIO()
        pil_img.save(image_bits, format='png', compress_level=0, optimize=False)
        save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
        labels.append([archive_fname, int(cur_label)])

    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
finally:
    # Always release the destination: for a zip archive this closes the file
    # handle; for a folder destination it is a no-op.
    close_dest()