<a href="https://colab.research.google.com/github/LEANHDUC2005/Neural-Network---ANN/blob/main/Fastspeech2%20Hifigan%20Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# FastSpeech 2 + HiFi-GAN: Training from scratch (LJSpeech)
# Compatible with Google Colab Free Tier (T4 GPU)

# Step 1: Install dependencies
!pip install torch torchaudio numpy matplotlib scipy tensorboard
!pip install phonemizer unidecode librosa


In [None]:
# Step 2: Clone HiFi-GAN for waveform decoder
!git clone https://github.com/jik876/hifi-gan.git
%cd hifi-gan
!pip install -r requirements.txt
%cd ..

In [None]:
# Step 3: Download and extract LJSpeech dataset
!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xjf LJSpeech-1.1.tar.bz2

In [None]:
# Step 4: Define FastSpeech 2 components
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of sequences.

    The sin/cos table is computed once for ``max_len`` positions and kept as a
    non-persistent buffer: it moves with the module on ``.to(device)`` but is
    not written to ``state_dict()``, so checkpoints keep the same keys as
    before.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        max_len: Largest sequence length (dimension 1) that can be encoded.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * -(torch.log(torch.tensor(10000.0)) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine
        # columns; slicing keeps the assignment shapes consistent.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension has size ``d_model``.

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} of the "
                "positional encoding table"
            )
        # .to() is a no-op when the module already lives on x's device.
        return x + self.pe[:, :seq_len].to(x.device)

class VariancePredictor(nn.Module):
    """Predict one scalar per time step from a sequence of feature vectors.

    Two Conv1d -> ReLU -> LayerNorm -> Dropout stages followed by a linear
    projection to a single value.

    Args:
        input_dim: Feature size of the input; also used as the hidden size.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(input_dim, input_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.LayerNorm(input_dim),
            nn.Dropout(0.5),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(input_dim, input_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.LayerNorm(input_dim),
            nn.Dropout(0.5),
        )
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, time, input_dim) to (batch, time)."""
        for stage in (self.conv1, self.conv2):
            conv, activation, norm, dropout = stage
            # Conv1d wants channels-first, LayerNorm normalises the LAST
            # dimension. Running the Sequential as-is on a (batch, channels,
            # time) tensor would normalise over time and fail whenever
            # time != input_dim, so go back to channels-last before the norm.
            x = conv(x.transpose(1, 2)).transpose(1, 2)
            x = dropout(norm(activation(x)))
        return self.linear(x).squeeze(-1)

class VarianceAdaptor(nn.Module):
    """Run duration, pitch and energy predictors over the same hidden sequence.

    The hidden sequence itself is passed through unchanged; the three
    predictions are only returned alongside it.

    Args:
        input_dim: Feature size handed to each ``VariancePredictor``.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.duration_predictor = VariancePredictor(input_dim)
        self.pitch_predictor = VariancePredictor(input_dim)
        self.energy_predictor = VariancePredictor(input_dim)

    def forward(self, x):
        """Return ``(x, duration, pitch, energy)`` for the hidden sequence ``x``."""
        # Tuple elements are evaluated left to right, so the predictors run in
        # the order duration -> pitch -> energy.
        return (
            x,
            self.duration_predictor(x),
            self.pitch_predictor(x),
            self.energy_predictor(x),
        )

class Encoder(nn.Module):
    """Single convolutional stage mapping input features to hidden features.

    Args:
        input_dim: Feature size of the incoming sequence.
        hidden_dim: Feature size of the returned sequence.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, time, input_dim) to (batch, time, hidden_dim)."""
        conv, activation, norm, dropout = self.conv
        # Conv1d needs channels-first, but LayerNorm normalises the LAST
        # dimension. Calling the Sequential directly on a (batch, channels,
        # time) tensor would normalise over time and fail unless
        # time == hidden_dim, so return to channels-last before the norm.
        x = conv(x.transpose(1, 2)).transpose(1, 2)
        return dropout(norm(activation(x)))

class Decoder(nn.Module):
    """Convolutional stage followed by a linear projection to mel bins.

    Args:
        hidden_dim: Feature size of the incoming hidden sequence.
        mel_dim: Number of output values per time step.
    """

    def __init__(self, hidden_dim, mel_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
        )
        self.linear = nn.Linear(hidden_dim, mel_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, time, hidden_dim) to (batch, time, mel_dim)."""
        conv, activation, norm, dropout = self.conv
        # Conv1d needs channels-first, but LayerNorm normalises the LAST
        # dimension. Calling the Sequential directly on a (batch, channels,
        # time) tensor would normalise over time and fail unless
        # time == hidden_dim, so return to channels-last before the norm.
        x = conv(x.transpose(1, 2)).transpose(1, 2)
        x = dropout(norm(activation(x)))
        return self.linear(x)

