In [None]:
import tensorflow as tf
from sklearn.model_selection import train_test_split
import os
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import matplotlib.patches as patches

from yolo_loss import YoloLoss
from yolo_model import yolo_v1_model
from yolo_preprocessing import *

2025-11-09 17:22:19.411545: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762705339.431909    6903 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762705339.438225    6903 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1762705339.453231    6903 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762705339.453256    6903 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762705339.453258    6903 computation_placer.cc:177] computation placer alr

## Creating a model

In [None]:
# Input resolution (height, width) and the S/B/C hyperparameters shared by the model,
# the loss and the encode/decode helpers.
# NOTE(review): presumably S = grid size, B = boxes per cell, C = number of classes,
# following the YOLOv1 notation -- confirm against yolo_model.py.
img_size = (448, 448)
S = 7
B = 2
C = 20

# The model takes a 3-channel input, so append the channel axis to the spatial size.
model = yolo_v1_model(img_size=(*img_size, 3), S=S, B=B, C=C)

## Preprocessing

In [None]:
# Directory holding the VOC2007 trainval annotation XML files.
PATH_DATASET_TRAIN = "/VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/Annotations/"

# Sort the directory listing so the seeded split below sees the files in a stable order.
files = [os.path.join(PATH_DATASET_TRAIN, f) for f in sorted(os.listdir(PATH_DATASET_TRAIN)) if f.endswith(".xml")]
train_files, valid_files = train_test_split(files, test_size=0.2, random_state=42, shuffle=True)

# Removing classes head, hand, foot like in the original paper
class_names = ["person", "bird", "cat", "cow", "dog", "horse", "sheep", "aeroplane", "bicycle", "boat", "bus", "car", "motorbike", "train", "bottle", "chair", "diningtable", "pottedplant", "sofa", "tvmonitor", "head", "foot", "hand"]
remove_those_vals = ['head', 'hand', 'foot']

batch_size = 32
train_gen = DataGenerator(train_files, img_size, class_names, remove_those_vals, S=S, B=B, C=C, batch_size=batch_size, shuffle=True, augment=True)
# Pass batch_size explicitly so validation uses the configured value instead of
# whatever DataGenerator defaults to; augmentation stays off for validation.
valid_gen = DataGenerator(valid_files, img_size, class_names, remove_those_vals, S=S, B=B, C=C, batch_size=batch_size, shuffle=True, augment=False)

## Model training

In [None]:
# Checkpoint the whole model (not just the weights) at the end of every epoch.
# save_best_only is off, so val_loss is not used to decide whether to save.
checkPointCallback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./weights/weights.keras',
    save_freq="epoch",
    save_best_only=False,
    save_weights_only=False,
    monitor="val_loss",
    mode="auto",
    initial_value_threshold=None,
    verbose=0,
)

# Scale the learning rate by `factor` when val_loss has not improved by at least
# `min_delta` for `patience` epochs, never going below `min_lr`.
# NOTE(review): factor=0.001 cuts the learning rate by 1000x in one step; values
# around 0.1-0.5 are more common -- confirm this is intentional.
ReduceLROnPlateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    mode='auto',
    factor=0.001,
    min_lr=1e-7,
    patience=10,
    cooldown=0,
    min_delta=0.0001,
    verbose=0,
)

In [None]:
# Adam with an initial learning rate of 1e-3 (the ReduceLROnPlateau callback may lower it).
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Project-local loss, built with the same S/B/C as the model; lambda_coord and
# lambda_noobj are passed through to YoloLoss.
loss_fn = YoloLoss(
    S=S,
    B=B,
    C=C,
    lambda_coord=5.0,
    lambda_noobj=0.5,
)

model.compile(optimizer=optimizer, loss=loss_fn)

In [None]:
epochs = 150

# train_gen / valid_gen are constructed with their own batch_size, so no
# `batch_size` is passed here: Keras documents that it must not be specified
# when the input is a generator / PyDataset that already yields batches.
# NOTE(review): assumes DataGenerator yields ready-made batches -- confirm in
# yolo_preprocessing.py.
history = model.fit(train_gen,
                    validation_data=valid_gen,
                    epochs=epochs,
                    shuffle=True,
                    callbacks=[ReduceLROnPlateau, checkPointCallback])

## Inference

In [None]:
PATH_DATASET_TEST = "/VOCtest_06-Nov-2007/VOCdevkit/VOC2007/Annotations/"
n_images_to_plot = 10

# Load a small test subset and shuffle images, boxes and labels together.
test_subset = load_dataset(
    PATH_DATASET_TEST,
    remove_those_vals,
    image_target_size=img_size,
    n_img=n_images_to_plot,
)
X_test, y_boxes_test, y_labels_test = shuffle(*test_subset)

