# In [None]:
#!/usr/bin/env python
# coding: utf8
"""
PyTorch/Torchaudio Audio Processing Library and Training Script

This file provides:
  - An abstract AudioAdapter and a concrete TorchaudioAdapter implementation.
  - Converter functions for channel manipulation and gain/dB conversions.
  - Spectrogram computation and simple augmentation functions (time stretch, pitch shift).
  - An AudioDataset class that reads audio file paths from a CSV file, loads each
    waveform, and computes its spectrogram; batching is done by a DataLoader.
  - A small convolutional encoder-decoder (U-Net-like, but without skip
    connections) for source separation.
  - A training loop demonstrating training on the MUSDB18 dataset. As written it
    uses the input mixture as its own target (a placeholder); supply isolated
    source stems as targets for real separation training.

Author: Your Name
License: MIT License
"""

import os
import csv
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchaudio
import torchaudio.transforms as T

# =============================================================================
# AudioAdapter and Concrete TorchaudioAdapter
# =============================================================================

class AudioAdapter(ABC):
    """
    Abstract interface for reading and writing audio as torch tensors.

    Subclasses must implement ``load`` and ``save``. A single shared default
    instance is created lazily by ``default()``.
    """

    # Cached default adapter; populated on the first call to ``default()``.
    _DEFAULT: Optional["AudioAdapter"] = None

    @abstractmethod
    def load(
        self,
        audio_path: Union[str, Path],
        offset: Optional[float] = None,
        duration: Optional[float] = None,
        sample_rate: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int]:
        """
        Read an audio file.

        Returns a ``(waveform, sample_rate)`` pair.
        """
        ...

    @abstractmethod
    def save(
        self,
        path: Union[str, Path],
        data: torch.Tensor,
        sample_rate: int,
    ) -> None:
        """
        Write ``data`` to ``path`` as an audio file at ``sample_rate``.
        """
        ...

    @classmethod
    def default(cls) -> "AudioAdapter":
        """Return the shared default adapter, creating a TorchaudioAdapter on first use."""
        instance = cls._DEFAULT
        if instance is None:
            instance = TorchaudioAdapter()
            cls._DEFAULT = instance
        return instance

    @classmethod
    def get(cls, descriptor: str) -> "AudioAdapter":
        """
        Return an adapter for ``descriptor``.

        The descriptor is currently ignored and the default adapter is returned.
        """
        return cls.default()


class TorchaudioAdapter(AudioAdapter):
    """
    A concrete AudioAdapter implementation using torchaudio.
    """

    def load(
        self,
        audio_path: Union[str, Path],
        offset: Optional[float] = None,
        duration: Optional[float] = None,
        sample_rate: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int]:
        """
        Load an audio file, optionally trimming and resampling it.

        Args:
            audio_path: path of the file to read.
            offset: start position in seconds (defaults to the beginning).
            duration: number of seconds to keep (defaults to the rest of the file).
            sample_rate: desired sample rate; the waveform is resampled when it
                differs from the file's native rate.

        Returns:
            ``(waveform, sample_rate)``. The waveform is laid out as
            (channels, samples), as returned by ``torchaudio.load``. The
            sample rate is the rate of the returned waveform.
        """
        # Convert Path to string if necessary.
        if isinstance(audio_path, Path):
            audio_path = str(audio_path)
        # torchaudio.load returns a waveform of shape (channels, samples)
        waveform, orig_sr = torchaudio.load(audio_path, normalize=True)
        # Apply offset and duration if specified (the whole file is decoded first).
        if offset is not None or duration is not None:
            start = int(offset * orig_sr) if offset is not None else 0
            if duration is not None:
                end = start + int(duration * orig_sr)
            else:
                end = waveform.size(1)
            waveform = waveform[:, start:end]
        # Resample if a different sample rate is requested.
        if sample_rate is not None and sample_rate != orig_sr:
            resampler = T.Resample(orig_sr, sample_rate)
            waveform = resampler(waveform)
            orig_sr = sample_rate
        return waveform, orig_sr

    def save(
        self,
        path: Union[str, Path],
        data: torch.Tensor,
        sample_rate: int,
    ) -> None:
        """
        Write ``data`` to ``path`` at ``sample_rate``.

        Args:
            path: destination file.
            data: waveform of shape (channels, samples), or a 1-D mono waveform.
                It may live on any device and may carry autograd history,
                e.g. a model output.
            sample_rate: sample rate to record in the file.
        """
        if isinstance(path, Path):
            path = str(path)
        # torchaudio.save needs a 2-D CPU tensor it can convert without
        # autograd, so strip grad history and move to host memory first.
        data = data.detach().cpu()
        if data.dim() == 1:
            # Treat a 1-D waveform as a single channel.
            data = data.unsqueeze(0)
        torchaudio.save(path, data, sample_rate)

