# In [None]:


import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt
import uuid
import os


# Load the trained digit classifier once, at import time, so every request
# reuses the same model object.
# NOTE(review): the path is resolved relative to the current working directory.
model = tf.keras.models.load_model(filepath="handwritten_digits.keras")

# Scratch directory that predict_and_show writes its preview PNGs into;
# exist_ok=True makes re-running this cell a no-op.
os.makedirs(name="/tmp/predictions", exist_ok=True)


def predict_and_show(image):
    """Classify an uploaded digit image and render the pre-processed input.

    The image is converted to grayscale, resized to 28x28, inverted and scaled
    to [0, 1] before being passed to ``model``.
    NOTE(review): inverting assumes uploads are dark digits on a light
    background and that the model expects light-on-dark (MNIST-style) input
    of shape (1, 28, 28) -- confirm against the training code.

    Args:
        image: NumPy array from ``gr.Image(type="numpy")``, or ``None`` when
            Predict is clicked before an image is uploaded. 2-D grayscale,
            3-channel RGB and 4-channel RGBA arrays are accepted.

    Returns:
        A 6-tuple in the order of the Predict button's outputs: the path of
        the saved preview PNG (``None`` when there was no image), the
        prediction text, updates for the Correct / Incorrect / Submit
        buttons, and an update that hides and clears the correction textbox.
    """
    if image is None:
        # cv2.cvtColor would raise on None; report it to the user instead.
        return (
            None,
            "Please upload a digit image first.",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False, value=""),
        )

    # COLOR_RGB2GRAY only accepts 3-channel input, so dispatch on the layout.
    if image.ndim == 2:
        gray = image
    elif image.shape[-1] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # INTER_AREA averages source pixels when shrinking; the default bilinear
    # filter can skip thin pen strokes when a large photo is reduced to 28x28.
    resized = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)
    inverted = cv2.bitwise_not(resized)
    normalized = inverted.astype(np.float32) / 255.0
    input_img = np.expand_dims(normalized, axis=0)  # add batch axis -> (1, 28, 28)

    # verbose=0 silences Keras' per-call progress bar in the server log.
    prediction = model.predict(input_img, verbose=0)
    predicted_class = int(np.argmax(prediction))
    # NOTE(review): this is a probability only if the model ends in softmax.
    confidence = float(np.max(prediction))

    fig, ax = plt.subplots()
    ax.imshow(inverted, cmap=plt.cm.binary)
    ax.axis('off')
    ax.set_title(f"Prediction: {predicted_class}")
    # A unique file per request keeps concurrent users from overwriting each
    # other's previews. NOTE(review): these files are never deleted.
    img_path = f"/tmp/predictions/pred_{uuid.uuid4().hex}.png"
    try:
        fig.savefig(img_path, bbox_inches='tight')
    finally:
        plt.close(fig)  # release the figure even if saving fails

    return (
        img_path,
        f"Predicted: {predicted_class} (Confidence: {confidence * 100:.2f}%)",
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
        # A new prediction starts a fresh feedback round: hide and clear any
        # correction box left open by a previous "Incorrect" click.
        gr.update(visible=False, value=""),
    )


def feedback_correct():
    """Acknowledge a correct prediction and collapse the correction controls.

    Returns:
        A 5-tuple for the Correct button's outputs: the thank-you message,
        three "hide" updates (correction textbox, Submit button, Incorrect
        button in the Blocks wiring), then "" to clear the correction text.
    """
    message = "Thanks for the confirmation ✅"
    # Build a separate update dict per output rather than sharing one object.
    hide_updates = [gr.update(visible=False) for _ in range(3)]
    return (message, *hide_updates, "")


def feedback_incorrect():
    """Open the correction controls after the user rejects a prediction.

    Returns:
        A 5-tuple for the Incorrect button's outputs: the prompt text,
        three "show" updates (correction textbox, Submit button, Correct
        button in the Blocks wiring), then "" to clear the correction text.
    """
    prompt = "Please enter the correct digit:"
    # One fresh update dict per output component.
    show_updates = [gr.update(visible=True) for _ in range(3)]
    return (prompt, *show_updates, "")


