<a href="https://colab.research.google.com/github/BotsKnowBest/ImageBind/blob/main/BKB_ImageBind.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Clone ImageBind from GitHub, download test files, and install dependencies



In [1]:
# Clone only when the repo is not already present, so re-running the notebook does not fail
# with "destination path 'ImageBind' already exists".
![ -d ImageBind ] || git clone https://github.com/facebookresearch/ImageBind.git
# Expose ImageBind's bpe/ directory in the working directory (skipped if ./bpe already exists).
# NOTE(review): presumably needed because ImageBind's text loader resolves bpe/ relative to
# the current working directory — confirm against ImageBind/data.py.
![ -e bpe ] || ln -s ./ImageBind/bpe/ .

fatal: destination path 'ImageBind' already exists and is not an empty directory.
ln: failed to create symbolic link './bpe': File exists


In [2]:
!mkdir -p test_data
# -nc (no-clobber) skips files that are already on disk, so re-runs neither re-download
# them nor pile up copies named lotr_boro.wav.1, startup.wav.1, ...
!wget -nc -P test_data --quiet https://github.com/BotsKnowBest/ImageBind/raw/main/test_data/lotr_boro.wav
!wget -nc -P test_data --quiet https://github.com/BotsKnowBest/ImageBind/raw/main/test_data/startup.wav

In [3]:
# Install the dependencies listed in the cloned repo's requirements.txt.
# The `cd` only affects the subshell running this command; the notebook's working directory
# is unchanged afterwards.
# NOTE(review): `%pip install -r ImageBind/requirements.txt` would guarantee the install
# targets the kernel's own environment — confirm requirements.txt has no cwd-relative entries.
!cd ImageBind && pip install -r requirements.txt


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/, https://download.pytorch.org/whl/cu113
Collecting pytorchvideo@ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d (from -r requirements.txt (line 5))
  Cloning https://github.com/facebookresearch/pytorchvideo.git (to revision 28fe037d212663c6a24f373b94cc5d478c8c1a1d) to /tmp/pip-install-46pm72cb/pytorchvideo_73fc281c637c4a3ca1df28a447d08b93
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/pytorchvideo.git /tmp/pip-install-46pm72cb/pytorchvideo_73fc281c637c4a3ca1df28a447d08b93
  Running command git rev-parse -q --verify 'sha^28fe037d212663c6a24f373b94cc5d478c8c1a1d'
  Running command git fetch -q https://github.com/facebookresearch/pytorchvideo.git 28fe037d212663c6a24f373b94cc5d478c8c1a1d
  Running command git checkout -q 28fe037d212663c6a24f373b94cc5d478c8c1a1d
  Resolved https://github.com/facebook

### Import packages

In [4]:
%matplotlib inline

import numpy as np
import torch
import torch.nn.functional as F
import sys
import warnings
from pathlib import Path

import ipywidgets as widgets
import IPython
from IPython.display import Image
from tqdm.notebook import tqdm

In [5]:
# Make the cloned ImageBind repo importable. Its modules are imported by bare top-level
# names (`data`, `models`), so this must run before the imports below.
# NOTE(review): inserting at index 0 means ImageBind's generic `data` module shadows any
# other module named `data` on sys.path.
sys.path.insert(0,'./ImageBind/')

# Import ImageBind's preprocessing helpers with UserWarnings suppressed for the duration
# of the import only.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=UserWarning)
    import data as ib_data

from models import imagebind_model
from models.imagebind_model import ModalityType

### Instantiate the ImageBind model

In [6]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build ImageBind-huge with pretrained weights (the log below shows the ~4.5 GB checkpoint
# being written to .checkpoints/imagebind_huge.pth), put it in inference mode and move it
# to `device`. nn.Module.eval() and .to() both return the module itself, so chaining keeps
# the same object.
ib_model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)

print(f"ImageBind model loaded to {device}.")

Downloading imagebind weights to .checkpoints/imagebind_huge.pth ...


  0%|          | 0.00/4.47G [00:00<?, ?B/s]

ImageBind model loaded to cuda:0.


### Check that ImageBind inference works

