Notebook servant à faire l'augmentation de données du jeu de données de détection et à créer le jeu de données de classification si souhaité.

# Import

In [4]:
import cv2
from matplotlib import pyplot as plt
%matplotlib inline
import albumentations as A

import json
import yaml

from glob import glob
from re import sub
from random import randrange
import os
from PIL import Image
import math
from tqdm import tqdm
import shutil

INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.0 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


# Augmentation des données

## Initialisation des chemins utilisés

In [1]:
# ---------------------------------------------------------------------------
# Which datasets to build
# ---------------------------------------------------------------------------
# If True, build a "detection only" dataset whose single class is "roll".
detection_only = False

# If True, build a "detection and classification" dataset (source class ids are kept).
detection_class = True

# If True, also build the classification dataset (one cropped image per box).
# The crops are cut from the "detection and classification" dataset, so this
# option forces detection_class = True (see the end of this cell).
classification = False

# Suffix of the output folders: detection_dataset_<name>,
# detection_only_dataset_<name> and classification_dataset_<name>.
name = 'front1'

# num_aug (int): number of images produced by the first (random) transformation
# for each source image. Every image, source or augmented, is also mirrored, so
# the train folder ends up with (nb_img_src * (num_aug + 1)) * 2 images.
num_aug = 6

# src_dir (str): source folder. The notebook looks for the "train", "val" and
# "test" splits inside it; each split holds "images" and "labels" sub-folders,
# and the categories are read from train/notes.json.
# Forward slashes are used so that the path is valid on Windows, Linux and macOS.
src_dir = "./dataset/frontview1"

# det_cls_res_path (str): output folder of the "detection and classification"
# dataset, laid out as follows:
#
# det_cls_res_path
# |- train
# |   |- images
# |   |   |- img1.jpg
# |   |   |- img2.jpg
# |   |   |- ...
# |   |- labels
# |       |- img1.txt
# |       |- img2.txt
# |       |- ...
# |- val
# |   |- images
# |   |   |- ...
# |   |- labels
# |       |- ...
# |- test
# |   |- ...
# |- data.yaml
det_cls_res_path = "detection_dataset_" + name

# det_only_res_path (str): output folder of the "detection only" dataset.
# Same layout as det_cls_res_path:
#
# det_only_res_path
# |- train
# |   |- images
# |   |- labels
# |- val
# |   |- images
# |   |- labels
# |- test
# |   |- ...
# |- data.yaml
det_only_res_path = "detection_only_dataset_" + name

# cls_res_path (str): output folder of the classification dataset, laid out as
# follows (one sub-folder per class inside each split):
#
# cls_res_path
# |- train
# |   |- class 1
# |   |   |- img 1
# |   |   |- img 2
# |   |   |- ...
# |   |- class 2
# |   |   |- ...
# |   |- ...
# |- val
# |   |- class 1
# |   |   |- ...
# |   |- ...
# |- test
# |   |- ...
# |- data.yaml
cls_res_path = "classification_dataset_" + name

# The classification crops are generated from the detection_class dataset.
if classification:
    detection_class = True

## Fonctions de visualisation

In [2]:
BOX_COLOR = (255, 0, 0)  # Default box colour: red when the image is in RGB order
TEXT_COLOR = (255, 255, 255)  # Label text colour: white


