# YOLO Object Tracking Test 

This notebook explores whether the YOLO object detection model can be used to track a mouse moving through a maze.

YOLO ("You Only Look Once") is a family of machine learning models for real-time object detection. A single forward pass of the network predicts both where objects are (localization) and what they are (classification). It is widely used across computer vision tasks because it is fast, accurate, and easy to use.

In this experiment, we use YOLO v8.2, starting from the pretrained `yolov8n` (nano) weights, to perform object localization. Localization means finding a bounding box that surrounds an object in the image. The box is defined by two opposite corner points, for example the top-left and bottom-right corners.
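YOLO-format training labels do not store the two corner points directly. Roboflow's YOLOv8 export, which the fine-tuning step below loads, presumably uses the same format. In that format, each object is one line of `class x_center y_center width height`, with all four coordinates normalized by the image width and height.

The sketch below shows that conversion. It assumes a box given as pixel corners `(x1, y1, x2, y2)`. `corners_to_yolo` is a hypothetical helper, not part of `ultralytics`.

```python
def corners_to_yolo(x1: float, y1: float, x2: float, y2: float,
                    img_w: int, img_h: int) -> tuple[float, float, float, float]:
    """Hypothetical helper: convert pixel corners to YOLO's normalized (cx, cy, w, h)."""
    # Order the corners so the result is the same whichever corner is given first
    x_min, x_max = sorted((x1, x2))
    y_min, y_max = sorted((y1, y2))
    # Box centre and size, each divided by the matching image dimension
    cx = (x_min + x_max) / 2 / img_w
    cy = (y_min + y_max) / 2 / img_h
    w = (x_max - x_min) / img_w
    h = (y_max - y_min) / img_h
    return cx, cy, w, h

# Hypothetical example: a 64x32 px box on a 640x480 frame
print(corners_to_yolo(100, 200, 164, 232, img_w=640, img_h=480))
```

Suppose the mouse is class `0` (an assumption about the dataset's class order). The label line for this box would then be `0` followed by these four values.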

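```python
import os
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import ultralytics
from tqdm import tqdm, trange

# Path to the recorded maze video; change this to point at your own copy
VIDEO_PATH = "/Users/henrywilliams/Downloads/videos/maze_test.mp4"

# Load the pretrained YOLOv8 nano weights (downloaded automatically on first use)
yolo = ultralytics.YOLO('yolov8n.pt')
```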

## Create a Dataset

The following code cell samples 250 random frames from the video and saves them as JPEG images in the `./frames/` directory. These frames are later annotated and used to fine-tune the YOLO model. To change where the frames are saved, modify the `OUTPUT_DIR` variable. To change the number of frames sampled, modify the `TOTAL_SAMPLES` variable. If `OUTPUT_DIR` already exists, sampling is skipped so that previously saved frames are not overwritten.

```python
TOTAL_SAMPLES = 250
OUTPUT_DIR = './frames'

def get_nth_frame(cap: cv2.VideoCapture, n: int, n_frames: Optional[int] = None) -> np.ndarray:
    """
    Retrieve the nth (0-indexed) frame from a video capture object.

    This function extracts the nth frame from a given video capture object `cap`.
    If `n_frames` is not provided, the function will determine the total number of
    frames in the video. The capture's read position is restored to where it was
    before the call, so sequential reads elsewhere are unaffected.

    Parameters:
    cap (cv2.VideoCapture): The video capture object from which to extract the frame.
    n (int): The 0-indexed frame number to retrieve.
    n_frames (Optional[int], optional): The total number of frames in the video.
                                        If not provided, it will be determined automatically.

    Returns:
    np.ndarray: The nth frame as a BGR image array.

    Raises:
    IndexError: If `n` is outside the range of frames available in the video.
    RuntimeError: If the frame could not be read from the video capture object.

    Example usage:
    >>> cap = cv2.VideoCapture('video.mp4')
    >>> frame = get_nth_frame(cap, 10)
    >>> cv2.imshow('Frame 10', frame)
    >>> cv2.waitKey(0)
    >>> cv2.destroyAllWindows()
    """
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if n_frames is None else n_frames
    current_frame = cap.get(cv2.CAP_PROP_POS_FRAMES)

    # Frames are 0-indexed, so the last valid index is n_frames - 1
    if not 0 <= n < n_frames:
        raise IndexError(f"Attempted to get frame {n} when only {n_frames} frames exist")

    # Seek to the requested frame, read it, then restore the original position
    cap.set(cv2.CAP_PROP_POS_FRAMES, n)
    ret, frame = cap.read()
    cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)

    if not ret:
        raise RuntimeError(f"Failed to read frame {n}")

    return frame

cap = cv2.VideoCapture(VIDEO_PATH)
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

if os.path.exists(OUTPUT_DIR):
    # Don't overwrite frames that may already have been annotated
    print(f"{OUTPUT_DIR} already exists, skipping frame sampling")
else:
    os.makedirs(OUTPUT_DIR)
    # Sample without replacement so no frame is saved (and overwritten) twice
    frame_ids = np.random.choice(n_frames, size=TOTAL_SAMPLES, replace=False)
    for n in tqdm(frame_ids):
        frame = get_nth_frame(cap, int(n), n_frames=n_frames)
        # Save into OUTPUT_DIR rather than a hard-coded path
        cv2.imwrite(os.path.join(OUTPUT_DIR, f"frame-{n}.jpg"), frame)

cap.release()
```

## Annotation 

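The annotated dataset is hosted on Roboflow. Visit [Roboflow Universe](https://universe.roboflow.com/mice-maze/mice-maze) to view and download it. The fine-tuning step below expects the YOLOv8 export to be unpacked into the `mice-maze.yolov8/` directory.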

## Fine-tuning the model

```python
# Select the best available backend to train on, in this order:
#   1. a CUDA-enabled GPU (modern NVIDIA cards)
#   2. MPS (Apple Silicon Macs: M1, M2, etc.)
#   3. the CPU, if neither of the above is available
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    # is_available() also checks the hardware; is_built() only checks the PyTorch build
    device = torch.device('mps')
else:
    print("No GPU backend available, defaulting to CPU")
    print("This might take a while; please ensure you installed the correct version of PyTorch for your hardware")
    print("See https://pytorch.org/get-started/locally/")
    device = torch.device('cpu')

# Fine-tune the pretrained model on the annotated Roboflow export
results = yolo.train(data='mice-maze.yolov8/data.yaml', device=device)
```
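Once the model is fine-tuned, its per-frame detections still have to be turned into a track of the mouse's position. With `ultralytics`, the boxes detected in a frame are available as `results[0].boxes.xyxy`.

The sketch below assumes those detections have already been collected into one `(x1, y1, x2, y2)` box per frame, or `None` when the mouse was not detected. It keeps the centroid of each box and linearly interpolates across missed frames. `boxes_to_trajectory` is a hypothetical helper, and linear interpolation is only one simple way to bridge gaps.

```python
from typing import Optional, Sequence

import numpy as np

def boxes_to_trajectory(boxes: Sequence[Optional[Sequence[float]]]) -> np.ndarray:
    """Hypothetical helper: (n_frames, 2) array of box centroids, gaps interpolated."""
    n = len(boxes)
    centres = np.full((n, 2), np.nan)
    for i, box in enumerate(boxes):
        if box is not None:
            x1, y1, x2, y2 = box
            centres[i] = ((x1 + x2) / 2, (y1 + y2) / 2)

    found = ~np.isnan(centres[:, 0])
    if not found.any():
        raise ValueError("No detections to build a trajectory from")

    # np.interp fills gaps linearly and holds the first/last detection at the ends
    frames = np.arange(n)
    for axis in range(2):
        centres[:, axis] = np.interp(frames, frames[found], centres[found, axis])
    return centres

# Hypothetical detections for four frames; the mouse is missed in frame 2
track = boxes_to_trajectory([(0, 0, 10, 10), (10, 0, 20, 10), None, (30, 0, 40, 10)])
```

Frames in which several boxes are detected would also need a rule for picking one, for example the box with the highest `boxes.conf`. That choice is left open here.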