def submit_correction(correct_digit):
    """Accept the user's corrected label for the last prediction.

    The correction is currently only printed to stdout; nothing is persisted.

    Args:
        correct_digit: Raw text from the "Enter Correct Digit" textbox
            (may be ``None`` or carry surrounding whitespace).

    Returns:
        A 5-tuple in the order of the Submit button's outputs: feedback
        message, correction-textbox update, Submit-button update,
        Correct-button update, and the new correction-textbox value.
        If the text is not a single digit 0-9, the correction controls stay
        visible and the typed text is kept so the user can fix it.
    """
    digit = (correct_digit or "").strip()
    # Only the ten MNIST classes are meaningful labels.
    if len(digit) != 1 or digit not in "0123456789":
        return (
            "Please enter a single digit (0-9).",
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(),  # leave the Correct button as it is
            gr.update(),  # keep the typed text for editing
        )

    # TODO: corrections are only printed; persist them if they should be kept.
    print(f"User entered correction: {digit}")
    return (
        f"Correction received: {digit}",
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        "",
    )


def _hidden():
    """Return a fresh Gradio update that hides a component."""
    return gr.update(visible=False)


with gr.Blocks() as demo:
    gr.Markdown("## ✏️ Handwritten Digit Recognition App")
    gr.Markdown("Upload a digit image. The model predicts and visualizes the digit. You can correct it if wrong.")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="numpy", label="Upload Digit Image")
            predict_btn = gr.Button("Predict")
        with gr.Column():
            output_image = gr.Image(label="Matplotlib Prediction Image")
            output_text = gr.Textbox(label="Prediction", interactive=False)

    with gr.Row():
        correct_btn = gr.Button("Correct", visible=False)
        incorrect_btn = gr.Button("Incorrect", visible=False)

    feedback_output = gr.Textbox(label="Feedback", interactive=False)
    correct_digit_input = gr.Textbox(label="Enter Correct Digit", visible=False)
    submit_btn = gr.Button("Submit Correction", visible=False)

    # Each handler returns a tuple whose positions match its ``outputs`` list.
    # ``correct_digit_input`` is listed twice in the feedback handlers: the
    # first slot receives a visibility update, the last one the text value.
    predict_btn.click(
        fn=predict_and_show,
        inputs=[img_input],
        outputs=[
            output_image,
            output_text,
            correct_btn,
            incorrect_btn,
            submit_btn,
            correct_digit_input
        ]
    ).then(
        # A new prediction starts a fresh feedback round: clear the previous
        # round's feedback and hide any correction box left open.
        fn=lambda: ("", _hidden()),
        inputs=None,
        outputs=[feedback_output, correct_digit_input],
    )

    correct_btn.click(
        fn=feedback_correct,
        inputs=[],
        outputs=[
            feedback_output,
            correct_digit_input,
            submit_btn,
            incorrect_btn,
            correct_digit_input
        ]
    ).then(
        # feedback_correct only hides Incorrect; the round is over, so hide
        # Correct too instead of leaving it clickable.
        fn=_hidden,
        inputs=None,
        outputs=[correct_btn],
    )

    incorrect_btn.click(
        fn=feedback_incorrect,
        inputs=[],
        outputs=[
            feedback_output,
            correct_digit_input,
            submit_btn,
            correct_btn,
            correct_digit_input
        ]
    )

    submit_btn.click(
        fn=submit_correction,
        inputs=[correct_digit_input],
        outputs=[
            feedback_output,
            correct_digit_input,
            submit_btn,
            correct_btn,
            correct_digit_input
        ]
    ).then(
        # submit_correction hides Correct; hide Incorrect as well so both
        # feedback buttons are consistent.
        fn=_hidden,
        inputs=None,
        outputs=[incorrect_btn],
    )

# Launch only when run as a script or notebook cell (both have
# __name__ == "__main__"), not when this module is imported.
# share=True requests a temporary public URL for the app.
if __name__ == "__main__":
    demo.launch(share=True)
