# In [10]:
import json
import yaml
from ultralytics import YOLO
import matplotlib.pyplot as plt
import cv2
import os
import torch
from ipywidgets import Text, Button, Output, VBox
from IPython.display import display

# Load the class names from the dataset's YAML description file.
def load_class_names(yaml_path):
    """Return the ``names`` entry of a dataset YAML file.

    Args:
        yaml_path: Path to the dataset YAML (here an Ultralytics-style
            ``data.yaml``).

    Returns:
        The value stored under the ``names`` key. Callers index it with integer
        class ids, so either a list or an int-keyed dict works.

    Raises:
        KeyError: if the file is empty or has no ``names`` key.
    """
    # YAML files are normally UTF-8; pin the encoding so accented class names
    # are not mis-decoded on platforms whose default locale encoding differs.
    with open(yaml_path, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)
    if not isinstance(data, dict) or 'names' not in data:
        raise KeyError(f"'names' not found in {yaml_path}")
    return data['names']

# Build the labels of every grid quadrant.
def generate_quadrant_names(rows, cols):
    """Return grid labels in row-major order.

    Each row gets a letter starting at 'A' and each column a 1-based number.
    For example, rows=2, cols=2 yields ['A1', 'A2', 'B1', 'B2'].
    """
    return [
        f"{chr(ord('A') + row_idx)}{col_idx + 1}"
        for row_idx in range(rows)
        for col_idx in range(cols)
    ]

