In [4]:
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.signal as signal
import collections
import os, sys
sys.path.append(os.path.abspath('../'))
import signaltrain as st
import signaltrain.audio as audio
import signaltrain.nn_modules.nn_proc as nn_proc

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Make the FloatTensor flavour that matches the chosen device the default tensor type.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if device.type == 'cuda' else 'torch.FloatTensor'
)

In [5]:
# Audio signal options and data processing

sr = 44100   # sample rate
t = 0        # placeholder; get_input_sample() replaces it with a time axis

# First argument is 4*sr (176400); presumably the clip length in samples -- TODO confirm in signaltrain.audio
ra_gen = audio.readaudio_generator(4*sr, norm=True, path="./")
next(ra_gen)  # the first item drawn from the generator is discarded here
def get_input_sample(chooser):
    """Return an input signal of the requested kind and update the global time axis ``t``.

    Args:
        chooser: one of 'sine', 'box', 'noisy sine', 'noisybox', 'pluck', 'real audio'.

    Returns:
        The signal produced by the matching ``signaltrain.audio`` generator (or the
        next item of ``ra_gen`` for 'real audio').

    Raises:
        ValueError: if ``chooser`` is not one of the names above.  The global ``t``
            is left untouched in that case.

    Side effects:
        Sets the global ``t`` to ``linspace(0, 1, chunk_size)`` for the synthetic
        signals, or to ``linspace(0, len(x)/sr, len(x))`` for 'real audio'.
    """
    global t, sr
    if 'real audio' == chooser:
        clip = next(ra_gen)
        t = np.linspace(0,clip.shape[0]/sr,clip.shape[0])
        return clip

    unit_axis = np.linspace(0,1,chunk_size)
    if 'sine' == chooser:
        sample = audio.randsine(unit_axis,freq_range=[5,20])
    elif 'box' == chooser:
        sample = audio.box(unit_axis)
    elif 'noisy sine' == chooser:
        # additive uniform noise in [-0.1, 0.1]
        sample = audio.randsine(unit_axis,freq_range=[5,20]) + 0.1*(2*np.random.rand(unit_axis.shape[0])-1)
    elif 'noisybox' == chooser:
        # NOTE(review): noise is multiplied in here but added for 'noisy sine' -- confirm this is intended
        sample = audio.box(unit_axis) * (2*np.random.rand(unit_axis.shape[0])-1)
    elif 'pluck' == chooser:
        sample = audio.pluck(unit_axis)
    else:
        # Previously fell off the end and returned None, which only failed later in the caller.
        raise ValueError(f"Unknown signal type: {chooser!r}")
    t = unit_axis
    return sample
  
chunk_size = 4096
def torch_chunkify(x, chunk_size=chunk_size):
    """Zero-pad a 1-D signal to a whole number of chunks and return it as a 2-D float32 tensor.

    Args:
        x: 1-D array-like with a ``shape`` attribute (indexed along axis 0).
        chunk_size: number of samples per row; defaults to the module-level ``chunk_size``
            as it was when this function was defined.

    Returns:
        A ``torch.float32`` tensor on ``device`` of shape ``(rows, chunk_size)`` where
        ``rows = ceil(len(x) / chunk_size)``; the tail of the last row is zero-filled.
    """
    n_samples = x.shape[0]
    rows = int(np.ceil(n_samples/chunk_size))  # this will be the batch size
    # Build the padded buffer directly in float32 so only single-precision data is
    # sent to the device (the old code moved float64 and converted afterwards).
    padded = np.zeros(rows*chunk_size, dtype=np.float32)
    padded[:n_samples] = x
    # Plain tensors replace the deprecated torch.autograd.Variable wrapper;
    # requires_grad is already False for tensors created this way.
    return torch.from_numpy(padded.reshape(rows, chunk_size)).to(device)

# Seed the notebook globals with a sine input; ThePlotterWidget compares against
# old_signal_type to decide whether the input must be regenerated.
old_signal_type = 'sine'
x = get_input_sample(old_signal_type)

# NOTE(review): this resets the default tensor type to the CPU FloatTensor even when the
# setup cell selected the CUDA type -- confirm that is intentional.
torch.set_default_tensor_type('torch.FloatTensor')
x_torch = torch_chunkify(x)
y_true, y_pred = 0, 0   # saving these for later globals



Set up the model and audio effect...


In [6]:
# Data settings
shrink_factor = 2  # every size below is divided by this to reduce the dimensionality of the run
time_series_length, sampling_freq = 8192 // shrink_factor, 44100. // shrink_factor

