# Audio Splicing - Drums

In [1]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn as nn
import soundfile as sf
import torch_optimizer as optim
import torch.nn.functional as F

from tqdm import tqdm
from functools import lru_cache
from typing import List, Tuple
from scipy.io.wavfile import write
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

---

## Hyperparameters

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device:\t\t\t{device}")

torch.manual_seed(42)
np.random.seed(42)

if str(device) == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.manual_seed_all(42)
    print(f"Allocated CUDA memory:\t{torch.cuda.memory_allocated() / 1024 ** 3:8.4f} GiB")

num_epochs = 75         # Number of epochs to train
data_dir = "/mnt/data/Daftset"  # Directory containing the dataset (we will work on a copy)
batch_size = 1          # 3060 tackles single-entry batches at max
learning_rate = 1e-3    # Light learning rate
num_channels = 2        # Number of audio channels
freq_orig = 44100       # Original frequency of the audio files
optim_k = 5             # Average weight updates over optim_k steps to stabilize training
optim_alpha = 0.3       # Weight of influence of Lookahead's fast weights on the slow weights
chunk_duration = 2      # Duration of training examples in seconds -> samples_per_example = (freq_orig) * chunk_duration
weight_decay = 1e-4     # Weight decay for the optimizer
spectral_weight = 0.75  # Spectral loss impact for total loss calculation
mse_weight = 0.25       # Impact weight for MSE to total loss calculation
accum_steps = 256       # Effective_batch_size = batch_size * accumulation_steps

Using device:			cuda
Allocated CUDA memory:	  0.0000 GiB


---

## Dataset

Smaller datasets may afford us to load the entire dataset at once during `init`, but this set is custom and sizewise unpredictable enough to require on-the-fly loading.
Thing is, if we strictly load on request, how can we provide a `__len__` method for the dataset?<br>
Thankfully, audio files have metadata we can use to determine their length without loading them, based on which we can provide a `__len__` method.

- `_process_files` loads the in-out file pairs by name and as a list of name pairs, but not the actual audio data. We need that for ordered access to the dataset.
- `_get_file_info` takes this list of name pairs and loads just the metadata of the audio files, which we can use to determine the length of all files. Note that we expect input and output to be the same size here. This is unrealistic, but we can trim or pad the labels accordingly later when we actually load the data.
- The main data providing complexity is in `__get_item__`, where we 
    - calculate the idx of the requested input, label pair at a global scale (across all audio)
    - assemble the input and the label pair, even across different files (load until we reach the requested sample size)
    - trim or pad the label to match the input size
    


