In [1]:
import io
import os
import os.path as pt
from multiprocessing.dummy import Pool as ThreadPool

from PIL import Image
from simplejpeg import decode_jpeg

from datadings.writer import FileWriter
from datadings.tools import yield_threaded

from sklearn.model_selection import train_test_split


# __doc__ += document_keys(
#     Imagenette
#     )

def Imagenette(
        key,
        image,
        label
):
    """
Build one Imagenette sample record with exactly three fields::

    {
        'key': key,      # sample identifier (file name in this notebook)
        'image': image,  # encoded image payload
        'label': label   # integer class index
    }
    """
    sample = dict(key=key, image=image, label=label)
    return sample

def __transform_image(im, size=256):
    """Resize ``im`` to a ``size`` x ``size`` square (aspect ratio not kept).

    Parameters
    ----------
    im : PIL.Image.Image
        Source image.
    size : int
        Edge length of the output square. Defaults to 256, which is the
        size every existing caller already received (the parameter used to
        be ignored in favour of a hard-coded 256).

    Returns
    -------
    PIL.Image.Image
        A new, resized image.
    """
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the identical
    # filter and is available in all Pillow versions.
    return im.resize((size, size), Image.LANCZOS)


def __decode(data):
    """Decode JPEG ``data`` (bytes) with simplejpeg into a PIL image of mode 'RGB'.

    The fast upsampling and fast DCT shortcuts are switched off (presumably
    trading speed for decoding accuracy). Not called by the visible code.
    """
    pixel_array = decode_jpeg(data, fastupsample=False, fastdct=False)
    return Image.fromarray(pixel_array, 'RGB')


def __tobytes(im):
    """Encode ``im`` as an optimized PNG and return the encoded bytes."""
    with io.BytesIO() as buffer:
        im.save(buffer, 'PNG', optimize=True)
        # Read the bytes before the context manager closes the buffer.
        return buffer.getvalue()


def yield_samples(files, labels):
    """Lazily load each image file as RGB and yield it with its path and label.

    Parameters
    ----------
    files : iterable of str
        Image file paths.
    labels : iterable
        Labels parallel to ``files`` (pairs are formed with ``zip``, so the
        shorter of the two determines how many samples are produced).

    Yields
    ------
    tuple
        ``(filename, image, label)`` where ``image`` is a fully loaded RGB
        PIL image that no longer depends on the open file.
    """
    for filename, label in zip(files, labels):
        # ``with`` guarantees the file handle is closed even when decoding or
        # conversion raises, so a corrupt image cannot leak a descriptor.
        with Image.open(filename) as source:
            # convert() loads the pixels and returns a new image object, so
            # the result stays valid after the file is closed.
            imagedata = source.convert('RGB')
        # Yield outside the ``with`` so the file is not held open while the
        # consumer processes the sample.
        yield filename, imagedata, label


def create_sample(item):
    """Turn one ``(filename, image, label)`` tuple into an Imagenette record.

    The image is passed through ``__transform_image`` and then PNG-encoded by
    ``__tobytes``; the record key is the last ``os.sep``-separated component
    of the filename.
    """
    path, pil_image, label = item

    encoded = __tobytes(__transform_image(pil_image))
    basename = path.rsplit(os.sep, 1)[-1]

    return Imagenette(basename, encoded, label)

def write_set(partition, files, labels, outdir='../datadings/', num_threads=8):
    """Encode one partition and write it to ``<outdir>/<partition>.msgpack``.

    Parameters
    ----------
    partition : str
        Partition name; used as the output file stem.
    files : sequence of str
        Image paths; ``len(files)`` is passed to the writer as the expected total.
    labels : sequence
        Labels parallel to ``files``.
    outdir : str
        Output directory (default keeps the previous hard-coded location).
    num_threads : int
        Number of worker threads that run ``create_sample``.

    Notes
    -----
    Samples are written in completion order (``imap_unordered``), so their
    order in the file may differ from the order of ``files``. Existing output
    files are overwritten.
    """
    # Image loading is wrapped by datadings' yield_threaded (presumably a
    # background prefetching generator -- see datadings docs).
    gen = yield_threaded(yield_samples(files, labels))

    outfile = pt.join(outdir, partition + '.msgpack')
    # Using the pool as a context manager shuts the worker threads down when
    # the partition is done (or fails); previously they were never released.
    with FileWriter(outfile, total=len(files), overwrite=True) as writer, \
            ThreadPool(num_threads) as pool:
        for sample in pool.imap_unordered(create_sample, gen):
            writer.write(sample)

