# Practical Thursday: Text-to-Audio Generation

This notebook gives a brief overview of modern models for text-to-audio generation. It covers two main models:
- [EnCodec](https://arxiv.org/abs/2210.13438), a residual-VQ-based audio codec that compresses the raw waveform into discrete tokens
- [AudioGen](https://arxiv.org/abs/2209.15352), a Transformer-based audio language model

In this lab session, we will train an EnCodec model on a toy dataset, from data preparation to model configuration. Then we will compare it with the EnCodec model from Meta AI, which is fully pretrained on a large dataset. Finally, we will see how a pretrained audio codec can be used for text-to-audio generation, using the AudioGen model that is also pretrained by Meta AI.

This notebook is inspired by the [AudioCraft](https://github.com/facebookresearch/audiocraft) project; see their repository for more details.

For any other questions, please contact xiaoyu[dot]bie[at]telecom-paris[dot]fr



In [None]:
%%capture
!python -m pip install -r 'requirements.txt'
!python -m pip install -e .

In [1]:
import torch
import torchaudio
import IPython
import julius
import numpy as np
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
## Load config
cfg_filepath = 'encodec16k.yaml'
cfg = OmegaConf.load(cfg_filepath)

## EnCodec

### Pipeline figure, from the [EnCodec](https://arxiv.org/abs/2210.13438) paper
![](assets/arc_encodec.png)


### Prepare dataloader

To prepare a dataloader, we first create a CSV manifest that contains the metadata (file path, sample rate, length in samples) of every training file.

The data directory is defined in `cfg.toy_data`.
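
As a quick sanity check, such a manifest can be read back with Python's standard `csv` module. The sketch below assumes the four-column header `id,filepath,sr,length` used in this notebook; the two rows, the file name `mani_audio_example.csv`, and the helper `total_duration_seconds` are made up for illustration.

```python
import csv
import os
import tempfile


def total_duration_seconds(manifest_path):
    """Sum length / sr over all rows of a manifest CSV."""
    total = 0.0
    with open(manifest_path, newline='') as f:
        for row in csv.DictReader(f):
            total += int(row['length']) / int(row['sr'])
    return total


# Hypothetical manifest, written to a temporary directory for illustration only
example_path = os.path.join(tempfile.mkdtemp(), 'mani_audio_example.csv')
with open(example_path, 'w', newline='') as f:
    f.write('id,filepath,sr,length\n')
    f.write('clip_a,data/clip_a.wav,16000,32000\n')  # 2 seconds at 16 kHz
    f.write('clip_b,data/clip_b.wav,16000,48000\n')  # 3 seconds at 16 kHz

manifest_seconds = total_duration_seconds(example_path)
print('Total audio len: {:.2f} s'.format(manifest_seconds))
```

Reading the manifest instead of reloading every audio file keeps this check cheap, which matters more on larger datasets than on the toy one.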

In [None]:
from audiodata import DatasetAudioTrain

audio_dir = Path(cfg.toy_data)
audio_manifest = 'mani_audio.csv'
ext = 'wav'

## Prepare data metadata
audio_len = 0
with open(audio_manifest, 'w') as f:
    f.write('id,filepath,sr,length\n') # libri-light too large, no silence trim
    for audio_filepath in tqdm(sorted(list(audio_dir.glob(f'**/*.{ext}'))), desc=f'prepare..'):
        audio_id = audio_filepath.stem
        x, sr = torchaudio.load(audio_filepath)
        length = x.shape[-1]
        utt_len = length / sr
        audio_len += utt_len
        line = '{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length)
        f.write(line)
    print('Total audio len: {:.2f} mins'.format(audio_len/60))


## get dataloader
dataset = DatasetAudioTrain(csv_file=audio_manifest,
                            sample_rate=cfg.sample_rate,
                            n_examples=cfg.dataset.n_examples,
                            chunk_size=cfg.dataset.segment_duration,
                            trim_silence=cfg.dataset.trim_silence,
                            normalize=cfg.dataset.normalize,
                            lufs_norm_db=cfg.dataset.lufs_norm_db,
                            lufs_var=cfg.dataset.lufs_var)

dataloader = DataLoader(dataset=dataset, 
                        batch_size=cfg.dataset.batch_size, num_workers=cfg.dataset.num_workers,
                        shuffle=cfg.dataset.shuffle, drop_last=cfg.dataset.drop_last)
print('Batch size: {}, {} iterations per epoch'.format(cfg.dataset.batch_size, len(dataloader)))

### Prepare model and loss

Training EnCodec is similar to training [VQGAN](https://arxiv.org/abs/2012.09841) and involves two main parts:
- VQ-VAE, for vector quantization and data reconstruction
- Discriminator, for adversarial training

Furthermore, EnCodec introduces a loss balancer to stabilize training. Let $g_i = \frac{\partial l_i}{\partial \hat{x}}$ be the gradient of loss $l_i$ with respect to the model output $\hat{x}$, and let $\langle || g_i ||_2 \rangle_{\beta}$ be the exponential moving average of its norm $|| g_i ||_2$. Given a set of weights $\lambda_i$ and a reference norm $R$, the balancer backpropagates the rescaled gradient:

$$
\hat{g}_i = R \frac{\lambda_i}{\sum_j \lambda_j} \cdot \frac{g_i}{\langle || g_i ||_2 \rangle_{\beta}}
$$

In practice, $R = 1$ and $\beta = 0.999$. 
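
To make the formula concrete, here is a minimal pure-Python sketch of the rescaling step. It is not the AudioCraft balancer: `ToyBalancer`, the loss names, and the gradient values are hypothetical, and the moving average is assumed to start at the first observed norm, which a real implementation may handle differently.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(v * v for v in vec))


class ToyBalancer:
    """Rescale per-loss gradients as g_hat_i = R * (lambda_i / sum_j lambda_j) * g_i / avg_norm_i."""

    def __init__(self, weights, ref_norm=1.0, beta=0.999):
        self.weights = weights
        self.ref_norm = ref_norm
        self.beta = beta
        self.avg_norms = {}  # moving average of each gradient norm

    def rescale(self, grads):
        total_weight = sum(self.weights[name] for name in grads)
        balanced = {}
        for name, g in grads.items():
            norm = l2_norm(g)
            # Assumption: the average is initialised with the first observed norm
            prev = self.avg_norms.get(name, norm)
            avg = self.beta * prev + (1.0 - self.beta) * norm
            self.avg_norms[name] = avg
            scale = self.ref_norm * self.weights[name] / total_weight / avg
            balanced[name] = [scale * v for v in g]
        return balanced


# Two hypothetical losses whose raw gradients differ by four orders of magnitude
toy_balancer = ToyBalancer({'l1': 1.0, 'adv': 3.0})
toy_grads = {'l1': [30.0, 40.0], 'adv': [0.003, 0.004]}
toy_balanced = toy_balancer.rescale(toy_grads)
for name, g in toy_balanced.items():
    print(name, round(l2_norm(g), 4))
```

After rescaling, the norm of each gradient depends only on its weight share $\lambda_i / \sum_j \lambda_j$ and not on the natural scale of the loss, which is the point of the balancer.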


In [None]:
from audiocraft.solvers.builders import (
    get_optimizer,
    get_audio_datasets,
    get_adversarial_losses,
    get_loss,
    get_balancer
)
from audiocraft.models.builders import get_compression_model


## get model and optimizer
model = get_compression_model(cfg.model)
optimizer = get_optimizer(model.parameters(), cfg.optim)
print('Use {} optimizer, learning rate: {}'.format(cfg.optim.optimizer, cfg.optim.lr))

# get loss function
adv_losses = get_adversarial_losses(cfg)
aux_losses = torch.nn.ModuleDict()
loss_weights = dict()
for loss_name, weight in cfg.losses.items():
    if loss_name in ['adv', 'feat']:
        for adv_name, _ in adv_losses.items():
            loss_weights[f'{loss_name}_{adv_name}'] = weight
    elif weight > 0:
        aux_losses[loss_name] = get_loss(loss_name, cfg)
        loss_weights[loss_name] = weight
balancer = get_balancer(loss_weights, cfg.balancer)
print("Total # of params: {:.2f} M".format(sum(p.numel() for p in model.parameters())/1e6))

### Training

Due to time constraints, we only run a short training on a toy dataset containing just 10 audio clips.

In [None]:
# train
ckpt_path = 'last_ckpt.pth'
total_epoch = cfg.optim.epochs
model = model.to(cfg.device)
model.train()
print('Training epoch: {}'.format(total_epoch))
for epo in range(total_epoch):
    for audio_data in tqdm(dataloader, total=len(dataloader)):
        # prepare data
        x = audio_data.to(cfg.device)
        y = x.clone()
        metrics = {}

        # forward
        qres = model(x)
        y_pred = qres.x

        # discriminator loss
        d_losses: dict = {}
        for adv_name, adversary in adv_losses.items():
            disc_loss = adversary.train_adv(y_pred, y)
            d_losses[f'd_{adv_name}'] = disc_loss
        metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
        
        balanced_losses: dict = {}
        other_losses: dict = {}

        # penalty from quantization
        if qres.penalty is not None and qres.penalty.requires_grad:
            other_losses['penalty'] = qres.penalty  # penalty term from the quantizer

        # adversarial losses
        for adv_name, adversary in adv_losses.items():
            adv_loss, feat_loss = adversary(y_pred, y)
            balanced_losses[f'adv_{adv_name}'] = adv_loss
            balanced_losses[f'feat_{adv_name}'] = feat_loss

        # auxiliary losses
        for loss_name, criterion in aux_losses.items():
            loss = criterion(y_pred, y)
            balanced_losses[loss_name] = loss

        # backprop losses that are not handled by balancer
        other_loss = torch.tensor(0., device=cfg.device)
        if 'penalty' in other_losses:
            other_loss += other_losses['penalty']
        if other_loss.requires_grad:
            other_loss.backward(retain_graph=True)
        
        # balancer losses backward
        metrics['g_loss'] = balancer.backward(balanced_losses, y_pred)

        # optimize
        optimizer.step()
        optimizer.zero_grad()

    # save model every epoch
    print('====> Epoch: {}, d_loss: {:.3f}, g_loss: {:.3f}'.format(epo, metrics['d_loss'], metrics['g_loss']))
    torch.save({
            'epoch': epo,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            }, ckpt_path)


# Audio Compression

Once the training is finished, we can use the model to reconstruct the input audio via:
- **Encoder**: $z = Enc(x)$
- **Quantizer**: $z_q = Quant(z)$
- **Decoder**: $y = Dec(z_q)$
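
The quantizer step is where the continuous latent becomes discrete tokens. A minimal sketch of residual vector quantization on a single latent vector follows; the two tiny codebooks and all `toy_*` names are invented for illustration, whereas a trained codec learns much larger codebooks and applies them to every frame.

```python
def nearest_index(codebook, vec):
    """Index of the codeword closest to `vec` in squared Euclidean distance."""
    return min(
        range(len(codebook)),
        key=lambda k: sum((c - v) ** 2 for c, v in zip(codebook[k], vec)),
    )


def rvq_encode(z, codebooks):
    """Quantize z stage by stage; each stage encodes the residual left by the previous one."""
    residual = list(z)
    indices = []
    for codebook in codebooks:
        k = nearest_index(codebook, residual)
        indices.append(k)
        residual = [r - c for r, c in zip(residual, codebook[k])]
    return indices


def rvq_decode(indices, codebooks):
    """Reconstruct z_q as the sum of the selected codewords."""
    z_q = [0.0] * len(codebooks[0][0])
    for k, codebook in zip(indices, codebooks):
        z_q = [a + c for a, c in zip(z_q, codebook[k])]
    return z_q


# Hypothetical codebooks: a coarse stage followed by a finer one
toy_codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0], [1.0, -1.0]],
    [[0.0, 0.0], [0.25, 0.25], [-0.25, 0.25], [0.25, -0.25]],
]
toy_z = [0.9, 1.2]
toy_indices = rvq_encode(toy_z, toy_codebooks)
toy_z_q = rvq_decode(toy_indices, toy_codebooks)
print(toy_indices, toy_z_q)
```

Each extra stage refines the approximation left by the previous ones, so the number of codebooks trades bitrate against reconstruction quality.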

We will compare the model trained on the toy dataset with the fully trained EnCodec model [facebook/encodec_32khz](https://huggingface.co/facebook/encodec_32khz).

### Load an example audio

In [None]:
## Load an audio
audio_filepath = 'example.wav'
x, fs = torchaudio.load(audio_filepath)
print('Audio length: {:.1f}s'.format(x.shape[-1]/fs))

## Display the audio
print('Original Audio:')
IPython.display.Audio(audio_filepath)

### Reconstruction through the model trained on toy dataset

In [None]:
recon_filepath = 'example_recon.wav'

## Load model
model = get_compression_model(cfg.model)
checkpoint = torch.load('last_ckpt.pth', map_location='cpu')  # load on CPU even if trained on GPU
model.load_state_dict(checkpoint['model_state_dict'])
model.cpu().eval()

## Preprocess
x_in = julius.resample_frac(x, old_sr=fs, new_sr=16000)

## Reconstruction
codes, scale = model.encode(x_in[None,])
y = model.decode(codes, scale)[0].detach()

## Write the audio
torchaudio.save(recon_filepath, y, sample_rate=16000)

## Display the audio
print('Reconstructed Audio from the model trained on toy dataset:')
IPython.display.Audio(recon_filepath)

### Reconstruction through the model from facebook

In [None]:
from audiocraft.models import CompressionModel
recon_filepath = 'example_recon_fb.wav'

## Load model
model_fb = CompressionModel.get_pretrained('facebook/encodec_32khz')

## Preprocess: no 16 kHz model is provided, so we use the 32 kHz model instead
x_in = julius.resample_frac(x, old_sr=fs, new_sr=32000)
codes, scale = model_fb.encode(x_in[None,])
y_fb = model_fb.decode(codes, scale)[0].detach()
y_fb = julius.resample_frac(y_fb, old_sr=32000, new_sr=16000)
torchaudio.save(recon_filepath, y_fb, sample_rate=16000)

## Display the audio
print('Reconstructed Audio from the pretrained facebook/encodec_32khz model:')
IPython.display.Audio(recon_filepath)
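
Listening remains the main way to compare the two reconstructions, but a rough number can complement it. The sketch below computes a plain waveform signal-to-noise ratio on synthetic signals; `snr_db` and the `toy_*` signals are illustrative and not part of AudioCraft. To apply it to real audio, both waveforms would need the same sample rate and be truncated to a common length.

```python
import math


def snr_db(reference, estimate):
    """Signal-to-noise ratio in dB between two equal-length sample sequences."""
    signal_power = sum(r * r for r in reference)
    noise_power = sum((r - e) ** 2 for r, e in zip(reference, estimate))
    if noise_power == 0.0:
        return float('inf')
    return 10.0 * math.log10(signal_power / noise_power)


# Synthetic stand-ins for an input and two reconstructions of different quality
toy_ref = [math.sin(2 * math.pi * 5 * n / 100) for n in range(100)]
toy_close = [0.9 * s for s in toy_ref]
toy_far = [0.5 * s for s in toy_ref]

print('close: {:.2f} dB'.format(snr_db(toy_ref, toy_close)))
print('far:   {:.2f} dB'.format(snr_db(toy_ref, toy_far)))
```

Waveform SNR only loosely tracks perceived quality, especially for codecs trained with adversarial losses, so it should be read alongside the listening test rather than instead of it.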

# Text-to-Audio Generation

Once we have a well-trained audio codec model, we can combine it with:
- **Text Encoder**, which turns the text into feature representations
- **Audio Language Model**, a Transformer-based language model that predicts the audio latent codes conditioned on the text
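
The audio language model generates codes autoregressively: at each step it scores every possible next token and one is sampled. The sketch below illustrates this loop with top-k sampling on a single token stream. It is a toy under stated assumptions: `toy_next_logits` is a hypothetical stand-in for the Transformer, it ignores the text condition, and the real model predicts several codebook streams rather than one.

```python
import math
import random


def sample_top_k(logits, k, rng):
    """Sample an index among the k highest-scoring entries of `logits`."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    max_logit = max(logits[i] for i in top)
    # Softmax weights restricted to the top-k entries
    weights = [math.exp(logits[i] - max_logit) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


def generate_codes(next_logits, n_steps, k, seed=0):
    """Autoregressive loop: each new token is sampled given all previous ones."""
    rng = random.Random(seed)
    tokens = []
    for _ in range(n_steps):
        tokens.append(sample_top_k(next_logits(tokens), k, rng))
    return tokens


def toy_next_logits(prefix, vocab_size=8):
    """Hypothetical model: prefers moving to the next token, else repeating the last one."""
    last = prefix[-1] if prefix else 0
    logits = [-5.0] * vocab_size
    logits[last] = 1.0
    logits[(last + 1) % vocab_size] = 3.0
    return logits


toy_codes = generate_codes(toy_next_logits, n_steps=20, k=2)
print(toy_codes)
```

Restricting sampling to the top-k candidates keeps some diversity while avoiding very unlikely tokens; the generated codes would then be passed to the codec decoder to obtain a waveform.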


### Pipeline figure, from the [AudioGen](https://arxiv.org/abs/2209.15352) paper
![](assets/arc_audiogen.png)

Due to time and resource constraints, we directly use the pretrained AudioGen model [facebook/audiogen-medium](https://huggingface.co/facebook/audiogen-medium).

In [None]:
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write


## you can use any custom text prompt
description = 'dog barking'
# description = 'siren of an emergency vehicle'
# description = 'footsteps in a corridor'

model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=5)
wav = model.generate([description], progress=True)[0]
audio_write('sample', wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) # -14 db LUFS

print('Description: {}'.format(description))
IPython.display.Audio('sample.wav')