# Clone and install the nari-labs/Dia library locally

In [None]:
# Clone the Dia repository into ./dia, then pip-install it from that local checkout.
# NOTE(review): `git clone` fails if ./dia already exists (e.g. when the cell is re-run),
# and the print below is unconditional — it does not check that the commands succeeded.
!git clone https://github.com/nari-labs/dia.git
!pip install -q ./dia
print("Dia library and dependencies installed.")

# Install required libraries

In [None]:
# soundfile: WAV read/write; ffmpeg-python: tempo/EQ processing; gradio: the web UI.
# NOTE(review): grpcio / grpcio-tools are not imported by any cell shown here — confirm they are needed.
# NOTE(review): ffmpeg-python presumably drives an external `ffmpeg` executable that pip does not
# install — confirm it is available on PATH in the runtime.
!pip install soundfile ffmpeg-python grpcio grpcio-tools gradio

# Import libraries

In [5]:
import gradio as gr
import soundfile as sf
from dia.model import Dia
import torch
import numpy as np
import subprocess
import os
import random
import re
import tempfile
import ffmpeg
from io import BytesIO

# Gradio App

In [7]:
# Words-per-minute used as the WPM slider's initial value and as the fallback returned by
# calculate_wpm_from_audio_duration when the audio duration is not positive.
DEFAULT_WPM = 150

def check_device():
    """Return ``"cuda"`` when torch can see a GPU, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def load_model(model_name):
    """Return the Dia model for *model_name*, loading it on first use.

    Loaded models are memoized per name on the function object, so repeated
    calls with the same name reuse the model instead of reloading weights.
    A call with a different name loads and caches that model separately.

    Args:
        model_name: identifier passed to ``Dia.from_pretrained``.

    Returns:
        The (cached) model object returned by ``Dia.from_pretrained``.
    """
    cache = getattr(load_model, "_models", None)
    if cache is None:
        cache = load_model._models = {}
    if model_name not in cache:
        cache[model_name] = Dia.from_pretrained(model_name)
    return cache[model_name]

def generate_audio(model, text, seed=42):
    """Run ``model.generate(text)`` after seeding every RNG with *seed*.

    Python's ``random``, NumPy's global RNG and torch (plus all CUDA devices
    when CUDA is available) are seeded, so that repeated calls with the same
    seed start sampling from the same RNG state.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    generated = model.generate(text)
    return generated

def save_audio(audio, file_name, sr=44100):
    """Write the sample buffer *audio* to *file_name* at *sr* Hz using soundfile."""
    sf.write(file_name, audio, samplerate=sr)

def adjust_audio_for_unified_speaker(input_audio, speed_factor=1.0):
    """Re-time a WAV recording and apply a fixed voice-smoothing filter chain with ffmpeg.

    The filtergraph applies, in order: ``atempo`` (tempo change), -5 dB
    equalizer cuts at 1 kHz, 2 kHz and 5 kHz, ``dynaudnorm`` and
    ``loudnorm=I=-16:TP=-1.5:LRA=11``.

    Args:
        input_audio: the source audio, given as a filesystem path
            (``str`` / ``os.PathLike``), raw ``bytes``, or a binary file-like
            object exposing ``.read()``.
        speed_factor: tempo multiplier (>1 plays faster, <1 slower). Must be a
            finite positive number.

    Returns:
        io.BytesIO, positioned at 0, holding the processed WAV file.

    Raises:
        gr.Error: if ``speed_factor`` is not a finite positive number, or if
            ffmpeg fails (the message includes the tail of ffmpeg's stderr).
    """
    # `not 0 < x < inf` also rejects NaN, because every comparison with NaN is False.
    if not 0 < speed_factor < float("inf"):
        raise gr.Error(f"Invalid speed factor: {speed_factor}")

    temp_paths = []  # every temp file created here; removed in `finally`
    try:
        if isinstance(input_audio, (str, os.PathLike)):
            # Paths can be handed to ffmpeg directly; no copy needed.
            input_path = os.fspath(input_audio)
        else:
            if isinstance(input_audio, (bytes, bytearray)):
                data = input_audio
            else:
                data = input_audio.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as input_temp_file:
                temp_paths.append(input_temp_file.name)
                input_temp_file.write(data)
            input_path = input_temp_file.name

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as output_temp_file:
            temp_paths.append(output_temp_file.name)
        output_path = output_temp_file.name

        # A single ffmpeg atempo instance only accepts a bounded factor (0.5-2.0 in
        # older builds; newer builds still reject < 0.5). Split the requested factor
        # into in-range stages whose product equals speed_factor.
        tempo_stages = []
        remaining = float(speed_factor)
        while remaining > 2.0:
            tempo_stages.append(2.0)
            remaining /= 2.0
        while remaining < 0.5:
            tempo_stages.append(0.5)
            remaining /= 0.5
        tempo_stages.append(remaining)

        filter_chain = ",".join(
            [f"atempo={stage}" for stage in tempo_stages]
            + [
                "equalizer=f=1000:t=q:w=1:g=-5",
                "equalizer=f=2000:t=q:w=1:g=-5",
                "equalizer=f=5000:t=q:w=1:g=-5",
                "dynaudnorm",
                "loudnorm=I=-16:TP=-1.5:LRA=11",
            ]
        )

        try:
            # Capture ffmpeg's output so a failure can report its stderr.
            ffmpeg.input(input_path).output(
                output_path,
                filter_complex=filter_chain,
            ).run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
        except ffmpeg.Error as e:
            if e.stderr:
                details = e.stderr.decode("utf-8", errors="replace").strip()[-500:]
            else:
                details = str(e)
            raise gr.Error(f"An error occurred while processing the audio: {details}") from e

        with open(output_path, "rb") as f:
            adjusted_audio = BytesIO(f.read())
        adjusted_audio.seek(0)
        return adjusted_audio
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the real result.
                pass

