In [7]:
import logging
import math
import os
import random

import torch
from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes
from tqdm import tqdm

def batched(array, size):
    """Yield successive lists of at most ``size`` items taken from ``array``.

    Args:
        array: Any iterable; it is consumed lazily, one item at a time.
        size: Maximum number of items per yielded list.

    Yields:
        Lists of ``size`` consecutive items; the final list may be shorter.
        Nothing is yielded when ``array`` is empty or ``size`` is smaller
        than 1.

    Exceptions raised while iterating ``array`` propagate to the caller
    instead of being silently swallowed (which would truncate the data).
    """
    if size < 1:
        return
    b = []
    for item in array:
        b.append(item)
        if len(b) >= size:
            yield b
            # Start a fresh list so batches already handed out are not mutated.
            b = []
    if b:
        # Trailing partial batch.
        yield b

def annotate_and_save(img, prediction, save_file, categories):
    """Draw the predicted boxes and their class names on ``img`` and save the result.

    Args:
        img: Image passed to ``draw_bounding_boxes`` as the drawing canvas.
        prediction: Mapping providing ``"boxes"`` and ``"labels"`` entries.
        save_file: Destination handed to the resulting PIL image's ``save``.
        categories: Lookup indexed by each entry of ``prediction["labels"]``
            to obtain the text drawn next to the corresponding box.
    """
    class_names = []
    for label_id in prediction["labels"]:
        class_names.append(categories[label_id])

    # Use the Helvetica font file from the working directory when it exists;
    # otherwise pass None as the font.
    font_file = "Helvetica.ttf"
    if not os.path.exists(font_file):
        font_file = None

    annotated = draw_bounding_boxes(
        image=img,
        boxes=prediction['boxes'],
        labels=class_names,
        colors="red",
        width=4,
        font=font_file,
        font_size=30,
    )
    to_pil_image(annotated.detach()).save(save_file)

def detect_annotate_save(images, save_files, preprocess, model, categories):
    """Run ``model`` on a batch of images and save one annotated copy per image.

    Args:
        images: Images of one batch; each is preprocessed for the model and
            also used, unmodified, as the canvas for the annotations.
        save_files: Output destinations, paired with ``images`` by position.
        preprocess: Callable applied to every image before inference.
        model: Callable receiving the list of preprocessed images and
            returning one prediction per image, in the same order.
        categories: Label lookup forwarded to ``annotate_and_save``.
    """
    processed_images = [preprocess(image) for image in images]
    # Pure inference: disable autograd so no graph is recorded for the forward
    # pass, which saves memory and time.
    with torch.no_grad():
        predictions = model(processed_images)
    for image, prediction, fname in zip(images, predictions, save_files):
        annotate_and_save(image, prediction, fname, categories)

def detection2D(files, output_dir='output', batch_size=5):
    """Detect objects in image files and save annotated copies to ``output_dir``.

    A Faster R-CNN (ResNet-50 FPN v2) model with its default weights and a
    box score threshold of 0.9 is loaded, the images are processed in batches
    and each annotated image is written to ``output_dir`` under the base name
    of its input file. Inputs that share a base name therefore overwrite each
    other.

    Args:
        files: Iterable of image file paths.
        output_dir: Directory receiving the annotated images; created
            (including missing parents) when it does not exist.
        batch_size: Number of images passed to the model per call; must be
            at least 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before touching the filesystem or loading the model.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size!r}")

    logger = logging.getLogger("detection2D")
    if not os.path.isdir(output_dir):
        logger.warning(f"Output directory '{output_dir}' was not found. Creating directory '{output_dir}'")
        # makedirs also creates missing parent directories; exist_ok tolerates
        # the directory appearing between the check above and this call.
        os.makedirs(output_dir, exist_ok=True)

    # Pair every input path with its output path. Only paths are kept here;
    # images are decoded one batch at a time further down so that memory use
    # is bounded by the batch size rather than by the number of files.
    jobs = [(f, os.path.join(output_dir, os.path.basename(f))) for f in files]

    logger.info("Loading model")
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()
    preprocess = weights.transforms()
    categories = weights.meta["categories"]

    logger.info("Starting detection")
    for batch in tqdm(batched(jobs, batch_size), total=math.ceil(len(jobs) / batch_size)):
        paths, save_names = zip(*batch)
        images = [decode_image(path) for path in paths]
        detect_annotate_save(images, save_names, preprocess, model, categories)
    
    

In [4]:
logging.basicConfig(level=logging.INFO)
# Collect the regular files inside "pics". Sorting makes the processing order
# reproducible (os.listdir returns entries in arbitrary order) and
# sub-directories are skipped because they cannot be decoded as images.
files = sorted(
    path
    for path in (os.path.join("pics", name) for name in os.listdir("pics"))
    if os.path.isfile(path)
)
detection2D(files, output_dir="output2", batch_size=2)

INFO:detection2D:Loading model
INFO:detection2D:Starting detection
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:10<00:00, 10.45s/it]
