# Audio Splicing 1 - Drums and Percussion

In [1]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils as utils

from torch.utils.data import Dataset, DataLoader

---

## Hyperparameters

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if str(device) == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print(f"Allocated CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 3:6.3f} GiB")

num_epochs = 10
data_dir = "/mnt/data/Daftset/Dataset"
tmp_dir = "/mnt/data/Daftset/Dataset_tmp"
batch_size = 1
learning_rate = 1e-3
num_channels = 2
freq_orig = 44100
freq_scale = 2
chunk_duration = 3
weight_decay = 1e-4
spectral_weight = 0.5

Using device: cuda
Allocated CUDA memory:  0.000 GiB


---

## Dataset

In [3]:
class AudioDataset(Dataset):
    def __init__(self, data_dir: str, tmp_dir: str, input_tail: str='.wav', label_tail: str='_labeled.wav'):
        self.data_dir = data_dir      # Directory containing input *and* label files
        self.tmp_dir = tmp_dir        # Directory for temporary files
        self.input_tail = input_tail  # File extensions for input files
        self.label_tail = label_tail  # File extensions for label files
        self.chunk_size = chunk_duration * (freq_orig // freq_scale)  # Number of samples per chunk
        self.input_label_pairs = self._process_files(input_tail, label_tail)  # List of tuples: (input_filename, label_filename)
        self.input_label_lengths = self._load_audio_lengths(self.input_label_pairs)  # List of tuples: (input_length, label_length)
        self.input_length = sum([in_len for in_len, _ in self.input_label_lengths])  # Total number of samples
        self.batch_count = 0  # 'Global' counter for batch generation

    def _process_files(self, input_tail: str, label_tail: str) -> list:
        all_files = os.listdir(self.data_dir)
        input_files = sorted([f for f in all_files if f.endswith(input_tail) and not f.endswith(label_tail)])
        label_files_set = {f for f in all_files if f.endswith(label_tail)}
        file_tuples = []
        for input_file in input_files:
            label_file = input_file.replace(input_tail, label_tail)
            if label_file in label_files_set:
                file_tuples.append((input_file, label_file))
            else:
                raise ValueError(f"Missing label file for {input_file}: Expected {label_file}.")
        if not file_tuples:
            raise ValueError("No matching input-label file pairs found.")
        return file_tuples

    def _load_audio_lengths(self, file_tuples: list) -> list:
        lengths = []
        for in_fname, lb_fname in file_tuples:
            audio_in = torchaudio.load(os.path.join(self.data_dir, in_fname))[0]
            audio_lb = torchaudio.load(os.path.join(self.data_dir, lb_fname))[0]
            # Trim both audios to equal length
            max_length = max(audio_in.shape[1], audio_lb.shape[1])
            if audio_in.shape[1] < max_length:
                audio_in = torch.nn.functional.pad(audio_in, (0, max_length - audio_in.shape[1]))
            elif audio_lb.shape[1] < max_length:
                audio_lb = torch.nn.functional.pad(audio_lb, (0, max_length - audio_lb.shape[1]))
            tmp_path_in = os.path.join(self.tmp_dir, in_fname)
            tmp_path_lb = os.path.join(self.tmp_dir, lb_fname)
            torchaudio.save(tmp_path_in, audio_in, freq_orig // freq_scale)
            torchaudio.save(tmp_path_lb, audio_lb, freq_orig // freq_scale)
            lengths.append((max_length, max_length))
        return lengths

    def __len__(self) -> int:
        return self.input_length // self.chunk_size

    def __getitem__(self, idx: int) -> tuple:
        global_start = (idx * self.chunk_size) % self.input_length
        input_chunk, label_chunk = [], []
        remaining_samples = self.chunk_size
        while remaining_samples > 0:
            for (in_file, lb_file), (audio_length, _) in zip(self.input_label_pairs, self.input_label_lengths):
                if global_start >= audio_length:
                    global_start -= audio_length
                    continue
                samples_from_file = min(audio_length - global_start, remaining_samples)
                path_in = os.path.join(self.tmp_dir, in_file)
                path_lb = os.path.join(self.tmp_dir, lb_file)
                audio_in = torchaudio.load(path_in)[0][:, global_start:global_start + samples_from_file]
                audio_lb = torchaudio.load(path_lb)[0][:, global_start:global_start + samples_from_file]
                input_chunk.append(audio_in)
                label_chunk.append(audio_lb)
                remaining_samples -= samples_from_file
                global_start = 0  # Reset for next file
                if remaining_samples == 0:
                    break
            if remaining_samples > 0:
                # Start over if we've gone through all files and still need more samples
                global_start = 0
        return torch.cat(input_chunk, dim=1), torch.cat(label_chunk, dim=1)

    def get_batch(self, batch_size, randomized=False):
        if randomized:
            max_start_idx = len(self) - 1
            idx = np.random.randint(0, max_start_idx + 1)
        else:
            idx = batch_size * self.batch_count
            if idx + batch_size > len(self):
                self.batch_count = 0
                idx = 0
        batch = []
        for i in range(batch_size):
            current_idx = (idx + i) % len(self) # Wrap around
            batch.append(self.__getitem__(current_idx))
        batch_input, batch_label = [torch.stack(items) for items in zip(*batch)]
        if not randomized:
            self.batch_count += 1
        return batch_input, batch_label

    @staticmethod
    def collate_fn(batch):
        batch_input, batch_label = [torch.stack(items) for items in zip(*batch)]
        batch_input = batch_input.view(batch_input.shape[0], 2, -1)
        batch_label = batch_label.view(batch_label.shape[0], 2, -1)
        return batch_input, batch_label

---

### Dataset Sanity Check

In [4]:
dataset = AudioDataset(data_dir, tmp_dir)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=AudioDataset.collate_fn)

In [5]:
# Direct Call Sanity Check
input_audio, label_audio = dataset[0]
print('Total Chunk Count:', len(dataset), '\nInput Tensor:', input_audio.shape, '\nLabel Tensor:', label_audio.shape)

# Expect shape: [2, 66150]
assert input_audio.shape == (2, 66150), "Error: Input tensor shape does not match expected size."
assert label_audio.shape == (2, 66150), "Error: Label tensor shape does not match expected size."

# Batch Call Sanity Check
input_batch, label_batch = dataset.get_batch(1)
print('Batch Input Tensor:', input_batch.shape, '\nBatch Label Tensor:', label_batch.shape)

# Expect shape: [1, 2, 66150]
assert input_batch.shape == (1, 2, 66150), "Error: Batch input tensor shape does not match expected size."
assert label_batch.shape == (1, 2, 66150), "Error: Batch label tensor shape does not match expected size."

# Loader Call Sanity Check
input_collate, label_collate = next(iter(data_loader))
print('Loader Input Tensor:', input_collate.shape, '\nLoader Label Tensor:', label_collate.shape)

# Expect shape: [1, 2, 66150]
assert input_collate.shape == (1, 2, 66150), "Error: Loader input tensor shape does not match expected size."
assert label_collate.shape == (1, 2, 66150), "Error: Loader label tensor shape does not match expected size."

# Check if input_audio and input_collate are equal
assert torch.equal(input_audio, input_collate.squeeze(0)), "Error: Collate Loader vs. Direct Call are not equal."

# Test multiple samples via direct call
for i in range(1, 10):
    input_audio, label_audio = dataset[i]
    assert input_audio.shape == (2, 66150), f"Error at index {i}: Input tensor shape mismatch."
    assert label_audio.shape == (2, 66150), f"Error at index {i}: Label tensor shape mismatch."

# Test multiple batches via get_batch
for _ in range(5):
    input_batch, label_batch = dataset.get_batch(1)
    assert input_batch.shape == (1, 2, 66150), f"Error at index {i}: Batch input tensor shape mismatch."
    assert label_batch.shape == (1, 2, 66150), f"Error at index {i}: Batch label tensor shape mismatch."

# Check edge cases (last sample)
input_audio_last, label_audio_last = dataset[len(dataset) - 1]
assert input_audio_last.shape == (2, 66150), "Error: Last sample input tensor shape mismatch."
assert label_audio_last.shape == (2, 66150), "Error: Last sample label tensor shape mismatch."

# Check random access in get_batch
for i in range(5):
    input_batch, label_batch = dataset.get_batch(1, randomized=True)
    assert input_batch.shape == (1, 2, 66150), f"Error at loop {i}: Randomized batch input tensor shape mismatch."
    assert label_batch.shape == (1, 2, 66150), f"Error at loop {i}: Randomized batch label tensor shape mismatch."

print("\nSanity checks passed!")

Total Chunk Count: 2994 
Input Tensor: torch.Size([2, 66150]) 
Label Tensor: torch.Size([2, 66150])
Batch Input Tensor: torch.Size([1, 2, 66150]) 
Batch Label Tensor: torch.Size([1, 2, 66150])
Loader Input Tensor: torch.Size([1, 2, 66150]) 
Loader Label Tensor: torch.Size([1, 2, 66150])

Sanity checks passed!


---

## Model

In [14]:
class ResidualDenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate=16, num_layers=4):
        super(ResidualDenseBlock, self).__init__()
        layers = []
        for i in range(num_layers):
            layers.append(
                nn.Conv1d(in_channels + i * growth_rate, growth_rate, kernel_size=3, padding=1)
            )
            layers.append(nn.LeakyReLU())
        self.layers = nn.Sequential(*layers)
        self.final_conv = nn.Conv1d(in_channels + num_layers * growth_rate, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = torch.cat([out, layer(out)], dim=1)
        return self.final_conv(out) + x

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.down = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU()
        )
        
    def forward(self, x):
        return self.down(x)

class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='linear', align_corners=True),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU()
        )
        
    def forward(self, x):
        return self.up(x)

