In [24]:
import sounddevice as sd
import numpy as np
import queue
import torch
import torchaudio
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
from torchaudio.transforms import Fade
from collections import deque

# --- Capture side (virtual "CABLE Output" device) ---------------------------
# INPUT_DEVICE is passed as `device=` to the sounddevice input stream, so the
# index is specific to the machine this was written on.
INPUT_DEVICE = 34
INPUT_CHANNELS = 2
INPUT_SAMPLE_RATE = 44_100

# --- Playback side (headphones / speakers) ----------------------------------
# OUTPUT_DEVICE is likewise a machine-specific sounddevice index.
OUTPUT_DEVICE = 31
OUTPUT_CHANNELS = 2
OUTPUT_SAMPLE_RATE = 48_000

# --- Block / window geometry ------------------------------------------------
# Audio is captured BLOCK_SIZE frames at a time (at INPUT_SAMPLE_RATE), and
# every inference pass runs on the most recent BLOCK_SIZE * WINDOW_BLOCKS
# frames. OVERLAP is the fraction of that window used as the fade length.
BLOCK_SIZE = 4096
WINDOW_BLOCKS = 4
OVERLAP = 0.45

# Sample format shared by the capture and playback streams.
DTYPE = np.float32

class AudioProcessor:
    """Capture audio, isolate one stem with Hybrid Demucs, and play it back.

    Data flow: the input stream callback pushes captured blocks onto
    ``input_queue``; ``run`` pops them, calls ``process_audio`` and pushes the
    result onto ``output_queue``; the output stream callback pops blocks from
    there (or plays silence when nothing is ready).

    NOTE(review): both streams are opened with ``blocksize=BLOCK_SIZE`` but at
    different sample rates, and exactly one output block is produced per input
    block. With INPUT_SAMPLE_RATE < OUTPUT_SAMPLE_RATE the output stream asks
    for blocks more often than they are produced, so some output callbacks
    will find the queue empty and emit silence. Confirm whether that is
    acceptable or whether the output should be re-blocked.
    """

    def __init__(self):
        """Load the model and build the queues, resampler, fade and buffers."""
        self.input_queue = queue.Queue()
        self.output_queue = queue.Queue()
        # Sliding window of the most recent WINDOW_BLOCKS captured blocks.
        self.block_buffer = deque(maxlen=WINDOW_BLOCKS)

        print("Loading Demucs model...")
        self.model = HDEMUCS_HIGH_MUSDB.get_model()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(f"Using device: {self.device}")

        self.resampler = torchaudio.transforms.Resample(
            orig_freq=INPUT_SAMPLE_RATE,
            new_freq=OUTPUT_SAMPLE_RATE,
            dtype=torch.float32
        ).to(self.device)

        # Computed once: used both as the fade length and as the overlap-add
        # length in process_audio.
        self._overlap_frames = int(BLOCK_SIZE * WINDOW_BLOCKS * OVERLAP)
        self.fade = Fade(
            fade_in_len=self._overlap_frames,
            fade_out_len=self._overlap_frames,
            fade_shape="linear"
        ).to(self.device)

        # State carried between process_audio calls. Both start as None so
        # the first call can be detected without hasattr().
        self.prev_chunk = None
        self.previous_output_remainder = None

        # Pre-fill the window with silence so the first real block already
        # yields a full-length window.
        silence = np.zeros((BLOCK_SIZE, INPUT_CHANNELS), dtype=DTYPE)
        for _ in range(WINDOW_BLOCKS - 1):
            self.block_buffer.append(silence.copy())

    def input_callback(self, indata, frames, time, status):
        """Input stream callback: queue a copy of the captured block.

        The copy is needed because the callback's buffer is only valid for
        the duration of the call.
        """
        if status:
            print(f"Input status: {status}")
        self.input_queue.put(indata.copy())

    def output_callback(self, outdata, frames, time, status):
        """Output stream callback: play the next queued block, else silence.

        If the queued block has fewer frames than ``outdata`` the remainder is
        zero-filled; if it has more, the excess is dropped. This avoids
        raising inside the audio callback on a frame-count mismatch.
        """
        if status:
            print(f"Output status: {status}")
        try:
            block = self.output_queue.get_nowait()
        except queue.Empty:
            outdata.fill(0)
            return
        usable = min(len(block), len(outdata))
        outdata[:usable] = block[:usable]
        outdata[usable:] = 0

    def process_audio(self, audio_chunk):
        """Run separation on the current window and return one output block.

        Args:
            audio_chunk: newly captured block, appended to the sliding window.
                Assumed to be (frames, channels) float32 as delivered by the
                input stream.

        Returns:
            A (BLOCK_SIZE, OUTPUT_CHANNELS) array when INPUT_CHANNELS equals
            OUTPUT_CHANNELS, as in this file. On any exception the error is
            printed and a block of silence is returned so playback continues.
        """
        try:
            # Add the new block to the window and flatten it into one array.
            self.block_buffer.append(audio_chunk)
            combined_chunk = np.concatenate(list(self.block_buffer), axis=0)

            # (frames, channels) -> (channels, frames) on the target device.
            waveform = torch.from_numpy(combined_chunk).to(self.device).T

            # Normalise with the statistics of the mono mix. This has to be
            # done before adding the batch axis: mean(0) on a (1, C, N)
            # tensor would average over the batch axis and do nothing.
            ref = waveform.mean(0)
            ref_mean = ref.mean()
            ref_std = ref.std()
            waveform = ((waveform - ref_mean) / (ref_std + 1e-7)).unsqueeze(0)

            with torch.inference_mode():
                # Call the module rather than .forward() so hooks still run.
                sources = self.model(waveform)

            # Take stem 0 (treated as drums here - confirm against the
            # bundle's source order), undo the normalisation and fade.
            drums = self.fade(sources[:, 0] * ref_std + ref_mean)

            # Overlap-add the faded tail of the previous window onto the head
            # of this one.
            # NOTE(review): windows advance by one BLOCK_SIZE per call, while
            # this aligns the previous tail with the current head as if the
            # hop were (window - overlap). Confirm the intended alignment.
            overlap = self._overlap_frames
            if self.prev_chunk is not None:
                drums[:, :, :overlap] += self.prev_chunk[:, :, -overlap:]
            self.prev_chunk = drums.clone()

            # Input and output streams use different sample rates.
            drums = self.resampler(drums)

            # (1, channels, frames) -> (frames, channels) on the host.
            drums_numpy = drums.cpu().squeeze(0).permute(1, 0).numpy()

            remainder = self.previous_output_remainder
            if remainder is not None:
                # Crossfade the end of the saved remainder into the start of
                # the new audio over half a block.
                crossfade_size = int(BLOCK_SIZE * 0.5)
                fade_in = np.linspace(0, 1, crossfade_size, dtype=DTYPE)[:, np.newaxis]
                fade_out = np.linspace(1, 0, crossfade_size, dtype=DTYPE)[:, np.newaxis]

                current_block = np.zeros((BLOCK_SIZE, OUTPUT_CHANNELS), dtype=DTYPE)
                current_block[:crossfade_size] = remainder[-crossfade_size:] * fade_out
                current_block[:crossfade_size] += drums_numpy[:crossfade_size] * fade_in
                current_block[crossfade_size:] = drums_numpy[crossfade_size:BLOCK_SIZE]
            else:
                # First call: nothing to crossfade with yet.
                current_block = drums_numpy[:BLOCK_SIZE]

            # Keep the first two blocks' worth for the next call's crossfade.
            self.previous_output_remainder = drums_numpy[:BLOCK_SIZE * 2]

            return current_block

        except Exception as e:
            # Deliberate best-effort: report and keep the stream alive.
            print(f"Error in process_audio: {e}")
            return np.zeros((BLOCK_SIZE, OUTPUT_CHANNELS), dtype=DTYPE)

    def run(self):
        """Open both streams and process blocks until interrupted.

        The streams are used as context managers, which start them on entry
        and stop and close them on exit. Constructing both inside one ``with``
        statement means the input stream is still closed if opening the output
        stream fails.
        """
        print("\nIt's TIME!!!")
        print("Press Ctrl + C or I + I to stop")

        try:
            with sd.InputStream(
                device=INPUT_DEVICE,
                channels=INPUT_CHANNELS,
                samplerate=INPUT_SAMPLE_RATE,
                dtype=DTYPE,
                blocksize=BLOCK_SIZE,
                callback=self.input_callback
            ), sd.OutputStream(
                device=OUTPUT_DEVICE,
                channels=OUTPUT_CHANNELS,
                samplerate=OUTPUT_SAMPLE_RATE,
                dtype=DTYPE,
                blocksize=BLOCK_SIZE,
                callback=self.output_callback
            ):
                while True:
                    # Input -> Process -> Validate -> Play
                    audio_chunk = self.input_queue.get()
                    drums = self.process_audio(audio_chunk)
                    # Force the expected shape. np.resize truncates or
                    # repeats the flattened data to fit.
                    if drums.shape != (BLOCK_SIZE, OUTPUT_CHANNELS):
                        drums = np.resize(drums, (BLOCK_SIZE, OUTPUT_CHANNELS))
                    self.output_queue.put(drums)

        except KeyboardInterrupt:
            print("\nStopping audio processing...")
        except Exception as e:
            print(f"\nError: {e}")

# Only start the model download and the audio streams when this is executed
# as a script or notebook cell, not when the module is imported.
if __name__ == "__main__":
    processor = AudioProcessor()
    # Blocks until interrupted (KeyboardInterrupt is handled inside run()).
    processor.run()

Loading Demucs model...
Using device: cuda:0

It's TIME!!!
Press Ctrl + C or I + I to stop

Stopping audio processing...
