In [30]:
from openai import embeddings, OpenAI
from dotenv import load_dotenv

# Presumably populates OPENAI_API_KEY / OPENAI_API_BASE / PINECONE_API_KEY
# (read via os.getenv in later cells) from a local .env file.
load_dotenv()

# Natural-language request descriptions, one per line. Line i is paired with
# line i of dataset/dataset.train.query when the vector index is built.
# A context manager closes the file; explicit UTF-8 avoids platform-dependent decoding.
with open("dataset/dataset.train.nl", encoding="utf-8") as nl_file:
    nl_reqs = nl_file.read().splitlines()

print(len(nl_reqs))

6352


In [31]:
from tqdm import tqdm
import os

# Embedding model shared by the training-set embeddings below.
EMBED_MODEL = "text-embedding-3-small"

# base_url comes from OPENAI_API_BASE; when unset it is None, which presumably
# makes the SDK use its default endpoint.
oai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE"))

# Embed all NL requests in fixed-size batches. Each API call returns one
# response object whose `.data` holds one embedding per input string.
batch_size = 150
res = []
for start in tqdm(range(0, len(nl_reqs), batch_size)):
    # Deliberately not named `embeds`: that name is later bound to the flat
    # list of individual embeddings, a different kind of object.
    batch_response = oai.embeddings.create(
        input=nl_reqs[start:start + batch_size],
        model=EMBED_MODEL,
    )
    res.append(batch_response)

100%|███████████████████████████████████████████████████████████████████| 43/43 [01:45<00:00,  2.45s/it]
IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [17]:
# Size of the first API response; every batch except possibly the last should hold batch_size items.
first_batch = res[0]
len(first_batch.data)

150

In [32]:
# Flatten the per-batch API responses into a single list of embedding objects,
# keeping the original request order (batch by batch, item by item).
embeds = []
for batch_response in res:
    embeds.extend(batch_response.data)
len(embeds)

6352

In [33]:
# The Pinecone index is created with dimension=1536, so every vector (not just
# the first) must have exactly that length before upserting.
embedding_dims = {len(e.embedding) for e in embeds}
assert embedding_dims == {1536}, f"unexpected embedding dimensions: {embedding_dims}"
len(embeds[0].embedding)

1536

In [34]:
import pickle

# Back up the embeddings so the embedding API calls need not be repeated.
# The context manager guarantees the file is flushed and closed.
# NOTE(review): these are OpenAI SDK objects; unpickling presumably needs a
# compatible `openai` version, and pickle files must only be loaded from trusted sources.
with open("embed_nl_backup.pkl", "wb") as backup_file:
    pickle.dump(embeds, backup_file)

In [28]:
from pinecone import Pinecone, ServerlessSpec

pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))

# Create the serverless index only if it is not already there, so re-running
# this cell (or Restart & Run All) does not fail on an existing index.
# dimension must equal the length of the text-embedding-3-small vectors (1536).
if "osm-queries" not in pc.list_indexes().names():
    pc.create_index(
        name="osm-queries",
        dimension=1536,
        spec=ServerlessSpec(
            cloud="aws",
            region="us-east-1",
        ),
    )

In [39]:
# Overpass queries, one per line, aligned by line number with nl_reqs / embeds.
with open("dataset/dataset.train.query", encoding="utf-8") as query_file:
    queries = query_file.read().splitlines()

# The three sequences are paired purely by position; a length mismatch would
# silently attach the wrong query or description to an embedding.
assert len(queries) == len(embeds) == len(nl_reqs)

# One Pinecone record per training example; the id is the line position as a string.
vectors = [
    {"id": str(i), "values": emb.embedding, "metadata": {"query": query, "nl": nl}}
    for i, (emb, query, nl) in enumerate(zip(embeds, queries, nl_reqs))
]

index = pc.Index("osm-queries")
upsert_batch_sz = 90
for start in tqdm(range(0, len(vectors), upsert_batch_sz)):
    index.upsert(vectors=vectors[start:start + upsert_batch_sz])

100%|███████████████████████████████████████████████████████████████████| 71/71 [00:47<00:00,  1.51it/s]


In [52]:
# Interactive retrieval: embed a free-text description with the same model used
# for the training set, then print the 5 most similar stored (nl, query) pairs.
nl = input("describe your query in natural language:").strip()
if not nl:
    # Nothing meaningful to embed; stop here instead of sending an empty request.
    raise ValueError("please describe the query; got an empty description")
emb = oai.embeddings.create(input=nl, model="text-embedding-3-small")
sims = index.query(vector=emb.data[0].embedding, top_k=5, include_metadata=True)
for match in sims["matches"]:
    print("----------------")
    print(match["metadata"]["nl"])
    print(match["metadata"]["query"])

describe your query in natural language: convenience stores next to railroads in germany


----------------
fuel stations directly at the motorway in Germany
[out:json][timeout:900];{{geocodeArea:"Deutschland"}}->.searchArea;(node["amenity"="fuel"]["atmotorway"="yes"](area.searchArea);way["amenity"="fuel"]["atmotorway"="yes"](area.searchArea);relation["amenity"="fuel"]["atmotorway"="yes"](area.searchArea););out;>;out skel qt;
----------------
service roads in Germany
[out:json][timeout:25];{{geocodeArea:"Deutschland"}}->.searchArea;(node["highway"="services"](area.searchArea);way["highway"="services"](area.searchArea);relation["highway"="services"](area.searchArea););out;>;out skel qt;
----------------
McDonald's around highways in Germany
[out:json][timeout:1000];{{geocodeArea:"germany"}}->.searchArea;(node["name"="McDonald's"](area.searchArea);way["name"="McDonald's"](area.searchArea);)->.mcd;way["highway"="motorway"](around.mcd:100)(area.searchArea);out;>;out skel qt;
----------------
convenience stores in the selected window
[out:json][timeout:25];(node["shop"="convenien

In [51]:
# Time a single Pinecone similarity query (top 5, with metadata); `emb` must already hold an embeddings response, as set by the interactive search cell.
%time index.query(vector=emb.data[0].embedding, top_k=5, include_metadata=True)

CPU times: user 49.8 ms, sys: 16.9 ms, total: 66.7 ms
Wall time: 1.31 s


{'matches': [{'id': '3037',
              'metadata': {'nl': 'Subway entrances in San Francisco',
                           'query': '[out:json][timeout:500];(way["railway"="subway_entrance"](37.66507,-122.598983,37.861302,-122.211227);node["railway"="subway_entrance"](37.66507,-122.598983,37.861302,-122.211227););out;>;out '
                                    'skel qt;'},
              'score': 0.590370417,
              'values': []},
             {'id': '3665',
              'metadata': {'nl': 'coffee shops in current view',
                           'query': '[out:json][timeout:25];(node["shop"="coffee"]({{bbox}});way["shop"="coffee"]({{bbox}});relation["shop"="coffee"]({{bbox}}););out;>;out '
                                    'skel qt;'},
              'score': 0.569064796,
              'values': []},
             {'id': '4255',
              'metadata': {'nl': 'cafe with coffee shop in the current view',
                           'query': '(node["amenity"="cafe"]["cuisine"