# Imports

In [None]:
# imports from external libraries
import torch
import torch.nn as nn
from torch import pi
import IPython.display as ipd
import plotly.graph_objects as go
import plotly

# imports from project
from pqmf import PseudoQMFBanks

# Example 1: Chirp PQMF analysis and synthesis

## Try your own parameters!

In [None]:
PQMF_BANDS = 6  # number of PQMF sub-bands (also the decimation factor)
PQMF_KS = 48  # filter kernel size; must be divisible by 4 * PQMF_BANDS

MAX_FREQ = 8000  # Hz, frequency reached by the chirp at the end of DURATION
DURATION = 3  # seconds
SR = 16000  # Hz, sampling rate

# Fail fast on invalid user-chosen parameters instead of deep inside the filter bank.
if PQMF_KS % (4 * PQMF_BANDS) != 0:
    raise ValueError(
        f'PQMF_KS ({PQMF_KS}) must be divisible by 4 * PQMF_BANDS ({4 * PQMF_BANDS})'
    )
if MAX_FREQ > SR / 2:
    import warnings

    # Frequencies above Nyquist (SR / 2) alias back into the band and distort the demo.
    warnings.warn(
        f'MAX_FREQ ({MAX_FREQ} Hz) exceeds the Nyquist frequency SR / 2 ({SR / 2} Hz); '
        'the chirp will alias.'
    )

In [None]:
# pqmf instance: PQMF_BANDS sub-bands, each decimated by PQMF_BANDS
pqmf = PseudoQMFBanks(decimation=PQMF_BANDS, kernel_size=PQMF_KS)

# Linear chirp sweeping 0 Hz -> MAX_FREQ over DURATION seconds:
#   phase(t) = pi * f(t) * t  with  f(t) = MAX_FREQ * t / DURATION,
# so the instantaneous frequency phase'(t) / (2*pi) = MAX_FREQ * t / DURATION.
# int(): torch.linspace needs an integer step count, and DURATION may be fractional.
n_samples = int(SR * DURATION)
t = torch.linspace(0, DURATION, n_samples)
chirp = torch.sin(pi * (MAX_FREQ * t / DURATION) * t)  # sin(πft)
chirp = chirp.unsqueeze(0).unsqueeze(0)  # torch formalism: (batch_size, channels, time_len)
chirp = pqmf.cut_tensor(chirp)  # avoid non-matching shapes between original and recomposed signals
ipd.Audio(chirp.squeeze(), rate=SR)  # play audio

## Computations

In [None]:
# Analysis splits the chirp into band channels (dim 1). Synthesis presumably returns one
# full-rate channel per band, which are summed over dim 1 to recompose a single channel.
# no_grad: this is inference only, so no autograd graph is needed; it also keeps the
# outputs convertible to numpy (plotly / IPython audio) if the PQMF filters are trainable.
with torch.no_grad():
    chirp_decomposed = pqmf(chirp, "analysis")
    chirp_recomposed = torch.sum(pqmf(chirp_decomposed, "synthesis"), 1, keepdim=True)

## Shapes and Signal-to-Noise Ratio

In [None]:
# All tensors are laid out as (batch_size, channels, time_len).
print(f'Original chirp length: {chirp.shape[2]} with {chirp.shape[1]} channel')
print(f'Decomposed chirp length: {chirp_decomposed.shape[2]} with {chirp_decomposed.shape[1]} channels')
print(f'Recomposed chirp length: {chirp_recomposed.shape[2]} with {chirp_recomposed.shape[1]} channel')

# SNR = 10 * log10(P_signal / P_noise): power of the reference (original) signal over the
# power of the reconstruction error. A perfect reconstruction yields +inf dB.
signal_power = (chirp ** 2).mean()
noise_power = ((chirp - chirp_recomposed) ** 2).mean()
snr_db = 10 * torch.log10(signal_power / noise_power)
print(f'SNR of chirp_recomposed: {snr_db.item():.2f}dB')


## Visualization

In [None]:
fig = go.Figure()

# Time axes (seconds) come from the actual tensor lengths rather than SR * DURATION:
# cut_tensor may have shortened the signal, and each trace's x and y must match in length.
time = torch.arange(chirp.shape[-1]) / SR
recomposed_time = torch.arange(chirp_recomposed.shape[-1]) / SR
# Bands are decimated by PQMF_BANDS, so consecutive band samples sit PQMF_BANDS input
# samples apart; this aligns every band with the full-rate signals.
band_time = torch.arange(chirp_decomposed.shape[-1]) * PQMF_BANDS / SR

# Vertical spacing between stacked traces (signal at 0, band k at offset * k).
offset = 3

# Trace signals (.detach() so array conversion works even if a tensor tracks gradients)
fig.add_trace(go.Scatter(x=time, y=chirp.detach().squeeze(), name='Original chirp'))
fig.add_trace(go.Scatter(x=recomposed_time, y=chirp_recomposed.detach().squeeze(), name='Recomposed chirp'))

# Trace bands
for band in range(PQMF_BANDS):
    fig.add_trace(go.Scatter(
        x=band_time,
        y=chirp_decomposed[0, band, :].detach() + offset * (band + 1),
        name=f'band_{1+band}',
    ))

axis_font = dict(family='Latin Modern Roman', size=18)
fig.update_layout(
    title={'text': (f"""Temporal representation <br><sup>To align all signals in time domain, bands are dilated by a factor {PQMF_BANDS} </sup>"""), 'y': 0.9, 'x': 0.45, 'xanchor': 'center', 'yanchor': 'top'},
    font=axis_font,
    paper_bgcolor='rgba(0,0,0,0)',
    # Axis `titlefont` was deprecated and removed in Plotly 6; the font now lives in title.font.
    xaxis=dict(title=dict(text='Time (seconds)', font=axis_font)),
    yaxis=dict(title=dict(text='Amplitude', font=axis_font)),
)

fig.update_yaxes(
    tickmode='array',
    ticktext=['signal'] + [f'band_{i+1}' for i in range(PQMF_BANDS)],
    tickvals=[offset * i for i in range(PQMF_BANDS + 1)],
)


fig.show()