# ReDimNet TO ReDimNetNoMel 

In [1]:
import torch
print(torch.__version__)

import torchaudio
import torchaudio.transforms as T

import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import copy



2.6.0+cu124


In [2]:
model_name='B0'
# train_type='ft_lm'
train_type='ptn'
dataset='vox2'

torch.hub.set_dir('/data/deep/redimnet/models')

original_model = torch.hub.load('IDRnD/ReDimNet', 'ReDimNet', 
                       model_name=model_name, 
                       train_type=train_type, 
                       dataset=dataset)

Using cache found in /data/deep/redimnet/models/IDRnD_ReDimNet_master


/data/deep/redimnet/models/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


In [3]:
from torchinfo import summary
summary(original_model, input_size=(1, 32000))

  with torch.cuda.amp.autocast(enabled=False):


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 60, 134]              --
│    └─Sequential: 2-1                                       [1, 60, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 60, 134]              --
├─ReDimNet: 1-2                                              [1, 600, 134]             --
│    └─Sequential: 2-2                                       [1, 600, 134]             --
│    │    └─Conv2d: 3-4                                      [1, 10, 60, 134]          100
│    │    └─LayerNorm: 3-5                                   [1, 10, 60, 134]          20
│   

we can see MelSpectrogram inside the model ; lets take it outside the model;


In [4]:
for name, module in original_model.named_children():
    print(name, "=>", module)

