In [None]:
from typing import Union
import torch
from pathlib import Path
import torchaudio

def process_wav_to_stack(audio_waveform: torch.Tensor, length: int = 4, step: int = 1, sample_rate: int = 16_000) -> torch.Tensor:
    """
    Split a waveform into a stack of fixed-length, possibly overlapping segments.

    Windows of `length` seconds are taken every `step` seconds, starting at
    sample 0. Trailing samples that do not fill a complete window are dropped.

    Args:
    - audio_waveform: torch.Tensor of shape (num_channels, num_samples), e.g. (1, num_samples)
    - length: duration of each segment in seconds (must be positive)
    - step: hop between the starts of consecutive segments in seconds (must be positive)
    - sample_rate: number of samples per second of `audio_waveform` (must be positive)

    Returns:
    - torch.Tensor of shape (num_segments, num_channels, segment_length),
      where segment_length = length * sample_rate

    Raises:
    - ValueError: if the input is not a 2D tensor, if `length`, `step` or
      `sample_rate` is not positive, or if the waveform is shorter than one segment
    """
    # check if input tensor is 2D (num_channels, num_samples)
    if audio_waveform.dim() != 2:
        raise ValueError(f"Expected a 2D tensor, got a {audio_waveform.dim()}D tensor")

    # A zero step or sample rate would otherwise surface as a ZeroDivisionError,
    # and negative values would silently produce nonsensical slices.
    for name, value in (("length", length), ("step", step), ("sample_rate", sample_rate)):
        if value <= 0:
            raise ValueError(f"`{name}` must be positive, got {value}")

    num_samples = audio_waveform.shape[1]
    segment_length = length * sample_rate
    step_length = step * sample_rate

    if num_samples < segment_length:
        raise ValueError("The audio waveform is too short to create even one segment")

    # Last start index at which a full-length window still fits in the waveform.
    last_start = num_samples - segment_length
    audio_segments = [
        audio_waveform[:, start : start + segment_length]
        for start in range(0, last_start + 1, step_length)
    ]
    # stacking adds a leading dimension: (num_segments, num_channels, segment_length)
    return torch.stack(audio_segments)

def get_audio_tensor(audio: Union[str, Path, torch.Tensor], length: int = 4, step: int=1, sample_rate: int=16_000) -> torch.Tensor:
    """
    Build a stack of `length`-second audio segments, taken every `step` seconds,
    from either an audio file on disk or an already-loaded waveform tensor.

    Args:
    - audio: path (str or Path) passed to `torchaudio.load`, or a waveform torch.Tensor
    - length: segment duration in seconds
    - step: hop between consecutive segments in seconds
    - sample_rate: sample rate used to convert seconds to samples. When `audio`
      is a path, this argument is replaced by the rate `torchaudio.load` returns.

    Returns:
    - the tensor produced by `process_wav_to_stack` for the resolved waveform

    Raises:
    - ValueError: if `audio` is neither a str, a Path nor a torch.Tensor
    """
    if isinstance(audio, torch.Tensor):
        # Already a waveform: use it as-is with the caller's sample rate.
        waveform = audio
    elif isinstance(audio, (str, Path)):
        # Loading from disk overrides `sample_rate` with the file's own rate.
        waveform, sample_rate = torchaudio.load(audio)
    else:
        raise ValueError("Expected a path to a wav file or a torch tensor")

    return process_wav_to_stack(waveform, length, step, sample_rate)


In [None]:
# Load an audio file by path with torchaudio and build a stack of audio segment tensors
import torchaudio
from IPython.display import Audio

# Demo: load one utterance, play it, then split it into overlapping windows.
path = Path("../data/voxceleb2_wav_eval/id00817/21.wav")
audio_waveform, sample_rate = torchaudio.load(path)
display(Audio(audio_waveform, rate=sample_rate))
print(audio_waveform.shape)

# Forward the file's own sample rate so the 4 s / 1 s windows are measured in
# real seconds even if the file is not sampled at the 16 kHz default.
audio_tensor = get_audio_tensor(audio_waveform, length=4, step=1, sample_rate=sample_rate)
print(audio_tensor.shape)

# Iterate over the first dimension of the stack: one playable segment per step.
for segment in audio_tensor:
    display(Audio(segment, rate=sample_rate))