In [7]:
# Sanity-check inference on ImageBind's bundled sample assets. Captions, images and audio
# clips are listed in the same (dog, car, bird) order, so matching pairs sit on the
# diagonal of each similarity matrix (the output below shows the diagonal dominating).
test_text_list = ["A dog.", "A car", "A bird"]
test_image_paths = [
    "./ImageBind/.assets/dog_image.jpg",
    "./ImageBind/.assets/car_image.jpg",
    "./ImageBind/.assets/bird_image.jpg",
]
test_audio_paths = [
    "./ImageBind/.assets/dog_audio.wav",
    "./ImageBind/.assets/car_audio.wav",
    "./ImageBind/.assets/bird_audio.wav",
]

test_inputs = {
    ModalityType.TEXT: ib_data.load_and_transform_text(test_text_list, device),
    ModalityType.VISION: ib_data.load_and_transform_vision_data(test_image_paths, device),
    ModalityType.AUDIO: ib_data.load_and_transform_audio_data(test_audio_paths, device),
}

with torch.no_grad():
    test_embeddings = ib_model(test_inputs)

# One row-wise softmax over raw dot products per modality pair: rows index the first
# modality of the pair, columns the second.
test_vt, test_at, test_va = (
    torch.softmax(test_embeddings[row_mod] @ test_embeddings[col_mod].T, dim=-1)
    for row_mod, col_mod in (
        (ModalityType.VISION, ModalityType.TEXT),
        (ModalityType.AUDIO, ModalityType.TEXT),
        (ModalityType.VISION, ModalityType.AUDIO),
    )
)

print(f"Vision x Text: {test_vt}")
print(f"Audio x Text: {test_at}")
print(f"Vision x Audio: {test_va}")


Vision x Text: tensor([[9.9684e-01, 3.1310e-03, 2.5929e-05],
        [5.4494e-05, 9.9993e-01, 2.0353e-05],
        [4.4846e-05, 1.3246e-02, 9.8671e-01]], device='cuda:0')
Audio x Text: tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]], device='cuda:0')
Vision x Audio: tensor([[0.8064, 0.1051, 0.0885],
        [0.1284, 0.7205, 0.1511],
        [0.0016, 0.0022, 0.9962]], device='cuda:0')


### Download the ImageNetV2 dataset

In [8]:
# Pin the commit this notebook was run against (see "Resolved ... to commit" in the log
# below) so re-runs install the same code. %pip installs into the kernel's own environment.
%pip install git+https://github.com/modestyachts/ImageNetV2_pytorch@14d4456c39fe7f02a665544dd9fc37c1a5f8b635

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/modestyachts/ImageNetV2_pytorch
  Cloning https://github.com/modestyachts/ImageNetV2_pytorch to /tmp/pip-req-build-5k_egn01
  Running command git clone --filter=blob:none --quiet https://github.com/modestyachts/ImageNetV2_pytorch /tmp/pip-req-build-5k_egn01
  Resolved https://github.com/modestyachts/ImageNetV2_pytorch to commit 14d4456c39fe7f02a665544dd9fc37c1a5f8b635
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: imagenetv2-pytorch
  Building wheel for imagenetv2-pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for imagenetv2-pytorch: filename=imagenetv2_pytorch-0.1-py3-none-any.whl size=2659 sha256=67d4839408c73ac3fe4cdff1850a5bc603110cc7a197fb3ca0ad47770e0a829b
  Stored in directory: /tmp/pip-ephem-wheel-cache-uhtctu8r/wheels/ea/e3/2d/38c8d17086a0ea5890dc0d4796db505e41323d2d9800b56fa7
Successfully bu

In [9]:
from imagenetv2_pytorch import ImageNetV2Dataset

# Constructing the dataset downloads and extracts ImageNetV2 when it is not on disk yet
# (see the log below). Build it once and reuse the instance; the original code constructed
# a second dataset just to count its files.
images = ImageNetV2Dataset()
print(f"{len(images.fnames)} images found.")

Dataset matched-frequency not found on disk, downloading....


100%|██████████| 1.26G/1.26G [00:22<00:00, 57.1MiB/s]


Extracting....
10000 images found.


