In [None]:
# Augmentation.py code
import torch
import torchaudio
import random
import matplotlib.pyplot as plt
import librosa
import librosa.display

def add_gaussian_noise(audio: torch.Tensor, noise_level: float = 0.005) -> torch.Tensor:
    """Return ``audio`` plus zero-mean Gaussian noise.

    Args:
        audio: input tensor of any shape; noise is drawn elementwise with the same
            shape, dtype and device via ``torch.randn_like``.
        noise_level: standard deviation of the added noise.

    Returns:
        A new tensor; ``audio`` itself is not modified.
    """
    perturbation = noise_level * torch.randn_like(audio)
    return torch.add(audio, perturbation)

def time_stretch(audio: torch.Tensor, rate: float = 1.0) -> torch.Tensor:
    """Change the speed of ``audio`` by ``rate`` using a phase vocoder.

    ``torchaudio.transforms.TimeStretch`` operates on a *complex STFT*, not on raw
    samples. A real-valued waveform is therefore converted to an STFT whose frequency
    resolution matches the transform (n_freq=201 -> n_fft=400, hop=200), stretched,
    and converted back to a waveform. Complex input is treated as an
    already-computed spectrogram and passed straight to TimeStretch.

    Args:
        audio: real waveform with time on the last dimension, e.g. (channels, time)
            or (time,); or a complex spectrogram of shape (..., 201, frames).
        rate: speed factor; > 1 shortens the signal, < 1 lengthens it.

    Returns:
        For real input, a waveform with the same leading dimensions and a new time
        length (roughly ``time / rate``). For complex input, the stretched spectrogram.
    """
    n_freq = 201
    n_fft = (n_freq - 1) * 2  # 400: the FFT size implied by n_freq
    hop_length = n_fft // 2   # matches TimeStretch's own default hop for this n_freq
    stretcher = torchaudio.transforms.TimeStretch(hop_length=hop_length, n_freq=n_freq)

    if audio.is_complex():
        # Already a spectrogram: stretch it directly.
        return stretcher(audio, rate)

    leading_shape = audio.shape[:-1]
    # torch.stft accepts (time,) or (batch, time); fold all leading dims into one batch dim.
    flat = audio.reshape(-1, audio.shape[-1])
    window = torch.hann_window(n_fft, device=audio.device, dtype=audio.dtype)

    spec = torch.stft(flat, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True)
    stretched = stretcher(spec, rate)
    restored = torch.istft(stretched, n_fft=n_fft, hop_length=hop_length, window=window)

    return restored.reshape(*leading_shape, restored.shape[-1])

def pitch_shift(audio: torch.Tensor, n_steps: int = 0, sample_rate: int = 16000) -> torch.Tensor:
    """Shift the pitch of ``audio`` by ``n_steps`` semitones.

    Args:
        audio: waveform tensor with time on the last dimension.
        n_steps: number of semitone steps to shift; 0 means no shift.
        sample_rate: sample rate of ``audio`` in Hz. Defaults to 16000, the value
            previously hard-coded here.

    Returns:
        The pitch-shifted waveform. When ``n_steps == 0`` an unmodified copy of ``audio``
        is returned.
    """
    if n_steps == 0:
        # No shift requested: skip the STFT/resample round trip, which is lossy.
        return audio.clone()
    # Functional form avoids building a new PitchShift module on every call.
    return torchaudio.functional.pitch_shift(audio, sample_rate, n_steps)

def random_augment(audio: torch.Tensor, rng: "random.Random | None" = None) -> torch.Tensor:
    """Apply a random subset (1 to 3) of the augmentations above, in random order.

    Each augmentation's strength is drawn fresh on every call:
      * add_gaussian_noise: noise_level ~ uniform(0.001, 0.01)
      * time_stretch:       rate        ~ uniform(0.9, 1.1)
      * pitch_shift:        n_steps     ~ randint(-2, 2) (inclusive)

    Args:
        audio: input tensor, passed to each chosen augmentation in turn.
        rng: optional ``random.Random`` instance for reproducible augmentation.
            When None, the global ``random`` module state is used.

    Returns:
        The augmented tensor. It may differ in length from ``audio`` when
        time_stretch is among the chosen augmentations.
    """
    rand = random if rng is None else rng

    # Parameters are drawn lazily, only for augmentations that are actually chosen,
    # so unused augmentations do not consume random state.
    augmentations = [
        lambda x: add_gaussian_noise(x, noise_level=rand.uniform(0.001, 0.01)),
        lambda x: time_stretch(x, rate=rand.uniform(0.9, 1.1)),
        lambda x: pitch_shift(x, n_steps=rand.randint(-2, 2)),
    ]

    num_augmentations = rand.randint(1, len(augmentations))
    for augment in rand.sample(augmentations, num_augmentations):
        audio = augment(audio)

    return audio