# Model inputs: VGG16-style preprocessing applied to a fresh array copy, so the
# images in X_test stay untouched for plotting.
X_test_preprocess = preprocess_input(np.array(X_test))

# Targets: one-hot class labels plus boxes converted by the project helper.
y_test_labels_onehot = to_one_hot(y_labels_test, class_names)
y_boxes_test_normalized = xywh_pixels_to_normalized_centers_per_image(y_boxes_test, img_size)

# Force every per-image box list into an (n_boxes, 4) array.
for img_idx, img_boxes in enumerate(y_boxes_test_normalized):
    y_boxes_test_normalized[img_idx] = np.array(img_boxes).reshape(-1, 4)

# For validate encode & decode functions: round-trip the ground truth through both.
y_test_encoded = encode_batch(y_boxes_test_normalized, y_test_labels_onehot, S=S, B=B, C=C)
y_test_decoded = decode_batch(y_test_encoded, S=S, B=B, img_size=img_size)

In [None]:
# Loss on the preprocessed test subset; the returned value is shown as the cell output.
model.evaluate(
    x=X_test_preprocess,
    y=y_test_encoded,
)

In [None]:
# Predict on the same preprocessed inputs that model.evaluate received; the raw
# X_test images are only meant for plotting.
y_pred = model(X_test_preprocess[:n_images_to_plot], training=False).numpy()

# Decode the raw network output into per-image detections, passing a confidence
# threshold of 0.5 to decode_batch.
y_pred_decoded = decode_batch(y_pred, S=S, B=B, img_size=img_size, conf_thresh=0.5)
print("Detections per image:", [len(r) for r in y_pred_decoded][:n_images_to_plot])

Non-Maximum Suppression (NMS) can be implemented here.

In [None]:
def _draw_boxes(axis, boxes, label):
    """Draw every box in ``boxes`` on ``axis`` with a "<class> (<conf>)" caption."""
    for raw_box in boxes:
        box = np.array(raw_box)
        x, y, w, h = box[0], box[1], box[2], box[3]
        conf, class_idx = box[4], box[5]

        # Rectangle is anchored at (x, y) and extends by (w, h).
        rect = patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='blue', facecolor='none', label=label)
        axis.add_patch(rect)

        label_text = class_names[int(class_idx)]
        axis.text(x, y - 5, f"{label_text} ({conf:.2f})", color='black', fontsize=9, backgroundcolor='white')


def plot_img_and_box(X, y_boxes, y_pred_boxes=False, n_images=1, figsize=(10, 10)):
    """Plot images in two columns: ground-truth boxes (left) and predicted boxes (right).

    Args:
        X: Indexable collection of images accepted by ``Axes.imshow``.
        y_boxes: Per-image sequences of ground-truth boxes. Each box is indexed
            as ``[x, y, w, h, conf, class_index]``; ``class_index`` is looked up
            in the module-level ``class_names``.
        y_pred_boxes: Per-image sequences of predicted boxes in the same layout.
            ``False`` (the default) or ``None`` leaves the right column with the
            image only.
        n_images: Number of images (figure rows) to plot, taken from the start
            of ``X``.
        figsize: ``(width, height)`` of a single row; the figure height is
            ``height * n_images``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # squeeze=False keeps `ax` two-dimensional even when n_images == 1.
    fig, ax = plt.subplots(n_images, 2, figsize=(figsize[0], figsize[1] * n_images), squeeze=False)

    # The default sentinel is False; previously it was indexed and raised a TypeError.
    has_predictions = y_pred_boxes is not None and y_pred_boxes is not False

    for i in range(n_images):
        ax[i, 0].imshow(X[i])
        ax[i, 1].imshow(X[i])

        # True boxes and labels
        _draw_boxes(ax[i, 0], y_boxes[i], label='true')

        # Pred boxes and labels
        if has_predictions:
            _draw_boxes(ax[i, 1], y_pred_boxes[i], label='pred')

        # Hide the axes for every image, including those without any boxes.
        ax[i, 0].axis('off')
        ax[i, 1].axis('off')

    plt.tight_layout(pad=0.1)
    plt.subplots_adjust(wspace=0.02, hspace=0.02, top=0.59, bottom=0.01)
    plt.show()

# Left column: decoded ground truth; right column: decoded model predictions.
plot_img_and_box(
    X_test,
    y_test_decoded,
    y_pred_decoded,
    n_images=n_images_to_plot,
    figsize=(6, 6),
)