# Crop a dataset with YOLOv5

After training a custom YOLOv5 model in MothDetectionYOLOv5.ipynb, it is time to put it into practice. In this notebook, a dataset is fed to the cropping model, and the crops around the detected objects (moths) are saved as image files.

## Setup

In [1]:
import os
import time
import torch
from PIL import Image, ImageOps


In [2]:
# Root folder of the dataset that has to be cropped
dataset_dir = "/data/mothRecognition/data/meetnet_230911_231019"

# Root folder that receives the crops
save_dir = "/data/croppedDatasetNew/"

# Weights file of the custom-trained YOLOv5 model
yolo_weights_path = "/data/mothDetection/yolov5/runs/train/exp3/weights/best.pt"


In [3]:
# Load the custom-trained weights through the YOLOv5 entry point on torch.hub
yolo = torch.hub.load(
    repo_or_dir='ultralytics/yolov5',
    model='custom',
    path=yolo_weights_path,
)


Using cache found in /home/farfalla/.cache/torch/hub/ultralytics_yolov5_master
YOLOv5 🚀 2023-10-12 Python-3.8.18 torch-2.0.1+cu117 CUDA:0 (NVIDIA GeForce RTX 4090, 24209MiB)

Fusing layers... 
Model summary: 157 layers, 7012822 parameters, 0 gradients, 15.8 GFLOPs
Adding AutoShape... 


## Functions

In [4]:
def chunks(lst, n):
    """Generate consecutive slices of lst that each hold at most n items.

    The last slice is shorter when len(lst) is not a multiple of n.
    """
    yield from (lst[offset:offset + n] for offset in range(0, len(lst), n))
        
def squareAndPadBox(box):
    """Return a squared and padded copy of a bounding box.

    The shorter side of the box is extended equally in both directions until
    the box is square; afterwards 10% of the (now square) side length is added
    on every side. The result is not clamped, so coordinates may be negative
    or otherwise lie outside the image the box came from.

    Args:
        box: Sequence starting with x1, y1, x2, y2 (list or tuple).

    Returns:
        A new list with the adjusted coordinates. The argument itself is left
        untouched, so the caller's box list is not silently modified.
    """
    squared = list(box)  # Work on a copy: never mutate the caller's box
    x1, y1, x2, y2 = squared[:4]

    w = x2 - x1
    h = y2 - y1

    # Make the crop square by elongating the shortest side
    pad_for_square = abs(w - h) / 2
    if w < h:
        x1 -= pad_for_square
        x2 += pad_for_square
    else:
        y1 -= pad_for_square
        y2 += pad_for_square

    # Pad with 10% of the square's side on all sides
    pad = 0.1 * (x2 - x1)
    squared[:4] = [x1 - pad, y1 - pad, x2 + pad, y2 + pad]
    return squared

