In [15]:
import pickle
import os
import numpy as np 
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def get_dataset(n_devices, batch_size, normalize, dtype):
    """Get DeFungi dataset splits."""
    if batch_size % n_devices:
        raise ValueError("Batch size %d isn't divided evenly by n_devices %d" %
                     (batch_size, n_devices))
    
    # Dataset that is all grayscale, or all RGB. Pulled from ../EDA/Dataset
    all_dataset = tf.data.Dataset.list_files(f'../EDA/Dataset/*/{dtype}_*.jpg')
    
    def decode(x):
        # Read and decode the image
        image = tf.io.read_file(x)
        image = tf.image.decode_jpeg(image, channels=3)
        
        # Get Label
        parts = tf.strings.split(x, '/')
        label_part = parts[-2]
        label = tf.strings.substr(label_part, pos=0, len=2)
        
        decoded = {
        'inputs':
            image,
        'targets':
            label
        }
        
        if normalize:
            decoded['inputs'] = decoded['inputs'] / 255
            
        return decoded
    
    # Maps the entire dataset to its corresponding label (supervised learning)
    # Note 'targets' are NOT integer-represented
    all_dataset = all_dataset.map(decode, num_parallel_calls=AUTOTUNE)
    
    # Shuffle, then split all_dataset into training, testing, validating
    all_size = all_dataset.cardinality().numpy()
    all_dataset = all_dataset.shuffle(buffer_size=all_size, reshuffle_each_iteration=True)
    
    # Calculate the sizes for train, val, and test sets
    train_ratio = 0.70; val_ratio = 0.15
    train_size = int(train_ratio * all_size)
    val_size = int(val_ratio * all_size)
    test_size = all_size - train_size - val_size
    
    # Split the dataset
    train_dataset = all_dataset.take(train_size)
    remaining_dataset = all_dataset.skip(train_size)
    val_dataset = remaining_dataset.take(val_size)
    test_dataset = remaining_dataset.skip(val_size)
    
    # Batch the datasets
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
    val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
    test_dataset = test_dataset.batch(batch_size, drop_remainder=True)
    
    # Shuffle the training dataset
    train_dataset = train_dataset.shuffle(buffer_size=256, reshuffle_each_iteration=True)

    return train_dataset, val_dataset, test_dataset, 5, 256, (batch_size, 32, 32,
                                                             1)

# grayscale or RGB dataset
dtype = input('Enter `rgb` if you want an rgb dataset, `grayscale` otherwise: ')

# Return trainset, evalset, testset
train_ds, eval_ds, test_ds, num_classes, vocab_size, input_shape = get_dataset(1, 32, True, dtype)


# Map the 3 sets, then encode in binary @ local directory
os.makedirs('../defungi_encoded/', exist_ok=True)
mapping = {"train": train_ds, "dev": eval_ds, "test": test_ds}
for component in mapping:
    ds_list = []
    for idx, inst in enumerate(iter(mapping[component])):
        ds_list.append({
            "input_ids_0": inst["inputs"].numpy()[0].reshape(-1),
            "label": inst["targets"].numpy()[0]
        })
        if idx % 100 == 0:
            print(f"{idx}\t\t", end="\r")
    with open(f"../defungi_encoded/image.{component}.pickle", "wb") as f:
        pickle.dump(ds_list, f)


Enter `rgb` if you want an rgb dataset, `grayscale` otherwise:  rgb


0		