<a href="https://colab.research.google.com/github/Marcottero/AI_project/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Mount point handed to drive.mount() below.
mount = '/content/drive'
# google.colab is imported here, in the same cell that uses it.
from google.colab import drive
drive.mount(mount)

Mounted at /content/drive


In [2]:
# Install the third-party packages imported by the next cell
# (ultralytics -> YOLO, easyocr -> OCR reader, pyyaml -> yaml).
!pip install ultralytics
!pip install easyocr
!pip install pyyaml

Collecting ultralytics
  Downloading ultralytics-8.3.146-py3-none-any.whl.metadata (37 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8.0->ultralytics)
  Downloading n

In [4]:
import torch
import torch.nn as nn
import cv2
import numpy as np
from ultralytics import YOLO
import easyocr
import yaml
import os
from pathlib import Path
import json
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt
from PIL import Image
import albumentations as A
from torch.utils.data import Dataset, DataLoader
import logging

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


  check_for_updates()


In [5]:
# Logging configuration: INFO level via basicConfig, plus the module-level
# logger used by the classes and functions defined below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TechnicalDrawingAnalyzer:
    """
    Multi-modal model for the automatic analysis of 2D technical drawings.

    Bundles a YOLO object detector with an EasyOCR reader (Italian and
    English) and holds the class map and thresholds shared by the rest of the
    pipeline.
    """

    def __init__(self, model_path: Optional[str] = None, device: str = 'auto'):
        """
        Initialise the analyzer.

        Args:
            model_path: Path to a custom YOLO model. When None (or when the
                path does not exist) the pre-trained YOLOv8n checkpoint is
                loaded instead.
            device: Device to use ('auto', 'cpu', 'cuda').
        """
        self.device = self._setup_device(device)
        self.model = self._load_model(model_path)
        # Honour the requested device for OCR as well: EasyOCR's `gpu` flag
        # defaults to True, which would ignore an explicit device='cpu'.
        self.ocr_reader = easyocr.Reader(
            ['it', 'en'], gpu=self.device.type != 'cpu'
        )

        # Class id -> class name used when parsing detections.
        self.classes = {
            0: 'missing_weld',
            1: 'weld_error',
            2: 'weld_ok',
            3: 'valid_name',
            4: 'des_name',
            5: 'mat_cod',
            6: 'part_cod'
        }

        # Detection thresholds.
        self.confidence_threshold = 0.25
        self.iou_threshold = 0.45

    def _setup_device(self, device: str) -> torch.device:
        """Resolve the device string; 'auto' picks CUDA when available, else CPU."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)

    def _load_model(self, model_path: Optional[str]) -> YOLO:
        """
        Load the YOLO model.

        A custom model is loaded when `model_path` points to an existing file.
        If a path was given but is missing, a warning is logged before falling
        back, so the substitution is not silent.

        NOTE(review): the fallback checkpoint was presumably not trained on
        the class map defined in __init__ -- confirm before relying on its
        detections without fine-tuning.
        """
        if model_path:
            if os.path.exists(model_path):
                logger.info(f"Caricamento modello personalizzato da: {model_path}")
                return YOLO(model_path)
            logger.warning(
                f"Modello personalizzato non trovato: {model_path}; "
                "uso del modello YOLOv8n pre-trained"
            )

        logger.info("Caricamento YOLOv8n pre-trained per transfer learning")
        return YOLO('yolov8n.pt')

class DrawingDataset(Dataset):
    """
    Custom dataset for annotated technical drawings.

    Each item is a dict with the RGB image, its YOLO-format boxes
    ([x_center, y_center, width, height]), the class ids and the image path.
    """

    # Image extensions picked up from image_dir (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir: str, label_dir: str, img_size: int = 640, augment: bool = True):
        """
        Args:
            image_dir: Directory containing the images
            label_dir: Directory containing the YOLO annotation files
            img_size: Target image size (stored only; no resize is applied
                by this class)
            augment: Whether to apply data augmentation
        """
        self.image_dir = Path(image_dir)
        self.label_dir = Path(label_dir)
        self.img_size = img_size

        # Sorted so that index -> file is reproducible across runs/machines.
        self.image_files = self._collect_image_files()

        # Augmentation pipeline chosen for technical drawings.
        self.augment_transform = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.Blur(blur_limit=3, p=0.3),
            A.RandomRotate90(p=0.2),  # Only 90-degree rotations for technical drawings
            A.HorizontalFlip(p=0.3),
        ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

        self.apply_augment = augment

    def _collect_image_files(self) -> List[Path]:
        """Return the image files in image_dir, sorted; [] if the directory is missing."""
        if not self.image_dir.is_dir():
            return []
        return sorted(
            path for path in self.image_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load one sample.

        Raises:
            ValueError: If the image file cannot be decoded.
        """
        img_path = self.image_files[idx]
        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread signals failure by returning None; fail with a clear
            # message instead of an opaque error inside cvtColor.
            raise ValueError(f"Impossibile caricare l'immagine: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Annotation file shares the image's stem.
        label_path = self.label_dir / f"{img_path.stem}.txt"
        boxes, class_labels = self._load_annotations(label_path)

        # Augmentation is applied only when requested and there are boxes.
        if self.apply_augment and boxes:
            augmented = self.augment_transform(image=image, bboxes=boxes, class_labels=class_labels)
            image = augmented['image']
            boxes = augmented['bboxes']
            class_labels = augmented['class_labels']

        return {
            'image': image,
            'boxes': boxes,
            'labels': class_labels,
            'image_path': str(img_path)
        }

    def _load_annotations(self, label_path: Path) -> Tuple[List, List]:
        """
        Read YOLO annotations from a file.

        Lines with fewer than five whitespace-separated fields are skipped;
        fields after the fifth are ignored. A missing file yields two empty
        lists.

        Returns:
            (boxes, class_labels) where each box is
            [x_center, y_center, width, height].
        """
        boxes = []
        class_labels = []

        if label_path.exists():
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        class_id = int(parts[0])
                        x_center, y_center, width, height = map(float, parts[1:5])
                        boxes.append([x_center, y_center, width, height])
                        class_labels.append(class_id)

        return boxes, class_labels

class ModelTrainer:
    """
    Trains the YOLO model with transfer learning.
    """

    def __init__(self, data_yaml_path: str):
        """
        Args:
            data_yaml_path: Path to the YAML file with the dataset configuration
        """
        self.data_yaml = data_yaml_path
        self.model = None

    def _build_data_config(self, train_dir: str, val_dir: str) -> Dict:
        """Build the dataset configuration dict written by create_data_yaml."""
        names = ['missing_weld', 'weld_error', 'weld_ok', 'valid_name',
                 'des_name', 'mat_cod', 'part_cod']
        return {
            # Absolute paths: relative train/val entries are not resolved
            # against the current working directory by the trainer, so a
            # relative path here can point at a non-existent location.
            'train': os.path.abspath(train_dir),
            'val': os.path.abspath(val_dir),
            'nc': len(names),  # Number of classes
            'names': names
        }

    def create_data_yaml(self, train_dir: str, val_dir: str, output_path: str):
        """
        Create the YAML configuration file for YOLO.

        Args:
            train_dir: Training images directory
            val_dir: Validation images directory
            output_path: Output path of the YAML file
        """
        data_config = self._build_data_config(train_dir, val_dir)

        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(data_config, f, default_flow_style=False)

        logger.info(f"File YAML creato: {output_path}")

    def train_model(self, epochs: int = 100, batch_size: int = 16, img_size: int = 640):
        """
        Run training with transfer learning.

        Args:
            epochs: Number of epochs
            batch_size: Batch size
            img_size: Image size

        Returns:
            Whatever `YOLO.train` returns for the run.
        """
        # Start from the pre-trained YOLOv8n checkpoint.
        self.model = YOLO('yolov8n.pt')

        logger.info("Inizio training con transfer learning...")
        results = self.model.train(
            data=self.data_yaml,
            epochs=epochs,
            batch=batch_size,
            imgsz=img_size,
            device=0 if torch.cuda.is_available() else 'cpu',
            patience=20,  # Early stopping
            save_period=10,  # Save a checkpoint every 10 epochs
            project='technical_drawings',
            name='yolo_custom',
            exist_ok=True,
            # Transfer-learning specific parameters
            lr0=0.001,  # Reduced learning rate for fine-tuning
            warmup_epochs=3,
            box=7.5,  # Box loss gain
            cls=0.5,   # Classification loss gain
            dfl=1.5    # DFL loss gain
        )

        logger.info("Training completato!")
        return results

class ErrorDetector:
    """
    Detects the specific errors in technical drawings.

    Runs the analyzer's detector on an image, groups detections by class name
    and derives weld, title-block (cartouche) and BOM checks, using OCR on the
    detected text regions.
    """

    def __init__(self, analyzer: "TechnicalDrawingAnalyzer"):
        """
        Args:
            analyzer: Provides `model`, `ocr_reader`, `classes`,
                `confidence_threshold` and `iou_threshold`.
        """
        self.analyzer = analyzer

    def analyze_drawing(self, image_path: str) -> Dict:
        """
        Analyse a technical drawing and collect every detected error.

        Args:
            image_path: Path of the image to analyse

        Returns:
            Dict with keys 'weld_errors', 'cartouche_errors', 'bom_errors'
            and 'summary'.

        Raises:
            ValueError: If the image cannot be loaded.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Impossibile caricare l'immagine: {image_path}")

        # Object detection; both analyzer thresholds are forwarded (the IoU
        # threshold was previously defined on the analyzer but never used).
        results = self.analyzer.model(
            image,
            conf=self.analyzer.confidence_threshold,
            iou=self.analyzer.iou_threshold,
        )

        detections = self._parse_detections(results[0])

        errors = {
            'weld_errors': self._analyze_weld_errors(detections, image),
            'cartouche_errors': self._analyze_cartouche_errors(detections, image),
            'bom_errors': self._analyze_bom_errors(detections, image),
            'summary': {}
        }

        errors['summary'] = self._create_error_summary(errors)

        return errors

    def _parse_detections(self, results) -> Dict:
        """
        Group detection results by class name.

        Every known class name is present in the result (possibly with an
        empty list). Class ids missing from the analyzer's class map are
        collected under 'unknown' instead of raising KeyError.
        """
        detections = {class_name: [] for class_name in self.analyzer.classes.values()}

        if results.boxes is not None:
            for box in results.boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                confidence = box.conf[0].cpu().numpy()
                class_id = int(box.cls[0].cpu().numpy())

                class_name = self.analyzer.classes.get(class_id, 'unknown')

                # setdefault creates the 'unknown' bucket on first use.
                detections.setdefault(class_name, []).append({
                    'bbox': [int(x1), int(y1), int(x2), int(y2)],
                    'confidence': float(confidence),
                    'class_id': class_id
                })

        return detections

    def _analyze_weld_errors(self, detections: Dict, image: np.ndarray) -> Dict:
        """Count weld detections per category and list the erroneous ones."""
        weld_errors = {
            'missing_welds': len(detections['missing_weld']),
            'incorrect_welds': len(detections['weld_error']),
            'correct_welds': len(detections['weld_ok']),
            'details': []
        }

        # Missing welds first, then incorrect welds.
        for error_type in ('missing_weld', 'weld_error'):
            for detection in detections[error_type]:
                weld_errors['details'].append({
                    'type': error_type,
                    'location': detection['bbox'],
                    'confidence': detection['confidence']
                })

        return weld_errors

    def _analyze_cartouche_errors(self, detections: Dict, image: np.ndarray) -> Dict:
        """
        Check the title block: validator present, and validator != designer.

        Names are compared case-insensitively after stripping whitespace; any
        matching pair between the two OCR text lists counts as "same".
        """
        cartouche_errors = {
            'validator_designer_same': False,
            'missing_validator': False,
            'details': []
        }

        validator_text = self._extract_text_from_detections(
            detections['valid_name'], image
        )
        designer_text = self._extract_text_from_detections(
            detections['des_name'], image
        )

        # No OCR text for the validator: field empty or not detected.
        if not validator_text:
            cartouche_errors['missing_validator'] = True
            cartouche_errors['details'].append({
                'type': 'missing_validator',
                'description': 'Campo validatore vuoto o non rilevato'
            })

        if validator_text and designer_text:
            if any(v_name.strip().lower() == d_name.strip().lower()
                   for v_name in validator_text for d_name in designer_text):
                cartouche_errors['validator_designer_same'] = True
                cartouche_errors['details'].append({
                    'type': 'validator_designer_same',
                    'validator': validator_text,
                    'designer': designer_text
                })

        return cartouche_errors

    def _analyze_bom_errors(self, detections: Dict, image: np.ndarray) -> Dict:
        """
        Check the BOM codes (simplified logic).

        Every material code is compared with every part code; a pair whose
        first three characters differ is reported as a mismatch. Codes shorter
        than three characters are ignored.
        """
        bom_errors = {
            'material_part_mismatch': False,
            'details': []
        }

        material_codes = self._extract_text_from_detections(
            detections['mat_cod'], image
        )
        part_codes = self._extract_text_from_detections(
            detections['part_cod'], image
        )

        if material_codes and part_codes:
            for mat_code in material_codes:
                for part_code in part_codes:
                    if len(mat_code) >= 3 and len(part_code) >= 3:
                        if mat_code[:3] != part_code[:3]:
                            bom_errors['material_part_mismatch'] = True
                            bom_errors['details'].append({
                                'type': 'code_mismatch',
                                'material_code': mat_code,
                                'part_code': part_code
                            })

        return bom_errors

    def _extract_text_from_detections(self, detections: List, image: np.ndarray) -> List[str]:
        """
        Run OCR on each detected region and return the recognised strings.

        Boxes are clamped to the image bounds before cropping, and boxes with
        no area after clamping are skipped. Only OCR results with confidence
        above 0.5 are kept. OCR failures are logged and skipped (best effort).
        """
        texts = []
        img_height, img_width = image.shape[:2]

        for detection in detections:
            x1, y1, x2, y2 = detection['bbox']

            # Clamp: a negative start index would otherwise wrap around in
            # NumPy slicing and produce an empty or wrong crop.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img_width, x2), min(img_height, y2)
            if x2 <= x1 or y2 <= y1:
                continue

            roi = image[y1:y2, x1:x2]

            if roi.size > 0:
                try:
                    ocr_results = self.analyzer.ocr_reader.readtext(roi)
                    for (_, text, confidence) in ocr_results:
                        if confidence > 0.5:  # OCR confidence threshold
                            texts.append(text)
                except Exception as e:
                    logger.warning(f"Errore OCR: {e}")

        return texts

    def _create_error_summary(self, errors: Dict) -> Dict:
        """
        Summarise the errors found.

        Returns:
            Dict with 'total_errors', 'error_types' and 'severity'
            ('none' for 0, 'low' up to 2, 'medium' up to 5, otherwise 'high').
        """
        summary = {
            'total_errors': 0,
            'error_types': [],
            'severity': 'low'
        }

        # Weld errors: missing plus incorrect welds.
        weld_error_count = (errors['weld_errors']['missing_welds'] +
                           errors['weld_errors']['incorrect_welds'])
        summary['total_errors'] += weld_error_count

        if weld_error_count > 0:
            summary['error_types'].append('weld_errors')

        # Title-block errors: each boolean flag counts as one error.
        cartouche_error_count = sum([
            errors['cartouche_errors']['validator_designer_same'],
            errors['cartouche_errors']['missing_validator']
        ])
        summary['total_errors'] += cartouche_error_count

        if cartouche_error_count > 0:
            summary['error_types'].append('cartouche_errors')

        # BOM errors: counted once regardless of how many pairs mismatch.
        if errors['bom_errors']['material_part_mismatch']:
            summary['total_errors'] += 1
            summary['error_types'].append('bom_errors')

        if summary['total_errors'] == 0:
            summary['severity'] = 'none'
        elif summary['total_errors'] <= 2:
            summary['severity'] = 'low'
        elif summary['total_errors'] <= 5:
            summary['severity'] = 'medium'
        else:
            summary['severity'] = 'high'

        return summary

class VisualizationTool:
    """
    Helpers for visualising results and annotations.
    """

    @staticmethod
    def visualize_detections(image_path: str, detections: Dict, output_path: Optional[str] = None):
        """
        Draw the detections on top of the image.

        Args:
            image_path: Path of the original image
            detections: Mapping of class name -> list of detections, each with
                a 'bbox' ([x1, y1, x2, y2]) and a 'confidence'
            output_path: Output file path (if None, the figure is shown)

        Raises:
            ValueError: If the image cannot be loaded.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None on failure; fail with a clear message
            # instead of an opaque error inside cvtColor.
            raise ValueError(f"Impossibile caricare l'immagine: {image_path}")
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # RGB colour per class; classes not listed here are drawn in grey.
        colors = {
            'missing_weld': (255, 0, 0),    # Red
            'weld_error': (255, 165, 0),    # Orange
            'weld_ok': (0, 255, 0),         # Green
            'valid_name': (0, 0, 255),      # Blue
            'des_name': (255, 0, 255),      # Magenta
            'mat_cod': (0, 255, 255),       # Cyan
            'part_cod': (255, 255, 0)       # Yellow
        }

        fig, ax = plt.subplots(figsize=(15, 10))
        ax.imshow(image_rgb)

        for class_name, class_detections in detections.items():
            # Matplotlib expects colour components in the 0-1 range.
            color = np.array(colors.get(class_name, (128, 128, 128))) / 255

            for detection in class_detections:
                x1, y1, x2, y2 = detection['bbox']

                rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                                     fill=False, color=color,
                                     linewidth=2)
                ax.add_patch(rect)

                # Label placed just above the box's top-left corner.
                ax.text(x1, y1-5, f"{class_name}: {detection['confidence']:.2f}",
                        color=color, fontsize=8,
                        bbox=dict(boxstyle="round,pad=0.3",
                                  facecolor='white', alpha=0.7))

        ax.axis('off')
        ax.set_title('Rilevazioni nel Disegno Tecnico')

        if output_path:
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            logger.info(f"Visualizzazione salvata: {output_path}")
            # Release the figure; otherwise every call leaves one open.
            plt.close(fig)
        else:
            plt.show()

