In [None]:
# Mount Google Drive at /content/drive (Colab-only helper); the model and video
# paths used by this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import cv2
import numpy as np
import tensorflow as tf
import time
import os

# Label table for the basket model: a detection's integer class id is used as
# an index into this list to obtain its label.
CLASSES_BASKET = ["ball", "net"]
# Label table for the goal model (presumably; not referenced by the code in this cell).
CLASSES_GOAL = ["goal"]

def initialize_interpreter(model_path):
    '''Create a TensorFlow Lite interpreter for the model file at `model_path`,
       allocate its tensors, and return it.
    '''
    tflite_interpreter = tf.lite.Interpreter(model_path=model_path)
    # Tensors have to be allocated once before the interpreter can be used.
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter

def preprocess_frame(frame, input_size, is_grayscale=False):
    ''' Resize a video frame and pack it as a uint8 batch of one for the TFLite model.

        Args:
            frame: image array for one video frame (callers pass the frame
                returned by cv2.VideoCapture.read()).
            input_size: (height, width) the frame is resized to.
            is_grayscale: accepted for backward compatibility but currently
                ignored -- no grayscale conversion is performed. It defaults to
                False so the two-argument calls used in this notebook work.

        Returns:
            The resized frame cast to uint8 with a leading batch dimension added.
    '''
    # NOTE(review): OpenCV decodes frames in BGR channel order and no channel
    # swap is done here -- confirm this matches what the models were trained on.
    img = tf.convert_to_tensor(frame)

    resized_img = tf.image.resize(img, input_size)
    resized_img = resized_img[tf.newaxis, :]  # Add batch dimension
    # Cast the resized frame to the uint8 dtype that is fed to the model.
    resized_img = tf.cast(resized_img, dtype=tf.uint8)
    return resized_img

def detect_objects(interpreter, frame, threshold):
    ''' Run the model on one preprocessed frame and return the confident detections.

        Args:
            interpreter: object exposing get_signature_runner(); the runner is
                called with the keyword argument `images`.
            frame: preprocessed input batch passed to the runner as `images`.
            threshold: minimum score (inclusive) a detection needs to be kept.

        Returns:
            A list of dicts, one per kept detection, with keys
            'bounding_box' (the detection's row of 'output_3'),
            'class_id' (int) and 'score'.
    '''
    # A new signature runner is requested on every call.
    signature_fn = interpreter.get_signature_runner()
    output = signature_fn(images=frame)

    # Index 0 selects the single image of the batch. The key-to-meaning mapping
    # below is the one this notebook relies on.
    # NOTE(review): confirm it matches the exported model's signature.
    scores = output['output_1'][0]
    classes = output['output_2'][0]
    boxes = output['output_3'][0]

    # Boolean mask of detections that reach the threshold.
    keep = scores >= threshold

    return [
        {
            'bounding_box': box,
            'class_id': int(class_id),
            'score': score
        }
        for box, class_id, score in zip(boxes[keep], classes[keep], scores[keep])
    ]

