# Audio Super-Resolution using Deep Learning

This notebook implements a deep learning model to perform audio super-resolution using a U-Net architecture.

## 1. Setup and Dependencies

First, let's install all the required libraries and import them.

In [None]:
!pip install torch torchaudio matplotlib numpy scipy librosa soundfile

import os
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from google.colab import drive
import librosa
import soundfile as sf
from tqdm.notebook import tqdm
import random
import gc
import json
from datetime import datetime
from tqdm import tqdm
import pickle
from pathlib import Path

## 2. Mount Google Drive and Set Paths

In [2]:
# Mount Google Drive so the dataset and the results folder are reachable.
drive.mount('/content/drive')

# Dataset and output locations. Defaults point at Google Drive; set the
# AUDIO_SR_DATASET_PATH / AUDIO_SR_RESULTS_DIR environment variables to use
# other folders without editing the notebook.
DATASET_PATH = os.environ.get('AUDIO_SR_DATASET_PATH', '/content/drive/MyDrive/fma_small')
RESULTS_DIR  = os.environ.get('AUDIO_SR_RESULTS_DIR', '/content/drive/MyDrive/audio_sr_results')  # models, metrics and plots

# Create the output directory if it doesn't exist yet.
os.makedirs(RESULTS_DIR, exist_ok=True)

Mounted at /content/drive


In [None]:
# Seed every RNG the notebook draws from (PyTorch, Python's `random`, NumPy) so
# file shuffling, dataset splits and random crops repeat from run to run.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(42)
del _seed_fn

## 3. Data Preparation

We'll create a custom dataset class for loading and preprocessing the audio files. This includes:
- Loading the audio files
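- Converting to mono if needed
- Randomly cropping (or zero-padding) each file to a fixed-length segment
- Creating low-resolution versions of the audio: each segment is resampled 44.1 kHz → 16 kHz → 44.1 kHz and then low-pass filtered, so the input and the target have the same length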
- Splitting the dataset into training, validation, and test sets

In [3]:
class AudioSuperResolutionDataset(Dataset):
    """(band-limited input, full-band target) pairs of mono audio segments.

    Every ``.mp3`` file under ``root_dir`` (searched recursively) is one item. On
    access the file is loaded, mixed down to mono, resampled to ``sr_orig``,
    randomly cropped (or zero-padded at the end) to ``segment_length`` samples,
    and paired with a degraded copy. The copy is resampled to ``sr_low`` and back
    to ``sr_orig``, then low-pass filtered.

    Args:
        root_dir: Directory searched recursively for ``.mp3`` files.
        segment_length: Number of samples in every returned segment.
        sr_orig: Sample rate of the high-resolution target, in Hz.
        sr_low: Intermediate sample rate used to simulate the low-resolution input, in Hz.
        max_files: If given, keep only the first ``max_files`` files after shuffling.
        max_load_retries: How many random replacement files ``__getitem__`` tries
            when a file fails to load, before raising ``RuntimeError``.
    """

    def __init__(self, root_dir, segment_length=16384, sr_orig=44100, sr_low=16000, max_files=None,
                 max_load_retries=10):
        self.root_dir = root_dir
        self.segment_length = segment_length
        self.sr_orig = sr_orig
        self.sr_low = sr_low
        self.max_load_retries = max_load_retries

        print(f"Scanning for MP3 files in {root_dir} (recursive)...")
        self.file_list = [
            os.path.join(dirpath, filename)
            for dirpath, _dirnames, filenames in os.walk(root_dir)
            for filename in filenames
            if filename.endswith('.mp3')
        ]
        print(f"Found {len(self.file_list)} MP3 files")

        # os.walk order depends on the filesystem. Sort first so that, with a fixed
        # `random` seed, the shuffle (and therefore the train/val/test split) is the
        # same on every run and machine. Then shuffle to mix files across folders.
        self.file_list.sort()
        random.shuffle(self.file_list)

        if max_files is not None:
            self.file_list = self.file_list[:max_files]
            print(f"Limited dataset to {len(self.file_list)} files")

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _random_crop_or_pad(waveform, segment_length):
        """Return a random ``segment_length``-sample window of a (channels, time) tensor.

        Shorter inputs are zero-padded at the end instead.
        """
        num_samples = waveform.shape[1]
        if num_samples > segment_length:
            # randint's upper bound is exclusive; +1 so the final window can be chosen too.
            start = int(torch.randint(0, num_samples - segment_length + 1, (1,)))
            return waveform[:, start:start + segment_length]
        return torch.nn.functional.pad(waveform, (0, segment_length - num_samples))

    @staticmethod
    def _fit_length(waveform, length):
        """Truncate or zero-pad a (channels, time) tensor to exactly ``length`` samples."""
        num_samples = waveform.shape[1]
        if num_samples > length:
            return waveform[:, :length]
        if num_samples < length:
            return torch.nn.functional.pad(waveform, (0, length - num_samples))
        return waveform

    def _load_with_fallback(self, idx):
        """Load file ``idx`` with torchaudio.

        If loading fails, random other files are tried instead, up to
        ``max_load_retries`` times, before ``RuntimeError`` is raised.
        """
        # getattr: dataset objects unpickled from before this attribute existed lack it.
        max_retries = getattr(self, 'max_load_retries', 10)
        audio_path = self.file_list[idx]
        for _attempt in range(max_retries + 1):
            try:
                return torchaudio.load(audio_path)
            except Exception as e:
                print(f"Error loading {audio_path}: {e}")
                audio_path = self.file_list[random.randint(0, len(self.file_list) - 1)]
        raise RuntimeError(
            f"Could not load any audio file after {max_retries + 1} attempts (requested index {idx})"
        )

    def __getitem__(self, idx):
        """Return ``(low_res, high_res)``, each a 1-D tensor of ``segment_length`` samples.

        Both signals are peak-normalized independently.
        """
        waveform, sample_rate = self._load_with_fallback(idx)

        # Mix down to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != self.sr_orig:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sr_orig)

        waveform = self._random_crop_or_pad(waveform, self.segment_length)

        # Simulate a low-resolution recording: resample down to sr_low and back up...
        waveform_low = torchaudio.functional.resample(waveform, self.sr_orig, self.sr_low)
        waveform_low = torchaudio.functional.resample(waveform_low, self.sr_low, self.sr_orig)

        # ...then low-pass at 60% of the low-rate Nyquist frequency (4.8 kHz for 16 kHz).
        cutoff_freq = self.sr_low / 2 * 0.6
        waveform_low = torchaudio.functional.lowpass_biquad(waveform_low, self.sr_orig, cutoff_freq)

        # The resampling round trip may change the length slightly; match the target.
        waveform_low = self._fit_length(waveform_low, self.segment_length)

        # Peak-normalize; the epsilon keeps silent segments from dividing by zero.
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
        waveform_low = waveform_low / (torch.max(torch.abs(waveform_low)) + 1e-8)

        return waveform_low.squeeze(0), waveform.squeeze(0)

def create_dataset_splits(dataset_path, batch_size=8, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1,
                          segment_length=16384, sr_orig=44100, sr_low=16000, num_workers=2, seed=42,
                          max_files=1000):
    """Build an AudioSuperResolutionDataset and split it into train/val/test DataLoaders.

    The train and validation sizes are ``int(ratio * N)``. The test split receives
    every remaining file, so ``test_ratio`` is effectively ``1 - train_ratio - val_ratio``
    and its value is not read. The split uses a ``torch.Generator`` seeded with
    ``seed``. Torch's and Python's global RNGs are also reseeded, which fixes the
    file shuffle inside the dataset.

    Args:
        dataset_path: Root directory searched recursively for ``.mp3`` files.
        batch_size: Batch size for all three loaders.
        train_ratio, val_ratio: Fractions of files for training and validation.
        test_ratio: Kept for API compatibility; see above.
        segment_length, sr_orig, sr_low: Forwarded to the dataset.
        num_workers: DataLoader worker processes per loader.
        seed: Seed for the global RNGs and the split generator.
        max_files: Cap on the number of files used (None = all).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader shuffles.

    Raises:
        ValueError: If the ratios are invalid, no files are found, or the training
            split would be empty.
    """
    if train_ratio <= 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"Need train_ratio > 0, val_ratio >= 0 and train_ratio + val_ratio <= 1; "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    # Reseed so the dataset's file shuffle and the split are reproducible.
    torch.manual_seed(seed)
    random.seed(seed)

    full_dataset = AudioSuperResolutionDataset(
        dataset_path,
        segment_length=segment_length,
        sr_orig=sr_orig,
        sr_low=sr_low,
        max_files=max_files,
    )

    dataset_size = len(full_dataset)
    if dataset_size == 0:
        raise ValueError(f"No .mp3 files found under {dataset_path}")

    train_size = int(train_ratio * dataset_size)
    val_size = int(val_ratio * dataset_size)
    test_size = dataset_size - train_size - val_size
    if train_size == 0:
        raise ValueError(
            f"Training split is empty ({dataset_size} files, train_ratio={train_ratio}); "
            f"add files or increase max_files"
        )

    print(f"Dataset split: {train_size} training, {val_size} validation, {test_size} test samples")

    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size, test_size], generator=generator
    )

    def worker_init_fn(worker_id):
        # Derive NumPy / `random` seeds from each worker's torch seed so workers differ.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # Pinned host memory only helps (and only avoids a warning) when a GPU is present.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, test_loader

## 4. Visualize Dataset Samples

Let's visualize some samples from our dataset to ensure everything is working correctly.

In [None]:
def visualize_samples(dataset_path, num_samples, segment_length, indices=None):
    """Plot waveforms and spectrograms of (low-res, high-res) pairs from the dataset.

    Args:
        dataset_path: Root directory passed to ``AudioSuperResolutionDataset``.
        num_samples: Number of pairs to plot when ``indices`` is not given.
        segment_length: Segment length in samples used by the dataset.
        indices: Optional explicit dataset indices to plot. If omitted,
            ``num_samples`` distinct indices are drawn with ``random.sample``
            (capped at the dataset size).
    """
    dataset = AudioSuperResolutionDataset(dataset_path, segment_length=segment_length)
    if len(dataset) == 0:
        print(f"No audio files found in {dataset_path}; nothing to visualize.")
        return

    if indices is None:
        indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for idx in indices:
        low_res, high_res = dataset[idx]
        high_np = high_res.numpy()
        low_np = low_res.numpy()

        # Waveforms: target on top, degraded input below.
        fig, (ax_high, ax_low) = plt.subplots(2, 1, figsize=(15, 6))
        ax_high.plot(high_np)
        ax_high.set_title('High Resolution (44.1 kHz)')
        ax_high.set_xlabel('Sample')
        ax_high.set_ylabel('Amplitude')

        ax_low.plot(low_np)
        ax_low.set_title('Low Resolution (44.1 kHz  ->  16 kHz  ->  44.1 kHz)')
        ax_low.set_xlabel('Sample')
        ax_low.set_ylabel('Amplitude')

        fig.tight_layout()
        plt.show()
        plt.close(fig)

        # Log-magnitude spectrograms (y axis is the STFT bin index, not Hz). The
        # low-resolution one should show attenuated energy above the dataset's
        # low-pass cutoff (60% of the sr_low Nyquist frequency, 4.8 kHz by default).
        fig, (ax_spec_high, ax_spec_low) = plt.subplots(1, 2, figsize=(15, 6))
        for ax, signal, title in ((ax_spec_high, high_np, 'High Resolution Spectrogram'),
                                  (ax_spec_low, low_np, 'Low Resolution Spectrogram')):
            spec_db = librosa.amplitude_to_db(np.abs(librosa.stft(signal)), ref=np.max)
            image = ax.imshow(spec_db, aspect='auto', origin='lower')
            ax.set_title(title)
            fig.colorbar(image, ax=ax, format='%+2.0f dB')

        fig.tight_layout()
        plt.show()
        plt.close(fig)

# Plot one sample pair as a sanity check of the preprocessing (comment this out to
# skip it; building the dataset walks the whole DATASET_PATH tree).
visualize_samples(DATASET_PATH, num_samples=1, segment_length=8192)

## 5. Model Architecture

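We'll implement a 1-D U-Net for audio super-resolution, an architecture commonly used for this task. A strided-convolution encoder compresses the waveform, and a transposed-convolution decoder expands it back to the original length. Skip connections concatenate encoder features into the decoder to help preserve fine temporal detail.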

In [5]:
class AudioUNet(nn.Module):
    """1-D U-Net that maps a band-limited waveform to a full-band waveform of the same length.

    Encoder: four Conv1d layers (kernel 15) with 64/128/256/512 channels, BatchNorm
    and LeakyReLU(0.2). The last three layers use stride 2, so the bottleneck is
    1/8 of the input length.

    Decoder: three ConvTranspose1d layers (kernel 16, stride 2), each doubling the
    length, with BatchNorm and ReLU. The matching encoder activation is
    concatenated on the channel axis after each (skip connections). A final Conv1d
    maps to one channel, followed by ``tanh``, so outputs lie in [-1, 1].

    Input: ``(batch, time)`` or ``(batch, 1, time)``. Output: ``(batch, 1, time)``.
    Lengths that are not a multiple of 8 are zero-padded internally and the output
    is cropped back to the input length.
    """

    # Three stride-2 encoder stages, so skip connections only line up for lengths divisible by 2**3.
    _LENGTH_MULTIPLE = 8

    def __init__(self):
        super(AudioUNet, self).__init__()

        # Encoder
        self.enc1 = nn.Conv1d(1, 64, kernel_size=15, stride=1, padding=7)
        self.enc2 = nn.Conv1d(64, 128, kernel_size=15, stride=2, padding=7)
        self.enc3 = nn.Conv1d(128, 256, kernel_size=15, stride=2, padding=7)
        self.enc4 = nn.Conv1d(256, 512, kernel_size=15, stride=2, padding=7)

        # Decoder. Input channels of dec3/dec2/dec1 include the concatenated skip features.
        self.dec4 = nn.ConvTranspose1d(512, 256, kernel_size=16, stride=2, padding=7)
        self.dec3 = nn.ConvTranspose1d(512, 128, kernel_size=16, stride=2, padding=7)
        self.dec2 = nn.ConvTranspose1d(256, 64, kernel_size=16, stride=2, padding=7)
        self.dec1 = nn.Conv1d(128, 1, kernel_size=15, stride=1, padding=7)

        # Batch normalization
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)

        self.bn_dec4 = nn.BatchNorm1d(256)
        self.bn_dec3 = nn.BatchNorm1d(128)
        self.bn_dec2 = nn.BatchNorm1d(64)

    def forward(self, x):
        # Add channel dimension if needed: (batch, time) -> (batch, 1, time).
        if len(x.shape) == 2:
            x = x.unsqueeze(1)

        # Pad to a multiple of 8 so every skip connection has matching length.
        input_length = x.shape[-1]
        pad = (-input_length) % self._LENGTH_MULTIPLE
        if pad:
            x = F.pad(x, (0, pad))

        # Encoder
        e1 = F.leaky_relu(self.bn1(self.enc1(x)), 0.2)
        e2 = F.leaky_relu(self.bn2(self.enc2(e1)), 0.2)
        e3 = F.leaky_relu(self.bn3(self.enc3(e2)), 0.2)
        e4 = F.leaky_relu(self.bn4(self.enc4(e3)), 0.2)

        # Decoder with skip connections
        d4 = F.relu(self.bn_dec4(self.dec4(e4)))
        d4 = torch.cat([d4, e3], dim=1)

        d3 = F.relu(self.bn_dec3(self.dec3(d4)))
        d3 = torch.cat([d3, e2], dim=1)

        d2 = F.relu(self.bn_dec2(self.dec2(d3)))
        d2 = torch.cat([d2, e1], dim=1)

        d1 = torch.tanh(self.dec1(d2))

        # Drop the internal padding (a no-op when none was added).
        return d1[..., :input_length]

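## 6. Define the Multi-Resolution STFT Loss

The loss compares STFT magnitudes of the prediction and the target at several FFT sizes. It combines a spectral-convergence term with an L1 magnitude term.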

In [6]:
class STFTLoss(nn.Module):
    """Multi-resolution STFT loss: spectral convergence plus L1 on linear magnitudes.

    For each resolution ``i``, prediction ``x`` and target ``y`` are transformed with
    ``torch.stft`` (n_fft=fft_sizes[i], hop=hop_sizes[i], win_length=win_lengths[i]).
    Their magnitudes ``X`` and ``Y`` are then compared via

        SC  = ||X - Y||_F / (||Y||_F + eps)
        MAG = mean(|X - Y|)

    The returned loss is mean(SC) + mean(MAG), averaged over resolutions.

    Args:
        fft_sizes, hop_sizes, win_lengths: Per-resolution STFT parameters (same
            length). ``win_lengths`` defaults to ``fft_sizes`` when None/empty.
        window: Name of a ``torch.<name>_window`` function, e.g. ``'hann'``.
        eps: Added to the spectral-convergence denominator so an all-silent target
            gives a finite loss instead of NaN/inf.

    Raises:
        ValueError: If the resolution lists are empty or differ in length, or the
            window name is unknown.
    """

    def __init__(self, fft_sizes=(512, 1024, 2048, 4096), hop_sizes=(128, 256, 512, 1024),
                 win_lengths=(512, 1024, 2048, 4096), window='hann', eps=1e-7):
        super(STFTLoss, self).__init__()

        self.fft_sizes = list(fft_sizes)
        self.hop_sizes = list(hop_sizes)
        self.win_lengths = list(win_lengths) if win_lengths else list(self.fft_sizes)
        self.eps = eps

        if not self.fft_sizes:
            raise ValueError("fft_sizes must contain at least one FFT size")
        if not (len(self.fft_sizes) == len(self.hop_sizes) == len(self.win_lengths)):
            raise ValueError(
                f"fft_sizes, hop_sizes and win_lengths must have equal length; got "
                f"{len(self.fft_sizes)}, {len(self.hop_sizes)}, {len(self.win_lengths)}"
            )

        window_fn = getattr(torch, f'{window}_window', None)
        if window_fn is None:
            raise ValueError(f"Unsupported window '{window}'; expected e.g. 'hann' (torch.hann_window)")

        # Buffers follow the module across .to(device) calls.
        for i, win_length in enumerate(self.win_lengths):
            self.register_buffer(f'window_{i}', window_fn(win_length))

    def stft(self, x, fft_size, hop_size, win_length, window):
        """Return the complex STFT of ``x``.

        Falls back to the real/imag-pair API on PyTorch versions that do not
        accept ``return_complex``.
        """
        try:
            return torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
        except TypeError:
            real, imag = torch.stft(x, fft_size, hop_size, win_length, window).unbind(-1)
            return torch.complex(real, imag)

    def compute_spectral_convergence(self, x_mag, y_mag):
        return torch.norm(x_mag - y_mag, p='fro') / (torch.norm(y_mag, p='fro') + self.eps)

    def compute_magnitude_loss(self, x_mag, y_mag):
        return torch.mean(torch.abs(x_mag - y_mag))

    def forward(self, x, y):
        """Loss between prediction ``x`` and target ``y``: 1-D ``(samples,)`` or 2-D ``(batch, samples)``."""
        if x.size() != y.size():
            raise ValueError(f"Inputs must have same size, got {x.size()} and {y.size()}")

        if len(x.shape) == 1:
            x = x.unsqueeze(0)
            y = y.unsqueeze(0)

        sc_loss = 0.0
        mag_loss = 0.0
        for i, (fft_size, hop_size, win_length) in enumerate(
                zip(self.fft_sizes, self.hop_sizes, self.win_lengths)):
            window = getattr(self, f'window_{i}')
            x_mag = torch.abs(self.stft(x, fft_size, hop_size, win_length, window))
            y_mag = torch.abs(self.stft(y, fft_size, hop_size, win_length, window))

            sc_loss += self.compute_spectral_convergence(x_mag, y_mag)
            mag_loss += self.compute_magnitude_loss(x_mag, y_mag)

        num_resolutions = len(self.fft_sizes)
        return sc_loss / num_resolutions + mag_loss / num_resolutions

## 7. Training Functions

Now let's define functions for training and monitoring our model's progress.

In [None]:
def save_to_file(file_path, train_losses, val_losses, learning_rates):
    """Write the training history plus a short summary to ``file_path`` as JSON.

    The JSON holds a timestamp, the number of completed epochs, the best validation
    loss and its 1-based epoch (None if there are none), the final learning rate,
    and the three full lists. The data is written to ``<file_path>.tmp`` and then
    renamed over ``file_path`` with ``os.replace``. On filesystems where rename is
    atomic, an interrupted save therefore cannot leave a truncated file. Failures
    are printed rather than raised, so a metrics hiccup never aborts training.

    Args:
        file_path: Destination path; missing parent directories are created.
        train_losses, val_losses, learning_rates: Per-epoch lists of floats.
    """
    if val_losses:
        best_val_loss = min(val_losses)
        best_epoch = val_losses.index(best_val_loss) + 1
    else:
        best_val_loss = best_epoch = None

    metrics_data = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "epochs_completed": len(train_losses),
        "best_val_loss": best_val_loss,
        "best_epoch": best_epoch,
        "final_learning_rate": learning_rates[-1] if learning_rates else None,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "learning_rates": learning_rates
    }

    tmp_path = os.fspath(file_path) + ".tmp"
    try:
        parent_dir = os.path.dirname(file_path)
        if parent_dir:  # '' for a bare filename, and os.makedirs('') would raise
            os.makedirs(parent_dir, exist_ok=True)
        with open(tmp_path, 'w') as f:
            json.dump(metrics_data, f, indent=4)
        os.replace(tmp_path, file_path)
        print(f"Training metrics saved to {file_path}")
    except Exception as e:
        print(f"Error saving metrics to {file_path}: {e}")


def plot_training_metrics(metrics_file=None, train_losses=None, val_losses=None, learning_rates=None, output_path=None):
    """Plot train/val loss and the learning-rate schedule, from lists or a metrics JSON file.

    If ``metrics_file`` is given and any of the three lists is None, all three are
    read from the file (the format written by ``save_to_file``). Any list still
    missing afterwards is treated as empty. The figure is optionally saved to
    ``output_path``.

    Returns:
        None. If the file cannot be read or there is nothing to plot, a message
        is printed and no figure is created.
    """
    if metrics_file and (train_losses is None or val_losses is None or learning_rates is None):
        try:
            with open(metrics_file, 'r') as f:
                metrics_data = json.load(f)
            train_losses = metrics_data.get('train_losses', [])
            val_losses = metrics_data.get('val_losses', [])
            learning_rates = metrics_data.get('learning_rates', [])
        except Exception as e:
            print(f"Error loading metrics from {metrics_file}: {e}")
            return None

        print(f"Loaded metrics from {metrics_file}")
        print(f"Epochs: {len(train_losses)}")
        if val_losses:
            best = min(val_losses)
            print(f"Best validation loss: {best:.4f} at epoch {val_losses.index(best) + 1}")

    train_losses = train_losses or []
    val_losses = val_losses or []
    learning_rates = learning_rates or []
    if not (train_losses or val_losses or learning_rates):
        print("No training metrics to plot.")
        return None

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Each series gets its own x-range, so lists of different lengths still plot.
    ax1.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss')
    ax1.plot(range(1, len(val_losses) + 1), val_losses, 'r-', label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('STFT Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True, linestyle='--', alpha=0.6)

    if val_losses:
        min_val_loss = min(val_losses)
        min_val_epoch = val_losses.index(min_val_loss) + 1
        ax1.annotate(f'Best: {min_val_loss:.4f}',
                     xy=(min_val_epoch, min_val_loss),
                     xytext=(min_val_epoch, min_val_loss * 1.1),
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
                     fontsize=10)

    ax2.plot(range(1, len(learning_rates) + 1), learning_rates, 'g-')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Learning Rate')
    ax2.set_title('Learning Rate Schedule')
    ax2.grid(True, linestyle='--', alpha=0.6)
    if learning_rates:  # log scale needs at least one positive value
        ax2.set_yscale('log')

    fig.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir:  # '' for a bare filename, and os.makedirs('') would raise
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(output_path)
        print(f"Plot saved to {output_path}")

    return None

# Example usage:
# 1. Plot from the saved metrics file. Before training has written metrics.json,
#    plot_training_metrics prints a loading error and returns None instead of
#    plotting, so this cell can run on a fresh kernel.
metrics_file = os.path.join(RESULTS_DIR, "metrics.json")
output_path  = os.path.join(RESULTS_DIR, "reconstructed_plot.png")
plot_training_metrics(metrics_file=metrics_file, output_path=output_path)

# 2. Plot from lists already in memory (e.g. after training in this session):
# plot_training_metrics(train_losses=train_losses, val_losses=val_losses, learning_rates=learning_rates)

In [8]:
def _training_checkpoint(epoch, model, optimizer, scheduler, train_losses, val_losses,
                         learning_rates, best_val_loss, train_loss, val_loss):
    """Bundle model/optimizer/scheduler state and loss history into one dict for ``torch.save``."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'learning_rates': learning_rates,
        'best_val_loss': best_val_loss,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }


def _plot_training_history(train_losses, val_losses, learning_rates, loss_title, annotate_best=False):
    """Return a figure with train/val loss (left) and the learning rate on a log scale (right)."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Separate x-ranges per series: a history restored from an older checkpoint
    # (e.g. one without learning rates) still plots instead of raising.
    ax1.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss')
    ax1.plot(range(1, len(val_losses) + 1), val_losses, 'r-', label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('STFT Loss')
    ax1.set_title(loss_title)
    ax1.legend()
    ax1.grid(True, linestyle='--', alpha=0.6)

    if annotate_best and val_losses:
        best = min(val_losses)
        best_epoch = val_losses.index(best) + 1
        ax1.annotate(f'Best: {best:.4f}',
                     xy=(best_epoch, best),
                     xytext=(best_epoch, best * 1.1),
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
                     fontsize=10)

    ax2.plot(range(1, len(learning_rates) + 1), learning_rates, 'g-')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Learning Rate')
    ax2.set_title('Learning Rate Schedule')
    ax2.grid(True, linestyle='--', alpha=0.6)
    if learning_rates:
        ax2.set_yscale('log')

    fig.tight_layout()
    return fig


def train_model(model, train_loader, val_loader, num_epochs=100, learning_rate=0.001,
                gradient_accumulation_steps=4, checkpoint=None, checkpoint_every=3):
    """Train ``model`` with the multi-resolution STFT loss, checkpointing into ``RESULTS_DIR``.

    Uses Adam with ``ReduceLROnPlateau`` (factor 0.5, patience 5) driven by the
    validation loss. Gradients are accumulated over ``gradient_accumulation_steps``
    batches per optimizer step.

    Files written to ``RESULTS_DIR``:
      * ``best_model.pt``: full checkpoint whenever the validation loss improves;
      * ``model_checkpoint_epoch_{n}.pt`` and ``loss_curves.png``: every
        ``checkpoint_every`` epochs;
      * ``metrics.json``: loss/LR history, rewritten after every epoch;
      * ``audio_sr_final_model.pt`` (bare ``state_dict``) and
        ``final_training_plots.png``: after the last epoch.

    Args:
        model: Network to train; moved to CUDA when available.
        train_loader, val_loader: Non-empty loaders yielding ``(low_res, high_res)`` batches.
        num_epochs: Total number of epochs, counting epochs already done in ``checkpoint``.
        learning_rate: Initial Adam learning rate. When resuming, the restored
            optimizer state supplies the learning rate instead.
        gradient_accumulation_steps: Batches per optimizer step.
        checkpoint: Optional dict saved by this function. Model, optimizer and
            scheduler state and the loss history are restored, and training
            resumes at ``checkpoint['epoch'] + 1``.
        checkpoint_every: Epoch interval for periodic checkpoints and plots
            (0/None disables them).

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean losses, including resumed history.

    Raises:
        ValueError: If either loader has no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = model.to(device)

    # Multi-resolution STFT loss instead of a plain L1 waveform loss.
    criterion = STFTLoss(
        fft_sizes   = [512, 1024, 2048, 4096],
        hop_sizes   = [128, 256, 512, 1024],
        win_lengths = [512, 1024, 2048, 4096],
        window='hann'
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    train_losses, val_losses, learning_rates = [], [], []
    start_epoch = 0
    best_val_loss = float('inf')

    if checkpoint is not None:
        # Restore the weights as well, so resuming works even when the caller
        # passes a freshly constructed model.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        start_epoch = checkpoint['epoch'] + 1

        # Copy the histories so appending below doesn't mutate the caller's dict.
        train_losses = list(checkpoint.get('train_losses', []))
        val_losses = list(checkpoint.get('val_losses', []))
        learning_rates = list(checkpoint.get('learning_rates', []))

        if 'best_val_loss' in checkpoint:
            best_val_loss = checkpoint['best_val_loss']
        elif 'val_loss' in checkpoint:
            best_val_loss = checkpoint['val_loss']

        print(f"Resumed training from epoch {start_epoch}")
        print(f"Previous best validation loss: {best_val_loss:.4f}")

    # Created after the optimizer state is restored, so it sees the resumed learning rate.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    if checkpoint is not None:
        if 'scheduler_state_dict' in checkpoint:
            # Restores the plateau counter and best value exactly.
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        elif val_losses:
            # Older checkpoints carry no scheduler state; prime it with the last val loss.
            scheduler.step(val_losses[-1])

    for epoch in range(start_epoch, num_epochs):
        # ---- Training ----
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        for i, (low_res, high_res) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")):
            low_res, high_res = low_res.to(device), high_res.to(device)

            outputs = model(low_res)

            # STFTLoss expects (batch, samples).
            outputs = outputs.squeeze(1) if outputs.dim() > 2 else outputs
            high_res = high_res.squeeze(1) if high_res.dim() > 2 else high_res

            loss = criterion(outputs, high_res) / gradient_accumulation_steps
            loss.backward()

            # Step once per accumulation group, and on the final (possibly partial) group.
            if (i + 1) % gradient_accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()
                optimizer.zero_grad()

            # Undo the accumulation scaling so the logged value is the true batch loss.
            train_loss += loss.item() * gradient_accumulation_steps

            if i % 50 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        train_loss = train_loss / len(train_loader)
        train_losses.append(train_loss)

        # ---- Validation ----
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for low_res, high_res in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                low_res, high_res = low_res.to(device), high_res.to(device)

                outputs = model(low_res)
                outputs = outputs.squeeze(1) if outputs.dim() > 2 else outputs
                high_res = high_res.squeeze(1) if high_res.dim() > 2 else high_res

                val_loss += criterion(outputs, high_res).item()

        val_loss = val_loss / len(val_loader)
        val_losses.append(val_loss)

        # Record the LR used for this epoch before the scheduler possibly lowers it.
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(
                _training_checkpoint(epoch, model, optimizer, scheduler, train_losses, val_losses,
                                     learning_rates, best_val_loss, train_loss, val_loss),
                os.path.join(RESULTS_DIR, 'best_model.pt'),
            )
            print(f"Saved new best model with validation loss: {val_loss:.4f}")

        # Periodic checkpoint + loss curves every `checkpoint_every` epochs.
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            torch.save(
                _training_checkpoint(epoch, model, optimizer, scheduler, train_losses, val_losses,
                                     learning_rates, best_val_loss, train_loss, val_loss),
                os.path.join(RESULTS_DIR, f'model_checkpoint_epoch_{epoch+1}.pt'),
            )

            fig = _plot_training_history(train_losses, val_losses, learning_rates,
                                         f'Training and Validation Loss (Epoch {epoch+1})')
            fig.savefig(os.path.join(RESULTS_DIR, 'loss_curves.png'))
            plt.show()
            plt.close(fig)

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        save_to_file(os.path.join(RESULTS_DIR, "metrics.json"), train_losses, val_losses, learning_rates)

    # Final weights only (bare state_dict, no optimizer/scheduler state).
    torch.save(model.state_dict(), os.path.join(RESULTS_DIR, 'audio_sr_final_model.pt'))

    fig = _plot_training_history(train_losses, val_losses, learning_rates,
                                 'Training and Validation Loss (Final)', annotate_best=True)
    fig.savefig(os.path.join(RESULTS_DIR, 'final_training_plots.png'))
    plt.show()
    plt.close(fig)

    print(f"Training completed! Best validation loss: {best_val_loss:.4f}")

    return train_losses, val_losses

## 8. Evaluation Functions

In [9]:
def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the average L1, MSE and STFT losses.

    Each metric is a mean of per-batch means, so a smaller final batch weighs as
    much as a full one. The STFT loss uses FFT sizes 1024/2048/512 with hops
    120/240/50 and window length equal to the FFT size. This configuration differs
    from the criterion in ``train_model``, so its value is not directly comparable
    to the training/validation loss.

    Args:
        model: Network mapping low-res batches to predictions; moved to CUDA when
            available and put in eval mode.
        test_loader: Yields ``(low_res, high_res)`` batches.

    Returns:
        ``(avg_l1_loss, avg_mse_loss, avg_stft_loss)`` as Python floats.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    stft_criterion = STFTLoss(
        fft_sizes=[1024, 2048, 512],
        hop_sizes=[120, 240, 50],
        win_lengths=None,
        window='hann',
    ).to(device)

    def _as_waveform(batch):
        # Drop a singleton channel axis so tensors are (batch, samples).
        return batch.squeeze(1) if batch.dim() > 2 else batch

    l1_sum = 0.0
    mse_sum = 0.0
    stft_sum = 0.0

    with torch.no_grad():
        for low_res, high_res in tqdm(test_loader, desc="Evaluating"):
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            prediction = _as_waveform(model(low_res))
            target = _as_waveform(high_res)

            l1_sum += F.l1_loss(prediction, target).item()
            mse_sum += F.mse_loss(prediction, target).item()
            stft_sum += stft_criterion(prediction, target).item()

    num_batches = len(test_loader)
    avg_l1_loss = l1_sum / num_batches
    avg_mse_loss = mse_sum / num_batches
    avg_stft_loss = stft_sum / num_batches

    print("Test Results:")
    print(f"- Average L1 Loss: {avg_l1_loss:.4f}")
    print(f"- Average MSE Loss: {avg_mse_loss:.4f}")
    print(f"- Average STFT Loss: {avg_stft_loss:.4f}")

    return avg_l1_loss, avg_mse_loss, avg_stft_loss

## 9. Inference Function

Let's create a function to enhance new audio files using our trained model.

In [10]:
def enhance_audio(model_path, input_audio_path, output_audio_path, chunk_size=65536):
    """Run a trained AudioUNet over an audio file and save the result at 44.1 kHz.

    The input is mixed to mono, resampled to 44.1 kHz and peak-normalized. It is
    then processed in non-overlapping chunks of ``chunk_size`` samples. The last
    chunk is zero-padded for the model and the padding is trimmed from the output,
    so the result has exactly as many samples as the (resampled) input.

    Args:
        model_path: Either a training checkpoint dict with a ``'model_state_dict'``
            key (e.g. ``best_model.pt``) or a bare state dict (``audio_sr_final_model.pt``).
        input_audio_path: Audio file readable by ``torchaudio.load``.
        output_audio_path: Destination passed to ``torchaudio.save``.
        chunk_size: Samples per model call. Keep it a multiple of 8 so the U-Net's
            three stride-2 stages line up.

    Returns:
        ``(input_waveform, enhanced_waveform)``: CPU tensors of shape ``(1, T)``
        with equal ``T``; the input is the normalized 44.1 kHz mono signal.

    Raises:
        ValueError: If the input file contains no samples.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load onto the CPU first so GPU-saved checkpoints also work on CPU-only machines.
    checkpoint = torch.load(model_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    model = AudioUNet()
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    waveform, sample_rate = torchaudio.load(input_audio_path)

    # Mix down to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != 44100:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 44100)

    total_samples = waveform.shape[1]
    if total_samples == 0:
        raise ValueError(f"No audio samples in {input_audio_path}")

    # Peak-normalize like the training data; the epsilon avoids NaN on silent input.
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)

    enhanced_chunks = []
    with torch.no_grad():
        for start in range(0, total_samples, chunk_size):
            chunk = waveform[:, start:start + chunk_size]
            valid_samples = chunk.shape[1]
            if valid_samples < chunk_size:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_size - valid_samples))

            # The (1, chunk_size) chunk is treated by AudioUNet as a batch of one and
            # comes back as (1, 1, chunk_size). Flatten to (1, chunk_size) so the
            # padding is trimmed along the time axis, not the channel axis.
            enhanced = model(chunk.to(device)).reshape(1, -1)
            enhanced_chunks.append(enhanced[:, :valid_samples].cpu())

    enhanced_waveform = torch.cat(enhanced_chunks, dim=1)

    torchaudio.save(output_audio_path, enhanced_waveform, 44100)

    return waveform.cpu(), enhanced_waveform

## 10. Training Pipeline

### 10.1. Split the dataset into train, val and test sets.

In [None]:
# --- Hardware check ---------------------------------------------------------
if not torch.cuda.is_available():
    print("No GPU available, using CPU")
else:
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"Memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")


# --- Data configuration -----------------------------------------------------
MAX_FILES      = 2000   # total files across train + val + test
BATCH_SIZE     = 32     # lower this on out-of-memory errors
SEGMENT_LENGTH = 65536  # samples per training segment (~1.5 s at 44.1 kHz); lower on OOM

train_loader, val_loader, test_loader = create_dataset_splits(
    DATASET_PATH,
    max_files=MAX_FILES,
    batch_size=BATCH_SIZE,
    segment_length=SEGMENT_LENGTH,
)

print(f"Dataset ready: {len(train_loader)} training batches, {len(val_loader)} validation batches")

### 10.2. Save and load the dataset splits, so that training resumed from a checkpoint uses exactly the same train/validation/test split.

In [12]:
DATASET_SAVE_PATH = os.path.join(RESULTS_DIR, "dataset_splits.pkl") # reused so training resumed from a checkpoint keeps the same splits

def load_dataset_splits(save_path=DATASET_SAVE_PATH, num_workers=0):
    """Rebuild train/val/test DataLoaders from splits pickled by ``save_dataset_splits``.

    SECURITY: this uses ``pickle.load``, which can execute arbitrary code. Only load
    split files this notebook wrote itself, never files from untrusted sources.

    Args:
        save_path: Pickle file written by ``save_dataset_splits``.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Returns:
        ``(train_loader, val_loader, test_loader, True)`` if ``save_path`` exists,
        otherwise ``(None, None, None, False)``. Only the training loader shuffles.
    """
    if not os.path.exists(save_path):
        return None, None, None, False

    print(f"Loading dataset splits from {save_path}")
    with open(save_path, 'rb') as f:
        dataset_info = pickle.load(f)

    batch_size = dataset_info['batch_size']
    # Pinned host memory only helps (and only avoids a warning) when a GPU is present.
    pin_memory = torch.cuda.is_available()

    def _make_loader(split_key, shuffle):
        return torch.utils.data.DataLoader(
            dataset_info[split_key],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    train_loader = _make_loader('train_dataset', shuffle=True)
    val_loader = _make_loader('val_dataset', shuffle=False)
    test_loader = _make_loader('test_dataset', shuffle=False)

    return train_loader, val_loader, test_loader, True

def save_dataset_splits(train_loader, val_loader, test_loader, save_path=DATASET_SAVE_PATH):
    """Pickle the datasets behind the three loaders so training can resume on the same splits.

    The pickle stores the three ``Dataset`` objects, the training batch size and,
    when the training dataset exposes it, its ``segment_length`` (``None`` otherwise).

    Args:
        train_loader, val_loader, test_loader: Loaders whose ``.dataset`` is saved.
        save_path: Destination file. Its parent directory is created if needed.

    Returns:
        ``save_path``, unchanged.
    """
    # A bare filename has an empty dirname, so fall back to the current directory.
    target_dir = os.path.dirname(save_path) or '.'
    os.makedirs(target_dir, exist_ok=True)

    train_dataset = train_loader.dataset
    dataset_info = dict(
        train_dataset=train_dataset,
        val_dataset=val_loader.dataset,
        test_dataset=test_loader.dataset,
        batch_size=train_loader.batch_size,
        segment_length=getattr(train_dataset, 'segment_length', None),
    )

    with open(save_path, 'wb') as handle:
        pickle.dump(dataset_info, handle)

    print(f"Dataset splits saved to {save_path}")
    return save_path


# Save the dataset splits for future use
# save_dataset_splits(train_loader, val_loader, test_loader)

### 10.3. Start training

In [None]:
# Build a fresh U-Net and report its size before starting a full training run.
model = AudioUNet()
print(f"Model parameters: {sum(param.numel() for param in model.parameters()):,}")

# Train from scratch. Adjust num_epochs as needed; gradient_accumulation_steps=4
# is meant to ease GPU memory pressure.
train_losses, val_losses = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=100,
    learning_rate=1e-3,
    gradient_accumulation_steps=4,
)

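### 10.4. (Optional) Resume training from a specific checkpoint
Files required in `RESULTS_DIR`:
 - `model_checkpoint_epoch_<N>.pt` — set `checkpoint_path` below to the epoch you want to resume from
 - `dataset_splits.pkl` — written by `save_dataset_splits` (section 10.2)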

In [None]:
# Resume training from a saved checkpoint, using the splits pickled by
# `save_dataset_splits` so training continues on the same train/val/test split.

# Load the saved splits first so a missing file fails fast, before the model and
# checkpoint are loaded. This assignment replaces any loaders built in 10.1. On
# failure they are all None, and train_model would otherwise crash with an
# unrelated error.
train_loader, val_loader, test_loader, loaded = load_dataset_splits(DATASET_SAVE_PATH)
if not loaded:
    raise FileNotFoundError(
        f"Failed to load dataset, please check the file path: {DATASET_SAVE_PATH} "
        "(create it once with save_dataset_splits(...) from section 10.2)"
    )

# Initialize model
model = AudioUNet()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Checkpoint path - specify the checkpoint to resume from
checkpoint_path = os.path.join(RESULTS_DIR, "model_checkpoint_epoch_39.pt")

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Resuming from epoch {checkpoint['epoch'] + 1}")
else:
    print("No checkpoint found. Starting training from scratch.")
    checkpoint = None

# NOTE(review): learning_rate is not passed here, so train_model's default applies
# unless it restores the optimizer state from `checkpoint`. Confirm this matches
# the learning_rate=0.0010 used for the initial run in 10.3.
train_losses, val_losses = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=110,  # Remember to pass a higher epoch number
    gradient_accumulation_steps=4,
    checkpoint=checkpoint  # Pass the loaded checkpoint
)

## 11. Evaluate the model on the test set

In [None]:
# Load the best checkpoint written during training and evaluate it on the test split.
best_model = AudioUNet()
# Without map_location, a checkpoint saved on a GPU runtime cannot be loaded on a
# CPU-only runtime. Pick the device the same way the resume cell does.
checkpoint = torch.load(
    os.path.join(RESULTS_DIR, 'best_model.pt'),
    map_location='cuda' if torch.cuda.is_available() else 'cpu',
)
best_model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate model on the held-out test loader (from 10.1 or reloaded in 10.4)
l1_loss, mse_loss, stft_loss = evaluate_model(best_model, test_loader)

## 12. Inference on new audio files

In [None]:
# Degrade a high-quality track (downsample -> upsample -> low-pass), enhance it
# with the best checkpoint, and compare the waveforms.
input_path_hq = '/content/hq-track.mp3'  # Change this to your input file
# Keep the degraded copy lossless. Re-encoding it as MP3 would add compression
# artifacts on top of the simulated bandwidth limit, and MP3 writing depends on
# an optional encoder backend.
input_path_lq = '/content/lq-track.wav'
output_path = '/content/enhanced.wav'


def simulate_low_quality(waveform, sample_rate, sr_low=16000, sr_target=44100, cutoff_ratio=0.6):
    """Return a band-limited copy of `waveform`, resampled to `sr_target` Hz.

    Steps:
        1. Resample from `sample_rate` down to `sr_low`.
        2. Resample back up to `sr_target`.
        3. Apply a biquad low-pass at `cutoff_ratio` of the low-rate Nyquist
           frequency (0.6 * 8000 Hz = 4800 Hz with the defaults).
    """
    degraded = torchaudio.functional.resample(waveform, sample_rate, sr_low)
    degraded = torchaudio.functional.resample(degraded, sr_low, sr_target)
    cutoff_freq = sr_low / 2 * cutoff_ratio
    return torchaudio.functional.lowpass_biquad(degraded, sr_target, cutoff_freq)


# Load the high-quality track and downmix stereo to mono.
waveform, sample_rate = torchaudio.load(input_path_hq)
if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

waveform_low = simulate_low_quality(waveform, sample_rate)
torchaudio.save(input_path_lq, waveform_low, 44100)

# Enhance the degraded file with the best checkpoint.
original, enhanced = enhance_audio(
    model_path=os.path.join(RESULTS_DIR, 'best_model.pt'),
    input_audio_path=input_path_lq,
    output_audio_path=output_path
)

# Plot the HQ reference, the model input and the model output on separate axes.
# The x-axis is the sample index; the reference may use a different sample rate.
fig, axes = plt.subplots(3, 1, figsize=(15, 9))
panels = [
    (waveform, f'High-Quality Reference ({sample_rate} Hz)'),
    (original, 'Model Input (low-quality track)'),
    (enhanced, 'Enhanced Audio (saved at 44100 Hz)'),
]
for ax, (signal, title) in zip(axes, panels):
    ax.plot(signal.numpy()[0])
    ax.set(title=title, xlabel='Sample index', ylabel='Amplitude')

fig.tight_layout()
plt.show()