# In [None]:
import os
import cv2
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, applications, callbacks
from tensorflow.keras.regularizers import l2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.saving import load_model
import time # Added for time measurement

# --- DATA PATHS ---
# Both inputs live in the same KaggleHub download (B3 subset, side view), so the
# shared directory is spelled out once and the two paths are derived from it.
_DATASET_ROOT = r"C:\Users\andrey\.cache\kagglehub\datasets\sadhliroomyprime\cattle-weight-detection-model-dataset-12k\versions\3\www.acmeai.tech Dataset - BMGF-LivestockWeight-CV\Vector\B3\Side\data"
new_images_dir = _DATASET_ROOT + r"\images"
new_annotations_file = _DATASET_ROOT + r"\COCO_Side.json"

# --- CONFIGURATION ---
DATA_LIMIT = 500          # max number of successfully processed annotations (None = no limit)
TARGET_SIZE = (224, 224)  # passed to cv2.resize as dsize, i.e. (width, height)
NUM_KEYPOINTS = 9         # keypoints per annotation; COCO stores (x, y, v) for each
AUGMENT_DATA = True       # add flipped / brightness-shifted / rotated copies of each sample
EPOCHS = 100              # upper bound; EarlyStopping may stop training earlier
BATCH_SIZE = 16

print("--- Script execution started ---")
print(f"TensorFlow Version: {tf.__version__}")
_gpu_devices = tf.config.list_physical_devices('GPU')
print(f"Available GPUs: {_gpu_devices}")

# Load the COCO-style annotation file; any failure aborts the script.
print(f"Loading annotations from: {new_annotations_file}")
try:
    with open(new_annotations_file, 'r', encoding='utf-8') as annotation_fp:
        data = json.load(annotation_fp)
except FileNotFoundError:
    print(f"!!! ERROR !!!: Could not find the annotation file: {new_annotations_file}")
    exit()
except json.JSONDecodeError as e:
    print(f"!!! ERROR !!!: Could not parse the JSON file: {new_annotations_file}. Error: {e}")
    exit()
except Exception as e:
    print(f"!!! ERROR !!!: An unexpected error occurred while reading the annotation file: {e}")
    exit()
else:
    print("Annotation file successfully loaded and parsed.")

# The three top-level COCO sections this script relies on.
print("Checking JSON structure...")
if any(section not in data for section in ('annotations', 'images', 'categories')):
    print("!!! ERROR !!!: The JSON file has an incorrect structure. Expected keys 'annotations', 'images', 'categories'.")
    print(f"Keys found: {list(data.keys())}")
    exit()
print("JSON structure is correct.")

annotations = data['annotations']
images_data = data['images']  # not named `images` to keep it distinct from image arrays
categories = data['categories']

print(f"Found {len(annotations)} annotations and {len(images_data)} image records in the JSON.")

# Function to load and augment images and keypoints
def _flip_keypoints_horizontally(keypoints_xy, image_width):
    """Mirror (x, y) keypoints for an image flipped left-right with cv2.flip(..., 1).

    Points stored as (0, 0) mark missing keypoints and are kept as (0.0, 0.0).
    """
    return [
        [image_width - x, y] if (x != 0 or y != 0) else [0.0, 0.0]
        for x, y in keypoints_xy
    ]


def _rotate_keypoints(keypoints_xy, affine_matrix):
    """Apply a 2x3 affine matrix (e.g. from cv2.getRotationMatrix2D) to (x, y) keypoints.

    Points stored as (0, 0) mark missing keypoints and are kept as (0.0, 0.0).
    """
    rotated = []
    for x, y in keypoints_xy:
        if x != 0 or y != 0:
            new_x, new_y = affine_matrix @ np.array([x, y, 1.0])
            rotated.append([new_x, new_y])
        else:
            rotated.append([0.0, 0.0])
    return rotated