def get_images_and_labels(split, root='../original/imagenette2-320'):
    """Collect ``.JPEG`` file paths and integer labels for one Imagenette split.

    Walks ``<root>/<split>`` recursively. Each ``.JPEG`` file is labelled by
    the name of the directory that contains it (a WordNet id). Files with
    any other extension are printed and skipped.

    Parameters
    ----------
    split : str
        Split directory name, e.g. ``'train'`` or ``'val'``.
    root : str
        Directory containing the split folders.

    Returns
    -------
    (list of str, list of int)
        Parallel lists of file paths and labels, in a deterministic
        (sorted directory/file) order.

    Raises
    ------
    KeyError
        If a ``.JPEG`` file lives in a directory whose name is not a known wnid.
    """
    destination = pt.join(root, split)

    # wnid -> label. NOTE: these indices are not in sorted-wnid order (which is
    # what e.g. torchvision's ImageFolder would assign). They are kept unchanged
    # so labels stay consistent with .msgpack files already written by this
    # notebook.
    dct_folder2label = {
        "n02102040": 0,  # English springer
        "n03445777": 1,  # golf ball
        "n03888257": 2,  # parachute
        "n01440764": 3,  # tench
        "n03417042": 4,  # garbage truck
        "n03425413": 5,  # gas pump
        "n02979186": 6,  # cassette player
        "n03000684": 7,  # chain saw
        "n03028079": 8,  # church
        "n03394916": 9,  # French horn
    }

    all_filenames = []
    all_labels = []

    for dirpath, dnames, fnames in os.walk(destination):
        # os.walk order is filesystem-dependent. Sorting in place (dnames must
        # be mutated for os.walk to honour it) makes the seeded train/val split
        # downstream reproducible across machines.
        dnames.sort()
        wnid = pt.basename(pt.normpath(dirpath))
        for f in sorted(fnames):
            if f.endswith(".JPEG"):
                all_filenames.append(pt.join(dirpath, f))
                all_labels.append(dct_folder2label[wnid])
            else:
                print(f)

    return all_filenames, all_labels

def write_sets(val_size=0.33, seed=42):
    """Build the train/val/test ``.msgpack`` files for Imagenette.

    Imagenette ships only ``train`` and ``val`` folders. The official ``val``
    folder is used as the test set. A stratified validation set is carved out
    of ``train``.

    Parameters
    ----------
    val_size : float
        Fraction of the original train folder moved into the validation set.
    seed : int
        ``random_state`` for the split, so it is reproducible.
    """
    train_files, train_labels = get_images_and_labels('train')
    test_files, test_labels = get_images_and_labels('val')

    train_files, val_files, train_labels, val_labels = train_test_split(
        train_files, train_labels,
        stratify=train_labels, test_size=val_size, random_state=seed,
    )

    partitions = [
        ('train', train_files, train_labels),
        ('val', val_files, val_labels),
        ('test', test_files, test_labels),
    ]

    for partition, part_files, part_labels in partitions:
        # Best-effort: if one partition's output cannot be created because it
        # already exists, report it and still build the remaining ones
        # (previously the error was silently swallowed and every later
        # partition was skipped too).
        try:
            write_set(partition, part_files, part_labels)
        except FileExistsError as err:
            print(f'Skipping {partition!r}: {err}')

In [2]:
# Build the train/val/test .msgpack files; progress bars and sample counts are shown below.
write_sets()

train.msgpack 100% 00:40<00:00, 155.05it/s
val.msgpack   0% 00:00<?, ?it/s

6344 samples written


val.msgpack 100% 00:19<00:00, 159.69it/s
test.msgpack   0% 00:00<?, ?it/s

3125 samples written


test.msgpack 100% 00:24<00:00, 159.97it/s

3925 samples written



