# Imports

In [None]:
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    !git clone https://github.com/jhauret/eben.git
    %cd eben
else:
    print('Not running on CoLab')

In [None]:
# imports from external libraries
import os
import torch
import torch.nn as nn
from torch import pi
import IPython.display as ipd
import plotly.graph_objects as go
import plotly
import torchaudio

from src.pqmf import PseudoQMFBanks
from src.temporal_transforms import TemporalTransforms
from src.generator import GeneratorEBEN

# Activity 1: PQMF analysis and synthesis of a chirp signal

## You can try your own parameters!

In [None]:
# --- PQMF filter bank ---
PQMF_BANDS = 6               # value passed as `decimation` to the filter bank
PQMF_KS = 8 * PQMF_BANDS     # kernel size; try 128 * PQMF_BANDS to see how it affects the bands

# --- Chirp test signal ---
SR = 16000                   # sampling rate, in Hz
DURATION = 3                 # signal length, in seconds
MAX_FREQ = 8000              # upper end of the frequency sweep, in Hz

In [None]:
# Filter bank used for both the analysis and the synthesis steps
pqmf = PseudoQMFBanks(decimation=PQMF_BANDS, kernel_size=PQMF_KS)

# Build the chirp as sin(pi * f * t), with f ramping linearly from 0 to MAX_FREQ
# while t ramps from 0 to DURATION over the same number of samples.
n_chirp_samples = SR * DURATION
sweep_freq = torch.linspace(start=0, end=MAX_FREQ, steps=n_chirp_samples)
sweep_time = torch.linspace(0, DURATION, n_chirp_samples)
chirp = torch.sin(pi * sweep_freq * sweep_time)

# torch formalism: (batch_size, channels, time_len)
chirp = chirp[None, None, :]

# avoid non-matching shapes between original and recomposed signals
chirp = pqmf.cut_tensor(chirp)

# play audio (last expression of the cell -> rendered as an audio widget)
ipd.Audio(chirp.squeeze(), rate=SR)

## Computations

In [None]:
# Analysis: run the chirp through the filter bank in "analysis" mode.
chirp_decomposed = pqmf(chirp, "analysis")

# Synthesis: run the bands back through the bank in "synthesis" mode, then sum
# over the channel dimension (dim 1) to obtain a single-channel signal again.
synthesized_bands = pqmf(chirp_decomposed, "synthesis")
chirp_recomposed = synthesized_bands.sum(dim=1, keepdim=True)

## Shapes and Signal-to-Noise Ratio

In [None]:
# Report length (dim 2) and channel count (dim 1) of the signal at each stage.
print(f'Original chirp length: {chirp.shape[2]} with {chirp.shape[1]} channel')
print(f'Decomposed chirp length: {chirp_decomposed.shape[2]} with {chirp_decomposed.shape[1]} channels')
# Fix: report the channel count of the recomposed tensor itself, not of the original chirp.
print(f'Recomposed chirp length: {chirp_recomposed.shape[2]} with {chirp_recomposed.shape[1]} channel')

# SNR in dB: mean power of the recomposed signal over mean power of the
# reconstruction error (original minus recomposed).
reconstruction_error = chirp - chirp_recomposed
snr_db = 10*torch.log10((chirp_recomposed**2).mean()/(reconstruction_error**2).mean()).item()
print(f'SNR of chirp_recomposed: {snr_db:.2f}dB')


## Visualization

In [None]:
fig = go.Figure()

# Time axis for the full-rate signals. `chirp` went through `pqmf.cut_tensor`,
# so slice the axis to the actual signal length to keep x and y aligned when
# other parameters are tried (assumes cut_tensor trims trailing samples -- TODO confirm).
time = torch.linspace(0, DURATION, SR*DURATION)[:chirp.shape[-1]]

label_font = dict(family='Latin Modern Roman', size=18)
band_offset = 3  # vertical spacing between stacked curves; matches the tick spacing below

# Trace signals
fig.add_trace(go.Scatter(x=time, y=chirp.squeeze(), name='Original chirp'))
fig.add_trace(go.Scatter(x=time, y=chirp_recomposed.squeeze(), name='Recomposed chirp'))

# Trace bands: each band is plotted against every PQMF_BANDS-th time stamp and
# shifted upwards so that the curves do not overlap.
for band in range(PQMF_BANDS):
    fig.add_trace(go.Scatter(
        x=time[::PQMF_BANDS],
        y=chirp_decomposed[0, band, :] + band_offset*band + band_offset,
        name=f'band_{1+band}'))

fig.update_layout(
    title={'text': (f"""Temporal representation <br><sup>To align all signals in time domain, bands are dilated by a factor {PQMF_BANDS} </sup>"""),'y':0.9,'x':0.45,'xanchor': 'center','yanchor': 'top'},
    font=label_font,
    paper_bgcolor='rgba(0,0,0,0)',
    # `titlefont` is deprecated; the nested `title.font` form is the supported spelling.
    xaxis=dict(title=dict(text='Time (seconds)', font=label_font)),
    yaxis=dict(title=dict(text='Amplitude', font=label_font)))

# One tick per curve: 0 for the full-rate signals, then one per band.
fig.update_yaxes(
    tickmode='array',
    ticktext=['signal'] + [f'band_{i+1}' for i in range(PQMF_BANDS)],
    tickvals=[band_offset*i for i in range(PQMF_BANDS+1)])

fig.show()

# Activity 2: inference with EBEN

## Loading audio and model weights

In [None]:
# load audio (try with your own audio.flac !)
# The waveform is converted with `.float()` before it is processed further down.
audio, sr = torchaudio.load('example_audio.flac', normalize=False)

# load generator weights
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the generator below is created on the CPU anyway.
# NOTE(security): torch.load unpickles the file -- only load checkpoints you trust.
weights = torch.load('./generator.ckpt', map_location='cpu')

# Instantiate EBEN's generator
generator = GeneratorEBEN(m=4, n=32)

# Load weights (last expression of the cell: displays the key-matching report)
generator.load_state_dict(weights)

## In-ear-like degradation pipeline

In [None]:
# Wrap the float waveform and its sampling rate. The methods below are called
# for their effect on this object; later cells read the result from
# `tt_audio_corrupted.audio`.
tt_audio_corrupted = TemporalTransforms(audio.float(), sr)

# degradation
# NOTE(review): presumably removes high frequencies, then adds noise -- confirm
# in src.temporal_transforms. If add_noise is random, seed the RNG for
# reproducible output.
tt_audio_corrupted.remove_hf()
tt_audio_corrupted.add_noise()

# smoothing borders
tt_audio_corrupted.smoothing()

# normalize
tt_audio_corrupted.normalize()

## Enhance audio with the EBEN model

In [None]:
# Add a leading batch dimension and let the generator trim the input with its
# own `cut_tensor` before the forward pass.
cut_corrupted_audio = generator.cut_tensor(tt_audio_corrupted.audio.unsqueeze(0))

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    enhanced_speech, enhanced_speech_decomposed = generator(cut_corrupted_audio)

## Listen to results

### In-ear-like

In [None]:
# Audio player for the degraded (in-ear-like) signal
ipd.Audio(data=tt_audio_corrupted.audio, rate=sr)

### EBEN enhanced

In [None]:
# Audio player for the generator output: detach from autograd and drop the
# singleton dimensions before handing the samples to the widget.
ipd.Audio(data=enhanced_speech.detach().squeeze(), rate=sr)

### Reference

In [None]:
# Audio player for the unprocessed reference recording
ipd.Audio(data=audio, rate=sr)