In [None]:
!pip install transformers 
!pip install torch
!pip install speechbrain
!pip install torchaudio==2.7.0
!pip install datasets==3.6.0

### Speaker encoder training

In [None]:
from functools import lru_cache

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, Audio
import torchaudio
from tqdm import tqdm

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("CUDA is available" if device == "cuda" else "CUDA is not available")

In [None]:
class CustomSpeakerEncoder(nn.Module):
    """Encode a log-mel spectrogram into an L2-normalised speaker embedding.

    Pipeline: three Conv1d + BatchNorm + ReLU layers over time, a Transformer
    encoder over the frame sequence, attention-weighted pooling over time, and
    a two-layer projection head.

    Args:
        input_dim: number of input channels (mel bins) of the spectrogram.
        hidden_dim: channels of the last conv layer and Transformer model
            width; must be divisible by ``num_heads``.
        output_dim: size of the returned embedding.
        num_layers: number of Transformer encoder layers.
        num_heads: attention heads per Transformer layer.
        dropout: dropout rate inside the Transformer and before ``fc2``.
    """

    def __init__(self, input_dim=80, hidden_dim=256, output_dim=512, num_layers=4, num_heads=8, dropout=0.1):
        super().__init__()

        # kernel_size=5, stride=1, padding=2 keeps the number of frames unchanged.
        self.conv1 = nn.Conv1d(input_dim, 128, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.conv3 = nn.Conv1d(256, hidden_dim, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm1d(hidden_dim)

        # forward() feeds [batch, time, hidden_dim]; batch_first=True makes the
        # Transformer attend over time. With PyTorch's default
        # (batch_first=False), dim 0 would be read as the sequence axis, so
        # attention would mix different utterances of the batch.
        # NOTE(review): parameter names/shapes are unchanged, so checkpoints
        # saved before this fix still load, but they were trained with the
        # batch-mixing behaviour and should be retrained.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # Scores each frame for attention pooling.
        self.attention = nn.Linear(hidden_dim, 1)

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, mel_spectrogram):
        """Return speaker embeddings of shape [batch, output_dim] with unit L2 norm.

        Args:
            mel_spectrogram: tensor of shape [batch, input_dim, time_steps]
                (Conv1d channel-first layout). Zero-padded frames are pooled
                like real frames; no padding mask is applied.
        """
        x = F.relu(self.bn1(self.conv1(mel_spectrogram)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        x = x.transpose(1, 2)  # [batch, time_steps, hidden_dim]

        transformer_out = self.transformer_encoder(x)

        # Softmax over the time axis gives one weight per frame; the weighted
        # sum pools the sequence down to [batch, hidden_dim].
        attention_weights = F.softmax(self.attention(transformer_out), dim=1)
        pooled = torch.sum(attention_weights * transformer_out, dim=1)

        x = F.relu(self.fc1(pooled))
        x = self.dropout(x)
        speaker_embedding = self.fc2(x)
        return F.normalize(speaker_embedding, p=2, dim=1)

In [None]:
class SpeakerClassifier(nn.Module):
    """Linear head that maps speaker embeddings to per-speaker logits.

    Args:
        embedding_dim: size of the incoming embeddings.
        num_speakers: number of output classes (one logit per speaker).
    """

    def __init__(self, embedding_dim=512, num_speakers=100):
        super().__init__()
        self.fc = nn.Linear(embedding_dim, num_speakers)

    def forward(self, embeddings):
        """Return unnormalised logits of shape [..., num_speakers]."""
        logits = self.fc(embeddings)
        return logits

In [None]:
@lru_cache(maxsize=None)
def _mel_transform(sample_rate, n_mels, n_fft, hop_length):
    """Build the torchaudio MelSpectrogram module once per configuration."""
    return torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )


def extract_mel_spectrogram(waveform, sample_rate=16000, n_mels=80, n_fft=400, hop_length=160):
    """Compute a log-compressed (``log1p``) mel spectrogram of a waveform.

    Args:
        waveform: audio samples as a NumPy array, list or tensor
            (typically 1-D).
        sample_rate: sampling rate of ``waveform`` in Hz.
        n_mels: number of mel bins.
        n_fft: FFT size (400 samples = 25 ms at 16 kHz).
        hop_length: hop between frames (160 samples = 10 ms at 16 kHz).

    Returns:
        For a 1-D waveform, a float32 tensor of shape [n_mels, frames].
    """
    # Always cast to float32: the STFT window and mel filterbank are float32,
    # so e.g. a float64 tensor would otherwise fail inside the transform.
    waveform = torch.as_tensor(waveform, dtype=torch.float32)

    # The transform is cached instead of rebuilt (window + filterbank) per call.
    mel_transform = _mel_transform(sample_rate, n_mels, n_fft, hop_length)

    mel_spec = mel_transform(waveform.unsqueeze(0))
    mel_spec = torch.log1p(mel_spec)
    return mel_spec.squeeze(0)

In [None]:
class SpeakerDataset(Dataset):
    """Map-style dataset yielding ``(log-mel spectrogram, speaker_id)`` pairs.

    ``data`` is an indexable sequence of records exposing
    ``record["audio"]["array"]`` and ``record["speaker_id"]``; spectrograms
    are computed lazily when an item is accessed.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        return extract_mel_spectrogram(record["audio"]["array"]), record["speaker_id"]


def collate_fn(batch, max_frames=500):
    """Pad/truncate a batch of mel spectrograms to a common number of frames.

    Args:
        batch: sequence of ``(mel_spec, speaker_id)`` pairs, where each
            ``mel_spec`` is 2-D with frames on dim 1.
        max_frames: upper bound on the batch length; longer spectrograms are
            truncated (their tail is dropped) to bound memory.

    Returns:
        ``(mel_batch, speaker_ids)``. ``mel_batch`` has shape
        ``[batch, n_mels, T]`` with ``T = min(longest input, max_frames)``.
        ``speaker_ids`` is a ``torch.long`` tensor of shape ``[batch]``.
    """
    mel_specs, speaker_ids = zip(*batch)
    target_len = min(max(mel.shape[1] for mel in mel_specs), max_frames)

    padded_mels = []
    for mel in mel_specs:
        trimmed = mel[:, :target_len]
        # Right-pad with zeros up to target_len (no-op when already full length).
        padded_mels.append(F.pad(trimmed, (0, target_len - trimmed.shape[1]), value=0))

    mel_batch = torch.stack(padded_mels)
    return mel_batch, torch.tensor(speaker_ids, dtype=torch.long)

In [None]:
# Stream the Dutch VoxPopuli training split, resampled to the 16 kHz rate that
# extract_mel_spectrogram defaults to, and materialise the first 1000 utterances.
dataset = load_dataset("facebook/voxpopuli", "nl", split="train", streaming=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

subset = list(dataset.take(1000))

# Sort the raw speaker ids so the speaker -> class-index mapping is identical on
# every run: iteration order of a set of strings depends on the per-process hash
# seed. key=str keeps the sort well-defined even if ids are not all strings.
unique_speakers = sorted({item["speaker_id"] for item in subset}, key=str)
speaker_to_id = {speaker: idx for idx, speaker in enumerate(unique_speakers)}
num_speakers = len(unique_speakers)
print(f"Number of unique speakers: {num_speakers}")

# Replace raw speaker ids with contiguous class indices for CrossEntropyLoss.
for item in subset:
    item["speaker_id"] = speaker_to_id[item["speaker_id"]]

# Shuffled mini-batches of 4; collate_fn pads/truncates each batch to a common
# frame count. num_workers=0 keeps loading in the notebook process.
speaker_dataset = SpeakerDataset(subset)
train_loader = DataLoader(
    dataset=speaker_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

# These hyper-parameters must match the ones used to rebuild the encoder when the
# checkpoint is loaded in the integration and inference sections.
speaker_encoder = CustomSpeakerEncoder(
    input_dim=80, hidden_dim=128, output_dim=512, num_layers=2, dropout=0.1
).to(device)
classifier = SpeakerClassifier(embedding_dim=512, num_speakers=num_speakers).to(device)

# A single optimizer over encoder + classifier head, so both are trained jointly.
trainable_params = [*speaker_encoder.parameters(), *classifier.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=1e-5)
criterion = nn.CrossEntropyLoss()

# Jointly train encoder + classifier with cross-entropy over speaker classes.
num_epochs = 10
# Release cached CUDA allocator blocks every this many batches (GPU only).
empty_cache_every = 20
speaker_encoder.train()
classifier.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx, (mel_specs, speaker_ids) in enumerate(pbar, start=1):
        mel_specs = mel_specs.to(device)
        speaker_ids = speaker_ids.to(device)

        embeddings = speaker_encoder(mel_specs)
        logits = classifier(embeddings)
        loss = criterion(logits, speaker_ids)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running statistics for the progress bar and the epoch summary.
        total_loss += loss.item()
        predicted = logits.argmax(dim=1)
        correct += (predicted == speaker_ids).sum().item()
        total += speaker_ids.size(0)

        # Count batches (not samples) so the period does not depend on
        # batch_size, and skip the call entirely when training on the CPU.
        if device == "cuda" and batch_idx % empty_cache_every == 0:
            torch.cuda.empty_cache()

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "acc": f"{100.*correct/total:.2f}%"
        })

    # max(..., 1) guards against an empty loader instead of dividing by zero.
    avg_loss = total_loss / max(len(train_loader), 1)
    accuracy = 100. * correct / max(total, 1)
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.2f}%")

# Only the encoder's weights are saved; the classifier head is not persisted.
output_file = "pretrained_speaker_encoder_6.pt"
torch.save(speaker_encoder.state_dict(), output_file)
print("Saved to: ", output_file)

### Integration with text-to-speech (SpeechT5 fine-tuning)

In [None]:
import os

# Speaker-encoder checkpoint written by the training section.
encoder_file_name = "pretrained_speaker_encoder_6.pt"
# Read the Hugging Face token from the environment rather than hard-coding a
# secret in the notebook; the "TOKEN" placeholder remains the fallback.
huggingface_token = os.environ.get("HF_TOKEN", "TOKEN")
# Passed to save_pretrained / from_pretrained for the fine-tuned model.
model_name = "ACCOUNTNAME/MODELNAME"

In [None]:
# Rebuild the encoder with the training hyper-parameters so the checkpoint's
# state_dict matches, load it, and switch to eval mode (BatchNorm running stats,
# dropout off).
custom_speaker_encoder = CustomSpeakerEncoder(
    input_dim=80,
    hidden_dim=128,
    output_dim=512,
    num_layers=2,
    dropout=0.1
).to(device)

# weights_only=True: the file holds only a state_dict, so refuse arbitrary
# pickled objects (also the default from PyTorch 2.6 on).
state_dict = torch.load(encoder_file_name, map_location=device, weights_only=True)
custom_speaker_encoder.load_state_dict(state_dict)
custom_speaker_encoder.eval()

In [None]:
from datasets import load_dataset, Audio, Dataset

from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, Seq2SeqTrainingArguments, Seq2SeqTrainer
from huggingface_hub import login
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# Stream the Dutch VoxPopuli training split at 16 kHz and keep the first 4000 examples.
dataset = (
    load_dataset("facebook/voxpopuli", "nl", split="train", streaming=True)
    .cast_column("audio", Audio(sampling_rate=16000))
    .take(4000)
)

# Base SpeechT5 TTS checkpoint to fine-tune; its processor tokenizes the text and
# featurizes the target audio in prepare_dataset.
checkpoint = "microsoft/speecht5_tts"
processor = SpeechT5Processor.from_pretrained(checkpoint)

def create_speaker_embedding(waveform):
    """Embed a single utterance with the custom speaker encoder.

    Runs without gradient tracking and returns the embedding, with size-1
    dimensions squeezed out, as a NumPy array on the CPU.
    """
    with torch.no_grad():
        mel_batch = extract_mel_spectrogram(waveform).unsqueeze(0).to(device)
        embedding = custom_speaker_encoder(mel_batch)
        return embedding.squeeze().cpu().numpy()

def prepare_dataset(example):
    """Turn one raw VoxPopuli record into SpeechT5 training features.

    Produces ``input_ids`` (from ``normalized_text``), ``labels`` (target
    features computed from the audio, batch dimension removed) and
    ``speaker_embeddings`` from the custom speaker encoder.
    """
    audio = example["audio"]
    features = processor(
        text=example["normalized_text"],
        audio_target=audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_attention_mask=False,
    )

    # The processor returns batched outputs; keep the single example.
    features["labels"] = features["labels"][0]

    # Speaker conditioning comes from the custom encoder (not SpeechBrain x-vectors).
    features["speaker_embeddings"] = create_speaker_embedding(audio["array"])

    return features

# Drop the raw columns so only the prepare_dataset outputs remain.
raw_columns = dataset.column_names
dataset = dataset.map(prepare_dataset, remove_columns=raw_columns)

def is_not_too_long(input_ids, max_length=200):
    """Return True when the token sequence is strictly shorter than ``max_length``.

    Used as a ``datasets`` filter predicate on the ``input_ids`` column; the
    default limit of 200 tokens matches the previous hard-coded value.
    """
    return len(input_ids) < max_length

# Explicit alias: earlier in the notebook the bare name ``Dataset`` also refers
# to torch.utils.data.Dataset, so spell out which class is meant here.
from datasets import Dataset as HFDataset

dataset = dataset.filter(is_not_too_long, input_columns=["input_ids"])
# Materialise the streamed examples into an in-memory dataset so it can be split.
dataset = HFDataset.from_list(list(dataset))
# Fixed seed so the train/test split is reproducible across runs.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

@dataclass
class TTSDataCollatorWithPadding:
    """Collate SpeechT5 TTS features into a padded training batch.

    Attributes:
        processor: object whose ``pad`` method pads ``input_ids`` and target
            ``labels`` and returns a batch with ``decoder_attention_mask``.
        reduction_factor: decoder reduction factor; target lengths are rounded
            down to a multiple of it. When None, it is read at call time from
            the notebook-global ``model.config.reduction_factor``.
    """

    processor: Any
    reduction_factor: Union[int, None] = None

    def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]

        # Collate inputs and targets with this collator's own processor (not a global).
        batch = self.processor.pad(input_ids=input_ids, labels=label_features, return_tensors="pt")

        # Replace padded target frames with -100 so the loss ignores them.
        batch["labels"] = batch["labels"].masked_fill(
            batch.decoder_attention_mask.unsqueeze(-1).ne(1), -100
        )

        # Not used during fine-tuning.
        del batch["decoder_attention_mask"]

        reduction_factor = self.reduction_factor
        if reduction_factor is None:
            reduction_factor = model.config.reduction_factor
        # Round target lengths down to a multiple of the reduction factor.
        if reduction_factor > 1:
            max_length = max(
                len(feature["input_values"]) // reduction_factor * reduction_factor
                for feature in label_features
            )
            batch["labels"] = batch["labels"][:, :max_length]

        # Stack per-example embeddings (lists or NumPy arrays) into one float32 tensor.
        batch["speaker_embeddings"] = torch.stack(
            [torch.as_tensor(embedding, dtype=torch.float32) for embedding in speaker_features]
        )

        return batch

# Base SpeechT5 model to fine-tune; key/value caching is switched off for training.
model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
model.config.use_cache = False

data_collator = TTSDataCollatorWithPadding(processor=processor)

training_args = Seq2SeqTrainingArguments(
    output_dir="speecht5_finetuned_voxpopuli_nl_custom6",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # effective batch of 32 examples per device
    learning_rate=1e-5,
    warmup_steps=50,  # was 500 when max_steps was 4000
    max_steps=2000,
    gradient_checkpointing=False,
    # Mixed precision needs a GPU; requesting fp16 on CPU makes the Trainer raise.
    fp16=torch.cuda.is_available(),
    eval_strategy="steps",
    per_device_eval_batch_size=2,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    greater_is_better=False,  # lower is better for the selection metric (eval loss by default)
    label_names=["labels"],
    push_to_hub=True,
)

login(token=huggingface_token)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=processor,
)

trainer.train()

# Save the fine-tuned model and processor under model_name (loaded again in the
# inference section), then push the trainer output to the Hub.
for component in (model, processor):
    component.save_pretrained(model_name)

trainer.push_to_hub()

### Inference

In [None]:
import os

# Speaker-encoder checkpoint written by the training section.
encoder_file_name = "pretrained_speaker_encoder_6.pt"
# Read the Hugging Face token from the environment rather than hard-coding a
# secret in the notebook; the "TOKEN" placeholder remains the fallback.
huggingface_token = os.environ.get("HF_TOKEN", "TOKEN")
# Fine-tuned model/processor location passed to from_pretrained below.
model_name = "ACCOUNTNAME/MODELNAME"

In [None]:
# Rebuild the encoder with the training hyper-parameters so the checkpoint's
# state_dict matches, load it, and switch to eval mode (BatchNorm running stats,
# dropout off).
custom_speaker_encoder = CustomSpeakerEncoder(
    input_dim=80,
    hidden_dim=128,
    output_dim=512,
    num_layers=2,
    dropout=0.1
).to(device)

# weights_only=True: the file holds only a state_dict, so refuse arbitrary
# pickled objects (also the default from PyTorch 2.6 on).
state_dict = torch.load(encoder_file_name, map_location=device, weights_only=True)
custom_speaker_encoder.load_state_dict(state_dict)
custom_speaker_encoder.eval()

In [None]:
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset, Audio as AudioFeature
import numpy as np

# Fine-tuned SpeechT5 processor/model from model_name, plus the HiFi-GAN vocoder
# used to turn generated spectrograms into audio; models run in eval mode on `device`.
processor = SpeechT5Processor.from_pretrained(model_name)
model = SpeechT5ForTextToSpeech.from_pretrained(model_name).to(device).eval()
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device).eval()

def synthesize_speech(text, reference_audio_array, sample_rate=16000):
    """Speak ``text`` in the voice of a reference recording.

    Args:
        text: text to synthesize.
        reference_audio_array: waveform of the reference speaker.
        sample_rate: sampling rate of the reference waveform, passed to the
            mel front-end of the speaker encoder.

    Returns:
        NumPy array with the output of ``model.generate_speech`` when the
        HiFi-GAN vocoder is supplied, i.e. the synthesized waveform.
    """
    reference_mel = extract_mel_spectrogram(reference_audio_array, sample_rate).unsqueeze(0).to(device)

    with torch.no_grad():
        speaker_embeddings = custom_speaker_encoder(reference_mel)
        input_ids = processor(text=text, return_tensors="pt")["input_ids"].to(device)
        waveform = model.generate_speech(input_ids, speaker_embeddings, vocoder=vocoder)

    return waveform.cpu().numpy()

dataset = load_dataset("facebook/voxpopuli", "nl", split="test", streaming=True).cast_column(
    "audio", AudioFeature(sampling_rate=16000)
)

# Use the third utterance of the test stream (index 2) as the voice reference.
it = iter(dataset)
for _ in range(2):
    next(it)  # skip the first two utterances
reference_example = next(it)
reference_audio = reference_example["audio"]["array"]

text = "hallo allemaal, ik praat nederlands. groetjes aan iedereen!"

audio = synthesize_speech(text=text, reference_audio_array=reference_audio)

from IPython.display import Audio as IPythonAudio
# Render an audio player for the synthesized waveform at 16 kHz.
IPythonAudio(data=audio, rate=16000)