##### Copyright 2021 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
import os
from tensorflow.io import gfile
from scipy import io
from scipy import sparse
import numpy as np
from scipy.io import wavfile
from scipy.signal import resample
import tempfile

In [None]:
# @title Mount Google Drive
# Makes the user's Google Drive visible to this Colab runtime under ROOT_DIR.
from google.colab import drive
# Mount point handed to `drive.mount` below.
ROOT_DIR = '/content/gdrive'
# NOTE(review): force_remount=True presumably re-mounts even when the mount
# point is already in use, so re-running this cell is safe -- confirm against
# the Colab documentation.
drive.mount(ROOT_DIR, force_remount=True)

In [None]:
# @title Code for cochlear implant audio-to-electrodogram processor.

from typing import List, Optional, Tuple
import dataclasses
import numpy as np
import scipy.signal as sig

# Order in which channels are placed within one electrodogram frame. The list
# is written with base-1 channel numbers; subtracting 1 yields base-0 indices.
ELGRAM_CHANNEL_ORDER = tuple(
    np.subtract([1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16], 1))
# Number of electrodogram channels (one entry per channel in the order above).
NUM_CHANNELS = len(ELGRAM_CHANNEL_ORDER)


@dataclasses.dataclass
class Params:
  """Processing parameters for the audio-to-electrodogram pipeline.

  Fields are grouped by the processing stage that reads them: preprocessing,
  filterbank, PCEN compression, coring, and conversion to electrodogram.
  """

  ## Preprocessing. ------------------------------------------------------------
  # Compression ratio applied in preprocessing. `preprocess` multiplies the
  # audio by mean(audio**4)**(-1 / (4 * ratio)), so larger values compress less.
  preprocess_compression_ratio: float = 2.0

  ## Filterbank. ---------------------------------------------------------------
  # Downsample factor between input audio rate and sample rate for envelopes.
  hop: int = 16
  # Number of frequency channels in the filterbank.
  num_channels: int = NUM_CHANNELS
  # Center frequency in Hz for channel 0; channel c is centered at
  # lowest_cf_hz * 2**(c / channels_per_octave).
  lowest_cf_hz: float = 350.0
  # Number of channels per octave.
  channels_per_octave: float = 4.0
  # Number of poles in the full gammatone filter, i.e. how many times the
  # single-stage filter is cascaded.
  gtf_order: int = 6
  # Q parameter of the gammatone filter.
  gtf_q: float = 5.0
  # Time constant in seconds for the "slow" envelope.
  tau_slow: float = 1.0
  # Time constant in seconds for the "medium" envelope.
  tau_medium: float = 0.1
  # Time constant in seconds for the "fast" envelope.
  tau_fast: float = 0.001

  ## PCEN compression. ---------------------------------------------------------
  # First stage denominator exponent, used on the slow envelope.
  pcen_alpha1: float = 0.8
  # Second stage denominator exponent, used on the medium envelope.
  pcen_alpha2: float = 0.6
  # First stage denominator offset, should be tuned proportional to the noise.
  pcen_epsilon1: float = 1e-6
  # Second stage denominator offset, should be tuned proportional to the noise.
  pcen_epsilon2: float = 1e-4
  # PCEN zero offset. Energy below this, amplitude above, suppresses noise.
  pcen_offset: float = 0.05
  # PCEN beta exponent. 1/2 and 1/3 are about equally good.
  pcen_beta: float = 1 / 2.0

  ## Coring sparsification step. -----------------------------------------------
  # Fraction of the per-frame mean (across channels) that is subtracted in the
  # coring sparsifying step.
  mean_subtraction_gain: float = 0.25

  ## Conversion to electrodogram. ----------------------------------------------
  # Log scale factor to compensate for exp(5*EF) in vocoder.
  elgram_log_scale: float = 5.0 / 0.4 * 0.5
  # Factor for adjusting log offset.
  elgram_floor_factor: float = 0.9
  # Max electrodogram magnitude.
  max_microamps: float = 550.0


