In [1]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
from ipywidgets import interact, fixed, IntSlider, HBox, Layout, Output, VBox
import ipywidgets as widgets

%matplotlib widget

# Block-diagram helpers: arrows and adder

In [2]:
# Close any figures left over from earlier runs, then open a single test canvas.
plt.close('all')
fig, axs = plt.subplots(nrows=1, figsize=(5, 3))

def draw_arrow(ax=None, pos=(0,0), length=1, width=1, linewidth=8, color='black', orientation='down'):
    """Draw a line-art arrow whose tip sits at ``pos``.

    The arrow is three ``ax.plot`` segments: a shaft that stops 0.2*length
    short of the tip, followed by two head strokes that meet at the tip.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Target axes; ``plt.gca()`` when omitted.
    pos : (x, y)
        Coordinates of the arrow tip.
    length : float
        Scale along the shaft (shaft spans 10*length, head strokes 2*length).
    width : float
        Scale across the shaft (head strokes spread +/- 2.7*width).
    linewidth, color
        Forwarded to every ``ax.plot`` call.
    orientation : {'down', 'up', 'right', 'left'}
        Direction the tip points.

    Raises
    ------
    AssertionError
        If ``orientation`` is not one of the four accepted values.
    """
    axis = plt.gca() if ax is None else ax
    tip = np.array(pos)
    x0, y0 = tip[0], tip[1]
    style = {'color': color, 'linewidth': linewidth}

    # +1: the shaft lies on the positive side of the tip; -1: the negative side.
    shaft_side = {'down': 1, 'up': -1, 'left': 1, 'right': -1}
    if orientation not in shaft_side:
        raise AssertionError(f"'{orientation}' is not a valid orientation.")
    s = shaft_side[orientation]

    spreads = (-2.7*width, 2.7*width)
    if orientation in ('down', 'up'):
        # Vertical arrow: shaft along y, head strokes splay out in x.
        axis.plot((x0, x0), y0 + np.array([s*10*length, s*0.2*length]), **style)
        for spread in spreads:
            axis.plot(x0 + np.array([0, spread]), y0 + np.array([0, s*2*length]), **style)
    else:
        # Horizontal arrow: shaft along x, head strokes splay out in y.
        axis.plot(x0 + np.array([s*10*length, s*0.2*length]), (y0, y0), **style)
        for spread in spreads:
            axis.plot(x0 + np.array([0, s*2*length]), y0 + np.array([0, spread]), **style)

def draw_adder(ax=None, pos=(0,0), height=1, width=1, scale=1, linewidth=5, color='black', n=100):
    """Draw a summing-junction symbol: an ellipse with a '+' inside it.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Target axes; ``plt.gca()`` when omitted.
    pos : (x, y)
        Centre of the symbol.
    height, width : float
        Vertical / horizontal semi-axes before scaling.
    scale : float
        Multiplies both semi-axes.
    linewidth, color
        Forwarded to every ``ax.plot`` call.
    n : int
        Number of points used to trace the ellipse.
    """
    target = plt.gca() if ax is None else ax
    rx = width * scale
    ry = height * scale
    centre = np.array(pos)
    cx, cy = centre[0], centre[1]
    kw = {'linewidth': linewidth, 'color': color}

    theta = np.linspace(0, 2*np.pi, n)
    target.plot(cx + rx*np.cos(theta), cy + ry*np.sin(theta), **kw)
    # The '+' arms stop 10 % short of the rim.
    target.plot(cx + np.array([-rx + 0.1*rx, rx - 0.1*rx]), (cy, cy), **kw)
    target.plot((cx, cx), cy + np.array([-ry + 0.1*ry, ry - 0.1*ry]), **kw)

# Smoke test: two inputs (blue from above, red from below) feed an adder,
# plus an arrow pointing right on the output side.
draw_arrow(axs, pos=(-2, 5), orientation='down', color='blue',
           width=0.5, linewidth=5)
draw_arrow(axs, pos=(-2, -5), orientation='up', color='red',
           width=0.5, linewidth=5)
draw_arrow(axs, pos=(13, 0), orientation='right', color='black',
           width=1.2, length=1.15, linewidth=5)
draw_adder(axs, pos=(-2, 0), scale=3, width=3/5, linewidth=3)
axs.set_xlim([-15, 15])
axs.set_ylim([-15, 15])
axs.axis('off')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Deconvolution demo: inverse filter followed by a Wiener filter

