# Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import time
from griffin_lim_algs import naive_griffin_lim, fast_griffin_lim, accelerated_griffin_lim
# from deep_agla import DeepAGLA
from deep_agla import DeepAGLA
from eval_metrics import evaluate_batch
from definitions import *

# Global notebook configuration.
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the
# reported device is one that tensors can actually be moved to.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Iteration count passed to the Griffin-Lim baselines further down.
N_ITER = 64
# Wall-clock duration in seconds per algorithm name; filled in by later cells.
times_algos = {}

print(f"Using device: {device}")

# Compare Model Vs. Original Algos

In [None]:
# Load the dataset array and keep its first entry as the sample that the
# cells below reconstruct.
data = np.load(DATA_PATH)
test_sample = data[0]

print(f"Data Shape: {data.shape}")
print(f"Sample Shape: {test_sample.shape}")

# Bare expression on the last line so the notebook renders the audio player.
Audio(data=test_sample, rate=22050)

## Quick Training Test

In [None]:
# Smoke-test the training loop on a small slice of the dataset.
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Subset
from dataset import AudioDataset
import os

print("Setting up quick training...")

# Restrict training to the leading samples (at most `subset_size` of them).
subset_size = 32
full_dataset = AudioDataset(npy_path=DATA_PATH)
n_selected = min(subset_size, len(full_dataset))
subset_indices = list(range(n_selected))
train_subset = Subset(full_dataset, subset_indices)

# Batches of 8, reshuffled every epoch, loaded in the main process.
train_loader = DataLoader(dataset=train_subset, batch_size=8, shuffle=True, num_workers=0)

# Weights handed to the model for each loss term.
quick_loss_weights = {
    'time_l1': 0.9,
    'time_mse': 0.9,
    'spec_l1': 0.1,
    'log_spec_l1': 0.1,
}
model = DeepAGLA(n_layers=64, lr=1e-3, loss_weights=quick_loss_weights)

# Single-GPU trainer: no logger, checkpoints written under
# ./quick_training_test, progress bar shown.
trainer_options = dict(
    max_epochs=10,
    accelerator='gpu',
    devices=1,
    logger=False,
    enable_checkpointing=True,
    default_root_dir='./quick_training_test',
    enable_progress_bar=True,
)
trainer = pl.Trainer(**trainer_options)

print(f"Starting quick training with {len(train_subset)} samples for 10 epochs...")
trainer.fit(model, train_loader)

# Persist the trained weights; the inference cell reloads them via `save_path`.
save_path = './quick_trained_model.ckpt'
trainer.save_checkpoint(save_path)
print(f"Quick training completed! Model saved to: {save_path}")
print(f"Final training loss: {trainer.callback_metrics}")

## Model Inference

In [13]:
# Reload the checkpoint written by the quick-training cell onto `device`.
CHECKPOINT_PATH = save_path
model = DeepAGLA.load_from_checkpoint(CHECKPOINT_PATH, map_location=device)

# Switch the model to its inference configuration, requesting all layers.
model.set_inference_mode(use_all_layers=True)

In [None]:
# Freshly constructed (untrained) model, built with the same hyperparameters
# as the quick-training cell, used as the reference for parameter drift.
model_test = DeepAGLA(
    n_layers=64,
    lr=1e-3,
    loss_weights={
        'time_l1': 0.9,
        'time_mse': 0.9,
        'spec_l1': 0.1,
        'log_spec_l1': 0.1,
    },
)

# Per-layer difference (trained minus fresh) for each scalar layer parameter.
param_names = ('alpha', 'beta', 'gamma')
param_diffs = {name: [] for name in param_names}
for layer_trained, layer_test in zip(model.layers, model_test.layers):
    for name in param_names:
        trained_value = getattr(layer_trained, name).item()
        fresh_value = getattr(layer_test, name).item()
        param_diffs[name].append(trained_value - fresh_value)

# Keep the individual lists available under their original names.
alpha_diffs = param_diffs['alpha']
beta_diffs = param_diffs['beta']
gamma_diffs = param_diffs['gamma']