def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=10):
    """Draw a bounding box and its class label on an image.

    The image is modified in place and also returned.

    Args:
        img (np.ndarray): Image to draw on; its shape must unpack as
            (height, width, channels).
        bbox (sequence of 4 floats): Box as (x_center, y_center, width, height),
            normalised to [0, 1] relative to the image size (YOLO format).
        class_name (str): Text written above the box.
        color (tuple of int, optional): Colour of the box outline and of the
            label background. Defaults to BOX_COLOR.
        thickness (int, optional): Thickness of the box outline. Defaults to 10.

    Returns:
        np.ndarray: The same image object, with the box and label drawn.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35

    # Convert the normalised centre/size box into pixel corner coordinates.
    x_c, y_c, w, h = bbox
    height, width, _ = img.shape
    x_min = int((x_c - w / 2) * width)
    x_max = int((x_c + w / 2) * width)
    y_min = int((y_c - h / 2) * height)
    y_max = int((y_c + h / 2) * height)

    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)

    # Filled rectangle behind the label, drawn just above the box.
    # It uses `color` (not the module default) so that a caller-supplied colour
    # applies to both the outline and the label background.
    (text_width, text_height), _ = cv2.getTextSize(class_name, font, font_scale, 1)
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=font,
        fontScale=font_scale,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img


def visualize(image, bboxes, category_ids, category_id_to_name, ax=None):
    """Display an image with its bounding boxes drawn on it.

    The boxes are drawn on a copy, so `image` itself is left untouched.

    Args:
        image (np.ndarray): Source image.
        bboxes (list): Bounding boxes, each in the format expected by
            visualize_bbox.
        category_ids (list of int): Class index of each box, in the same order
            as `bboxes`.
        category_id_to_name (list of str): Lookup from class index to class
            name (indexed with each category id).
        ax (matplotlib axes, optional): Axes to draw into. When None, a new
            12x12 figure is created. Defaults to None.
    """
    img = image.copy()
    for bbox, category_id in zip(bboxes, category_ids):
        img = visualize_bbox(img, bbox, category_id_to_name[category_id])

    # Identity test: `!= None` would go through the object's own comparison
    # operator, which is not guaranteed to return a plain boolean.
    if ax is not None:
        ax.imshow(img)
    else:
        plt.figure(figsize=(12, 12))
        plt.imshow(img)

## Augmentation

Penser à modifier le chemin d'accès aux données sources en cas de changement.

L'augmentation est faite en 2 étapes : une première avec des augmentations non systématiques (c'est-à-dire qu'elles ont une probabilité de ne pas se produire) et une deuxième avec des augmentations systématiques (miroir, entre autres).

### Mise en place du dossier

In [3]:
# Create the output folders and their split/images + split/labels sub-folders.

# Only the splits that actually exist in the source folder are mirrored.
available_entries = os.listdir(src_dir)
splits = [split for split in ['test', 'val', 'train'] if split in available_entries]

# Output roots to create, depending on the datasets requested in the configuration.
dataset_roots = []
if detection_class:
    dataset_roots.append(det_cls_res_path)
if detection_only:
    dataset_roots.append(det_only_res_path)

for dataset_root in dataset_roots:
    # exist_ok=True makes the cell safe to re-run; any other failure
    # (permissions, a file in the way, ...) is raised instead of being printed
    # and ignored, since every later cell needs these folders.
    os.makedirs(dataset_root, exist_ok=True)
    for split in splits:
        for kind in ('images', 'labels'):
            os.makedirs(os.path.join(dataset_root, split, kind), exist_ok=True)


NameError: name 'os' is not defined

### Copie des données val et test

In [42]:
# The val and test splits are copied as they are (no augmentation).
available_entries = os.listdir(src_dir)
splits = [split for split in ['test', 'val'] if split in available_entries]

if detection_class:
    # Plain copy of every image and label file.
    for split in splits:
        for kind in ['images', 'labels']:
            src_path = os.path.join(src_dir, split, kind)
            targ_path = os.path.join(det_cls_res_path, split, kind)
            for file in tqdm(os.listdir(src_path), desc = 'file copied'):
                shutil.copy2(os.path.join(src_path, file), targ_path)

if detection_only:
    # Images are copied unchanged; labels are rewritten with every class id
    # replaced by 0 (the single class of the detection-only dataset).
    for split in splits:
        img_src_path = os.path.join(src_dir, split, "images")
        lab_src_path = os.path.join(src_dir, split, "labels")
        # root_dir (a path) is the documented way to glob relative to a folder;
        # dir_fd is meant for an open directory descriptor.
        images = glob('*.jpg', root_dir=img_src_path)
        img_only_res_path = os.path.join(det_only_res_path, split, "images")
        lab_only_res_path = os.path.join(det_only_res_path, split, "labels")
        for img_name in tqdm(images, desc = 'images copied'):
            shutil.copy2(os.path.join(img_src_path, img_name), img_only_res_path)
            label_name = sub("jpg$", "txt", img_name)
            src_label_path = os.path.join(lab_src_path, label_name)
            dst_label_path = os.path.join(lab_only_res_path, label_name)
            # Both files are context-managed so they are closed even on error.
            with open(src_label_path, "r") as label_file, open(dst_label_path, 'w') as new_label_file:
                for line in label_file:
                    fields = line.split()
                    if not fields:
                        continue  # skip blank lines
                    # Keep the 4 box coordinates; an explicit newline keeps one
                    # box per line whatever the source line looked like.
                    coords = [fields[i] for i in range(1, 5)]
                    new_label_file.write(' '.join(['0'] + coords) + '\n')

file copied: 100%|██████████| 4/4 [00:00<00:00, 66.52it/s]
file copied: 100%|██████████| 4/4 [00:00<00:00, 212.16it/s]


### Création du fichier .yaml pour YOLO

In [8]:
if detection_class:
    try:
        # The category list is read from the notes.json of the train split.
        with open(os.path.join(src_dir, 'train', "notes.json")) as json_f:
            json_data = json.load(json_f)
        categories = json_data['categories']

        # names maps each category id to its name (written to data.yaml).
        names = {category['id']: category['name'] for category in categories}
        # cat2name lists the class names in the order of notes.json; later cells
        # index it with the class id found in the label files.
        # NOTE(review): this assumes the ids are 0..n-1 in file order - confirm.
        cat2name = [category['name'] for category in categories]

        data = {
            'names': names,
            'nc': len(categories),
            'train': "./train/images",
            'val': "./val/images",
            'test': "./test/images",
        }

        with open(os.path.join(det_cls_res_path, "data.yaml"), 'w') as yaml_f:
            yaml.dump(data, yaml_f, default_flow_style=False, allow_unicode=True)

    except IOError as error:
        print(error)

if detection_only:
    # Single-class dataset: nothing is needed from notes.json here.
    # The class key is the integer index 0, like the ids used in the label files.
    data = {
        'names': {0: 'Roll'},
        'nc': 1,
        'train': "./train/images",
        'val': "./val/images",
        'test': "./test/images",
    }
    try:
        with open(os.path.join(det_only_res_path, "data.yaml"), 'w') as yaml_f:
            yaml.dump(data, yaml_f, default_flow_style=False, allow_unicode=True)

    except IOError as error:
        print(error)

[Errno 2] No such file or directory: 'detection_dataset_front1\\data.yaml'


### Création des données augmentées

In [44]:
# dossier cibles
# Source folders of the train split
img_src_path = os.path.join(src_dir, 'train', 'images')
lab_src_path = os.path.join(src_dir, 'train', 'labels')

# Output folders of the "detection and classification" dataset
img_cls_res_path = os.path.join(det_cls_res_path, 'train', 'images')
lab_cls_res_path = os.path.join(det_cls_res_path, 'train', 'labels')

# File names (not full paths) of the source train images.
# root_dir (a path) is the documented way to glob relative to a folder;
# dir_fd is meant for an open directory descriptor.
images = glob('*.jpg', root_dir=img_src_path)

# Output folders of the "detection only" dataset
img_only_res_path = os.path.join(det_only_res_path, 'train', 'images')
lab_only_res_path = os.path.join(det_only_res_path, 'train', 'labels')

In [45]:
# Transformations
# Stage 1 - non-systematic augmentations: each transform is applied with its own probability.
transform1 = A.Compose(
    [
        # Pixel-level transforms
        A.RandomBrightnessContrast(p=0.2),
        A.RandomGamma(p=0.2),
        A.ISONoise(p=0.2),
        A.GaussNoise(p=0.2),
        A.CLAHE(p=0.2), # adds contrast
        A.RandomSunFlare(src_radius = 100, num_flare_circles_upper= 10, p=0.2), # mind the radius: if too large it can hide a box entirely
        A.RandomSunFlare(src_radius = 100, num_flare_circles_upper= 10, p=0.2),
        A.RandomSunFlare(src_radius = 100, num_flare_circles_upper= 10, p=0.2), # several flares to get different angles (the circles form a line)

        # Spatial transforms
        A.BBoxSafeRandomCrop(p=0.2),
        A.Rotate(limit=(-10, 10), p=0.3), # TODO: check which angles are reasonable
        # A.PixelDropout(dropout_prob=0.01 ,p=0.5),
    ],
    # BboxParams also accepts size-related parameters, useful for boxes that would become too small
    bbox_params=A.BboxParams(format='yolo', label_fields=['category_ids']),
)

# Stage 2 - systematic augmentation: horizontal mirror, always applied.
transform2 = A.Compose(
    [
        A.HorizontalFlip(p=1),
    ],
    bbox_params=A.BboxParams(format='yolo', label_fields=['category_ids']),
)

assert detection_class or detection_only, "enable detection_class and/or detection_only in the configuration cell"


def _write_yolo_labels(label_path, bboxes, category_ids):
    """Write one 'class x_center y_center width height' line per box."""
    with open(label_path, 'w') as label_file:
        for category_id, bbox in zip(category_ids, bboxes):
            coords = " ".join(str(bbox[i]) for i in range(4))
            label_file.write(str(category_id) + " " + coords + '\n')


def _save_sample(prefix, img_name, label_name, image, bboxes, category_ids):
    """Save an image and its labels, under `prefix`, in every enabled output dataset."""
    if detection_class:
        cv2.imwrite(os.path.join(img_cls_res_path, prefix + img_name), image)
        _write_yolo_labels(os.path.join(lab_cls_res_path, prefix + label_name), bboxes, category_ids)
    if detection_only:
        cv2.imwrite(os.path.join(img_only_res_path, prefix + img_name), image)
        # Single-class dataset: every box gets class 0.
        _write_yolo_labels(os.path.join(lab_only_res_path, prefix + label_name), bboxes, [0] * len(bboxes))


# Data augmentation
for img_name in tqdm(images, desc = 'images processed'):
    # --- Read the source image and its bounding boxes ---
    # NOTE(review): cv2.imread returns BGR and the image is passed to the
    # transforms without conversion - check whether the colour-based transforms
    # expect RGB.
    image = cv2.imread(os.path.join(img_src_path, img_name))
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: " + os.path.join(img_src_path, img_name))
    bboxes = []
    category_ids = []
    label_name = sub("jpg$", "txt", img_name)
    with open(os.path.join(lab_src_path, label_name), "r") as label_file:
        for line in label_file:
            fields = line.split()
            if not fields:
                continue  # skip blank lines
            bboxes.append([float(fields[i]) for i in range(1, 5)])
            category_ids.append(int(fields[0]))

    # --- Copy the source image and labels to the output folders ---
    if detection_class:
        shutil.copy2(os.path.join(img_src_path, img_name), img_cls_res_path)
        shutil.copy2(os.path.join(lab_src_path, label_name), lab_cls_res_path)
    if detection_only:
        shutil.copy2(os.path.join(img_src_path, img_name), img_only_res_path)
        # Labels are rewritten with class 0, one box per line (the newline
        # between boxes is required for files holding several boxes).
        _write_yolo_labels(os.path.join(lab_only_res_path, label_name), bboxes, [0] * len(bboxes))

    # --- Augmented images: transformed, and mirror of the transformed ---
    for ind in range(num_aug):
        transformed = transform1(image=image, bboxes=bboxes, category_ids=category_ids)
        transformed_mirror = transform2(image=transformed['image'], bboxes=transformed['bboxes'], category_ids=transformed['category_ids'])

        _save_sample("transformed_" + str(ind) + "_", img_name, label_name,
                     transformed['image'], transformed['bboxes'], transformed['category_ids'])
        _save_sample("transformed_miroir_" + str(ind) + "_", img_name, label_name,
                     transformed_mirror['image'], transformed_mirror['bboxes'], transformed_mirror['category_ids'])

    # --- Mirror of the source image ---
    mirror = transform2(image=image, bboxes=bboxes, category_ids=category_ids)
    _save_sample("miroir_", img_name, label_name, mirror['image'], mirror['bboxes'], mirror['category_ids'])






images processed: 100%|██████████| 21/21 [00:10<00:00,  2.03it/s]


### Check data augmentation results

In [46]:
# Visual check of the augmentation results - currently DISABLED.
# The whole cell is wrapped in a string literal, so none of this code runs; the
# notebook only echoes the string as the cell output.
# The wrapped code reads the images of img_cls_res_path and shows them in 3x3
# grids, once as raw images and once through visualize(), which draws the boxes
# read from lab_cls_res_path and needs cat2name to be defined.
"""
images = glob('*.jpg', dir_fd=img_cls_res_path)
num_images = len(images)

