In [None]:
# Mount Google Drive at /content/drive so the project folder under MyDrive
# (used by the next cell) is reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Work from the project's source directory so the relative paths used below
# (./voices/..., ./model/...) resolve. os.chdir is used instead of the bare `cd`
# automagic, which silently stops working when automagic is off or when a
# variable named `cd` exists in the namespace.
import os

PROJECT_SOURCE_DIR = '/content/drive/MyDrive/IAP_Final/PC/source'
os.chdir(PROJECT_SOURCE_DIR)
os.getcwd()

/content/drive/.shortcut-targets-by-id/10b_mTeA0LBn0XnIPn1kDV03ig94XXR0w/IAP_Final/PC/source


In [None]:
import torch
import torch.nn as nn

class VGGish(nn.Module):
    """VGG-style convolutional encoder that maps a log-mel patch to a 128-d embedding.

    ``features`` holds four conv stages (64, 128, 256x2 and 512x2 channels, each
    conv 3x3 with padding 1 followed by ReLU), and every stage ends in a 2x2,
    stride-2 max-pool. ``classifier`` flattens and applies 512*6*4 -> 4096 -> 128.
    The 6x4 spatial size assumed by the head is what a 96x64 input leaves after
    four stride-2 pools (96/16 = 6, 64/16 = 4).
    """

    def __init__(self):
        super().__init__()

        # Channel plan for `features`: an int adds Conv2d(3x3, padding=1) + ReLU
        # with that many output channels; "M" adds a 2x2 / stride-2 max-pool.
        conv_plan = (64, "M", 128, "M", 256, 256, "M", 512, 512, "M")

        stages = []
        in_channels = 1
        for step in conv_plan:
            if step == "M":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stages += [nn.Conv2d(in_channels, step, kernel_size=3, padding=1), nn.ReLU()]
            in_channels = step
        self.features = nn.Sequential(*stages)

        flat_dim = in_channels * 6 * 4  # 512 channels x 6 x 4 after four 2x pools
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 4096),
            nn.ReLU(),
            nn.Linear(4096, 128),  # 128-d output embedding
        )

    def forward(self, x):
        """Embed a batch of log-mel patches.

        Args:
            x: tensor of shape [B, 1, 96, 64].

        Returns:
            Tensor of shape [B, 128].
        """
        return self.classifier(self.features(x))


In [None]:
import torch
import torchaudio
import time
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

# 🔁 Build the VGGish network defined above and switch it to inference mode.
# No checkpoint is loaded here, so the weights are randomly initialised; in this
# notebook the model is only used to measure latency.
model = VGGish().eval()

# 🔁 Audio preprocessing function (includes fixing the patch to 96 x 64)
def preprocess_wav_to_vggish_input(wav_path):
    """Load an audio file and return one log-mel patch shaped for the VGGish model.

    Pipeline:
    1. Resample to 16 kHz.
    2. Downmix to mono.
    3. Compute a 64-band mel power spectrogram (25 ms window / 10 ms hop at
       16 kHz, 125-7500 Hz band as in the VGGish frontend).
    4. Crop or zero-pad the time axis to exactly 96 frames.
    5. Convert power to dB.

    Args:
        wav_path: path to an audio file readable by ``torchaudio.load``.

    Returns:
        Float tensor of shape [1, 1, 96, 64] (batch, channel, time, mel).
    """
    target_sr = 16000
    num_frames = 96

    waveform, sr = torchaudio.load(wav_path)

    # Resample to 16 kHz.
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)

    # Downmix to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Mel settings chosen to resemble the VGGish spec. torchaudio spells the band
    # limits `f_min` / `f_max` (not `fmin` / `fmax`), so they *are* supported.
    mel_transform = MelSpectrogram(
        sample_rate=target_sr,
        n_fft=400,
        hop_length=160,
        win_length=400,
        n_mels=64,
        f_min=125.0,
        f_max=7500.0,
    )
    db_transform = AmplitudeToDB()

    mel_spec = mel_transform(waveform)  # [1, 64, T], power

    # Fix the time axis to 96 frames *before* converting to dB. Zero power maps to
    # the dB floor (silence). Zero-padding after the conversion would instead
    # append a constant 0 dB (power 1.0) tail. AmplitudeToDB is element-wise here
    # (top_db=None), so cropping before or after it yields identical values.
    n_time = mel_spec.shape[2]
    if n_time < num_frames:
        mel_spec = torch.nn.functional.pad(mel_spec, (0, num_frames - n_time))
    else:
        mel_spec = mel_spec[:, :, :num_frames]

    log_mel_spec = db_transform(mel_spec)  # [1, 64, 96]

    # [1, 64, 96] -> [1, 96, 64] -> [1, 1, 96, 64]
    return log_mel_spec.permute(0, 2, 1).unsqueeze(0)