class CV_TasNet_Block(nn.Module):
    # Simplified representation; adjust based on specific requirements or TasNet variant
    def __init__(self, in_channels, out_channels, num_blocks=8):
        super(CV_TasNet_Block, self).__init__()
        self.blocks = nn.ModuleList([self._build_block(in_channels, out_channels) for _ in range(num_blocks)])
        
    def _build_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, in_channels, kernel_size=3, padding=1)
        )
        
    def forward(self, x):
        residual = x
        for block in self.blocks:
            x = block(x) + residual
            residual = x
        return x

class TemporalSelfAttention(nn.Module):
    def __init__(self, d_model, nhead=4):
        super(TemporalSelfAttention, self).__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead)

    def forward(self, x):
        x = x.permute(2, 0, 1)  # Convert from (batch, channels, time) to (time, batch, channels)
        attn_output, _ = self.attention(x, x, x)
        return attn_output.permute(1, 2, 0)  # Convert back to (batch, channels, time)

class AudioUNet(nn.Module):
    def __init__(self, num_channels=2):
        super(AudioUNet, self).__init__()
        self.down1 = ResidualDenseBlock(num_channels)
        self.down2 = DownSample(16, 32)
        self.down3 = DownSample(32, 64)
        self.cv_tasnet = CV_TasNet_Block(64, 64)
        self.temporal_attention = TemporalSelfAttention(64)

        # Multi-scale feature fusion
        # Fuse upsampled features and skip connection from downsampled features
        self.up3 = UpSample(64 + 64, 32)
        self.up2 = UpSample(32 + 32, 16)
        self.up1 = nn.Sequential(
            nn.Conv1d(16 + 16, 16, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(16, num_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
                utils.weight_norm(m)

    def forward(self, x): # x: (batch, channels, time)
        skip1 = self.down1(x)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        x = self.cv_tasnet(skip3)
        x = self.temporal_attention(x)
        x = self.up3(torch.cat([x, skip3], dim=1))
        x = self.up2(torch.cat([x, skip2], dim=1))
        x = self.up1(torch.cat([x, skip1], dim=1))
        return x

---

## Training

In [8]:
dataset = AudioDataset(data_dir, tmp_dir)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=AudioDataset.collate_fn)

In [15]:
def spectral_loss(output, target, n_fft=1024, hop_length=None):
    output_stft = torch.stft(output, n_fft=n_fft, hop_length=hop_length, return_complex=True)
    target_stft = torch.stft(target, n_fft=n_fft, hop_length=hop_length, return_complex=True)
    spectral_diff = torch.abs(torch.abs(output_stft) - torch.abs(target_stft))
    return torch.mean(spectral_diff)

In [16]:
model = AudioUNet(num_channels=2).to(device) # Adjust audio_length and num_speakers
criterion_mse = nn.MSELoss() # Mean Squared Error for audio regression tasks
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

  WeightNorm.apply(module, name, dim)


In [17]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(data_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        
        mse_loss = criterion_mse(outputs, labels)
        spec_loss = spectral_loss(outputs, labels, n_fft=1024, hop_length=256)
        loss = mse_loss + spectral_weight * spec_loss

        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 20 == 19:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 20))
            running_loss = 0.0

        if str(device) == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    #if epoch % 2 == 0: # Validate every 2 epochs
    #    model.eval()
    #    with torch.no_grad():
    #        val_loss = 0.0
    #        for i, data in enumerate(val_loader, 0):
    #            inputs, labels = data
    #            inputs, labels = inputs.to(device), labels.to(device)
    #            outputs = model(inputs)
    #            loss = criterion(outputs, labels)
    #            val_loss += loss.item()
    #        print(f'Validation loss: {val_loss / len(val_loader):.3f}')
        
    if epoch % 5 == 0: # Save every 5 epochs
        torch.save(model.state_dict(), f'audio_unet_epoch_{epoch}.pth')

A: torch.Size([1, 2, 66150])


RuntimeError: Given groups=1, weight of size [16, 18, 3], expected input[1, 36, 66150] to have 18 channels, but got 36 channels instead