# MQGAN Audio Reconstruction Demo

**Assumptions:**
- Audio files are in a folder specified by `AUDIO_FOLDER`.
- Checkpoints for the preencoder (VQGAN) and ISTFT models exist in directories specified by `PREENCODER_MODEL_DIR` and `ISTFT_MODEL_DIR`.
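- Configuration files for spectrogram extraction and model parameters are available at the paths set by `SPEC_CONFIG_PATH` and `MODEL_CONFIG_PATH` (by default `configs/spec_config_hifimusic.yaml` and `configs/model_config_hifimusic.yaml`).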


In [None]:
import os
import glob
import torch
import torchaudio
import yaml
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display

# Import custom modules
from scripted_preencoder import ScriptedPreEncoder
from istftnetfe import ISTFTNetFE, TorchSTFT
from convert_spectrograms import TorchMelSpectrogramExtractor # Use the project's extractor

# --- Configuration ---
# Paths marked TODO are placeholders and must be edited before running the notebook.
AUDIO_FOLDER = "path/to/your/audio/files"  # TODO: Change this path
PREENCODER_MODEL_DIR = "path/to/preencoder/checkpoint"  # TODO: Change this path
ISTFT_MODEL_DIR = "path/to/istft/checkpoint"  # TODO: Change this path
SPEC_CONFIG_PATH = "configs/spec_config_hifimusic.yaml" # Use the project's spec config
MODEL_CONFIG_PATH = "configs/model_config_hifimusic.yaml" # Use the project's model config
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_FILES_TO_PROCESS = 5  # Limit for demo

print(f"Using device: {DEVICE}")

# Load configurations.
# The encoding is pinned to UTF-8 so that parsing does not depend on the
# platform's locale default. `safe_load` is used, so the YAML cannot
# construct arbitrary Python objects.
with open(SPEC_CONFIG_PATH, 'r', encoding='utf-8') as f:
    spec_config = yaml.safe_load(f)
print(f"Loaded spectrogram config from {SPEC_CONFIG_PATH}")

with open(MODEL_CONFIG_PATH, 'r', encoding='utf-8') as f:
    model_config = yaml.safe_load(f)
print(f"Loaded model config from {MODEL_CONFIG_PATH}")

In [None]:
# --- Load Models ---
# Each loader runs inside its own try block so that a failure is reported with
# a short, model-specific message before the original exception is re-raised.

try:
    preencoder = ScriptedPreEncoder(PREENCODER_MODEL_DIR, device=DEVICE)
    print(f"PreEncoder loaded. Mel channels: {preencoder.mel_channels}")
    # Verify config matches model. An explicit raise is used rather than
    # `assert` so the check still runs when Python is started with -O.
    if preencoder.mel_channels != spec_config['spectrogram']['n_mel_channels']:
        raise ValueError(
            f"Mel channels mismatch: preencoder ({preencoder.mel_channels}) vs config ({spec_config['spectrogram']['n_mel_channels']})"
        )
except Exception as e:
    print(f"Error loading PreEncoder: {e}")
    raise

try:
    # Initialize ISTFTNetFE with dummy components, then load the traced generator.
    # NOTE(review): presumably `load_ts` replaces these placeholders with the
    # checkpoint's real generator/STFT — confirm against ISTFTNetFE.
    dummy_stft = TorchSTFT(filter_length=1024, hop_length=256, win_length=1024) # Placeholder
    dummy_gen = torch.nn.Identity() # Placeholder
    istft_model = ISTFTNetFE(dummy_gen, dummy_stft)
    istft_model.load_ts(ISTFT_MODEL_DIR, in_dev=DEVICE)
    print(f"ISTFT model loaded. Sampling rate: {istft_model.sampling_rate}")
    # Verify config matches model (explicit raise for the same reason as above).
    if istft_model.sampling_rate != spec_config['spectrogram']['sampling_rate']:
        raise ValueError(
            f"Sampling rate mismatch: istft ({istft_model.sampling_rate}) vs config ({spec_config['spectrogram']['sampling_rate']})"
        )
except Exception as e:
    print(f"Error loading ISTFT model: {e}")
    raise

# Initialize the spectrogram extractor using the loaded config
mel_extractor = TorchMelSpectrogramExtractor(spec_config['spectrogram'])
mel_extractor.transf = mel_extractor.transf.to(DEVICE) # Move extractor to device if needed

In [None]:
# --- Helper Functions ---

def load_and_preprocess_audio(filepath, target_sr=spec_config['spectrogram']['sampling_rate']):
    """Read an audio file and return it as a mono waveform at ``target_sr``.

    Args:
        filepath: Path handed to ``torchaudio.load``.
        target_sr: Sample rate the returned waveform should have. The default
            is read from ``spec_config`` once, when this function is defined.

    Returns:
        A ``(waveform, target_sr)`` tuple. Dimension 0 of the loaded tensor is
        treated as the channel axis and is averaged down to a single channel
        when it holds more than one.
    """
    audio, source_sr = torchaudio.load(filepath)

    # Downmix multi-channel audio by averaging over the channel axis.
    channel_count = audio.shape[0]
    if channel_count > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Nothing more to do when the file is already at the requested rate.
    if source_sr == target_sr:
        return audio, target_sr

    to_target_rate = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)
    return to_target_rate(audio), target_sr

def plot_spectrogram(spectrogram, title="Spectrogram", ax=None):
    """Draw a spectrogram tensor as an image with time along the x-axis.

    Args:
        spectrogram: Tensor to plot. A 3-D input is treated as batched and only
            its first item is drawn. The layout is assumed to be
            (batch, time, channels) or (time, channels); the data is transposed
            before plotting so that channels run along the y-axis.
        title: Title placed above the axes.
        ax: Existing Matplotlib axes to draw on. When omitted, a new 10x4
            figure is created and shown before the function returns.
    """
    # Remember whether the figure is ours *before* `ax` is rebound; testing
    # `ax is None` after creating the axes would always be False and the
    # figure would never be shown.
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(10, 4))

    # Assuming spectrogram is (batch, time, channels) or (time, channels)
    frame = spectrogram[0] if spectrogram.ndim == 3 else spectrogram
    # detach() keeps the NumPy conversion working for tensors that track gradients.
    spec_to_plot = frame.detach().cpu().numpy()

    im = ax.imshow(spec_to_plot.T, aspect='auto', origin='lower', interpolation='none')
    ax.set_title(title)
    ax.set_xlabel("Time Frames")
    ax.set_ylabel("Frequency Bins")
    # Attach the colorbar to the figure that owns `ax`, which is not
    # necessarily pyplot's current figure when the caller supplies the axes.
    ax.figure.colorbar(im, ax=ax, format="%+2.0f dB")
    if owns_figure:
        plt.show()

def process_audio_file(filepath):
    """Run one audio file through the full reconstruction round trip.

    The file is loaded, converted to a log-mel spectrogram, encoded to token
    indices by the pre-encoder, decoded back to a spectrogram, and synthesized
    to audio by the ISTFT model. Audio players for the original and the
    reconstruction, plus both spectrogram plots, are displayed along the way.

    Args:
        filepath: Path of the audio file to process.

    Returns:
        The tuple ``(waveform, mel_spec_input, indices, reconstructed_spec,
        reconstructed_waveform)``: the loaded audio, the batched mel
        spectrogram fed to the pre-encoder, the encoded indices, the decoded
        spectrogram, and the synthesized audio moved to the CPU.
    """
    clip_name = os.path.basename(filepath)
    print(f"\n--- Processing {clip_name} ---")

    # Step 1: load the file at the sample rate the spectrogram config asks for.
    target_rate = spec_config['spectrogram']['sampling_rate']
    waveform, sr = load_and_preprocess_audio(filepath, target_rate)
    print(f"Loaded waveform shape: {waveform.shape}, Sample rate: {sr}")

    # Player for the original audio.
    original_player = Audio(waveform.numpy(), rate=sr)
    display(original_player)

    # Step 2: log-mel spectrogram from the project's extractor.
    # Assumes the extractor takes (1, T) and returns (T, n_mels) — TODO confirm.
    device_waveform = waveform.to(DEVICE)
    with torch.no_grad():
        mel_spec = mel_extractor.get_mel_from_wav(device_waveform)
    # Prepend a batch axis for the pre-encoder.
    mel_spec_input = mel_spec[None]
    print(f"Created log-mel-spectrogram shape: {mel_spec_input.shape}")

    plot_spectrogram(mel_spec_input.squeeze(0), title="Original Log-Mel Spectrogram")

    # Steps 3 and 4: spectrogram -> token indices -> spectrogram.
    with torch.no_grad():
        indices = preencoder.encode(mel_spec_input)
        print(f"Encoded indices shape: {indices.shape}")
        reconstructed_spec = preencoder.decode(indices)
    print(f"Reconstructed spectrogram shape: {reconstructed_spec.shape}")

    plot_spectrogram(reconstructed_spec.squeeze(0), title="Reconstructed Log-Mel Spectrogram")

    # Step 5: synthesize audio. The last two axes are swapped first, on the
    # assumption that the ISTFT generator wants (B, n_mels, T) — TODO confirm.
    vocoder_input = reconstructed_spec.transpose(1, 2).to(DEVICE)

    with torch.no_grad():
        # The generator returns a (spec, phase) pair; only the spec is passed on.
        # NOTE(review): the phase output is discarded here — confirm that
        # calling istft_model on spec_out alone is the intended usage.
        spec_out, _ = istft_model.gen(vocoder_input)
        synthesized = istft_model(spec_out)

    reconstructed_waveform = synthesized.cpu()
    print(f"Reconstructed waveform shape: {reconstructed_waveform.shape}")

    # Player for the reconstructed audio, at the ISTFT model's own sample rate.
    reconstructed_player = Audio(reconstructed_waveform.numpy(), rate=istft_model.sampling_rate)
    display(reconstructed_player)

    return waveform, mel_spec_input, indices, reconstructed_spec, reconstructed_waveform

In [None]:
# --- Main Execution ---

# Find audio files: collect every file in AUDIO_FOLDER whose name ends with one
# of the extensions listed in the spectrogram config.
audio_extensions = spec_config['io']['audio_extensions']
audio_files = []
for ext in audio_extensions:
    # Handle extensions with or without leading dot
    pattern_ext = ext if ext.startswith('.') else f'.{ext}'
    audio_files.extend(glob.glob(os.path.join(AUDIO_FOLDER, f'*{pattern_ext}')))

# glob returns matches in arbitrary order and overlapping extension entries
# (e.g. "wav" and ".wav") would list a file twice; de-duplicate and sort so
# the subset selected below is the same on every run.
audio_files = sorted(set(audio_files))

if not audio_files:
    print(f"No audio files found in {AUDIO_FOLDER}")
else:
    print(f"Found {len(audio_files)} audio files. Processing up to {MAX_FILES_TO_PROCESS}.")
    for filepath in audio_files[:MAX_FILES_TO_PROCESS]:
        # Best-effort demo loop: report a failing file and carry on with the rest.
        try:
            process_audio_file(filepath)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")