# In [None]:
import os
import tensorflow as tf

def predict_logits(
    model_path,
    test_dir,
    *,
    image_size=(224, 224),
    batch_size=32,
    rescale=True,
    extensions=(".jpg",),
):
    """
    Run a saved Keras model over every matching image under a directory.

    Takes a model path and test directory, returns dictionary of
    filename:logit pairs.

    Args:
        model_path (str): Path to the .h5 model file.
        test_dir (str): Directory searched recursively for test images.
        image_size (tuple[int, int]): (height, width) every image is resized
            to before inference; should match the model's input size.
        batch_size (int): Number of images per inference batch.
        rescale (bool): If True, divide pixel values by 255.0 (the original
            behavior). NOTE(review): the original author was unsure whether
            the model expects this -- confirm against the training pipeline.
        extensions (str | tuple[str, ...]): File extensions to include,
            matched case-insensitively. Files are decoded with
            ``tf.image.decode_jpeg``, so only JPEG extensions belong here
            (e.g. ``(".jpg", ".jpeg")``).

    Returns:
        dict: Maps each image's basename to the model's single output value
        for it, as a float. Whether that value is a raw logit or a
        probability depends on the model's final activation, which is not
        inspected here. If no matching files are found, returns an empty
        dict without loading the model.

    Raises:
        FileNotFoundError: If ``test_dir`` is not a directory.
        ValueError: If the model produces more than one value per image,
            which cannot be represented as one float per file.
    """
    import warnings

    if not os.path.isdir(test_dir):
        raise FileNotFoundError(f"Test directory not found: {test_dir}")

    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # Sort so the processing order (and which duplicate basename wins)
    # does not depend on the filesystem's directory listing order.
    image_files = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(test_dir)
        for name in names
        if name.lower().endswith(suffixes)
    )

    # An empty Python list becomes a float32 tensor in from_tensor_slices,
    # which read_file rejects; also skip the (possibly slow) model load.
    if not image_files:
        return {}

    model = tf.keras.models.load_model(model_path)

    height, width = image_size

    def process_image(path):
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, (height, width))
        img = tf.cast(img, tf.float32)
        if rescale:  # Python bool, resolved once when the map fn is traced.
            img = img / 255.0
        return img

    # map() keeps element order by default, so predictions stay aligned
    # with image_files even with parallel decoding.
    dataset = (
        tf.data.Dataset.from_tensor_slices(image_files)
        .map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    predictions = model.predict(dataset, verbose=0)

    # Normalize (N,) or (N, 1) outputs to (N, 1). Reject wider outputs
    # instead of letting zip() silently misalign files and values.
    predictions = predictions.reshape(len(image_files), -1)
    if predictions.shape[1] != 1:
        raise ValueError(
            f"Expected one output value per image, got {predictions.shape[1]}"
        )

    results = {}
    for path, value in zip(image_files, predictions[:, 0]):
        key = os.path.basename(path)
        if key in results:
            # Recursive search can find the same basename in several folders.
            warnings.warn(
                f"Duplicate filename {key!r}; keeping prediction for {path}",
                stacklevel=2,
            )
        results[key] = float(value)

    return results

# Example usage:
# model_path = "path/to/your/model.h5"
# test_dir = "path/to/test/data"
# results = predict_logits(model_path, test_dir)