# AudioLM

AudioLM is an high-quality audio generation framework, with long-term consistency. The main idea is to map an input audio sequence of discrete tokens in an intermediary discrete reppresentation space.
One of the main issues that has been addressed in this work is the concilitation between generating audio with long-term consistency and and the recostruction of high quality audio.

This are the necessary imports.

In [1]:
import os
from pathlib import Path

import torch
from torch import nn

from audiolm.data_preparation import AudioDataLoader
from audiolm.w2v_hubert import W2VHuBert
from audiolm.absolute_transformer import (
    SemanticTransformer,
    CoarseAcousticTransformer,
    FineAcousticTransformer,
)
from audiolm.encodec import Encodec

In [None]:
%load_ext tensorboard

This are constants that will be important later.

In [44]:
DATA_PATH = Path(os.getcwd()) / Path("..") / Path("data") / Path("datasets")
MODEL_PATH = Path(os.getcwd()) / Path("..") / Path("data")
INTERVALS = 10
EARLY_STOP_COUNTER = 0
EARLY_STOPPING_RANGE = 5
EPOCHS = 10

Both the semantic encoder are assumed pretrained and freezed ahead of time.

In [4]:
semantic_encoder = W2VHuBert()
acoustic_encoder_decoder = Encodec()

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


The three dataloaders that will be needed, train val ,test

In [15]:
train_dataloader = AudioDataLoader(DATA_PATH / "train", batch_size=4,max_elems=10)
val_dataloader = AudioDataLoader(DATA_PATH / "val", batch_size=4,max_elems=5)
test_dataloader = AudioDataLoader(DATA_PATH / "test", batch_size=4,max_elems=2)

We define what an abstract trainer should look like. This will give the skeleton for the specific trainers.
This is needed because this framework uses a hierachical model approach, however at each level of the hierarchy only one transfomere is trained and much of the inherent logic is repeated.
We can use this insight to create spoecialized classes for the training of a given transformer.
The loss generator is the main component that will be changed the most during the specialization as it reflects how the input/output are created and is the only thing that changes between hierachies

In [45]:
import os
from math import ceil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional

from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from audiolm.encodec import Encodec
from audiolm.constants import DEVICE
from audiolm.data_preparation import AudioDataLoader
from audiolm.w2v_hubert import W2VHuBert
from audiolm.utils import save_checkpoint, save_model


