In [4]:
import os
import random
import shutil
import tensorflow as tf
from PIL import Image
import numpy as np

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def convert_to_tfrecords_split(input_dir, output_dir, train_split=0.8):
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.jpg') or file.endswith('.png'):
                img_path = os.path.join(root, file)
                label = os.path.basename(root)

                img = Image.open(img_path)
                img_array = np.array(img)
                shape = img_array.shape

                if len(shape) == 3 and shape[2] == 3:  # Check if the image is RGB
                    img = img.convert('RGB')  # Convert image to RGB
                    img_bytes = tf.io.encode_jpeg(img_array).numpy()

                    # Create a tf.Example containing JUST image data and label
                    example = tf.train.Example(features=tf.train.Features(feature={
                        'label': _bytes_feature(label.encode('utf-8')),
                        'image_raw': _bytes_feature(img_bytes),
                    }))

                    # Determine whether to save the example to the train or eval folder
                    split_folder = 'train' if random.random() < train_split else 'eval'
                    relative_path = os.path.relpath(root, input_dir)
                    output_subdir = os.path.join(output_dir, relative_path, split_folder)
                    os.makedirs(output_subdir, exist_ok=True)

                    # Save the tf.Example to a TFRecord file
                    tfrecord_file = os.path.join(output_subdir, file + '.tfrecord')
                    with tf.io.TFRecordWriter(tfrecord_file) as writer:
                        writer.write(example.SerializeToString())
                else:
                    # Display the image with its shape and skip it
                    print(f"Skipping image {img_path} with shape {shape}")
                    img.show()

# Convert images in the Felidae directory to TFRecords and split them
input_dir = 'C:/Users/Guillermo/Desktop/My_projects/Data/Felidae'
output_dir = 'felidae_tfrecords_split'
convert_to_tfrecords_split(input_dir, output_dir)


Skipping image C:/Users/Guillermo/Desktop/My_projects/Data/Felidae\Lion\Lion_037.jpg with shape (598, 479, 4)
Skipping image C:/Users/Guillermo/Desktop/My_projects/Data/Felidae\Puma\Puma_003.jpg with shape (200, 340)
