In [1]:
from pynq import Overlay
# Load the base overlay bitstream; `ol` is read as a module-level global by the
# audio processing class defined in a later cell.
ol = Overlay("base.bit")

In [23]:
import ipywidgets as widgets
import numpy as np
import traitlets
import plotly.express as px
import plotly.graph_objs as go
import cv2
import math
from colour import Color
from matplotlib import pyplot as plt
import asyncio

from hearlight.audio import audio_led_driver_setup

from pynq import allocate

class AudioProcessSoftwareCNN(traitlets.HasTraits):
    """Test class for software audio processing with CNN.

    Adds a weights-file upload and a CNN-version dropdown to the supplied panel
    and wires the panel's start button to the (currently empty) ``run`` coroutine.
    """
    weights_trait = traitlets.Dict()
    cnn_version_trait = traitlets.Int()

    # traits which will be linked to main control panel if dashboard is used (after initialisation of the class)
    led_max_current_trait = traitlets.Float()
    channel_max_current_trait = traitlets.Float()
    switch_max_current_trait = traitlets.Float()
    device_max_current_trait = traitlets.Float()
    dac_ref_trait = traitlets.Float()
    irr_to_current_trait = traitlets.Dict()
    current_to_irr_trait = traitlets.Dict()

    def __init__(self, panel):
        """Create the processor settings widgets and attach them to ``panel``.

        Parameters
        ----------
        panel
            Control panel object. This constructor reads ``panel.start_button`` and
            assigns ``panel.processor_settings`` and ``panel.fig_ft``.
        """
        self.panel = panel

        self.panel.start_button.on_click(self._start_button_clicked)

        # file upload for weights selection
        self.weights_file_upload = widgets.FileUpload(accept='.dat', multiple=True)
        traitlets.link((self.weights_file_upload, 'value'), (self, 'weights_trait'))

        # dropdown for selecting nn version
        self.cnn_version_dropdown = widgets.Dropdown(options=[('v1', 1), ('v2', 2), ('v3', 3)],
                                                     layout={'width': '150px'})
        traitlets.link((self.cnn_version_dropdown, 'value'), (self, 'cnn_version_trait'))

        # label -> widget mapping handed to the panel
        self.panel.processor_settings = {'Select weights file: ': self.weights_file_upload,
                                         'Select CNN version: ': self.cnn_version_dropdown}

        # placeholder shown where a spectrogram figure would go
        self.panel.fig_ft = widgets.HTML(value='<i style="color:white;">Put a spectrogram here...</i>')

    def _start_button_clicked(self, button):
        """Start-button callback: schedule ``run`` on the event loop and keep the task handle."""
        self.run_task = asyncio.ensure_future(self.run())

    async def run(self):
        """Do the CNN demo...

        Not implemented yet; the coroutine returns immediately.
        """
        pass
        
