In [3]:
# ===============================================================================
# 🚀 완전한 Recycle Segmentation 파이프라인
# ===============================================================================

# PyTorch 관련 임포트
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.amp import autocast, GradScaler
from torch.cuda.amp import autocast, GradScaler

# Transformers 관련 임포트
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

# 이미지 처리 관련 임포트
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from PIL import Image as PILImage

# 시스템 및 유틸리티 임포트
import os
import re
import json
import random
import shutil
import zipfile
import math
from glob import glob

# 진행상황 및 데이터 처리 관련 임포트
from tqdm import tqdm
from collections import Counter
import albumentations as A
from itertools import cycle

# ===============================================================================
# 📋 STEP 0: 데이터 정리 및 기본 설정
# ===============================================================================

print("🗂️ 데이터 정리 시작...")

# 데이터 정리: train 폴더에서 datasets 폴더로 이동
base_dir = "C:/Users/USER/Desktop/Reco_Notebook"
os.makedirs(os.path.join(base_dir, "datasets"), exist_ok=True)
train_dir = os.path.join(base_dir, "train")
train_imgs = glob(os.path.join(train_dir, "*.jpg"))
train_masks = glob(os.path.join(train_dir, "*.png"))

print(f"📁 train 폴더에서 찾은 이미지: {len(train_imgs)}개")
print(f"📁 train 폴더에서 찾은 마스크: {len(train_masks)}개")

# 이미지 복사
for train_img in train_imgs:
    image_dir = os.path.join(base_dir, "datasets", "images", os.path.basename(train_img))
    os.makedirs(os.path.dirname(image_dir), exist_ok=True)
    shutil.copy(train_img, image_dir)

# 마스크 복사
for train_mask in train_masks:
    mask_dir = os.path.join(base_dir, "datasets", "masks", os.path.basename(train_mask))
    os.makedirs(os.path.dirname(mask_dir), exist_ok=True)
    shutil.copy(train_mask, mask_dir)

print(f"✅ 데이터 정리 완료!")
print(f"   📁 이미지: {len(train_imgs)}개 → datasets/images/")
print(f"   📁 마스크: {len(train_masks)}개 → datasets/masks/")

# 클래스 정의: 7개 클래스 (background 포함)
class_names = [
    "background", "can", "glass",
    "paper", "plastic", "styrofoam", "vinyl"
]

# 클래스명 ↔ ID 매핑
label2id = {name: i for i, name in enumerate(class_names)}
id2label = {i: name for name, i in label2id.items()}
num_classes = len(class_names)

# 시각화용 색상 (배경은 투명 처리)
class_colors_bright = [
    None,             # background - 투명 (원본 이미지 배경 보임)
    (0, 255, 255),    # can - 밝은 청록색 (Cyan)
    (255, 255, 0),    # glass - 밝은 노란색
    (128, 255, 0),    # paper - 연두색
    (255, 0, 0),      # plastic - 밝은 빨간색
    (0, 128, 255),    # styrofoam - 밝은 파란색
    (255, 0, 128)     # vinyl - 밝은 분홍색
]

# 디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔧 사용 디바이스: {device}")

# 데이터 경로 설정
image_dir = os.path.join(base_dir, "datasets", "images")
mask_dir = os.path.join(base_dir, "datasets", "masks")

# ===============================================================================
# 🔍 STEP 1: 이미지-마스크 파일 매칭 및 데이터셋
# ===============================================================================

