In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchaudio

from tqdm import tqdm
from evaluate import load
import matplotlib.pyplot as plt 

from model.Speech2Text import Speech2Text
from model.SpeechGenerator import SpeechGenerator
from utils.Config import ConfigSLP, ConfigNAC, ConfigDiTTO
from utils.MLS import MLSDataset
from utils.Processing import Processing

from torch.utils.data import DataLoader

In [None]:
ConfigSLP.display()
ConfigNAC.display()
ConfigDiTTO.display()

In [None]:
# Processing.remove_metadata_from_audio_folder(ConfigSLP.TRAIN_PATH+"/"+"audio", ConfigSLP.TRAIN_PATH+"/"+"audio_clean",)
# Processing.remove_metadata_from_audio_folder(ConfigSLP.TEST_PATH+"/"+"audio", ConfigSLP.TEST_PATH+"/"+"audio_clean",)
# Processing.remove_metadata_from_audio_folder(ConfigSLP.DEV_PATH+"/"+"audio", ConfigSLP.DEV_PATH+"/"+"audio_clean",)

## Speech Generation with DiTTO-TTs and Vocoder

In [None]:
train_set = MLSDataset(
    data_dir=ConfigDiTTO.TRAIN_PATH,
    max_text_token_length=ConfigDiTTO.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigDiTTO.SAMPLE_RATE,
    nb_samples=ConfigDiTTO.NB_SAMPLES,
    tokenizer_model="gpt2"
)

test_set = MLSDataset(
    data_dir=ConfigDiTTO.TEST_PATH,
    max_text_token_length=ConfigDiTTO.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigDiTTO.SAMPLE_RATE,
    nb_samples=ConfigDiTTO.NB_SAMPLES,
    tokenizer_model="gpt2"
)

train_loader = DataLoader(train_set, batch_size=ConfigNAC.BATCH_SIZE, shuffle=True, collate_fn=MLSDataset.collate_fn)
test_loader = DataLoader(test_set, batch_size=ConfigNAC.BATCH_SIZE, shuffle=True, collate_fn=MLSDataset.collate_fn)

In [None]:
ConfigDiTTO.DIFFUSION_STEPS = 1000

speech_generator = SpeechGenerator(
    nac_model_path="/tempory/M2-DAC/UE_DEEP/AMAL/DiTTO-TTS/src/params/NAC_epoch_20.pth",
    ditto_model_path="/tempory/M2-DAC/UE_DEEP/AMAL/DiTTO-TTS/src/params/DiTTO_epoch_20.pth",
    lambda_factor=ConfigNAC.LAMBDA_FACTOR,
    sample_rate=ConfigNAC.SAMPLE_RATE,
    device=ConfigDiTTO.DEVICE
)

In [None]:
def test_with_loader(loader, prompt):
    ConfigDiTTO.DIFFUSION_STEPS = 1000

    batch = next(iter(loader))
    batch["audio"] = batch["audio"].to(ConfigDiTTO.DEVICE)
    batch["text"]["input_ids"] = batch["text"]["input_ids"].to(ConfigDiTTO.DEVICE)
    batch["text"]["attention_mask"] = batch["text"]["attention_mask"].to(ConfigDiTTO.DEVICE)

    for audio_tensor, padding_mask_audio, text_input  in zip(batch["audio"], batch["padding_mask_audio"],  batch["text"]["input_ids"]):
        generated_waveform = speech_generator.generate_speech_from_audio_tensor(
            audio_tensor.to(ConfigDiTTO.DEVICE).unsqueeze(0), 
            padding_mask_audio.to(ConfigDiTTO.DEVICE).unsqueeze(0),
            text_input.unsqueeze(0),
            is_tokenized=True
        )
        output_path = "output.wav"
        torchaudio.save(output_path, generated_waveform.cpu(), ConfigDiTTO.SAMPLE_RATE)
        break

In [None]:
wave = test_with_loader(train_loader, "Bonjour, comment çava tout le monde ?")

In [None]:
ConfigDiTTO.DIFFUSION_STEPS = 1000

In [65]:
import torch
from tqdm import tqdm
from evaluate import load
from transformers import GPT2Tokenizer

cer_metric = load("cer")
wer_metric = load("wer")

model = Speech2Text(sampling_rate=16000)
model.eval()


