# Voice Generator Model Training

This notebook implements a voice generator model using Tacotron2 for text-to-speech synthesis and WaveNet for audio generation.

In [None]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import librosa
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Load the trained rap lyrics generator model
# `model_path` is a local directory; the tokenizer and the causal language
# model are both restored from it via `from_pretrained`.
model_path = './trained_model'
# The tokenizer is reused later in this file to size the text embedding and
# to encode lyrics inside the dataset.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# NOTE(review): `lyrics_model` is not referenced again in the visible part of
# this file — confirm whether loading it here is still required.
lyrics_model = AutoModelForCausalLM.from_pretrained(model_path)

In [None]:
class RapVoiceDataset(Dataset):
    """Dataset pairing vocal recordings with their tokenized lyrics.

    Each item is a dict with two entries:

    * ``'audio'``: waveform of shape ``[1, max_audio_length]`` -- down-mixed
      to mono, resampled to ``sample_rate`` and right-trimmed or zero-padded
      to a fixed length.
    * ``'lyrics'``: whatever ``tokenizer.encode(text, return_tensors='pt')``
      returns for the matching lyrics string.

    NOTE(review): Hugging Face tokenizers presumably return a ``[1, n_tokens]``
    tensor whose length varies per item, so batching more than one item with
    the default collate function likely needs a padding ``collate_fn`` --
    confirm against the tokenizer in use.

    Args:
        audio_dir: Directory expected to contain the audio files.
        lyrics_file: Path of the file holding the lyrics.
        tokenizer: Object exposing ``encode(text, return_tensors='pt')``.
        max_audio_length: Number of samples in every returned waveform,
            counted *after* resampling. The default of ``16000 * 10`` samples
            is about 7.3 seconds at the default 22050 Hz target rate.
        sample_rate: Target sampling rate in Hz; recordings with a different
            rate are resampled to it.
    """

    def __init__(self, audio_dir, lyrics_file, tokenizer, max_audio_length=16000*10,
                 sample_rate=22050):
        self.audio_dir = audio_dir
        self.lyrics_file = lyrics_file
        self.tokenizer = tokenizer
        self.max_audio_length = max_audio_length
        self.sample_rate = sample_rate

        # Parallel lists: audio_files[i] is the recording for lyrics[i].
        self.audio_files = []
        self.lyrics = []

        # Resample transforms keyed by source rate, so the resampling kernel
        # is built once per distinct rate instead of once per item.
        self._resamplers = {}

        # TODO: Implement loading of audio files and lyrics

    def __len__(self):
        return len(self.audio_files)

    def _prepare_audio(self, audio, sr):
        """Return ``audio`` as mono, at ``self.sample_rate``, with fixed length.

        Args:
            audio: Waveform tensor indexed as ``[channels, time]``.
            sr: Sampling rate of ``audio`` in Hz.
        """
        # Average the channels of multi-channel recordings into one.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)

        if sr != self.sample_rate:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                self._resamplers[sr] = resampler
            audio = resampler(audio)

        # Force a fixed length so items can be stacked into a batch.
        n_samples = audio.shape[1]
        if n_samples > self.max_audio_length:
            return audio[:, :self.max_audio_length]
        return F.pad(audio, (0, self.max_audio_length - n_samples))

    def __getitem__(self, idx):
        """Load, preprocess and tokenize the ``idx``-th (audio, lyrics) pair."""
        audio_path = self.audio_files[idx]
        lyrics = self.lyrics[idx]

        audio, sr = torchaudio.load(audio_path)
        audio = self._prepare_audio(audio, sr)

        lyrics_tokens = self.tokenizer.encode(lyrics, return_tensors='pt')

        return {
            'audio': audio,
            'lyrics': lyrics_tokens
        }

