In [3]:
# 필요한 라이브러리 설치
!pip install asteroid torch torchaudio soundfile tqdm matplotlib thop

Collecting asteroid
  Using cached asteroid-0.7.0-py3-none-any.whl.metadata (11 kB)
Collecting thop
  Using cached thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Collecting asteroid-filterbanks>=0.4.0 (from asteroid)
  Using cached asteroid_filterbanks-0.4.0-py3-none-any.whl.metadata (3.3 kB)
Collecting pytorch-lightning>=2.0.0 (from asteroid)
  Using cached pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics<=0.11.4 (from asteroid)
  Using cached torchmetrics-0.11.4-py3-none-any.whl.metadata (15 kB)
Collecting pb-bss-eval>=0.0.2 (from asteroid)
  Using cached pb_bss_eval-0.0.2-py3-none-any.whl.metadata (3.1 kB)
Collecting torch-stoi>=0.1.2 (from asteroid)
  Using cached torch_stoi-0.2.3-py3-none-any.whl.metadata (3.6 kB)
Collecting torch-optimizer<0.2.0,>=0.0.1a12 (from asteroid)
  Using cached torch_optimizer-0.1.0-py3-none-any.whl.metadata (53 kB)
Collecting julius (from asteroid)
  Using cached julius-0.2.7-py3-none-any.whl
Collecting cac

In [4]:
!pip install torchinfo ptflops

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting ptflops
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Downloading ptflops-0.7.4-py3-none-any.whl (19 kB)
Installing collected packages: torchinfo, ptflops
Successfully installed ptflops-0.7.4 torchinfo-1.8.0


In [1]:
import os
import torch
import soundfile as sf
from asteroid.models import DPTNet
# from asteroid.losses import SI_SNR
from ptflops import get_model_complexity_info

In [2]:
# SI-SNR 계산 함수 (직접 구현)
def calculate_si_snr(estimation, reference):
    """
    Args:
        estimation (torch.Tensor): 추정된 오디오 (B, T)
        reference (torch.Tensor): 참조 오디오 (B, T)

    Returns:
        si_snr (torch.Tensor): SI-SNR 값 (B,)
    """
    reference = reference - reference.mean(dim=-1, keepdim=True)
    estimation = estimation - estimation.mean(dim=-1, keepdim=True)
    # Inner product
    s_target = torch.sum(estimation * reference, dim=-1, keepdim=True) * reference / torch.sum(reference**2, dim=-1, keepdim=True)
    e_noise = estimation - s_target
    si_snr = 10 * torch.log10(torch.sum(s_target**2, dim=-1) / torch.sum(e_noise**2, dim=-1))
    return si_snr


In [3]:
# 데이터셋 로드 함수
def load_dataset(dataset_path):
    """
    Args:
        dataset_path (str): 데이터셋의 최상위 폴더 경로 (/test)

    Returns:
        file_list (list): 데이터 파일 리스트 [(mixed_path, [s1_path, s2_path]), ...]
    """
    mixed_path = os.path.join(dataset_path, "mixed_data")
    source_paths = [os.path.join(dataset_path, "s1"), os.path.join(dataset_path, "s2")]

    file_list = []
    for filename in os.listdir(mixed_path):
        if filename.endswith(".wav"):
            mixed_file = os.path.join(mixed_path, filename)
            sources = [os.path.join(source_path, filename) for source_path in source_paths]
            file_list.append((mixed_file, sources))
    return file_list

In [4]:
# 모델 초기화
def initialize_model():
    model = DPTNet(n_src=2)  # 2개의 소스 분리
    model.eval()
    return model

In [5]:
from asteroid.metrics import get_metrics
def calculate_metrics(estimates, references, mixture):
    """
    Args:
        estimates (Tensor): 추정된 소스 신호 (N, T)
        references (list of torch.Tensor): 정답 소스 신호 리스트
        mixture (torch.Tensor): 혼합 신호

    Returns:
        si_snri (float): SI-SNRi 값
        sdri (float): SDRi 값
    """
    # estimates 텐서 변환: (B, N, T) -> (N, T)
    estimates = estimates.squeeze(0)  # (N, T)
    print(f"Adjusted estimates shape: {estimates.shape}")

    # 모델 출력 길이에 참조 신호 길이 맞추기
    ref_len = estimates.size(1)
    references = [
        torch.tensor(ref[:ref_len], dtype=torch.float32) if len(ref) > ref_len
        else torch.nn.functional.pad(torch.tensor(ref, dtype=torch.float32), (0, ref_len - len(ref)))
        for ref in references
    ]
    references = torch.stack(references)  # (N, T)

    # 혼합 신호 길이 맞추기
    mixture = mixture[:, :ref_len]  # (1, T)

    # SI-SNR 계산
    si_snr_mixture = calculate_si_snr(mixture.repeat(len(references), 1), references).mean().item()
    si_snr_estimates = calculate_si_snr(estimates, references).mean().item()

    # SI-SNRi 및 SDRi 계산
    si_snri = si_snr_estimates - si_snr_mixture
    # SDR 계산 (Asteroid의 get_metrics 사용)
    metrics_mixture = get_metrics(
        references.numpy(),         # 정답 신호
        mixture.numpy(),            # 혼합 신호
        references.numpy(),         # 기준은 항상 정답 신호
        sample_rate=8000,
        metrics_list=["sdr"]
    )
    metrics_estimates = get_metrics(
        references.numpy(),         # 정답 신호
        estimates.numpy(),          # 추정 신호
        references.numpy(),         # 기준은 항상 정답 신호
        sample_rate=8000,
        metrics_list=["sdr"]
    )

    # SDRi 계산
    sdr_mixture = metrics_mixture["sdr"]
    sdr_estimates = metrics_estimates["sdr"]
    sdri = sdr_estimates - sdr_mixture
    return si_snri, sdri