@dataclasses.dataclass
class Envelopes:
  """Energy envelopes of one filterbank channel at three smoothing speeds.

  `filterbank` produces one instance per channel; `apply_pcen` uses the slow
  and medium envelopes to gain-control the fast one.
  """
  # Envelope smoothed with `Params.tau_slow`, taken after the first GTF stage.
  slow: np.ndarray
  # Envelope smoothed with `Params.tau_medium`, taken after the second stage.
  medium: np.ndarray
  # Envelope smoothed with `Params.tau_fast`, taken after the last GTF stage.
  fast: np.ndarray


def desired_audio_sample_rate(params: Optional[Params] = None) -> float:
  """Computes the desired sample rate for the input audio based on the hop.

  `amplitudes_to_elgram` emits `2 * NUM_CHANNELS` electrodogram samples for
  every envelope frame, and one envelope frame consumes `params.hop` audio
  samples. For the audio and the electrodogram to cover the same duration the
  audio rate must therefore be `hop / (2 * NUM_CHANNELS * sample_period)`.

  The previous expression had `hop` and `NUM_CHANNELS` swapped, which only
  gives the right answer when `hop == NUM_CHANNELS` (the default, 16).

  NOTE(review): the 18e-6 constant is taken to be the duration in seconds of
  one electrodogram sample (one phase of a biphasic pulse) -- confirm against
  the electrodogram specification.

  Args:
    params: Params; only `hop` is read. Defaults to `Params()`.

  Returns:
    Audio sample rate in Hz.
  """
  if params is None:
    params = Params()
  elgram_sample_period_s = 18e-6
  # One frame holds a two-sample (biphasic) pulse for each channel.
  frame_period_s = 2 * NUM_CHANNELS * elgram_sample_period_s
  return params.hop / frame_period_s


def audio_to_elgram(audio: np.ndarray,
                    sample_rate_hz: float,
                    params: Optional[Params] = None) -> np.ndarray:
  """Runs the full pipeline from an audio waveform to an electrodogram.

  Chains `audio_to_amplitudes` and `amplitudes_to_elgram`, forwarding the same
  `params` object to both stages.

  Args:
    audio: Input audio waveform.
    sample_rate_hz: Sample rate of `audio` in Hz.
    params: Params, or None to let each stage use its defaults.

  Returns:
    The electrodogram as returned by `amplitudes_to_elgram`.
  """
  return amplitudes_to_elgram(
      audio_to_amplitudes(audio, sample_rate_hz, params), params)


def audio_to_amplitudes(audio: np.ndarray,
                        sample_rate_hz: float,
                        params: Optional[Params] = None) -> np.ndarray:
  """Processes audio to compressed amplitudes.

  Pipeline: `preprocess` -> `filterbank` -> `apply_pcen`, followed by a coring
  step that subtracts a fraction of the per-frame mean and clamps at zero, and
  a final peak normalization.

  Args:
    audio: 1D numpy array, input audio waveform. Should be scaled so that
      samples have the nominal range [-1, 1].
    sample_rate_hz: Sample rate in Hz. Should be
      `desired_audio_sample_rate(params)`.
    params: Params. Defaults to `Params()`.

  Returns:
    2D numpy array with `params.num_channels` rows and one column per envelope
    frame. Values lie in [0, 1]; the peak is exactly 1 unless every amplitude
    is zero (e.g. silent input), in which case the zeros are returned as is.
  """
  if params is None:
    params = Params()
  audio = preprocess(audio, sample_rate_hz, params)
  envelopes = filterbank(audio, sample_rate_hz, params)
  amplitudes = apply_pcen(envelopes, params)

  # Find the mean for each frame.
  means = np.mean(amplitudes, axis=0, keepdims=True)
  # Enhance and sparsify spectral shape.
  amplitudes = np.maximum(0.0,
                          amplitudes - means * params.mean_subtraction_gain)
  # Normalization for listening tests to use full range without clip. Skip it
  # when the peak is zero: dividing would turn the whole array into NaN.
  peak = amplitudes.max()
  if peak > 0.0:
    amplitudes /= peak
  return amplitudes


