In [11]:
import numpy as np
from PIL import Image
import tensorflow as tf

# Build the TFLite interpreter for the distilled MobileNet and reserve
# memory for all of its tensors (required before any set/get_tensor call).
MODEL_PATH = "mobilenet_distilled.tflite"
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()

# Tensor metadata; only the first input and first output are used below.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

input_shape = input_details[0]["shape"]
print("Expected input shape:", input_shape)

# --- Turn the image into a tensor with the layout/size/dtype the model wants ---

# Only 4-D image inputs are supported. A channel dimension of 3 at axis 1 is
# treated as NCHW (as in the printed shape [1, 3, 256, 256]); at the last axis
# as NHWC. NCHW is checked first, so a shape where both axes equal 3 is
# resolved as NCHW.
if len(input_shape) == 4 and input_shape[1] == 3:
    channels_first = True
    target_h, target_w = int(input_shape[2]), int(input_shape[3])
elif len(input_shape) == 4 and input_shape[-1] == 3:
    channels_first = False
    target_h, target_w = int(input_shape[1]), int(input_shape[2])
else:
    raise ValueError(f"Unexpected input shape format: {input_shape}")

# Load as 3-channel RGB; the context manager closes the file handle
# deterministically instead of relying on garbage collection.
with Image.open("dfsjh.jpeg") as src:
    image = src.convert("RGB")

# Resize to the model's own spatial size rather than a hard-coded 256x256.
# Note that PIL's resize takes (width, height).
image = image.resize((target_w, target_h))

# Scale to [0, 1], then apply per-channel mean/std normalization. Using
# float32 constants keeps the arithmetic in float32 (plain Python lists would
# upcast the whole array to float64).
# NOTE(review): these are the standard ImageNet statistics - confirm they match
# the preprocessing the model was trained with.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
image_np = (np.asarray(image, dtype=np.float32) / 255.0 - mean) / std

image_np = np.expand_dims(image_np, axis=0)  # add batch dim -> (1, H, W, 3)
if channels_first:
    image_np = np.transpose(image_np, (0, 3, 1, 2))  # -> (1, 3, H, W)

# Match the model's input dtype: set_tensor rejects a dtype mismatch.
# Integer-quantized inputs use q = real / scale + zero_point (clipped to the
# dtype's range); float inputs get a plain cast.
input_dtype = input_details[0]["dtype"]
if np.issubdtype(input_dtype, np.integer):
    scale, zero_point = input_details[0]["quantization"]
    if scale == 0:
        raise ValueError(
            f"Integer input tensor has no quantization parameters: {input_details[0]}"
        )
    limits = np.iinfo(input_dtype)
    image_np = np.clip(
        np.round(image_np / scale + zero_point), limits.min, limits.max
    ).astype(input_dtype)
else:
    image_np = image_np.astype(input_dtype)

print("Input tensor shape:", image_np.shape)
print("Input tensor dtype:", image_np.dtype)

# Feed the prepared tensor into the model's first input and run one
# forward pass.
input_index = input_details[0]["index"]
interpreter.set_tensor(input_index, image_np)
interpreter.invoke()

# Fetch the first output tensor. argmax without an axis works on the
# flattened array, so with a batch of one this is the top-scoring index.
output = interpreter.get_tensor(output_details[0]["index"])
pred = output.argmax()
print("Predicted class index:", pred)


Expected input shape: [  1   3 256 256]
Input tensor shape: (1, 3, 256, 256)
Input tensor dtype: float32
Predicted class index: 0
