In [43]:
import torch
from torchvision.models import resnet50
from torchvision.transforms import transforms
from PIL import Image

In [42]:
# Load the pre-trained ResNet model
model = resnet50(pretrained=True)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [45]:
def preprocess_image(img_path):
    img = Image.open(img_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img_tensor = transform(img).unsqueeze(0)  # Add a batch dimension
    return img_tensor

In [50]:
def predict_class(img_path, model):
    # Preprocess the input image
    img_tensor = preprocess_image(img_path)
    with torch.no_grad():
        predictions = model(img_tensor)
    _, predicted_class = torch.max(predictions, 1)
    return predicted_class.item()

In [53]:
def get_class_label(class_index):
    with open('imagenet_classes.txt', 'r') as f:
        labels = f.readlines()
    class_labels = [label.strip() for label in labels]
    return class_labels[class_index]

In [55]:
img_path = '/Users/' #write the path to your image
predicted_class = predict_class(img_path, model)
class_label = get_class_label(predicted_class)

print(f'Predicted Class Index: {predicted_class}')
print(f'Predicted Class Label: {class_label}')

Predicted Class Index: 283
Predicted Class Label: 283: 'Persian cat',
