In [3]:
# Check if running locally or on colab
try:
    import google.colab # type: ignore  # noqa: F401
    # Check if running locally or in colab
    %pip install https://github.com/Haislich/AudioLM/raw/main/dist/audiolm-0.1.0.tar.gz
    # Thanks stack overflow https://stackoverflow.com/questions/77451004/attributeerror-module-numpy-linalg-umath-linalg-has-no-attribute-ilp64#:~:text=pip%20install%20numpy%3D%3D1.23.5
    %pip install numpy==1.23.5
except:  # noqa: E722
    pass


## AudioLM: A Language Modeling Approach to Audio Generation

**AudioLM** represents a self-supervised learning methodology aimed at generating high-quality audio with sustained long-term consistency. This technique maps input audio to a sequence of discrete tokens, treating the audio generation process akin to a language modeling task within this representational framework. The proposal introduces a hybrid tokenization scheme, employing the discretized activations of a pre-trained masked language model on audio to capture long-term structures, and discrete codes from a neural audio codec to ensure high-fidelity synthesis.

#### Tokenization Approach:
The core innovation of this method lies in the **Hybrid Tokenization Scheme**, crucial for subsequent conditioning of transformers.

**Hybrid Tokenization Scheme**: AudioLM combines semantic and acoustic tokens hierarchically to strike a balance between long-term consistency and high-quality audio synthesis. The semantic tokens are extracted from a pre-trained w2v-BERT model, while the acoustic tokens are produced by a SoundStream codec fine-tuned on a speech dataset. The two kinds of tokens are complementary: semantic tokens offer good phonetic discriminability, acoustic tokens offer good reconstruction quality.

<div style="text-align: center;">
    <img src="images/semantic_acoustic_tokens.png" alt="Hybrid Tokenization Scheme" width="400">
</div>

#### Hybrid Tokenization:
As depicted in the image, the hybrid tokenization scheme is divided into two parallel components:
1. **Semantic Tokens: W2V-Bert-based tokenizer**
   This component extracts features from the audio waveform as 1024-dimensional embeddings. A K-means quantizer with 1024 clusters then discretizes these embeddings: each feature vector is assigned to its nearest cluster centroid, and the index of that centroid is the corresponding **semantic token**.

2. **Acoustic Tokens: Soundstream tokenizer**
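   A SoundStream neural codec encodes the waveform and discretizes it with a residual vector quantizer (RVQ), so each frame is represented by a hierarchy of $Q$ tokens, one per quantizer. The first (coarse) quantizers capture fundamental acoustic properties, while the remaining (fine) quantizers refine the details needed for high-fidelity reconstruction. In this notebook the codec role is played by EnCodec (`audiolm.encodec.Encodec`).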

#### Autoregressive Prediction:
The prediction phase is subsequently divided into three distinct stages:
1. **Semantic Stage:** This initial stage trains a decoder-only transformer on sequences of semantic tokens; at inference time, given a sequence of semantic tokens, it continues it autoregressively. Specifically, the first stage models $p(z_t \mid z_{<t})$, i.e. it predicts each semantic token from all the previous ones, in order to capture long-term temporal structure.

<div style="text-align: center;">
    <img src="images/semantic_stage.png" alt="Semantic Stage" width="400">
</div>

2. **Coarse Acoustic Stage:** Like the first stage, this one performs autoregressive prediction. Unlike the first, the second transformer predicts the acoustic tokens produced by the first $Q'$ (here four) quantizers of SoundStream, conditioned on the previously generated semantic tokens. The acoustic tokens have a hierarchical structure, in which the coarse quantizers capture fundamental acoustic properties such as speaker identity and recording conditions. The second stage models $p(y_t^q \mid z, y_{<t}^{\leq Q'}, y_t^{<q})$ for $q \leq Q'$. Its input is the flattened token sequence $(z_1, z_2, \ldots, z_{T_S}, y_1^{1}, y_1^{2}, \ldots, y_1^{Q'}, y_2^{1}, \ldots, y_{T_A}^{Q'})$, and $y_1^{1}$ is the first acoustic token predicted during training.

<div style="text-align: center;">
    <img src="images/acoustic_stage.png" alt="Coarse Acoustic Stage" width="400">
</div>

3. **Fine Acoustic Stage and Decoding:** Finally, this stage predicts the tokens of the remaining (fine) quantizers, $q > Q'$, conditioned on the coarse acoustic tokens. This significantly improves audio quality by removing the lossy compression artifacts that remain when only the coarse quantizers are used. The prompt and the sampled acoustic tokens are then fed into the **SoundStream** decoder to reconstruct a waveform $\hat x$.

<div style="text-align: center;">
    <img src="images/decoding.png" alt="Fine Acoustic Stage and Decoding" width="400">
</div>



## Imports and set up

In [None]:
%load_ext tensorboard
%load_ext autoreload


In [None]:
import os
from pathlib import Path

import torch
from torch import nn

from audiolm.data_preparation import AudioDataLoader
from audiolm.w2v_hubert import W2VHuBert
from audiolm.absolute_transformer import (
    SemanticTransformer,
    CoarseAcousticTransformer,
    FineAcousticTransformer,
)
from audiolm.encodec import Encodec
from audiolm.utils import (
    get_latest_checkpoint_path,
    get_model_path,
    load_checkpoint,
    load_model,
)

In [None]:
# Paths are relative to the notebook's working directory.
# Root for checkpoints, trained models and TensorBoard runs.
SAVE_LOAD_PATH = Path.cwd() / ".." / "data"
# Dataset root; the dataloaders below read its `train`, `val` and `test` splits.
DATA_PATH = SAVE_LOAD_PATH / "datasets"

# Training hyper-parameters shared by every stage.
EPOCHS = 20
EARLY_STOPPING_RANGE = 5  # epochs without validation improvement before stopping
EARLY_STOP_COUNTER = 0  # initial value of the early-stopping counter
INTERVALS = 10  # NOTE(review): passed to the trainers but not read by their loops

## Semantic and Acoustic encoders

In [None]:
# Pre-trained, frozen tokenizers: W2VHuBert yields the semantic representation,
# Encodec is the acoustic codec used both to encode and to decode audio.
semantic_encoder, acoustic_encoder_decoder = W2VHuBert(), Encodec()

## Dataloader

In [None]:
# One loader per split, all with batches of 6; `max_elems` presumably caps how
# many elements each split contributes (kept small here) — TODO confirm.
train_dataloader, val_dataloader, test_dataloader = (
    AudioDataLoader(DATA_PATH / split, batch_size=6, max_elems=limit)
    for split, limit in (("train", 35), ("val", 15), ("test", 5))
)

We first define what an abstract trainer should look like; it provides the skeleton for the stage-specific trainers.
This is needed because the framework uses a hierarchical model: at each level of the hierarchy only one transformer is trained, yet much of the training logic is the same.
We exploit this by creating a specialized subclass for the training of each transformer.
The loss generator (`loss_generator`) is the component that changes between specializations, since it defines how inputs and targets are built at each level of the hierarchy.

## Trainer class

In [None]:
from math import ceil
from abc import ABC, abstractmethod
from typing import Optional

from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm


from audiolm.encodec import Encodec
from audiolm.constants import DEVICE
from audiolm.data_preparation import AudioDataLoader
from audiolm.w2v_hubert import W2VHuBert
from audiolm.utils import save_checkpoint, save_model

class Trainer(ABC):
    """
    Abstract base class for training one transformer of the AudioLM hierarchy.

    Subclasses choose which transformer is trained and implement
    `loss_generator`, i.e. how a batch of audio is turned into a loss. This
    class provides the shared training, validation and test loops,
    per-epoch checkpointing, TensorBoard logging and early stopping.
    """

    @abstractmethod
    # pylint: disable =too-many-arguments
    def __init__(
        self,
        semantic_encoder: Optional[W2VHuBert] = None,
        semantic_transformer: Optional[SemanticTransformer] = None,
        acoustic_encoder_decoder: Optional[Encodec] = None,
        coarse_acoustic_transformer: Optional[CoarseAcousticTransformer] = None,
        fine_acoustic_transformer: Optional[FineAcousticTransformer] = None,
        train_dataloader: Optional[AudioDataLoader] = None,
        val_dataloader: Optional[AudioDataLoader] = None,
        test_dataloader: Optional[AudioDataLoader] = None,
        loss: Optional[nn.Module] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        intervals: Optional[int] = None,
        save_path: Optional[os.PathLike] = None,
        early_stop_counter: Optional[int] = None,
        early_stopping_range: Optional[int] = None,
        epochs: Optional[int] = None,
    ):
        """Store the components shared by every stage and create `save_path`."""
        self.semantic_encoder = semantic_encoder
        self.semantic_transformer = semantic_transformer
        self.acoustic_encoder_decoder = acoustic_encoder_decoder
        self.coarse_acoustic_transformer = coarse_acoustic_transformer
        self.fine_acoustic_transformer = fine_acoustic_transformer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optimizer
        self.intervals = intervals
        self.epochs = epochs
        self.save_path = save_path
        self.best_val_loss = float("inf")
        self.early_stopping_range = early_stopping_range
        self.early_stop_counter = early_stop_counter
        self.loss = loss
        if save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    # region: Abstract methods, this methods must be redefined accordingly.
    @abstractmethod
    def loss_generator(self, batch):
        """Build the inputs/targets for the trained transformer from `batch` and return the loss."""

    @abstractmethod
    def train(self):
        """
        Train the Transformer model.
        """

    @abstractmethod
    def test(self):
        """Test the model on the test dataset."""

    # endregion

    # region: Private methods.
    def _train_step(self, model: nn.Module) -> float:
        """Run one optimisation pass over `train_dataloader`; return the mean per-batch loss."""
        model.train()
        total_loss = 0.0
        n_batches = 0
        # `len(train_dataloader)` is divided by `batch_size` to size the progress
        # bar, which suggests it counts elements rather than batches — TODO confirm.
        expected_batches = ceil(
            len(self.train_dataloader) / self.train_dataloader.batch_size
        )
        for batch in tqdm(self.train_dataloader, total=expected_batches):
            batch = batch.to(DEVICE)
            loss = self.loss_generator(batch)
            total_loss += loss.item()
            n_batches += 1
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        # Average over the batches actually seen: independent of what
        # `len(dataloader)` counts, and safe for an empty loader.
        return total_loss / max(n_batches, 1)

    def _validation_step(self, model: nn.Module) -> float:
        """Evaluate on `val_dataloader` without gradients; return the mean per-batch loss."""
        model.eval()
        total_loss = 0.0
        n_batches = 0
        with torch.inference_mode():
            for batch in self.val_dataloader:
                batch = batch.to(DEVICE)
                loss = self.loss_generator(batch)
                total_loss += loss.item()
                n_batches += 1
        return total_loss / max(n_batches, 1)

    def _train(self, model: nn.Module):
        """
        Train `model` for up to `self.epochs` epochs.

        After every epoch a checkpoint is saved and the train/validation losses
        are logged to TensorBoard under `<save_path>/runs/<ModelClassName>`.
        Training stops early once the validation loss has not improved for
        `early_stopping_range` consecutive epochs. The final model is saved
        with `save_model`.
        """
        model_name = type(model).__name__
        writer = SummaryWriter(Path(self.save_path) / "runs" / model_name)
        try:
            for epoch in tqdm(range(self.epochs), total=self.epochs, desc="Training"):
                train_loss = self._train_step(model)
                validation_loss = self._validation_step(model)
                print("SAVING CHECKPOINT...")
                save_checkpoint(
                    model, epoch, self.optimizer, self.early_stop_counter, self.save_path
                )
                print("SAVING RUN FOR TENSORBOARD...")
                writer.add_scalars(
                    main_tag=f"Loss_{model_name}",
                    tag_scalar_dict={
                        "train_loss": train_loss,
                        "validation_loss": validation_loss,
                    },
                    global_step=epoch,
                )

                if validation_loss < self.best_val_loss:
                    self.best_val_loss = validation_loss
                    self.early_stop_counter = 0
                else:
                    self.early_stop_counter += 1

                if self.early_stop_counter >= self.early_stopping_range:
                    print(f"Early stopping training at epoch: {epoch+1}")
                    break
        finally:
            # Make sure the event file is flushed and closed even if training fails.
            writer.flush()
            writer.close()
        save_model(model, self.save_path)

    def _test(self, model):
        """Evaluate on `test_dataloader`, print and return the mean per-batch loss."""
        model.eval()
        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for batch in tqdm(self.test_dataloader, desc="Testing"):
                batch = batch.to(DEVICE)
                loss = self.loss_generator(batch)
                total_loss += loss.item()
                n_batches += 1

        test_loss = total_loss / max(n_batches, 1)
        print(f"Test Loss: {test_loss: .4f}")

        return test_loss
    # endregion

## Semantic Trainer class

In [None]:
class SemanticTrainer(Trainer):
    """
    Trainer for the first (semantic) stage of the hierarchy.

    Only `semantic_transformer` is optimised; `semantic_encoder` is
    pre-trained and used to produce the transformer's inputs.
    """

    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        train_dataloader: AudioDataLoader,
        val_dataloader: AudioDataLoader,
        test_dataloader: AudioDataLoader,
        loss: nn.Module,
        optimizer: torch.optim.Optimizer,
        intervals: int,
        save_path: Path,
        early_stop_counter: int,
        early_stopping_range: int,
        epochs: int,
    ):
        """
        Takes as input `semantic_encoder` and `semantic_transformer`.
        Together they determine the semantic modelling.

        `semantic_encoder` must be trained ahead of time; this trainer only
        trains `semantic_transformer`.

        Args
        ----
            `semantic_encoder` (W2VHuBert)

            `semantic_transformer` (SemanticTransformer)

            `train_dataloader` (AudioDataLoader)

            `val_dataloader` (AudioDataLoader)

            `test_dataloader` (AudioDataLoader)

            `loss` (nn.Module)

            `optimizer` (torch.optim.Optimizer)

            `intervals` (int)

            `save_path` (Path)

            `early_stop_counter` (int)

            `early_stopping_range` (int)

            `epochs` (int)
        """
        super().__init__(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            loss=loss,
            optimizer=optimizer,
            intervals=intervals,
            save_path=save_path,
            early_stop_counter=early_stop_counter,
            early_stopping_range=early_stopping_range,
            epochs=epochs,
        )
        # Keep the result of `.to` on `self`: assigning it to the local argument
        # (as before) discarded it whenever `.to` returns a new object.
        self.semantic_encoder = self.semantic_encoder.to(DEVICE)
        self.semantic_transformer = self.semantic_transformer.to(DEVICE)

    def loss_generator(self, batch):
        """Encode `batch` semantically and return the semantic transformer's loss on it."""
        # The encoder is pre-trained and not optimised here: no autograd graph needed.
        with torch.no_grad():
            semantic_encode = self.semantic_encoder(batch)

        output, target = self.semantic_transformer.fit(semantic_encode)
        return self.loss(output, target)

    def train(self):
        return self._train(self.semantic_transformer)

    def test(self):
        return self._test(self.semantic_transformer)


In [None]:
# TRAIN_SEMANTIC: (re)train the semantic transformer. When False, an already-saved
# model is loaded if one exists; otherwise training still happens.
TRAIN_SEMANTIC = False
# RESUME_SEMANTIC_TRAINING: when training, start from the latest checkpoint if any.
RESUME_SEMANTIC_TRAINING = True

semantic_transformer = SemanticTransformer(num_heads=4, layers=2)
semantic_loss = nn.CrossEntropyLoss()
semantic_optimizer = torch.optim.Adam(semantic_transformer.parameters(), lr=0.001)

# Models and checkpoints live under <SAVE_LOAD_PATH>/models/<ModelClassName>.
semantic_transformer_root = (
    Path(SAVE_LOAD_PATH) / "models" / type(semantic_transformer).__name__
)
semantic_transformer_path = get_model_path(semantic_transformer_root)
checkpoint_path = get_latest_checkpoint_path(semantic_transformer_root)

if TRAIN_SEMANTIC or not semantic_transformer_path:
    if RESUME_SEMANTIC_TRAINING and checkpoint_path:
        print("Starting from the last epoch")
        semantic_transformer, _, semantic_optimizer, _ = load_checkpoint(
            semantic_transformer, semantic_transformer_root
        )
    semantic_trainer = SemanticTrainer(
        semantic_encoder=semantic_encoder,
        semantic_transformer=semantic_transformer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
        loss=semantic_loss,
        optimizer=semantic_optimizer,
        intervals=INTERVALS,
        save_path=SAVE_LOAD_PATH,
        early_stop_counter=EARLY_STOP_COUNTER,
        early_stopping_range=EARLY_STOPPING_RANGE,
        epochs=EPOCHS,
    )
    semantic_trainer.train()
else:
    load_model(semantic_transformer, semantic_transformer_path)



In [None]:
%tensorboard --logdir=../data/runs/Semantic_Transformer/

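The second stage of the hierarchy is coarse acoustic modelling. For each batch, the trainer encodes the audio with the semantic encoder and lets the pre-trained semantic transformer generate semantic tokens. It then encodes the same audio with the acoustic codec to obtain the coarse acoustic tokens.
The two token sequences are concatenated (semantic tokens first) and passed to the coarse acoustic transformer's `fit`, whose output and target are compared by the loss. This concatenation is how the coarse stage is conditioned on the semantic tokens.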

## Coarse Acoustic Trainer

In [None]:
class CoarseAcousticTrainer(Trainer):
    """
    Trainer for the second (coarse acoustic) stage of the hierarchy.

    Only `coarse_acoustic_transformer` is optimised. The semantic encoder,
    the semantic transformer and the acoustic codec are expected to be
    trained ahead of time and are only used to build the conditioning.
    """

    # Semantic tokens generated per second of requested audio. Matches the `* 50`
    # used throughout the notebook; presumably the semantic token rate — TODO confirm.
    SEMANTIC_TOKENS_PER_SECOND = 50

    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        acoustic_encoder_decoder: Encodec,
        coarse_acoustic_transformer: CoarseAcousticTransformer,
        train_dataloader: AudioDataLoader,
        val_dataloader: AudioDataLoader,
        test_dataloader: AudioDataLoader,
        loss: nn.Module,
        optimizer: torch.optim.Optimizer,
        intervals: int,
        save_path: Path,
        early_stop_counter: int,
        early_stopping_range: int,
        generate_audio_len: int,
        epochs: int,
    ):
        """
        `semantic_encoder`, `semantic_transformer` and `acoustic_encoder_decoder`
        must be trained ahead of time; this trainer only trains
        `coarse_acoustic_transformer`.

        Args
        ----
            `semantic_encoder` (W2VHuBert)

            `semantic_transformer` (SemanticTransformer): pre-trained, used to
            generate the semantic conditioning tokens.

            `acoustic_encoder_decoder` (Encodec): provides the coarse acoustic tokens.

            `coarse_acoustic_transformer` (CoarseAcousticTransformer): the model trained here.

            `train_dataloader` (AudioDataLoader)

            `val_dataloader` (AudioDataLoader)

            `test_dataloader` (AudioDataLoader)

            `loss` (nn.Module)

            `optimizer` (torch.optim.Optimizer)

            `intervals` (int)

            `save_path` (Path)

            `early_stop_counter` (int)

            `early_stopping_range` (int)

            `generate_audio_len` (int): seconds of semantic tokens to generate
            for the conditioning.

            `epochs` (int)
        """

        super().__init__(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            acoustic_encoder_decoder=acoustic_encoder_decoder,
            coarse_acoustic_transformer=coarse_acoustic_transformer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            loss=loss,
            optimizer=optimizer,
            intervals=intervals,
            save_path=save_path,
            early_stop_counter=early_stop_counter,
            early_stopping_range=early_stopping_range,
            epochs=epochs,
        )
        # Keep the result of `.to` on `self` (it may be a new object for wrappers).
        self.semantic_encoder = self.semantic_encoder.to(DEVICE)
        self.semantic_transformer = self.semantic_transformer.to(DEVICE)
        self.acoustic_encoder_decoder = self.acoustic_encoder_decoder.to(DEVICE)
        self.coarse_acoustic_transformer = self.coarse_acoustic_transformer.to(DEVICE)
        # The semantic transformer is frozen here: generate without dropout.
        # (`_train` only toggles the mode of the coarse transformer.)
        self.semantic_transformer.eval()
        self.generate_audio_len = generate_audio_len

    def loss_generator(self, batch):
        """Build semantic + coarse acoustic conditioning for `batch` and return the coarse loss."""
        # The conditioning comes from frozen models: no autograd graph needed.
        with torch.no_grad():
            semantic_encode = self.semantic_encoder(batch)
            semantic_token = self.semantic_transformer.generate(
                semantic_encode,
                self.generate_audio_len * self.SEMANTIC_TOKENS_PER_SECOND,
            )
            coarse_acoustic_tokens, _, _ = self.acoustic_encoder_decoder.encode(batch)
            conditioning = torch.cat((semantic_token, coarse_acoustic_tokens), dim=1)

        output, target = self.coarse_acoustic_transformer.fit(conditioning)
        return self.loss(output, target)

    def train(self):
        return self._train(self.coarse_acoustic_transformer)

    def test(self):
        return self._test(self.coarse_acoustic_transformer)


In [None]:
# TRAIN_COARSE: (re)train the coarse acoustic transformer. When False, an
# already-saved model is loaded if one exists; otherwise training still happens.
TRAIN_COARSE = False
# RESUME_COARSE_TRAINING: when training, start from the latest coarse checkpoint if any.
RESUME_COARSE_TRAINING = True
# GENERATE_AUDIO_LEN: length in seconds of the semantic tokens generated as conditioning.
GENERATE_AUDIO_LEN = 3

# The coarse stage is conditioned on the semantic stage, so the trained semantic
# transformer is required; its hyper-parameters must match its training cell.
semantic_transformer = SemanticTransformer(num_heads=4, layers=2)
semantic_transformer_root = (
    Path(SAVE_LOAD_PATH) / "models" / type(semantic_transformer).__name__
)
semantic_transformer_path = get_model_path(semantic_transformer_root)
if not semantic_transformer_path:
    raise FileNotFoundError(
        f"No trained {type(semantic_transformer).__name__} found under "
        f"{semantic_transformer_root}; run the semantic training cell first."
    )
load_model(semantic_transformer, semantic_transformer_path)

coarse_acoustic_transformer = CoarseAcousticTransformer()
coarse_loss = nn.CrossEntropyLoss()
coarse_optimizer = torch.optim.Adam(
    coarse_acoustic_transformer.parameters(), lr=0.001
)

coarse_transformer_root = (
    Path(SAVE_LOAD_PATH) / "models" / type(coarse_acoustic_transformer).__name__
)
coarse_transformer_path = get_model_path(coarse_transformer_root)
checkpoint_path = get_latest_checkpoint_path(coarse_transformer_root)

if not TRAIN_COARSE and coarse_transformer_path:
    load_model(coarse_acoustic_transformer, coarse_transformer_path)
else:
    if RESUME_COARSE_TRAINING and checkpoint_path:
        print("Starting from the last epoch")
        # Resume the coarse transformer from its own checkpoint; the semantic
        # transformer loaded above must stay untouched.
        coarse_acoustic_transformer, _, coarse_optimizer, _ = load_checkpoint(
            coarse_acoustic_transformer, coarse_transformer_root
        )
    coarse_acoustic_trainer = CoarseAcousticTrainer(
        semantic_encoder=semantic_encoder,
        semantic_transformer=semantic_transformer,
        acoustic_encoder_decoder=acoustic_encoder_decoder,
        coarse_acoustic_transformer=coarse_acoustic_transformer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
        loss=coarse_loss,
        optimizer=coarse_optimizer,
        intervals=INTERVALS,
        save_path=SAVE_LOAD_PATH,
        early_stop_counter=EARLY_STOP_COUNTER,
        early_stopping_range=EARLY_STOPPING_RANGE,
        generate_audio_len=GENERATE_AUDIO_LEN,
        epochs=EPOCHS,
    )
    coarse_acoustic_trainer.train()

For the sake of completeness we also implement the fine acoustic trainer, even though it is not used in this notebook.

In [None]:
class FineAcousticTrainer(Trainer):
    """
    Trainer for the third (fine acoustic) stage of the hierarchy.

    Only `fine_acoustic_transformer` is optimised. The semantic encoder, the
    semantic transformer, the acoustic codec and the coarse acoustic
    transformer are expected to be trained ahead of time and are only used
    to build the conditioning.
    """

    # Tokens generated per second of requested audio by the semantic and coarse
    # stages. They match the `* 50` / `* 75` factors used elsewhere in the notebook;
    # presumably the token rates of the semantic encoder and the codec — TODO confirm.
    SEMANTIC_TOKENS_PER_SECOND = 50
    ACOUSTIC_TOKENS_PER_SECOND = 75

    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        acoustic_encoder_decoder: Encodec,
        coarse_acoustic_transformer: CoarseAcousticTransformer,
        fine_acoustic_transformer: FineAcousticTransformer,
        train_dataloader: AudioDataLoader,
        val_dataloader: AudioDataLoader,
        test_dataloader: AudioDataLoader,
        loss: nn.Module,
        optimizer: torch.optim.Optimizer,
        intervals: int,
        save_path: Path,
        early_stop_counter: int,
        early_stopping_range: int,
        generate_audio_len: int,
        epochs: int,
    ):
        """
        All components except `fine_acoustic_transformer` must be trained
        ahead of time; this trainer only trains `fine_acoustic_transformer`.

        Args
        ----
            `semantic_encoder` (W2VHuBert)

            `semantic_transformer` (SemanticTransformer)

            `acoustic_encoder_decoder` (Encodec): provides coarse and fine acoustic tokens.

            `coarse_acoustic_transformer` (CoarseAcousticTransformer)

            `fine_acoustic_transformer` (FineAcousticTransformer): the model trained here.

            `train_dataloader` (AudioDataLoader)

            `val_dataloader` (AudioDataLoader)

            `test_dataloader` (AudioDataLoader)

            `loss` (nn.Module)

            `optimizer` (torch.optim.Optimizer)

            `intervals` (int)

            `save_path` (Path)

            `early_stop_counter` (int)

            `early_stopping_range` (int)

            `generate_audio_len` (int): seconds of semantic/coarse tokens to
            generate for the conditioning.

            `epochs` (int)
        """
        super().__init__(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            acoustic_encoder_decoder=acoustic_encoder_decoder,
            coarse_acoustic_transformer=coarse_acoustic_transformer,
            fine_acoustic_transformer=fine_acoustic_transformer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            loss=loss,
            optimizer=optimizer,
            intervals=intervals,
            save_path=save_path,
            early_stop_counter=early_stop_counter,
            early_stopping_range=early_stopping_range,
            epochs=epochs,
        )
        # Batches are moved to DEVICE by the training loops, so every model must
        # live there too (as in the other trainers).
        self.semantic_encoder = self.semantic_encoder.to(DEVICE)
        self.semantic_transformer = self.semantic_transformer.to(DEVICE)
        self.acoustic_encoder_decoder = self.acoustic_encoder_decoder.to(DEVICE)
        self.coarse_acoustic_transformer = self.coarse_acoustic_transformer.to(DEVICE)
        self.fine_acoustic_transformer = self.fine_acoustic_transformer.to(DEVICE)
        # Upstream transformers are frozen here: generate without dropout.
        self.semantic_transformer.eval()
        self.coarse_acoustic_transformer.eval()
        self.generate_audio_len = generate_audio_len

    def loss_generator(self, batch):
        """Build coarse-token conditioning for `batch` and return the fine transformer's loss."""
        # The conditioning comes from frozen models: no autograd graph needed.
        with torch.no_grad():
            semantic_encode = self.semantic_encoder(batch)
            semantic_token = self.semantic_transformer.generate(
                semantic_encode,
                self.generate_audio_len * self.SEMANTIC_TOKENS_PER_SECOND,
            )
            coarse_acoustic_tokens, fine_acoustic_tokens, _ = (
                self.acoustic_encoder_decoder.encode(batch)
            )
            coarse_conditioning = torch.cat(
                (semantic_token, coarse_acoustic_tokens), dim=1
            )
            coarse_tokens = self.coarse_acoustic_transformer.generate(
                coarse_conditioning,
                self.generate_audio_len * self.ACOUSTIC_TOKENS_PER_SECOND,
            )

        # Use `fit`, like the other stages, to obtain the (output, target) pair;
        # calling the module directly does not return that pair.
        output, target = self.fine_acoustic_transformer.fit(
            torch.cat((coarse_tokens, fine_acoustic_tokens), dim=1)
        )
        return self.loss(output, target)

    def train(self):
        return self._train(self.fine_acoustic_transformer)

    def test(self):
        return self._test(self.fine_acoustic_transformer)

We can now define the full AudioLM model.

In [None]:
class AudioLM:
    """
    Inference wrapper chaining the AudioLM stages.

    The frozen semantic encoder and the semantic transformer produce semantic
    tokens. These condition the coarse acoustic transformer, and the resulting
    coarse acoustic tokens are decoded back to audio by the acoustic codec.
    The fine acoustic transformer is stored but not used by `generate`.
    """

    def __init__(
        self,
        semantic_encoder: W2VHuBert,
        semantic_transformer: SemanticTransformer,
        acoustic_encoder_decoder: Encodec,
        coarse_acoustic_transformer: CoarseAcousticTransformer,
        fine_acoustic_transformer: FineAcousticTransformer,
        # https://stackoverflow.com/a/53797072
        *,
        audio_len=1,
        # We set Q' = 4 such that we predict the flattened tokens corresponding
        # to the coarse 4 layers in the second stage.
        n_coarse_quantizers=4,
        # Not specified, but num quantizers must be a power of 2
        # so this is the most reasonable combination.
        n_fine_quantizers=4,
    ) -> None:
        """Store the stages and freeze the parameters of both encoders' `.model`."""
        super().__init__()
        self.semantic_encoder = semantic_encoder
        for param in self.semantic_encoder.model.parameters():
            param.requires_grad = False
        self.acoustic_encoder_decoder = acoustic_encoder_decoder
        for param in self.acoustic_encoder_decoder.model.parameters():
            param.requires_grad = False
        self.semantic_transformer = semantic_transformer
        self.coarse_acoustic_transformer = coarse_acoustic_transformer
        self.fine_acoustic_transformer = fine_acoustic_transformer
        # NOTE(review): `audio_len` is stored but `generate` uses its own argument.
        self.audio_len = audio_len
        self.n_coarse_quantizers = n_coarse_quantizers
        self.n_fine_quantizers = n_fine_quantizers

    def generate(self, x: torch.Tensor, audio_len: int = 3):
        """
        Generate audio conditioned on the prompt `x`.

        Args
        ----
            `x` (torch.Tensor): audio prompt in the format accepted by both
            encoders (presumably a batch of waveforms — TODO confirm).

            `audio_len` (int): seconds to generate; sets how many semantic
            (`audio_len * 50`) and coarse acoustic (`audio_len * 75`) tokens
            are sampled.

        Returns
        -------
            The `"audio_values"` entry of the codec decoder's output.
        """
        # Pure inference: no autograd graph is needed for any stage.
        with torch.no_grad():
            semantic_encode = self.semantic_encoder(x)
            semantic_token = self.semantic_transformer.generate(
                semantic_encode, audio_len * 50
            )

            coarse_acoustic_tokens, _, _ = self.acoustic_encoder_decoder.encode(x)

            coarse_conditioning = torch.cat(
                (semantic_token, coarse_acoustic_tokens), dim=1
            )
            coarse_tokens = self.coarse_acoustic_transformer.generate(
                coarse_conditioning, audio_len * 75
            )

            output = self.acoustic_encoder_decoder.decode(
                coarse_tokens.unsqueeze(0).unsqueeze(0), [None]
            )
        return output["audio_values"]

    @staticmethod
    def from_pretrained(
        models_path: os.PathLike,
        semantic_encoder: W2VHuBert,
        acoustic_encoder_decoder: Encodec,
        semantic_transformer: Optional[SemanticTransformer] = None,
        coarse_acoustic_transformer: Optional[CoarseAcousticTransformer] = None,
    ) -> "AudioLM":
        """
        Build an `AudioLM` from transformers trained with this notebook.

        Weights are looked up exactly like the training cells do, i.e. with
        `get_model_path(<models_path>/models/<ModelClassName>)` and loaded in
        place with `load_model`.

        Args
        ----
            `models_path` (os.PathLike): the save/load root used for training.

            `semantic_encoder` (W2VHuBert)

            `acoustic_encoder_decoder` (Encodec)

            `semantic_transformer` (SemanticTransformer, optional): model
            instance to load into; defaults to the architecture used in the
            semantic training cell (`num_heads=4, layers=2`).

            `coarse_acoustic_transformer` (CoarseAcousticTransformer, optional):
            model instance to load into; defaults to `CoarseAcousticTransformer()`.

        Raises
        ------
            FileNotFoundError: if no trained model is found for a stage.
        """
        models_root = Path(models_path) / "models"
        if semantic_transformer is None:
            # Must match the training hyper-parameters for the weights to load.
            semantic_transformer = SemanticTransformer(num_heads=4, layers=2)
        if coarse_acoustic_transformer is None:
            coarse_acoustic_transformer = CoarseAcousticTransformer()

        for transformer in (semantic_transformer, coarse_acoustic_transformer):
            model_name = type(transformer).__name__
            model_root = models_root / model_name
            model_path = get_model_path(model_root)
            if not model_path:
                raise FileNotFoundError(
                    f"No trained {model_name} found under {model_root}"
                )
            load_model(transformer, model_path)
            # Inference only: disable dropout and similar train-time behaviour.
            transformer.eval()

        return AudioLM(
            semantic_encoder=semantic_encoder,
            semantic_transformer=semantic_transformer,
            acoustic_encoder_decoder=acoustic_encoder_decoder,
            coarse_acoustic_transformer=coarse_acoustic_transformer,
            fine_acoustic_transformer=None,
        )

We now instantiate the model with `AudioLM.from_pretrained`, assuming that the transformers have been trained ahead of time.

In [None]:
# The trained transformers live under SAVE_LOAD_PATH (see the training cells);
# `MODEL_PATH` was never defined.
audiolm = AudioLM.from_pretrained(SAVE_LOAD_PATH, semantic_encoder, acoustic_encoder_decoder)