def split_sentences(text):
    """Split *text* into sentences on whitespace that follows terminal punctuation.

    A boundary is whitespace preceded by ``.``, ``!`` or ``?``, optionally
    followed by one closing quote or bracket (``"``, ``'``, ``”``, ``’``, ``)``
    or ``]``). This lets quoted dialogue such as ``He said "Hi." Then...`` split
    after the quote. Empty pieces are dropped and each sentence is stripped.

    Args:
        text: the input string; ``None`` or blank text yields ``[]``.

    Returns:
        list[str]: the non-empty sentences, in order.
    """
    if not text:
        return []
    # Two fixed-width lookbehinds: `re` does not allow a variable-width lookbehind.
    pieces = re.split(
        r"""(?<=[.!?])\s+|(?<=[.!?]["'\u201d\u2019)\]])\s+""",
        text.strip(),
    )
    return [piece.strip() for piece in pieces if piece.strip()]

def batch_sentences(sentences, batch_size):
    """Chunk *sentences* into consecutive slices of at most *batch_size* items.

    The last chunk holds whatever remains and may be shorter.
    """
    batches = []
    for start in range(0, len(sentences), batch_size):
        stop = start + batch_size
        batches.append(sentences[start:stop])
    return batches

def generate_silence(duration, sr=44100):
    """Return *duration* seconds of zero-valued float32 samples at *sr* Hz."""
    sample_count = int(duration * sr)
    return np.zeros(shape=(sample_count,), dtype=np.float32)

def calculate_wpm_from_audio_duration(text, duration_seconds):
    """Estimate words per minute for *text* spoken over *duration_seconds*.

    Words are counted by whitespace splitting. If the duration is not a
    positive number, ``DEFAULT_WPM`` is returned instead of dividing.
    """
    words_spoken = len(text.split())
    if not duration_seconds > 0:
        return DEFAULT_WPM
    words_per_second = words_spoken / duration_seconds
    return words_per_second * 60

def generate_narration(text_input, progress=gr.Progress()):
    """Synthesize *text_input* with Dia and write the result to ``output_combined.wav``.

    The text is split into sentences and grouped four at a time. Each group is
    tagged as speaker ``[S1]`` and generated with the same fixed seed.
    Consecutive groups are joined with one second of silence.

    Args:
        text_input: the story text from the UI.
        progress: Gradio progress tracker, advanced once per batch.

    Returns:
        A 5-tuple matching the outputs wired to this handler:
        (combined WAV path, the input text for later speed adjustment,
        update for the WPM slider, update showing the adjust button,
        "Current WPM: ..." display string).

    Raises:
        gr.Error: if the text contains no sentences to narrate.
    """
    sentences = split_sentences(text_input or "")
    if not sentences:
        raise gr.Error("Please enter some text to narrate.")

    model = load_model("nari-labs/Dia-1.6B")
    seed_value = 42  # same seed per batch so every batch starts from identical RNG state

    sentence_batches = batch_sentences(sentences, 4)
    total_batches = len(sentence_batches)

    segments = []
    sr = 44100
    # Per-batch WAVs live in a private temp dir that is removed even if generation fails.
    # The write/read round trip mirrors how the combined file is encoded.
    with tempfile.TemporaryDirectory(prefix="audio_parts_") as parts_dir:
        for idx, batch in enumerate(sentence_batches):
            progress((idx + 1, total_batches), desc="Generating narration")

            tagged_sentences = " ".join(f"[S1] {s}" for s in batch)
            audio = generate_audio(model, tagged_sentences, seed=seed_value)

            part_filename = os.path.join(parts_dir, f"batch_{idx}.wav")
            save_audio(audio, part_filename)
            y, sr = sf.read(part_filename)

            # One-second pause *between* batches only, so no dead air trails the narration.
            if segments:
                segments.append(generate_silence(1, sr))
            segments.append(y)

    # Join once instead of re-concatenating inside the loop (which is quadratic).
    final_audio = np.concatenate(segments)

    combined_filename = "output_combined.wav"
    save_audio(final_audio, combined_filename, sr=sr)

    duration_seconds = len(final_audio) / sr
    current_wpm = calculate_wpm_from_audio_duration(text_input, duration_seconds)

    return (
        combined_filename,
        text_input,
        gr.update(value=current_wpm, visible=True),
        gr.update(visible=True),
        f"Current WPM: {current_wpm:.2f}"
    )

