# In [None]:
import functools
import logging
import math

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Root logging at INFO and above, plus a module-scoped logger used for
# error reporting in this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        dropout: Dropout probability applied after adding the encodings.
            Dropout is only active in training mode; call ``.eval()`` for
            deterministic inference.
        max_len: Number of positions precomputed in the ``pe`` buffer.
            Longer inputs are still accepted; their table is built on the fly.
        batch_first: If False (default), input is ``(seq, batch, d_model)``;
            if True, input is ``(batch, seq, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False):
        super().__init__()
        self.d_model = d_model
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        # Shape (max_len, 1, d_model) so it broadcasts over the batch dim of
        # sequence-first input. A buffer follows .to(device) and is saved in
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', self._sinusoid_table(max_len, d_model))

    @staticmethod
    def _sinusoid_table(length, d_model):
        """Return the (length, 1, d_model) sin/cos table for positions 0..length-1."""
        position = torch.arange(length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(length, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than frequencies;
        # trim div_term so the cos assignment shapes match.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, x):
        """Return ``dropout(x + pe)``, indexing encodings along the sequence axis."""
        seq_len = x.size(1 if self.batch_first else 0)
        if seq_len <= self.pe.size(0):
            pe = self.pe[:seq_len]
        else:
            # Beyond the precomputed table: build a long-enough one, matching
            # the buffer's dtype and device.
            pe = self._sinusoid_table(seq_len, self.d_model).to(self.pe)
        if self.batch_first:
            # (seq, 1, d_model) -> (1, seq, d_model) to broadcast over batch.
            pe = pe.transpose(0, 1)
        return self.dropout(x + pe)

class AudioPreprocessor:
    """Turn an audio file into a single fixed-size Wav2Vec2 embedding.

    Pipeline: load -> down-mix to mono -> resample to 16 kHz -> Wav2Vec2
    processor -> Wav2Vec2Model -> add sinusoidal position encodings along the
    time axis -> mean-pool over time.
    """

    # Sampling rate the waveform is resampled to and reported to the processor.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self, model_name="facebook/wav2vec2-base"):
        """Load the processor and model, and build the positional encoder.

        Args:
            model_name: Checkpoint name passed to both ``from_pretrained``
                calls. Defaults to the previously hard-coded
                ``facebook/wav2vec2-base``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2Model.from_pretrained(model_name).to(self.device)
        self.model.eval()
        # Size the encoder from the checkpoint (768 for wav2vec2-base) instead
        # of hard-coding it. eval() turns off the encoder's dropout:
        # torch.no_grad() does not. Without it, every embedding was randomly
        # masked/rescaled and not reproducible.
        self.positional_encoder = PositionalEncoding(
            d_model=self.model.config.hidden_size
        ).to(self.device).eval()

    def _load_waveform(self, audio_path):
        """Load *audio_path* as a 1-D mono waveform at ``TARGET_SAMPLE_RATE``."""
        waveform, sample_rate = torchaudio.load(audio_path)

        # Down-mix multi-channel audio by averaging over the channel dim (0).
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != self.TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, self.TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        # Drop only the channel dim. A bare squeeze() would also collapse a
        # 1-sample clip into a 0-d tensor.
        return waveform.squeeze(0)

    def _add_positional_encoding(self, features):
        """Add position encodings along the time axis of (batch, frames, hidden) features."""
        num_frames = features.size(1)
        if num_frames > self.positional_encoder.pe.size(0):
            # The encoder's precomputed table covers only max_len positions.
            # Rebuild it large enough rather than failing on long recordings.
            self.positional_encoder = PositionalEncoding(
                d_model=features.size(-1), max_len=num_frames
            ).to(self.device).eval()
        # PositionalEncoding indexes its table by dim 0 (sequence-first input).
        # Passing batch-first features directly would encode the batch index
        # (always 0 here) instead of each frame's position, so swap axes.
        encoded = self.positional_encoder(features.transpose(0, 1))
        return encoded.transpose(0, 1)

    def get_embedding(self, audio_path):
        """Process an audio file and return its time-averaged embedding.

        Args:
            audio_path: Anything ``torchaudio.load`` accepts (e.g. a file path).

        Returns:
            A 1-D NumPy array of length ``model.config.hidden_size`` (768 for
            wav2vec2-base): the mean over frames of the position-encoded
            last hidden state.

        Raises:
            Any exception from loading, preprocessing or inference is logged
            (with traceback) and re-raised unchanged.
        """
        try:
            waveform = self._load_waveform(audio_path)

            inputs = self.processor(
                waveform.numpy(),
                sampling_rate=self.TARGET_SAMPLE_RATE,
                return_tensors="pt",
                padding=True,
            )
            input_values = inputs.input_values.to(self.device)

            with torch.no_grad():
                # Batch-first hidden states: (batch, frames, hidden).
                features = self.model(input_values).last_hidden_state
                features = self._add_positional_encoding(features)
                # Average over frames; squeeze drops the batch dim of 1.
                features = features.mean(dim=1).squeeze().cpu().numpy()

            return features

        except Exception as e:
            logger.exception("오디오 처리 중 오류 발생: %s", e)
            raise

@functools.lru_cache(maxsize=1)
def _default_preprocessor():
    """Return a process-wide AudioPreprocessor, created on first use."""
    return AudioPreprocessor()


def extract_audio_features(audio_path):
    """Return the Wav2Vec2 embedding of the audio at *audio_path*.

    Reuses one cached AudioPreprocessor, so repeated calls do not reload the
    processor and model with ``from_pretrained`` every time. The cached model
    stays in memory (on GPU, if one is available) for the life of the process.
    """
    return _default_preprocessor().get_embedding(audio_path)