# Function to visualize the augmentations
def plot_waveform(waveform, title="Waveform"):
    """Plot each channel of ``waveform`` as a line against sample index.

    Args:
        waveform: tensor with time on the last dimension. All leading dimensions are
            flattened into separate lines (one per channel); 1-D input is one line.
        title: figure title.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU can become numpy.
    samples = waveform.detach().cpu()
    # View as (channels, time); unlike .t(), this also accepts inputs with > 2 dims.
    samples = samples.reshape(-1, samples.shape[-1])

    fig, ax = plt.subplots(figsize=(12, 4))
    for channel_idx, channel in enumerate(samples):
        ax.plot(channel.numpy(), label=f"channel {channel_idx}")
    if samples.shape[0] > 1:
        ax.legend(loc="upper right")
    ax.set_title(title)
    ax.set_xlabel("Sample")
    ax.set_ylabel("Amplitude")
    plt.show()

def plot_spectrogram(waveform, title="Spectrogram", sample_rate: int = 16000):
    """Plot the dB-scaled STFT magnitude of ``waveform`` using librosa's default STFT settings.

    Args:
        waveform: tensor with time on the last dimension. Multi-channel input is
            averaged to mono for display.
        title: figure title.
        sample_rate: sample rate of ``waveform`` in Hz, used to label the time and
            frequency axes. Defaults to 16000, the value previously hard-coded here.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU can become numpy.
    samples = waveform.detach().cpu().numpy()
    if samples.ndim > 1:
        # librosa.stft returns one spectrogram per channel for 2-D input, which
        # specshow cannot draw, so mix everything down to a single channel.
        samples = samples.reshape(-1, samples.shape[-1]).mean(axis=0)

    spectrogram_db = librosa.amplitude_to_db(abs(librosa.stft(samples)))

    fig, ax = plt.subplots(figsize=(12, 4))
    mesh = librosa.display.specshow(spectrogram_db, sr=sample_rate, x_axis='time', y_axis='hz', ax=ax)
    fig.colorbar(mesh, ax=ax, format='%+2.0f dB')
    ax.set_title(title)
    plt.show()

# Example usage
# Load a sample audio file (you'll need to provide your own audio file)
# waveform, sample_rate = torchaudio.load('path_to_your_audio_file.wav')

# # Original audio
# plot_waveform(waveform, "Original Waveform")
# plot_spectrogram(waveform, "Original Spectrogram")

# # Augmented audio
# augmented_waveform = random_augment(waveform)
# plot_waveform(augmented_waveform, "Augmented Waveform")
# plot_spectrogram(augmented_waveform, "Augmented Spectrogram")

In [None]:
# updated audio dataset methods
def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
    """Return ``(features, class_index)`` for the sample at ``index``.

    The class name is the name of the sample file's parent directory, looked up in
    ``self.class_to_idx``. The audio from ``self.load_audio`` is passed through
    ``self.standardize_audio`` and then through ``self.feature_extractor``
    (called with a numpy array and ``sampling_rate=16000``).

    NOTE(review): if ``load_audio`` can return multi-channel audio, ``squeeze()``
    leaves it 2-D. The extractor may then treat channels as a batch, and ``[0]``
    would keep only the first channel. Confirm the inputs are mono.
    """
    raw_audio = self.load_audio(index)
    label = self.class_to_idx[self.paths[index].parent.name]

    # Bring the clip to a fixed length / sample rate before feature extraction.
    fixed_audio = self.standardize_audio(raw_audio)

    # Convert with the model's feature extractor; [0] selects the first entry of
    # 'input_values' (presumably dropping the batch dimension — confirm).
    encoded = self.feature_extractor(
        fixed_audio.squeeze().numpy(),
        return_tensors="pt",
        sampling_rate=16000,
    )
    features = encoded['input_values'][0]

    return features, label

def standardize_audio(self, audio_tensor: torch.Tensor, target_length: int = 80000, target_sr: int = 16000, current_sr: int = 16000) -> torch.Tensor:
    """Resample ``audio_tensor`` to ``target_sr`` and pad/trim it to ``target_length`` samples.

    With the defaults this yields 5 seconds at 16 kHz (80000 = 5 * 16000). Note that
    ``target_length`` is a sample count: it does not change automatically with
    ``target_sr``.

    Args:
        audio_tensor: audio with time on the last dimension, e.g. (channels, time)
            or (time,).
        target_length: exact number of samples in the output's last dimension.
        target_sr: sample rate to resample to.
        current_sr: sample rate of ``audio_tensor``. Defaults to 16000, the rate
            that was previously assumed unconditionally.

    Returns:
        Tensor with the same leading dimensions as ``audio_tensor`` whose last
        dimension is ``target_length``. Short input is zero-padded at the end and
        long input is truncated.
    """
    if current_sr != target_sr:
        audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=current_sr, new_freq=target_sr)

    # Operate on the last axis so 1-D, (channels, time) and batched inputs all work.
    num_samples = audio_tensor.size(-1)
    if num_samples < target_length:
        # Right-pad with zeros (silence) along the time axis.
        return torch.nn.functional.pad(audio_tensor, (0, target_length - num_samples))
    return audio_tensor[..., :target_length]