# Make into TFRecord

In [1]:
import os
import csv
import json
from pprint import pprint
import collections
import numpy as np
import pandas as pd
import tensorflow as tf
import glob
import pickle

data_dir = os.path.join("..", "data")
# danbooru_img_dir = os.path.join(data_dir, "danbooru", "38986_61452_upload_danbooru-images", "danbooru-images", "0059")
danbooru_meta_dir = os.path.join(data_dir, "danbooru", "danbooru2019", "metadata")
danbooru_project_name = "project-small"

# Tfrecord
tfrecord_dir = os.path.join(data_dir, "preprocessed", danbooru_project_name)

# Metadata
output_metadata_dir = os.path.join(data_dir, "preprocessed", danbooru_project_name)
output_metadata_filename = "metadata.pkl"
output_metadata_path = os.path.join(output_metadata_dir, output_metadata_filename)
print(output_metadata_path)


# Make project name subfolder
if not os.path.exists(tfrecord_dir):
    print("Making folder: {}".format(tfrecord_dir))
    os.makedirs(tfrecord_dir)

..\data\preprocessed\project-small\metadata.pkl


In [2]:
%%time


def read_metadata(file_path, top_n=None):
    lines = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        line = file.readline()
        while line:
            line = json.loads(file.readline())
            lines[line['id']] = line
    return lines

            
def read_metadata_dict(file_path, top_n=None, filter_attributes=None, filter_tag=None):
    data = dict()
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            l = json.loads(line)
            if filter_attributes:
                l = {k: l.get(k, None) for k in filter_attributes}
                l['tags'] = [t['name'] for t in l['tags']]
                
            
            if filter_tag:
                if len([l for l in l['tags'] if l in filter_tag]):
                    data[l['id']] = l
            else:
                data[l['id']] = l
    return data

def read_metadata_list(file_path, top_n=None, filter_attributes=None, filter_tag=None):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            l = json.loads(line)
            if filter_attributes:
                l = {k: l.get(k, None) for k in filter_attributes}
                l['tags'] = [t['name'] for t in l['tags']]
                       
            if filter_tag:
                if len([l for l in l['tags'] if l in filter_tag]):
                    data.append(l)
            else:
                data.append(l)
    return data

def read_metadata_dir(file_dir, filter_attributes=None, filter_tag=None):
    # Walk through each
    json_path_list = [os.path.join(root, file) for root, subdirs, files in os.walk(file_dir) for file in files]
    json_dict = dict()
    for json_i, json_path in enumerate(json_path_list):
        print("Loading: {} of {}".format(str(json_i + 1).zfill(3), str(len(json_path_list)).zfill(3)))
        json_dict.update(read_metadata_dict(json_path, filter_attributes=filter_attributes, filter_tag=filter_tag))
        
    return json_dict
        
    

# data_meta_raw = read_metadata_list(danbooru_meta_path, filter_attributes=['id', 'file_ext', 'tags'], filter_tag=["transparent_background"])
# data_meta = read_metadata_list(danbooru_meta_path, filter_attributes=['id', 'file_ext', 'tags'], filter_tag=["transparent_background", ""])
# data_meta = read_metadata_dict(danbooru_meta_path, filter_attributes=['id', 'file_ext', 'tags'])

single_filter_tag = "dragon_ball"
# data_meta = read_metadata_dir(danbooru_meta_dir, filter_attributes=['id', 'file_ext', 'tags', 'is_rating_locked'], filter_tag=[single_filter_tag])


Wall time: 0 ns


In [3]:

# Read
with open(output_metadata_path, 'rb') as f:
    data_meta = pickle.load(f)

# Get tags from each image
tag_list = [tag for k, v in data_meta.items() for tag in v['tags']]

# Count each tags
tag_counts = collections.Counter(tag_list)

print("There are n = {} labels".format(len(tag_counts)))
tag_all_array = np.array(list(tag_counts))

tag_map = dict(zip(tag_all_array, list(range(len(tag_all_array)))))


There are n = 15096 labels


# Image List

In [4]:
image_dir = os.path.join(data_dir, "danbooru", "danbooru2019", "original")
print(image_dir)
print(os.path.join(image_dir, "*"))


