<a href="https://colab.research.google.com/github/YoonSa8/ML_Notebooks/blob/master/Computer_Vision_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
#-- this notebook performes the following
# 1- Player, referee and ball detection using Yolov8
# 2- team classification using cloth colors HSV , kmeans
# 3- jersey number recognition using transformers
# 4- player tracking using ByTrack
# 5- annotated video output with all metadata shown

In [16]:
#installing dependancies
!pip install ultralytics mmocr opencv-python-headless numpy matplotlib scikit-learn
!pip install -U onnxruntime



In [17]:
#try another model since mmcv builder takes too much time
!pip install transformers



In [18]:
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from collections import defaultdict
import os

In [19]:
#configration
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
detector = YOLO("yolov8x.pt")
TEAM_COLORS= {
    "referee": "mint",
    "teamA": "red",
    "teamB": "white"
}
input_path = "/content/ball_tiled_output.mp4"
output_path = "Output.mp4"
FRAME_SKIP = 2

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [55]:
def extract_frames(video_path, skip=1):
  cap = cv2.VideoCapture(video_path)
  frames = []
  idx =0
  while True:
    ret, frame = cap.read()
    if not ret: break
    if idx%skip ==0:
      frames.append(frame)
    idx+=1
  cap.release()
  return frames

In [56]:
def detection_yolo(frames):
  all_detections = []
  for i, frame in enumerate(frames):
    results = detector(frame, verbose=False)[0]
    boxes= []
    for box in results.boxes.data.cpu().numpy():
      x1,y1,x2,y2, conf, cls_id = box
      label = detector.names[int(cls_id)]
      if label in ["person", "sports ball"]:
        boxes.append((int(x1), int(y1),int(x2),int(y2), label, conf))
    all_detections.append(boxes)
  return all_detections

In [57]:
def extract_team_color(image, k=3):
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  image = cv2.resize(image, (30,30))
  pixels = image.reshape(-1,3)
  kmeans= KMeans(n_clusters=k).fit(pixels)
  counts = np.bincount(kmeans.labels_)
  color = kmeans.cluster_centers_[np.argmax(counts)]
  return color

In [58]:
def classify_team(color):
  r,g,b = color
  if g>180 and b>180 : return "referee"
  elif r>150 and g<100: return "teamA"
  elif r>180 and g>180: return "teamB"
  return "unclear"

In [59]:
def preprocess_befor_ocr(img):
  if not isinstance(img, np.ndarray):
        img =np.array(img)
  img = cv2.resize(img,(img.shape[1]*3,img.shape[0]*3))
  img = cv2.GaussianBlur(img,(3,3),0)
  img = cv2.convertScaleAbs(img, alpha=1.8, beta=0)
  return img

In [60]:
def read_jersey_num(img):
    if isinstance(img, np.ndarray):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        img = preprocess_befor_ocr(img)
    pixel_values = processor(images=img, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_text.strip()

In [61]:
def annotate_output(frames, detections):
  h,w, _ = frames[0].shape
  fourcc= cv2.VideoWriter_fourcc(*'mp4v')
  out = cv2.VideoWriter(output_path,fourcc, 15,(w,h))
  for frame, boxes in zip(frames, detections):
    count =1
    for box in boxes:
      x1,y1,x2,y2,label,conf= box
      crop= frame[y1:y2, x1:x2]
      if crop.size == 0:
                continue
      dom_color = extract_team_color(crop)
      team = classify_team(dom_color)
      jersey = read_jersey_num(crop)
      color = (0,0,255) if team == 'teamA' else (255,255,255) if team =="teamB" else (255,192,203)
      cv2.rectangle(frame, (x1,y1), (x2,y2), color, 2)
      label_txt = f'{team} #{jersey}' if label == 'person' else label
      cv2.putText(frame, label_txt, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    count+=1
    print(f'frame number: {count}')
    out.write(frame)
  out.release()


In [None]:
frames = extract_frames(input_path, skip=FRAME_SKIP)
detections = detection_yolo(frames)
print(f"Total frames: {len(frames)}, total detections: {len(detections)}")
annotate_output(frames, detections)

Total frames: 372, total detections: 372