In [3]:
class AudioDataset(Dataset):
    def __init__(self, data_dir: str, batch_size: int = batch_size, chunk_duration: float = chunk_duration, 
                 freq_orig: int = freq_orig, input_tail: str='.wav', label_tail: str='_labeled.wav'):
        self.data_dir = data_dir
        self.input_tail = input_tail
        self.label_tail = label_tail
        self.batch_size = batch_size
        self.chunk_duration = chunk_duration
        self.freq_orig = freq_orig
        self.input_label_pairs = self._process_files(input_tail, label_tail)  # Get input-label file name pairs 
        self.file_info = self._get_file_info() # Sift only through metadata for each file
        self.input_length = sum(info['length'] for info in self.file_info) # Calculate total length and chunk information (from metadata)
        self.chunk_size = int(chunk_duration * freq_orig)  # Ensure chunk_size is an integer
        self.batch_count = 0
        self._file_handle_cache = {} # Create cache for file handles

    def _process_files(self, input_tail: str, label_tail: str) -> List[Tuple[str, str]]:
        all_files = os.listdir(self.data_dir)
        input_files = sorted([f for f in all_files if f.endswith(input_tail) and not f.endswith(label_tail)])
        label_files_set = {f for f in all_files if f.endswith(label_tail)}
        file_tuples = []
        
        for input_file in input_files:
            label_file = input_file.replace(input_tail, label_tail)
            if label_file in label_files_set:
                file_tuples.append((input_file, label_file))
            else:
                raise ValueError(f"Missing label file for {input_file}: Expected {label_file}.")
        
        if not file_tuples:
            raise ValueError("No matching input-label file pairs found.")
        return file_tuples

    def _get_file_info(self) -> List[dict]:
        file_info = []
        for in_fname, lb_fname in self.input_label_pairs:
            info = torchaudio.info(os.path.join(self.data_dir, in_fname))
            length = info.num_frames
            file_info.append({'length': length,
                              'input_path': os.path.join(self.data_dir, in_fname),
                              'label_path': os.path.join(self.data_dir, lb_fname)})
        return file_info

    @lru_cache(maxsize=8)
    def _get_file_handle(self, file_path: str) -> sf.SoundFile:
        if file_path not in self._file_handle_cache:
            self._file_handle_cache[file_path] = sf.SoundFile(file_path, 'r')
        return self._file_handle_cache[file_path]

    def _normalize_audio(self, data: torch.Tensor) -> torch.Tensor:
        data = data.float()
        if torch.abs(data).max() > 0:
            data = data / (torch.abs(data).max() + 1e-8)
        return torch.clamp(data, min=-0.99, max=0.99)  # Prevent extreme values

    def _pad_audio(self, audio: torch.Tensor, target_length: int) -> torch.Tensor:
        current_length = audio.shape[1]
        if current_length < target_length:
            padding = target_length - current_length
            return F.pad(audio, (0, padding))
        elif current_length > target_length:
            return audio[:, :target_length]
        return audio

    def _read_audio_chunk(self, file_path: str, start: int, length: int) -> torch.Tensor:
        handle = self._get_file_handle(file_path)
        handle.seek(start)
        data = handle.read(length)
        data = torch.from_numpy(data).T
        data = self._normalize_audio(data)
        return self._pad_audio(data, self.chunk_size)

    def __len__(self) -> int:
        total_chunks = self.input_length // self.chunk_size
        return total_chunks // self.batch_size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if idx >= len(self):
            raise IndexError("Index out of bounds")
            
        global_start = (idx * self.chunk_size) % self.input_length
        input_chunks, label_chunks = [], []
        samples_remaining = self.chunk_size
        cumulative_length = 0
        current_file_index = 0

        # First, find the correct starting file
        while cumulative_length + self.file_info[current_file_index]['length'] <= global_start:
            cumulative_length += self.file_info[current_file_index]['length']
            current_file_index += 1
            if current_file_index >= len(self.file_info):
                current_file_index = 0
                cumulative_length = 0

        # Now read the chunks
        while samples_remaining > 0:
            if current_file_index >= len(self.file_info):
                current_file_index = 0

            file_info = self.file_info[current_file_index]
            local_start = global_start - cumulative_length if current_file_index == 0 else 0
            
            # Calculate how many samples we can take from this file
            samples_from_file = min(
                file_info['length'] - local_start,  # samples available in file
                samples_remaining  # samples we still need
            )

            if samples_from_file > 0:
                input_chunk = self._read_audio_chunk(
                    file_info['input_path'], 
                    local_start, 
                    samples_from_file
                )
                label_chunk = self._read_audio_chunk(
                    file_info['label_path'], 
                    local_start, 
                    samples_from_file
                )
                
                input_chunks.append(input_chunk)
                label_chunks.append(label_chunk)
                samples_remaining -= samples_from_file

            cumulative_length += file_info['length']
            current_file_index += 1

        # Concatenate all chunks and ensure final size
        input_audio = torch.cat(input_chunks, dim=1)
        label_audio = torch.cat(label_chunks, dim=1)

        # Final size check and adjustment
        input_audio = self._pad_audio(input_audio, self.chunk_size)
        label_audio = self._pad_audio(label_audio, self.chunk_size)

        return input_audio, label_audio

    def get_batch(self, batch_size: int, randomized: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        if randomized:
            idx = np.random.randint(0, len(self))
        else:
            idx = batch_size * self.batch_count
            if idx + batch_size > len(self):
                self.batch_count = 0
                idx = 0
                
        batch = [self.__getitem__((idx + i) % len(self)) for i in range(batch_size)]
        batch_input, batch_label = [torch.stack(items) for items in zip(*batch)]
        
        if not randomized:
            self.batch_count += 1
            
        return batch_input, batch_label

    @staticmethod
    def collate_fn(batch):
        batch_input, batch_label = [torch.stack(items) for items in zip(*batch)]
        batch_input = batch_input.view(batch_input.shape[0], 2, -1)
        batch_label = batch_label.view(batch_label.shape[0], 2, -1)
        return batch_input, batch_label

    def __del__(self):
        for handle in self._file_handle_cache.values():
            handle.close()

---

### Dataset Sanity Check

In [4]:
dataset = AudioDataset(data_dir)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=AudioDataset.collate_fn)

