In [2]:
! pip install requests tensorflow_hub tensorflow_hub websockets pillow

Collecting websockets
  Downloading websockets-13.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading websockets-13.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (157 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m157.3/157.3 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: websockets
Successfully installed websockets-13.0.1


In [None]:
import argparse
import os
import tarfile
import json
import tempfile

import requests
import tensorflow as tf
import time
from PIL import Image
import requests
import tensorflow_hub as hub
import websockets.sync.client as websocket
from tqdm import tqdm




def run_detector(img, detector):
    img_tensor = tf.convert_to_tensor(img)
    img_tensor = tf.image.convert_image_dtype(img_tensor, tf.float32)[tf.newaxis, ...]
    result = detector(img_tensor)
    result = {key: value.numpy() for key, value in result.items()}
    return result


def parse(result):
    boxes = result["detection_boxes"]
    classes = result["detection_class_entities"]
    scores = result["detection_scores"]

    objects = []
    for i in range(len(boxes)):
        if scores[i] < 0.5:
            continue

        objects.append(
            {
                "box": boxes[i].tolist(),
                "class": classes[i].decode("utf-8"),
                "score": float(scores[i]),
            }
        )

    return objects


def process(ws: websocket.ClientConnection, url: str, detector: str):
    def send(data):
        ws.send(json.dumps(data))

    send(["process", None])
    item = ws.recv()

    with tempfile.SpooledTemporaryFile() as file:
        response = requests.get(f'{url if "http" in url else f"http://{url}"}/file/keyframes/{item}')
        file.write(response.content)

        image = Image.open(file).convert("RGB")
        start = time.time()
        result = run_detector(image, detector)
        print(f"Processed {item} in {(time.time() - start)}s")
        objects = parse(result)

        send(["finish", objects])

def download_model(model):
    response = requests.get(f"https://tfhub.dev/{model}?tf-hub-format=compressed", stream=True)
    total_size = int(response.headers.get("content-length", 0))
    chunk_size = 1024 * 1024

    with open("model.tar.gz", "wb") as file:
        for data in tqdm(response.iter_content(chunk_size=chunk_size),
                         total=total_size // chunk_size,
                         unit="MB",
                         unit_scale=True,
                         ascii=True,
                         desc="Downloading model"):
            file.write(data)

    with tarfile.open("model.tar.gz", "r:gz") as file:
        file.extractall("model")
    os.remove("model.tar.gz")
    print(f"Downloaded model {model}")

    return


def main():
    # parser = argparse.ArgumentParser(description='AIC Swarm Manager')
    # parser.add_argument('--url',
    #                     type=str,
    #                     default='localhost:8000',
    #                     help='The URL to connect to')
    # parser.add_argument("--model",
    #                     type=str,
    #                     default="google/faster_rcnn/openimages_v4/inception_resnet_v2/1",
    #                     help="The model to use")

    # args = parser.parse_args()
    url = "35.187.254.228:8000"
    model = "google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
    ws_url = f"ws://{url}/session"

    if not url or not model:
        raise ValueError("URL and model must be provided")

    if not os.path.exists("model"):
        download_model(model)

    model = hub.load("model")
    detector = model.signatures["default"]
    print(f"Loaded model {model}")

    processed = 0
    with websocket.connect(ws_url) as ws:
        print("Connected to server")
        while True:
            try:
                process(ws, url, detector)
                processed += 1
            except Exception as error:
                print(f"Processed {processed} items, good job :D")
                print("Find error, raising....")
                raise error


if __name__ == "__main__":
    main()

Downloading model: 230MB [00:32, 7.13MB/s]                         


Downloaded model google/faster_rcnn/openimages_v4/inception_resnet_v2/1
Loaded model <tensorflow.python.trackable.autotrackable.AutoTrackable object at 0x784bd0994430>
Connected to server
Processed L12/L12_V005/034.jpg in 54.605309009552s
Processed L12/L12_V005/054.jpg in 1.2488927841186523s
Processed L12/L12_V005/055.jpg in 1.2894234657287598s
Processed L12/L12_V005/058.jpg in 1.4598002433776855s
Processed L12/L12_V005/059.jpg in 1.5131864547729492s
Processed L12/L12_V005/061.jpg in 1.2795250415802002s
Processed L12/L12_V005/062.jpg in 1.278784990310669s
Processed L12/L12_V005/063.jpg in 1.2688732147216797s
Processed L12/L12_V005/064.jpg in 1.267535924911499s
Processed L12/L12_V005/066.jpg in 1.261793613433838s
Processed L12/L12_V005/069.jpg in 1.2511029243469238s
Processed L12/L12_V005/070.jpg in 1.2734308242797852s
Processed L12/L12_V005/071.jpg in 1.2715458869934082s
Processed L12/L12_V005/072.jpg in 1.4838438034057617s
Processed L12/L12_V005/073.jpg in 1.2645843029022217s
Processe