In [None]:
import os
import glob
import random
import math
from pathlib import Path
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, backend as K
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

print('TensorFlow:', tf.__version__)

# Reproducibility: seed every RNG this notebook touches.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Dataset location. Set DATA_ROOT here (or export BR35H_DATA_ROOT) to a directory
# laid out either as <root>/<class>/* (split inside the notebook) or as
# <root>/{train,val,test}/<class>/* (existing split used as-is).
DATA_ROOT = os.environ.get('BR35H_DATA_ROOT', '')
CLASS_NAMES = None  # None -> class names inferred from sub-directory names, sorted

# Image/Training Params
IMG_SIZE = (150, 150)
IMG_SHAPE = IMG_SIZE + (3,)
NUM_CLASSES = 2
BATCH_SIZE = 32  # images per batch for the train/val/test tf.data pipelines

# Data Loading: Br35H file paths, train/val/test splits, and tf.data pipelines

In [None]:

def _list_images_in_class_dir(class_dir):
    """Return the image files directly inside ``class_dir``, sorted by path.

    Extensions are matched case-insensitively (``a.JPG`` is picked up like
    ``a.jpg``). The result is sorted so the file order, and therefore the seeded
    train/val/test split built from it, does not depend on the filesystem's
    enumeration order. Hidden files (leading ``.``) are skipped, as the previous
    ``glob``-based version did, and the directory is not recursed into.

    Args:
        class_dir: directory to scan (str or Path).

    Returns:
        list[str]: ``os.path.join(class_dir, name)`` for each matching file;
        an empty list if ``class_dir`` is not a directory.
    """
    exts = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'}
    if not os.path.isdir(class_dir):
        return []
    files = []
    for name in os.listdir(class_dir):
        if name.startswith('.'):
            continue
        if os.path.splitext(name)[1].lower() not in exts:
            continue
        path = os.path.join(class_dir, name)
        if os.path.isfile(path):  # skip directories that merely look like images
            files.append(path)
    return sorted(files)

def _scan_classes(root, class_names=None):
    """Return the ordered class names for ``root``; a class's label is its index.

    If ``class_names`` is given, a new list with the same order is returned.
    Otherwise every visible sub-directory of ``root`` is treated as a class and the
    names are sorted alphabetically. Hidden directories (leading ``.``, e.g.
    ``.ipynb_checkpoints`` or ``.git``) are skipped so they cannot become spurious
    classes and push labels past the one-hot depth.

    Args:
        root: directory whose sub-directories are the classes (str or Path).
        class_names: optional explicit class order.

    Returns:
        list[str]: class names in label order.
    """
    if class_names is not None:
        return list(class_names)
    return sorted(
        entry.name
        for entry in Path(root).iterdir()
        if entry.is_dir() and not entry.name.startswith('.')
    )

def _load_paths_labels(root, class_names):
    """Gather every image path under ``root/<class>`` with its integer label.

    A file's label is the position of its class within ``class_names``; classes
    are visited in that order, so the output is grouped by class.

    Args:
        root: dataset directory containing one sub-directory per class.
        class_names: ordered class names.

    Returns:
        (paths, labels): two parallel numpy arrays.
    """
    paths, labels = [], []
    for label, class_name in enumerate(class_names):
        class_files = _list_images_in_class_dir(os.path.join(root, class_name))
        paths += class_files
        labels += [label] * len(class_files)
    return np.array(paths), np.array(labels)

def _read_image(path, target_size):
    """Load one image from disk as a float32 array resized to ``target_size``.

    ``make_tf_dataset`` wraps this function in ``tf.numpy_function`` declared to
    return ``tf.float32``. The dtype is therefore pinned explicitly instead of
    relying on the global Keras ``floatx`` setting, which ``img_to_array`` would
    otherwise follow.

    Args:
        path: image file path (str).
        target_size: (height, width) passed to ``load_img``.

    Returns:
        numpy float32 array of raw pixel values (not rescaled; the caller divides
        by 255).
    """
    img = tf.keras.utils.load_img(path, target_size=target_size)
    return tf.keras.utils.img_to_array(img, dtype='float32')