image_list = []
accepted_file_ext = ['jpg', 'jpeg', 'png']
for file_ext in accepted_file_ext:
    image_list.extend(glob.glob(os.path.join(image_dir, "*", "*.{file_ext}".format(file_ext=file_ext))))
    
print(len(image_list))
pprint(image_list[:10])

..\data\danbooru\danbooru2019\original
..\data\danbooru\danbooru2019\original\*
8055
['..\\data\\danbooru\\danbooru2019\\original\\0000\\162000.jpg',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\2066000.jpg',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\2767000.png',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\2881000.jpg',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\2963000.jpg',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\3191000.jpg',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\3253000.png',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\3436000.png',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\3493000.png',
 '..\\data\\danbooru\\danbooru2019\\original\\0000\\3605000.png']


# TFRecord

In [8]:
import pathlib
import contextlib2

class DataUtil(object):
    def __init__(self):
        pass

    def int64_feature(self, value):
      return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


    def int64_list_feature(self, value):
      return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


    def bytes_feature(self, value):
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


    def bytes_list_feature(self, value):
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


    def float_list_feature(self, value):
      return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def open_sharded_output_tfrecords(exit_stack, base_path, num_shards, extension="tfrecord"):
    """Opens all TFRecord shards for writing and adds them to an exit stack.
    Args:
    exit_stack: A context2.ExitStack used to automatically closed the TFRecords
      opened in this function.
    base_path: The base path for all shards
    num_shards: The number of shards
    Returns:
    The list of opened TFRecords. Position k in the list corresponds to shard k.
    """
    tf_record_output_filenames = [
        '{}-{:05d}-of-{:05d}.{}'.format(base_path, idx, num_shards, extension)
        for idx in range(num_shards)
    ]

    tfrecords = [
        exit_stack.enter_context(tf.io.TFRecordWriter(file_name))
        for file_name in tf_record_output_filenames
    ]

    return tfrecords


# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(img_path, height=512, width=512, resize=True):
    
    filename = os.path.basename(img_path)
    img_id, img_ext = os.path.splitext(filename)
    
    parts = tf.strings.split(img_path, os.sep)

    image = tf.io.read_file(img_path)
    if img_ext in [".jpg", ".jpeg"]:
        image = tf.image.decode_jpeg(image, channels=3)
    if img_ext in [".png"]:
        image = tf.image.decode_png(image, channels=3, dtype=tf.dtypes.uint8)
#     image = tf.image.resize(image, [height, width], preserve_aspect_ratio=True)
    if resize:
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize_with_pad(image=image, target_height=height, target_width=width)
    image = tf.image.convert_image_dtype(image, tf.uint8)
    return image


def create_dataset(img, img_path, labels_text, label_id, file_ext):
    
    if tf.is_tensor(img):
        img = img.numpy()
        
    # Flatten
    img_flat = img.astype(dtype=np.uint8).flatten()
        
    labels_text = [tag.encode("utf8") for tag in labels_text]
    
    image_shape = img.shape

    feature_dict = {
        'image/height': dataset_util.int64_feature(image_shape[0]),
        'image/width': dataset_util.int64_feature(image_shape[1]),
        'image/depth': dataset_util.int64_feature(image_shape[2]),
        'image/object/class/text': dataset_util.bytes_list_feature(labels_text),
        'image/object/class/label': dataset_util.int64_list_feature(label_id),
        'image/encoded': dataset_util.int64_list_feature(img_flat.tolist()),
#         'image/encoded': dataset_util.bytes_list_feature(img_flat.encode('utf8')),
#         'image/encoded': dataset_util.bytes_feature(img.tobytes()),
        'image/filename': dataset_util.bytes_feature(img_path.encode('utf8')),
        'image/format': dataset_util.bytes_feature(file_ext.encode('utf8')),
    }

    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))

    return example


dataset_util = DataUtil()


In [9]:


