# インポート

In [2]:
# Standard-library and numerical imports.
import numpy as np
import os, random
# Project-local helper; called later to split the custom dataset into train / validation / test directories.
from utils.split_dataset import split_dataset

# TFLite Model Maker pieces used by this notebook: export formats, model specs, object detector.
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
# Fail fast if the installed TensorFlow is not a 2.x release.
assert tf.__version__.startswith('2')

# Restrict the TensorFlow and absl loggers to ERROR-level messages.
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

In [2]:
# Directory layout used by the notebook (all paths are relative to the working directory).
dir_input = 'inputs'
dir_input_sample = os.path.join(dir_input, "sample")
dir_input_dataset = os.path.join(dir_input, "dataset")
# Image / annotation folders handed to split_dataset() when use_custom_dataset is True.
# They were referenced by the data-preparation cell but never defined, which raised a
# NameError on a fresh kernel.
# NOTE(review): assumes the custom dataset is stored as inputs/dataset/images and
# inputs/dataset/annotations — TODO confirm and adjust to the real layout.
dir_images = os.path.join(dir_input_dataset, "images")
dir_annotations = os.path.join(dir_input_dataset, "annotations")
# Destination passed as out_path to split_dataset() (custom-dataset path only).
dir_processing = 'processing'
# Destination folder for the exported model and label file.
dir_output = 'outputs'

In [3]:
# Sample-dataset settings.
# use_custom_dataset = False -> load the sample CSV; True -> load the Pascal VOC custom dataset.
use_custom_dataset = False

# CSV describing the sample dataset.
# NOTE(review): the original author marked this file as "needs conversion" — confirm what conversion is required.
file_sample = os.path.join(dir_input_sample, "sample.csv")

# File names used when exporting the trained model and its labels.
file_output_tflite = "efficientdet-lite-salad.tflite"
file_output_labels = "salad-labels.txt"

# データ準備

In [4]:
if use_custom_dataset:
    # Custom dataset: split the raw images/annotations into train / validation / test
    # directories, then load each split from its Pascal VOC annotations.
    label_map = {1: 'shiro', 2: 'others'}
    train_dir, val_dir, test_dir = split_dataset(
        dir_images,
        dir_annotations,
        val_split=0.2,
        test_split=0.2,
        out_path=dir_processing,
    )
    # Each split directory is read through its 'images' and 'annotations' sub-folders.
    train_data, validation_data, test_data = [
        object_detector.DataLoader.from_pascal_voc(
            os.path.join(split_dir, 'images'),
            os.path.join(split_dir, 'annotations'),
            label_map=label_map,
        )
        for split_dir in (train_dir, val_dir, test_dir)
    ]
else:
    # Sample dataset: a single CSV yields all three splits at once.
    train_data, validation_data, test_data = object_detector.DataLoader.from_csv(file_sample)


INFO:tensorflow:Cache will be stored in C:\Users\saito\AppData\Local\Temp\tmp8x2rhyed with prefix filename train_41e10587cc007cd10ef3e87cd861a652. Cache_prefix is C:\Users\saito\AppData\Local\Temp\tmp8x2rhyed\train_41e10587cc007cd10ef3e87cd861a652
INFO:tensorflow:Cache will be stored in C:\Users\saito\AppData\Local\Temp\tmpm4hg79kh with prefix filename val_41e10587cc007cd10ef3e87cd861a652. Cache_prefix is C:\Users\saito\AppData\Local\Temp\tmpm4hg79kh\val_41e10587cc007cd10ef3e87cd861a652
INFO:tensorflow:Cache will be stored in C:\Users\saito\AppData\Local\Temp\tmppsvsib53 with prefix filename test_41e10587cc007cd10ef3e87cd861a652. Cache_prefix is C:\Users\saito\AppData\Local\Temp\tmppsvsib53\test_41e10587cc007cd10ef3e87cd861a652
INFO:tensorflow:On image 0
INFO:tensorflow:On image 100
INFO:tensorflow:On image 0
INFO:tensorflow:On image 0


In [5]:
# Report how many examples ended up in each split.
for split_name, split_data in (
    ('train', train_data),
    ('validation', validation_data),
    ('test', test_data),
):
    print(f'{split_name} count: {len(split_data)}')

train count: 175
validation count: 25
test count: 25


# モデル選択  
`efficientdet_lite0`～`efficientdet_lite4` の 5 種類から選べる（末尾の数字 0～4 を変更する）

In [6]:
# Name of the EfficientDet-Lite variant to train; looked up through model_spec.
spec_name = 'efficientdet_lite4'
spec = model_spec.get(spec_name)

# 学習

In [None]:
# Training options, gathered in one place so they are easy to tune.
training_options = dict(
    model_spec=spec,
    validation_data=validation_data,
    epochs=1,
    batch_size=10,
    train_whole_model=True,
)
# Build and train the detector on the training split.
model = object_detector.create(train_data=train_data, **training_options)

# 検証（数値）

In [None]:
# Evaluate the trained model on the test split; the returned value is displayed as the cell output.
model.evaluate(test_data)

# 検証（画像）

In [6]:
# Collect the .jpg file names used for the visual check.
# Custom dataset: the 'images' folder under test_dir (returned by split_dataset);
# otherwise: the sample input folder. images_path is now set in both cases so later
# cells can build full paths with os.path.join(images_path, name).
if use_custom_dataset:
    images_path = os.path.join(test_dir, "images")
else:
    images_path = dir_input_sample

filenames = os.listdir(images_path)
# os.listdir() returns entries in arbitrary order; sort so the list is reproducible across runs.
list_images = sorted(name for name in filenames if name.endswith('.jpg'))

# 変換

In [10]:
# Export the trained model as a TFLite file plus a label file into dir_output.
export_formats = [ExportFormat.TFLITE, ExportFormat.LABEL]
model.export(
    export_dir=dir_output,
    tflite_filename=file_output_tflite,
    label_filename=file_output_labels,
    export_format=export_formats,
)