In [1]:
!pip install pytorch_lightning
import random
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.core import load
from librosa.util import normalize
from torch.utils.data import DataLoader

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m52.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.11.0-py3-none-any.whl (25 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m58.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)
  Downloading nvidia_cuda_runtime

In [2]:
class Generator(nn.Module):
    """Convolutional encoder/decoder mapping a ``mel_dim``-channel sequence to ``audio_dim`` channels.

    Three ``Conv1d`` blocks widen the channels (mel_dim -> 128 -> 256 -> 512) at constant
    length. Three ``ConvTranspose1d`` layers then narrow them (512 -> 256 -> 128 -> audio_dim)
    while exactly doubling the length each time. A final ``nn.Linear`` mixes channels at
    every time step.

    Args:
        mel_dim: number of input channels (presumably mel bins -- TODO confirm).
        audio_dim: number of output channels.

    Shapes:
        input  ``(batch, mel_dim, T)``
        output ``(batch, audio_dim, 8 * T)``
    """

    def __init__(self, mel_dim, audio_dim):
        super().__init__()

        # Encoder: kernel 3 / padding 1 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(in_channels=mel_dim, out_channels=128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.relu3 = nn.ReLU(inplace=True)

        # Decoder: L_out = (L - 1) * 2 - 2 * 1 + (3 - 1) + output_padding + 1.
        # With output_padding=1 this is exactly 2 * L; without it, it would be 2 * L - 1.
        self.deconv1 = nn.ConvTranspose1d(in_channels=512, out_channels=256, kernel_size=3,
                                          stride=2, padding=1, output_padding=1)
        self.bn4 = nn.BatchNorm1d(256)
        self.relu4 = nn.ReLU(inplace=True)

        self.deconv2 = nn.ConvTranspose1d(in_channels=256, out_channels=128, kernel_size=3,
                                          stride=2, padding=1, output_padding=1)
        self.bn5 = nn.BatchNorm1d(128)
        self.relu5 = nn.ReLU(inplace=True)

        self.deconv3 = nn.ConvTranspose1d(in_channels=128, out_channels=audio_dim, kernel_size=3,
                                          stride=2, padding=1, output_padding=1)

        # Applied over the channel axis in forward(), hence in/out features == audio_dim.
        self.output = nn.Linear(audio_dim, audio_dim)

    def forward(self, mel_spectrogram):
        """Run the encoder/decoder.

        Args:
            mel_spectrogram: tensor of shape ``(batch, mel_dim, T)``.

        Returns:
            Tensor of shape ``(batch, audio_dim, 8 * T)``.
        """
        x = self.relu1(self.bn1(self.conv1(mel_spectrogram)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))

        x = self.relu4(self.bn4(self.deconv1(x)))
        x = self.relu5(self.bn5(self.deconv2(x)))
        x = self.deconv3(x)

        # nn.Linear acts on the last dimension. Move channels there and back, so the
        # projection mixes channels per time step instead of mixing time steps.
        return self.output(x.transpose(1, 2)).transpose(1, 2)


In [3]:
class Discriminator(nn.Module):
    """Convolutional classifier producing one unnormalized score per input sequence.

    Three ``Conv1d`` + ``BatchNorm1d`` + ``LeakyReLU(0.2)`` blocks widen the channels
    (mel_dim -> 128 -> 256 -> 512) at constant length. They are followed by a global
    average pool over time and a linear layer down to a single value. No sigmoid is
    applied, so the output is a raw logit.

    Args:
        mel_dim: number of input channels.

    Shapes:
        input  ``(batch, mel_dim, T)``
        output ``(batch, 1)``
    """

    def __init__(self, mel_dim):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=mel_dim, out_channels=128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(128)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Third block. Previously it was assigned to conv2/bn2, clobbering the 128->256
        # layer and leaving conv3/bn3/leaky_relu3 undefined.
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.leaky_relu3 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Collapse the time axis so any input length T is accepted.
        self.pool = nn.AdaptiveAvgPool1d(1)
        # Must match conv3's out_channels.
        self.fc = nn.Linear(512, 1)

    def forward(self, mel_spectrogram):
        """Score a batch of sequences.

        Args:
            mel_spectrogram: tensor of shape ``(batch, mel_dim, T)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with raw logits.
        """
        x = self.leaky_relu1(self.bn1(self.conv1(mel_spectrogram)))
        x = self.leaky_relu2(self.bn2(self.conv2(x)))
        x = self.leaky_relu3(self.bn3(self.conv3(x)))

        x = self.pool(x)               # (batch, 512, 1)
        x = x.view(x.size(0), -1)      # (batch, 512)
        return self.fc(x)