def preprocess_datasets():
    """datasets 폴더의 images(.jpg)와 masks(.png)를 매칭하고 클래스별 픽셀 수 계산"""
    print("\n" + "=" * 60)
    print("📊 STEP 1: datasets 폴더 데이터 매칭 시작")

    def get_base_name(filename):
        """파일명에서 숫자 기반 키 생성 (숫자만 추출)"""
        import re
        numbers = re.findall(r'\d+', filename)
        return '_'.join(numbers) if numbers else filename

    def find_matching_files():
        # 숫자 기반 매칭 방법 (930개 모두 성공!)
        image_list = [f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')]
        mask_list = [f for f in os.listdir(mask_dir) if f.lower().endswith('.png')]
        
        print(f"📁 JPG 이미지 파일 개수: {len(image_list)}개")
        print(f"📁 PNG 마스크 파일 개수: {len(mask_list)}개")
        
        def get_numeric_key(filename):
            """파일명에서 숫자만 추출해서 키 생성"""
            import re
            numbers = re.findall(r'\d+', filename)
            return '_'.join(numbers) if numbers else filename
        
        # 숫자 기반으로 매칭
        image_dict = {get_numeric_key(f): f for f in image_list}
        mask_dict = {get_numeric_key(f): f for f in mask_list}
        
        matched_pairs = []
        for base_key in image_dict:
            if base_key in mask_dict:
                matched_pairs.append({
                    'base_name': base_key, 
                    'image_file': image_dict[base_key], 
                    'mask_file': mask_dict[base_key]
                })
        
        # 매칭 결과 확인
        image_only = set(image_dict.keys()) - set(mask_dict.keys())
        mask_only = set(mask_dict.keys()) - set(image_dict.keys())
        
        if image_only:
            print(f"⚠️ 매칭되지 않은 이미지: {len(image_only)}개")
            if len(image_only) <= 3:
                for key in list(image_only):
                    print(f"   - {image_dict[key]}")
        
        if mask_only:
            print(f"⚠️ 매칭되지 않은 마스크: {len(mask_only)}개")
            if len(mask_only) <= 3:
                for key in list(mask_only):
                    print(f"   - {mask_dict[key]}")
        
        print(f"✅ 숫자 기반 매칭 성공: {len(matched_pairs)}개")
        return matched_pairs

    final_data_list = []
    pixel_counter = Counter()
    matched_pairs = find_matching_files()

    print(f"🔍 매칭된 이미지-마스크 쌍: {len(matched_pairs)}개")

    for pair in tqdm(matched_pairs, desc="데이터 매칭"):
        img_path = os.path.join(image_dir, pair['image_file'])
        mask_path = os.path.join(mask_dir, pair['mask_file'])

        if not os.path.exists(img_path) or not os.path.exists(mask_path):
            continue

        try:
            mask_array = np.array(Image.open(mask_path))
            unique_classes = np.unique(mask_array)
            for class_id in unique_classes:
                if class_id in label2id.values():
                    pixel_counter[class_id] += (mask_array == class_id).sum()

            final_data_list.append({
                "image": img_path,
                "label": mask_path,
                "class_ids": unique_classes.tolist(),
                "base_name": pair['base_name']
            })

        except Exception as e:
            print(f"⚠️ 마스크 로딩 오류: {pair['base_name']} - {str(e)}")
            continue

    print(f"✅ 최종 데이터 개수: {len(final_data_list)}개")
    return final_data_list, pixel_counter

def create_basic_transforms(input_size=512):
    return A.Compose([A.Resize(input_size, input_size)])

class ImprovedSegDataset(Dataset):
    def __init__(self, items, processor, input_size=512):
        self.items = items
        self.processor = processor
        self.input_size = input_size
        self.transform = create_basic_transforms(input_size)
        self.max_class_id = max(label2id.values())
        self.valid_items = items
        print(f"📊 전체 데이터셋 크기: {len(self.valid_items)}개")

    def __len__(self):
        return len(self.valid_items)

    def __getitem__(self, idx):
        if idx >= len(self.valid_items):
            idx = idx % len(self.valid_items)

        rec = self.valid_items[idx]

        try:
            image = cv2.imread(rec['image'])
            if image is None:
                raise ValueError(f"이미지 로드 실패: {rec['image']}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            mask = cv2.imread(rec['label'], cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise ValueError(f"마스크 로드 실패: {rec['label']}")

            if image.shape[:2] != mask.shape[:2]:
                h, w = min(image.shape[0], mask.shape[0]), min(image.shape[1], mask.shape[1])
                image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

            mask = np.clip(mask, 0, self.max_class_id)

            if self.transform:
                try:
                    transformed = self.transform(image=image, mask=mask)
                    image, mask = transformed['image'], transformed['mask']
                except Exception as e:
                    print(f"⚠️ Transform 오류: {e}")

            try:
                proc = self.processor(images=image, return_tensors="pt")
                pixel_values = proc["pixel_values"].squeeze(0)
            except Exception as e:
                image_tensor = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32) / 255.0
                pixel_values = image_tensor

            labels = torch.tensor(mask, dtype=torch.long)

            return {
                "pixel_values": pixel_values,
                "labels": labels,
                "filename": os.path.basename(rec['image']),
                "original_image_path": rec['image']
            }

        except Exception as e:
            print(f"⚠️ 데이터 로딩 오류: {rec['image']} - {str(e)}")
            if idx == 0:
                return self._get_dummy_data()
            else:
                return self.__getitem__(0)

    def _get_dummy_data(self):
        dummy_image = torch.zeros(3, self.input_size, self.input_size, dtype=torch.float32)
        dummy_mask = torch.zeros(self.input_size, self.input_size, dtype=torch.long)
        return {
            "pixel_values": dummy_image,
            "labels": dummy_mask,
            "filename": "dummy.jpg",
            "original_image_path": "dummy_path"
        }

def create_clean_dataset(items, processor, input_size=512):
    print(f"🔍 필터링 없는 데이터셋 생성 중...")
    print(f"📊 입력 데이터 개수: {len(items)}개")
    dataset = ImprovedSegDataset(items, processor, input_size)
    print(f"✅ 최종 데이터셋 크기: {len(dataset)}개")
    return dataset

# ===============================================================================
# 🚀 STEP 2: Loss 함수 및 학습 시스템
# ===============================================================================

class CombinedBoundaryLoss(nn.Module):
    def __init__(self, class_weights=None, dice_weight=0.5, ce_weight=0.3, boundary_weight=0.2, smooth=1e-7):
        super().__init__()
        self.class_weights = class_weights
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.boundary_weight = boundary_weight
        self.smooth = smooth

        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        self.register_buffer('sobel_x', sobel_x)
        self.register_buffer('sobel_y', sobel_y)

    def forward(self, logits, targets):
        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights)

        probs = F.softmax(logits, dim=1)
        dice_losses = []
        num_classes = logits.shape[1]
        for cls in range(1, num_classes):
            t_cls = (targets == cls).float()
            p_cls = probs[:, cls]
            inter = (p_cls * t_cls).sum(dim=[1,2])
            union = p_cls.sum(dim=[1,2]) + t_cls.sum(dim=[1,2])
            dice_score = ((2 * inter + self.smooth) / (union + self.smooth))
            dice_losses.append(1 - dice_score)

        if dice_losses:
            dice_loss = torch.stack(dice_losses, dim=1).mean()
        else:
            dice_loss = torch.tensor(0.0, device=logits.device)

        # Boundary Loss
        pred_mask = torch.argmax(probs, dim=1, keepdim=True).float()
        gt_mask = targets.unsqueeze(1).float()

        gx_pred = F.conv2d(pred_mask, self.sobel_x, padding=1)
        gy_pred = F.conv2d(pred_mask, self.sobel_y, padding=1)
        edge_pred = torch.sqrt(gx_pred ** 2 + gy_pred ** 2 + 1e-8)

        gx_gt = F.conv2d(gt_mask, self.sobel_x, padding=1)
        gy_gt = F.conv2d(gt_mask, self.sobel_y, padding=1)
        edge_gt = torch.sqrt(gx_gt ** 2 + gy_gt ** 2 + 1e-8)

        edge_mask = (edge_gt > 0.1).float()
        if edge_mask.sum() > 0:
            boundary_loss = F.l1_loss(edge_pred * edge_mask, edge_gt * edge_mask)
        else:
            boundary_loss = torch.tensor(0.0, device=logits.device)

        total_loss = (
            self.ce_weight * ce_loss +
            self.dice_weight * dice_loss +
            self.boundary_weight * boundary_loss
        )
        return total_loss