def load_data(annotations, images_data, images_dir, target_size=TARGET_SIZE, num_keypoints=NUM_KEYPOINTS, augment=AUGMENT_DATA, limit=DATA_LIMIT):
    """Load resized images and scaled keypoints from COCO-style annotations.

    For each annotation, the referenced image is read with cv2 (BGR channel order)
    and resized to `target_size` (cv2 dsize, i.e. (width, height)). Its keypoints
    are scaled into that resized frame. Keypoints stored as (0, 0) are treated as
    missing and kept at (0, 0). The visibility flag is ignored.

    Args:
        annotations: list of COCO annotation dicts ('image_id', 'keypoints', 'id').
        images_data: list of COCO image dicts ('id', 'file_name', 'width', 'height').
        images_dir: directory containing the image files.
        target_size: (width, height) that images are resized to.
        num_keypoints: expected number of keypoints (the 'keypoints' list has 3x values).
        augment: if True, also add a flipped, a brightness-shifted and a rotated copy.
        limit: stop after this many annotations were processed successfully (None = all).

    Returns:
        (X, y, info), where:
        - X is a float32 array of images scaled to [0, 1].
        - y is an array of flat [x0, y0, x1, y1, ...] keypoints in target-size pixels.
        - info is a list of per-example dicts. Augmented copies keep the source
          'annotation_id' and get a prefixed 'file_name'.
        Returns (None, None, None) when nothing could be loaded.
    """
    print(f"\n--- Starting data loading (Limit: {limit} annotations) ---")
    start_time = time.time()
    images_list = []
    keypoints_list = []
    image_info_list = []
    processed_count = 0
    skipped_annotations = 0
    skipped_images = 0
    limit_label = limit if limit is not None else "no limit"

    def add_example(image_uint8, keypoints_xy, info):
        # Convert everything first and append last, so an exception part-way
        # through can never leave the three parallel lists with different lengths.
        image_float = image_uint8.astype(np.float32) / 255.0  # float32 halves memory vs float64
        keypoints_flat = np.asarray(keypoints_xy, dtype=np.float64).flatten()
        images_list.append(image_float)
        keypoints_list.append(keypoints_flat)
        image_info_list.append(info)

    # Dictionary for quick access to image information by ID
    image_id_map = {img['id']: img for img in images_data}
    print(f"Created a map for {len(image_id_map)} images.")

    for i, annotation in enumerate(annotations):
        if limit is not None and processed_count >= limit:
            print(f"\nReached the limit of {limit} successfully processed annotations. Stopping data loading.")
            break

        if (i + 1) % 50 == 0:  # Logging progress
            print(f"  Processing annotation {i+1}/{len(annotations)} (Target: {processed_count}/{limit_label})...")

        # Looked up once with .get so annotations without an 'id' never raise KeyError.
        ann_id = annotation.get('id', 'N/A')

        image_id = annotation.get('image_id')
        if image_id is None:
            print(f"  Warning: Annotation {ann_id} is missing 'image_id'. Skipping.")
            skipped_annotations += 1
            continue

        if image_id not in image_id_map:
            print(f"  Warning: Skipping annotation {ann_id} because no image with ID {image_id} was found.")
            skipped_annotations += 1
            continue

        image_info = image_id_map[image_id]
        image_name = image_info.get('file_name')
        if not image_name:
            print(f"  Warning: Image info for ID {image_id} is missing 'file_name'. Skipping annotation {ann_id}.")
            skipped_annotations += 1
            continue

        # Original image dimensions (used to scale the keypoints)
        orig_width = image_info.get('width')
        orig_height = image_info.get('height')
        if not orig_width or not orig_height:
            print(f"  Warning: Image info for ID {image_id} ('{image_name}') is missing dimensions ('width' or 'height'). Skipping.")
            skipped_annotations += 1
            continue

        image_path = os.path.join(images_dir, image_name)
        if not os.path.exists(image_path):
            print(f"  Warning: Image file not found: {image_path}. Skipping annotation {ann_id}.")
            skipped_annotations += 1
            skipped_images += 1
            continue

        try:
            image = cv2.imread(image_path)
        except Exception as e:
            print(f"  ERROR reading image {image_path}: {e}. Skipping.")
            skipped_annotations += 1
            skipped_images += 1
            continue
        if image is None:
            print(f"  Warning: Failed to load image (cv2.imread returned None): {image_path}. Skipping.")
            skipped_annotations += 1
            skipped_images += 1
            continue

        try:
            image_resized = cv2.resize(image, target_size)
        except Exception as e:
            print(f"  ERROR resizing image {image_path}: {e}. Skipping.")
            skipped_annotations += 1
            continue

        # target_size is (width, height), matching cv2.resize's dsize.
        scale_x = target_size[0] / orig_width
        scale_y = target_size[1] / orig_height

        keypoints_list_raw = annotation.get('keypoints')
        if keypoints_list_raw is None:
            print(f"  Warning: Annotation {ann_id} is missing the 'keypoints' key. Skipping.")
            skipped_annotations += 1
            continue

        # --- Keypoint Processing: COCO stores (x, y, visibility) triplets ---
        expected_kpt_len = num_keypoints * 3
        if len(keypoints_list_raw) != expected_kpt_len:
            print(f"  Warning: Incorrect number of values ({len(keypoints_list_raw)}) in 'keypoints' for annotation {ann_id} (expected {expected_kpt_len}). Skipping.")
            skipped_annotations += 1
            continue

        keypoints_scaled = []
        valid_keypoints = True
        for ki in range(0, expected_kpt_len, 3):
            try:
                x_orig = float(keypoints_list_raw[ki])
                y_orig = float(keypoints_list_raw[ki + 1])
                # visibility (keypoints_list_raw[ki + 2]) is ignored
            except (ValueError, TypeError) as e:
                print(f"  Warning: Invalid coordinate data ({keypoints_list_raw[ki]}, {keypoints_list_raw[ki+1]}) in annotation {ann_id}: {e}. Skipping annotation.")
                valid_keypoints = False
                break

            if x_orig == 0 and y_orig == 0:
                keypoints_scaled.append([0.0, 0.0])  # missing keypoint stays (0, 0)
            else:
                keypoints_scaled.append([x_orig * scale_x, y_orig * scale_y])

        if not valid_keypoints:
            skipped_annotations += 1
            continue

        current_info = {
            'file_name': image_name,
            'orig_width': orig_width,
            'orig_height': orig_height,
            'scale_x': scale_x,
            'scale_y': scale_y,
            'annotation_id': ann_id,  # shared by augmented copies; useful for debugging/grouping
        }
        add_example(image_resized, keypoints_scaled, current_info)

        if augment:
            # 1. Horizontal flip
            try:
                add_example(
                    cv2.flip(image_resized, 1),
                    _flip_keypoints_horizontally(keypoints_scaled, target_size[0]),
                    {**current_info, 'file_name': f"flipped_{image_name}"},
                )
            except Exception as e:
                print(f"  Error during augmentation (flip) for {image_name}: {e}")

            # 2. Brightness adjustment (keypoints unchanged)
            try:
                brightness = np.random.uniform(0.8, 1.2)
                bright_image = np.clip(image_resized * brightness, 0, 255).astype(np.uint8)
                add_example(
                    bright_image,
                    keypoints_scaled,
                    {**current_info, 'file_name': f"bright_{image_name}"},
                )
            except Exception as e:
                print(f"  Error during augmentation (bright) for {image_name}: {e}")

            # 3. Rotation about the image centre by a random angle in degrees.
            # NOTE(review): rotated keypoints may fall outside the frame; they are kept as-is.
            try:
                angle = np.random.uniform(-15, 15)
                center = (target_size[0] // 2, target_size[1] // 2)
                rotation = cv2.getRotationMatrix2D(center, angle, 1.0)
                rotated_image = cv2.warpAffine(image_resized, rotation, target_size)
                add_example(
                    rotated_image,
                    _rotate_keypoints(keypoints_scaled, rotation),
                    {**current_info, 'file_name': f"rotated_{image_name}"},
                )
            except Exception as e:
                print(f"  Error during augmentation (rotate) for {image_name}: {e}")

        processed_count += 1  # counts source annotations, not augmented copies

    end_time = time.time()
    print(f"\n--- Data loading finished in {end_time - start_time:.2f} sec ---")
    print(f"Successfully processed annotations: {processed_count}")
    print(f"Skipped annotations due to errors/filters: {skipped_annotations}")
    print(f"Skipped images due to loading/path errors: {skipped_images}")
    print(f"Total examples (with augmentation): {len(images_list)}")

    if not images_list:
        print("!!! ERROR !!!: Image list is empty after data loading!")
        return None, None, None

    return np.array(images_list), np.array(keypoints_list), image_info_list

# --- LOADING DATA WITH A LIMIT ---
X, y, image_info_list = load_data(annotations, images_data, new_images_dir)

# --- POST-LOADING CHECK ---
# load_data signals failure with (None, None, None); an empty array is equally fatal.
_nothing_loaded = X is None or X.shape[0] == 0
if _nothing_loaded:
    print("\n!!! CRITICAL ERROR !!!: Failed to load data. Further execution is not possible.")
    exit()

for _caption, _shape in (("\nShape of loaded images (X):", X.shape),
                         ("Shape of loaded keypoints (y):", y.shape)):
    print(_caption, _shape)
print(f"Number of examples for training (with augmentation): {X.shape[0]}")

# --- DATA SPLITTING ---
# Augmented copies (flip / brightness / rotation) carry the 'annotation_id' of the
# sample they were derived from. The split is done over those ids, so every copy of
# a sample lands on the same side. A plain row-wise split would put near-duplicates
# of training images into the validation set and inflate validation metrics.
print("\nSplitting data into training and validation sets...")
try:
    _group_of_id = {}
    _row_groups = np.array(
        [_group_of_id.setdefault(info['annotation_id'], len(_group_of_id)) for info in image_info_list]
    )
    # Raises ValueError (handled below) when there are too few groups to split.
    _train_groups, _val_groups = train_test_split(
        np.arange(len(_group_of_id)), test_size=0.2, random_state=42
    )
    _is_train = np.isin(_row_groups, _train_groups)
    _train_idx = np.flatnonzero(_is_train)
    _val_idx = np.flatnonzero(~_is_train)

    X_train, X_val = X[_train_idx], X[_val_idx]
    y_train, y_val = y[_train_idx], y[_val_idx]
    info_train = [image_info_list[i] for i in _train_idx]
    info_val = [image_info_list[i] for i in _val_idx]
    print(f"Training set: {X_train.shape[0]} examples")
    print(f"Validation set: {X_val.shape[0]} examples")
except ValueError as e:
    print(f"\n!!! ERROR during data splitting: {e}")
    print("Perhaps too little data was loaded to split. Check DATA_LIMIT.")
    exit()

# --- MODEL CREATION ---
print("\nCreating the model...")
def create_improved_model(img_height, img_width, num_keypoints, trainable_layers=30):
    """Build an EfficientNetB3-based regressor that outputs `num_keypoints` (x, y) pairs.

    Args:
        img_height: input image height in pixels.
        img_width: input image width in pixels.
        num_keypoints: number of keypoints; the output layer has 2 * num_keypoints units.
        trainable_layers: how many of the last base-model layers to unfreeze.
            BatchNormalization layers stay frozen everywhere. 0 freezes the whole base.

    Returns:
        An uncompiled tf.keras.Model that maps images with pixel values in [0, 1]
        to flat [x0, y0, x1, y1, ...] predictions.
    """
    print(f"  Creating EfficientNetB3 base model (input: {img_height}x{img_width}x3)")
    base_model = applications.EfficientNetB3(
        input_shape=(img_height, img_width, 3),
        include_top=False,
        weights='imagenet'
    )
    print("  Base model created.")

    print(f"  Freezing base model layers, except for the last {trainable_layers} (and BN)")
    # Explicit cutoff index: `layers[:-0]` would otherwise select nothing when
    # trainable_layers == 0.
    cutoff = max(len(base_model.layers) - trainable_layers, 0)
    for index, layer in enumerate(base_model.layers):
        # BN layers must be set to False explicitly: they are trainable by default,
        # so merely skipping them left the BN layers in the tail trainable.
        layer.trainable = index >= cutoff and not isinstance(layer, layers.BatchNormalization)
    print("  Layers frozen/unfrozen.")

    inputs = layers.Input(shape=(img_height, img_width, 3))
    # load_data / predict_keypoints divide pixels by 255, but Keras' EfficientNet
    # contains its own Rescaling + Normalization and expects raw [0, 255] values.
    # Undo the division here so the pretrained weights see their intended range.
    # NOTE(review): cv2 delivers BGR while ImageNet weights assume RGB; confirm
    # whether a channel swap is wanted upstream.
    x = layers.Rescaling(255.0)(inputs)

    print("  Building the custom model 'head'...")
    x = base_model(x, training=False)  # keep BN in inference mode inside the backbone

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1024, activation='relu', kernel_regularizer=l2(0.001))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(512, activation='relu', kernel_regularizer=l2(0.001))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.2)(x)

    outputs = layers.Dense(num_keypoints * 2, activation='linear', name='keypoints_output')(x)
    print(f"  Output layer created: {num_keypoints * 2} neurons.")

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    print("Model successfully created.")
    return model