In [5]:
from audio_data import AudioDataset

class GANModule(pl.LightningModule):
    """Lightning module that trains ``Generator`` and ``Discriminator`` adversarially.

    Two optimizers are used (discriminator first, then generator). Lightning 2.x only
    supports multiple optimizers with manual optimization, so ``training_step`` runs
    both update steps itself.

    Args:
        mel_dim: channel count of the generator input and of the discriminator input.
        audio_dim: channel count of the generator output.
            NOTE(review): generator output is fed directly to the discriminator, whose
            first conv expects ``mel_dim`` channels. This only works when
            ``audio_dim == mel_dim`` -- confirm the intended design.
        learning_rate: Adam learning rate for both networks.
        batch_size: DataLoader batch size.
        training_files: first positional argument of ``AudioDataset``.
        num_workers: DataLoader worker processes.
        segment_length: forwarded to ``AudioDataset`` (default is the previous hard-coded value).
        sampling_rate: forwarded to ``AudioDataset`` (default is the previous hard-coded value).
    """

    def __init__(self, mel_dim, audio_dim, learning_rate=1e-3, batch_size=32, training_files="train.txt", num_workers=4,
                 segment_length=16000, sampling_rate=22050):
        super().__init__()
        # Two optimizers -> step them manually in training_step.
        self.automatic_optimization = False

        self.mel_dim = mel_dim
        self.generator = Generator(mel_dim, audio_dim)
        self.discriminator = Discriminator(mel_dim)

        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.train_data = AudioDataset(training_files, segment_length=segment_length, sampling_rate=sampling_rate)
        self.train_loader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    def forward(self, mel_spectrogram):
        """Run the generator on ``mel_spectrogram`` and return its output."""
        generated_audio = self.generator(mel_spectrogram)
        return generated_audio

    def train_dataloader(self):
        """Return the loader built in ``__init__`` so ``Trainer.fit(module)`` can find training data."""
        return self.train_loader

    def training_step(self, batch, batch_idx):
        """Run one discriminator update followed by one generator update.

        Returns:
            The detached sum ``d_loss + g_loss``, for reporting only. Gradients are
            applied inside this method.
        """
        optimizer_d, optimizer_g = self.optimizers()

        # NOTE(review): assumes each batch is a 2-tuple whose first element is the real
        # discriminator input of shape (batch, mel_dim, T) -- confirm against AudioDataset.
        real_mel_spectrogram, _ = batch

        # The generator's first conv expects mel_dim input channels (was hard-coded to 100).
        noise = torch.randn(real_mel_spectrogram.size(0), self.mel_dim, 1, device=self.device)
        # No torch.no_grad(): the generator update below needs gradients through this output.
        fake_mel_spectrogram = self.generator(noise)

        # Discriminator step: real -> 1, fake -> 0.
        # detach() keeps the generator out of this backward pass.
        self.toggle_optimizer(optimizer_d)
        real_discrimination = self.discriminator(real_mel_spectrogram)
        fake_discrimination = self.discriminator(fake_mel_spectrogram.detach())
        d_loss = F.binary_cross_entropy_with_logits(real_discrimination, torch.ones_like(real_discrimination)) + \
                 F.binary_cross_entropy_with_logits(fake_discrimination, torch.zeros_like(fake_discrimination))
        optimizer_d.zero_grad()
        self.manual_backward(d_loss)
        optimizer_d.step()
        self.untoggle_optimizer(optimizer_d)

        # Generator step: push the discriminator to label fakes as real.
        # Only the generator's optimizer is stepped here.
        self.toggle_optimizer(optimizer_g)
        generator_discrimination = self.discriminator(fake_mel_spectrogram)
        g_loss = F.binary_cross_entropy_with_logits(generator_discrimination, torch.ones_like(generator_discrimination))
        optimizer_g.zero_grad()
        self.manual_backward(g_loss)
        optimizer_g.step()
        self.untoggle_optimizer(optimizer_g)

        self.log("d_loss", d_loss, on_step=True, prog_bar=True, logger=True)
        self.log("g_loss", g_loss, on_step=True, prog_bar=True, logger=True)
        return (d_loss + g_loss).detach()

    def configure_optimizers(self):
        """Return one Adam optimizer per network, ordered (discriminator, generator) as unpacked in training_step."""
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate)
        optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate)
        return [optimizer_d, optimizer_g]

