In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import librosa
import librosa.display
import numpy as np
import os
from glob import glob
import matplotlib.pyplot as plt

# Define constants
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE = 22050
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512
BATCH_SIZE = 8
EPOCHS = 20
LEARNING_RATE = 2e-4

# Load GTZAN dataset
class GTZANDataset(Dataset):
    def __init__(self, root_dir):
        self.file_paths = glob(os.path.join(root_dir, "genres_original", "*", "*.wav"))
        self.labels = [os.path.basename(os.path.dirname(fp)) for fp in self.file_paths]
        self.label_dict = {genre: idx for idx, genre in enumerate(sorted(set(self.labels)))}
        self.transform = transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE, 
            n_mels=N_MELS, 
            n_fft=N_FFT, 
            hop_length=HOP_LENGTH
        )
    
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        label = self.label_dict[self.labels[idx]]
        waveform, _ = librosa.load(file_path, sr=SAMPLE_RATE)
        waveform = torch.tensor(waveform[:SAMPLE_RATE*3])  # 3 seconds clip
        mel_spec = self.transform(waveform)  # Shape: [1, n_mels, time]
        return mel_spec.squeeze(0), label  # Return [n_mels, time]

# Load dataset
train_dataset = GTZANDataset("gtzan_dataset")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define Transformer-based Generator
class TransformerGenerator(nn.Module):
    def __init__(self, input_dim=N_MELS, num_heads=4, ff_dim=256):
        super(TransformerGenerator, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        self.fc = nn.Linear(input_dim, input_dim)
    
    def forward(self, x):
        x = x.permute(2, 0, 1)  # (batch, n_mels, time) -> (time, batch, n_mels)
        x = self.transformer_encoder(x)
        x = self.fc(x)
        x = x.permute(1, 2, 0)  # Back to (batch, n_mels, time)
        return x

# Define Discriminator
class CNNDiscriminator(nn.Module):
    def __init__(self, input_dim=N_MELS):
        super(CNNDiscriminator, self).__init__()
        self.conv = nn.Sequential(
            # Input: [batch, 1, n_mels, time]
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dim: [batch, 1, n_mels, time]
        return self.conv(x)

# Define Genre Classifier
class GenreClassifier(nn.Module):
    def __init__(self, input_dim=N_MELS, num_classes=10):
        super(GenreClassifier, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Calculate the flattened size
        self.flattened_size = 128 * (N_MELS//8) * ((3*SAMPLE_RATE//HOP_LENGTH)//8)
        self.fc = nn.Sequential(
            nn.Linear(self.flattened_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Initialize models
generator = TransformerGenerator().to(device)
discriminator = CNNDiscriminator().to(device)
classifier = GenreClassifier().to(device)

# Train the classifier first
print("Training the genre classifier...")
criterion = nn.CrossEntropyLoss()
classifier_optimizer = optim.Adam(classifier.parameters(), lr=1e-3)

for epoch in range(10):  # Train classifier for 10 epochs
    for mel_specs, labels in train_loader:
        mel_specs, labels = mel_specs.to(device), labels.to(device)
        
        classifier_optimizer.zero_grad()
        outputs = classifier(mel_specs)
        loss = criterion(outputs, labels)
        loss.backward()
        classifier_optimizer.step()
    
    print(f"Classifier Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Save the trained classifier
torch.save(classifier.state_dict(), "genre_classifier.pth")
classifier.eval()

# Define losses for GAN
adversarial_loss = nn.BCELoss()
content_loss = nn.MSELoss()
cycle_loss = nn.L1Loss()

# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Training loop for GAN
print("Training the GAN...")
for epoch in range(EPOCHS):
    for batch_idx, (real_data, _) in enumerate(train_loader):
        real_data = real_data.to(device)
        batch_size = real_data.size(0)

        # Generate fake data
        fake_data = generator(real_data)

        # Train Discriminator
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        d_optimizer.zero_grad()
        
        # Real data
        real_output = discriminator(real_data)
        real_loss = adversarial_loss(real_output, real_labels)
        
        # Fake data
        fake_output = discriminator(fake_data.detach())
        fake_loss = adversarial_loss(fake_output, fake_labels)
        
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        g_optimizer.zero_grad()
        g_loss_adv = adversarial_loss(discriminator(fake_data), real_labels)
        g_loss_content = content_loss(fake_data, real_data)
        g_loss_cycle = cycle_loss(generator(fake_data), real_data)
        g_loss = g_loss_adv + 0.5 * g_loss_content + 0.5 * g_loss_cycle
        g_loss.backward()
        g_optimizer.step()

        if batch_idx % 10 == 0:
            print(f"GAN Epoch [{epoch}/{EPOCHS}] Batch [{batch_idx}/{len(train_loader)}] "
                  f"D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")

# Genre conversion function
def convert_genre(audio_tensor, source_genre, target_genre):
    print(f"Converting from {source_genre} to {target_genre}...")
    audio_tensor = audio_tensor.to(device)
    converted_audio = generator(audio_tensor.unsqueeze(0)).squeeze(0)
    
    # Predict the genre of the converted audio
    predicted_genre_idx = predict_genre(converted_audio)
    predicted_genre = list(train_dataset.label_dict.keys())[predicted_genre_idx]
    
    print(f"Converted audio predicted as: {predicted_genre}")
    return converted_audio

# Genre prediction function
def predict_genre(audio_tensor):
    audio_tensor = audio_tensor.to(device)
    with torch.no_grad():
        prediction = classifier(audio_tensor.unsqueeze(0))
    return torch.argmax(prediction, dim=1).item()

print("Training Complete!")

Training the genre classifier...
Classifier Epoch 1, Loss: 2.4843
Classifier Epoch 2, Loss: 2.3151
Classifier Epoch 3, Loss: 2.2965
Classifier Epoch 4, Loss: 2.3064
Classifier Epoch 5, Loss: 2.2924
Classifier Epoch 6, Loss: 2.2930
Classifier Epoch 7, Loss: 1.9837
Classifier Epoch 8, Loss: 3.2562
Classifier Epoch 9, Loss: 2.2312
Classifier Epoch 10, Loss: 2.1923
Training the GAN...
GAN Epoch [0/20] Batch [0/125] D Loss: 1.3809 G Loss: 279197.5625
GAN Epoch [0/20] Batch [10/125] D Loss: 1.1439 G Loss: 1722914.6250
GAN Epoch [0/20] Batch [20/125] D Loss: 1.0305 G Loss: 2214226.7500
GAN Epoch [0/20] Batch [30/125] D Loss: 0.9264 G Loss: 128528.2812
GAN Epoch [0/20] Batch [40/125] D Loss: 0.9133 G Loss: 916655.1875
GAN Epoch [0/20] Batch [50/125] D Loss: 0.8592 G Loss: 503788.5312
GAN Epoch [0/20] Batch [60/125] D Loss: 0.8334 G Loss: 206348.8906
GAN Epoch [0/20] Batch [70/125] D Loss: 0.8021 G Loss: 1033042.3125
GAN Epoch [0/20] Batch [80/125] D Loss: 0.7790 G Loss: 1751631.6250
GAN Epoch 

In [31]:
def load_and_preprocess_audio(file_path, duration=3):
    """Load and preprocess an audio file for conversion"""
    # Load audio file
    waveform, sr = librosa.load(file_path, sr=SAMPLE_RATE)
    
    # Trim or pad to desired duration
    if len(waveform) > SAMPLE_RATE * duration:
        waveform = waveform[:SAMPLE_RATE * duration]
    else:
        padding = SAMPLE_RATE * duration - len(waveform)
        waveform = np.pad(waveform, (0, padding), mode='constant')
    
    # Convert to mel-spectrogram
    transform = transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH
    )
    waveform_tensor = torch.tensor(waveform)
    mel_spec = transform(waveform_tensor).squeeze(0)  # [n_mels, time]
    
    return mel_spec

def convert_audio_file(input_path, output_path, source_genre, target_genre):
    """
    Convert an audio file from source_genre to target_genre
    and save the converted audio.
    
    Args:
        input_path: Path to input audio file
        output_path: Path to save converted audio
        source_genre: Name of source genre (for logging)
        target_genre: Name of target genre (for logging)
    """
    # Load and preprocess audio
    print(f"Loading and preprocessing {input_path}...")
    mel_spec = load_and_preprocess_audio(input_path)
    
    # Convert genre - ensure we're in evaluation mode
    generator.eval()
    with torch.no_grad():
        print(f"Converting from {source_genre} to {target_genre}...")
        converted_mel = convert_genre(mel_spec, source_genre, target_genre)
    
    # Convert mel-spectrogram back to audio (using Griffin-Lim)
    print("Converting back to waveform...")
    
    # Create inverse mel scale transform
    inv_mel = transforms.InverseMelScale(
        n_stft=N_FFT // 2 + 1,
        n_mels=N_MELS,
        sample_rate=SAMPLE_RATE
    )
    
    # Create Griffin-Lim transform
    griffin_lim = transforms.GriffinLim(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_iter=32
    )
    
    # Process through inverse transforms
    # Add batch dimension if needed: [1, n_mels, time]
    # Detach and move to CPU
    converted_mel = converted_mel.unsqueeze(0).detach().cpu()
    
    # Convert mel to linear spectrogram
    spec_estimate = inv_mel(converted_mel)
    
    # Reconstruct waveform
    waveform = griffin_lim(spec_estimate)
    
    # Save the converted audio
    # Ensure waveform is detached numpy array
    torchaudio.save(output_path, waveform.detach(), SAMPLE_RATE)
    print(f"Converted audio saved to {output_path}")

# Example usage:
if __name__ == "__main__":
    # Example file paths - replace with your actual files
    input_audio = "gtzan_dataset/genres_original/metal/metal.00000.wav"  # Your input audio file
    output_audio = "output_rock.wav"  # Where to save converted audio
    
    # Available genres from GTZAN dataset
    available_genres = list(train_dataset.label_dict.keys())
    print("Available genres:", available_genres)
    
    # Convert jazz to rock (example)
    convert_audio_file(
        input_path=input_audio,
        output_path=output_audio,
        source_genre="metal",
        target_genre="rock"
    )

Available genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
Loading and preprocessing gtzan_dataset/genres_original/metal/metal.00000.wav...
Converting from metal to rock...
Converting from metal to rock...
Converted audio predicted as: metal
Converting back to waveform...
Converted audio saved to output_rock.wav