predictions = []
references = []
max_batch = 5
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def cer_wer_on_loader(loader):
    with torch.no_grad():
        for i,batch in tqdm(enumerate(loader)):

            batch["audio"] = batch["audio"].to(ConfigDiTTO.DEVICE)
            batch["text"]["input_ids"] = batch["text"]["input_ids"].to(ConfigDiTTO.DEVICE)
            batch["text"]["attention_mask"] = batch["text"]["attention_mask"].to(ConfigDiTTO.DEVICE)

            for audio_tensor, padding_mask_audio, text_input  in zip(batch["audio"], batch["padding_mask_audio"],  batch["text"]["input_ids"]):
                generated_waveform = speech_generator.generate_speech_from_audio_tensor(
                    audio_tensor.to(ConfigDiTTO.DEVICE).unsqueeze(0), 
                    padding_mask_audio.to(ConfigDiTTO.DEVICE).unsqueeze(0),
                    text_input.unsqueeze(0),
                    is_tokenized=True
                )
                transcription = model(generated_waveform)
                predictions.extend(transcription)
                
            ref_texts = tokenizer.batch_decode(batch["text"]["input_ids"].to(ConfigDiTTO.DEVICE), skip_special_tokens=True)
            references.extend(ref_texts)
            if i > max_batch:
                break

    # Calcul des métriques
    cer_score = cer_metric.compute(predictions=predictions, references=references)
    wer_score = wer_metric.compute(predictions=predictions, references=references)

    print("CER score:", cer_score)
    print("WER score:", wer_score)

Downloading builder script: 100%|██████████| 5.60k/5.60k [00:00<00:00, 7.04MB/s]
Downloading builder script: 100%|██████████| 4.49k/4.49k [00:00<00:00, 4.70MB/s]
Some weights of Speech2TextForConditionalGeneration were not initialized from the model checkpoint at facebook/s2t-medium-mustc-multilingual-st and are newly initialized: ['model.decoder.embed_positions.weights', 'model.encoder.embed_positions.weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [66]:
cer_wer_on_loader(train_loader)
cer_wer_on_loader(test_loader)

1000it [00:12, 78.75it/s]




1000it [00:12, 77.91it/s]




1000it [00:12, 77.38it/s]




1000it [00:13, 76.81it/s]




1000it [00:13, 76.33it/s]




1000it [00:13, 76.06it/s]




1000it [00:13, 75.87it/s]




1000it [00:13, 75.78it/s]




1000it [00:13, 75.83it/s]




1000it [00:13, 75.60it/s]




1000it [00:13, 75.54it/s]




1000it [00:13, 75.51it/s]




1000it [00:13, 75.51it/s]




1000it [00:13, 75.59it/s]




1000it [00:13, 75.50it/s]




1000it [00:13, 75.64it/s]




1000it [00:13, 75.53it/s]




1000it [00:13, 75.62it/s]




1000it [00:13, 75.51it/s]




1000it [00:13, 75.36it/s]




1000it [00:13, 75.55it/s]




1000it [00:13, 75.36it/s]




1000it [00:13, 75.45it/s]




1000it [00:13, 75.56it/s]




1000it [00:13, 75.47it/s]




1000it [00:13, 75.43it/s]




1000it [00:13, 75.56it/s]




1000it [00:13, 75.54it/s]




6it [06:36, 66.12s/it]


CER score: 0.9305486490966351
WER score: 0.9981549815498155


1000it [00:13, 75.64it/s]




1000it [00:13, 75.60it/s]




1000it [00:13, 75.64it/s]




1000it [00:13, 75.55it/s]




1000it [00:13, 75.52it/s]




1000it [00:13, 75.60it/s]




1000it [00:13, 75.53it/s]




1000it [00:13, 75.59it/s]




1000it [00:13, 75.53it/s]




1000it [00:13, 75.62it/s]




1000it [00:13, 75.52it/s]




1000it [00:13, 75.57it/s]




1000it [00:13, 75.50it/s]




1000it [00:13, 75.59it/s]




1000it [00:13, 75.50it/s]




1000it [00:13, 75.50it/s]




1000it [00:13, 75.52it/s]




1000it [00:13, 75.76it/s]




1000it [00:13, 75.45it/s]




1000it [00:13, 75.53it/s]




1000it [00:13, 75.57it/s]




1000it [00:13, 75.53it/s]




1000it [00:13, 75.52it/s]




1000it [00:13, 75.51it/s]




1000it [00:13, 75.57it/s]




1000it [00:13, 75.53it/s]




1000it [00:13, 75.57it/s]




1000it [00:13, 75.51it/s]




6it [06:35, 65.98s/it]

CER score: 0.9305370442963544
WER score: 0.9973509933774835