class AudioProcessSoftwareFFT(traitlets.HasTraits):
    """Class for software audio processing with FFT.

    Records audio from the overlay's audio codec in fixed-length windows, takes the
    FFT magnitude of each window and writes it to a 10x10 array of LED counts, using
    a frequency map derived from an image of the tonotopic map.

    Relies on the module-level overlay ``ol`` and on attributes of the supplied
    ``panel`` (``fs_trait``, ``start_button``, ``fig`` and the plot pause toggles).
    """
    fs_trait = traitlets.Int()
    time_to_run_trait = traitlets.Int()
    n_leds_trait = traitlets.Int()
    n_fft_bins_trait = traitlets.Int()
    min_frequency_trait = traitlets.Float()
    max_frequency_trait = traitlets.Float()
    n_samples_window_trait = traitlets.Int()

    # traits which will be linked to main control panel if dashboard is used (after initialisation of the class)
    led_max_current_trait = traitlets.Float()
    channel_max_current_trait = traitlets.Float()
    switch_max_current_trait = traitlets.Float()
    device_max_current_trait = traitlets.Float()
    dac_ref_trait = traitlets.Float()
    irr_to_current_trait = traitlets.Dict()
    current_to_irr_trait = traitlets.Dict()

    def __init__(self, panel):
        """Build the settings widgets, plots, frequency mapping, LED driver and audio input.

        Parameters
        ----------
        panel
            Control panel object. This constructor reads ``panel.fs_trait`` and
            ``panel.start_button`` and assigns ``panel.stop_button``,
            ``panel.processor_settings`` and ``panel.fig_ft``.
        """
        self.panel = panel

        # base overlay loaded at module level
        self.ol = ol

        # rebuild the plot axis and the bin-to-grid mapping when the sample rate changes
        self.panel.observe(self._setup_fft_plot, ['fs_trait'])
        self.panel.observe(self.setup_demo, ['fs_trait'])

        self.panel.start_button.on_click(self._start_button_clicked)

        # toggle button for stop button override
        self.panel.stop_button = widgets.ToggleButton(description='STOP', icon='stop')
        self.panel.stop_button.add_class('stop_button')

        # NOTE: each link() copies the widget's initial value into the trait, and the
        # observers are registered afterwards, so none of them fire during construction.

        # text entry for time to run
        self.time_to_run_entry = widgets.IntText(value=10, disabled=False, layout={'width': '100%'})
        traitlets.link((self.time_to_run_entry, 'value'), (self, 'time_to_run_trait'))

        # text entry for number of LEDs
        self.n_leds_entry = widgets.IntText(value=1, disabled=False, layout={'width': '100%'})
        traitlets.link((self.n_leds_entry, 'value'), (self, 'n_leds_trait'))

        # text entry for number of FFT bins
        self.n_fft_bins_entry = widgets.IntText(value=200, disabled=False, layout={'width': '100%'})
        traitlets.link((self.n_fft_bins_entry, 'value'), (self, 'n_fft_bins_trait'))
        self.observe(self._setup_fft_plot, ['n_fft_bins_trait'])
        self.observe(self.setup_demo, ['n_fft_bins_trait'])

        # text entry for min frequency
        self.min_frequency_entry = widgets.FloatText(value=4000, disabled=False, layout={'width': '100%'})
        traitlets.link((self.min_frequency_entry, 'value'), (self, 'min_frequency_trait'))
        self.observe(self.tonotopic_map_to_frequencies, ['min_frequency_trait'])
        self.observe(self.setup_demo, ['min_frequency_trait'])

        # text entry for max frequency
        self.max_frequency_entry = widgets.FloatText(value=32000, disabled=False, layout={'width': '100%'})
        traitlets.link((self.max_frequency_entry, 'value'), (self, 'max_frequency_trait'))
        self.observe(self.tonotopic_map_to_frequencies, ['max_frequency_trait'])
        # the grid mapping depends on the frequency map, so refresh it here as well
        # (registered after the map observer so it sees the regenerated map)
        self.observe(self.setup_demo, ['max_frequency_trait'])

        # number of time samples to perform FFT
        self.n_samples_window_entry = widgets.IntText(value=2048, disabled=False, layout={'width': '100%'})
        traitlets.link((self.n_samples_window_entry, 'value'), (self, 'n_samples_window_trait'))
        self.observe(self.setup_demo, ['n_samples_window_trait'])

        # label -> widget mapping handed to the panel
        self.panel.processor_settings = {'Time to run (s): ': self.time_to_run_entry,
                                         'Number of LEDs: ': self.n_leds_entry,
                                         'Number of FFT bins: ': self.n_fft_bins_entry,
                                         'Minimum frequency (Hz): ': self.min_frequency_entry,
                                         'Maximum frequency (Hz): ': self.max_frequency_entry,
                                         'Number of samples in window: ': self.n_samples_window_entry}

        # PLOTS SETUP
        # frequency domain plot
        # divide by 2 as we will only look at positive frequencies
        n_positive_bins = int(self.n_fft_bins_trait / 2)
        buffer_ft = np.zeros(n_positive_bins)
        ft_mag_line = px.line(x=np.linspace(0, int(self.panel.fs_trait / 2), n_positive_bins),
                              y=buffer_ft)
        self.panel.fig_ft = go.FigureWidget(ft_mag_line)

        self._setup_fft_plot(0)

        # generate frequency map from image of tonotopic map
        self.tonotopic_map_to_frequencies(0)

        # setup audio processing with FFT demo
        self.setup_demo(0)

        # function to update main control panel log - linked automatically if dashboard is used
        self.update_main_control_panel_log = None

        # load LED driver and set up linked array of currents
        self.dac_ref_trait = 6.25
        self.switch_max_current_trait = 130
        self.device_max_current_trait = 2500
        self.observe(self.setup_led_driver, names=['dac_ref_trait', 'switch_max_current_trait', 'device_max_current_trait'])
        self.setup_led_driver([])

        # Configure audio input
        self.pAudio = self.ol.audio_codec_ctrl_0
        self.pAudio.configure()
        self.pAudio.select_line_in()

    def _setup_fft_plot(self, trait_change):
        """Reset the frequency-domain trace and apply the dark plot styling.

        ``trait_change`` is ignored; it is only present so the method can be used as
        a traitlets observer.
        """
        # frequency domain plot
        # divide by 2 as we will only look at positive frequencies
        n_positive_bins = int(self.n_fft_bins_trait / 2)
        trace = self.panel.fig_ft.data[0]
        trace.update({'x': np.linspace(0, int(self.panel.fs_trait / 2), n_positive_bins)})
        trace.update({'y': np.zeros(n_positive_bins)})

        self.panel.fig_ft.update_layout(
            xaxis={'title': 'frequency (Hz)', 'gridcolor': '#444444'},
            yaxis={'title': 'magnitude', 'gridcolor': '#444444'},
            title={'text': 'Fourier Transform Magnitude', 'font': {'color': '#FFFFFF'}},
            paper_bgcolor='#212121',
            plot_bgcolor='#212121',
        )
        self.panel.fig_ft.update_xaxes(color='#FFFFFF')
        self.panel.fig_ft.update_yaxes(color='#FFFFFF')
        self.panel.fig_ft.data[0].line.color = "#635faa"

    def show_original_tonotopic_map(self):
        """Display the tonotopic map image as loaded (converted from BGR to RGB)."""
        plt.imshow(cv2.cvtColor(self.im, cv2.COLOR_BGR2RGB))
        plt.show()

    def show_sampled_tonotopic_map(self):
        """Display the down-sampled tonotopic map (already stored as RGB)."""
        plt.imshow(self.im_sampled)
        plt.show()

    def tonotopic_map_to_frequencies(self, trait_change):
        """Converts an image of the tonotopic map to a frequency map. This should be changed to a better solution.

        The image is sub-sampled, each sampled pixel is matched to the nearest of 20
        colours on a blue-to-red gradient, and that colour's index selects a
        frequency between ``min_frequency_trait`` and ``max_frequency_trait``.
        The result is stored in ``self.frequency_map`` (10x10).

        ``trait_change`` is ignored (observer signature).

        Raises
        ------
        FileNotFoundError
            If the tonotopic map image cannot be read.
        """
        image_path = 'images/tonotopic_map_image.png'
        self.im = cv2.imread(image_path)
        if self.im is None:
            # cv2.imread signals failure by returning None rather than raising
            raise FileNotFoundError(f"Could not read tonotopic map image '{image_path}'")

        # sub-sample the image on a coarse grid (row sampling starts at pixel row 10)
        self.im_sampled = self.im[10::math.ceil(np.shape(self.im)[0]/10), 0::math.ceil(np.shape(self.im)[1]/10), :]
        self.im_sampled = cv2.cvtColor(self.im_sampled, cv2.COLOR_BGR2RGB)

        # colour gradient running from blue (lowest frequency) to red (highest frequency)
        n_colours = 20
        colours = list(Color("blue").range_to(Color("red"), n_colours))

        frequencies = np.linspace(start=self.min_frequency_trait, stop=self.max_frequency_trait, num=n_colours)

        colours_list_rgb = [colour.rgb for colour in colours]
        colours_image_rgb = self.im_sampled/255

        # get frequency bin for each location on sampled tonotopic map
        self.frequency_map = np.zeros((10, 10))

        for r in range(10):
            for c in range(10):
                # use euclidean distance to get frequency bins for tonotopic map image
                distances = [np.linalg.norm(colours_list_rgb[colour] - colours_image_rgb[r, c, :])
                             for colour in range(n_colours)]
                self.frequency_map[r, c] = frequencies[np.argmin(np.array(distances))]

    def setup_led_driver(self, trait_change):
        """Allocate the LED counts buffer and (re)program the LED driver.

        Runs once from ``__init__`` and again whenever ``dac_ref_trait``,
        ``switch_max_current_trait`` or ``device_max_current_trait`` changes.
        ``trait_change`` is ignored (observer signature).

        Raises
        ------
        ValueError
            If ``dac_ref_trait`` is not one of the supported reference values.
        """
        # updating this array will set the LEDs to the corresponding current counts value
        self.led_counts = allocate(shape=(10, 10), dtype=np.uint16)
        self.led_counts.fill(0)

        # driver setting index -> DAC reference value
        dac_refs = {0: 3.125,
                    1: 6.25,
                    2: 12.5,
                    3: 25,
                    4: 50,
                    5: 100,
                    6: 200,
                    7: 300}
        # position of the selected reference value (equals its key, as keys are 0..7 in order)
        dac_ref = list(dac_refs.values()).index(self.dac_ref_trait)
        switch_max_current = int(self.switch_max_current_trait)
        device_max_current = int(self.device_max_current_trait)

        led_driver_program = audio_led_driver_setup(self.ol)
        # limits come from the traits (dac_ref_trait defaults to 6.25 in __init__)
        led_driver_program.leds_configure(dac_ref, switch_max_current, device_max_current)
        led_driver_program.leds_start(self.led_counts)

        if self.update_main_control_panel_log is not None:
            self.update_main_control_panel_log('Microblaze LED driver programmed\n')

    def setup_demo(self, trait_change):
        """Use frequency map generated from image of tonotopic map to create mapping of FFT bins to grid.

        Sets ``ft_bins``, ``idx_min``, ``idx_max``, ``ft_bin_indices_to_grid``,
        ``sample_time`` and ``threshold``. ``trait_change`` is ignored (observer signature).
        """
        # GENERATE MAPPING OF FREQUENCY TO GRID
        self.ft_bins = np.linspace(0, self.panel.fs_trait/2, int(self.n_fft_bins_trait/2))

        # locations of the bins closest to the minimum and maximum frequency
        self.idx_min = np.argmin(np.abs(self.ft_bins - self.min_frequency_trait))
        self.idx_max = np.argmin(np.abs(self.ft_bins - self.max_frequency_trait))
        ft_bins_in_range = np.linspace(self.ft_bins[self.idx_min], self.ft_bins[self.idx_max], 100)

        # for every grid cell, index (0..99) of the in-range bin closest to its mapped frequency
        self.ft_bin_indices_to_grid = [[np.argmin(np.abs(ft_bins_in_range - self.frequency_map[r, c]))
                                        for c in range(10)]
                                       for r in range(10)]

        # duration of one recording window in seconds
        self.sample_time = self.n_samples_window_trait / self.panel.fs_trait

        # LED count values below this are forced to zero in run()
        self.threshold = 10000

    def _start_button_clicked(self, button):
        """Start-button callback: schedule ``run`` on the event loop and keep the task handle."""
        self.run_task = asyncio.ensure_future(self.run())

    async def run(self):
        """Run demo for specified time.

        Each iteration records one window, updates the time and frequency plots
        (unless paused), and writes the thresholded FFT magnitudes to ``led_counts``.
        The loop ends after ``time_to_run_trait`` seconds' worth of windows or when
        the stop button is toggled. The LEDs are switched off on exit, including
        when an exception or cancellation interrupts the loop.
        """
        n_windows = int(self.time_to_run_trait / self.sample_time)
        grid_indices = np.array(self.ft_bin_indices_to_grid)  # loop-invariant

        try:
            for i in range(n_windows):
                # time domain plot
                self.pAudio.record(seconds=self.sample_time)
                # change to signed int and look at one channel (mono)
                buffer_signed = ((self.pAudio.buffer << 8).view(np.int32) >> 8)[0::2]

                # UPDATE TIME PLOT
                if not self.panel.time_plot_pause_toggle.value:
                    previous_y = self.panel.fig.data[0]['y'][self.n_samples_window_trait:]
                    self.panel.fig.data[0].update({'y': np.concatenate([previous_y, buffer_signed])})

                # compute fft
                # NOTE(review): this keeps the first n_fft_bins/2 outputs of a full-length FFT,
                # which only spans 0..fs/2 if the window length equals n_fft_bins - confirm intent.
                ft = abs(np.fft.fft(buffer_signed))[0:int(self.n_fft_bins_trait/2)]

                # UPDATE FREQUENCY PLOT
                if not self.panel.freq_plot_pause_toggle.value:
                    self.panel.fig_ft.data[0].update({'y': ft})

                # peak magnitude of each of 100 groups; empty groups (fewer than 100 bins) count as 0
                ft_bin_max_values = np.array([arr.max() if arr.size else 0.0
                                              for arr in np.array_split(ft, 100)])
                # clip to [0, 1] so that scaling to uint16 saturates instead of overflowing
                ft_bin_max_values_norm = np.clip(ft_bin_max_values / 1000000, 0.0, 1.0)

                # translate values to grid
                ft_bin_values_on_grid = (ft_bin_max_values_norm[grid_indices] * 65535).astype(np.uint16)
                ft_bin_values_on_grid[ft_bin_values_on_grid < self.threshold] = 0

                # LED CONTROL ##############################
                self.led_counts[:, :] = ft_bin_values_on_grid[:, :]

                if self.panel.stop_button.value:
                    self.panel.stop_button.value = False
                    break

                # yield to the event loop every 10 windows so widget events can be processed
                if i % 10 == 0:
                    await asyncio.sleep(0.001)
        finally:
            # switch off all LEDs
            self.led_counts.fill(0)
        