# TARGET_SIZE is used as cv2's dsize, i.e. (width, height) (see cv2.resize in
# load_data), whereas create_improved_model takes (height, width). Passing them in
# this order keeps the model input consistent with the image arrays when the size
# is not square.
model = create_improved_model(TARGET_SIZE[1], TARGET_SIZE[0], NUM_KEYPOINTS)

# --- OPTIMIZER AND LOSS FUNCTION ---
# Staircase decay: multiply the learning rate by 0.9 roughly every 5 epochs
# (decay_steps = 5 * full batches per epoch, at least 1).
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=max(1, (X_train.shape[0] // BATCH_SIZE) * 5),
    decay_rate=0.9, staircase=True
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)


# --- MODEL COMPILATION ---
print("\nCompiling the model...")
model.compile(
    optimizer=optimizer,
    loss='mse',  # regression on pixel coordinates; 'mae' is a built-in alternative
    metrics=['mae']
)
print("Model compiled.")
model.summary()

# --- CALLBACKS ---
print("\nSetting up callbacks...")
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss', patience=15, verbose=1, restore_best_weights=True
)
# NOTE(review): ReduceLROnPlateau is presumably disabled because the optimizer
# already uses a LearningRateSchedule; confirm before re-enabling.
# reduce_lr = callbacks.ReduceLROnPlateau(
#     monitor='val_loss', factor=0.2, patience=5, verbose=1, min_lr=1e-6
# )
checkpoint_path = 'best_keypoints_model_9pts_limited.keras'  # name marks the limited-data variant
checkpoint = callbacks.ModelCheckpoint(
    checkpoint_path, monitor='val_loss', save_best_only=True, verbose=1
)
print(f"Callbacks configured. The best model will be saved to '{checkpoint_path}'")