def get_basket_roi(input_video_path, interpreter, threshold, num_frames):
    ''' Find the basket's region of interest in the first frames of a video.

        Reads up to `num_frames` frames, runs detection on each, and uses the
        first usable 'net' detection. That box is widened by 0.8x its width on
        each side and extended upward by 1.8x its height (bottom edge kept),
        then clamped to the frame.

        Args:
            input_video_path: path of the video to scan.
            interpreter: interpreter for the basket model; its first input's
                shape is read as (_, height, width, _).
            threshold: minimum detection score passed to detect_objects.
            num_frames: maximum number of frames to inspect.

        Returns:
            (xmin, ymin, xmax, ymax) in pixel coordinates of the source video.

        Raises:
            ValueError: if the video cannot be opened, or no 'net' is detected
                within the inspected frames.
    '''
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {input_video_path}")

    roi = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

        for fr in range(num_frames):
            ret, frame = cap.read()
            if not ret:
                break
            preprocessed_frame = preprocess_frame(frame, (input_height, input_width), is_grayscale=False)
            detections = detect_objects(interpreter, preprocessed_frame, threshold=threshold)

            for obj in detections:
                class_id = int(obj['class_id'])
                # Skip ids outside the label table: a too-large id would raise
                # IndexError and a negative one would silently wrap around.
                if not 0 <= class_id < len(CLASSES_BASKET):
                    continue
                if CLASSES_BASKET[class_id] != 'net':
                    continue

                # Box coordinates are fractions of the frame size; scale to pixels.
                ymin, xmin, ymax, xmax = obj['bounding_box']
                xmin = int(xmin * frame_width)
                xmax = int(xmax * frame_width)
                ymin = int(ymin * frame_height)
                ymax = int(ymax * frame_height)

                width = xmax - xmin
                height = ymax - ymin
                # A zero-area box cannot be cropped into a video; try the next one.
                if width <= 0 or height <= 0:
                    continue

                # Expand the bounding box
                xmin = max(0, int(xmin - 0.8 * width))  # Increase leftward expansion by 0.8 times
                xmax = min(frame_width, int(xmax + 0.8 * width))  # Increase rightward expansion by 0.8 times
                ymin = max(0, int(ymin - 1.8 * height))  # Increase upward expansion 1.8 times
                ymax = min(frame_height, ymax)  # Keep the bottom fixed

                roi = (xmin, ymin, xmax, ymax)
                break
            if roi is not None:
                print(f"Basket (net) detected in the frame {fr+1}")
                break
    finally:
        # Release the capture even when inference raises.
        cap.release()

    if roi is None:
        raise ValueError("Basket (net) not detected in the initial frames.")
    return roi


