# Fine-Tuning OuteTTS for French TTS on Google Colab

This notebook provides comprehensive steps for fine-tuning the OuteTTS model for French Text-to-Speech (TTS) using Google Colab. Follow the steps below to set up the environment, prepare the data, train the model, and evaluate the results.

## Step 1: Environment Setup

First, we need to set up the environment and install the necessary dependencies.

In [None]:
# Install necessary packages
# Each line is an IPython shell escape (`!`), so this cell only runs inside a
# notebook environment such as Colab, not as a plain Python script.
# `outetts` provides the codec, alignment and interface classes imported in the later cells.
!pip install outetts
!pip install torch torchaudio
!pip install transformers
# polars, loguru and tqdm are imported directly by the data preparation cell.
!pip install polars
!pip install loguru
!pip install tqdm
# NOTE(review): the packages below are not imported anywhere in this notebook;
# presumably they are runtime dependencies of outetts — confirm before pruning.
!pip install soundfile
!pip install openai-whisper
!pip install mecab-python3 unidic-lite
!pip install uroman
!pip install pygame


## Step 2: Data Preparation

Prepare your dataset for fine-tuning. Your raw data must be stored as Parquet files in which every row has a `"transcript"` field (a string) and an `"audio"` field, a struct whose `"bytes"` key holds the raw bytes of the audio file.

In [None]:
# Example script to process and prepare data
import os
import polars as pl
import torch
from tqdm import tqdm
from outetts.wav_tokenizer.audio_codec import AudioCodec
from outetts.version.v2.prompt_processor import PromptProcessor
from outetts.version.v2.alignment import CTCForcedAlignment
import io
from loguru import logger

class DataCreation:
    """Turn raw Parquet audio/transcript shards into training-prompt shards.

    Every ``.parquet`` file found under ``audio_files_path`` is read row by
    row. Each row must provide a ``'transcript'`` string and an ``'audio'``
    mapping whose ``'bytes'`` entry holds the encoded audio. For every row the
    audio is force-aligned to its transcript, encoded into codec tokens and
    rendered into a training prompt. Prompts are buffered in memory and
    written to ``save_dir`` as numbered Parquet files of ``save_len`` rows.

    Args:
        model_tokenizer_path: Tokenizer location handed to ``PromptProcessor``.
        audio_files_path: Directory searched recursively for input shards.
        save_dir: Output directory; created on the first save if missing.
        save_len: Number of prompts per output shard.
        device: Device string for the codec and the aligner. When ``None``
            (default) CUDA is used if available, otherwise the CPU, so the
            class no longer crashes on runtimes without a GPU.
    """

    # The slicing in create_speaker converts seconds to code indices with this
    # factor, i.e. the code assumes 75 codec frames per second of audio.
    _CODES_PER_SECOND = 75

    def __init__(self, model_tokenizer_path: str, audio_files_path: str, save_dir: str, save_len: int = 5000, device: "str | None" = None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.audio_codec = AudioCodec(device=self.device, load_decoder=False)
        self.prompt_processor = PromptProcessor(model_tokenizer_path)
        self.files = self.get_files(audio_files_path, ".parquet")
        self.ctc = CTCForcedAlignment(self.device)
        self.save_dir = save_dir
        self.save_len = save_len
        self.save_id = 0  # index of the next output shard
        self.data = []  # prompts waiting to be flushed to disk

    def get_files(self, folder_path, extension_filter=None):
        """Recursively list the files under ``folder_path``.

        Args:
            folder_path: Root directory to walk.
            extension_filter: A suffix or an iterable of suffixes; only file
                names ending in one of them are kept. A falsy value keeps
                every file.

        Returns:
            The matching paths as a sorted list. Sorting makes the processing
            order, and therefore the content of each output shard,
            independent of the order in which the OS lists directories.
        """
        if isinstance(extension_filter, str):
            extension_filter = [extension_filter]
        suffixes = tuple(extension_filter) if extension_filter else None
        matching_files = []
        for root, _, files in os.walk(folder_path):
            for file in files:
                # str.endswith accepts a tuple, so one call covers all suffixes.
                if suffixes is None or file.endswith(suffixes):
                    matching_files.append(os.path.join(root, file))
        matching_files.sort()
        return matching_files

    def create_speaker(self, audio, transcript: str):
        """Align ``audio`` to ``transcript`` and attach codec tokens to each word.

        Args:
            audio: Audio source passed straight to ``self.ctc.align``.
            transcript: Text spoken in ``audio``.

        Returns:
            ``{"text": transcript, "words": [...]}`` where every word entry
            holds ``"word"``, ``"duration"`` (seconds, two decimals) and
            ``"codes"`` (that word's slice of the codec tokens).
        """
        words = self.ctc.align(audio, transcript)
        # Encode the concatenation of the aligned word segments in one pass;
        # the per-word slices below index into this single code sequence.
        full_codes = self.audio_codec.encode(
            self.audio_codec.convert_audio_tensor(
                audio=torch.cat([w["audio"] for w in words], dim=1),
                sr=self.ctc.sample_rate
            ).to(self.audio_codec.device)
        ).tolist()
        codes = full_codes[0][0]
        data = []
        start = 0
        for w in words:
            # NOTE(review): "x1" is presumably the word's end position in
            # samples; dividing by the sample rate gives seconds — confirm.
            end = int(round((w["x1"] / self.ctc.sample_rate) * self._CODES_PER_SECOND))
            word_tokens = codes[start:end]
            start = end
            if not word_tokens:
                # Very short words can round to an empty slice; fall back to a
                # single token so every word keeps at least one code.
                word_tokens = [1]
            data.append({
                "word": w["word"],
                "duration": round(len(word_tokens) / self._CODES_PER_SECOND, 2),
                "codes": word_tokens
            })
        return {
            "text": transcript,
            "words": data,
        }

    def save(self):
        """Write the buffered prompts to the next numbered shard and reset the buffer."""
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, f"{self.save_id:06d}.parquet")
        logger.info(f"Saving data: {path}")
        pl.DataFrame(self.data).write_parquet(path)
        self.data = []
        self.save_id += 1

    def run(self):
        """Process every input shard and write the resulting prompt shards.

        Rows that fail are logged and skipped so that one bad sample does not
        abort the whole run. Whatever is left in the buffer after the last
        input file is flushed as a final, possibly shorter, shard.
        """
        for file_path in self.files:
            df = pl.read_parquet(file_path)
            for data in tqdm(df.iter_rows(named=True), total=len(df)):
                try:
                    transcript = data['transcript']
                    audio = io.BytesIO(data['audio']['bytes'])
                    speaker = self.create_speaker(audio, transcript)
                    prompt = self.prompt_processor.get_training_prompt(speaker)
                    self.data.append({
                        'prompt': prompt,
                    })
                    # ">=" rather than "==" so the buffer is still flushed if
                    # save_len is zero or negative.
                    if len(self.data) >= self.save_len:
                        self.save()
                except Exception as e:
                    logger.error(f"{e}\n\nOccasional exceptions are expected due to potential inaccuracies or issues in the audio files. This is normal and may not indicate a critical problem, but a high frequency of errors might suggest an underlying issue that needs attention.")
        if self.data:
            self.save()

