In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import GradScaler, autocast

from pathlib import Path
from PIL import Image
import numpy as np
import json
from tqdm import tqdm
import copy
from types import SimpleNamespace

# pycocotools가 필요합니다. pip install pycocotools
from pycocotools import mask as mask_utils

# --- 프로젝트 경로 추가 및 모듈 임포트 ---
import sys
sys.path.append('.')

from models.blade_model_v2 import BladeModelV2
from utils.criterion import SetCriterion
from utils.hungarian_matcher import HungarianMatcher

# Record the torch build and CUDA availability at the top of the run log.
print("PyTorch Version: " + str(torch.__version__))
print("CUDA Available: " + str(torch.cuda.is_available()))

# --- Final configuration ---
class Config:
    """Central configuration for blade-damage training.

    Every setting is a class attribute, so it can be read from the class
    itself or from an instance (``config = Config()``).
    """

    # Dataset root. Set the BLADE_DATA_ROOT environment variable to run on a
    # machine without this exact directory layout; an unset or empty variable
    # falls back to the original path.
    DATA_ROOT = Path(
        os.environ.get('BLADE_DATA_ROOT')
        or 'C:/EngineBladeAI/EngineInspectionAI_MS/data/final_dataset_augmented'
    )
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- Training hyperparameters ---
    BATCH_SIZE = 4
    EPOCHS = 50
    LR = 2e-5  # lowered from 1e-4, which proved too high
    WEIGHT_DECAY = 1e-4
    NUM_WORKERS = 0
    LR_DROP_STEP = 20  # step_size (in scheduler steps) for the StepLR scheduler
    GRADIENT_CLIP_VAL = 1.0  # max_norm passed to clip_grad_norm_

    # Model hyperparameters, grouped in SimpleNamespace objects for attribute access.
    MODEL = SimpleNamespace(
        BACKBONE=SimpleNamespace(NAME='ConvNeXt-Tiny'),
        FPN=SimpleNamespace(OUT_CHANNELS=256),
        HEAD_B=SimpleNamespace(
            FEAT_CHANNELS=256,
            OUT_CHANNELS=256,
            NUM_CLASSES=3,  # number of damage classes
            QUERIES_PER_CLASS=100,
            DEC_LAYERS=6
        )
    )
    # Loss settings forwarded to SetCriterion.
    LOSS = SimpleNamespace(
        CLASS_WEIGHTS=[1.5, 1.0, 1.3],  # Crack, Nick, Tear
        EOS_COEF=0.1  # passed to SetCriterion as eos_coef
    )

# Instantiate the configuration and echo the key settings for the run log.
config = Config()
print("\n--- Configuration Initialized ---")
print(
    f"Data Path: {config.DATA_ROOT}",
    f"Device: {config.DEVICE}",
    f"Initial Learning Rate: {config.LR}",
    sep="\n",
)

