
Licht-T/tf-centernet

tf-centernet


CenterNet implementation with TensorFlow 2.

Install

pip install tf-centernet

Example

Object detection

import numpy as np
import PIL.Image
import centernet

# Default: num_classes=80
obj = centernet.ObjectDetection(num_classes=80)

# Default: weights_path=None
# With num_classes=80 and weights_path=None, the pre-trained COCO model is loaded.
# Otherwise, the user-specified weight file is loaded.
obj.load_weights(weights_path=None)

img = np.array(PIL.Image.open('./data/sf.jpg'))[..., ::-1]

# With `debug=True`, an image with the predicted bounding boxes is also created
boxes, classes, scores = obj.predict(img, debug=True)
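The `[..., ::-1]` slice in the example above reverses only the last (channel) axis of the array returned by PIL, turning RGB order into BGR. The README does not say why BGR is used; the snippet below only illustrates what the slice does on a tiny synthetic image and does not call tf-centernet.

```python
import numpy as np

# A tiny 1x2 "image" in RGB order: one pure-red pixel, one pure-blue pixel.
rgb = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reverse only the last axis (channels): RGB -> BGR.
# Height and width are untouched, and no data is copied (the result is a view).
bgr = rgb[..., ::-1]

print(bgr.shape)             # (1, 2, 3)
print(bgr[0, 0].tolist())    # [0, 0, 255]  (the red pixel, now in BGR order)
print(np.shares_memory(rgb, bgr))  # True
```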

[Image: sample object-detection output]
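`obj.predict` returns `boxes`, `classes` and `scores`. This README does not document their exact layout, so the following is a minimal sketch that assumes they are parallel sequences (lists or NumPy arrays) with one entry per detection. `filter_detections` and `min_score` are hypothetical names, not part of tf-centernet, and each box is treated as opaque because its coordinate format is not specified here.

```python
from typing import Any, List, Sequence, Tuple


def filter_detections(
    boxes: Sequence[Any],
    classes: Sequence[Any],
    scores: Sequence[float],
    min_score: float = 0.5,
) -> List[Tuple[Any, Any, float]]:
    """Keep detections with score >= min_score, sorted by descending score.

    Hypothetical helper: assumes the three inputs are parallel sequences
    with one entry per detection.
    """
    if not (len(boxes) == len(classes) == len(scores)):
        raise ValueError("boxes, classes and scores must have the same length")
    kept = [
        (box, cls, float(score))
        for box, cls, score in zip(boxes, classes, scores)
        if score >= min_score
    ]
    # Highest-confidence detections first.
    kept.sort(key=lambda item: item[2], reverse=True)
    return kept


# Usage with dummy values standing in for the outputs of obj.predict:
dets = filter_detections(
    [[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 2, 2]],
    [0, 2, 0],
    [0.9, 0.3, 0.6],
    min_score=0.5,
)
print(dets)  # [([0, 0, 10, 10], 0, 0.9), ([1, 1, 2, 2], 0, 0.6)]
```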

Pose estimation

import numpy as np
import PIL.Image
import centernet

# Default: num_joints=17
pe = centernet.PoseEstimation(num_joints=17)

# Default: weights_path=None
# With num_joints=17 and weights_path=None, the pre-trained COCO model is loaded.
# Otherwise, the user-specified weight file is loaded.
pe.load_weights(weights_path=None)

# Tune this threshold for better predictions
pe.score_threshold = 0.1

img = np.array(PIL.Image.open('./data/chi.jpg'))[..., ::-1]

# With `debug=True`, an image with the predicted keypoints is also created
boxes, keypoints, scores = pe.predict(img, debug=True)

[Image: sample pose-estimation output]
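With `num_joints=17` and the pre-trained COCO weights, it is plausible (but not stated in this README) that each person's keypoints follow the standard 17-joint COCO keypoint order. The sketch below assumes each entry of `keypoints` is a sequence of 17 (x, y) pairs in that order; `COCO_KEYPOINT_NAMES` and `name_keypoints` are hypothetical illustrations, not part of tf-centernet.

```python
from typing import Dict, Sequence, Tuple

# Standard COCO keypoint order (17 joints).
COCO_KEYPOINT_NAMES = (
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
)


def name_keypoints(person: Sequence[Sequence[float]]) -> Dict[str, Tuple[float, float]]:
    """Map one person's keypoints to COCO joint names.

    Hypothetical helper: assumes `person` holds 17 (x, y) pairs in COCO order.
    """
    if len(person) != len(COCO_KEYPOINT_NAMES):
        raise ValueError(
            f"expected {len(COCO_KEYPOINT_NAMES)} keypoints, got {len(person)}"
        )
    return {
        name: (float(point[0]), float(point[1]))
        for name, point in zip(COCO_KEYPOINT_NAMES, person)
    }


# Usage with dummy coordinates standing in for one entry of `keypoints`:
person = [(float(i), float(2 * i)) for i in range(17)]
named = name_keypoints(person)
print(named["left_wrist"])  # (9.0, 18.0)
```

If the real output uses a different layout (for example, an extra confidence value per joint), only the indexing inside `name_keypoints` would need to change.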

TODO

  • Object detection
  • Pre-trained model for object detection with Hourglass-104
  • Pose estimation
  • Pre-trained model for pose estimation with Hourglass-104
  • DLA-34 backbone and pre-trained models
  • Training function and loss definition
  • Training data augmentation