# --- MODEL TRAINING ---
def train_model_func(X_train, y_train, X_val, y_val, model, epochs=EPOCHS, batch_size=BATCH_SIZE):
    """Fit `model` on the training split, validating on (X_val, y_val) after each epoch.

    Uses the module-level `early_stopping` and `checkpoint` callbacks and returns
    the History object produced by `model.fit`.
    """
    print(f"\n--- Starting model training ({epochs} epochs, batch_size={batch_size}) ---")
    fit_options = dict(
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[early_stopping, checkpoint],
        verbose=1,
    )
    training_history = model.fit(X_train, y_train, **fit_options)
    print("--- Training finished ---")
    return training_history

# Training is opt-in. To train unconditionally, uncomment:
# history = train_model_func(X_train, y_train, X_val, y_val, model)
# By default the script only reports whether a saved checkpoint already exists.
history = None  # stays None when training is skipped, so no plots are attempted

if not os.path.exists(checkpoint_path):
    print(f"Model file '{checkpoint_path}' not found. If needed, uncomment the training block.")
    # To train automatically when no checkpoint exists, uncomment:
    # print("Starting training because the model was not found...")
    # history = train_model_func(X_train, y_train, X_val, y_val, model)
else:
    print(f"Found an existing model file '{checkpoint_path}'. Training will be skipped unless uncommented above.")