### Compute ImageBind embeddings for the ImageNetV2 dataset

In [10]:
def chunker(seq, size):
    """Lazily yield consecutive slices of ``seq`` holding at most ``size`` items each.

    The last slice is shorter when ``len(seq)`` is not a multiple of ``size``. The range of
    start offsets is built eagerly, so an invalid ``size`` of 0 raises at call time.
    """
    offsets = range(0, len(seq), size)
    return (seq[start:start + size] for start in offsets)

def get_load_and_transform_fn(modality):
    """Return the ImageBind ``data`` preprocessing function for ``modality``.

    This notebook calls the returned function as ``fn(inputs, device)``, where ``inputs``
    is a list of file paths (or of strings for TEXT).

    Raises:
        ValueError: if no loader here supports ``modality``. This includes THERMAL,
            DEPTH and IMU, which have no dedicated loader in this notebook.
    """
    if modality == ModalityType.VISION:
        return ib_data.load_and_transform_vision_data
    if modality == ModalityType.TEXT:
        return ib_data.load_and_transform_text
    if modality == ModalityType.AUDIO:
        return ib_data.load_and_transform_audio_data
    if modality == 'video':
        # 'video' is not a ModalityType value, so callers that check ModalityType
        # membership never reach this branch; it is kept for direct callers.
        return ib_data.load_and_transform_video_data
    # NOTE(review): THERMAL, DEPTH and IMU used to be routed to the RGB vision loader, and
    # any other value silently returned None. ImageBind's thermal/depth encoders appear to
    # expect single-channel images and its IMU encoder sensor sequences (confirm against
    # models/imagebind_model.py), so those inputs would only fail later with an opaque
    # shape error. Fail here with a clear message instead.
    raise ValueError(f"No ImageBind loader available for modality {modality!r}")

