# common CACL utils with visual tests

In [1]:
%load_ext autoreload
%autoreload 2
## our utils
from utils.common_import import *

2.6.0+cu124


## audio 

In [None]:
def audio_wav_to_waveform(wav_path, target_sample_rate=16000):
    """Load an audio file as a mono waveform resampled to ``target_sample_rate``.

    Parameters
    ----------
    wav_path : str | os.PathLike
        Path to an audio file readable by ``torchaudio.load``.
    target_sample_rate : int, optional
        Output sample rate in Hz (default 16000, matching ``_SR`` used by the
        Mel front-end in this notebook).

    Returns
    -------
    Tensor [1, T]
        Mono waveform; multi-channel input is averaged across channels.
    """
    waveform, sample_rate = torchaudio.load(wav_path)  # shape: [channels, time]

    # Down-mix multi-channel audio by averaging, keeping the channel axis.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    return waveform

## plots

In [None]:
def plot_wave_spec_calced(test_audio_file):
    """Load an audio file, compute its log-Mel features and plot them.

    ``waveform_to_logmel`` returns a tensor of shape [1, 1, n_mels, frames]
    for a single file. It squeezes to [n_mels, frames] and is drawn with Mel
    bands on the y-axis and frames on the x-axis, for any ``n_mels``.

    Parameters
    ----------
    test_audio_file : str | os.PathLike
        Path to the audio file passed to ``audio_wav_to_waveform``.
    """
    import matplotlib.pyplot as plt

    waveform = audio_wav_to_waveform(test_audio_file)
    mel_spec = waveform_to_logmel(waveform)
    print(f"Waveform shape: {waveform.shape} , type: {type(waveform)}")

    # [1, 1, n_mels, frames] -> [n_mels, frames]; no transpose needed.
    mel_spec_2d = mel_spec.squeeze()
    if isinstance(mel_spec_2d, torch.Tensor):
        mel_spec_2d = mel_spec_2d.cpu().numpy()

    plt.figure(figsize=(10, 4))
    # Rows are Mel bands and columns are frames.
    # origin='lower' puts the lowest band at the bottom.
    plt.imshow(mel_spec_2d, aspect='auto', origin='lower', cmap='viridis')
    plt.title('Mel Spectrogram')
    plt.xlabel('Time')
    plt.ylabel('Mel Frequency Bands')
    # Values are natural-log Mel power with the per-band mean removed, not dB.
    plt.colorbar(format='%+.1f', label='log-Mel (mean-normalised)')
    plt.tight_layout()
    plt.show()


## MEL CALC

In [2]:
# ------------------------------------------------------------------
#  ReDimNet front-end settings (taken from the IDRnD repo defaults)
#    • 16 kHz audio
#    • pre-emphasis α = 0.97
#    • 25 ms window  (400 samples), zero-padded to a 512-point FFT
#    • 15 ms hop     (240 samples)  ➜ 1 + 32000 // 240 = 134 frames
#      for a 2-s clip (centered framing)
#    • 60 Mel bins, 20 Hz → 7.6 kHz
# ------------------------------------------------------------------
_PREEMPH  = 0.97
_SR       = 16_000
_N_FFT    = 512
_WIN_LEN  = 400
_HOP      = 240
_N_MELS   = 60
_F_MIN    = 20.0
_F_MAX    = 7600.0
_EPS      = 1e-6            # numerical stability


# Module-level MelSpectrogram: the filterbank is built once and reused.
_mel_layer = T.MelSpectrogram(
    sample_rate=_SR,
    n_fft=_N_FFT,
    win_length=_WIN_LEN,
    hop_length=_HOP,
    f_min=_F_MIN,
    f_max=_F_MAX,
    n_mels=_N_MELS,
    power=2.0,               # power spectrogram; the (natural) log is taken by the caller
    center=True,             # pad both sides so frame t is centered at t * hop_length
    pad_mode="reflect",
    window_fn=torch.hamming_window,
)

