In [4]:
from layers.autoencoder.vae import AudioAutoencoder,OobleckDecoder,OobleckEncoder
import os
import torch



# Build the autoencoder, load the trained weights and prepare it for inference.
# Previous checkpoint, kept for reference:
# model_state = torch.load("../models/Nov29.pth")
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location loads the tensors straight onto DEVICE, so a checkpoint saved on a GPU
# still loads on a CPU-only machine.
model_state = torch.load("../train_log/best_model.pth", map_location=DEVICE)
model = AudioAutoencoder(sample_rate=16000, downsampling_ratio=2048).to(DEVICE)
model.load_state_dict(model_state)
# Inference only: switch any dropout / batch-norm layers to evaluation behaviour.
model.eval()
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal Parameters: {total_params:,}")


In [None]:
import torchaudio
import torchaudio.transforms as T

def load_audio(file_path, target_sample_rate=16000):
    """
    Loads an audio file, downmixes it to mono and resamples it to the target sample rate.

    Args:
        file_path (str): Path to the audio file.
        target_sample_rate (int): Desired sample rate. Defaults to 16kHz.

    Returns:
        torch.Tensor: Mono waveform of shape (1, num_samples). The leading channel
        dimension is always kept, whether the source file is mono or multichannel.
    """
    # torchaudio.load returns (channels, time) by default.
    waveform, sample_rate = torchaudio.load(file_path)

    # Downmix multichannel audio by averaging the channels. keepdim=True keeps the
    # result (1, time), the same layout a mono file already has.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample along the last (time) axis if necessary.
    if sample_rate != target_sample_rate:
        resample = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resample(waveform)

    return waveform
import torchaudio

def save_audio(file_path, waveform, sample_rate=16000):
    """
    Saves a waveform as an audio file.

    Args:
        file_path (str): Path to save the audio file (e.g., "output.wav").
        waveform (torch.Tensor): Audio waveform. Accepted layouts:
            - 1D (time,): treated as mono.
            - 2D (channels, time).
            - Higher rank with leading singleton dims, e.g. (1, channels, time)
              from a batched model output.
            May live on any device and may carry autograd history.
        sample_rate (int): Sample rate of the audio. Defaults to 16kHz.

    Returns:
        None

    Raises:
        ValueError: If the waveform cannot be reduced to a (channels, time) tensor.
    """
    # Writing needs a plain CPU tensor: drop the autograd graph and leave the GPU.
    waveform = waveform.detach().cpu()

    # Strip leading singleton (batch) dims, e.g. (1, 1, time) -> (1, time).
    while waveform.dim() > 2 and waveform.size(0) == 1:
        waveform = waveform.squeeze(0)

    # Ensure waveform is 2D (channel, time), even if mono.
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if waveform.dim() != 2:
        raise ValueError(
            f"Expected a (channels, time) waveform, got shape {tuple(waveform.shape)}"
        )

    # Save the audio
    torchaudio.save(file_path, waveform, sample_rate)

import torch
from layers.tools.losses import MultiResolutionSTFTLoss

def _infer_device(model):
    """Return the device of the model's first parameter, or a CUDA/CPU default if it has none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def process_audio_in_chunks(file_path, model, chunk_size=16384*2, target_sample_rate=16000):
    """
    Reconstructs an audio file by passing non-overlapping chunks through `model`.

    Each chunk is moved to the model's device on its own, so the full waveform never
    has to fit through the model at once. The last chunk is zero-padded up to
    `chunk_size`, so every chunk the model sees has the same length. Its output is
    then trimmed back to the real number of samples. As a result, the tail of the
    file is no longer dropped.

    Args:
        file_path (str): Path to the input audio file.
        model: Preloaded model with `encode` and `decode` methods. Each chunk is fed
            as a (1, 1, chunk_size) tensor.
        chunk_size (int): Size of each chunk in samples. Must be positive.
        target_sample_rate (int): Target sample rate for processing.

    Returns:
        torch.Tensor: The reconstructed audio, concatenated along the last (time) axis.
        It stays on the model's device.

    Raises:
        ValueError: If `chunk_size` is not positive or the audio file has no samples.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Follow the model's device instead of assuming cuda:0, so inputs and weights always match.
    device = _infer_device(model)

    # Load on CPU and flatten to 1D; the input is downmixed to mono by load_audio.
    waveform = load_audio(file_path, target_sample_rate).reshape(-1)
    total_len = waveform.numel()
    if total_len == 0:
        raise ValueError(f"No audio samples found in {file_path!r}")

    output_chunks = []
    with torch.no_grad():
        for start in range(0, total_len, chunk_size):
            chunk = waveform[start:start + chunk_size]
            valid_len = chunk.numel()
            if valid_len < chunk_size:
                # Zero-pad the short tail chunk rather than skipping it.
                chunk = torch.nn.functional.pad(chunk, (0, chunk_size - valid_len))
            chunk = chunk.to(device).view(1, 1, -1)  # (1, 1, chunk_size)
            decoded = model.decode(model.encode(chunk))
            # Drop the batch dim and any output that corresponds to padding.
            output_chunks.append(decoded.squeeze(0)[..., :valid_len])

    # Concatenate chunks directly on the model's device.
    return torch.cat(output_chunks, dim=-1)

# Main processing: reconstruct one demo file, save it, and score it against the input.
input_path = "../demo_dataset/no14/0/audio0.mp3"
reconstructed_audio = process_audio_in_chunks(input_path, model)

# Save reconstructed audio. save_audio already adds a channel dim to 1D input, so the
# tensor is passed as-is. An extra unsqueeze(0) would produce a 3D tensor whenever the
# model output already has a channel dim.
save_audio("output3.wav", reconstructed_audio.detach().cpu())

# Compute and print STFT loss
f = MultiResolutionSTFTLoss()
# Flatten both signals to 1D so their layouts agree whether or not they carry a channel
# dim. This assumes a single-channel reconstruction, matching the (1, 1, time) model
# input.
original_flat = load_audio(input_path).reshape(-1).to(reconstructed_audio.device)
output_flat = reconstructed_audio.reshape(-1)
# Chunked reconstruction may not return exactly len(input) samples; compare the overlap.
common_len = min(original_flat.numel(), output_flat.numel())
# NOTE(review): assumes MultiResolutionSTFTLoss takes (batch, channels, time) — confirm
# against layers.tools.losses.
original_audio = original_flat[:common_len].view(1, 1, -1)
output_audio = output_flat[:common_len].view(1, 1, -1)
error = f(original_audio, output_audio)
print(f"STFT Loss: {error}")
