In [1]:
import torch
import torchaudio

# Record the installed library versions, one per line, for reproducibility.
print(torch.__version__, torchaudio.__version__, sep="\n")

2.0.1+cu117
2.0.2+cpu


In [2]:
import io
import os
import tarfile
import tempfile

import matplotlib.pyplot as plt
import requests


from IPython.display import Audio
from torchaudio.utils import download_asset

# SAMPLE_GSM = download_asset("tutorial-assets/steam-train-whistle-daniel_simon.gsm")
# Raw string literal: the Windows path contains backslash sequences (\S, \c)
# that are not valid escapes in an ordinary string and trigger a warning.
SAMPLE_WAV = r"D:\Speech\clips\common_voice_en_8981.wav"
# SAMPLE_WAV_8000 = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042-8000hz.wav")

In [3]:
# Query the clip's metadata and print it; the printed AudioMetaData reports
# sample rate, frame count, channel count, bit depth and encoding.
metadata = torchaudio.info(SAMPLE_WAV)
print(metadata)

AudioMetaData(sample_rate=32000, num_frames=161280, num_channels=1, bits_per_sample=16, encoding=PCM_S)


In [4]:
# Load the clip as a (waveform, sample_rate) pair. The plotting helpers in this
# notebook index the waveform as (channel, frame).
waveform, sample_rate = torchaudio.load(SAMPLE_WAV)

