In [1]:
%pwd

'/home/ubuntu/experiments/object-counter'

In [3]:
# Entrypoint: webapp.py

from io import BytesIO
from flask import request, jsonify
from counter import config

count_action = config.get_count_action()

def object_detection():
    """Count objects in an uploaded image and return the result as JSON.

    Expects a multipart form with the image in the ``file`` part and an
    optional ``threshold`` form field (default 0.5). Both are forwarded to
    ``count_action.execute``.

    Returns:
        ``jsonify(count_response)`` on success, or a
        ``(jsonify({'error': ...}), 400)`` response when ``threshold`` cannot
        be parsed as a float.
    """
    uploaded_file = request.files['file']
    try:
        threshold = float(request.form.get('threshold', 0.5))
    except ValueError:
        # A non-numeric threshold is a client error; previously it escaped as
        # an unhandled ValueError and the client got an HTTP 500.
        return jsonify({'error': 'threshold must be a number'}), 400
    image = BytesIO()
    uploaded_file.save(image)
    # save() leaves the buffer positioned at its end; rewind so consumers
    # that do not seek on their own read the whole image.
    image.seek(0)
    count_response = count_action.execute(image, threshold)
    return jsonify(count_response)

In [None]:
# From counter/config.py
import os

from counter.adapters.count_repo import CountMongoDBRepo
from counter.adapters.object_detector import TFSObjectDetector
from counter.domain.actions import CountDetectedObjects

def get_count_action() -> CountDetectedObjects:
    """Build the CountDetectedObjects action for the current environment.

    The ``ENV`` environment variable (default ``'dev'``) selects a module-level
    factory named ``<ENV>_count_action``; e.g. ``ENV=prod`` calls
    ``prod_count_action()``.

    Raises:
        KeyError: if no callable factory with that name exists in this module.
    """
    env = os.environ.get('ENV', 'dev')
    count_action_fn = f"{env}_count_action"
    factory = globals().get(count_action_fn)
    # Reject missing names, non-callables and this dispatcher itself
    # (ENV=get would otherwise recurse until RecursionError).
    if not callable(factory) or factory is get_count_action:
        raise KeyError(f"no count action factory {count_action_fn!r} for ENV={env!r}")
    return factory()

def prod_count_action() -> CountDetectedObjects:
    """Wire the production action: TF Serving detector plus MongoDB count repo.

    Connection settings come from environment variables, with these defaults:
    TFS_HOST='localhost', TFS_PORT=8501, MONGO_HOST='localhost',
    MONGO_PORT=27017, MONGO_DB='prod_counter'. The served model name is
    fixed to 'rfcn'.

    Raises:
        ValueError: if TFS_PORT or MONGO_PORT is set to a non-integer value.
    """
    tfs_host = os.environ.get('TFS_HOST', 'localhost')
    # Environment values are strings while the defaults are ints; cast so the
    # ports have one type whether or not the variable is set (MongoClient
    # rejects a str port).
    tfs_port = int(os.environ.get('TFS_PORT', 8501))
    mongo_host = os.environ.get('MONGO_HOST', 'localhost')
    mongo_port = int(os.environ.get('MONGO_PORT', 27017))
    mongo_db = os.environ.get('MONGO_DB', 'prod_counter')
    return CountDetectedObjects(TFSObjectDetector(tfs_host, tfs_port, 'rfcn'),
                                CountMongoDBRepo(host=mongo_host, port=mongo_port, database=mongo_db))

In [None]:
# From counter/adapters/object_detector.py

import json
from typing import List, BinaryIO

import numpy as np
import requests
from PIL import Image

from counter.domain.models import Prediction, Box
from counter.domain.ports import ObjectDetector

class TFSObjectDetector(ObjectDetector):
    """ObjectDetector adapter backed by a TensorFlow Serving REST endpoint.

    Images are POSTed to ``http://<host>:<port>/v1/models/<model>:predict``.
    The first prediction in the response is converted into domain
    ``Prediction`` objects. Class ids are mapped to display names from
    ``counter/adapters/mscoco_label_map.json``.
    """

    def __init__(self, host, port, model):
        """
        Args:
            host: TensorFlow Serving host name.
            port: TensorFlow Serving REST port.
            model: name of the served model (used in the URL path).
        """
        self.url = f"http://{host}:{port}/v1/models/{model}:predict"
        self.classes_dict = self.__build_classes_dict()

    def predict(self, image: BinaryIO) -> List[Prediction]:
        """Run detection on ``image`` and return its predictions.

        Args:
            image: binary file-like object holding an encoded image PIL can open.

        Returns:
            One Prediction per detection reported by the model.

        Raises:
            requests.HTTPError: if the serving endpoint returns an error status.
        """
        np_image = self.__to_np_array(image)
        # Serialise with json rather than interpolating the list's Python repr,
        # which is only valid JSON by coincidence.
        payload = json.dumps({"instances": np.expand_dims(np_image, 0).tolist()})
        response = requests.post(self.url, data=payload)
        # Fail with the HTTP error itself instead of a later KeyError on 'predictions'.
        response.raise_for_status()
        predictions = response.json()['predictions'][0]
        return self.__raw_predictions_to_domain(predictions)

    @staticmethod
    def __build_classes_dict():
        """Load the label map as a ``{label id: display name}`` dict.

        The path is relative, so it is resolved against the current working
        directory (presumably the project root; verify for other entrypoints).
        """
        with open('counter/adapters/mscoco_label_map.json', encoding='utf-8') as json_file:
            labels = json.load(json_file)
        return {label['id']: label['display_name'] for label in labels}

    @staticmethod
    def __to_np_array(image: BinaryIO):
        """Decode ``image`` into a ``(height, width, 3)`` uint8 RGB array.

        Converting to RGB first lets grayscale, palette and RGBA images
        (e.g. PNGs with transparency) through. Their raw pixel data does not
        have 3 channels, so reshaping it directly raised ValueError.
        """
        pil_image = Image.open(image)
        return np.array(pil_image.convert('RGB'), dtype=np.uint8)

    def __raw_predictions_to_domain(self, raw_predictions: dict) -> List[Prediction]:
        """Convert one raw TF Serving prediction dict into domain Predictions.

        Reads the first ``num_detections`` entries of ``detection_boxes``,
        ``detection_scores`` and ``detection_classes``. Each box is indexed as
        ``[ymin, xmin, ymax, xmax]`` and re-ordered into the Box fields.

        Raises:
            KeyError: if a detected class id is not in the label map.
        """
        num_detections = int(raw_predictions.get('num_detections'))
        boxes = raw_predictions['detection_boxes']
        scores = raw_predictions['detection_scores']
        classes = raw_predictions['detection_classes']
        predictions = []
        for i in range(num_detections):
            detection_box = boxes[i]
            box = Box(xmin=detection_box[1], ymin=detection_box[0], xmax=detection_box[3], ymax=detection_box[2])
            class_name = self.classes_dict[classes[i]]
            predictions.append(Prediction(class_name=class_name, score=scores[i], box=box))
        return predictions


