In [9]:
import csv
import importlib
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

import config


In [12]:

# Re-import config so edits to config.py take effect without restarting the kernel
# (Python caches imported modules, so the top-level `import config` keeps old values).
importlib.reload(config)

# Inference parameters, shared with the training run via config.py
img_width = config.img_width
img_height = config.img_height
class_names = config.class_names
TF_MODEL_FILE_PATH = config.TF_MODEL_FILE_PATH

# Number of tiles fed to the interpreter per invoke(). The whole batch is held in
# memory at once (roughly BATCH_SIZE * img_height * img_width * 3 float32 values),
# so lower this if the kernel runs out of memory.
BATCH_SIZE = 12500

I initially tried to run this with multiple threads, until I discovered that TF Lite is already multithreaded internally. Instead, just tune the `BATCH_SIZE` parameter and make sure the model architecture supports variable batch sizes. I found it's best to build the TF Lite model to serve batch processing from the start.

In [13]:


# load the interpreter
def load_interpreter():
    """Create a TF Lite interpreter for the model stored at TF_MODEL_FILE_PATH.

    Returns:
        Tuple ``(interpreter, input_details, output_details)``. The interpreter
        already has its tensors allocated; the two details lists come from
        ``get_input_details()`` and ``get_output_details()``.
    """
    tflite_interpreter = tf.lite.Interpreter(model_path=TF_MODEL_FILE_PATH)
    # Tensors must be allocated before input/output tensors can be used.
    tflite_interpreter.allocate_tensors()
    return (
        tflite_interpreter,
        tflite_interpreter.get_input_details(),
        tflite_interpreter.get_output_details(),
    )

# Module-level interpreter used by predict_and_classify_images.
interpreter, input_details, output_details = load_interpreter()


def open_imgs(img_paths):
    """Load images from disk and stack them into a single batch array.

    Each image is resized to ``(img_height, img_width)`` while loading.

    Args:
        img_paths: Iterable of image file paths.

    Returns:
        numpy array with one entry per path, in the same order as ``img_paths``.
    """
    loaded = [
        tf.keras.utils.img_to_array(
            tf.keras.utils.load_img(path, target_size=(img_height, img_width))
        )
        for path in img_paths
    ]
    return np.array(loaded)


def predict_and_classify_images(batch_imgs):
    """Run the TF Lite model on a batch of images and classify each one.

    Args:
        batch_imgs: Array of images shaped (batch, img_height, img_width, 3),
            e.g. as returned by ``open_imgs``. Converted to float32 for inference.

    Returns:
        List of ``(class_name, confidence)`` tuples, one per image and in input
        order. ``confidence`` is the softmax probability of the predicted class
        expressed as a percentage (0-100).
    """
    batch_imgs = np.asarray(batch_imgs, dtype=np.float32)
    input_index = input_details[0]['index']
    target_shape = [len(batch_imgs), img_height, img_width, 3]

    # Resizing + reallocating is only needed when the batch size changes (in
    # practice just the final, shorter batch). Query the interpreter for its
    # current shape: the module-level input_details was captured at load time
    # and does not reflect later resizes.
    current_shape = [int(dim) for dim in interpreter.get_input_details()[0]['shape']]
    if current_shape != target_shape:
        interpreter.resize_tensor_input(input_index, target_shape)
        interpreter.allocate_tensors()

    interpreter.set_tensor(input_index, batch_imgs)
    interpreter.invoke()
    logits = interpreter.get_tensor(output_details[0]['index']).astype(np.float32, copy=False)

    # Row-wise softmax in numpy: one vectorised op for the whole batch instead of
    # one TF eager call per image. Subtracting the row max keeps exp() from
    # overflowing without changing the result.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp_scores = np.exp(shifted)
    probs = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    best = probs.argmax(axis=-1)
    confidences = 100 * probs[np.arange(len(best)), best]
    return [(class_names[idx], conf) for idx, conf in zip(best, confidences)]

# Input tiles are read from input_directory; results go to the log file set in config.
input_directory = './data/02-source-map-tiles'
log_file_path = config.log_file_path

# Gather every regular file under input_directory. Sorted so the processing (and
# log) order is reproducible, since os.walk order depends on the filesystem.
all_files = sorted(
    full_path
    for root, _, files in os.walk(input_directory)
    for full_path in (os.path.join(root, name) for name in files)
    if os.path.isfile(full_path)
)
total_files = len(all_files)
print(f"Collected all file paths, total tiles: {total_files}")

# Classify in batches of BATCH_SIZE and log one "path,class,confidence" row per
# image. csv.writer quotes any field containing a comma, so such paths or class
# names cannot break the row structure.
with open(log_file_path, 'w', newline='') as log_file:
    log_writer = csv.writer(log_file, lineterminator='\n')
    with tqdm(total=total_files, desc="Processing images") as pbar:
        for i in range(0, total_files, BATCH_SIZE):
            batch_paths = all_files[i:i + BATCH_SIZE]
            batch_imgs = open_imgs(batch_paths)
            predictions = predict_and_classify_images(batch_imgs)
            # zip() below would silently drop rows on a length mismatch.
            if len(predictions) != len(batch_paths):
                raise RuntimeError(
                    f"Got {len(predictions)} predictions for {len(batch_paths)} images"
                )
            for path, (pred_class, confidence) in zip(batch_paths, predictions):
                log_writer.writerow([path, pred_class, f"{confidence:.2f}"])
            pbar.update(len(batch_paths))

print("Processing and logging completed.")


Collected all file paths, total tiles: 1593368


Processing images: 100%|██████████| 1593368/1593368 [2:34:34<00:00, 171.80it/s] 

Processing and logging completed.



