In [None]:
import torch
import os
from matplotlib import pyplot as plt
import numpy as np
import cv2

1. Load the custom model 

In [None]:
# Location of the trained weights file -- edit this to point at the model you trained.
MODEL_WEIGHTS = 'yolov5/runs/train/exp15/weights/last.pt'

# Build a model from the 'custom' entry point of the ultralytics/yolov5 torch.hub
# repository, initialised from the weights file at MODEL_WEIGHTS.
# Note: force_reload=True makes torch.hub re-download the ultralytics/yolov5
# repository on every run instead of reusing its local cache; it does not refer
# to the weights file itself.
model = torch.hub.load('ultralytics/yolov5', 'custom', path=MODEL_WEIGHTS, force_reload=True)


2. Test with an image

In [None]:
# Path of the image to run detection on, inside the data/images directory.
# Replace 'awake' with the name of the file you want to check, including its
# extension (e.g. 'awake.jpg'), so that the path points at an existing image file.
IMAGE_DIR = os.path.join('data', 'images')
img = os.path.join(IMAGE_DIR, 'awake')


In [None]:
# Run the custom YOLOv5 model on the image at path `img`; the returned detection
# results object is printed and rendered in the cells below.
results = model(img)

In [None]:
# Print a text summary of the detections held in `results`.
results.print()

In [None]:
# Import matplotlib and configure it to display plots inline in Jupyter Notebook or Jupyter Lab
%matplotlib inline

# Display the rendered image using matplotlib
# The image is obtained from the 'results' object and rendered using np.squeeze() to remove any unnecessary dimensions
# Finally, plt.imshow() displays the image
plt.imshow(np.squeeze(results.render()))

# Show the plot
plt.show()


3. Test the model with realtime video

In [None]:
# Index of the capture device to open; 0 is the system's default camera.
CAMERA_INDEX = 0

cap = cv2.VideoCapture(CAMERA_INDEX)
if not cap.isOpened():
    # Without this message the loop below would silently do nothing.
    print(f'Could not open camera {CAMERA_INDEX}')

try:
    # Keep processing frames for as long as the capture device stays open.
    while cap.isOpened():
        ret, frame = cap.read()
        # read() returns ret=False (and no usable frame) when a frame could not
        # be grabbed, e.g. the camera was disconnected; stop instead of passing
        # None to the model.
        if not ret:
            break

        # OpenCV delivers frames in BGR channel order, while the YOLOv5 hub
        # model expects RGB numpy arrays, so convert before inference.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = model(rgb_frame)

        # render() draws the detections onto the (RGB) input and returns one
        # image per input; convert it back to BGR for display with OpenCV.
        annotated = cv2.cvtColor(results.render()[0], cv2.COLOR_RGB2BGR)
        cv2.imshow('YOLO', annotated)

        # Quit the video feed when 'q' is pressed.
        if cv2.waitKey(10) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even if inference raised or
    # the kernel was interrupted; otherwise the device stays locked by the kernel.
    cap.release()
    cv2.destroyAllWindows()