In [5]:
# Direct Call Sanity Check
input_audio, label_audio = dataset[0]
print('Total Chunk Count:', len(dataset), '\nInput Tensor:', input_audio.shape, '\nLabel Tensor:', label_audio.shape)
assert_f = (freq_orig) * chunk_duration

# Expect shape: [2, 44100]
assert input_audio.shape == (2, assert_f), "Error: Input tensor shape does not match expected size."
assert label_audio.shape == (2, assert_f), "Error: Label tensor shape does not match expected size."

# Batch Call Sanity Check
input_batch, label_batch = dataset.get_batch(1)
print('Batch Input Tensor:', input_batch.shape, '\nBatch Label Tensor:', label_batch.shape)

# Expect shape: [1, 2, 44100]
assert input_batch.shape == (1, 2, assert_f), "Error: Batch input tensor shape does not match expected size."
assert label_batch.shape == (1, 2, assert_f), "Error: Batch label tensor shape does not match expected size."

# Loader Call Sanity Check
input_collate, label_collate = next(iter(data_loader))
print('Loader Input Tensor:', input_collate.shape, '\nLoader Label Tensor:', label_collate.shape)

# Expect shape: [1, 2, 44100]
assert input_collate.shape == (1, 2, assert_f), "Error: Loader input tensor shape does not match expected size."
assert label_collate.shape == (1, 2, assert_f), "Error: Loader label tensor shape does not match expected size."

# Check if input_audio and input_collate are equal
assert torch.equal(input_audio, input_collate.squeeze(0)), "Error: Collate Loader vs. Direct Call are not equal."

# Test multiple samples via direct call
for i in range(1, 10):
    input_audio, label_audio = dataset[i]
    assert input_audio.shape == (2, assert_f), f"Error at index {i}: Input tensor shape mismatch."
    assert label_audio.shape == (2, assert_f), f"Error at index {i}: Label tensor shape mismatch."

# Test multiple batches via get_batch
for _ in range(5):
    input_batch, label_batch = dataset.get_batch(1)
    assert input_batch.shape == (1, 2, assert_f), f"Error at index {i}: Batch input tensor shape mismatch."
    assert label_batch.shape == (1, 2, assert_f), f"Error at index {i}: Batch label tensor shape mismatch."

# Check edge cases (last sample)
input_audio_last, label_audio_last = dataset[len(dataset) - 1]
assert input_audio_last.shape == (2, assert_f), "Error: Last sample input tensor shape mismatch."
assert label_audio_last.shape == (2, assert_f), "Error: Last sample label tensor shape mismatch."

# Check random access in get_batch
for i in range(5):
    input_batch, label_batch = dataset.get_batch(1, randomized=True)
    assert input_batch.shape == (1, 2, assert_f), f"Error at loop {i}: Randomized batch input tensor shape mismatch."
    assert label_batch.shape == (1, 2, assert_f), f"Error at loop {i}: Randomized batch label tensor shape mismatch."

