In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision.io import read_image
from typing import List, Tuple
import os
import glob
import json
from ultralytics import YOLO
from torchvision.io import read_image
from torchvision.transforms.functional import resize
import shutil
from tqdm.auto import trange, tqdm
from pathlib import Path



# Root of the dataset and the directories of its two splits.
dataset_path = "SoccerNet/jersey-2023/"
train_path, test_path = (
    os.path.join(dataset_path, split) for split in ('train', 'test')
)

# Every .jpg found at any depth below each split's images/ directory.
train_images, test_images = (
    glob.glob(os.path.join(split_path, 'images/**/*.jpg'), recursive=True)
    for split_path in (train_path, test_path)
)

# Detector checkpoint loaded from a local training run.
model = YOLO('runs/detect/BOX2/weights/best.pt')

# Grayscale conversion that still outputs 3 channels.
to_gray = v2.Grayscale(num_output_channels=3)

# Folder for images taken out of the player folders. It lives inside the train
# images/ directory, so listing that directory also returns this folder.
destination_folder = Path(train_path) / 'images/-1/'
destination_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# CLEAN DATASET FROM FALSE IMAGE/LABEL ASSOCIATIONS
#
# For every player folder whose ground-truth label is not -1, run `model` on
# each image. Images for which the model returns no boxes are moved to
# images/-1/<player id>/ and recorded as -1 in `train_targets`; the others keep
# the player's label. Folders already labelled -1 are left untouched and
# contribute one -1 target per file they contain.

# Close the ground-truth file deterministically instead of leaking the handle.
with open(os.path.join(train_path, 'train_gt.json')) as gt_file:
    train_gt = json.load(gt_file)
train_targets = []

counter = 0
images_root = os.path.join(train_path, 'images/')
for player_id in tqdm(os.listdir(images_root), desc="Analyzing players"):
    player_dir = os.path.join(images_root, player_id)

    # images/ also holds the quarantine folder (destination_folder) and may hold
    # stray files; none of those have a ground-truth entry, so looking them up
    # would raise KeyError. Skip them before touching train_gt.
    if (
        player_id == destination_folder.name
        or player_id not in train_gt
        or not os.path.isdir(player_dir)
    ):
        print(f"Skipping {player_id}")
        continue

    label = train_gt[player_id]
    if label == -1:
        print(f"Player id: {player_id} label is already -1 ...")
        train_targets.extend([label] * len(os.listdir(player_dir)))
        continue

    # Must exist before moving: shutil.move into a missing directory would
    # rename the image to that path instead of placing it inside a folder.
    player_folder = Path(os.path.join(destination_folder, f'{player_id}/'))
    player_folder.mkdir(parents=True, exist_ok=True)

    # os.listdir returns a full list up front, so moving files out of the
    # folder while iterating is safe.
    for image in tqdm(os.listdir(player_dir), desc=f"Analyzing images for player id {player_id}"):
        image_path = os.path.join(player_dir, image)
        img = read_image(image_path)
        # Resize to 640x640 and divide by 255 before the grayscale conversion
        # (assumes read_image returned 8-bit pixel values -- TODO confirm).
        img = resize(img, (640, 640)) / 255.0
        img = to_gray(img)
        pred = model.predict(img.unsqueeze(0), verbose=False)
        # Relies on the boxes container being falsy when nothing was detected.
        if pred[0].boxes:
            train_targets.append(label)
        else:
            train_targets.append(-1)
            shutil.move(image_path, player_folder)
            counter += 1

# Report the total once; the progress bars already show per-player progress.
print(f"Moved : {counter} images")


In [None]:
# REVERT DATASET
#
# Move every image found under images/-1/<player id>/ back into
# images/<player id>/ and report how many files were restored.

# Start from zero so the cell works on a fresh kernel and does not continue
# counting from a previous cell's total.
counter = 0
quarantine_root = os.path.join(train_path, 'images/-1/')
for player_id in tqdm(os.listdir(quarantine_root), desc="Reverting original dataset: "):
    quarantined_dir = os.path.join(quarantine_root, player_id)
    if not os.path.isdir(quarantined_dir):
        # Only per-player sub-folders hold quarantined images.
        continue

    revert_path = os.path.join(train_path, 'images/', player_id)
    # Make sure the target is a directory; otherwise shutil.move would rename
    # the image to this path instead of placing it inside the folder.
    os.makedirs(revert_path, exist_ok=True)

    for image in os.listdir(quarantined_dir):
        # Build the source path from the file being iterated, not from a
        # variable left over from another cell.
        image_path = os.path.join(quarantined_dir, image)
        shutil.move(image_path, revert_path)
        counter += 1

print(f"Restored: {counter} images")