In [None]:

from encoder.utils import convert_audio
import torchaudio
import torch
from decoder.pretrained import WavTokenizer


device = torch.device('cpu')

config_path = "./configs/medium_matadata.yml"
model_path = "../../pretrained_models/wavtokenizer_medium_music_audio_320_24k_v2.ckpt"
# TODO(review): placeholder that nothing below reads — set a real output path or remove.
audio_outpath = "xxx"

# Target sample rate passed to convert_audio below (checkpoint name says "24k").
SAMPLE_RATE = 24000

wavtokenizer = WavTokenizer.from_pretrained0802(config_path, model_path)
wavtokenizer = wavtokenizer.to(device)


wav, sr = torchaudio.load("../../../dataset/no8/0/audio0.mp3")
# Resample from the file's rate `sr` to SAMPLE_RATE with 1 target channel
# (encodec-style convert_audio(wav, sr, target_sr, target_channels)).
wav = convert_audio(wav, sr, SAMPLE_RATE, 1)
# Create bandwidth_id on the same device as the model and input, so switching
# `device` to a GPU does not cause a CPU/GPU tensor mismatch inside the model.
bandwidth_id = torch.tensor([0], device=device)
wav = wav.to(device)
features, discrete_code = wavtokenizer.encode_infer(wav, bandwidth_id=bandwidth_id)

# Shapes observed for the example file above:
#   features:      torch.Size([1, 512, 40643])
#   discrete_code: torch.Size([1, 1, 40643])
#   wav:           torch.Size([1, 13005504])
# 13005504 / 40643 ≈ 320 samples per token, i.e. ~75 tokens per second at 24 kHz.

In [None]:
# Report the shapes of the encoder outputs and the (resampled) input waveform.
for tensor in (features, discrete_code, wav):
    print(tensor.shape)

In [None]:

# Reconstruct audio from the discrete token ids.
# decode() consumes continuous features (the next cell passes `features`,
# shaped [1, 512, T]), not token ids ([1, 1, T]). Map the codes back to
# features first, then decode.
features_from_codes = wavtokenizer.codes_to_features(discrete_code)
audio_out = wavtokenizer.decode(features_from_codes, bandwidth_id=bandwidth_id)
# Move the result to the CPU before writing, so this also works when `device` is a GPU.
torchaudio.save("out.wav", audio_out.cpu(), sample_rate=24000, encoding='PCM_S', bits_per_sample=16)

In [None]:
# Reconstruct audio directly from the features returned by encode_infer.
# Print a compact summary rather than dumping the full (large) tensor.
print(features.shape, features.dtype, features.device)
audio_out = wavtokenizer.decode(features, bandwidth_id=bandwidth_id)
# NOTE: writes the same "out.wav" as the codes-based decoding cell, overwriting it.
# Move the result to the CPU before writing, so this also works when `device` is a GPU.
torchaudio.save("out.wav", audio_out.cpu(), sample_rate=24000, encoding='PCM_S', bits_per_sample=16)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
#%pip install seaborn
import seaborn as sns

def plot_codec(discrete_code, title="WavTokenizer Codec Visualization", figsize=(15, 5)):
    """Draw a token sequence as a single-row heatmap.

    Args:
        discrete_code: tensor of token ids. Size-1 dims are squeezed away and
            whatever remains is flattened into one row (e.g. [1, 1, T] -> T columns).
        title: figure title.
        figsize: matplotlib figure size in inches.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as the current
        figure, so callers can chain ``plt.savefig(...)``.
    """
    row = discrete_code.squeeze().cpu().numpy().reshape(1, -1)

    fig, ax = plt.subplots(figsize=figsize)
    # viridis is a continuous colormap: colour encodes the numeric value of
    # each token id (neighbouring ids get similar colours).
    image = ax.imshow(row, aspect='auto', cmap='viridis', interpolation='nearest')
    fig.colorbar(image, ax=ax, label='Token Values')

    ax.set_title(title)
    ax.set_xlabel('Time Steps')
    ax.set_ylabel('Tokens')
    ax.set_yticks([])  # only one row is drawn, so y ticks carry no information
    ax.grid(False)

    fig.tight_layout()
    return plt

