# Audio Splicing - Drums

In [1]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn as nn
import soundfile as sf
import torch.optim as optim
import torch.nn.functional as F

from functools import lru_cache
from typing import List, Tuple
from torch.nn.utils.parametrizations import weight_norm
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

---

## Hyperparameters

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device:\t\t\t{device}")

if str(device) == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print(f"Allocated CUDA memory:\t{torch.cuda.memory_allocated() / 1024 ** 3:8.4f} GiB")

num_epochs = 10                             # Number of epochs to train
data_dir = "/mnt/data/Daftset/Dataset"      # Directory containing the dataset (we will work on a copy)
tmp_dir = "/mnt/data/Daftset/Dataset_tmp"   # Directory to store temporary files (aligned, resampled dataset copy)
batch_size = 1          # 3060 tackles single-entry batches at max
learning_rate = 5e-4    # Light learning rate
num_channels = 2        # Number of audio channels
freq_orig = 44100       # Original frequency of the audio files
freq_scale = 2          # Downsampling factor
chunk_duration = 2      # Duration of training examples in seconds -> samples_per_example = (freq_orig // freq_scale) * chunk_duration
weight_decay = 1e-4     # Weight decay for the optimizer
spectral_weight = 0.5   # Spectral loss impact for total loss calculation

Using device:			cuda
Allocated CUDA memory:	  0.0000 GiB


---

## Dataset

