In [None]:
from transformers import SpeechT5Processor
from transformers import SpeechT5ForTextToSpeech
from transformers import SpeechT5HifiGan
from datasets import load_dataset, Audio
import soundfile as sf

import os
import torch
from torch import manual_seed

import torchaudio
import torchaudio.transforms as T

from datasets import Dataset
from tqdm import tqdm
from textgrid import TextGrid
from joblib import Parallel, delayed

from speechbrain.inference import EncoderClassifier
import warnings
warnings.simplefilter("ignore")

In [2]:
manual_seed(32)  # seeds torch's RNG only


class Config:
    """Notebook-wide settings shared by the preprocessing, training and inference cells."""

    # Prefer the GPU whenever torch can see one.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Target audio sampling rate in Hz.
    sampling_rate = 16000

In [3]:
# Pretrained SpeechT5 processor/acoustic model, HiFi-GAN vocoder, and one
# reference x-vector (row 0 of the CMU ARCTIC x-vector validation split).
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(Config.device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
# unsqueeze(0) adds a leading (batch) dimension.
speaker_embeddings = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)

def infer(text, emotion, file_name):
    """Synthesize ``text`` tagged with ``emotion`` and write ``<file_name>.wav``.

    Uses the notebook-level ``processor``, ``model``, ``vocoder`` and reference
    ``speaker_embeddings``.

    Args:
        text: transcript to speak.
        emotion: emotion label appended as ``" [EMOTION] <emotion>"``.
        file_name: output path without the ``.wav`` extension.
    """
    # The token ids and the speaker embedding are created on CPU while the model
    # may live on CUDA; move everything onto the model's device.
    device = model.device
    combined_input = f"{text} [EMOTION] {emotion}"
    inputs = processor(text=combined_input, return_tensors="pt")
    speech = model.generate_speech(
        inputs["input_ids"].to(device),
        speaker_embeddings.to(device),
        vocoder=vocoder.to(device),
    )
    # soundfile needs a CPU numpy array.
    sf.write(f"{file_name}.wav", speech.detach().cpu().numpy(), samplerate=Config.sampling_rate)

In [None]:
### One Speaker - One emotion
ds_root = "/home/alina/datasets/EMOV-DB/EMOV"
# Emotion label -> sub-folder of `ds_root` that the cell below reads it from.
emo_abv = dict(amused="1", sleepiness="1", anger="2", neutral="2")

# Resample transforms keyed by source sampling rate, so each resampling kernel
# is built once per rate instead of once per file.
_RESAMPLERS = {}


def process_file(emotion, file, path):
    """Load one utterance and build its ``"<transcript> [EMOTION] <emotion>"`` prompt.

    Args:
        emotion: emotion label to keep; only ``.wav`` files whose name starts
            with it are processed (same rule as the two-speaker variant).
        file: file name inside ``path``.
        path: directory holding ``file`` and its matching ``.TextGrid``.

    Returns:
        ``(input_text, waveform, emotion, audio_path)``. ``waveform`` is the
        ``torchaudio.load`` tensor resampled to ``Config.sampling_rate``, and
        ``audio_path`` is the bare file name. Returns ``None`` when the file is
        skipped or loading fails (the error is printed and processing continues).
    """
    try:
        if not file.endswith(".wav") or not file.startswith(emotion):
            return None

        # splitext only swaps the extension, unlike str.replace on the whole name.
        stem = os.path.splitext(file)[0]
        tg = TextGrid.fromFile(os.path.join(path, stem + ".TextGrid"))
        words_tier = tg.getFirst("words")
        text = " ".join(interval.mark for interval in words_tier if interval.mark.strip())

        waveform, sr = torchaudio.load(os.path.join(path, file))
        target_sr = Config.sampling_rate
        if sr != target_sr:
            if sr not in _RESAMPLERS:
                _RESAMPLERS[sr] = T.Resample(orig_freq=sr, new_freq=target_sr)
            waveform = _RESAMPLERS[sr](waveform)

        input_text = f"{text} [EMOTION] {emotion}"
        return input_text, waveform, emotion, file

    except Exception as e:
        # Deliberate best-effort: report and skip the file.
        print(f"Error processing {file}: {e}")
        return None

