### 가장 유사한 pair를 미리 저장합니다.

- 유사도를 매번 계산하는 것은 비용이 너무 크기 때문입니다.

In [None]:
import os
import pandas as pd
import librosa
import numpy as np
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm
import re

# -----------------------------
# Step 1: Load Dataset with Correct Paths
# -----------------------------
def load_dataset(attributes_file, class_dir, class_name):
    """Read one class's attributes CSV and return its audio paths and labels.

    Args:
        attributes_file: Path to the attributes CSV. It must contain a
            ``file_name`` column.
        class_dir: Directory of the class being loaded. Entries of
            ``file_name`` are joined onto the *parent* of this directory.
            For a ``class_dir`` built as ``os.path.join(root, class_name)``
            that parent is ``root``, so no module-level global is needed.
        class_name: Name of the class. Not used by this function; kept so
            existing callers keep working.

    Returns:
        A ``(file_paths, labels)`` tuple of two equally long lists.
        ``labels[i]`` is ``'anomaly'`` when the lower-cased file name contains
        ``"anomaly"`` and ``'normal'`` otherwise.

    Raises:
        FileNotFoundError: If ``attributes_file`` is not an existing file.
    """
    if not os.path.isfile(attributes_file):
        raise FileNotFoundError(f"Attributes file not found: {attributes_file}")

    # Take the file names from the CSV and derive each label from its name.
    df = pd.read_csv(attributes_file)
    filenames = df['file_name'].tolist()
    labels = ['anomaly' if 'anomaly' in name.lower() else 'normal' for name in filenames]

    # NOTE(review): assumes `file_name` entries are relative to the dataset
    # root (i.e. they already start with the class directory name) -- confirm
    # against the attributes files.
    # normpath drops a trailing separator so dirname really yields the parent.
    dataset_root = os.path.dirname(os.path.normpath(class_dir))
    file_paths = [os.path.join(dataset_root, f) for f in filenames]

    return file_paths, labels

# -----------------------------
# Step 2: Compute Spectrograms for All Files
# -----------------------------
def compute_all_spectrograms(file_paths, n_fft=160, hop_length=80, target_shape=(32, 32)):
    """Compute a fixed-size magnitude spectrogram for every readable file.

    Args:
        file_paths: Iterable of audio file paths.
        n_fft: FFT window size passed to ``librosa.stft``.
        hop_length: Hop length passed to ``librosa.stft``.
        target_shape: ``(rows, columns)`` of each returned array. The STFT
            magnitude is forced to this shape with
            ``librosa.util.fix_length`` along axis 0 and then axis 1, which
            pads or truncates rather than interpolating.

    Returns:
        Dict mapping each path *exactly as given in* ``file_paths`` to its
        spectrogram array. Files that do not exist or fail to load are
        reported on stdout and left out of the dict.
    """
    n_rows, n_cols = target_shape
    spectrograms = {}
    for path in tqdm(file_paths, desc="Computing Spectrograms"):
        # The normalised form is used only to touch the disk. The result is
        # stored under the caller's own string so that a later
        # `spectrograms.get(path)` succeeds even when normalisation rewrites
        # the path (e.g. separators on Windows, a leading "./").
        norm_path = os.path.normpath(path)

        if not os.path.exists(norm_path):
            print(f"File not found: {norm_path}. Skipping.")
            continue

        try:
            # sr=None: load without asking librosa for a specific sample rate.
            y, _ = librosa.load(norm_path, sr=None)
            S = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))
            S_resized = librosa.util.fix_length(S, size=n_rows, axis=0)
            S_resized = librosa.util.fix_length(S_resized, size=n_cols, axis=1)
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort the
            # whole batch.
            print(f"Error loading {norm_path}: {e}. Skipping.")
            continue

        spectrograms[path] = S_resized

    return spectrograms

# -----------------------------
# Step 3: Find Best Matching Pair for Anomalies using SSIM
# -----------------------------
def find_best_matches_for_anomalies(anomaly_paths, normal_paths, spectrograms):
    """Pair every anomaly file with the normal file of highest SSIM.

    Args:
        anomaly_paths: Paths of the anomalous recordings.
        normal_paths: Paths of the candidate normal recordings.
        spectrograms: Dict mapping a path to its spectrogram array. Paths
            without an entry are ignored.

    Returns:
        List of dicts with the keys ``"anomaly"``, ``"normal"``,
        ``"similarity"`` and ``"method"`` (always ``"SSIM"``), one per anomaly
        that could be matched. An anomaly is skipped, with a message on
        stdout where noted, when

        * it has no spectrogram (silently),
        * its spectrogram has no positive value, or
        * no normal candidate yields a comparable score.
    """
    # The set of usable normal candidates does not depend on the anomaly, so
    # resolve it once instead of once per anomaly.
    candidates = []
    for normal_path in normal_paths:
        normal_spectrogram = spectrograms.get(normal_path)
        if normal_spectrogram is not None:
            candidates.append((normal_path, normal_spectrogram))

    pairs = []

    for anomaly_path in tqdm(anomaly_paths, desc="Finding Best Matches"):
        # Make sure the anomaly actually has a spectrogram.
        anomaly_spectrogram = spectrograms.get(anomaly_path)
        if anomaly_spectrogram is None:
            continue

        # The anomaly's maximum is used as the SSIM data range. A range of 0
        # would zero out SSIM's stabilising constants and give a meaningless
        # (or NaN) score, so such a spectrogram cannot be matched.
        data_range = anomaly_spectrogram.max()
        if not data_range > 0:
            print(f"Spectrogram of {anomaly_path} has no positive values. Skipping.")
            continue

        best_match = None
        best_score = float("-inf")

        for normal_path, normal_spectrogram in candidates:
            # Only the mean score is needed, so the full SSIM map is not
            # requested.
            score = ssim(anomaly_spectrogram, normal_spectrogram, data_range=data_range)

            if score > best_score:
                best_score = score
                best_match = normal_path

        if best_match is None:
            # No candidate at all, or every score was NaN: emitting a row
            # without a partner would not be a pair.
            print(f"No usable normal file for {anomaly_path}. Skipping.")
            continue

        pairs.append({
            "anomaly": anomaly_path,
            "normal": best_match,
            "similarity": best_score,
            "method": "SSIM"
        })

    return pairs