# Analysis parameters
ft_size, hop_size = 1024 // shrink_factor, 384 // shrink_factor
# Hops needed to cover the whole series plus hops needed to cover one transform window.
expected_time_frames = int(
    np.ceil(time_series_length / hop_size) + np.ceil(ft_size / hop_size)
)


# Choose the audio effect and the checkpoint file holding the matching model weights.
# Alternatives (change both assignments together):
#   audio.Compressor()  with 'modelcheckpoint.tar'
#   audio.Denoise()     with 'modelcheckpoint_denoise.tar'
effect = audio.Compressor_4c()
checkpoint_file = 'modelcheckpoint_4c.tar'
knobranges = effect.knob_ranges   # indexed elsewhere as [i,0] (slider min) and [i,1] (slider max)

def load_model(checkpoint_file, silent=False):
    """Build an ``st_model`` and load its weights from a checkpoint.

    Args:
        checkpoint_file: path passed to ``st.misc.load_checkpoint``.
        silent: when True, everything written to ``sys.stdout`` while the checkpoint
            is read and the model is constructed is discarded.

    Returns:
        The constructed model with the checkpoint's state dict loaded
        (``strict=False``, so missing or unexpected keys are not reported as errors).
    """
    import contextlib  # local import: only this function needs it

    # ExitStack restores stdout and closes the sink even if loading raises.
    # The previous hand-rolled swap leaked the file handle, left stdout redirected
    # on error, and wrote a stray 'trash' file into the working directory.
    with contextlib.ExitStack() as stack:
        if silent:
            sink = stack.enter_context(open(os.devnull, 'w'))
            stack.enter_context(contextlib.redirect_stdout(sink))

        state_dict, rv = st.misc.load_checkpoint(checkpoint_file, device=device)
        knob_names = rv['knob_names']
        model = nn_proc.st_model(scale_factor=rv['scale_factor'],
                                 shrink_factor=rv['shrink_factor'],
                                 num_knobs=len(knob_names),
                                 sr=rv['sr'])
        model.load_state_dict(state_dict, strict=False)   # overwrite the weights using the checkpoint
    return model

# Initial load with silent left at its default (False), so the loader's diagnostics are printed below.
model = load_model(checkpoint_file)



***** Checkpoint file found. Loading weights.
Input chunk size = 4096
Intended Output chunk size = 4096
Sample rate = 44100
    Setting out_chunk_size = y_size = 3968
AsymMPAEC: expected_time_frames, ft_size, hop_size, decomposition_rank, n_knobs, output_tf =  14 1024 384 64 4 14
AsymAutoEncoder __init__: T, R, K, OT =  14 64 4 14
AsymAutoEncoder __init__: T, R, K, OT =  14 64 4 14


In [7]:
# Optional: download the latest checkpoint file (in case the copy on GitHub is out of date)
!curl -OL http://hedges.belmont.edu/~shawley/modelcheckpoint_4c.tar

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 56.2M  100 56.2M    0     0  7894k      0  0:00:07  0:00:07 --:--:-- 9173k


## Main Demo
Define interactive widgets and their handler routine. Get an interactive graph...

In [8]:
def forward_and_plot(knobs_wc, refresh=True):
    """Apply the reference effect and the network to the current input, then plot the results.

    Works on notebook globals: reads ``x`` (a continuous 1-D array of mono audio) and
    the time axis ``t``; updates ``x``, ``x_torch``, ``y_true`` and ``y_pred`` so that
    later cells can play them back.

    Args:
        knobs_wc: array of knob settings in 'world coordinates', one entry per row of
            ``knobranges``.  (Attack and release are in milliseconds only due to
            IPython display limitations.)
        refresh: when True, reload the weights from ``checkpoint_file`` before
            predicting (useful for checking progress during model training);
            when False, use the notebook-level ``model``.
    """
    global x, x_torch, y_true, y_pred   # for playing audio later

    # Use a separate local name: assigning to `model` inside this function made it a
    # local variable, so refresh=False used to raise UnboundLocalError.
    net = load_model(checkpoint_file, silent=True) if refresh else model

    # convert from 'world coordinates' of knobs to normalized [-0.5,0.5] values for network
    knobs_nn = (knobs_wc - knobranges[:,0])/(knobranges[:,1]-knobranges[:,0]) - 0.5

    # run the effect
    if not isinstance(effect,audio.Denoise): #...actually we'll do something special elsewhere for Denoiser
        y_true, x = effect.go(x, knobs_nn)

    # Break the (possibly updated) input into chunks *before* sizing the knob batch,
    # so the number of knob rows always matches the number of audio chunks.
    x_torch = torch_chunkify(x)

    # use the same knob settings for all chunks
    rows = x_torch.size()[0]
    knobs = np.tile(knobs_nn,(rows,1))
    knobs_torch = torch.from_numpy(knobs).float().to(device)

    # call the network in inference; no autograd graph is needed for plotting
    with torch.no_grad():
        y_pred_torch, _, _ = net.forward(x_torch, knobs_torch)

    # Plot
    y_pred = y_pred_torch.detach().cpu().numpy().flatten()[0:t.shape[0]]  #flattened numpy version

    plt.figure(figsize=(8,5))
    plt.plot(t,x,c=(0,0,0.6,0.75),lw=1.5, label='Input')
    # Target and prediction are right-aligned against the end of the time axis.
    plt.plot(t[-len(y_true):],y_true,c='r',lw=1.5, label='Target')
    plt.plot(t[-len(y_pred):],y_pred,c=(0,0.5,0,0.75),lw=1.5, label='Predicted')

    plt.legend(loc='lower right')
    plt.title(effect.name)
    plt.ylim(-1,1)
    plt.show()


