In [1]:
# !pip install transformers torch ipywidgets pillow

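# Cybershuttle environments likely ship with these packages already; uncomment the line above if any are missing (prefer `%pip`, which installs into the running kernel's environment).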

In [2]:
# Imports
import torch
from PIL import Image
from ipywidgets import FileUpload, Button, VBox, Output, Label
from IPython.display import display
import io
import os
import time
from transformers import AutoModel, AutoFeatureExtractor, AutoImageProcessor, ResNetForImageClassification
from PIL import Image
import tracemalloc
import subprocess
import gc
import platform

In [3]:
model_name = "microsoft/resnet-50"

# Download the model and image processor once, then reuse the local copy on later runs.
save_directory = "./my_local_model"

# Only trust the local copy if save_pretrained() wrote both config files; a bare or
# half-written directory (e.g. from an interrupted run) would otherwise make the
# local load below fail instead of triggering a fresh download.
local_copy_complete = all(
    os.path.isfile(os.path.join(save_directory, config_file))
    for config_file in ("config.json", "preprocessor_config.json")
)

if not local_copy_complete:
    model = ResNetForImageClassification.from_pretrained(model_name)
    # AutoImageProcessor supersedes the deprecated AutoFeatureExtractor for vision
    # models. The variable keeps its old name because later cells refer to it.
    feature_extractor = AutoImageProcessor.from_pretrained(model_name)

    # Save locally
    model.save_pretrained(save_directory)
    feature_extractor.save_pretrained(save_directory)
    print("Model and feature extractor saved successfully!")
else:
    # Load from saved directory
    model = ResNetForImageClassification.from_pretrained(save_directory)
    feature_extractor = AutoImageProcessor.from_pretrained(save_directory)
    print("Loaded model from local directory")

# Move the model to the GPU when one is available; `device` is also read by later cells.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Inference only: keep layers such as dropout / batch norm in evaluation mode.
model.eval()

Loaded model from local directory




In [4]:
# # Load model and processor (only once for optimization purposes)
# processor = AutoImageProcessor.from_pretrained("./my_local_model")
# model = ResNetForImageClassification.from_pretrained("./my_local_model")

# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)

In [5]:
def classify_image(image, feature_extractor, model, device=None):
    """Classify a single image and return the model's top label.

    Args:
        image: A PIL image, or anything ``PIL.Image.open`` accepts (a file path or
            a binary file object).
        feature_extractor: Preprocessor called as
            ``feature_extractor(image, return_tensors="pt")``; it must return a
            mapping of tensors that ``model`` accepts as keyword arguments.
        model: Classification model whose output exposes ``.logits`` and whose
            ``config.id2label`` maps class indices to label strings.
        device: Device to run inference on. Defaults to the device holding the
            model's parameters, so inputs always land next to the weights.

    Returns:
        The label string of the highest-scoring class.
    """
    if isinstance(image, Image.Image):
        # Normalise RGBA / grayscale / palette images to the 3 channels the
        # preprocessor's per-channel normalisation expects.
        rgb_image = image.convert("RGB")
    else:
        # convert() decodes the pixels, so the underlying file can be closed here.
        with Image.open(image) as opened_image:
            rgb_image = opened_image.convert("RGB")

    if device is None:
        device = next(model.parameters()).device

    inputs = feature_extractor(rgb_image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # One image in -> one prediction out, so .item() yields a plain int index.
    predicted_index = logits.argmax(-1).item()
    return model.config.id2label[predicted_index]

In [6]:
import time
import tracemalloc
from PIL import Image
import io
from ipywidgets import FileUpload, Button, Label, Output, VBox
from IPython.display import display

# UI widgets: a multi-file image uploader, two action buttons, a one-line status
# indicator, and an output area that receives the per-image results.
uploader = FileUpload(multiple=True, accept='image/*')
classify_button = Button(button_style="success", description="🚀 Classify Image")
clear_output_button = Button(button_style="warning", description="🧹 Clear Results")
status_label = Label(value="Upload an image and click the button!")
output = Output()

def _timed_classification(image):
    """Classify `image` while measuring wall-clock latency and peak traced memory.

    Returns:
        A ``(prediction, elapsed_seconds, peak_bytes)`` tuple.

    NOTE: tracemalloc only sees allocations made through Python's own memory
    allocators; tensor buffers allocated natively by torch are most likely not
    counted, so ``peak_bytes`` reflects Python-side overhead rather than total
    model memory.
    """
    tracemalloc.start()
    try:
        start_time = time.perf_counter()  # monotonic, high-resolution interval clock
        prediction = classify_image(image, feature_extractor, model)
        elapsed = time.perf_counter() - start_time
        _, peak = tracemalloc.get_traced_memory()
    finally:
        # Stop tracing even if classification fails; otherwise tracemalloc stays on,
        # slowing down (and skewing the numbers for) every later click.
        tracemalloc.stop()
    return prediction, elapsed, peak


def on_classify_clicked(b):
    """Click handler: classify every uploaded file and print results into `output`.

    A failure on one file is reported in the output area and the remaining files
    are still processed; the status label reflects whether any file failed.

    Args:
        b: The button that was clicked (unused; required by ``Button.on_click``).
    """
    output.clear_output()

    if not uploader.value:
        with output:
            print("❗ Please upload at least one image.")
        status_label.value = "⚠️ No images uploaded."
        return

    status_label.value = "🧠 Thinking..."
    failed = []
    with output:
        # Each entry is read as a dict with "name" and "content" (bytes-like) keys,
        # i.e. the ipywidgets 8 FileUpload value format.
        for file_info in uploader.value:
            uploaded_filename = file_info["name"]
            try:
                image = Image.open(io.BytesIO(file_info["content"]))
                prediction, elapsed, peak = _timed_classification(image)
            except Exception as exc:  # report this file and move on to the next one
                failed.append(uploaded_filename)
                print(f"❌ Could not classify {uploaded_filename}: {exc}\n")
                continue

            # Display image and results
            display(image)
            print(f"🧠 Predicted class for {uploaded_filename}: {prediction}")
            print(f"⏱️ Inference time: {elapsed:.4f} seconds")
            print(f"📈 Peak memory usage: {peak / 1024 / 1024:.4f} MB\n")

    if failed:
        status_label.value = f"⚠️ Done, but {len(failed)} file(s) failed."
    else:
        status_label.value = "✅ Done!"

def on_clear_output_clicked(b):
    """Click handler: forget the uploaded files and reset the result area.

    Args:
        b: The button that was clicked (unused; required by ``Button.on_click``).
    """
    if isinstance(uploader.value, dict):
        # ipywidgets 7.x style value (dict keyed by filename): the commonly used
        # reset empties the dict and zeroes the private upload counter.
        uploader.value.clear()
        uploader._counter = 0
    else:
        # ipywidgets 8.x style value (immutable tuple): assign a new empty tuple so
        # the change is synced; setting `_counter` alone does not clear the files.
        uploader.value = ()
    output.clear_output()
    status_label.value = "Upload an image and click the button!"
    with output:
        print("Results cleared")

# Register each button's click handler, then render the whole app as one vertical stack.
for button, handler in (
    (classify_button, on_classify_clicked),
    (clear_output_button, on_clear_output_clicked),
):
    button.on_click(handler)

app_layout = VBox([status_label, uploader, classify_button, clear_output_button, output])
display(app_layout)


VBox(children=(Label(value='Upload an image and click the button!'), FileUpload(value=(), accept='image/*', de…