In [1]:
from AudioDataLoading import DataLoad

# NOTE: delete './Data/genres_original/jazz/jazz.00054.wav' from the data folder before
# running this cell; that file is broken and will crash the loader.
# NOTE(review): prepare_data(), defined later in this notebook, builds its own DataLoad
# and calls fetch_dataset() again, so the dataset is read twice and X / y from this cell
# look like a load sanity check only — confirm they are still needed.
data_path = "./Data/genres_original"
data_loader = DataLoad(data_path)
X, y = data_loader.fetch_dataset()

In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from AudioDataLoading import DataLoad
from AudioCVAEEncoderArch import AudioCVAEEncoder  
from AudioCVAEDecoderArch import AudioCVAEDecoder  
from Loss import get_loss
from AudioCVAETrainer import TrainerCVAE
from datetime import datetime

def prepare_data(dataset_path="./Data/genres_original", batch_size=32,
                 test_size=0.1, random_state=365):
    """Load the audio dataset and wrap it in train / held-out DataLoaders.

    Args:
        dataset_path: Directory handed to ``DataLoad``. Defaults to the path
            that was previously hard-coded in this function.
        batch_size: Batch size used for both returned loaders.
        test_size: Forwarded to ``train_test_split`` as the held-out share.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible between runs.

    Returns:
        tuple: ``(train_loader, test_loader)``. The training loader reshuffles
        every epoch; the held-out loader keeps a fixed order.
    """
    dataload = DataLoad(dataset_path)
    all_audios, all_attrs = dataload.fetch_dataset()

    # fetch_dataset() is expected to return NumPy arrays (torch.from_numpy
    # rejects anything else). Audio is cast to float32, labels to int64.
    all_audios = torch.from_numpy(all_audios).float()
    all_attrs = torch.from_numpy(all_attrs).long()

    # Insert a channel axis at position 1.
    # Presumably (N, T) -> (N, 1, T) — TODO confirm against DataLoad's output.
    all_audios = all_audios.unsqueeze(1)

    # Split audio and attributes together so each clip keeps its own label.
    X_train, X_val, y_train, y_val = train_test_split(
        all_audios, all_attrs, test_size=test_size, random_state=random_state
    )

    train_dataset = TensorDataset(X_train, y_train)
    test_dataset = TensorDataset(X_val, y_val)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

def get_timestamp():
    """Return the current local time as a ``YYYYMMDD_HHMMSS`` string."""
    now = datetime.now()
    # datetime.__format__ delegates to strftime, so this yields the same text.
    return f"{now:%Y%m%d_%H%M%S}"

def main(num_epochs=50, factor=10, device=None):
    """Train the conditional audio VAE and save the encoder/decoder weights.

    Builds the data loaders, instantiates ``AudioCVAEEncoder`` and
    ``AudioCVAEDecoder``, trains them through ``TrainerCVAE`` and writes both
    ``state_dict``s to timestamped ``.pth`` files in the working directory.

    Args:
        num_epochs: Forwarded to ``TrainerCVAE.train``.
        factor: Forwarded to ``TrainerCVAE.train`` as ``factor``; its meaning
            is defined in AudioCVAETrainer (presumably a loss-term weight —
            TODO confirm).
        device: Device string handed to ``TrainerCVAE``. When ``None``, CUDA
            is used if available and the CPU otherwise, so the notebook no
            longer requires a GPU just to start.
    """
    LATENT_SPACE_SIZE = 128
    CONDITION_DIM = 10
    INPUT_LENGTH = 661500
    OUTPUT_LENGTH = 661500

    if device is None:
        # Previously hard-coded to "cuda"; fall back to the CPU when there is
        # no usable GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    train_loader, test_loader = prepare_data()

    # Instantiate the CVAE encoder and decoder with matching latent and
    # condition sizes.
    encoder = AudioCVAEEncoder(
        latent_dim=LATENT_SPACE_SIZE,
        input_length=INPUT_LENGTH,
        condition_dim=CONDITION_DIM,
        use_embedding=False
    )
    decoder = AudioCVAEDecoder(
        latent_dim=LATENT_SPACE_SIZE,
        condition_dim=CONDITION_DIM,
        output_length=OUTPUT_LENGTH
    )

    # Each batch from the loaders is an (audio, genre) pair.
    # NOTE(review): the genre is presumably one-hot encoded inside TrainerCVAE
    # (use_embedding=False above) — confirm in AudioCVAETrainer.
    trainer = TrainerCVAE(
        trainloader=train_loader,
        testloader=test_loader,
        Encoder=encoder,
        Decoder=decoder,
        latent_dim=LATENT_SPACE_SIZE,
        device=device
    )

    # NOTE(review): the recorded run printed an epoch-1 train loss of ~3.9e13
    # against a test loss of ~2.5e5. That gap suggests numerical instability —
    # check input scaling and the `factor` weighting.
    trainer.train(num_epochs=num_epochs, factor=factor)

    timestamp = get_timestamp()
    encoder_name = f"audio_cvae_encoder_{timestamp}.pth"
    decoder_name = f"audio_cvae_decoder_{timestamp}.pth"

    # Only the weights are saved; reloading requires rebuilding the models
    # with the same constructor arguments.
    torch.save(encoder.state_dict(), encoder_name)
    torch.save(decoder.state_dict(), decoder_name)
    print(f"Models saved as {encoder_name} and {decoder_name}")

# Entry point: run the full training pipeline when this code is executed directly
# (as a script, or as a notebook cell, where __name__ is also "__main__").
if __name__ == "__main__":
    main()


Epoch [1/50], Train Loss: 38762054376083.9609, Test Loss: 254000.5312