class FastSpeech2(nn.Module):
    """Phoneme-id to mel model: embedding, positions, encoder, adaptor, decoder.

    Args:
        vocab_size: Number of rows in the phoneme embedding table.
        d_model: Hidden feature size shared by all sub-modules.
        mel_dim: Number of values the decoder emits per time step.
    """

    def __init__(self, vocab_size=300, d_model=256, mel_dim=80):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.encoder = Encoder(d_model, d_model)
        self.variance_adaptor = VarianceAdaptor(d_model)
        self.decoder = Decoder(d_model, mel_dim)

    def forward(self, phoneme_ids):
        """Return ``(mel, duration, pitch, energy)`` for a batch of phoneme ids.

        The variance adaptor hands the hidden sequence back unchanged, so the
        decoder is fed exactly what the encoder produced; the three variance
        predictions are returned but not used to alter the hidden sequence.
        """
        hidden = self.encoder(self.pos_enc(self.embedding(phoneme_ids)))
        hidden, dur, pitch, energy = self.variance_adaptor(hidden)
        return self.decoder(hidden), dur, pitch, energy

In [None]:
# Step 5: Preprocess LJSpeech to extract phonemes and mel-spectrograms
import os
import librosa
import numpy as np
from phonemizer import phonemize
from phonemizer.separator import Separator
from unidecode import unidecode
from tqdm import tqdm
import torchaudio

# Config
LJ_PATH = "LJSpeech-1.1"  # dataset root; metadata.csv and wavs/ are read from here
SAMPLING_RATE = 22050  # passed as `sr` when loading audio and when writing output.wav
N_MELS = 80  # number of mel bands requested from librosa (`n_mels`)
HOP_LENGTH = 256  # spectrogram hop length passed to librosa (`hop_length`)

# Phoneme conversion
def text_to_phonemes(text):
    """Convert text to a single-space-separated string of en-us phones.

    Args:
        text: Input sentence; it is stripped and transliterated to ASCII with
            ``unidecode`` before phonemization.

    Returns:
        Phones separated by single spaces with no leading/trailing space, so
        ``result.split()`` yields one token per phone. Empty input gives "".
    """
    text = unidecode(text.strip())
    if not text:
        return ""
    # phonemizer rejects a Separator whose non-empty separators are equal, and
    # the word separator defaults to ' ', so phone=' ' cannot be used
    # directly. Use a distinct phone separator and normalise afterwards.
    separator = Separator(phone='|', word=' ', syllable='')
    phones = phonemize(text, language='en-us', backend='espeak', separator=separator)
    # Collapse phone and word boundaries into single spaces.
    return ' '.join(phones.replace('|', ' ').split())

# Create mel-spectrogram
def wav_to_mel(wav_path):
    """Load a wav file and return its log-mel spectrogram, frames first.

    The features follow the convention of the HiFi-GAN vocoder used later in
    this notebook: a mel-filtered *magnitude* spectrogram (fmin=0, fmax=8000),
    clamped at 1e-5 and passed through the natural log. The previous dB scale
    with ``ref=np.max`` rescaled every utterance to its own peak and used a
    different log base, so mels learned from it are not on the scale the
    vocoder was trained on.

    NOTE(review): fmax=8000 matches the upstream config_v1 / universal
    settings -- confirm against the config shipped with the checkpoint.

    Args:
        wav_path: Path to the audio file; it is resampled to SAMPLING_RATE.

    Returns:
        Array of shape [T, N_MELS] (time frames first).
    """
    y, sr = librosa.load(wav_path, sr=SAMPLING_RATE)
    mel = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_fft=1024,
        hop_length=HOP_LENGTH,
        win_length=1024,
        n_mels=N_MELS,
        fmin=0,
        fmax=8000,
        power=1.0,  # magnitude, not power, spectrogram
    )
    log_mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))
    return log_mel.T  # [T, 80]

# Prepare dataset (subset for quick demo)
metadata_path = os.path.join(LJ_PATH, "metadata.csv")
with open(metadata_path, "r", encoding="utf-8") as f:
    lines = f.readlines()

# Pick a small subset for demo.
# First pass: phonemize and extract mels, keeping phones as token lists so the
# vocabulary can be built from exactly the symbols that occur.
raw_items = []
for line in tqdm(lines[:100]):
    parts = line.strip().split("|")
    if len(parts) < 3:
        # Each row is read as "id|text|normalized text"; skip malformed rows
        # instead of failing on parts[2].
        continue
    wav_path = os.path.join(LJ_PATH, "wavs", parts[0] + ".wav")
    phones = text_to_phonemes(parts[2]).split()
    if not phones:
        continue
    raw_items.append((phones, wav_to_mel(wav_path)))

