In [1]:
import os
os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe"
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Audio
from phonemizer import phonemize

import torchaudio


In [3]:
class TTSConfig:
    d_model = 256
    num_latents = 128
    n_blocks = 6
    d_ff = 1024
    phoneme_vocab_size = 70
    mel_dim = 80
    mel_win = 1024
    mel_n_fft = 1024
    mel_hop = 256
    sample_rate = 22050
    batch_size = 32
    max_text_len = 80
    max_mel_len = 800
    lr = 1e-4
    epochs = 10
    num_heads = 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
CONFIG = TTSConfig()


In [4]:
#Perceiver Attention Block

class PerceiverAttention(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super(PerceiverAttention, self).__init__()

        self.self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self,x):
        attention_output, _ = self.self_attention(x, x, x)
        x = self.layer_norm(x + attention_output)
        ff = self.lin2(F.gelu(self.lin1(x)))
        x = self.layer_norm2(x + ff)
        return x

In [5]:
# import math
# class PerceiverTTS(nn.Module):
#     def __init__(self, config):
#         super(PerceiverTTS, self).__init__()
#         self.config = config
#
#         #Phoneme Embed
#         self.phoneme_embed = nn.Embedding(config.phoneme_vocab_size, config.d_model)
#
#         #Learnable Latents
#         self.latent = nn.Parameter(torch.randn(1, config.num_latents, config.d_model) * 0.02)
#         self.input_cross_attn = nn.MultiheadAttention(config.d_model, config.num_heads, batch_first=True)
#
#         self.latent_block = nn.ModuleList(
#             [
#                 PerceiverAttention(config.d_model, config.num_heads, config.d_ff)
#                 for _ in range(config.n_blocks)
#             ]
#         )
#         self.output_proj = nn.Linear(config.d_model, config.mel_dim)
#         self.upsampled = nn.Upsample(scale_factor= config.max_mel_len / config.num_latents, mode = 'linear')
#
#     def forward(self, phonemes):
#         b = phonemes.shape[0]
#         x = self.phoneme_embed(phonemes)
#
#         latent = self.latent.expand(b, -1, -1)
#         latent, _ = self.input_cross_attn(latent, x, x)
#
#         for block in self.latent_block:
#             latent = block(latent)
#         mel_len = self.config.max_mel_len
#         latent_len = latent.shape[1]
#         repeat_factor = math.ceil(mel_len / latent_len)
#
#         latent_expand = latent.repeat_interleave(repeat_factor, dim=1)
#         latent_expand = latent_expand[:, :mel_len, :]
#
#         mels = self.output_proj(latent_expand)
#         return mels


In [6]:
#Implementing a Pretrain Perceiver (Current Perceiver is working but frontier positional encoding wasnt use so that might be the major cause of model struggle to learn

In [7]:
from transformers import DistilBertTokenizer, DistilBertModel, Wav2Vec2Model, Wav2Vec2FeatureExtractor

from transformers import DistilBertTokenizer, DistilBertModel, Wav2Vec2Model, Wav2Vec2FeatureExtractor
import torch
import torch.nn as nn
import torch.nn.functional as F

