In [1]:
%%capture
%pip install tensorflow numpy matplotlib pandas scikit-learn seaborn

In [2]:
import tensorflow as tf
import pathlib
from typing import Tuple

2024-05-31 12:06:06.280889: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


**Create a dataset class that streams labelled images from a directory**

In [None]:
# Define Dataset Class
import tensorflow as tf
import pathlib
from typing import Tuple

class MatrixDataset:
    """
    Builds a batched ``tf.data`` pipeline of labelled PNG images.

    Expected layout: ``data_dir/<label>/<image>.png``, where ``<label>`` is the
    image's integer label written in hexadecimal (e.g. ``data_dir/1f/a.png`` -> 31).
    """

    def __init__(self, data_dir: str, batch_size: int, img_size: Tuple[int, int] = (64, 64)):
        """
        Initializes the MatrixDataset.

        Args:
            data_dir (str): Directory containing one subdirectory per label, each
                holding that label's PNG images.
            batch_size (int): Batch size for the dataset.
            img_size (Tuple[int, int]): Desired image size (height, width).
        """
        self.data_dir = pathlib.Path(data_dir)
        self.img_size = img_size
        self.batch_size = batch_size

    def _load_and_preprocess_image(self, file_path) -> tf.Tensor:
        """
        Loads a PNG image, resizes it to ``self.img_size`` and scales it to [0, 1].

        Accepts a ``str``/``pathlib.Path`` (eager use) or a scalar string tensor
        (inside ``Dataset.map``). Inside ``Dataset.map`` the ``try`` only catches
        errors raised while the function is traced; a file that fails to read or
        decode at iteration time surfaces as a TensorFlow error from the iterator.

        Args:
            file_path (str | pathlib.Path | tf.Tensor): Path to the PNG image.

        Returns:
            tf.Tensor: float32 tensor of shape (height, width, 4) -- RGBA, because
            the PNG is decoded with ``channels=4``.

        Raises:
            RuntimeError: If building the load/decode ops fails.
        """
        if isinstance(file_path, pathlib.PurePath):
            file_path = str(file_path)  # tf.io.read_file needs a string, not a Path
        try:
            img = tf.io.read_file(file_path)
            img = tf.image.decode_png(img, channels=4)
            img = tf.image.resize(img, self.img_size)
            img = tf.cast(img, tf.float32) / 255.0  # Normalize to [0, 1]
            return img
        except Exception as e:
            raise RuntimeError(f"Error loading and preprocessing image {file_path}: {e}") from e

    def _get_label(self, file_path) -> str:
        """
        Extracts the label from the file path. Assumes that images are stored in
        subdirectories named after their labels.

        This is plain Python string/path handling, so it must be given a real path
        (``str`` or ``pathlib.Path``), not a symbolic tensor from ``Dataset.map``.

        Args:
            file_path (str | pathlib.Path): Path to the PNG image.

        Returns:
            str: Label of the image in hexadecimal format (the parent directory name).

        Raises:
            RuntimeError: If ``file_path`` cannot be interpreted as a path.
        """
        try:
            return pathlib.Path(file_path).parent.name
        except Exception as e:
            raise RuntimeError(f"Error extracting label from {file_path}: {e}") from e

    def _hex_to_int(self, hex_str: str) -> int:
        """
        Converts a hexadecimal string to an integer.

        Args:
            hex_str (str): Hexadecimal string.

        Returns:
            int: Integer representation of the hexadecimal string.

        Raises:
            RuntimeError: If ``hex_str`` is not valid hexadecimal.
        """
        try:
            return int(hex_str, 16)
        except ValueError as e:
            raise RuntimeError(f"Error converting hex string to int: {hex_str}: {e}") from e

    def _load_and_preprocess_from_path_label(self, file_path) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Eagerly loads and preprocesses one image together with its label.

        Intended for direct (eager) use on a single file, e.g. for inspection.
        ``create_dataset`` does not use it: the label logic is Python string
        handling and cannot run on the symbolic tensors seen inside ``Dataset.map``.

        Args:
            file_path (str | pathlib.Path): Path to the image file.

        Returns:
            Tuple[tf.Tensor, tf.Tensor]: Preprocessed image and int32 label tensor.
        """
        label = self._get_label(file_path)
        label = self._hex_to_int(label)
        label = tf.convert_to_tensor(label, dtype=tf.int32)
        image = self._load_and_preprocess_image(file_path)
        return image, label

    def _collect_files_and_labels(self) -> Tuple[list, list]:
        """
        Lists every ``data_dir/*/*.png`` file and resolves its integer label in Python.

        Returns:
            Tuple[list, list]: Sorted file paths (str) and their int labels, aligned.

        Raises:
            RuntimeError: If no file matches, a directory name is not hex, or a
                label does not fit in int32 (the dtype of the label tensor).
        """
        pattern = str(self.data_dir / "*" / "*.png")
        file_paths = sorted(tf.io.gfile.glob(pattern))
        if not file_paths:
            raise RuntimeError(f"No PNG files found matching {pattern}")
        labels = []
        for path in file_paths:
            label = self._hex_to_int(self._get_label(path))
            if not tf.int32.min <= label <= tf.int32.max:
                raise RuntimeError(f"Label {label} of {path} does not fit in int32")
            labels.append(label)
        return file_paths, labels

    def _load_image_with_label(self, file_path: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Graph-compatible map function: decodes the image and passes the label through."""
        return self._load_and_preprocess_image(file_path), label

    def create_dataset(self) -> tf.data.Dataset:
        """
        Creates a tf.data.Dataset from the PNG images in the data directory.

        Labels are resolved eagerly (and validated) before the pipeline is built, so
        only image decoding happens inside ``Dataset.map``.

        Returns:
            tf.data.Dataset: Batches of (float32 images of shape
            (batch, height, width, 4), int32 labels of shape (batch,)).

        Raises:
            RuntimeError: If the files cannot be listed/labelled or the pipeline
                cannot be built.
        """
        try:
            file_paths, labels = self._collect_files_and_labels()
            paths_ds = tf.data.Dataset.from_tensor_slices(
                (tf.constant(file_paths, dtype=tf.string), tf.constant(labels, dtype=tf.int32))
            )
            # Full shuffle of the (cheap) file list each epoch, matching list_files(shuffle=True)
            paths_ds = paths_ds.shuffle(buffer_size=len(file_paths), reshuffle_each_iteration=True)
            # Decode images in parallel; AUTOTUNE picks the parallelism from available CPU
            labeled_ds = paths_ds.map(self._load_image_with_label, num_parallel_calls=tf.data.AUTOTUNE)
            # Shuffle, batch, and prefetch so loading overlaps with training
            dataset = labeled_ds.shuffle(buffer_size=100).batch(self.batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
            return dataset
        except Exception as e:
            raise RuntimeError(f"Error creating dataset: {e}") from e

# Usage example:
# dataset = MatrixDataset(data_dir='path/to/pngs', img_size=(64, 64), batch_size=32)
# train_dataset = dataset.create_dataset()

*Load Datasets*