# One (emotion, file name, folder) task per directory entry; a folder shared by
# several emotions is listed once per emotion.
tasks = [
    (emotion, file, path)
    for emotion, subfolder in emo_abv.items()
    for path in (os.path.join(ds_root, subfolder),)
    for file in os.listdir(path)
]

results = Parallel(n_jobs=1)(
    delayed(process_file)(*task) for task in tqdm(tasks)
)

100%|██████████| 14296/14296 [00:43<00:00, 329.45it/s] 


In [None]:
# Skipped/failed files come back as None; transpose the remaining tuples into columns.
texts, waveforms, emotions, audio_paths = zip(*filter(None, results))

In [None]:
tuple(len(column) for column in (texts, waveforms, emotions, audio_paths))

In [None]:
dataset = Dataset.from_dict(
    dict(text=texts, target=waveforms, emotion=emotions, audio_path=audio_paths)
)

In [None]:
### Two speakers - one emotion
from tqdm import tqdm
from datasets import Dataset
import torchaudio
import torchaudio.transforms as T
import os
from textgrid import TextGrid
from joblib import Parallel, delayed

ds_root = "/home/alina/datasets/EMOV-DB/EMOV"
# Emotion label -> sub-folder of `ds_root`; here the keys are matched against file-name prefixes.
emo_abv = dict(amused="1", sleepiness="1", anger="2", neutral="2")

# Shared 48 kHz -> 16 kHz resampling transform.
resampler = T.Resample(orig_freq=48000, new_freq=16000)

def process_file(file_path):
    """Load one ``.wav`` and its TextGrid transcript into a dataset record.

    Args:
        file_path: path to a file under ``ds_root``.

    Returns:
        A dict with these keys, or ``None`` when the file is not a ``.wav``,
        its name starts with no key of ``emo_abv``, or loading fails (the
        error is printed):
        - ``text``: ``"<transcript> [EMOTION] <emotion>"``.
        - ``target``: the waveform resampled to ``Config.sampling_rate``.
        - ``emotion``: the matched label.
        - ``audio_path``: the path relative to ``ds_root``.
    """
    try:
        if not file_path.endswith(".wav"):
            return None

        filename = os.path.basename(file_path)
        emotion = next((emo for emo in emo_abv if filename.startswith(emo)), None)
        if emotion is None:
            return None

        # splitext only swaps the extension, unlike str.replace on the whole path.
        tg = TextGrid.fromFile(os.path.splitext(file_path)[0] + ".TextGrid")
        words_tier = tg.getFirst("words")
        text = " ".join(interval.mark for interval in words_tier if interval.mark.strip())

        waveform, sr = torchaudio.load(file_path)
        target_sr = Config.sampling_rate
        if sr != target_sr:
            # Reuse the module-level resampler when its rates match instead of
            # rebuilding the kernel for every file.
            if sr == resampler.orig_freq and resampler.new_freq == target_sr:
                waveform = resampler(waveform)
            else:
                waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)

        return {
            "text": f"{text} [EMOTION] {emotion}",
            "target": waveform,
            "emotion": emotion,
            "audio_path": os.path.relpath(file_path, ds_root),
        }

    except Exception as e:
        # Deliberate best-effort: report and skip the file.
        print(f"Error processing {file_path}: {e}")
        return None

# Every .wav directly inside sub-folders '1' and '2' of ds_root.
all_files = [
    os.path.join(ds_root, subfolder, file)
    for subfolder in ['1', '2']
    for file in os.listdir(os.path.join(ds_root, subfolder))
    if file.endswith(".wav")
]

results = Parallel(n_jobs=4)(delayed(process_file)(fp) for fp in tqdm(all_files))

# process_file returns None for skipped or failed files.
processed_data = [record for record in results if record is not None]
dataset = Dataset.from_list(processed_data)


