<a href="https://colab.research.google.com/github/OverfitSurvivor/code/blob/main/drone_DAE_plz_0416.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
!pip install pytorch_msssim

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->pytorch_msssim)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->pytorch_msssim)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->pytorch_msssim)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->pytorch_msssim)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->pytorch_msssim)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->pytorch_mss

In [3]:
import os
import zipfile

zip_path = "/content/drive/MyDrive/ICSV31AIChallengeDataset.zip"  # 업로드한 ZIP 파일 경로
extract_path = "/content/ICSV31AIChallengeDataset"  # 압축을 풀 폴더 경로

# 폴더가 없으면 생성
os.makedirs(extract_path, exist_ok=True)

# 압축 해제
with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall(extract_path)

print("압축 해제 완료:", extract_path)

압축 해제 완료: /content/ICSV31AIChallengeDataset


## 모듈 불러오기

In [4]:
import csv
import argparse
import os
from typing import Any, List, Tuple

import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

## 라벨링 및 데이터셋 로더
+ global scaling STFT 적용

In [5]:
#######################
# 1. Utils
#######################
def read_csv(file_path: str) -> List:
    with open(file_path, "r") as f:
        reader = csv.reader(f)
        return list(reader)

def save_csv(save_data: List[Any], save_file_path: str) -> None:
    with open(save_file_path, "w", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerows(save_data)

def get_anomaly_label(file_path: str) -> int:
    file_name = os.path.basename(file_path)
    train_mode = file_name.split("_")[0]
    if train_mode == "test":
        return -1
    elif "normal" in file_name:
        return 0
    else:
        return 1

def get_drone_label(file_path: str) -> int:
    file_name = os.path.basename(file_path)
    drone_mode = file_name.split("_")[1]
    if drone_mode == "A":
        return 0
    elif drone_mode == "B":
        return 1
    elif drone_mode == "C":
        return 2
    else:
        return -1

def get_direction_label(file_path: str) -> int:
    file_name = os.path.basename(file_path)
    direction_mode = file_name.split("_")[2]
    if direction_mode == "Back":
        return 0
    elif direction_mode == "Front":
        return 1
    elif direction_mode == "Left":
        return 2
    elif direction_mode == "Right":
        return 3
    elif direction_mode == "Clockwise":
        return 4
    elif direction_mode == "CounterClockwise":
        return 5
    else:
        return -1

### 2048/768/256
Global Mean: -4.636630590752749
Global Std: 12.173359870910645

### 2048/768/192 #
Global Mean: -4.616434057646901
Global Std: 12.162795066833496

### 2048/768/128
Global Mean: -4.620500203780632
Global Std: 12.1651611328125

### 4096/4096/256
Global Mean: 2.3749662920402215
Global Std: 11.927532196044922

In [6]:
#######################
# 2. Feature Extraction & Augmentation
#######################
# 계산된 전역 평균과 표준편차
# STFT 바꿀때마다 계산해줘야함

GLOBAL_MEAN =  -4.616434057646901
GLOBAL_STD = 12.162795066833496
def wav_to_log_stft(
    wav_path: str,
    sr: int,
    n_fft: int,
    win_length: int,
    hop_length: int,
    power: float,
) -> torch.Tensor:
    """
    WAV 파일을 STFT 기반 로그 스펙트럼으로 변환.
    - torchaudio.transforms.Spectrogram로 STFT 계산
    - AmplitudeToDB로 로그 변환 후 global standard scaling 적용 **
    """
    stft_transform = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=power
    )
    wav_data, _ = torchaudio.load(wav_path)
    spec = stft_transform(wav_data)
    amp_to_db = torchaudio.transforms.AmplitudeToDB()
    log_spec = amp_to_db(spec)

    # 전체 데이터셋의 평균과 표준편차로 정규화 적용
    # clamp 추가
    log_spec = (log_spec - GLOBAL_MEAN) / (GLOBAL_STD + 1e-9)
    return log_spec


import torchaudio.transforms as T
# -------------------------------
# Feature Extraction & Augmentation
# -------------------------------
def augment_spec(spec: torch.Tensor) -> torch.Tensor:
    max_shift = int(spec.shape[-1] * 0.1)
    shift = torch.randint(-max_shift, max_shift + 1, (1,)).item()
    spec = torch.roll(spec, shifts=shift, dims=-1)
    time_mask_param = max(1, int(spec.shape[-1] * 0.05))
    time_mask = T.TimeMasking(time_mask_param=time_mask_param)
    spec = time_mask(spec)
    freq_mask_param = max(1, int(spec.shape[-2] * 0.05))
    freq_mask = T.FrequencyMasking(freq_mask_param=freq_mask_param)
    spec = freq_mask(spec)
    return spec