def _pre_emphasis(wave: torch.Tensor, alpha: float = _PREEMPH) -> torch.Tensor:
    """Apply a first-order pre-emphasis filter along the last (time) axis.

    y[n] = x[n] − α·x[n−1] for n ≥ 1; the first sample is copied unchanged.
    Works for any input rank with time last (e.g. [T] or [B, T]).
    Returns a new tensor; ``wave`` itself is not modified.
    """
    y = wave.clone()
    # Both operands come from the untouched input, so there is no aliasing
    # with the slice being written.
    y[..., 1:] = wave[..., 1:] - alpha * wave[..., :-1]
    return y


def pad_or_crop_logmel(log_mel, target_frames=200):
    """Force the last (time) axis of ``log_mel`` to exactly ``target_frames``.

    - Shorter inputs are right-padded with zeros.
    - Longer inputs are center-cropped. When the excess is odd, one more frame
      is dropped from the end than from the start.

    Parameters
    ----------
    log_mel : Tensor [..., frames]
        Features with time on the last axis, e.g. [B, n_mels, frames].
    target_frames : int
        Desired number of frames (default 200).

    Returns
    -------
    Tensor [..., target_frames]
        Leading dimensions are unchanged. A message is printed whenever
        padding or cropping happens; otherwise the input is returned as-is.
    """
    n_frames = log_mel.shape[-1]
    if n_frames < target_frames:
        # A (left, right) pad pair in F.pad applies to the last dimension.
        log_mel = F.pad(log_mel, (0, target_frames - n_frames))
        print(f"Padding log_mel from {n_frames} to {target_frames} frames")
    elif n_frames > target_frames:
        start = (n_frames - target_frames) // 2
        log_mel = log_mel[..., start:start + target_frames]
        print(f"Cropping log_mel from {n_frames} to {target_frames} frames")
    return log_mel


@torch.no_grad()
def waveform_to_logmel(
    wave: torch.Tensor,
    *,
    max_samples: int = 32_000,
    target_frames: int = 134,
) -> torch.Tensor:
    """Convert a mono 16 kHz waveform into mean-normalised log-Mel features.

    Steps, in order:
    1. Truncate to ``max_samples``.
    2. Apply pre-emphasis.
    3. Compute the Mel power spectrogram (``_mel_layer``).
    4. Take the natural log.
    5. Subtract the per-band mean over time.
    6. Pad or crop to ``target_frames``.
    7. Add a channel axis.

    Parameters
    ----------
    wave : Tensor [T] | [1, T]
        Mono waveform, assumed to already be sampled at ``_SR`` (16 kHz).
    max_samples : int, keyword-only
        Number of samples kept from the start (default 32 000 = 2 s at 16 kHz).
    target_frames : int, keyword-only
        Number of output frames (default 134 = 1 + 32000 // hop for a 2-s clip).

    Returns
    -------
    log_mel : Tensor [1, 1, n_mels, target_frames]
        Presumably the layout expected by ``model_no_mel`` — TODO confirm.

    Raises
    ------
    ValueError
        If ``wave`` is 2-D with more than one row (multi-channel/batched input).
    """
    # Add the batch axis *before* slicing time, otherwise 1-D input fails.
    if wave.dim() == 1:      # (T,) → (1, T)
        wave = wave.unsqueeze(0)
    elif wave.dim() == 2 and wave.shape[0] > 1:
        raise ValueError("Input must be mono; got multi-channel tensor.")

    wave = wave[:, :max_samples]
    print(f"Input waveform shape: {wave.shape}")

    wave = _pre_emphasis(wave.float())

    mel = _mel_layer(wave)                 # Mel power → [1, n_mels, frames]
    mel = torch.log(mel + _EPS)            # natural log; _EPS avoids log(0)

    # Remove each Mel band's mean over time.
    mel = mel - mel.mean(dim=-1, keepdim=True)

    mel = pad_or_crop_logmel(mel, target_frames=target_frames)

    # Add the channel axis expected by a Conv2d stem.
    mel = mel.unsqueeze(1)                 # → [1, 1, n_mels, frames]
    return mel