# Check dataset item size consistency
for idx in tqdm(range(len(dataset)), desc="Checking dataset item size consistency"):
    input_audio, label_audio = dataset[idx]
    assert input_audio.shape == (2, assert_f), f"Error at index {idx}: Input tensor shape mismatch. Expected (2, {assert_f}), got {input_audio.shape}"
    assert label_audio.shape == (2, assert_f), f"Error at index {idx}: Label tensor shape mismatch. Expected (2, {assert_f}), got {label_audio.shape}"

# Taking a listen ensures consistent pair assembly
input_audio, label_audio = dataset[np.random.randint(0, len(dataset)-1)]
input_audio = input_audio.T.numpy()
label_audio = label_audio.T.numpy()
write('input_sample.wav', freq_orig, input_audio)
write('label_sample.wav', freq_orig, label_audio)

print("\nSanity checks passed!")

Total Chunk Count: 2329 
Input Tensor: torch.Size([2, 88200]) 
Label Tensor: torch.Size([2, 88200])
Batch Input Tensor: torch.Size([1, 2, 88200]) 
Batch Label Tensor: torch.Size([1, 2, 88200])
Loader Input Tensor: torch.Size([1, 2, 88200]) 
Loader Label Tensor: torch.Size([1, 2, 88200])


Checking dataset item size consistency: 100%|██████████| 2329/2329 [00:20<00:00, 112.48it/s]


Sanity checks passed!





---

## Model

In [6]:
class ResidualDenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate=32, num_layers=6):
        super(ResidualDenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_channels + i * growth_rate
            self.norms.append(nn.InstanceNorm1d(in_ch))
            self.layers.append(nn.Sequential(
                nn.Conv1d(in_ch, growth_rate, kernel_size=3, padding=1),
                nn.InstanceNorm1d(growth_rate),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.1)))
        final_ch = in_channels + num_layers * growth_rate
        self.final_norm = nn.InstanceNorm1d(final_ch)
        self.final_conv = nn.Conv1d(final_ch, in_channels, kernel_size=1)
        self.output_norm = nn.InstanceNorm1d(in_channels)
        
    def forward(self, x):
        features = [x]
        for norm, layer in zip(self.norms, self.layers):
            inputs = torch.cat(features, dim=1)
            inputs = norm(inputs)
            out = layer(inputs)
            features.append(out)
        out = self.final_norm(torch.cat(features, dim=1))
        out = self.final_conv(out)
        out = self.output_norm(out)
        return (out * 0.5) + x

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.down = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU())
        
    def forward(self, x):
        return self.down(x)

class Resizer(nn.Module):
    def __init__(self, in_channel, out_channel, target_size):
        super(Resizer, self).__init__()
        self.conv = nn.ConvTranspose1d(in_channel, out_channel, kernel_size=4, stride=2, padding=1)
        self.target_size = target_size
        self.activation = nn.LeakyReLU()

    def forward(self, x):
        x = self.conv(x)
        x = F.interpolate(x, size=self.target_size, mode='linear', align_corners=True)
        return self.activation(x)

class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='linear', align_corners=True),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU())
        
    def forward(self, x, skip):
        x = self.up(x)
        diff = x.size(2) - skip.size(2)
        if diff > 0:
            x = x[:, :, :skip.size(2)]
        elif diff < 0:
            x = nn.functional.pad(x, (0, -diff))
        return x

class CV_TasNet_Block(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks=14):
        super(CV_TasNet_Block, self).__init__()
        self.blocks = nn.ModuleList([self._build_block(in_channels, out_channels) for _ in range(num_blocks)])
        
    def _build_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, in_channels, kernel_size=3, padding=1))
        
    def forward(self, x):
        residual = x
        for block in self.blocks:
            x = block(x) + residual
            residual = x
        return x

class TemporalSelfAttention(nn.Module):
    def __init__(self, d_model, nhead=8):
        super(TemporalSelfAttention, self).__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x.permute(1, 2, 0)

