In [None]:
import os
import pathlib
import matplotlib
import matplotlib.pyplot as plt
import io
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from six.moves.urllib.request import urlopen
import tensorflow as tf
import tensorflow_hub as hub

tf.get_logger().setLevel('ERROR')

In [None]:
def load_image_into_numpy_array(path):
    """Load an image from a URL or a file path into a batched numpy array.

    The image is decoded with PIL, converted to RGB (so grayscale, palette and
    RGBA inputs also yield exactly three channels) and given a leading batch
    axis so it can be fed directly to the detection model.

    Args:
        path: an ``http(s)`` URL or a file path readable by ``tf.io.gfile``.

    Returns:
        uint8 numpy array with shape (1, img_height, img_width, 3).
    """
    if path.startswith('http'):
        # Close the HTTP response as soon as the payload has been read.
        with urlopen(path) as response:
            image_data = response.read()
    else:
        # Close the file handle as soon as the payload has been read.
        with tf.io.gfile.GFile(path, 'rb') as image_file:
            image_data = image_file.read()

    # Wrap the raw bytes exactly once; BytesIO cannot wrap another BytesIO.
    image = Image.open(BytesIO(image_data)).convert('RGB')

    # np.array on an RGB PIL image yields (height, width, 3); prepend the batch axis.
    return np.expand_dims(np.array(image, dtype=np.uint8), axis=0)