## cosine_similarity

In [None]:
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two torch embeddings as a Python float.

    Similarity is taken along the last axis. Both 1-D ``[D]`` and single-row
    ``[1, D]`` embeddings work; for ``[1, D]`` this equals the ``dim=1``
    result. Inputs that produce more than one similarity value (e.g.
    ``[B, D]`` with ``B > 1``) raise, because ``.item()`` needs exactly one
    element.
    """
    return F.cosine_similarity(embedding1, embedding2, dim=-1).item()

def cosine_similarity_numpys(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """Cosine similarity of two embeddings stored as NumPy arrays.

    Both inputs are flattened first, so shapes ``(D,)`` and ``(1, D)`` can be
    mixed freely. ``1e-8`` is added to the product of the norms, so a zero
    vector yields 0.0 rather than a division by zero.
    """
    a = emb1.reshape(-1)
    b = emb2.reshape(-1)
    denom = np.linalg.norm(a) * np.linalg.norm(b) + 1e-8
    return np.dot(a, b) / denom


## ONNX Func

In [None]:
class NHWCWrapper(nn.Module):
    """Adapter that lets a channels-first (NCHW) model accept NHWC input.

    The wrapped network is kept as ``self.model``, so its parameters appear
    under the ``model.`` prefix in ``state_dict``.
    """

    def __init__(self, model_nchw):
        super().__init__()
        self.model = model_nchw

    def forward(self, x):
        """Reorder a 4-D NHWC tensor to contiguous NCHW and run the model."""
        n_ax, h_ax, w_ax, c_ax = 0, 1, 2, 3
        nchw = x.permute(n_ax, c_ax, h_ax, w_ax)
        return self.model(nchw.contiguous())
    
    

def export_to_onnx(
    model,
    onnx_path="ReDimNet_no_mel.onnx",
    *,
    n_mels=60,
    time_frames=134,
    variable_length=False,
    opset_version=17,
):
    """Export ``model`` (log-Mel in, embedding out) to an ONNX file.

    The model is put into eval mode, which mutates the caller's model. It is
    then exported using a random dummy input of shape
    [1, 1, n_mels, time_frames].

    Parameters
    ----------
    model : torch.nn.Module
        Network that takes a [B, 1, n_mels, frames] tensor.
    onnx_path : str
        Destination file.
    n_mels, time_frames : int, keyword-only
        Dummy-input size. The defaults match the 60-bin, 134-frame front-end.
    variable_length : bool, keyword-only
        False (default): the graph is exported with a fixed input shape.
        True: the batch axis of both tensors and the time axis of ``log_mel``
        are exported as symbolic dimensions ("batch" / "time").
    opset_version : int, keyword-only
        ONNX opset to target (default 17).
    """
    model.eval()

    dummy_input = torch.randn(1, 1, n_mels, time_frames)

    extra_kwargs = {}
    if variable_length:
        # {tensor name: {axis index: symbolic name}}. The embedding length
        # stays fixed; only its batch axis varies.
        extra_kwargs["dynamic_axes"] = {
            "log_mel": {0: "batch", 2: "time"},
            "embedding": {0: "batch"},
        }

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["log_mel"],
        output_names=["embedding"],
        opset_version=opset_version,
        **extra_kwargs,
    )

    print("Exported to", onnx_path)
    



In [None]:
from onnxconverter_common.float16 import convert_float_to_float16

def restore_in_half_precision(onnx_path, output_path):
    """Load an ONNX model, cast its float32 tensors to float16, and save it.

    ``keep_io_types=True`` keeps the graph's inputs and outputs in float32,
    so callers can keep feeding float32 data.
    """
    fp16_model = convert_float_to_float16(
        onnx.load(onnx_path),
        keep_io_types=True,
    )
    onnx.save(fp16_model, output_path)
    print(f"Converted {onnx_path} to half precision and saved as {output_path}")