class PretrainedPerceiver(nn.Module):
    def __init__(self, config, id_to_phoneme=None):
        super().__init__()
        self.config = config
        # Robust: Always set self.id_to_phoneme (dummy if not provided)
        self.id_to_phoneme = id_to_phoneme or {i: f"PH{i}" for i in range(config.phoneme_vocab_size)}

        self.text_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.text_encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
        # Freeze
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        self.text_proj = nn.Linear(768, config.d_model)

        self.audio_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base')
        self.audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        self.audio_proj = nn.Linear(self.audio_encoder.config.hidden_size, config.d_model)

        # Latent and Cross Attention
        self.latent = nn.Parameter(torch.rand(1, config.num_latents, config.d_model) * 0.02)
        self.text_to_latent_cross = nn.MultiheadAttention(config.d_model, config.num_heads, batch_first=True)

        self.latent_blocks = nn.ModuleList([  # Consistent name
            PerceiverAttention(config.d_model, config.num_heads, config.d_ff)
            for _ in range(config.n_blocks)
        ])

        self.latent_to_audio_cross = nn.MultiheadAttention(config.d_model, config.num_heads, batch_first=True)
        self.output_proj = nn.Linear(config.d_model, config.mel_dim)

    def unfreeze_layers(self, phase):
        if phase == 'frozen':
            pass
        elif phase == 'top':
            for name, param in self.text_encoder.named_parameters():
                if 'encoder.layer.5' in name or 'encoder.layer.4' in name:
                    param.requires_grad = True
            for name, param in self.audio_encoder.named_parameters():
                if 'encoder.layers.-2' in name or 'encoder.layers.-1' in name:
                    param.requires_grad = True
        elif phase == 'full':
            for param in self.text_encoder.parameters():
                param.requires_grad = True
            for param in self.audio_encoder.parameters():
                param.requires_grad = True
        print(f"Unfrozen phase: {phase}")

    def forward(self, phonemes, mels=None):
        b = phonemes.shape[0]
        device = phonemes.device

        phoneme_strings = []
        for row in phonemes.cpu().tolist():
            str_row = ' '.join([self.id_to_phoneme.get(id, '<UNK>') for id in row])
            phoneme_strings.append(str_row)

        text_inputs = self.text_tokenizer(
            phoneme_strings,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.config.max_text_len
        )
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        text_feats = self.text_encoder(**text_inputs).last_hidden_state
        text_feats = self.text_proj(text_feats)

        latent = self.latent.expand(b, -1, -1)
        latent = self.text_to_latent_cross(latent, text_feats, text_feats)[0]

        for block in self.latent_blocks:
            latent = block(latent)

        if self.training and mels is not None:
            mels_flat = mels.view(b, -1).cpu().numpy()
            audio_feats = self.audio_extractor(
                mels_flat,
                sampling_rate=16000,
                return_tensors='pt'
            ).input_values.to(device)

            audio_feats = self.audio_encoder(audio_feats).last_hidden_state.mean(1)
            audio_feats = self.audio_proj(audio_feats).unsqueeze(1)
            latent = self.latent_to_audio_cross(latent, audio_feats, audio_feats)[0]

        mel_len = self.config.max_mel_len
        latent_expand = F.interpolate(
            latent.transpose(1, 2),
            size=mel_len,
            mode='linear'
        ).transpose(1, 2)

        mels_out = self.output_proj(latent_expand)
        return mels_out


In [8]:
import io
import soundfile as sf



class LJSpeechDataset(Dataset):
    def __init__(self, config, split="full[:1%]"):
        super().__init__()
        self.config = config

        # Load dataset
        self.dataset = load_dataset("MikhailT/lj-speech", split=split)

        self.dataset = self.dataset.cast_column("audio", Audio(decode=False))

        # Mel transform
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.mel_n_fft,
            hop_length=config.mel_hop,
            win_length=config.mel_win,
            n_mels=config.mel_dim,
        )

        # Build phoneme vocab
        self.phoneme_vocab = {"<PAD>": 0, "<UNK>": 1}
        self._build_vocab()
        self.phoneme_to_id = {p: i for i, p in self.phoneme_vocab.items()}

    def _build_vocab(self):
        all_phonemes = set()

        for item in self.dataset:
            text = item["normalized_text"]
            p_str = self.phonemize(text)

            for p in p_str.split():
                all_phonemes.add(p)

        self.phoneme_vocab = {
            "<PAD>": 0,
            "<UNK>": 1,
        }

        for p in sorted(all_phonemes):
            if p not in self.phoneme_vocab:
                self.phoneme_vocab[p] = len(self.phoneme_vocab)

        print("Final phoneme vocab size:", len(self.phoneme_vocab))

        assert "<PAD>" in self.phoneme_vocab
        assert "<UNK>" in self.phoneme_vocab


    def phonemize(self, text):
        return phonemize(
            [text],
            backend="espeak",
            language="en-us",
            with_stress=True,
            strip=True,
        )[0]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        audio_bytes = item["audio"]["bytes"]

        with sf.SoundFile(io.BytesIO(audio_bytes)) as f:
            waveform = f.read(dtype="float32")
            sr = f.samplerate

        waveform = torch.tensor(waveform).float()

        if waveform.ndim > 1:
            waveform = waveform.mean(1)

        if sr != self.config.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sr, self.config.sample_rate
            )

        mel = self.mel_transform(waveform.unsqueeze(0)).squeeze(0).T

        mel = torch.clamp(mel, min=1e-5)
        mel = torch.log(mel)

        mel = (mel - mel.mean()) / (mel.std() + 1e-6)

        if mel.shape[0] > self.config.max_mel_len:
            mel = mel[: self.config.max_mel_len]
        else:
            mel = F.pad(
                mel, (0, 0, 0, self.config.max_mel_len - mel.shape[0])
            )


        text = item["normalized_text"]
        p_str = self.phonemize(text)

        unk_id = self.phoneme_to_id.get("<UNK>", 1)


        phoneme_ids = [
            self.phoneme_to_id[p] if p in self.phoneme_to_id else unk_id
            for p in p_str.split()
        ]

        phonemes = torch.tensor(
            phoneme_ids[: self.config.max_text_len], dtype=torch.long
        )

        if len(phonemes) < self.config.max_text_len:
            phonemes = F.pad(
                phonemes, (0, self.config.max_text_len - len(phonemes))
            )

        return phonemes, mel