class Trainer(ABC):
    """
    Trainer class for training a Transformer model.
    """

    @abstractmethod
    # pylint: disable =too-many-arguments
    def __init__(
        self,
        semantic_encoder: Optional[W2VHuBert] = None,
        semantic_transformer: Optional[SemanticTransformer] = None,
        acoustic_encoder_decoder: Optional[Encodec] = None,
        coarse_acoustic_transformer: Optional[CoarseAcousticTransformer] = None,
        fine_acoustic_transformer: Optional[FineAcousticTransformer] = None,
        train_dataloader: Optional[AudioDataLoader] = None,
        val_dataloader: Optional[AudioDataLoader] = None,
        test_dataloader: Optional[AudioDataLoader] = None,
        loss: Optional[nn.Module] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        intervals: Optional[int] = None,
        save_path: Optional[os.PathLike] = None,
        early_stop_counter: Optional[int] = None,
        early_stopping_range: Optional[int] = None,
        epochs: Optional[int] = None,
    ):
        self.semantic_encoder = semantic_encoder
        self.semantic_transformer = semantic_transformer
        self.acoustic_encoder_decoder = acoustic_encoder_decoder
        self.coarse_acoustic_transformer = coarse_acoustic_transformer
        self.fine_acoustic_transformer = fine_acoustic_transformer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optimizer
        self.intervals = intervals
        self.epochs = epochs
        self.save_path = save_path
        self.best_val_loss = float("inf")
        self.early_stopping_range = early_stopping_range
        self.early_stop_counter = early_stop_counter
        self.loss = loss
        if save_path is not None and not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    # region: Abstract methods, this methods must be redefined accordingly.
    @abstractmethod
    def loss_generator(self, batch):
        """Generate loss"""

    @abstractmethod
    def train(self):
        """
        Train the Transformer model.
        """

    @abstractmethod
    def test(self):
        """Test the model on the test dataset."""

    # endregion

    # region: Private methods.
    def _train_step(self, model: nn.Module) -> float:
        model.train()
        train_loss = 0
        
        for batch in tqdm(self.train_dataloader, total = ceil(len(self.train_dataloader) / self.train_dataloader.batch_size) ):
            batch = batch.to(DEVICE)
            loss = self.loss_generator(batch)
            train_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        train_loss /= len(self.train_dataloader)
        return train_loss

    def _validation_step(self, model: nn.Module) -> float:
        model.eval()
        validation_loss = 0
        with torch.inference_mode():
            for batch in self.val_dataloader:
                batch = batch.to(DEVICE)
                loss = self.loss_generator(batch)
                validation_loss += loss.item()

        validation_loss /= len(self.val_dataloader)

        return validation_loss

    def _train(self, model: nn.Module):
        writer = SummaryWriter(Path(self.save_path) / "runs" / str(type(model).__name__))
        for epoch in tqdm(range(self.epochs), total = self.epochs, desc="Training"):
            train_loss = self._train_step(model)
            validation_loss = self._validation_step(model)
            print("SAVING CHECKPOINT...")
            save_checkpoint(
                model, epoch, self.optimizer, self.early_stop_counter, self.save_path
            )
            print("SAVING RUN FOR TENSORBOARD...")
            writer.add_scalars(
                main_tag=f"Loss_{str(type(model).__name__)}",
                tag_scalar_dict={
                    "train_loss": train_loss,
                    "validation_loss": validation_loss,
                },
                global_step=epoch,
            )

            if validation_loss < self.best_val_loss:
                self.best_val_loss = validation_loss
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1

            if self.early_stop_counter >= self.early_stopping_range:
                print(f"Early stopping training at epoch: {epoch+1}")
                break
        writer.flush()
        writer.close()
        save_model(model, self.save_path)

    def _test(self, model):
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch in tqdm(self.test_dataloader, desc="Testing"):
                batch = batch.to(DEVICE)
                loss = self.loss_generator(batch)
                test_loss += loss.item()

        test_loss /= len(self.test_dataloader)
        print(f"Test Loss: {test_loss: .4f}")

        return test_loss

    # endregion

The first hierarchy of the model is used to train the semantic transformer used for autoregressive prediction of semantic tokens.
The training is pretty straight forward in fact we only need to take the output of berts tokenization and feed into the transformer.
`train` and `test` are overridden in order to specify which trasformer we're interessed to train and test rispectively.

In [32]:
class SemanticTrainer(Trainer):
    """Trainer class derived from `Trainer`."""

    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        train_dataloader: AudioDataLoader,
        val_dataloader: AudioDataLoader,
        test_dataloader: AudioDataLoader,
        loss: nn.Module,
        optimizer: torch.optim.Optimizer,
        intervals: int,
        save_path: Path,
        early_stop_counter: int,
        early_stopping_range: int,
        epochs: int,
    ):
        """
        Takes as input `semantic_encoder` and `semantic_transformer`.
        They determine the `semantic_modelling`.

        `semantic_encoder` must be trained ahead of time, this trainer only
        trains `semantic_transformer`.

        Args
        ----
            `semantic_encoder` (W2VHuBert)

            `semantic_transformer` (TransformerDecoderOnly)

            `train_dataloader` (AudioDataLoader)

            `val_dataloader` (AudioDataLoader)

            `test_dataloader` (AudioDataLoader)

            `loss` (nn.Module)

            `optimizer` (torch.optim.Optimizer)

            `intervals` (int)

            `save_path` (Path)

            `early_stop_counter` (int)

            `early_stopping_range` (int)

            `epochs` (int)
        """
        super().__init__(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            loss=loss,
            optimizer=optimizer,
            intervals=intervals,
            save_path=save_path,
            early_stop_counter=early_stop_counter,
            early_stopping_range=early_stopping_range,
            epochs=epochs,
        )

    def loss_generator(self, batch):
        semantic_encode = self.semantic_encoder(batch)

        output, target = self.semantic_transformer.fit(semantic_encode)

        loss = self.loss(output, target)
        return loss

    def train(self):
        return self._train(self.semantic_transformer)

    def test(self):
        return self._test(self.semantic_transformer)