backbone => ReDimNet(
  (stem): Sequential(
    (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
    (2): to1d()
  )
  (stage0): Sequential(
    (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
    (1): to2d(f=60,c=10)
    (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
    (3): ConvBlock2d(
      (conv_block): ResBasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      

In [5]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    A wrapper around the original ReDimNetWrap that:
      - Excludes the 'spec' (MelBanks) module
      - Uses 'backbone', 'pool', 'bn', and 'linear'
    We expect a precomputed mel spectrogram as input with shape [B, 1, n_mels, time_frames].
    """
    def __init__(self, original_wrap):
        super().__init__()
        # Grab references to the submodules we want to keep
        self.backbone = original_wrap.backbone
        self.pool = original_wrap.pool
        self.bn = original_wrap.bn
        self.linear = original_wrap.linear

    def forward(self, x):
        # x: shape [B, 1, n_mels, time_frames]
        # (1) Pass through the backbone
        x = self.backbone(x)    # shape might become [B, channels, frames] or similar
        # (2) Pooling
        x = self.pool(x)        # ASTP => shape likely [B, embedding_dim]
        # (3) BatchNorm
        x = self.bn(x)
        # (4) Final linear => 192-dim (if that's your embedding size)
        x = self.linear(x)
        return x


# Create an instance of our new model that skips the MelBanks front-end
model_no_mel = ReDimNetNoMel(original_model)
model_no_mel.eval()



ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_st

## Utility Function for WAV -> MelSpectrogram

In [6]:
def pad_or_crop_logmel(log_mel, target_frames=200):
    """
    Ensures log_mel is shaped [1, n_mels, target_frames] by:
    - Padding with zeros on the right if too short
    - Center-cropping if too long
    """
    B, M, T = log_mel.shape
    if T < target_frames:
        pad_amt = target_frames - T
        log_mel = F.pad(log_mel, (0, pad_amt))  # pad at end
        print(f"Padding log_mel from {T} to {target_frames} frames")
    elif T > target_frames:
        start = (T - target_frames) // 2
        log_mel = log_mel[:, :, start:start + target_frames]
        print(f"Cropping log_mel from {T} to {target_frames} frames")
    return log_mel


# my orig
# def waveform_to_logmel(
#     waveform: torch.Tensor,
#     sample_rate: int = 16000,
#     n_fft: int = 512,
#     hop_length: int = 160,
#     n_mels: int = 60,       ## 72 for vox2 ;  60 for B0
#     f_min: float = 20.0,
#     f_max: float = 8000.0,
#     preemphasis_alpha: float = 0.97,
#     target_frames=200
# ):
#     """
#     Reproduces the main logic of 'NormalizeAudio', 'PreEmphasis',
#     and 'MelSpectrogram' from the 'MelBanks' layer.
#     """

#     # 1) NormalizeAudio
#     waveform = waveform / (waveform.abs().max() + 1e-8)

#     # 2) PreEmphasis
#     shifted = torch.roll(waveform, shifts=1, dims=1)
#     waveform_preemph = waveform - preemphasis_alpha * shifted
#     # fix first sample
#     waveform_preemph[:, 0] = waveform[:, 0]

#     # 3) MelSpectrogram
#     ## todo: check if log not twice;
#     mel_transform = torchaudio.transforms.MelSpectrogram(
#         sample_rate=sample_rate,
#         n_fft=n_fft,
#         hop_length=hop_length,
#         n_mels=n_mels,
#         f_min=f_min,
#         f_max=f_max,
#         power=2.0,
#         center=False
#     )
#     mel_spec = mel_transform(waveform_preemph)  # shape: [channel=1, n_mels, time_frames]

#     # Log scale
#     log_mel = torch.log(torch.clamp(mel_spec, min=1e-8, max=1e8))

    
#     # 5) Pad/crop to fixed number of frames
#     log_mel = pad_or_crop_logmel(log_mel, target_frames=target_frames)
    
#     # 6) Standardize
#     # mean = log_mel.mean()
#     # std = log_mel.std()
#     # log_mel = (log_mel - mean) / (std + 1e-8)
#     # print(f"log_mel  mean={mean:.4f}  std={std:.4f}  min={log_mel.min():.4f}  max={log_mel.max():.4f}")
#     log_mel = (log_mel - 0.0) / 1.0

#     print("Log-mel shape:", log_mel.shape)
#     return log_mel


# chat1
# def waveform_to_logmel(
#         waveform: torch.Tensor,
#         sample_rate: int = 16000,
#         *,
#         # ---------- MelBanks hyper-parameters ----------
#         n_fft: int = 512,
#         win_length: int = 400,
#         hop_length: int = 240,
#         n_mels: int = 60,
#         f_min: float = 20.0,
#         f_max: float = 7600.0,
#         # ---------- Pre-processing switches ----------
#         do_preemph: bool = True,
#         preemph: float = 0.97,
#         norm_signal: bool = False,
#         # ---------- Spec-norm ----------
#         top_db: float = 80.0,
#         target_frames: int = 200,
#         eps: float = 1e-8):
#     """
#     Re-implements MelBanks exactly as used in the ReDimNet checkpoints whose
#     config string includes:

#         {'sample_rate': 16000, 'n_fft': 512, 'win_length': 400,
#          'hop_length': 240, 'f_min': 20, 'f_max': 7600,
#          'n_mels': 60, 'norm_signal': False, 'do_preemph': True}
#     The output shape is [1, n_mels, target_frames].
#     """

#     # 0) (Optional) signal-level normalisation
#     if norm_signal:
#         waveform = waveform / (waveform.abs().max() + eps)

#     # 1) Pre-emphasis
#     if do_preemph:
#         waveform[:, 1:] -= preemph * waveform[:, :-1]

#     # 2) Mel power spectrogram  →  dB
#     mel = torchaudio.transforms.MelSpectrogram(
#               sample_rate=sample_rate,
#               n_fft=n_fft,
#               win_length=win_length,
#               hop_length=hop_length,
#               n_mels=n_mels,
#               f_min=f_min,
#               f_max=f_max,
#               power=2.0,            # **power** spectrum!
#               center=True)(waveform)

#     logmel = torchaudio.transforms.AmplitudeToDB(top_db=top_db)(mel)

#     # 3) Pad / centre-crop to exactly `target_frames`
#     _, _, T = logmel.shape
#     if T < target_frames:
#         logmel = F.pad(logmel, (0, target_frames - T))
#     elif T > target_frames:
#         start = (T - target_frames) // 2
#         logmel = logmel[:, :, start:start + target_frames]

#     # 4) spec_norm – per-mel-band CMVN over the **time** axis
#     mean = logmel.mean(dim=-1, keepdim=True)
#     std  = logmel.std (dim=-1, keepdim=True)
#     logmel = (logmel - mean) / (std + eps)

#     return logmel        # shape [1, n_mels, target_frames]

# pas1
def waveform_to_logmel(waveform: torch.Tensor, sample_rate: int):
    """Convert [1, T] waveform to log-mel [1, 1, n_mels, frames] using PyTorch's MelBanks logic."""
    # Pre-emphasis
    waveform = waveform.float()
    waveform = torch.cat([waveform[:, :1], waveform[:, 1:] - 0.97 * waveform[:, :-1]], dim=1)

    # MelSpectrogram matches internal settings
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=512,
        win_length=400,
        hop_length=160,
        f_min=20.0,
        f_max=7600.0,
        n_mels=60,
        window_fn=torch.hamming_window,
        power=2.0,
        norm='slaney',
        mel_scale='slaney',
        center=True,
        pad_mode='reflect'
    )(waveform)  # shape: [1, n_mels, frames]

    # Add epsilon, log, mean-normalize
    mel = torch.log(mel + 1e-6)
    
    # 3) Pad / centre-crop to exactly `target_frames`
    target_frames=200
    _, _, T = mel.shape
    if T < target_frames:
        mel = F.pad(mel, (0, target_frames - T))
    elif T > target_frames:
        start = (T - target_frames) // 2
        mel = mel[:, :, start:start + target_frames]
    
    mel = mel - mel.mean(dim=-1, keepdim=True)

    return mel

In [7]:
import librosa
def waveform_to_logmel_librosa(
    waveform: torch.Tensor, sample_rate: int = 16000,
    n_fft: int = 512, win_length: int = 400, hop_length: int = 160,
    n_mels: int = 60, fmin: float = 20.0, fmax: float = 7600.0,
    preemph: float = 0.97
) -> torch.Tensor:
    """
    Convert waveform [1, T] to log-mel [1, 1, n_mels, frames] using Librosa,
    matching Slaney normalization and ReDimNet conventions.
    """
    # 1) Pre-emphasis
    wav = waveform.squeeze(0).cpu().numpy()
    wav = np.append(wav[0], wav[1:] - preemph * wav[:-1])

    # 2) Mel spectrogram with Slaney normalization
    S = librosa.feature.melspectrogram(
        y=wav,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window='hann',
        center=True,
        pad_mode='reflect',
        power=2.0,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        htk=False,
        norm='slaney'  # 🔑 area-normalized triangular filters :contentReference[oaicite:1]{index=1}
    )

    # 3) Add epsilon, log, mean-normalize
    logmel = np.log(S + 1e-6)
    logmel = logmel - np.mean(logmel, axis=1, keepdims=True)

    # 4) To PyTorch tensor with required shape
    logmel_t = torch.from_numpy(logmel).unsqueeze(0).unsqueeze(0)
    return logmel_t.float()


In [8]:
def example_inference(wav_path: str):
    # (a) Load audio
    waveform, sample_rate = torchaudio.load(wav_path)  # shape: [channels, time]
    # If stereo, select one channel, or average:
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    
    target_sample_rate = 16000  # Force to 16kHz as per model requirements
    # Resample if needed
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # (b) Convert to log-mel
    log_mel = waveform_to_logmel(waveform, sample_rate=target_sample_rate)
    # Now log_mel shape is [1, n_mels, time_frames].

    # (c) Model expects a batch, so add batch dim => [B=1, 1, n_mels, time_frames]
    log_mel = log_mel.unsqueeze(0)
    print('feeding logmel shape:', log_mel.shape)

    # (d) Forward pass
    with torch.no_grad():
        embedding = model_no_mel(log_mel)  # shape typically [1, 192] or so

    print("Embedding shape:", embedding.shape)
    #print("Embedding:", embedding)
    return embedding

In [9]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    return F.cosine_similarity(embedding1, embedding2).item()

def cosine_similarity_numpys(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """
    Compute cosine similarity between two vectors of shape (D,) or (1, D).
    """
    # If shape is (1, D), flatten to (D,)
    v1 = emb1.flatten()
    v2 = emb2.flatten()

    # dot product
    dot = np.dot(v1, v2)
    # norms
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)

    # Add a small epsilon in case of very small norms
    sim = dot / (norm1 * norm2 + 1e-8)
    return sim


In [10]:
embed0 = example_inference("test000.wav")


# embed1 = example_inference("testRob1.wav")
# embed2 = example_inference("testRob2.wav")
embed1 = example_inference("pas_1.wav")
embed2 = example_inference("pas_2.wav")


embed3 = example_inference("testme1.wav")
embed4 = example_inference("testme2.wav")

feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: torch.Size([1, 192])
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: torch.Size([1, 192])
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: torch.Size([1, 192])
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: torch.Size([1, 192])
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: torch.Size([1, 192])


In [11]:
print(f"Similarity (robot to robot): {cosine_similarity(embed1, embed2)}")
print(f"Similarity (robot to webvoice): {cosine_similarity(embed1, embed0)}")
print(f"Similarity (robot to me1   ): {cosine_similarity(embed1, embed3)}")
print(f"Similarity (robot to me2   ): {cosine_similarity(embed1, embed4)}")
print(f"Similarity (me 1 to me 2  ): {cosine_similarity(embed3, embed4)}")

Similarity (robot to robot): 0.7475607395172119
Similarity (robot to webvoice): 0.27281489968299866
Similarity (robot to me1   ): 0.22384625673294067
Similarity (robot to me2   ): 0.3730255663394928
Similarity (me 1 to me 2  ): 0.33060920238494873


In [12]:
embed1[0][:10]

tensor([-2.7882, -4.7194, -0.9860, -3.0060, -1.1502, -2.5357,  2.7918, -1.2449,
        -1.1448,  2.1949])

In [13]:
#load pt embedding from file to compare 
refB0_embed0 = torch.load("/tmp/refB0_embed0.pt")
refB0_embed1 = torch.load("/tmp/refB0_embed1.pt")
refB0_embed2 = torch.load("/tmp/refB0_embed2.pt")
refB0_embed3 = torch.load("/tmp/refB0_embed3.pt")
refB0_embed4 = torch.load("/tmp/refB0_embed4.pt")

print(f"Similarity (web to REF): {cosine_similarity(refB0_embed0, embed0)}")
print(f"Similarity (robot1 to REF): {cosine_similarity(refB0_embed1, embed1)}")
print(f"Similarity (robot2 to REF): {cosine_similarity(refB0_embed2, embed2)}")
print(f"Similarity (me1 to REF): {cosine_similarity(refB0_embed3, embed3)}")
print(f"Similarity (me2 to REF): {cosine_similarity(refB0_embed4, embed4)}")


Similarity (web to REF): 0.3106271028518677
Similarity (robot1 to REF): 0.32703477144241333
Similarity (robot2 to REF): 0.22965891659259796
Similarity (me1 to REF): 0.3689720928668976
Similarity (me2 to REF): 0.4013173580169678


## store

In [14]:
class NHWCWrapper(nn.Module):
    def __init__(self, model_nchw):
        super().__init__()
        self.model = model_nchw

    def forward(self, x):
        # x: NHWC => NCHW
        x = x.permute(0, 3, 1, 2).contiguous()
        return self.model(x)

In [15]:
import onnx

def export_to_onnx(model, onnx_path="ReDimNet_no_mel.onnx"):
    model.eval()
    
    # Create a dummy input with shape matching [B=1, 1, n_mels=72, time_frames=200] (example)
    dummy_input = torch.randn(1, 1, 60, 200)
    
    #  fixed-length segments 
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["log_mel"],
        output_names=["embedding"],
        opset_version=11
    )
    print("Exported to", onnx_path)
    
    # ## Export with Dynamic Axis
    # torch.onnx.export(
    #     model_no_mel,                 # your model
    #     dummy_input,                  # e.g. shape [1, 1, 60, 200]
    #     onnx_path,
    #     input_names=["log_mel"],
    #     output_names=["embedding"],
    #     opset_version=11,
    #     dynamic_axes={
    #         "log_mel": {0: "batch_size", 3: "time_frames"},
    #         "embedding": {0: "batch_size"}
    #     }
    # )
    # print("Exported to", onnx_path)

# Example usage
export_to_onnx(model_no_mel)

Exported to ReDimNet_no_mel.onnx


In [16]:
export_to_onnx(model_no_mel,onnx_path = "ReDimNet_no_mel.onnx")
!ls -lah ReDimNet_no_mel.onnx

Exported to ReDimNet_no_mel.onnx
-rw-rw-r-- 1 vlad vlad 4.1M Jun 16 17:10 ReDimNet_no_mel.onnx


### store half

In [17]:
model_NHWC = NHWCWrapper(model_no_mel)
dummy_input_NHWC = torch.rand(1, 60, 200, 1)*16-8

fp16_net = copy.deepcopy(model_NHWC).half().eval()
fp16_dummy = dummy_input_NHWC.half()

#  fixed-length segments 
torch.onnx.export(
   fp16_net,
   fp16_dummy,
   "ReDimNet_no_mel_half.onnx",
   input_names=["log_mel"],
   output_names=["embedding"],
   opset_version=13
)

PyTorch's .half() is unreliable for full model precision control in ONNX. Instead:

In [18]:
from onnxconverter_common.float16 import convert_float_to_float16
import onnx

model_fp32 = onnx.load("ReDimNet_no_mel.onnx")
model_fp16 = convert_float_to_float16(model_fp32, keep_io_types=True)
onnx.save(model_fp16, "ReDimNet_no_mel_fp16.onnx")



## verify

In [19]:
onnx_model_path = "ReDimNet_no_mel.onnx"
# onnx_model_path = "ReDimNet_no_mel_fp16.onnx"

In [20]:
import onnx
onnx_model = onnx.load("ReDimNet_no_mel.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")


ONNX model is valid!


In [21]:
import onnxruntime as ort

def run_inference_onnx(onnx_path, wav_path):
    """
    Loads an audio file, converts to log-mel, and runs inference
    in an ONNX session. Returns the embedding as a NumPy array.
    """
    #######################################
    # 1) Load your ONNX model
    #######################################
    # (Optional) onnx.checker to confirm it’s valid
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"Loaded and checked ONNX model from: {onnx_path}")

    # Create an inference session
    session = ort.InferenceSession(onnx_path)

    # Usually we retrieve the first input & output name
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    #######################################
    # 2) Load audio, get log-mel
    #######################################
    waveform, sample_rate = torchaudio.load(wav_path)
    # If multi-channel, downmix:
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
        
    # Resample if needed
    target_sample_rate=16000
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    log_mel = waveform_to_logmel(waveform, sample_rate=target_sample_rate)
    # Insert a batch dimension => shape [B, 1, n_mels, frames]
    log_mel = log_mel.unsqueeze(0)  # => [1, 1, n_mels, time_frames]
    print('feeding logmel shape:', log_mel.shape)
    
    #######################################
    # 3) ONNX Inference
    #######################################
    # Convert to NumPy for ONNX runtime
    log_mel_np = log_mel.cpu().numpy()
    
    
    # Run inference
    outputs = session.run([output_name], {input_name: log_mel_np})
    # outputs is a list; typically we want the first item
    embedding = outputs[0]  # shape is [1, embedding_dim]

    print("Embedding shape:", embedding.shape)
    # print("Embedding data:\n", embedding)
    return embedding

In [22]:
embed0 = run_inference_onnx(onnx_model_path, "test000.wav")


# embed1 = run_inference_onnx(onnx_model_path, "testRob1.wav")
# embed2 = run_inference_onnx(onnx_model_path, "testRob2.wav")
embed1 = run_inference_onnx(onnx_model_path, "pas_1.wav")
embed2 = run_inference_onnx(onnx_model_path, "pas_2.wav")



embed3 = run_inference_onnx(onnx_model_path, "testme1.wav")
embed4 = run_inference_onnx(onnx_model_path, "testme2.wav")

Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
feeding logmel shape: torch.Size([1, 1, 60, 200])
Embedding shape: (1, 192)


In [23]:
print(f"Similarity (robot to robot): {cosine_similarity_numpys(embed1, embed2)}")
print(f"Similarity (robot to webvoice): {cosine_similarity_numpys(embed1, embed0)}")
print(f"Similarity (robot to me1   ): {cosine_similarity_numpys(embed1, embed3)}")
print(f"Similarity (robot to me2   ): {cosine_similarity_numpys(embed1, embed4)}")
print(f"Similarity (me 1 to me 2  ): {cosine_similarity_numpys(embed3, embed4)}")

Similarity (robot to robot): 0.747560441493988
Similarity (robot to webvoice): 0.2728148102760315
Similarity (robot to me1   ): 0.2238464206457138
Similarity (robot to me2   ): 0.37302565574645996
Similarity (me 1 to me 2  ): 0.33060914278030396


In [24]:
embed1[0][:10]

array([-2.7882264 , -4.7193723 , -0.98600155, -3.0060067 , -1.1502321 ,
       -2.5357099 ,  2.7918017 , -1.2448845 , -1.144834  ,  2.1948512 ],
      dtype=float32)