In [2]:
import sys
import os
import tensorflow as tf

# Add the parent of the current working directory (normally the notebook's folder) to sys.path so `config` can be imported
sys.path.append(os.path.abspath(".."))

import config

In [3]:
class NestedTFRecordsConverter:
    """Convert a 'fake' / 'real' image directory tree into two TFRecord files.

    Images found under ``<base_dir>/fake`` are written with label 0 and images
    under ``<base_dir>/real`` with label 1. Each record stores the resized
    pixels (divided by 255) as a flat float list, plus the height, width and
    channel count needed to reshape that list back into an image.
    """

    # Lower-case file extensions that _list_images treats as images
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, base_dir, output_dir):
        """
        Initialize converter for nested directory structure

        Args:
            base_dir (str): Base directory containing 'fake' and 'real' subdirectories
            output_dir (str): Directory to save TFRecords files (created if missing)
        """
        self.base_dir = base_dir
        self.output_dir = output_dir

        # Paths to fake and real image directories
        self.fake_dir = os.path.join(base_dir, 'fake')
        self.real_dir = os.path.join(base_dir, 'real')

        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

    def _list_images(self, directory):
        """
        Recursively list all image files in a directory

        Args:
            directory (str): Root directory to search for images

        Returns:
            list: Full paths to image files, sorted so that the result (and any
                slice taken from it) is the same on every run. A directory that
                does not exist yields an empty list.
        """
        image_paths = []
        for root, _, files in os.walk(directory):
            for file in files:
                # Extension match is case-insensitive
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    image_paths.append(os.path.join(root, file))
        # os.walk order depends on the filesystem; sort for reproducibility
        return sorted(image_paths)

    def _parse_image(self, filename):
        """
        Read and preprocess image

        Args:
            filename (str): Path to image file

        Returns:
            tf.Tensor: float32 image of shape (*config.IMG_SIZE, 3), pixel
                values divided by 255
        """
        # Read image
        image = tf.io.read_file(filename)

        # expand_animations=False keeps GIFs 3-D (first frame only); the default
        # would return a 4-D [frames, H, W, C] tensor and corrupt the stored
        # height/width/channels metadata.
        image = tf.image.decode_image(
            image, channels=3, expand_animations=False)

        # Resize to consistent size
        image = tf.image.resize(image, config.IMG_SIZE)

        # Normalize pixel values
        image = tf.cast(image, tf.float32) / 255.0

        return image

    def _create_tfrecord(self, image_paths, label, output_filename):
        """
        Create TFRecord file from image paths

        Images that fail to load or serialize are reported and skipped so one
        bad file does not abort the whole conversion.

        Args:
            image_paths (list): List of image file paths
            label (int): Label for these images (0 for fake, 1 for real)
            output_filename (str): Path to save TFRecord file

        Returns:
            int: Number of records actually written
        """
        written = 0
        skipped = 0
        with tf.io.TFRecordWriter(output_filename) as writer:
            for path in image_paths:
                try:
                    # Parse and preprocess image
                    image = self._parse_image(path)

                    # Serialize image features
                    feature = {
                        'image': tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy().flatten())),
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                        'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[0]])),
                        'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[1]])),
                        'channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[2]]))
                    }

                    # Create example and write to TFRecord
                    example = tf.train.Example(
                        features=tf.train.Features(feature=feature))
                    writer.write(example.SerializeToString())
                    written += 1

                except Exception as e:
                    # Deliberate best-effort: report the file and keep going
                    skipped += 1
                    print(f"Error processing {path}: {e}")

        if skipped:
            print(f"Skipped {skipped} of {len(image_paths)} images "
                  f"while writing {output_filename}")
        return written

    def convert_to_tfrecords(self, max_images_per_class=None):
        """
        Convert images to TFRecords

        Writes 'fake_images.tfrecord' (label 0) and 'real_images.tfrecord'
        (label 1) into the output directory.

        Args:
            max_images_per_class (int, optional): Limit number of images per
                class. None or 0 means no limit. The limit is applied to the
                sorted path list, so the same files are selected on every run.
        """
        # List images in fake and real directories
        fake_images = self._list_images(self.fake_dir)
        real_images = self._list_images(self.real_dir)

        # Optionally limit number of images
        if max_images_per_class:
            fake_images = fake_images[:max_images_per_class]
            real_images = real_images[:max_images_per_class]

        # Print image counts
        print(f"Fake images found: {len(fake_images)}")
        print(f"Real images found: {len(real_images)}")

        # Create TFRecords for fake images (label 0)
        fake_tfrecord_path = os.path.join(
            self.output_dir, 'fake_images.tfrecord')
        self._create_tfrecord(fake_images, 0, fake_tfrecord_path)

        # Create TFRecords for real images (label 1)
        real_tfrecord_path = os.path.join(
            self.output_dir, 'real_images.tfrecord')
        self._create_tfrecord(real_images, 1, real_tfrecord_path)

    def verify_tfrecords(self, tfrecord_path):
        """
        Verify TFRecords file contents

        Parses every record, reshapes the flat pixel list using the height,
        width and channels stored in the same record, prints label and shape
        for the first five records, then prints the total record count.

        Args:
            tfrecord_path (str): Path to TFRecord file
        """
        dataset = tf.data.TFRecordDataset(tfrecord_path)

        feature_description = {
            # Variable-length float vector: the real length is whatever the
            # writer stored (height * width * channels), so no image size is
            # hard-coded here.
            'image': tf.io.FixedLenSequenceFeature(
                [], tf.float32, allow_missing=True),
            'label': tf.io.FixedLenFeature([], tf.int64),
            'height': tf.io.FixedLenFeature([], tf.int64),
            'width': tf.io.FixedLenFeature([], tf.int64),
            'channels': tf.io.FixedLenFeature([], tf.int64)
        }

        # Count images and print some details
        count = 0
        for record in dataset:
            parsed_record = tf.io.parse_single_example(
                record, feature_description)
            # Raises if the pixel count disagrees with the stored dimensions
            image = tf.reshape(parsed_record['image'],
                               [parsed_record['height'],
                                parsed_record['width'],
                                parsed_record['channels']])
            count += 1

            # Print first few images details
            if count <= 5:
                print(f"Image {count}:")
                print(f"  Label: {parsed_record['label'].numpy()}")
                print(f"  Shape: {image.shape}")

        print(f"Total images in TFRecord: {count}")


# Build the converter from the input/output locations defined in config
converter = NestedTFRecordsConverter(
    config.DATA_DIR,
    config.OUTPUT_TFRECORD_DIR,
)

# Write the fake and real TFRecord files into the output directory
converter.convert_to_tfrecords()


Fake images found: 10000
Real images found: 10000