def amplitudes_to_elgram(amplitudes,
                         params: Optional[Params] = None) -> np.ndarray:
  """Converts compressed amplitudes to electrodogram.

  Amplitudes are log-compressed, shifted so the maximum maps to 1 and clamped
  at 0, then scaled by `params.max_microamps`. Every amplitude frame expands to
  `2 * NUM_CHANNELS` electrodogram samples; within a frame each channel gets a
  positive sample immediately followed by its negation, and the channels occupy
  consecutive sample pairs in `ELGRAM_CHANNEL_ORDER`.

  Args:
    amplitudes: 2D numpy array with 16 rows.
    params: Params. Defaults to `Params()`.

  Returns:
    2D numpy array with 16 rows and `2 * NUM_CHANNELS` columns per input frame.
  """
  if params is None:
    params = Params()

  log_scale = params.elgram_log_scale
  floor = params.elgram_floor_factor * np.exp(-log_scale)
  levels = np.log(amplitudes + floor) / log_scale
  # Shift so the loudest value is 1, and drop everything below 0.
  levels = np.maximum(0.0, 1.0 + levels - levels.max())

  samples_per_frame = 2 * NUM_CHANNELS
  num_frames = levels.shape[1]
  elgram = np.zeros((NUM_CHANNELS, samples_per_frame * num_frames))

  for slot, channel in enumerate(ELGRAM_CHANNEL_ORDER):
    pulse = params.max_microamps * levels[channel]
    first_phase = 2 * slot  # Offset of this channel's pulse within a frame.
    elgram[channel, first_phase::samples_per_frame] = pulse
    elgram[channel, first_phase + 1::samples_per_frame] = -pulse

  return elgram


def preprocess(audio: np.ndarray,
               unused_sample_rate_hz: float,
               params: Optional[Params] = None) -> np.ndarray:
  """Applies pre-emphasis and a whole-signal gain to the input audio.

  Args:
    audio: Input audio waveform. It is not modified.
    unused_sample_rate_hz: Ignored.
    params: Params; only `preprocess_compression_ratio` is read. Defaults to
      `Params()`.

  Returns:
    The pre-emphasized audio multiplied by a single gain factor.
  """
  if params is None:
    params = Params()

  # First-order FIR pre-emphasis: y[n] = x[n] - 0.8 * x[n - 1].
  emphasized = sig.lfilter([1.0, -0.8], [1.0], audio)

  # The gain stands in for a slow-acting compressor. It is derived from the
  # whole signal rather than causally because some recordings are too short,
  # or begin with pure silence, for a causal estimate to settle. The level
  # measure is like RMS but uses the 4th power so that peaks weigh more; the
  # 1e-12 keeps the gain finite for all-zero input.
  fourth_moment = np.mean(emphasized**4)
  gain_exponent = -1 / (4 * params.preprocess_compression_ratio)
  agc_gain = (1e-12 + fourth_moment)**gain_exponent
  emphasized *= agc_gain
  return emphasized


def design_gammatone_filter_stage(
    center_frequency_hz: float,
    sample_rate_hz: float,
    order: int = 6,
    q: float = 5.0) -> Tuple[np.ndarray, np.ndarray]:
  """Designs one stage of a complex gammatone filter (GTF).

  The stage has a single complex pole, obtained by mapping the s-plane pole
  `(-zeta + 1j) * omega` to the z-plane with `z = exp(s / sample_rate_hz)`,
  where `zeta = sqrt(order) / (2 * q)` and `omega` is the center frequency in
  rad/s. The numerator is proportional to `[1, 1]`.

  Args:
    center_frequency_hz: Passband center frequency in Hz.
    sample_rate_hz: Sample rate in Hz.
    order: Integer, the order of the full filter. It only enters through zeta.
    q: Filter Q value.

  Returns:
    (gtf_numer, gtf_denom) filter coefficients for one stage; `gtf_denom` is
    complex. Apply the stage `order` times to realize the full GTF.
  """
  damping = np.sqrt(order) / (2.0 * q)
  omega = 2.0 * np.pi * center_frequency_hz
  sample_period_s = 1.0 / sample_rate_hz
  z_pole = np.exp((-damping * omega + 1j * omega) * sample_period_s)

  # Gain of 1 / (z - z_pole) evaluated on the unit circle at the center
  # frequency; the numerator divides it back out.
  z_center = np.exp(1j * omega / sample_rate_hz)
  gain_at_center = 1.0 / np.abs(z_center - z_pole)

  numer = np.ones(2) / (2 * gain_at_center)
  denom = np.array([1.0, -z_pole])
  return numer, denom