# Example usage
# The three paths are placeholders; replace them with real locations before
# running. `save_len` is not passed, so the constructor default (5000 prompts
# per output file) applies.
data_creator = DataCreation(
    model_tokenizer_path="path/to/model/tokenizer",
    audio_files_path="path/to/audio/files",
    save_dir="path/to/save/processed/data"
)
# Processes every .parquet file found under `audio_files_path` and writes the
# generated prompts to `save_dir`.
data_creator.run()


## Step 3: Model Training

Now, we will fine-tune the OuteTTS model using the prepared dataset.

In [None]:
from outetts import InterfaceHF, HFModelConfig

# NOTE(review): confirm these names against the installed outetts release.
# Published versions may expose versioned configuration classes and a
# different InterfaceHF constructor, and `train` / `save_checkpoint` should be
# verified to exist on the interface before relying on this cell.

# Define model configuration
config = HFModelConfig(
    model_path="OuteAI/OuteTTS-0.3-500M",
    language="fr",
    tokenizer_path="path/to/tokenizer",  # placeholder: replace with a real path
    max_seq_length=4096,
    device="cuda"
)

# Initialize the model interface
tts_interface = InterfaceHF(config)

# Number of full passes over the training data. Placeholder value: tune it for
# your dataset size.
num_epochs = 3

# NOTE: `data_loader` is not defined anywhere in this notebook. Before running
# this cell, define it as an iterable of batches built from the prompt shards
# written in Step 2.

# Example training loop
for epoch in range(num_epochs):
    for batch in data_loader:
        # Fine-tune the model with the batch data
        tts_interface.train(batch)
    # Save one checkpoint per epoch. The file name only depends on the epoch,
    # so saving inside the batch loop would rewrite the same file after every
    # batch and spend most of the run on disk I/O.
    tts_interface.save_checkpoint(f"checkpoint_{epoch}.pt")


## Step 4: Model Evaluation

Evaluate the fine-tuned model to ensure it performs well on French TTS.

In [None]:
# Evaluation cell: load the fine-tuned model, synthesize one French sentence,
# write it to disk and show an inline audio player.
import IPython.display as ipd
from outetts import InterfaceHF, HFModelConfig

# Model settings, collected in one place. The two paths are placeholders.
eval_settings = dict(
    model_path="path/to/fine-tuned/model",
    language="fr",
    tokenizer_path="path/to/tokenizer",
    max_seq_length=4096,
    device="cuda",
)
config = HFModelConfig(**eval_settings)

# Build the inference interface from that configuration.
tts_interface = InterfaceHF(config)

# Synthesize a short test sentence and store the result as a WAV file.
text = "Bonjour, comment ça va?"
output = tts_interface.generate(text)
output.save("output.wav")

# Must stay the last expression of the cell so the notebook renders the player.
ipd.Audio("output.wav")
