In [10]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Constants
IMG_WIDTH = 128
IMG_HEIGHT = 128

class CropRowDetector:
    def __init__(self, model_path):
        self.model = tf.keras.models.load_model(model_path, compile=False)

    def preprocess_image(self, image):
        img = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        img = np.expand_dims(img, axis=0)  # Add batch dimension
        return img

    def split_image(self, image, size=(IMG_WIDTH, IMG_HEIGHT)):
        h, w, _ = image.shape
        sub_images = []
        for y in range(0, h, size[1]):
            for x in range(0, w, size[0]):
                sub_img = image[y:y+size[1], x:x+size[0]]
                if sub_img.shape[0] == size[1] and sub_img.shape[1] == size[0]:
                    sub_images.append(sub_img)
        return sub_images

    def reconstruct_image(self, predictions, original_shape, size=(IMG_WIDTH, IMG_HEIGHT)):
        reconstructed = np.zeros((original_shape[0], original_shape[1], 3), dtype=np.uint8)  # 3 channels for RGB
        index = 0
        
        for y in range(0, original_shape[0], size[1]):
            for x in range(0, original_shape[1], size[0]):
                if index < len(predictions):
                    mask = predictions[index].squeeze()  # Remove any singleton dimensions
                    mask = np.stack([mask] * 3, axis=-1)  # Convert (128, 128) to (128, 128, 3)
                    reconstructed[y:y+size[1], x:x+size[0]] = mask
                    index += 1
        return reconstructed

    def process_full_image(self, image_path):
        img = cv2.imread(image_path)
        if img is None:
            print(f"Error: Unable to read image at {image_path}")
            return None
        
        original_shape = img.shape
        sub_images = self.split_image(img)

        # Predict for each sub-image
        predictions = []
        for sub_img in sub_images:
            input_image = self.preprocess_image(sub_img)
            prediction = self.model.predict(input_image)[0]
            predicted_mask = (prediction > 0.5).astype(np.uint8) * 255  # Make sure it's a binary mask
            predictions.append(predicted_mask)

        # Reconstruct the predicted image
        reconstructed_image = self.reconstruct_image(predictions, original_shape)

        return img, reconstructed_image

def main(image_path, model_path):
    detector = CropRowDetector(model_path)
    original_image, reconstructed_image = detector.process_full_image(image_path)
    
    if original_image is not None and reconstructed_image is not None:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        plt.title("Original Image")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(cv2.cvtColor(reconstructed_image, cv2.COLOR_BGR2RGB))  # Show in color
        plt.title("Reconstructed Output")
        plt.axis('off')

        plt.show()
    else:
        print("Image processing failed.")

if __name__ == "__main__":
    # Specify the image path and the model path
    image_path = "C:/Users/lenovo/Documents/GitHub/Mechatronics/Paper Submission/Aerial Image2.png"
    model_path = "C:/Users/lenovo/Documents/GitHub/Mechatronics/Codes/crop_row_detection_model.h5"
    main(image_path, model_path)




ValueError: could not broadcast input array from shape (128,128,3) into shape (128,4,3)