# In [None]:
import cv2
import numpy as np
import RPi.GPIO as GPIO
import time
import tflite_runtime.interpreter as tflite

# -------------------------------
# GPIO Setup
# -------------------------------
SCAN_BUTTON = 17  # BCM-numbered input pin wired to the scan push-button
GPIO.setmode(GPIO.BCM)
# Internal pull-up: the pin idles HIGH; the scan logic treats LOW as "pressed"
# and HIGH as "released".
GPIO.setup(SCAN_BUTTON, GPIO.IN, pull_up_down=GPIO.PUD_UP)

# -------------------------------
# ORB + Matcher Setup
# -------------------------------
MAX_FRAME_WIDTH = 600  # frames wider than this are downscaled before matching
# Keypoint budget per frame (OpenCV's default is 500).
orb = cv2.ORB_create(nfeatures=800)
# ORB descriptors are binary, so Hamming distance is the matching norm;
# cross-checking keeps only mutually-best matches.
bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

def preprocess(frame, width=MAX_FRAME_WIDTH):
    """Downscale *frame* so it is at most *width* pixels wide.

    The aspect ratio is preserved (height is truncated to an int). Frames
    that are already narrow enough are returned unchanged, as the same object.
    """
    rows, cols = frame.shape[:2]
    if cols <= width:
        return frame
    scaled_rows = int(rows * width / cols)
    return cv2.resize(frame, (width, scaled_rows), interpolation=cv2.INTER_AREA)

def linear_blend(target, warped, mask_warped):
    """Composite ``warped`` onto ``target`` in place, feathering the overlap.

    Args:
        target: BGR uint8 canvas. Pixels whose grayscale value is 0 are
            treated as empty.
        warped: BGR uint8 image with the same shape as ``target``.
        mask_warped: 2-D array, non-zero where ``warped`` holds valid content.

    Returns:
        ``target`` (modified in place).
    """
    new_mask = mask_warped.astype(bool)
    existing = cv2.cvtColor(target, cv2.COLOR_BGR2GRAY) > 0
    overlap = new_mask & existing

    # Pixels covered only by the new image are copied verbatim. Previously
    # they were blended with the empty (black) canvas and came out darkened.
    fresh = new_mask & ~existing
    target[fresh] = warped[fresh]
    if not overlap.any():
        return target

    # Feather: each image is weighted by the distance to the edge of its own
    # valid region, so the seam fades smoothly from one image to the other.
    w_new = cv2.distanceTransform(new_mask.astype(np.uint8), cv2.DIST_L2, 5)[overlap]
    w_old = cv2.distanceTransform(existing.astype(np.uint8), cv2.DIST_L2, 5)[overlap]
    total = w_new + w_old
    alpha = np.divide(w_new, total, out=np.full_like(total, 0.5), where=total > 0)[:, None]
    blended = (warped[overlap].astype(np.float32) * alpha
               + target[overlap].astype(np.float32) * (1.0 - alpha))
    target[overlap] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    return target

