In [None]:
import os
import torchaudio
import torchaudio.transforms as T
import torch

# --- Configuration ---------------------------------------------------------
target_sr = 16000                      # sample rate (Hz) every clip is stored at
min_duration = 5                       # clip length in seconds
min_length = target_sr * min_duration  # clip length in samples

name = "music"
target_dir = f"./{name}"
wav_tensor_list = []

# One resampler per source sample rate, created lazily and reused.
resamplers = {}  # orig_sr -> resampler


def split_fixed_clips(waveform, clip_length, dtype=torch.float16):
    """Cut a 1-D waveform into consecutive, non-overlapping fixed-size clips.

    Args:
        waveform: 1-D tensor of samples.
        clip_length: number of samples per clip.
        dtype: dtype the returned clips are stored in.

    Returns:
        A list of ``len(waveform) // clip_length`` tensors, each exactly
        ``clip_length`` samples long. A trailing remainder shorter than
        ``clip_length`` is discarded; an input shorter than ``clip_length``
        yields an empty list.
    """
    n_clips = waveform.shape[0] // clip_length
    # copy=True gives every clip its own storage. A plain slice is a view of
    # the whole file, and torch.save would then serialise the full waveform
    # behind every clip.
    return [
        waveform[k * clip_length:(k + 1) * clip_length].to(dtype=dtype, copy=True)
        for k in range(n_clips)
    ]


for root, dirs, files in os.walk(target_dir):
    # os.walk order depends on the file system; sort so the clip order (and
    # therefore any later seeded shuffle) is the same on every run.
    dirs.sort()
    for fname in sorted(files):
        if not fname.lower().endswith(".wav"):
            continue
        path = os.path.join(root, fname)
        try:
            waveform, sr = torchaudio.load(path)
            # Do all signal processing in float32; only the stored clips are
            # reduced to float16. Resample caches a float32 kernel, so feeding
            # it half-precision input is a dtype mismatch.
            waveform = waveform.to(torch.float32)

            # Down-mix to mono.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            waveform = waveform.squeeze(0)

            # Resample only when the source rate differs from the target.
            if sr != target_sr:
                if sr not in resamplers:
                    resamplers[sr] = T.Resample(orig_freq=sr, new_freq=target_sr)
                waveform = resamplers[sr](waveform)

            # Files shorter than one clip contribute nothing; a file that is
            # exactly one clip long contributes one clip.
            wav_tensor_list.extend(split_fixed_clips(waveform, min_length))

        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(f"오류: {path}, {e}")

print(f"총 {len(wav_tensor_list)}개의 5초짜리 wav 텐서를 저장했습니다.")


총 30361개의 5초짜리 wav 텐서를 저장했습니다.


In [None]:
import os
import torchaudio
import torchaudio.transforms as T
import torch
import random

# --- Configuration ---------------------------------------------------------
target_sr = 16000                  # sample rate (Hz) every clip is stored at
min_sec = 5                        # shortest clip, in seconds
max_sec = 8                        # longest clip, in seconds
min_length = target_sr * min_sec   # shortest clip, in samples
max_length = target_sr * max_sec   # longest clip, in samples
name = "music"
target_dir = f"./{name}"
wav_tensor_list = []

# One resampler per source sample rate, created lazily and reused.
resamplers = {}  # orig_sr -> resampler

# Dedicated, seeded generator so the random clip boundaries are reproducible
# and do not depend on (or disturb) the global `random` state.
clip_rng = random.Random(42)


def split_random_clips(waveform, sample_rate, min_sec, max_sec, rng=random):
    """Cut a 1-D waveform into consecutive clips of random duration.

    Clip durations are drawn uniformly from ``[min_sec, max_sec]`` seconds.
    Cutting stops when the next drawn clip would run past the end of the
    waveform; whatever is left is kept as one final clip if it is at least
    ``min_sec`` long, and discarded otherwise.

    Args:
        waveform: 1-D tensor of samples.
        sample_rate: samples per second of ``waveform``.
        min_sec: minimum clip duration in seconds.
        max_sec: maximum clip duration in seconds.
        rng: object providing ``uniform(a, b)``; defaults to the ``random`` module.

    Returns:
        A list of independent (cloned) tensors covering the waveform from its
        start, in order, without overlap.
    """
    shortest = sample_rate * min_sec
    num_samples = waveform.shape[0]
    clips = []
    start_idx = 0
    while start_idx + shortest <= num_samples:
        clip_len = int(rng.uniform(min_sec, max_sec) * sample_rate)
        end_idx = start_idx + clip_len
        if end_idx > num_samples:
            break
        # clone(): a slice is a view of the whole file, and torch.save would
        # serialise the full waveform behind every clip.
        clips.append(waveform[start_idx:end_idx].clone())
        start_idx = end_idx

    # Keep the tail as well when it is still at least the minimum length.
    if num_samples - start_idx >= shortest:
        clips.append(waveform[start_idx:].clone())
    return clips


for root, dirs, files in os.walk(target_dir):
    # Sort so files are visited in a stable order; otherwise the seeded
    # generator would still produce different clips on different machines.
    dirs.sort()
    for fname in sorted(files):
        if not fname.lower().endswith(".wav"):
            continue
        path = os.path.join(root, fname)
        try:
            waveform, sr = torchaudio.load(path)

            # Down-mix to mono.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            waveform = waveform.squeeze(0)

            # Resample only when the source rate differs from the target.
            if sr != target_sr:
                if sr not in resamplers:
                    resamplers[sr] = T.Resample(orig_freq=sr, new_freq=target_sr)
                waveform = resamplers[sr](waveform)

            # Files shorter than min_sec produce no clips.
            wav_tensor_list.extend(
                split_random_clips(waveform, target_sr, min_sec, max_sec, rng=clip_rng)
            )

        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(f"오류: {path}, {e}")

print(f"총 {len(wav_tensor_list)}개의 랜덤 길이(5~8초) wav 텐서를 저장했습니다.")


In [2]:
import random
# Shuffle a copy using a private generator with a fixed seed (for
# reproducibility). Shuffling `wav_tensor_list` in place would make this cell
# non-idempotent: re-running it would shuffle an already shuffled list and
# give a different split even though the seed is fixed.
shuffled_clips = list(wav_tensor_list)
random.Random(42).shuffle(shuffled_clips)

# Split 90 % / 10 % into training and validation sets.
n_total = len(shuffled_clips)
n_train = int(n_total * 0.9)

train_list = shuffled_clips[:n_train]
valid_list = shuffled_clips[n_train:]

print(f"학습 데이터 개수: {len(train_list)}")
print(f"검증 데이터 개수: {len(valid_list)}")

학습 데이터 개수: 27324
검증 데이터 개수: 3037


In [3]:
# Write each split to its own .pt file in the working directory; the file
# names carry the dataset name.
train_path = f"train_{name}.pt"
eval_path = f"eval_{name}.pt"

torch.save(train_list, train_path)
torch.save(valid_list, eval_path)