In [1]:
import numpy as np
import json 
import os   
import random    
import cv2   
import time     
import tensorflow as tf
import falcon   
from falcon_multipart.middleware import MultipartMiddleware   
from PIL import Image   
from darkflow.net.build import TFNet 

from keras.models import model_from_json   
from keras.preprocessing import image  
from keras.preprocessing.image import img_to_array, load_img  
from keras.preprocessing.image import ImageDataGenerator 
from keras.optimizers import SGD     

from keras.applications.vgg16 import VGG16 
from keras.applications.vgg16 import preprocess_input   
from keras.applications.vgg16 import decode_predictions   

from waitress import serve 
import shutil 
import glob


from keras import backend as K     
# Create a TensorFlow session whose GPU options have `allow_growth` enabled
# and register it as the session used by the Keras backend.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

In [2]:
# Allow CORS
class CORSMiddleware:
    """Falcon middleware that sets a wildcard CORS origin header on responses."""

    def process_request(self, req, res):
        """Set ``Access-Control-Allow-Origin: *`` on ``res``.

        ``req`` is part of the middleware signature but is not used here.
        """
        header_name, header_value = 'Access-Control-Allow-Origin', '*'
        res.set_header(header_name, header_value)

In [3]:
class PostImage(object):
    """Falcon resource that detects signboards in an uploaded image and classifies them.

    Attributes:
        tfnet: darkflow ``TFNet`` detector built from the YOLOv2 signboard
            config and weights.
        vgg16: ``VGG16_tabelog`` classifier applied to every detection.
    """

    def __init__(self):
        # Initialise darkflow.
        options = {
            "model" : "cfg/yolov2-signboard0731.cfg",
            "load" : "bin/signboard0731/yolov2-signboard0731_10000.weights",
            "threshold" : 0.1,
            "gpu" : 0.9
        }
        self.tfnet = TFNet(options)

        # Initialise VGG16.
        self.vgg16 = VGG16_tabelog()

    def on_post(self, req, res):
        """Detect and classify signboards in the uploaded image.

        Reads the multipart field named ``file``, decodes it with OpenCV, runs
        the darkflow detector and passes every detection through
        ``self.vgg16.predict``. The resulting list is serialised with
        ``MyEncoder`` and written to ``res.body`` with status 200.

        Raises:
            falcon.HTTPBadRequest: if there is no ``file`` upload, or its
                content is empty or cannot be decoded as an image.
        """
        upload = req.get_param('file')
        # A missing field yields None and a plain (non-file) form field has no
        # ``.file`` attribute; both used to surface as an AttributeError (500).
        stream = getattr(upload, 'file', None)
        if stream is None:
            raise falcon.HTTPBadRequest(
                title='Missing file',
                description='Send the image as a multipart field named "file".'
            )

        data = stream.read()
        arr = np.frombuffer(data, dtype=np.uint8)
        # Skip imdecode for an empty buffer; imdecode itself returns None when
        # the bytes are not a decodable image.
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR) if arr.size else None
        if img is None:
            raise falcon.HTTPBadRequest(
                title='Invalid image',
                description='The uploaded file could not be decoded as an image.'
            )

        predict = self.tfnet.return_predict(img)
        print(predict)

        result = [self.vgg16.predict(img, item) for item in predict]

        result_json = json.dumps(result, cls=MyEncoder)
        print(result_json)
        res.status = falcon.HTTP_200
        res.body = result_json

    # GET is not used.
    def on_get(self, req, res):
        """Reply with a fixed JSON message telling the client to POST an image."""
        res.body = '{"message": "画像をPOSTしてください．"}'
        res.status = falcon.HTTP_200

