In [None]:
# Pin an older gradio release first, for compatibility with the other modules
!pip install gradio==5.35.0

In [None]:
import gradio as gr
import torch
import timm
import cv2
import numpy as np
import pandas as pd
from torchvision import transforms
import os
import shutil

# Configuration
# Query CUDA once and reuse the answer for the device choice and the log lines.
_cuda_ok = torch.cuda.is_available()

CONFIG = dict(
    img_size=224,
    num_classes=3,
    model_name="swin_tiny_patch4_window7_224",
    device=torch.device("cuda") if _cuda_ok else torch.device("cpu"),
)

print("Using device:", CONFIG["device"])
print("GPU available:", _cuda_ok)
if _cuda_ok:
    print("GPU name:", torch.cuda.get_device_name(0))

# Maps the model's output index to the human-readable severity label.
IDX2LABEL = dict(enumerate(("Healthy", "Moderate", "Severe")))

# Inference-time preprocessing: array -> PIL image -> tensor, then per-channel
# normalisation with mean 0.5 and std 0.5 on each of the three channels.
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]

_to_pil = transforms.ToPILImage()
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(_NORM_MEAN, _NORM_STD)

test_tf = transforms.Compose([_to_pil, _to_tensor, _normalize])

# Load the trained model
# The architecture is created without pretrained weights; the fine-tuned
# weights are taken from the checkpoint file below.
model = timm.create_model(
    CONFIG['model_name'],
    pretrained=False,
    num_classes=CONFIG['num_classes'],
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded. A
# state_dict contains nothing else, so this does not change what gets loaded.
_state_dict = torch.load(
    "/kaggle/input/test-best-model/best_sick_f1_model(1).pth",
    map_location=CONFIG['device'],
    weights_only=True,
)
model.load_state_dict(_state_dict)
model.to(CONFIG['device'])
# Switch to evaluation mode for inference.
model.eval()

# Single-image prediction
def predict_fn(img):
    """Classify one retinal image and return its label and class probabilities.

    Args:
        img: The image to classify. The UI passes a PIL image; any array-like
            accepted by ``np.array`` with three colour channels also works.

    Returns:
        A tuple ``(pred_label, prob_data, prob_data)`` where ``pred_label`` is
        the predicted class name and ``prob_data`` is a DataFrame with the
        columns ``Class`` and ``Probability``. The DataFrame is returned twice
        because it feeds both the bar plot and the table output.

    Raises:
        gr.Error: If no image was supplied (e.g. the button was pressed before
            uploading anything).
    """
    if img is None:
        # Surface a readable message in the UI instead of a cv2 stack trace.
        raise gr.Error("Please upload an image before running the prediction.")

    # PNG uploads may be RGBA, palette or grayscale; the normalisation in
    # test_tf is defined for exactly three channels, so force RGB first.
    if hasattr(img, "convert"):
        img = img.convert("RGB")

    img = np.array(img)
    img = cv2.resize(img, (CONFIG['img_size'], CONFIG['img_size']))
    tensor = test_tf(img).unsqueeze(0).to(CONFIG['device'])
    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.softmax(outputs, dim=1)[0].cpu().numpy()
        _, pred = torch.max(outputs, 1)

    pred_label = IDX2LABEL[pred.item()]
    class_ids = range(CONFIG['num_classes'])
    prob_data = pd.DataFrame({
        "Class": [IDX2LABEL[i] for i in class_ids],
        # Plain Python floats keep the table/plot payload JSON-friendly.
        "Probability": [float(probs[i]) for i in class_ids],
    })

    return pred_label, prob_data, prob_data

# Batch prediction
def batch_predict_fn(file_obj):
    """Classify a single image file or every image inside a zip archive.

    Args:
        file_obj: The uploaded file, given either as a path string or as an
            object exposing the path through ``.name``. ``None`` means nothing
            was uploaded.

    Returns:
        A DataFrame with the columns ``Image``, ``Prediction``, ``Healthy``,
        ``Moderate`` and ``Severe`` (one row per image). Images that cannot be
        decoded get the prediction ``"Unreadable"`` and empty probabilities.
        For archives, ``Image`` is the path relative to the archive root and
        rows are sorted by that path; non-image and hidden entries are ignored.
    """
    import tempfile
    import zipfile

    columns = ["Image", "Prediction", "Healthy", "Moderate", "Severe"]
    if file_obj is None:
        return pd.DataFrame(columns=columns)

    # Accept both a plain path and an object carrying the path in `.name`.
    upload_path = str(getattr(file_obj, "name", file_obj))
    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def classify(path, display_name):
        """Return one result row for the image stored at *path*."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals an undecodable file with None, not an error.
            return [display_name, "Unreadable", None, None, None]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img, (CONFIG['img_size'], CONFIG['img_size']))
        tensor = test_tf(img_resized).unsqueeze(0).to(CONFIG['device'])
        with torch.no_grad():
            outputs = model(tensor)
            probs = torch.softmax(outputs, dim=1)[0].cpu().numpy()
            _, pred = torch.max(outputs, 1)
        return [
            display_name,
            IDX2LABEL[pred.item()],
            float(probs[0]),
            float(probs[1]),
            float(probs[2]),
        ]

    if not upload_path.lower().endswith(".zip"):
        row = classify(upload_path, os.path.basename(upload_path))
        return pd.DataFrame([row], columns=columns)

    results = []
    # TemporaryDirectory removes the extracted files once we are done, so
    # repeated batch runs do not pile up data on disk.
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(upload_path, "r") as zip_ref:
            zip_ref.extractall(tmpdir)

        # Walk recursively: archives often wrap their images in a folder and
        # carry extra entries (docs, OS metadata) that are not images.
        img_paths = []
        for root, _dirs, fnames in os.walk(tmpdir):
            for fname in fnames:
                if fname.startswith("."):
                    continue
                if fname.lower().endswith(image_exts):
                    img_paths.append(os.path.join(root, fname))

        for path in sorted(img_paths):
            results.append(classify(path, os.path.relpath(path, tmpdir)))

    return pd.DataFrame(results, columns=columns)

# Build the Gradio UI: one tab for single-image prediction, one for batch
# prediction, and one with static model information.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("""
    # 🩺 Diabetic Retinopathy Severity Classification
    ---
    This demo uses a **Swin Transformer** model to classify retinal fundus images into three severity levels of diabetic retinopathy.  
    - **Classes**: Healthy / Moderate / Severe  
    - **Dataset**: APTOS 2019 (Kaggle)  
    - **Model**: Swin-Tiny Patch4 Window7 224  

    🌟 *Early detection of diabetic retinopathy is crucial for preventing vision loss.*  
    """)

    # Single-image prediction tab
    with gr.Tab("🔍 Single Prediction"):
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(type="pil", label="📤 Upload Retinal Image")
                submit_btn = gr.Button("Run Prediction")
            with gr.Column(scale=1):
                prediction = gr.Label(label="Final Prediction")
                prob_plot = gr.BarPlot(x="Class", y="Probability", label="Prediction Confidence")
                prob_table = gr.Dataframe(headers=["Class", "Probability"], label="Numerical Probabilities")

        submit_btn.click(fn=predict_fn, inputs=img_input, outputs=[prediction, prob_plot, prob_table])

        # Example images for the single-prediction tab
        gr.Markdown("### 🎯 Try with Example Images")
        gr.Examples(
            examples=[
                "/kaggle/input/demonstration/class 2.png",
                "/kaggle/input/demonstration/class 4.png",
            ],
            inputs=img_input,
            outputs=[prediction, prob_plot, prob_table],
            fn=predict_fn,
            # Without this, clicking an example only fills the input and the
            # prediction promised by the label never runs.
            run_on_click=True,
            label="Click an example image to run prediction",
        )

    # Batch prediction tab
    with gr.Tab("📂 Batch Prediction"):
        # gr.File accepts one file here, so the label asks for a single image
        # or an archive rather than "multiple images".
        batch_input = gr.File(
            file_types=[".zip", ".jpg", ".jpeg", ".png"],
            label="Upload a single image or a zip of images",
        )
        batch_btn = gr.Button("Run Batch Prediction")
        batch_output = gr.Dataframe(headers=["Image", "Prediction", "Healthy", "Moderate", "Severe"])
        batch_btn.click(fn=batch_predict_fn, inputs=batch_input, outputs=batch_output)

        # Example archive for the batch tab: zip the demonstration folder into
        # the working directory so it can be offered as a clickable example.
        gr.Markdown("### 📦 Try with Example Batch")
        shutil.make_archive("/kaggle/working/batch_test", 'zip', "/kaggle/input/demonstration/test_batch")
        gr.Examples(
            examples=[
                "/kaggle/working/batch_test.zip",
            ],
            inputs=batch_input,
            outputs=batch_output,
            fn=batch_predict_fn,
            run_on_click=True,
            label="Click to run batch example",
        )

    # Static model information tab
    with gr.Tab("ℹ️ Model Info"):
        gr.Markdown("""
        ## 📖 Model Information
        - **Architecture**: Swin Transformer Tiny  
        - **Training Data**: APTOS 2019 (Kaggle)  
        - **Performance**: Average Sick F1-score = 0.8277  
        - **Metrics**: Precision, Recall, F1-score available  
        """)

# share=True requests a public link; allowed_paths lets Gradio serve the
# example images stored under the demonstration input folder.
demo.launch(share=True, allowed_paths=["/kaggle/input/demonstration"])