#######################
# 3. Dataset
#######################
class BaselineDataLoader(Dataset):
    def __init__(
        self,
        file_list: List[str],
        sr: int,
        n_fft: int,
        win_length: int,
        hop_length: int,
        power: float,
        augment: bool = False
    ) -> None:
        self.file_list = file_list
        self.sr = sr
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.power = power
        self.augment = augment

    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int, int]:
        wav_path = self.file_list[idx]
        spec = wav_to_log_stft(wav_path, self.sr, self.n_fft, self.win_length, self.hop_length, self.power)
        if self.augment:
            spec = augment_spec(spec)
        if torch.isnan(spec).any() or torch.isinf(spec).any():
            print(f"Warning: Spec from {wav_path} contains NaN or Inf")
        anomaly_label = get_anomaly_label(wav_path)
        drone_label = get_drone_label(wav_path)
        direction_label = get_direction_label(wav_path)
        # 데이터셋에서 스펙트로그램이 로드된 후 검사

        return spec, anomaly_label, drone_label, direction_label

In [9]:
#!/usr/bin/env python3
"""
Drone Anomaly Detection Training Script

이 스크립트는 향상된 Denoising Autoencoder(DAE)를 이용하여 드론 이상 탐지 모델을 학습하고 평가합니다.
훈련 시에는 warm-up epoch 동안 L1 손실을 사용한 뒤, MSSSIM + L1 손실을 사용하고, 최종 평가 시에는 재구성 품질을 GMSD로 측정합니다.
CosineAnnealingLR을 적용하여 학습률을 조정하며, 각 에포크마다 업데이트된 학습률을 출력합니다.
"""

import os
import math
import argparse
import csv
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from pytorch_msssim import ms_ssim
import torchaudio.transforms as T