# --- TRAINING VISUALIZATION ---
def plot_history(history):
    """Show loss and MAE curves (training vs validation) side by side from a Keras History.

    Prints a notice and returns early when `history` is falsy or holds no data.
    A missing metric key, or any other plotting error, is reported instead of raised.
    """
    if not history or not history.history:
        print("No history data to visualize (perhaps training was skipped or the model was loaded).")
        return

    print("\nVisualizing the training process...")
    plt.figure(figsize=(12, 5))

    # (history key, legend label stem, subplot title, y-axis label), one per subplot.
    panels = (
        ('loss', 'Loss', 'Loss Function Dynamics', 'Loss'),
        ('mae', 'MAE', 'Mean Absolute Error Dynamics', 'MAE'),
    )
    try:
        for position, (key, label, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(history.history[key], label=f'{label} (Training)')
            plt.plot(history.history[f'val_{key}'], label=f'{label} (Validation)')
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(ylabel)
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.show()
        print("Training plots displayed.")
    except KeyError as e:
        print(f"Error during visualization: missing key in training history: {e}")
    except Exception as e:
        print(f"Error while plotting graphs: {e}")

# Plot learning curves only when a training run above produced a History object;
# `history` is initialised to None and stays None while training is commented out.
if history: # Call only if training occurred
    plot_history(history)

# --- FUNCTIONS FOR PREDICTION AND VISUALIZATION ---
def rescale_keypoints(scaled_keypoints_flat, orig_width, orig_height, target_size=TARGET_SIZE):
    """Map flat keypoints [x0, y0, x1, y1, ...] from target-size space to original pixels.

    Args:
        scaled_keypoints_flat: numpy array of interleaved x/y coordinates in the
            resized (target_size) frame.
        orig_width: width of the original image.
        orig_height: height of the original image.
        target_size: (width, height) the image was resized to before prediction.

    Returns:
        numpy array with one [x, y] row per keypoint.
    """
    factor_x = orig_width / target_size[0]
    factor_y = orig_height / target_size[1]
    xy_pairs = scaled_keypoints_flat.reshape(-1, 2)
    return np.array([[px * factor_x, py * factor_y] for px, py in xy_pairs])

def visualize_keypoints(image_path, predicted_keypoints_original, title="Predicted Keypoints",
                        save_path='./predicted_keypoints_output_limited.png'):
    """Draw predicted keypoints on an image, display it, and save the annotated copy.

    Args:
        image_path: path of the image the keypoints were predicted for.
        predicted_keypoints_original: sequence of (x, y) pairs in original-image pixels.
            Points outside the image are not drawn.
        title: figure title.
        save_path: output file for the annotated image (written with cv2.imwrite).
    """
    print(f"\nVisualizing result for: {image_path}")
    print(f"  Predicted keypoints (original scale): {predicted_keypoints_original}")

    image = cv2.imread(image_path)
    if image is None:
        print(f"  Error: Failed to load image for visualization: {image_path}")
        return

    # Draw in RGB so matplotlib shows true colours. cv2.circle writes colour tuples
    # in the array's own channel order, so RGB tuples are correct on this array.
    annotated = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    orig_height, orig_width = annotated.shape[:2]

    # Loop-invariant drawing parameters.
    color_map = plt.get_cmap("tab10")
    short_side = min(orig_width, orig_height)
    dot_radius = max(5, int(short_side * 0.01))
    ring_radius = max(7, int(short_side * 0.012))

    for i, point in enumerate(predicted_keypoints_original):
        x, y = int(point[0]), int(point[1])
        if not (0 <= x < orig_width and 0 <= y < orig_height):
            continue  # predicted outside the frame: nothing to draw
        # Integer index selects the i-th palette colour (wrapping after color_map.N).
        r, g, b, _alpha = color_map(i % color_map.N)
        fill_color = (int(r * 255), int(g * 255), int(b * 255))
        cv2.circle(annotated, (x, y), radius=dot_radius, color=fill_color, thickness=-1)  # filled dot
        cv2.circle(annotated, (x, y), radius=ring_radius, color=(255, 255, 255), thickness=2)  # white outline

    plt.figure(figsize=(10, 10))
    plt.imshow(annotated)
    plt.title(title)
    plt.axis('off')

    try:
        # cv2.imwrite expects BGR and reports most failures via a False return value.
        if cv2.imwrite(save_path, cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR)):
            print(f"  Image with keypoints saved to: {save_path}")
        else:
            print(f"  Failed to save the image: cv2.imwrite returned False for {save_path}")
    except Exception as e:
        print(f"  Failed to save the image: {e}")
    plt.show()