In [None]:
from typing import List

from pymongo import MongoClient

from counter.domain.models import ObjectCount
from counter.domain.ports import ObjectCountRepo

class CountMongoDBRepo(ObjectCountRepo):
    """ObjectCountRepo that keeps cumulative per-class counts in MongoDB.

    Documents live in the ``counter`` collection of ``database``. Each holds
    an ``object_class`` and its running ``count``.
    """

    def __init__(self, host, port, database):
        self.__host = host
        self.__port = port
        self.__database = database
        # MongoClient is created on first use and then reused (see __get_counter_col).
        self.__client = None

    def __get_counter_col(self):
        """Return the ``counter`` collection, creating the MongoClient on first use.

        PyMongo recommends one long-lived client per process. The previous code
        built a new client on every call and never closed it, so each
        read/update leaked a client and its connection pool.
        """
        if self.__client is None:
            self.__client = MongoClient(self.__host, self.__port)
        return self.__client[self.__database].counter

    def read_values(self, object_classes: List[str] = None) -> List[ObjectCount]:
        """Return stored counts, optionally restricted to ``object_classes``.

        A falsy ``object_classes`` (None or an empty list) returns every class.
        """
        counter_col = self.__get_counter_col()
        query = {"object_class": {"$in": object_classes}} if object_classes else None
        return [ObjectCount(doc['object_class'], doc['count']) for doc in counter_col.find(query)]

    def update_values(self, new_values: List[ObjectCount]):
        """Add each value's count to its class total, inserting unseen classes (upsert)."""
        counter_col = self.__get_counter_col()
        for value in new_values:
            counter_col.update_one({'object_class': value.object_class},
                                   {'$inc': {'count': value.count}},
                                   upsert=True)

    def close(self):
        """Close the cached MongoClient, if one was created; a later call reopens it."""
        if self.__client is not None:
            self.__client.close()
            self.__client = None


In [None]:
# From counter/domain/actions.py

from PIL import Image

from counter.debug import draw
from counter.domain.models import CountResponse
from counter.domain.ports import ObjectDetector, ObjectCountRepo
from counter.domain.predictions import over_threshold, count

class CountDetectedObjects:
    """Use case: count the objects detected in an image and keep running totals.

    Predictions come from the injected ObjectDetector. Those kept by
    ``over_threshold`` are counted and added to the ObjectCountRepo. The
    repo's totals are then read back into the response.
    """

    def __init__(self, object_detector: ObjectDetector, object_count_repo: ObjectCountRepo):
        self.__object_detector = object_detector
        self.__object_count_repo = object_count_repo

    def execute(self, image, threshold) -> CountResponse:
        """Detect, count and persist the objects in ``image``.

        Args:
            image: image handed to the detector (and re-opened with PIL for debug output).
            threshold: minimum score forwarded to ``over_threshold``.

        Returns:
            CountResponse with this image's counts and the repo totals, read
            after this image's counts were added.
        """
        repo = self.__object_count_repo
        current = count(self.__find_valid_predictions(image, threshold))
        repo.update_values(current)
        return CountResponse(current_objects=current, total_objects=repo.read_values())

    def __find_valid_predictions(self, image, threshold):
        """Return, as a list, the detector's predictions kept by ``over_threshold``.

        Both the full and the filtered prediction sets are passed to the debug
        hook.
        """
        raw = self.__object_detector.predict(image)
        self.__debug_image(image, raw, "all_predictions.jpg")
        kept = [prediction for prediction in over_threshold(raw, threshold=threshold)]
        self.__debug_image(image, kept, f"valid_predictions_with_threshold_{threshold}.jpg")
        return kept

    @staticmethod
    def __debug_image(image, predictions, image_name):
        """Pass the PIL-opened ``image``, ``predictions`` and ``image_name`` to ``draw``.

        Does nothing when Python runs with -O (``__debug__`` is False) or when
        ``image`` is None.
        """
        if not __debug__ or image is None:
            return
        draw(predictions, Image.open(image), image_name)