# Complete usage example
def main():
    """
    Example of how to wire up the whole system.

    Creates the dataset directory layout, writes the YOLO dataset YAML, and
    builds the analyzer and error detector. Training and the analysis of a
    sample drawing are left as commented-out examples.

    Returns:
        Tuple (analyzer, error_detector).
    """

    # 1. Directory layout: <root>/{images,labels}/{train,val}
    dataset_root = Path("technical_drawings_dataset")
    images_train = dataset_root / "images" / "train"
    images_val = dataset_root / "images" / "val"
    labels_train = dataset_root / "labels" / "train"
    labels_val = dataset_root / "labels" / "val"

    # Create the directories when they do not exist yet.
    for folder in (images_train, images_val, labels_train, labels_val):
        folder.mkdir(parents=True, exist_ok=True)

    # 2. Training preparation: the trainer reads the same YAML it writes.
    trainer = ModelTrainer("data.yaml")
    trainer.create_data_yaml(
        train_dir=str(images_train),
        val_dir=str(images_val),
        output_path="data.yaml"
    )

    # 3. Model training (enable once data is available)
    # results = trainer.train_model(epochs=100, batch_size=8)

    # 4. Load the model for inference
    analyzer = TechnicalDrawingAnalyzer()
    error_detector = ErrorDetector(analyzer)

    # 5. Analyse a drawing
    # image_path = "example_drawing.jpg"
    # if os.path.exists(image_path):
    #     errors = error_detector.analyze_drawing(image_path)
    #
    #     print("=== RISULTATI ANALISI ===")
    #     print(f"Errori totali: {errors['summary']['total_errors']}")
    #     print(f"Severità: {errors['summary']['severity']}")
    #     print(f"Tipi di errore: {errors['summary']['error_types']}")
    #
    #     # Visualise the results
    #     detections = error_detector._parse_detections(
    #         analyzer.model(cv2.imread(image_path), conf=0.25)[0]
    #     )
    #     VisualizationTool.visualize_detections(
    #         image_path, detections, "output_visualization.png"
    #     )

    logger.info("Sistema configurato e pronto per l'uso!")

    return analyzer, error_detector

# Entry point: run the setup and keep the analyzer and error detector as
# module-level names.
if __name__ == "__main__":
    analyzer, error_detector = main()

Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt to 'yolov8n.pt'...


100%|██████████| 6.25M/6.25M [00:00<00:00, 102MB/s]


Progress: |██████████████████████████████████████████████████| 100.0% Complete



Progress: |--------------------------------------------------| 0.0% CompleteProgress: |--------------------------------------------------| 0.1% CompleteProgress: |--------------------------------------------------| 0.1% CompleteProgress: |--------------------------------------------------| 0.2% CompleteProgress: |--------------------------------------------------| 0.2% CompleteProgress: |--------------------------------------------------| 0.3% CompleteProgress: |--------------------------------------------------| 0.3% CompleteProgress: |--------------------------------------------------| 0.4% CompleteProgress: |--------------------------------------------------| 0.5% CompleteProgress: |--------------------------------------------------| 0.5% CompleteProgress: |--------------------------------------------------| 0.6% CompleteProgress: |--------------------------------------------------| 0.6% CompleteProgress: |--------------------------------------------------| 0.7% Complet