In [9]:
import time 
import functools
import numpy as np

from pynq import (Overlay,
                  allocate)

# some handy functions to use along widgets
from IPython.display import display, Markdown, clear_output
# widget packages
import ipywidgets as widgets

%matplotlib inline
import matplotlib.pyplot as plt

In [21]:
class GUI:
    """Widget front-end for running PYNQ DNN bitstreams and plotting their error.

    One button per step: load the selected bitstream, run it over the test
    vectors in ``./../data`` for every SNR point and channel, plot its
    normalized error next to the stored LS/MMSE reference curves, and reset.
    """

    # Length of one DNN output vector. The DNN input buffer holds two such
    # halves: one row of ``yin`` followed by ``xin``.
    VEC_LEN = 104

    def __init__(self):
        # Number of SNR points and of channels per SNR point in the data files.
        self.N_SNR = 7
        self.N_CH = 200

        # EDIT THIS: Write the number of DNN IPs you want to run and plot
        # MAX VALUE = 7
        self.N_DNN = 1

        # EDIT THIS: Write the path of the DNN IPs you want to run and plot
        self.ip_paths_dnn = ["./bitstreams/sample_ip.bit",
                             "",
                             "",
                             "",
                             "",
                             "",
                             ""]

        # EDIT THIS: Write the name you want to display next to each IP
        # This name does not need to be the same as the bitstream name
        self.ip_names = ["dnn_sample",
                         "",
                         "",
                         "",
                         "",
                         "",
                         ""]

        # [index, name, Overlay, [dma, neuralnet]] once an IP is loaded, else [].
        self.overlay = []
        # One result slot per configurable IP: [] or [name, denorm_output, err].
        self.outputs = [[] for _ in self.ip_paths_dnn]

    def _delete_all(self):
        """Forget the loaded IP and clear every stored result slot."""
        self.overlay = []
        # One slot per configurable IP (same as __init__): results are stored
        # at self.outputs[ip_index], so fewer slots would raise IndexError.
        self.outputs = [[] for _ in self.ip_paths_dnn]

    @property
    def ip_outputs(self):
        """Per-IP result slots (see ``self.outputs``)."""
        return self.outputs

    @ip_outputs.setter
    def ip_outputs(self, arg_outputs):
        self.outputs = arg_outputs

    @staticmethod
    def _read_numeric_rows(filename, sep=None, ignore_chars=""):
        """Parse ``filename`` into one float32 array per non-blank line.

        Args:
            filename: path of a text file.
            sep: field separator; ``None`` splits on runs of whitespace.
            ignore_chars: characters deleted from each line before splitting
                (e.g. ``"{}"`` for brace-wrapped rows).

        Returns:
            List of 1-D ``np.float32`` arrays in file order. Blank lines and
            empty fields (from trailing separators) are skipped, so the last
            line parses correctly whether or not it ends with a newline.
        """
        rows = []
        with open(filename, 'r') as reader:
            for line in reader:
                for ch in ignore_chars:
                    line = line.replace(ch, "")
                fields = [field.strip() for field in line.split(sep)]
                fields = [field for field in fields if field]
                if fields:
                    rows.append(np.asarray(fields, dtype=np.float32))
        return rows

    def _load_xin_file(self, filename):
        """Load the whitespace-separated ``xin`` matrix and return its first column.

        The file holds up to ``VEC_LEN`` lines of ``N_CH`` values each; lines
        not present stay zero.

        Returns:
            float64 array of shape ``(VEC_LEN,)``: the first value of each line.
        """
        data = np.zeros(shape=(self.VEC_LEN, self.N_CH))
        for row_idx, row in enumerate(self._read_numeric_rows(filename)):
            data[row_idx] = row
        return data[:, 0]

    def _load_yin_file(self, filename):
        """Load the brace-wrapped, comma-separated ``yin`` matrix.

        Each data line looks like ``{v0,v1,...},``; braces, trailing commas
        and newlines are ignored. Lines with fewer than two values (e.g. a
        lone brace) are skipped, and data lines fill rows in order.

        Returns:
            float64 array of shape ``(N_CH, VEC_LEN)``; rows not present in
            the file stay zero.
        """
        data = np.zeros(shape=(self.N_CH, self.VEC_LEN))
        rows = self._read_numeric_rows(filename, sep=",", ignore_chars="{}")
        rows = [row for row in rows if row.size > 1]
        for row_idx, row in enumerate(rows):
            data[row_idx] = row
        return data

    def _load_actual_file(self, filename):
        """Load a whitespace-separated reference matrix, transposed.

        The file holds up to ``VEC_LEN`` lines of ``N_CH`` values each.

        Returns:
            float64 array of shape ``(N_CH, VEC_LEN)``.
        """
        data = np.zeros(shape=(self.VEC_LEN, self.N_CH))
        for row_idx, row in enumerate(self._read_numeric_rows(filename)):
            data[row_idx] = row
        return data.transpose()

    def _load_err_file(self, filename):
        """Load a reference error curve: one value per line, one line per SNR point.

        Returns:
            float64 array of shape ``(N_SNR,)``.

        Raises:
            ValueError: if a line holds more than one value.
        """
        data = np.zeros(shape=(self.N_SNR,))
        for row_idx, row in enumerate(self._read_numeric_rows(filename)):
            data[row_idx] = row.item()
        return data

    def _load_dnn_file(self, filename):
        """Load a comma-separated normalization vector (mean or std).

        Returns:
            The last non-blank line as a float32 array, or float64 zeros of
            shape ``(VEC_LEN,)`` if the file has no data lines.
        """
        rows = self._read_numeric_rows(filename, sep=",")
        if not rows:
            return np.zeros(shape=(self.VEC_LEN,))
        return rows[-1]

    def _load_dnn_ip(self, index):
        """Program the bitstream at ``ip_paths_dnn[index]`` and cache its handles.

        Sets ``self.overlay = [index, name, Overlay, [axi_dma_0, neuralNetworkHwR_0]]``.
        """
        _overlay = Overlay(self.ip_paths_dnn[index])
        dma = _overlay.axi_dma_0
        neuralnet = _overlay.neuralNetworkHwR_0
        self.overlay = [index, self.ip_names[index], _overlay, [dma, neuralnet]]

    def _calculate_dnn_ip_output(self):
        """Run every (SNR point, channel) input through the loaded IP.

        Requires ``_load_dnn_ip`` to have been called. Stores
        ``[name, denormalized_output, normalized_error]`` in
        ``self.outputs[ip_index]``.
        """
        ip_index, ip_name = self.overlay[0], self.overlay[1]

        in_buffer = allocate(shape=(2 * self.VEC_LEN,), dtype=np.float32)
        out_buffer = allocate(shape=(self.VEC_LEN,), dtype=np.float32)
        try:
            IP_output = np.zeros([self.N_SNR, self.N_CH, self.VEC_LEN])
            # The xin file is the same for every SNR point: read it once.
            xin = self._load_xin_file("./../data/XinC_1.txt")

            for n_snr in range(self.N_SNR):
                if n_snr == 0:
                    print(f"\nSignal: {n_snr+1}")
                else:
                    print(f"\n\nSignal: {n_snr+1}")

                yin = self._load_yin_file(f"./../data/Yinc_m{n_snr+1}m.dat")

                for n_ch in range(self.N_CH):
                    print(f"  Channel: {n_ch+1}\r", end="")
                    # Blocks until the DMA has written out_buffer.
                    self._run_dnn_ip(yin[n_ch], xin, in_buffer, out_buffer)
                    IP_output[n_snr, n_ch, :] = out_buffer
        finally:
            # Release the physically contiguous buffers even if a run fails.
            in_buffer.freebuffer()
            out_buffer.freebuffer()

        denorm_IP_output = self._denomarlize(IP_output)
        Err_DNN_normalized_hw = self._caculate_error(denorm_IP_output)

        self.outputs[ip_index] = [ip_name, denorm_IP_output, Err_DNN_normalized_hw]

    def _denomarlize(self, IP_output):
        """Undo output normalization: ``IP_output * std + mean``.

        ``std``/``mean`` are per-element vectors read from ``./../data``; they
        are broadcast over the (SNR point, channel) axes.

        Returns:
            float64 array of shape ``(N_SNR, N_CH, VEC_LEN)``.
        """
        dnn_mean = self._load_dnn_file("./../data/dnn_mean_o7.dat")
        dnn_std = self._load_dnn_file("./../data/dnn_std_o7.dat")

        scaled = np.asarray(IP_output)[:self.N_SNR, :self.N_CH] * dnn_std + dnn_mean
        return scaled.astype(np.float64)

    def _caculate_error(self, IP_output):
        """Normalized error per SNR point: ``||act - out||_2^2 / ||act||_2^2``.

        ``np.linalg.norm(M, 2)`` on a 2-D array is the spectral norm (largest
        singular value).

        Returns:
            float64 array of shape ``(N_SNR,)``.
        """
        Err_DNN_hw = np.zeros([self.N_SNR,])
        Phf_hw = np.zeros([self.N_SNR,])

        for n_snr in range(self.N_SNR):
            act = self._load_actual_file(f"./../data/actual{n_snr+1}.txt")
            Err_DNN_hw[n_snr] = np.linalg.norm(act - IP_output[n_snr], 2)**2
            Phf_hw[n_snr] = np.linalg.norm(act, 2)**2

        # Average over channels (the factor cancels in the ratio).
        Err_DNN_hw /= self.N_CH
        Phf_hw /= self.N_CH
        return Err_DNN_hw / Phf_hw

    def _run_dnn_ip(self, yin, xin, in_buffer, out_buffer):
        """Run one inference: pack ``[yin | xin]`` into ``in_buffer``, start the
        IP, stream through the DMA and wait until ``out_buffer`` is filled.

        Uses ``self.overlay = [index, name, _overlay, [dma, neuralnet]]``.
        """
        dma, neuralnet = self.overlay[3]
        half = self.VEC_LEN
        in_buffer[:half] = yin[:half]
        in_buffer[half:2 * half] = xin[:half]

        # Register 0x00 <- 0x1 starts the IP.
        neuralnet.write(0x00, 0x1)
        dma.sendchannel.transfer(in_buffer)
        dma.recvchannel.transfer(out_buffer)
        # transfer() only starts the DMA; the caller reads out_buffer next and
        # the following call reuses both channels, so wait for completion.
        dma.sendchannel.wait()
        dma.recvchannel.wait()

    def _plot_err_curves(self):
        """Plot the normalized error of every IP that has results.

        ``self.outputs`` holds one slot per IP; a filled slot is
        ``[ip_name, IP_output, err_normalized]``. The four LS/MMSE reference
        curves are drawn once (only when at least one IP has results), then
        one curve per IP, against SNR points 0, 5, 10, ...
        """
        results = [out for out in self.outputs if len(out) != 0]
        snr_axis = np.arange(0, self.N_SNR * 5, 5)

        if results:
            err_ls_ip_sim = self._load_err_file("./../data/err_ls_ip_sim.txt")
            err_ls_ip_th = self._load_err_file("./../data/err_ls_ip_th.txt")
            err_mmse_ip_sim = self._load_err_file("./../data/err_mmse_ip_sim.txt")
            err_mmse_ip_th = self._load_err_file("./../data/err_mmse_ip_th.txt")

            plt.semilogy(snr_axis, err_ls_ip_sim, label="err_ls_ip_sim", marker="1", markersize=20)
            plt.semilogy(snr_axis, err_ls_ip_th, label="err_ls_ip_th")

            plt.semilogy(snr_axis, err_mmse_ip_sim, label="err_mmse_ip_sim", marker="1", markersize=20)
            plt.semilogy(snr_axis, err_mmse_ip_th, label="err_mmse_ip_th")

            for out in results:
                plt.semilogy(snr_axis, out[2], label=out[0])

            plt.legend(loc="upper right")

        plt.xlabel("Preamble SNR")
        plt.ylabel("Average Error per subcarrier")
        plt.show()

    def _create_widgets(self):
        """Build the IP dropdown, the four buttons and the shared output area."""
        self.box_layout = widgets.Layout(display='flex',
                flex_flow='column',
                align_items='center',
                width='100%')

        self.ip_selector = widgets.Dropdown(
                        options=self.ip_names[:self.N_DNN],
                        value=self.ip_names[0],
                        description='Load IP:')

        self.btn_load_IP = widgets.Button(description='Load the IP')
        self.btn_run_selected_ip = widgets.Button(description='Run IP')
        self.btn_plot = widgets.Button(description='Plot')
        self.btn_clear_outputs = widgets.Button(description='Reset')
        self.out = widgets.Output()

        self.btn_load_IP.on_click(self._btn_load_IP_on_click)
        self.btn_run_selected_ip.on_click(self._btn_run_ip_on_click)
        self.btn_plot.on_click(self._btn_plot_on_click)
        self.btn_clear_outputs.on_click(self._btn_clear_outputs_on_click)

    def _btn_load_IP_on_click(self, change):
        """Load the IP selected in the dropdown (``change`` is the clicked button)."""
        with self.out:
            clear_output()
            # Dropdown options are ip_names[:N_DNN], so the widget's position
            # is the IP index; unlike list.index() it copes with duplicate names.
            index = self.ip_selector.index
            self._load_dnn_ip(index)
            print(f"{self.ip_selector.value} IP loaded!")

    def _btn_run_ip_on_click(self, change):
        """Run the loaded IP over all test vectors."""
        with self.out:
            clear_output()
            if not self.overlay:
                print("No IP loaded. Select an IP and press 'Load the IP' first.")
                return
            print(f"Running {self.overlay[1]} IP...")
            self._calculate_dnn_ip_output()
            print("\n\nIP run finished!")

    def _btn_plot_on_click(self, change):
        """Plot the error curves of all IPs run so far."""
        with self.out:
            clear_output()
            self._plot_err_curves()

    def _btn_clear_outputs_on_click(self, change):
        """Forget the loaded IP and all stored results."""
        with self.out:
            clear_output()
            self._delete_all()
            print("All outputs cleared!")

    def display_widgets(self):
        """Create the widgets and render them in the notebook."""
        self._create_widgets()
        display(
            widgets.HBox(
                [
                    self.ip_selector,
                    self.btn_load_IP,
                    self.btn_run_selected_ip,
                    self.btn_plot,
                    self.btn_clear_outputs,
                    self.out
                ],
                layout=self.box_layout)
        )

In [22]:
# Build the GUI (its constructor only sets configuration attributes) and
# render the widgets. Later cells read `x`, so keep this name.
x = GUI()
x.display_widgets()

HBox(children=(Dropdown(description='Load IP:', options=('dnn_sample',), value='dnn_sample'), Button(descripti…

In [None]:
# Inspect stored results: one slot per IP, each [] or
# [ip_name, denormalized output, normalized error].
x.outputs