# Sauti ya Kenya - TTS Training

This notebook trains the Kenyan Swahili TTS model using the FastSpeech 2 architecture, with memory optimizations (mixed-precision training and gradient accumulation).

In [None]:
# Attach Google Drive to the Colab VM so that data and checkpoints stored
# under /content/drive/MyDrive are reachable from later cells.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Clone repository and install dependencies
!git clone https://github.com/Msingi-AI/Sauti-Ya-Kenya.git
%cd Sauti-Ya-Kenya
!pip install -r requirements.txt
!pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118

In [None]:
# Copy data from Drive (if stored there)
!mkdir -p data/processed
!cp -r "/content/drive/MyDrive/Sauti-Ya-Kenya/data/processed/*" data/processed/
!cp -r "/content/drive/MyDrive/Sauti-Ya-Kenya/data/tokenizer" data/tokenizer/

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from src.preprocessor import SwahiliTokenizer
from src.model import FastSpeech2
from src.dataset import TTSDataset
from src.config import ModelConfig

# --- Tunable settings for this run -----------------------------------------
VOCAB_SIZE = 8000
TOKENIZER_MODEL_PATH = 'data/tokenizer/tokenizer.model'
PROCESSED_DIR = 'data/processed'
METADATA_PATH = 'data/processed/metadata.csv'
BATCH_SIZE = 8
NUM_WORKERS = 2
LEARNING_RATE = 0.001

# --- Runtime setup ----------------------------------------------------------
# Turn on cuDNN benchmark mode and create the gradient scaler used for
# mixed-precision (AMP) training.
torch.backends.cudnn.benchmark = True
scaler = GradScaler()

# Model hyperparameters come from the project's default configuration.
config = ModelConfig()

# Tokenizer: build it, then load the trained model file from disk.
tokenizer = SwahiliTokenizer(vocab_size=VOCAB_SIZE)
tokenizer.load(TOKENIZER_MODEL_PATH)

# Dataset over the preprocessed data, wrapped in a shuffling loader.
dataset = TTSDataset(
    data_dir=PROCESSED_DIR,
    metadata_file=METADATA_PATH,
    tokenizer=tokenizer,
)
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader = DataLoader(dataset, **loader_options)

# Model on the GPU, optimised with AdamW.
model = FastSpeech2(config)
model = model.cuda()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Training loop with memory optimizations
def train_epoch(model, dataloader, optimizer, scaler, grad_accum_steps=4):
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    
    for i, batch in enumerate(dataloader):
        # Move batch to GPU
        text_ids = batch['text_ids'].cuda()
        mel_target = batch['mel_target'].cuda()
        duration = batch['duration'].cuda()
        
        # Forward pass with mixed precision
        with autocast():
            mel_output, duration_pred = model(text_ids)
            loss = model.loss(mel_output, mel_target, duration_pred, duration)
            loss = loss / grad_accum_steps
        
        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        
        if (i + 1) % grad_accum_steps == 0:
            # Unscale gradients for clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # Update weights
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        total_loss += loss.item() * grad_accum_steps
        
        # Print progress
        if i % 10 == 0:
            print(f'Batch {i}, Loss: {loss.item():.4f}')
            
    return total_loss / len(dataloader)

In [None]:
# Training configuration
num_epochs = 100
save_interval = 10  # write a checkpoint every `save_interval` epochs
checkpoint_dir = 'checkpoints'
drive_path = '/content/drive/MyDrive/Sauti-Ya-Kenya/checkpoints'

# Create both checkpoint locations once, before any training time is spent.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(drive_path, exist_ok=True)

# Training loop
for epoch in range(num_epochs):
    loss = train_epoch(model, dataloader, optimizer, scaler)
    print(f'Epoch {epoch}, Loss: {loss:.4f}')

    # Save checkpoint
    if (epoch + 1) % save_interval == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Keep the AMP loss-scale state too, so a resumed run does not
            # restart loss scaling from its default value.
            'scaler_state_dict': scaler.state_dict(),
            'loss': loss
        }
        # File names use the 0-based epoch index.
        checkpoint_name = f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, f'{checkpoint_dir}/{checkpoint_name}')

        # Save to Drive so the checkpoint survives the Colab VM being recycled
        torch.save(checkpoint, f'{drive_path}/{checkpoint_name}')