training becomes a matter of instatiating the trainer and call train.

In [48]:
semantic_transformer = SemanticTransformer(num_heads=16,layers=12)
semantic_loss = nn.CrossEntropyLoss()
semantic_optimizer = torch.optim.Adam(
        semantic_transformer.parameters(), lr=0.001
)
semantic_trainer = SemanticTrainer(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            loss=semantic_loss,
            optimizer=semantic_optimizer,
            intervals=INTERVALS,
            save_path=MODEL_PATH,
            early_stop_counter=EARLY_STOP_COUNTER,
            early_stopping_range=EARLY_STOPPING_RANGE,
            epochs=EPOCHS,
        )
semantic_trainer.train()

Training:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

  0%|          | 0/68 [00:00<?, ?it/s]

In [42]:
%tensorboard --logdir=../data/runs/Semantic_Transformer/

Reusing TensorBoard on port 6007 (pid 22024), started 0:01:32 ago. (Use '!kill 22024' to kill it.)

The second stage of this hierachy is the coarse acoustic modelling ... informazioni sul coarse.
come funziona la generazione della loss i.e come avviene il condizionamento ecc.

In [63]:
class CoarseAcousticTrainer(Trainer):
    """Trainer class derived from `Trainer`."""

    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        acoustic_encoder_decoder: Encodec,
        coarse_acoustic_transformer: CoarseAcousticTransformer,
        train_dataloader: AudioDataLoader,
        val_dataloader: AudioDataLoader,
        test_dataloader: AudioDataLoader,
        loss: nn.Module,
        optimizer: torch.optim.Optimizer,
        intervals: int,
        save_path: Path,
        early_stop_counter: int,
        early_stopping_range: int,
        epochs: int,
    ):
        """
        Takes as input `semantic_encoder` and `semantic_transformer`.
        They determine the `semantic_modelling`.

        `semantic_encoder` must be trained ahead of time, this trainer only
        trains `semantic_transformer`.

        Args
        ----
            `semantic_encoder` (W2VHuBert)

            `semantic_transformer` (TransformerDecoderOnly)

            `train_dataloader` (AudioDataLoader)

            `val_dataloader` (AudioDataLoader)

            `test_dataloader` (AudioDataLoader)

            `loss` (nn.Module)

            `optimizer` (torch.optim.Optimizer)

            `intervals` (int)

            `save_path` (Path)

            `early_stop_counter` (int)

            `early_stopping_range` (int)

            `epochs` (int)
        """
        super().__init__(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            acoustic_encoder_decoder=acoustic_encoder_decoder,
            coarse_acoustic_transformer=coarse_acoustic_transformer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            loss=loss,
            optimizer=optimizer,
            intervals=intervals,
            save_path=save_path,
            early_stop_counter=early_stop_counter,
            early_stopping_range=early_stopping_range,
            epochs=epochs,
        )

    def loss_generator(self, batch):

        semantic_encode = self.semantic_encoder(batch)
        print(semantic_encode.shape())
        semantic_token = self.semantic_transformer.generate(semantic_encode, 3)
       
        coarse_acoustic_tokens, _, _ = self.acoustic_encoder_decoder.encode(batch)

        conditioning = torch.cat((semantic_token, coarse_acoustic_tokens), dim=1)

        output, target = self.coarse_acoustic_transformer.fit(conditioning)

        loss = self.loss(output, target)
        return loss

    def train(self):
        return self._train(self.coarse_acoustic_transformer)

    def test(self):
        return self._test(self.coarse_acoustic_transformer)