class AudioUNet(nn.Module):
    def __init__(self, num_channels=2):
        super(AudioUNet, self).__init__()
        self.down1 = ResidualDenseBlock(num_channels)
        self.down2 = DownSample(2, 4)
        self.down3 = DownSample(4, 8)
        self.down4 = DownSample(8, 16)
        self.cv_tasnet = CV_TasNet_Block(16, 16)
        self.temporal_attention = TemporalSelfAttention(16)
        self.up4 = UpSample(16, 8)
        self.up3 = UpSample(8, 4)
        self.up2 = UpSample(4, 2)
        self.resizer = Resizer(2, 2, assert_f)
        self.up1 = nn.Sequential(nn.Conv1d(4, 4, kernel_size=3, padding=1),
                                 nn.LeakyReLU(),
                                 nn.Conv1d(4, num_channels, kernel_size=3, padding=1),
                                 nn.Tanh()) # Normalize to [-1, 1] for audio
 
    def forward(self, x):
        x_orig = x
        skip1 = self.down1(x)            # (batch, 2, (freq_orig) * chunk_duration)
        skip2 = self.down2(skip1)        # (batch, 4, ((freq_orig) * chunk_duration) / 2)
        skip3 = self.down3(skip2)        # (batch, 8, ((freq_orig) * chunk_duration) / 4)
        skip4 = self.down4(skip3)        # (batch, 16, ((freq_orig) * chunk_duration) / 4)
        x = self.cv_tasnet(skip4)        # (batch, 16, ((freq_orig) * chunk_duration) / 8)
        x = self.temporal_attention(x)   # (batch, 16, ((freq_orig) * chunk_duration) / 8)
        x = self.up4(x, skip4)           # (batch, 8, ((freq_orig) * chunk_duration) / 4)
        x = self.up3(x, skip3)           # (batch, 4, ((freq_orig) * chunk_duration) / 4)
        x = self.up2(x, skip2)           # (batch, 2, ((freq_orig) * chunk_duration) / 2)
        x = self.resizer(x)              # (batch, 2, (freq_orig) * chunk_duration)
        x = torch.cat([x, skip1], dim=1) # (batch, 4, (freq_orig) * chunk_duration)
        x = self.up1(x)                  # (batch, 2, (freq_orig) * chunk_duration)
        return x + x_orig                # x applies 'morphing' to original input audio -> model focuses solely on splicing, not reconstruction

---

## Training

In [7]:
def create_train_val_splits(dataset, val_ratio=0.2, shuffle=True):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(val_ratio * dataset_size))
    if shuffle:
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    return SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices)

In [8]:
dataset = AudioDataset(data_dir)
train_sampler, val_sampler = create_train_val_splits(dataset, val_ratio=0.2, shuffle=True)

train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=AudioDataset.collate_fn, num_workers=0)
val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler, collate_fn=AudioDataset.collate_fn, num_workers=0)

In [9]:
def spectral_loss(output, target, n_fft=1024, hop_length=None, epsilon=1e-8):
    if hop_length is None:
        hop_length = n_fft // 4
    window = torch.hann_window(n_fft).to(output.device)
    def safe_log(x):
        return torch.log(torch.clamp(x, min=epsilon))
    loss = 0
    for i in range(output.shape[1]):
        output_stft = torch.stft(output[:, i, :], n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True)
        target_stft = torch.stft(target[:, i, :], n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True)
        output_mag, target_mag = torch.abs(output_stft), torch.abs(target_stft)
        output_log_mag, target_log_mag = safe_log(output_mag + epsilon), safe_log(target_mag + epsilon)
        mag_loss = torch.mean(torch.abs(output_log_mag - target_log_mag))
        output_phase, target_phase = torch.angle(output_stft), torch.angle(target_stft)
        phase_loss = 1 - torch.mean(torch.cos(output_phase - target_phase))
        loss += mag_loss + 0.1 * phase_loss
    return loss / output.shape[1]