# One subplot per parameter, laid out side by side.
plt.figure(figsize=(14, 4))
for position, name in enumerate(param_names, start=1):
    display_name = name.capitalize()
    diffs = param_diffs[name]
    plt.subplot(1, len(param_names), position)
    plt.plot(range(len(diffs)), diffs, marker='o', label=f'{display_name} Diff')
    plt.title(f'{display_name} Differences Across Layers')
    plt.xlabel('Layer Index')
    plt.ylabel(f'{display_name} Diff')
    plt.legend()
plt.tight_layout()
plt.show()

In [15]:
# Time a single inference pass of the trained model on the test sample.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# durations; time.time() can jump if the system clock is adjusted.
# CUDA work is queued asynchronously, so synchronize before reading the clock
# on either side of the call; otherwise unfinished GPU work can fall outside
# the measured interval.
timing_on_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()

if timing_on_cuda:
    torch.cuda.synchronize()
start_time_model = time.perf_counter()
recon_model, losses = model.inference_forward(torch.tensor(test_sample).to(device).unsqueeze(0))
if timing_on_cuda:
    torch.cuda.synchronize()
end_time_model = time.perf_counter()

times_algos['Model'] = end_time_model - start_time_model

## Test Original GLAs

In [None]:
# Run the three reference Griffin-Lim variants on the test sample and time
# each one. perf_counter() is used because it is a monotonic clock meant for
# measuring durations.
start_time_naive = time.perf_counter()
recon_naive, metrics_naive, losses_naive = naive_griffin_lim(test_sample, n_iter=N_ITER)
end_time_naive = time.perf_counter()
times_algos['Naive'] = end_time_naive - start_time_naive

start_time_fast = time.perf_counter()
recon_fast, metrics_fast, losses_fast = fast_griffin_lim(test_sample, n_iter=N_ITER)
end_time_fast = time.perf_counter()
times_algos['Fast'] = end_time_fast - start_time_fast

start_time_accel = time.perf_counter()
recon_accel, metrics_accel, losses_accel = accelerated_griffin_lim(test_sample, n_iter=N_ITER)
end_time_accel = time.perf_counter()
times_algos['Accelerated'] = end_time_accel - start_time_accel

# Convert the model's loss dict (one sequence per loss name) into the same
# layout the Griffin-Lim functions return: one dict per iteration carrying an
# 'Iteration' number plus the display-named losses.
model_loss_keys = {
    'L1 Waveform': 'time_l1',
    'MSE Waveform': 'time_mse',
    'L1 Spectral': 'spec_l1',
    'Log L1 Spectral': 'log_spec_l1',
    'Total Loss': 'total',
}
# Never index past the shortest recorded loss sequence, and never report more
# than N_ITER steps, so the model curve lines up with the baselines.
n_model_steps = min([N_ITER] + [len(losses[key]) for key in model_loss_keys.values()])
model_losses_matched = [
    {
        'Iteration': step + 1,
        **{display_name: losses[key][step] for display_name, key in model_loss_keys.items()},
    }
    for step in range(n_model_steps)
]

# (legend prefix, per-iteration loss records, extra plot kwargs) per curve.
# The model curve is dashed to set it apart from the baselines.
loss_histories = (
    ('Naive', losses_naive, {}),
    ('Fast', losses_fast, {}),
    ('Accel', losses_accel, {}),
    ('Model', model_losses_matched, {'linestyle': '--'}),
)

# (loss key, whether to draw the y-axis on a log scale) per figure.
loss_figures = (
    ('L1 Waveform', False),
    ('MSE Waveform', False),
    ('L1 Spectral', True),
    ('Log L1 Spectral', True),
    ('Total Loss', False),
)

for loss_name, use_log_scale in loss_figures:
    # 'Total Loss' already ends in "Loss"; the other keys get it appended.
    axis_label = loss_name if loss_name.endswith('Loss') else f'{loss_name} Loss'

    plt.figure(figsize=(12, 6))
    for algo_label, history, plot_kwargs in loss_histories:
        plt.plot(
            [entry['Iteration'] for entry in history],
            [entry[loss_name] for entry in history],
            label=f'{algo_label} {loss_name}',
            **plot_kwargs,
        )
    plt.xlabel('Iteration')
    plt.ylabel(axis_label)
    plt.title(f'{axis_label} Over Iterations')
    if use_log_scale:
        plt.yscale('log')
    plt.legend()
    plt.grid()
    plt.show()