def adjust_speed(text_input, new_wpm):
    """Re-time the last narration so it plays at approximately *new_wpm* words per minute.

    The current rate is measured from ``output_combined.wav``, the unmodified
    narration, so repeated adjustments do not compound. The speed factor is
    ``new_wpm / current_wpm``. The processed audio is written to
    ``output_adjusted.wav``.

    Args:
        text_input: the narrated text (from the Gradio state), used to count words.
        new_wpm: the target words-per-minute from the slider.

    Returns:
        (path of the adjusted WAV, "Current WPM: ..." string, "Speed Factor: ..." string).

    Raises:
        gr.Error: if no narration file or narration text is available.
    """
    audio_file_path = "output_combined.wav"
    if not os.path.exists(audio_file_path):
        raise gr.Error("No existing narration found")
    # Without words the measured WPM would be 0 and the speed factor would divide by zero.
    if not text_input or not text_input.split():
        raise gr.Error("No narration text found; generate a narration first")

    audio_info = sf.info(audio_file_path)
    duration_seconds = audio_info.frames / audio_info.samplerate
    current_wpm = calculate_wpm_from_audio_duration(text_input, duration_seconds)
    speed_factor = new_wpm / current_wpm

    # Hand over an open binary file: the processing helper reads its input via .read().
    with open(audio_file_path, "rb") as source:
        adjusted_audio = adjust_audio_for_unified_speaker(source, speed_factor=speed_factor)

    # gr.Audio outputs take a file path (or (sample_rate, array)), not an in-memory buffer,
    # so persist the processed WAV and return its path.
    adjusted_audio_path = "output_adjusted.wav"
    with open(adjusted_audio_path, "wb") as sink:
        sink.write(adjusted_audio.read())

    adjusted_audio_info = sf.info(adjusted_audio_path)
    adjusted_duration = adjusted_audio_info.frames / adjusted_audio_info.samplerate
    adjusted_wpm = calculate_wpm_from_audio_duration(text_input, adjusted_duration)

    return (
        adjusted_audio_path,
        f"Current WPM: {adjusted_wpm:.2f}",
        f"Speed Factor: {speed_factor:.2f}"
    )

with gr.Blocks(title="🎙️ EchoTales: The Voice of Stories") as app:
    gr.Markdown("# 🎙️ EchoTales: The Voice of Stories")

    with gr.Row():
        with gr.Column():
            text_input = gr.TextArea(label="Enter your story text here:", lines=10)
            generate_btn = gr.Button("Generate Narration")

            audio_output = gr.Audio(label="Generated Narration", interactive=False)

            # Text of the last generated narration; adjust_speed counts its words.
            generated_text_state = gr.State()

            # Speed controls start hidden and are revealed after a narration is generated.
            with gr.Group(visible=False) as speed_adjust_group:
                wpm_slider = gr.Slider(minimum=50, maximum=300, value=DEFAULT_WPM, label="Adjust WPM")
                wpm_display = gr.Textbox(label="Current WPM")
                speed_factor_display = gr.Textbox(label="Speed Adjustment Factor")
                adjust_btn = gr.Button("Adjust Speed")

            device_info = gr.Textbox(label="Device", value=f"Running on: {check_device()}", interactive=False)

    # After generation, show the speed-adjustment group. It was created with visible=False,
    # so it must be set visible here or the adjust controls can never be reached.
    generate_btn.click(
        fn=generate_narration,
        inputs=text_input,
        outputs=[audio_output, generated_text_state, wpm_slider, adjust_btn, wpm_display],
    ).then(
        lambda: gr.update(visible=True),
        inputs=None,
        outputs=speed_adjust_group
    )

    adjust_btn.click(
        fn=adjust_speed,
        inputs=[generated_text_state, wpm_slider],
        outputs=[audio_output, wpm_display, speed_factor_display]
    )

# Launch the UI. share=True also publishes it on a temporary public gradio.live URL
# (as shown in the cell output below), so anyone with the link can use the app.
if __name__ == "__main__":
    app.launch(share=True)


* Running on local URL:  http://127.0.0.1:7864
* Running on public URL: https://215ba56c74c0a265ec.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
