# Task

- Make the decoding streamable. This means that the input to the decoder is provided iteratively (e.g. when decoding the output of an autoregressive generative model while it is still generating). To enable streaming, the decoder must be able to process the input in chunks as they arrive, instead of being applied to the whole input sequence at once. The pseudocode snippet below shows what the streaming solution should enable.

```python
# How it is done normally (without streaming):
all_tokens = []
for token in codec_tokens:
    all_tokens.append(token)

# The DAC decoder waits until the full sequence is generated, then decodes it.
audio_wav_1 = dac.decode(all_tokens)

# Pseudocode of the solution we are aiming for:
audio_wav_2 = []
for token in codec_tokens:
    # DAC decodes as tokens become available.
    audio_wav_2.append(dac.streaming_decode(token))
```
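To make the target behaviour concrete without the real codec, here is a minimal NumPy sketch under stated assumptions: `toy_decode` is a hypothetical stand-in for a non-causal decoder (each frame is upsampled to `HOP` samples and smoothed by a symmetric filter), and `ToyStreamingDecoder` is an illustrative overlap scheme, not DAC's API. Each decode window carries up to `context_frames` already-emitted frames of left context and holds back `context_frames` frames of lookahead, so as long as `context_frames` covers the decoder's receptive field, the streamed output matches the offline output.

```python
import numpy as np

HOP = 4    # samples per frame in this toy (the 16 kHz DAC model uses 320)
HALF = 7   # the toy filter looks 7 samples to each side (non-causal)
KERNEL = np.ones(2 * HALF + 1) / (2 * HALF + 1)


def toy_decode(frames):
    """Hypothetical non-causal decoder: upsample each frame, then smooth."""
    upsampled = np.repeat(np.asarray(frames, dtype=float), HOP)
    # Zero-pad both ends so the output has exactly len(frames) * HOP samples
    return np.convolve(np.pad(upsampled, HALF), KERNEL, mode="valid")


class ToyStreamingDecoder:
    """Illustrative overlap scheme: left context + emitted frames + lookahead."""

    def __init__(self, chunk_frames, context_frames):
        assert chunk_frames > context_frames >= 0
        self.chunk_frames = chunk_frames
        self.context_frames = context_frames
        self.reset()

    def reset(self):
        self.buffer = []  # [already-emitted left context | pending frames]
        self.n_ctx = 0    # number of left-context frames at the start of the buffer

    def streaming_decode(self, frame):
        self.buffer.append(frame)
        if len(self.buffer) - self.n_ctx < self.chunk_frames:
            return None  # not enough pending frames yet
        y = toy_decode(self.buffer)
        n_emit = self.chunk_frames - self.context_frames  # hold back the lookahead frames
        out = y[self.n_ctx * HOP:(self.n_ctx + n_emit) * HOP]
        # Keep the last `context_frames` emitted frames as left context
        emitted_end = self.n_ctx + n_emit
        keep_from = max(0, emitted_end - self.context_frames)
        self.buffer = self.buffer[keep_from:]
        self.n_ctx = emitted_end - keep_from
        return out

    def flush(self):
        """Emit whatever is still pending once the token stream has ended."""
        pending = len(self.buffer) - self.n_ctx
        out = toy_decode(self.buffer)[self.n_ctx * HOP:] if pending > 0 else None
        self.reset()
        return out


rng = np.random.default_rng(0)
tokens = rng.standard_normal(53)  # 53 toy scalar "frames" arriving one at a time
decoder = ToyStreamingDecoder(chunk_frames=8, context_frames=2)
pieces = [decoder.streaming_decode(t) for t in tokens] + [decoder.flush()]
streamed = np.concatenate([p for p in pieces if p is not None])
offline = toy_decode(tokens)
print(streamed.shape, np.allclose(streamed, offline))  # (212,) True
```

Setting `context_frames=0` in this sketch makes every chunk boundary differ from the offline output, which is the failure mode a streaming decoder has to avoid.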

- Prepare a quick demo (e.g. a notebook) to present your solution.
- Think about testing, edge cases, and design trade-offs (you don’t need an exhaustive test suite or to retrain anything, but be ready to discuss what you would consider).
- It is completely fine to mock data for this assignment, but if you prefer, you are also welcome to use real speech/music/sfx data.

## Note
Before starting the task, I made a minor modification to the DAC codebase, because the original code changed the decoder's output length when processing 16 kHz audio:

Code file: `dac/model/dac.py`, lines 94-109
```python3
# Original code
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )
# Modified code
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.floor(stride / 2), # ceil() -> floor(), for 16kHz
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )
```
Consider 1 s of audio with the 16 kHz DAC model: it is encoded into 50 tokens in total, so we can compute the feature length after each upsampling layer of the decoder:

| Upsampling layer | Feature dim | Upsample rate | Feature length (`math.floor`) | Feature length (`math.ceil`) |
| --- | --- | --- | --- | --- |
| 0 (input) | 1536 | - | 50 | 50 |
| 1 | 768 | 8 | 400 | 400 |
| 2 | 384 | 5 | 2,001 | 1,999 |
| 3 | 192 | 4 | 8,004 | 7,996 |
| 4 | 96 | 2 | 16,008 | 15,992 |
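The table can be reproduced from PyTorch's `ConvTranspose1d` output-length formula with dilation 1, `L_out = (L_in - 1) * stride - 2 * padding + kernel_size + output_padding`. A short sketch, assuming that only the transposed convolutions change the length (the other decoder convolutions are length-preserving) and using `kernel_size = 2 * stride` as in `DecoderBlock`:

```python
import math


def conv_transpose1d_len(length, stride, kernel_size, padding, output_padding=0):
    # PyTorch ConvTranspose1d output length for dilation=1
    return (length - 1) * stride - 2 * padding + kernel_size + output_padding


def decoder_lengths(n_frames, rates, round_fn):
    """Feature length after each upsampling block, with kernel_size = 2 * stride
    and padding = round_fn(stride / 2) as in DecoderBlock."""
    lengths = [n_frames]
    for stride in rates:
        padding = round_fn(stride / 2)
        lengths.append(conv_transpose1d_len(lengths[-1], stride, 2 * stride, padding))
    return lengths


rates = [8, 5, 4, 2]  # upsample rates of the 16 kHz decoder, from the table
print(decoder_lengths(50, rates, math.floor))  # [50, 400, 2001, 8004, 16008]
print(decoder_lengths(50, rates, math.ceil))   # [50, 400, 1999, 7996, 15992]
```

With `kernel_size = 2 * stride`, each block yields `L * stride + stride - 2 * padding`: exactly `L * stride` for even strides under either rounding, and `L * stride + 1` (`floor`) or `L * stride - 1` (`ceil`) for the stride-5 block. That single extra or missing sample is then multiplied by 4 * 2 = 8, which gives 16,008 and 15,992.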

In practice, I found that using `math.floor` matches the pretrained model released with DAC.

```python
import torch
import torchaudio
import numpy as np
from tqdm import tqdm

import dac
from audiotools import AudioSignal
```

```python
# Load DAC model
model_path = 'pretrained/weights_16khz_8kbps_0.0.5.pth'
model = dac.DAC.load(model_path)
model.eval()
print('Successfully loaded DAC model')
```

```
Successfully loaded DAC model
```


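```python
# Define DAC streaming decoder
class DacStreamingDecoder:
    """Decode DAC codes frame by frame with overlapping windows.

    Each decode window holds up to `overlap_frames` already-emitted frames of
    left context, the frames to emit, and `overlap_frames` frames of lookahead
    that are emitted by a later window. Every frame is emitted exactly once and
    sees decoder context on both sides.
    """

    def __init__(self,
                 model,
                 chunk_frames,
                 overlap_frames=None):
        self.model = model
        self.chunk_frames = chunk_frames
        self.overlap_frames = overlap_frames if overlap_frames is not None else self.compute_rf()
        self.hop_length = int(np.prod(self.model.decoder_rates))

        assert self.chunk_frames > self.overlap_frames, \
            f'chunk_frames: {self.chunk_frames} should be larger than overlap_frames: {self.overlap_frames}'
        self.out_frames = self.chunk_frames - self.overlap_frames  # frames emitted per decode call
        self.out_samples = self.out_frames * self.hop_length
        self.reset()

    def reset(self):
        self.code_buffer = []  # [left context | pending frames], one (n_q,) tensor per frame
        self.ctx_frames = 0    # number of leading buffer frames that were already emitted

    def compute_rf(self, T=101):
        # Decode a random latent sequence and backpropagate from the centre output sample
        device = next(self.model.parameters()).device
        z = torch.randn(1, self.model.latent_dim, T, device=device, requires_grad=True)
        out = self.model.decode(z)
        grad = torch.zeros_like(out)
        grad[:, :, grad.shape[-1] // 2] = 1
        out.backward(grad)
        # Latent frames that receive gradient lie inside the receptive field
        gradmap = z.grad.detach().abs().sum(dim=1).squeeze(0)
        idx = (gradmap != 0).nonzero(as_tuple=False).squeeze(-1)
        self.model.zero_grad(set_to_none=True)  # do not leave gradients on the weights
        center = T // 2
        # The decoder is non-causal: take the larger of the left and right extents
        return max(center - int(idx.min()), int(idx.max()) - center)

    def _decode_frames(self, frames):
        chunk = torch.stack(frames, dim=0).transpose(0, 1).unsqueeze(0)  # (1, n_q, T)
        chunk = chunk.to(next(self.model.parameters()).device)
        with torch.no_grad():
            z, _, _ = self.model.quantizer.from_codes(chunk)
            y = self.model.decode(z)
        return y.squeeze().cpu()

    def streaming_decode(self,
                         token_frame: torch.Tensor):
        # Add the new token frame into the buffer
        self.code_buffer.append(token_frame.detach().cpu())
        if len(self.code_buffer) - self.ctx_frames < self.chunk_frames:
            return None

        # Decode [left context | chunk] and emit only the frames that have full lookahead
        y = self._decode_frames(self.code_buffer)
        start = self.ctx_frames * self.hop_length
        out = y[start:start + self.out_samples]

        # Keep the last `overlap_frames` emitted frames as left context for the next window
        emitted_end = self.ctx_frames + self.out_frames
        keep_from = max(0, emitted_end - self.overlap_frames)
        self.code_buffer = self.code_buffer[keep_from:]
        self.ctx_frames = emitted_end - keep_from
        return out

    def decode_last(self):
        # Flush the frames that have not been emitted yet once the token stream ends
        if len(self.code_buffer) <= self.ctx_frames:
            self.reset()
            return None
        y = self._decode_frames(self.code_buffer)
        out = y[self.ctx_frames * self.hop_length:]
        self.reset()
        return out
```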
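`compute_rf` estimates the receptive field by backpropagating from one output sample and checking which latent frames receive gradient. The same idea can be sanity-checked on a toy stack where the answer is known analytically: a convolution with kernel size k and dilation d adds d * (k - 1) / 2 samples of context per side, so kernels of size 7 with dilations 1, 3, 9 give 3 * (1 + 3 + 9) = 39. The sketch below swaps the gradient for one-at-a-time input perturbations so it runs with NumPy alone; `toy_residual_stack` is a hypothetical toy model, not the DAC decoder.

```python
import numpy as np


def dilated_conv_same(x, kernel, dilation):
    """Zero-padded 'same' convolution with a dilated, odd-length kernel."""
    taps = np.zeros((len(kernel) - 1) * dilation + 1)
    taps[::dilation] = kernel
    half = (len(taps) - 1) // 2
    return np.convolve(np.pad(x, half), taps, mode="valid")


def toy_residual_stack(x, kernels, dilations=(1, 3, 9)):
    """Hypothetical toy model: residual units with kernel size 7 and dilations 1, 3, 9."""
    for kernel, dilation in zip(kernels, dilations):
        x = x + np.tanh(dilated_conv_same(x, kernel, dilation))
    return x


def measure_receptive_field(fn, length=201, eps=1e-2, seed=0):
    """Perturb one input sample at a time; return the (left, right) extent that moves the centre output."""
    x = np.random.default_rng(seed).standard_normal(length)
    centre = length // 2
    base = fn(x)[centre]
    hits = []
    for i in range(length):
        bumped = x.copy()
        bumped[i] += eps
        if fn(bumped)[centre] != base:
            hits.append(i)
    return centre - min(hits), max(hits) - centre


kernels = np.random.default_rng(1).standard_normal((3, 7)) * 0.3
print(measure_receptive_field(lambda v: toy_residual_stack(v, kernels)))  # (39, 39)
```

Because the decoder uses symmetric padding, the left and right extents should agree, as they do here; measuring both sides is a cheap guard against asymmetric padding.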

```python
# Load audio signal file (point audio_path to your own file)
audio_path = 'p227_001_16k.wav'
signal = AudioSignal(audio_path)
print('Loaded audio signal with {} samples at {} Hz'.format(signal.audio_data.shape[-1], signal.sample_rate))

# DAC expects mono audio at the model's sample rate (model.preprocess asserts on a mismatch)
signal = signal.to_mono().resample(model.sample_rate)
sample_rate = signal.sample_rate

# Encode audio signal as one long file
signal.to(model.device)
x = model.preprocess(signal.audio_data, signal.sample_rate)
z, codes, latents, _, _ = model.encode(x)
print('Audio length after processing: {}'.format(x.shape[-1]))

# Get audio tokens: one (n_q,) tensor per frame
codec_tokens = list(codes[0].T)
print('Got {} audio tokens with {} residual codebooks'.format(len(codec_tokens), len(codec_tokens[0])))
```
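The token bookkeeping relies on a simple layout change: `codes` from `model.encode` has shape `(batch, n_q, T)`, `list(codes[0].T)` yields one `(n_q,)` vector per frame, and `streaming_decode` stacks buffered frames back into `(1, n_q, T)` before calling `quantizer.from_codes`. A NumPy sketch of that round trip, with hypothetical sizes `n_q = 12` and `T = 6` (torch's `stack`, `transpose` and `unsqueeze` behave the same way):

```python
import numpy as np

n_q, T = 12, 6  # hypothetical: 12 codebooks, 6 frames
codes = np.arange(n_q * T).reshape(1, n_q, T)  # (batch, n_q, T), like the codes from model.encode

# One (n_q,) vector per frame, as in `codec_tokens = list(codes[0].T)`
codec_tokens = list(codes[0].T)

# Re-assemble the first 4 buffered frames into (1, n_q, 4), mirroring
# torch.stack(..., dim=0).transpose(0, 1).unsqueeze(0) in streaming_decode
chunk = np.stack(codec_tokens[:4], axis=0).T[None]
print(len(codec_tokens), codec_tokens[0].shape, chunk.shape)  # 6 (12,) (1, 12, 4)
```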

```python
# Streaming decoding
streaming_decoder = DacStreamingDecoder(model, chunk_frames=20)
audio_wav_chunks = []
for token in tqdm(codec_tokens, desc='streaming decoding'):
    # DAC decodes as tokens become available (None until a chunk is full)
    audio_wav_chunks.append(streaming_decoder.streaming_decode(token))
# Flush the frames still held in the buffer once the token stream ends
audio_wav_chunks.append(streaming_decoder.decode_last())
audio_wav_chunks = [c for c in audio_wav_chunks if c is not None]
audio_wav = torch.cat(audio_wav_chunks, dim=0)
```

```
streaming decoding: 100%|████████████████████████████████████████████████████████████████████████████████████| 206/206 [00:00<00:00, 229.82it/s]
```
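The chunk size trades latency against compute. The helpers below are an illustrative sketch (hypothetical names) of a schedule in which each decode window carries up to `context_frames` frames of left context and `context_frames` frames of lookahead: they list which frames each call decodes and emits, and compute how much audio must be covered by tokens before the first output. The 320-sample hop (8 * 5 * 4 * 2) and the 16 kHz rate come from the 16 kHz model discussed above.

```python
def chunk_schedule(num_frames, chunk_frames, context_frames):
    """Hypothetical helper: (window_start, window_end, emit_start, emit_end) per decode call."""
    assert chunk_frames > context_frames >= 0
    n_emit = chunk_frames - context_frames
    schedule, emitted = [], 0
    # A call fires once `chunk_frames` frames past the last emitted one have arrived
    while emitted + chunk_frames <= num_frames:
        window_start = max(0, emitted - context_frames)
        schedule.append((window_start, emitted + chunk_frames, emitted, emitted + n_emit))
        emitted += n_emit
    # Final flush after the token stream ends
    if emitted < num_frames:
        schedule.append((max(0, emitted - context_frames), num_frames, emitted, num_frames))
    return schedule


def first_output_latency_s(chunk_frames, hop_length=320, sample_rate=16000):
    """Seconds of audio whose tokens must arrive before the first chunk is decoded."""
    return chunk_frames * hop_length / sample_rate


for call in chunk_schedule(num_frames=26, chunk_frames=8, context_frames=2):
    print(call)
print(first_output_latency_s(20))  # 0.4
```

With `chunk_frames=20` as in the cell above, the first audio is returned only after 20 tokens, i.e. 0.4 s of audio at 50 tokens per second. Every later call decodes up to `chunk_frames + context_frames` frames to emit `chunk_frames - context_frames`, so a larger chunk lowers the relative overhead but raises the latency.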


```python
# Show audio
from IPython.display import Audio, display

# Reference audio
print('Reference audio')
ref_audio = x[0].detach().cpu().numpy()
display(Audio(ref_audio, rate=sample_rate))

# Decoded audio from streaming decoding
print('Reconstructed audio from streaming decoding')
recon_audio = audio_wav.unsqueeze(0).detach().cpu().numpy()
display(Audio(recon_audio, rate=sample_rate))

# Decoded audio from global (non-streaming) decoding of the full token sequence
print('Reconstructed audio from global decoding')
with torch.no_grad():
    z = model.quantizer.from_codes(codes)[0]
    recon_audio_global = model.decode(z)
recon_audio_global = recon_audio_global[0].detach().cpu().numpy()
display(Audio(recon_audio_global, rate=sample_rate))
```

```
Reference audio
Reconstructed audio from streaming decoding
Reconstructed audio from global decoding
```
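Listening is a useful first check, but for testing it helps to compare the streamed waveform against the global decode numerically. A minimal sketch of such a comparison, where the function name and the choice of metrics are assumptions: it trims both signals to a common length (assumed non-empty) and reports the length difference, the maximum absolute error and an SNR in dB.

```python
import numpy as np


def compare_waveforms(reference, estimate):
    """Compare two waveforms: length difference, max absolute error and SNR in dB."""
    reference = np.asarray(reference, dtype=np.float64).reshape(-1)
    estimate = np.asarray(estimate, dtype=np.float64).reshape(-1)
    n = min(len(reference), len(estimate))  # compare only the overlapping part
    error = reference[:n] - estimate[:n]
    noise_energy = np.sum(error ** 2)
    signal_energy = np.sum(reference[:n] ** 2)
    snr_db = np.inf if noise_energy == 0 else 10 * np.log10(signal_energy / noise_energy)
    return {
        "length_diff": len(reference) - len(estimate),
        "max_abs_err": float(np.max(np.abs(error))),
        "snr_db": float(snr_db),
    }


print(compare_waveforms(np.ones(100), 0.9 * np.ones(100)))
```

In this notebook it could be called as `compare_waveforms(recon_audio_global, recon_audio)`; a non-zero length difference or a low SNR would point to missing samples or chunk-boundary artifacts.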