def crop_and_save_video(input_video_path, output_video_path, roi):
    ''' Crop every frame of a video to `roi` and save the result as a new video.

        Args:
            input_video_path: path of the video to read.
            output_video_path: path of the cropped video to write (mp4v codec).
            roi: (xmin, ymin, xmax, ymax) crop rectangle in pixels.

        Raises:
            ValueError: if the input cannot be opened, the ROI is empty, or the
                output writer cannot be opened.
    '''
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {input_video_path}")

    out = None
    try:
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        xmin, ymin, xmax, ymax = roi
        cropped_width = xmax - xmin
        cropped_height = ymax - ymin
        if cropped_width <= 0 or cropped_height <= 0:
            raise ValueError(f"Invalid ROI for cropping: {roi}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for saving video
        out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (cropped_width, cropped_height))
        # Without this check a writer that failed to open would drop every
        # frame silently while the function still reported success.
        if not out.isOpened():
            raise ValueError(f"Error opening video writer: {output_video_path}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Rows are indexed by y and columns by x.
            out.write(frame[ymin:ymax, xmin:xmax])
    finally:
        # Release both handles even if reading or writing raises.
        cap.release()
        if out is not None:
            out.release()

    print(f"Cropped video saved to {output_video_path}")

def process_goal_detection(input_video_path, output_video_path, model_path, threshold):
    ''' Run the goal model on every frame of a video and save an annotated copy.

        Each detection at or above `threshold` is drawn as a green rectangle
        with its score; all frames are written to the output video.

        Args:
            input_video_path: path of the (cropped) video to analyse.
            output_video_path: path of the annotated video to write (mp4v codec).
            model_path: path of the TFLite goal model.
            threshold: minimum detection score passed to detect_objects.

        Raises:
            ValueError: if the input video or the output writer cannot be opened.
    '''
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {input_video_path}")

    print("Processing goal detection...")

    out = None
    frame_count = 0
    try:
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        interpreter = initialize_interpreter(model_path)
        # Elements 1:3 of the first input's shape are passed as the resize size.
        input_shape = interpreter.get_input_details()[0]['shape'][1:3]

        # Create VideoWriter object to save the output video with goal detection
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for saving video
        out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
        # Fail loudly instead of silently producing no output file.
        if not out.isOpened():
            raise ValueError(f"Error opening video writer: {output_video_path}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            preprocessed_frame = preprocess_frame(frame, input_shape, is_grayscale=False)
            detections = detect_objects(interpreter, preprocessed_frame, threshold)

            # Annotate a copy so the frame that was read stays untouched.
            annotated_frame = frame.copy()
            for obj in detections:
                # Box is (ymin, xmin, ymax, xmax) as fractions of the frame size.
                box = obj['bounding_box']
                x_min = int(box[1] * frame_width)
                y_min = int(box[0] * frame_height)
                x_max = int(box[3] * frame_width)
                y_max = int(box[2] * frame_height)

                # Draw bounding box and label
                cv2.rectangle(annotated_frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                label = f"Score: {obj['score']:.2f}"
                # Keep the label inside the frame when the box touches the top edge.
                label_y = max(y_min - 10, 15)
                cv2.putText(annotated_frame, label, (x_min, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            out.write(annotated_frame)  # Write annotated frame to output video
            frame_count += 1
            if frame_count % 500 == 0:
                print(f"Processed frame {frame_count}")
    finally:
        # Release both handles even if inference or writing raises.
        cap.release()
        if out is not None:
            out.release()

    print(f"Output video with frames {frame_count} saved to {output_video_path}")


In [None]:
# Main execution
cwd = os.getcwd()  # NOTE(review): not used in this cell; kept in case other cells rely on it
MODEL_PATH = '/content/drive/MyDrive/Ballogy/Models'
BASKET_MODEL_NAME = 'ballogy_1_eff0.tflite'
GOAL_MODEL_NAME = 'mixed_model_3.tflite'
DETECTION_THRESHOLD_BASKET = 0.2   # Basket detection threshold
DETECTION_THRESHOLD_GOAL = 0.70    # Goal detection threshold
BASKET_SEARCH_FRAMES = 300         # Max number of initial frames scanned for the net

# Input and output video paths
INPUT_VIDEO_PATH = "/content/drive/MyDrive/Ballogy_Videos/right/right_v1.mp4"
OUTPUT_DIR = "/content/drive/MyDrive/Ballogy/output/mm3_0.70/right_v1"
CROPPED_VIDEO_PATH = f"{OUTPUT_DIR}/cropped.mp4"
OUTPUT_VIDEO_PATH = f"{OUTPUT_DIR}/goal_detection.mp4"

# OpenCV does not create missing directories for its output file, so make sure
# the output folder exists before any video is written.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Basket detection model path
basket_model_path = f'{MODEL_PATH}/{BASKET_MODEL_NAME}'
basket_interpreter = initialize_interpreter(basket_model_path)

# Step 1: Detect basket and crop video
roi = get_basket_roi(
    INPUT_VIDEO_PATH,
    basket_interpreter,
    threshold=DETECTION_THRESHOLD_BASKET,
    num_frames=BASKET_SEARCH_FRAMES,
)
crop_and_save_video(INPUT_VIDEO_PATH, CROPPED_VIDEO_PATH, roi)

# Goal detection model path
goal_model_path = f'{MODEL_PATH}/{GOAL_MODEL_NAME}'

# Step 2: Detect goals in cropped video
process_goal_detection(CROPPED_VIDEO_PATH, OUTPUT_VIDEO_PATH, goal_model_path, threshold=DETECTION_THRESHOLD_GOAL)

Basket (net) detected in the frame 142
Cropped video saved to /content/drive/MyDrive/Ballogy/output/mm3_0.70/right_v1/cropped.mp4
Processing goal detection...
Processed frame 500
Processed frame 1000
Processed frame 1500
Processed frame 2000
Processed frame 2500
Processed frame 3000
Processed frame 3500
Processed frame 4000
Processed frame 4500
Processed frame 5000
Processed frame 5500
Processed frame 6000
Processed frame 6500
Output video with frames 6784 saved to /content/drive/MyDrive/Ballogy/output/mm3_0.70/right_v1/goal_detection.mp4