def predict_keypoints(image_path, model, target_size=TARGET_SIZE):
    """Predict keypoints for one image file and return them in original-image pixels.

    The image is read with cv2, resized to `target_size`, divided by 255 and passed
    to `model.predict` as a batch of one. The output is then rescaled with
    rescale_keypoints. Returns None (after printing a message) if any step fails.
    """
    print(f"\nPrediction for image: {image_path}")
    source = cv2.imread(image_path)
    if source is None:
        print(f"  Error: Failed to load image for prediction: {image_path}")
        return None

    src_height, src_width = source.shape[:2]
    print(f"  Original size: {src_width}x{src_height}")

    try:
        batch = np.expand_dims(cv2.resize(source, target_size) / 255.0, axis=0)
        print(f"  Input data shape for model: {batch.shape}")
    except Exception as e:
        print(f"  Error preparing image for prediction: {e}")
        return None

    try:
        t_start = time.time()
        raw_output = model.predict(batch)
        elapsed = time.time() - t_start
        print(f"  Prediction executed in {elapsed:.4f} sec.")
        print(f"  Model output shape: {raw_output.shape}")
    except Exception as e:
        print(f"  Error during model.predict execution: {e}")
        return None

    try:
        restored = rescale_keypoints(raw_output[0], src_width, src_height, target_size)
        print(f"  Predicted keypoints shape (original scale): {restored.shape}")
    except Exception as e:
        print(f"  Error during keypoint rescaling: {e}")
        return None
    return restored


