In [None]:
from SlicingDataLoading import DataLoad
import torchaudio
# Smoke test: check that the dataset at data_path loads via DataLoad.fetch_dataset().
# X, y are the two values it returns — presumably (audio clips, per-clip labels);
# TODO confirm against SlicingDataLoading.
# NOTE(review): this holds the full dataset in memory, and the training cell's
# prepare_data() loads it again independently — `del X, y` before training if
# memory is tight.
data_path = "./Data/genres_original"
data_loader = DataLoad(data_path)
X, y = data_loader.fetch_dataset()

In [None]:
import torch
from audio_resnet_encoder_wavenet_decoder import Encoder
# Smoke test: push one random single-channel clip of 90000 samples through the
# encoder and report the output shapes.
# Gradients are not needed here. Without no_grad, the global outputs below would
# keep the whole autograd graph (every saved activation) alive in the kernel.
encoder = Encoder(latent_dim=1024)
sample_input = torch.randn(1, 1, 90000)
with torch.no_grad():
    mean, logvar, latent_encoding = encoder(sample_input)
print("Mean Shape:", mean.shape)
print("Log Variance Shape:", logvar.shape)
print("Latent Encoding Shape:", latent_encoding.shape)

In [None]:
import torch
from audio_resnet_encoder_wavenet_decoder import Decoder
# Smoke test: decode one random latent vector (decoder configured with
# output_length=90000) and report the output shape.
# Gradients are not needed here. Without no_grad, the global `output_audio`
# would keep the decoder's whole autograd graph alive in the kernel.
decoder = Decoder(latent_dim=1024, output_length=90000)
latent_vector = torch.randn(1, 1024)
with torch.no_grad():
    output_audio = decoder(latent_vector)
print(f"Decoder output shape: {output_audio.shape}")