In [3]:
class AudioDataset(Dataset):
    def __init__(self, data_dir: str, tmp_dir: str, input_tail: str='.wav', label_tail: str='_labeled.wav'):
        """Initialize the dataset with minimal memory footprint."""
        self.data_dir = data_dir
        self.tmp_dir = tmp_dir
        self.input_tail = input_tail
        self.label_tail = label_tail
        self.batch_size = batch_size
        
        # Process files and store only metadata
        self.input_label_pairs = self._process_files(input_tail, label_tail)
        self.file_info = self._get_file_info()
        
        # Calculate total length and chunk information
        self.input_length = sum(info['length'] for info in self.file_info)
        self.chunk_size = chunk_duration * (freq_orig // freq_scale)
        self.batch_count = 0
        
        # Create cache for file handles
        self._file_handle_cache = {}
        
    def _process_files(self, input_tail: str, label_tail: str) -> List[Tuple[str, str]]:
        """Process files and return input-label pairs."""
        all_files = os.listdir(self.data_dir)
        input_files = sorted([f for f in all_files if f.endswith(input_tail) and not f.endswith(label_tail)])
        label_files_set = {f for f in all_files if f.endswith(label_tail)}
        
        file_tuples = []
        for input_file in input_files:
            label_file = input_file.replace(input_tail, label_tail)
            if label_file in label_files_set:
                file_tuples.append((input_file, label_file))
            else:
                raise ValueError(f"Missing label file for {input_file}: Expected {label_file}.")
        
        if not file_tuples:
            raise ValueError("No matching input-label file pairs found.")
        return file_tuples

    def _get_file_info(self) -> List[dict]:
        """Get file information without loading audio data."""
        file_info = []
        for in_fname, lb_fname in self.input_label_pairs:
            info = torchaudio.info(os.path.join(self.data_dir, in_fname))
            length = info.num_frames
            file_info.append({
                'length': length,
                'input_path': os.path.join(self.data_dir, in_fname),
                'label_path': os.path.join(self.data_dir, lb_fname)
            })
        return file_info

    @lru_cache(maxsize=8)
    def _get_file_handle(self, file_path: str) -> sf.SoundFile:
        """Get or create cached file handle."""
        if file_path not in self._file_handle_cache:
            self._file_handle_cache[file_path] = sf.SoundFile(file_path, 'r')
        return self._file_handle_cache[file_path]

    def _read_audio_chunk(self, file_path: str, start: int, length: int) -> torch.Tensor:
        """Read a chunk of audio data efficiently."""
        handle = self._get_file_handle(file_path)
        handle.seek(start)
        data = handle.read(length)
        return torch.from_numpy(data).T

    def __len__(self) -> int:
        """Calculate total number of chunks."""
        total_chunks = self.input_length // self.chunk_size
        return total_chunks // self.batch_size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single chunk of audio data."""
        global_start = (idx * self.chunk_size) % self.input_length
        input_chunks, label_chunks = [], []
        remaining_samples = self.chunk_size
        cumulative_length = 0
        
        for file_info in self.file_info:
            if global_start >= cumulative_length + file_info['length']:
                cumulative_length += file_info['length']
                continue
                
            local_start = global_start - cumulative_length
            if local_start < 0:
                local_start = 0
                
            samples_from_file = min(file_info['length'] - local_start, remaining_samples)
            
            if samples_from_file <= 0:
                continue
                
            input_chunk = self._read_audio_chunk(file_info['input_path'], local_start, samples_from_file)
            label_chunk = self._read_audio_chunk(file_info['label_path'], local_start, samples_from_file)
            
            input_chunks.append(input_chunk)
            label_chunks.append(label_chunk)
            
            remaining_samples -= samples_from_file
            global_start = 0
            cumulative_length += file_info['length']
            
            if remaining_samples <= 0:
                break
                
        if remaining_samples > 0:
            # Handle wrap-around case
            return self.__getitem__(0)
            
        return (torch.cat(input_chunks, dim=1), 
                torch.cat(label_chunks, dim=1))

    def get_batch(self, batch_size: int, randomized: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a batch of audio chunks."""
        if randomized:
            idx = np.random.randint(0, len(self))
        else:
            idx = batch_size * self.batch_count
            if idx + batch_size > len(self):
                self.batch_count = 0
                idx = 0
                
        batch = [self.__getitem__((idx + i) % len(self)) for i in range(batch_size)]
        batch_input, batch_label = [torch.stack(items) for items in zip(*batch)]
        
        if not randomized:
            self.batch_count += 1
            
        return batch_input, batch_label

    @staticmethod
    def collate_fn(batch):
        """Collate function for DataLoader."""
        batch_input, batch_label = [torch.stack(items) for items in zip(*batch)]
        batch_input = batch_input.view(batch_input.shape[0], 2, -1)
        batch_label = batch_label.view(batch_label.shape[0], 2, -1)
        return batch_input, batch_label

    def __del__(self):
        """Clean up file handles."""
        for handle in self._file_handle_cache.values():
            handle.close()

---

### Dataset Sanity Check

In [4]:
dataset = AudioDataset(data_dir, tmp_dir)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=AudioDataset.collate_fn)

In [5]:
# Direct Call Sanity Check
input_audio, label_audio = dataset[0]
print('Total Chunk Count:', len(dataset), '\nInput Tensor:', input_audio.shape, '\nLabel Tensor:', label_audio.shape)

assert_f = (freq_orig // freq_scale) * chunk_duration

# Expect shape: [2, 44100]
assert input_audio.shape == (2, assert_f), "Error: Input tensor shape does not match expected size."
assert label_audio.shape == (2, assert_f), "Error: Label tensor shape does not match expected size."

# Batch Call Sanity Check
input_batch, label_batch = dataset.get_batch(1)
print('Batch Input Tensor:', input_batch.shape, '\nBatch Label Tensor:', label_batch.shape)

# Expect shape: [1, 2, 44100]
assert input_batch.shape == (1, 2, assert_f), "Error: Batch input tensor shape does not match expected size."
assert label_batch.shape == (1, 2, assert_f), "Error: Batch label tensor shape does not match expected size."

# Loader Call Sanity Check
input_collate, label_collate = next(iter(data_loader))
print('Loader Input Tensor:', input_collate.shape, '\nLoader Label Tensor:', label_collate.shape)

# Expect shape: [1, 2, 44100]
assert input_collate.shape == (1, 2, assert_f), "Error: Loader input tensor shape does not match expected size."
assert label_collate.shape == (1, 2, assert_f), "Error: Loader label tensor shape does not match expected size."

# Check if input_audio and input_collate are equal
assert torch.equal(input_audio, input_collate.squeeze(0)), "Error: Collate Loader vs. Direct Call are not equal."

# Test multiple samples via direct call
for i in range(1, 10):
    input_audio, label_audio = dataset[i]
    assert input_audio.shape == (2, assert_f), f"Error at index {i}: Input tensor shape mismatch."
    assert label_audio.shape == (2, assert_f), f"Error at index {i}: Label tensor shape mismatch."

# Test multiple batches via get_batch
for _ in range(5):
    input_batch, label_batch = dataset.get_batch(1)
    assert input_batch.shape == (1, 2, assert_f), f"Error at index {i}: Batch input tensor shape mismatch."
    assert label_batch.shape == (1, 2, assert_f), f"Error at index {i}: Batch label tensor shape mismatch."

# Check edge cases (last sample)
input_audio_last, label_audio_last = dataset[len(dataset) - 1]
assert input_audio_last.shape == (2, assert_f), "Error: Last sample input tensor shape mismatch."
assert label_audio_last.shape == (2, assert_f), "Error: Last sample label tensor shape mismatch."

# Check random access in get_batch
for i in range(5):
    input_batch, label_batch = dataset.get_batch(1, randomized=True)
    assert input_batch.shape == (1, 2, assert_f), f"Error at loop {i}: Randomized batch input tensor shape mismatch."
    assert label_batch.shape == (1, 2, assert_f), f"Error at loop {i}: Randomized batch label tensor shape mismatch."

print("\nSanity checks passed!")

Total Chunk Count: 4438 
Input Tensor: torch.Size([2, 44100]) 
Label Tensor: torch.Size([2, 44100])
Batch Input Tensor: torch.Size([1, 2, 44100]) 
Batch Label Tensor: torch.Size([1, 2, 44100])
Loader Input Tensor: torch.Size([1, 2, 44100]) 
Loader Label Tensor: torch.Size([1, 2, 44100])

Sanity checks passed!


---

## Model

In [6]:
class ResidualDenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate=16, num_layers=4):
        super(ResidualDenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(nn.Sequential(
                nn.Conv1d(in_channels + i * growth_rate, growth_rate, kernel_size=3, padding=1),
                nn.LeakyReLU()
            ))
        self.final_conv = nn.Conv1d(in_channels + num_layers * growth_rate, in_channels, kernel_size=1)

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            out = layer(torch.cat(features, dim=1))
            features.append(out)
        return self.final_conv(torch.cat(features, dim=1)) + x

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.down = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU()
        )
        
    def forward(self, x):
        return self.down(x)

class Resizer(nn.Module):
    def __init__(self, in_channel, out_channel, target_size):
        super(Resizer, self).__init__()
        self.conv = nn.ConvTranspose1d(in_channel, out_channel, kernel_size=4, stride=2, padding=1)
        self.target_size = target_size
        self.activation = nn.LeakyReLU()

    def forward(self, x):
        x = self.conv(x)
        # Use interpolate for precise resizing
        x = F.interpolate(x, size=self.target_size, mode='linear', align_corners=True)
        return self.activation(x)

class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='linear', align_corners=True),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU()
        )
        
    def forward(self, x, skip):
        x = self.up(x)
        # Adjust x or skip size if they don't match
        diff = x.size(2) - skip.size(2)
        if diff > 0:
            x = x[:, :, :skip.size(2)]
        elif diff < 0:
            x = nn.functional.pad(x, (0, -diff))
        return x

class CV_TasNet_Block(nn.Module):
    # Simplified representation; adjust based on specific requirements or TasNet variant
    def __init__(self, in_channels, out_channels, num_blocks=8):
        super(CV_TasNet_Block, self).__init__()
        self.blocks = nn.ModuleList([self._build_block(in_channels, out_channels) for _ in range(num_blocks)])
        
    def _build_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, in_channels, kernel_size=3, padding=1)
        )
        
    def forward(self, x):
        residual = x
        for block in self.blocks:
            x = block(x) + residual
            residual = x
        return x

class TemporalSelfAttention(nn.Module):
    def __init__(self, d_model, nhead=4):
        super(TemporalSelfAttention, self).__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead)

    def forward(self, x):
        x = x.permute(2, 0, 1)  # Convert from (batch, channels, time) to (time, batch, channels)
        attn_output, _ = self.attention(x, x, x)
        return attn_output.permute(1, 2, 0)  # Convert back to (batch, channels, time)

class AudioUNet(nn.Module):
    def __init__(self, num_channels=2):
        super(AudioUNet, self).__init__()
        # (batch, channels, time), e.g. (1, 2, 66150)
        self.down1 = ResidualDenseBlock(num_channels)
        self.down2 = DownSample(2, 4)
        self.down3 = DownSample(4, 8)
        self.cv_tasnet = CV_TasNet_Block(8, 8)
        self.temporal_attention = TemporalSelfAttention(8)

        # Multi-scale feature fusion
        # Fuse upsampled features and skip connection from downsampled features
        self.up3 = UpSample(8, 4)
        self.up2 = UpSample(4, 2)
        self.resizer = Resizer(2, 2, assert_f)
        self.up1 = nn.Sequential(
            nn.Conv1d(4, 4, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(4, num_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
                weight_norm(m, name='weight', dim=0)

    def forward(self, x):
        skip1 = self.down1(x)            # (batch, 2, (freq_orig // freq_scale) * chunk_duration)
        skip2 = self.down2(skip1)        # (batch, 4, ((freq_orig // freq_scale) * chunk_duration) / 2)
        skip3 = self.down3(skip2)        # (batch, 8, ((freq_orig // freq_scale) * chunk_duration) / 4)
        x = self.cv_tasnet(skip3)        # (batch, 8, ((freq_orig // freq_scale) * chunk_duration) / 4)
        x = self.temporal_attention(x)   # (batch, 8, ((freq_orig // freq_scale) * chunk_duration) / 4)
        x = self.up3(x, skip3)           # (batch, 4, ((freq_orig // freq_scale) * chunk_duration) / 4)
        x = self.up2(x, skip2)           # (batch, 2, ((freq_orig // freq_scale) * chunk_duration) / 2)
        x = self.resizer(x)              # (batch, 2, (freq_orig // freq_scale) * chunk_duration)
        x = torch.cat([x, skip1], dim=1) # (batch, 4, (freq_orig // freq_scale) * chunk_duration)
        return self.up1(x)               # (batch, 2, (freq_orig // freq_scale) * chunk_duration)

---

## Training

In [7]:
def create_train_val_splits(dataset, val_ratio=0.2, shuffle=True):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(val_ratio * dataset_size))
    if shuffle:
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    return SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices)

In [8]:
dataset = AudioDataset(data_dir, tmp_dir)
train_sampler, val_sampler = create_train_val_splits(dataset, val_ratio=0.2)

train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=AudioDataset.collate_fn, num_workers=0)
val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler, collate_fn=AudioDataset.collate_fn, num_workers=0)

In [9]:
def spectral_loss(output, target, n_fft=1024, hop_length=None, epsilon=1e-10):
    window = torch.hann_window(n_fft).to(output.device)
    assert output.dim() == 3 and target.dim() == 3, "Input tensors must be 3D (batch, channels, time)"
    _, num_channels, _ = output.shape
    loss = 0
    for i in range(num_channels):
        output_stft = torch.stft(output[:, i, :], n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True)
        target_stft = torch.stft(target[:, i, :], n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True)
        spectral_diff = torch.abs(torch.abs(output_stft) - torch.abs(target_stft) + epsilon)
        loss += torch.mean(spectral_diff)
    return loss / num_channels

In [10]:
model = AudioUNet(num_channels=2).to(device).float() # Adjust audio_length and num_speakers
criterion_mse = nn.MSELoss().to(device)              # Mean Squared Error
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scaler = torch.amp.GradScaler(enabled=(str(device) != 'cpu'), init_scale=2.0) 

In [11]:
print(f"Total number of parameters: {sum(p.numel() for p in model.parameters()):,}")

Total number of parameters: 9,890


In [12]:
accumulation_steps = 32      # Effective_batch_size = batch_size * accumulation_steps
best_val_loss = float('inf') # Initialize as highest possible

for epoch in range(num_epochs):
    # Training loop
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()
    for i, data in enumerate(data_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device).float(), labels.to(device).float()
        with torch.amp.autocast(device_type=str(device)):
            outputs = model(inputs)
            mse_loss = criterion_mse(outputs, labels)
            spec_loss = spectral_loss(outputs, labels, n_fft=1024, hop_length=256)
            # Normalize mini losses to emulate larger batch size (not just accumulation)
            loss = (mse_loss + spectral_weight * spec_loss) / accumulation_steps
        scaler.scale(loss).backward()
        if (i + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        # Showing the true mini-batch loss during logging
        running_loss += loss.item() * accumulation_steps
        if str(device) == "cuda":
            del inputs, labels, outputs, mse_loss, spec_loss, loss
            torch.cuda.empty_cache()
        if i % 100 == 99:
            print(f'Epoch [{epoch+1:3}/{num_epochs}] | '
                  f'Mini-Batch [{i+1:4}/{len(data_loader)}] | '
                  f'Train: {(running_loss / 100):8.6f} | '
                  f'LR: {optimizer.param_groups[-1]["lr"]:.6f} | '
                  f'Effective Batch: {batch_size * accumulation_steps}')
            running_loss = 0.0
    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(val_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device).float(), labels.to(device).float()
            with torch.amp.autocast(device_type=str(device)):
                outputs = model(inputs)
                mse_loss = criterion_mse(outputs, labels)
                spec_loss = spectral_loss(outputs, labels, n_fft=1024, hop_length=256)
                loss = mse_loss + spectral_weight * spec_loss
                val_loss += loss.item()
            if device.type == "cuda":
                del inputs, labels, outputs, mse_loss, spec_loss, loss
                torch.cuda.empty_cache()
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch [{epoch+1:3}/{num_epochs}] | Validation: {avg_val_loss:8.6f}')
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
            }, 'best_audio_unet.pth')
    if str(device) == "cuda":
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

Epoch [  1/10] | Mini-Batch [100/4438] | Train: 0.571767 | LR: 0.000500 | Effective Batch: 8
Epoch [  1/10] | Mini-Batch [200/4438] | Train: 0.549238 | LR: 0.000500 | Effective Batch: 8
Epoch [  1/10] | Mini-Batch [300/4438] | Train: 0.277730 | LR: 0.000500 | Effective Batch: 8
Epoch [  1/10] | Mini-Batch [400/4438] | Train: 0.326018 | LR: 0.000500 | Effective Batch: 8
Epoch [  1/10] | Mini-Batch [500/4438] | Train: 0.296549 | LR: 0.000500 | Effective Batch: 8


KeyboardInterrupt: 