In [5]:
def plot_waveform(waveform, sample_rate, output_path=None, width=8, height=4, dpi=100):
    """Plot every channel of an audio tensor against time in seconds.

    Args:
        waveform: ``torch.Tensor`` of shape ``(num_channels, num_frames)``. A
            1-D tensor of shape ``(num_frames,)`` is treated as one channel.
        sample_rate: Samples per second; converts frame indices to seconds.
        output_path: If truthy, the figure is saved to this path instead of
            being shown.
        width: Figure width in inches.
        height: Figure height in inches.
        dpi: Figure resolution in dots per inch.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    # detach()/cpu() keep the conversion working for tensors that track
    # gradients or live on another device; both are no-ops otherwise.
    samples = waveform.detach().cpu().numpy()

    num_channels, num_frames = samples.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    # squeeze=False always returns a 2-D grid of axes, so the single-channel
    # case needs no special wrapping.
    figure, axes = plt.subplots(
        num_channels, 1, figsize=(width, height), dpi=dpi, squeeze=False
    )
    for c, axis in enumerate(axes[:, 0]):
        axis.plot(time_axis, samples[c], linewidth=1)
        axis.grid(True)
        if num_channels > 1:
            axis.set_ylabel(f"Channel {c+1}")
    figure.suptitle("Waveform")

    if output_path:
        # Save via the figure handle rather than pyplot's implicit current figure.
        figure.savefig(output_path, bbox_inches="tight")
    else:
        plt.show(block=False)

In [None]:
# Plot the loaded clip; output_path is omitted, so the figure is shown rather than saved.
plot_waveform(waveform, sample_rate)

In [6]:
import torchaudio.transforms as transforms
import numpy as np
import torch.nn as nn

class LogMelSpec(nn.Module):
    """Mel spectrogram followed by a natural logarithm.

    Args:
        sample_rate: Sample rate of the input audio, forwarded to
            ``torchaudio.transforms.MelSpectrogram``.
        n_mels: Number of mel bins, forwarded to ``MelSpectrogram``.
        win_length: Window length, forwarded to ``MelSpectrogram``.
        hop_length: Hop length, forwarded to ``MelSpectrogram``.
    """

    def __init__(self, sample_rate=8000, n_mels=128, win_length=160, hop_length=80):
        super().__init__()
        self.transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=n_mels,
            win_length=win_length, hop_length=hop_length)

    def forward(self, x):
        """Return ``log(mel_spectrogram(x) + 1e-14)`` as a ``torch.Tensor``."""
        x = self.transform(x)  # mel spectrogram
        # torch.log (not np.log) keeps the op inside autograd and works for
        # tensors that require grad or live on a non-CPU device.
        # The small offset avoids log(0) = -inf on silent bins.
        return torch.log(x + 1e-14)

class SpecAugment(nn.Module):
    """Apply time masking and then frequency masking to a spectrogram.

    Args:
        rate: Accepted for call compatibility; not used by this class.
        policy: Accepted for call compatibility; not used by this class.
        freq_mask: Mask parameter passed to ``FrequencyMasking``.
        time_mask: Mask parameter passed to ``TimeMasking``.
    """

    def __init__(self, rate, policy, freq_mask, time_mask):
        super().__init__()
        self.time_masking = torchaudio.transforms.TimeMasking(time_mask_param=time_mask)
        self.freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=freq_mask)

    def forward(self, spec):
        """Return ``spec`` after the time mask followed by the frequency mask."""
        return self.freq_masking(self.time_masking(spec))

def _draw_specgram_column(axes, specgram, column, panel_title):
    """Draw each channel of ``specgram`` into one column of a 2-D axes grid."""
    num_channels = specgram.shape[0]
    for c in range(num_channels):
        ax = axes[c][column]
        ax.imshow(specgram[c].detach().cpu().numpy(), origin='lower', aspect='auto', cmap='viridis', interpolation='nearest')
        ax.set_xlabel("Frame")
        # With several channels the y-label identifies the channel instead.
        ax.set_ylabel(f"Channel {c+1}" if num_channels > 1 else "Mel Frequency Bin")
        ax.set_title(panel_title)


def plot_log_mel_specgram_torchaudio_featurizer(waveform, sample_rate, n_feats=81, title="Log Mel Spectrogram"):
    """Plot a log-mel spectrogram next to its SpecAugment-masked version.

    Args:
        waveform: Audio tensor passed to ``LogMelSpec``; the result must be
            3-D, unpacked as ``(num_channels, num_features, num_frames)``.
        sample_rate: Sample rate forwarded to ``LogMelSpec``.
        n_feats: Number of mel bins (``n_mels``) for ``LogMelSpec``.
        title: Overall figure title.

    Raises:
        ValueError: If the computed spectrogram is not 3-D.
    """
    featurizer = LogMelSpec(sample_rate=sample_rate, n_mels=n_feats, win_length=160, hop_length=80)
    log_mel_specgram = featurizer(waveform)

    # Unpacking three values doubles as a rank check on the spectrogram.
    num_channels, _num_features, _num_frames = log_mel_specgram.shape

    # squeeze=False always gives a (num_channels, 2) grid, so one channel
    # needs no special wrapping. Column 0 is original, column 1 is augmented.
    figure, axes = plt.subplots(num_channels, 2, squeeze=False)

    _draw_specgram_column(axes, log_mel_specgram, 0, "Original Spectrogram")

    # Augment a clone so masking can never alter the data behind the first panel.
    # NOTE(review): some torchaudio releases may mask in place — confirm.
    spec_augment = SpecAugment(rate=0.2, policy=2, freq_mask=10, time_mask=15)  # Adjust augmentation parameters
    augmented_specgram = spec_augment(log_mel_specgram.clone())
    _draw_specgram_column(axes, augmented_specgram, 1, "Augmented Spectrogram")

    figure.suptitle(title)
    plt.tight_layout()
    plt.show(block=False)


In [None]:
# Load the audio file (repeats the path and load from the earlier cells; the
# path is a raw string so its backslashes are taken literally)
SAMPLE_WAV = r"D:\Speech\clips\common_voice_en_8981.wav"
waveform, sample_rate = torchaudio.load(SAMPLE_WAV)

# Plot the original and SpecAugment-masked log-mel spectrograms side by side
plot_log_mel_specgram_torchaudio_featurizer(waveform, sample_rate)

In [None]:
# Build an IPython Audio widget from the first channel ([0]) of the loaded clip.
Audio(waveform.numpy()[0], rate=sample_rate)