# Choices offered for the input signal; interact() renders this list as a dropdown.
input_dict = {'signal_type': ['box','sine','pluck','noisybox','noisy sine','real audio'] }
knobsdict = collections.OrderedDict({})
# One slider per effect knob, covering that knob's range in 25 steps and
# firing only when the handle is released (continuous_update=False).
knobslist = [
    FloatSlider(min=knobranges[i,0],
                max=knobranges[i,1],
                step=(knobranges[i,1]-knobranges[i,0])/25,
                continuous_update=False)
    for i in range(len(effect.knob_names))
]
def ThePlotterWidget(**kwargs):
    """Widget callback: regenerate the input if the signal type changed, then run and plot.

    Args:
        **kwargs: must contain 'signal_type'; every other entry is taken, in the order
            it was passed, as a knob value in world coordinates.

    Side effects:
        Updates the globals ``old_signal_type``, ``x``, ``x_torch`` and ``y_true``,
        then calls ``forward_and_plot``.
    """
    global old_signal_type, x, x_torch, y_true
    # Pick the signal type out by name instead of assuming it is the first keyword,
    # so the knob values cannot be shifted by a change in argument order.
    settings = dict(kwargs)
    signal_type = settings.pop('signal_type')
    knobs_wc = np.array(list(settings.values()))
    if (signal_type != old_signal_type): # don't regen x unless input changed
        x = get_input_sample(signal_type)
        y_true = x.copy()
        if isinstance(effect, audio.Denoise):
            # y_true keeps the clean copy; x gets uniform noise with a random amplitude in [0.2, 0.5)
            x += (0.2+0.3*np.random.rand())*(2*np.random.rand(x.shape[0])-1)
    old_signal_type = signal_type
    x_torch = torch_chunkify(x)
    forward_and_plot(knobs_wc)

# Map each knob name to its slider, then put the signal-type choices first
# and hand everything to interact() as keyword arguments.
slider_dict = {f'{knob_name}': slider for knob_name, slider in zip(effect.knob_names, knobslist)}
kwargs = dict(input_dict, **slider_dict)
interact(ThePlotterWidget, **kwargs)