def energy_envelope(signal: np.ndarray,
                    sample_rate_hz: float,
                    time_constant_s: float,
                    hop: int = 16,
                    hot_initialize: bool = True) -> np.ndarray:
  """Extracts a smoothed, downsampled energy envelope.

  The squared magnitude of `signal` is low-pass filtered, decimated by `hop`,
  and then smoothed by a one-pole filter with the requested time constant.

  Args:
    signal: Possibly complex-valued input signal.
    sample_rate_hz: Sample rate in Hz.
    time_constant_s: Positive float, smoothing time constant in seconds.
    hop: Positive integer, downsampling factor.
    hot_initialize: If true, start the smoothing filter from the steady state
      it would reach for a constant input of `mean(energy**4)**0.25`. If
      false, start it from zero.

  Returns:
    Numpy array, the envelope at `sample_rate_hz / hop`.
  """
  # Anti-aliasing low-pass ahead of the decimation: a double real pole whose
  # corner (in samples) is tied to the hop.
  corner_samples = 0.7 * hop
  aa_denom = [
      1, -2.0 * np.exp(-1.0 / corner_samples),
      np.exp(-2.0 / corner_samples)
  ]
  # Numerator [g, g] has a null at Nyquist; g gives unit gain at DC.
  aa_numer = np.full(2, 0.5 * np.sum(aa_denom))
  energy = sig.lfilter(aa_numer, aa_denom, np.abs(signal)**2)[::hop]

  # One-pole smoother at the decimated rate: s = -1/tau maps to z = exp(s*T).
  hop_period_s = hop / sample_rate_hz
  z_pole = np.exp((-1.0 / time_constant_s) * hop_period_s)
  smooth_numer = [1 - z_pole]
  smooth_denom = [1, -z_pole]

  # Peak-weighted average level for a "hot" start, zero for a cold one.
  start_level = np.mean(energy**4)**0.25 if hot_initialize else 0.0
  initial_state = sig.lfilter_zi(smooth_numer, smooth_denom) * start_level

  return sig.lfilter(smooth_numer, smooth_denom, energy, zi=initial_state)[0]


def filterbank(audio: np.ndarray,
               sample_rate_hz: float,
               params: Optional[Params] = None) -> List[Envelopes]:
  """Runs the GTF filterbank on audio and collects energy envelopes.

  For each channel one gammatone stage is designed and applied repeatedly. The
  slow envelope is measured after the first pass, the medium envelope after the
  second, and the fast envelope after the remaining passes.

  Args:
    audio: Numpy array, audio to filter.
    sample_rate_hz: Sample rate in Hz.
    params: Params. Defaults to `Params()`.

  Returns:
    A list of `params.num_channels` Envelopes, in which each envelope has sample
    rate `sample_rate_hz / params.hop`.
  """
  if params is None:
    params = Params()

  bank = []
  for channel in range(params.num_channels):
    # Channels are spaced geometrically upward from the lowest center frequency.
    center_hz = params.lowest_cf_hz * 2**(channel / params.channels_per_octave)
    numer, denom = design_gammatone_filter_stage(
        center_frequency_hz=center_hz,
        sample_rate_hz=sample_rate_hz,
        order=params.gtf_order,
        q=params.gtf_q)

    # First pass feeds the slow envelope.
    stage_out = sig.lfilter(numer, denom, audio)
    slow = energy_envelope(stage_out, sample_rate_hz, params.tau_slow,
                           params.hop)

    # Second pass feeds the medium envelope.
    stage_out = sig.lfilter(numer, denom, stage_out)
    medium = energy_envelope(stage_out, sample_rate_hz, params.tau_medium,
                             params.hop)

    # The rest of the cascade feeds the fast envelope, which starts from zero
    # rather than being hot-initialized.
    for _ in range(params.gtf_order - 2):
      stage_out = sig.lfilter(numer, denom, stage_out)
    fast = energy_envelope(
        stage_out,
        sample_rate_hz,
        params.tau_fast,
        params.hop,
        hot_initialize=False)

    bank.append(Envelopes(slow=slow, medium=medium, fast=fast))

  return bank


