In [2]:
%pip install demucs
%pip install soundfile

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [9]:
import torch
import torchaudio
import matplotlib.pyplot as plt
from IPython.display import Audio
from mir_eval import separation
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.utils import download_asset

In [10]:
bundle = HDEMUCS_HIGH_MUSDB_PLUS
# Rate the pretrained weights expect; later cells size chunks and playback from it.
sample_rate = bundle.sample_rate

# Prefer the first CUDA device, fall back to the CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = bundle.get_model()
model.to(device)

print(f"Sample rate: {sample_rate}")

Sample rate: 44100


  state_dict = torch.load(path)


In [11]:
from torchaudio.transforms import Fade


def separate_sources(
    model,
    mix,
    segment=10.0,
    overlap=0.1,
    device=None,
    frame_rate=None,
):
    """
    Apply a separation model to a mixture chunk by chunk and cross-fade the chunks back together.

    Each chunk spans ``segment * (1 + overlap)`` seconds. Consecutive chunks overlap by
    ``overlap`` seconds. In that region the earlier chunk gets a linear fade-out and the later
    chunk a matching linear fade-in, so the two contributions are summed into the result.

    Args:
        model: separation model exposing ``sources`` (one entry per output stem). It is called
            on ``(batch, channels, samples)`` chunks located on ``device``.
        mix (torch.Tensor): mixture of shape ``(batch, channels, length)``.
        segment (float): segment length in seconds.
        overlap (float): extends each chunk by this fraction of ``segment``. It is also the
            cross-fade length in seconds.
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
        frame_rate (int or None): sample rate of ``mix`` in Hz, used to convert seconds into
            samples. Defaults to the notebook-level ``sample_rate``.

    Returns:
        torch.Tensor: ``(batch, len(model.sources), channels, length)`` on ``mix.device``.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)
    if frame_rate is None:
        frame_rate = sample_rate  # notebook-level global set by earlier cells

    batch, channels, length = mix.shape

    chunk_len = int(frame_rate * segment * (1 + overlap))
    # Integer overlap, so the overlapping region is exactly as long as the fades applied to it.
    overlap_frames = int(overlap * frame_rate)

    start = 0
    end = chunk_len
    # If the first chunk already reaches the end of the mix, nothing follows it to cross-fade
    # into. Fading it out would silence the tail of short inputs.
    fade = Fade(
        fade_in_len=0,
        fade_out_len=0 if end >= length else overlap_frames,
        fade_shape="linear",
    )

    # The full-length result stays on mix.device; only individual chunks visit `device`.
    final = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)

    with torch.no_grad():
        while start < length - overlap_frames:
            chunk = mix[:, :, start:end].to(device)
            out = fade(model(chunk))
            final[:, :, :, start:end] += out.to(final.device)
            if start == 0:
                # Every later chunk overlaps its predecessor, so from now on chunks fade in.
                fade.fade_in_len = overlap_frames
                start += chunk_len - overlap_frames
            else:
                start += chunk_len
            end += chunk_len
            if end >= length:
                # The upcoming chunk is the last one: nothing follows it, so no fade-out.
                fade.fade_out_len = 0
    return final


def plot_spectrogram(stft, title="Spectrogram", vmin=-60, vmax=0):
    """Plot the magnitude of an STFT on a decibel scale.

    Args:
        stft (torch.Tensor): 2-D (possibly complex) tensor. Rows are drawn bottom-up on the
            vertical axis (presumably frequency bins; confirm against the caller).
            It may live on any device and may require grad.
        title (str): axes title.
        vmin (float): lower end of the colour scale in dB.
        vmax (float): upper end of the colour scale in dB.

    Returns:
        matplotlib.axes.Axes: the axes holding the image.
    """
    # detach/cpu so that .numpy() works for CUDA tensors and tensors that require grad.
    magnitude = stft.detach().abs().cpu()
    # The small epsilon keeps log10 finite where the magnitude is exactly zero.
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    figure, axis = plt.subplots(1, 1)
    axis.imshow(spectrogram, cmap="viridis", vmin=vmin, vmax=vmax, origin="lower", aspect="auto")
    axis.set_title(title)
    figure.tight_layout()
    return axis

In [None]:
# We download the audio file from our storage. Feel free to download another file and use audio from a specific path
SAMPLE_SONG = download_asset("tutorial-assets/hdemucs_mix.wav")
waveform, file_sample_rate = torchaudio.load(SAMPLE_SONG)  # replace SAMPLE_SONG with desired path for different song
# The pretrained model expects audio at the bundle's rate, and separate_sources converts
# seconds to samples with the global `sample_rate`. Resample any song recorded at another rate.
if file_sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, file_sample_rate, bundle.sample_rate)
sample_rate = bundle.sample_rate
waveform = waveform.to(device)
mixture = waveform

# parameters
segment: int = 10
overlap = 0.1

print("Separating track")

# Normalize with statistics of the channel-averaged signal; undone after separation below.
ref = waveform.mean(0)
waveform = (waveform - ref.mean()) / ref.std()  # normalization

sources = separate_sources(
    model,
    waveform[None],
    device=device,
    segment=segment,
    overlap=overlap,
)[0]
sources = sources * ref.std() + ref.mean()

sources_list = model.sources
sources = list(sources)

# Map each stem name to its separated waveform.
audios = dict(zip(sources_list, sources))

In [None]:
import os
from datetime import datetime, timedelta, timezone
from queue import Queue
from time import sleep

import numpy as np
import speech_recognition as sr
import torch
import torchaudio
from IPython.display import Audio, display
from torchaudio.transforms import Fade, Resample

# Audio recording setup: capture parameters plus the speech_recognition objects whose
# background listener feeds raw microphone bytes into `data_queue`.
energy_threshold = 1000
record_timeout = 2.0  # seconds; used as phrase_time_limit for the background listener
phrase_timeout = 3.0
phrase_time = None
data_queue = Queue()

recorder = sr.Recognizer()
# Fixed threshold, no automatic adaptation while listening.
# NOTE(review): the later adjust_for_ambient_noise call presumably recalibrates this value.
recorder.energy_threshold, recorder.dynamic_energy_threshold = energy_threshold, False

source = sr.Microphone(sample_rate=16000)

def record_callback(_, audio: sr.AudioData) -> None:
    """Background-listener callback: push the raw bytes of each captured phrase onto `data_queue`.

    The first positional argument is ignored.
    """
    data_queue.put(audio.get_raw_data())

# Initialize Demucs model (HDEMUCS_HIGH_MUSDB_PLUS is imported in an earlier cell).
bundle = HDEMUCS_HIGH_MUSDB_PLUS
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = bundle.get_model()
model.to(device)
# Rate at which the model expects its input; microphone audio is resampled to it.
model_sample_rate = bundle.sample_rate

def separate_sources(model, mix, segment=10.0, overlap=0.1, device=None, frame_rate=None):
    """Apply a separation model to ``mix`` chunk by chunk, cross-fading overlapping chunks.

    Each chunk spans ``segment * (1 + overlap)`` seconds. Consecutive chunks overlap by
    ``overlap`` seconds, with a linear fade-out on the earlier chunk and a linear fade-in on
    the later one.

    Args:
        model: separation model exposing ``sources``. It is called on
            ``(batch, channels, samples)`` chunks located on ``device``.
        mix (torch.Tensor): mixture of shape ``(batch, channels, length)``.
        segment (float): segment length in seconds.
        overlap (float): extends each chunk by this fraction of ``segment``. It is also the
            cross-fade length in seconds.
        device (torch.device, str, or None): device for the per-chunk computation; defaults
            to ``mix.device``. The full-length result is kept on ``mix.device``.
        frame_rate (int or None): sample rate of ``mix`` in Hz; defaults to the notebook-level
            ``model_sample_rate``.

    Returns:
        torch.Tensor: ``(batch, len(model.sources), channels, length)`` on ``mix.device``.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)
    if frame_rate is None:
        frame_rate = model_sample_rate  # notebook-level global set in this cell

    batch, channels, length = mix.shape
    chunk_len = int(frame_rate * segment * (1 + overlap))
    # Integer overlap, so the overlapping region is exactly as long as the fades applied to it.
    overlap_frames = int(overlap * frame_rate)

    start = 0
    end = chunk_len
    # Inputs no longer than one chunk (e.g. the ~2 s microphone clips) have no following chunk.
    # Fading them out would silence their last `overlap` seconds.
    fade = Fade(
        fade_in_len=0,
        fade_out_len=0 if end >= length else overlap_frames,
        fade_shape="linear",
    )

    # The full-length result stays on mix.device; only individual chunks visit `device`.
    final = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)

    with torch.no_grad():
        while start < length - overlap_frames:
            chunk = mix[:, :, start:end].to(device)
            out = fade(model(chunk))
            final[:, :, :, start:end] += out.to(final.device)
            if start == 0:
                # Every later chunk overlaps its predecessor, so from now on chunks fade in.
                fade.fade_in_len = overlap_frames
                start += chunk_len - overlap_frames
            else:
                start += chunk_len
            end += chunk_len
            if end >= length:
                # The upcoming chunk is the last one: nothing follows it, so no fade-out.
                fade.fade_out_len = 0
    return final