In [None]:
# SpeechBrain x-vector encoder, cached under /tmp/<model name>.
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    savedir=os.path.join("/tmp", spk_model_name),
    run_opts={"device": Config.device},
)

def create_speaker_embedding(waveform):
    """Compute an L2-normalised speaker embedding for ``waveform`` with ``speaker_model``.

    Args:
        waveform: audio samples as a nested list, numpy array or tensor. It is
            passed straight to ``speaker_model.encode_batch``; presumably
            ``(batch, time)``, i.e. ``(1, time)`` for a mono clip — TODO confirm.

    Returns:
        numpy array with all size-1 dimensions squeezed out.
    """
    # as_tensor accepts lists, arrays and tensors alike (torch.tensor warns when
    # given a tensor). Forcing float32 keeps float64 arrays from reaching the
    # encoder, whose weights are presumably float32.
    wavs = torch.as_tensor(waveform, dtype=torch.float32)
    with torch.no_grad():
        embeddings = speaker_model.encode_batch(wavs)
        embeddings = torch.nn.functional.normalize(embeddings, dim=2)
    return embeddings.squeeze().cpu().numpy()

def prepare_dataset(single_data):
    """Turn one raw record into SpeechT5 training features.

    The processor tokenizes ``text`` and extracts ``labels`` from the ``target``
    waveform. ``emotion``, ``text`` and ``audio_path`` are copied through, because
    the ``map`` call below drops the original columns. A speaker embedding is
    attached last.
    """
    waveform = single_data["target"]
    features = processor(
        text=single_data["text"],
        audio_target=waveform,
        sampling_rate=Config.sampling_rate,
        return_attention_mask=False,
    )
    # The processor returns batched labels; keep the first entry (one mono clip assumed).
    features["labels"] = features["labels"][0]
    for key in ("emotion", "text", "audio_path"):
        features[key] = single_data[key]
    features["speaker_embeddings"] = create_speaker_embedding(waveform)
    return features


dataset = dataset.map(
    prepare_dataset,
    remove_columns=dataset.column_names,
    desc="Processing dataset",
)

Processing dataset: 100%|██████████| 1209/1209 [03:48<00:00,  5.29 examples/s]


In [None]:
def is_not_too_long(input_ids, max_length=200):
    """Return True when the tokenized prompt has fewer than ``max_length`` tokens.

    ``Dataset.filter(..., input_columns=["input_ids"])`` passes only the
    ``input_ids`` column, so ``max_length`` keeps its default of 200 there.
    """
    return len(input_ids) < max_length

# Drop over-long prompts, then hold out 10% for evaluation (fixed seed).
dataset = dataset.filter(
    is_not_too_long,
    input_columns=["input_ids"],
    desc="Filtering too long samples...",
).train_test_split(test_size=0.1, seed=42, shuffle=True)


Filtering too long samples...: 100%|██████████| 1209/1209 [00:00<00:00, 34658.93 examples/s]


In [None]:
# Write the splits under the file names of the cell that actually built `dataset`.
# Previously both pairs were written from the same data, so one pair was always mislabeled.
#   False -> "One Speaker - One emotion":  train_data.jsonl / test_data.jsonl
#   True  -> "Two speakers - one emotion": train_data_double.jsonl / test_data_double.jsonl
# On Restart & Run All the two-speaker cell runs last, so True is the default.
TWO_SPEAKERS = True
export_dir = "/home/alina/datasets/EMOV-DB/EMOV"
suffix = "_double" if TWO_SPEAKERS else ""

for split_name in ("train", "test"):
    dataset[split_name].to_json(
        os.path.join(export_dir, f"{split_name}_data{suffix}.jsonl"), lines=True
    )

In [None]:
# EMOV root; speaker sub-folders ('1', '2') are resolved per example below.
base_path = "/home/alina/datasets/EMOV-DB/EMOV"