# -----------------------------
# Step 4: Save Pairs to CSV
# -----------------------------
def compute_and_save_all_class_pairs(datasets_dir, output_file="all_class_matching_pairs.csv"):
    """Build anomaly/normal pairs for every class directory and save them.

    Each sub-directory of ``datasets_dir`` is treated as one class. Its
    ``attributes_00.csv`` is read with ``load_dataset``, spectrograms are
    computed with ``compute_all_spectrograms`` and pairs are produced by
    ``find_best_matches_for_anomalies``. Every pair is tagged with its class
    name and all pairs are written to one CSV.

    Args:
        datasets_dir: Directory whose sub-directories are the classes.
        output_file: Destination CSV path.

    Returns:
        None. The CSV always carries the columns ``anomaly``, ``normal``,
        ``similarity``, ``method`` and ``class``, even when no pair was found.
    """
    all_pairs = []

    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order of the CSV reproducible between runs and machines.
    for class_name in sorted(os.listdir(datasets_dir)):
        class_dir = os.path.join(datasets_dir, class_name)
        if not os.path.isdir(class_dir):
            continue

        attributes_file = os.path.join(class_dir, "attributes_00.csv")
        try:
            # Only this call is expected to raise FileNotFoundError.
            file_paths, labels = load_dataset(attributes_file, class_dir, class_name)
        except FileNotFoundError:
            print(f"Attributes file not found for class: {class_name}")
            continue

        # Split the class's files by label.
        anomaly_paths = [path for path, label in zip(file_paths, labels) if label == 'anomaly']
        normal_paths = [path for path, label in zip(file_paths, labels) if label == 'normal']

        # Compute every spectrogram once up front so the pairwise comparison
        # does not reload audio.
        spectrograms = compute_all_spectrograms(anomaly_paths + normal_paths)

        # Find the most similar normal file for each anomaly file.
        for pair in find_best_matches_for_anomalies(anomaly_paths, normal_paths, spectrograms):
            pair["class"] = class_name
            all_pairs.append(pair)

    # Explicit columns keep the header present even when `all_pairs` is empty.
    columns = ["anomaly", "normal", "similarity", "method", "class"]
    all_pairs_df = pd.DataFrame(all_pairs, columns=columns)
    all_pairs_df.to_csv(output_file, index=False)
    print(f"All class matching pairs saved to {output_file}")

# -----------------------------
# Load Matching Pairs from Saved CSV File
# -----------------------------
def load_all_class_pairs(output_file="all_class_matching_pairs.csv"):
    """Read the saved pair table and group its rows by class.

    Args:
        output_file: Path of the CSV to read. It is expected to have the
            columns ``class``, ``anomaly``, ``normal``, ``method`` and
            ``similarity``.

    Returns:
        Dict mapping each class name to a list of dicts with the keys
        ``"anomaly"``, ``"normal"``, ``"method"`` and ``"similarity"``, in CSV
        row order. An empty dict is returned, after printing a hint, when
        ``output_file`` does not exist.
    """
    # Guard clause: nothing can be loaded until the CSV has been written.
    if not os.path.exists(output_file):
        print(f"No matching pairs file found at {output_file}. Please run compute_and_save_all_class_pairs first.")
        return {}

    grouped = {}
    for _, record in pd.read_csv(output_file).iterrows():
        key = record['class']
        entry = {
            "anomaly": record['anomaly'],
            "normal": record['normal'],
            "method": record["method"],
            "similarity": record["similarity"]
        }
        # setdefault creates the per-class list the first time a class is seen.
        grouped.setdefault(key, []).append(entry)

    print("All class matching pairs loaded from CSV.")
    return grouped

# -----------------------------
# Example Usage
# -----------------------------
# Where the class directories live and where the pair table is written.
datasets_dir = "../../datasets/dev"
output_file = "all_class_matching_pairs.csv"

# Build the pair table once, then read it back grouped by class.
compute_and_save_all_class_pairs(datasets_dir, output_file)
all_class_pairs = load_all_class_pairs(output_file)

class_name = "gearbox"  # class whose pairs are printed
if class_name not in all_class_pairs:
    print(f"No pairs found for class: {class_name}")
else:
    for pair in all_class_pairs[class_name]:
        print(f"Class: {class_name}, Anomaly: {pair['anomaly']}, Normal: {pair['normal']}, Method: {pair['method']}, Similarity: {pair['similarity']}")
