# 🎙️ Train Your Own TTS Voice Model
**Tacotron2 Text-to-Speech Training on Google Colab or Kaggle**

This notebook trains a voice model from scratch using the LJSpeech dataset.

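## Steps:
1. Enable GPU runtime
2. Install dependencies and set up checkpoint storage
3. Download (or, on Kaggle, attach) the LJSpeech dataset
4. Define the model
5. Train the model
6. Test the model
7. Download the trained model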

## 1️⃣ Enable GPU
Go to **Runtime → Change runtime type → GPU (T4)**

In [None]:
# Report the installed PyTorch build and whether a CUDA GPU is visible to this runtime.
import torch

cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if not cuda_ok:
    print("‚ö†Ô∏è No GPU detected! Go to Runtime ‚Üí Change runtime type ‚Üí GPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2️⃣ Install Dependencies and Set Up Storage

In [None]:
!pip install -q torch torchaudio librosa numpy scipy tqdm matplotlib soundfile

import os
import shutil

# ============== DETECT PLATFORM (Kaggle first since it also has /content) ==============
# '/content' alone cannot tell Colab from Kaggle, so Kaggle is identified by /kaggle/input.
IS_KAGGLE = os.path.exists('/kaggle/input')
IS_COLAB = os.path.exists('/content') and not IS_KAGGLE  # Colab only if not Kaggle

print(f"üñ•Ô∏è Platform: {'Kaggle' if IS_KAGGLE else 'Colab' if IS_COLAB else 'Local'}")

# ============== SETUP STORAGE ==============
if IS_KAGGLE:
    # Kaggle: use working directory (persists during session, shown in the Output tab)
    CHECKPOINT_BASE = '/kaggle/working/tts_checkpoints'

    # A checkpoint from an earlier session can be attached as a dataset named 'tts-checkpoint'
    uploaded_dir = '/kaggle/input/tts-checkpoint'
    if os.path.exists(uploaded_dir):
        print("üìÇ Found uploaded checkpoint dataset!")
        os.makedirs(f'{CHECKPOINT_BASE}/output', exist_ok=True)
        for f in ['best_model.pt', 'latest_model.pt']:
            src = f'{uploaded_dir}/{f}'
            dst = f'{CHECKPOINT_BASE}/output/{f}'
            if not os.path.exists(src):
                continue
            # Never clobber a checkpoint already written in this session (e.g. when this
            # cell is re-run after training) with the older uploaded copy.
            if os.path.exists(dst):
                print(f"  Skipped {f}: a copy already exists in {CHECKPOINT_BASE}/output")
                continue
            shutil.copy(src, dst)
            print(f"  ‚úÖ Copied {f}")
    print("üíæ Checkpoints: /kaggle/working/tts_checkpoints")
    print("üì• Download from: Output tab (right side) after training")

elif IS_COLAB:
    # Colab: keep checkpoints on Google Drive so they survive runtime resets
    from google.colab import drive
    drive.mount('/content/drive')
    CHECKPOINT_BASE = '/content/drive/MyDrive/tts_checkpoints'
    print(f"üíæ Checkpoints: {CHECKPOINT_BASE}")

else:
    CHECKPOINT_BASE = './tts_checkpoints'
    print(f"üíæ Checkpoints: {CHECKPOINT_BASE}")

os.makedirs(CHECKPOINT_BASE, exist_ok=True)

## 3️⃣ Download LJSpeech Dataset

In [None]:
import os
import tarfile
import requests
from tqdm import tqdm
from pathlib import Path

# ============== DATASET SETUP ==============
import shutil

if IS_KAGGLE and os.path.exists('/kaggle/input/ljspeech11'):
    # Use Kaggle's attached LJSpeech dataset (no download needed!)
    print("‚úÖ Using Kaggle's LJSpeech dataset (instant!)")
    DATA_DIR = '/kaggle/input/ljspeech11/LJSpeech-1.1'
else:
    # Download + extract into ./data (Colab, local runs, or Kaggle without the dataset attached)
    os.makedirs('data', exist_ok=True)
    url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
    tar_path = "data/LJSpeech-1.1.tar.bz2"

    if not os.path.exists(tar_path):
        print("Downloading LJSpeech dataset (~2.6GB)...")
        # Stream into a ".part" file and rename it only after the whole body has arrived.
        # Otherwise an HTTP error page or an interrupted download would sit at tar_path and
        # the existence check above would treat it as a finished download on every rerun.
        part_path = tar_path + ".part"
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total = int(response.headers.get('content-length', 0))
            with open(part_path, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
                for chunk in response.iter_content(1 << 20):
                    f.write(chunk)
                    pbar.update(len(chunk))
        os.replace(part_path, tar_path)
        print("Download complete!")
    else:
        print("Dataset already downloaded")

    if not os.path.exists('data/LJSpeech-1.1'):
        print("Extracting...")
        # Extract into a staging directory and move the result into place at the end, so an
        # interrupted extraction never leaves a partial data/LJSpeech-1.1 that looks complete.
        staging_dir = 'data/.extracting'
        with tarfile.open(tar_path, 'r:bz2') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Interpreters with extraction filters: reject absolute / '..' member paths
                tar.extractall(staging_dir, filter='data')
            else:
                tar.extractall(staging_dir)
        os.replace(f'{staging_dir}/LJSpeech-1.1', 'data/LJSpeech-1.1')
        shutil.rmtree(staging_dir, ignore_errors=True)
        print("Extraction complete!")
    else:
        print("Already extracted")

    DATA_DIR = 'data/LJSpeech-1.1'

wavs_dir = Path(DATA_DIR) / 'wavs'

# Prepare metadata: one "wav_path|transcript" line per clip whose audio file is present
os.makedirs('data/processed', exist_ok=True)
available_wavs = {f.stem for f in wavs_dir.glob('*.wav')}
print(f"Found {len(available_wavs)} audio files")

entries = []
with open(f'{DATA_DIR}/metadata.csv', 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split('|')
        # Column 0 is the clip id; column 2 is used as the transcript (LJSpeech's normalized text)
        if len(parts) >= 3 and parts[0] in available_wavs:
            wav_path = wavs_dir / f"{parts[0]}.wav"
            entries.append(f"{wav_path}|{parts[2]}")

with open('data/processed/metadata.txt', 'w') as f:
    f.write('\n'.join(entries))

print(f"‚úÖ Prepared {len(entries)} samples for training")

## 4️⃣ Define Model Architecture (Tacotron2)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass, field
from typing import Tuple, List, Optional, Dict
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

# ============== CONFIGURATION ==============
@dataclass
class AudioConfig:
    """Waveform / mel-spectrogram feature settings consumed by AudioProcessor."""
    sample_rate: int = 22050  # target rate in Hz; loaded audio is resampled to this
    n_fft: int = 1024  # FFT size of the STFT behind the mel transform
    hop_length: int = 256  # samples between successive mel frames
    win_length: int = 1024  # STFT window length in samples
    n_mels: int = 80  # number of mel bins (also the decoder's per-frame output width)
    mel_fmin: float = 0.0  # lowest mel filterbank frequency (Hz)
    mel_fmax: float = 8000.0  # highest mel filterbank frequency (Hz)

@dataclass
class ModelConfig:
    """Tacotron2 layer sizes and decoding settings read by Encoder, Attention, Decoder and PostNet."""
    encoder_embedding_dim: int = 512  # character embedding width; also encoder conv channels and BiLSTM output width
    encoder_n_convolutions: int = 3  # number of ConvBlocks in the encoder
    encoder_kernel_size: int = 5  # kernel size of each encoder ConvBlock
    attention_rnn_dim: int = 1024  # hidden size of the attention LSTMCell (its output is the attention query)
    attention_dim: int = 128  # shared projection size for query, memory and location features
    attention_location_n_filters: int = 32  # output channels of the location conv
    attention_location_kernel_size: int = 31  # kernel size of the location conv over [previous, cumulative] weights
    decoder_rnn_dim: int = 1024  # hidden size of the decoder LSTMCell
    prenet_dim: int = 256  # width of both prenet layers
    max_decoder_steps: int = 1000  # upper bound on frames generated at inference
    gate_threshold: float = 0.5  # inference stops once sigmoid(gate logit) exceeds this
    p_attention_dropout: float = 0.1  # dropout on the attention LSTMCell hidden state
    p_decoder_dropout: float = 0.1  # dropout on the decoder LSTMCell hidden state
    postnet_embedding_dim: int = 512  # channels of the intermediate postnet convs
    postnet_kernel_size: int = 5  # kernel size of every postnet conv
    postnet_n_convolutions: int = 5  # total postnet conv layers (first + middle + last)

@dataclass
class TTSConfig:
    """Top-level settings: audio/model sub-configs, character vocabulary and training knobs."""
    audio: AudioConfig = field(default_factory=AudioConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    characters: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 .,!?'-"
    pad_token: str = "_"
    batch_size: int = 16
    learning_rate: float = 1e-3
    weight_decay: float = 1e-6
    epochs: int = 500
    grad_clip_thresh: float = 1.0
    data_path: str = "data/processed/metadata.txt"
    checkpoint_dir: str = "checkpoints"
    output_dir: str = "output"

    @property
    def vocab_size(self) -> int:
        """Embedding rows needed: one per character plus one for the pad token."""
        pad_slots = 1
        return pad_slots + len(self.characters)

    @property
    def n_mels(self) -> int:
        """Convenience alias for ``self.audio.n_mels``."""
        mel_bins = self.audio.n_mels
        return mel_bins

print("‚úÖ Configuration defined")

In [None]:
# ============== TEXT & AUDIO PROCESSING ==============
import soundfile as sf

class TextProcessor:
    """Map transcript characters to integer ids; id 0 is reserved for the pad token."""

    def __init__(self, config: TTSConfig):
        mapping = {config.pad_token: 0}
        mapping.update({ch: pos for pos, ch in enumerate(config.characters, start=1)})
        self.char_to_idx = mapping
        self.idx_to_char = {idx: ch for ch, idx in mapping.items()}

    def text_to_sequence(self, text: str) -> List[int]:
        """Encode ``text`` as ids, silently dropping characters outside the vocabulary."""
        lookup = self.char_to_idx
        return [lookup[ch] for ch in text if ch in lookup]

class AudioProcessor:
    """Load waveforms from disk and convert them into standardized log-mel features."""

    def __init__(self, config: AudioConfig):
        self.config = config
        cfg = config
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=cfg.sample_rate,
            n_fft=cfg.n_fft,
            win_length=cfg.win_length,
            hop_length=cfg.hop_length,
            f_min=cfg.mel_fmin,
            f_max=cfg.mel_fmax,
            n_mels=cfg.n_mels,
        )

    def load_audio(self, path: str) -> torch.Tensor:
        """Read ``path`` as a mono float32 waveform at ``config.sample_rate``.

        soundfile is used instead of torchaudio.load to avoid the torchcodec issue.
        """
        samples, file_sr = sf.read(path)
        wav = torch.from_numpy(samples).float()

        # Multi-channel files come back 2-D; average the channel axis down to mono
        if wav.dim() > 1:
            wav = wav.mean(dim=1)

        target_sr = self.config.sample_rate
        if file_sr == target_sr:
            return wav
        return torchaudio.functional.resample(wav, file_sr, target_sr)

    def audio_to_mel(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the log-mel of ``waveform``, standardized to zero mean / unit std.

        NOTE(review): mean and std are taken per utterance, so absolute loudness is
        discarded and synthesis code can only approximately undo the scaling.
        """
        batched = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
        log_mel = torch.log(self.mel_transform(batched).clamp(min=1e-5))
        standardized = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-8)
        return standardized.squeeze(0)

class LJSpeechDataset(Dataset):
    """(audio path, transcript) pairs from a ``path|text`` metadata file, featurized on access."""

    def __init__(self, metadata_path: str, config: TTSConfig, max_samples: int = None):
        self.text_processor = TextProcessor(config)
        self.audio_processor = AudioProcessor(config.audio)
        self.samples = []
        with open(metadata_path, 'r') as handle:
            for line_no, raw in enumerate(handle):
                # max_samples caps how many metadata *lines* are read (falsy means no cap)
                if max_samples and line_no >= max_samples:
                    break
                fields = raw.strip().split('|')
                if len(fields) < 2 or not Path(fields[0]).exists():
                    continue
                self.samples.append((fields[0], fields[1]))
        print(f"Loaded {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        wav_path, transcript = self.samples[idx]
        ids = self.text_processor.text_to_sequence(transcript)
        audio = self.audio_processor
        mel = audio.audio_to_mel(audio.load_audio(wav_path))
        return {
            'text': torch.LongTensor(ids),
            'text_length': len(ids),
            'mel': mel,
            'mel_length': mel.shape[1],
        }

def collate_fn(batch):
    """Zero-pad a list of dataset items into batch tensors.

    Returns a dict with 'text' (B, max_text) long ids, 'text_lengths', 'mel'
    (B, n_mels, max_mel), 'mel_lengths', and 'gate' (B, max_mel), which is 1.0 from
    each item's last real frame onward and 0.0 before it.
    """
    longest_text = max(item['text_length'] for item in batch)
    longest_mel = max(item['mel_length'] for item in batch)
    n_mels = batch[0]['mel'].shape[0]
    size = len(batch)

    text = torch.zeros(size, longest_text, dtype=torch.long)
    mel = torch.zeros(size, n_mels, longest_mel)
    gate = torch.zeros(size, longest_mel)

    for row, item in enumerate(batch):
        t_len, m_len = item['text_length'], item['mel_length']
        text[row, :t_len] = item['text']
        mel[row, :, :m_len] = item['mel']
        # Stop-token target: set on the final real frame and on all padding after it
        gate[row, m_len - 1:] = 1.0

    return {
        'text': text,
        'text_lengths': torch.LongTensor([item['text_length'] for item in batch]),
        'mel': mel,
        'mel_lengths': torch.LongTensor([item['mel_length'] for item in batch]),
        'gate': gate,
    }

print("‚úÖ Data processing defined")

In [None]:
# ============== TACOTRON2 MODEL ==============
class ConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU -> Dropout; keeps the time length for odd kernels."""

    def __init__(self, in_ch, out_ch, kernel, dropout=0.5):
        super().__init__()
        same_pad = (kernel - 1) // 2
        self.conv = nn.Conv1d(in_ch, out_ch, kernel, padding=same_pad)
        self.bn = nn.BatchNorm1d(out_ch)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h)
        h = F.relu(h)
        return self.dropout(h)

class Encoder(nn.Module):
    """Character ids -> conv stack -> bidirectional LSTM that skips padded time steps."""

    def __init__(self, config: TTSConfig):
        super().__init__()
        width = config.model.encoder_embedding_dim
        self.embedding = nn.Embedding(config.vocab_size, width)
        blocks = [
            ConvBlock(width, width, config.model.encoder_kernel_size)
            for _ in range(config.model.encoder_n_convolutions)
        ]
        self.convs = nn.ModuleList(blocks)
        # width // 2 hidden units per direction -> the concatenated output is `width` again
        self.lstm = nn.LSTM(width, width // 2, batch_first=True, bidirectional=True)

    def forward(self, text, lengths):
        """Encode padded ids ``text`` (B, T) into (B, max(lengths), width) features."""
        h = self.embedding(text).transpose(1, 2)  # Conv1d expects (B, channels, T)
        for block in self.convs:
            h = block(h)
        # pack_padded_sequence requires the lengths on the CPU
        packed = nn.utils.rnn.pack_padded_sequence(
            h.transpose(1, 2), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return out

class Attention(nn.Module):
    """Location-sensitive additive attention over the encoder outputs."""

    def __init__(self, config: ModelConfig, enc_dim):
        super().__init__()
        adim = config.attention_dim
        n_filt = config.attention_location_n_filters
        ksize = config.attention_location_kernel_size
        self.query = nn.Linear(config.attention_rnn_dim, adim, bias=False)
        self.memory = nn.Linear(enc_dim, adim, bias=False)
        self.v = nn.Linear(adim, 1, bias=False)
        # Two input channels: previous attention weights and cumulative attention weights
        self.loc_conv = nn.Conv1d(2, n_filt, ksize, padding=(ksize - 1) // 2)
        self.loc_dense = nn.Linear(n_filt, adim, bias=False)

    def forward(self, query, memory, attn_cat, mask=None):
        """Return ``(context, weights)`` for one decoder step.

        query: (B, attention_rnn_dim); memory: (B, T, enc_dim); attn_cat: (B, 2, T);
        mask: optional bool (B, T), True where the encoder position is padding.
        context is (B, enc_dim) and weights is (B, T), summing to 1 over T.
        """
        q_proj = self.query(query.unsqueeze(1))                  # (B, 1, attention_dim)
        m_proj = self.memory(memory)                             # (B, T, attention_dim)
        l_proj = self.loc_dense(self.loc_conv(attn_cat).transpose(1, 2))
        scores = self.v(torch.tanh(q_proj + m_proj + l_proj)).squeeze(-1)  # (B, T)
        if mask is not None:
            scores = scores.masked_fill(mask, -float('inf'))
        weights = F.softmax(scores, dim=1)
        context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)
        return context, weights

class Prenet(nn.Module):
    """Two Linear+ReLU bottleneck layers applied to the previous mel frame.

    Dropout is called with ``training=True`` unconditionally, so it stays active in
    eval mode as well.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        first = nn.Linear(in_dim, out_dim)
        second = nn.Linear(out_dim, out_dim)
        self.layers = nn.ModuleList([first, second])

    def forward(self, x):
        h = x
        for fc in self.layers:
            h = F.relu(fc(h))
            h = F.dropout(h, 0.5, training=True)
        return h

class Decoder(nn.Module):
    """Autoregressive attention decoder: one mel frame plus one stop-gate logit per step."""

    def __init__(self, config: TTSConfig):
        super().__init__()
        self.n_mels = config.n_mels
        self.max_steps = config.model.max_decoder_steps
        self.gate_thresh = config.model.gate_threshold
        enc_dim = config.model.encoder_embedding_dim

        # Submodule names and creation order are kept as-is: they define the checkpoint
        # state_dict keys and the parameter-initialization order.
        self.prenet = Prenet(config.n_mels, config.model.prenet_dim)
        self.attn_rnn = nn.LSTMCell(config.model.prenet_dim + enc_dim, config.model.attention_rnn_dim)
        self.attention = Attention(config.model, enc_dim)
        self.dec_rnn = nn.LSTMCell(config.model.attention_rnn_dim + enc_dim, config.model.decoder_rnn_dim)
        self.linear = nn.Linear(config.model.decoder_rnn_dim + enc_dim, config.n_mels)
        self.gate = nn.Linear(config.model.decoder_rnn_dim + enc_dim, 1)
        self.attn_drop = nn.Dropout(config.model.p_attention_dropout)
        self.dec_drop = nn.Dropout(config.model.p_decoder_dropout)

    @staticmethod
    def _padding_mask(memory, lengths):
        """Bool (B, T) mask, True at encoder positions at or beyond each item's length."""
        B, T = memory.shape[:2]
        return torch.arange(T, device=memory.device).expand(B, T) >= lengths.unsqueeze(1)

    def init_states(self, memory):
        """Zero state tuple: (attn h, attn c, dec h, dec c, attn weights, cumulative weights, context)."""
        B, T, D = memory.shape
        return (memory.new_zeros(B, self.attn_rnn.hidden_size),
                memory.new_zeros(B, self.attn_rnn.hidden_size),
                memory.new_zeros(B, self.dec_rnn.hidden_size),
                memory.new_zeros(B, self.dec_rnn.hidden_size),
                memory.new_zeros(B, T), memory.new_zeros(B, T),
                memory.new_zeros(B, D))

    def step(self, dec_in, states, memory, mask):
        """Run one decoder step from the previous frame ``dec_in`` (B, n_mels).

        Returns (mel frame (B, n_mels), gate logit (B, 1), updated state tuple).
        """
        ah, ac, dh, dc, aw, awc, ctx = states
        pre = self.prenet(dec_in)
        ah, ac = self.attn_rnn(torch.cat([pre, ctx], 1), (ah, ac))
        ah = self.attn_drop(ah)
        # Location features see the previous step's weights and their running sum
        ctx, aw = self.attention(ah, memory, torch.stack([aw, awc], 1), mask)
        awc = awc + aw
        dh, dc = self.dec_rnn(torch.cat([ah, ctx], 1), (dh, dc))
        dh = self.dec_drop(dh)
        out = torch.cat([dh, ctx], 1)
        return self.linear(out), self.gate(out), (ah, ac, dh, dc, aw, awc, ctx)

    def forward(self, memory, mel, lengths):
        """Teacher-forced decoding.

        memory: encoder outputs (B, T_enc, enc_dim); mel: targets (B, n_mels, T_mel);
        lengths: text lengths (B,) used to mask padded encoder positions.
        Returns mel predictions (B, n_mels, T_mel) and gate logits (B, T_mel).
        """
        B = memory.size(0)
        mask = self._padding_mask(memory, lengths)
        go = memory.new_zeros(B, self.n_mels)
        # The input at step t is target frame t-1 (an all-zero "go" frame at t = 0)
        inputs = torch.cat([go.unsqueeze(1), mel.transpose(1, 2)[:, :-1]], 1)
        states = self.init_states(memory)
        mels, gates = [], []
        for t in range(inputs.size(1)):
            m, g, states = self.step(inputs[:, t], states, memory, mask)
            mels.append(m)
            gates.append(g)
        return torch.stack(mels, 2), torch.cat(gates, 1)

    def inference(self, memory, lengths=None):
        """Free-running decoding that feeds each predicted frame back in as the next input.

        Stops once sigmoid(gate) > gate_threshold has fired for every batch item (for
        B == 1: the first step whose gate fires, exactly as before), or after
        max_decoder_steps. Items that stop early keep generating until the whole batch
        is done, so callers with B > 1 should trim each item at its own stop frame.
        Returns mel frames (B, n_mels, steps) and gate logits (B, steps).
        """
        B = memory.size(0)
        mask = None if lengths is None else self._padding_mask(memory, lengths)
        dec_in = memory.new_zeros(B, self.n_mels)
        states = self.init_states(memory)
        # Per-item "gate has fired" flags; calling .item() on the (B, 1) gate only worked for B == 1
        finished = torch.zeros(B, dtype=torch.bool, device=memory.device)
        mels, gates = [], []
        for _ in range(self.max_steps):
            m, g, states = self.step(dec_in, states, memory, mask)
            mels.append(m)
            gates.append(g)
            finished |= torch.sigmoid(g).squeeze(1) > self.gate_thresh
            if bool(finished.all()):
                break
            dec_in = m
        return torch.stack(mels, 2), torch.cat(gates, 1)

class PostNet(nn.Module):
    """Stack of 1-D convs whose output is added to the decoder mel as a residual refinement."""

    def __init__(self, config: TTSConfig):
        super().__init__()
        hidden = config.model.postnet_embedding_dim
        k = config.model.postnet_kernel_size
        n = config.model.postnet_n_convolutions
        n_mels = config.n_mels
        pad = (k - 1) // 2

        def conv_bn(c_in, c_out):
            # Conv at index 0, BatchNorm at index 1 of each stage (checkpoint key layout)
            return [nn.Conv1d(c_in, c_out, k, padding=pad), nn.BatchNorm1d(c_out)]

        stages = [nn.Sequential(*conv_bn(n_mels, hidden), nn.Tanh(), nn.Dropout(0.5))]
        stages += [
            nn.Sequential(*conv_bn(hidden, hidden), nn.Tanh(), nn.Dropout(0.5))
            for _ in range(n - 2)
        ]
        # Last stage projects back to n_mels with no Tanh, so it is not bounded to [-1, 1]
        stages.append(nn.Sequential(*conv_bn(hidden, n_mels), nn.Dropout(0.5)))
        self.convs = nn.ModuleList(stages)

    def forward(self, x):
        out = x
        for stage in self.convs:
            out = stage(out)
        return out

class Tacotron2(nn.Module):
    """Encoder + attention decoder + residual post-net."""

    def __init__(self, config: TTSConfig):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.postnet = PostNet(config)

    def _refine(self, coarse_mel):
        """Add the post-net's residual correction to a decoder mel."""
        return coarse_mel + self.postnet(coarse_mel)

    def forward(self, text, text_len, mel, mel_len):
        """Teacher-forced pass: returns (decoder mel, refined mel, gate logits).

        ``mel_len`` is accepted for a symmetric call signature but is not used here.
        """
        memory = self.encoder(text, text_len)
        # Attention masking is driven by the *text* lengths (the encoder side is padded)
        coarse, gate_logits = self.decoder(memory, mel, text_len)
        return coarse, self._refine(coarse), gate_logits

    def inference(self, text):
        """Synthesize a refined mel; assumes ``text`` is a single unpadded sequence (1, T)."""
        text_len = torch.LongTensor([text.size(1)]).to(text.device)
        memory = self.encoder(text, text_len)
        coarse, _ = self.decoder.inference(memory, text_len)
        return self._refine(coarse)

print("‚úÖ Tacotron2 model defined")

## 5️⃣ Train the Model

In [None]:
# ============== TRAINING ==============
# Configuration
config = TTSConfig()
config.batch_size = 16  # Lowered to 8 automatically below when running on CPU
config.learning_rate = 1e-3

# Choose training mode
QUICK_TEST = False  # True = short smoke test (50 epochs); False = extended run (200 epochs)
RESUME_TRAINING = True  # True = continue from the latest/best checkpoint if one exists

if QUICK_TEST:
    MAX_SAMPLES = 1000
    EPOCHS = 50
    print("üöÄ QUICK TEST MODE: 1000 samples, 50 epochs")
else:
    MAX_SAMPLES = 1000  # Keep same samples for consistency when resuming
    EPOCHS = 200  # Train to 200 epochs total (epoch numbering continues across resumes)
    print("üöÄ EXTENDED TRAINING: 1000 samples, 200 epochs")

# Setup device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cpu':
    print("‚ö†Ô∏è Training on CPU - will be slower!")
    config.batch_size = 8

# Platform-aware checkpoint locations (CHECKPOINT_BASE is set in the dependencies/storage cell)
config.checkpoint_dir = f'{CHECKPOINT_BASE}/checkpoints'
config.output_dir = f'{CHECKPOINT_BASE}/output'
print(f"üíæ Saving to: {config.output_dir}")

os.makedirs(config.checkpoint_dir, exist_ok=True)
os.makedirs(config.output_dir, exist_ok=True)

# Create dataset and model
dataset = LJSpeechDataset(config.data_path, config, max_samples=MAX_SAMPLES)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True,
                        collate_fn=collate_fn, num_workers=0)

model = Tacotron2(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate,
                             weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

# Resume from checkpoint if available
start_epoch = 0
best_loss = float('inf')

latest_path = f"{config.output_dir}/latest_model.pt"
best_path = f"{config.output_dir}/best_model.pt"

if RESUME_TRAINING:
    checkpoint_path = None
    if os.path.exists(latest_path):
        checkpoint_path = latest_path
        print(f"\nüìÇ Found latest checkpoint")
    elif os.path.exists(best_path):
        checkpoint_path = best_path
        print(f"\nüìÇ Found best checkpoint")

    if checkpoint_path:
        # weights_only=False because the checkpoint pickles the TTSConfig object. Unpickling
        # can run arbitrary code, so only load checkpoints you produced yourself.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        resumed_loss = checkpoint.get('loss', float('inf'))
        best_loss = resumed_loss
        # latest_model.pt stores the *last* epoch's loss, not the best one. Seeding best_loss
        # with it would let a worse epoch overwrite best_model.pt, so read the real best too.
        if checkpoint_path != best_path and os.path.exists(best_path):
            best_checkpoint = torch.load(best_path, map_location='cpu', weights_only=False)
            best_loss = min(best_loss, best_checkpoint.get('loss', float('inf')))
            del best_checkpoint
        print(f"‚úÖ Resumed from epoch {start_epoch}, loss: {resumed_loss:.4f}")
        print(f"   Best loss so far: {best_loss:.4f}")
    else:
        print("\nüÜï No checkpoint found, starting fresh...")
else:
    print("\nüÜï Starting fresh training...")

mse_loss = nn.MSELoss()
bce_loss = nn.BCEWithLogitsLoss()

params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {params:,}")
print(f"Batches per epoch: {len(dataloader)}")
print(f"Training epochs: {start_epoch} ‚Üí {EPOCHS}")

In [None]:
# Training loop (continues from start_epoch)
if start_epoch < EPOCHS and len(dataloader) == 0:
    # Otherwise the per-epoch average below divides by zero
    raise RuntimeError(f"No training batches: check that {config.data_path} lists existing audio files")

losses = []


def _save_checkpoint(path, epoch, loss):
    """Write a resumable checkpoint (model, optimizer, scheduler, loss, config) to ``path``.

    The data goes to ``path + '.tmp'`` first and is then renamed over ``path``, so an
    interrupted save (e.g. a runtime disconnect mid-write) leaves the previous file in
    place instead of a truncated one, on filesystems where rename is atomic.
    """
    tmp_path = f"{path}.tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': config
    }, tmp_path)
    os.replace(tmp_path, path)


for epoch in range(start_epoch, EPOCHS):
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        text = batch['text'].to(device)
        text_len = batch['text_lengths'].to(device)
        mel = batch['mel'].to(device)
        mel_len = batch['mel_lengths'].to(device)
        gate = batch['gate'].to(device)

        optimizer.zero_grad()
        mel_out, mel_post, gate_out = model(text, text_len, mel, mel_len)

        # Both the raw decoder mel and the post-net mel are pulled toward the target.
        # NOTE(review): padded frames are included in both MSE terms (no length mask).
        loss = mse_loss(mel_out, mel) + mse_loss(mel_post, mel) + bce_loss(gate_out, gate)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_thresh)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        pbar.set_postfix({'loss': f"{batch_loss:.4f}"})

    avg_loss = total_loss / len(dataloader)
    losses.append(avg_loss)
    scheduler.step(avg_loss)

    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f} - LR: {optimizer.param_groups[0]['lr']:.2e}")

    # Save best model
    if avg_loss < best_loss:
        best_loss = avg_loss
        _save_checkpoint(f"{config.output_dir}/best_model.pt", epoch, best_loss)
        print(f"  ‚úÖ Saved best model (loss: {best_loss:.4f})")

    # ALWAYS save latest checkpoint (so we never lose progress!)
    _save_checkpoint(f"{config.output_dir}/latest_model.pt", epoch, avg_loss)

    # Also keep a numbered checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        _save_checkpoint(f"{config.checkpoint_dir}/checkpoint_epoch_{epoch+1}.pt", epoch, avg_loss)
        print(f"  üíæ Checkpoint saved: epoch {epoch+1}")

print("\nüéâ Training complete!")
print(f"Best loss: {best_loss:.4f}")
print(f"Model saved to: {config.output_dir}/best_model.pt")

In [None]:
# Plot this session's per-epoch average loss against the real (1-based) epoch numbers
# printed during training, so a resumed run doesn't restart the x-axis at 0.
epoch_numbers = range(start_epoch + 1, start_epoch + 1 + len(losses))
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_numbers, losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.grid(True)
fig.savefig(f"{config.output_dir}/loss_curve.png")
plt.show()

## 6️⃣ Test the Model

In [None]:
import scipy.signal
import numpy as np
from IPython.display import Audio

# Load best model (weights_only=False is needed because the checkpoint pickles the config object).
# map_location=device lets a checkpoint written on a GPU load on a CPU-only runtime.
checkpoint = torch.load(f"{config.output_dir}/best_model.pt", map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Text to synthesize
text = "Hello, I am your AI tutor. How can I help you today?"

# Convert text to sequence (characters outside config.characters are dropped)
text_proc = TextProcessor(config)
seq = text_proc.text_to_sequence(text)
text_tensor = torch.LongTensor([seq]).to(device)

# Generate mel spectrogram
with torch.no_grad():
    mel = model.inference(text_tensor)

print(f"Generated mel shape: {mel.shape}")

# Training features were standardized per utterance (AudioProcessor.audio_to_mel), so the
# original log-mel mean/std cannot be recovered; these constants only roughly undo that.
MEL_DENORM_STD = 2.5
MEL_DENORM_MEAN = 4.0


def griffin_lim_improved(mel, config, n_iter=100):
    """Reconstruct a float32 waveform from a normalized log-mel with Griffin-Lim.

    The mel is roughly de-normalized, mapped back to a linear spectrogram through the
    pseudo-inverse of the mel filterbank, and its phase is refined over ``n_iter``
    ISTFT/STFT round trips.
    """
    mel = mel.squeeze().cpu().numpy()

    # Denormalize (approximately) and leave log space
    mel = mel * MEL_DENORM_STD + MEL_DENORM_MEAN
    mel = np.exp(mel)
    mel = np.clip(mel, 0, 1000)  # Prevent explosion

    # Filterbank built with the same parameters as AudioProcessor's MelSpectrogram,
    # transposed to (n_mels, n_freqs)
    mel_basis = torchaudio.functional.melscale_fbanks(
        n_freqs=config.audio.n_fft // 2 + 1,
        f_min=config.audio.mel_fmin,
        f_max=config.audio.mel_fmax,
        n_mels=config.audio.n_mels,
        sample_rate=config.audio.sample_rate
    ).numpy().T

    # Inverse mel -> linear. The training MelSpectrogram uses torchaudio's default power=2.0,
    # so this is a *power* spectrogram; Griffin-Lim works on magnitudes, hence the sqrt.
    linear_power = np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel))
    linear = np.sqrt(linear_power)

    stft_args = dict(
        fs=config.audio.sample_rate,
        nperseg=config.audio.win_length,
        noverlap=config.audio.win_length - config.audio.hop_length,
    )
    angles = np.exp(2j * np.pi * np.random.rand(*linear.shape))
    for i in range(n_iter):
        audio = scipy.signal.istft(linear * angles, **stft_args)[1]
        if i < n_iter - 1:
            _, _, new_spec = scipy.signal.stft(audio, **stft_args)
            angles = np.exp(1j * np.angle(new_spec[:linear.shape[0], :]))

    return audio.astype(np.float32)


# ============== VOCODER: try pretrained HiFi-GAN, fall back to Griffin-Lim ==============
try:
    hifigan_available = False
    try:
        # NOTE(review): confirm this bundle exists in the installed torchaudio; if it does
        # not, the AttributeError simply routes synthesis to Griffin-Lim.
        bundle = torchaudio.pipelines.HIFIGAN_VOCODER
        hifigan = bundle.get_vocoder().to(device)
        hifigan_available = True
        print("‚úÖ Using HiFi-GAN vocoder (high quality)")
    except Exception as vocoder_err:  # Exception (not a bare except) so Ctrl-C still interrupts
        print("‚ö†Ô∏è HiFi-GAN not available, using Griffin-Lim (lower quality)")
        print(f"   ({type(vocoder_err).__name__}: {vocoder_err})")

    if hifigan_available:
        # HiFi-GAN expects its own mel format; rough rescale of our standardized mel
        mel_for_vocoder = mel.squeeze(0)  # Remove batch dim
        mel_denorm = mel_for_vocoder * MEL_DENORM_STD

        with torch.no_grad():
            audio_tensor = hifigan(mel_denorm)

        audio = audio_tensor.squeeze().cpu().numpy()
    else:
        audio = griffin_lim_improved(mel, config)

except Exception as e:
    print(f"Error: {e}")
    print("Using basic Griffin-Lim...")
    # Basic fallback; librosa is imported here so the main path does not depend on it
    import librosa
    mel_np = mel.squeeze().cpu().numpy()
    mel_np = np.exp(mel_np)
    audio = librosa.feature.inverse.mel_to_audio(
        mel_np,
        sr=config.audio.sample_rate,
        n_fft=config.audio.n_fft,
        hop_length=config.audio.hop_length,
        win_length=config.audio.win_length,
        n_iter=100
    )

# Peak-normalize to 0.9 of full scale
audio = audio / (np.abs(audio).max() + 1e-8) * 0.9

print(f"Generated {len(audio) / config.audio.sample_rate:.2f} seconds of audio")
print("\n‚ö†Ô∏è NOTE: For clear speech, train for 200+ epochs with QUICK_TEST = False")

# Play audio
Audio(audio, rate=config.audio.sample_rate)

## 7️⃣ Download Trained Model

In [None]:
# Download the trained model
import shutil

# Make the best checkpoint available for download on the current platform.
model_file = f"{config.output_dir}/best_model.pt"

if not os.path.exists(model_file):
    # Checked up front for every platform: Colab's files.download would otherwise raise,
    # and local runs would claim a file was saved when training never produced one.
    print("‚ùå No model found to download")
elif IS_KAGGLE:
    # On Kaggle: copy to the top of /kaggle/working so it appears in the Output tab
    kaggle_output = "/kaggle/working/best_model.pt"
    shutil.copy(model_file, kaggle_output)
    print(f"‚úÖ Model copied to: {kaggle_output}")
    print("üì• Download from the 'Output' tab on the right side of the notebook!")
elif IS_COLAB:
    from google.colab import files
    files.download(model_file)
    print("‚úÖ Model downloaded!")
else:
    print(f"‚úÖ Model saved to: {model_file}")

## 📝 Notes for Your Professor

This notebook implements:

1. **Tacotron2 Architecture** - A well-known attention-based sequence-to-sequence TTS model
   - Encoder: Character embedding + 3 conv layers + BiLSTM
   - Attention: Location-sensitive attention mechanism
   - Decoder: Autoregressive mel spectrogram generation
   - PostNet: 5 conv layers to refine output

2. **Training from Scratch**
   - Public LJSpeech dataset (~24 hours, 13,100 clips); the training cell uses a 1,000-clip subset (`MAX_SAMPLES`)
   - Custom data pipeline for audio processing
   - Mel spectrogram features (80 mel bins)

3. **Complete ML Pipeline**
   - Data loading and preprocessing
   - Model definition
   - Training loop with checkpointing
   - Loss computation (MSE on the decoder and post-net mel outputs + BCE on the stop gate)
   - Inference and audio synthesis