def plot_codec_distribution(discrete_code, n_vocab=1024, figsize=(10, 5)):
    """
    Plot how often each token id occurs in ``discrete_code``.

    Args:
        discrete_code: tensor of token ids. Size-1 dims are squeezed; any
            remaining dims are counted together.
        n_vocab: codebook size. The x-axis spans ids [0, n_vocab) so unused
            codes show up as gaps. If the data contain ids >= n_vocab, the axis
            is widened to include them rather than clipping bars.
            NOTE(review): confirm the default matches this checkpoint's codebook.
        figsize: matplotlib figure size in inches.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current, so
        callers can chain ``plt.savefig(...)``.
    """
    # Remove batch and channel dimensions, convert to numpy
    tokens = discrete_code.squeeze().cpu().numpy()

    # Count token frequencies (unique ids come back sorted)
    unique, counts = np.unique(tokens, return_counts=True)

    plt.figure(figsize=figsize)

    plt.bar(unique, counts, alpha=0.7)
    plt.title('Distribution of Codec Tokens')
    plt.xlabel('Token Value')
    plt.ylabel('Frequency')

    # Show the whole vocabulary range; guard the empty case before calling max().
    x_max = n_vocab if unique.size == 0 else max(n_vocab, int(unique.max()) + 1)
    plt.xlim(-0.5, x_max - 0.5)

    plt.grid(True, alpha=0.3)

    plt.tight_layout()

    return plt

# Convenience wrapper: write both codec plots to PNG files.
def visualize_codec(discrete_code):
    """
    Save two figures for ``discrete_code`` to the working directory:
    the token-sequence heatmap (codec_sequence.png) and the token-frequency
    histogram (codec_distribution.png).
    """
    jobs = (
        (plot_codec, 'codec_sequence.png'),
        (plot_codec_distribution, 'codec_distribution.png'),
    )
    for plot_fn, filename in jobs:
        plot_fn(discrete_code)
        plt.savefig(filename)
        # Close the figure after saving so it is not left open.
        plt.close()

# Example usage with your data:
# assuming discrete_code has shape [1, 1, 40643]
visualize_codec(discrete_code)

# Optional: Plot a subset for better visibility of patterns
def plot_codec_subset(discrete_code, start_idx=0, duration_seconds=10, tokens_per_second=75):
    """
    Save a heatmap of a time window of ``discrete_code`` to a PNG.

    Args:
        discrete_code: token tensor whose last dim is time (e.g. [1, 1, T]).
        start_idx: first token index of the window (in tokens, not seconds).
        duration_seconds: window length in seconds; may be fractional.
        tokens_per_second: token rate. 75 corresponds to 24000 Hz / 320-sample
            hop; change it for checkpoints with a different rate or hop.

    The file is named codec_subset_<start>_<end>.png, where <end> is the
    exclusive end index actually plotted (clamped to the sequence length).

    Raises:
        ValueError: if start_idx is negative or the window contains no tokens.
    """
    # int() so a fractional duration still yields a valid slice bound.
    n_tokens = int(round(duration_seconds * tokens_per_second))
    total_tokens = discrete_code.shape[-1]
    end_idx = min(start_idx + n_tokens, total_tokens)
    if start_idx < 0 or end_idx <= start_idx:
        raise ValueError(
            f"empty window: start_idx={start_idx}, n_tokens={n_tokens}, "
            f"sequence length={total_tokens}"
        )
    subset = discrete_code[..., start_idx:end_idx]

    # Report the real length if the window was clamped at the end of the sequence.
    shown_seconds = (
        duration_seconds
        if end_idx - start_idx == n_tokens
        else (end_idx - start_idx) / tokens_per_second
    )
    plot_codec(subset, title=f"Codec Visualization ({shown_seconds}s segment)")
    plt.savefig(f'codec_subset_{start_idx}_{end_idx}.png')
    plt.close()

# Plot first 10 seconds
plot_codec_subset(discrete_code, start_idx=0, duration_seconds=10)