In [0]:
import tensorflow as tf
import numpy as np
import scipy as scp
from scipy import signal, linalg
from numpy import asarray, array, ravel, repeat, prod, mean, where, ones
#!pip install scikits.audiolab
#!pip install --upgrade-strategy=only-if-needed git+https://github.com/Uiuran/BregmanToolkit
#!pip install scikit-image
import scikits.audiolab as audio
import matplotlib.pyplot as plt
from tensorflow.python.client import timeline
correlate = signal.correlate

# Tensorboard display
from IPython.display import clear_output, Image, display, HTML

In [0]:
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.

    Parameters:

    - session: The TensorFlow session to be frozen.

    Default Keyword Parameters

    - keep_var_names: A list of variable names that should not be frozen. Defaults None.
    - output_names: Names of the relevant graph outputs/operation/tensor to be written.
      Defaults None. The list passed in is not modified; the op names of all
      global variables are appended to a copy of it, so those nodes are kept too.
    - clear_devices: Remove the device directives from the graph for better portability. Defaults True.

    return The frozen graph definition (a GraphDef).
    """
    graph = session.graph
    with graph.as_default():
        global_vars = tf.global_variables()
        # Every global variable is frozen unless explicitly listed in keep_var_names.
        freeze_var_names = list(set(v.op.name for v in global_vars).difference(keep_var_names or []))
        # Copy before extending: `output_names or []` would alias the caller's
        # list, and `+=` would then mutate it in place.
        output_names = list(output_names or [])
        output_names += [v.op.name for v in global_vars]
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names,
            variable_names_whitelist=freeze_var_names)
        return frozen_graph

def wiener(im, mysize=None, noise=None):
    """
    Perform a Wiener filter on an N-dimensional array.

    Apply a Wiener filter to the N-dimensional array `im`. A small constant
    (1e-8) is added to the local variance so the gain ``1 - noise / lVar``
    does not divide by zero in perfectly flat regions.

    Parameters
    ----------
    im : array_like
        An N-dimensional array.
    mysize : int, float or array_like, optional
        A scalar or an N-length list giving the size of the Wiener filter
        window in each dimension.  Elements of mysize should be odd.
        If mysize is a scalar, then this scalar is used as the size
        in each dimension. Non-integer sizes are truncated to int, so a
        computed value such as ``fs / 20`` is accepted. Defaults to 3 in
        every dimension.
    noise : float, optional
        The noise-power to use. If None, then noise is estimated as the
        average of the local variance of the input.

    Returns
    -------
    out : ndarray
        Wiener filtered result with the same shape as `im`.
    subtract : ndarray
        The residual ``im - out``, i.e. the part removed by the filter.

    Raises
    ------
    ValueError
        If any window size is smaller than 1.
    """
    im = asarray(im)
    if mysize is None:
        mysize = [3] * im.ndim
    # Window sizes count samples, so they must be integers; numpy refuses a
    # float shape when building the ones() averaging kernel.
    mysize = asarray(mysize).astype(int)
    if mysize.shape == ():
        mysize = repeat(mysize.item(), im.ndim)
    if (mysize < 1).any():
        raise ValueError("mysize entries must be >= 1, got %r" % (mysize,))

    kernel = ones(mysize)
    # float() forces true division even for integer input under Python 2.
    n_window = float(prod(mysize, axis=0))

    # Estimate the local mean
    lMean = correlate(im, kernel, 'same') / n_window

    # Estimate the local variance; the 1e-8 nudges it away from zero so the
    # gain computed below does not divide by zero in flat regions.
    lVar = (correlate(im ** 2, kernel, 'same') / n_window - lMean ** 2 + 1e-8)

    # Estimate the noise power if needed.
    if noise is None:
        noise = mean(ravel(lVar), axis=0)

    res = (im - lMean)
    res *= (1 - noise / lVar)
    res += lMean
    # Where local variance is below the noise power, fall back to the local mean.
    out = where(lVar < noise, lMean, res)
    subtract = im - out
    return out, subtract

In [0]:
# Helper functions
def strip_consts(graph_def, max_const_size=32):
    """Return a copy of ``graph_def`` with large constant payloads replaced.

    Every node is copied into a fresh ``tf.GraphDef``. For ``Const`` nodes whose
    ``tensor_content`` is longer than ``max_const_size`` bytes, the content is
    overwritten with the text ``<stripped N bytes>`` so the graph can be
    embedded compactly (show_graph uses this before rendering).
    """
    stripped = tf.GraphDef()
    for src_node in graph_def.node:
        dst_node = stripped.node.add()
        dst_node.MergeFrom(src_node)
        if dst_node.op != 'Const':
            continue
        const_tensor = dst_node.attr['value'].tensor
        n_bytes = len(const_tensor.tensor_content)
        if n_bytes > max_const_size:
            const_tensor.tensor_content = tf.compat.as_bytes("<stripped %d bytes>" % n_bytes)
    return stripped
  
def rename_nodes(graph_def, rename_func):
    """Return a copy of ``graph_def`` with node names mapped through ``rename_func``.

    Each node's own name and every entry of its ``input`` list are renamed.
    Input entries starting with ``^`` (TensorFlow's control-input marker)
    keep the ``^`` prefix; only the name after it is passed to ``rename_func``.
    """
    renamed = tf.GraphDef()
    for src_node in graph_def.node:
        dst_node = renamed.node.add()
        dst_node.MergeFrom(src_node)
        dst_node.name = rename_func(dst_node.name)
        for idx in range(len(dst_node.input)):
            ref = dst_node.input[idx]
            if ref[0] == '^':
                dst_node.input[idx] = '^' + rename_func(ref[1:])
            else:
                dst_node.input[idx] = rename_func(ref)
    return renamed
  
# Function that uses HTML and JavaScript to display TensorBoard in the notebook and on the web
def show_graph(graph_def, max_const_size=32):
    """Visualize a TensorFlow graph inline in the notebook.

    Parameters:

    - graph_def: a GraphDef, or any object with an ``as_graph_def()`` method
      (e.g. a ``tf.Graph``), which is converted first.
    - max_const_size: constants whose ``tensor_content`` exceeds this many bytes
      are replaced by a short placeholder (via strip_consts) before embedding.

    The graph text is injected into a TensorBoard ``tf-graph-basic`` widget loaded
    from tensorboard.appspot.com inside an iframe, so rendering needs network access.
    NOTE(review): that host and HTML imports may no longer be served/supported by
    current browsers -- confirm the widget still renders.
    """
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    # The page is passed through the iframe's srcdoc attribute, whose value is
    # entity-decoded once before being parsed as HTML. Escape '&' first so
    # '&...' sequences inside the graph text survive that decoding, then '"'
    # so the attribute value is not terminated early.
    srcdoc = code.replace('&', '&amp;').replace('"', '&quot;')
    iframe = """
        <iframe seamless style="width:800px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(srcdoc)
    display(HTML(iframe))

In [20]:
data,fs,enc = audio.wavread('/content/species2.wav')
# The (1, length, 1) tensors built below hold a single channel.
assert np.ndim(data) == 1, 'expected a mono recording, got shape %r' % (np.shape(data),)
# Wiener window of fs/20 samples (50 ms). Integer division keeps it an int
# under Python 3 as well, where `fs/20` would be a float window size.
wiener_window = int(fs // 20)
# `out` is the Wiener-smoothed signal, used below as the training target;
# `noise` is the residual data - out.
out,noise = wiener(data,mysize=wiener_window)
l = np.size(data)
print(l)
# Signals shaped (batch=1, length, channels=1) for tf.nn.conv1d; start from
# zeros instead of uninitialized memory, then fill in.
tensordata = np.zeros(shape=(1,l,1))
tensornoise = np.zeros(shape=(1,l,1))
tensordata[0,:,0] = data
tensornoise[0,:,0] = out

FILTER_LEN = 1600      # number of taps of the learned FIR filter
LEARNING_RATE = 0.05   # Adam step size

# Comp. Graph
graph = tf.Graph()
with graph.as_default():
  signal_in = tf.placeholder(tf.float32,(None,l,1), name='signal_in')
  # conv1d filter layout is [width, in_channels, out_channels].
  wfilter = tf.get_variable('wfilter', shape=[FILTER_LEN,1,1], initializer=tf.random_normal_initializer(), dtype=tf.float32)
  signal_ref = tf.placeholder(tf.float32,(None,l,1), name='signal_ref')
  # 1D Convolve  which internally uses 2D reshaped https://www.tensorflow.org/api_docs/python/tf/nn/conv1d
  signal_out = tf.nn.conv1d(signal_in,wfilter,1,'SAME', name='signal_out')
  # Sum of squared errors between the filter output and the reference signal.
  loss = tf.reduce_sum(tf.math.squared_difference(signal_out,signal_ref, name = 'squared'), name = 'loss')
  minimize_op = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)

  # Print operations for the graph built
  for op in graph.get_operations():
    print(op.name)
  graph_def = graph.as_graph_def()
  show_graph(graph_def)


634880
signal_in
wfilter/Initializer/random_normal/shape
wfilter/Initializer/random_normal/mean
wfilter/Initializer/random_normal/stddev
wfilter/Initializer/random_normal/RandomStandardNormal
wfilter/Initializer/random_normal/mul
wfilter/Initializer/random_normal
wfilter
wfilter/Assign
wfilter/read
signal_ref
signal_out/ExpandDims/dim
signal_out/ExpandDims
signal_out/ExpandDims_1/dim
signal_out/ExpandDims_1
signal_out/Conv2D
signal_out/Squeeze
squared
Const
loss
gradients/Shape
gradients/grad_ys_0
gradients/Fill
gradients/loss_grad/Reshape/shape
gradients/loss_grad/Reshape
gradients/loss_grad/Shape
gradients/loss_grad/Tile
gradients/squared_grad/Shape
gradients/squared_grad/Shape_1
gradients/squared_grad/BroadcastGradientArgs
gradients/squared_grad/scalar
gradients/squared_grad/Mul
gradients/squared_grad/sub
gradients/squared_grad/mul_1
gradients/squared_grad/Sum
gradients/squared_grad/Reshape
gradients/squared_grad/Sum_1
gradients/squared_grad/Reshape_1
gradients/squared_grad/Neg
grad

In [0]:
# Create session

# Set PROFILE = True to record a full execution trace of every training step and
# write it as timeline_0_0_step_<step>.json (Chrome trace format). Tracing adds
# overhead, so it is off by default.
PROFILE = False
N_STEPS = 120

with graph.as_default():
  session = tf.Session(graph=graph)

  feed_dict = {
      signal_in: tensordata,
      signal_ref: tensornoise
  }
  trace_level = tf.RunOptions.FULL_TRACE if PROFILE else tf.RunOptions.NO_TRACE
  options = tf.RunOptions(trace_level=trace_level)
  run_metadata = tf.RunMetadata()
  session.run(tf.global_variables_initializer())

  # Perform gradient descent (Adam) steps
  for step in range(N_STEPS):
    # Fetch `loss` in the same run as the train op so the forward pass is
    # computed once per step; the fetched value is the loss before this
    # step's update, as a separate session.run(loss) beforehand would give.
    _, loss_value = session.run([minimize_op, loss],
                                feed_dict=feed_dict,
                                options=options,
                                run_metadata=run_metadata)
    print("Step:", step, " Loss:", loss_value)

    # Profiling
    if PROFILE:
      fetched_timeline = timeline.Timeline(run_metadata.step_stats)
      chrome_trace = fetched_timeline.generate_chrome_trace_format()
      with open('timeline_0_0_step_%d.json' % step, 'w') as f:
        f.write(chrome_trace)

  # Plot
  signal_out_value = session.run(signal_out, feed_dict)
  wfilter_value = wfilter.eval(session=session)
  print('Wiener filter ')
  plt.figure()
  # wfilter is created with shape [1600, 1, 1]: the taps run along axis 0,
  # so index [:, 0, 0] (indexing [0, :, 0] would plot a single tap).
  plt.plot(wfilter_value[:,0,0])
  plt.title('Learned filter taps')
  plt.xlabel('tap')
  print('output_filter_SGD')
  plt.figure()
  plt.plot(signal_out_value[0,:,0])
  plt.title('Filter output')
  plt.xlabel('sample')
  print('input_signal')
  plt.figure()
  plt.plot(tensordata[0,:,0])
  plt.title('Input signal')
  plt.xlabel('sample')

('Step:', 0, ' Loss:', 3020094.0)
('Step:', 1, ' Loss:', 2008613.5)
('Step:', 2, ' Loss:', 1298492.2)
('Step:', 3, ' Loss:', 844935.75)
('Step:', 4, ' Loss:', 589868.1)
('Step:', 5, ' Loss:', 467732.25)
('Step:', 6, ' Loss:', 420709.0)
('Step:', 7, ' Loss:', 408023.88)
('Step:', 8, ' Loss:', 405211.84)
('Step:', 9, ' Loss:', 399636.94)
('Step:', 10, ' Loss:', 386526.03)
('Step:', 11, ' Loss:', 365835.0)
('Step:', 12, ' Loss:', 339833.06)
('Step:', 13, ' Loss:', 311490.03)
('Step:', 14, ' Loss:', 283417.12)
('Step:', 15, ' Loss:', 257226.88)
('Step:', 16, ' Loss:', 233377.81)
('Step:', 17, ' Loss:', 211512.77)
('Step:', 18, ' Loss:', 191065.0)
('Step:', 19, ' Loss:', 171762.6)
('Step:', 20, ' Loss:', 153820.42)
('Step:', 21, ' Loss:', 137812.67)
('Step:', 22, ' Loss:', 124344.59)
('Step:', 23, ' Loss:', 113716.95)
('Step:', 24, ' Loss:', 105750.81)
('Step:', 25, ' Loss:', 99836.38)
('Step:', 26, ' Loss:', 95150.836)
('Step:', 27, ' Loss:', 90915.19)
('Step:', 28, ' Loss:', 86571.91)
('S

In [0]:
# Freeze the trained session's variables into constants and write the resulting
# GraphDef in binary form to model/./wiener.pb.
frozen_graph = freeze_session(session=session)
tf.train.write_graph(frozen_graph, logdir="model", name="./wiener.pb", as_text=False)

In [0]:
# Save the trained filter's output as a WAV file, using the fs/enc returned by
# the most recent wavread (species2.wav when cells run in order).
filtered_signal = signal_out_value[0, :, 0]
audio.wavwrite(filtered_signal, '/content/species2filter.wav', fs=fs, enc=enc)

In [0]:
# Load species1.wav and two related files. Each call rebinds fs and enc, so
# afterwards they hold the values from the last file read.
# NOTE(review): species1filter*.wav are presumably outputs of earlier runs of
# this filtering pipeline on species1.wav -- confirm how they were produced.
data,fs,enc = audio.wavread('/content/species1.wav')
data1,fs,enc = audio.wavread('/content/species1filter.wav')
data2,fs,enc = audio.wavread('/content/species1filter0.wav')