def get_dataset_imagebin_embeddings(paths_batches, modality, batch_size=32):
    """Embed every input in ``paths_batches`` with ImageBind, ``batch_size`` items per pass.

    Args:
        paths_batches: flat sequence of inputs to embed (file paths, or strings for TEXT).
            Despite the name it is not pre-batched; batching happens here.
        modality: a ``ModalityType`` value that selects the loader and the model output.
        batch_size: number of inputs preprocessed and encoded per forward pass.

    Returns:
        dict with ``'embeds'``, the per-batch model outputs stacked row-wise (one row per
        input, left on the model's device), and ``'fnames'``, the inputs in the same
        row order.

    Raises:
        ValueError: if ``modality`` is not a ``ModalityType`` value, if ``batch_size``
            is < 1, or if ``paths_batches`` is empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if modality not in vars(ModalityType).values():
        raise ValueError(f"Unsupported modality: {modality!r}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(paths_batches) == 0:
        # torch.vstack([]) would otherwise fail with a less helpful message.
        raise ValueError("paths_batches is empty; nothing to embed")

    # The loader depends only on the modality, so look it up once rather than per batch.
    load_and_transform = get_load_and_transform_fn(modality)
    batches = list(chunker(paths_batches, batch_size))

    embed_fnames = []
    embed_chunks = []
    with torch.no_grad():
        for batch in tqdm(batches, desc=f"Embedding {modality}"):
            inputs = {modality: load_and_transform(batch, device)}
            embed_chunks.append(ib_model(inputs)[modality])
            embed_fnames.extend(batch)

    return {
        'embeds': torch.vstack(embed_chunks),
        'fnames': embed_fnames,
    }
    

In [11]:
# Reuse the dataset instance created in the download cell instead of constructing a new one.
imagenetv2_images_paths = [str(path) for path in images.fnames]
ibembeds_dataset_vision = get_dataset_imagebin_embeddings(imagenetv2_images_paths, ModalityType.VISION)

# Invariant relied on by the retrieval cells: one embedding row per file name, same order.
assert ibembeds_dataset_vision['embeds'].shape[0] == len(ibembeds_dataset_vision['fnames'])
print(ibembeds_dataset_vision['embeds'].shape)

  0%|          | 0/313 [00:00<?, ?it/s]

torch.Size([10000, 1024])


### Find the 5 images closest to an audio query

In [12]:
def get_closest_embeds(input_data, dataset_embeds):
    """Rank every dataset row by cosine similarity to each query in ``input_data``.

    Args:
        input_data: mapping ``{ModalityType value: single input}``, i.e. one file path
            (or one string for TEXT) per modality. Each input is embedded as a
            one-item batch.
        dataset_embeds: tensor of dataset embeddings with one row per dataset item.

    Returns:
        dict mapping each modality in ``input_data`` to a list of all dataset row
        indices, most similar first.

    Raises:
        ValueError: if a key of ``input_data`` is not a ``ModalityType`` value.
    """
    # Validate up front (and without `assert`, which `python -O` strips) so that bad
    # keys fail before any model work is done.
    for modality in input_data:
        if modality not in vars(ModalityType).values():
            raise ValueError(f"Unsupported modality: {modality!r}")

    # Cosine similarity is the dot product of L2-normalised vectors.
    dataset_embeds_norm = dataset_embeds / dataset_embeds.norm(dim=-1, keepdim=True)
    dataset_embeds_norm = dataset_embeds_norm.cpu().numpy()

    inputs = {
        modality: get_load_and_transform_fn(modality)([query], device)
        for modality, query in input_data.items()
    }
    with torch.no_grad():
        embeddings = ib_model(inputs)

    closest_embeds = {}
    for modality in input_data:
        query_embed = embeddings[modality]
        query_norm = (query_embed / query_embed.norm(dim=-1, keepdim=True)).cpu().numpy()
        # One row of similarities, because the query was a one-item batch.
        similarities = query_norm @ dataset_embeds_norm.T
        # argsort is ascending; reverse the single row so the best match comes first.
        closest_embeds[modality] = np.argsort(similarities, axis=1)[0][::-1].tolist()

    return closest_embeds

In [13]:
def display_closest_images(input_data, dataset_vision, top_k=1):
    """Display the ``top_k`` dataset images closest to each query in ``input_data``.

    One horizontal row of thumbnails is shown per query modality, best match first.

    Args:
        input_data: mapping ``{ModalityType value: single input}`` used as queries.
        dataset_vision: dict with ``'embeds'`` (one row per image) and ``'fnames'``
            (image paths in the same row order), as returned by
            ``get_dataset_imagebin_embeddings``. If falsy, nothing is displayed.
        top_k: number of images to show per query. It is capped at the dataset size.

    Raises:
        ValueError: if ``top_k < 1`` or a key of ``input_data`` is not a
            ``ModalityType`` value.
    """
    # `display` is only an implicit builtin inside IPython; import it explicitly.
    from IPython.display import display

    # Validate before doing any (expensive) embedding work.
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    for modality in input_data:
        if modality not in vars(ModalityType).values():
            raise ValueError(f"Unsupported modality: {modality!r}")
    if not dataset_vision:
        return

    closest_embeds_vision = get_closest_embeds(input_data, dataset_vision['embeds'])
    fnames = dataset_vision['fnames']

    for modality in input_data:
        thumbnails = []
        # Slicing caps top_k at the number of ranked images instead of raising IndexError.
        for image_idx in closest_embeds_vision[modality][:top_k]:
            # Context manager closes each file; the original leaked one handle per image.
            with open(fnames[image_idx], 'rb') as image_file:
                image_bytes = image_file.read()
            thumbnails.append(widgets.Image(value=image_bytes, width=190))
        display(widgets.HBox(thumbnails))
    

In [14]:
input_data = {
    ModalityType.AUDIO: "test_data/startup.wav",
}

# display_closest_images renders the thumbnails itself and returns None, so the result
# is not stored.
display_closest_images(
    input_data,
    dataset_vision=ibembeds_dataset_vision,
    top_k=5
)

HBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C…

In [15]:
input_data = {
    ModalityType.AUDIO: "test_data/lotr_boro.wav",
}

# display_closest_images renders the thumbnails itself and returns None, so the result
# is not stored.
display_closest_images(
    input_data,
    dataset_vision=ibembeds_dataset_vision,
    top_k=5
)

HBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C…