In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [257]:
# Tacotron 2 Model definition
class Tacotron2(nn.Module):
    def __init__(self, n_mel_channels=80, n_symbols=2000, n_speakers=1):
        super(Tacotron2, self).__init__()
        self.embedding = nn.Embedding(n_symbols, 512)
        self.encoder = nn.LSTM(512, 512, batch_first=True)
        self.decoder = nn.LSTM(512 + n_mel_channels, 512, batch_first=True)  # Adjusted to handle concatenated inputs
        self.mel_projection = nn.Linear(512, n_mel_channels)

    def forward(self, text, mel_input):
        embedded_text = self.embedding(text)
        encoder_outputs, (h_n, c_n) = self.encoder(embedded_text)
        
        # Repeat encoder hidden state across time steps
        repeated_h_n = h_n.repeat(mel_input.size(1), 1, 1).permute(1, 0, 2)
        repeated_c_n = c_n.repeat(mel_input.size(1), 1, 1).permute(1, 0, 2)

        # Concatenate encoder outputs with mel_input for the decoder
        decoder_input = torch.cat((repeated_h_n, mel_input), dim=-1)

        decoder_outputs, _ = self.decoder(decoder_input, (h_n, c_n))
        mel_outputs = self.mel_projection(decoder_outputs)
        return mel_outputs


In [None]:
import librosa
import numpy as np
import torch.utils.data as data

In [258]:
# Dataset class
class TTSDataset(data.Dataset):
    def __init__(self, text_files, mel_files):
        self.text_files = text_files
        self.mel_files = mel_files

    def __len__(self):
        return len(self.text_files)

    def __getitem__(self, idx):
        text = self.load_text(self.text_files[idx])
        mel = self.load_mel(self.mel_files[idx])
        return text, mel

    def load_text(self, text_file):
        with open(text_file, 'r', encoding='utf-8') as f:
            text = f.read().strip()
        # Convert text to Unicode code points
        text_indices = [ord(char) for char in text]
        return torch.tensor(text_indices, dtype=torch.long)

    def load_mel(self, mel_file):
        mel, _ = librosa.load(mel_file, sr=22050)
        mel = librosa.feature.melspectrogram(y=mel, sr=22050, n_mels=80)
        mel = librosa.power_to_db(mel, ref=np.max)
        return torch.tensor(mel, dtype=torch.float32).T  # Transpose for (time, mel)


In [259]:
# Padding function
def collate_fn(batch):
    texts, mels = zip(*batch)
    
    text_lengths = [len(t) for t in texts]
    mel_lengths = [len(m) for m in mels]
    
    max_text_len = max(text_lengths)
    max_mel_len = max(mel_lengths)
    
    padded_texts = torch.zeros(len(texts), max_text_len, dtype=torch.long)
    padded_mels = torch.zeros(len(mels), max_mel_len, mels[0].size(1))
    
    for i in range(len(texts)):
        padded_texts[i, :text_lengths[i]] = texts[i]
        padded_mels[i, :mel_lengths[i], :] = mels[i]
    
    return padded_texts, padded_mels


In [None]:
import pandas as pd
import arabic_reshaper
from bidi.algorithm import get_display
import os


In [None]:
metadata = pd.read_csv('archive/metadata.csv')

In [None]:
def find_file_name(row):
  # print(row.values)
  value = row.values[0]
  # print(value)
  return value[value.find('|')+1:]

def find_file_text(row):
  value = row.values[0]
  return value[:value.find('|')].replace('\u200c', '')

In [None]:
metadata['file'] = np.array(["archive/wavs/"+find_file_name(row) for row in metadata.iloc])
metadata['file_name'] = np.array([find_file_name(row) for row in metadata.iloc])

In [None]:
metadata['text'] = np.array([find_file_text(row) for row in metadata.iloc])
metadata['text_file'] = np.array(["archive/text/"+find_file_name(row)[:-3]+"txt" for row in metadata.iloc])

In [None]:
for row in metadata.iloc:
    file_path = os.path.join('archive/text', row['file_name'][:-3]+"txt")
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(row['text'])

In [261]:
# text_files = ['path/to/text1.txt', 'path/to/text2.txt', ...]
# mel_files = ['path/to/mel1.wav', 'path/to/mel2.wav', ...]

dataset = TTSDataset(metadata['text_file'].values, metadata['file'].values)
dataloader = data.DataLoader(dataset, batch_size=200, shuffle=True, collate_fn=collate_fn)


In [262]:
# Training loop
def train(model, dataloader, num_epochs=10, learning_rate=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    model.train()
    for epoch in range(num_epochs):
        for text, mel in dataloader:
            optimizer.zero_grad()
            mel_input = mel[:, :-1, :]
            mel_target = mel[:, 1:, :]
            # print(len(text[0]))

            mel_pred = model(text, mel_input)
            loss = criterion(mel_pred, mel_target)

            loss.backward()
            optimizer.step()

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

model = Tacotron2()
train(model, dataloader)

Epoch 1/10, Loss: 506.21783447265625
Epoch 2/10, Loss: 502.0771179199219


KeyboardInterrupt: 

In [None]:
def synthesize(model, text):
    model.eval()
    with torch.no_grad():
        text = torch.tensor([ord(c) for c in text]).unsqueeze(0)
        mel_input = torch.zeros((1, 1, 80))
        mel_output = model(text, mel_input)
    return mel_output.squeeze().cpu().numpy()

mel_output = synthesize(model, "Sample text")