def calculate_improved_weights(data_list, device='cuda'):
    """클래스 개수 맞춘 개선된 가중치 계산"""
    pixel_counter = Counter()

    for rec in data_list:
        try:
            mask = np.array(Image.open(rec["label"]).convert("L"))
            unique, counts = np.unique(mask, return_counts=True)
            for cls_id, count in zip(unique, counts):
                if 0 <= cls_id <= 6:
                    pixel_counter[cls_id] += count
        except:
            continue

    fg_pixels = {k: v for k, v in pixel_counter.items() if k > 0}
    total_fg = sum(fg_pixels.values())

    weights = np.ones(7)
    weights[0] = 0.05

    for cls_id in range(1, 7):
        if cls_id in fg_pixels:
            freq = fg_pixels[cls_id] / total_fg
            weights[cls_id] = np.sqrt(1.0 / (freq + 1e-6))

    weights[1:] = weights[1:] / weights[1:].sum() * 6

    print(f"\n📊 클래스 가중치:")
    for i, w in enumerate(weights):
        class_name = id2label.get(i, f"class_{i}")
        print(f"  {class_name}: {w:.3f}")

    return torch.tensor(weights, dtype=torch.float32, device=device)

def remove_tiny_noise_with_confidence(pred_mask, single_probs, min_area_ratio=0.003, conf_thresh=0.3, adaptive_thresh=True):
    """노이즈 제거 함수"""
    H, W = pred_mask.shape
    total_pixels = H * W
    filtered_mask = np.zeros_like(pred_mask)

    for cls_id in np.unique(pred_mask):
        if cls_id == 0:
            continue

        class_mask = (pred_mask == cls_id).astype(np.uint8)
        num_labels, labels = cv2.connectedComponents(class_mask)

        if adaptive_thresh:
            class_confidences = single_probs[cls_id][class_mask == 1]
            if len(class_confidences) > 0:
                adaptive_conf_thresh = max(conf_thresh, np.percentile(class_confidences, 75))
            else:
                adaptive_conf_thresh = conf_thresh
        else:
            adaptive_conf_thresh = conf_thresh

        for label_id in range(1, num_labels):
            component_mask = (labels == label_id)
            area = component_mask.sum()
            area_ratio = area / total_pixels

            if area_ratio >= min_area_ratio:
                filtered_mask[component_mask] = cls_id
            else:
                comp_confidences = single_probs[cls_id][component_mask]
                max_conf = comp_confidences.max() if comp_confidences.size > 0 else 0.0

                if max_conf >= adaptive_conf_thresh:
                    filtered_mask[component_mask] = cls_id

    return filtered_mask

def refine_mask_morphology(mask, kernel_size=5):
    """형태학적 연산"""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    opened = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
    closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)
    return closed

def gentle_predict(batch, model, input_size=512, num_classes=10, confidence_threshold=None, use_multiscale=False, use_tta=False):
    """예측 함수"""
    imgs = batch["pixel_values"].to(device)

    with torch.no_grad():
        if use_tta:
            tta_preds = []
            
            # 원본
            outputs = model(pixel_values=imgs)
            logits = outputs.logits
            if logits.shape[-2:] != (input_size, input_size):
                logits = F.interpolate(logits, size=(input_size, input_size), mode="bilinear", align_corners=False)
            tta_preds.append(F.softmax(logits, dim=1))

            # 좌우 반전
            imgs_h_flipped = torch.flip(imgs, dims=[3])
            outputs = model(pixel_values=imgs_h_flipped)
            logits = outputs.logits
            if logits.shape[-2:] != (input_size, input_size):
                logits = F.interpolate(logits, size=(input_size, input_size), mode="bilinear", align_corners=False)
            logits_h_flipped_back = torch.flip(logits, dims=[3])
            tta_preds.append(F.softmax(logits_h_flipped_back, dim=1))

            probs = torch.stack(tta_preds).mean(dim=0)
        else:
            outputs = model(pixel_values=imgs)
            logits = outputs.logits
            if logits.shape[-2:] != (input_size, input_size):
                logits = F.interpolate(logits, size=(input_size, input_size), mode="bilinear", align_corners=False)
            probs = F.softmax(logits, dim=1)

        filtered_pred_list = []
        for i in range(probs.shape[0]):
            single_probs = probs[i].cpu().numpy()
            pred_mask = np.argmax(single_probs, axis=0)

            filtered1 = remove_tiny_noise_with_confidence(
                pred_mask, single_probs, min_area_ratio=0.003, conf_thresh=0.3, adaptive_thresh=True
            )
            filtered2 = refine_mask_morphology(filtered1, kernel_size=5)
            
            filtered_pred_list.append(torch.tensor(filtered2, device=device))

        pred = torch.stack(filtered_pred_list)

    return probs, pred