def collate_fn(batch):
    phonemes, mels = zip(*batch)
    return torch.stack(phonemes), torch.stack(mels)


In [None]:

train_ds = LJSpeechDataset(CONFIG, split="full[40%:]")
val_ds = LJSpeechDataset(CONFIG, split="full[40%:60%]")

train_loader = DataLoader(
    train_ds, batch_size=CONFIG.batch_size, shuffle=True, collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_ds, batch_size=CONFIG.batch_size, shuffle=False, collate_fn=collate_fn,
)
print(f"Train batches: {len(train_loader)}, Val: {len(val_loader)}")

Final phoneme vocab size: 10095


In [None]:
ph, mel = train_ds[10]
print("Phonemes:", ph.shape)
print("Mel:", mel.shape)


In [None]:
from torch.amp import autocast, GradScaler
from tqdm import tqdm

def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    scaler = GradScaler()
    total_loss = 0.0
    total_samples = 0

    pbar = tqdm(loader, desc="Train Epoch")

    for batch_idx, (phonemes, mels) in enumerate(pbar):
        phonemes = phonemes.to(device)
        mels = mels.to(device)

        optimizer.zero_grad()

        with autocast(device_type='cuda', dtype=torch.float16):
            preds = model(phonemes)
            loss = criterion(preds, mels)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss * phonemes.size(0)
        total_samples += phonemes.size(0)

        if batch_idx % 50 == 0:
            pbar.set_postfix({'Batch Loss': f'{batch_loss:.4f}'})

    avg_loss = total_loss / total_samples
    print(f"Train Epoch Avg Loss: {avg_loss:.4f}")
    return avg_loss

@torch.no_grad()
def eval_one_epoch(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_samples = 0

    pbar = tqdm(loader, desc="Eval Epoch")

    for phonemes, mels in pbar:
        phonemes = phonemes.to(device)
        mels = mels.to(device)

        # AMP for eval (faster inference)
        with autocast(device_type='cuda', dtype=torch.float16):
            preds = model(phonemes)
            loss = criterion(preds, mels)

        total_loss += loss.item() * phonemes.size(0)
        total_samples += phonemes.size(0)

        pbar.set_postfix({'Batch Loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / total_samples
    print(f"Eval Avg Loss: {avg_loss:.4f}")
    return avg_loss

In [None]:

model = PretrainedPerceiver(
    CONFIG ).to(CONFIG.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
criterion = nn.MSELoss()

In [None]:
phonemes = torch.randint(0, CONFIG.phoneme_vocab_size, (32, 100)).to(device=CONFIG.device)

# Instantiate with dummy id_to_phoneme if not passed
model = PretrainedPerceiver(CONFIG, id_to_phoneme={i: f"PH{i}" for i in range(CONFIG.phoneme_vocab_size)}).to(CONFIG.device)

preds = model(phonemes)
print("Preds shape:", preds.shape)  # [32, 800, 80]

mels_dummy = torch.randn(32, 800, 80).to(CONFIG.device)
loss = criterion(preds, mels_dummy)
print("Loss computed OK:", loss.item())

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-5,betas = (0.9, 0.98), eps = 1e-9, weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-6)

# Initial freeze
model.unfreeze_layers('frozen')

for epoch in range(30):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, CONFIG.device)
    val_loss = eval_one_epoch(model, val_loader, criterion, CONFIG.device)
    scheduler.step()

    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    # Progressive unfreeze
    if epoch + 1 == 10:
        model.unfreeze_layers('top')
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)
        scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
    elif epoch + 1 == 20:
        model.unfreeze_layers('full')
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
        scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-7)

In [None]:
def tts_inference(model, text, dataset, device):
    model.eval()

    phoneme_str = dataset.phonemize(text)

    phoneme_ids = [
        dataset.phoneme_to_id.get(p, dataset.phoneme_to_id["<UNK>"])
        for p in phoneme_str.split()
    ]

    phonemes = torch.tensor(
        phoneme_ids[:dataset.config.max_text_len]
    ).unsqueeze(0).to(device)

    with torch.no_grad():
        mel = model(phonemes)

    return mel.squeeze(0).cpu()
