<td>
   <a target="_blank" href="https://labelbox.com"><img src="https://labelbox.com/blog/content/images/2021/02/logo-v4.svg" width="256"/></a>
</td>

<td>
<a href="https://colab.research.google.com/github/Labelbox/labelbox-python/blob/develop/examples/integrations/huggingface/huggingface_custom_embeddings.ipynb" target="_blank"><img
src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
</td>

<td>
<a href="https://github.com/Labelbox/labelbox-python/tree/develop/examples/integrations/huggingface/huggingface_custom_embeddings.ipynb" target="_blank"><img
src="https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white" alt="GitHub"></a>
</td>

# Install required libraries

In [None]:
# Install the Labelbox SDK (with its optional "data" extras) and Hugging Face
# transformers into the kernel's environment; -q keeps pip output short.
%pip install -q "labelbox[data]"
%pip install -q transformers

# Imports

In [None]:
# All libraries used by the notebook: Labelbox SDK, Hugging Face transformers,
# PyTorch for inference, Pillow + requests to fetch images, tqdm for progress.
import labelbox as lb
import transformers

# Level 50 is the standard logging.CRITICAL level: hide transformers'
# informational and warning log messages.
transformers.logging.set_verbosity(50)
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from tqdm import tqdm
import numpy as np

# Labelbox Credentials

In [None]:
import getpass
import os

# Never hardcode the Labelbox API key in the notebook (it would be saved and
# shared with it). Read it from the LABELBOX_API_KEY environment variable, and
# prompt for it interactively when that variable is not set.
API_KEY = os.environ.get("LABELBOX_API_KEY") or getpass.getpass(
    "Labelbox API key: ")
client = lb.Client(API_KEY)

# Select data rows in Labelbox for custom embeddings

In [None]:
# ID of the Labelbox dataset whose images will be embedded.
# The images are downloaded directly from their URLs later in this notebook,
# so those URLs must be reachable from here. For private cloud storage, make
# sure they carry a valid access token from your cloud provider.
DATASET_ID = ""

In [None]:
# Look up the source dataset by ID; its data rows are exported in the next cell.
dataset = client.get_dataset(DATASET_ID)

In [None]:
# Start an export of every data row in the source dataset and block until the
# export has finished, so that its errors and result are available.
export_task = dataset.export_v2()
export_task.wait_till_done()

# Report export problems (if any) before reading the result.
if export_task.errors:
    print(export_task.errors)

export_json = export_task.result

# Each exported record holds its asset URL under ["data_row"]["row_data"].
data_row_urls = [
    record["data_row"]["row_data"]
    for record in export_json
]

# Load a Hugging Face model to generate custom embeddings

In [None]:
# Load the ResNet-50 backbone (`ResNetModel`) from the Hugging Face Hub,
# together with the image preprocessing pipeline published with the same
# "microsoft/resnet-50" checkpoint.
model = transformers.ResNetModel.from_pretrained("microsoft/resnet-50")
image_processor = transformers.AutoImageProcessor.from_pretrained(
    "microsoft/resnet-50"
)

# Pick an existing custom embedding in Labelbox, or create a custom embedding

In [None]:
# Register a new custom embedding in the workspace. `dims` must equal the
# length of the vectors uploaded later; 2048 is what ResNet-50 produces here,
# so adjust it if you swap in a different model.
new_custom_embedding_id = client.create_embedding(
    dims=2048,
    name="My new awesome embedding",
).id

# Alternatively, reuse an embedding that already exists in your workspace:
# existing_embedding_id = client.get_embedding_by_name(name="ResNet img 2048").id

# Generate and upload custom embeddings

In [None]:
def _embed_image_url(url, timeout=30):
    """Download the image at ``url`` and return its embedding from ``model``.

    The image is converted to RGB, resized to 224x224, run through
    ``image_processor`` and ``model``, and the model's ``last_hidden_state`` is
    average-pooled to 1x1 and flattened.

    Args:
        url: HTTP(S) URL of the image.
        timeout: seconds passed to ``requests.get`` so a stalled server cannot
            hang the notebook forever.

    Returns:
        A numpy array of shape (1, C), where C is the channel count of
        ``last_hidden_state``. Returns ``None`` when the server answers with a
        non-200 status.

    Raises:
        Any network, image-decoding, or model error propagates to the caller.
    """
    # `with` closes the streamed connection even if decoding fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Skipping URL: {url}. HTTP status {response.status_code}")
            return None
        # convert() forces Pillow to read the whole stream, so the image is
        # fully in memory before the response is closed.
        image = Image.open(response.raw).convert("RGB").resize((224, 224))

    # Preprocess the image for model input
    img_hf = image_processor(image, return_tensors="pt")

    with torch.no_grad():
        # Only the final feature map is needed, so the per-stage hidden
        # states are not requested.
        last_layer = model(**img_hf).last_hidden_state
        pooled = F.adaptive_avg_pool2d(last_layer, (1, 1))
        # (batch, C, 1, 1) -> (batch, C)
        return torch.flatten(pooled, start_dim=1).cpu().numpy()


img_emb = []
data_rows = []

# Embed every exported image and build the data-row payload in the SAME
# iteration, so each vector stays paired with the URL it was computed from
# even when earlier URLs are skipped or fail.
for url in tqdm(data_row_urls):
    try:
        embedding = _embed_image_url(url)
    except Exception as e:
        print(f"Error processing URL: {url}. Exception: {e}")
        continue
    if embedding is None:
        continue

    img_emb.append(embedding)
    data_rows.append({
        "row_data": url,
        "embeddings": [{
            "embedding_id": new_custom_embedding_id,
            "vector": embedding[0].tolist(),
        }],
    })

In [None]:
# Upload the URL + embedding payloads into a brand-new dataset.
# `iam_integration=None` creates the dataset without a delegated-access
# integration. NOTE: this rebinds `dataset` to the newly created dataset.
dataset = client.create_dataset(name="image_custom_embedding_resnet",
                                iam_integration=None)
task = dataset.create_data_rows(data_rows)
# create_data_rows returns an asynchronous Task; wait for it to finish so the
# errors printed below reflect the actual outcome of the upload.
task.wait_till_done()
print(task.errors)