In [43]:
import tensorflow as tf
import numpy as np
import time
import os
import glob
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# ---- CONSTANTS ----
IMG_HEIGHT = 128
IMG_WIDTH = 128
MAX_BOXES = 5  # Number of bounding boxes per image
NUM_CLASSES = 2  # Assuming binary classification (0 or 1)
EPOCHS = 10
BATCH_SIZE = 8
STEPS_PER_EPOCH = 100
VAL_STEPS = 20

# ---- DATA LOADING ----
train_images_dir = "path/to/train/images"
train_labels_dir = "path/to/train/labels"

val_images_dir = "path/to/val/images"
val_labels_dir = "path/to/val/labels"

train_image_paths = sorted(glob.glob(os.path.join(train_images_dir, "*.jpg")))
val_image_paths = sorted(glob.glob(os.path.join(val_images_dir, "*.jpg")))

train_label_paths = sorted(glob.glob(os.path.join(train_labels_dir, "*.txt")))
val_label_paths = sorted(glob.glob(os.path.join(val_labels_dir, "*.txt")))

print(f"Images found for training: {len(train_image_paths)}")
print(f"Images found for validation: {len(val_image_paths)}")
print(f"Label files found: {len(train_label_paths)}")

def load_labels(label_path):
    labels = np.zeros((MAX_BOXES, 5))  # Format: [class, x, y, w, h]
    try:
        with open(label_path, "r") as f:
            lines = f.readlines()
            for i, line in enumerate(lines[:MAX_BOXES]):
                labels[i] = np.array([float(x) for x in line.strip().split()])
    except:
        pass  # If file missing, return zeros
    return labels

def data_generator(image_paths, label_paths, batch_size):
    while True:
        for i in range(0, len(image_paths), batch_size):
            batch_images = []
            batch_labels = []
            for j in range(batch_size):
                if i + j >= len(image_paths):
                    break
                img = load_img(image_paths[i + j], target_size=(IMG_HEIGHT, IMG_WIDTH))
                img = img_to_array(img) / 255.0  # Normalize to [0,1]
                label = load_labels(label_paths[i + j])
                
                batch_images.append(img)
                batch_labels.append(label)
            
            yield np.array(batch_images), np.array(batch_labels)

train_generator = data_generator(train_image_paths, train_label_paths, BATCH_SIZE)
val_generator = data_generator(val_image_paths, val_label_paths, BATCH_SIZE)

# ---- CUSTOM LOSS FUNCTION ----
def custom_loss(y_true, y_pred):
    tf.print("DEBUG: y_true shape:", tf.shape(y_true))
    tf.print("DEBUG: y_pred shape:", tf.shape(y_pred))

    y_true = tf.reshape(y_true, [-1, MAX_BOXES, 5])
    y_pred = tf.reshape(y_pred, [-1, MAX_BOXES, 5])

    # Extracting class labels and bounding box predictions
    class_labels = tf.cast(y_true[..., 0], tf.int32)  # Class labels (int)
    logits = y_pred[..., 0]  # Predicted class scores

    # Ensure class_labels and logits have the same shape
    class_labels = tf.reshape(class_labels, [-1])
    logits = tf.reshape(logits, [-1, NUM_CLASSES])  # Make logits match

    bbox_loss = tf.reduce_mean(tf.keras.losses.MAE(y_true[..., 1:], y_pred[..., 1:]))
    class_loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(class_labels, logits, from_logits=True))

    return bbox_loss + class_loss

# ---- MODEL DEFINITION ----
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(MAX_BOXES * 5, activation="linear"),
    tf.keras.layers.Reshape((MAX_BOXES, 5))
])

# ---- COMPILE MODEL ----
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=custom_loss,
    metrics=["accuracy"]
)

# ---- TRAINING ----
start_time = time.time()

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VAL_STEPS
)

end_time = time.time()
print(f"Total training time: {end_time - start_time:.2f} seconds")


Images found for training: 0
Images found for validation: 0
Label files found: 0


KeyboardInterrupt: 