# Olfaction-Vision-Language-Embeddings

This notebook is a quick-start guide to loading the olfaction-vision-language models and computing joint multimodal embeddings from a paired olfaction-vision data sample.

### Install Libraries

In [1]:
!pip install transformers
!pip install safetensors



### Import and Configure

In [3]:
import torch
import torch.nn as nn
from safetensors.torch import load_file
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from PIL import Image


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Width of the joint embedding space: 512 for the small models, 2048 for the large ones.
EMBED_DIM = 512

# Checkpoint locations; the file names are keyed by EMBED_DIM.
ENCODER_FILE_PATH = f"./olf_encoder_{EMBED_DIM}.pt"
GNN_FILE_PATH = f"./gnn_{EMBED_DIM}.pt"

### Embeddings Function

In [4]:
def get_embeddings(clip_model, olf_encoder, graph_model, image, olf_vec):
    """
    Gets joint olfaction-vision-language embeddings for a given image and olfaction vector.

    All three models are switched to eval mode and every forward pass runs under
    ``torch.no_grad()``, so no gradients are tracked for the result.

    :param clip_model: vision-language model exposing ``get_image_features``
    :param olf_encoder: olfactory encoder from aromas/molecules
    :param graph_model: cross-modal associator, called as ``graph_model(vision_embeds, olf_embeds)``
    :param image: PIL image (any mode; it is converted to RGB before encoding)
    :param olf_vec: olfaction vector, either a tensor or a plain numeric sequence
    :return: joint olfaction-vision-language embeddings with singleton dimensions squeezed out
    """
    clip_model.eval()
    olf_encoder.eval()
    graph_model.eval()

    # NOTE(review): only a resize and [0, 1] scaling are applied here; the
    # mean/std normalization that CLIPProcessor would perform is not. Presumably
    # this matches how the olfaction/graph checkpoints were trained -- confirm
    # before changing the preprocessing.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Force three channels so grayscale / RGBA / palette images still yield the
    # (1, 3, 224, 224) batch that is passed as pixel_values below.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    # as_tensor accepts tensors and plain sequences alike and avoids the
    # copy-construct UserWarning that torch.tensor(<tensor>) emits.
    olf_tensor = torch.as_tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        vision_embeds = clip_model.get_image_features(pixel_values=image_tensor)
        if EMBED_DIM not in (512, 768):
            # NOTE(review): this linear layer is freshly (randomly) initialised on
            # every call, so the projected vision features -- and therefore the
            # returned embeddings -- are not reproducible across calls. Presumably
            # a trained projection should be loaded instead; verify.
            projection = nn.Linear(vision_embeds.shape[-1], EMBED_DIM).to(DEVICE)
            vision_embeds = projection(vision_embeds)
        vision_embeds = vision_embeds.to(DEVICE)
        olf_embeds = olf_encoder(olf_tensor).to(DEVICE)
        ovl_embeds = graph_model(vision_embeds, olf_embeds).squeeze()

    return ovl_embeds

### Get Joint Embeddings from a Data Sample

In [5]:
# Load the models
# map_location puts the TorchScript checkpoints on DEVICE. get_embeddings moves
# its inputs to DEVICE, so leaving the modules on the device they were saved
# from would cause a device mismatch (or a load failure for GPU-saved files on
# a CPU-only machine).
olf_encoder = torch.jit.load(ENCODER_FILE_PATH, map_location=DEVICE)
graph_model = torch.jit.load(GNN_FILE_PATH, map_location=DEVICE)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)

# Build example vision-olfaction sample with dummy data
example_image = Image.new('RGB', (224, 224))
# The saved copy is not read back in this cell; the in-memory image is used below.
example_image.save("/tmp/image_example.jpg")
# Random stand-in for a 112-element olfaction vector.
example_olf_vec = torch.randn(112)

# Run inference
embeddings = get_embeddings(
    clip_model,
    olf_encoder,
    graph_model,
    example_image,
    example_olf_vec
)
print("Embeddings", embeddings)

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

  olf_tensor = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)


model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Embeddings tensor([-1.2382e-02,  2.3777e-02, -9.0558e-03, -2.7979e-02, -5.9549e-03,
         1.4678e-02,  1.2424e-02,  2.2361e-02,  2.6167e-02,  4.0972e-02,
        -2.1756e-02,  3.2277e-02, -1.7444e-02,  5.4060e-03,  1.0332e-02,
         5.6603e-02, -1.3966e-02,  7.7477e-03, -1.8678e-03,  9.0651e-03,
         2.6837e-02,  9.8322e-03,  1.3258e-02, -3.7714e-02, -5.1809e-02,
         5.2755e-02,  1.2107e-02, -1.3018e-02,  1.8774e-02, -5.6916e-02,
        -2.9328e-02, -2.7885e-02,  1.4422e-03,  2.0761e-02, -2.9871e-02,
         2.7381e-02,  2.2001e-02, -2.1286e-02,  2.1925e-02, -2.4814e-03,
        -5.7491e-03,  9.8888e-03, -3.1911e-02, -3.9359e-02,  2.4043e-02,
         2.1701e-02,  6.7535e-04, -3.9431e-02,  3.2242e-02,  3.3601e-02,
         2.4195e-02,  1.9859e-02,  2.0833e-02, -4.2177e-03, -6.6240e-02,
         5.9470e-02,  1.4912e-02, -4.3088e-02, -4.2664e-02, -7.2853e-02,
        -4.0203e-02, -2.0882e-02,  2.9648e-02,  1.9767e-02,  4.8396e-02,
        -2.9807e-02,  3.1730e-02, -2.017