In [None]:
pip install librosa

In [8]:
import torch
import torch.onnx
import numpy as np
import librosa

class AudioToMagnitude(torch.nn.Module):
    """
    A PyTorch module for computing the magnitude of the Short-Time Fourier Transform (STFT) of an audio signal.

    Args:
        hop_length (int): Hop length used in the STFT computation.
        n_fft (int): Number of Fourier coefficients used in each frame of the STFT.

    Attributes:
        hop_length (int): Hop length used in the STFT computation.
        n_fft (int): Number of Fourier coefficients used in each frame of the STFT.

    Methods:
        forward(x): Compute the magnitude of the STFT given an input audio signal.

    Example:
        # Create an instance of the AudioToMagnitude module
        model = AudioToMagnitude()

        # Load an audio file using librosa
        audio_path = "path/to/your/audio/file.wav"
        y, sr = librosa.load(audio_path)

        # Convert the audio signal to a PyTorch tensor
        input_tensor = torch.tensor(y)

        # Run the forward pass to compute the magnitude
        magnitude = model(input_tensor)
    """

    def __init__(self, hop_length=512, n_fft=2048):
        """
        Initialize the AudioToMagnitude module with specified hop_length and n_fft.

        Args:
            hop_length (int): Hop length used in the STFT computation.
            n_fft (int): Number of Fourier coefficients used in each frame of the STFT.
        """
        super(AudioToMagnitude, self).__init__()
        self.hop_length = hop_length
        self.n_fft = n_fft

    def forward(self, x):
        """
        Compute the magnitude of the Short-Time Fourier Transform (STFT) given an input audio signal.

        Args:
            x (torch.Tensor): Input audio signal as a 1D tensor.

        Returns:
            torch.Tensor: Magnitude of the STFT as a 3D tensor with shape (batch_size, frequency_bins, time_frames).

            This implementation manually computes the STFT using PyTorch operations.

        Example:
            # Create an instance of the AudioToMagnitude module
            model = AudioToMagnitude()

            # Load an audio file using librosa
            audio_path = "path/to/your/audio/file.wav"
            y, sr = librosa.load(audio_path)

            # Convert the audio signal to a PyTorch tensor
            input_tensor = torch.tensor(y)

            # Run the forward pass to compute the magnitude
            magnitude = model(input_tensor)
        """
        # Manually implement STFT using PyTorch operations
        window = torch.hann_window(self.n_fft, periodic=True).to(x.device)
        x = x.unsqueeze(1)
        x = x.view(x.size(0), -1, x.size(-1))

        # Real part of STFT
        stft_real = torch.nn.functional.conv1d(x, window.view(1, 1, -1), stride=self.hop_length, padding=self.n_fft // 2)
        # Imaginary part of STFT
        stft_imag = torch.nn.functional.conv1d(x, window.view(1, 1, -1), stride=self.hop_length, padding=self.n_fft // 2)

        # Represent complex numbers using real and imaginary parts
        complex_stft = torch.stack([stft_real, stft_imag], dim=-1)

        magnitude = torch.abs(complex_stft)
        return magnitude

# Audio file path
audio_path = "D:/ZOHO INTERN/pytorch/arun.wav"
y, sr = librosa.load(audio_path)

# Create a PyTorch model
model = AudioToMagnitude()

# Convert the audio signal to a PyTorch tensor
input_tensor = torch.tensor(y)

# Export the PyTorch model to ONNX format
onnx_file_path = "D:/ZOHO INTERN/pytorch/audio_to_magnitude.onnx"
torch.onnx.export(model, input_tensor, onnx_file_path, verbose=True, input_names=["input_audio"], output_names=["output_magnitude"])

# Run the ONNX model
onnx_output = run_onnx_model(onnx_file_path, input_tensor.numpy())

# Print the output magnitude
print("Output Magnitude Shape:", onnx_output[0].shape)
print("Output Magnitude:", onnx_output[0])


Output Magnitude Shape: (52369, 1, 1, 2)
Output Magnitude: [[[[4.5421851e-07 4.5421851e-07]]]


 [[[4.6568475e-06 4.6568475e-06]]]


 [[[1.6021406e-05 1.6021406e-05]]]


 ...


 [[[2.0786545e-01 2.0786545e-01]]]


 [[[2.2808412e-01 2.2808412e-01]]]


 [[[1.1395590e-01 1.1395590e-01]]]]