def improved_training(model, train_loader, val_loader, processor, class_weights_tensor,
                     max_epochs=400, patience=40, device='cuda',
                     use_enhanced_loss=False, use_advanced_scheduler=False):
    """학습 함수"""
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model = model.to(device)
    scaler = GradScaler()
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4, eps=1e-8)

    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=12,
                                  threshold=0.005, min_lr=1e-6, verbose=True)

    criterion = CombinedBoundaryLoss(
        class_weights=class_weights_tensor,
        dice_weight=0.5,
        ce_weight=0.3,
        boundary_weight=0.2
    ).to(device)

    history = {'train_loss': [], 'val_loss': [], 'dice_scores': [], 'iou_scores': [], 'learning_rates': []}
    best_dice = 0.0
    early_stop_counter = 0
    best_model_path = os.path.join(base_dir, "results", "best_model")

    for epoch in range(1, max_epochs + 1):
        # Training
        model.train()
        train_losses = []
        for batch in train_loader:
            imgs = batch["pixel_values"].to(device)
            masks = batch["labels"].to(device)
            masks = torch.clamp(masks, 0, 9)
            masks = F.interpolate(
                masks.unsqueeze(1).float(),
                size=(512, 512),
                mode='nearest'
            ).squeeze(1).long()

            optimizer.zero_grad()

            with autocast():
                outputs = model(pixel_values=imgs)
                logits = outputs.logits
                logits = F.interpolate(logits, size=(512, 512), mode='bilinear')
                loss = criterion(logits, masks)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())

        avg_train_loss = np.mean(train_losses)
        current_lr = optimizer.param_groups[0]['lr']

        # Validation
        val_dice, val_iou = evaluate_model_fairly(model, val_loader)

        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                imgs = batch["pixel_values"].to(device)
                masks = batch["labels"].to(device)
                masks = torch.clamp(masks, 0, 9)
                masks = F.interpolate(
                    masks.unsqueeze(1).float(),
                    size=(512, 512),
                    mode='nearest'
                ).squeeze(1).long()

                with autocast():
                    outputs = model(pixel_values=imgs)
                    logits = outputs.logits
                    logits = F.interpolate(logits, size=(512, 512), mode='bilinear')
                    loss = criterion(logits, masks)
                val_losses.append(loss.item())

        avg_val_loss = np.mean(val_losses) if val_losses else float('inf')

        scheduler.step(val_dice)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['dice_scores'].append(val_dice)
        history['iou_scores'].append(val_iou)
        history['learning_rates'].append(current_lr)

        print(f"Epoch {epoch}/{max_epochs}  ▶  Train Loss: {avg_train_loss:.4f}  |  Val Loss: {avg_val_loss:.4f}  |  Dice: {val_dice:.4f}  |  IoU: {val_iou:.4f}  |  LR: {current_lr:.2e}")

        if val_dice > best_dice:
            best_dice = val_dice
            try:
                os.makedirs(best_model_path, exist_ok=True)
                model.eval()
                model.save_pretrained(best_model_path, safe_serialization=False)
                processor.save_pretrained(best_model_path)
                print(f"✅ New Best! Dice {val_dice:.4f} 저장 완료")
            except Exception as e:
                print(f"❌ 모델 저장 실패: {e}")
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print(f"\n🛑 Early stopping at epoch {epoch}")
            break

    return history, best_model_path