In [None]:
class FinalBladeDataset(Dataset):
    """Dataset for the merged blade-inspection dataset (``final_dataset``).

    Directory layout read by this class::

        <root>/<split>/images/<file_name>
        <root>/<split>/annotations.json   # COCO-style 'images' / 'annotations'

    Category convention (as implemented in ``__getitem__``):
      * ``category_id == 1`` -> blade; all blade masks of an image are OR-ed
        into a single binary ``blade_mask``.
      * ``category_id`` in ``2 .. num_damage_classes + 1`` -> one damage
        instance with class index ``category_id - 2``.
      * Any other ``category_id`` is skipped with a warning.

    ``__getitem__`` returns ``(image, target)`` where ``target`` contains:
      * ``blade_mask``  -- (H, W) int64 tensor with values {0, 1}
      * ``labels``      -- (N,) int64 damage class indices
      * ``multilabel``  -- (num_damage_classes,) float32 presence vector
      * ``masks``       -- (N, H, W) float32 instance masks with values {0, 1}
    If the transform returns a (C, H, W) tensor, the masks are resized
    (nearest neighbour) to that H, W. This keeps them pixel-aligned with the
    image and lets samples of different original sizes share a batch.
    Otherwise the masks keep the original image size.
    """

    BLADE_CATEGORY_ID = 1
    FIRST_DAMAGE_CATEGORY_ID = 2

    def __init__(self, root, split='train', transform=None, num_damage_classes=3):
        """
        Args:
            root: dataset root directory.
            split: sub-directory to load, e.g. 'train' or 'valid'.
            transform: callable applied to the PIL RGB image. Defaults to a
                640x640 resize, ToTensor and ImageNet mean/std normalisation.
            num_damage_classes: number of damage classes (length of the
                multilabel vector; valid damage ids are 2..num_damage_classes+1).
        """
        self.root = Path(root)
        self.split = split
        self.num_damage_classes = num_damage_classes
        self.images_dir = self.root / self.split / 'images'

        json_path = self.root / self.split / 'annotations.json'
        # Explicit UTF-8: the platform default encoding (e.g. cp949 on Korean
        # Windows) can fail on non-ASCII file or category names.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        self.images_info = self.data['images']
        # image_id -> list of its annotations
        self.annotations_map = {}
        for ann in self.data['annotations']:
            self.annotations_map.setdefault(ann['image_id'], []).append(ann)

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((640, 640)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.images_info)

    @staticmethod
    def _decode_segmentation(seg, height, width):
        """Decode a COCO segmentation into an (height, width) uint8 mask, or None if unusable.

        Polygon lists use every polygon with at least 3 points (6 coordinates),
        merged by union. RLE dicts may have compressed or uncompressed counts.
        Raises ValueError if the decoded mask does not match the image size.
        """
        if isinstance(seg, dict):
            if 'counts' not in seg or 'size' not in seg:
                return None
            # Uncompressed RLE (counts given as a list) must be converted first.
            rle = mask_utils.frPyObjects(seg, height, width) if isinstance(seg['counts'], list) else seg
        elif isinstance(seg, list):
            polygons = [p for p in seg if isinstance(p, (list, tuple)) and len(p) >= 6]
            if not polygons:
                return None
            rle = mask_utils.merge(mask_utils.frPyObjects(polygons, height, width))
        else:
            return None

        mask = mask_utils.decode(rle)
        if mask.ndim == 3:
            mask = np.max(mask, axis=2)
        if mask.shape != (height, width):
            raise ValueError(f"mask size {mask.shape} != image size {(height, width)}")
        return np.asarray(mask, dtype=np.uint8)

    @staticmethod
    def _resize_masks(blade_mask, damage_masks, size):
        """Nearest-neighbour resize of the (H, W) blade mask and (N, H, W) damage masks to ``size``."""
        blade = F.interpolate(blade_mask[None, None].float(), size=size, mode='nearest')[0, 0].long()
        if damage_masks.shape[0] == 0:
            # interpolate cannot handle a zero-channel input.
            damage = damage_masks.new_zeros((0, *size))
        else:
            damage = F.interpolate(damage_masks[None], size=size, mode='nearest')[0]
        return blade, damage

    def __getitem__(self, idx):
        img_info = self.images_info[idx]
        img_id = img_info['id']
        img_path = self.images_dir / img_info['file_name']
        # The context manager closes the file; convert() returns an
        # independent, fully loaded image.
        with Image.open(img_path) as pil_image:
            image = pil_image.convert('RGB')

        original_w, original_h = image.size

        blade_mask = np.zeros((original_h, original_w), dtype=np.uint8)
        damage_masks_np = []
        damage_labels = []
        multilabel_vector = torch.zeros(self.num_damage_classes, dtype=torch.float32)

        for ann in self.annotations_map.get(img_id, []):
            try:
                mask = self._decode_segmentation(ann.get('segmentation'), original_h, original_w)
            except Exception as e:
                print(f"Warning: Failed to decode segmentation for ann_id {ann.get('id')}. Error: {e}")
                continue
            if mask is None:
                # Missing or degenerate segmentation (e.g. every polygon has < 3 points).
                continue

            cat_id = ann['category_id']
            if cat_id == self.BLADE_CATEGORY_ID:
                blade_mask = np.maximum(blade_mask, mask)
                continue

            class_idx = cat_id - self.FIRST_DAMAGE_CATEGORY_ID
            if not 0 <= class_idx < self.num_damage_classes:
                # A negative index would wrap around in multilabel_vector and
                # silently mislabel the instance, so unknown ids are skipped.
                print(f"Warning: Skipping ann_id {ann.get('id')} with unexpected category_id {cat_id}.")
                continue

            damage_masks_np.append(mask)
            damage_labels.append(class_idx)
            multilabel_vector[class_idx] = 1.0

        image = self.transform(image)

        blade_mask_t = torch.from_numpy(blade_mask).long()
        if damage_masks_np:
            damage_masks_t = torch.from_numpy(np.stack(damage_masks_np)).float()
        else:
            damage_masks_t = torch.zeros((0, original_h, original_w), dtype=torch.float32)

        # Put the masks on the same pixel grid as the transformed image so that
        # samples with different original sizes can be stacked in one batch.
        if isinstance(image, torch.Tensor) and image.ndim == 3:
            out_size = tuple(image.shape[-2:])
            if out_size != (original_h, original_w):
                blade_mask_t, damage_masks_t = self._resize_masks(blade_mask_t, damage_masks_t, out_size)

        target = {
            'blade_mask': blade_mask_t,
            'labels': torch.tensor(damage_labels, dtype=torch.int64),
            'multilabel': multilabel_vector,
            'masks': damage_masks_t,
        }
        return image, target