In [6]:
class VGG16_tabelog:
    """Keras classifier that assigns a store label to a detected signboard crop.

    Attributes:
        batch_size: Stored as 32; not read anywhere in this class.
        file_name: Path prefix of the model files (``.json`` architecture and
            ``.h5`` weights).
        label: Sorted names of the entries under the training directory; the
            index of a name is the class index used with ``np.argmax``.
        zero_list: List of zeros, one per label.
        node: Dict mapping every label to 0.
        url: Dict mapping every label to ``"https://tabelog.com/"`` followed by
            the label with ``_`` replaced by ``/``.
        model: The loaded and compiled Keras model.
        graph: TensorFlow default graph captured after the model was loaded.
    """

    def __init__(self):
        # The model and graph are also published as module-level names because
        # earlier versions of this class exposed them only that way.
        global model
        global graph

        self.batch_size = 32
        self.file_name = "../signboard_classifier/vgg16_tabelogimg_cv2out_150"
        train_dirs = sorted(glob.glob("../signboard_classifier/tabelogimg_cv2out_150/train/*"))
        # basename (rather than splitting on "/") also works when glob returns
        # OS-specific separators.
        self.label = [os.path.basename(path) for path in train_dirs]

        self.zero_list = [0] * len(self.label)
        self.node = dict(zip(self.label, self.zero_list))
        self.url = dict(
            (name, "https://tabelog.com/" + name.replace("_", "/"))
            for name in self.label
        )

        # Load the model; the context manager closes the JSON file afterwards.
        with open(self.file_name + ".json") as json_file:
            json_string = json_file.read()
        self.model = model_from_json(json_string)
        self.model.load_weights(self.file_name + ".h5")
        self.model.compile(
            optimizer = SGD(lr = 0.0001, momentum = 0.9),
            loss = "categorical_crossentropy",
            metrics = ["accuracy"]
        )
        # Keep the graph so predict() can re-enter it explicitly.
        self.graph = tf.get_default_graph()

        model = self.model
        graph = self.graph

    def predict(self, src, _item):
        """Classify the region of ``src`` described by ``_item``.

        Args:
            src: Image array indexed as ``[y, x, channel]``. The channel order
                is reversed before classification (presumably BGR from OpenCV
                converted to RGB; confirm against the caller).
            _item: Detection dict with ``'topleft'`` and ``'bottomright'``
                entries, each holding ``'x'`` and ``'y'``.

        Returns:
            The same ``_item`` dict, updated in place with ``"name"`` (the
            predicted label), ``"node"`` and ``"url"``.
        """
        item = _item

        tlx = item['topleft']['x']
        tly = item['topleft']['y']
        brx = item['bottomright']['x']
        bry = item['bottomright']['y']

        dst = src[tly:bry, tlx:brx]
        temp_img = cv2.resize(dst, (224, 224))
        temp_img = Image.fromarray(temp_img[:, :, ::-1].copy())

        x = img_to_array(temp_img)
        x = x.astype("float32")/255.0
        x = x.reshape((1, 224, 224, 3))

        # Use this instance's own model and graph, not module globals that
        # another instance or notebook cell could have rebound.
        with self.graph.as_default():
            img_pred = self.model.predict(x)

        name = self.label[np.argmax(img_pred)]
        print(name)
        item["name"] = name
        item["node"] = self.node[name]
        item["url"] = self.url[name]

        return item

In [4]:
class MyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types."""

    def default(self, obj):
        """Return a JSON-serialisable equivalent of ``obj``.

        NumPy arrays become nested lists, NumPy floating scalars become
        ``float`` and NumPy integer scalars become ``int``. Anything else is
        passed to ``json.JSONEncoder.default``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        return super(MyEncoder, self).default(obj)

In [2]:
# Build the Falcon application and register the image resource on /gunicorn.
app = falcon.API(
    middleware=[
        CORSMiddleware(),
        # presumably what makes req.get_param('file') return the upload
        MultipartMiddleware(),
    ],
)
app.add_route(
    '/gunicorn',
    PostImage(),
)

In [5]:
# Serve the Falcon app with waitress, listening on 0.0.0.0:18000.
# NOTE(review): serve() presumably blocks until interrupted, so the cells
# below would not run under "Run All" — confirm.
serve(app, listen='0.0.0.0:18000')

In [6]:
# Image paths
# Root folder whose direct children are listed as stores below.
tabelog_img_input_path = ""
# Root folder used by the later cells for the cropped signboard images.
tabelog_img_output_path = ""

# Every entry directly under the input root, in sorted order.
store_list = glob.glob(tabelog_img_input_path + "/*")
store_list.sort()

In [4]:
# Extract signboard regions
#
# For every image of every store, run the YOLO detector and save each
# sufficiently confident detection as a cropped image under
# <tabelog_img_output_path>/<store name>/.
#
# NOTE: `tfnet` must already exist in the kernel; in this notebook it is
# created in the YOLO configuration cell further down.

# Only detections with a confidence strictly above this value are saved.
MIN_CONFIDENCE = 0.8

for store_dir in store_list:
    store_name = os.path.basename(os.path.normpath(store_dir))
    # os.path.join gives the same location whether or not
    # tabelog_img_output_path ends with a separator. Plain concatenation wrote
    # to "<root><store>/" when it did not, which the later cells (globbing
    # root + "/*") never see.
    store_out_dir = os.path.join(tabelog_img_output_path, store_name)

    for img_path in sorted(glob.glob(store_dir + "/*")):
        print(img_path)
        imgcv = cv2.imread(img_path)
        if not isinstance(imgcv, np.ndarray):
            # cv2.imread did not return an image for this file; skip it.
            continue

        result = tfnet.return_predict(imgcv)
        for i, detection in enumerate(result):
            if not MIN_CONFIDENCE < detection['confidence']:
                continue

            tlx = detection['topleft']['x']
            tly = detection['topleft']['y']
            brx = detection['bottomright']['x']
            bry = detection['bottomright']['y']

            imgcv_out = imgcv[tly:bry, tlx:brx]
            if imgcv_out.size == 0:
                # A zero-area box cannot be written as an image.
                continue

            if not os.path.exists(store_out_dir):
                os.makedirs(store_out_dir)

            # The detection index is prefixed so several crops from the same
            # source image do not overwrite each other.
            out_name = str(i) + os.path.basename(img_path)
            cv2.imwrite(os.path.join(store_out_dir, out_name), imgcv_out)

In [11]:
# Discard stores that have too few images
import shutil