# Display name -> TensorFlow Hub handle for every detector offered in this notebook.
# Each handle is expected to match its display name (architecture, input size and
# the `_kpts` suffix for keypoint variants).
ALL_MODELS = {
    'CenterNet HourGlass104 512x512': 'https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1',
    'CenterNet HourGlass104 Keypoints 512x512': 'https://tfhub.dev/tensorflow/centernet/hourglass_512x512_kpts/1',
    'CenterNet HourGlass104 1024x1024': 'https://tfhub.dev/tensorflow/centernet/hourglass_1024x1024/1',
    'CenterNet HourGlass104 Keypoints 1024x1024': 'https://tfhub.dev/tensorflow/centernet/hourglass_1024x1024_kpts/1',
    'CenterNet Resnet50 V1 FPN 512x512': 'https://tfhub.dev/tensorflow/centernet/resnet50v1_fpn_512x512/1',
    'CenterNet Resnet50 V1 FPN keypoints 512x512': 'https://tfhub.dev/tensorflow/centernet/resnet50v1_fpn_512x512_kpts/1',
    'CenterNet Resnet101 v1 FPN 512x512': 'https://tfhub.dev/tensorflow/centernet/resnet101v1_fpn_512x512/1',
    'CenterNet Resnet50 V2 512x512': 'https://tfhub.dev/tensorflow/centernet/resnet50v2_512x512/1',
    # NOTE(review): confirm this handle exists on TF Hub; the CenterNet V2
    # keypoints model may only be published for ResNet50 (resnet50v2_512x512_kpts).
    'CenterNet Resnet101 V2 keypoints 512x512': 'https://tfhub.dev/tensorflow/centernet/resnet101v2_512x512_kpts/1',
    'EfficientDet D0 512x512': 'https://tfhub.dev/tensorflow/efficientdet/d0/1',
    'EfficientDet D1 640x640': 'https://tfhub.dev/tensorflow/efficientdet/d1/1',
    'EfficientDet D2 768x768': 'https://tfhub.dev/tensorflow/efficientdet/d2/1',
    'EfficientDet D3 896x896': 'https://tfhub.dev/tensorflow/efficientdet/d3/1',
    'EfficientDet D4 1024x1024': 'https://tfhub.dev/tensorflow/efficientdet/d4/1',
    'EfficientDet D5 1280x1280': 'https://tfhub.dev/tensorflow/efficientdet/d5/1',
    'EfficientDet D6 1280x1280': 'https://tfhub.dev/tensorflow/efficientdet/d6/1',
    'EfficientDet D7 1536x1536': 'https://tfhub.dev/tensorflow/efficientdet/d7/1',
    'SSD MobileNet v2 320x320': 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/2',
    'SSD MobileNet V1 FPN 640x640': 'https://tfhub.dev/tensorflow/ssd_mobilenet_v1/fpn_640x640/1',
    'SSD MobileNet V2 FPNLite 320x320': 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/fpnlite_320x320/1',
    'SSD MobileNet V2 FPNLite 640x640': 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/fpnlite_640x640/1',
    'SSD ResNet50 v1 FPN 640x640 (RetinaNet50)': 'https://tfhub.dev/tensorflow/retinanet/resnet50_v1_fpn_640x640/1',
    'SSD ResNet50 V1 FPN 1024x1024 (RetinaNet50)': 'https://tfhub.dev/tensorflow/retinanet/resnet50_v1_fpn_1024x1024/1',
    'SSD ResNet101 V1 FPN 640x640 (RetinaNet101)': 'https://tfhub.dev/tensorflow/retinanet/resnet101_v1_fpn_640x640/1',
    'SSD ResNet101 V1 FPN 1024x1024 (RetinaNet101)': 'https://tfhub.dev/tensorflow/retinanet/resnet101_v1_fpn_1024x1024/1',
    'SSD ResNet152 V1 FPN 640x640 (RetinaNet152)': 'https://tfhub.dev/tensorflow/retinanet/resnet152_v1_fpn_640x640/1',
    'SSD ResNet152 V1 FPN 1024x1024 (RetinaNet152)': 'https://tfhub.dev/tensorflow/retinanet/resnet152_v1_fpn_1024x1024/1',
    'Faster R-CNN ResNet50 V1 640x640': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet50_v1_640x640/1',
    'Faster R-CNN ResNet50 V1 1024x1024': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet50_v1_1024x1024/1',
    'Faster R-CNN ResNet50 V1 800x1333': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet50_v1_800x1333/1',
    'Faster R-CNN ResNet101 V1 640x640': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet101_v1_640x640/1',
    'Faster R-CNN ResNet101 V1 1024x1024': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet101_v1_1024x1024/1',
    'Faster R-CNN ResNet101 V1 800x1333': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet101_v1_800x1333/1',
    'Faster R-CNN ResNet152 V1 640x640': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet152_v1_640x640/1',
    'Faster R-CNN ResNet152 V1 1024x1024': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet152_v1_1024x1024/1',
    'Faster R-CNN ResNet152 V1 800x1333': 'https://tfhub.dev/tensorflow/faster_rcnn/resnet152_v1_800x1333/1',
    'Faster R-CNN Inception ResNet V2 640x640': 'https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1',
    'Faster R-CNN Inception ResNet V2 1024x1024': 'https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_1024x1024/1',
    'Mask R-CNN Inception ResNet V2 1024x1024': 'https://tfhub.dev/tensorflow/mask_rcnn/inception_resnet_v2_1024x1024/1'
}

# Sample images keyed by display name. Values are either relative paths into the
# cloned `models` repository or Wikimedia URLs; the loader tells them apart by
# the 'http' prefix. This name is assigned again in the image-loading cell
# further down the notebook.
IMAGES_FOR_TEST = {
    'Beach': 'models/research/object_detection/test_images/image2.jpg',
    'Dogs': 'models/research/object_detection/test_images/image1.jpg',
    'Naxos Taverna': 'https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg',
    'Beatles': 'https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg',
    'Phones': 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Google_I_%26_O_2015_phone.jpg/640px-Google_I_%26_O_2015_phone.jpg',
    'Birds': 'https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053893633%29.jpg'
}