def add_full_path(example):
    """Make ``example["audio_path"]`` an absolute path under ``base_path``.

    Handles both layouts built above:
    * two-speaker cell: path relative to the EMOV root, e.g. ``"2/anger_x.wav"``;
    * one-speaker cell: bare file name, e.g. ``"anger_x.wav"``. Its sub-folder
      is looked up from the example's emotion in ``emo_abv``.
    Absolute paths are returned unchanged, so re-running this cell is harmless.
    """
    audio_path = example["audio_path"]
    if os.path.isabs(audio_path):
        return example
    if not os.path.dirname(audio_path):
        audio_path = os.path.join(emo_abv[example["emotion"]], audio_path)
    example["audio_path"] = os.path.join(base_path, audio_path)
    return example


dataset = dataset.map(add_full_path)

In [15]:
print("Total train samples: {}\nTotal test samples: {}".format(len(dataset["train"]), len(dataset["test"])))

Total train samples: 1088
Total test samples: 121


In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class TTSDataCollatorWithPadding:
    """Pad a list of SpeechT5 examples into one training batch.

    Each feature must provide ``input_ids``, ``labels`` (spectrogram frames),
    ``speaker_embeddings`` and ``emotion``.
    """

    # Processor whose ``pad`` pads the token ids and the spectrogram labels.
    processor: Any
    # Decoder reduction factor. When None, it is read from the notebook-level
    # ``model.config.reduction_factor`` (the previous behaviour).
    reduction_factor: Union[int, None] = None

    def __call__(self, features):
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]
        # NOTE(review): emotion strings are forwarded in the batch as-is. Confirm
        # the "emotion" column survives Trainer's remove_unused_columns, and that
        # the model being trained accepts an `emotion` argument.
        emotions = [feature["emotion"] for feature in features]

        # Use the processor this collator was constructed with, not a global.
        batch = self.processor.pad(
            input_ids=input_ids,
            labels=label_features,
            return_tensors="pt",
        )

        # Padded label frames are set to -100 (presumably ignored by the loss).
        batch["labels"] = batch["labels"].masked_fill(
            batch.decoder_attention_mask.unsqueeze(-1).ne(1), -100
        )
        del batch["decoder_attention_mask"]

        reduction_factor = (
            self.reduction_factor
            if self.reduction_factor is not None
            else model.config.reduction_factor
        )
        if reduction_factor > 1:
            # Round each target length down to a multiple of the reduction
            # factor and trim the padded labels to the longest such length.
            target_lengths = [len(feature["input_values"]) for feature in label_features]
            max_length = max(length - length % reduction_factor for length in target_lengths)
            batch["labels"] = batch["labels"][:, :max_length]

        batch["speaker_embeddings"] = torch.tensor(speaker_features)
        batch["emotion"] = emotions

        return batch
    
data_collator = TTSDataCollatorWithPadding(processor)

In [None]:
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="/home/alina/repos/SpeechT5/experiments/exp0045",
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    eval_strategy="steps",
    eval_steps=100,
    # load_best_model_at_end needs save_steps to be a multiple of eval_steps.
    save_steps=1000,
    logging_steps=5,
    lr_scheduler_type="cosine",
    report_to=["tensorboard"],
    label_names=["labels"],
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=5,
)


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    # `processing_class` supersedes the deprecated `tokenizer=` argument
    # (this notebook already uses the newer `eval_strategy` name).
    processing_class=processor.tokenizer,
)

trainer.train()

4.53.0.dev0
/home/ashapiro/miniconda3/envs/t5/lib/python3.10/site-packages/transformers/__init__.py
transformers.training_args_seq2seq


## Inference

In [None]:
# Fine-tuned acoustic model from the training run; processor, vocoder and the
# reference x-vector (row 0 of the CMU ARCTIC validation split) are unchanged.
ckpt = '/home/alina/repos/SpeechT5/experiments/exp0045/checkpoint-2000'
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained(ckpt).to(Config.device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
# unsqueeze(0) adds a leading (batch) dimension.
speaker_embeddings = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)