# Phone symbol -> integer id. Ids start at 1: pad_sequence pads with 0 and the
# model is created with len(phoneme_vocab) + 1 embedding rows, so id 0 stays
# free for padding.
phoneme_vocab = {
    symbol: index
    for index, symbol in enumerate(
        sorted({p for phones, _ in raw_items for p in phones}), start=1
    )
}

# Second pass: store tensors, which is what the training loop pads and feeds
# to the embedding layer (strings / numpy arrays cannot be padded there).
data = [
    (
        torch.tensor([phoneme_vocab[p] for p in phones], dtype=torch.long),
        torch.tensor(mel, dtype=torch.float32),
    )
    for phones, mel in raw_items
]

# Save to disk (optional):
torch.save(data, "train_subset.pt")


In [None]:
# Step 8: Create Dataset and DataLoader
from torch.utils.data import Dataset, DataLoader

class LJSpeechDataset(Dataset):
    """Thin ``Dataset`` view over an already-prepared, indexable collection.

    Args:
        data: Any object supporting ``len()`` and integer indexing; items are
            returned exactly as stored.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of stored items."""
        item_count = len(self.data)
        return item_count

    def __getitem__(self, idx):
        """Return the stored item at position ``idx`` without modification."""
        sample = self.data[idx]
        return sample

# train_subset.pt holds a pickled Python list, not a plain tensor state dict.
# Recent PyTorch releases default torch.load to weights_only=True, which
# refuses arbitrary pickled objects, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary code; this is acceptable
# only because the file is produced locally by the preprocessing cell.
dataset = LJSpeechDataset(torch.load("train_subset.pt", weights_only=False))
# Identity collate: each batch is the raw list of items; padding is done in
# the training loop.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=lambda x: x)

# Step 9: Initialize Model and Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): phoneme_vocab must already exist from an earlier cell; the +1
# presumably reserves id 0 for padding -- confirm against how ids are assigned.
model = FastSpeech2(vocab_size=len(phoneme_vocab)+1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Step 10: Define Loss Functions
class FastSpeech2Loss(nn.Module):
    """Unweighted sum of mel L1 loss and duration/pitch/energy MSE losses."""

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, mel_pred, mel_target, dur_pred, dur_target,
                pitch_pred, pitch_target, energy_pred, energy_target):
        """Return the scalar total loss.

        The duration target is cast to float before the MSE; all four terms
        are added with equal weight, in the order mel, duration, pitch,
        energy.
        """
        total = self.l1_loss(mel_pred, mel_target)
        total = total + self.mse_loss(dur_pred, dur_target.float())
        total = total + self.mse_loss(pitch_pred, pitch_target)
        total = total + self.mse_loss(energy_pred, energy_target)
        return total

criterion = FastSpeech2Loss()

# Step 11: Training Loop
def train(model, dataloader, optimizer, criterion, epochs=10):
    """Run a simple training loop and print the mean loss of every epoch.

    Each batch is expected to be a list of ``(phoneme_ids, mel)`` items.
    Items are converted with ``torch.as_tensor`` so tensors, NumPy arrays and
    integer lists are all accepted, then zero-padded to the longest item of
    the batch. Duration, pitch and energy targets are placeholders (ones /
    uniform noise), not values extracted from the data. Tensors are placed on
    the module-level ``device``.

    Args:
        model: Module returning ``(mel, duration, pitch, energy)`` for a batch
            of padded phoneme ids.
        dataloader: Iterable of batches as described above.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable taking the four prediction/target pairs.
        epochs: Number of passes over ``dataloader``.

    Raises:
        RuntimeError: If the predicted mel does not have the same shape as the
            padded target mel. The loss is element-wise, so this happens when
            the model emits one frame per input phoneme while the target has
            one frame per audio hop, i.e. when nothing expands the phoneme
            sequence to frame rate.
    """
    pad = nn.utils.rnn.pad_sequence
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(dataloader):
            # Prepare batch (accept tensors as well as array-likes)
            phonemes = [torch.as_tensor(item[0], dtype=torch.long) for item in batch]
            mels = [torch.as_tensor(item[1], dtype=torch.float32) for item in batch]

            # Pad sequences and move to device
            phonemes_padded = pad(phonemes, batch_first=True).to(device)
            mels_padded = pad(mels, batch_first=True).to(device)

            # Dummy targets for demo (in real training, extract these from
            # data); created directly on the target device.
            target_shape = phonemes_padded.shape
            dur_target = torch.ones(target_shape, device=device)
            pitch_target = torch.rand(target_shape, device=device)
            energy_target = torch.rand(target_shape, device=device)

            # Forward pass
            optimizer.zero_grad()
            mel_pred, dur_pred, pitch_pred, energy_pred = model(phonemes_padded)

            # Fail early with an actionable message instead of an opaque
            # broadcasting error from inside the loss.
            if mel_pred.shape != mels_padded.shape:
                raise RuntimeError(
                    f"predicted mel shape {tuple(mel_pred.shape)} does not match "
                    f"target mel shape {tuple(mels_padded.shape)}; the model must "
                    "expand the phoneme sequence to the mel frame rate (length "
                    "regulation with per-phoneme durations) before decoding"
                )

            # Calculate loss
            loss = criterion(
                mel_pred, mels_padded,
                dur_pred, dur_target,
                pitch_pred, pitch_target,
                energy_pred, energy_target
            )

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An empty dataloader would otherwise divide by zero.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")