def collate_fn(batch):
    """Merge a list of ``(image, target)`` samples into one batch.

    Images are stacked along a new leading batch dimension. Targets stay a
    plain list because each image can carry a different number of instances.
    """
    image_list = []
    target_list = []
    for sample in batch:
        image_list.append(sample[0])
        target_list.append(sample[1])
    return torch.stack(image_list, dim=0), target_list


print("✅ Dataset class and collate_fn are defined.")

In [None]:
# Build the train/validation datasets and their loaders. Both loaders share
# batch size, worker count and collate_fn; only the training split is shuffled.
print("--- Creating DataLoaders ---")
train_dataset = FinalBladeDataset(split='train', root=config.DATA_ROOT)
val_dataset = FinalBladeDataset(split='valid', root=config.DATA_ROOT)

train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    collate_fn=collate_fn,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    collate_fn=collate_fn,
    shuffle=False,
)
print("✅ DataLoaders created!")
print(f"   Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

In [None]:
# ===== Cell 4: initialise model, loss function and optimizer =====

print("--- Initializing Model, Criterion, Optimizer ---")
model = BladeModelV2(config).to(config.DEVICE)

# Hungarian matching costs and per-term loss weights. The damage terms are
# weighted heavily to emphasise damage detection. `weight_dict` is also read
# as a global by train_epoch and validate.
matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0)
weight_dict = {'loss_ce': 5.0, 'loss_mask': 10.0, 'loss_dice': 10.0}

criterion = SetCriterion(
    num_classes=config.MODEL.HEAD_B.NUM_CLASSES,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=config.LOSS.EOS_COEF,
    losses=['labels', 'masks'],
    class_weights=config.LOSS.CLASS_WEIGHTS,
).to(config.DEVICE)

optimizer = AdamW(model.parameters(), weight_decay=config.WEIGHT_DECAY, lr=config.LR)
scaler = GradScaler()

# Step decay: multiply the LR by gamma (0.1, the StepLR default) every
# LR_DROP_STEP scheduler steps. It only takes effect if lr_scheduler.step()
# is called.
lr_scheduler = StepLR(optimizer, step_size=config.LR_DROP_STEP, gamma=0.1)

print("✅ Initialization complete.")

In [None]:
# torchmetrics에서 필요한 모든 평가 지표 클래스를 임포트합니다.
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.classification import MulticlassJaccardIndex # Blade IoU용