def evaluate_model_fairly(model, val_loader, use_strict=False):
    """모델 평가"""
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch in val_loader:
            _, preds = gentle_predict(batch, model, 512, 10)

            targets = batch["labels"].to(device)
            targets = F.interpolate(
                targets.unsqueeze(1).float(),
                size=(512, 512),
                mode="nearest"
            ).squeeze(1).long()

            all_preds.append(preds.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    if all_preds and all_targets:
        pred_flat = np.concatenate([p.flatten() for p in all_preds])
        target_flat = np.concatenate([t.flatten() for t in all_targets])
        dice, iou, _ = calculate_advanced_metrics(pred_flat, target_flat, 10)
        return dice, iou

    return 0.0, 0.0

def calculate_advanced_metrics(pred, target, num_classes):
    """메트릭 계산"""
    pred = np.array(pred).flatten()
    target = np.array(target).flatten()

    valid_mask = target != 0
    pred_valid = pred[valid_mask]
    target_valid = target[valid_mask]

    if len(target_valid) == 0:
        return 0.0, 0.0, np.zeros(num_classes)

    dice_scores = []
    iou_scores = []
    precision_scores = np.zeros(num_classes)

    for cls in range(1, num_classes):
        pred_cls = (pred_valid == cls)
        target_cls = (target_valid == cls)

        if target_cls.sum() == 0:
            continue

        intersection = (pred_cls & target_cls).sum()
        union = (pred_cls | target_cls).sum()

        if union > 0:
            iou = intersection / union
            dice = (2 * intersection) / (pred_cls.sum() + target_cls.sum())
            iou_scores.append(iou)
            dice_scores.append(dice)

        if pred_cls.sum() > 0:
            precision_scores[cls] = intersection / pred_cls.sum()

    return (np.mean(dice_scores) if dice_scores else 0.0,
            np.mean(iou_scores) if iou_scores else 0.0,
            precision_scores)

# ===============================================================================
# 🎨 STEP 3: 시각화 함수들
# ===============================================================================

def mask_to_color_rgb(mask):
    """마스크를 RGB 컬러 이미지로 변환"""
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    
    for class_id in range(1, len(class_names)):
        if class_id < len(class_colors_bright) and class_colors_bright[class_id] is not None:
            class_mask = (mask == class_id)
            if class_mask.any():
                color = class_colors_bright[class_id]
                color_mask[class_mask] = color
    
    return color_mask

def add_readable_center_labels(image, mask, class_names, label_scale=0.6):
    """각 세그먼트의 중심에 읽기 쉬운 라벨 추가"""
    labeled_image = image.copy()
    
    for class_id in range(1, len(class_names)):
        class_mask = (mask == class_id)
        if class_mask.any():
            coords = np.where(class_mask)
            center_y = int(np.mean(coords[0]))
            center_x = int(np.mean(coords[1]))
            
            label_text = class_names[class_id]
            
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = label_scale
            thickness = 2
            
            (text_w, text_h), baseline = cv2.getTextSize(label_text, font, font_scale, thickness)
            
            box_x1 = max(0, center_x - text_w // 2 - 5)
            box_y1 = max(0, center_y - text_h // 2 - 5)
            box_x2 = min(image.shape[1], center_x + text_w // 2 + 5)
            box_y2 = min(image.shape[0], center_y + text_h // 2 + 5)
            
            cv2.rectangle(labeled_image, (box_x1, box_y1), (box_x2, box_y2), (0, 0, 0), -1)
            
            text_x = center_x - text_w // 2
            text_y = center_y + text_h // 2
            cv2.putText(labeled_image, label_text, (text_x, text_y), font, font_scale, (255, 255, 255), thickness)
    
    return labeled_image

def visualize_with_transparent_background(image, mask, alpha=0.7):
    """배경은 투명하게, 객체들만 색칠해서 오버레이"""
    result = image.copy()
    
    for class_id in range(1, len(class_names)):
        if class_id < len(class_colors_bright) and class_colors_bright[class_id] is not None:
            class_mask = (mask == class_id)
            if class_mask.any():
                color = class_colors_bright[class_id]
                color_overlay = np.full_like(image[class_mask], color)
                result[class_mask] = cv2.addWeighted(
                    image[class_mask], 1-alpha,
                    color_overlay, alpha,
                    0
                )
    
    return result

def create_pure_mask_visualization(mask):
    """순수 마스크 시각화 (배경 완전 투명)"""
    h, w = mask.shape
    rgba_image = np.zeros((h, w, 4), dtype=np.uint8)
    
    for class_id in range(1, len(class_names)):
        if class_id < len(class_colors_bright) and class_colors_bright[class_id] is not None:
            class_mask = (mask == class_id)
            if class_mask.any():
                color = class_colors_bright[class_id]
                rgba_image[class_mask] = [color[0], color[1], color[2], 255]
    
    return rgba_image

def save_prediction_comparison_readable(image_tensor, true_mask, pred_mask, save_path, original_image_path):
    """완벽한 예측 비교 저장"""
    if original_image_path and os.path.exists(original_image_path):
        orig_img = PILImage.open(original_image_path).convert('RGB')
        image_np = np.array(orig_img)

        if image_np.shape[:2] != (512, 512):
            image_np = cv2.resize(image_np, (512, 512))

    else:
        if torch.is_tensor(image_tensor):
            image_np = image_tensor.permute(1, 2, 0).cpu().numpy()

            if image_np.min() >= -3 and image_np.max() <= 3:
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                image_np = image_np * std + mean
                image_np = np.clip(image_np, 0, 1)

            if image_np.max() <= 1:
                image_np = (image_np * 255).astype(np.uint8)
            else:
                image_np = image_np.astype(np.uint8)
        else:
            image_np = image_tensor

        if len(image_np.shape) == 3 and image_np.shape[2] == 3:
            image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
            image_np = np.ascontiguousarray(image_np, dtype=np.uint8)

    target_h, target_w = image_np.shape[:2]
    if true_mask.shape != (target_h, target_w):
        true_mask = cv2.resize(true_mask.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST)
    if pred_mask.shape != (target_h, target_w):
        pred_mask = cv2.resize(pred_mask.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST)

    true_color_rgb = mask_to_color_rgb(true_mask)
    pred_color_rgb = mask_to_color_rgb(pred_mask)
    
    true_with_labels = add_readable_center_labels(true_color_rgb.copy(), true_mask, class_names, label_scale=0.7)
    pred_with_labels = add_readable_center_labels(pred_color_rgb.copy(), pred_mask, class_names, label_scale=0.7)

    transparent_overlay = visualize_with_transparent_background(image_np, pred_mask, alpha=0.6)
    overlay_with_labels = add_readable_center_labels(transparent_overlay.copy(), pred_mask, class_names, label_scale=0.6)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    axes[0].imshow(image_np)
    axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(true_with_labels)
    axes[1].set_title("Ground Truth", fontsize=12, fontweight='bold')
    axes[1].axis('off')

    axes[2].imshow(pred_with_labels)
    axes[2].set_title("Prediction", fontsize=12, fontweight='bold')
    axes[2].axis('off')

    axes[3].imshow(overlay_with_labels)
    axes[3].set_title("Transparent Overlay", fontsize=12, fontweight='bold')
    axes[3].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

# ===============================================================================
# 🗂️ STEP 4: 결과 저장 시스템
# ===============================================================================

def clean_save_all_results(model, val_loader, processor, history, final_dice, final_iou, 
                          best_model_path, save_base_dir=None):
    """깔끔한 결과 저장"""
    if save_base_dir is None:
        save_base_dir = os.path.join(base_dir, "results")
    
    print("🗂️ 깔끔한 결과 저장 시작...")
    
    viz_dir = os.path.join(save_base_dir, "visualizations")
    perf_dir = os.path.join(save_base_dir, "performance")
    
    if os.path.exists(viz_dir):
        shutil.rmtree(viz_dir)
    if os.path.exists(perf_dir):
        shutil.rmtree(perf_dir)
    
    os.makedirs(save_base_dir, exist_ok=True)
    
    # Best Model 확인
    model_save_dir = os.path.join(save_base_dir, "best_model")
    if os.path.exists(model_save_dir):
        files = os.listdir(model_save_dir)
        if files:
            print(f"✅ 1. Best Model 확인됨: {model_save_dir}")
        else:
            print(f"⚠️ 1. Best Model 폴더는 있지만 비어있음")
    else:
        print(f"❌ 1. Best Model 폴더 없음: {model_save_dir}")
        os.makedirs(model_save_dir, exist_ok=True)
        model.save_pretrained(model_save_dir, safe_serialization=False)
        processor.save_pretrained(model_save_dir)
        print(f"✅ 1. Current Model 저장: {model_save_dir}")
    
    # 모든 예측 결과 시각화 저장
    os.makedirs(viz_dir, exist_ok=True)
    print("🎨 2. 모든 예측 결과 시각화 저장 중...")
    total_saved = save_all_prediction_visualizations(model, val_loader, viz_dir)
    print(f"✅ 2. 시각화 완료: {total_saved}개 이미지 → {viz_dir}")
    
    # 성능 결과 저장
    os.makedirs(perf_dir, exist_ok=True)
    
    save_training_graphs(history, os.path.join(perf_dir, "training_history.png"))
    save_performance_report(history, final_dice, final_iou, os.path.join(perf_dir, "performance_report.txt"))
    
    print(f"✅ 3. 성능 결과 저장: {perf_dir}")
    
    print(f"\n🎉 깔끔한 저장 완료!")
    print(f"📁 저장 위치: {save_base_dir}")
    print(f"📊 구조:")
    print(f"  ├── best_model/        (학습된 모델)")
    print(f"  ├── visualizations/    ({total_saved}개 예측 시각화)")
    print(f"  └── performance/       (성능 그래프 + 리포트)")
    
    return save_base_dir

def save_all_prediction_visualizations(model, val_loader, save_dir):
    """모든 validation 데이터의 예측 결과를 4패널로 시각화하여 저장"""
    model.eval()
    total_saved = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc="시각화 저장")):
            imgs = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            filenames = batch["filename"]
            
            _, preds = gentle_predict(batch, model, 512, len(class_names))
            
            batch_size = imgs.shape[0]
            for i in range(batch_size):
                img_path = batch["original_image_path"][i]
                img_np = np.array(Image.open(img_path).convert("RGB"))
                
                gt_mask = labels[i].cpu().numpy().astype(np.uint8)
                pred_mask = preds[i].cpu().numpy().astype(np.uint8)
                
                filename = f"prediction_{batch_idx:03d}_{i:02d}_{os.path.splitext(filenames[i])[0]}.png"
                save_path = os.path.join(save_dir, filename)
                
                create_4panel_visualization(img_np, gt_mask, pred_mask, save_path)
                total_saved += 1
    
    return total_saved

def create_4panel_visualization(img_np, gt_mask, pred_mask, save_path):
    """4패널 시각화: 원본 + GT + 예측 + 오버레이"""
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    
    axes[0].imshow(img_np)
    axes[0].set_title("Original Image", fontsize=14, fontweight='bold')
    axes[0].axis('off')

    gt_color = mask_to_color_rgb(gt_mask)
    gt_with_labels = add_readable_center_labels(gt_color.copy(), gt_mask, class_names, label_scale=0.7)
    axes[1].imshow(gt_with_labels)
    axes[1].set_title("Ground Truth", fontsize=14, fontweight='bold')
    axes[1].axis('off')
    
    pred_color = mask_to_color_rgb(pred_mask)
    pred_with_labels = add_readable_center_labels(pred_color.copy(), pred_mask, class_names, label_scale=0.7)
    axes[2].imshow(pred_with_labels)
    axes[2].set_title("Prediction", fontsize=14, fontweight='bold')
    axes[2].axis('off')
    
    overlay = create_overlay(img_np, pred_mask, alpha=0.4)
    overlay_with_labels = add_readable_center_labels(overlay, pred_mask, class_names, label_scale=0.6)
    axes[3].imshow(overlay_with_labels)
    axes[3].set_title("Overlay", fontsize=14, fontweight='bold')
    axes[3].axis('off')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close()

def create_overlay(img_np, pred_mask, alpha=0.4):
    """원본 이미지 + 예측 마스크 오버레이"""
    overlay = img_np.copy().astype(np.float32)
    color_mask = mask_to_color_rgb(pred_mask).astype(np.float32)
    
    mask_area = (pred_mask > 0)
    overlay[mask_area] = (
        overlay[mask_area] * (1 - alpha) + 
        color_mask[mask_area] * alpha
    )
    
    return overlay.astype(np.uint8)

def save_training_graphs(history, save_path):
    """학습 히스토리 그래프 저장"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    axes[0, 0].plot(history['train_loss'], label='Train Loss', color='blue', linewidth=2)
    axes[0, 0].plot(history['val_loss'], label='Val Loss', color='red', linewidth=2)
    axes[0, 0].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    axes[0, 1].plot(history['dice_scores'], label='Dice Score', color='green', linewidth=2)
    axes[0, 1].set_title('Dice Score', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Dice Score')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    axes[1, 0].plot(history['iou_scores'], label='IoU Score', color='orange', linewidth=2)
    axes[1, 0].set_title('IoU Score', fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('IoU Score')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    if 'learning_rates' in history:
        axes[1, 1].plot(history['learning_rates'], label='Learning Rate', color='purple', linewidth=2)
        axes[1, 1].set_title('Learning Rate', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        axes[1, 1].set_yscale('log')
    else:
        axes[1, 1].axis('off')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight', facecolor='white')
    plt.close()

def save_performance_report(history, final_dice, final_iou, save_path):
    """성능 리포트 텍스트 파일 저장"""
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("🎯 Recycle Segmentation 성능 리포트\n")
        f.write("=" * 60 + "\n\n")
        
        f.write("📊 학습 결과:\n")
        f.write(f"  • 총 에포크: {len(history['train_loss'])}\n")
        f.write(f"  • 최종 Train Loss: {history['train_loss'][-1]:.4f}\n")
        f.write(f"  • 최종 Val Loss: {history['val_loss'][-1]:.4f}\n")
        f.write(f"  • 최고 Dice Score: {max(history['dice_scores']):.4f}\n")
        f.write(f"  • 최고 IoU Score: {max(history['iou_scores']):.4f}\n\n")
        
        f.write("🎯 최종 성능:\n")
        f.write(f"  • Dice Score: {final_dice:.4f}\n")
        f.write(f"  • IoU Score: {final_iou:.4f}\n\n")
        
        f.write("📈 성능 평가:\n")
        if final_dice > 0.85:
            f.write("  ✅ 우수한 성능! 실제 배포 가능한 수준\n")
        elif final_dice > 0.7:
            f.write("  🟢 좋은 성능! 실용적으로 사용 가능\n")
        elif final_dice > 0.5:
            f.write("  🟡 보통 성능. 추가 개선으로 향상 가능\n")
        else:
            f.write("  🔴 성능 부족. 추가 튜닝 필요\n")

# ===============================================================================
# 🚀 STEP 5: 메인 파이프라인 실행
# ===============================================================================

def calculate_class_weights_from_pixel_counter(pixel_counter):
    """픽셀 카운터로부터 클래스 가중치 계산"""
    print("\n📊 클래스 가중치 계산 중...")
    
    total_pixels = sum(pixel_counter.values())
    
    class_weights = []
    for class_id in range(len(class_names)):
        if class_id in pixel_counter and pixel_counter[class_id] > 0:
            weight = total_pixels / (len(class_names) * pixel_counter[class_id])
        else:
            weight = 1.0
        class_weights.append(weight)
    
    max_weight = max(class_weights)
    if max_weight > 10:
        class_weights = [w / max_weight * 10 for w in class_weights]
    
    class_weights_tensor = torch.FloatTensor(class_weights)
    
    print("🎯 클래스별 가중치:")
    for i, (class_name, weight) in enumerate(zip(class_names, class_weights)):
        pixel_count = pixel_counter.get(i, 0)
        print(f"   {class_name}: {weight:.3f} (픽셀 수: {pixel_count:,})")
    
    return class_weights_tensor

def run_clean_pipeline():
    """완전한 파이프라인 실행"""
    print("🚀 깔끔한 Recycle Segmentation 파이프라인 시작!")
    print("="*80)

    try:
        # STEP 1: 데이터 전처리
        print("\n📊 STEP 1: 데이터 전처리 시작")
        final_data_list, pixel_counter = preprocess_datasets()

        if len(final_data_list) == 0:
            print("❌ 처리할 데이터가 없습니다!")
            return

        print(f"✅ 전처리 완료: {len(final_data_list)}개 데이터")

        # STEP 2: 클래스 가중치 계산
        print("\n🔧 STEP 2: 클래스 가중치 계산")
        
        class_weights_tensor = calculate_class_weights_from_pixel_counter(pixel_counter).to(device)
        print(f"✅ 클래스 가중치 계산 완료")

        # 모델 로딩
        print("\n🤖 모델 및 프로세서 로딩 중...")
        processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small", use_fast=True)
        model = AutoModelForSemanticSegmentation.from_pretrained(
            "apple/deeplabv3-mobilevit-small",
            num_labels=len(class_names),
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True
        ).to(device)
        print("✅ 모델 로딩 완료")

        # Train/Val 분할
        print("\n📊 Train/Val 데이터 분할 중...")
        random.seed(42)
        random.shuffle(final_data_list)
        
        split_idx = int(len(final_data_list) * 0.8)
        train_list = final_data_list[:split_idx]
        val_list = final_data_list[split_idx:]

        print(f"📊 데이터 분할 완료: Train {len(train_list)}개, Val {len(val_list)}개")

        # 데이터셋 생성
        train_ds = create_clean_dataset(train_list, processor, input_size=512)
        val_ds = create_clean_dataset(val_list, processor, input_size=512)
        
        train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0, pin_memory=True)

        print(f"✅ DataLoader 생성 완료: Train batches {len(train_loader)}, Val batches {len(val_loader)}")

        # STEP 3: 모델 학습
        print("\n🚀 STEP 3: 모델 학습 시작")
        history, best_model_path = improved_training(
            model, train_loader, val_loader, processor, class_weights_tensor,
            max_epochs=200, patience=30, device=device
        )

        # Best 모델 로딩
        if best_model_path and os.path.exists(best_model_path):
            print(f"\n📥 Best 모델 로딩 중: {best_model_path}")
            try:
                model = AutoModelForSemanticSegmentation.from_pretrained(
                    best_model_path, local_files_only=True
                ).to(device)
                processor = AutoImageProcessor.from_pretrained(best_model_path)
                print("✅ Best 모델 로딩 완료!")
            except Exception as e:
                print(f"⚠️ Best 모델 로딩 실패: {e}")

        # 최종 성능 평가
        print("\n📊 최종 성능 평가 중...")
        model.eval()
        all_preds, all_targets = [], []
        with torch.no_grad():
            for batch in val_loader:
                _, pred = gentle_predict(batch, model, 512, len(class_names))
                all_preds.append(pred.cpu().numpy())
                all_targets.append(batch["labels"].cpu().numpy())
        
        pred_flat = np.concatenate([p.flatten() for p in all_preds])
        target_flat = np.concatenate([t.flatten() for t in all_targets])
        final_dice, final_iou, _ = calculate_advanced_metrics(pred_flat, target_flat, len(class_names))

        # STEP 4: 결과 저장
        print("\n🗂️ STEP 4: 깔끔한 결과 저장")
        save_dir = clean_save_all_results(
            model, val_loader, processor, history, 
            final_dice, final_iou, best_model_path
        )

        print(f"\n🎉 완료! 모든 결과가 저장되었습니다:")
        print(f"📁 {save_dir}")
        print(f"📊 최종 성능: Dice {final_dice:.4f}, IoU {final_iou:.4f}")

        return save_dir

    except Exception as e:
        print(f"❌ 파이프라인 실행 중 오류: {e}")
        import traceback
        traceback.print_exc()
        return None

# 실행
if __name__ == "__main__":
    run_clean_pipeline()

🗂️ 데이터 정리 시작...
📁 train 폴더에서 찾은 이미지: 930개
📁 train 폴더에서 찾은 마스크: 930개
✅ 데이터 정리 완료!
   📁 이미지: 930개 → datasets/images/
   📁 마스크: 930개 → datasets/masks/
🔧 사용 디바이스: cuda
🚀 깔끔한 Recycle Segmentation 파이프라인 시작!

📊 STEP 1: 데이터 전처리 시작

📊 STEP 1: datasets 폴더 데이터 매칭 시작
📁 JPG 이미지 파일 개수: 930개
📁 PNG 마스크 파일 개수: 930개
✅ 숫자 기반 매칭 성공: 930개
🔍 매칭된 이미지-마스크 쌍: 930개


데이터 매칭: 100%|██████████| 930/930 [00:02<00:00, 423.23it/s]


✅ 최종 데이터 개수: 930개
✅ 전처리 완료: 930개 데이터

🔧 STEP 2: 클래스 가중치 계산

📊 클래스 가중치 계산 중...
🎯 클래스별 가중치:
   background: 0.195 (픽셀 수: 178,471,959)
   can: 5.586 (픽셀 수: 6,234,788)
   glass: 5.579 (픽셀 수: 6,242,213)
   paper: 4.862 (픽셀 수: 7,163,569)
   plastic: 2.850 (픽셀 수: 12,220,847)
   styrofoam: 3.296 (픽셀 수: 10,565,234)
   vinyl: 1.521 (픽셀 수: 22,895,310)
✅ 클래스 가중치 계산 완료

🤖 모델 및 프로세서 로딩 중...


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
  _torch_pytree._register_pytree_node(
  return torch.load(checkpoint_file, map_location=map_location)
Some weights of MobileViTForSemanticSegmentation were not initialized from the model checkpoint at apple/deeplabv3-mobilevit-small and are newly initialized because the shapes did not match:
- segmentation_head.classifier.convolution.weight: found shape torch.Size([21, 256, 1, 1]) in the checkpoint and torch.Size([7, 256, 1, 1]) in the model instantiated
- segmentation_head.classifier.convolution.bias: found shape torch.Size([21]) in the checkpoint and torch.Size([7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ 모델 로딩 완료

📊 Train/Val 데이터 분할 중...
📊 데이터 분할 완료: Train 744개, Val 186개
🔍 필터링 없는 데이터셋 생성 중...
📊 입력 데이터 개수: 744개
📊 전체 데이터셋 크기: 744개
✅ 최종 데이터셋 크기: 744개
🔍 필터링 없는 데이터셋 생성 중...
📊 입력 데이터 개수: 186개
📊 전체 데이터셋 크기: 186개
✅ 최종 데이터셋 크기: 186개
✅ DataLoader 생성 완료: Train batches 46, Val batches 12

🚀 STEP 3: 모델 학습 시작


  scaler = GradScaler()
  with autocast():
  with autocast():


Epoch 1/200  ▶  Train Loss: 1.0833  |  Val Loss: 1.0572  |  Dice: 0.4425  |  IoU: 0.2943  |  LR: 1.00e-04
✅ New Best! Dice 0.4425 저장 완료
Epoch 2/200  ▶  Train Loss: 1.0428  |  Val Loss: 1.0124  |  Dice: 0.6619  |  IoU: 0.5223  |  LR: 1.00e-04
✅ New Best! Dice 0.6619 저장 완료
Epoch 3/200  ▶  Train Loss: 0.9948  |  Val Loss: 0.9483  |  Dice: 0.8392  |  IoU: 0.7391  |  LR: 1.00e-04
✅ New Best! Dice 0.8392 저장 완료
Epoch 4/200  ▶  Train Loss: 0.9270  |  Val Loss: 0.8712  |  Dice: 0.8935  |  IoU: 0.8166  |  LR: 1.00e-04
✅ New Best! Dice 0.8935 저장 완료
Epoch 5/200  ▶  Train Loss: 0.8556  |  Val Loss: 0.8046  |  Dice: 0.9091  |  IoU: 0.8395  |  LR: 1.00e-04
✅ New Best! Dice 0.9091 저장 완료
Epoch 6/200  ▶  Train Loss: 0.7926  |  Val Loss: 0.7363  |  Dice: 0.9326  |  IoU: 0.8772  |  LR: 1.00e-04
✅ New Best! Dice 0.9326 저장 완료
Epoch 7/200  ▶  Train Loss: 0.7384  |  Val Loss: 0.6944  |  Dice: 0.9397  |  IoU: 0.8893  |  LR: 1.00e-04
✅ New Best! Dice 0.9397 저장 완료
Epoch 8/200  ▶  Train Loss: 0.6970  |  Val Loss:

시각화 저장: 100%|██████████| 12/12 [02:08<00:00, 10.74s/it]


✅ 2. 시각화 완료: 186개 이미지 → C:/Users/USER/Desktop/Reco_Notebook\results\visualizations
✅ 3. 성능 결과 저장: C:/Users/USER/Desktop/Reco_Notebook\results\performance

🎉 깔끔한 저장 완료!
📁 저장 위치: C:/Users/USER/Desktop/Reco_Notebook\results
📊 구조:
  ├── best_model/        (학습된 모델)
  ├── visualizations/    (186개 예측 시각화)
  └── performance/       (성능 그래프 + 리포트)

🎉 완료! 모든 결과가 저장되었습니다:
📁 C:/Users/USER/Desktop/Reco_Notebook\results
📊 최종 성능: Dice 0.9899, IoU 0.9801
