# Olfaction-Vision-Language-Classifier

This quick start shows how to load the olfaction-vision-language models and compute the logits (unnormalized scores) for the presence of observed chemical compounds in a visual scene, given a set of aroma descriptors encoded as an olfaction vector.

### Install Libraries

In [1]:
!pip install transformers
!pip install safetensors



### Import and Configure

In [2]:
import torch
import torch.nn as nn
from safetensors.torch import load_file
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from PIL import Image


# Prefer the CUDA device when the runtime exposes one; otherwise run on CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Embedding width; the classifier checkpoints use 512.
EMBED_DIM = 512
# Model checkpoint paths; the embedding width is part of each file name.
ENCODER_FILE_PATH = "./olf_encoder_{}_c.pt".format(EMBED_DIM)
GNN_FILE_PATH = "./gnn_{}_c.pt".format(EMBED_DIM)

### Embeddings Function

In [3]:
def get_embeddings(clip_model, olf_encoder, graph_model, image, olf_vec):
    """
    Gets joint olfaction-vision-language logits for a given image and olfaction vector.

    All three models are switched to eval mode (this mutates the passed-in
    modules) and the forward passes run under ``torch.no_grad()``.

    :param clip_model: vision-language model exposing ``get_image_features``
    :param olf_encoder: olfactory encoder from aromas/molecules
    :param graph_model: cross-modal associator, called as
        ``graph_model(vision_embeds, olf_embeds)``
    :param image: PIL image; converted to RGB so grayscale/RGBA/palette
        images still produce a 3-channel tensor
    :param olf_vec: olfaction vector (sequence, array, or tensor)
    :return: output of ``graph_model`` with size-1 dimensions squeezed away
        (e.g. a 0-d tensor when the model emits a single logit)
    """
    clip_model.eval()
    olf_encoder.eval()
    graph_model.eval()

    # NOTE(review): only resize + ToTensor (pixel values in [0, 1]); no CLIP
    # mean/std normalization is applied. Confirm this matches the
    # preprocessing used when graph_model was trained.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # ToTensor emits one channel per image band, so force 3-channel RGB.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    # as_tensor accepts an existing tensor without the copy-construct
    # UserWarning that torch.tensor() emits for tensor inputs.
    olf_tensor = torch.as_tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        vision_embeds = clip_model.get_image_features(pixel_values=image_tensor)
        if EMBED_DIM not in (512, 768):
            # NOTE(review): this Linear layer is freshly (randomly) initialized
            # on every call, so the projected embeddings are untrained and
            # differ between calls. Confirm whether a trained projection
            # should be loaded instead.
            projection = nn.Linear(vision_embeds.shape[-1], EMBED_DIM).to(DEVICE)
            vision_embeds = projection(vision_embeds)
        vision_embeds = vision_embeds.to(DEVICE)
        olf_embeds = olf_encoder(olf_tensor).to(DEVICE)
        ovl_logits = graph_model(vision_embeds, olf_embeds).squeeze()

    return ovl_logits

### Compute Joint Logits for a Data Sample

In [4]:
# Load the models. map_location remaps the TorchScript modules onto DEVICE
# regardless of the device they were saved from, so they match the inputs
# get_embeddings moves to DEVICE (and CUDA-saved checkpoints load on CPU hosts).
olf_encoder = torch.jit.load(ENCODER_FILE_PATH, map_location=DEVICE)
graph_model = torch.jit.load(GNN_FILE_PATH, map_location=DEVICE)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)

# Build example vision-olfaction sample with dummy data: an all-black
# 224x224 RGB image and a random 112-element olfaction vector.
example_image = Image.new('RGB', (224, 224))
example_image.save("/tmp/image_example.jpg")
example_olf_vec = torch.randn(112)

# Run inference
logits = get_embeddings(
    clip_model,
    olf_encoder,
    graph_model,
    example_image,
    example_olf_vec,
)
print("Logits", logits)

Access to the secret `HF_TOKEN` has not been granted on this notebook.
You will not be requested again.
Please restart the session if you want to be prompted again.
  olf_tensor = torch.tensor(olf_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)


Logits tensor(3.8031e+23, device='cuda:0')