def _centreSquareBox(width, height):
    """Return (left, upper, right, lower) of the largest centred square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


def squareYoloCrop(images,
                   predictions,
                   filenames,
                   save_dir,
                   size=240,
                   center_cropped=0,
                   multi_crop=False,
                   min_box_score=0.5):
    """Save square crops of the detected objects for a batch of images.

    For every image, the boxes whose score reaches min_box_score are squared
    and padded with squareAndPadBox, cropped out, resized to size x size and
    saved in save_dir. When an image has no such box, the largest centred
    square of the image is cropped instead.

    Saved files are named "<score>_<name>_<index><extension>", where <score>
    is the box score with two decimals, or the word "centre" for a centre
    crop.

    Args:
        images: Images exposing width, height and crop() (PIL-style).
        predictions: One 2-D tensor per image; columns 0-3 are read as
            x1, y1, x2, y2 and column 4 as the box score.
        filenames: Path of each image, in the same order as images.
        save_dir: Directory to write the crops to; created if missing.
        size: Side length in pixels of the saved crops.
        center_cropped: Running count of centre-cropped images to continue
            from.
        multi_crop: If True, save every box that reaches min_box_score;
            otherwise only the highest-scoring box.
        min_box_score: Minimum score for a box to be used.

    Returns:
        center_cropped increased by the number of images in this batch that
        had to be centre-cropped.
    """
    os.makedirs(save_dir, exist_ok=True)

    for image, prediction, source_path in zip(images, predictions, filenames):
        boxes = prediction[:, :4].cpu().tolist()  # x1, y1, x2, y2
        # Plain floats: cheap to compare and directly usable in file names
        scores = prediction[:, 4].cpu().tolist()
        boxes_valid = bool(scores) and max(scores) >= min_box_score

        if boxes_valid:
            if multi_crop:
                selected = [(box, score) for box, score in zip(boxes, scores)
                            if score >= min_box_score]
            else:
                best_crop_index = max(range(len(scores)), key=scores.__getitem__)
                selected = [(boxes[best_crop_index], scores[best_crop_index])]
            crops = [(squareAndPadBox(box), f"{score:.2f}") for box, score in selected]
        else:  # Do a centre crop instead
            center_cropped += 1
            crops = [(_centreSquareBox(image.width, image.height), 'centre')]

        # splitext only splits off the last extension, so names that contain
        # additional periods (e.g. "img.v2.jpg") are handled as well.
        stem, extension = os.path.splitext(os.path.basename(source_path))
        if not stem or not extension:
            print(f"Failed saving crop: could not derive a name and extension from {source_path!r}")
            continue

        for box_i, (box, label) in enumerate(crops):
            crop = image.crop(box).resize((size, size))
            filename = f"{label}_{stem}_{box_i}{extension}"
            try:
                crop.save(os.path.join(save_dir, filename))
            except ValueError:
                print(f"ValueError for {filename}, not saved.")
            except Exception as err:
                # Best effort: report the failure and carry on with the rest
                print(f"Saving {filename} failed ({type(err).__name__}: {err}), not saved.")

    return center_cropped


## Crop

In [6]:
# Walk the dataset directory by directory, run YOLOv5 on batches of images and
# save the crops in a folder with the same name as the image's own folder.
batch_size = 512
faulty_batches = []  # Bookkeeping of batches for which inference failed
nr_of_processed_images = 0
center_cropped = 0
nr_of_dirs = 0  # Defined up front so the progress line never hits a NameError

start = time.time()

for i, (sub_dir, dirs, files) in enumerate(os.walk(dataset_dir)):
    if i == 0:
        # The first entry of os.walk is the dataset root itself
        nr_of_dirs = len(dirs)
    if not files:
        continue

    sub_dir_lowest = os.path.basename(sub_dir)
    paths = [os.path.join(sub_dir, file) for file in files]
    for j, batch in enumerate(chunks(paths, batch_size)):
        nr_of_processed_images += len(batch)
        print(f'Dir {i} out of {nr_of_dirs}/batch {j} (batch size = {batch_size})/image {nr_of_processed_images}/center cropped = {center_cropped}', end='\r')

        # Keep images and their paths in lockstep, so that a skipped file
        # cannot shift the file names of the images that follow it.
        image_list = []
        loaded_paths = []
        for filename in batch:
            try:
                im = Image.open(filename)
            except (OSError, ValueError) as err:
                # E.g. a non-image file inside the dataset folder
                print(f"\nCould not open {filename}, skipped: {err}")
                continue
            try:
                ImageOps.exif_transpose(im, in_place=True)
            except Exception:
                print("Problem with reading EXIF data.")
            image_list.append(im)
            loaded_paths.append(filename)

        if not image_list:
            continue

        try:
            try:
                results = yolo(image_list).pred
            except Exception as err:
                faulty_batches.append({'batch_size': batch_size,
                                       'nr_of_processed_images': nr_of_processed_images,
                                       'dir (i)': i,
                                       'batch (j)': j})
                print(f"\nError in this batch, try later (batch {j} out of dir {i} with batch_size={batch_size}) (OOM?): {err!r}")
                continue
            # NOTE(review): the threshold used to be 25.0, which a confidence
            # score in [0, 1] can never reach, so every image was centre
            # cropped. 0.25 is assumed to be the intended value - confirm.
            center_cropped = squareYoloCrop(images=image_list,
                                            predictions=results,
                                            filenames=loaded_paths,
                                            # Trailing separator kept on purpose
                                            save_dir=os.path.join(save_dir, sub_dir_lowest, ""),
                                            center_cropped=center_cropped,
                                            multi_crop=True,
                                            min_box_score=0.25)
        finally:
            # Release file handles and pixel data before the next batch
            for im in image_list:
                im.close()
end = time.time()

print("\n")
print('Time taken:', end - start, 's')
print('faulty batches:', faulty_batches)


Dir 254 out of 254/batch 0 (batch size = 512)/image 4338 out of 4338/center cropped = 112

214.71734714508057
faulty batches: []


If something has gone wrong in one of the batches (that is, when `faulty_batches` is not empty), you can process the affected images separately by copying the loop above and restricting it to the faulty batches. For this, set `batch_size = 1`, so that every image is processed on its own.