In [1]:
import torch
import torchaudio
from torch import nn
from t2spec_converter import TextToSpecConverter
import soundfile as sf
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

In [3]:
device="cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
class PreprocessedLJSpeech(Dataset):
    def __init__(self, root="./"):
        self.dataset = torchaudio.datasets.LJSPEECH(root=root, download=True)
        self.sample_rate = 22050
        self.mel_transform = MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=1024,
            win_length=1024,
            hop_length=256,
            n_mels=80,
            f_min=0.0,
            f_max=8000.0,
            power=1.5,
        )
        self.amplitude_to_db = AmplitudeToDB(stype="power")

    def __getitem__(self, idx):
        waveform, sr, transcript, *_ = self.dataset[idx]
        assert sr == self.sample_rate, "Sample rate mismatch"
        
        if waveform.shape[0] > 1:
            waveform = waveform[0:1]

        mel_spec = self.mel_transform(waveform)
        mel_spec_db = self.amplitude_to_db(mel_spec)
        mel_spec_db = mel_spec_db.squeeze(0)  # [n_mels, T]
        
        return mel_spec_db, waveform.squeeze(0), transcript

    def __len__(self):
        return len(self.dataset)


In [None]:
dataset = PreprocessedLJSpeech()

# Даталоадер
def collate_fn(batch):
    # batch: list of tuples (mel_spec, audio, transcript)
    mel_specs, audios, transcripts = zip(*batch)
    
    # паддинг мел-спектрограмм и аудио для батча (если нужно)
    mel_lengths = [m.shape[1] for m in mel_specs]
    max_mel_len = max(mel_lengths)
    
    padded_mels = []
    for m in mel_specs:
        pad = max_mel_len - m.shape[1]
        if pad > 0:
            m = torch.nn.functional.pad(m, (0, pad))
        padded_mels.append(m)
    mel_batch = torch.stack(padded_mels)
    
    # Аналогично с аудио (для удобства)
    audio_lengths = [a.shape[0] for a in audios]
    max_audio_len = max(audio_lengths)
    padded_audios = []
    for a in audios:
        pad = max_audio_len - a.shape[0]
        if pad > 0:
            a = torch.nn.functional.pad(a, (0, pad))
        padded_audios.append(a)
    audio_batch = torch.stack(padded_audios)
    
    return mel_batch, audio_batch, transcripts

# Пример DataLoader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Цикл проверки
for mel_batch, audio_batch, transcripts in dataloader:
    print("Mel shape:", mel_batch.shape)
    print("Audio shape:", audio_batch.shape)
    print("Text:", transcripts[:2])
    break


Mel shape: torch.Size([8, 80, 824])
Audio shape: torch.Size([8, 210845])
Text: ("Observing his blood-covered chest as he was pulled into his wife's lap, Governor Connally believed himself mortally wounded.", 'A simultaneous attack was made upon the captain and the first mate.')


In [14]:
class Generator(nn.Module):
    def __init__(self, upsample_factors=(4, 4, 4), n_mels=80):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(n_mels, 512, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.ConvTranspose1d(
                512,
                256,
                kernel_size=upsample_factors[0] * 2,
                stride=upsample_factors[0],
                padding=upsample_factors[0] // 2 + 1,
                output_padding=1,
            ),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose1d(
                256,
                128,
                kernel_size=upsample_factors[1] * 2,
                stride=upsample_factors[1],
                padding=upsample_factors[1] // 2 + 1,
                output_padding=1,
            ),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose1d(
                128,
                64,
                kernel_size=upsample_factors[2] * 2,
                stride=upsample_factors[2],
                padding=upsample_factors[2] // 2 + 1,
                output_padding=1,
            ),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 1, kernel_size=7, padding=3),
            nn.Tanh(),
        )

    def forward(self, mel):
        return self.model(mel)

In [18]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(1, 64, 15, stride=1, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, 15, stride=4, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(128, 256, kernel_size=15, stride=4, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(256, 512, kernel_size=15, stride=4, padding=7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(512, 1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, audio):
        out = self.model(audio)
        return out.mean(dim=[1, 2])

In [10]:
def generator_loss(d_fake, gen_audio, real_audio):
    l1 = F.l1_loss(gen_audio, real_audio)
    adv = F.mse_loss(d_fake, torch.ones_like(d_fake))
    return l1 + 0.001 * adv  

def discriminator_loss(d_real, d_fake):
    real_loss = F.mse_loss(d_real, torch.ones_like(d_real))
    fake_loss = F.mse_loss(d_fake, torch.zeros_like(d_fake))
    return real_loss + fake_loss

In [19]:
num_epochs = 5

generator = Generator()
generator.to(device)

discriminator = Discriminator()
discriminator.to(device)

g_opt = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.9))
d_opt = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9))

In [20]:

for epoch in range(num_epochs):
    for mel, real_audio, _ in dataloader:
        mel = mel.to(device)
        real_audio = real_audio.to(device)

        # Generator
        gen_audio = generator(mel)
        d_fake = discriminator(gen_audio)
        g_loss = generator_loss(d_fake, gen_audio, real_audio)
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        # Discriminator
        d_real = discriminator(real_audio)
        d_fake = discriminator(gen_audio.detach())
        d_loss = discriminator_loss(d_real, d_fake)
        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()
        
        print(f"Epoch {epoch} | G_loss: {g_loss.item():.4f} | D_loss: {d_loss.item():.4f}")

  l1 = F.l1_loss(gen_audio, real_audio)


RuntimeError: The size of tensor a (47467) must match the size of tensor b (189853) at non-singleton dimension 2

In [None]:
t2s = TextToSpecConverter()


generator.load_state_dict(torch.load("generator.pt"))
generator.eval()

with open("test_sentences.txt") as f:
    sentences = [line.strip() for line in f.readlines()]

for i, sent in enumerate(sentences[:5]):
    mel = torch.tensor(t2s.text2spec(sent)).unsqueeze(0).to(device)  # [1, 80, T]
    with torch.no_grad():
        audio = generator(mel).cpu().squeeze().numpy()
    sf.write(f"gen_{i}.wav", audio, samplerate=22050)