In [None]:
import os
import shutil
import tensorflow as tf
from zipfile import ZipFile, BadZipFile


class DataPipeline:
    """Turn a zipped, flat image dataset into batched ``tf.data`` datasets.

    The pipeline has three steps, run in order by :meth:`running_engine`:
    extract the ZIP archive, sort loose image files into per-class
    subfolders, and build the train / validation / test datasets.
    """

    def __init__(self, path_rd, extract_to, img_size=(200, 200), batch_size=32):
        """
        Initialize the data pipeline.

        Args:
            path_rd (str): Path to the compressed (ZIP) dataset file.
            extract_to (str): Destination directory for extracted dataset.
            img_size (tuple): Target image dimensions (height, width).
            batch_size (int): Size of training batches.
        """
        self.path_rd = path_rd
        self.extract_to = extract_to
        self.img_size = img_size
        self.batch_size = batch_size
        # Class folder names; a file is assigned to a class when its name
        # starts with one of these (compared case-insensitively).
        self.class_names = [
            "crazing",
            "inclusion",
            "patches",
            "pitted_surface",
            "rolled-in_scale",
            "scratches",
        ]

    def running_engine(self):
        """Execute the full pipeline and return prepared datasets.

        Returns:
            tuple: ``(train_data, val_data, test_data)`` as produced by
            :meth:`create_datasets`.
        """
        self.extract_data()
        self.organize_flat_structure()
        return self.create_datasets()

    def extract_data(self):
        """Step 1: Extract ZIP dataset if not already extracted.

        Extraction is skipped when ``extract_to`` already exists and
        contains at least one entry.

        Raises:
            RuntimeError: If the archive is not a valid ZIP file. The
                underlying ``BadZipFile`` is attached as ``__cause__``.
        """
        # A missing or empty destination means nothing was extracted yet.
        if os.path.exists(self.extract_to) and os.listdir(self.extract_to):
            print("Data already exists on disk.")
            return

        print(f"Extracting data from {self.path_rd}...")
        try:
            with ZipFile(self.path_rd, "r") as zip_file:
                zip_file.extractall(self.extract_to)
        except BadZipFile as err:
            # Chain explicitly so the original zipfile error stays visible.
            raise RuntimeError("Dataset ZIP is corrupted.") from err
        print("Extraction completed successfully.")

    def organize_flat_structure(self):
        """
        Step 2: Organize flat images into class-named subfolders.

        For each of the ``train``, ``valid`` and ``test`` directories under
        ``extract_to`` that exists, every regular file whose name starts
        with a class name (case-insensitive) is moved into
        ``<split>/<class_name>/``. Existing subdirectories are skipped and
        files matching no class are left where they are.
        """
        print("Organizing images into class-specific folders...")

        # Lower-case each class prefix once rather than once per file.
        prefixes = [(name.lower(), name) for name in self.class_names]

        for split in ("train", "valid", "test"):
            split_path = os.path.join(self.extract_to, split)

            # isdir (not exists): a stray regular file with a split's name
            # must be skipped instead of crashing os.listdir.
            if not os.path.isdir(split_path):
                continue

            for filename in os.listdir(split_path):
                file_path = os.path.join(split_path, filename)

                # --- Skip directories ---
                if os.path.isdir(file_path):
                    continue

                # --- Move file into correct class folder ---
                lowered = filename.lower()
                for prefix, class_name in prefixes:
                    if lowered.startswith(prefix):
                        target_dir = os.path.join(split_path, class_name)
                        os.makedirs(target_dir, exist_ok=True)
                        shutil.move(file_path, os.path.join(target_dir, filename))
                        break

    def create_datasets(self):
        """
        Step 3: Create tf.data.Dataset objects from organized directories.

        Training data is read from ``<extract_to>/train``. Validation and
        test data are both read from ``<extract_to>/valid``, which is split
        50/50 with a fixed seed; the ``test`` directory is not read here.

        Training batches are augmented (random flip and rotation); all
        three datasets are then rescaled by ``1/255`` and prefetched.

        Returns:
            tuple: ``(train_data, val_data, test_data)``.
        """
        train_path = os.path.join(self.extract_to, "train")
        val_path = os.path.join(self.extract_to, "valid")

        load = tf.keras.utils.image_dataset_from_directory
        common = dict(
            image_size=self.img_size,
            batch_size=self.batch_size,
            label_mode="categorical",
        )
        # Both halves of "valid" must use the same fraction and seed so the
        # "training" and "validation" subsets are drawn from the same split.
        halved = dict(common, validation_split=0.5, seed=123)

        # --- Load datasets ---
        train_data = load(train_path, **common)
        val_data = load(val_path, subset="training", **halved)
        test_data = load(val_path, subset="validation", **halved)

        # --- Preprocessing layers ---
        rescaler = tf.keras.layers.Rescaling(1.0 / 255)
        augmentation = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal_and_vertical"),
            tf.keras.layers.RandomRotation(0.2),
        ])

        def rescale(images, labels):
            return rescaler(images), labels

        # --- Apply augmentation to training data (before rescaling) ---
        train_data = train_data.map(
            lambda x, y: (augmentation(x, training=True), y),
            num_parallel_calls=tf.data.AUTOTUNE,
        )

        # --- Apply rescaling to all datasets, then prefetch ---
        return tuple(
            dataset.map(rescale, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE)
            for dataset in (train_data, val_data, test_data)
        )
        
if __name__ == "__main__":
    # Script entry point: build the pipeline for the bundled archive and
    # run all three steps, keeping the resulting datasets at module level.
    pipeline = DataPipeline(
        path_rd="data/dataset.zip",
        extract_to="data/steel_data",
    )
    train_data, val_data, test_data = pipeline.running_engine()