# Imports

In [7]:
# Torch / TorchAudio
import torch
import torchaudio
from torchaudio.transforms import Fade

# MatPlotLib (Graphing)
import matplotlib.pyplot as plt

# IPython audio widget, mir_eval for SDR (Signal-to-Distortion Ratio) calculations, and the pre-trained model bundle
from IPython.display import Audio
from mir_eval import separation
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.utils import download_asset

# PyTorch + Model Configuration

In [9]:

# Pre-trained model bundle
bundle = HDEMUCS_HIGH_MUSDB_PLUS

# Inference device: pinned to the CPU
# (GPU alternative: "cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Build the model from the bundle and place it on the chosen device
model = bundle.get_model()
model.to(device)

# Sample rate reported by the bundle
sample_rate = bundle.sample_rate
print(f"Sample rate: {sample_rate}")

Sample rate: 44100


# Split-Sources Function
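Note: This model is very memory-intensive, so most of the function below follows the approach recommended by PyTorch: split the song into overlapping chunks, separate each chunk on its own, and cross-fade the results back together to reduce memory load.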

In [2]:
def split_sources(
        model,
        mix,
        segment=10.0,
        overlap=1,
        device=None
):
    """Separate a mixture into sources by running the model over overlapping chunks.

    The mixture is cut into chunks along its last dimension, each chunk is
    passed through ``model.forward`` without gradient tracking, and the chunk
    outputs are linearly faded and summed into one output tensor so that
    neighbouring chunks cross-fade over their shared region.

    Args:
        model: Object exposing ``sources`` (its length sets the number of
            output sources) and ``forward(chunk)``.
        mix: Tensor of shape ``(batch, channels, length)``.
        segment: Nominal chunk duration in seconds.
        overlap: Used in two places: each chunk spans
            ``segment * (1 + overlap)`` seconds, and the cross-fade between
            neighbouring chunks lasts ``overlap`` seconds.
        device: Device on which the output tensor is allocated. Defaults to
            ``mix.device``.

    Returns:
        Tensor of shape ``(batch, len(model.sources), channels, length)``.

    Note:
        Reads the module-level ``sample_rate`` to convert seconds to frames.
    """
    # Check that device is configured
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    # Get audio parameters
    batch, channels, length = mix.shape

    # Get chunk information
    chunk_len = int(sample_rate * segment * (1 + overlap))
    start = 0
    end = chunk_len
    overlap_frames = overlap * sample_rate

    # Configure fade (helps combine audio after splitting into chunks).
    # The very first chunk has nothing before it, so it only fades out.
    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")

    # Create a tensor filled with zeros which will be used for output
    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    while start < length - overlap_frames:
        # Get current chunk (slicing clips automatically at the end of the song)
        chunk = mix[:, :, start:end]

        # Apply model
        with torch.no_grad():
            out = model.forward(chunk)

        # Apply fade
        out = fade(out)

        # Accumulate current chunk into the final output; the faded tails of
        # neighbouring chunks add up across the overlap region
        final[:, :, :, start:end] += out

        # Configure fade in / out for next chunk
        if start == 0:
            # Every chunk after the first fades in over the shared region and
            # starts where the previous chunk began fading out
            fade.fade_in_len = int(overlap_frames)
            start += int(chunk_len - overlap_frames)
        else:
            start += chunk_len
        # The window end must advance on every iteration, including the first,
        # otherwise later chunks shrink to just the overlap region
        end += chunk_len
        if end >= length:
            # The next chunk is the last one: nothing follows it, so no fade-out
            fade.fade_out_len = 0

    # Return only once every chunk has been processed
    return final

# Plot spectrogram function
This will help us visualize the audio.

In [3]:
def plot_spectrogram(stft, title="Spectrogram"):
    """Draw the magnitude of ``stft`` on a dB-style scale in a new figure.

    Args:
        stft: Tensor whose element-wise magnitude (``stft.abs()``) is plotted.
            Assumed to be 2-D, since it is handed to ``imshow`` - TODO confirm
            against callers.
        title: Title placed above the plot.
    """
    # The small offset keeps log10 finite where the magnitude is exactly zero
    log_magnitude = torch.log10(stft.abs() + 1e-8)
    decibels = 20 * log_magnitude.numpy()

    # Single-axes figure; colour range is clamped to [-60, 0]
    figure_axes = plt.subplots(1, 1)[1]
    figure_axes.imshow(
        decibels,
        cmap="viridis",
        vmin=-60,
        vmax=0,
        origin="lower",
        aspect="auto",
    )
    figure_axes.set_title(title)
    plt.tight_layout()

# Run Model

In [15]:
# Load song
waveform, file_sample_rate = torchaudio.load("./song.wav")

# The bundle reports the sample rate the pre-trained model works at. Bring the
# song to that rate before separating, rather than feeding the model audio at
# whatever rate the file happens to use.
if file_sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(
        waveform,
        orig_freq=file_sample_rate,
        new_freq=bundle.sample_rate,
    )
sample_rate = bundle.sample_rate

waveform = waveform.to(device)
mixture = waveform  # keep the un-normalized audio for the "full" export below

# Configure parameters
# NOTE: with these values the first chunk covers segment * (1 + overlap) = 1100
# seconds, so any song shorter than that is processed as a single chunk.
segment: int = 1000
overlap = 0.1

# Normalize Audio (statistics are taken from the channel-averaged signal)
ref = waveform.mean(0)
waveform = (waveform - ref.mean()) / ref.std()

# Split sources
# waveform[None] adds the leading batch dimension the function expects;
# [0] removes it again from the result.
sources = split_sources(
    model,
    waveform[None],
    device=device,
    segment=segment,
    overlap=overlap,
)[0]

# Undo normalization
sources = sources * ref.std() + ref.mean()

# Pair each separated signal with its name: iterating the result tensor yields
# one tensor per source along its first dimension, and model.sources holds the
# matching names, so zip() builds a name -> audio lookup.
sources_list = model.sources
sources = list(sources)

audios = dict(zip(sources_list, sources))

# Spectrogram transform (power=None keeps the complex STFT values)
N_FFT = 4096
N_HOP = 4
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)


def export_wav(signal, path, rate):
    """Wrap ``signal`` in an IPython ``Audio`` object, write its data to ``path``, and return it."""
    clip = Audio(signal, rate=rate)
    with open(path, "wb") as file:
        file.write(clip.data)
    return clip


# Full Audio
full = export_wav(mixture, "./full.wav", sample_rate)

# Drums Audio
drums = export_wav(audios["drums"], "./drums.wav", sample_rate)

# Bass Audio
bass = export_wav(audios["bass"], "./bass.wav", sample_rate)

# Vocals Audio
vocals = export_wav(audios["vocals"], "./vocals.wav", sample_rate)

# Other Audio
other = export_wav(audios["other"], "./other.wav", sample_rate)
