In [1]:
import html
import os, sys
import pathlib

import pandas as pd
import torch
import torch.nn as nn

# Make the project root importable so the `src.*` imports below resolve.
# NOTE: os.path.dirname("./notebooks") is simply "." -- it never located this
# notebook -- so both paths are relative to the kernel's working directory,
# which is expected to be the notebooks/ folder (the "../" paths used in this
# notebook assume the same).
currentUrl = os.curdir
parentUrl = os.path.abspath(os.path.join(currentUrl, os.pardir))
# Guard the append so re-running this cell does not pile up duplicate entries.
if parentUrl not in sys.path:
    sys.path.append(parentUrl)

from src.models.MultiModalFusion import MultiModalFusion
from src.trainer.MultiModalFusionTrainer import MultiModalFusionTrainer
from src.utils.Retrieval import FetchSimilar
import yaml
from IPython.display import Audio, display, display_jpeg, Image
from IPython.core.display import HTML
from PIL import Image

# Experiment configuration (paths relative to notebooks/).
with open('../configs/MultiModalFusion.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Checkpoint used for retrieval and the gallery it searches over.
CHECKPOINT_PATH = "../logs/MultiModalFusion/yd2gaqhs/checkpoints/epoch=31-val_loss=2.23-val_mean_similarity=0.37.ckpt"
DATASET_ROOT = "../datasets/speech-handsign_commands_balanced2"
# Fall back to CPU so the notebook still runs without a GPU.
# NOTE(review): assumes FetchSimilar accepts device="cpu" -- confirm.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

fetcher = FetchSimilar(
    chkpt_path=CHECKPOINT_PATH,
    image_path=f"{DATASET_ROOT}/handsign/",
    audio_path=f"{DATASET_ROOT}/speech/",
    device=DEVICE,
)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of Wav2Vec2ConformerModel were not initialized from the model checkpoint at facebook/wav2vec2-conformer-rope-large-960h-ft and are newly initialized: ['wav2vec2_conformer.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2_conformer.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return F.conv1d(input, weight, bias, self.stride,


In [2]:
def render_top_k(query_path:str|pathlib.PosixPath, top_k:dict, query_class:str|None=None) -> None:
    data = []
    if type(query_path) == str:
        query_path = pathlib.Path(query_path)
    if query_path.suffix == ".wav":
        query_html = f'<audio controls src="{query_path}" style="display:block; margin:0 auto;"></audio>'
    else:
        query_html = f'<img src="{query_path}" alt="query image" style="max-width:300px; height:auto; display:block; margin:0 auto;">'
        
    for k, v in top_k.items():
        path, cls, embed, score = v.values()
        data += [{ "path": path, "cls": cls, "embed": embed, "score": score, "modality": k.split("#")[1]}]
    # Define a function to render HTML for images and audio
    def render_table(idx, row):
        if row["modality"] == "image":
            display = f'<img src="{row["path"]}" alt="Image" style="width:100px;height:auto;">'
        else:
            display = f'<audio controls src="{row["path"]}" style="width:200px;"></audio>'
        return f'<tr><td>{idx+1}</td><td>{row["cls"]}</td><td>{display}</td><td>{row["score"]:.3f}</td></tr>'
    
    # Generate the table HTML
    table_html = """
    <table border="1" style="border-collapse:collapse; text-align:center; margin:auto;">
        <tr>
            <th>Rank</th>
            <th>Class</th>
            <th>Display</th>
            <th>Similarity Score</th>
        </tr>
    """
    for idx, row in enumerate(data):
        table_html += render_table(idx, row)
    table_html += "</table>"

    # Combine top media and the table
    full_html = f"""
    <div style="text-align:center; margin-bottom:20px;">
        <b>{query_class if query_class is not None else str(query_path)}</b>
        {query_html}
    </div>
    {table_html}
    """
    
    # Display the complete HTML
    display(HTML(full_html))

In [4]:
# Speech -> hand-sign: rank the 10 gallery images closest to a spoken "no".
query_path = "../datasets/speech-handsign_commands_balanced2/speech/no/no_11.wav"
top_k, query_info = fetcher.top_k(path=query_path, modality="image", k=10)
render_top_k(query_path, top_k, query_class="no")

Rank,Class,Display,Similarity Score
1,no,,0.159
2,no,,0.154
3,no,,0.153
4,no,,0.151
5,no,,0.151
6,no,,0.15
7,no,,0.15
8,no,,0.15
9,no,,0.15
10,no,,0.15


In [5]:
# Hand-sign -> speech: query with a "stop" photo stored outside the gallery
# directory and rank the 10 closest speech clips.
query_path = "../datasets/test_stop2.jpeg"
top_k, query_info = fetcher.top_k(path=query_path, modality="audio", k=10)
render_top_k(query_path, top_k, query_class="stop")

Rank,Class,Display,Similarity Score
1,stop,,0.147
2,stop,,0.146
3,yes,,0.142
4,left,,0.142
5,stop,,0.135
6,no,,0.135
7,stop,,0.133
8,yes,,0.132
9,stop,,0.131
10,stop,,0.129


In [7]:
# Speech -> hand-sign: a "stop" recording stored outside the gallery directory,
# ranked against the 10 closest gallery images.
query_path = "../datasets/test_stop.wav"
top_k, query_info = fetcher.top_k(path=query_path, modality="image", k=10)
render_top_k(query_path, top_k, query_class="stop")

Rank,Class,Display,Similarity Score
1,stop,,0.313
2,stop,,0.312
3,stop,,0.311
4,stop,,0.307
5,stop,,0.304
6,stop,,0.303
7,stop,,0.302
8,stop,,0.298
9,stop,,0.294
10,stop,,0.292


In [40]:
# Sanity check: cosine similarity between one hand-sign image and one speech
# clip whose directory names ("left" vs "go") suggest different classes.
# Paths are relative to notebooks/, like every other cell in this notebook.
# NOTE(review): this uses a different checkpoint (bmwnmh9z) and dataset copy
# (..._balanced) than the fetcher above -- confirm that is intended.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
trainer_module = MultiModalFusionTrainer.load_from_checkpoint(
    "../logs/MultiModalFusion/bmwnmh9z/checkpoints/epoch=35-val_loss=2.18-val_mean_similarity=0.38.ckpt",
    map_location=DEVICE,
)
# Inference only: eval() disables train-time layers such as dropout.
model = trainer_module.model.to(DEVICE).eval()
sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
img_file = '../datasets/speech-handsign_commands_balanced/handsign/left/hand1_g_bot_seg_2_cropped.jpeg'
wav_file = '../datasets/speech-handsign_commands_balanced/speech/go/888a0c49_nohash_2.wav'
with torch.no_grad():
    img_embed = model.encode_image(img_file)
    audio_embed = model.encode_speech(wav_file)
    similarity = sim(img_embed, audio_embed)
similarity

Some weights of Wav2Vec2ConformerModel were not initialized from the model checkpoint at facebook/wav2vec2-conformer-rope-large-960h-ft and are newly initialized: ['wav2vec2_conformer.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2_conformer.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor(-0.0472, device='cuda:0')