In [10]:
model = AudioUNet(num_channels=2).to(device).float() # Adjust audio_length and num_speakers
model = torch.compile(model)
criterion_mse = nn.MSELoss().to(device) # Mean Squared Error
base_optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, eps=1e-8)

# Switching between providing 'fast weights' and 'slow weights' for AdamW optimizer update calculations
optimizer = optim.Lookahead(base_optimizer, k=optim_k, alpha=optim_alpha)

# Gradually warm and then cool down LR over time
scheduler = OneCycleLR(optimizer, max_lr=learning_rate*2, epochs=num_epochs, steps_per_epoch=len(train_loader) // accum_steps,
                       pct_start=0.1, anneal_strategy='cos', div_factor=10.0, final_div_factor=1000.0)

print(f"Model Parameter Count: {sum(p.numel() for p in model.parameters()):,}")

Model Parameter Count: 78,000


In [11]:
best_val_loss = float('inf') # Initialize as highest possible

print(f'Effective Batch Size: {batch_size * accum_steps} examples')

for epoch in range(num_epochs):
    # Training loop
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device).float(), labels.to(device).float()
        with torch.amp.autocast(device_type=str(device)):
            outputs = model(inputs)
            mse_loss = criterion_mse(outputs, labels)
            spec_loss = spectral_loss(outputs, labels, n_fft=1024, hop_length=256)
            # Normalize mini losses to emulate larger batch size (not just accumulation)
            loss = (mse_weight * mse_loss + spectral_weight * spec_loss) / accum_steps
        loss.backward() # Accumulate gradients
        if (i + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
        if torch.isnan(loss):
            print(f"NaN detected in loss at epoch {epoch+1}, batch {i+1}")
            print(f"MSE Loss: {mse_loss.item()}, Spectral Loss: {spec_loss.item()}")
        # Showing the true mini-batch loss during logging
        running_loss += loss.item() * accum_steps
        if str(device) == "cuda":
            del inputs, labels, outputs, mse_loss, spec_loss, loss
            torch.cuda.empty_cache()
        if i % (2 * accum_steps) == (2 * accum_steps) - 1:
            print(f'Epoch [{epoch+1:3}/{num_epochs}] | '
                  f'Mini-Batch [{i+1:4}/{len(data_loader)}] | '
                  f'Train: {(running_loss / (2 * accum_steps)):8.6f} | '
                  f'LR: {optimizer.param_groups[-1]["lr"]:.6f}')
            running_loss = 0.0
    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(val_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device).float(), labels.to(device).float()
            with torch.amp.autocast(device_type=str(device)):
                outputs = model(inputs)
                mse_loss = criterion_mse(outputs, labels)
                spec_loss = spectral_loss(outputs, labels, n_fft=1024, hop_length=256)
                loss = mse_weight * mse_loss + spectral_weight * spec_loss
                val_loss += loss.item()
            if device.type == "cuda":
                del inputs, labels, outputs, mse_loss, spec_loss, loss
                torch.cuda.empty_cache()
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch [{epoch+1:3}/{num_epochs}] | Validation: {avg_val_loss:8.6f}')
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_loss': best_val_loss,
            }, 'best_audio_unet.pth')
    if str(device) == "cuda":
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

Effective Batch Size: 256 examples
Epoch [  1/75] | Mini-Batch [ 512/2329] | Train: 3.176363 | LR: 0.000207
Epoch [  1/75] | Mini-Batch [1024/2329] | Train: 3.083100 | LR: 0.000227
Epoch [  1/75] | Mini-Batch [1536/2329] | Train: 3.046726 | LR: 0.000260
Epoch [  1/75] | Validation: 2.084655
Epoch [  2/75] | Mini-Batch [ 512/2329] | Train: 2.995884 | LR: 0.000332
Epoch [  2/75] | Mini-Batch [1024/2329] | Train: 2.966147 | LR: 0.000395
Epoch [  2/75] | Mini-Batch [1536/2329] | Train: 3.094161 | LR: 0.000468
Epoch [  2/75] | Validation: 1.870864
Epoch [  3/75] | Mini-Batch [ 512/2329] | Train: 2.784604 | LR: 0.000596
Epoch [  3/75] | Mini-Batch [1024/2329] | Train: 2.921733 | LR: 0.000690
Epoch [  3/75] | Mini-Batch [1536/2329] | Train: 2.751730 | LR: 0.000791
Epoch [  3/75] | Validation: 1.850610
Epoch [  4/75] | Mini-Batch [ 512/2329] | Train: 2.737650 | LR: 0.000950
Epoch [  4/75] | Mini-Batch [1024/2329] | Train: 2.613682 | LR: 0.001059
Epoch [  4/75] | Mini-Batch [1536/2329] | Train:

