In [4]:
import sounddevice as sd
import numpy as np
import queue
import torch
import torchaudio
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
from torchaudio.transforms import Fade

In [5]:
# Input settings (CABLE Output)
INPUT_SAMPLE_RATE = 44100
INPUT_DEVICE = 34
INPUT_CHANNELS = 2

# Output settings (HyperX Cloud II)
OUTPUT_SAMPLE_RATE = 48000
OUTPUT_DEVICE = 31
OUTPUT_CHANNELS = 2

# Common settings
DTYPE = np.float32
BLOCK_SIZE = 2048
OVERLAP = 0.2

class AudioProcessor:
    def __init__(self):
        self.input_queue = queue.Queue()
        self.output_queue = queue.Queue()
        
        print("Loading Demucs model...")
        self.model = HDEMUCS_HIGH_MUSDB.get_model()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(f"Using device: {self.device}")
        
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=INPUT_SAMPLE_RATE,
            new_freq=OUTPUT_SAMPLE_RATE,
            dtype=torch.float32
        ).to(self.device)
        
        overlap_frames = int(BLOCK_SIZE * OVERLAP)
        self.fade = Fade(
            fade_in_len=overlap_frames,
            fade_out_len=overlap_frames,
            fade_shape="linear"
        ).to(self.device)
        
        self.prev_chunk = None
    
    # Callbacks for the stream
    def input_callback(self, indata, frames, time, status):
        if status:
            print(f"Input status: {status}")
        self.input_queue.put(indata.copy())
    
    def output_callback(self, outdata, frames, time, status):
        if status:
            print(f"Output status: {status}")
        try:
            data = self.output_queue.get_nowait()
            outdata[:] = data
        except queue.Empty:
            outdata.fill(0)
    
    def process_audio(self, audio_chunk):
        try:
            # Convert to torch tensor, normalize and move to device
            waveform = torch.from_numpy(audio_chunk).to(self.device).T.unsqueeze(0)
            ref = waveform.mean(0)
            waveform = (waveform - ref.mean()) / (ref.std() + 1e-7)
            
            with torch.inference_mode():
                sources = self.model.forward(waveform)
            
           # Get strike sounds, denormalize, and fade
            drums = self.fade(sources[:, 0] * ref.std() + ref.mean())
            if self.prev_chunk is not None:
                overlap_size = int(BLOCK_SIZE * OVERLAP)
                drums[:, :, :overlap_size] += self.prev_chunk[:, :, -overlap_size:]
            self.prev_chunk = drums.clone()
            
            # Different frequencies so gotta resample
            drums = self.resampler(drums)
            
            # Return to CPU for output
            return drums.cpu().squeeze(0).permute(1, 0).numpy()
            
        except Exception as e:
            print(f"Error in process_audio: {e}")
            return np.zeros((BLOCK_SIZE, OUTPUT_CHANNELS))
    
    def run(self):
        print("\nIt's TIME!!!")
        print("Press Ctrl + C or I + I to stop")
        
        input_stream = None
        output_stream = None
        
        try:
            input_stream = sd.InputStream(
                device=INPUT_DEVICE,
                channels=INPUT_CHANNELS,
                samplerate=INPUT_SAMPLE_RATE,
                dtype=DTYPE,
                blocksize=BLOCK_SIZE,
                callback=self.input_callback
            )
            
            output_stream = sd.OutputStream(
                device=OUTPUT_DEVICE,
                channels=OUTPUT_CHANNELS,
                samplerate=OUTPUT_SAMPLE_RATE,
                dtype=DTYPE,
                blocksize=BLOCK_SIZE,
                callback=self.output_callback
            )
            
            with input_stream, output_stream:
                input_stream.start()
                output_stream.start()
                
                while True:
                    # Input -> Process -> Play
                    audio_chunk = self.input_queue.get()
                    drums = %time self.process_audio(audio_chunk)
                    # Usually there's a size mismatch, so use this
                    if drums.shape != (BLOCK_SIZE, OUTPUT_CHANNELS):
                        drums = np.resize(drums, (BLOCK_SIZE, OUTPUT_CHANNELS))
                    self.output_queue.put(drums)
                    
        except KeyboardInterrupt:
            print("\nStopping audio processing...")
        except Exception as e:
            print(f"\nError: {e}")
        finally:
            if input_stream is not None and input_stream.active:
                input_stream.stop()
            if output_stream is not None and output_stream.active:
                output_stream.stop()

In [6]:
processor = AudioProcessor()
processor.run()

Loading Demucs model...
Using device: cuda:0

It's TIME!!!
Press Ctrl + C or I + I to stop
CPU times: total: 0 ns
Wall time: 81 ms
CPU times: total: 0 ns
Wall time: 90.3 ms
CPU times: total: 0 ns
Wall time: 30.9 ms
CPU times: total: 0 ns
Wall time: 32.2 ms
CPU times: total: 0 ns
Wall time: 34 ms
CPU times: total: 0 ns
Wall time: 23 ms
CPU times: total: 0 ns
Wall time: 20.4 ms
CPU times: total: 15.6 ms
Wall time: 27.7 ms
CPU times: total: 0 ns
Wall time: 24.7 ms
CPU times: total: 15.6 ms
Wall time: 33.7 ms
CPU times: total: 0 ns
Wall time: 23.9 ms
CPU times: total: 0 ns
Wall time: 24 ms
CPU times: total: 0 ns
Wall time: 27.8 ms
CPU times: total: 0 ns
Wall time: 22.6 ms
CPU times: total: 0 ns
Wall time: 20.3 ms
CPU times: total: 0 ns
Wall time: 26.1 ms
CPU times: total: 0 ns
Wall time: 23.8 ms
CPU times: total: 0 ns
Wall time: 29 ms
CPU times: total: 0 ns
Wall time: 22.9 ms
CPU times: total: 0 ns
Wall time: 25.4 ms
CPU times: total: 0 ns
Wall time: 27 ms
CPU times: total: 0 ns
Wall time:

CPU times: total: 0 ns
Wall time: 20.9 ms
CPU times: total: 0 ns
Wall time: 27.2 ms
CPU times: total: 0 ns
Wall time: 23.6 ms
CPU times: total: 0 ns
Wall time: 25 ms
CPU times: total: 0 ns
Wall time: 23.4 ms
CPU times: total: 0 ns
Wall time: 20.5 ms
CPU times: total: 0 ns
Wall time: 20.8 ms
CPU times: total: 0 ns
Wall time: 24.5 ms
CPU times: total: 0 ns
Wall time: 27.9 ms
CPU times: total: 0 ns
Wall time: 21.4 ms
CPU times: total: 0 ns
Wall time: 22.3 ms
CPU times: total: 0 ns
Wall time: 22.1 ms
CPU times: total: 0 ns
Wall time: 23.1 ms
CPU times: total: 0 ns
Wall time: 28.4 ms
CPU times: total: 15.6 ms
Wall time: 21.9 ms
CPU times: total: 0 ns
Wall time: 20 ms
CPU times: total: 0 ns
Wall time: 22.9 ms


KeyboardInterrupt: 


Error: 'NoneType' object has no attribute 'shape'