# =============================================================================
# Converter Functions
# =============================================================================

def to_n_channels(waveform: torch.Tensor, n_channels: int) -> torch.Tensor:
    """
    Ensure that the waveform has exactly ``n_channels`` channels (dimension 0).

    If there are more channels than requested, the first ``n_channels`` are
    kept. If there are fewer, the existing channels are cycled: for example,
    2 -> 5 channels yields channels 0, 1, 0, 1, 0.

    Args:
        waveform: tensor whose first dimension is the channel axis.
        n_channels: number of channels wanted.

    Returns:
        A tensor with ``n_channels`` entries along dimension 0.

    Raises:
        ValueError: if channels must be added but ``waveform`` has none.
    """
    current = waveform.size(0)
    if current >= n_channels:
        return waveform[:n_channels]
    if current == 0:
        raise ValueError("cannot create channels from a waveform with zero channels")
    # Tensor.expand only works for a singleton channel axis, so tile whole
    # copies of all channels until n_channels is covered, then trim the excess.
    repeats = -(-n_channels // current)  # ceiling division
    tiled = waveform.repeat(repeats, *([1] * (waveform.dim() - 1)))
    return tiled[:n_channels]

def to_stereo(waveform: torch.Tensor) -> torch.Tensor:
    """
    Return a two-channel version of ``waveform`` (channels on dimension 0).

    Mono input is duplicated into both channels; input with more than two
    channels keeps only the first two; anything else is returned unchanged.
    """
    num_channels = waveform.size(0)
    if num_channels > 2:
        return waveform[:2]
    if num_channels == 1:
        return waveform.repeat(2, 1)
    return waveform

def gain_to_db(tensor: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor:
    """
    Convert a linear gain tensor to decibels (20 * log10).

    Values below ``epsilon`` (including zero and negatives) are clamped to
    ``epsilon`` first, so the result is always finite.
    """
    floored = tensor.clamp(min=epsilon)
    return floored.log10().mul(20.0)

def db_to_gain(tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert decibel values to linear gain, computing 10 ** (dB / 20).
    """
    exponent = tensor / 20.0
    gain = torch.pow(10.0, exponent)
    return gain

# =============================================================================
# Spectrogram and Data Augmentation Functions
# =============================================================================

def compute_spectrogram(
    waveform: torch.Tensor,
    n_fft: int = 2048,
    hop_length: int = 512,
    power: float = 1.0,
    window: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute the (magnitude ** power) spectrogram of a waveform with ``torch.stft``.

    Args:
        waveform: real tensor of shape (samples,) or (channels, samples).
        n_fft: FFT size; the output has ``n_fft // 2 + 1`` frequency bins.
        hop_length: number of samples between successive frames.
        power: exponent applied to the magnitude (1.0 = magnitude, 2.0 = power).
        window: analysis window of length ``n_fft``. Defaults to a Hann window
            on the waveform's device and dtype. Without an explicit window,
            ``torch.stft`` uses a rectangular one, which causes spectral leakage.

    Returns:
        A tensor of shape (channels, n_fft // 2 + 1, frames), or
        (n_fft // 2 + 1, frames) for 1-D input. Because the signal is
        centered, ``frames = samples // hop_length + 1``.
    """
    if window is None:
        window = torch.hann_window(n_fft, device=waveform.device, dtype=waveform.dtype)
    spec_complex = torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
        center=True,
    )
    magnitude = spec_complex.abs() ** power
    return magnitude

def time_stretch(spectrogram: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
    """
    Time-stretch a spectrogram by ``factor`` using bilinear interpolation.

    The time axis (last dimension) is resampled to ``int(time * factor)``
    frames and then cropped or zero-padded back to the original length. As a
    result, ``factor > 1`` drops the trailing frames and ``factor < 1``
    zero-pads the end.

    Args:
        spectrogram: tensor of shape (..., freq, time), e.g. (channels, freq,
            time) or a batched (batch, channels, freq, time).
        factor: stretch factor; must be positive.

    Returns:
        A tensor with the same shape as ``spectrogram``.

    Raises:
        ValueError: if ``factor`` is not positive or the input has fewer than
            two dimensions.
    """
    if not factor > 0:
        raise ValueError(f"factor must be positive, got {factor}")
    if spectrogram.dim() < 2:
        raise ValueError("spectrogram must have at least (freq, time) dimensions")
    *leading, freq, time = spectrogram.shape
    # Keep at least one frame so interpolate never receives an empty size.
    new_time = max(1, int(time * factor))
    # Fold all leading dims into the channel axis: interpolate wants (N, C, H, W).
    flat = spectrogram.reshape(1, -1, freq, time)
    stretched = F.interpolate(flat, size=(freq, new_time),
                              mode='bilinear', align_corners=False)
    stretched = stretched.reshape(*leading, freq, new_time)
    # Crop or pad to original time length.
    if new_time > time:
        stretched = stretched[..., :time]
    elif new_time < time:
        stretched = F.pad(stretched, (0, time - new_time))
    return stretched

def pitch_shift(spectrogram: torch.Tensor, semitone_shift: float = 0.0) -> torch.Tensor:
    """
    Shift the pitch of a spectrogram by ``semitone_shift`` semitones.

    The frequency axis (second-to-last dimension) is rescaled with bilinear
    interpolation by ``2 ** (semitone_shift / 12)``, so content near bin ``k``
    moves to about bin ``k * factor``. The result is then cropped or
    zero-padded at the high-frequency end back to the original number of bins.

    Args:
        spectrogram: tensor of shape (..., freq, time), e.g. (channels, freq,
            time) or a batched (batch, channels, freq, time).
        semitone_shift: shift in semitones (positive = up, negative = down).

    Returns:
        A tensor with the same shape as ``spectrogram``.

    Raises:
        ValueError: if the input has fewer than two dimensions.
    """
    if spectrogram.dim() < 2:
        raise ValueError("spectrogram must have at least (freq, time) dimensions")
    factor = 2 ** (semitone_shift / 12.0)
    *leading, freq, time = spectrogram.shape
    # Keep at least one bin so extreme downward shifts do not crash interpolate.
    new_freq = max(1, int(freq * factor))
    # Fold all leading dims into the channel axis: interpolate wants (N, C, H, W).
    flat = spectrogram.reshape(1, -1, freq, time)
    shifted = F.interpolate(flat, size=(new_freq, time),
                            mode='bilinear', align_corners=False)
    shifted = shifted.reshape(*leading, new_freq, time)
    # Crop or pad frequency dimension back to original.
    if new_freq > freq:
        shifted = shifted[..., :freq, :]
    elif new_freq < freq:
        shifted = F.pad(shifted, (0, 0, 0, freq - new_freq))
    return shifted

# =============================================================================
# Dataset Module
# =============================================================================

class AudioDataset(Dataset):
    """
    A PyTorch dataset of magnitude spectrograms built from audio files listed in a CSV.

    Each item is built as follows:
      1. Load up to ``duration`` seconds from the start of the file with the
         given AudioAdapter, at ``sample_rate``.
      2. Force the waveform to stereo.
      3. Zero-pad or trim it to exactly ``int(duration * sr)`` samples.
      4. Compute the magnitude spectrogram.
      5. Apply the optional ``transform``.

    Fixing the length in step 3 means every item has the same shape, so the
    default DataLoader collate function can stack them into batches.

    CSV file format:
      - Must contain at least a column named 'file' with paths to audio files,
        relative to ``audio_path``.
    """

    def __init__(
        self,
        csv_file: str,
        audio_path: str,
        audio_adapter: AudioAdapter,
        sample_rate: int = 44100,
        duration: float = 20.0,
        transform: Optional[Any] = None,
    ):
        """
        Args:
            csv_file: path of the CSV index file.
            audio_path: directory that the CSV 'file' entries are relative to.
            audio_adapter: adapter used to load audio.
            sample_rate: sample rate to load (and resample) audio at.
            duration: seconds of audio per item; if None, the whole file is
                used and no length fixing is applied.
            transform: optional callable applied to each spectrogram.

        Raises:
            ValueError: if the CSV has a header without a 'file' column.
        """
        self.audio_path = audio_path
        self.audio_adapter = audio_adapter
        self.sample_rate = sample_rate
        self.duration = duration
        self.transform = transform

        # utf-8-sig also strips a BOM (e.g. from spreadsheet exports); with a BOM
        # left in place the first header would not be recognised as 'file'.
        with open(csv_file, 'r', newline='', encoding='utf-8-sig') as f:
            reader = csv.DictReader(f)
            # Fail early with a clear message instead of a KeyError in __getitem__.
            if reader.fieldnames is not None and 'file' not in reader.fieldnames:
                raise ValueError(
                    f"CSV file {csv_file!r} has no 'file' column "
                    f"(found columns: {reader.fieldnames})"
                )
            self.samples: List[Dict[str, str]] = list(reader)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load, fix the length of, and transform the spectrogram for item ``idx``."""
        sample = self.samples[idx]
        # Build the full path to the audio file.
        file_path = Path(self.audio_path) / sample['file']
        # Load the waveform using the provided audio adapter.
        waveform, sr = self.audio_adapter.load(file_path, duration=self.duration, sample_rate=self.sample_rate)
        waveform = to_stereo(waveform)
        # Files shorter than `duration`, or off-by-one lengths from resampling,
        # would otherwise give spectrograms of different widths that cannot be
        # stacked into a batch.
        if self.duration is not None:
            waveform = self._fix_length(waveform, int(self.duration * sr))
        # Compute the spectrogram.
        spec = compute_spectrogram(waveform, n_fft=2048, hop_length=512, power=1.0)
        # Apply any additional transforms (e.g., time stretch, pitch shift).
        if self.transform is not None:
            spec = self.transform(spec)
        return spec

    @staticmethod
    def _fix_length(waveform: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Zero-pad or trim the last dimension of ``waveform`` to ``num_samples``."""
        length = waveform.size(-1)
        if length > num_samples:
            return waveform[..., :num_samples]
        if length < num_samples:
            return F.pad(waveform, (0, num_samples - length))
        return waveform

# =============================================================================
# Define a Lightweight Model (Simple U-Net style architecture)
# =============================================================================

class SimpleSeparationModel(nn.Module):
    """
    Small convolutional encoder-decoder mapping a spectrogram to a spectrogram.

    Two stride-2 convolutions downsample (freq, time) by 4, and two stride-2
    transposed convolutions upsample back. There are no skip connections. The
    final ReLU keeps outputs non-negative.

    The output is cropped to the input's spatial size. This matters when freq or
    time is not a multiple of 4, e.g. the 1025 bins of a 2048-point STFT; the
    output then still matches the target's shape in the loss.

    Args:
        in_channels: number of input channels (e.g. 2 for stereo spectrograms).
        out_channels: number of output channels.
        features: width of the first conv layer; the second uses twice as many.
    """

    def __init__(self, in_channels=1, out_channels=1, features=16):
        super().__init__()
        # Encoder: reduce time-frequency resolution.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features * 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        # Decoder: reconstruct the spectrogram.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(features * 2, features, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(features, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map a (batch, in_channels, freq, time) tensor to (batch, out_channels, freq, time)."""
        height, width = x.shape[-2:]
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        # Each stride-2 conv gives ceil(n / 2) and each transposed conv doubles,
        # so the decoder yields 4 * ceil(ceil(n / 2) / 2) >= n; crop back to n.
        return decoded[..., :height, :width]

# =============================================================================
# Training Setup
# =============================================================================

def train(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """
    Train ``model`` for ``num_epochs`` epochs and print the mean loss of each epoch.

    NOTE: the target is the input batch itself (a reconstruction placeholder);
    replace it with the true isolated sources for real source separation.

    Args:
        model: module to train; it is switched to training mode.
        dataloader: iterable of input tensors. It does not need ``__len__``;
            batches are counted while iterating.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer over the model's parameters.
        device: device each batch is moved to.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        A list with one mean batch loss per epoch. The value is NaN for an
        epoch that produced no batches.
    """
    model.train()
    history = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            # Expecting batch shape: (batch_size, channels, freq_bins, time_frames)
            inputs = batch.to(device)
            targets = inputs  # In practice, replace with true isolated sources.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1
        # Guard against an empty loader instead of dividing by zero.
        mean_loss = epoch_loss / num_batches if num_batches else float("nan")
        history.append(mean_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.6f}")
    return history

# =============================================================================
# Main Function: Data Loading and Training Loop
# =============================================================================

def _augment(spec: torch.Tensor) -> torch.Tensor:
    """
    Demonstration augmentation: time stretch followed by pitch shift.

    This is defined at module level rather than inside ``main`` so that
    DataLoader worker processes can pickle it, which the 'spawn' start method
    requires. The fixed factors are identity settings; randomize them for real
    augmentation.
    """
    spec = time_stretch(spec, factor=1.0)
    spec = pitch_shift(spec, semitone_shift=0.0)
    return spec


def main():
    """Build the MUSDB18 dataset, train the model, and save its weights."""
    # Set device (CPU or CUDA)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up paths and parameters
    csv_file = 'musdb18_train.csv'      # CSV with a column 'file'
    audio_folder = 'musdb18_audio'        # Folder containing the MUSDB18 audio files
    sample_rate = 44100
    duration = 20.0                       # seconds per sample

    # Create the audio adapter.
    adapter = AudioAdapter.default()

    # Create the dataset.
    dataset = AudioDataset(
        csv_file=csv_file,
        audio_path=audio_folder,
        audio_adapter=adapter,
        sample_rate=sample_rate,
        duration=duration,
        transform=_augment,  # Set to None if no augmentation is desired.
    )

    # Create a DataLoader for batching; pinned host memory speeds up GPU copies.
    dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=4,
        pin_memory=device.type == "cuda",
    )

    # AudioDataset converts every waveform to stereo, so spectrogram batches
    # have 2 channels: (batch, 2, freq_bins, time_frames).
    model = SimpleSeparationModel(in_channels=2, out_channels=2, features=16).to(device)

    # Define loss function and optimizer.
    criterion = nn.L1Loss()  # L1 loss is common for source separation.
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Train the model.
    train(model, dataloader, criterion, optimizer, device, num_epochs=10)

    # Save the model after training.
    torch.save(model.state_dict(), "lightweight_separation_model.pth")
    print("Training complete and model saved.")

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