In [67]:
print(MODEL_PATH / "models" / f"{str(type(semantic_transformer).__name__)}.pth")

c:\Users\josed\Documents\AudioLM\notebooks\..\data\models\SemanticTransformer.pth


In [66]:
semantic_transformer = SemanticTransformer()
state_dict = torch.load(
            MODEL_PATH / "models" / f"{str(type(semantic_transformer).__name__)}.pth"
        )
semantic_transformer.load_state_dict(state_dict)
coarse_acoustic_transformer = CoarseAcousticTransformer(num_heads=16,layers=12)
coarse_loss = nn.CrossEntropyLoss()
coarse_optimizer = torch.optim.Adam(
                coarse_acoustic_transformer.parameters(), lr=0.001
            )
coarse_acoustic_trainer = CoarseAcousticTrainer(
                semantic_encoder=semantic_encoder,
                semantic_transformer=semantic_transformer,
                acoustic_encoder_decoder=acoustic_encoder_decoder,
                coarse_acoustic_transformer=coarse_acoustic_transformer,
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                test_dataloader=test_dataloader,
                loss=coarse_loss,
                optimizer=coarse_optimizer,
                intervals=INTERVALS,
                save_path=MODEL_PATH,
                early_stop_counter=EARLY_STOP_COUNTER,
                early_stopping_range=EARLY_STOPPING_RANGE,
                epochs=EPOCHS,
            )
coarse_acoustic_trainer.train()

Training:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/69 [00:00<?, ?it/s]

TypeError: 'torch.Size' object is not callable

In [57]:
print(len(train_dataloader))

275


For the sake of completness we implement the Fine acoustic trainer, even though it won't be used.

In [None]:
class FineAcousticTrainer(Trainer):
    """Trainer class derived from `Trainer`."""

    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        acoustic_encoder_decoder: Encodec,
        coarse_acoustic_transformer: CoarseAcousticTransformer,
        fine_acoustic_transformer: FineAcousticTransformer,
        train_dataloader: AudioDataLoader,
        val_dataloader: AudioDataLoader,
        test_dataloader: AudioDataLoader,
        loss: nn.Module,
        optimizer: torch.optim.Optimizer,
        intervals: int,
        save_path: Path,
        early_stop_counter: int,
        early_stopping_range: int,
        epochs: int,
    ):
        """
        Takes as input `semantic_encoder` and `semantic_transformer`.
        They determine the `semantic_modelling`.

        `semantic_encoder` must be trained ahead of time, this trainer only
        trains `semantic_transformer`.

        Args
        ----
            `semantic_encoder` (W2VHuBert)

            `semantic_transformer` (TransformerDecoderOnly)

            `train_dataloader` (AudioDataLoader)

            `val_dataloader` (AudioDataLoader)

            `test_dataloader` (AudioDataLoader)

            `loss` (nn.Module)

            `optimizer` (torch.optim.Optimizer)

            `intervals` (int)

            `save_path` (Path)

            `early_stop_counter` (int)

            `early_stopping_range` (int)

            `epochs` (int)
        """
        super().__init__(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            acoustic_encoder_decoder=acoustic_encoder_decoder,
            coarse_acoustic_transformer=coarse_acoustic_transformer,
            fine_acoustic_transformer=fine_acoustic_transformer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            loss=loss,
            optimizer=optimizer,
            intervals=intervals,
            save_path=save_path,
            early_stop_counter=early_stop_counter,
            early_stopping_range=early_stopping_range,
            epochs=epochs,
        )

    def loss_generator(self, batch):
        semantic_encode = self.semantic_encoder(batch)
        semantic_token = self.semantic_transformer.generate(semantic_encode, 3)

        coarse_acoustic_tokens, fine_acoustic_tokens, _ = (
            self.acoustic_encoder_decoder.encode(batch)
        )
        coarse_conditioning = torch.cat((semantic_token, coarse_acoustic_tokens), dim=1)
        coarse_tokens = self.coarse_acoustic_transformer.generate(
            coarse_conditioning, 3
        )

        output, target = self.fine_acoustic_transformer(
            torch.cat((coarse_tokens, fine_acoustic_tokens), dim=1)
        )
        loss = self.loss(output, target)
        return loss

    def train(self):
        return self._train(self.fine_acoustic_transformer)

    def test(self):
        return self._test(self.fine_acoustic_transformer)

