In [2]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoImageProcessor, AutoModelForImageClassification
import numpy as np
from time import time

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

# ============================================
# 1. Load Models
# ============================================
print("Loading models...")

# Google's ViT model
vit_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to(device)
vit_model.eval()

# Facebook's DeiT model
deit_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
deit_model = AutoModelForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224").to(device)
deit_model.eval()

print("Models loaded successfully!\n")

Using device: cpu

Loading models...


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Models loaded successfully!



In [4]:


# ============================================
# 2. Helper Functions
# ============================================

def load_image_from_url(url):
    """Load image from URL"""
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert('RGB')
    return img

def load_image_from_file(file_path):
    """Load image from local file"""
    img = Image.open(file_path).convert('RGB')
    return img

def predict_with_model(model, processor, image, model_name, device):
    """Run inference with a model"""
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        start_time = time()
        outputs = model(**inputs)
        inference_time = time() - start_time

    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_idx]
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_idx].item()

    return predicted_label, confidence, inference_time

# ============================================
# 3. Test with Sample Images
# ============================================

# Example image URLs (ImageNet-style images)
test_images = {
    "cat": "https://images.unsplash.com/photo-1574158622682-e40e69881006?w=224",
   # "dog": "https://images.unsplash.com/photo-1633722715463-d30628519b1e?w=224"
    #"bird": "https://images.unsplash.com/photo-1444464666175-1642158e3c45?w=224"
}

print("=" * 70)
print("IMAGE CLASSIFICATION RESULTS")
print("=" * 70)

IMAGE CLASSIFICATION RESULTS


In [5]:


for image_name, image_url in test_images.items():
    print(f"\n{'='*70}")
    print(f"Image: {image_name.upper()}")
    print(f"URL: {image_url}")
    print(f"{'='*70}")

    try:
        # Load image
        image = load_image_from_url(image_url)
        image.thumbnail((224, 224))

        # ViT predictions
        print("\n--- Google Vision Transformer (ViT) ---")
        vit_label, vit_conf, vit_time = predict_with_model(
            vit_model, vit_processor, image, "ViT", device
        )
        print(f"Prediction: {vit_label}")
        print(f"Confidence: {vit_conf:.4f}")
        print(f"Inference time: {vit_time*1000:.2f}ms")

        # DeiT predictions
        print("\n--- Facebook Data-efficient Image Transformer (DeiT) ---")
        deit_label, deit_conf, deit_time = predict_with_model(
            deit_model, deit_processor, image, "DeiT", device
        )
        print(f"Prediction: {deit_label}")
        print(f"Confidence: {deit_conf:.4f}")
        print(f"Inference time: {deit_time*1000:.2f}ms")

        # Comparison
        print("\n--- Comparison ---")
        print(f"Agree: {vit_label == deit_label}")
        print(f"Average confidence: {(vit_conf + deit_conf)/2:.4f}")

    except Exception as e:
        print(f"Error processing {image_name}: {str(e)}")

# ============================================
# 4. Test with Local Image (Optional)
# ============================================
print("\n" + "="*70)
print("TO USE LOCAL IMAGE:")
print("="*70)
print("""
# Uncomment and modify the path:
# local_image = load_image_from_file("path/to/your/image.jpg")
# vit_label, vit_conf, _ = predict_with_model(vit_model, vit_processor, local_image, "ViT", device)
# deit_label, deit_conf, _ = predict_with_model(deit_model, deit_processor, local_image, "DeiT", device)
# print(f"ViT: {vit_label} ({vit_conf:.4f})")
# print(f"DeiT: {deit_label} ({deit_conf:.4f})")
""")


Image: CAT
URL: https://images.unsplash.com/photo-1574158622682-e40e69881006?w=224

--- Google Vision Transformer (ViT) ---
Prediction: Egyptian cat
Confidence: 0.7972
Inference time: 3129.53ms

--- Facebook Data-efficient Image Transformer (DeiT) ---
Prediction: Egyptian cat
Confidence: 0.7124
Inference time: 1369.83ms

--- Comparison ---
Agree: True
Average confidence: 0.7548

TO USE LOCAL IMAGE:

# Uncomment and modify the path:
# local_image = load_image_from_file("path/to/your/image.jpg")
# vit_label, vit_conf, _ = predict_with_model(vit_model, vit_processor, local_image, "ViT", device)
# deit_label, deit_conf, _ = predict_with_model(deit_model, deit_processor, local_image, "DeiT", device)
# print(f"ViT: {vit_label} ({vit_conf:.4f})")
# print(f"DeiT: {deit_label} ({deit_conf:.4f})")



In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [17]:

local_images={"dog": load_image_from_file("/content/drive/My Drive/Colab Notebooks/GeneratedImages/dog.jpg") ,
        "bird": load_image_from_file("/content/drive/My Drive/Colab Notebooks/GeneratedImages/bird.jpg") ,
              "Ammar":load_image_from_file("/content/drive/My Drive/Colab Notebooks/GeneratedImages/ammarphotolinkedin.jpg"),
               "lady":load_image_from_file("/content/drive/My Drive/Colab Notebooks/GeneratedImages/lady.jpg")}


for image_name, local_image in local_images.items():
    print(f"\n{'='*70}")
    print(f"Image: {image_name.upper()}")
    print(f"{'='*70}")
    vit_label, vit_conf, _ = predict_with_model(vit_model, vit_processor, local_image, "ViT", device)
    deit_label, deit_conf, _ = predict_with_model(deit_model, deit_processor, local_image, "DeiT", device)
    print(f"ViT: {vit_label} ({vit_conf:.4f})")
    print(f"DeiT: {deit_label} ({deit_conf:.4f})")


Image: DOG
ViT: golden retriever (0.9650)
DeiT: golden retriever (0.9962)

Image: BIRD
ViT: bee eater (0.4672)
DeiT: bee eater (0.9791)

Image: AMMAR
ViT: suit, suit of clothes (0.5560)
DeiT: suit, suit of clothes (0.9494)

Image: LADY
ViT: jersey, T-shirt, tee shirt (0.0775)
DeiT: pajama, pyjama, pj's, jammies (0.2738)
