In [1]:
import kafka
from kafka import KafkaProducer, KafkaConsumer
from kafka.errors import kafka_errors
import traceback
import json
import sys
import socket
import cv2
import os
import base64
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display
import numpy as np
from bigdl.serving.schema import *
import time

In [2]:
# Kafka broker address used by get_result() when it builds its OutputQueue:
# the broker is looked up under this machine's own hostname.
host_name = socket.gethostname()
# Broker port, kept as a string; OutputQueue joins it with the host into the
# "host:port" bootstrap address.
port = "9092"

In [3]:
class OutputQueue:
    """Reader for serving results published on a Kafka topic.

    Wraps a ``KafkaConsumer`` (stored as ``self.db``) subscribed to a single
    topic and turns each record value -- base64 text wrapping an Arrow IPC
    stream -- into numpy array(s).
    """

    def __init__(self, host=None, port=None, group_id='group-1',
                 auto_offset_reset='earliest', **kwargs):
        """Create the underlying consumer.

        Args:
            host: Broker host; a falsy value falls back to "localhost".
            port: Broker port, as a string or an int; a falsy value falls
                back to "9092".
            group_id: Consumer group id forwarded to ``KafkaConsumer``.
            auto_offset_reset: Offset reset policy forwarded to
                ``KafkaConsumer``.
            **kwargs: ``topic_name`` (optional) selects the topic and
                defaults to "cluster-serving_serving_stream"; every other
                entry is forwarded to ``KafkaConsumer`` unchanged.
        """
        host = host or "localhost"
        port = port or "9092"
        # topic_name is consumed here so it is not forwarded to KafkaConsumer
        # as an unknown configuration key.
        self.topic_name = (kwargs.pop("topic_name", None)
                           or "cluster-serving_serving_stream")
        # Formatting instead of "+" concatenation lets callers pass the port
        # as an int without hitting a TypeError.
        self.db = KafkaConsumer(self.topic_name,
                                bootstrap_servers="{}:{}".format(host, port),
                                group_id=group_id,
                                auto_offset_reset=auto_offset_reset,
                                **kwargs)

    def dequeue(self):
        """Poll the topic once (500 ms) and decode whatever arrived.

        Returns:
            dict mapping each record's decoded key to the value returned by
            ``get_ndarray_from_b64``; empty when the poll returned nothing.

        Raises:
            Whatever decoding raises; in that case no offsets are committed
            by this call.
        """
        records = self.db.poll(timeout_ms=500)
        decoded = {}
        for messages in records.values():
            for message in messages:
                res_id = message.key.decode()
                decoded[res_id] = self.get_ndarray_from_b64(message.value.decode())
        # Commit only once every polled record decoded cleanly, so a decode
        # failure does not mark unprocessed records as consumed.
        self.db.commit()
        return decoded

    def get_ndarray_from_b64(self, b64str):
        """Decode a base64 string holding an Arrow IPC stream.

        Args:
            b64str: Base64 text (str or bytes) of the serialized stream.

        Returns:
            A single ndarray when the stream holds exactly one record batch,
            otherwise a list with one ndarray per batch.

        Raises:
            AssertionError: if the stream contains no record batch.
        """
        # NOTE(review): `pa` is not imported by name in this notebook;
        # presumably it arrives through `from bigdl.serving.schema import *`
        # -- confirm, or import pyarrow explicitly.
        raw = base64.b64decode(b64str)
        buffer = pa.BufferReader(raw).read_buffer()
        batches = list(pa.ipc.open_stream(buffer))
        assert len(batches) > 0
        if len(batches) == 1:
            return self.get_ndarray_from_record_batch(batches[0])
        return [self.get_ndarray_from_record_batch(batch) for batch in batches]

    def get_ndarray_from_record_batch(self, record_batch):
        """Rebuild an ndarray from one record batch.

        Column 0 supplies the flat data (via ``to_numpy``) and column 1 the
        shape entries (via ``to_pylist``).

        Returns:
            The data reshaped to the recovered shape.
        """
        data = record_batch[0].to_numpy()
        # Keep only truthy shape entries. NOTE(review): presumably the shape
        # column is padded with nulls up to the data length -- confirm; as
        # written, a genuine 0-sized dimension would be dropped as well.
        shape = [dim for dim in record_batch[1].to_pylist() if dim]
        return data.reshape(shape)

    def close(self):
        """Close the underlying Kafka consumer."""
        self.db.close()

In [4]:
def get_result(timeout=5):
    """Drain serving results from Kafka until the stream goes quiet.

    Builds an ``OutputQueue`` on the notebook-level ``host_name`` / ``port``
    and polls it in a loop. Every non-empty batch is printed and merged into
    a pending buffer, which is reset once it holds more than 6 entries (the
    point where results would be displayed). The loop ends when no result
    has arrived for more than ``timeout`` seconds and the pending buffer is
    empty.

    Args:
        timeout: Idle time in seconds tolerated before giving up.

    Returns:
        None. A summary line with the total number of results is printed.

    Raises:
        Whatever ``OutputQueue.dequeue`` raises; the consumer is closed
        before the exception propagates.
    """
    # Monotonic clock: the idle timer must not jump when the wall clock is
    # adjusted.
    last_activity = time.monotonic()
    output_api = OutputQueue(host=host_name, port=port)
    results_queue = {}
    total_num = 0
    try:
        while True:
            result = output_api.dequeue()
            if result:
                total_num += len(result)
                last_activity = time.monotonic()
                # result = process_result(result)  # post-processing hook
                print("获取结果:", result)
                results_queue.update(result)
                if len(results_queue) > 6:  # flush once more than 6 results are pending
                    # show_images(results_queue)  # display the results
                    results_queue = {}
            elif time.monotonic() - last_activity > timeout:
                if not results_queue:
                    break
                # Idle with leftovers: flush them now; the next idle poll
                # then ends the loop.
                # show_images(results_queue)  # display the results
                results_queue = {}
    finally:
        # Release the consumer even when dequeue() raises or the cell is
        # interrupted, so no connection is left open in the kernel.
        output_api.close()
    print("Task completed successfully, total num is : {}".format(total_num))

In [5]:
# Run the polling loop with the default idle timeout (timeout=5 seconds).
get_result()

Task completed successfully, total num is : 0