## Inference

In [15]:
assert_f = (freq_orig) * chunk_duration

def load_model(model_path, device='cpu'):
    if isinstance(device, str):
        device = torch.device(device)
    state_dict = torch.load(model_path, map_location=device, weights_only=False)
    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    new_state_dict = {}
    for key, value in state_dict.items():
        if 'weight_mask' not in key:
            new_key = key.replace('weight_orig', 'weight')
            new_state_dict[new_key] = value
    model = AudioUNet(num_channels=2).to(device).float()
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    return model

def preprocess_audio(audio_path, target_sample_rate=44100):
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"File not found: {audio_path}")
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != target_sample_rate:
        transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = transform(waveform)
    # Normalize the waveform
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    # Ensure the waveform is stereo
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    elif waveform.shape[0] > 2:
        waveform = waveform[:2, :]
    
    return waveform

def infer_audio(model, audio_path, output_path, device, chunk_duration=2, sample_rate=44100):
    waveform = preprocess_audio(audio_path, target_sample_rate=sample_rate)
    chunk_size = int(chunk_duration * sample_rate)
    total_length = waveform.shape[1]
    processed_waveform = torch.zeros_like(waveform)
    overlap_count = torch.zeros(waveform.shape[1], dtype=torch.float32)
    overlap = chunk_size // 2
    window = torch.hann_window(chunk_size, dtype=torch.float32)
    
    model.eval()
    with torch.no_grad():
        for start in tqdm(range(0, total_length, overlap), desc="Generating audio"):
            end = min(start + chunk_size, total_length)
            if end - start < chunk_size:
                chunk = torch.zeros((2, chunk_size), dtype=torch.float32)
                chunk[:, :(end - start)] = waveform[:, start:end]
            else:
                chunk = waveform[:, start:end]
            chunk = chunk.unsqueeze(0).to(device)
            output = model(chunk).squeeze(0).cpu()
            if end - start < chunk_size:
                valid_length = end - start
                output = output[:, :valid_length]
                window_section = window[:valid_length]
            else:
                window_section = window
            output = output * window_section.view(1, -1)
            if end - start < chunk_size:
                processed_waveform[:, start:end] += output
                overlap_count[start:end] += window_section
            else:
                processed_waveform[:, start:start + chunk_size] += output
                overlap_count[start:start + chunk_size] += window_section
    mask = overlap_count > 0
    processed_waveform[:, mask] /= overlap_count[mask].view(1, -1)
    processed_waveform = processed_waveform / (torch.max(torch.abs(processed_waveform)) + 1e-8)
    processed_waveform = torch.clamp(processed_waveform, -0.99, 0.99)
    sf.write(output_path, processed_waveform.T.numpy(), sample_rate)
    print(f"Generated audio saved at: {output_path}")
    return processed_waveform

model_path  = "best_audio_unet.pth"
audio_path  = "input_audio.wav"
output_path = "generated_audio.wav"

# You'll need more than 16GB of GPU memory to run via CUDA
device = torch.device("cpu")
model = load_model(model_path, device)
infer_audio(model, audio_path, output_path, device)

Generated audio saved at: generated_audio.wav


tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0660, 0.0667, 0.0660],
        [0.0000, 0.0000, 0.0000,  ..., 0.1472, 0.1460, 0.1472]])