# 🔁 Inference-time measurement function (PyTorch model)
def measure_inference_time(wav_path, repeat=5):
    """Average latency of one forward pass of the global ``model`` on a WAV clip.

    The clip is preprocessed once, outside the timed region. One untimed warm-up
    pass is followed by ``repeat`` timed passes, all under ``torch.no_grad()``.

    Args:
        wav_path: audio file passed to ``preprocess_wav_to_vggish_input``.
        repeat: number of timed forward passes; must be >= 1.

    Returns:
        Mean latency in seconds (float). The value is also printed.

    Raises:
        ValueError: if ``repeat`` is less than 1.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    input_tensor = preprocess_wav_to_vggish_input(wav_path)

    times = []
    with torch.no_grad():
        # Warm-up pass, excluded from the average so one-off setup costs (if any)
        # are not counted.
        model(input_tensor)

        for _ in range(repeat):
            # perf_counter is monotonic and high-resolution; time.time() is neither
            # guaranteed monotonic nor meant for benchmarking short intervals.
            start = time.perf_counter()
            model(input_tensor)
            times.append(time.perf_counter() - start)

    avg_time = sum(times) / repeat
    print(f"Inference time for {wav_path}: {avg_time:.4f} seconds")
    return avg_time

# 🔁 Example run: time the PyTorch model on one recorded user clip.
wav_path = "./voices/user_voice_clips/loaded_musics/cliped_user (1).wav"
measure_inference_time(wav_path, repeat=5)

Inference time for ./voices/user_voice_clips/loaded_musics/cliped_user (1).wav: 0.0763 seconds


0.07627973556518555

In [None]:
pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m50.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected pack

In [None]:
def measure_onnx_inference_time(wav_path, onnx_path, repeat=5):
    """Average latency of one ONNX Runtime forward pass on a preprocessed WAV clip.

    The clip is preprocessed and the ONNX session is created before timing
    starts. One untimed warm-up run is followed by ``repeat`` timed runs on the
    CPU execution provider.

    Args:
        wav_path: audio file passed to ``preprocess_wav_to_vggish_input``.
        onnx_path: path to the ``.onnx`` model file.
        repeat: number of timed runs; must be >= 1.

    Returns:
        Mean latency in seconds (float). The value is also printed.

    Raises:
        ValueError: if ``repeat`` is less than 1.
    """
    # Imported locally so this cell does not rely on `np` / `ort` having been
    # defined by some other cell (neither is imported elsewhere in the notebook).
    import numpy as np
    import onnxruntime as ort

    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    input_tensor = preprocess_wav_to_vggish_input(wav_path)  # [1, 1, 96, 64]

    # Model's expected input: [1, 64, 96] (mel x time). squeeze/permute gives a
    # non-contiguous view, so make an explicit C-contiguous float32 copy.
    input_numpy = np.ascontiguousarray(
        input_tensor.squeeze(1).permute(0, 2, 1).numpy(), dtype=np.float32
    )

    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feeds = {ort_session.get_inputs()[0].name: input_numpy}

    # Warm-up run, excluded from the average.
    ort_session.run(None, feeds)

    times = []
    for _ in range(repeat):
        # perf_counter: monotonic, high-resolution clock intended for benchmarking.
        start = time.perf_counter()
        ort_session.run(None, feeds)
        times.append(time.perf_counter() - start)

    avg_time = sum(times) / repeat
    print(f"[ONNX] Inference time for {wav_path}: {avg_time:.4f} seconds")
    return avg_time


In [None]:
# Time the ONNX VGGish model on the same user clip used for the PyTorch run.
onnx_model_path = "./model/VGGish/audioset-vggish-3.onnx"
wav_path = "./voices/user_voice_clips/loaded_musics/cliped_user (1).wav"

measure_onnx_inference_time(wav_path=wav_path, onnx_path=onnx_model_path)

[ONNX] Inference time for ./voices/user_voice_clips/loaded_musics/cliped_user (1).wav: 0.0459 seconds


0.04587597846984863