# TFRecord
def image_to_tfrecord_shards(image_path_list, data_meta, tag_map, tfrecord_dir, tfrecord_prefix, n_shards=8):
    
    tfrecord_output_path = os.path.join(tfrecord_dir, tfrecord_prefix)
    data_info_json_path = os.path.join(tfrecord_dir, "data_info.json")

    img_counter = 0
    
    print_per_n = 10
    
    tag_list = []
    
    # Start shards
    with contextlib2.ExitStack() as tf_record_close_stack:
        print("image_path_list n = {}".format(len(image_path_list)))
        
        # List of shards
        output_tfrecords = open_sharded_output_tfrecords(
            tf_record_close_stack, tfrecord_output_path, n_shards)
        
        # Loop through image
        for img_i, img_path in enumerate(image_path_list):
            
            # img path
            img_filename = os.path.basename(img_path)
            img_id, img_ext = os.path.splitext(img_filename)

            # If not found in metadata continue
            if img_id not in data_meta.keys():
                print("img_id = {} not in data_met".format(img_id))
                continue
            
            if img_counter % print_per_n == 0:
                print("Working on image i = {}".format(img_counter))
            # Get information
            img_meta = data_meta[img_id]
            
            if img_meta is None:
                print("img_id = {} does not exist in img_meta".format(img_id))
                continue

            # Parse Image
            try:
                img = parse_image(img_path, resize=True)
            except:
                print("img_id = {} failed to load image {}".format(img_id, img_path))
                continue
                
                
            try:
                tags = img_meta['tags']
            except:
                print("img_id = {} tags does not exist in img_meta".format(img_id))
                print(img_meta)
                print(img_id)
                break
                
            try:
                tags = img_meta['tags']
            except:
                print("img_id = {} tags does not exist in img_meta".format(img_id))
                print(img_meta)
                break
                
            try:
                file_ext = img_meta['file_ext']
            except:
                print("img_id = {} file_ext does not exist in img_meta".format(img_id))
                print(img_meta)
                break
                
            try:
                file_ext = img_meta['file_ext']
            except:
                print("img_id = {} file_ext does not exist in img_meta".format(img_id))
                print(img_meta)
                break
                
            try:
                label_id = [tag_map[tag] for tag in tags]
            except:
                print("img_id = {} label_id error".format(img_id))
                print(img_meta)
                break
                
            

            # Tags
            tag_list.extend(tags)
                
            # Make an example
            example = create_dataset(
                img=img, 
                img_path=img_path, 
                labels_text=tags, 
                label_id=label_id, 
                file_ext=file_ext)

            # Write tfrecord to shard
            shard_idx = img_counter % tfrecord_n_shards
            output_tfrecords[shard_idx].write(example.SerializeToString())
            img_counter += 1
        print("Made tfrecord with n = {}".format(img_counter))

    # Dataset Info
    tag_counts = dict(collections.Counter(tag_list))
    n_classes = len(tag_counts)
    tag_map = dict(zip(list(tag_counts), list(range(n_classes))))
    data_info = {
        'tag_counts': tag_counts,
        'n_classes': n_classes,
        'tag_map': tag_map,
        'n_records': img_counter
    }
    with open(data_info_json_path, 'w') as outfile:
        json.dump(data_info, outfile)



# Output
dataset_name = "train"  # train, valid, test
tfrecord_dir_test = os.path.join(tfrecord_dir, dataset_name)
tfrecord_prefix_test = dataset_name
tfrecord_n_shards = 8

# Image List
image_path_list = image_list[512:]
print("n images = {}".format(len(image_path_list)))

image_to_tfrecord_shards(
    image_path_list=image_path_list, 
    data_meta=data_meta,
    tag_map=tag_map,
    tfrecord_dir=tfrecord_dir_test, 
    tfrecord_prefix=dataset_name, 
    n_shards=8)

n images = 7543
image_path_list n = 7543
Working on image i = 0
Working on image i = 10
Working on image i = 20
Working on image i = 30
img_id = 2652067 failed to load image ..\data\danbooru\danbooru2019\original\0067\2652067.zip
Working on image i = 30
Working on image i = 40
img_id = 1619069 failed to load image ..\data\danbooru\danbooru2019\original\0069\1619069.gif
Working on image i = 50
Working on image i = 60
Working on image i = 70
Working on image i = 80
Working on image i = 90
Working on image i = 100
Working on image i = 110
Working on image i = 120
img_id = 665081 failed to load image ..\data\danbooru\danbooru2019\original\0081\665081.gif
Working on image i = 130
Working on image i = 140
Working on image i = 150
Working on image i = 160
Working on image i = 170
Working on image i = 180
Working on image i = 190
Working on image i = 200
Working on image i = 210
Working on image i = 220
Working on image i = 230
Working on image i = 240
Working on image i = 250
Working on image