In [3]:
import os
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Parameters from training
# Tile size in pixels: images are resized to (img_height, img_width) when loaded,
# and the interpreter input is resized to [batch, img_height, img_width, 3].
img_width, img_height = 15, 15
# Label names; the argmax position of a prediction row indexes into this list.
class_names = ['Class-0-water', 'Class-1', 'Class-2', 'Class-3', 'Class-4', 'Class-5']
# Path of the TFLite model file that load_interpreter() opens.
TF_MODEL_FILE_PATH = 'initial-model-hi-res-15px-batched-turbo-4o.tflite'
# Number of image paths loaded and classified per interpreter invocation.
BATCH_SIZE = 12500  # Adjust the batch size as needed

# Load the interpreter
def load_interpreter():
    """Create the TFLite interpreter for TF_MODEL_FILE_PATH and allocate its tensors.

    Returns:
        A 3-tuple ``(interpreter, input_details, output_details)`` where the
        last two are whatever the interpreter reports from
        ``get_input_details()`` and ``get_output_details()``.
    """
    tflite = tf.lite.Interpreter(model_path=TF_MODEL_FILE_PATH)
    # Tensors must be allocated before the details are queried or data is set.
    tflite.allocate_tensors()
    return tflite, tflite.get_input_details(), tflite.get_output_details()

# Module-level interpreter state, created once; predict_and_classify_images
# reads these three globals on every call.
interpreter, input_details, output_details = load_interpreter()

# Function to open and preprocess images
def open_imgs(img_paths):
    """Load every image in ``img_paths`` and return them as one NumPy array.

    Each file is opened with ``tf.keras.utils.load_img`` using a target size
    of ``(img_height, img_width)`` and converted with
    ``tf.keras.utils.img_to_array``.

    Args:
        img_paths: Iterable of image file paths.

    Returns:
        ``np.array`` built from the per-image arrays, in input order. An
        empty ``img_paths`` gives an empty array of shape ``(0,)``.
    """
    def _load_one(path):
        # Resize on load so every tile has the same spatial dimensions.
        picture = tf.keras.utils.load_img(path, target_size=(img_height, img_width))
        return tf.keras.utils.img_to_array(picture)

    return np.array([_load_one(path) for path in img_paths])

# Function to predict and classify a batch of images
def predict_and_classify_images(batch_imgs):
    """Classify a batch of images with the module-level TFLite interpreter.

    Args:
        batch_imgs: Array-like of images. It is converted to float32 and fed
            to the interpreter after the input tensor is resized to
            ``[len(batch_imgs), img_height, img_width, 3]``.

    Returns:
        A list with one ``(class_name, confidence_percent)`` tuple per image,
        in input order. ``class_name`` comes from ``class_names`` at the
        argmax position and ``confidence_percent`` is 100 times the largest
        softmax value of that image's output row. An empty batch yields an
        empty list.
    """
    # Nothing to classify: skip the interpreter entirely rather than asking it
    # to handle a zero-length (or wrongly shaped) input tensor.
    if len(batch_imgs) == 0:
        return []

    batch_imgs = np.asarray(batch_imgs, dtype=np.float32)  # interpreter input is fed as float32
    input_index = input_details[0]['index']

    # The final batch is usually smaller than the others, so the input tensor
    # is resized (and tensors reallocated) to the actual batch length.
    interpreter.resize_tensor_input(input_index, [len(batch_imgs), img_height, img_width, 3])
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_index, batch_imgs)
    interpreter.invoke()
    predictions = interpreter.get_tensor(output_details[0]['index'])

    # One softmax over the whole batch instead of one TF op per row.
    # NOTE(review): assumes the output is 2-D (batch, n_classes) -- confirm
    # against the model; the last axis is treated as the class axis.
    probabilities = np.asarray(tf.nn.softmax(predictions, axis=-1))
    best_indices = np.argmax(probabilities, axis=-1)
    best_scores = np.max(probabilities, axis=-1)

    return [
        (class_names[class_idx], 100 * score)
        for class_idx, score in zip(best_indices, best_scores)
    ]

# Input tile directory and output log location
input_directory = './data/02-map-chopped-tiles-uncategorised-in-15px'
log_file_path = './classification_log-hi-res-Large-Tiles-batched-new-train-4o.txt'

# Walk the input tree and keep every entry that is a regular file
all_files = []
for root, _, files in os.walk(input_directory):
    for filename in files:
        candidate = os.path.join(root, filename)
        if os.path.isfile(candidate):
            all_files.append(candidate)
total_files = len(all_files)
print(f"Collected all file paths, total: {total_files}")

# Classify batch by batch, logging one "path,class,confidence" line per file
with open(log_file_path, 'w') as log_file, tqdm(total=total_files, desc="Processing images") as pbar:
    for start in range(0, total_files, BATCH_SIZE):
        batch_paths = all_files[start:start + BATCH_SIZE]
        predictions = predict_and_classify_images(open_imgs(batch_paths))
        log_file.writelines(
            f"{path},{pred_class},{confidence:.2f}\n"
            for path, (pred_class, confidence) in zip(batch_paths, predictions)
        )
        # Advance the bar by the number of files actually handled in this batch.
        pbar.update(len(batch_paths))

print("Processing and logging completed.")


Collected all file paths, total: 255640


Processing images: 100%|██████████| 255640/255640 [03:25<00:00, 1241.14it/s]

Processing and logging completed.