# Run training
# Uses the module-level model, dataloader, optimizer and criterion created
# above; epochs=5 overrides train()'s default of 10.
train(model, dataloader, optimizer, criterion, epochs=5)

# Step 12: Save trained model
# Only the parameters (state_dict) are written, so loading them later requires
# constructing a FastSpeech2 instance with the same sizes first.
torch.save(model.state_dict(), "fastspeech2_ljspeech.pth")

In [None]:
# Step 13: Load pretrained HiFi-GAN
import json
import sys

# The repository is cloned into "hifi-gan". A hyphenated directory cannot be
# imported as the package `hifi_gan`, and the upstream modules import each
# other by bare name (e.g. models.py imports from `utils`), so put the
# checkout itself on sys.path and import its modules directly.
if "hifi-gan" not in sys.path:
    sys.path.insert(0, "hifi-gan")

from models import Generator
from env import AttrDict

# Load HiFi-GAN config
# NOTE(review): confirm this file exists. The upstream repository appears to
# ship config_v1.json / config_v2.json / config_v3.json, while pretrained
# checkpoints are distributed together with their own config.json.
with open("hifi-gan/config.json") as f:
    # json.load avoids rebinding the global `data`, which holds the training
    # items from the preprocessing cell.
    hifi_config = AttrDict(json.load(f))

# Initialize HiFi-GAN
hifi_gan = Generator(hifi_config).to(device)
# NOTE(review): generator_universal.pth is not downloaded anywhere in this
# notebook -- it has to be placed in hifi-gan/ before running this cell.
hifi_checkpoint = torch.load("hifi-gan/generator_universal.pth", map_location=device)
# Upstream checkpoints wrap the weights as {"generator": state_dict}; also
# accept a bare state_dict so either file layout loads.
hifi_gan.load_state_dict(
    hifi_checkpoint["generator"] if "generator" in hifi_checkpoint else hifi_checkpoint
)
hifi_gan.eval()
hifi_gan.remove_weight_norm()

# Step 14: Synthesis function
def synthesize(text):
    """Synthesize a waveform for ``text`` with the acoustic model and HiFi-GAN.

    Uses the module-level ``model``, ``hifi_gan``, ``phoneme_vocab`` and
    ``device``. Phones that are not in ``phoneme_vocab`` are dropped.

    Args:
        text: Sentence to speak.

    Returns:
        Tuple ``(audio, mel)``: the vocoder output as a NumPy array and the
        mel tensor produced by the acoustic model.

    Raises:
        ValueError: If none of the phones of ``text`` is in the vocabulary.
    """
    # Convert text to phoneme ids
    phonemes = text_to_phonemes(text)
    ids = [phoneme_vocab[p] for p in phonemes.split() if p in phoneme_vocab]
    if not ids:
        # torch.tensor([]) would be a float tensor and fail inside the
        # embedding lookup with an unrelated-looking error.
        raise ValueError(f"no known phonemes found for text: {text!r}")
    phoneme_ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)

    # train() leaves the model in training mode; dropout must be disabled for
    # inference. Restore the previous mode afterwards so that a later training
    # run is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Generate mel-spectrogram
            mel, _, _, _ = model(phoneme_ids)
            # Generate waveform with HiFi-GAN (expects channels-first input)
            audio = hifi_gan(mel.transpose(1, 2)).squeeze().cpu().numpy()
    finally:
        model.train(was_training)

    return audio, mel

# Step 15: Test synthesis
# Run one sentence through the full text -> mel -> waveform pipeline.
test_text = "Hello world, this is a test of speech synthesis."
audio, mel = synthesize(test_text)

# Save and play audio
import soundfile as sf
from IPython.display import Audio

# Write the waveform at the same sampling rate used for preprocessing.
sf.write("output.wav", audio, SAMPLING_RATE)
# Must stay the last expression of the cell so the notebook renders the player.
Audio("output.wav")