def apply_pcen(envelopes: List[Envelopes],
               params: Optional[Params] = None) -> np.ndarray:
  """Applies PCEN-like AGC to fast envelopes to compute compressed amplitudes.

  Two cascaded gains are derived per channel: the first from the slow envelope,
  the second from the medium envelope after the first gain has been applied.
  The fast envelope is scaled by both, offset, raised to `pcen_beta`, and the
  offset's own contribution is subtracted so that zero energy maps to zero.

  See reference:
    Y. Wang, P. Getreuer, T. Hughes, R. F. Lyon, and R. A. Saurous,
    “Trainable frontend for robust and far-field keyword spotting,” in Proc.
    IEEE ICASSP, 2017.

  Args:
    envelopes: List of num_channels Envelopes.
    params: Params. Defaults to `Params()`.

  Returns:
    2D numpy array with num_channels rows.
  """
  if params is None:
    params = Params()

  # The frame count is taken from the first channel's slow envelope.
  amplitudes = np.empty((len(envelopes), len(envelopes[0].slow)))
  beta = params.pcen_beta
  offset = params.pcen_offset
  zero_level = offset**beta

  for row, env in zip(amplitudes, envelopes):
    slow_gain = (env.slow + params.pcen_epsilon1)**(-params.pcen_alpha1)
    medium_gain = (env.medium * slow_gain +
                   params.pcen_epsilon2)**(-params.pcen_alpha2)
    # `row` is a view into `amplitudes`, so this fills the output in place.
    row[:] = (env.fast * slow_gain * medium_gain + offset)**beta - zero_level

  return amplitudes
  

In [None]:
# @title File read/write functions.
def write_wav(filename, waveform, sample_rate=16000):
  """Writes an audio waveform as a 16-bit PCM .wav file.

  Args:
    filename: Destination passed through to `scipy.io.wavfile.write`.
    waveform: Float array-like with nominal range [-1, 1]. Samples are scaled
      by 2**15, clipped to the int16 range and rounded.
    sample_rate: Sample rate in Hz. Non-integer rates, such as the one
      `desired_audio_sample_rate` returns, are rounded to the nearest integer
      because the wav header stores the rate as an integer.
  """
  # np.asarray so that a plain Python list is scaled element-wise instead of
  # being repeated by the `*` operator.
  scaled = np.asarray(waveform) * 2**15
  pcm = np.round(np.clip(scaled, -32768, 32767)).astype(np.int16)
  wavfile.write(filename, int(round(sample_rate)), pcm)

def read_wav(wav_path, sample_rate=16000, channel=None):
  """Reads a wav file as a float numpy array, resampling if needed.

  Integer PCM data is scaled to the nominal range [-1, 1) according to its bit
  depth (16-bit data is divided by 2**15); floating-point wav data is returned
  unscaled.

  Args:
    wav_path: String, path to .wav file.
    sample_rate: Sample rate in Hz that the audio is converted to. May be a
      non-integer value.
    channel: Int, option to select a particular channel for multi-channel
      audio. Ignored for mono files. If None, all channels are kept.

  Returns:
    Audio as float32 numpy array: 1D for mono files or when `channel` is given,
    otherwise 2D with shape (num_samples, num_channels).
  """
  sr_read, x = wavfile.read(wav_path)

  stored_dtype = x.dtype
  x = x.astype(np.float32)
  if np.issubdtype(stored_dtype, np.integer):
    info = np.iinfo(stored_dtype)
    half_range = float(2**(info.bits - 1))
    if info.min == 0:
      # Unsigned PCM (8-bit wav) is offset binary: mid-scale means silence.
      x -= half_range
    x /= half_range

  # Pick the channel first so that only the audio that is returned gets
  # resampled.
  if x.ndim > 1 and channel is not None:
    x = x[:, channel]

  if sr_read != sample_rate:
    num_samples = int(round((float(sample_rate) / sr_read) * len(x)))
    x = resample(x, num_samples)
  return x