We can now define the full audiolm model

In [None]:
class AudioLM:
    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        acoustic_encoder_decoder: Encodec,
        coarse_acoustic_transformer: CoarseAcousticTransformer,
        fine_acoustic_transformer: FineAcousticTransformer,
        # https://stackoverflow.com/a/53797072
        *,
        audio_len=1,
        # We set Q' = 4 such that we predict the flattened tokens corresponding
        # to the coarse 4 layers in the second stage.
        n_coarse_quantizers=4,
        # Not specified, but num quantizers must be a power of 2
        # so this is the most reasonable combination.
        n_fine_quantizers=4,
    ) -> None:
        super().__init__()
        self.semantic_encoder = semantic_encoder
        for param in self.semantic_encoder.model.parameters():
            param.requires_grad = False
        self.acoustic_encoder_decoder = acoustic_encoder_decoder
        for param in self.acoustic_encoder_decoder.model.parameters():
            param.requires_grad = False
        self.semantic_transformer = semantic_transformer
        self.coarse_acoustic_transformer = coarse_acoustic_transformer
        self.fine_acoustic_transformer = fine_acoustic_transformer
        self.audio_len = audio_len
        self.n_coarse_quantizers = n_coarse_quantizers
        self.n_fine_quantizers = n_fine_quantizers

    def generate(self, x: torch.Tensor, audio_len: int = 3):
        
        semantic_encode = self.semantic_encoder(x)
        semantic_token = self.semantic_transformer.generate(
            semantic_encode, audio_len * 50
        )

        coarse_acoustic_tokens, fine_acoustic_tokens, audio_scales = (
            self.acoustic_encoder_decoder.encode(x)
        )

        coarse_conditioning = torch.cat((semantic_token, coarse_acoustic_tokens), dim=1)
        coarse_tokens = self.coarse_acoustic_transformer.generate(
            coarse_conditioning, audio_len * 75
        )
        
        output = self.acoustic_encoder_decoder.decode(
            coarse_tokens.unsqueeze(0).unsqueeze(0), [None]
        )
        return output["audio_values"]
    @staticmethod
    def from_pretrained(
        models_path: os.PathLike,
        semantic_encoder : W2VHuBert,
        acoustic_encoder_decoder : Encodec
        ):
        
        semantic_transformer = SemanticTransformer()
        state_dict = torch.load(
            models_path / "models" / f"{str(type(semantic_transformer).__name__)}.pth"
        )
        semantic_transformer.load_state_dict(state_dict)
        
        coarse_acoustic_transformer = CoarseAcousticTransformer()
        state_dict = torch.load(
            models_path
            / "models"
            / f"{str(type(coarse_acoustic_transformer).__name__)}.pth"
        )
        coarse_acoustic_transformer.load_state_dict(state_dict)
        return AudioLM(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            acoustic_encoder_decoder=acoustic_encoder_decoder,
            coarse_acoustic_transformer=coarse_acoustic_transformer,
            fine_acoustic_transformer=None,
        )

We now create a function for the instantation of the model assuming that the transformers are trained ahead of time

In [None]:
audiolm = AudioLM.from_pretrained(MODEL_PATH, semantic_encoder,acoustic_encoder_decoder)
