In [None]:
import torch
import torchaudio
from encodec import EncodecModel
from typing import List, Tuple, Optional

# Type alias for one encoded frame as used in this file: a tensor of codebook
# indices paired with an optional second tensor (the frame's scale; this file
# stores None there when rebuilding frames from text).
EncodedFrame = Tuple[torch.Tensor, Optional[torch.Tensor]]

def load_audio(file_path):
    """Read an audio file with torchaudio.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        A ``(waveform, sample_rate)`` tuple exactly as produced by
        ``torchaudio.load``.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    return (waveform, sample_rate)

def get_model_and_sr(channels, bandwidth):
    """Pick the pretrained Encodec model and sample rate for a channel count.

    One channel selects the 24 kHz model, two channels select the 48 kHz
    model. The bandwidth is checked against the values accepted for the
    selected model before the model itself is instantiated.

    Args:
        channels: Number of audio channels (1 or 2).
        bandwidth: Requested target bandwidth.

    Returns:
        A ``(model, target_sr)`` tuple: the pretrained model and the sample
        rate (24000 or 48000) audio must have before being encoded with it.

    Raises:
        ValueError: If ``channels`` is neither 1 nor 2, or if ``bandwidth``
            is not accepted for the selected model.
    """
    # Reject unsupported layouts up front so the remaining code only has to
    # distinguish mono from stereo.
    if not (channels == 1 or channels == 2):
        raise ValueError("Unsupported number of channels: Encodec supports only mono or stereo.")

    if channels == 1:
        mono_bandwidths = (1.5, 3, 6, 12, 24)
        if bandwidth not in mono_bandwidths:
            raise ValueError("Unsupported bandwidth for mono audio.")
        return EncodecModel.encodec_model_24khz(pretrained=True), 24000

    stereo_bandwidths = (3, 6, 12, 24)
    if bandwidth not in stereo_bandwidths:
        raise ValueError("Unsupported bandwidth for stereo audio.")
    return EncodecModel.encodec_model_48khz(pretrained=True), 48000

def resample_audio(audio, original_sr, target_sr):
    """Bring audio to ``target_sr``, leaving it untouched when already there.

    Args:
        audio: Audio data to resample.
        original_sr: Sample rate the audio currently has.
        target_sr: Sample rate the audio should have.

    Returns:
        The input object itself when both rates are equal, otherwise the
        output of ``torchaudio.transforms.Resample(original_sr, target_sr)``
        applied to ``audio``.
    """
    # Nothing to do when the rates already agree.
    if original_sr == target_sr:
        return audio

    resampler = torchaudio.transforms.Resample(original_sr, target_sr)
    return resampler(audio)

def process_audio(input_file, bandwidth, embedding_file, output_file=None):
    """
    Process an audio file to encode it to embeddings and optionally decode it back.

    The audio is loaded, the model is chosen from its channel count, the audio
    is resampled to that model's sample rate and then encoded at the requested
    bandwidth. The codebook indices of every encoded frame are written to
    ``embedding_file`` as text: line 1 holds the channel count, line 2 the
    bandwidth, and each following line holds the space-separated indices of
    one codebook of one frame. Only the first element of each encoded frame
    is written; the second element is discarded.

    Args:
        input_file (str): Path to the input audio file.
        bandwidth (float): Bandwidth to be used for the model.
        embedding_file (str): Path where embeddings should be saved.
        output_file (str, optional): Path to save the decoded audio. If None
            (or otherwise falsy), no audio is saved.

    Raises:
        ValueError: If the channel count or bandwidth is rejected by
            ``get_model_and_sr``.
    """
    audio, sr = load_audio(input_file)

    # The first axis of the loaded tensor is used as the channel count.
    channels = audio.shape[0]

    model, target_sr = get_model_and_sr(channels, bandwidth)
    audio = resample_audio(audio, sr, target_sr)
    model.set_target_bandwidth(bandwidth)

    # Inference only: no gradients are needed for encoding.
    with torch.no_grad():
        encoded_frames = model.encode(audio.unsqueeze(0))  # add batch dimension

    with open(embedding_file, 'w') as f:
        f.write(f"{channels}\n{bandwidth}\n")
        for codes, _ in encoded_frames:
            # Batch size is 1, so codes[0] holds one row of indices per codebook.
            for row in codes[0].cpu().numpy():
                f.write(' '.join(map(str, row)) + '\n')

    print(f"Embeddings saved to: {embedding_file}")

    # If an output file is provided, decode the embeddings back to audio and save it
    if output_file:
        # Decode under no_grad, as decode_audio_from_embeddings does: this
        # avoids building an autograd graph and produces a tensor that does
        # not require grad before it is handed to torchaudio.save.
        with torch.no_grad():
            output_audio = model.decode(encoded_frames)
        torchaudio.save(output_file, output_audio[0], target_sr)
        print(f"Output audio saved to: {output_file}")


def read_embeddings_from_txt(file_path: str, num_codebooks: int) -> "List[EncodedFrame]":
    """
    Reads embeddings from a .txt file and converts them to a decodable format.

    The first two lines of the file (channel count and bandwidth) are skipped.
    Every remaining non-blank line must contain the space-separated integer
    indices of one codebook; each run of ``num_codebooks`` consecutive lines
    is stacked into one frame. All lines belonging to the same frame must
    contain the same number of indices, since they are combined with
    ``torch.stack``.

    Args:
        file_path (str): Path to the .txt file containing the embeddings.
        num_codebooks (int): Number of codebooks (lines) that make up one frame.

    Returns:
        List[EncodedFrame]: One ``(codes, None)`` tuple per frame, where
        ``codes`` is a ``torch.long`` tensor of shape
        ``[1, num_codebooks, T]`` and ``None`` stands for "no scale".

    Raises:
        ValueError: If ``num_codebooks`` is smaller than 1, if a line holds a
            token that is not an integer, or if the file ends in the middle
            of a frame (the number of data lines is not a multiple of
            ``num_codebooks``).
    """
    if num_codebooks < 1:
        raise ValueError(f"num_codebooks must be at least 1, got {num_codebooks!r}.")

    encoded_frames: List[EncodedFrame] = []

    # Codebook rows collected so far for the frame currently being assembled.
    pending_rows = []

    with open(file_path, 'r') as file:
        for line_number, line in enumerate(file):
            # Skip the first two lines (channels and bandwidth)
            if line_number < 2:
                continue

            tokens = line.split()
            if not tokens:
                continue  # Skip empty lines

            pending_rows.append(
                torch.tensor([int(token) for token in tokens], dtype=torch.long)
            )

            # Once all codebooks of a frame are collected, build the frame
            # tensor of shape [1, K, T_f] and start the next frame.
            if len(pending_rows) == num_codebooks:
                frame_tensor = torch.stack(pending_rows, dim=0).unsqueeze(0)
                encoded_frames.append((frame_tensor, None))
                pending_rows = []

    # Leftover rows mean the file is truncated or num_codebooks does not match
    # the file; report it instead of silently dropping the data.
    if pending_rows:
        raise ValueError(
            f"Incomplete frame at end of {file_path!r}: got {len(pending_rows)} "
            f"codebook line(s), expected {num_codebooks}."
        )

    return encoded_frames

def decode_audio_from_embeddings(embedding_file, output_file):
    """
    Rebuild an audio file from embeddings stored in a text file.

    The first line of the embedding file is parsed as the channel count and
    the second as the bandwidth; both select and configure the model. The
    remaining lines are read back as encoded frames, decoded, and written
    to ``output_file``.

    Args:
        embedding_file (str): Path to the .txt file containing the embeddings.
        output_file (str): Path to save the decoded audio.
    """
    with open(embedding_file, 'r') as f:
        header_lines = f.readlines()

    # Line 1: channel count, line 2: bandwidth.
    channels = int(header_lines[0].strip())
    bandwidth = float(header_lines[1].strip())

    # Select the model for this channel count and configure its bandwidth.
    model, target_sr = get_model_and_sr(channels, bandwidth)
    model.set_target_bandwidth(bandwidth)

    # Codebooks per frame: 1.5 kbps -> 2, 3 kbps -> 4, ... for mono; stereo
    # uses half as many at the same bandwidth.
    codebooks_per_frame = int(bandwidth * 4 / (3 * channels))

    frames = read_embeddings_from_txt(embedding_file, codebooks_per_frame)

    # Inference only, so gradients are disabled while decoding.
    with torch.no_grad():
        decoded_batch = model.decode(frames)

    # Drop the batch dimension before writing the audio out.
    waveform = decoded_batch[0]
    torchaudio.save(output_file, waveform, target_sr)

    print(f"Decoded audio saved to: {output_file}")


In [3]:
import torch
from encodec import EncodecModel

def get_codebooks(audio_type):
    """
    Load a pretrained Encodec model and stack its quantizer codebooks.

    Parameters:
        audio_type (str): Which model to load.
                          'mono' or 'encodec_24khz' selects the 24kHz model,
                          'stereo' or 'encodec_48khz' selects the 48kHz model.

    Returns:
        torch.Tensor: The ``embed`` tensors of all quantizer layers stacked
                      along a new leading dimension, i.e. a 3D tensor of shape
                      [codebook_level, num_vectors_per_codebook, vector_dimension].

    Raises:
        ValueError: If ``audio_type`` is not one of the four accepted names.
    """
    wants_24khz = audio_type in ('mono', 'encodec_24khz')
    wants_48khz = audio_type in ('stereo', 'encodec_48khz')

    # Validate before touching the model factories.
    if not (wants_24khz or wants_48khz):
        raise ValueError("Unsupported audio type. Choose 'mono', 'encodec_24khz', 'stereo', or 'encodec_48khz'.")

    if wants_24khz:
        model = EncodecModel.encodec_model_24khz(pretrained=True)
    else:
        model = EncodecModel.encodec_model_48khz(pretrained=True)

    # Gather one codebook tensor per quantizer layer, in layer order.
    per_layer_codebooks = []
    for vq_layer in model.quantizer.vq.layers:
        per_layer_codebooks.append(vq_layer._codebook.embed)

    # A new leading dimension indexes the codebook level.
    return torch.stack(per_layer_codebooks, dim=0)
