In [3]:
import os
import librosa
import wandb
import numpy as np
import multiprocessing as mp

import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
from torch.utils.data import random_split

# wandb.init(project='TransformerWaveNet')

In [34]:
def load_audio(audio_path, sample_rate=22050, duration=5):
    # Load audio file with librosa, automatically resampling to the given sample rate
    audio, sr = librosa.load(audio_path, sr=sample_rate, duration=duration)
    
    # Calculate target number of samples
    target_length = sample_rate * duration
    
    # Pad audio if it is shorter than the target length
    if len(audio) < target_length:
        padding = target_length - len(audio)
        audio = np.pad(audio, (0, padding), mode='constant')
    # Truncate audio if it is longer than the target length
    elif len(audio) > target_length:
        audio = audio[:target_length]
    
    return audio

def get_spectrogram(audio, n_fft=2048, hop_length=512, max_length=130):
    # Generate a spectrogram
    spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    # Convert to magnitude (amplitude)
    spectrogram = np.abs(spectrogram)
    
    # Pad or truncate the spectrogram to ensure all are the same length
    if spectrogram.shape[1] < max_length:
        padding = max_length - spectrogram.shape[1]
        spectrogram = np.pad(spectrogram, ((0, 0), (0, padding)), mode='constant')
    else:
        spectrogram = spectrogram[:, :max_length]
    
    return spectrogram

class AudioDataset(Dataset):
    def __init__(self, root_dir, sample_rate=22050, n_fft=2048, hop_length=512, max_length=130):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.max_length = max_length
        self.files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(root_dir) for f in filenames if f.endswith('.mp3') or f.endswith('.wav')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio_path = self.files[idx]
        audio = load_audio(audio_path, self.sample_rate)
        spectrogram = get_spectrogram(audio, self.n_fft, self.hop_length, self.max_length)
        return audio, spectrogram

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)

    dataset = AudioDataset(root_dir='DATA')
    loader = DataLoader(dataset, batch_size=10, shuffle=True)

In [35]:
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    total_size = len(dataset)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)
    test_size = total_size - train_size - val_size  # Ensure all data is used

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    return train_dataset, val_dataset, test_dataset

data_folder_path = 'DATA'
# dataset = AudioDataset(root_dir=data_folder_path)

# Assuming 'dataset' is an instance of AudioDataset
train_dataset, val_dataset, test_dataset = split_dataset(dataset)

# Create DataLoaders for each dataset split
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0)

In [36]:
# Check if the dataset is correctly set up
print("Number of samples in dataset:", len(train_dataset))

# Create a DataLoader instance (make sure parameters like batch_size are set correctly)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

# Try to fetch a single batch to see if it works
try:
    data = next(iter(train_loader))
    print("Single batch loaded successfully:", data)
except Exception as e:
    print("Failed to load a batch:", e)