# Pairs of keypoint indices that are joined by a line when a pose skeleton is
# drawn; the list is passed as `keypoint_edges` to the visualization helper.
COCO17_HUMAN_POSE_KEYPOINTS = [
    (0, 1),   # Nose to Left eye
    (0, 2),   # Nose to Right eye
    (1, 3),   # Left eye to Left ear
    (2, 4),   # Right eye to Right ear
    (0, 5),   # Nose to Left shoulder
    (0, 6),   # Nose to Right shoulder
    (5, 7),   # Left shoulder to Left elbow
    (7, 9),   # Left elbow to Left wrist
    (6, 8),   # Right shoulder to Right elbow
    (8, 10),  # Right elbow to Right wrist
    (5, 6),   # Left shoulder to Right shoulder
    (5, 11),  # Left shoulder to Left hip
    (6, 12),  # Right shoulder to Right hip
    (11, 12), # Left hip to Right hip
    (11, 13), # Left hip to Left knee
    (13, 15), # Left knee to Left ankle
    (12, 14), # Right hip to Right knee
    (14, 16)  # Right knee to Right ankle
]



Visualization tools

# To visualize the images with the detected boxes, keypoints and segmentation masks, we will use the TensorFlow Object Detection API. To install it, we first clone the repository.

In [None]:
!git clone --depth 1 https://github.com/tensorflow/models

