In [92]:
from torchvision.io.image import read_image
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models import ResNet50_Weights
import torch
from PIL import Image
from torchvision import transforms

In [134]:
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = fasterrcnn_resnet50_fpn(weights=weights, num_classes=len(weights.meta["categories"]), weights_backbone=ResNet50_Weights.DEFAULT).to(device)
model.eval()
preprocess = weights.transforms()

In [138]:
img = Image.open("test.jpeg").convert('RGB')
img_tensor_i = transforms.Compose([
    transforms.PILToTensor()
])(img)
img_tensor_f = transforms.ToTensor()(img)
prep_img = preprocess(img_tensor_f).unsqueeze_(0)
prediction = model(prep_img)[0]
threshold = 0.75
prediction['boxes'] = prediction['boxes'][prediction['scores'] > threshold]
prediction['labels'] = prediction['labels'][prediction['scores'] > threshold]
prediction['scores'] = prediction['scores'][prediction['scores'] > threshold]
prediction

{'boxes': tensor([[692.9408, 286.2021, 935.4335, 815.8149],
         [313.8521, 456.7144, 469.8397, 810.5626],
         [534.9616, 447.1707, 604.5193, 561.5443],
         [646.1714, 398.4750, 714.6689, 717.5959],
         [583.5797, 393.5256, 660.8668, 729.6469],
         [686.7581, 728.0333, 742.6044, 823.1209],
         [313.1302, 494.9341, 391.5665, 631.1080]], grad_fn=<IndexBackward0>),
 'labels': tensor([ 1,  1, 31,  1,  1, 31,  1]),
 'scores': tensor([0.9987, 0.9859, 0.9841, 0.9770, 0.9753, 0.9714, 0.8051],
        grad_fn=<IndexBackward0>)}

In [143]:
labels = ["{}: {:.2f}%".format(weights.meta["categories"][prediction["labels"][i]], round(float(prediction['scores'][i]) * 100, 0)) for i in range(len(prediction["labels"]))]
box = draw_bounding_boxes(image = img_tensor_i, boxes=prediction["boxes"],
                          labels=labels,
                          colors=(0, 255, 42),
                          width=2, 
                          font_size=17,
                          font='Arial')

im = to_pil_image(box.detach())

['person - [692.9407958984375, 286.20208740234375, 935.4334716796875, 815.81494140625] - 100.00%',
 'person - [313.85205078125, 456.7144470214844, 469.8396911621094, 810.5626220703125] - 99.00%',
 'handbag - [534.9616088867188, 447.1706848144531, 604.519287109375, 561.5443115234375] - 98.00%',
 'person - [646.17138671875, 398.4749755859375, 714.6688842773438, 717.595947265625] - 98.00%',
 'person - [583.5797119140625, 393.52557373046875, 660.8667602539062, 729.6469116210938] - 98.00%',
 'handbag - [686.758056640625, 728.0333251953125, 742.6043701171875, 823.1209106445312] - 97.00%',
 'person - [313.13018798828125, 494.93414306640625, 391.56646728515625, 631.1080322265625] - 81.00%']