# One training epoch: mixed-precision forward/backward with gradient clipping.
def train_epoch(model, criterion, dataloader, optimizer, device, epoch, blade_loss_weight=1.0):
    """Train ``model`` for one epoch and return the mean total loss per batch.

    The per-batch loss is the ``weight_dict``-weighted sum of the criterion's
    damage losses, plus ``blade_loss_weight`` times a BCE-with-logits loss.
    That BCE term compares ``outputs['blade_logits']`` (bilinearly resized to
    the ground-truth size) with the stacked ``blade_mask`` targets.

    Relies on the notebook globals ``config`` (EPOCHS for the progress bar,
    GRADIENT_CLIP_VAL), ``weight_dict`` and the AMP ``scaler``.

    Args:
        model: network whose output dict includes ``'blade_logits'``.
        criterion: callable ``criterion(outputs, targets)`` returning a dict of losses.
        dataloader: yields ``(images, targets)`` batches (see ``collate_fn``).
        optimizer: optimizer stepped through ``scaler``.
        device: device each batch is moved to.
        epoch: zero-based epoch index (used only for the progress-bar label).
        blade_loss_weight: weight of the blade BCE term (default 1.0).

    Returns:
        Mean total loss over the batches, or ``nan`` if the loader produced none.
    """
    model.train()
    criterion.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.EPOCHS} [Train]")
    for images, targets in pbar:
        images = images.to(device)
        targets_gpu = [{k: v.to(device) for k, v in t.items()} for t in targets]
        with autocast():
            outputs = model(images)
            loss_dict = criterion(outputs, targets_gpu)
            # Only loss terms named in weight_dict contribute. The tensor start
            # value keeps damage_loss a tensor even when no key matches.
            # NOTE(review): auxiliary per-decoder-layer losses returned under
            # other keys (if any) are ignored here — confirm this is intended.
            damage_loss = sum(
                (loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict),
                torch.zeros((), device=images.device),
            )
            blade_logits = outputs['blade_logits']
            gt_blade_masks = torch.stack([t['blade_mask'] for t in targets_gpu]).unsqueeze(1).float()
            blade_logits_resized = F.interpolate(blade_logits, size=gt_blade_masks.shape[-2:], mode="bilinear", align_corners=False)
            loss_blade = F.binary_cross_entropy_with_logits(blade_logits_resized, gt_blade_masks)
            weighted_loss = damage_loss + loss_blade * blade_loss_weight

        optimizer.zero_grad()
        scaler.scale(weighted_loss).backward()

        # Gradients must be unscaled before clipping so the threshold applies
        # to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRADIENT_CLIP_VAL)

        scaler.step(optimizer)
        scaler.update()

        loss_value = weighted_loss.item()
        total_loss += loss_value
        num_batches += 1
        pbar.set_postfix({'loss': f'{loss_value:.4f}', 'L_dmg': f'{damage_loss.item():.2f}', 'L_bld': f'{loss_blade.item():.2f}'})

    return total_loss / num_batches if num_batches else float('nan')