Cloning into 'models'...
remote: Enumerating objects: 4327, done.[K
remote: Counting objects: 100% (4327/4327), done.[K
remote: Compressing objects: 100% (3346/3346), done.[K
remote: Total 4327 (delta 1208), reused 2046 (delta 909), pack-reused 0 (from 0)[K
Receiving objects: 100% (4327/4327), 53.65 MiB | 24.30 MiB/s, done.
Resolving deltas: 100% (1208/1208), done.


Installing the Object Detection API

In [None]:
# 1. Clean up existing installations
!pip uninstall -y tensorflow protobuf
!rm -rf models  # Remove previous clone if exists

# 2. Install EXACT versions that work together
!pip install tensorflow==2.12.0
!pip install protobuf==3.20.3

# 3. Clone TF Models fresh (critical!)
!git clone --depth 1 https://github.com/tensorflow/models

# 4. COMPILE PROTOBUFS CORRECTLY (most important step)
%cd models/research/
!protoc object_detection/protos/*.proto --python_out=.

# 5. Install Object Detection API
!cp object_detection/packages/tf2/setup.py .
!pip install .

# 6. Add to Python path
import sys
sys.path.append('/content/models/research')
sys.path.append('/content/models/research/slim')

# 7. RESTART RUNTIME NOW (Colab: Runtime > Restart runtime)
# Then run ONLY the imports below after restart

Found existing installation: tensorflow 2.18.0
Uninstalling tensorflow-2.18.0:
  Successfully uninstalled tensorflow-2.18.0
Found existing installation: protobuf 5.29.4
Uninstalling protobuf-5.29.4:
  Successfully uninstalled protobuf-5.29.4
Collecting tensorflow==2.12.0
  Downloading tensorflow-2.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting gast<=0.4.0,>=0.2.1 (from tensorflow==2.12.0)
  Downloading gast-0.4.0-py3-none-any.whl.metadata (1.1 kB)
Collecting keras<2.13,>=2.12.0 (from tensorflow==2.12.0)
  Downloading keras-2.12.0-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting numpy<1.24,>=1.22 (from tensorflow==2.12.0)
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Collecting protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 (from tensorflow==2.12.0)
  Downloading protobuf-4.25.7-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes)
Collecting tenso

Collecting protobuf==3.20.3
  Downloading protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Downloading protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 4.25.7
    Uninstalling protobuf-4.25.7:
      Successfully uninstalled protobuf-4.25.7
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-text 2.18.1 requires tensorflow<2.19,>=2.18.0, but you have tensorflow 2.12.0 which is incompatible.
tensorflow-decision-forests 1.11.0 requires tensorflow==2.18.0, but you have tensorflow 2.12.0 which is incompatible.
bigframes 2.1.0 requires numpy>=1.24.0, but you have numpy 1.23.5 which is incompatible.
orbax-checkpoi

Cloning into 'models'...
remote: Enumerating objects: 4327, done.[K
remote: Counting objects: 100% (4327/4327), done.[K
remote: Compressing objects: 100% (3345/3345), done.[K
remote: Total 4327 (delta 1208), reused 2051 (delta 910), pack-reused 0 (from 0)[K
Receiving objects: 100% (4327/4327), 53.65 MiB | 18.27 MiB/s, done.
Resolving deltas: 100% (1208/1208), done.
/content/models/research
Processing /content/models/research
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting avro-python3 (from object_detection==0.1)
  Downloading avro-python3-1.10.2.tar.gz (38 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting apache-beam (from object_detection==0.1)
  Downloading apache_beam-2.64.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.4 kB)
Collecting contextlib2 (from object_detection==0.1)
  Downloading contextlib2-21.6.0-py2.py3-none-any.whl.metadata (4.1 kB)
Collecting lvis (from object_detection==0.1)
  Downloading lvis-0.5.3-py3

Now we can import the dependencies we will need later

In [None]:
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_utils
from object_detection.utils import ops as utils_ops

print("Imports successful! 🎉")

TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

Load label map data (for plotting)

 Label maps map index numbers to category names, so that when our convolutional network predicts 5, we know that this corresponds to "airplane". Here we use an internal utility function, but anything that returns a dictionary mapping integers to appropriate string labels would be fine. For simplicity, we load the label map from the same repository from which we loaded the Object Detection API code.

In [None]:
import os
from object_detection.utils import label_map_util

# 1. First ensure the label file exists
if not os.path.exists(PATH_TO_LABELS):
    !wget https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt -P ./models/research/object_detection/data/

# 2. Use the correct function
PATH_TO_LABELS = './models/research/object_detection/data/mscoco_label_map.pbtxt'
try:
    # Try newer version first
    category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
except AttributeError:
    # Fallback to older version
    category_index = label_map_util.get_label_map_dict(PATH_TO_LABELS, use_display_name=True)

print("Successfully loaded categories:", len(category_index))

TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [None]:
# Model catalogue used for the selection below. This assignment rebinds
# ALL_MODELS to a shorter list than the one defined in the first code cell, so
# only the entries listed here remain selectable afterwards.
ALL_MODELS = {
    'CenterNet HourGlass104 512x512': 'https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1',
    'CenterNet HourGlass104 Keypoints 512x512': 'https://tfhub.dev/tensorflow/centernet/hourglass_512x512_kpts/1',
    'EfficientDet D0 512x512': 'https://tfhub.dev/tensorflow/efficientdet/d0/1',
    'EfficientDet D1 640x640': 'https://tfhub.dev/tensorflow/efficientdet/d1/1',
    'EfficientDet D2 768x768': 'https://tfhub.dev/tensorflow/efficientdet/d2/1',
    'EfficientDet D3 896x896': 'https://tfhub.dev/tensorflow/efficientdet/d3/1',
    'EfficientDet D4 1024x1024': 'https://tfhub.dev/tensorflow/efficientdet/d4/1',
    'EfficientDet D5 1280x1280': 'https://tfhub.dev/tensorflow/efficientdet/d5/1',
    'EfficientDet D6 1280x1280': 'https://tfhub.dev/tensorflow/efficientdet/d6/1',
    'EfficientDet D7 1536x1536': 'https://tfhub.dev/tensorflow/efficientdet/d7/1',
    'SSD MobileNet v2 320x320': 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/2'
}

# Pick a model by display name and resolve its TF Hub handle.
model_display_name = 'EfficientDet D7 1536x1536'
model_handle = ALL_MODELS[model_display_name]

print('Selected model: ' + model_display_name)
print('Model handle at TensorFlow Hub: {}'.format(model_handle))

Selected model: EfficientDet D7 1536x1536
Model handle at TensorFlow Hub: https://tfhub.dev/tensorflow/efficientdet/d7/1


Build a detection model and load pre-trained model weights

Here we choose which object detection model to use. Select the architecture and it will be loaded automatically. If you want to try other architectures later, just change the next cell and re-execute the following ones.

Tip: if you want to read details about the selected model, you can follow the link (the model handle) and read the additional documentation on TF Hub. After you select a model, we print the handle to make this easier.

In [None]:
# Choose the detector by display name and look up the matching TF Hub handle.
model_display_name = 'EfficientDet D7 1536x1536'
model_handle = ALL_MODELS[model_display_name]
print(f'Selected model: {model_display_name}')
print(f'Model handle at TensorFlow Hub: {model_handle}')

Selected model: EfficientDet D7 1536x1536
Model handle at TensorFlow Hub: https://tfhub.dev/tensorflow/efficientdet/d7/1


Loading the Selected model from TensorFlow Hub

Here we just need the model handle that was selected and use the Tensorflow Hub library to load it to memory.

In [None]:
# Load the model identified by `model_handle` through TensorFlow Hub. The
# resulting `hub_model` is called on the image batch in the inference cell.
print('Loading model...')
hub_model = hub.load(model_handle)
print('Model loaded!')

Loading model...




Model loaded!


Loading an image

Let's try the model on a simple image. To help with this, we provide a list of test images.

Here are some simple things to try out if you are curious:

Try running inference on your own images, just upload them to colab and load the same way it's done in the cell below
Modify some of the input images and see if detection still works. Some simple things to try out here include flipping the image horizontally or converting it to grayscale (note that we still expect the input image to have 3 channels).
Be careful: the model expects 3-channel images, so when using images with an alpha channel, the alpha will count as a 4th channel.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
import os
import urllib.request

def load_wikimedia_image(url):
    """Fetch an image from Wikimedia Commons using an explicit User-Agent header.

    Returns the opened PIL image. On any failure the error is printed and
    None is returned instead.
    """
    request_headers = {
        'User-Agent': 'Python-Requests/2.31.0 (CommonsImageLoader/1.0; +https://your-site.org)'
    }
    try:
        reply = requests.get(url, headers=request_headers, timeout=10)
        # Turn HTTP error statuses into exceptions handled below.
        reply.raise_for_status()
        payload = BytesIO(reply.content)
        return Image.open(payload)
    except Exception as e:
        print(f"Error loading Wikimedia image: {e}")
        return None

def load_local_image(path):
    """Read an image file from disk and open it with PIL.

    Returns the opened PIL image. On any failure the error is printed and
    None is returned instead.
    """
    try:
        with open(path, 'rb') as handle:
            raw_bytes = handle.read()
        # The image is backed by the in-memory copy, not by the closed file.
        return Image.open(BytesIO(raw_bytes))
    except Exception as e:
        print(f"Error loading local image: {e}")
        return None

def load_image_into_numpy_array(path):
    """Load an image from a Wikimedia URL, another URL or a local path.

    This definition replaces the function of the same name defined in the
    first code cell of the notebook.

    Args:
        path: Wikimedia upload URL, any other ``http(s)`` URL, or a local
            file path.

    Returns:
        Writable uint8 numpy array with shape (1, height, width, 3), or None
        when the image could not be loaded (the error is printed).
    """
    if path.startswith('https://upload.wikimedia.org'):
        image = load_wikimedia_image(path)
    elif path.startswith('http'):
        try:
            response = requests.get(path, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        except Exception as e:
            print(f"Error loading image: {e}")
            return None
    else:
        # Local file path
        image = load_local_image(path)

    if image is None:
        return None

    # Convert to RGB so grayscale, palette and RGBA images also produce exactly
    # three channels; np.array (not asarray) guarantees a writable copy because
    # the caller modifies the batch in place.
    rgb_pixels = np.array(image.convert('RGB'), dtype=np.uint8)
    return rgb_pixels[np.newaxis, ...]

# Test images - mixed local and Wikimedia.
# Local entries are paths relative to the directory containing the cloned
# `models` repository; remote entries are Wikimedia upload URLs.
IMAGES_FOR_TEST = {
    'Beach' : 'models/research/object_detection/test_images/image2.jpg',
    'Dogs' : 'models/research/object_detection/test_images/image1.jpg',
    # By Heiko Gorski, source: https://commons.wikimedia.org/wiki/File:Naxos_Taverna.jpg
    'Naxos Taverna' : 'https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg',
    # Source: https://commons.wikimedia.org/wiki/File:The_Coleoptera_of_the_British_islands_(Plate_125)_(8592917784).jpg
    'Beatles' : 'https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg',
    # By Americo Toledano, Source: https://commons.wikimedia.org/wiki/File:Biblioteca_Maim%C3%B3nides,_Campus_Universitario_de_Rabanales_007.jpg
    # NOTE(review): the attribution above names a different file than the URL below - confirm which one is intended.
    'Phones' : 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Google_I_%26_O_2015_phone.jpg/640px-Google_I_%26_O_2015_phone.jpg',
    # Source: https://commons.wikimedia.org/wiki/File:The_Coleoptera_of_the_British_islands_(Plate_125)_(8592917784).jpg
    # NOTE(review): the source above is the Coleoptera plate, not the birds file referenced by the URL below - confirm.
    'Birds' : 'https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053893633%29.jpg'
}

# Image selection and optional preprocessing (Colab form fields).
# The option list must match the keys of IMAGES_FOR_TEST exactly.
selected_image = 'Beach'  #@param ['Beach', 'Dogs', 'Naxos Taverna', 'Beatles', 'Phones', 'Birds']
flip_image_horizontally = False  #@param {type:"boolean"}
convert_image_to_grayscale = False  #@param {type:"boolean"}

# Reset first so that a failed load cannot leave an image from an earlier run
# behind for the inference cell to pick up silently.
image_np = None

if selected_image not in IMAGES_FOR_TEST:
    print(f"Error: '{selected_image}' not available")
else:
    image_path = IMAGES_FOR_TEST[selected_image]
    is_remote = image_path.startswith('http')

    # Check if local file exists
    if not is_remote and not os.path.exists(image_path):
        print(f"Error: Local file '{image_path}' not found")
        print("Please ensure you've cloned the TensorFlow models repository and the files exist")
        print("You can clone it with:")
        print("!git clone --depth 1 https://github.com/tensorflow/models")
    else:
        print(f"Loading {selected_image} from {'local path' if not is_remote else 'URL'}...")
        image_np = load_image_into_numpy_array(image_path)

        if image_np is not None:
            if flip_image_horizontally:
                image_np[0] = np.fliplr(image_np[0]).copy()

            if convert_image_to_grayscale:
                # Average the channels, then repeat the result so the model
                # still receives three channels.
                image_np[0] = np.tile(np.mean(image_np[0], 2, keepdims=True), (1, 1, 3)).astype(np.uint8)

            shown_image = image_np[0]
            shown_title = f"Image: {selected_image}"
        else:
            # Fallback blank image
            blank_image = np.zeros((1, 512, 512, 3), dtype=np.uint8)
            shown_image = blank_image[0]
            shown_title = "Failed to load image"

        plt.figure(figsize=(24, 32))
        plt.imshow(shown_image)
        plt.title(shown_title, pad=20)
        plt.axis('off')
        plt.show()

Error: Local file 'models/research/object_detection/test_images/image2.jpg' not found
Please ensure you've cloned the TensorFlow models repository and the files exist
You can clone it with:
!git clone --depth 1 https://github.com/tensorflow/models


In [None]:
# Running inference
# `hub_model` is called on the batched image; its output is iterated as a
# mapping whose values expose `.numpy()`. `result` holds the same entries as
# numpy arrays and is what the visualization cells read.
results = hub_model(image_np)
result = {key: value.numpy() for key, value in results.items()}
print(result.keys())

dict_keys(['detection_anchor_indices', 'detection_boxes', 'detection_classes', 'detection_multiclass_scores', 'detection_scores', 'num_detections', 'raw_detection_boxes', 'raw_detection_scores'])


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from object_detection.utils import visualization_utils as vis_utils

# 1. COCO keypoint connections: index pairs joined by a line when a pose
#    skeleton is drawn. The same list is also defined in the first code cell.
COCO17_HUMAN_POSE_KEYPOINTS = [
    (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6),
    (5, 7), (7, 9), (6, 8), (8, 10), (5, 6), (5, 11),
    (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16)
]

# 2. Ensure you have these required variables:
#    - image_np: Your loaded image array
#    - result: Detection results from your model
#    - category_index: Loaded label map

# 3. Visualization code with error handling
# NOTE(review): the broad `except Exception` below prints only the message, so a
# missing `category_index` or `result` shows up as a one-line text, not a traceback.
try:
    # Offset added to the predicted class ids before the label lookup; also
    # read by the optional mask cell further down.
    label_id_offset = 0
    # Draw on a copy so the original image batch stays untouched.
    image_np_with_detections = image_np.copy()

    # Use keypoints if available
    keypoints, keypoint_scores = None, None
    if 'detection_keypoints' in result:
        keypoints = result['detection_keypoints'][0]
        keypoint_scores = result['detection_keypoint_scores'][0]

    # Visualization: index [0] selects the single image of the batch.
    vis_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections[0],
        result['detection_boxes'][0],
        (result['detection_classes'][0] + label_id_offset).astype(int),
        result['detection_scores'][0],
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=200,
        min_score_thresh=.30,
        agnostic_mode=False,
        keypoints=keypoints,
        keypoint_scores=keypoint_scores,
        keypoint_edges=COCO17_HUMAN_POSE_KEYPOINTS)

    # Display
    plt.figure(figsize=(24, 32))
    plt.imshow(image_np_with_detections[0])
    plt.axis('off')  # Cleaner display without axes
    plt.show()

except Exception as e:
    print(f"Visualization error: {e}")
    # Fallback: Show original image if detection fails
    plt.figure(figsize=(24, 32))
    plt.imshow(image_np[0])
    plt.title("Detection failed - showing original image")
    plt.axis('off')
    plt.show()

OPTIONAL

In [None]:
# Optional: draw instance segmentation masks. Only models whose output contains
# 'detection_masks' produce anything here. Relies on `label_id_offset` set in
# the visualization cell above.
image_np_with_mask = image_np.copy()
if 'detection_masks' in results:
  # The reframing helper takes the box-relative masks and the boxes as two
  # separate tensor arguments; use the first (only) image of the batch.
  detection_masks = tf.convert_to_tensor(result['detection_masks'][0])
  detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])

  # Reframe the per-box masks to full-image masks (height, width of the image).
  detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
      detection_masks,
      detection_boxes,
      image_np.shape[1],
      image_np.shape[2])
  # Binarize the masks at 0.5.
  detection_masks_reframed = tf.cast(detection_masks_reframed>0.5,tf.uint8)
  result['detection_masks_reframed'] = detection_masks_reframed.numpy()
  vis_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_mask[0],
      result['detection_boxes'][0],
      (result['detection_classes'][0] + label_id_offset).astype(int),
      result['detection_scores'][0],
      category_index,
      use_normalized_coordinates=True,
      max_boxes_to_draw=200,
      min_score_thresh=.30,
      agnostic_mode=False,
      instance_masks=result.get('detection_masks_reframed', None),
      line_thickness=8)
  plt.figure(figsize=(24, 32))
  plt.imshow(image_np_with_mask[0])
  plt.show()