def infer(text, emotion, file_name):
    """Synthesize ``text`` tagged with ``emotion`` and write ``<file_name>.wav``.

    Uses the notebook-level ``processor``, ``model``, ``vocoder`` and reference
    ``speaker_embeddings``.

    Args:
        text: transcript to speak.
        emotion: emotion label appended as ``" [EMOTION] <emotion>"``.
        file_name: output path without the ``.wav`` extension.
    """
    # The token ids and the speaker embedding are created on CPU while the model
    # may live on CUDA; move everything onto the model's device.
    device = model.device
    combined_input = f"{text} [EMOTION] {emotion}"
    inputs = processor(text=combined_input, return_tensors="pt")
    speech = model.generate_speech(
        inputs["input_ids"].to(device),
        speaker_embeddings.to(device),
        vocoder=vocoder.to(device),
    )
    # soundfile needs a CPU numpy array.
    sf.write(f"{file_name}.wav", speech.detach().cpu().numpy(), samplerate=Config.sampling_rate)

In [None]:
import os
import torchaudio
import torch
from IPython.display import Audio

def build_output_paths(save_dir, save_name):
    """Return ``(wav_path, spec_path)`` inside ``save_dir`` for ``save_name``.

    Only the base name of ``save_name`` is used, and its extension is replaced.
    For example, ``"1/amused_0001.wav"`` gives ``save_dir/amused_0001.wav`` and
    ``save_dir/amused_0001_spec.pt``. A name without ``.wav`` (such as the
    default ``"output"``) still gets two distinct files.
    """
    stem = os.path.splitext(os.path.basename(save_name))[0]
    return os.path.join(save_dir, f"{stem}.wav"), os.path.join(save_dir, f"{stem}_spec.pt")


def gen(prompt, emo, model, speaker_embeddings, save_dir="outputs", save_name="output",
        device="cpu", sampling_rate=16000):
    """Synthesize ``prompt`` with emotion tag ``emo``; save the wav and the spectrogram.

    Args:
        prompt: transcript without an ``[EMOTION]`` tag (the tag is added here).
        emo: emotion label.
        model: SpeechT5 TTS model; it is moved to ``device`` in place.
        speaker_embeddings: speaker embedding tensor passed to ``generate_speech``.
        save_dir: output directory (created if missing).
        save_name: file name; only its base name without extension is used.
        device: device to run the model and the global ``vocoder`` on.
        sampling_rate: sample rate written into the wav header.
    """
    os.makedirs(save_dir, exist_ok=True)

    text = f"{prompt} [EMOTION] {emo}"
    inputs = processor(text=text, return_tensors="pt")
    model.to(device)
    vocoder.to(device)
    spectrogram = model.generate_speech(inputs["input_ids"].to(device), speaker_embeddings.to(device))

    with torch.no_grad():
        speech = vocoder(spectrogram)

    # Previously a save_name without ".wav" made spec_path equal wav_path, so the
    # spectrogram overwrote the audio; both paths now come from one helper.
    wav_path, spec_path = build_output_paths(save_dir, save_name)
    # torchaudio.save expects (channels, time); the vocoder output is assumed 1-D.
    torchaudio.save(wav_path, speech.unsqueeze(0).cpu(), sampling_rate)
    torch.save(spectrogram.cpu(), spec_path)

    print(f"Saved wav to {wav_path}")
    print(f"Saved spectrogram to {spec_path}")

In [None]:
save_dir = "/home/alina/repos/SpeechT5/experiments/exp0045/evaluation"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Only test examples with this emotion are synthesized.
target_emotion = "sleepiness"

for example in tqdm(dataset["test"], desc="Synthesizing test set"):
    emotion = example["emotion"]
    if emotion != target_emotion:
        continue
    # `text` was stored as "<transcript> [EMOTION] <label>" during preprocessing,
    # and gen() appends the tag again; strip it so the prompt matches training.
    text = example["text"].split(" [EMOTION] ")[0]
    # A distinct name keeps the global `speaker_embeddings` used by infer() intact.
    example_embeddings = torch.tensor(example["speaker_embeddings"]).unsqueeze(0)
    save_name = os.path.basename(example["audio_path"])

    gen(prompt=text,
        emo=emotion,
        model=model,
        speaker_embeddings=example_embeddings,
        save_dir=save_dir,
        save_name=save_name,
        device=device)