# Train Player Detection Model

In this notebook, we fine-tune a YOLOv8 model for hockey player detection. We start from a checkpoint pre-trained on soccer players and train it on a custom hockey dataset whose images are annotated with bounding boxes.

## Import necessary libraries

In [2]:
# Check GPU availability
!nvidia-smi

# Install dependencies
!pip3 install -q ultralytics roboflow gdown python-dotenv pyyaml

import os
from roboflow import Roboflow
from IPython.display import Image
import torch

# Set up directories
HOME = os.getcwd()
!mkdir -p {HOME}/datasets
!mkdir -p {HOME}/models

zsh:1: command not found: nvidia-smi
You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.
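
The `nvidia-smi` check above fails on machines without an NVIDIA driver, as the output shows. A minimal sketch of a portable alternative is to ask PyTorch which backend is available. The helper names `pick_device` and `detect_device` are hypothetical; the returned strings follow the values Ultralytics accepts for its `device` argument (`0`, `mps`, `cpu`).

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return a device string, preferring CUDA, then Apple MPS, then CPU."""
    if cuda_available:
        return "0"      # first CUDA GPU
    if mps_available:
        return "mps"    # Apple-silicon GPU backend
    return "cpu"


def detect_device() -> str:
    """Query PyTorch if it is installed; fall back to CPU otherwise."""
    try:
        import torch
    except ImportError:
        return "cpu"
    mps = getattr(torch.backends, "mps", None)  # absent in old PyTorch builds
    return pick_device(
        torch.cuda.is_available(),
        bool(mps is not None and mps.is_available()),
    )


DEVICE = detect_device()
print("Selected device:", DEVICE)
```

The result could then be passed to the training command as `device={DEVICE}`, assuming the chosen backend supports every operation the model needs.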


## Download pre-trained model

We start from Roboflow's pre-trained YOLOv8 soccer player detection model, which was trained on a custom dataset of soccer players.

Because we will be fine-tuning on hockey players, a checkpoint that has already learned a similar task is a better starting point than generic weights.

In [3]:
# Download the soccer player detection model
!gdown -O "{HOME}/models/football-player-detection.pt" "https://drive.google.com/uc?id=17PXFNlx-jI7VjVo_vQnB1sONjRyvoB-q"

# Verify the model downloaded correctly
!ls -la {HOME}/models/

Downloading...
From (original): https://drive.google.com/uc?id=17PXFNlx-jI7VjVo_vQnB1sONjRyvoB-q
From (redirected): https://drive.google.com/uc?id=17PXFNlx-jI7VjVo_vQnB1sONjRyvoB-q&confirm=t&uuid=926fc169-d56f-4ad3-8f88-b41aaf44e066
To: /Users/jetjadeja/Projects/work/sieve/hockey-vision-analytics/notebooks/models/football-player-detection.pt
100%|████████████████████████████████████████| 137M/137M [00:02<00:00, 51.0MB/s]
total 295936
drwxr-xr-x@ 3 jetjadeja  staff         96 Jun  3 14:49 .
drwxr-xr-x@ 5 jetjadeja  staff        160 Jun  3 14:48 ..
-rw-r--r--@ 1 jetjadeja  staff  136802409 Jul 25  2024 football-player-detection.pt
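
Listing the directory confirms that a file exists, but not that it is a usable checkpoint; a failed Drive download can leave an HTML error page under the same name. The sketch below is a heuristic only. It assumes the checkpoint was written in PyTorch's zip-based format (whose files start with the bytes `PK\x03\x04`), and both the helper name `looks_like_checkpoint` and the size threshold are hypothetical.

```python
import os
import tempfile

ZIP_MAGIC = b"PK\x03\x04"  # header of a zip archive


def looks_like_checkpoint(path: str, min_bytes: int = 1_000_000) -> bool:
    """Heuristic: the file exists, is reasonably large, and starts like a zip."""
    if not os.path.isfile(path) or os.path.getsize(path) < min_bytes:
        return False
    with open(path, "rb") as f:
        return f.read(4) == ZIP_MAGIC


# Demonstration on throwaway files (not real checkpoints).
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.pt")
    bad = os.path.join(tmp, "bad.pt")
    with open(good, "wb") as f:
        f.write(ZIP_MAGIC + b"\0" * 64)
    with open(bad, "wb") as f:
        f.write(b"<!DOCTYPE html>" + b" " * 64)
    good_ok = looks_like_checkpoint(good, min_bytes=16)
    bad_ok = looks_like_checkpoint(bad, min_bytes=16)

print(good_ok, bad_ok)
```

In the notebook this would be called as `looks_like_checkpoint(f"{HOME}/models/football-player-detection.pt")`.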


## Download datasets

In [4]:
import os
from dotenv import load_dotenv
from roboflow import Roboflow

# Load environment variables from .env file
load_dotenv()

# Get API key from environment variables
ROBOFLOW_API_KEY = os.getenv('ROBOFLOW_API_KEY')

if not ROBOFLOW_API_KEY:
    raise ValueError("ROBOFLOW_API_KEY not found in .env file. Please add your API key to the .env file.")

print("✅ Successfully loaded ROBOFLOW_API_KEY from .env file")

# Initialize Roboflow with API key
rf = Roboflow(api_key=ROBOFLOW_API_KEY)

# Download the hockey players dataset to the desired location
dataset_location = f"{HOME}/datasets/players"
project = rf.workspace("sportcontract").project("hockey-fwm0b")
dataset = project.version(1).download("yolov8", location=dataset_location)

# Check the dataset structure
!cat {dataset_location}/data.yaml

✅ Successfully loaded ROBOFLOW_API_KEY from .env file
loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in /Users/jetjadeja/Projects/work/sieve/hockey-vision-analytics/notebooks/datasets/players to yolov8:: 100%|██████████| 83190/83190 [00:01<00:00, 51859.72it/s]
Extracting Dataset Version Zip to /Users/jetjadeja/Projects/work/sieve/hockey-vision-analytics/notebooks/datasets/players in yolov8:: 100%|██████████| 3178/3178 [00:00<00:00, 8325.63it/s]


names:
- Coach
- Players
- centerSpot
- faceOffSpot
- goal
- goalPost
- goalie
- human
- player
- player_cheering
- puck
- redSpot
- referee
- scoreboard
nc: 14
roboflow:
  license: CC BY 4.0
  project: hockey-fwm0b
  url: https://universe.roboflow.com/sportcontract/hockey-fwm0b/dataset/1
  version: 1
  workspace: sportcontract
test: ../test/images
train: ../train/images
val: ../valid/images


### Inspect dataset

Let's now quickly inspect the dataset to ensure it has the correct format and contains the necessary annotations. The dataset should be in YOLO format, which consists of images and corresponding text files with bounding box annotations.

In [5]:
# Check how many images we have
!echo "Training images:"
!ls {dataset.location}/train/images | wc -l
!echo "Validation images:"
!ls {dataset.location}/valid/images | wc -l

# Look at the classes
import yaml
with open(f'{dataset.location}/data.yaml', 'r') as file:
    data_config = yaml.safe_load(file)
print("Classes:", data_config['names'])

Training images:
    1391
Validation images:
     125
Classes: ['Coach', 'Players', 'centerSpot', 'faceOffSpot', 'goal', 'goalPost', 'goalie', 'human', 'player', 'player_cheering', 'puck', 'redSpot', 'referee', 'scoreboard']
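
To make the YOLO label format concrete: each line of a label file holds a class index followed by the box center and size, all normalized to the range 0 to 1. The sketch below parses one such line and converts it to pixel corner coordinates. It assumes plain detection labels with exactly five fields (polygon-style exports would need different handling), and the names `YoloBox`, `parse_label_line`, and `to_pixel_xyxy` are illustrative.

```python
from typing import NamedTuple, Tuple


class YoloBox(NamedTuple):
    class_id: int
    x_center: float  # normalized to [0, 1]
    y_center: float
    width: float
    height: float


def parse_label_line(line: str) -> YoloBox:
    """Parse 'class x_center y_center width height' into a YoloBox."""
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"expected 5 fields, got {len(parts)}: {line!r}")
    return YoloBox(int(parts[0]), *(float(p) for p in parts[1:]))


def to_pixel_xyxy(box: YoloBox, img_w: int, img_h: int) -> Tuple[float, float, float, float]:
    """Convert a normalized center/size box to pixel (x1, y1, x2, y2)."""
    x1 = (box.x_center - box.width / 2) * img_w
    y1 = (box.y_center - box.height / 2) * img_h
    x2 = (box.x_center + box.width / 2) * img_w
    y2 = (box.y_center + box.height / 2) * img_h
    return x1, y1, x2, y2


# Made-up label line: class 8 ('player' in the class list above).
box = parse_label_line("8 0.5 0.5 0.25 0.5")
corners = to_pixel_xyxy(box, img_w=1280, img_h=720)
print(box.class_id, corners)
```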


## Data Setup

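Next, we narrow the dataset to the classes we care about. The original 14 classes include rink markings, goals, the puck, and the scoreboard, none of which we want to detect here, so we write a new data config with a reduced class list: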

In [7]:
# Read the original data.yaml
with open(f'{dataset.location}/data.yaml', 'r') as file:
    data_config = yaml.safe_load(file)

print("Original classes:", data_config['names'])

# New class list: the people on the ice. These IDs must match the target IDs
# used when the label files are remapped in the next section; labels that
# exceed 'nc' would otherwise be rejected at training time.
new_names = {
    0: 'player',    # also absorbs 'Players', 'human' and 'player_cheering'
    1: 'goalie',
    2: 'referee',
    3: 'coach',
}

# Create the new data.yaml with the reduced class list
new_data_config = {
    'path': dataset.location,
    'train': 'train/images',
    'val': 'valid/images',
    'names': new_names,
    'nc': len(new_names),
}

# Save the new configuration
with open(f'{dataset.location}/data_players_only.yaml', 'w') as file:
    yaml.dump(new_data_config, file)

print("New classes:", new_data_config['names'])

Original classes: ['Coach', 'Players', 'centerSpot', 'faceOffSpot', 'goal', 'goalPost', 'goalie', 'human', 'player', 'player_cheering', 'puck', 'redSpot', 'referee', 'scoreboard']
New classes: {0: 'player', 1: 'goalie', 2: 'referee', 3: 'coach'}
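
Label files store class indices rather than names, so a name-based plan has to be translated into an index-based mapping before the labels can be rewritten. One way to avoid hard-coding indices is to derive them from the `names` list in `data.yaml`. This is a sketch only: `build_index_mapping` is a hypothetical helper, and the target IDs (player 0, goalie 1, referee 2, coach 3) are the ones used by the remapping step later in this notebook.

```python
from typing import Dict, List


def build_index_mapping(old_names: List[str], name_to_new_id: Dict[str, int]) -> Dict[int, int]:
    """Translate {old class name: new id} into {old class index: new id}."""
    unknown = set(name_to_new_id) - set(old_names)
    if unknown:
        raise KeyError(f"names not present in the dataset: {sorted(unknown)}")
    return {old_names.index(name): new_id for name, new_id in name_to_new_id.items()}


# Class order copied from the data.yaml printed above.
OLD_NAMES = ['Coach', 'Players', 'centerSpot', 'faceOffSpot', 'goal', 'goalPost',
             'goalie', 'human', 'player', 'player_cheering', 'puck', 'redSpot',
             'referee', 'scoreboard']

# Names that are not listed here are dropped.
NAME_TO_NEW_ID = {
    'player': 0, 'Players': 0, 'human': 0, 'player_cheering': 0,
    'goalie': 1,
    'referee': 2,
    'Coach': 3,
}

index_mapping = build_index_mapping(OLD_NAMES, NAME_TO_NEW_ID)
print(dict(sorted(index_mapping.items())))
```

Deriving the indices this way keeps the mapping correct if a later dataset version reorders or adds classes.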


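## Remap Annotations

Next, we rewrite the class ID in every label file so that the original classes collapse into the consolidated set. Annotations whose class is not in the mapping are dropped. The label files are overwritten in place, so run this cell only once per dataset download; a second run would remap the already-remapped IDs.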

In [8]:
import glob

def remap_labels(label_path, class_mapping):
    """Remap multiple classes to consolidated player classes"""
    with open(label_path, 'r') as f:
        lines = f.readlines()
    
    new_lines = []
    for line in lines:
        parts = line.strip().split()
        if len(parts) > 0:
            old_class = int(parts[0])
            # Map to new classes
            if old_class in class_mapping:
                new_class = class_mapping[old_class]
                parts[0] = str(new_class)
                new_lines.append(' '.join(parts) + '\n')
    
    return new_lines

# Define mapping from old class indices to new class indices.
# Old indices follow the order of `names` in the original data.yaml.
class_mapping = {
    0: 3,   # Coach -> coach (3)
    1: 0,   # Players -> player (0)
    6: 1,   # goalie -> goalie (1)
    7: 0,   # human -> player (0)
    8: 0,   # player -> player (0)
    9: 0,   # player_cheering -> player (0)
    12: 2,  # referee -> referee (2)
    # Dropped (no entry): centerSpot, faceOffSpot, goal, goalPost, puck, redSpot, scoreboard
}

# Process all label files
for split in ['train', 'valid', 'test']:
    label_dir = f'{dataset.location}/{split}/labels'
    if os.path.exists(label_dir):
        label_files = glob.glob(f'{label_dir}/*.txt')
        print(f"Processing {len(label_files)} files in {split}...")
        
        for label_file in label_files:
            new_lines = remap_labels(label_file, class_mapping)
            with open(label_file, 'w') as f:
                f.writelines(new_lines)

print("Label remapping complete!")

Processing 1391 files in train...
Processing 125 files in valid...
Processing 67 files in test...
Label remapping complete!
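
Because the cell above overwrites the label files, running it twice corrupts the labels (for example, the new `goalie` ID 1 would be read as the old `Players` index and become 0). A hedged alternative is to keep an untouched copy of the labels and always remap from that copy. The function name `remap_label_dir` and the `labels_original` directory name are hypothetical.

```python
import os
import tempfile
from typing import Dict


def remap_label_dir(src_dir: str, dst_dir: str, mapping: Dict[int, int]) -> int:
    """Remap every .txt label file in src_dir into dst_dir; return file count."""
    os.makedirs(dst_dir, exist_ok=True)
    n_files = 0
    for name in sorted(os.listdir(src_dir)):
        if not name.endswith(".txt"):
            continue
        kept = []
        with open(os.path.join(src_dir, name)) as f:
            for line in f:
                parts = line.split()
                if parts and int(parts[0]) in mapping:
                    parts[0] = str(mapping[int(parts[0])])
                    kept.append(" ".join(parts) + "\n")
        with open(os.path.join(dst_dir, name), "w") as f:
            f.writelines(kept)
        n_files += 1
    return n_files


# Demonstration on throwaway files (the box values are made up).
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "labels_original")
    dst = os.path.join(tmp, "labels")
    os.makedirs(src)
    with open(os.path.join(src, "frame_0001.txt"), "w") as f:
        f.write("8 0.5 0.5 0.1 0.2\n10 0.3 0.3 0.01 0.01\n12 0.7 0.4 0.1 0.2\n")
    mapping = {0: 3, 1: 0, 6: 1, 7: 0, 8: 0, 9: 0, 12: 2}
    remap_label_dir(src, dst, mapping)
    remap_label_dir(src, dst, mapping)  # a second run gives the same result
    with open(os.path.join(dst, "frame_0001.txt")) as f:
        demo_output = f.read()

print(demo_output)
```

In the notebook, each split's `labels` directory would first be copied to `labels_original` (before any remapping), and the function would then write back into `labels`, which is where the trainer looks for annotations.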


## Verify our dataset

In [9]:
# Check class distribution after remapping

def count_classes(split):
    label_dir = f'{dataset.location}/{split}/labels'
    class_counts = {0: 0, 1: 0, 2: 0, 3: 0}
    
    if os.path.exists(label_dir):
        for label_file in glob.glob(f'{label_dir}/*.txt'):
            with open(label_file, 'r') as f:
                for line in f:
                    if line.strip():
                        class_id = int(line.split()[0])
                        if class_id in class_counts:
                            class_counts[class_id] += 1
    return class_counts

# Count for each split
for split in ['train', 'valid']:
    counts = count_classes(split)
    print(f"\n{split.upper()} set class distribution:")
    print(f"  Players: {counts[0]}")
    print(f"  Goalies: {counts[1]}")
    print(f"  Referees: {counts[2]}")
    print(f"  Coaches: {counts[3]}")
    print(f"  Total: {sum(counts.values())}")


TRAIN set class distribution:
  Players: 7656
  Goalies: 712
  Referees: 896
  Coaches: 28
  Total: 9292

VALID set class distribution:
  Players: 1014
  Goalies: 99
  Referees: 101
  Coaches: 7
  Total: 1221
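
The raw counts are easier to judge as proportions. The short sketch below converts the training-set counts printed above into shares; `class_shares` is an illustrative helper, not part of any library.

```python
from typing import Dict

# Training-set counts copied from the output above.
train_counts = {"player": 7656, "goalie": 712, "referee": 896, "coach": 28}


def class_shares(counts: Dict[str, int]) -> Dict[str, float]:
    """Return each class's fraction of all annotations."""
    total = sum(counts.values())
    return {name: n / total for name, n in counts.items()}


shares = class_shares(train_counts)
for name, share in shares.items():
    print(f"{name:>8}: {share:6.1%}")
```

Players make up roughly 82% of the training annotations and coaches about 0.3%, so per-class metrics for the rare classes are likely to be noisy.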


## Finetune the model

Now, let's fine-tune the pre-trained YOLOv8 model on our custom dataset of hockey players. Training is run through the Ultralytics `yolo` command-line interface; Roboflow was only used to download the dataset.

In [None]:
# Transfer learning: start from the soccer checkpoint and train on the
# remapped hockey dataset. freeze=0 leaves every layer trainable
# ("aggressive" retraining).
!yolo task=detect \
      mode=train \
      model={HOME}/models/football-player-detection.pt \
      data={dataset.location}/data_players_only.yaml \
      epochs=300 \
      imgsz=640 \
      batch=32 \
      patience=50 \
      freeze=0 \
      project={HOME}/models/player_detection \
      name=hockey_aggressive_v1 \
      exist_ok=True \
      amp=True \
      cache=True \
      lr0=0.01 \
      lrf=0.001 \
      momentum=0.937 \
      weight_decay=0.0005 \
      warmup_epochs=5 \
      warmup_momentum=0.8 \
      workers=8 \
      close_mosaic=30 \
      hsv_h=0.015 \
      hsv_s=0.7 \
      hsv_v=0.4 \
      degrees=15 \
      translate=0.2 \
      scale=0.5 \
      flipud=0.0 \
      fliplr=0.5 \
      mosaic=1.0 \
      mixup=0.15 \
      copy_paste=0.1 \
      cos_lr=True

Ultralytics 8.3.148 🚀 Python-3.9.6 torch-2.7.0 CPU (Apple M4 Pro)
engine/trainer: agnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=32, bgr=0.0, box=7.5, cache=True, cfg=None, classes=None, close_mosaic=30, cls=0.5, conf=None, copy_paste=0.1, copy_paste_mode=flip, cos_lr=True, cutmix=0.0, data=/Users/jetjadeja/Projects/work/sieve/hockey-vision-analytics/notebooks/datasets/players/data_players_only.yaml, degrees=15, deterministic=True, device=cpu, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=300, erasing=0.4, exist_ok=True, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=0, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=640, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.001, mask_ratio=4, max_det=300, mixup=0.15, mode=train, model=/Users/jetjadeja/Projects/work/sieve/hockey-vision-analytics/notebooks/models/football-player-detection.pt, momentum=0.937, mosaic=1.0, multi_sc

In [None]:
# # During training, monitor progress
# from IPython.display import Image, display
# import time

# # Check if training has started
# time.sleep(30)  # Wait for first epoch

# # Display training progress (paths follow project/name from the training command)
# if os.path.exists(f'{HOME}/models/player_detection/hockey_aggressive_v1/results.png'):
#     display(Image(f'{HOME}/models/player_detection/hockey_aggressive_v1/results.png', width=800))

# # After training completes, validate
# !yolo task=detect \
#       mode=val \
#       model={HOME}/models/player_detection/hockey_aggressive_v1/weights/best.pt \
#       data={dataset.location}/data_players_only.yaml \
#       imgsz=640 \
#       batch=32

KeyboardInterrupt: 
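
The training command above combines `cos_lr=True` with `lr0=0.01` and `lrf=0.001`. As I understand the Ultralytics scheduler, `lrf` is a fraction of `lr0`, so the learning rate decays along a half cosine from `lr0` toward `lr0 * lrf`. The sketch below reproduces that curve under this assumption; it ignores the warmup phase (`warmup_epochs=5`), and `cosine_lr` is an illustrative function, not a library call.

```python
import math


def cosine_lr(epoch: float, epochs: int = 300, lr0: float = 0.01, lrf: float = 0.001) -> float:
    """Learning rate at `epoch` for a half-cosine decay from lr0 to lr0 * lrf."""
    progress = (1 - math.cos(epoch * math.pi / epochs)) / 2  # goes 0 -> 1
    return lr0 * (1.0 + progress * (lrf - 1.0))


for e in (0, 75, 150, 225, 300):
    print(f"epoch {e:3d}: lr = {cosine_lr(e):.6f}")
```

Under these settings the rate ends near `1e-5`, three orders of magnitude below the starting value, so most of the weight movement happens in the first half of the run.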

## Validate our model

In [None]:
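# Run validation on the best checkpoint from the training run above.
# The path follows project/name from the training command, and the data
# config must be the remapped one, not the original 14-class data.yaml.
!yolo task=detect \
      mode=val \
      model={HOME}/models/player_detection/hockey_aggressive_v1/weights/best.pt \
      data={dataset.location}/data_players_only.yaml \
      imgsz=640

# Run prediction on the validation images and save the annotated results
!yolo task=detect \
      mode=predict \
      model={HOME}/models/player_detection/hockey_aggressive_v1/weights/best.pt \
      source={dataset.location}/valid/images \
      save=True \
      conf=0.25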
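
Besides the validation command, the per-epoch metrics can be read back from the run directory. Ultralytics normally writes a `results.csv` next to the `weights` folder; the sketch below finds the epoch with the highest value of one metric column. The column name `metrics/mAP50-95(B)` is an assumption that should be checked against the header of your own file, `best_epoch` is a hypothetical helper, and the sample rows are invented for illustration, not real training results.

```python
import csv
import io
from typing import Tuple


def best_epoch(csv_text: str, metric: str = "metrics/mAP50-95(B)") -> Tuple[int, float]:
    """Return (epoch, value) of the row with the highest `metric`."""
    reader = csv.DictReader(io.StringIO(csv_text), skipinitialspace=True)
    # Some versions pad headers and values with spaces, so strip both.
    rows = [{k.strip(): v.strip() for k, v in row.items()} for row in reader]
    best = max(rows, key=lambda r: float(r[metric]))
    return int(float(best["epoch"])), float(best[metric])


# Invented sample in the assumed layout.
SAMPLE = """\
epoch, metrics/mAP50(B), metrics/mAP50-95(B)
1, 0.41, 0.20
2, 0.55, 0.31
3, 0.53, 0.29
"""

epoch, score = best_epoch(SAMPLE)
print(f"best epoch: {epoch} (mAP50-95 = {score})")
```

For a real run, the text would come from `open(f"{HOME}/models/player_detection/hockey_aggressive_v1/results.csv").read()`.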