Number of samples in dataset: 407
Single batch loaded successfully: [tensor([[ 1.1482e-02,  2.9690e-02,  3.3771e-02,  ..., -5.3938e-03,
         -3.2458e-04,  5.0673e-02],
        [-3.3641e-02, -6.0243e-02, -5.2029e-02,  ...,  1.2704e-02,
          1.7610e-02,  9.4385e-03],
        [-6.5417e-03, -8.3363e-03, -1.1754e-03,  ..., -6.6309e-02,
         -6.9432e-02, -5.7693e-02],
        ...,
        [-4.2689e-02, -7.9236e-02, -8.8070e-02,  ..., -2.9538e-01,
         -2.9291e-01, -3.2033e-01],
        [-2.9800e-03, -1.3442e-02, -7.7015e-03,  ..., -7.4696e-02,
         -8.0101e-02, -7.1932e-02],
        [-1.3502e-03, -5.8114e-05,  4.2739e-04,  ...,  1.0737e-01,
          9.4365e-02,  9.3138e-02]]), tensor([[[6.7074e-02, 4.5210e-03, 2.2363e-02,  ..., 3.4480e-02,
          1.6141e-02, 2.5023e-01],
         [1.1363e-01, 1.1417e-01, 1.4535e-01,  ..., 9.4468e-02,
          1.3799e-01, 1.4693e-01],
         [1.5949e-01, 1.8849e-01, 1.8715e-01,  ..., 1.2030e-01,
          8.3797e-02, 1.1366e-01],
 

In [37]:
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Training set size: 407
Validation set size: 87
Test set size: 88


In [62]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerWaveNet(nn.Module):
    def __init__(self, audio_channels=1, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
        super(TransformerWaveNet, self).__init__()
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.dilated_convs = nn.ModuleList()
        self.condition_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # Initial convolution layer for raw audio
        self.audio_conv = nn.Conv1d(audio_channels, num_channels, 1)
        # self.audio_conv = nn.Conv1d(10, out_channels, kernel_size)

        # Initial convolution layer for spectrogram
        self.spectrogram_conv = nn.Conv1d(audio_channels, num_channels, 1)

        # Transformer block
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
            num_layers=3)

        # Dilated convolutions and condition convolutions
        for _ in range(num_blocks):
            for i in range(num_layers):
                dilation = 2 ** i
                self.dilated_convs.append(nn.Conv1d(num_channels, 2 * num_channels, kernel_size, dilation=dilation, padding=dilation))
                self.condition_convs.append(nn.Conv1d(num_channels, 2 * num_channels, 1))
                self.residual_convs.append(nn.Conv1d(num_channels, num_channels, 1))
                self.skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))

        # Output layers
        self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
        self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

    def forward(self, audio, spectrogram):
        # Process audio and spectrogram
        audio = self.audio_conv(audio)
        spectrogram = self.spectrogram_conv(spectrogram)

        # Combine audio and spectrogram
        x = audio + spectrogram

        # Transformer processing
        x = self.transformer(x)
        
        skip_connections = []

        for b in range(self.num_blocks):
            for l in range(self.num_layers):
                # Dilated convolution
                dilated = self.dilated_convs[b * self.num_layers + l](x)
                # Split for gated activation
                filtered, gate = torch.split(dilated, dilated.size(1) // 2, dim=1)
                x = torch.tanh(filtered) * torch.sigmoid(gate)
                # Residual and skip connections
                x = self.residual_convs[b * self.num_layers + l](x)
                skip = self.skip_convs[b * self.num_layers + l](x)
                skip_connections.append(skip)

        # Sum all skip connections
        x = torch.sum(torch.stack(skip_connections), dim=0)

        # Final convolutions
        x = F.relu(self.final_conv1(x))
        x = self.final_conv2(x)

        return x

    def generate(self, audio, spectrogram):
        """
        Generate audio using the model in an autoregressive manner.
        Assumes the model is already trained and in eval mode.
        """
        self.eval()  # Ensure the model is in evaluation mode
        with torch.no_grad():  # No need to track gradients
            # Assuming the inputs are already on the correct device and preprocessed
            generated_audio = self.forward(audio, spectrogram)

            # Post-processing if necessary (e.g., applying a sigmoid to ensure output is in the correct range)
            generated_audio = torch.sigmoid(generated_audio)  # Example post-processing

        return generated_audio

In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Load a pre-trained VGGish model for audio feature extraction
vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')

# Define the Perceptual Loss using VGGish as the feature extractor
class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = feature_extractor
        self.feature_extractor.eval()  # Set to evaluation mode

    def forward(self, generated_audio, target_audio):
        with torch.no_grad():
            real_features = self.feature_extractor(target_audio)
        generated_features = self.feature_extractor(generated_audio)
        loss = F.l1_loss(generated_features, real_features)
        return loss

perceptual_loss = PerceptualLoss(vggish)

Using cache found in C:\Users\rahat/.cache\torch\hub\harritaylor_torchvggish_master


In [64]:
class MultiScaleSpectrogramLoss(nn.Module):
    def __init__(self, scales=[1024, 2048, 4096]):
        super(MultiScaleSpectrogramLoss, self).__init__()
        self.scales = scales

    def forward(self, generated_audio, target_audio):
        loss = 0
        for scale in self.scales:
            gen_spec = torch.stft(generated_audio, n_fft=scale, return_complex=True)
            target_spec = torch.stft(target_audio, n_fft=scale, return_complex=True)
            loss += F.l1_loss(gen_spec.abs(), target_spec.abs())
        return loss / len(self.scales)

spectrogram_loss = MultiScaleSpectrogramLoss()

In [65]:
# For demonstration, let's assume we have a simple CNN as a discriminator
class SimpleAudioDiscriminator(nn.Module):
    def __init__(self):
        super(SimpleAudioDiscriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(16 * 16, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

    def intermediate_forward(self, x):
        x = F.relu(self.conv1(x))
        return x

discriminator = SimpleAudioDiscriminator()

class FeatureMatchingLoss(nn.Module):
    def __init__(self, discriminator):
        super(FeatureMatchingLoss, self).__init__()
        self.discriminator = discriminator
        self.discriminator.eval()

    def forward(self, generated_audio, target_audio):
        with torch.no_grad():
            real_features = self.discriminator.intermediate_forward(target_audio)
        generated_features = self.discriminator.intermediate_forward(generated_audio)
        loss = F.l1_loss(generated_features, real_features)
        return loss

feature_matching_loss = FeatureMatchingLoss(discriminator)

In [66]:
# Example of a composite loss
class CompositeLoss(nn.Module):
    def __init__(self, perceptual_loss, spectrogram_loss, feature_matching_loss):
        super(CompositeLoss, self).__init__()
        self.perceptual_loss = perceptual_loss
        self.spectrogram_loss = spectrogram_loss
        self.feature_matching_loss = feature_matching_loss

    def forward(self, generated_audio, target_audio):
        loss = (self.perceptual_loss(generated_audio, target_audio) +
                self.spectrogram_loss(generated_audio, target_audio) +
                self.feature_matching_loss(generated_audio, target_audio))
        return loss

In [67]:
import torch
from torch.utils.data import DataLoader
import os
import wandb  # Ensure wandb is imported if you're using it

def train(model, train_loader, val_loader, optimizer, criterion, epochs, device):
    model.to(device)
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for i, (audio, spectrogram) in enumerate(train_loader):
            print("Here")
            audio, spectrogram = audio.to(device), spectrogram.to(device)
            optimizer.zero_grad()
            output = model(audio, spectrogram)
            loss = criterion(output, audio)  # Adjusted to use a proper loss function
            loss.backward()
            optimizer.step()

            # Log loss to wandb
            wandb.log({"train_loss": loss.item()})
            print(f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}")
            # Save model checkpoint
            torch.save(model.state_dict(), f'TW_Checkpoint/model_TW_{epoch}.pt')

            # Generate synthetic data and add to train_loader
            if i % 10 == 0:  # Every 10 iterations, generate synthetic data
                with torch.no_grad():
                    synthetic_audio = model.generate(audio, spectrogram)
                train_loader.dataset.append((synthetic_audio, spectrogram))

        epoch_loss /= len(train_loader)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss}")
        wandb.log({"epoch_loss": epoch_loss})

        # Validation loop
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for audio, spectrogram in val_loader:
                audio, spectrogram = audio.to(device), spectrogram.to(device)
                output = model(audio, spectrogram)
                val_loss += criterion(output, audio).item()
            val_loss /= len(val_loader)

        # Log validation loss to wandb
        wandb.log({"val_loss": val_loss})

In [68]:
from torch.optim import Adam

train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerWaveNet().to(device)
print(model)
optimizer = Adam(model.parameters(), lr=0.001)
# composite_loss = CompositeLoss(perceptual_loss, spectrogram_loss, feature_matching_loss)
composite_loss = CompositeLoss(perceptual_loss, spectrogram_loss, feature_matching_loss)

# train(model, train_loader, val_loader, optimizer, composite_loss, epochs=50, device=device)

TransformerWaveNet(
  (dilated_convs): ModuleList(
    (0): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(1,))
    (1): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(2,), dilation=(2,))
    (2): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(4,), dilation=(4,))
    (3): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(8,), dilation=(8,))
    (4): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(16,), dilation=(16,))
    (5): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(32,), dilation=(32,))
    (6): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(64,), dilation=(64,))
    (7): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(128,), dilation=(128,))
    (8): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(256,), dilation=(256,))
    (9): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(512,), dilation=(512,))
    (10): Conv1d(64, 128, kernel_size=(2,), stride=(1,), padding=(1,))
    (11): Conv1d(64, 128,

NameError: name 'audio' is not defined

In [69]:
train(model, train_loader, val_loader, optimizer, composite_loss, epochs=50, device=device)

Here


RuntimeError: Given groups=1, weight of size [64, 1, 1], expected input[1, 10, 110250] to have 1 channels, but got 10 channels instead