In [None]:
import torch
import torchaudio
import matplotlib.pyplot as plt
import librosa
import librosa.display
import IPython.display as ipd
from f5_tts.model.dataset import *
from f5_tts.model.modules import MelSpec
from tqdm import tqdm
import numpy as np
import io
import soundfile as sf
import json
from vocos import Vocos

In [None]:
# Loaded Vocos vocoders keyed by pretrained model id, so repeated calls (e.g. the
# preview loop) reuse one model instead of reloading the checkpoint every time.
_VOCOS_CACHE = {}


def _get_vocos(model_id="charactr/vocos-mel-24khz"):
    """Return the Vocos vocoder for ``model_id``, loading it only on first use."""
    if model_id not in _VOCOS_CACHE:
        _VOCOS_CACHE[model_id] = Vocos.from_pretrained(model_id)
    return _VOCOS_CACHE[model_id]


def _show_sample_details(dataset, index, mel_spec, figsize):
    """Print the mel shape and timed lyric segments, then plot the mel spectrogram."""
    # Segment files are picked modulo their count, as in the original notebook.
    # NOTE(review): assumes dataset.json_files is aligned with dataset indices — confirm.
    json_path = dataset.json_files[index % len(dataset.json_files)]
    with open(json_path, 'r', encoding='utf-8') as f:
        segments = json.load(f)

    print(f"Mel Spectrogram Shape: {mel_spec.shape}")
    print("\nDetailed segments:")
    for seg_no, segment in enumerate(segments, start=1):
        print(f"{seg_no}. [{segment['start']:.2f}s - {segment['end']:.2f}s]: {segment['text']}")
    print(f"{'='*80}")

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    img = librosa.display.specshow(
        mel_spec.detach().cpu().numpy(),
        y_axis='mel',
        x_axis='time',
        sr=dataset.target_sample_rate,
        hop_length=dataset.hop_length,
        ax=ax,
    )
    # NOTE(review): the values are presumably log-magnitude mels rather than
    # decibels, so the colorbar is left unitless instead of labelled "dB" — confirm.
    fig.colorbar(img, ax=ax, format='%+2.1f')
    ax.set_title('Mel Spectrogram')
    fig.tight_layout()
    plt.show()
    # Release the figure so previewing many samples does not accumulate open figures.
    plt.close(fig)


def display_sample_with_audio(dataset, index, figsize=(8, 3), show_details=False, vocos=None):
    """Display one dataset sample in a Jupyter notebook.

    The output has these parts:
    - a header;
    - a playable audio widget whose waveform is reconstructed from the
      sample's mel spectrogram with a Vocos vocoder;
    - the sample's text.

    With ``show_details=True`` it also does the following:
    - prints the mel shape;
    - prints the timed lyric segments from the sample's JSON file;
    - plots the mel spectrogram.

    Args:
        dataset: Indexable dataset whose items are dicts with ``'mel_spec'``
            (a torch tensor) and ``'text'`` keys. It must expose
            ``target_sample_rate``, plus ``hop_length`` and ``json_files``
            when ``show_details`` is True.
        index: 0-based index of the sample. The header shows ``index + 1``.
        figsize: Matplotlib figure size for the spectrogram plot. Used only
            when ``show_details`` is True.
        show_details: If True, also print the segment breakdown and plot the
            mel spectrogram. Defaults to False, which produces the same
            output as before.
        vocos: Optional pre-loaded Vocos model. If None, a cached
            ``"charactr/vocos-mel-24khz"`` model is loaded once and reused.
    """
    sample = dataset[index]
    mel_spec = sample['mel_spec']
    text = sample['text']

    if vocos is None:
        vocos = _get_vocos()

    # Vocos.decode takes a batched mel, so add a batch dimension and cast to float32.
    # Assumes mel_spec is (n_mels, frames) — TODO confirm against the dataset.
    # The input is moved to the vocoder's device so a GPU-resident model also works.
    vocos_device = next(vocos.parameters()).device
    mel_spec_batch = mel_spec.unsqueeze(0).float().to(vocos_device)
    with torch.no_grad():
        audio = vocos.decode(mel_spec_batch)
    audio = audio.squeeze(0)  # drop the batch dimension

    print(f"\n{'='*80}")
    print(f"Sample {index + 1}")
    print(f"{'='*80}")

    print("Audio Player (Reconstructed from Mel Spectrogram):")
    ipd.display(ipd.Audio(audio.cpu().numpy(), rate=dataset.target_sample_rate))

    print(f"\nText: {text}")

    if show_details:
        _show_sample_details(dataset, index, mel_spec, figsize)

    print("\n" + "-"*80 + "\n")


In [None]:
# Mel-spectrogram settings for the dataset. These presumably need to match the
# Vocos checkpoint used to invert mels back to audio ("charactr/vocos-mel-24khz")
# — TODO confirm before changing any value.
MEL_CONFIG = dict(
    mel_spec_type="vocos",
    target_sample_rate=24_000,
    n_mel_channels=100,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
)
dataset = RussianSingingDataset(**MEL_CONFIG)

In [None]:
# Preview the first few samples: each shows the vocoded audio player and its text.
N_PREVIEW_SAMPLES = 5
for sample_idx in range(N_PREVIEW_SAMPLES):
    display_sample_with_audio(dataset, sample_idx)