# Determine which grid quadrants a bounding box overlaps, given the slab bounds.
def get_quadrants_and_sizes(x1, y1, x2, y2, cx1, cy1, cx2, cy2, rows, cols):
    """Return the fraction of a box's pixels that falls in each grid quadrant.

    The slab region ``(cx1, cy1)-(cx2, cy2)`` is split into ``rows x cols``
    cells. Each cell is ``(cx2 - cx1) // cols`` pixels wide and
    ``(cy2 - cy1) // rows`` pixels tall. The last row and column absorb any
    leftover pixels. The box ``(x1, y1)-(x2, y2)`` is inclusive on both ends.

    Args:
        x1, y1, x2, y2: Inclusive pixel bounds of the defect box.
        cx1, cy1, cx2, cy2: Bounds of the slab region being gridded.
        rows, cols: Grid dimensions.

    Returns:
        Dict mapping quadrant label (``'A1'``, ``'B3'``, ...) to the fraction
        of the box area inside it. Keys appear column by column, then row by
        row, in order of first appearance. An empty box (``x2 < x1`` or
        ``y2 < y1``) gives ``{}``.

    Raises:
        ValueError: if the slab is smaller than one pixel per grid cell.
    """
    quadrant_width = (cx2 - cx1) // cols
    quadrant_height = (cy2 - cy1) // rows
    if quadrant_width <= 0 or quadrant_height <= 0:
        raise ValueError(
            f"Slab region ({cx1}, {cy1})-({cx2}, {cy2}) is too small "
            f"for a {rows}x{cols} grid"
        )

    # The grid is axis-aligned, so a quadrant's pixel count is
    # (pixels in its column) * (pixels in its row). Counting columns and rows
    # separately is O(width + height) instead of visiting every pixel.
    col_counts = {}
    for x in range(x1, x2 + 1):
        col = min((x - cx1) // quadrant_width, cols - 1)  # last column absorbs the remainder
        col_counts[col] = col_counts.get(col, 0) + 1

    row_counts = {}
    for y in range(y1, y2 + 1):
        row = min((y - cy1) // quadrant_height, rows - 1)  # last row absorbs the remainder
        row_counts[row] = row_counts.get(row, 0) + 1

    # Normalize by the box area so the fractions sum to 1 for a non-empty box.
    total_size = (x2 - x1 + 1) * (y2 - y1 + 1)
    quadrants = {}
    for col, col_count in col_counts.items():
        for row, row_count in row_counts.items():
            quadrants[chr(65 + row) + str(col + 1)] = col_count * row_count / total_size
    return quadrants

# Compute the slab score.
def calculate_score(defects, rows, cols, penalty_table=None):
    """Score a slab from its defects and annotate each defect entry with its penalty.

    Each distinct defect (same class and coordinates) is charged once. For
    every quadrant it overlaps, the charge is ``penalty_table[class] * fraction``,
    doubled in the centre quadrant. The detection loop builds one entry per
    quadrant for the same defect. Those entries are de-duplicated for scoring,
    but every entry gets its ``'penalty'`` set: the penalty for its own
    ``'quadrant'`` when it has one, otherwise the defect's total penalty.

    Args:
        defects: List of dicts with keys ``'class'``, ``'coordinates'``
            (``x1``/``y1``/``x2``/``y2``), ``'quadrants'`` (label -> fraction)
            and optionally ``'quadrant'``. Each dict's ``'penalty'`` is set in place.
        rows, cols: Grid dimensions; labels follow the ``'A1'`` scheme.
        penalty_table: Mapping of class name -> base penalty. Defaults to the
            module-level ``penalties`` mapping.

    Returns:
        ``(score, quadrant_scores)``. ``score`` is 100 minus all penalties.
        ``quadrant_scores`` maps every quadrant label to 100 minus that
        quadrant's penalties. Both are clamped at 0.
    """
    table = penalties if penalty_table is None else penalty_table
    score = 100
    # For an even number of rows/cols this picks the cell just below/right of
    # the true centre.
    center_quadrant = chr(65 + rows // 2) + str(cols // 2 + 1)
    quadrant_penalties = {
        chr(65 + r) + str(c + 1): 0 for r in range(rows) for c in range(cols)
    }
    processed_defects = set()

    for defect in defects:
        coords = defect['coordinates']
        base_penalty = table[defect['class']]

        per_quadrant = {}
        for quadrant, size in defect['quadrants'].items():
            quadrant_penalty = base_penalty * size
            if quadrant == center_quadrant:
                quadrant_penalty *= 2  # doubled penalty in the centre quadrant
            per_quadrant[quadrant] = quadrant_penalty

        # Annotate every entry (including duplicates) with the penalty that
        # belongs to the quadrant it represents.
        own_quadrant = defect.get('quadrant')
        if own_quadrant in per_quadrant:
            defect['penalty'] = per_quadrant[own_quadrant]
        else:
            defect['penalty'] = sum(per_quadrant.values())

        # Charge the score only once per physical defect.
        defect_id = (defect['class'], coords['x1'], coords['y1'], coords['x2'], coords['y2'])
        if defect_id in processed_defects:
            continue
        processed_defects.add(defect_id)
        for quadrant, quadrant_penalty in per_quadrant.items():
            score -= quadrant_penalty
            quadrant_penalties[quadrant] += quadrant_penalty

    # Scores cannot go below 0.
    quadrant_scores = {name: max(100 - p, 0) for name, p in quadrant_penalties.items()}
    return max(score, 0), quadrant_scores

# Draw text over a filled (opaque) background rectangle.
def put_text_with_background(image, text, position, font, font_scale, font_color, thickness, bg_color):
    """Draw ``text`` on ``image`` over a solid ``bg_color`` box.

    ``position`` is the text origin as passed to ``cv2.putText``, i.e. the
    bottom-left corner of the text. The background spans the measured text
    width. It reaches 5 px above the text height and 5 px below the origin.
    ``image`` is modified in place.
    """
    origin_x, origin_y = position
    text_size, _baseline = cv2.getTextSize(text, font, font_scale, thickness)
    box_w, box_h = text_size
    top_left = (origin_x, origin_y - box_h - 5)
    bottom_right = (origin_x + box_w, origin_y + 5)
    cv2.rectangle(image, top_left, bottom_right, bg_color, -1)  # thickness -1 => filled
    cv2.putText(image, text, (origin_x, origin_y), font, font_scale, font_color, thickness)

# Load the class names from the dataset YAML file.
yaml_path = 'datasets/granito-3nbcw/data.yaml'
class_names = load_class_names(yaml_path)

# Pick the fastest available backend: CUDA GPU, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Fixed parameters
model_path = 'runs/detect/train7/weights/best.pt'      # trained YOLO weights
image_dir = 'datasets/granito-3nbcw/train/images/'     # typed file names are resolved here
output_dir = 'runs/detections/images'                  # annotated and cropped images are written here
json_filename = 'runs/detections/detection_results.json'  # per-image results are merged into this file
rows = 3   # quadrant grid rows (labelled A, B, C, ...)
cols = 3   # quadrant grid columns (labelled 1, 2, 3, ...)
margin_percentage = 0.01       # fraction of the slab size added around the detected slab box
min_confidence = 0.25          # detections below this confidence are listed as disregarded
crop_margin_percentage = 0.01  # extra margin around the slab for the cropped output image
penalties = {'veio': 1, 'furo': 2, 'fanado': 3}  # base penalty per defect class name

# Graphical interface: a file-name field, a trigger button and an output area for the report.
image_input = Text(
    placeholder='Digite o nome do arquivo de imagem',
    description='Imagem:',
)
process_button = Button(
    tooltip='Clique para processar a imagem',
    icon='check',
    description='Processar',
)
output = Output()

# Cache of loaded YOLO models keyed by (weights path, device), so clicking the
# button repeatedly does not reload the weights from disk every time.
_model_cache = {}


def _load_model():
    """Return the YOLO model at ``model_path`` moved to ``device``, loading it at most once."""
    key = (model_path, str(device))
    model = _model_cache.get(key)
    if model is None:
        model = YOLO(model_path)
        model.to(device)
        _model_cache[key] = model
    return model


def process_image(change):
    """Handle a click on the "Processar" button: detect, score and report one image.

    Steps:
    - Read the file name typed in ``image_input`` and resolve it under ``image_dir``.
    - Run the YOLO model on the image and locate the slab (the first class-0 box).
    - Classify every other detection as a defect inside the slab, a detection
      outside the slab, or one below ``min_confidence``.
    - Score the slab per quadrant and print a report into the ``output`` widget.
    - Merge the results into ``json_filename``.
    - Write and display the annotated and cropped images under ``output_dir``.

    Args:
        change: Argument passed by ``Button.on_click``; not used.
    """
    with output:
        output.clear_output()
        image_file = image_input.value
        image_path = os.path.join(image_dir, image_file)

        # isfile rather than exists: an empty name resolves to image_dir itself,
        # which exists and would make the model run on the whole directory.
        if not os.path.isfile(image_path):
            print(f"Arquivo {image_path} não encontrado.")
            return

        image = cv2.imread(image_path)
        if image is None:  # cv2.imread returns None for unreadable / non-image files
            print(f"Não foi possível ler a imagem {image_path}.")
            return
        height, width, _ = image.shape

        # Run detection on the requested image.
        results = _load_model()(image_path)
        detections = results[0].boxes

        # The slab is the first detection of class 0.
        chapa_box = None
        for box in detections:
            if int(box.cls[0]) == 0:
                chapa_box = box
                break

        if chapa_box is None:
            print("Nenhuma chapa (classe 0) foi detectada na imagem.")
            return

        # Slab bounds, widened by margin_percentage and clamped to the image.
        cx1, cy1, cx2, cy2 = map(int, chapa_box.xyxy[0])
        margin_w = int((cx2 - cx1) * margin_percentage)
        margin_h = int((cy2 - cy1) * margin_percentage)
        cx1 = max(cx1 - margin_w, 0)
        cy1 = max(cy1 - margin_h, 0)
        cx2 = min(cx2 + margin_w, width)
        cy2 = min(cy2 + margin_h, height)
        # Region shown in the cropped output image: the slab plus a further margin.
        crop_margin_w = int((cx2 - cx1) * crop_margin_percentage)
        crop_margin_h = int((cy2 - cy1) * crop_margin_percentage)
        crop_cx1 = max(cx1 - crop_margin_w, 0)
        crop_cy1 = max(cy1 - crop_margin_h, 0)
        crop_cx2 = min(cx2 + crop_margin_w, width)
        crop_cy2 = min(cy2 + crop_margin_h, height)

        results_summary = []      # one entry per (defect, quadrant) pair inside the slab
        errors_summary = []       # confident detections not fully inside the slab
        disregarded_summary = []  # detections below min_confidence

        for box in detections:
            cls = int(box.cls[0].item())
            if cls == 0:
                continue  # the slab itself is not a defect
            class_name = class_names[cls]
            conf = float(box.conf[0].item())
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

            if conf < min_confidence:
                disregarded_summary.append({
                    "class": class_name,
                    "confidence": conf,
                    "coordinates": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
                })
                continue

            inside_slab = (cx1 <= x1 <= cx2 and cy1 <= y1 <= cy2
                           and cx1 <= x2 <= cx2 and cy1 <= y2 <= cy2)
            if not inside_slab:
                errors_summary.append({
                    "class": class_name,
                    "confidence": conf,
                    "coordinates": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
                })
                continue

            quadrants_and_sizes = get_quadrants_and_sizes(x1, y1, x2, y2, cx1, cy1, cx2, cy2, rows, cols)
            for quadrant, size in quadrants_and_sizes.items():
                results_summary.append({
                    "quadrant": quadrant,
                    "class": class_name,
                    "confidence": conf,
                    "penalty": size * penalties[class_name],
                    "coordinates": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
                    "quadrants": quadrants_and_sizes
                })

        score, quadrant_scores = calculate_score(results_summary, rows, cols)
        results_summary.sort(key=lambda x: x["quadrant"])

        print("\nParâmetros Utilizados:")
        print(f"{'Imagem:':<20} {image_dir}")
        print(f"{'Nome do Arquivo:':<20} {image_file}")
        print(f"{'Linhas:':<20} {rows}")
        print(f"{'Colunas:':<20} {cols}")
        print(f"{'Margem:':<20} {margin_percentage * 100}%")
        print(f"{'Confiança mínima:':<20} {min_confidence}")
        print(f"{'Margem de corte:':<20} {crop_margin_percentage * 100}%\n")

        print(f"Pontuação da chapa: {score:.2f}%\n")

        print("Resultados da detecção:")
        if results_summary:
            print(f"{'Quadrante':<10} {'Classe':<10} {'Confiança':<10} {'Penalidade':<10} {'Coordenadas':<30}")
            print("="*90)
            for result in results_summary:
                penalty = result.get('penalty', 0)
                print(f"{result['quadrant']:<10} {result['class']:<10}      {result['confidence']:.2f}        {penalty:.2f} ({result['coordinates']['x1']}, {result['coordinates']['y1']}), ({result['coordinates']['x2']}, {result['coordinates']['y2']})")
        else:
            print("Nenhum defeito foi detectado na chapa.\n")

        if errors_summary:
            print("\nErros fora da chapa:")
            print(f"{'Classe':<10} {'Confiança':<10} {'Coordenadas':<20}")
            print("="*70)
            for error in errors_summary:
                print(f"{error['class']:<10} {error['confidence']:.2f}       ({error['coordinates']['x1']}, {error['coordinates']['y1']}), ({error['coordinates']['x2']}, {error['coordinates']['y2']})")
        else:
            print("Nenhum erro fora da chapa foi detectado.\n")

        if disregarded_summary:
            print("\nDefeitos desconsiderados por confiança mínima:")
            print(f"{'Classe':<10} {'Confiança':<10} {'Coordenadas':<20}")
            print("="*70)
            for disregarded in disregarded_summary:
                print(f"{disregarded['class']:<10} {disregarded['confidence']:.2f}       ({disregarded['coordinates']['x1']}, {disregarded['coordinates']['y1']}), ({disregarded['coordinates']['x2']}, {disregarded['coordinates']['y2']})")
        else:
            print("Nenhum defeito foi desconsiderado por confiança mínima.\n")

        # Make sure the destination folders exist; otherwise cv2.imwrite fails
        # silently and opening the JSON file for writing raises.
        os.makedirs(output_dir, exist_ok=True)
        json_dir = os.path.dirname(json_filename)
        if json_dir:
            os.makedirs(json_dir, exist_ok=True)

        base_name = os.path.basename(image_file)
        output_image_file = os.path.join(output_dir, 'result_' + base_name)
        cropped_image_file = os.path.join(output_dir, 'cropped_' + base_name)

        output_data = {
            image_file: {
                "results": results_summary,
                "score": f"{score:.2f}%",
                "output_image": output_image_file,
                "cropped_image": cropped_image_file
            }
        }

        # Merge with results from previously processed images.
        if os.path.exists(json_filename):
            with open(json_filename, 'r', encoding='utf-8') as json_file:
                existing_data = json.load(json_file)
            existing_data[image_file] = output_data[image_file]
            output_data = existing_data

        with open(json_filename, 'w', encoding='utf-8') as json_file:
            json.dump(output_data, json_file, indent=4)

        print(f"Resultados salvos em '{json_filename}'\n")

        # Full image with the model's own box annotations.
        result_image = results[0].plot()
        if cv2.imwrite(output_image_file, result_image):
            print(f"Imagem salva em: {output_image_file}\n")
        else:
            print(f"Falha ao salvar a imagem em: {output_image_file}\n")

        plt.imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        plt.show()

        # Cropped slab with the quadrant grid, labels and per-quadrant scores.
        # Slab coordinates are shifted by the crop origin to land in the crop.
        cropped_image = image[crop_cy1:crop_cy2, crop_cx1:crop_cx2]

        quadrant_width = (cx2 - cx1) // cols
        quadrant_height = (cy2 - cy1) // rows
        for i in range(1, cols):
            cv2.line(cropped_image, (cx1 + i * quadrant_width - crop_cx1, cy1 - crop_cy1), (cx1 + i * quadrant_width - crop_cx1, cy2 - crop_cy1), (0, 255, 0), 2)
        for i in range(1, rows):
            cv2.line(cropped_image, (cx1 - crop_cx1, cy1 + i * quadrant_height - crop_cy1), (cx2 - crop_cx1, cy1 + i * quadrant_height - crop_cy1), (0, 255, 0), 2)

        quadrant_names = generate_quadrant_names(rows, cols)
        for i, name in enumerate(quadrant_names):
            qx = cx1 + (i % cols) * quadrant_width + quadrant_width // 2 - crop_cx1
            qy = cy1 + (i // cols) * quadrant_height + quadrant_height // 2 - crop_cy1
            put_text_with_background(cropped_image, name, (qx, qy - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2, (0, 0, 0))
            put_text_with_background(cropped_image, f"{quadrant_scores[name]:.2f}%", (qx, qy + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2, (0, 0, 0))

        if cv2.imwrite(cropped_image_file, cropped_image):
            print(f"Imagem salva em: {cropped_image_file}\n")
        else:
            print(f"Falha ao salvar a imagem em: {cropped_image_file}\n")

        plt.imshow(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        plt.show()

# Run process_image on every button click, then render the widgets stacked vertically.
process_button.on_click(process_image)
display(VBox(children=[image_input, process_button, output]))


# Output: VBox(children=(Text(value='', description='Imagem:', placeholder='Digite o nome do arquivo de imagem'), Button…