In [7]:
# 메인 함수
def main(dataset_path):
    # 데이터 로드
    file_list = load_dataset(dataset_path)

    # 모델 초기화
    model = initialize_model()

    total_si_snri = 0
    total_sdri = 0
    total_files = len(file_list)

    print(f"총 {total_files}개의 파일을 처리합니다...")

    # 각 파일에 대해 처리
    for idx, (mixed_file, source_files) in enumerate(file_list):
        print(f"[{idx + 1}/{total_files}] Processing: {mixed_file}")

        # 혼합 오디오 및 소스 로드
        mixed_wave, _ = sf.read(mixed_file)
        target_waves = [sf.read(src)[0] for src in source_files]

        # 입력 데이터 준비
        mixture = torch.tensor(mixed_wave, dtype=torch.float32).unsqueeze(0)  # (1, T)
        references = [torch.tensor(tgt, dtype=torch.float32) for tgt in target_waves]


        # 음성 분리 수행
        with torch.no_grad():
            estimates = model(mixture)  # 모델 추정 결과 (B, T, N)

        # 디버깅: estimates 내용 확인
        output_len = estimates.size(-1)
        print(f"Model output shape: {estimates.shape}")

        mixture = mixture[:, :output_len]

        # SI-SNRi 및 SDRi 계산
        si_snri, sdri = calculate_metrics(estimates, references, mixture)
        total_si_snri += si_snri
        total_sdri += sdri

        print(f"  SI-SNRi: {si_snri:.2f} dB, SDRi: {sdri:.2f} dB")

    # 최종 결과 출력
    avg_si_snri = total_si_snri / total_files
    avg_sdri = total_sdri / total_files
    print(f"\n평균 SI-SNRi: {avg_si_snri:.2f} dB")
    print(f"평균 SDRi: {avg_sdri:.2f} dB")


In [8]:
DATASET_PATH = "/content/drive/MyDrive/Dataset/mk_1800_dataset/test"
main(DATASET_PATH)

총 180개의 파일을 처리합니다...
[1/180] Processing: /content/drive/MyDrive/Dataset/mk_1800_dataset/test/mixed_data/Animal_Animal_401.wav
Model output shape: torch.Size([1, 2, 55125])
Adjusted estimates shape: torch.Size([2, 55125])


  else torch.nn.functional.pad(torch.tensor(ref, dtype=torch.float32), (0, ref_len - len(ref)))


  SI-SNRi: -23.68 dB, SDRi: -10.02 dB
[2/180] Processing: /content/drive/MyDrive/Dataset/mk_1800_dataset/test/mixed_data/Animal_Animal_402.wav
Model output shape: torch.Size([1, 2, 220500])
Adjusted estimates shape: torch.Size([2, 220500])
  SI-SNRi: -19.45 dB, SDRi: -15.94 dB
[3/180] Processing: /content/drive/MyDrive/Dataset/mk_1800_dataset/test/mixed_data/Animal_Animal_403.wav
Model output shape: torch.Size([1, 2, 220500])
Adjusted estimates shape: torch.Size([2, 220500])
  SI-SNRi: -31.74 dB, SDRi: 2.36 dB
[4/180] Processing: /content/drive/MyDrive/Dataset/mk_1800_dataset/test/mixed_data/Animal_Animal_404.wav
Model output shape: torch.Size([1, 2, 220500])
Adjusted estimates shape: torch.Size([2, 220500])
  SI-SNRi: -24.49 dB, SDRi: 8.96 dB
[5/180] Processing: /content/drive/MyDrive/Dataset/mk_1800_dataset/test/mixed_data/Animal_Animal_405.wav
Model output shape: torch.Size([1, 2, 220500])
Adjusted estimates shape: torch.Size([2, 220500])
  SI-SNRi: -34.44 dB, SDRi: -29.47 dB
[6/180

In [9]:
from torchinfo import summary
from ptflops import get_model_complexity_info

def calculate_model_metrics(model):
    """
    Args:
        model: PyTorch 모델

    Returns:
        params (str): 모델 파라미터 수
        macs (str): 모델 MACs 수
    """
    # 입력 데이터 크기 (예: 1초 길이의 신호, 8000 Hz 샘플링 레이트 가정)
    input_size = (1, 32000)  # (Batch, Time)

    # 모델 파라미터 수 계산
    print("모델 요약:")
    summary(model, input_size=input_size)

    # MACs 수 계산
    macs, params = get_model_complexity_info(
        model,
        input_size,
        as_strings=True,
        print_per_layer_stat=False,
        verbose=False
    )
    return params, macs

# DualPathRNN 또는 DPTNet 모델 초기화
from asteroid.models import DPRNNTasNet

# 모델 생성
model = DPTNet(n_src=2)

# 파라미터와 MACs 계산
params, macs = calculate_model_metrics(model)
print(f"모델 파라미터 수: {params}")
print(f"모델 MACs 수: {macs}")

모델 요약:
모델 파라미터 수: 8.53 M
모델 MACs 수: 72.87 GMac
