In [None]:
# !pip install transformers torch ipywidgets pillow

# Cybershuttle environments most likely ship these packages already; uncomment
# the install line above only if one of the imports below fails.

: 

In [None]:
# Imports
import torch
from transformers import AutoImageProcessor, ResNetForImageClassification
from PIL import Image
from ipywidgets import FileUpload, Button, VBox, Output, Label
from IPython.display import display
import io
import os

: 

In [None]:
# Pick the compute device first, then load the processor and the model a single
# time so that every later cell reuses the same objects instead of reloading.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

processor = AutoImageProcessor.from_pretrained("./my_local_model")
model = ResNetForImageClassification.from_pretrained("./my_local_model").to(device)

: 

In [None]:
def classify_image(image, processor, model):
    """Classify a single image and return the predicted label.

    Args:
        image: Either a ``PIL.Image.Image`` or anything ``Image.open`` accepts
            (a path or a binary file-like object).
        processor: Callable image processor; invoked as
            ``processor(image, return_tensors="pt")`` and expected to return a
            mapping of tensors.
        model: Classification model whose output exposes ``.logits`` and whose
            ``config.id2label`` maps class indices to label strings.

    Returns:
        The ``id2label`` entry for the highest-scoring class.
    """
    # Normalise to 3-channel RGB: grayscale, palette and RGBA images would
    # otherwise reach the processor with an unexpected channel count.
    # ``convert`` returns a new image, so the caller's object is left untouched.
    if isinstance(image, Image.Image):
        rgb_image = image.convert("RGB")
    else:
        # ``convert`` loads the pixel data, so the underlying file can be
        # closed as soon as the copy exists.
        with Image.open(image) as opened_image:
            rgb_image = opened_image.convert("RGB")

    inputs = processor(rgb_image, return_tensors="pt")

    # Follow the device of the model that was actually passed in rather than a
    # notebook-level global, so the tensors always land next to the weights.
    target_device = model.device
    inputs = {k: v.to(target_device) for k, v in inputs.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]

: 

In [None]:
# Build the widgets that make up the classifier UI. The uploader is restricted
# to image files and accepts several of them in one go.
output = Output()
status_label = Label(value="Upload an image and click the button!")
classify_button = Button(button_style="success", description="🚀 Classify Image")
uploader = FileUpload(multiple=True, accept='image/*')

def _iter_uploads(upload_value):
    """Yield ``(filename, content)`` pairs from a ``FileUpload.value``.

    Handles both shapes the widget has used: a dict keyed by filename
    (ipywidgets 7) and a sequence of per-file dicts carrying a ``'name'`` key
    (ipywidgets 8).
    """
    if isinstance(upload_value, dict):
        for filename, file_info in upload_value.items():
            yield filename, file_info['content']
    else:
        for file_info in upload_value:
            yield file_info['name'], file_info['content']


# Button logic with "Thinking..." status
def on_classify_clicked(b):
    """Classify every uploaded image and render the results in ``output``.

    Args:
        b: The button instance supplied by ipywidgets; unused.

    Each file is handled independently: a file that cannot be opened or
    classified is reported in the output area and the remaining files are
    still processed, so the status label never stays on "Thinking...".
    """
    output.clear_output()
    status_label.value = "🧠 Thinking..."

    if not uploader.value:
        with output:
            print("❗ Please upload at least one image.")
        status_label.value = "⚠️ No images uploaded."
        return

    failed_count = 0
    with output:
        # Loop through all images uploaded
        for uploaded_filename, uploaded_content in _iter_uploads(uploader.value):
            try:
                # ``content`` may be bytes or a memoryview; BytesIO accepts both.
                image = Image.open(io.BytesIO(uploaded_content))
                prediction = classify_image(image, processor, model)
            except Exception as exc:
                # Deliberately broad: this is a UI callback, and any per-file
                # failure should be shown to the user rather than abort the
                # batch and leave the status label stale.
                failed_count += 1
                print(f"❌ Could not classify {uploaded_filename}: {exc}")
                continue

            display(image)
            print(f"🧠 Predicted class for {uploaded_filename}: {prediction}")

    if failed_count:
        status_label.value = f"⚠️ Done, but {failed_count} file(s) could not be classified."
    else:
        status_label.value = "✅ Done!"


# Run the classification handler whenever the button is clicked.
classify_button.on_click(on_classify_clicked)


: 

In [None]:
# Render the app: status line on top, then the uploader, the classify button,
# and finally the area where images and predictions appear.
display(
    VBox(
        children=[
            status_label,
            uploader,
            classify_button,
            output,
        ]
    )
)

: 