In [1]:
import rospy
from sensor_msgs.msg import Image
import sys
from cv_bridge import CvBridge

import numpy as np
import pandas as pd

import glob
import cv2
 
import torch
import torchvision
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import TensorDataset
import os

import time

In [2]:
## User configure---------------

# Single object class the detector reports.
class_name    = "car"
# Serialized model file passed to torch.load in the model cell.
weight        = "model.pt"
test_path     = "test"
output_path   = "result"
# Class names of the training dataset (kept as a list).
dataset_class = [class_name]
# Box colour per class index. The image that gets drawn on is RGB,
# so (255, 0, 0) renders red.
colors        = (
    (0, 0, 0),
    (255, 0, 0),
)

## -------------------------------

In [3]:
# Run on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torchvision detectors reserve label 0 for background, so the class-name
# lookup table gets a placeholder at index 0 followed by the dataset classes.
# Build a NEW list: aliasing `dataset_class` and calling .insert() on it
# mutated the config list and made this cell grow it on every re-run.
data_class = ['__background__'] + list(dataset_class)
classes = tuple(data_class)

# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# files. This file holds a whole pickled model; with torch >= 2.6 (where
# weights_only defaults to True) it may need weights_only=False.
model = torch.load(weight, map_location=device)

model.to(device)

# Inference mode; torchvision detectors return detections (not losses) in eval.
model.eval()

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256)
          (relu): ReLU(inplace=True)
          (downsample)

In [4]:
## Publish setting
# Annotated frames go out on this topic; the outgoing queue is bounded to one message.
pub = rospy.Publisher(name="output_image", data_class=Image, queue_size=1)

In [5]:
def process_image(msg, score_threshold=0.5):
    """Run the detector on one ROS image and publish an annotated copy.

    Subscriber callback. Converts ``msg`` to an OpenCV image and resizes it to
    a height of 512 px, keeping the aspect ratio. Runs ``model`` on ``device``
    and draws a box plus a "<class> <score>" label for every detection whose
    score is strictly greater than ``score_threshold``. The result is
    published on ``pub`` as an ``rgb8`` image.

    Exactly one image is published per input frame, including frames with no
    detections, so the output topic keeps pace with the input.

    Args:
        msg: incoming ``sensor_msgs/Image``, decoded as ``bgr8``.
        score_threshold: minimum confidence (exclusive) for a detection to be drawn.

    Any exception is logged with ``rospy.logerr`` and the frame is skipped
    instead of propagating out of the callback.
    """
    try:
        bridge = CvBridge()
        img_bgr = bridge.imgmsg_to_cv2(msg, "bgr8")

        # Scale to a fixed height of 512 px, preserving the aspect ratio.
        height, width = img_bgr.shape[:2]
        img_bgr = cv2.resize(img_bgr, (512 * width // height, 512))
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        image_tensor = torchvision.transforms.functional.to_tensor(img)
        with torch.no_grad():
            prediction = model([image_tensor.to(device)])[0]

        # Move results to the CPU once (not once per detection) and threshold
        # them in a single vectorized step.
        boxes = prediction['boxes'].cpu().numpy().astype('int')
        scores = prediction['scores'].cpu().numpy()
        keep = scores > score_threshold

        font = cv2.FONT_HERSHEY_SIMPLEX
        # Every detection is reported as classes[1]; the model's predicted
        # label is deliberately not used (single-class setup).
        cat = 1
        color = colors[cat]
        for box, score in zip(boxes[keep], scores[keep]):
            txt = '{} {}'.format(classes[cat], str(round(float(score), 2)))
            text_w, text_h = cv2.getTextSize(txt, font, 0.5, 2)[0]
            x1, y1, x2, y2 = box
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            # Filled strip just above the box that serves as the label background.
            cv2.rectangle(img, (x1, y1 - text_h - 2), (x1 + text_w, y1 - 2), color, -1)
            cv2.putText(img, txt, (x1, y1 - 2), font, 0.5, (0, 0, 0),
                        thickness=1, lineType=cv2.LINE_AA)

        # Publish once, after all boxes are drawn. Publishing inside the loop
        # emitted partially annotated frames and nothing at all for empty frames.
        output_img = bridge.cv2_to_imgmsg(img, "rgb8")
        output_img.header = msg.header  # keep the source timestamp / frame_id
        pub.publish(output_img)
    except Exception as err:
        rospy.logerr("process_image failed: %s", err)
                
                

In [None]:
def start_node(image_topic="camera/color/image_raw", queue_size=1, buff_size=2**24):
    """Initialise the ``faster_rcnn`` node, subscribe to the camera and block in spin().

    Args:
        image_topic: topic providing ``sensor_msgs/Image`` frames.
        queue_size: incoming message queue length. rospy's default (None) is
            an unbounded queue, so a slow inference callback falls further and
            further behind. With 1, each callback works on the newest frame.
        buff_size: incoming buffer size in bytes. rospy's default is 65536,
            smaller than a camera image; rospy recommends at least
            queue_size * average message size.
    """
    rospy.init_node('faster_rcnn')
    rospy.Subscriber(image_topic, Image, process_image,
                     queue_size=queue_size, buff_size=buff_size)
    rospy.spin()

In [None]:
try:
    start_node()
except rospy.ROSInterruptException as interrupt:
    # ROS shutdown interrupt: report it and let the cell finish.
    print(interrupt)