# --- MODEL USAGE ---
print("\n--- Loading and testing the best model ---")
best_model = None
if os.path.exists(checkpoint_path):
    try:
        print(f"Loading model from: {checkpoint_path}")
        best_model = load_model(checkpoint_path)
        print("Model successfully loaded.")

        # Pick a test image from the validation set. load_data names augmented copies
        # "flipped_<name>", "bright_<name>" and "rotated_<name>". Those files do not
        # exist on disk, so use an original entry, or the source of an augmented one.
        augmentation_prefixes = ("flipped_", "bright_", "rotated_")
        test_image_path = None
        if 'info_val' in locals() and info_val:
            test_image_filename = next(
                (info['file_name'] for info in info_val
                 if not info['file_name'].startswith(augmentation_prefixes)),
                None,
            )
            if test_image_filename is None:
                # Only augmented copies are in the validation set: recover the source name.
                test_image_filename = info_val[0]['file_name']
                for prefix in augmentation_prefixes:
                    if test_image_filename.startswith(prefix):
                        test_image_filename = test_image_filename[len(prefix):]
                        break
            test_image_path = os.path.join(new_images_dir, test_image_filename)
            print(f"Selected test image from validation set: {test_image_filename}")
        else:  # No validation set: fall back to a fixed file name
            default_test_image = "9_s_181_F.jpg"  # REPLACE IF THIS FILE DOESN'T EXIST
            test_image_path = os.path.join(new_images_dir, default_test_image)
            print(f"Validation set is empty or unavailable. Using default test image: {default_test_image}")
            print(f"!!! ATTENTION: Make sure the file '{default_test_image}' exists in '{new_images_dir}' !!!")

        print(f"Path to the test image: {test_image_path}")

        if test_image_path and os.path.exists(test_image_path):
            predicted_keypoints = predict_keypoints(test_image_path, best_model)
            if predicted_keypoints is not None:
                visualize_keypoints(test_image_path, predicted_keypoints)
        else:
            print(f"!!! ERROR: Test image NOT FOUND at path: {test_image_path}")
            print("!!! Check if the file exists and the filename is correct, or if the validation set is not empty.")

    except Exception as e:
        print(f"!!! ERROR while loading or using the model: {e}")
        print(f"!!! Make sure the file '{checkpoint_path}' exists, is not corrupted, and that training ran at least partially.")
else:
    print(f"WARNING: The best model file '{checkpoint_path}' was not found. Testing is not possible.")


# --- MODEL EVALUATION (PCK) ---
def calculate_pck(model, X_val, y_val, threshold_factor=0.2, target_size=TARGET_SIZE,
                  num_keypoints=NUM_KEYPOINTS, predictions=None):
    """Return the Percentage of Correct Keypoints (PCK) on a validation set.

    A keypoint is correct when its Euclidean distance to the ground truth, in
    target-size pixels, is below `threshold_factor` times the diagonal of
    `target_size`. Ground-truth points stored as (0, 0) are treated as missing
    (this is how load_data marks them) and are excluded.

    Args:
        model: object with a Keras-style `predict(X, verbose=0)` method.
        X_val: validation images.
        y_val: flat ground-truth keypoints, one row per image.
        threshold_factor: fraction of the target-size diagonal used as the threshold.
        target_size: (width, height) of the frame the keypoints live in.
        num_keypoints: keypoints per example.
        predictions: optional precomputed `model.predict(X_val)` output. When given,
            the model is not run again (useful when sweeping several thresholds).

    Returns:
        PCK in [0, 1] as a float. Returns 0.0 when the set is empty, prediction or
        reshaping fails, or there are no visible keypoints.
    """
    if X_val is None or X_val.shape[0] == 0:
        print("Cannot calculate PCK: validation set is empty.")
        return 0.0

    if predictions is None:
        try:
            predictions = model.predict(X_val, verbose=0)
        except Exception as e:
            print(f"  Error during prediction on the validation set: {e}")
            return 0.0

    num_examples = len(X_val)
    try:
        pred_xy = np.asarray(predictions).reshape(num_examples, num_keypoints, 2)
        gt_xy = np.asarray(y_val).reshape(num_examples, num_keypoints, 2)
    except ValueError as e:
        print(f"  Error on reshape during PCK calculation: {e}")
        return 0.0

    visible = np.any(gt_xy != 0, axis=-1)  # (num_examples, num_keypoints) mask
    total_visible_keypoints = int(np.count_nonzero(visible))
    if total_visible_keypoints == 0:
        return 0.0

    normalizer = np.sqrt(target_size[0] ** 2 + target_size[1] ** 2) * threshold_factor
    errors = np.linalg.norm(pred_xy - gt_xy, axis=-1)
    correct_keypoints = int(np.count_nonzero(errors[visible] < normalizer))
    return correct_keypoints / total_visible_keypoints

