In [None]:
!pip install torch torchvision torchaudio
import torch
import torchvision
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Load a Faster R-CNN detector (ResNet-50 + FPN backbone) pre-trained on COCO.
# `weights="DEFAULT"` replaces the deprecated `pretrained=True` flag and
# selects the same COCO checkpoint without triggering a deprecation warning.
model = models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# Swap the box-prediction head for one sized to a user-defined class count.
# NOTE(review): the replacement head is newly constructed and nothing in this
# file fine-tunes it or loads a checkpoint for it, so detections produced
# below are presumably not meaningful until that is done -- confirm.
num_classes = 2  # 1 class (object) + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Function to load and process the image
def process_image(image_path):
    """Run the module-level detection ``model`` on one image file.

    Args:
        image_path: Path of the image to open with PIL.

    Returns:
        A ``(image, prediction)`` tuple: the RGB ``PIL.Image`` that was fed
        to the model, and the model's output for the single-image batch.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    # The detector normalises with 3-channel statistics, so force RGB;
    # otherwise RGBA (and palette/greyscale) files are fed with the wrong
    # channel count. ``convert`` returns a fully loaded copy, which lets the
    # underlying file handle be closed right away.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Convert the PIL image to a CHW float tensor scaled to [0, 1].
    image_tensor = transforms.ToTensor()(image)

    # Put the model in evaluation mode so it returns detections.
    model.eval()

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        prediction = model([image_tensor])

    return image, prediction

# Get the image path from the user
# (blocks waiting for console input; the string is passed on unchanged)
image_path = input("Enter the path to your image: ")

# Process the image
# process_image returns a 2-tuple: the loaded image and the model's prediction.
image, prediction = process_image(image_path)

# Visualize the results
def visualize_results(image, prediction, score_threshold=0.5):
    """Show ``image`` with the detected boxes drawn on top of it.

    Args:
        image: Image accepted by ``plt.imshow``.
        prediction: Sequence whose first element is a mapping with the keys
            ``'boxes'``, ``'labels'`` and ``'scores'``; each box unpacks to
            ``(xmin, ymin, xmax, ymax)``.
        score_threshold: Only detections whose score is strictly greater
            than this value are drawn. Defaults to 0.5.
    """
    detections = prediction[0]

    # Plot the image
    plt.imshow(image)

    # Add the bounding boxes to the plot
    for box, label, score in zip(
        detections['boxes'], detections['labels'], detections['scores']
    ):
        if score > score_threshold:
            # Plain Python numbers keep plotting and formatting independent
            # of how the tensor type converts or prints itself.
            xmin, ymin, xmax, ymax = (float(coord) for coord in box)
            # Trace the rectangle outline, closing it back at the first corner.
            plt.plot(
                [xmin, xmax, xmax, xmin, xmin],
                [ymin, ymin, ymax, ymax, ymin],
                linewidth=3,
            )
            plt.text(
                xmin,
                ymin,
                f'Class: {int(label)}, Score: {float(score):.2f}',
                bbox=dict(facecolor='white'),
            )

    plt.show()

# Draw the detections on top of the image obtained from process_image.
visualize_results(image, prediction)



Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:05<00:00, 32.1MB/s]
