# Chroma

This notebook covers how to get started with the `Chroma` vector store.

>[Chroma](https://docs.trychroma.com/getting-started) is an AI-native open-source vector database focused on developer productivity and happiness. Chroma is licensed under Apache 2.0. View the full docs of `Chroma` at [this page](https://docs.trychroma.com/reference/py-collection), and find the API reference for the LangChain integration at [this page](https://python.langchain.com/api_reference/chroma/vectorstores/langchain_chroma.vectorstores.Chroma.html).

## Setup

To access `Chroma` vector stores you'll need to install the `langchain-chroma` integration package.

In [None]:
pip install -qU "langchain-chroma>=0.1.2"

### Credentials

You can use the `Chroma` vector store without any credentials; installing the package above is enough.

If you want best-in-class automated tracing of your model calls, you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting the lines below:

In [None]:
# import getpass
# import os
#
# os.environ["LANGSMITH_API_KEY"] = getpass.getpass("Enter your LangSmith API key: ")
# os.environ["LANGSMITH_TRACING"] = "true"

## Initialization

### Basic Initialization

Below is a basic initialization, including the use of a directory to save the data locally.


In [None]:
# | output: false
# | echo: false
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

In [None]:
from langchain_chroma import Chroma

vector_store = Chroma(
    collection_name="example_collection",
    embedding_function=embeddings,
    persist_directory="./chroma_langchain_db",  # Where to save data locally, remove if not necessary
)

### Initialization from client

You can also initialize from a `Chroma` client, which is particularly useful if you want easier access to the underlying database.

In [None]:
import chromadb

persistent_client = chromadb.PersistentClient()
collection = persistent_client.get_or_create_collection("collection_name")
collection.add(ids=["1", "2", "3"], documents=["a", "b", "c"])

vector_store_from_client = Chroma(
    client=persistent_client,
    collection_name="collection_name",
    embedding_function=embeddings,
)

## Manage vector store

Once you have created your vector store, you can interact with it by adding, updating, and deleting items.

### Add items to vector store

We can add items to our vector store by using the `add_documents` function.

In [None]:
from uuid import uuid4

from langchain_core.documents import Document

document_1 = Document(
    page_content="I had chocolate chip pancakes and scrambled eggs for breakfast this morning.",
    metadata={"source": "tweet"},
    id=1,
)

document_2 = Document(
    page_content="The weather forecast for tomorrow is cloudy and overcast, with a high of 62 degrees.",
    metadata={"source": "news"},
    id=2,
)

document_3 = Document(
    page_content="Building an exciting new project with LangChain - come check it out!",
    metadata={"source": "tweet"},
    id=3,
)

document_4 = Document(
    page_content="Robbers broke into the city bank and stole $1 million in cash.",
    metadata={"source": "news"},
    id=4,
)

document_5 = Document(
    page_content="Wow! That was an amazing movie. I can't wait to see it again.",
    metadata={"source": "tweet"},
    id=5,
)

document_6 = Document(
    page_content="Is the new iPhone worth the price? Read this review to find out.",
    metadata={"source": "website"},
    id=6,
)

document_7 = Document(
    page_content="The top 10 soccer players in the world right now.",
    metadata={"source": "website"},
    id=7,
)

document_8 = Document(
    page_content="LangGraph is the best framework for building stateful, agentic applications!",
    metadata={"source": "tweet"},
    id=8,
)

document_9 = Document(
    page_content="The stock market is down 500 points today due to fears of a recession.",
    metadata={"source": "news"},
    id=9,
)

document_10 = Document(
    page_content="I have a bad feeling I am going to get deleted :(",
    metadata={"source": "tweet"},
    id=10,
)

documents = [
    document_1,
    document_2,
    document_3,
    document_4,
    document_5,
    document_6,
    document_7,
    document_8,
    document_9,
    document_10,
]
uuids = [str(uuid4()) for _ in range(len(documents))]

vector_store.add_documents(documents=documents, ids=uuids)

['f22ed484-6db3-4b76-adb1-18a777426cd6',
 'e0d5bab4-6453-4511-9a37-023d9d288faa',
 '877d76b8-3580-4d9e-a13f-eed0fa3d134a',
 '26eaccab-81ce-4c0a-8e76-bf542647df18',
 'bcaa8239-7986-4050-bf40-e14fb7dab997',
 'cdc44b38-a83f-4e49-b249-7765b334e09d',
 'a7a35354-2687-4bc2-8242-3849a4d18d34',
 '8780caf1-d946-4f27-a707-67d037e9e1d8',
 'dec6af2a-7326-408f-893d-7d7d717dfda9',
 '3b18e210-bb59-47a0-8e17-c8e51176ea5e']
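
Because the IDs above come from `uuid4`, every run of the cell produces a fresh set of IDs. If repeatable IDs are preferable, one option is to derive them from the document content. The following is a minimal sketch under that assumption; `stable_id` and the choice of `NAMESPACE_URL` are illustrative and not part of LangChain or Chroma.

```python
from uuid import NAMESPACE_URL, uuid5


def stable_id(page_content: str, source: str) -> str:
    """Hypothetical helper: derive a repeatable ID from a document's content and source."""
    # uuid5 hashes the namespace and the name, so equal inputs always give equal IDs.
    return str(uuid5(NAMESPACE_URL, f"{source}:{page_content}"))


texts = [
    ("tweet", "Building an exciting new project with LangChain - come check it out!"),
    ("news", "Robbers broke into the city bank and stole $1 million in cash."),
]
stable_ids = [stable_id(content, source) for source, content in texts]
print(stable_ids)
```

A list built this way could be passed as `ids=` in place of `uuids`, with the caveat that two documents with identical content and source would then share one ID.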

### Update items in vector store

Now that we have added documents to our vector store, we can update existing documents by using the `update_documents` function.

In [None]:
updated_document_1 = Document(
    page_content="I had chocolate chip pancakes and fried eggs for breakfast this morning.",
    metadata={"source": "tweet"},
    id=1,
)

updated_document_2 = Document(
    page_content="The weather forecast for tomorrow is sunny and warm, with a high of 82 degrees.",
    metadata={"source": "news"},
    id=2,
)

vector_store.update_document(document_id=uuids[0], document=updated_document_1)
# You can also update multiple documents at once
vector_store.update_documents(
    ids=uuids[:2], documents=[updated_document_1, updated_document_2]
)

### Delete items from vector store

We can also delete items from our vector store as follows:

In [None]:
vector_store.delete(ids=[uuids[-1]])  # ids takes a list of IDs

## Query vector store

Once your vector store has been created and the relevant documents have been added, you will most likely want to query it while running your chain or agent.

### Query directly

#### Similarity search

Performing a simple similarity search can be done as follows:

In [None]:
results = vector_store.similarity_search(
    "LangChain provides abstractions to make working with LLMs easy",
    k=2,
    filter={"source": "tweet"},
)
for res in results:
    print(f"* {res.page_content} [{res.metadata}]")

* Building an exciting new project with LangChain - come check it out! [{'source': 'tweet'}]
* LangGraph is the best framework for building stateful, agentic applications! [{'source': 'tweet'}]
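
The `filter` argument restricts results to documents whose metadata matches. The sketch below mimics only the equality case in plain Python to show the idea. `matches` and `records` are hypothetical names; the real filtering happens inside Chroma, which supports more operators than are modeled here.

```python
from typing import Any, Dict, List


def matches(metadata: Dict[str, Any], where: Dict[str, Any]) -> bool:
    """Return True when every key in `where` equals the value stored in `metadata`."""
    for key, expected in where.items():
        # Accept the shorthand {"source": "tweet"} and the explicit {"source": {"$eq": "tweet"}}.
        if isinstance(expected, dict):
            expected = expected["$eq"]
        if metadata.get(key) != expected:
            return False
    return True


records: List[Dict[str, Any]] = [
    {"text": "Building an exciting new project with LangChain", "metadata": {"source": "tweet"}},
    {"text": "The stock market is down 500 points today", "metadata": {"source": "news"}},
    {"text": "The top 10 soccer players in the world right now.", "metadata": {"source": "website"}},
]

tweets = [r["text"] for r in records if matches(r["metadata"], {"source": "tweet"})]
print(tweets)
```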


#### Similarity search with score

If you want to execute a similarity search and receive the corresponding scores, you can run the cell below. The score returned is a distance, so lower values indicate closer matches:

In [None]:
results = vector_store.similarity_search_with_score(
    "Will it be hot tomorrow?", k=1, filter={"source": "news"}
)
for res, score in results:
    print(f"* [SIM={score:3f}] {res.page_content} [{res.metadata}]")

* [SIM=1.726390] The stock market is down 500 points today due to fears of a recession. [{'source': 'news'}]
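
To make the direction of the score concrete, here is a small sketch that compares squared Euclidean distance with cosine similarity on toy vectors. It assumes the collection uses Chroma's default `l2` space; the vectors and function names are made up for illustration and stand in for real embeddings.

```python
import math
from typing import Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance: 0.0 means identical, larger means less similar."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


# Toy unit-length vectors standing in for real embeddings.
query = [1.0, 0.0]
close = [0.8, 0.6]
far = [0.0, 1.0]

for name, vec in [("close", close), ("far", far)]:
    d = squared_l2(query, vec)
    cos = cosine_similarity(query, vec)
    # For unit vectors the two measures are linked by d = 2 - 2 * cos.
    print(f"{name}: distance={d:.3f} cosine={cos:.3f} 2-2cos={2 - 2 * cos:.3f}")
```

Under this reading, a distance well above 1 for unit-length embeddings corresponds to a low cosine similarity, that is, a weak match.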


#### Search by vector

You can also search by vector:

In [None]:
results = vector_store.similarity_search_by_vector(
    embedding=embeddings.embed_query("I love green eggs and ham!"), k=1
)
for doc in results:
    print(f"* {doc.page_content} [{doc.metadata}]")

* I had chocolate chip pancakes and fried eggs for breakfast this morning. [{'source': 'tweet'}]


#### Other search methods

There are a variety of other search methods that are not covered in this notebook, such as MMR search. For a full list of the search abilities available for `Chroma`, check out the [API reference](https://python.langchain.com/api_reference/chroma/vectorstores/langchain_chroma.vectorstores.Chroma.html).

### Query by turning into retriever

You can also transform the vector store into a retriever for easier usage in your chains. For more information on the different search types and kwargs you can pass, please visit the API reference [here](https://python.langchain.com/api_reference/chroma/vectorstores/langchain_chroma.vectorstores.Chroma.html#langchain_chroma.vectorstores.Chroma.as_retriever).

In [None]:
retriever = vector_store.as_retriever(
    search_type="mmr", search_kwargs={"k": 1, "fetch_k": 5}
)
retriever.invoke("Stealing from the bank is a crime", filter={"source": "news"})

[Document(metadata={'source': 'news'}, page_content='Robbers broke into the city bank and stole $1 million in cash.')]
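
The retriever above uses `search_type="mmr"` (maximal marginal relevance): it fetches `fetch_k` candidates and then picks `k` of them, balancing relevance against redundancy. The following plain-Python sketch illustrates that selection step on toy vectors. It is not LangChain's implementation; `mmr_select` is a hypothetical name, and `lambda_mult` mirrors the name of the LangChain parameter that controls the trade-off.

```python
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def mmr_select(
    query: Sequence[float],
    candidates: List[Sequence[float]],
    k: int,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Pick k candidate indices, trading relevance to the query against redundancy.

    Vectors are assumed to be (close to) unit length, so the dot product
    approximates cosine similarity.
    """
    selected: List[int] = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:
        def score(i: int) -> float:
            relevance = dot(query, candidates[i])
            # Redundancy is the highest similarity to anything already chosen.
            redundancy = max((dot(candidates[i], candidates[j]) for j in selected), default=0.0)
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy
        best = max(remaining, key=score)
        selected.append(best)
        remaining.remove(best)
    return selected


query = [1.0, 0.0]
# Stand-ins for the fetch_k nearest neighbours: two near-duplicates and one distinct vector.
fetched = [[0.8, 0.6], [0.78, 0.626], [0.6, -0.8]]
print(mmr_select(query, fetched, k=2))                   # balances relevance and diversity
print(mmr_select(query, fetched, k=2, lambda_mult=1.0))  # relevance only
```

With the default weighting, the second pick skips the near-duplicate in favor of the distinct vector; with `lambda_mult=1.0` the selection reduces to plain nearest-neighbor ranking.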

## Usage for retrieval-augmented generation

For guides on how to use this vector store for retrieval-augmented generation (RAG), see the following sections:

- [Tutorials](/docs/tutorials/)
- [How-to: Question and answer with RAG](https://python.langchain.com/docs/how_to/#qa-with-rag)
- [Retrieval conceptual docs](https://python.langchain.com/docs/concepts/retrieval)

## API reference

For detailed documentation of all `Chroma` vector store features and configurations head to the API reference: https://python.langchain.com/api_reference/chroma/vectorstores/langchain_chroma.vectorstores.Chroma.html

In [44]:
import torch
import time
import psutil
import os
import tqdm
import numpy as np
from geoclip import LocationEncoder
from torch.cuda.amp import autocast, GradScaler

# Utility to measure memory usage
def get_memory_usage():
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss  # Resident Set Size in bytes
    return mem / (1024 ** 2)  # Convert to MB

# Generate random GPS locations
def generate_random_locations(num_locations=100000):
    latitudes = np.random.uniform(-90, 90, num_locations)
    longitudes = np.random.uniform(-180, 180, num_locations)
    return torch.tensor(np.column_stack((latitudes, longitudes)), dtype=torch.float32)

# Main embedding and evaluation function
def embed_and_evaluate(
    gps_encoder,
    gps_data,
    batch_size=4096,
    output_file="location_embeddings.pt",
    mixed_precision=True
):
    """
    Embed GPS locations and measure performance.
    Args:
        gps_encoder (LocationEncoder): The GPS encoder model.
        gps_data (torch.Tensor): GPS data to embed.
        batch_size (int): Number of locations per batch for embedding.
        output_file (str): Path to save embeddings.
        mixed_precision (bool): Whether to use mixed precision for faster computation.
    """
    num_locations = gps_data.shape[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gps_encoder.to(device)
    gps_data = gps_data.to(device)

    # Initialize storage for embeddings
    embeddings = torch.empty((num_locations, 512), device=device)  # Assuming 512 output dimensions

    scaler = GradScaler() if mixed_precision else None
    start_time = time.time()

    print(f"Embedding {num_locations} locations in batches of {batch_size}...")
    for i in tqdm.tqdm(range(0, num_locations, batch_size), desc="Embedding Progress", unit="batch"):
        batch = gps_data[i:i+batch_size]
        with torch.no_grad():
            if mixed_precision:
                with autocast():
                    embeddings[i:i+batch_size] = gps_encoder(batch)
            else:
                embeddings[i:i+batch_size] = gps_encoder(batch)

    total_time = time.time() - start_time
    total_memory = get_memory_usage()

    # Save embeddings directly from GPU
    torch.save(embeddings, output_file)

    # Metrics
    throughput = num_locations / total_time  # Locations per second
    print(f"\nEmbedding complete!")
    print(f"Total locations: {num_locations}")
    print(f"Total time taken: {total_time:.2f} seconds")
    print(f"Throughput: {throughput:.2f} locations/sec")
    print(f"Memory used: {total_memory:.2f} MB")
    print(f"Embeddings saved to: {output_file}")

    # Return metrics for further use
    return {
        "num_locations": num_locations,
        "total_time": total_time,
        "throughput": throughput,
        "memory_used": total_memory,
        "embedding_shape": embeddings.shape
    }

# Main execution
if __name__ == "__main__":
    # Initialize encoder
    gps_encoder = LocationEncoder()

    # Generate random GPS data
    gps_data = generate_random_locations(num_locations=10000)

    # Embed and evaluate
    metrics = embed_and_evaluate(
        gps_encoder,
        gps_data,
        batch_size=4096,
        output_file="location_embeddings.pt",
        mixed_precision=True  # Enable mixed precision
    )

    # Print detailed metrics
    print("\nPerformance Metrics:")
    for key, value in metrics.items():
        print(f"{key}: {value}")

Embedding 10000 locations in batches of 4096...
Embedding Progress: 100%|██████████| 3/3 [00:20<00:00,  6.89s/batch]

Embedding complete!
Total locations: 10000
Total time taken: 20.68 seconds
Throughput: 483.56 locations/sec
Memory used: 2939.89 MB
Embeddings saved to: location_embeddings.pt

Performance Metrics:
num_locations: 10000
total_time: 20.679872751235962
throughput: 483.56196966455394
memory_used: 2939.890625
embedding_shape: torch.Size([10000, 512])
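
The figures in this output can be cross-checked with a little arithmetic. The sketch below recomputes the batch count, per-batch sizes, throughput, and time per batch from the values reported above; the variable names are illustrative.

```python
import math

num_locations = 10_000
batch_size = 4096
total_time = 20.68  # seconds, taken from the run above

num_batches = math.ceil(num_locations / batch_size)
# Mirrors gps_data[i:i+batch_size]: slicing clamps the final batch to the data length.
batch_sizes = [
    min(batch_size, num_locations - start)
    for start in range(0, num_locations, batch_size)
]

print(num_batches, batch_sizes)
print(f"{num_locations / total_time:.2f} locations/sec")
print(f"{total_time / num_batches:.2f} s/batch")
```

The last batch is smaller than the others, which is why the pre-allocated `embeddings` tensor is filled with slice assignment rather than fixed-size writes.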





In [45]:
from chromadb import Documents, EmbeddingFunction, Embeddings
import torch
from geoclip import LocationEncoder
import numpy as np

class GeoClipEmbeddingFunction(EmbeddingFunction):
    def __init__(self):
        self.gps_encoder = LocationEncoder()

    def __call__(self, input: Documents) -> Embeddings:
        embeddings = []
        invalid_count = 0  # Keep track of invalid documents
        for document in input:
            try:
                lat, lon = map(float, document.strip().split(',')) # Added .strip()
                # Validate latitude and longitude ranges
                if not (-90 <= lat <= 90 and -180 <= lon <= 180):
                    raise ValueError("Latitude and longitude out of range")

                gps_data = torch.tensor([[lat, lon]], dtype=torch.float32) # Explicit dtype
                gps_embedding = self.gps_encoder(gps_data).tolist()[0]
                embeddings.append(gps_embedding)
            except (ValueError, IndexError) as e:
                print(f"Warning: Could not parse document as valid lat,lon: '{document}'. Error: {e}")
                embeddings.append(np.zeros(512).tolist()) # Append a zero vector
                invalid_count += 1

        if invalid_count > 0:
            print(f"Warning: {invalid_count} out of {len(input)} documents could not be parsed as valid lat,lon.")

        return embeddings


# Example usage:
embedding_function = GeoClipEmbeddingFunction()
documents = ["40.7128,-74.0060", "34.0522,-118.2437", "invalid input", "91,-181", "41.8781,-87.6298", "40.7128, -74.0060"] # Added invalid and spaced input
embeddings = embedding_function(documents)
print(f"Generated {len(embeddings)} embeddings.")
print(embeddings)

# Example Chroma usage (assuming you have a Chroma client):
import chromadb
client = chromadb.Client()

collection = client.create_collection(name="my_locations", embedding_function=embedding_function)

ids = ["id1", "id2", "id3", "id4", "id5", "id6"]
collection.add(documents=documents, ids=ids)

query_results = collection.query(query_texts=["40.7,-74.0"], n_results=3)
print("\nChroma Query Results:")
print(query_results)

query_results_2 = collection.query(query_texts=["90,0"], n_results=3) # Querying near a pole
print("\nChroma Query Results near North Pole:")
print(query_results_2)

client.delete_collection(name="my_locations") # Clean up

Generated 6 embeddings.
[array([-1.03931492e-02,  7.88416248e-03,  1.30271316e-02,  2.85092555e-03,
       -2.33397912e-03, -1.76571161e-02,  4.16676700e-03, -1.17765460e-03,
       -2.52922997e-03, -8.62812251e-03, -1.32290144e-02,  5.69932163e-05,
       -2.76782829e-03,  4.55948524e-04,  1.40813529e-03,  5.91676682e-04,
        4.80296649e-03, -1.39843225e-02,  1.87701173e-03,  1.86796859e-02,
        2.86985189e-03,  4.59994189e-03, -3.32547911e-03, -1.33848917e-02,
       -5.43905888e-03, -3.39362537e-03, -1.54661015e-03,  2.00032070e-03,
        1.46485260e-02, -6.98821060e-03,  5.97658567e-04, -8.34883936e-03,
        1.63335986e-02,  1.45360790e-02, -7.30384327e-03,  1.18450401e-03,
       -3.68343201e-04,  1.11040249e-02, -9.40784719e-03, -1.67028010e-02,
        7.43302051e-03,  4.18525189e-03, -2.39410438e-03,  6.81754574e-03,
        4.71056625e-03,  2.04296364e-03,  2.32487079e-03, -1.39349718e-02,
       -7.92150479e-03,  7.67707825e-03,  1.26732569e-02, -7.76223699e-03,


In [65]:
from chromadb.utils import embedding_functions
from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction
try:
    from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction
    print("GeoClipEmbeddingFunction imported successfully!")
except ModuleNotFoundError as e:
    print(f"Error importing GeoClipEmbeddingFunction: {e}")
# Instantiate the GeoCLIP embedding function
geoclip_ef = GeoClipEmbeddingFunction()

ModuleNotFoundError: No module named 'chromadb.utils.embedding_functions.geoclip_embedding_function'

In [67]:
import os
import sys

# Create necessary directories if they don't exist
os.makedirs('/content/chroma-geoclip/chromadb/utils/embedding_functions', exist_ok=True)

# Add the parent directory to Python path
sys.path.append('/content/chroma-geoclip')

# Create __init__.py files to mark directories as Python packages
directories = [
    '/content/chroma-geoclip/chromadb',
    '/content/chroma-geoclip/chromadb/utils',
    '/content/chroma-geoclip/chromadb/utils/embedding_functions'
]

for directory in directories:
    init_path = os.path.join(directory, '__init__.py')
    if not os.path.exists(init_path):
        with open(init_path, 'w') as f:
            pass

# Install the package in development mode
os.chdir('/content/chroma-geoclip')
!pip install -e .

Obtaining file:///content/chroma-geoclip
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: chromadb
  Building editable for chromadb (pyproject.toml) ... done
  Created wheel for chromadb: filename=chromadb-0.1.dev2018-0.editable-py3-none-any.whl size=10283 sha256=e478b2df84c317bd2ebd8e21d96cee5e0dd7649240e54eeac1cca669b6b34c06
  Stored in directory: /tmp/pip-ephem-wheel-cache-f424yvx0/wheels/2c/de/86/4269e693e61b70d7754c4ac4a754634b0de125f377ef78160f
Successfully built chromadb
Installing collected packages: chromadb
  Attempting uninstall: chromadb
    Found existing installation: chromadb 0.5.23
    Uninstalling chromadb-0.5.23:
      Successfully uninstalled chromadb-0.5.23
ERROR: pip's dependency resolver does not currentl

In [60]:
import sys

# Save the original sys.path
original_sys_path = sys.path.copy()

# Add your custom base path (adjust this to your base path)
custom_base_path = "/content/chroma-geoclip"
if custom_base_path not in sys.path:
    sys.path.append(custom_base_path)

print(f"Custom path added: {custom_base_path}")
print("Updated Python Path:", sys.path)

Custom path added: /content/chroma-geoclip
Updated Python Path: ['/content', '/env/python', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/usr/local/lib/python3.10/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.10/dist-packages/IPython/extensions', '/usr/local/lib/python3.10/dist-packages/setuptools/_vendor', '/root/.ipython', '/tmp/tmpup3qmamd', '/content/geo-clip', '/content/geo-clip', '/content/geo-clip', '/content/geo-clip/geoclip', '/content/geo-clip/geoclip/model', '/content/chroma-geoclip']
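
The path list above contains repeated entries, and entries earlier in `sys.path` take precedence over later ones. When a local checkout and an installed package share a name, it helps to check which file an import would actually resolve to. Here is a minimal sketch using only the standard library; `add_to_path` and `resolved_location` are hypothetical helper names.

```python
import importlib.util
import sys
from typing import Optional


def add_to_path(path: str) -> bool:
    """Append `path` to sys.path once; return True if it was added."""
    if path in sys.path:
        return False
    sys.path.append(path)
    return True


def resolved_location(module_name: str) -> Optional[str]:
    """Return the file a top-level module would be imported from, or None if not found."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


# "json" is a standard-library stand-in; in this notebook you would check "chromadb".
print(resolved_location("json"))
print(resolved_location("module_that_does_not_exist"))
```

If the reported location points at the local checkout rather than the installed package, imports of names defined only in the installed package will fail.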


In [1]:
import os
import sys

# Set the correct working directory
os.chdir('/content/chroma-geoclip')

# Create the full directory structure
directory_path = 'chromadb/utils/embedding_functions'
os.makedirs(directory_path, exist_ok=True)

# Create __init__.py files in all parent directories
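for path in ['chromadb', 'chromadb/utils', directory_path]:
    init_path = os.path.join(path, '__init__.py')
    # Only create missing files: opening an existing __init__.py with 'w' would truncate it
    if not os.path.exists(init_path):
        with open(init_path, 'w') as f:
            pass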

# Install the package
!pip install -e .

# Add the directory to Python path
sys.path.insert(0, '/content/chroma-geoclip')

Obtaining file:///content/chroma-geoclip
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: chromadb
  Building editable for chromadb (pyproject.toml) ... done
  Created wheel for chromadb: filename=chromadb-0.1.dev2018-0.editable-py3-none-any.whl size=10283 sha256=814400f1039ee974cfc0805fd6d70e87157c8797c2fcf357a9ca23ba5c37dec0
  Stored in directory: /tmp/pip-ephem-wheel-cache-cchsxyij/wheels/2c/de/86/4269e693e61b70d7754c4ac4a754634b0de125f377ef78160f
Successfully built chromadb
Installing collected packages: chromadb
  Attempting uninstall: chromadb
    Found existing installation: chromadb 0.1.dev2018
    Uninstalling chromadb-0.1.dev2018:
      Successfully uninstalled chromadb-0.1.dev2018
ERROR: pip's dependency resolver do

In [57]:
!git clone https://github.com/Latticeworks1/chroma-geoclip.git

Cloning into 'chroma-geoclip'...
remote: Enumerating objects: 22514, done.
remote: Counting objects: 100% (106/106), done.
remote: Compressing objects: 100% (86/86), done.
remote: Total 22514 (delta 48), reused 21 (delta 20), pack-reused 22408 (from 3)
Receiving objects: 100% (22514/22514), 45.18 MiB | 14.16 MiB/s, done.
Resolving deltas: 100% (14390/14390), done.


In [4]:
import importlib
import logging
from typing import Optional, List
import numpy as np
import torch
from geoclip import LocationEncoder

# Define top-level types for ChromaDB integration
Document = str
Documents = List[str]
Embedding = List[float]
Embeddings = List[List[float]]

logger = logging.getLogger(__name__)

class GeoClipEmbeddingFunction:
    """
    A custom embedding function for geographic coordinates using GeoCLIP.
    Compatible with ChromaDB's embedding interface.
    """

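    def __init__(self, model_path: Optional[str] = None) -> None:
        try:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # LocationEncoder() loads its bundled pretrained weights. Custom weights are
            # loaded explicitly, assuming model_path points to a compatible state dict.
            self._encoder = LocationEncoder()
            if model_path:
                self._encoder.load_state_dict(torch.load(model_path, map_location=self._device))
            self._encoder.to(self._device)
            self._encoder.eval()
            logger.info(f"GeoCLIP model initialized on {self._device}")
        except Exception as e:
            raise ValueError(f"GeoCLIP initialization failed: {e}") from e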

    def __call__(self, input: Documents) -> Embeddings:
        coordinates_batch = []
        valid_indices = []

        for idx, item in enumerate(input):
            try:
                lat, lon = map(float, item.strip().split(','))
                if -90 <= lat <= 90 and -180 <= lon <= 180:
                    coordinates_batch.append([lat, lon])
                    valid_indices.append(idx)
                else:
                    logger.warning(f"Invalid coordinates: {item}")
            except (ValueError, IndexError) as e:
                logger.warning(f"Parsing error for '{item}': {e}")

        # One separate zero vector per input, so entries never share a single list object
        embeddings: Embeddings = [[0.0] * 512 for _ in input]

        if coordinates_batch:
            try:
                coordinates = torch.tensor(coordinates_batch, dtype=torch.float32).to(self._device)
                with torch.no_grad():
                    batch_embeddings = self._encoder(coordinates).cpu().numpy()

                for idx, embedding in zip(valid_indices, batch_embeddings):
                    embeddings[idx] = embedding.tolist()

            except Exception as e:
                logger.error(f"Embedding generation error: {e}")

        return embeddings

    @staticmethod
    def validate_coordinates(coord_str: str) -> bool:
        """Validate coordinate string format and ranges."""
        try:
            lat, lon = map(float, coord_str.strip().split(','))
            return -90 <= lat <= 90 and -180 <= lon <= 180
        except (ValueError, AttributeError):
            return False
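
The class above substitutes a zero vector for any input it cannot parse. An alternative is to separate valid from invalid inputs before embedding, so that malformed strings never reach the collection. The sketch below shows only that parsing step in plain Python, under the same "lat,lon" convention; `parse_lat_lon` is a hypothetical helper, not part of GeoCLIP or Chroma.

```python
import math
from typing import List, Optional, Tuple


def parse_lat_lon(text: str) -> Optional[Tuple[float, float]]:
    """Parse "lat,lon" into floats; return None if malformed or out of range."""
    parts = text.split(",")
    if len(parts) != 2:
        return None
    try:
        lat, lon = float(parts[0]), float(parts[1])
    except ValueError:
        return None
    # float() accepts "nan" and "inf", so reject non-finite values explicitly.
    if not (math.isfinite(lat) and math.isfinite(lon)):
        return None
    if not (-90 <= lat <= 90 and -180 <= lon <= 180):
        return None
    return lat, lon


samples = ["40.7128,-74.0060", "40.7128, -74.0060", "invalid input", "91,-181", "nan,0"]
valid: List[Tuple[float, float]] = []
rejected: List[str] = []
for s in samples:
    parsed = parse_lat_lon(s)
    if parsed is None:
        rejected.append(s)
    else:
        valid.append(parsed)

print(valid)
print(rejected)
```

Only the entries in `valid` would then be passed to the encoder, and `rejected` can be logged or reported back to the caller.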

In [14]:
# my_embedding_script.py (Create this file)
from typing import List
from chromadb import Client, EmbeddingFunction, Embeddings
from sentence_transformers import SentenceTransformer
import numpy as np

class CustomEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model_name: str = "all-mpnet-base-v2"):
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: List[str]) -> Embeddings:
        # Lowercase the text so queries match regardless of case
        preprocessed_input = [text.lower() for text in input]
        embeddings = self.model.encode(preprocessed_input)

        # Scale each embedding to unit length, leaving all-zero vectors untouched
        normalized_embeddings = []
        for embedding in embeddings:
            norm = np.linalg.norm(embedding)
            if norm != 0:
                normalized_embeddings.append((embedding / norm).tolist())
            else:
                normalized_embeddings.append(embedding.tolist())

        return normalized_embeddings

# Example usage and Chroma integration:
if __name__ == "__main__": # This ensures the code below only runs when the script is executed directly
    embedding_function = CustomEmbeddingFunction()
    documents = ["This is the First document.", "this is the second DOCUMENT!"]

    client = Client() # No persistence for this example. For persistence use chromadb.PersistentClient
    try:
        client.delete_collection("my_collection")
    except Exception:
        # Most likely the collection does not exist yet, so there is nothing to delete
        pass
    collection = client.create_collection("my_collection", embedding_function=embedding_function)
    collection.add(documents=documents, ids=["id1", "id2"])

    results = collection.query(query_texts=["what is the first document"], n_results=1)
    print("Chroma Results (lowercase query):", results)

    results = collection.query(query_texts=["WHAT IS THE FIRST DOCUMENT?"], n_results=1)
    print("Chroma Results (uppercase query):", results)

    results = collection.query(query_texts=["second"], n_results=1)
    print("Chroma Results (second query):", results)

In [15]:
import importlib
import logging
from typing import Optional, List
import numpy as np
import torch
from geoclip import LocationEncoder
from chromadb import Client
from chromadb.api.types import Document, Documents, Embedding, Embeddings

# Define logging
logger = logging.getLogger(__name__)

# GeoCLIP Embedding Function
class GeoClipEmbeddingFunction:
    """
    A custom embedding function for geographic coordinates using GeoCLIP.
    Compatible with ChromaDB's embedding interface.
    """
    def __init__(self, model_path: Optional[str] = None) -> None:
        """
        Initialize the GeoCLIP embedding function.

        Args:
            model_path (Optional[str]): Path to custom model weights. If None, use default.
        """
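        try:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # LocationEncoder() loads its bundled pretrained weights. Custom weights are
            # loaded explicitly, assuming model_path points to a compatible state dict.
            self._encoder = LocationEncoder()
            if model_path:
                self._encoder.load_state_dict(torch.load(model_path, map_location=self._device))
            self._encoder.to(self._device)
            self._encoder.eval()
            logger.info(f"GeoCLIP model initialized on {self._device}")
        except Exception as e:
            raise ValueError(f"GeoCLIP initialization failed: {e}") from e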

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for geographic coordinates.

        Args:
            input (Documents): List of latitude,longitude strings.

        Returns:
            Embeddings: List of generated embeddings.
        """
        coordinates_batch = []
        valid_indices = []

        # Parse and validate input
        for idx, item in enumerate(input):
            try:
                lat, lon = map(float, item.strip().split(','))
                if -90 <= lat <= 90 and -180 <= lon <= 180:
                    coordinates_batch.append([lat, lon])
                    valid_indices.append(idx)
                else:
                    logger.warning(f"Invalid coordinates: {item}")
            except (ValueError, IndexError) as e:
                logger.warning(f"Parsing error for '{item}': {e}")

        # Default embeddings for invalid inputs: one separate zero vector per input,
        # so entries never share a single list object
        embeddings: Embeddings = [[0.0] * 512 for _ in input]

        if coordinates_batch:
            try:
                coordinates = torch.tensor(coordinates_batch, dtype=torch.float32).to(self._device)
                with torch.no_grad():
                    batch_embeddings = self._encoder(coordinates).cpu().numpy()

                for idx, embedding in zip(valid_indices, batch_embeddings):
                    embeddings[idx] = embedding.tolist()

            except Exception as e:
                logger.error(f"Embedding generation error: {e}")

        return embeddings

    @staticmethod
    def validate_coordinates(coord_str: str) -> bool:
        """
        Validate coordinate string format and ranges.

        Args:
            coord_str (str): Coordinate string in "lat,lon" format.

        Returns:
            bool: True if valid, False otherwise.
        """
        try:
            lat, lon = map(float, coord_str.strip().split(','))
            return -90 <= lat <= 90 and -180 <= lon <= 180
        except (ValueError, AttributeError):
            return False


# Initialize Chroma Client and GeoCLIP
client = Client()
geo_clip_embedder = GeoClipEmbeddingFunction()

# Create a ChromaDB Collection
collection = client.create_collection(
    name="geo_locations",
    embedding_function=geo_clip_embedder
)

# Example Data
locations = ["40.7128,-74.0060", "34.0522,-118.2437", "51.5074,-0.1278"]  # NYC, LA, London
metadata = [{"city": "NYC"}, {"city": "LA"}, {"city": "London"}]
ids = ["loc1", "loc2", "loc3"]

# Add Data to the Collection
collection.add(documents=locations, metadatas=metadata, ids=ids)

# Query Example
query_results = collection.query(query_texts=["40.7128,-74.0060"], n_results=1)
print("Query Results:", query_results)

ImportError: cannot import name 'Client' from 'chromadb' (/content/chroma-geoclip/chromadb/__init__.py)

In [17]:
%cd /content/chroma-geoclip

/content


In [18]:
import sys
sys.path.append("/content/chroma-geoclip")

In [9]:
try:
    from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction
    print("GeoClipEmbeddingFunction imported successfully!")
except ModuleNotFoundError as e:
    print(f"Error importing GeoClipEmbeddingFunction: {e}")

GeoClipEmbeddingFunction imported successfully!


In [19]:
from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction

In [21]:
import random
from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction
from chromadb import Client

# Initialize GeoCLIP embedding function and ChromaDB client
geo_clip_embedder = GeoClipEmbeddingFunction()
client = Client()

# Create a ChromaDB collection with the GeoCLIP embedding function
geo_collection = client.create_collection(
    name="geo_clip_collection",
    embedding_function=geo_clip_embedder
)

# Add example documents to the collection
example_coords = ["40.7128, -74.0060", "34.0522, -118.2437", "51.5074, -0.1278"]
geo_collection.add(
    documents=example_coords,
    ids=["id1", "id2", "id3"],
)

print("Documents added to the collection!")

# Query the collection
query_results = geo_collection.query(
    query_texts=["40.7128, -74.0060"],  # Query using NYC coordinates
    n_results=2
)

# Display the query results
print("\nQuery Results:")
for result in query_results["documents"]:
    print(result)

# Test with a large dataset of random coordinates
large_coords = [f"{random.uniform(-90, 90):.4f}, {random.uniform(-180, 180):.4f}" for _ in range(10000)]
large_embeddings = geo_clip_embedder(large_coords)

# Save embeddings and coordinates for later use
import torch
output_path = "geo_embeddings.pt"
torch.save({"coordinates": large_coords, "embeddings": large_embeddings}, output_path)

print(f"\nGenerated embeddings for {len(large_coords)} locations.")
print(f"Embeddings saved to {output_path}")

ImportError: cannot import name 'Client' from 'chromadb' (/content/chroma-geoclip/chromadb/__init__.py)

In [22]:
import random
import os
from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction

# Define a simple in-memory embedding store
class SimpleEmbeddingStore:
    def __init__(self):
        self.data = {}

    def add(self, documents, embeddings, ids):
        for doc, emb, id_ in zip(documents, embeddings, ids):
            self.data[id_] = {"document": doc, "embedding": emb}

    def query(self, query_embedding, top_k=5):
        # Perform a simple nearest neighbor search (Euclidean distance)
        results = []
        for id_, record in self.data.items():
            distance = sum((qe - re) ** 2 for qe, re in zip(query_embedding, record["embedding"])) ** 0.5
            results.append((id_, distance, record["document"]))
        results.sort(key=lambda x: x[1])  # Sort by distance
        return results[:top_k]

# Initialize GeoCLIP embedding function and embedding store
geo_clip_embedder = GeoClipEmbeddingFunction()
embedding_store = SimpleEmbeddingStore()

# Add example documents to the embedding store
example_coords = ["40.7128, -74.0060", "34.0522, -118.2437", "51.5074, -0.1278"]
example_ids = ["id1", "id2", "id3"]
embeddings = geo_clip_embedder(example_coords)
embedding_store.add(documents=example_coords, embeddings=embeddings, ids=example_ids)

print("Documents added to the embedding store!")

# Query the embedding store
query_coord = "40.7128, -74.0060"
query_embedding = geo_clip_embedder([query_coord])[0]
query_results = embedding_store.query(query_embedding, top_k=2)

# Display the query results
print("\nQuery Results:")
for result in query_results:
    print(f"ID: {result[0]}, Distance: {result[1]:.4f}, Document: {result[2]}")

# Test with a large dataset of random coordinates
large_coords = [f"{random.uniform(-90, 90):.4f}, {random.uniform(-180, 180):.4f}" for _ in range(10000)]
large_embeddings = geo_clip_embedder(large_coords)

# Save embeddings and coordinates for later use
import torch
output_path = "geo_embeddings.pt"
torch.save({"coordinates": large_coords, "embeddings": large_embeddings}, output_path)

print(f"\nGenerated embeddings for {len(large_coords)} locations.")
print(f"Embeddings saved to {output_path}")

  self.load_state_dict(torch.load(f"{file_dir}/weights/location_encoder_weights.pth"))


Documents added to the embedding store!

Query Results:
ID: id1, Distance: 0.0000, Document: 40.7128, -74.0060
ID: id3, Distance: 0.2311, Document: 51.5074, -0.1278


KeyboardInterrupt: 
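
The run above was interrupted before the final two `print` calls, that is, somewhere in the part of the cell that embeds 10,000 coordinates in a single call and saves them. A minimal sketch of splitting such a list into batches with a progress message follows; `batched` is a hypothetical helper, and the placeholder strings stand in for real coordinates.

```python
from typing import Iterator, List, Sequence


def batched(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


coords = [f"{i}.0, {i}.0" for i in range(250)]  # placeholder strings, not real locations
processed = 0
for batch in batched(coords, batch_size=100):
    # A real run would call the embedding function on `batch` here.
    processed += len(batch)
    print(f"Processed {processed} of {len(coords)} coordinates")
```

Counting processed items, rather than testing the loop index, keeps the progress message correct for any batch size.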

In [24]:
import random
import os
from typing import List, Tuple, Dict, Optional, Any
import numpy as np
import torch
from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction

class GeoEmbeddingStore:
    """
    An in-memory embedding store for geographic coordinates using GeoCLIP embeddings.
    Uses brute-force cosine-distance search and saves/loads embeddings with torch.
    """

    def __init__(self, batch_size: int = 100):
        """
        Initialize the embedding store.

        Args:
            batch_size: Size of batches for processing large datasets
        """
        self.data: Dict[str, Dict[str, Any]] = {}
        self.embedding_function = GeoClipEmbeddingFunction()
        self.batch_size = batch_size
        self._embeddings_matrix: Optional[np.ndarray] = None
        self._needs_update = True

    def add(self, documents: List[str], embeddings: List[List[float]], ids: List[str]) -> None:
        """
        Add documents and their embeddings to the store.

        Args:
            documents: List of coordinate strings
            embeddings: List of embedding vectors
            ids: List of unique identifiers
        """
        if not (len(documents) == len(embeddings) == len(ids)):
            raise ValueError("Length mismatch between documents, embeddings, and IDs")

        for doc, emb, id_ in zip(documents, embeddings, ids):
            if id_ in self.data:
                print(f"Warning: Overwriting existing ID: {id_}")
            self.data[id_] = {
                "document": doc,
                "embedding": np.array(emb, dtype=np.float32)
            }
        self._needs_update = True

    def _update_embeddings_matrix(self) -> None:
        """Update the internal embeddings matrix for efficient search."""
        if self._needs_update:
            self._embeddings_matrix = np.stack([
                self.data[id_]["embedding"] for id_ in self.data
            ])
            self._needs_update = False

    def query(self, query_embedding: List[float], top_k: int = 5) -> List[Tuple[str, float, str]]:
        """
        Query the store using cosine similarity.

        Args:
            query_embedding: Query vector
            top_k: Number of results to return

        Returns:
            List of (id, distance, document) tuples
        """
        self._update_embeddings_matrix()
        query_vec = np.array(query_embedding, dtype=np.float32)

        # Compute cosine similarity
        query_norm = np.linalg.norm(query_vec)
        if query_norm == 0:
            raise ValueError("Query embedding has zero norm")

        query_vec = query_vec / query_norm
        matrix_norms = np.linalg.norm(self._embeddings_matrix, axis=1)
        normalized_matrix = self._embeddings_matrix / matrix_norms[:, np.newaxis]
        similarities = np.dot(normalized_matrix, query_vec)

        # Convert to distances and get top k
        distances = 1 - similarities
        top_indices = np.argsort(distances)[:top_k]

        results = []
        ids = list(self.data.keys())
        for idx in top_indices:
            id_ = ids[idx]
            results.append((
                id_,
                float(distances[idx]),
                self.data[id_]["document"]
            ))

        return results

    def generate_random_coordinates(self, count: int) -> List[str]:
        """
        Generate random geographic coordinates.

        Args:
            count: Number of coordinates to generate

        Returns:
            List of coordinate strings
        """
        return [
            f"{random.uniform(-90, 90):.4f}, {random.uniform(-180, 180):.4f}"
            for _ in range(count)
        ]

    def process_large_dataset(self, coordinates: List[str], output_path: str) -> None:
        """
        Process and save embeddings for a large dataset in batches.

        Args:
            coordinates: List of coordinate strings
            output_path: Path to save the embeddings
        """
        all_embeddings = []

        for i in range(0, len(coordinates), self.batch_size):
            batch = coordinates[i:i + self.batch_size]
            batch_embeddings = self.embedding_function(batch)
            all_embeddings.extend(batch_embeddings)

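            # i is always a multiple of batch_size, so count processed items instead of i + 1
            processed = i + len(batch)
            if processed % 1000 == 0:
                print(f"Processed {processed} coordinates...")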

        torch.save({
            "coordinates": coordinates,
            "embeddings": all_embeddings
        }, output_path)

        print(f"Saved {len(coordinates)} embeddings to {output_path}")

    def load_embeddings(self, path: str) -> None:
        """
        Load previously saved embeddings.

        Args:
            path: Path to the saved embeddings file
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"No embeddings file found at {path}")

        data = torch.load(path)
        embeddings = data["embeddings"]
        coordinates = data["coordinates"]
        ids = [f"id_{i}" for i in range(len(coordinates))]

        self.add(coordinates, embeddings, ids)
        print(f"Loaded {len(coordinates)} embeddings from {path}")

# Usage example
if __name__ == "__main__":
    # Initialize the store
    store = GeoEmbeddingStore()

    # Add example data
    example_coords = [
        "40.7128, -74.0060",  # New York
        "34.0522, -118.2437", # Los Angeles
        "51.5074, -0.1278"    # London
    ]
    example_ids = ["nyc", "la", "london"]
    embeddings = store.embedding_function(example_coords)
    store.add(example_coords, embeddings, example_ids)

    # Query example
    query_coord = "40.7128, -74.0060"  # New York
    query_embedding = store.embedding_function([query_coord])[0]
    results = store.query(query_embedding, top_k=2)

    print("\nQuery Results:")
    for id_, distance, doc in results:
        print(f"ID: {id_}, Distance: {distance:.4f}, Coordinates: {doc}")

    # Generate and process large dataset
    large_coords = store.generate_random_coordinates(10000)
    store.process_large_dataset(large_coords, "geo_embeddings.pt")


Query Results:
ID: nyc, Distance: 0.0000, Coordinates: 40.7128, -74.0060
ID: london, Distance: 0.5433, Coordinates: 51.5074, -0.1278
Saved 10000 embeddings to geo_embeddings.pt
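
The `query` method above ranks by cosine distance (1 minus cosine similarity), whereas the earlier `SimpleEmbeddingStore` used Euclidean distance, so the London distances from the two runs (0.2311 and 0.5433) are not directly comparable. The following is a minimal pure-Python sketch of the same ranking idea on made-up 2-D vectors; `cosine_distance` and `top_k` are hypothetical names, and the vectors are not GeoCLIP output.

```python
import math
from typing import Dict, List, Tuple


def cosine_distance(a: List[float], b: List[float]) -> float:
    """Return 1 - cos(a, b); raises ValueError for a zero-length vector."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("Cosine distance is undefined for a zero vector")
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (norm_a * norm_b)


def top_k(query: List[float], vectors: Dict[str, List[float]], k: int) -> List[Tuple[str, float]]:
    """Rank stored vectors by cosine distance to the query, nearest first."""
    scored = [(id_, cosine_distance(query, vec)) for id_, vec in vectors.items()]
    scored.sort(key=lambda pair: pair[1])
    return scored[:k]


# Toy vectors chosen so the distances are easy to check by hand.
toy = {"same_direction": [2.0, 0.0], "orthogonal": [0.0, 3.0], "opposite": [-1.0, 0.0]}
ranked = top_k([1.0, 0.0], toy, k=3)
print(ranked)
```

Cosine distance ignores vector length, so `[2.0, 0.0]` is at distance 0 from the query `[1.0, 0.0]`; under Euclidean distance it would be at distance 1.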


In [28]:
import logging
from typing import List, Tuple, Dict, Union
from chromadb.api.types import EmbeddingFunction, Embeddings
import torch
from geoclip import LocationEncoder

logger = logging.getLogger(__name__)

# Type alias for clarity
LocationType = Union[Tuple[float, float], Dict[str, float]]
Locations = List[LocationType]

class GeoCLIPEmbeddingFunction(EmbeddingFunction):
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.geoclip_model = LocationEncoder().to(self.device)
            logger.info(f"GeoCLIP model loaded on device: {self.device}")
        except Exception as e:
            logger.exception(f"Error initializing GeoCLIP model: {e}")
            raise  # Re-raise the exception to stop execution

    def __call__(self, locations: Locations) -> Embeddings:
        if not locations:
            logger.warning("Input locations list is empty. Returning empty embeddings.")
            return []

        try:
            gps_data = []
            for location in locations:
                if isinstance(location, tuple) and len(location) == 2:
                    lat, lon = location
                elif isinstance(location, dict) and "latitude" in location and "longitude" in location:
                    lat = location["latitude"]
                    lon = location["longitude"]
                else:
                    logger.error(f"Invalid location format: {location}. Expected tuple (lat, lon) or dict {{'latitude': ..., 'longitude': ...}}")
                    return []  # Give up on the whole batch if any location is invalid

                gps_data.append([lat, lon])

            gps_tensor = torch.tensor(gps_data, dtype=torch.float32).to(self.device)
            with torch.no_grad(): # Add no_grad for inference
                gps_embeddings = self.geoclip_model(gps_tensor).cpu().tolist() # Move back to CPU for Chroma

            return gps_embeddings

        except Exception as e:
            logger.exception(f"Error embedding locations: {e}")
            return []

# Example Usage:
try:
    embedding_function = GeoCLIPEmbeddingFunction()
    locations_tuples: Locations = [(40.7128, -74.0060), (34.0522, -118.2437)]
    embeddings_tuples = embedding_function(locations_tuples)
    print(embeddings_tuples)

    locations_dicts: List[Dict[str, float]] = [{"latitude": 40.7128, "longitude": -74.0060}, {"latitude": 34.0522, "longitude": -118.2437}]
    embeddings_dicts = embedding_function(locations_dicts)
    print(embeddings_dicts)

    locations_mixed: List[object] = [(40.7128, -74.0060), {"latitude": 34.0522, "longitude": -118.2437}, "not a location"]
    embeddings_mixed = embedding_function(locations_mixed)
    print(embeddings_mixed)

    locations_empty: Locations = []
    embeddings_empty = embedding_function(locations_empty)
    print(embeddings_empty)
except Exception as e:
    print(f"Error in example usage: {e}")

ERROR:__main__:Invalid location format: not a location. Expected tuple (lat, lon) or dict {'latitude': ..., 'longitude': ...}


[array([-1.03931390e-02,  7.88415316e-03,  1.30271409e-02,  2.85094045e-03,
       -2.33398378e-03, -1.76571198e-02,  4.16677073e-03, -1.17765041e-03,
       ... (remaining values truncated in the captured output)

In [25]:
import random
import os
from typing import List, Tuple, Dict, Optional, Any
import numpy as np
import requests
import torch
from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction

class GeoEmbeddingStore:
    """
    An in-memory embedding store for geographic coordinates using GeoCLIP embeddings.
    Supports brute-force cosine-distance search, geocoding of city names, and saving embeddings to disk.
    """

    def __init__(self, batch_size: int = 100):
        """
        Initialize the embedding store.

        Args:
            batch_size: Size of batches for processing large datasets
        """
        self.data: Dict[str, Dict[str, Any]] = {}
        self.embedding_function = GeoClipEmbeddingFunction()
        self.batch_size = batch_size
        self._embeddings_matrix: Optional[np.ndarray] = None
        self._needs_update = True

    def add(self, documents: List[str], embeddings: List[List[float]], ids: List[str]) -> None:
        """
        Add documents and their embeddings to the store.

        Args:
            documents: List of coordinate strings
            embeddings: List of embedding vectors
            ids: List of unique identifiers
        """
        if not (len(documents) == len(embeddings) == len(ids)):
            raise ValueError("Length mismatch between documents, embeddings, and IDs")

        for doc, emb, id_ in zip(documents, embeddings, ids):
            if id_ in self.data:
                print(f"Warning: Overwriting existing ID: {id_}")
            self.data[id_] = {
                "document": doc,
                "embedding": np.array(emb, dtype=np.float32)
            }
        self._needs_update = True

    def _update_embeddings_matrix(self) -> None:
        """Update the internal embeddings matrix for efficient search."""
        if self._needs_update:
            self._embeddings_matrix = np.stack([
                self.data[id_]["embedding"] for id_ in self.data
            ])
            self._needs_update = False

    def query(self, query_embedding: List[float], top_k: int = 5) -> List[Tuple[str, float, str]]:
        """
        Query the store using cosine similarity.

        Args:
            query_embedding: Query vector
            top_k: Number of results to return

        Returns:
            List of (id, distance, document) tuples
        """
        self._update_embeddings_matrix()
        query_vec = np.array(query_embedding, dtype=np.float32)

        # Compute cosine similarity
        query_norm = np.linalg.norm(query_vec)
        if query_norm == 0:
            raise ValueError("Query embedding has zero norm")

        query_vec = query_vec / query_norm
        matrix_norms = np.linalg.norm(self._embeddings_matrix, axis=1)
        normalized_matrix = self._embeddings_matrix / matrix_norms[:, np.newaxis]
        similarities = np.dot(normalized_matrix, query_vec)

        # Convert to distances and get top k
        distances = 1 - similarities
        top_indices = np.argsort(distances)[:top_k]

        results = []
        ids = list(self.data.keys())
        for idx in top_indices:
            id_ = ids[idx]
            results.append((
                id_,
                float(distances[idx]),
                self.data[id_]["document"]
            ))

        return results

    def get_coordinates_from_api(self, city_name: str) -> Optional[str]:
        """
        Retrieve coordinates from the Nominatim (OpenStreetMap) search API.

        Args:
            city_name: Name of the city

        Returns:
            Coordinate string "latitude, longitude" or None if not found.
        """
        try:
            url = "https://nominatim.openstreetmap.org/search"
            params = {
                "q": city_name,
                "format": "json",
                "addressdetails": 0,
                "limit": 1,
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            data = response.json()
            if data:
                lat = data[0]["lat"]
                lon = data[0]["lon"]
                return f"{lat}, {lon}"
            else:
                print(f"Warning: Could not find coordinates for {city_name}.")
                return None
        except Exception as e:
            print(f"Error fetching coordinates for {city_name}: {e}")
            return None

    def process_large_dataset(self, coordinates: List[str], output_path: str) -> None:
        """
        Process and save embeddings for a large dataset in batches.

        Args:
            coordinates: List of coordinate strings
            output_path: Path to save the embeddings
        """
        all_embeddings = []

        for i in range(0, len(coordinates), self.batch_size):
            batch = coordinates[i:i + self.batch_size]
            batch_embeddings = self.embedding_function(batch)
            all_embeddings.extend(batch_embeddings)

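            # i is always a multiple of batch_size, so count processed items instead of i + 1
            processed = i + len(batch)
            if processed % 1000 == 0:
                print(f"Processed {processed} coordinates...")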

        torch.save({
            "coordinates": coordinates,
            "embeddings": all_embeddings
        }, output_path)

        print(f"Saved {len(coordinates)} embeddings to {output_path}")

# Usage example
if __name__ == "__main__":
    store = GeoEmbeddingStore()

    # Get coordinates for cities using API
    cities = ["New York", "Los Angeles", "London", "Tokyo", "Paris", "Mumbai", "Moscow", "Sydney", "Cape Town", "Dubai"]
    example_coords = []
    example_ids = []

    for city in cities:
        coords = store.get_coordinates_from_api(city)
        if coords:
            example_coords.append(coords)
            example_ids.append(city.lower().replace(" ", "_"))

    embeddings = store.embedding_function(example_coords)
    store.add(example_coords, embeddings, example_ids)
    print("City coordinates and embeddings added to the store.")

    # Query example
    query_coord = "40.7128, -74.0060"  # New York
    query_embedding = store.embedding_function([query_coord])[0]
    results = store.query(query_embedding, top_k=3)

    print("\nQuery Results:")
    for id_, distance, doc in results:
        print(f"ID: {id_}, Distance: {distance:.4f}, Coordinates: {doc}")

Error fetching coordinates for New York: 403 Client Error: Forbidden for url: https://nominatim.openstreetmap.org/search?q=New+York&format=json&addressdetails=0&limit=1
Error fetching coordinates for Los Angeles: 403 Client Error: Forbidden for url: https://nominatim.openstreetmap.org/search?q=Los+Angeles&format=json&addressdetails=0&limit=1
Error fetching coordinates for London: 403 Client Error: Forbidden for url: https://nominatim.openstreetmap.org/search?q=London&format=json&addressdetails=0&limit=1
Error fetching coordinates for Tokyo: 403 Client Error: Forbidden for url: https://nominatim.openstreetmap.org/search?q=Tokyo&format=json&addressdetails=0&limit=1
Error fetching coordinates for Paris: 403 Client Error: Forbidden for url: https://nominatim.openstreetmap.org/search?q=Paris&format=json&addressdetails=0&limit=1
Error fetching coordinates for Mumbai: 403 Client Error: Forbidden for url: https://nominatim.openstreetmap.org/search?q=Mumbai&format=json&addressdetails=0&limit=1


ValueError: Expected Embedings to be non-empty list or numpy array, got []
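
Every geocoding request above returned 403, so `example_coords` stayed empty and the embedding call received `[]`, which produced the `ValueError`. Nominatim's public usage policy asks clients to send an identifying User-Agent or Referer, and the default one sent by HTTP libraries may be rejected, so that is one plausible cause. The sketch below only assembles the request arguments and adds an early guard for the empty case; `build_geocode_request`, `require_non_empty`, the timeout value, and the placeholder identity are all assumptions made for illustration.

```python
from typing import Any, Dict, List

NOMINATIM_URL = "https://nominatim.openstreetmap.org/search"  # same endpoint as the cell above


def build_geocode_request(city_name: str, user_agent: str) -> Dict[str, Any]:
    """Assemble keyword arguments for requests.get(); performs no network I/O."""
    if not user_agent.strip():
        raise ValueError("A non-empty, identifying User-Agent is required")
    return {
        "url": NOMINATIM_URL,
        "params": {"q": city_name, "format": "json", "addressdetails": 0, "limit": 1},
        "headers": {"User-Agent": user_agent},
        "timeout": 10,  # seconds; an arbitrary illustrative value
    }


def require_non_empty(coords: List[str]) -> List[str]:
    """Fail early with a clear message instead of passing [] to the embedder."""
    if not coords:
        raise RuntimeError("No coordinates were resolved; skipping the embedding step")
    return coords


# The identity string below is a placeholder, not a real application or address.
kwargs = build_geocode_request("New York", "geo-notebook-demo/0.1 (you@example.com)")
print(kwargs["headers"])
# With the requests package installed: response = requests.get(**kwargs)
```

Calling `require_non_empty(example_coords)` before the embedding step would turn the downstream `ValueError` into a message that names the actual problem.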

In [31]:
import logging
from typing import List, Tuple, Dict, Union, cast

from chromadb.api.types import EmbeddingFunction, Embeddings

import torch
from geoclip import LocationEncoder

logger = logging.getLogger(__name__)

# Type alias for clarity
LocationType = Union[Tuple[float, float], Dict[str, float]]
Locations = List[LocationType]

class GeoCLIPEmbeddingFunction(EmbeddingFunction[Locations]):
    def __init__(self):
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.geoclip_model = LocationEncoder().to(self.device)
            logger.info(f"GeoCLIP model loaded on device: {self.device}")
        except ImportError:
            raise ValueError(
                "The geoclip python package is not installed. Please install it with `pip install geoclip`"
            )
        except Exception as e:
            logger.exception(f"Error initializing GeoCLIP model: {e}")
            raise  # Re-raise the exception to stop execution

    def __call__(self, input: Locations) -> Embeddings:
        if not input:
            logger.warning("Input locations list is empty. Returning empty embeddings.")
            return cast(Embeddings, [])

        try:
            gps_data = []
            for location in input:
                if isinstance(location, tuple) and len(location) == 2:
                    lat, lon = location
                elif isinstance(location, dict) and "latitude" in location and "longitude" in location:
                    lat = location["latitude"]
                    lon = location["longitude"]
                else:
                    logger.error(f"Invalid location format: {location}. Expected tuple (lat, lon) or dict {{'latitude': ..., 'longitude': ...}}")
                    return cast(Embeddings, [])

                gps_data.append([lat, lon])

            gps_tensor = torch.tensor(gps_data, dtype=torch.float32).to(self.device)
            with torch.no_grad():
                embeddings = self.geoclip_model(gps_tensor).cpu().tolist()

            return cast(Embeddings, embeddings)

        except Exception as e:
            logger.exception(f"Error embedding locations: {e}")
            return cast(Embeddings, [])

In [33]:
import logging
from typing import List, Tuple, Dict, Union, cast, Optional
from dataclasses import dataclass

import torch
from chromadb.api.types import EmbeddingFunction, Embeddings
from geoclip.model import LocationEncoder

logger = logging.getLogger(__name__)

LocationType = Union[Tuple[float, float], Dict[str, float]]
Locations = List[LocationType]

@dataclass
class GeoCLIPConfig:
    batch_size: int = 32
    validate_coordinates: bool = True
    model_path: Optional[str] = None

class GeoCLIPEmbeddingFunction(EmbeddingFunction[Locations]):
    def __init__(self, config: Optional[GeoCLIPConfig] = None) -> None:
        self.config = config or GeoCLIPConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.geoclip_model = LocationEncoder(self.config.model_path)
        self.geoclip_model.to(self.device).eval()

    def _process_location(self, location: LocationType) -> Optional[List[float]]:
        if isinstance(location, tuple) and len(location) == 2:
            lat, lon = location
        elif isinstance(location, dict) and "latitude" in location and "longitude" in location:
            lat, lon = float(location["latitude"]), float(location["longitude"])
        else:
            return None

        if self.config.validate_coordinates and not (-90 <= lat <= 90 and -180 <= lon <= 180):
            return None

        return [lat, lon]

    def __call__(self, input: Locations) -> Embeddings:
        if not input:
            return cast(Embeddings, [])

        locations = []
        for location in input:
            coords = self._process_location(location)
            if coords:
                locations.append(coords)

        if not locations:
            return cast(Embeddings, [])

        try:
            locations_tensor = torch.tensor(locations, dtype=torch.float32).to(self.device)
            with torch.no_grad():
                embeddings = self.geoclip_model(locations_tensor).cpu().tolist()
            return cast(Embeddings, embeddings)
        except Exception as e:
            logger.error(f"Embedding generation failed: {str(e)}")
            return cast(Embeddings, [])

In [34]:
import chromadb
from chromadb.utils.embedding_functions import EmbeddingFunction
from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoCLIPEmbeddingFunction

client = chromadb.PersistentClient()
embedding_function = GeoCLIPEmbeddingFunction()

collection = client.get_or_create_collection(
    name="locations",
    embedding_function=embedding_function
)

collection.add(
    documents=["40.7128,-74.0060", "51.5074,-0.1278"],
    ids=["nyc", "london"]
)

results = collection.query(
    query_texts=["40.7128,-74.0060"],
    n_results=2
)

ImportError: cannot import name 'EmbeddingFunction' from 'chromadb.utils.embedding_functions' (/content/chroma-geoclip/chromadb/utils/embedding_functions/__init__.py)
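
The failing import is not used anywhere in that cell, and the class definitions in the surrounding cells import `EmbeddingFunction` from `chromadb.api.types` instead. Separately, the cell passes `"lat,lon"` strings as documents, while the `GeoCLIPEmbeddingFunction` definitions in this notebook accept tuples or dicts; if the imported class behaves like those definitions, the strings would need to be parsed first. A minimal sketch of such a parser follows; `parse_coordinate` is a hypothetical helper, not part of Chroma or GeoCLIP.

```python
from typing import Optional, Tuple


def parse_coordinate(text: str) -> Optional[Tuple[float, float]]:
    """Parse "lat,lon" into a (lat, lon) tuple; return None if malformed or out of range."""
    parts = text.split(",")
    if len(parts) != 2:
        return None
    try:
        lat, lon = float(parts[0]), float(parts[1])
    except ValueError:
        return None
    if not (-90 <= lat <= 90 and -180 <= lon <= 180):
        return None
    return (lat, lon)


print(parse_coordinate("40.7128,-74.0060"))
print(parse_coordinate("not a location"))
```

Returning `None` rather than raising lets the caller decide whether to skip, pad, or reject a bad entry.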

In [32]:
import logging
from typing import List, Tuple, Dict, Union, cast, Optional
from dataclasses import dataclass

import torch
from chromadb.api.types import EmbeddingFunction, Embeddings
from geoclip.model import LocationEncoder

logger = logging.getLogger(__name__)

LocationType = Union[Tuple[float, float], Dict[str, float]]
Locations = List[LocationType]

@dataclass
class GeoCLIPConfig:
    batch_size: int = 32
    embedding_dimension: int = 512
    validate_coordinates: bool = True
    model_path: Optional[str] = None

class GeoCLIPEmbeddingFunction(EmbeddingFunction[Locations]):
    def __init__(self, config: Optional[GeoCLIPConfig] = None):
        self.config = config or GeoCLIPConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.geoclip_model = LocationEncoder(self.config.model_path)
        self.geoclip_model.to(self.device).eval()

    def _validate_coordinates(self, lat: float, lon: float) -> bool:
        return -90 <= lat <= 90 and -180 <= lon <= 180

    def _extract_coordinates(self, location: LocationType) -> Optional[Tuple[float, float]]:
        if isinstance(location, tuple) and len(location) == 2:
            lat, lon = location
        elif isinstance(location, dict) and "latitude" in location and "longitude" in location:
            lat = float(location["latitude"])
            lon = float(location["longitude"])
        else:
            return None

        if self.config.validate_coordinates and not self._validate_coordinates(lat, lon):
            return None

        return (lat, lon)

    def _process_batch(self, batch: List[List[float]]) -> List[List[float]]:
        gps_tensor = torch.tensor(batch, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            embeddings = self.geoclip_model(gps_tensor).cpu().tolist()
        return embeddings

    def __call__(self, input: Locations) -> Embeddings:
        if not input:
            return cast(Embeddings, [])

        coordinates_batch = []
        valid_indices = []
        # Zero-vector placeholders, built per row so that rows do not share one list object
        result_embeddings = [[0.0] * self.config.embedding_dimension for _ in input]

        for idx, location in enumerate(input):
            coords = self._extract_coordinates(location)
            if coords:
                coordinates_batch.append([coords[0], coords[1]])
                valid_indices.append(idx)

        if not coordinates_batch:
            return cast(Embeddings, result_embeddings)

        for batch_start in range(0, len(coordinates_batch), self.config.batch_size):
            batch_end = batch_start + self.config.batch_size
            batch = coordinates_batch[batch_start:batch_end]
            batch_indices = valid_indices[batch_start:batch_end]

            batch_embeddings = self._process_batch(batch)

            for idx, embedding in zip(batch_indices, batch_embeddings):
                result_embeddings[idx] = embedding

        return cast(Embeddings, result_embeddings)

    def dimension(self) -> int:
        return self.config.embedding_dimension
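
The `__call__` above keeps its output aligned with its input by remembering the index of each valid location and leaving a zero vector wherever a location was rejected. The same bookkeeping can be sketched without torch or GeoCLIP; `embed_aligned` and `fake_embed` are hypothetical names, and `fake_embed` is a stand-in that simply echoes its input as a 2-d vector.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def embed_aligned(
    items: Sequence[Optional[Tuple[float, float]]],
    embed_batch: Callable[[List[Tuple[float, float]]], List[Vector]],
    dimension: int,
    batch_size: int = 2,
) -> List[Vector]:
    """Embed the non-None items in batches; None items keep a zero vector."""
    results: List[Vector] = [[0.0] * dimension for _ in items]  # independent rows
    valid = [(idx, item) for idx, item in enumerate(items) if item is not None]
    for start in range(0, len(valid), batch_size):
        chunk = valid[start:start + batch_size]
        vectors = embed_batch([item for _, item in chunk])
        for (idx, _), vector in zip(chunk, vectors):
            results[idx] = vector  # write back at the original position
    return results


def fake_embed(batch: List[Tuple[float, float]]) -> List[Vector]:
    """Stand-in for a real model: returns [lat, lon] unchanged as a 2-d 'embedding'."""
    return [[lat, lon] for lat, lon in batch]


aligned = embed_aligned([(1.0, 2.0), None, (3.0, 4.0)], fake_embed, dimension=2)
print(aligned)
```

One trade-off of this design is that a zero vector is indistinguishable from a real embedding to the caller, so invalid inputs are padded silently rather than reported.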

In [20]:
try:
    from chromadb.utils.embedding_functions.geoclip_embedding_function import GeoClipEmbeddingFunction
    print("GeoClipEmbeddingFunction imported successfully!")
except ImportError as e:
    print(f"Error importing GeoClipEmbeddingFunction: {e}")

GeoClipEmbeddingFunction imported successfully!


In [30]:
import logging
import sys

# Add the parent directory to the path so you can import the module correctly
sys.path.append("..")  # Adjust ".." if your directory structure is different
from geoclip_chroma.geoclip_embedding import GeoCLIPEmbeddingFunction

# Configure logging (optional but recommended)
logging.basicConfig(level=logging.INFO)

try:
    embedding_function = GeoCLIPEmbeddingFunction()

    locations_tuples = [(40.7128, -74.0060), (34.0522, -118.2437)]  # NYC, LA
    embeddings_tuples = embedding_function(locations_tuples)
    print("Tuple Embeddings:", embeddings_tuples)

    locations_dicts = [
        {"latitude": 40.7128, "longitude": -74.0060},
        {"latitude": 34.0522, "longitude": -118.2437},
    ]
    embeddings_dicts = embedding_function(locations_dicts)
    print("Dict Embeddings:", embeddings_dicts)

    locations_mixed = [(40.7128, -74.0060), {"latitude": 34.0522, "longitude": -118.2437}, "not a location"]
    embeddings_mixed = embedding_function(locations_mixed)
    print("Mixed Embeddings:", embeddings_mixed)

    locations_empty = []
    embeddings_empty = embedding_function(locations_empty)
    print("Empty Embeddings:", embeddings_empty)

except Exception as e:
    logging.exception(f"Error in example usage: {e}")

ModuleNotFoundError: No module named 'geoclip_chroma'

In [62]:
import sys

base_path = "/content/chroma-geoclip"  # Adjust this to your project directory
if base_path not in sys.path:
    sys.path.append(base_path)

In [46]:
import torch
import clip
from geoclip import LocationEncoder
from PIL import Image
import numpy as np
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Load GeoCLIP encoder
try:
    geoclip_encoder = LocationEncoder()
except Exception as e:
    logger.error(f"Failed to initialize GeoCLIP: {e}")
    raise

def get_image_embedding(image_path):
    """Gets the CLIP embedding for an image."""
    try:
        image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
        with torch.no_grad():
            image_embedding = clip_model.encode_image(image).cpu().numpy().flatten()
        return image_embedding
    except Exception as e:
        logger.error(f"Error processing image {image_path}: {e}")
        return None

def get_geographic_embedding(latitude, longitude):
    """Gets the GeoCLIP embedding for geographic coordinates."""
    try:
        gps_data = torch.tensor([[latitude, longitude]], dtype=torch.float32)
        geo_embedding = geoclip_encoder(gps_data).squeeze().tolist()
        return geo_embedding
    except Exception as e:
        logger.error(f"Error processing geographic coordinates: {e}")
        return None

# Example Usage
image_path = "path/to/your/image.jpg"  # Replace with your image path
latitude = 37.7749  # Example latitude
longitude = -122.4194  # Example longitude

image_embedding = get_image_embedding(image_path)
geo_embedding = get_geographic_embedding(latitude, longitude)

if image_embedding is not None and geo_embedding is not None:
    # Concatenate the embeddings (example fusion)
    combined_embedding = np.concatenate([image_embedding, geo_embedding])
    print("Image Embedding Shape:", image_embedding.shape)
    print("Geographic Embedding Shape:", np.array(geo_embedding).shape)
    print("Combined Embedding Shape:", combined_embedding.shape)

    # Now you can use combined_embedding for similarity search, etc.

else:
    # Report every failure, including the case where both embeddings are missing
    if image_embedding is None:
        print("Could not generate image embedding")
    if geo_embedding is None:
        print("Could not generate geographic embedding")

ModuleNotFoundError: No module named 'clip'
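
The cell above stops at concatenation. One common way to make a fused vector usable for similarity search is to L2-normalize each modality first, so that neither dominates purely because of scale, and then compare fused vectors with cosine similarity. The sketch below uses plain Python lists and made-up toy vectors in place of real CLIP and GeoCLIP outputs; the helper names are illustrative assumptions, not part of either library.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def fuse(image_vec: Sequence[float], geo_vec: Sequence[float]) -> List[float]:
    """Normalize each modality, then concatenate (simple late fusion)."""
    return l2_normalize(image_vec) + l2_normalize(geo_vec)


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    na, nb = l2_normalize(a), l2_normalize(b)
    return sum(x * y for x, y in zip(na, nb))


# Toy vectors standing in for real image / location embeddings.
query = fuse([1.0, 0.0, 0.0], [0.0, 1.0])
same = fuse([2.0, 0.0, 0.0], [0.0, 3.0])   # same directions, different scale
other = fuse([0.0, 1.0, 0.0], [1.0, 0.0])  # orthogonal in both modalities

print(cosine_similarity(query, same))
print(cosine_similarity(query, other))
```

A fused list like `query` could then be stored as a precomputed embedding in a vector store, provided every stored vector is built with the same fusion recipe.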

In [43]:
!pip install geoclip



In [47]:
!git clone https://github.com/VicenteVivan/geo-clip.git

fatal: destination path 'geo-clip' already exists and is not an empty directory.


In [48]:
import sys
sys.path.append('/content/geo-clip')

In [55]:
import os
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from image_encoder import ImageEncoder
from location_encoder import LocationEncoder
from misc import load_gps_data
from PIL import Image

class GeoCLIP(nn.Module):
    def __init__(self, from_pretrained=True, queue_size=4096):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.image_encoder = ImageEncoder()
        self.location_encoder = LocationEncoder()

        # Pick the device first: _load_weights() reads self.device for map_location
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set up paths relative to the current file location
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.gps_gallery = load_gps_data(os.path.join(current_dir, "gps_gallery", "coordinates_100K.csv"))
        self._initialize_gps_queue(queue_size)

        if from_pretrained:
            self.weights_folder = os.path.join(current_dir, "weights")
            self._load_weights()

        # Move all submodules and buffers to the selected device
        self.to(self.device)

    def to(self, device):
        self.device = device
        self.image_encoder.to(device)
        self.location_encoder.to(device)
        self.logit_scale.data = self.logit_scale.data.to(device)
        if hasattr(self, 'gps_gallery'):
            self.gps_gallery = self.gps_gallery.to(device)
        return super().to(device)

    def _load_weights(self):
        try:
            self.image_encoder.mlp.load_state_dict(
                torch.load(os.path.join(self.weights_folder, "image_encoder_mlp_weights.pth"),
                          map_location=self.device)
            )
            self.location_encoder.load_state_dict(
                torch.load(os.path.join(self.weights_folder, "location_encoder_weights.pth"),
                          map_location=self.device)
            )
            self.logit_scale = nn.Parameter(
                torch.load(os.path.join(self.weights_folder, "logit_scale_weights.pth"),
                          map_location=self.device)
            )
        except Exception as e:
            print(f"Error loading weights: {e}")

    def _initialize_gps_queue(self, queue_size):
        self.queue_size = queue_size
        self.register_buffer("gps_queue", torch.randn(2, self.queue_size))
        self.gps_queue = nn.functional.normalize(self.gps_queue, dim=0)
        self.register_buffer("gps_queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def dequeue_and_enqueue(self, gps):
        gps_batch_size = gps.shape[0]
        gps_ptr = int(self.gps_queue_ptr)

        assert self.queue_size % gps_batch_size == 0

        self.gps_queue[:, gps_ptr:gps_ptr + gps_batch_size] = gps.t()
        gps_ptr = (gps_ptr + gps_batch_size) % self.queue_size
        self.gps_queue_ptr[0] = gps_ptr

    def get_gps_queue(self):
        return self.gps_queue.t()

    def forward(self, image, location):
        # Ensure inputs are on the correct device
        image = image.to(self.device)
        location = location.to(self.device)

        # Compute Features
        image_features = self.image_encoder(image)
        location_features = self.location_encoder(location)
        logit_scale = self.logit_scale.exp()

        # Normalize features
        image_features = F.normalize(image_features, dim=1)
        location_features = F.normalize(location_features, dim=1)

        # Compute similarity
        logits_per_image = logit_scale * (image_features @ location_features.t())

        return logits_per_image

    @torch.no_grad()
    def predict(self, image_path, top_k):
        try:
            # Load and preprocess image
            image = Image.open(image_path)
            image = self.image_encoder.preprocess_image(image)
            image = image.to(self.device)

            # Move GPS gallery to device
            gps_gallery = self.gps_gallery.to(self.device)

            # Get predictions
            logits_per_image = self.forward(image, gps_gallery)
            probs_per_image = logits_per_image.softmax(dim=-1).cpu()

            # Get top k predictions
            top_pred = torch.topk(probs_per_image, top_k, dim=1)
            top_pred_gps = self.gps_gallery[top_pred.indices[0]]
            top_pred_prob = top_pred.values[0]

            return top_pred_gps, top_pred_prob

        except Exception as e:
            print(f"Error during prediction: {e}")
            raise

ImportError: attempted relative import with no known parent package
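
The `forward` and `predict` methods above boil down to a few steps: normalize both feature sets, take dot products, multiply by the exponentiated logit scale, apply softmax, and keep the top k entries. Below is a dependency-free sketch of that scoring step, using toy 2-D features instead of real encoder outputs. The function names are illustrative, and the gallery values are invented for the example.

```python
import math
from typing import List, Sequence, Tuple


def normalize(vec: Sequence[float]) -> List[float]:
    """Scale a non-zero vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_locations(
    image_feat: Sequence[float],
    gallery_feats: Sequence[Sequence[float]],
    logit_scale: float,
    k: int,
) -> List[Tuple[int, float]]:
    """Return (gallery index, probability) for the k best-matching entries."""
    img = normalize(image_feat)
    # Scaled cosine similarity between the image and every gallery entry.
    logits = [
        logit_scale * sum(a * b for a, b in zip(img, normalize(g)))
        for g in gallery_feats
    ]
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


scale = 1 / 0.07  # exp(log(1 / 0.07)), the initial logit scale in the class above
gallery = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy location features
result = top_k_locations([0.9, 0.1], gallery, scale, k=2)
print(result)
```

The large scale factor sharpens the softmax, so small differences in cosine similarity turn into large differences in probability.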

In [50]:
!pip install -r /content/geo-clip/requirements.txt



In [56]:
import os

# Create the package directories first; open() cannot create missing parent directories
directories = ['model', 'model/rff', 'model/gps_gallery']
for directory in directories:
    os.makedirs(directory, exist_ok=True)
    init_path = os.path.join(directory, '__init__.py')
    if not os.path.exists(init_path):
        # Write an empty __init__.py so the directory is importable as a package
        with open(init_path, 'w'):
            pass

with open('model/location_encoder.py', 'w') as f:
    f.write('''import torch
import torch.nn as nn
from model.rff import GaussianEncoding
from model.misc import file_dir''')

!pip install -e .