In [46]:
from typing import List, Tuple, Union
import torch
import numpy as np
from nnAudio import Spectrogram
import torchaudio.compliance.kaldi as ta_kaldi
from spafe.fbanks import gammatone_fbanks
import librosa
class Cqt(torch.nn.Module):
    """Convert Raw audio to CQT"""
    def __init__(
        self,
        fs: int = 16000,
        n_bins: int = 80,
        bins_per_octave: int = 12,
        hop_length: int = 160,
    ):
        super().__init__()
        self.extract_cqt = Spectrogram.CQT(sr=fs, n_bins=n_bins, bins_per_octave=bins_per_octave, hop_length=hop_length)

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        output = []
        output_lens = []
        for i, instance in enumerate(feat):
            cqt = self.extract_cqt(feat[i].unsqueeze(0)).squeeze().T
            output.append(cqt)
            output_lens.append(cqt.shape[0])
        cqt_feat = torch.stack(output, 0)
        output_lens = torch.tensor(output_lens).to(cqt_feat.device)
        return cqt_feat, output_lens


class Mfcc(torch.nn.Module):
    """Convert Raw audio to MFCC"""
    def __init__(self, fs: int = 16000, n_mfcc: int = 80, n_fft: int = 400, hop_length: int = 160):
        super().__init__()
        self.fs = fs
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        def extract_mfcc(wav):
            mfcc = librosa.feature.mfcc(y=wav, sr=self.fs, n_mfcc=self.n_mfcc, n_fft=self.n_fft, hop_length=self.hop_length)
            return mfcc.T

        output = []
        output_lens = []
        for i, instance in enumerate(feat):
            mfcc = extract_mfcc(feat[i].cpu().numpy())
            output.append(torch.Tensor(mfcc))
            output_lens.append(mfcc.shape[0])
        mfcc_feat = torch.stack(output, 0)
        output_lens = torch.tensor(output_lens).to(mfcc_feat.device)
        return mfcc_feat, output_lens


class Gamma(torch.nn.Module):
    """Convert Raw audio to GAMMA"""
    def __init__(self, fs: int = 16000, n_filts: int = 80, n_fft: int = 400, hop_length: int = 160):
        super().__init__()
        self.fs = fs
        self.n_filts = n_filts
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        def extract_gamma(wav):
            gammatone_filter_bank = gammatone_fbanks.gammatone_filter_banks(nfilts=self.n_filts, nfft=self.n_fft, fs=self.fs)
            y = librosa.util.normalize(wav)
            magnitude = np.abs(librosa.stft(y=y, win_length=self.n_fft, n_fft=self.n_fft, hop_length=self.hop_length)) ** 2
            Gam = np.dot(gammatone_filter_bank[0], magnitude)
            LogGamSpec = librosa.power_to_db(Gam, ref=np.max)
            return LogGamSpec.T

        output = []
        output_lens = []
        for i, instance in enumerate(feat):
            gamma = extract_gamma(feat[i].cpu().numpy())
            output.append(torch.Tensor(gamma))
            output_lens.append(gamma.shape[0])
        gamma_feat = torch.stack(output, 0)
        output_lens = torch.tensor(output_lens).to(gamma_feat.device)
        return gamma_feat, output_lens


class LogMel(torch.nn.Module):
    """Convert Raw audio to LOGMEL using librosa"""
    def __init__(self, fs: int = 16000, n_mels: int = 80, n_fft: int = 400, hop_length: int = 160, fmin: float = 0.0, fmax: float = None):
        super().__init__()
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax if fmax is not None else fs // 2

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        def extract_mel(wav):
            wav_tensor = torch.tensor(wav).unsqueeze(0)
            return mel_spectrogram_transform(wav_tensor).T

        output=[]
        output_lens=[]
        for i, instance in enumerate(feat):
            mel = extract_mel(feat[i].cpu().numpy())
            output.append(torch.Tensor(mel))
            output_lens.append(mel.shape[0])
        #Hardcode again
        logmel_feat = torch.stack(output,0).cuda()
        output_lens = logmel_feat.new_full(
                [logmel_feat.size(0)], fill_value=logmel_feat.size(1), dtype=torch.long
        )
        return logmel_feat, output_lens.cuda()

import torchaudio.compliance.kaldi as ta_kaldi
import torch
from typing import Tuple

class LogMelTorchaudio(torch.nn.Module):
    """Convert Raw audio to LOGMEL using torchaudio"""
    def __init__(self, fs: int = 16000, n_mels: int = 80, n_fft: int = 400, hop_length: int = 160, fmin: float = 0.0, fmax: float = None):
        super().__init__()
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax if fmax is not None else fs // 2

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        def extract_mel(wav):
            wav_tensor = torch.tensor(wav).unsqueeze(0)
            features = ta_kaldi.fbank(
                wav_tensor,
                num_mel_bins=self.n_mels,
                sample_frequency=self.fs,
                frame_length=self.n_fft / self.fs * 1000,  # Convert to milliseconds
                frame_shift=self.hop_length / self.fs * 1000,  # Convert to milliseconds
                window_type='hamming',
                use_energy=True,
                dither=0.0,
                energy_floor=0.0,
                use_log_fbank=True,
                use_power=True,
                snip_edges=False,
            )
            return features.numpy()

        output = []
        output_lens = []
        for i, instance in enumerate(feat):
            mel = extract_mel(feat[i].cpu().numpy())
            output.append(torch.Tensor(mel))
            output_lens.append(mel.shape[0])
        logmel_feat = torch.stack(output, 0)
        output_lens = torch.tensor(output_lens).to(logmel_feat.device)
        return logmel_feat, output_lens

    
    
    
class FusedFeatureExtractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.logmel_extractor = LogMel()
        self.mfcc_extractor = Mfcc()
        self.cqt_extractor = Cqt()
        self.gamma_extractor = Gamma()
        self.logmeltorchaudio=LogMelTorchaudio()

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        logmel_feat, logmel_lens = self.logmel_extractor(input, input_lengths)
        mfcc_feat, mfcc_lens = self.mfcc_extractor(input, input_lengths)
        cqt_feat, cqt_lens = self.cqt_extractor(input, input_lengths)
        gamma_feat, gamma_lens = self.gamma_extractor(input, input_lengths)
        logmelTA_feat, logmel_lensTA = self.logmeltorchaudio(input, input_lengths)
        print(logmel_lens)
        print(mfcc_lens)
        print(cqt_lens)
        print(gamma_lens)
        print(logmel_lensTA)
        min_length = min(logmel_lens.min(), mfcc_lens.min(), cqt_lens.min(), gamma_lens.min())

        logmel_feat = logmel_feat[:, :min_length, :]
        mfcc_feat = mfcc_feat[:, :min_length, :]
        cqt_feat = cqt_feat[:, :min_length, :]
        gamma_feat = gamma_feat[:, :min_length, :]

        return [logmel_feat, mfcc_feat, cqt_feat, gamma_feat], logmel_lens.new_full((logmel_feat.size(0),), min_length)


# Example usage
fused_feature_extractor = FusedFeatureExtractor()

input_tensor = torch.randn(8, 16000)  # Example input tensor
input_lengths = torch.tensor([16000] * 8)  # Example input lengths

# Extract synchronized features
synchronized_features, synchronized_lengths = fused_feature_extractor(input_tensor, input_lengths)


Creating CQT kernels ...CQT kernels created, time used = 0.0076 seconds


  return mel_spectrogram_transform(wav_tensor).T


tensor([101, 101, 101, 101, 101, 101, 101, 101], device='cuda:0')
tensor([101, 101, 101, 101, 101, 101, 101, 101])
tensor([101, 101, 101, 101, 101, 101, 101, 101])
tensor([101, 101, 101, 101, 101, 101, 101, 101])
tensor([100, 100, 100, 100, 100, 100, 100, 100])


In [16]:
for i in [32452,46346,25667,422]:
    input_tensor = torch.randn(8, i)  # Example input tensor
    input_lengths = torch.tensor([i] * 8)  # Example input lengths

    # Extract synchronized features
    synchronized_features, synchronized_lengths = fused_feature_extractor(input_tensor, input_lengths)

tensor([201, 201, 201, 201, 201, 201, 201, 201])
tensor([203, 203, 203, 203, 203, 203, 203, 203])
tensor([203, 203, 203, 203, 203, 203, 203, 203])
tensor([203, 203, 203, 203, 203, 203, 203, 203])
tensor([288, 288, 288, 288, 288, 288, 288, 288])
tensor([290, 290, 290, 290, 290, 290, 290, 290])
tensor([290, 290, 290, 290, 290, 290, 290, 290])
tensor([290, 290, 290, 290, 290, 290, 290, 290])
tensor([158, 158, 158, 158, 158, 158, 158, 158])
tensor([161, 161, 161, 161, 161, 161, 161, 161])
tensor([161, 161, 161, 161, 161, 161, 161, 161])
tensor([161, 161, 161, 161, 161, 161, 161, 161])


RuntimeError: Argument #4: Padding size should be less than the corresponding input dimension, but got: padding (8192, 8192) at dimension 2 of input [1, 1, 422]

In [44]:
import torch
import torchaudio
import torchaudio.transforms as transforms

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load an example audio file
waveform, sample_rate = torchaudio.load('/home/marin/Desktop/TEST2.wav')
waveform = torch.randn(1,18000)
waveform = waveform.to(device)

# Define the Mel Spectrogram transform
mel_spectrogram_transform = transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_mels=80,
    n_fft=400,
    hop_length=160,
).to('cpu')

# Define the MFCC transform
mfcc_transform = transforms.MFCC(
    sample_rate=sample_rate,
    n_mfcc=80,
    melkwargs={'n_mels': 80, 'n_fft': 400, 'hop_length': 160}
).to(device)

# Apply the transforms
mel_spectrogram = mel_spectrogram_transform(waveform)
mfcc = mfcc_transform(waveform)

print("Mel Spectrogram shape:", mel_spectrogram.shape)
print("MFCC shape:", mfcc.shape)


RuntimeError: stft input and window must be on the same device but got self on cuda:0 and window on cpu