# Bar plot of times for each algorithm.
# One colour per bar: a list shorter than the number of bars is cycled by
# matplotlib, which would give the last bar the same colour as the first.
time_bar_colors = ['blue', 'orange', 'green', 'red']
plt.figure(figsize=(8, 5))
plt.bar(
    list(times_algos.keys()),
    list(times_algos.values()),
    color=time_bar_colors[:len(times_algos)],
)
plt.xlabel('Algorithm')
plt.ylabel('Time (seconds)')
plt.title('Time Taken by Each Griffin-Lim Algorithm')
plt.show()

# Metrics

In [17]:
# Metrics Evaluation: compute for the model output the same metrics the
# Griffin-Lim functions returned for their own reconstructions.

from griffin_lim_algs import compute_all_metrics, match_signals

# Bring the reconstruction back to a NumPy array on the host.
# detach() first: Tensor.numpy() raises if the tensor still requires grad.
model_recon = recon_model.detach().squeeze().cpu().numpy()

# Properly align the signals before computing metrics.
aligned_original, aligned_model_recon = match_signals(test_sample, model_recon)
model_metrics = compute_all_metrics(aligned_original, aligned_model_recon)

In [None]:
# Bar charts comparing the model against the three Griffin-Lim baselines,
# one figure per metric, with the best performer outlined in gold.

metrics = model_metrics.keys()

# Define which metrics are "higher is better" vs "lower is better".
higher_is_better = {'SNR (dB)', 'SSNR (dB)', 'PSNR (dB)', 'SER (dB)', 'SISDR (dB)', 'SISNR (dB)', 'STOI'}
lower_is_better = {'MSE', 'MAE', 'LSD (dB)', 'Spectral Convergence', 'THD (%)'}

# Labels are listed explicitly, in the same order as `values` below, so the
# bars cannot be mislabelled if the timing cells were run in another order.
algorithm_names = ['Model', 'Naive', 'Fast', 'Accelerated']
colors = ['blue', 'orange', 'green', 'red']

# Number formatting per metric; anything not listed uses two decimals.
value_formats = {
    'MSE': '{:.4f}',
    'MAE': '{:.4f}',
    'Spectral Convergence': '{:.4f}',
    'STOI': '{:.3f}',
    'THD (%)': '{:.2f}%',
}

for metric in metrics:
    plt.figure(figsize=(10, 6))

    # Model value, then the last recorded value of each baseline.
    values = [
        model_metrics[metric],
        metrics_naive[metric][-1],
        metrics_fast[metric][-1],
        metrics_accel[metric][-1],
    ]

    # Lower wins only for metrics listed as such; unknown metrics default to
    # "higher is better".
    if metric in lower_is_better and metric not in higher_is_better:
        best_idx = int(np.argmin(values))
    else:
        best_idx = int(np.argmax(values))

    # Highlight the best bar with a thicker gold outline, others in black.
    edge_colors = ['black'] * len(values)
    edge_colors[best_idx] = 'gold'
    edge_widths = [1] * len(values)
    edge_widths[best_idx] = 3

    bars = plt.bar(algorithm_names, values, color=colors, edgecolor=edge_colors, linewidth=edge_widths)

    # Put each value just beyond the end of its bar: above positive bars,
    # below negative ones, so the text never overlaps the bar.
    value_format = value_formats.get(metric, '{:.2f}')
    for bar, value in zip(bars, values):
        height = bar.get_height()
        offset = abs(height) * 0.01
        if height >= 0:
            label_y, label_va = height + offset, 'bottom'
        else:
            label_y, label_va = height - offset, 'top'
        plt.text(bar.get_x() + bar.get_width() / 2., label_y,
                 value_format.format(value), ha='center', va=label_va,
                 fontsize=10, fontweight='bold')

    plt.xlabel('Algorithm')
    plt.ylabel(metric)
    plt.title(f'{metric} Comparison')

    # Add a subtle grid.
    plt.grid(True, alpha=0.3)

    # Leave headroom for the value labels; also pad the bottom when any bar
    # is negative, since those labels sit below the bar.
    y_min, y_max = plt.ylim()
    y_range = y_max - y_min
    padded_min = y_min - y_range * 0.1 if min(values) < 0 else y_min
    plt.ylim(padded_min, y_max + y_range * 0.1)

    plt.show()