# Audio Splicing 1 - Drums and Percussion

In [1]:
import os
import torch
import torchaudio
import numpy as np

from torch.utils.data import Dataset, DataLoader

In [None]:
# Pick the compute device; on GPU, clear cached allocations first so the
# reported figure reflects only live tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    torch.cuda.empty_cache()   # return cached, unused blocks to the driver
    torch.cuda.ipc_collect()   # reclaim memory released through CUDA IPC
    allocated_gib = torch.cuda.memory_allocated() / 1024 ** 3
    print(f"Allocated CUDA memory: {allocated_gib:6.3f} GiB")  # Should be 0.000 GiB

# Module-level settings; AudioDataset reads freq_orig, freq_scale and
# chunk_duration directly.
num_epochs = 10
data_dir = ""          # must be set: os.listdir("") raises FileNotFoundError
batch_size = 1
learning_rate = 1e-3
num_channels = 2
freq_orig = 44100      # sample rate the audio files are assumed to be stored at
freq_scale = 2         # audio is resampled to freq_orig // freq_scale
chunk_duration = 3     # seconds of audio per dataset item

In [5]:
class AudioDataset(Dataset):
    """Fixed-length chunks of paired (input, label) audio files from one directory.

    Every file in ``data_dir`` that ends in ``input_tail`` (but not ``label_tail``)
    is an input. Its label has the same name with the trailing ``input_tail``
    replaced by ``label_tail``. Both are resampled from ``orig_freq`` to
    ``new_freq``, trimmed to their common length, and cut into consecutive chunks
    of ``chunk_seconds * new_freq`` samples. Item ``i`` is the ``i``-th chunk when
    the chunks of all files (in sorted filename order) are laid end to end.

    Args:
        data_dir: Directory holding the input and label files.
        input_tail: Filename suffix identifying input files.
        label_tail: Filename suffix identifying label files.
        orig_freq: Sample rate the files are assumed to be stored at (the rate
            reported by ``torchaudio.load`` is not used). Defaults to the
            module-level ``freq_orig``.
        new_freq: Target sample rate. Defaults to ``orig_freq // freq_scale``.
        chunk_seconds: Chunk duration in seconds. Defaults to the module-level
            ``chunk_duration``.
    """

    def __init__(self, data_dir, input_tail='.wav', label_tail='_labeled.wav', *,
                 orig_freq=None, new_freq=None, chunk_seconds=None):
        self.data_dir = data_dir
        self.input_tail = input_tail
        self.label_tail = label_tail
        self.orig_freq = freq_orig if orig_freq is None else orig_freq
        self.new_freq = self.orig_freq // freq_scale if new_freq is None else new_freq
        seconds = chunk_duration if chunk_seconds is None else chunk_seconds
        self.chunk_size = seconds * self.new_freq
        self.batch_count = 0
        # Built once and reused: the Resample transform caches its kernel, so
        # constructing it per __getitem__ call repeats that work every time.
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=self.orig_freq, new_freq=self.new_freq)
        # Sorted so the global-index -> (file, chunk) mapping is reproducible;
        # os.listdir order is arbitrary.
        self.input_files = sorted(
            f for f in os.listdir(data_dir)
            if f.endswith(input_tail) and not f.endswith(label_tail)
        )
        self.chunk_counts = self._calculate_chunk_counts()

    @staticmethod
    def _resampled_length(num_frames, orig_freq, new_freq):
        """Return how many samples resampling ``num_frames`` from ``orig_freq`` to ``new_freq`` yields.

        This is ``ceil(new_freq * num_frames / orig_freq)``, the length torchaudio's
        resampler truncates its output to, computed with exact integer arithmetic.
        """
        return -(-new_freq * num_frames // orig_freq)

    def _label_filename(self, input_filename):
        """Return the label filename paired with ``input_filename``.

        Only the trailing ``input_tail`` is swapped. ``str.replace`` would also
        rewrite earlier occurrences (e.g. ``a.wav.b.wav``).
        """
        stem = input_filename[:len(input_filename) - len(self.input_tail)]
        return stem + self.label_tail

    def _calculate_chunk_counts(self):
        """Return a long tensor with the number of chunks each input file yields.

        Counts use the *resampled* length, trimmed to the shorter of input and
        label, because that is exactly what ``__getitem__`` slices. Counting at the
        stored sample rate would overestimate by ``orig_freq / new_freq`` and
        create indices whose chunks are empty.
        """
        chunk_counts = []
        for input_filename in self.input_files:
            input_path = os.path.join(self.data_dir, input_filename)
            num_frames = torchaudio.load(input_path)[0].shape[1]
            label_path = os.path.join(self.data_dir, self._label_filename(input_filename))
            # A missing label is reported later by __getitem__; fall back to the
            # input length here so construction behaves as before.
            if os.path.exists(label_path):
                num_frames = min(num_frames, torchaudio.load(label_path)[0].shape[1])
            length = self._resampled_length(num_frames, self.orig_freq, self.new_freq)
            chunk_counts.append(-(-length // self.chunk_size))  # ceiling division
        # Explicit dtype: torch.tensor([]) would otherwise be float for an empty dir.
        return torch.tensor(chunk_counts, dtype=torch.long)

    def __len__(self):
        # len() requires a Python int, not a 0-d tensor.
        return int(self.chunk_counts.sum())

    def _get_file_and_chunk_idx(self, global_idx):
        """Map a dataset-wide chunk index to ``(file_idx, chunk_idx)`` as Python ints.

        Negative indices count from the end, as for Python sequences. This is
        why chunk_counts are kept per file rather than summed.

        Raises:
            IndexError: if ``global_idx`` is outside ``[-len(self), len(self))``.
        """
        global_idx = int(global_idx)
        if global_idx < 0:
            global_idx += len(self)
        if global_idx < 0:
            raise IndexError("Index out of range")
        cumulative = 0
        for file_idx, num_chunks in enumerate(self.chunk_counts.tolist()):
            if global_idx < cumulative + num_chunks:
                return file_idx, global_idx - cumulative
            cumulative += num_chunks
        raise IndexError("Index out of range")

    def __getitem__(self, global_idx):
        """Return the ``global_idx``-th ``(input_chunk, label_chunk)`` pair.

        Each chunk is a ``[channels, samples]`` slice of the resampled audio,
        at most ``self.chunk_size`` samples long. A file's last chunk may be shorter.

        Raises:
            IndexError: if ``global_idx`` is out of range.
            FileNotFoundError: if the paired label file does not exist.
        """
        file_idx, chunk_idx = self._get_file_and_chunk_idx(global_idx)
        input_filename = self.input_files[file_idx]
        input_path = os.path.join(self.data_dir, input_filename)
        label_path = os.path.join(self.data_dir, self._label_filename(input_filename))

        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Label file {label_path} does not exist.")

        # NOTE(review): the whole file is loaded and resampled for every chunk;
        # caching per file would be much faster if memory allows.
        input_audio = self.resampler(torchaudio.load(input_path)[0])
        label_audio = self.resampler(torchaudio.load(label_path)[0])

        # Trim to a common length so input and label chunks align sample-for-sample.
        length = min(input_audio.shape[1], label_audio.shape[1])
        start = chunk_idx * self.chunk_size
        end = min(start + self.chunk_size, length)

        # Should only be reachable if the files changed after chunk_counts was
        # computed; keep returning empty chunks rather than raising.
        if start >= length:
            return (torch.zeros((input_audio.shape[0], 0)),
                    torch.zeros((label_audio.shape[0], 0)))

        return input_audio[:, start:end], label_audio[:, start:end]

    def get_batch(self, batch_size, randomized=False):
        """Return ``(input_chunks, label_chunks)`` tuples for ``batch_size`` items.

        With ``randomized`` the indices are drawn without replacement using
        ``np.random.choice``. Otherwise they continue sequentially from
        ``self.batch_count`` (which is advanced) and wrap around ``len(self)``.
        """
        if randomized:
            idx = np.random.choice(len(self), batch_size, replace=False)
        else:
            idx = np.arange(self.batch_count, self.batch_count + batch_size) % len(self)
            self.batch_count += batch_size

        input_audios, label_audios = zip(*[self[i] for i in idx])
        return input_audios, label_audios

    @staticmethod
    def collate_fn(batch):
        """Batch ``(input, label)`` pairs of ``[channels, samples]`` tensors.

        ``pad_sequence`` pads along dim 0, so each chunk is transposed to
        ``[samples, channels]`` first. Shorter chunks are zero-padded at the end in
        time, and the result is ``[batch, channels, max_samples]``. Padding the
        untransposed tensors would pad channels and fail on unequal lengths.
        """
        def _pad(chunks):
            padded = torch.nn.utils.rnn.pad_sequence(
                [chunk.transpose(0, 1) for chunk in chunks], batch_first=True)
            return padded.transpose(1, 2).contiguous()

        input_audio, label_audio = zip(*batch)
        return _pad(input_audio), _pad(label_audio)

In [6]:
# Build the chunked dataset and a shuffling loader that batches items with
# AudioDataset.collate_fn. Use the configured batch_size rather than a
# hard-coded 1, so the hyperparameter cell actually controls batching.
dataset = AudioDataset(data_dir)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0,
                         collate_fn=AudioDataset.collate_fn)
input_audio, label_audio = dataset[0]  # first chunk, fetched directly for the checks below

In [None]:
# Direct-call sanity check -> expect [2, 66150], [2, 66150]
# (chunk_duration * (freq_orig // freq_scale) = 3 * 22050 = 66150 samples per channel)
total_chunks = len(dataset)
print('Total Chunk Count:', total_chunks, '\nInput Tensor:', input_audio.shape, '\nLabel Tensor:', label_audio.shape)

# Loader sanity check -> expect [1, 2, 66150], [1, 2, 66150]
# TODO: investigate intermittent [1, 2, 0] results. With batch_size=1, a zero-length
# time axis means the shuffled index mapped to a chunk starting at or past the end of
# the resampled, trimmed audio (the empty-tensor branch of AudioDataset.__getitem__),
# i.e. chunk_counts disagrees with the resampled lengths.
loader_iter = iter(data_loader)
inputs, targets = next(loader_iter)
print('Loader Input Tensor:', inputs.shape, '\nLoader Label Tensor:', targets.shape)