def validate(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return the loss and metrics.

    Returned keys:
      * ``loss``: mean per-batch total loss (same formula as training, with
        blade weight 1.0); ``nan`` if the loader is empty.
      * ``blade_iou``: IoU of the blade (foreground) class only, from
        ``sigmoid(blade_logits) > 0.5``.
      * ``mAP``, ``mAP_50``, ``mAR_100``: torchmetrics segm mAP@[.5:.95],
        mAP@0.5 and mAR@100 over the damage instances.
      * ``precision`` / ``recall`` / ``f1_score``: proxies kept for the
        existing report. ``precision`` is mAP@0.5 and ``recall`` is mAR@100;
        neither is a thresholded precision/recall.

    Relies on the notebook globals ``config`` and ``weight_dict``.
    """
    model.eval()
    criterion.eval()

    num_damage_classes = config.MODEL.HEAD_B.NUM_CLASSES
    # Per-class IoU so that the blade class (index 1) can be read on its own.
    # The default macro average would also fold in the background IoU.
    blade_iou_metric = MulticlassJaccardIndex(num_classes=2, average='none').to(device)
    map_metric = MeanAveragePrecision(iou_type="segm")

    val_losses = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="[Valid]")
        for images, targets in pbar:
            images = images.to(device)
            targets_gpu = [{k: v.to(device) for k, v in t.items()} for t in targets]

            with autocast():
                outputs = model(images)
                loss_dict = criterion(outputs, targets_gpu)
                damage_loss = sum(
                    (loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict),
                    torch.zeros((), device=images.device),
                )
                blade_logits = outputs['blade_logits']
                gt_blade_masks = torch.stack([t['blade_mask'] for t in targets_gpu]).unsqueeze(1).float()
                blade_logits_resized = F.interpolate(blade_logits, size=gt_blade_masks.shape[-2:], mode="bilinear", align_corners=False)
                loss_blade = F.binary_cross_entropy_with_logits(blade_logits_resized, gt_blade_masks)
                weighted_loss = damage_loss + loss_blade * 1.0
            val_losses.append(weighted_loss.item())

            # Blade IoU on the thresholded blade logits.
            pred_blade_masks = (torch.sigmoid(blade_logits_resized) > 0.5).int().squeeze(1)
            blade_iou_metric.update(pred_blade_masks, gt_blade_masks.squeeze(1).int())

            # Damage instance mAP.
            preds_for_map = _segm_predictions_for_map(
                outputs['pred_logits'], outputs['pred_masks'], targets, num_damage_classes
            )
            targets_for_map = [dict(masks=(t['masks'] > 0.5), labels=t['labels']) for t in targets]
            map_metric.update(preds_for_map, targets_for_map)

    blade_iou = blade_iou_metric.compute()[1].item()
    map_results = map_metric.compute()

    map_50 = map_results['map_50'].item()
    mar_100 = map_results['mar_100'].item()
    f1 = 2 * (map_50 * mar_100) / (map_50 + mar_100 + 1e-6)

    return {
        'loss': float(np.mean(val_losses)) if val_losses else float('nan'),
        'blade_iou': blade_iou,
        'mAP': map_results['map'].item(),
        'mAP_50': map_50,
        'mAR_100': mar_100,
        'precision': map_50,  # proxy: mAP@0.5, not thresholded precision
        'recall': mar_100,    # proxy: mAR@100
        'f1_score': f1,
    }


def _segm_predictions_for_map(pred_logits, pred_masks, targets, num_damage_classes):
    """Convert per-query class/mask logits into torchmetrics segm predictions (on CPU).

    Assumes ``pred_logits`` is (B, Q, C) and ``pred_masks`` is (B, Q, h, w)
    logits — TODO confirm against BladeModelV2.
    """
    # A trailing extra column is treated as the "no object" class (the
    # criterion is built with eos_coef). It must not win the argmax, or the
    # query loses its real damage score and gets an out-of-range label.
    has_no_object = pred_logits.shape[-1] == num_damage_classes + 1
    preds = []
    for i, target in enumerate(targets):
        # .float(): logits may be fp16 under autocast; upcast before softmax.
        probs = F.softmax(pred_logits[i].float(), dim=-1)
        if has_no_object:
            probs = probs[..., :-1]
        scores, labels = probs.max(-1)

        mask_logits = pred_masks[i].float()
        gt_size = tuple(target['masks'].shape[-2:])
        if tuple(mask_logits.shape[-2:]) != gt_size:
            # Mask IoU requires predictions and ground truth on the same pixel grid.
            mask_logits = F.interpolate(
                mask_logits.unsqueeze(0), size=gt_size, mode="bilinear", align_corners=False
            ).squeeze(0)

        preds.append(dict(
            masks=(torch.sigmoid(mask_logits) > 0.5).cpu(),
            scores=scores.cpu(),
            labels=labels.cpu(),
        ))
    return preds

In [None]:
# ===== Main training loop =====

print("\n--- 🚀 Starting Final Training 🚀 ---")
best_val_loss = float('inf')

for epoch in range(config.EPOCHS):
    train_loss = train_epoch(model, criterion, train_loader, optimizer, config.DEVICE, epoch)
    val_metrics = validate(model, criterion, val_loader, config.DEVICE)
    # StepLR is used per epoch here: advance it once after the epoch's
    # optimizer updates. Without this call the LR never decays and
    # config.LR_DROP_STEP has no effect.
    lr_scheduler.step()

    val_loss = val_metrics['loss']

    print(f"\nEpoch {epoch+1}/{config.EPOCHS} -> Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"  [Blade] IoU: {val_metrics['blade_iou']:.4f}")
    print(f"  [Damage] mAP: {val_metrics['mAP']:.4f} | Precision: {val_metrics['precision']:.4f} | "
          f"Recall: {val_metrics['recall']:.4f} | F1: {val_metrics['f1_score']:.4f}")
    print(f"  [LR] next epoch: {lr_scheduler.get_last_lr()[0]:.2e}")

    # Checkpoint whenever the validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'blade_damage_best_model.pth')
        print(f"✨ New best model saved with validation loss: {best_val_loss:.4f}")

print("\n--- 🎉 Training Complete ---")