In [87]:
class Deconvolution_demo():
    """Interactive figure illustrating deconvolution of a blurred, noisy signal.

    Signal chain, all sampled on ``self.t`` (1001 points on [-5, 5]):

        x --(* h)--> x1 --(+ n)--> y --(* h_inv)--> y1 --(* h_WH)--> x~

    ``x1`` is stored as ``self.xh`` and ``x~`` as ``self.xt``. Every
    convolution is circular because it is computed with the FFT. The figure is
    drawn into an ipywidgets ``Output`` (``self.out``), which is displayed as a
    side effect of construction.
    """
    def __init__(self, var=1):
        """Compute every signal of the chain and draw the block diagram.

        Parameters
        ----------
        var : float
            Variance of the additive white Gaussian noise ``n``.
        """
        self.out = Output(layout={'width': '1000px', 'height': '700px'})
        self.axs = []

        self.var = var
        self.t = np.linspace(-5, 5, 1001)
        # Each *_func reads the attributes computed on the lines above it,
        # so this order must be kept.
        self.x = self.x_func()
        self.h = self.h_func()
        self.xh = self.xh_func()
        self.n = self.n_func()
        self.y = self.y_func()
        self.y1 = self.y1_func()
        self.xt = self.xt_func()

        # Initialize figure
        self.init_figure()

        # Display
        display(self.out)
        self.fig.tight_layout(pad=0.1, w_pad=1.0, h_pad=0.1)

    def x_func(self, t=None):
        """Input signal: x(t) = exp(-t) for t >= 0 and 0 otherwise."""
        t = self.t if t is None else t

        return np.where(t >= 0, np.exp(-t), 0)

    def h_func(self, t=None):
        """Blur kernel: Gaussian exp(-t**2 / 2) centred at t = 0, scaled by 1/(2*pi).

        NOTE(review): the unit-variance Gaussian pdf uses 1/sqrt(2*pi); the
        1/(2*pi) factor only rescales h (and therefore the signal-to-noise
        ratio of y) -- confirm which was intended.
        """
        t = self.t if t is None else t

        return 1/(2*np.pi)*np.exp(-t**2/2)

    def xh_func(self, x=None, h=None):
        """Blurred signal x1 = x * h, computed as a circular convolution via the FFT.

        ``h`` is centred on the middle sample of ``t``, so the raw circular
        convolution is delayed by len(t)//2 samples; ``ifftshift`` rolls it
        back so that x1 lines up with ``t`` (for the odd-length grid used
        here). The sum is not multiplied by the sample spacing, so x1 is
        1/dt times the continuous-time convolution.
        """
        x = self.x if x is None else x
        h = self.h if h is None else h
        return np.real(np.fft.ifftshift(np.fft.ifft(np.fft.fft(x) * np.fft.fft(h))))

    def n_func(self, var=None, t=None):
        """White Gaussian noise of variance ``var``, one sample per point of ``t``.

        Drawn from NumPy's global (legacy) RNG; seed it with ``np.random.seed``
        before construction for reproducible output.
        """
        t = self.t if t is None else t
        var = self.var if var is None else var

        return np.random.normal(loc=0.0, scale=np.sqrt(var), size=len(t))

    def y_func(self, xh=None, n=None):
        """Observed signal y = x1 + n.

        Parameters
        ----------
        xh, n : ndarray, optional
            Blurred signal and noise; default to ``self.xh`` and ``self.n``.
        """
        xh = self.xh if xh is None else xh
        n = self.n if n is None else n

        # Use the argument, not self.xh, so callers can pass their own signal.
        return xh + n

    def y1_func(self, y=None, h=None):
        """Naive inverse filter: y1 = IFFT(FFT(y) / FFT(h)).

        Frequencies where |FFT(h)| is small divide the noise by a small number,
        which is what the following Wiener stage has to undo. ``fftshift``
        reverses the ``ifftshift`` alignment applied in ``xh_func``.
        """
        y = self.y if y is None else y
        h = self.h if h is None else h

        H = np.fft.fft(h)
        return np.real(np.fft.fftshift(np.fft.ifft(np.fft.fft(y) / H)))

    def xt_func(self, y1=None, x=None, h=None, var=None):
        """Wiener-filter the inverse-filtered signal ``y1`` to obtain x~.

        After inverse filtering, y1 = x + n1 with FFT(n1) = FFT(n) / H, so the
        noise power spectrum at this stage is Sn1 = len(y1) * var / |H|**2
        (numpy's FFT is unnormalised, hence E|FFT(n)|**2 = len(n) * var for
        white noise). The gain Sx / (Sx + Sn1) uses the true spectrum
        Sx = |FFT(x)|**2, i.e. this is an oracle Wiener filter.

        Parameters
        ----------
        y1, x, h : ndarray, optional
            Default to ``self.y1``, ``self.x`` and ``self.h``.
        var : float, optional
            Variance of the noise ``n`` added to x1; defaults to ``self.var``.

        Returns
        -------
        ndarray
            Real-valued estimate x~, aligned with ``self.t``.
        """
        x = self.x if x is None else x
        h = self.h if h is None else h
        y1 = self.y1 if y1 is None else y1
        var = self.var if var is None else var

        Sx = np.abs(np.fft.fft(x))**2
        # The inverse filter maps n to n / H, so its PSD is divided -- not
        # multiplied -- by |H|^2. A zero in H gives Sn1 = inf and hence zero gain.
        with np.errstate(divide='ignore'):
            Sn1 = len(y1) * var / np.abs(np.fft.fft(h))**2
        Hw = Sx / (Sx + Sn1)
        # Hw is real and non-negative (zero phase) and y1 is already aligned
        # with t, so no fftshift here: one would displace x~ by half the window.
        return np.real(np.fft.ifft(np.fft.fft(y1) * Hw))

    def init_figure(self):
        """Build the 12x3 GridSpec figure inside ``self.out``.

        Axes 0-3 hold the connecting arrows / adder of the block diagram and
        axes 4-9 the signal plots; their indices are stored in the ``*_ind``
        attributes.
        """
        with self.out:
            self.fig = plt.figure(figsize=(8.5, 6))
            self.gs = self.fig.add_gridspec(12, 3)

        ##### DRAW CONNECTING ARROWS #####
        # Prefilter stage
        self.axs.append(self.fig.add_subplot(self.gs[3, 0]))
        self.draw_arrow(self.axs[0], width=0.5, linewidth=3)
        self.axs[0].set_xlim([-15, 15])
        self.axs[0].set_ylim([-1, 10])
        self.axs[0].axis('off')
        self.prefilt_txt = self.axs[0].text(2, 3, r'$\ast h(t)$')

        # Adder stage
        self.axs.append(self.fig.add_subplot(self.gs[7:9, 0]))
        # Down
        self.draw_arrow(self.axs[1], pos=(0,6), width=0.5, color='black', orientation='down', linewidth=3)
        # Up
        self.draw_arrow(self.axs[1], pos=(0,-6), width=0.5, color='black', orientation='up', linewidth=3)
        # Right
        self.draw_arrow(self.axs[1], pos=(13,0), width=1.8, length=1, color='black', orientation='right', linewidth=3)
        # Adder
        self.draw_adder(self.axs[1], pos=(0,0), scale=3.5, width=0.4, linewidth=2)
        self.axs[1].set_xlim([-15, 15])
        self.axs[1].set_ylim([-15, 15])
        self.axs[1].axis('off')

        # Wiener filter stage
        self.axs.append(self.fig.add_subplot(self.gs[3, 2]))
        self.draw_arrow(self.axs[2], width=0.5, orientation='up', linewidth=3)
        self.axs[2].set_xlim([-15, 15])
        self.axs[2].set_ylim([-10, 1])
        self.axs[2].axis('off')
        self.wien_filt_txt = self.axs[2].text(1, -8, r'$\ast h_{WH}(t)$')

        # Inverse filter stage
        self.axs.append(self.fig.add_subplot(self.gs[7:9, 2]))
        self.axs[3].plot([-15, 0], [0, 0], color='black', linewidth=3)
        self.draw_arrow(self.axs[3], pos=(0, 10), width=0.5, orientation='up', linewidth=3)
        self.axs[3].set_xlim([-15, 15])
        self.axs[3].set_ylim([-11, 11])
        self.axs[3].axis('off')
        self.inv_filt_txt = self.axs[3].text(-8, -5, r'$\ast h_{inv}(t)$')
        ##### END OF CONNECTING ARROWS #####

        # Define axes indices for plots
        self.x_ind = 4
        self.axs.append(self.fig.add_subplot(self.gs[0:3, 0]))
        self.axs[self.x_ind].plot(self.t, self.x)
        self.axs[self.x_ind].set_title(r'$x(t)$')
        # NOTE: this panel shows the blurred signal x1 = self.xh, not h itself.
        self.h_ind = 5
        self.axs.append(self.fig.add_subplot(self.gs[4:7, 0]))
        self.axs[self.h_ind].plot(self.t, self.xh)
        self.axs[self.h_ind].set_title(r'$x_1(t)$')
        self.n_ind = 6
        self.axs.append(self.fig.add_subplot(self.gs[9:12, 0]))
        self.axs[self.n_ind].plot(self.t, self.n)
        self.axs[self.n_ind].set_title(r'$n(t)$')
        self.xt_ind = 7
        self.axs.append(self.fig.add_subplot(self.gs[0:3, 2]))
        self.axs[self.xt_ind].plot(self.t, self.xt)
        self.axs[self.xt_ind].set_title(r'$\tilde{x}(t)$')
        self.y1_ind = 8
        self.axs.append(self.fig.add_subplot(self.gs[4:7, 2]))
        self.axs[self.y1_ind].plot(self.t, self.y1)
        self.axs[self.y1_ind].set_title(r'$y_1(t)$')
        self.y_ind = 9
        self.axs.append(self.fig.add_subplot(self.gs[6:10, 1]))
        self.axs[self.y_ind].plot(self.t, self.y)
        self.axs[self.y_ind].set_title(r'$y(t)$')

    def draw_arrow(self, ax=None, pos=(0,0), length=1, width=1, linewidth=8, color='black', orientation='down'):
        """Draw a line-art arrow whose tip sits at ``pos`` (same as the module-level helper).

        ``orientation`` is the direction the tip points: 'down', 'up',
        'right' or 'left'; anything else raises AssertionError.
        """
        if ax is None:
            ax = plt.gca()
        pos = np.array(pos)
        if orientation == 'down':
            ax.plot((pos[0], pos[0]), pos[1] + np.array([10*length, 0.2*length]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, -2.7*width]), pos[1] + np.array([0, 2*length]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, 2.7*width]), pos[1] + np.array([0, 2*length]), color=color, linewidth=linewidth)
        elif orientation == 'up':
            ax.plot((pos[0], pos[0]), pos[1] + np.array([-10*length, -0.2*length]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, -2.7*width]), pos[1] + np.array([0, -2*length]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, 2.7*width]), pos[1] + np.array([0, -2*length]), color=color, linewidth=linewidth)
        elif orientation == 'right':
            ax.plot(pos[0] + np.array([-10*length, -0.2*length]), (pos[1], pos[1]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, -2*length]), pos[1] + np.array([0, -2.7*width]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, -2*length]), pos[1] + np.array([0, 2.7*width]), color=color, linewidth=linewidth)
        elif orientation == 'left':
            ax.plot(pos[0] + np.array([10*length, 0.2*length]), (pos[1], pos[1]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, 2*length]), pos[1] + np.array([0, -2.7*width]), color=color, linewidth=linewidth)
            ax.plot(pos[0] + np.array([0, 2*length]), pos[1] + np.array([0, 2.7*width]), color=color, linewidth=linewidth)
        else:
            raise AssertionError(f"'{orientation}' is not a valid orientation.")

    def draw_adder(self, ax=None, pos=(0,0), height=1, width=1, scale=1, linewidth=5, color='black', n=100):
        """Draw an ellipse with a '+' inside, centred at ``pos`` (same as the module-level helper)."""
        if ax is None:
            ax = plt.gca()
        height *= scale
        width *= scale
        pos = np.array(pos)
        # Circle
        ax.plot(pos[0] + width*np.cos(np.linspace(0, 2*np.pi, n)), pos[1] + height*np.sin(np.linspace(0, 2*np.pi, n)), linewidth=linewidth, color=color)
        # Cross
        ax.plot(pos[0] + [-width + 0.1*width, width - 0.1*width], (pos[1],pos[1]), linewidth=linewidth, color=color)
        ax.plot((pos[0], pos[0]), pos[1] + [-height + 0.1*height, height - 0.1*height], linewidth=linewidth, color=color)


In [88]:
plt.close('all')
# The demo's noise comes from NumPy's legacy global RNG (np.random.normal),
# so seed it to make the figure reproducible on Restart & Run All.
np.random.seed(0)
deconv_demo = Deconvolution_demo(var=1)

Output(layout=Layout(height='700px', width='1000px'))

In [57]:
# Sanity check of scipy's convolve/deconvolve round trip on a unit impulse.
x = np.where(np.arange(11) == 5, 1.0, 0.0)
h = [1, 0, 1]
y = signal.convolve(x, h)
print(x, h, y)
xt = signal.deconvolve(y, h)  # (quotient, remainder) pair
print(xt)

[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [1, 0, 1] [0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0.]
(array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