interactive(children=(Dropdown(description='signal_type', options=('box', 'sine', 'pluck', 'noisybox', 'noisy …

<function __main__.ThePlotterWidget(**kwargs)>

## So what does it sound like?

Note that there's a "bug/feature" in Jupyter Notebook's audio "display" whereby it rescales the audio, which makes it almost useless for checking how a compressor performs.  
So first we're going to define our own Audio display. 

In [6]:
# Redefine IPython audio display widget to disable normalization
from IPython.core.display import DisplayObject
class Audio(DisplayObject):
    """Audio display object whose peak normalization can be switched off.

    Accepts raw sample data, a filename or a URL.  Non-bytes ``data`` is encoded as
    16-bit mono PCM WAV; with ``norm=False`` samples are clipped to [-1, 1] instead of
    being rescaled to full scale, so level differences stay audible.
    """

    # NOTE(review): IPython's DisplayObject is understood to open files using
    # `_read_flags` (text mode unless overridden); audio files need binary mode.
    # Confirm against the installed IPython version.
    _read_flags = 'rb'

    def __init__(self, data=None, filename=None, url=None, embed=None, rate=None, autoplay=False, norm=True):
        """Create the display object.

        Args:
            data: sample sequence/array, or ready-made bytes.
            filename: path of an audio file to load.
            url: URL of an audio resource.
            embed: whether to embed the data in the page; a URL is linked rather than
                embedded unless ``embed`` is True.
            rate: sampling rate in Hz; required when ``data`` holds samples.
            autoplay: start playing as soon as the element is rendered.
            norm: rescale samples to full scale (True) or only clip to [-1, 1] (False).

        Raises:
            ValueError: if no source is given, or ``embed`` is False without a URL.
        """
        if filename is None and url is None and data is None:
            raise ValueError("No audio data found. Expecting filename, url, or data.")
        if embed is False and url is None:
            raise ValueError("No url found. Expecting url when embed=False")

        # Set before the base constructor runs, because reload() below reads self.embed.
        self.embed = not (url is not None and embed is not True)
        self.autoplay = autoplay
        super().__init__(data=data, url=url, filename=filename)

        if self.data is not None and not isinstance(self.data, bytes):
            self.data = self._make_wav(data,rate,norm)

    def reload(self):
        """Reload the raw data from file or URL."""
        import mimetypes
        if self.embed:
            super().reload()

        if self.filename is not None:
            self.mimetype = mimetypes.guess_type(self.filename)[0]
        elif self.url is not None:
            self.mimetype = mimetypes.guess_type(self.url)[0]
        else:
            self.mimetype = "audio/wav"

    def _make_wav(self,data,rate,norm):
        """Encode a 1-D sequence of samples as a 16-bit mono PCM WAV byte string.

        Args:
            data: 1-D array-like of samples.
            rate: sampling rate in Hz.
            norm: if True, divide by the peak absolute value before scaling to 16 bits;
                if False, clip to [-1, 1] instead.

        Returns:
            The complete WAV file contents as bytes.

        Raises:
            ValueError: if ``rate`` is None or ``data`` is not one-dimensional.
        """
        import wave
        from io import BytesIO

        if rate is None:
            raise ValueError("rate must be specified when data is an array of audio samples.")
        samples = np.asarray(data, dtype=np.float64)
        if samples.ndim != 1:
            raise ValueError("Expected a 1-D array of mono audio samples.")

        if norm:
            peak = np.max(np.abs(samples)) if samples.size else 0.0
            # An all-zero (or empty) signal has nothing to rescale; skipping the
            # division avoids the divide-by-zero the per-sample version hit.
            if peak > 0:
                samples = samples / peak
        else:
            samples = np.clip(samples, -1, 1)
        # Truncates toward zero, like the int() conversion it replaces.
        pcm = (samples * 32767).astype('<i2')

        fp = BytesIO()
        # Close the writer before reading the buffer so the header is final.
        with wave.open(fp, mode='wb') as waveobj:
            waveobj.setnchannels(1)
            waveobj.setframerate(rate)
            waveobj.setsampwidth(2)
            waveobj.setcomptype('NONE','NONE')
            waveobj.writeframes(pcm.tobytes())
        return fp.getvalue()

    def _data_and_metadata(self):
        """shortcut for returning metadata with url information, if defined"""
        md = {}
        if self.url:
            md['url'] = self.url
        if md:
            return self.data, md
        else:
            return self.data

    def _repr_html_(self):
        """Return the HTML ``<audio>`` element for this object."""
        src = """
                <audio controls="controls" {autoplay}>
                    <source src="{src}" type="{type}" />
                    Your browser does not support the audio element.
                </audio>
              """
        return src.format(src=self.src_attr(),type=self.mimetype, autoplay=self.autoplay_attr())

    def src_attr(self):
        """Return the ``src`` value: a base64 data URI when embedding, else the URL, else ''."""
        import base64
        if self.embed and (self.data is not None):
            return """data:{type};base64,{base64}""".format(type=self.mimetype,
                                                            base64=base64.b64encode(self.data).decode('ascii'))
        elif self.url is not None:
            return self.url
        else:
            return ""

    def autoplay_attr(self):
        """Return the HTML autoplay attribute, or an empty string when autoplay is off."""
        if(self.autoplay):
            return 'autoplay="autoplay"'
        else:
            return ''

Using the audio from the above graph(s)...

Input audio:

In [7]:
# norm=False: samples are clipped to [-1, 1] rather than rescaled, so the true level is heard.
Audio(x, rate=44100, norm=False)

'Real' effect output, i.e. Target audio:

In [8]:
# Played with normalization off, so its level is not rescaled.
Audio(y_true, rate=44100,norm=False)

Network output, i.e., predicted audio:

In [9]:
# Played with normalization off, so its level is not rescaled.
Audio(y_pred, rate=44100, norm=False)