In [24]:
from hearlight import Dashboard

In [25]:
%%capture
# LED-control and array-currents panels are switched off; the audio control panel
# is given the FFT processor class defined above.
dashboard = Dashboard(
    led_control_panel=False,
    array_currents_panel=False,
    audio_control_panel=AudioProcessSoftwareFFT,
)

In [26]:
# display the dashboard's banner widget as the cell output
dashboard.banner_panel

Box(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x0e\xe2\x00\x00\x01<\x08\x06\x00\x00\x…

In [27]:
def _tonotopic_map_section(heading_text, image_path):
    """Build one heading + boxed image pair for the tonotopic map panel.

    Parameters
    ----------
    heading_text : str
        HTML shown as the section heading.
    image_path : str
        Path of the image file whose bytes are shown in the image widget.

    Returns
    -------
    tuple
        ``(heading, image, box)`` widgets, with ``image`` already placed inside ``box``.
    """
    heading = widgets.HTML(value=heading_text)
    heading.add_class('section_heading')

    # read the bytes inside a context manager so the file handle is closed promptly
    with open(image_path, 'rb') as image_file:
        image = widgets.Image(value=image_file.read())
    image.add_class('t_map_image_css')

    box = widgets.Box(children=(image,))
    box.add_class('t_map_image_box_css')

    return heading, image, box


# create panel for images of tonotopic map
tonotopic_map_images_panel = widgets.VBox()
tonotopic_map_images_panel.add_class('t_map_panel_css')

(tonotopic_map_image_original_heading,
 tonotopic_map_image_original_image,
 tonotopic_map_image_original_box) = _tonotopic_map_section(
    '<b>&nbsp;Original tonotopic map</b>',
    'images/tonotopic_map_image_orig.png')

(tonotopic_map_image_sampled_heading,
 tonotopic_map_image_sampled_image,
 tonotopic_map_image_sampled_box) = _tonotopic_map_section(
    '<b>&nbsp;Sampled tonotopic map</b>',
    'images/tonotopic_map_image_sampled.png')

# stack the two sections: heading, image box, heading, image box
tonotopic_map_images_panel.children = (
    tonotopic_map_image_original_heading,
    tonotopic_map_image_original_box,
    tonotopic_map_image_sampled_heading,
    tonotopic_map_image_sampled_box,
)

#tonotopic_map_images_panel

In [28]:
from ipywidgets import GridspecLayout

# 12 x 12 grid that hosts every dashboard panel
dashboard_box = GridspecLayout(12, 12, width='1465px', height='850px', grid_gap="10px")

# (row span, column span, widget) for each region of the grid
panel_slots = (
    (slice(0, 5), slice(0, 2), dashboard.info_panel),
    (slice(0, 5), slice(2, 12), dashboard.main_control_panel),
    (slice(5, 12), slice(0, 10), dashboard.audio_control_panel),
    (slice(5, 12), slice(10, 12), tonotopic_map_images_panel),
)
for row_span, col_span, slot_widget in panel_slots:
    dashboard_box[row_span, col_span] = slot_widget

dashboard_box

GridspecLayout(children=(VBox(children=(HTML(value='<h1 style="font-size:1.7em"><b>HearLight PYNQ control syst…

# DEMO: LED control through audio input using FFT
--- 
*May take a few minutes to load...*

#### Instructions
- Connect audio input to *LINE-IN* on PYNQ-Z2 board
- Start the program by pressing ***START*** on the audio control panel
- The audio signal is split into windows, and each window is processed with an FFT in real time to control the LEDs according to the tonotopic map of the mouse auditory cortex.