num_grids = math.ceil(num_images / 9)
for grid in range(num_grids):
    fig, axs = plt.subplots(3, 3, figsize=(20, 20))  # Create a 3x3 grid of subplots
    fig, axs2 = plt.subplots(3, 3, figsize=(15, 15))  # Create a 3x3 grid of subplots
    grid_image_paths = images[grid * 9 : (grid + 1) * 9]

    for ax, img_name in zip(axs.flatten(), grid_image_paths):
        image = cv2.imread(img_cls_res_path + '/' + img_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")

    for ax, img_name in zip(axs2.flatten(), grid_image_paths):
        image = cv2.imread(img_cls_res_path + '/' + img_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        bboxes =[]
        category_ids =[]
        label_name = sub("jpg$", "txt", img_name)
        with open(lab_cls_res_path + '/' + label_name, "r") as label_file:
            for line in label_file:
                split_line = line.split(' ')
                bboxes.append([float(split_line[1]), float(split_line[2]), float(split_line[3]), float(split_line[4])])
                category_ids.append(int(split_line[0]))
        visualize(image, bboxes, category_ids, cat2name, ax)
        ax.axis("off")

    plt.tight_layout()
    plt.show()"""

'\nimages = glob(\'*.jpg\', dir_fd=img_cls_res_path)\nnum_images = len(images)\n\nnum_grids = math.ceil(num_images / 9)\nfor grid in range(num_grids):\n    fig, axs = plt.subplots(3, 3, figsize=(20, 20))  # Create a 3x3 grid of subplots\n    fig, axs2 = plt.subplots(3, 3, figsize=(15, 15))  # Create a 3x3 grid of subplots\n    grid_image_paths = images[grid * 9 : (grid + 1) * 9]\n\n    for ax, img_name in zip(axs.flatten(), grid_image_paths):\n        image = cv2.imread(img_cls_res_path + \'/\' + img_name)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        ax.imshow(image)\n        ax.axis("off")\n\n    for ax, img_name in zip(axs2.flatten(), grid_image_paths):\n        image = cv2.imread(img_cls_res_path + \'/\' + img_name)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        bboxes =[]\n        category_ids =[]\n        label_name = sub("jpg$", "txt", img_name)\n        with open(lab_cls_res_path + \'/\' + label_name, "r") as label_file:\n            for lin

### Training data health check

In [47]:

def _count_instances(label_dir, label_names, nb_classes):
    """Count the boxes of each class index over the given YOLO label files."""
    counts = [0] * nb_classes
    for fname in label_names:
        with open(os.path.join(label_dir, fname), "r") as label_file:
            for line in label_file:
                fields = line.split()
                if fields:  # skip blank lines
                    counts[int(fields[0])] += 1
    return counts


def _print_instances(counts):
    """Print the number of instances of each class of cat2name."""
    for i in range(len(cat2name)):
        print(cat2name[i], " a ", counts[i], " instances")


def _print_mean_size(img_dir, img_names):
    """Print the mean width and height of the given images (nothing to average if empty)."""
    if not img_names:
        print("no image found in " + img_dir)
        return
    total_w = 0
    total_h = 0
    for fname in img_names:
        with Image.open(os.path.join(img_dir, fname)) as img:
            w, h = img.size
            total_w += w
            total_h += h
    print("mean width = ", total_w / len(img_names))
    print("mean heigh = ", total_h / len(img_names))


# File names (relative to their folder) of the source and augmented train data.
# root_dir (a path) is the documented way to glob relative to a folder;
# dir_fd is meant for an open directory descriptor.
labels_src = glob('*.txt', root_dir=lab_src_path)
labels_augment = glob('*.txt', root_dir=lab_cls_res_path)
img_src = glob('*.jpg', root_dir=img_src_path)
img_augment = glob('*.jpg', root_dir=img_cls_res_path)

# Source data: instances per class
print("Il y a " + str(len(labels_src)) + " images dans le dossier source:")
label_count = _count_instances(lab_src_path, labels_src, len(cat2name))
_print_instances(label_count)

# Augmented data: instances per class
print("------------------------------------------\nIl y a " + str(len(labels_augment)) + " images dans le dossier augmenté:")
label_count = _count_instances(lab_cls_res_path, labels_augment, len(cat2name))
_print_instances(label_count)

# Mean image dimensions
print("------------------------------------------\nLes dimension moyennes du jeu source sont:")
_print_mean_size(img_src_path, img_src)

print("------------------------------------------\nLes dimension moyennes du jeu augmenté sont:")
_print_mean_size(img_cls_res_path, img_augment)

Il y a 21 images dans le dossier source:


NameError: name 'cat2name' is not defined

# Génération du Dataset "classification"

<span style="color: red"> Attention </span>: S'il y a un problème lors de l'entraînement, vérifier :
- s'il y a au <span style="color: red"> minimum </span> 2 splits ("train" et "val" ou "train" et "test")
- si les splits créés ont au <span style="color: red"> minimum </span> une image par classe
- si le champ data du modèle YOLO pointe vers le dossier contenant les splits ("train", ...)

Le fichier <span style="color: blue"> data.yaml </span> n'est ici qu'<span style="color: blue"> informatif </span> et n'est pas utilisé par YOLO!

In [23]:
# Guard cell: stops "Run All" here unless the classification dataset was requested.
assert classification, "classification is False: set classification = True in the configuration cell to build the classification dataset"

AssertionError: 

## Initialisation des chemins utilisés

In [None]:
# src_dir (str): from this cell on, the source folder is the output of the
# "detection and classification" dataset; the cropping cell reads
# <src_dir>/<split>/images and <src_dir>/<split>/labels.
# NOTE: this rebinds the src_dir set in the configuration cell - re-run that
# cell before running the augmentation section again.
src_dir = det_cls_res_path




## Cropping 

In [None]:
# Root folder of the classification dataset (exist_ok makes the cell re-runnable).
os.makedirs(cls_res_path, exist_ok=True)

# Cropping, split by split
for split in ["train", "val", "test"]:
    img_src_path = src_dir + "/" + split + "/images"
    lab_src_path = src_dir + "/" + split + "/labels"
    # root_dir (a path) is the documented way to glob relative to a folder;
    # dir_fd is meant for an open directory descriptor.
    images = glob('*.jpg', root_dir=img_src_path)
    if len(images) == 0:
        continue  # split missing or empty: nothing to crop

    # One sub-folder per class inside the split folder.
    img_res_path = cls_res_path + '/' + split
    for classname in cat2name:
        os.makedirs(img_res_path + '/' + classname, exist_ok=True)

    for img_name in tqdm(images, desc = split):
        img_src = cv2.imread(img_src_path + '/' + img_name)
        if img_src is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("cannot read image: " + img_src_path + '/' + img_name)
        img_h, img_w = img_src.shape[:2]

        # Read the boxes and convert the normalised YOLO values
        # (x_center, y_center, width, height) to integer pixels.
        bboxes = []
        category_ids = []
        label_name = sub("jpg$", "txt", img_name)
        with open(lab_src_path + '/' + label_name, "r") as label_file:
            for line in label_file:
                fields = line.split()
                if not fields:
                    continue  # skip blank lines
                bboxes.append([int(float(fields[1]) * img_w), int(float(fields[2]) * img_h),
                               int(float(fields[3]) * img_w), int(float(fields[4]) * img_h)])
                category_ids.append(int(fields[0]))

        # Crop each box and save it in the folder of its class.
        for k, (bbox, cat_id) in enumerate(zip(bboxes, category_ids)):
            # cv2 / numpy image axes: origin at the top-left corner, X grows to
            # the right, Y grows downwards; arrays are indexed [y, x].
            # Bounds are clamped to the image: a negative start index would be
            # read by numpy as "counted from the end" and give a wrong or empty crop.
            xmin = max(bbox[0] - bbox[2] // 2, 0)
            xmax = min(bbox[0] + bbox[2] // 2, img_w)
            ymin = max(bbox[1] - bbox[3] // 2, 0)
            ymax = min(bbox[1] + bbox[3] // 2, img_h)
            if xmax <= xmin or ymax <= ymin:
                # An empty crop cannot be written by cv2.imwrite.
                tqdm.write("skipped empty crop " + str(k) + " of " + img_name)
                continue
            crop_img = img_src[ymin:ymax, xmin:xmax] # [y, x]
            cat_name = cat2name[cat_id]
            crop_img_name = img_res_path + '/' + cat_name + '/cropped_image' + str(k) + '_' + img_name
            cv2.imwrite(crop_img_name, crop_img)

# data.yaml of the classification dataset
data = {
    'path': cls_res_path,
    'names': cat2name,
    'nc': len(cat2name),
    'train': "train",
    'val': "val",
    'test': "test",
}

try:
    with open(cls_res_path + "/data.yaml", 'w') as yaml_f:
        yaml.dump(data, yaml_f, default_flow_style=False, allow_unicode=True)

except IOError as error:
    print(error)


In [None]:
# Size statistics of the cropped classification images, split by split.
print("Dimension moyenne des images de classification:")

for split in ["train", "val", "test"]:
    print("------------------------------------------\n" + split + ':')
    widths = []
    heights = []

    split_path = cls_res_path + '/' + split

    for label in tqdm(cat2name, desc = split):
        # root_dir (a path) is the documented way to glob relative to a folder;
        # dir_fd is meant for an open directory descriptor.
        fnames = glob('*.jpg', root_dir=split_path + '/' + label)
        for fname in tqdm(fnames, desc = 'image in '+label):
            with Image.open(split_path + '/' + label + '/' + fname) as img:
                w, h = img.size
                widths.append(w)
                heights.append(h)

    nb_img = len(widths)
    if nb_img != 0:
        print("\nmean width = ", sum(widths) / nb_img)
        print("mean heigh = ", sum(heights) / nb_img)
        print("max width: ", max(widths))
        print("min width: ", min(widths))
        print("max heigh: ", max(heights))
        print("min heigh: ", min(heights))
        print("number of images: ", nb_img)
    else:
        print("No image in " + split + " split!")

Dimension moyenne des images de classification:
------------------------------------------
train:


image in 3G: 0it [00:00, ?it/s]:00<?, ?it/s]
image in 4G: 0it [00:00, ?it/s]
image in B4: 0it [00:00, ?it/s]
image in B8: 0it [00:00, ?it/s]
train: 100%|██████████| 4/4 [00:00<00:00, 177.49it/s]


No image in train split!
------------------------------------------
val:


image in 3G: 0it [00:00, ?it/s]0<?, ?it/s]
image in 4G: 0it [00:00, ?it/s]
image in B4: 0it [00:00, ?it/s]
image in B8: 0it [00:00, ?it/s]
val: 100%|██████████| 4/4 [00:00<00:00, 188.75it/s]


No image in val split!
------------------------------------------
test:


image in 3G: 0it [00:00, ?it/s]00<?, ?it/s]
image in 4G: 0it [00:00, ?it/s]
image in B4: 0it [00:00, ?it/s]
image in B8: 0it [00:00, ?it/s]
test: 100%|██████████| 4/4 [00:00<00:00, 170.15it/s]

No image in test split!