In [None]:
class Tacotron2(nn.Module):
    """Text-to-mel network: embedding -> 1-D conv encoder -> GRU -> mel head.

    The network emits one mel frame per input token, so the output has the
    same sequence length as the input.

    Args:
        vocab_size: Number of distinct token ids accepted by the embedding.
        embedding_dim: Size of each token embedding.
        encoder_dim: Channel count of both encoder convolutions.
        decoder_dim: Hidden size of the two-layer GRU decoder.
        n_mels: Number of mel bands predicted per frame (default 80).
    """

    def __init__(self, vocab_size, embedding_dim=512, encoder_dim=512, decoder_dim=1024,
                 n_mels=80):
        super().__init__()

        self.n_mels = n_mels

        # Text encoder
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.Sequential(
            nn.Conv1d(embedding_dim, encoder_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(encoder_dim),
            nn.ReLU(),
            nn.Conv1d(encoder_dim, encoder_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(encoder_dim),
            nn.ReLU()
        )

        # Decoder
        self.decoder = nn.GRU(
            input_size=encoder_dim,
            hidden_size=decoder_dim,
            num_layers=2,
            batch_first=True
        )

        # Mel spectrogram prediction; the Tanh bounds every value to [-1, 1].
        self.mel_predictor = nn.Sequential(
            nn.Linear(decoder_dim, n_mels),
            nn.Tanh()
        )

    def forward(self, x):
        """Map token ids ``[batch_size, seq_len]`` to mels ``[batch_size, seq_len, n_mels]``."""
        embedded = self.embedding(x)  # [batch_size, seq_len, embedding_dim]
        # Conv1d expects channels before time.
        embedded = embedded.transpose(1, 2)  # [batch_size, embedding_dim, seq_len]

        # Encode
        encoded = self.encoder(embedded)  # [batch_size, encoder_dim, seq_len]
        # The batch_first GRU expects time before features.
        encoded = encoded.transpose(1, 2)  # [batch_size, seq_len, encoder_dim]

        # Decode; the final hidden state is not needed.
        decoded, _ = self.decoder(encoded)  # [batch_size, seq_len, decoder_dim]

        # Predict mel spectrogram
        mel_spec = self.mel_predictor(decoded)  # [batch_size, seq_len, n_mels]

        return mel_spec

In [None]:
class WaveNet(nn.Module):
    """Stack of gated, dilated 1-D convolutions with residual and skip paths.

    Every layer keeps the time length of its input (kernel size 3 with
    ``padding == dilation``), so the output has as many time steps as the
    input.

    Dilations follow ``1, 2, 4, ..., 2 ** (dilation_cycle - 1)`` and then
    start over. Previously the dilation was ``2 ** layer`` without bound, so
    the default 20 layers reached a dilation (and per-side padding) of
    ``2 ** 19`` samples. Parameter shapes are unchanged by the cycling, and
    stacks with ``layers <= dilation_cycle`` are built exactly as before.

    Args:
        in_channels: Channels of the conditioning input (default 80).
        out_channels: Channels of the generated signal.
        layers: Number of gated dilated layers.
        residual_channels: Channels carried along the residual path.
        gate_channels: Channels produced by the filter and gate convolutions.
        skip_channels: Channels accumulated on the skip path.
        dilation_cycle: Number of layers after which the dilation resets to 1.

    Raises:
        ValueError: If ``dilation_cycle`` is smaller than 1.
    """

    def __init__(self, in_channels=80, out_channels=1, layers=20, residual_channels=64,
                 gate_channels=64, skip_channels=64, dilation_cycle=10):
        super().__init__()

        if dilation_cycle < 1:
            raise ValueError(f'dilation_cycle must be >= 1, got {dilation_cycle}')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.residual_channels = residual_channels
        self.gate_channels = gate_channels
        self.skip_channels = skip_channels
        self.dilation_cycle = dilation_cycle

        # Initial convolution
        self.start_conv = nn.Conv1d(in_channels, residual_channels, 1)

        # Dilated convolutions
        self.dilated_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()

        kernel_size = 3
        for layer in range(layers):
            # Restart the doubling every `dilation_cycle` layers.
            dilation = 2 ** (layer % dilation_cycle)
            # Symmetric padding that keeps the time length unchanged.
            padding = (kernel_size - 1) * dilation // 2

            self.dilated_convs.append(
                nn.Conv1d(residual_channels, gate_channels, kernel_size,
                          padding=padding, dilation=dilation)
            )
            self.gate_convs.append(
                nn.Conv1d(residual_channels, gate_channels, kernel_size,
                          padding=padding, dilation=dilation)
            )
            self.skip_convs.append(nn.Conv1d(gate_channels, skip_channels, 1))
            self.residual_convs.append(nn.Conv1d(gate_channels, residual_channels, 1))

        # Final convolutions; the Tanh bounds the output to [-1, 1].
        self.final_conv = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(skip_channels, skip_channels, 1),
            nn.ReLU(),
            nn.Conv1d(skip_channels, out_channels, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map ``[batch_size, in_channels, time]`` to ``[batch_size, out_channels, time]``."""
        x = self.start_conv(x)
        skip = 0

        stack = zip(self.dilated_convs, self.gate_convs, self.skip_convs, self.residual_convs)
        for filter_conv, gate_conv, skip_conv, residual_conv in stack:
            residual = x

            # Gated activation: the tanh "filter" is modulated by a sigmoid "gate",
            # both computed from the same layer input.
            gated = torch.tanh(filter_conv(residual)) * torch.sigmoid(gate_conv(residual))

            # Skip connection: every layer contributes directly to the output head.
            skip = skip + skip_conv(gated)

            # Residual connection feeds the next layer.
            x = residual_conv(gated) + residual

        # Only the summed skip path reaches the output head.
        return self.final_conv(skip)

In [None]:
def train_voice_generator(model, train_loader, optimizer, device, epochs=10,
                          target_transform=None):
    """Train ``model`` to predict mel spectrograms from tokenized lyrics.

    Args:
        model: Module mapping token ids ``[batch, seq_len]`` to a mel tensor.
        train_loader: Iterable of dict batches with ``'audio'`` and
            ``'lyrics'`` tensors.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.
        epochs: Number of passes over ``train_loader``.
        target_transform: Optional callable applied to the ``'audio'`` batch
            to produce the regression target (for example a waveform-to-mel
            conversion). When ``None`` the ``'audio'`` tensor is used as-is
            and must already have the shape of the model output.

    Returns:
        List with the average L1 loss of every epoch.

    Raises:
        ValueError: If an epoch yields no batches, or if the target's shape
            differs from the model output's shape.
    """
    model.train()
    epoch_losses = []

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch in progress_bar:
            audio = batch['audio'].to(device)
            lyrics = batch['lyrics'].to(device)

            # Per-item token tensors shaped [1, seq_len] collate to
            # [batch, 1, seq_len]; drop the singleton axis so the model
            # receives [batch, seq_len].
            if lyrics.dim() == 3 and lyrics.size(1) == 1:
                lyrics = lyrics.squeeze(1)

            target = audio if target_transform is None else target_transform(audio)

            optimizer.zero_grad()

            # Forward pass
            mel_spec = model(lyrics)

            # l1_loss broadcasts mismatched shapes when it can, which would
            # silently compare the prediction against something that is not
            # a mel spectrogram of the same size -- fail loudly instead.
            if target.shape != mel_spec.shape:
                raise ValueError(
                    f'target shape {tuple(target.shape)} does not match predicted mel '
                    f'shape {tuple(mel_spec.shape)}; supply mel-spectrogram targets or '
                    f'pass a target_transform that produces them'
                )

            # Calculate loss (L1 loss between predicted and target mel spectrograms)
            loss = F.l1_loss(mel_spec, target)

            # Backward pass
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix({'loss': batch_loss})

        # Averaging over zero batches used to raise ZeroDivisionError; report
        # the actual problem (nothing to train on) instead.
        if num_batches == 0:
            raise ValueError('train_loader yielded no batches; nothing to train on')

        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f'Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}')

    return epoch_losses

In [None]:
# Initialize models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# len(tokenizer) also counts tokens added on top of the base vocabulary, so
# every id the tokenizer can emit has a row in the embedding table
# (tokenizer.vocab_size covers the base vocabulary only).
tacotron2 = Tacotron2(vocab_size=len(tokenizer)).to(device)
wavenet = WaveNet().to(device)

# Initialize optimizer
# NOTE(review): WaveNet's parameters are registered here, but only `tacotron2`
# is passed to the training function below, so WaveNet receives no gradients
# and is saved with its initial weights — confirm whether that is intended.
optimizer = torch.optim.Adam(list(tacotron2.parameters()) + list(wavenet.parameters()), lr=0.001)

# Create dataset and dataloader
# TODO: Implement dataset creation with actual audio files
dataset = RapVoiceDataset(audio_dir='data/audio', lyrics_file='data/lyrics.txt', tokenizer=tokenizer)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Train the model
train_voice_generator(tacotron2, train_loader, optimizer, device)

# Save the trained models
# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists first.
os.makedirs('checkpoints', exist_ok=True)
torch.save(tacotron2.state_dict(), 'checkpoints/tacotron2_model.pth')
torch.save(wavenet.state_dict(), 'checkpoints/wavenet_model.pth')