def store_elgram(results, elgram_output_dir):
  """Saves an electrodogram as a sparse matrix in a MATLAB .mat file.

  The output file is named after the source audio file: the directory and
  extension of `results['sourceName']` are dropped and '.mat' is appended.

  Args:
    results: Dict with keys 'sourceName' (path of the source audio file) and
      'elGram' (2D electrodogram array).
    elgram_output_dir: Directory to write the .mat file into; created if
      missing.
  """
  source_basename = os.path.basename(results['sourceName'])
  stem = os.path.splitext(source_basename)[0]

  elgram_output_path = os.path.join(elgram_output_dir, stem + '.mat')
  gfile.makedirs(os.path.dirname(elgram_output_path))

  sparse_mat = sparse.csr_matrix(results['elGram'])

  # The file is written locally first and then copied with gfile (presumably
  # because the destination may not be a plain local path). The context
  # manager removes the staging directory afterwards, so repeated calls no
  # longer leave temporary directories behind.
  with tempfile.TemporaryDirectory() as staging_dir:
    local_temp_path = os.path.join(staging_dir,
                                   os.path.basename(elgram_output_path))
    io.savemat(local_temp_path, {'elData': sparse_mat})
    gfile.copy(local_temp_path, elgram_output_path, overwrite=True)


def ci_process_file_elgram(wav_filename: str) -> np.ndarray:
  """Computes the electrodogram for a wav file.

  The file is read at the sample rate that `desired_audio_sample_rate` gives
  for the chosen hop. For multi-channel files only the first channel is
  processed, since the rest of the pipeline expects a 1D waveform.

  Args:
    wav_filename: string name of the wav file

  Returns:
    elgram: a 2D float array electrodogram
  """
  params = Params(
      hop=16,  # can be 12 or 16 or 24 or various other possibilities; 16 best?
  )
  fs = desired_audio_sample_rate(params)
  # channel=0 is ignored for mono files and selects the first channel
  # otherwise.
  audio = read_wav(wav_filename, sample_rate=fs, channel=0)

  amplitudes = audio_to_amplitudes(audio, fs, params)
  return amplitudes_to_elgram(amplitudes, params)

In [None]:
#@title run audio-to-electrodogram
# Converts every wav file matching `audio_clip_subpath` under PATH_AUDIO to an
# electrodogram and stores it as a .mat file under PATH_ELGRAM2.
audio_clip_subpath = '*.wav'  #@param

# Folder with the input audio, anchored at the Drive mount point so the cell
# does not depend on the current working directory.
# E.g. <ROOT_DIR>/My Drive/cihack_audio_enhanced_mixed_variable
PATH_AUDIO = os.path.join(ROOT_DIR, 'My Drive',
                          'cihack_audio_enhanced_mixed_variable')

# Output folder, created next to the input folder.
PATH_ELGRAM2 = PATH_AUDIO + '_elgram2'

wavs = gfile.glob(os.path.join(PATH_AUDIO, audio_clip_subpath))
if not wavs:
  # Make a wrong path or an unmounted Drive visible instead of doing nothing.
  print('No files matching %r found in %s' % (audio_clip_subpath, PATH_AUDIO))
for wav in wavs:
  elgram = ci_process_file_elgram(wav)
  results = {'sourceName': wav, 'elGram': elgram}
  store_elgram(results, PATH_ELGRAM2)