In [6]:
import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report


In [7]:
# Load the pre-trained ResNet50 model
model = models.resnet50(pretrained=True)
model.eval()  # Set the model to evaluation mode


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 149MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
# Define transformations to resize and normalize the image
preprocess = transforms.Compose([
    transforms.Resize(256),       # Resize the image to 256x256 pixels
    transforms.CenterCrop(224),   # Crop it to 224x224 pixels, the input size expected by ResNet
    transforms.ToTensor(),        # Convert the image to a PyTorch tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean values for normalization
        std=[0.229, 0.224, 0.225]    # ImageNet std deviation values for normalization
    )
])


In [9]:
# Download images directly into Colab with updated URLs
!wget -O cat1.jpg https://images.unsplash.com/photo-1574158622682-e40e69881006
!wget -O cat2.jpg https://images.unsplash.com/photo-1543852786-1cf6624b9987?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=MnwzNjUyOXwwfDF8c2VhcmNofDR8fGNhdHxlbnwwfHx8fDE2MzUyNTUwNjM&ixlib=rb-1.2.1&q=80&w=400
!wget -O dog1.jpg https://images.unsplash.com/photo-1517849845537-4d257902454a
!wget -O dog2.jpg https://images.unsplash.com/photo-1525253086316-d0c936c814f8
!wget -O car1.jpg https://images.unsplash.com/photo-1493238792000-8113da705763
!wget -O car2.jpg https://images.unsplash.com/photo-1502877338535-766e1452684a
!wget -O flower1.jpg https://images.unsplash.com/photo-1501004318641-b39e6451bec6
!wget -O flower2.jpg https://images.unsplash.com/photo-1516979187457-637abb4f9353
!!wget -O bird1.jpg https://images.unsplash.com/photo-1557683316-973673baf926
!wget -O bird2.jpg https://images.unsplash.com/photo-1557683316-973673baf926


--2024-10-31 22:55:38--  https://images.unsplash.com/photo-1574158622682-e40e69881006
Resolving images.unsplash.com (images.unsplash.com)... 151.101.2.208, 151.101.66.208, 151.101.130.208, ...
Connecting to images.unsplash.com (images.unsplash.com)|151.101.2.208|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 900980 (880K) [image/jpeg]
Saving to: ‘cat1.jpg’


2024-10-31 22:55:38 (24.5 MB/s) - ‘cat1.jpg’ saved [900980/900980]

--2024-10-31 22:55:38--  https://images.unsplash.com/photo-1543852786-1cf6624b9987?crop=entropy
Resolving images.unsplash.com (images.unsplash.com)... 151.101.2.208, 151.101.66.208, 151.101.130.208, ...
Connecting to images.unsplash.com (images.unsplash.com)|151.101.2.208|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5517104 (5.3M) [image/jpeg]
Saving to: ‘cat2.jpg’


2024-10-31 22:55:39 (83.5 MB/s) - ‘cat2.jpg’ saved [5517104/5517104]

--2024-10-31 22:55:39--  https://images.unsplash.com/photo-1517849845537-4d

In [10]:
def load_image(image_path):
    img = Image.open(image_path)          # Open the image
    img = preprocess(img).unsqueeze(0)    # Apply preprocessing and add batch dimension
    return img


In [11]:
def predict_image_class(image_path):
    img_tensor = load_image(image_path)        # Load and preprocess image
    with torch.no_grad():                      # Disable gradient calculations
        output = model(img_tensor)             # Forward pass
    probabilities = torch.nn.functional.softmax(output[0], dim=0)  # Convert to probabilities

    # Get the top 5 predicted classes
    _, top5_indices = torch.topk(probabilities, 5)
    return top5_indices, probabilities[top5_indices]


In [12]:
import json
import requests

url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
labels = requests.get(url).json()


In [13]:
def print_predictions(image_path):
    top5_indices, top5_probs = predict_image_class(image_path)
    for idx, prob in zip(top5_indices, top5_probs):
        print(f"{labels[idx]}: {prob:.4f}")


In [14]:
print_predictions("dog1.jpg")

pug: 0.9675
Griffon Bruxellois: 0.0262
French Bulldog: 0.0020
Boston Terrier: 0.0004
Bullmastiff: 0.0004


In [15]:
print_predictions("dog2.jpg")


Border Collie: 0.5926
Japanese Chin: 0.1174
collie: 0.0945
Papillon: 0.0664
English Setter: 0.0228


In [16]:
print_predictions("cat1.jpg")

tabby cat: 0.6762
tiger cat: 0.2017
Egyptian Mau: 0.1200
carton: 0.0004
lynx: 0.0003


In [17]:
print_predictions("cat2.jpg")


Egyptian Mau: 0.5895
bow tie: 0.2444
tabby cat: 0.0858
tiger cat: 0.0364
poke bonnet: 0.0099


In [18]:
print_predictions("car1.jpg")


taxicab: 0.9131
parking meter: 0.0183
station wagon: 0.0152
car wheel: 0.0150
sports car: 0.0058


In [20]:
print_predictions("car2.jpg")


sports car: 0.6234
car wheel: 0.1727
convertible: 0.1326
station wagon: 0.0499
grille: 0.0154


In [21]:
print_predictions("flower1.jpg")

pot: 0.2644
vase: 0.1658
cup: 0.0435
plectrum: 0.0366
barrette: 0.0235


In [23]:
print_predictions("flower2.jpg")

pill bottle: 0.2361
match: 0.1679
eraser: 0.1258
carton: 0.1002
accordion: 0.0691


In [25]:
print_predictions("bird1.jpg")

wing: 0.0546
lighthouse: 0.0268
space shuttle: 0.0229
water bottle: 0.0197
seashore: 0.0196


In [26]:
print_predictions("bird2.jpg")

wing: 0.0546
lighthouse: 0.0268
space shuttle: 0.0229
water bottle: 0.0197
seashore: 0.0196
