# Загрузка библиотек

In [1]:
import pandas as pd
import os
import numpy as np
from ultralytics import YOLO
import random
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from IPython.display import clear_output

In [2]:
RANDOM_STATE = 42
seed = RANDOM_STATE

# Seed the stdlib and NumPy global RNGs so that random draws made later in the
# notebook (e.g. pandas .sample() without an explicit random_state, which uses
# NumPy's global RNG) are repeatable for a fresh kernel.
random.seed(seed)
np.random.seed(seed)

# NOTE: Python reads PYTHONHASHSEED only at interpreter start-up, so this only
# affects child processes launched afterwards, not the already-running kernel.
os.environ.update(PYTHONHASHSEED=str(seed))

# Подготовка датасета для валидации

In [3]:
# Load the test split, then keep only non-augmented images that contain a
# terminal and whose path contains the 'ЗНО' marker.
df_test = pd.read_csv('csvs/test.csv', sep=';', index_col=[0]).reset_index(drop=True)
df_test = df_test.loc[
    (df_test['terminal'] == 1)
    # Vectorized substring tests; na=False treats a missing file name as
    # "no match" (such rows are dropped by the 'ЗНО' test) instead of raising.
    & ~df_test['file_name'].str.contains('augm', regex=False, na=False)
    & df_test['file_name'].str.contains('ЗНО', regex=False, na=False)
]

# Balanced evaluation subset: 100 damaged + 100 undamaged images.
# An explicit random_state makes the draw independent of how often earlier
# cells consumed the global NumPy RNG, so re-running this cell reproduces the
# same subset. .sample() raises ValueError if a class has fewer than 100 rows.
# NOTE(review): a row flagged both damaged and undamaged could land in both
# halves of the concatenation — confirm the two flags are mutually exclusive.
df_test_damaged = df_test.loc[
    df_test['terminal_damaged'] == 1
].sample(100, random_state=RANDOM_STATE)
df_test_undamaged = df_test.loc[
    df_test['terminal_undamaged'] == 1
].sample(100, random_state=RANDOM_STATE)
df_test = pd.concat(
    (df_test_damaged, df_test_undamaged)
)

# Формирование папки с прогнозами

In [4]:
# Load the detector from the "best.pt" checkpoint of training run "train3".
model = YOLO(model='runs/detect/train3/weights/best.pt')

In [10]:
from pathlib import PureWindowsPath

# Detections below this confidence are not drawn. The output folder name is
# derived from it (0.4 -> 'val_photos_threshold40') so the two cannot drift apart.
conf_threshold = 0.4
val_folder = Path(f'val_photos_threshold{round(conf_threshold * 100)}')
val_folder.mkdir(parents=True, exist_ok=True)


def _draw_detections(ax, result, threshold):
    """Draw every box of *result* whose confidence is >= *threshold* onto *ax*."""
    for box in result.boxes:
        conf = float(box.conf)
        if conf < threshold:
            continue
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        cls = int(box.cls)
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   fill=False, linewidth=2, edgecolor='red'))
        ax.text(x1, y1 - 10, f'{result.names[cls]} {conf:.2f}', color='white', fontsize=9,
                bbox=dict(facecolor='red', alpha=0.5))


# For each sampled test image save a side-by-side figure: the original image
# titled with its ground-truth damage flag, and the same image with the
# predicted boxes drawn on top.
for _, row in tqdm(df_test.iterrows(), total=len(df_test)):
    image_path = row['file_name']
    damaged_flag = row['terminal_damaged']
    results = model.predict(image_path)
    # Read the image once and close the file handle immediately; copy() keeps
    # the pixel data usable after the file is closed.
    with Image.open(image_path) as src:
        image = src.copy()
    fig, ax = plt.subplots(1, 2, figsize=(19, 10))
    try:
        ax[0].imshow(image)
        ax[1].imshow(image)
        _draw_detections(ax[1], results[0], conf_threshold)
        damage_status = 'Terminal is damaged' if damaged_flag else 'Terminal is not damaged'
        ax[0].set_title(f"Original Image\n{damage_status}")
        ax[0].set_axis_off()
        ax[1].set_title('Damage Detection')
        ax[1].set_axis_off()
        # PureWindowsPath splits on both '\\' and '/', so the base name is
        # extracted correctly for Windows-style and POSIX-style paths alike.
        # NOTE(review): images sharing a base name in different folders will
        # overwrite each other's output figure.
        file_name = val_folder / PureWindowsPath(image_path).name
        fig.savefig(file_name)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
clear_output()