# -------------------------------
# Utility Functions & Set Seed
# -------------------------------
def set_seed(seed: int) -> None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def match_size(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    src_h, src_w = source.size(2), source.size(3)
    tgt_h, tgt_w = target.size(2), target.size(3)
    if src_h > tgt_h or src_w > tgt_w:
        start_h = (src_h - tgt_h) // 2
        start_w = (src_w - tgt_w) // 2
        source = source[:, :, start_h:start_h+tgt_h, start_w:start_w+tgt_w]
    elif src_h < tgt_h or src_w < tgt_w:
        diff_h = tgt_h - src_h
        diff_w = tgt_w - src_w
        source = F.pad(source, (diff_w // 2, diff_w - diff_w // 2,
                                diff_h // 2, diff_h - diff_h // 2))
    return source

# -------------------------------
# Loss Functions
# -------------------------------
class MSSSIMLoss(nn.Module):
    def __init__(self, window_size=7, size_average=True, data_range=1.0, K=(0.01, 0.03)):
        super(MSSSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.data_range = data_range
        self.K = K

    def forward(self, img1, img2):
        return 1 - ms_ssim(img1, img2,
                           data_range=self.data_range,
                           win_size=self.window_size,
                           size_average=self.size_average,
                           K=self.K)

class CombinedLoss(nn.Module):
    def __init__(self, window_size=7, size_average=True, msssim_weight=0.8, l1_weight=0.2):
        super(CombinedLoss, self).__init__()
        self.msssim_loss = MSSSIMLoss(window_size, size_average)
        self.msssim_weight = msssim_weight
        self.l1_weight = l1_weight

    def forward(self, img1, img2):
        loss_msssim = self.msssim_loss(img1, img2)
        loss_l1 = F.l1_loss(img1, img2)
        return self.msssim_weight * loss_msssim + self.l1_weight * loss_l1

class GMSDLoss(nn.Module):
    def __init__(self, T: float = 0.0026):
        super(GMSDLoss, self).__init__()
        self.T = T
        weight_x = torch.tensor([[-1, 0, 1],
                                 [-2, 0, 2],
                                 [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        weight_y = torch.tensor([[-1, -2, -1],
                                 [0, 0, 0],
                                 [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('weight_x', weight_x)
        self.register_buffer('weight_y', weight_y)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        weight_x = self.weight_x.to(dtype=img1.dtype, device=img1.device)
        weight_y = self.weight_y.to(dtype=img1.dtype, device=img1.device)
        grad1_x = F.conv2d(img1, weight_x, padding=1)
        grad1_y = F.conv2d(img1, weight_y, padding=1)
        grad2_x = F.conv2d(img2, weight_x, padding=1)
        grad2_y = F.conv2d(img2, weight_y, padding=1)
        grad1 = torch.sqrt(grad1_x ** 2 + grad1_y ** 2)
        grad2 = torch.sqrt(grad2_x ** 2 + grad2_y ** 2)
        eps = 1e-5
        denom = grad1 ** 2 + grad2 ** 2 + self.T
        denom = torch.clamp(denom, min=eps)
        gms = (2 * grad1 * grad2 + self.T) / denom
        gmsd = torch.std(gms, unbiased=False)
        return gmsd

class MultiScaleGMSDLoss(nn.Module):
    def __init__(self, scales=4, T: float = 0.0026):
        super(MultiScaleGMSDLoss, self).__init__()
        self.scales = scales
        self.gmsd_loss = GMSDLoss(T=T)

    def forward(self, img1, img2):
        gmsd_vals = []
        for scale in range(self.scales):
            gmsd_val = self.gmsd_loss(img1, img2)
            gmsd_vals.append(gmsd_val)
            if scale < self.scales - 1:
                img1 = F.interpolate(img1, scale_factor=0.5, mode='bilinear', align_corners=False)
                img2 = F.interpolate(img2, scale_factor=0.5, mode='bilinear', align_corners=False)
        return sum(gmsd_vals) / len(gmsd_vals)

# -------------------------------
# Dataset Class
# -------------------------------
class BaselineDataLoader(Dataset):
    def __init__(self, file_list: list, sr: int, n_fft: int, win_length: int, hop_length: int, power: float, augment: bool = False) -> None:
        self.file_list = file_list
        self.sr = sr
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.power = power
        self.augment = augment

    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, idx: int) -> tuple:
        wav_path = self.file_list[idx]
        spec = wav_to_log_stft(wav_path, self.sr, self.n_fft, self.win_length, self.hop_length, self.power)
        if self.augment:
            spec = augment_spec(spec)
        anomaly_label = get_anomaly_label(wav_path)
        drone_label = get_drone_label(wav_path)
        direction_label = get_direction_label(wav_path)
        return spec, anomaly_label, drone_label, direction_label

# -------------------------------
# DataLoader 생성 함수
# -------------------------------
def get_train_val_loader(args: argparse.Namespace, pin_memory: bool = False, num_workers: int = 0):
    file_list = sorted(os.listdir(args.train_dir))
    file_list = [os.path.join(args.train_dir, f) for f in file_list]
    # train_dir에는 정상 데이터만 존재한다고 가정.
    train_files, val_files = train_test_split(file_list, test_size=0.2, random_state=2025)
    train_files = [f for f in train_files if get_anomaly_label(f) == 0]
    train_dataset = BaselineDataLoader(train_files, sr=args.sr, n_fft=args.n_fft,
                                        win_length=args.win_length, hop_length=args.hop_length,
                                        power=args.power, augment=True)
    val_dataset = BaselineDataLoader(val_files, sr=args.sr, n_fft=args.n_fft,
                                      win_length=args.win_length, hop_length=args.hop_length,
                                      power=args.power, augment=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, val_loader

def get_eval_loader(args: argparse.Namespace, pin_memory: bool = False, num_workers: int = 0):
    file_list = sorted(os.listdir(args.eval_dir))
    file_list = [os.path.join(args.eval_dir, f) for f in file_list]
    dataset = BaselineDataLoader(file_list, sr=args.sr, n_fft=args.n_fft, win_length=args.win_length,
                                 hop_length=args.hop_length, power=args.power, augment=False)
    loader = DataLoader(dataset, batch_size=1, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
    return loader, file_list

# -------------------------------
# Model Architecture (Enhanced DAE)
# -------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------------------------
# Residual Block
# -------------------------------
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dropout=0.03):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

# -------------------------------
# Encoder / Decoder Blocks
# -------------------------------
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.03):
        super(EncoderBlock, self).__init__()
        self.resblock = ResidualBlock(in_channels, out_channels, stride=2, dropout=dropout)

    def forward(self, x):
        return self.resblock(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, dropout=0.03):
        super(DecoderBlock, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                                         stride=2, padding=1, output_padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        total_in = out_channels + skip_channels
        self.resblock = ResidualBlock(total_in, out_channels, stride=1, dropout=dropout)

    def forward(self, x, skips):
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dropout(x)

        if isinstance(skips, (list, tuple)):
            skips = [match_size(s, x) for s in skips]
            skip = torch.cat(skips, dim=1)
        else:
            skip = match_size(skips, x)

        x = torch.cat([x, skip], dim=1)
        x = self.resblock(x)
        return x

# -------------------------------
# Enhanced Denoising Autoencoder
# -------------------------------
class EnhancedDenoisingAutoencoder(nn.Module):
    def __init__(self, input_channels=1, dropout=0.03):
        super(EnhancedDenoisingAutoencoder, self).__init__()
        self.enc1 = EncoderBlock(input_channels, 32, dropout)
        self.enc2 = EncoderBlock(32, 64, dropout)
        self.enc3 = EncoderBlock(64, 128, dropout)
        self.enc4 = EncoderBlock(128, 256, dropout)
        self.enc5 = EncoderBlock(256, 512, dropout)

        # 정확한 skip channel 수로 수정
        self.dec5 = DecoderBlock(in_channels=512, skip_channels=256, out_channels=256, dropout=dropout)
        self.dec4 = DecoderBlock(in_channels=256, skip_channels=384, out_channels=128, dropout=dropout)
        self.dec3 = DecoderBlock(in_channels=128, skip_channels=192, out_channels=64, dropout=dropout)
        self.dec2 = DecoderBlock(in_channels=64,  skip_channels=96,  out_channels=32, dropout=dropout)

        self.dec1 = nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2,
                                       padding=1, output_padding=1)

    def forward(self, x):
        e1 = self.enc1(x)   # (B, 32)
        e2 = self.enc2(e1)  # (B, 64)
        e3 = self.enc3(e2)  # (B, 128)
        e4 = self.enc4(e3)  # (B, 256)
        e5 = self.enc5(e4)  # (B, 512)

        d5 = self.dec5(e5, e4)              # decoder5 ← e4
        d4 = self.dec4(d5, [e4, e3])        # decoder4 ← e4, e3
        d3 = self.dec3(d4, [e3, e2])        # decoder3 ← e3, e2
        d2 = self.dec2(d3, [e2, e1])        # decoder2 ← e2, e1
        d1 = self.dec1(d2)                  # 마지막 복원

        d1 = F.interpolate(d1, size=x.shape[2:], mode='bilinear', align_corners=False)
        return d1

def DAEModel(dropout) -> nn.Module:
    return EnhancedDenoisingAutoencoder(input_channels=1, dropout=dropout)

# -------------------------------
# Training & Evaluation Functions
# -------------------------------
def train_and_evaluate(args: argparse.Namespace) -> None:
    print("Training started...")
    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    set_seed(2025)
    if device.type == "cuda":
        torch.backends.cudnn.benchmark = True

    model = DAEModel(args.dropout).to(device)

    train_loader, val_loader = get_train_val_loader(args, pin_memory=True, num_workers=args.n_workers)
    eval_loader, eval_file_list = get_eval_loader(args, pin_memory=True, num_workers=args.n_workers)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Combined loss (MSSSIM + L1) for normal training after warm-up
    criterion = CombinedLoss(window_size=7, size_average=True, msssim_weight=0.9, l1_weight=0.1)
    # L1 loss only for warm-up phase
    criterion_l1 = nn.L1Loss()

    train_losses = []
    val_losses = []
    scaler = torch.amp.GradScaler()

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss = 0.0
        count = 0
        p_bar = tqdm(train_loader, total=len(train_loader),
                     desc=f"Epoch {epoch}/{args.epochs}", ncols=100)
        for data in p_bar:
            spec = data[0].to(device)
            # 노이즈 추가 (training augmentation)
            noisy_spec = spec + args.noise_factor * torch.randn_like(spec)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type='cuda'):
                decoded = model(noisy_spec)
                # Warm-up: 초기 args.warmup_epochs 에서는 L1 loss만 사용
                if epoch <= args.warmup_epochs:
                    loss = criterion_l1(decoded, spec)
                else:
                    loss = criterion(decoded, spec)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            count += 1
            p_bar.set_description(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        avg_train_loss = total_loss / count
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch} average train loss: {avg_train_loss:.4f}")

        model.eval()
        total_val_loss = 0.0
        val_count = 0
        y_true_val = []
        y_pred_val = []
        with torch.no_grad():
            for val_data in val_loader:
                spec, anomaly_label, _, _ = val_data
                spec = spec.to(device)
                decoded = model(spec)
                loss = criterion(decoded, spec)
                total_val_loss += loss.item()
                val_count += 1
                y_true_val.append(anomaly_label.item())
                y_pred_val.append(loss.item())
        avg_val_loss = total_val_loss / val_count
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch} average val loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            checkpoint_path = os.path.join(args.model_dir, args.model_path)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Epoch {epoch}: Model improved. Saved checkpoint to {checkpoint_path}")
        else:
            patience_counter += 1
            if patience_counter >= args.early_stopping_patience:
                print("Early stopping triggered.")
                break

        scheduler.step()

    # 학습 곡선 저장
    epochs_range = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10, 6))
    plt.plot(epochs_range, train_losses, label='Train Loss')
    plt.plot(epochs_range, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Combined Loss (MSSSIM + L1)')
    plt.title('Learning Curve')
    plt.legend()
    plt.grid(True)
    curve_path = os.path.join(args.result_dir, "learning_curve.png")
    plt.savefig(curve_path)
    plt.show()
    print(f"Learning curve saved to {curve_path}")

    # GMSD 기반 평가 (clamp 제거)
    gmsd_criterion = MultiScaleGMSDLoss().to(device)
    y_true_eval = []
    y_pred_eval = []
    score_list = [["File Name", "True Label", "Anomaly Score"]]

    model.eval()
    with torch.no_grad():
        for idx, (spec, anomaly_label, drone_label, _) in enumerate(tqdm(eval_loader, desc="Evaluating")):
            spec = spec.to(device)
            decoded = model(spec)
            loss_gmsd = gmsd_criterion(decoded, spec)

            y_true_eval.append(anomaly_label.item())
            y_pred_eval.append(loss_gmsd.item())

            file_name = os.path.splitext(os.path.basename(eval_file_list[idx]))[0]
            score_list.append([file_name, anomaly_label.item(), loss_gmsd.item()])

    roc_auc_eval = metrics.roc_auc_score(np.array(y_true_eval), np.array(y_pred_eval))
    print(f"Final ROC AUC on eval set: {roc_auc_eval:.4f}")

    csv_path = os.path.join(args.result_dir, "eval_score_enhanced_dae_11.csv")
    with open(csv_path, mode="w", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(score_list)
    print(f"Saved evaluation scores to {csv_path}")

# -------------------------------
# Argument Parsing 함수
# -------------------------------
def get_args() -> argparse.Namespace:
    param = {
        "train_dir": "/content/ICSV31AIChallengeDataset/train",
        "eval_dir": "/content/ICSV31AIChallengeDataset/eval",
        "result_dir": "/content/drive/MyDrive",
        "model_dir": "/content/drive/MyDrive",
        "model_path": "enhanced_model_dae_11.pth",
        "epochs": 250,
        "batch_size": 32,
        "lr": 0.001,
        "gpu": 0,
        "n_workers": 1,
        "early_stopping_patience": 20,
        "warmup_epochs": 15,           # Warm-up epoch: 초기 epoch은 L1 loss만 사용
        "noise_factor": 0.05,
        "dropout": 0.05,
        "sr": 16000,
        "n_fft": 2048,
        "win_length": 768,
        "hop_length": 192,
        "power": 2.0
    }
    parser = argparse.ArgumentParser()
    for key, value in param.items():
        parser.add_argument(f"--{key.replace('_','-')}", default=value, type=type(value))
    args, unknown = parser.parse_known_args()
    return args

## 모델 학습 및 평가

In [None]:
 if __name__ == "__main__":
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    set_seed(2025)
    train_and_evaluate(args)

Training started...
Using device: cuda


Epoch 1, Loss: 0.3627: 100%|██████████████████████████████████████| 135/135 [01:06<00:00,  2.02it/s]

Epoch 1 average train loss: 0.4138





Epoch 1 average val loss: 0.6578
Epoch 1: Model improved. Saved checkpoint to /content/drive/MyDrive/enhanced_model_dae_11.pth


Epoch 2, Loss: 0.3535:  32%|████████████▍                          | 43/135 [00:19<00:40,  2.26it/s]