# Store folders with fewer entries than this are deleted.
MIN_IMAGES_PER_STORE = 20
# Folders created by the dataset-split cell; they hold one sub-folder per
# store, not images, so they must never be judged by their entry count.
SPLIT_DIR_NAMES = ("train", "validation", "test")

# With an empty root the pattern below becomes "/*", i.e. the filesystem
# root, and small top-level directories would be deleted recursively.
if not tabelog_img_output_path:
    raise ValueError("tabelog_img_output_path is empty; set it before pruning stores")

output_store_list = sorted(glob.glob(tabelog_img_output_path+"/*"))
for store_dir in output_store_list:
    if os.path.basename(store_dir) in SPLIT_DIR_NAMES:
        continue
    if not os.path.isdir(store_dir):
        # shutil.rmtree only works on directories; leave stray files alone.
        continue
    if len(glob.glob(store_dir + "/*")) < MIN_IMAGES_PER_STORE:
        shutil.rmtree(store_dir)

In [3]:
# Split the dataset
#
# For every store folder under tabelog_img_output_path, move its *.jpg files
# into <root>/validation/<store>, <root>/test/<store> and <root>/train/<store>.
# The split is random and unseeded, so each run picks different files; call
# random.seed(...) beforehand if a reproducible split is needed.

import glob

VALIDATION_RATIO = 0.2
TEST_RATIO = 0.2
SPLIT_DIR_NAMES = ("validation", "test", "train")


def _move_images(img_paths, dest_dir):
    """Move every file in ``img_paths`` into ``dest_dir``, creating it if needed.

    Nothing is created when ``img_paths`` is empty.
    """
    if not img_paths:
        return
    if not os.path.exists(dest_dir):
        os.makedirs(dest_dir)
    for img_path in img_paths:
        shutil.move(img_path, dest_dir)


output_store_list=sorted(glob.glob(tabelog_img_output_path+"/*"))

for store_dir in output_store_list:
    store_name = os.path.basename(store_dir)
    if store_name in SPLIT_DIR_NAMES:
        # Output of an earlier run of this cell, not a store.
        continue

    img_list = glob.glob(store_dir + "/*.jpg")
    val_num = int(len(img_list) * VALIDATION_RATIO)
    test_num = int(len(img_list) * TEST_RATIO)
    print(val_num)

    # One shuffle, then slices: validation and test are disjoint random
    # subsets and everything left over is training data.
    shuffled = random.sample(img_list, len(img_list))
    partitions = (
        ("validation", shuffled[:val_num]),
        ("test", shuffled[val_num:val_num + test_num]),
        ("train", shuffled[val_num + test_num:]),
    )
    for split_name, split_imgs in partitions:
        _move_images(
            split_imgs,
            tabelog_img_output_path + "/" + split_name + "/" + store_name,
        )

In [33]:
# Delete stores that have fewer than 20 images (crops with confidence above 80%)
import shutil

# Store folders with fewer entries than this are deleted.
MIN_IMAGES_PER_STORE = 20
# Folders created by the dataset-split cell. They contain one sub-folder per
# store, so with fewer than 20 stores the entry-count test below would delete
# the whole split.
SPLIT_DIR_NAMES = ("train", "validation", "test")

# With an empty root the pattern below becomes "/*", i.e. the filesystem
# root, and small top-level directories would be deleted recursively.
if not tabelog_img_output_path:
    raise ValueError("tabelog_img_output_path is empty; set it before pruning stores")

output_store_list = sorted(glob.glob(tabelog_img_output_path+"/*"))
for store_dir in output_store_list:
    if os.path.basename(store_dir) in SPLIT_DIR_NAMES:
        continue
    if not os.path.isdir(store_dir):
        # shutil.rmtree only works on directories; leave stray files alone.
        continue
    if len(glob.glob(store_dir + "/*")) < MIN_IMAGES_PER_STORE:
        shutil.rmtree(store_dir)

In [None]:
# YOLO configuration: build the darkflow network from the signboard config
# and weight files.
options = dict(
    model="cfg/yolov2-signboard0731.cfg",
    load="bin/signboard0731/yolov2-signboard0731_10000.weights",
    threshold=0.1,
    gpu=0.9,
)
tfnet = TFNet(options)

In [None]:
# Load one image and run the detector on it as a quick check.
img_path = ""
imgcv = cv2.imread(img_path)

# cv2.imread returns None (not an ndarray) when the file cannot be read.
image_loaded = isinstance(imgcv, np.ndarray)
if image_loaded:
    print("true")
else:
    print("false")
print(type(imgcv))

# Only run the detector on a real image; previously an unreadable path sent
# None into tfnet.return_predict. An empty result keeps the next cell usable.
result = tfnet.return_predict(imgcv) if image_loaded else []
print(result)

In [None]:
# Crop every detected box out of the image; after the loop `imgcv_out` holds
# the crop of the last detection.
for i, detection in enumerate(result):
    top_left = detection['topleft']
    bottom_right = detection['bottomright']
    tlx, tly = top_left['x'], top_left['y']
    brx, bry = bottom_right['x'], bottom_right['y']

    imgcv_out = imgcv[tly:bry, tlx:brx]