def load_br35h(DATA_ROOT, class_names=None, img_size=(150,150), val_split=0.2, test_split=0.1):
    """Collect image paths and integer labels split into train/val/test.

    Two directory layouts are supported:

    * ``DATA_ROOT/{train,val,test}/<class>/*``: the existing split is used as-is
      (``val_split``/``test_split`` are ignored); class order comes from ``train``.
    * ``DATA_ROOT/<class>/*``: one pool, split here with two stratified, seeded
      (``random_state=42``) ``train_test_split`` calls. The validation and test
      counts are ``round(val_split * N)`` and ``round(test_split * N)``.

    Args:
        DATA_ROOT: dataset root directory (str or Path); must not be empty.
        class_names: explicit class order (label = index); ``None`` infers it from
            the sorted sub-directory names.
        img_size: unused here (images are decoded lazily by ``make_tf_dataset``);
            kept for call compatibility.
        val_split: fraction of the pooled images held out for validation.
        test_split: fraction of the pooled images held out for testing.

    Returns:
        ``(X_train, y_train), (X_val, y_val), (X_test, y_test), classes``, where
        each ``X_*`` is a numpy array of file paths and each ``y_*`` the matching
        integer labels.

    Raises:
        FileNotFoundError: ``DATA_ROOT`` is empty or does not exist.
        ValueError: invalid split fractions, or no images were found.
    """
    # Path('') is the current directory and always "exists", so an unset root
    # would otherwise silently scan the working directory.
    if not str(DATA_ROOT).strip():
        raise FileNotFoundError("DATA_ROOT is empty; set it to the Br35H dataset directory")
    data_root = Path(DATA_ROOT)
    if not data_root.exists():
        raise FileNotFoundError(f"DATA_ROOT not found: {DATA_ROOT}")

    split_dirs = ['train', 'val', 'test']
    if all((data_root / d).exists() for d in split_dirs):
        tr_classes = _scan_classes(data_root / 'train', class_names)
        X_train, y_train = _load_paths_labels(data_root / 'train', tr_classes)
        X_val, y_val = _load_paths_labels(data_root / 'val', tr_classes)
        X_test, y_test = _load_paths_labels(data_root / 'test', tr_classes)
        return (X_train, y_train), (X_val, y_val), (X_test, y_test), tr_classes

    if not (0 < val_split < 1 and 0 < test_split < 1 and val_split + test_split < 1):
        raise ValueError(
            f"val_split and test_split must be in (0, 1) with a sum below 1; "
            f"got val_split={val_split}, test_split={test_split}"
        )

    classes = _scan_classes(DATA_ROOT, class_names)
    X, y = _load_paths_labels(DATA_ROOT, classes)
    n_total = len(X)
    if n_total == 0:
        raise ValueError(f"No images found under {DATA_ROOT} for classes {classes}")

    # Integer counts rather than float fractions: sklearn applies ceil() to float
    # test sizes, and e.g. (0.2 + 0.1) * 3000 == 900.0000000000001 -> 901.
    n_val = max(1, round(val_split * n_total))
    n_test = max(1, round(test_split * n_total))
    if n_val + n_test >= n_total:
        raise ValueError(f"Too few images ({n_total}) for the requested val/test splits")

    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=n_val + n_test, stratify=y, random_state=42, shuffle=True)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=n_test, stratify=y_temp, random_state=42, shuffle=True)
    return (X_train, y_train), (X_val, y_val), (X_test, y_test), classes

def make_tf_dataset(X_paths, y, batch_size=32, img_size=(150,150), shuffle=True, augment=False, num_classes=None):
    """Build a batched, prefetched ``tf.data`` pipeline of (image, one-hot label).

    Images are decoded lazily from ``X_paths`` with ``_read_image``, resized to
    ``img_size`` and scaled to ``[0, 1]``.

    Args:
        X_paths: numpy array of image file paths.
        y: numpy array of integer class labels aligned with ``X_paths``.
        batch_size: images per batch (the last batch may be smaller).
        img_size: (height, width) passed to ``_read_image``; list or tuple.
        shuffle: reshuffle the (path, label) pairs every epoch before decoding.
        augment: apply random horizontal and vertical flips.
        num_classes: one-hot depth; ``None`` uses the global ``NUM_CLASSES``.

    Returns:
        tf.data.Dataset yielding ``(images, labels)`` batches, with images of
        static shape ``(batch, height, width, 3)``.
    """
    AUTOTUNE = tf.data.AUTOTUNE
    img_size = tuple(img_size)  # `img_size + (3,)` below requires a tuple
    depth = NUM_CLASSES if num_classes is None else num_classes

    def _load_and_preprocess(path, label):
        # load_img is not a TF op, so it runs through tf.numpy_function: the path
        # arrives as bytes and the static shape must be restored afterwards.
        img = tf.numpy_function(lambda p: _read_image(p.decode('utf-8'), img_size), [path], tf.float32)
        img.set_shape(img_size + (3,))
        img = img / 255.0
        if augment:
            img = tf.image.random_flip_left_right(img)
            img = tf.image.random_flip_up_down(img)
        return img, tf.one_hot(label, depth=depth)

    ds = tf.data.Dataset.from_tensor_slices((X_paths.astype('U'), y.astype(np.int32)))
    if shuffle:
        # Elements are still (path, label) pairs at this point, so a buffer that
        # spans the whole split is cheap and gives a uniform shuffle, even when the
        # input is grouped by class (as in the pre-split train/val/test layout).
        # max(1, ...) keeps the buffer size valid for an empty split.
        ds = ds.shuffle(buffer_size=max(1, len(X_paths)), reshuffle_each_iteration=True)
    ds = ds.map(_load_and_preprocess, num_parallel_calls=AUTOTUNE)
    return ds.batch(batch_size).prefetch(AUTOTUNE)

(X_train, y_train), (X_val, y_val), (X_test, y_test), CLASS_NAMES_INFER = load_br35h(
    DATA_ROOT, class_names=CLASS_NAMES, img_size=IMG_SIZE)
print('Classes:', CLASS_NAMES_INFER)
print('Train/Val/Test sizes:', len(X_train), len(X_val), len(X_test))

# The pipelines one-hot encode with depth=NUM_CLASSES. A label >= NUM_CLASSES
# would become an all-zero target rather than an error, so fail fast here.
assert len(CLASS_NAMES_INFER) == NUM_CLASSES, (
    f"found {len(CLASS_NAMES_INFER)} classes {CLASS_NAMES_INFER}, expected NUM_CLASSES={NUM_CLASSES}"
)
assert len(X_train) > 0 and len(X_val) > 0 and len(X_test) > 0, 'a train/val/test split is empty'

train_ds = make_tf_dataset(X_train, y_train, BATCH_SIZE, IMG_SIZE, shuffle=True, augment=True)
val_ds   = make_tf_dataset(X_val,   y_val,   BATCH_SIZE, IMG_SIZE, shuffle=False)
test_ds  = make_tf_dataset(X_test,  y_test,  BATCH_SIZE, IMG_SIZE, shuffle=False)
