In [2]:
import serial
import numpy as np
from PIL import Image
import os
import time
from pynput import keyboard

# Serial link to the Pico; the baud rate must match the Pico firmware's setting.
# Every read on this port blocks for at most 1 second (timeout=1).
ser = serial.Serial('COM5', 115200, timeout=1)

# Create the raw_images folder if it doesn't exist (no-op when it already does).
os.makedirs('raw_images', exist_ok=True)


def _next_image_index(folder):
    """Return one past the largest N among files named image_N.png in *folder*.

    Returns 0 when the folder holds no such files. Other files are ignored.
    """
    prefix = 'image_'
    indices = []
    for name in os.listdir(folder):
        stem, ext = os.path.splitext(name)
        number = stem[len(prefix):]
        # isdecimal (not isdigit) guarantees int() below cannot fail.
        if ext == '.png' and stem.startswith(prefix) and number.isdecimal():
            indices.append(int(number))
    return max(indices, default=-1) + 1


# Continue numbering after images saved by earlier runs so they are not
# overwritten by a fresh session.
image_counter = _next_image_index('raw_images')

def _bits_to_image(image_data, width=96, height=96):
    """Convert a packed 1-bit-per-pixel buffer into a black/white PIL image.

    Bits are unpacked most-significant-bit first (the numpy default) and laid
    out row-major as a (height, width) grid. Returns None, after printing a
    diagnostic, when the buffer does not contain exactly width * height bits.
    """
    bits = np.unpackbits(np.frombuffer(image_data, dtype=np.uint8))
    if bits.size != width * height:
        print(f"Unexpected number of bits: {bits.size}")
        return None
    # Scale bit values 0/1 to grayscale 0/255 so the saved image is viewable.
    pixels = (bits.reshape((height, width)) * 255).astype(np.uint8)
    # A 2-D uint8 array is mapped to mode 'L' (grayscale) automatically;
    # passing the mode explicitly is deprecated in recent Pillow releases.
    return Image.fromarray(pixels)


def on_press(key):
    """Handle a key press from the pynput keyboard listener.

    Space: clear the serial buffers, send the capture command b'C' to the
    Pico, wait for the frame header, read the packed image data and save it
    as raw_images/image_<image_counter>.png, then increment image_counter.
    Every failure is reported with a print and leaves the counter unchanged.

    Enter: return False, which tells pynput to stop the listener.

    Any other key is ignored.
    """
    global image_counter
    if key == keyboard.Key.enter:
        print("Enter key pressed. Exiting program.")
        return False  # Returning False stops the pynput listener.
    if key != keyboard.Key.space:
        return None

    print("Space bar pressed. Capturing image...")
    # Drop stale bytes in both directions so the header search only sees
    # the response to this capture command.
    ser.reset_input_buffer()
    ser.reset_output_buffer()
    ser.write(b'C')  # Send capture command to Pico

    # If the firmware emits debug lines terminated by 'END' before the frame,
    # call read_arduino_debug_messages() here.

    if not find_header():
        print("Failed to find image header.")
        return None
    print("Header found. Reading image data...")

    image_data = read_image_data()
    if not image_data:
        print("Failed to read image data.")
        return None

    img = _bits_to_image(image_data)
    if img is None:
        return None

    image_filename = f'raw_images/image_{image_counter}.png'
    img.save(image_filename)
    print(f"Image saved as {image_filename}")
    image_counter += 1
    return None

def read_arduino_debug_messages(timeout=5):
    """Print text lines from the Pico until a line equal to 'END' arrives.

    Each non-empty line is decoded as UTF-8 (undecodable bytes replaced) and
    printed as "Pico says: <line>".

    Args:
        timeout: Overall time budget in seconds. It is enforced on every
            iteration, so a Pico that keeps talking without ever sending
            'END' can no longer keep this loop running forever. A single
            readline() may still block for the port's own read timeout.
    """
    deadline = time.monotonic()+ timeout  # monotonic: immune to clock changes
    while time.monotonic() < deadline:
        if not ser.in_waiting:
            time.sleep(0.01)  # Avoid busy-waiting while the port is idle.
            continue
        line = ser.readline().decode('utf-8', errors='replace').strip()
        if line:
            print(f"Pico says: {line}")
            if line == 'END':
                return
    print("Timeout while waiting for Pico debug messages.")

def find_header(max_attempts=5000, timeout=5):
    """Consume serial bytes until the two-byte frame header 0x55 0xAA is seen.

    Bytes preceding the header are discarded. On success the header itself
    has been consumed, so the next byte read is the first byte of the frame.

    Args:
        max_attempts: Maximum number of single-byte reads, counting reads
            that time out, before giving up.
        timeout: Overall wall-clock budget in seconds. It is needed because
            each ser.read(1) can block for the port's read timeout (1 s as
            the port is opened in this file). Counting attempts alone could
            therefore block for thousands of seconds when the Pico is silent.

    Returns:
        True once the header has been read, False on timeout or when
        max_attempts is exhausted.
    """
    deadline = time.monotonic() + timeout
    previous = None
    for _ in range(max_attempts):
        if time.monotonic() >= deadline:
            break
        byte = ser.read(1)
        if not byte:
            continue
        current = byte[0]
        # Compare against the previous byte instead of reading ahead, so a
        # run such as 0x55 0x55 0xAA is still recognised as a header.
        if previous == 0x55 and current == 0xAA:
            return True
        previous = current
    return False

def read_image_data(image_size=(96 * 96) // 8, timeout=5):
    """Read exactly *image_size* bytes of packed image data from the Pico.

    The default size, 1152 bytes, is a 96x96 image at 1 bit per pixel.

    Args:
        image_size: Number of bytes making up one frame.
        timeout: Overall time budget in seconds.

    Returns:
        A bytearray of exactly image_size bytes, or None (after printing a
        diagnostic) if the full frame did not arrive before the timeout.
    """
    image_data = bytearray()
    deadline = time.monotonic() + timeout  # monotonic: immune to clock changes
    while len(image_data) < image_size and time.monotonic() < deadline:
        waiting = ser.in_waiting
        if not waiting:
            time.sleep(0.01)  # Small delay to prevent CPU overuse
            continue
        # Never read past the end of the frame. Any bytes that arrive after
        # it would otherwise push the length over image_size and make the
        # final check reject a complete image.
        data = ser.read(min(waiting, image_size - len(image_data)))
        image_data.extend(data)
        print(f"Received {len(data)} bytes, total {len(image_data)}/{image_size}")
    if len(image_data) == image_size:
        return image_data
    print(f"Incomplete image data received. Received {len(image_data)} bytes.")
    return None

print("Press space bar to capture an image. Press enter to exit.")

# Block until the listener stops. on_press returns False when Enter is
# pressed, which ends the listener and lets join() return.
try:
    listener = keyboard.Listener(on_press=on_press)
    with listener:
        listener.join()
except KeyboardInterrupt:
    print("Keyboard interrupt received. Exiting program.")
finally:
    # Release the COM port however the listener ended.
    ser.close()


Press space bar to capture an image. Press enter to exit.
Space bar pressed. Capturing image...
