From 83469f95cf9e07cb92fa7336d744b26400096d90 Mon Sep 17 00:00:00 2001 From: kdt523 Date: Fri, 17 Oct 2025 01:04:53 +0530 Subject: [PATCH 1/3] Add Vision Transformer demo for image classification (Fixes #13372) --- DIRECTORY.md | 1 + computer_vision/vision_transformer_demo.py | 219 +++++++++++++++++++++ pyproject.toml | 3 + 3 files changed, 223 insertions(+) create mode 100644 computer_vision/vision_transformer_demo.py diff --git a/DIRECTORY.md b/DIRECTORY.md index 0f9859577493..d7b0b074889c 100644 --- a/DIRECTORY.md +++ b/DIRECTORY.md @@ -139,6 +139,7 @@ * [Mean Threshold](computer_vision/mean_threshold.py) * [Mosaic Augmentation](computer_vision/mosaic_augmentation.py) * [Pooling Functions](computer_vision/pooling_functions.py) + * [Vision Transformer Demo](computer_vision/vision_transformer_demo.py) ## Conversions * [Astronomical Length Scale Conversion](conversions/astronomical_length_scale_conversion.py) diff --git a/computer_vision/vision_transformer_demo.py b/computer_vision/vision_transformer_demo.py new file mode 100644 index 000000000000..7da6d0f639f0 --- /dev/null +++ b/computer_vision/vision_transformer_demo.py @@ -0,0 +1,219 @@ +""" +Vision Transformer (ViT) Image Classification Demo + +This module demonstrates how to use a pre-trained Vision Transformer (ViT) model +from Hugging Face for image classification tasks. + +Vision Transformers apply the transformer architecture (originally designed for NLP) +to computer vision by splitting images into patches and processing them with +self-attention mechanisms. + +Requirements: + - torch + - transformers + - Pillow (PIL) + - requests + +Resources: + - Paper: https://arxiv.org/abs/2010.11929 + - Hugging Face: https://huggingface.co/docs/transformers/model_doc/vit + +Example Usage: + from computer_vision.vision_transformer_demo import classify_image + + # Classify an image from URL + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + result = classify_image(url) + print(f"Predicted: {result['label']} (confidence: {result['score']:.2%})") + + # Classify a local image + result = classify_image("path/to/image.jpg", top_k=3) + for pred in result['top_k_predictions']: + print(f"{pred['label']}: {pred['score']:.2%}") +""" + +from __future__ import annotations + +import sys +from io import BytesIO +from pathlib import Path +from typing import Any + +try: + import requests + import torch + from PIL import Image + from transformers import ViTForImageClassification, ViTImageProcessor +except ImportError as e: + print(f"Error: Missing required dependency: {e.name}") + print("Install dependencies: pip install torch transformers pillow requests") + sys.exit(1) + + +def load_image(image_source: str | Path, timeout: int = 10) -> Image.Image: + """ + Load an image from a URL or local file path. + + Args: + image_source: URL string or Path object to the image + timeout: Network timeout in seconds (default: 10) + + Returns: + PIL Image object + + Raises: + TimeoutError: If request times out + ConnectionError: If URL is unreachable + FileNotFoundError: If local file doesn't exist + IOError: If image cannot be opened + + Examples: + >>> # Test with non-existent file + >>> try: + ... load_image("nonexistent_file.jpg") + ... except FileNotFoundError: + ... print("File not found") + File not found + """ + if isinstance(image_source, (str, Path)) and str(image_source).startswith( + ("http://", "https://") + ): + try: + response = requests.get(str(image_source), timeout=timeout) + response.raise_for_status() + return Image.open(BytesIO(response.content)).convert("RGB") + except requests.exceptions.Timeout: + msg = ( + f"Request timed out after {timeout} seconds. " + "Try increasing the timeout parameter." + ) + raise TimeoutError(msg) + except requests.exceptions.RequestException as e: + msg = f"Failed to download image from URL: {e}" + raise ConnectionError(msg) from e + else: + # Load from local file + file_path = Path(image_source) + if not file_path.exists(): + msg = f"Image file not found: {file_path}" + raise FileNotFoundError(msg) + return Image.open(file_path).convert("RGB") + + +def classify_image( + image_source: str | Path, + model_name: str = "google/vit-base-patch16-224", + top_k: int = 1, +) -> dict[str, Any]: + """ + Classify an image using a Vision Transformer model. + + Args: + image_source: URL or local path to the image + model_name: Hugging Face model identifier (default: google/vit-base-patch16-224) + top_k: Number of top predictions to return (default: 1) + + Returns: + Dictionary containing: + - label: Predicted class label + - score: Confidence score (0-1) + - top_k_predictions: List of top-k predictions (if top_k > 1) + + Raises: + ValueError: If top_k is less than 1 + FileNotFoundError: If image file doesn't exist + ConnectionError: If unable to download from URL + + Examples: + >>> # Test parameter validation + >>> try: + ... classify_image("test.jpg", top_k=0) + ... except ValueError as e: + ... print("Invalid top_k") + Invalid top_k + """ + if top_k < 1: + raise ValueError("top_k must be at least 1") + # Load image + image = load_image(image_source) + + # Load pre-trained model and processor + # Using context manager pattern for better resource management + processor = ViTImageProcessor.from_pretrained(model_name) + model = ViTForImageClassification.from_pretrained(model_name) + + # Preprocess image + inputs = processor(images=image, return_tensors="pt") + + # Perform inference + with torch.no_grad(): # Disable gradient calculation for inference + outputs = model(**inputs) + logits = outputs.logits + + # Get predictions + probabilities = torch.nn.functional.softmax(logits, dim=-1) + top_k_probs, top_k_indices = torch.topk(probabilities, k=top_k, dim=-1) + + # Format results + predictions = [] + for prob, idx in zip(top_k_probs[0], top_k_indices[0]): + predictions.append( + {"label": model.config.id2label[idx.item()], "score": prob.item()} + ) + + result = { + "label": predictions[0]["label"], + "score": predictions[0]["score"], + "top_k_predictions": predictions if top_k > 1 else None, + } + + return result + + +def main() -> None: + """ + Main function demonstrating Vision Transformer usage. + + Downloads a sample image and performs classification. + """ + print("Vision Transformer (ViT) Image Classification Demo") + print("=" * 60) + + # Sample image URL (two cats on a couch from COCO dataset) + image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + print(f"\nLoading image from: {image_url}") + + try: + # Get top-3 predictions + result = classify_image(image_url, top_k=3) + + print(f"\n{'Prediction Results':^60}") + print("-" * 60) + print(f"Top Prediction: {result['label']}") + print(f"Confidence: {result['score']:.2%}") + + if result["top_k_predictions"]: + print(f"\n{'Top 3 Predictions':^60}") + print("-" * 60) + for i, pred in enumerate(result["top_k_predictions"], 1): + print(f"{i}. {pred['label']:<40} {pred['score']:>6.2%}") + + # Example with local image (commented out) + print("\n" + "=" * 60) + print("To classify a local image, use:") + print(' result = classify_image("path/to/your/image.jpg")') + print(" print(f\"Predicted: {result['label']}\")") + + except TimeoutError as e: + print(f"\nError: {e}") + print("Please check your internet connection and try again.") + except ConnectionError as e: + print(f"\nError: {e}") + except Exception as e: + print(f"\nUnexpected error: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 60ba0d3b65d9..dd02f0286f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,12 +20,15 @@ dependencies = [ "opencv-python>=4.10.0.84", "pandas>=2.2.3", "pillow>=11.3", + "requests>=2.31.0", "rich>=13.9.4", "scikit-learn>=1.5.2", "scipy>=1.16.2", "sphinx-pyproject>=0.3", "statsmodels>=0.14.4", "sympy>=1.13.3", + "torch>=2.0.0", + "transformers>=4.30.0", "tweepy>=4.14", "typing-extensions>=4.12.2", "xgboost>=2.1.3", From 8701c927d33ccc5f6909f771fcb253d8a833f0e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Oct 2025 19:50:35 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dd02f0286f90..016569101fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,15 +20,15 @@ dependencies = [ "opencv-python>=4.10.0.84", "pandas>=2.2.3", "pillow>=11.3", - "requests>=2.31.0", + "requests>=2.31", "rich>=13.9.4", "scikit-learn>=1.5.2", "scipy>=1.16.2", "sphinx-pyproject>=0.3", "statsmodels>=0.14.4", "sympy>=1.13.3", - "torch>=2.0.0", - "transformers>=4.30.0", + "torch>=2", + "transformers>=4.30", "tweepy>=4.14", "typing-extensions>=4.12.2", "xgboost>=2.1.3", From 9b7385c8bf89d89e194460b1fc76aa874239fa89 Mon Sep 17 00:00:00 2001 From: kdt523 Date: Fri, 17 Oct 2025 13:11:51 +0530 Subject: [PATCH 3/3] Add doctest to main() function per reviewer request --- computer_vision/vision_transformer_demo.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/computer_vision/vision_transformer_demo.py b/computer_vision/vision_transformer_demo.py index 3370bea6226f..cb34b21bc220 100644 --- a/computer_vision/vision_transformer_demo.py +++ b/computer_vision/vision_transformer_demo.py @@ -176,6 +176,14 @@ def main() -> None: Main function demonstrating Vision Transformer usage. Downloads a sample image and performs classification. + + Examples: + >>> # Verify main is callable + >>> callable(main) + True + >>> # Verify main returns None + >>> main() is None # doctest: +SKIP + True """ print("Vision Transformer (ViT) Image Classification Demo") print("=" * 60)