# Plot PCK as a function of the distance threshold.
def plot_pck_curve(model, X_val, y_val, target_size=TARGET_SIZE, num_keypoints=NUM_KEYPOINTS):
    """Plot PCK over a sweep of thresholds (fractions of the target-size diagonal).

    The model is run once on X_val. The per-keypoint errors are then compared with
    every threshold, instead of re-predicting the whole set for each threshold.
    Ground-truth points stored as (0, 0) are treated as missing and are not scored.
    If prediction or reshaping fails, an error is printed and nothing is plotted.
    """
    if X_val is None or X_val.shape[0] == 0:
        print("Cannot plot PCK curve: validation set is empty.")
        return

    thresholds = np.linspace(0.01, 0.5, num=20)  # Range of thresholds

    print(f"\n--- Calculating PCK for various thresholds to plot the curve (on {X_val.shape[0]} examples)... ---")
    try:
        predictions = model.predict(X_val, verbose=0)
    except Exception as e:
        print(f"  Error during prediction on the validation set: {e}")
        return

    num_examples = len(X_val)
    try:
        pred_xy = np.asarray(predictions).reshape(num_examples, num_keypoints, 2)
        gt_xy = np.asarray(y_val).reshape(num_examples, num_keypoints, 2)
    except ValueError as e:
        print(f"  Error on reshape during PCK calculation: {e}")
        return

    visible = np.any(gt_xy != 0, axis=-1)
    visible_errors = np.linalg.norm(pred_xy - gt_xy, axis=-1)[visible]
    diagonal = np.sqrt(target_size[0] ** 2 + target_size[1] ** 2)

    pck_values = []
    for thresh in thresholds:
        print(f"  Calculating PCK for threshold: {thresh:.3f}")
        normalizer = diagonal * thresh
        current_pck = float(np.mean(visible_errors < normalizer)) if visible_errors.size else 0.0
        pck_values.append(current_pck)
        print(f"  PCK@{thresh:.3f}: {current_pck:.4f}")

    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, pck_values, marker='o', linestyle='-')
    plt.title('PCK Curve (Percentage of Correct Keypoints)')
    plt.xlabel(f'Threshold (fraction of the diagonal of the target size {target_size})')
    plt.ylabel('PCK@threshold')
    plt.grid(True)
    plt.ylim(0, 1.05)  # PCK is in the range [0, 1]
    # Show a subset of the thresholds as X-axis ticks for readability.
    num_xticks = 10
    xtick_indices = np.linspace(0, len(thresholds) - 1, num_xticks, dtype=int)
    plt.xticks(thresholds[xtick_indices], [f"{thresholds[i]:.2f}" for i in xtick_indices], rotation=45)

    plt.tight_layout()
    plt.show()
    print("PCK curve displayed.")


# Evaluate only when both a trained model and a non-empty validation set exist;
# otherwise report every missing prerequisite.
_model_ready = best_model is not None
_val_ready = X_val is not None and X_val.shape[0] > 0

if _model_ready and _val_ready:
    print("\nEvaluating the loaded model (single threshold)...")
    default_threshold = 0.2
    pck_score = calculate_pck(best_model, X_val, y_val, threshold_factor=default_threshold)
    print(f"--- PCK Result for threshold {default_threshold} ---")
    print(f"  PCK@{default_threshold}: {pck_score:.4f} (Calculated on {X_val.shape[0]} validation examples)")

    # Sweep a range of thresholds and plot the resulting PCK curve.
    plot_pck_curve(best_model, X_val, y_val)
else:
    if not _model_ready:
        print("\nCannot evaluate model and plot PCK curve: model not loaded.")
    if not _val_ready:
        print("\nCannot evaluate model and plot PCK curve: validation set is empty or was not created.")


# --- FINE-TUNING - optional ---
# A separate fine-tuning step (e.g. unfreezing more backbone layers) could go here;
# it is not part of this script.

# --- COMPLETION ---
print("\n--- Script execution finished ---")