def stitch_frames(frame_sequence):
    """Stitch overlapping frames into a single mosaic image.

    The first frame is pasted near the top centre of a blank canvas that is
    3x its width and 4x its height. Each later frame is matched against the
    last frame that registered successfully: ORB features, brute-force
    Hamming matching, and a RANSAC homography. Homographies are chained so
    every frame is warped into *canvas* coordinates and then blended in with
    ``linear_blend``. Frames without enough matches are skipped.

    Args:
        frame_sequence: Sequence of BGR frames; must contain at least one.

    Returns:
        The canvas cropped to the bounding box of all painted pixels.
    """
    first = preprocess(frame_sequence[0])
    H, W = first.shape[:2]
    big_canvas = np.zeros((H*4, W*3, 3), dtype=np.uint8)
    big_mask = np.zeros((H*4, W*3), dtype=np.uint8)
    canvas_size = (big_canvas.shape[1], big_canvas.shape[0])  # (width, height) for OpenCV

    # NOTE(review): the first frame sits near the top edge, so content that
    # drifts upward or off the canvas is clipped. Presumably scans move
    # downward; confirm.
    y0 = 10
    x0 = (big_canvas.shape[1] - W) // 2
    big_canvas[y0:y0+H, x0:x0+W] = first
    big_mask[y0:y0+H, x0:x0+W] = 255

    # Maps pixel coordinates of the last registered frame onto the canvas.
    # It starts as the translation that placed the first frame at (x0, y0).
    prev_to_canvas = np.array([[1.0, 0.0, x0],
                               [0.0, 1.0, y0],
                               [0.0, 0.0, 1.0]])

    prev_kp, prev_des = orb.detectAndCompute(cv2.cvtColor(first, cv2.COLOR_BGR2GRAY), None)

    for raw in frame_sequence[1:]:
        frame = preprocess(raw)
        gframe = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        kp, des = orb.detectAndCompute(gframe, None)
        if des is None or prev_des is None:
            continue
        # Query = current frame, train = previous registered frame.
        matches = bf.match(des, prev_des)
        matches = sorted(matches, key=lambda m: m.distance)[:100]

        if len(matches) < 8:
            continue

        src_pts = np.float32([kp[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([prev_kp[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        # Hmat maps current-frame pixels onto previous-frame pixels.
        Hmat, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)

        if Hmat is None:
            continue

        # Chain current -> previous -> canvas. Warping with Hmat alone placed
        # frames relative to the canvas origin instead of relative to where
        # the previous frame actually sits on the canvas.
        to_canvas = prev_to_canvas @ Hmat
        warped = cv2.warpPerspective(frame, to_canvas, canvas_size)
        warped_gray = cv2.cvtColor(warped, cv2.COLOR_BGR2GRAY)
        # Near-black pixels (gray <= 10) are treated as warp fill, not content.
        warped_mask = (warped_gray > 10).astype(np.uint8)

        big_canvas = linear_blend(big_canvas, warped, warped_mask)
        # Assign instead of add: uint8 addition wraps (255 + 255 == 254), so a
        # clip() afterwards cannot saturate and repeated overlaps could reach 0.
        big_mask[warped_mask > 0] = 255

        prev_kp, prev_des = kp, des
        prev_to_canvas = to_canvas

    ys, xs = np.where(big_mask > 0)
    if ys.size == 0:
        return big_canvas
    miny, maxy = ys.min(), ys.max()
    minx, maxx = xs.min(), xs.max()
    return big_canvas[miny:maxy+1, minx:maxx+1]

# -------------------------------
# TFLite Setup
# -------------------------------
# Labels indexed by model output position. NOTE(review): the order must match
# the label order used in training; confirm.
class_names = ["Healthy", "Anthracnose", "OtherDisease"]  # replace with your classes

interpreter = tflite.Interpreter(model_path="mango_leaf_model.tflite")
interpreter.allocate_tensors()  # must run before tensors can be set or read
input_details, output_details = interpreter.get_input_details(), interpreter.get_output_details()

def classify_leaf(img):
    """Classify a leaf image with the module-level TFLite interpreter.

    The image is resized to the model's input height/width and scaled to
    [0, 1]. For quantized models (integer input dtype with a non-zero scale),
    the scaled values are then quantized with the tensor's (scale, zero_point).
    An integer output is dequantized the same way, so ``confidence`` is on
    the model's real-valued scale for both float and quantized models.

    Args:
        img: OpenCV image. NOTE(review): OpenCV delivers BGR; if the model was
            trained on RGB images, convert with ``cv2.COLOR_BGR2RGB`` first.
            Confirm against the training pipeline.

    Returns:
        ``(class_idx, confidence)``: index of the highest output score and
        that score. It is a probability only if the model ends in softmax
        (not visible here).
    """
    inp = input_details[0]
    out = output_details[0]
    h, w = inp['shape'][1], inp['shape'][2]
    data = cv2.resize(img, (w, h)).astype(np.float32) / 255.0

    in_dtype = inp['dtype']
    if np.issubdtype(in_dtype, np.integer):
        scale, zero_point = inp['quantization']
        if scale:
            data = data / scale + zero_point
        else:
            # No quantization params: assume the model takes raw pixel values.
            data = data * 255.0
        info = np.iinfo(in_dtype)
        data = np.clip(np.rint(data), info.min, info.max)
    data = np.expand_dims(data.astype(in_dtype), axis=0)

    interpreter.set_tensor(inp['index'], data)
    interpreter.invoke()

    scores = interpreter.get_tensor(out['index'])[0]
    if np.issubdtype(out['dtype'], np.integer):
        scale, zero_point = out['quantization']
        if scale:
            scores = (scores.astype(np.float32) - zero_point) * scale
    class_idx = int(np.argmax(scores))
    confidence = float(scores[class_idx])
    return class_idx, confidence

# Optional severity (example)
def estimate_severity(img, threshold=120):
    """Rough severity estimate: percentage of bright pixels on the leaf image.

    This is a placeholder heuristic. Any pixel whose grayscale value exceeds
    ``threshold`` counts as "diseased"; colour is not examined.

    Pure-black pixels (grayscale 0) are left out of the denominator, because
    stitched images contain black filler where no frame was warped. Counting
    that filler diluted the percentage.

    Args:
        img: BGR image.
        threshold: Grayscale level above which a pixel counts as diseased.

    Returns:
        Percentage in [0, 100], or 0.0 if the image has no non-black pixels.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    diseased_pixels = np.count_nonzero(gray > threshold)
    valid_pixels = np.count_nonzero(gray > 0)
    if valid_pixels == 0:
        return 0.0
    return (diseased_pixels / valid_pixels) * 100

# -------------------------------
# Scan Function
# -------------------------------
def scan_leaf():
    """Capture frames while the scan button is held, then stitch and analyse.

    Every 5th camera frame is kept. Capture stops when the camera stops
    delivering frames or the button reads HIGH (released). With more than
    one frame, the stitched mosaic is saved to ``stitched_leaf.png``,
    classified, and given a severity estimate.
    """
    print("📸 Scanning started...")
    cap = cv2.VideoCapture(0)
    frames = []
    frame_count = 0

    # try/finally so the camera is released even if capture or GPIO raises.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % 5 == 0:
                frames.append(frame)
                print(f"Captured frame {len(frames)}")
            frame_count += 1

            if GPIO.input(SCAN_BUTTON) == GPIO.HIGH:  # button released
                break
            cv2.waitKey(1)
    finally:
        cap.release()
        cv2.destroyAllWindows()

    if len(frames) > 1:
        stitched = stitch_frames(frames)
        cv2.imwrite("stitched_leaf.png", stitched)
        print("🌿 Stitched leaf saved as stitched_leaf.png")

        # Classification
        class_idx, confidence = classify_leaf(stitched)
        print(f"Disease: {class_names[class_idx]}, Confidence: {confidence*100:.2f}%")

        # Severity
        severity = estimate_severity(stitched)
        print(f"Severity: {severity:.2f}%")
    else:
        print("⚠️ Not enough frames for stitching.")

# -------------------------------
# Main Loop
# -------------------------------
# Poll the button: a LOW reading (pressed) starts a scan. Ctrl-C exits.
try:
    while True:
        if GPIO.input(SCAN_BUTTON) == GPIO.LOW:  # button pressed
            scan_leaf()
            # Wait for release so a scan that ended early (e.g. the camera
            # returned no frames) is not re-triggered while still held.
            while GPIO.input(SCAN_BUTTON) == GPIO.LOW:
                time.sleep(0.05)
        time.sleep(0.1)

except KeyboardInterrupt:
    print("\n👋 Exiting")
finally:
    # Release the GPIO pins on every exit path, including errors from scan_leaf.
    GPIO.cleanup()