In [None]:
class Trainer:
    """Train/evaluate loop for the audio CVAE, with per-epoch checkpointing.

    ``model(batch)`` must return ``(recon_batch, mu, logvar)``, and the model
    must expose ``encoder`` and ``decoder`` sub-modules (their state dicts are
    saved after every epoch).

    Args:
        trainloader: Iterable of training batches (tensors); must support ``len``.
        testloader: Iterable of validation batches (tensors); must support ``len``.
        model: The CVAE module; moved to ``device`` on construction.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that the model and every batch are moved to.
        latent_dim: Latent size; stored for reference, not used by the loop.
        sample_rate: Sample rate given to the mel-spectrogram loss term
            (defaults to 16000, the previously hard-coded value).
    """

    # n_fft values of the multi-resolution STFT magnitude loss.
    STFT_RESOLUTIONS = (512, 1024, 2048)

    def __init__(self, trainloader, testloader, model, optimizer, device, latent_dim, sample_rate=16000):
        self.trainloader = trainloader
        self.testloader = testloader
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.latent_dim = latent_dim
        self.sample_rate = sample_rate
        # Built lazily by _mel_transform so torchaudio is only needed once a
        # loss is actually computed.
        self._mel_spec = None

        self.model.to(self.device)

    def _mel_transform(self, device):
        """Return the cached MelSpectrogram transform, placed on ``device``.

        The transform's window and filterbank are module buffers. They must
        live on the same device as the audio, otherwise CUDA batches fail.
        Building the transform once also avoids recomputing the filterbank on
        every batch.
        """
        mel_spec = getattr(self, "_mel_spec", None)
        if mel_spec is None:
            import torchaudio  # lazy: only the loss needs it

            mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=getattr(self, "sample_rate", 16000),
                n_fft=1024,
                hop_length=256,
                n_mels=80,
            )
        # Module.to moves buffers in place; it is a no-op when already there.
        self._mel_spec = mel_spec.to(device)
        return self._mel_spec

    @staticmethod
    def _stft_mag_loss(x, recon_x, n_fft):
        """L1 distance between the STFT magnitudes of ``x`` and ``recon_x``.

        Uses hop ``n_fft // 4`` and an explicit Hann window. Without a window,
        ``torch.stft`` applies a rectangular one, which leaks spectral energy
        (and emits a warning).
        """
        hop = n_fft // 4
        window = torch.hann_window(n_fft, device=x.device, dtype=x.dtype)
        stft_x = torch.stft(x, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
        stft_recon = torch.stft(recon_x, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
        return torch.nn.functional.l1_loss(torch.abs(stft_x), torch.abs(stft_recon))

    def loss_function(self, recon_x, x, mu, logvar):
        """CVAE loss: multi-resolution STFT + mel-spectrogram + KL terms.

        Args:
            recon_x: Model reconstruction; dim 1 is squeezed before use.
            x: Input batch; dim 1 is squeezed before use.
            mu: Latent means.
            logvar: Latent log-variances.

        Returns:
            Scalar tensor
            ``10 * sum_{n_fft in 512,1024,2048} STFT-mag L1 + 5 * mel L1 + KL``.
        """
        # NOTE(review): assumes (batch, 1, time) inputs, so squeezing dim 1
        # yields the (batch, time) layout torch.stft expects.
        x_squeezed = x.squeeze(1)
        recon_x_squeezed = recon_x.squeeze(1)

        mag_loss = sum(
            self._stft_mag_loss(x_squeezed, recon_x_squeezed, n_fft)
            for n_fft in self.STFT_RESOLUTIONS
        ) * 10

        mel_spec = self._mel_transform(x_squeezed.device)
        mel_loss = torch.nn.functional.l1_loss(mel_spec(x_squeezed), mel_spec(recon_x_squeezed)) * 5

        # KL(q(z|x) || N(0, I)), averaged over every element (batch AND latent
        # dims). Compared with summing over latent dims, this scales the term
        # down by the latent size.
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return mag_loss + kl_loss + mel_loss

    def train_one_epoch(self):
        """Run one optimisation pass over ``trainloader``.

        Returns:
            Mean of the per-batch loss values (float).
        """
        self.model.train()
        running_loss = 0
        for batch in self.trainloader:
            batch = batch.to(self.device)

            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(batch)
            loss = self.loss_function(recon_batch, batch, mu, logvar)

            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

        return running_loss / len(self.trainloader)

    def evaluate(self):
        """Compute the loss over ``testloader`` without gradient tracking.

        Returns:
            Mean of the per-batch loss values (float).
        """
        self.model.eval()
        running_loss = 0
        with torch.no_grad():
            for batch in self.testloader:
                batch = batch.to(self.device)
                recon_batch, mu, logvar = self.model(batch)
                loss = self.loss_function(recon_batch, batch, mu, logvar)
                running_loss += loss.item()
        return running_loss / len(self.testloader)

    def train(self, num_epochs=30, factor=10):
        """Train for ``num_epochs`` epochs, checkpointing after each one.

        After every epoch, the encoder and decoder state dicts are saved to the
        working directory under epoch- and timestamp-stamped ``.pth`` names.

        Args:
            num_epochs: Number of epochs to run.
            factor: Unused; kept so existing callers (``factor=10``) still work.
        """
        # Local import so the class does not depend on a helper defined in a
        # later cell; the format matches get_timestamp().
        from datetime import datetime

        for epoch in range(1, num_epochs + 1):
            train_loss = self.train_one_epoch()
            test_loss = self.evaluate()
            print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            encoder_name = f"epoch_{epoch}_audio_cvae_encoder_{timestamp}_with_wavenet.pth"
            decoder_name = f"epoch_{epoch}_audio_cvae_decoder_{timestamp}_with_wavenet.pth"
            torch.save(self.model.encoder.state_dict(), encoder_name)
            torch.save(self.model.decoder.state_dict(), decoder_name)
            print(f"Models saved as {encoder_name} and {decoder_name}")
            

In [None]:
# This line runs before the import block below, so bring torch into scope here;
# otherwise running this cell on a fresh kernel raises NameError.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from SlicingDataLoading import DataLoad
from audio_resnet_encoder_wavenet_decoder import CVAE
from datetime import datetime
import torchaudio

def prepare_data(dataset_path="./Data/genres_original", batch_size=8, test_size=0.15, random_state=365):
    """Load the dataset and return train/validation DataLoaders over the audio.

    Args:
        dataset_path: Root folder passed to ``DataLoad``.
        batch_size: Batch size for both loaders.
        test_size: Fraction of clips held out for validation.
        random_state: Seed for the train/validation split.

    Returns:
        ``(train_loader, test_loader)``. Each batch is a float32 tensor with a
        size-1 dimension inserted at axis 1. The training loader shuffles; the
        validation loader does not.
    """
    dataload = DataLoad(dataset_path)
    # The labels returned alongside the audio are not used by this pipeline,
    # so they are not converted to a tensor.
    all_audios, _ = dataload.fetch_dataset()

    # Split the NumPy array (a documented input type for train_test_split).
    # The drawn indices depend only on len(all_audios) and random_state.
    train_audios, val_audios = train_test_split(all_audios, test_size=test_size, random_state=random_state)

    # (N, ...) -> (N, 1, ...): add the single channel dimension.
    X_train = torch.from_numpy(train_audios).float().unsqueeze(1)
    X_val = torch.from_numpy(val_audios).float().unsqueeze(1)

    train_loader = DataLoader(dataset=X_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=X_val, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

def get_timestamp():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``.

    Used to stamp checkpoint filenames.
    """
    stamp_format = "%Y%m%d_%H%M%S"
    now = datetime.now()
    return now.strftime(stamp_format)

def main(seed=365):
    """Train the audio CVAE and save the final encoder/decoder weights.

    Args:
        seed: Seed for torch's random generators, covering weight init,
            DataLoader shuffling and any torch sampling inside the model. This
            makes runs reproducible, though it does not force deterministic
            CUDA kernels. Pass ``None`` to skip seeding. The train/validation
            split is seeded separately by ``prepare_data``.
    """
    LATENT_SPACE_SIZE = 1024

    if seed is not None:
        torch.manual_seed(seed)

    train_loader, test_loader = prepare_data()

    model = CVAE(LATENT_SPACE_SIZE).to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0005)

    trainer = Trainer(
        trainloader=train_loader,
        testloader=test_loader,
        model=model,
        optimizer=optimizer,
        device=device,
        latent_dim=LATENT_SPACE_SIZE
    )

    # Trainer.train also checkpoints after every epoch. This final save stores
    # the end-of-training weights under a name without the epoch prefix.
    trainer.train(num_epochs=6, factor=10)
    timestamp = get_timestamp()

    encoder_name = f"audio_cvae_encoder_{timestamp}_with_wavenet.pth"
    decoder_name = f"audio_cvae_decoder_{timestamp}_with_wavenet.pth"

    torch.save(model.encoder.state_dict(), encoder_name)
    torch.save(model.decoder.state_dict(), decoder_name)

    print(f"Models saved as {encoder_name} and {decoder_name}")

# In a notebook kernel __name__ is "__main__", so executing this cell trains.
if __name__ == "__main__":
    main()