# Main processing loop: the background listener queues microphone audio. Each queued batch is
# separated with Demucs, and the vocals stem is played back.
with source:
    recorder.adjust_for_ambient_noise(source)

stop_call = recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
print("Model loaded and microphone initialized.\n")

# The microphone rate is fixed, so build the resampler once. Move it to `device`: its sinc
# kernel is a module buffer created on the CPU and would not match a CUDA waveform otherwise.
resampler = None
if source.SAMPLE_RATE != model_sample_rate:
    resampler = Resample(orig_freq=source.SAMPLE_RATE, new_freq=model_sample_rate).to(device)

try:
    while True:
        if data_queue.empty():
            sleep(0.25)
            continue

        # Time at which the most recent batch of audio was picked up.
        phrase_time = datetime.now(timezone.utc)

        # Drain the queue through its thread-safe API. Joining `data_queue.queue` and then
        # clearing it races with the background callback and can drop a chunk arriving in
        # between. This loop is the only consumer, so get() cannot block here.
        chunks = []
        while not data_queue.empty():
            chunks.append(data_queue.get())
        audio_data = b"".join(chunks)

        # 16-bit PCM -> float in [-1, 1). bytearray gives frombuffer a writable buffer.
        waveform = torch.frombuffer(bytearray(audio_data), dtype=torch.int16).float() / 32768.0
        waveform = waveform.unsqueeze(0).to(device)

        # Resample to model's sample rate if necessary
        if resampler is not None:
            waveform = resampler(waveform)

        # Duplicate the mono channel to match the model's channel count.
        if waveform.shape[0] != model.audio_channels:
            waveform = waveform.expand(model.audio_channels, -1)

        # Normalize the waveform; clamping avoids dividing by zero on a flat batch.
        ref = waveform.mean(0)
        ref_mean = ref.mean()
        ref_std = ref.std().clamp_min(1e-8)
        waveform = (waveform - ref_mean) / ref_std

        # Separate sources using Demucs
        separated_sources = separate_sources(
            model,
            waveform.unsqueeze(0),
            device=device,
            segment=10,
            overlap=0.1
        )[0]

        # Denormalize the sources
        separated_sources = separated_sources * ref_std + ref_mean

        # Extract vocals
        sources_list = model.sources
        sources = list(separated_sources)
        audios = dict(zip(sources_list, sources))
        vocals = audios["vocals"]

        # Play back the separated vocals
        vocals_np = vocals.squeeze().cpu().numpy()
        display(Audio(vocals_np, rate=model_sample_rate))

except KeyboardInterrupt:
    print("\nSeparation stopped by user.")
finally:
    # Always stop the background listener, even when processing raised another exception.
    stop_call()