<td>
   <a target="_blank" href="https://labelbox.com" ><img src="https://labelbox.com/blog/content/images/2021/02/logo-v4.svg" width="256" /></a>
</td>

<td>
<a href="https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/custom_embeddings.ipynb" target="_blank"><img
src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
</td>

<td>
<a href="https://github.com/Labelbox/labelbox-python/tree/master/examples/basics/custom_embeddings.ipynb" target="_blank"><img
src="https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white" alt="GitHub"></a>
</td>

# Custom Embeddings

You can improve your data exploration and similarity search experience by adding your own custom embeddings. Labelbox allows you to upload up to 100 different custom embeddings on any kind of data. You can experiment with different embeddings to power your data selection.

# Setup

In [None]:
!pip3 install -q "labelbox"

In [None]:
import json
import os

import labelbox as lb
import numpy as np

In [None]:
# Read the API key from the LABELBOX_API_KEY environment variable when it is set,
# so the secret does not have to be saved inside the notebook. If the variable is
# not set, paste your key in place of the empty-string fallback.
API_KEY = os.environ.get("LABELBOX_API_KEY", "")
client = lb.Client(API_KEY)

# Select data rows in Labelbox for custom embeddings

In [None]:
# Turn on the client's experimental features before using the cells below
client.enable_experimental = True

# Fetch the Labelbox dataset whose images will receive custom embeddings.
# Labelbox starts processing embeddings once there are more than 1000 of a given
# type, so for this demo pick a dataset with over 1000 data rows.
dataset_id = "<ADD YOUR DATASET ID>"
dataset = client.get_dataset(dataset_id)

# Launch the export and block until it has finished
export_task = dataset.export()
export_task.wait_till_done()

In [None]:
# Filled by json_stream_handler with one parsed dict per exported record
data_rows = []


def json_stream_handler(output: lb.JsonConverterOutput):
  """Decode the JSON string carried by `output` and append the result to `data_rows`."""
  data_rows.append(json.loads(output.json_str))


def print_export_error(error):
  """Print a single item received from the export's error stream."""
  print(error)


# Report export errors, if there are any
if export_task.has_errors():
  error_stream = export_task.get_stream(
    converter=lb.JsonConverter(),
    stream_type=lb.StreamType.ERRORS,
  )
  error_stream.start(stream_handler=print_export_error)

# Stream the exported records into `data_rows`
if export_task.has_result():
  result_stream = export_task.get_stream(
    converter=lb.JsonConverter(),
    stream_type=lb.StreamType.RESULT,
  )
  export_json = result_stream.start(stream_handler=json_stream_handler)

In [None]:
# Gather the id of every exported data row
data_row_ids = []
for exported_row in data_rows:
  data_row_ids.append(exported_row["data_row"]["id"])

# For the sake of this demo, keep only the first 1000 examples
data_row_ids = data_row_ids[:1000]

# Create the payload for custom embeddings
- It should be an `.ndjson` (newline-delimited JSON) file.
- Every line is a JSON object terminated by a `\n` character.
- It does not have to be created through Python.

In [None]:
nb_data_rows = len(data_row_ids)
print("Number of data rows: ", nb_data_rows)

# Labelbox supports custom embedding vectors of dimension up to 2048
EMBEDDING_DIMS = 2048

# Generate one random vector per data row. A seeded generator makes the demo
# reproducible across "Restart & Run All", and .tolist() returns nested lists of
# plain Python floats rather than NumPy scalars, so the vectors print and
# serialize as ordinary numbers.
rng = np.random.default_rng(seed=0)
custom_embeddings = rng.random((nb_data_rows, EMBEDDING_DIMS)).tolist()

In [None]:
# Pair each data row id with its vector; every pair becomes one payload record
payload = [
  {"id": row_id, "vector": vector}
  for row_id, vector in zip(data_row_ids, custom_embeddings)
]

# Show the record count and the first record as a quick check
print("payload", len(payload), payload[:1])

In [None]:
# Write the payload as newline-delimited JSON: one JSON object per line, each
# terminated by "\n". Opening the file in "w" mode truncates anything left over
# from a previous run, so no explicit delete is needed beforehand.
with open("payload.ndjson", "w", encoding="utf-8") as f:
  for p in payload:
    f.write(json.dumps(p) + "\n")

In [None]:
# Read the file back line by line to confirm that every record parses as JSON
sanity_check_payload = []
with open("payload.ndjson") as ndjson_file:
  for raw_line in ndjson_file:
    sanity_check_payload.append(json.loads(raw_line))
print("Nb of custom embedding vectors in sanity_check_payload: ", len(sanity_check_payload))

In [None]:
# See all custom embeddings available in your Labelbox workspace
embeddings = client.get_embeddings()
embeddings  # last expression of the cell, so the result is actually displayed

In [None]:
# Register a new custom embedding; skip this cell if you would rather re-use an existing one
embedding_name = "my_custom_embedding_2048_dimensions"
embedding_dims = 2048
embedding = client.create_embedding(embedding_name, embedding_dims)

In [None]:
# Delete a custom embedding.
# The cells below still use `embedding` to upload and count vectors, so running
# the delete unconditionally would break "Run All". It is therefore opt-in:
# set DELETE_EMBEDDING to True only when you really want to remove the embedding.
DELETE_EMBEDDING = False
if DELETE_EMBEDDING:
  embedding.delete()

# Upload the payload to Labelbox

In [None]:
# Upload the NDJSON payload file to the custom embedding held in `embedding`
# (point `embedding` at any existing custom embedding to upload there instead)
payload_path = "./payload.ndjson"
embedding.import_vectors_from_file(payload_path)

# Get the count of imported vectors for a custom embedding

In [None]:
# Count how many data rows have a specific custom embedding (this can take a couple of minutes)
count = embedding.get_imported_vector_count()
# Print the result so the cell shows the count instead of silently storing it
print("Number of imported vectors: ", count)