In [4]:
!pip install chromadb


Collecting chromadb
  Downloading chromadb-0.5.20-py3-none-any.whl.metadata (6.8 kB)
Collecting build>=1.0.3 (from chromadb)
  Downloading build-1.2.2.post1-py3-none-any.whl.metadata (6.5 kB)
Collecting chroma-hnswlib==0.7.6 (from chromadb)
  Downloading chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (252 bytes)
Collecting fastapi>=0.95.2 (from chromadb)
  Downloading fastapi-0.115.5-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn>=0.18.3 (from uvicorn[standard]>=0.18.3->chromadb)
  Downloading uvicorn-0.32.1-py3-none-any.whl.metadata (6.6 kB)
Collecting posthog>=2.4.0 (from chromadb)
  Downloading posthog-3.7.4-py2.py3-none-any.whl.metadata (2.0 kB)
Collecting onnxruntime>=1.14.1 (from chromadb)
  Downloading onnxruntime-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelemetry_exporter_otlp_proto_grpc-1.28.2-py3

In [5]:
import chromadb
from transformers import AutoTokenizer, AutoModel
import torch

In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [122]:
db_path_py = "/content/drive/MyDrive/Chroma DB/Desc_Falcon/Chroma_DB_Py_Desc"

In [123]:
client_py = chromadb.PersistentClient(path=db_path_py)

In [124]:
collections = client_py.list_collections()


In [125]:
collections

[Collection(name=Java-VectorDB-Desc-Falcon)]

In [29]:
collection_py = client_py.get_collection(name="Python-VectorDB-Desc-Falcon")

In [42]:
tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/falcon-rw-1b")
model = AutoModel.from_pretrained("Rocketknight1/falcon-rw-1b")

In [48]:
def get_embeddings(texts,model, tokenizer):

    # Tokenize input
    inputs = tokenizer([texts], return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Mean pooling to get embeddings
    embeddings = outputs.hidden_states[-1].float().mean(dim=1).cpu().numpy()
    return embeddings

In [116]:
# Update the function to return results sorted by similarity
def get_similar_code_by_description(description, collection, model, tokenizer,db_path,collection_name ,top_n=5):
    try:

        db_path = db_path
        client = chromadb.PersistentClient(path=db_path)
        #collections = client.list_collections()
        collection = client.get_collection(name=collection_name)

        embedding = get_embeddings(description, model, tokenizer)

        embedding_list = embedding.flatten().tolist()

        # Query the collection
        results = collection.query(
            query_embeddings=[embedding_list],
            n_results=top_n
        )
        print("Query Results:", results)

        similar_code_snippets = []
        for item_id, document, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["distances"][0]
        ):
            similar_code_snippets.append({
                "id": item_id,
                "description": document,
                "distance": distance
            })

        similar_code_snippets.sort(key=lambda x: x["distance"])

        for snippet in similar_code_snippets:
            print(f"ID: {snippet['id']}")
            print(f"Description: {snippet['description']}")
            print(f"Distance: {snippet['distance']}")
            print("-" * 50)

        return similar_code_snippets

    except Exception as e:
        print(f"An error occurred: {e}")
        return []


In [112]:
# Example description to query
description = "A function that calculates the sum of two numbers"

In [126]:
# Call the function
similar_snippets = get_similar_code_by_description(description, collection_py, model, tokenizer,"/content/drive/MyDrive/Chroma DB/Desc_Falcon/Chroma_DB_Py_Desc","Python-VectorDB-Desc-Falcon",1038)

Query Results: {'ids': [['448', '299', '656', '335', '397', '491', '573', '952', '924', '875', '324', '63', '320', '402', '614', '970', '672', '133', '227', '248', '704', '796', '738', '925', '137', '295', '951', '144', '717', '870', '404', '309', '345', '77', '798', '308', '481', '398', '955', '596', '321', '466', '453', '524', '343', '212', '530', '566', '442', '147', '88', '259', '130', '547', '655', '325', '968', '615', '93', '10', '218', '232', '357', '486', '219', '609', '89', '422', '287', '946', '267', '514', '927', '754', '724', '974', '622', '686', '599', '270', '476', '658', '959', '386', '268', '540', '541', '589', '470', '463', '294', '688', '963', '371', '169', '867', '841', '780', '504', '97', '763', '832', '701', '9', '145', '389', '905', '953', '518', '385', '699', '123', '558', '67', '511', '452', '598', '506', '301', '782', '585', '340', '42', '512', '728', '126', '249', '410', '579', '623', '911', '289', '773', '753', '45', '416', '755', '292', '400', '8', '611', '4

In [128]:
similar_snippets

[{'id': '448',
  'description': 'Write a function to calculate the sum of perrin numbers.',
  'distance': 634.7796630859375},
 {'id': '299',
  'description': 'Write a function to calculate the maximum aggregate from the list of tuples.',
  'distance': 729.4241943359375},
 {'id': '656',
  'description': 'Write a python function to find the minimum sum of absolute differences of two arrays.',
  'distance': 749.9418334960938},
 {'id': '335',
  'description': 'Write a function to find the sum of arithmetic progression.',
  'distance': 799.8870239257812},
 {'id': '397',
  'description': 'Write a function to find the median of three specific numbers.',
  'distance': 804.1572265625},
 {'id': '491',
  'description': 'Write a function to find the sum of geometric progression series.',
  'distance': 808.5955200195312},
 {'id': '573',
  'description': 'Write a python function to calculate the product of the unique numbers of a given list.',
  'distance': 813.4615478515625},
 {'id': '952',
  'desc

In [127]:
similar_ids = [snippet['id'] for snippet in similar_snippets]

print("Similar IDs:", similar_ids)

Similar IDs: ['448', '299', '656', '335', '397', '491', '573', '952', '924', '875', '324', '63', '320', '402', '614', '970', '672', '133', '227', '248', '704', '796', '738', '925', '137', '295', '951', '144', '717', '870', '404', '309', '345', '77', '798', '308', '481', '398', '955', '596', '321', '466', '453', '524', '343', '212', '530', '566', '442', '147', '88', '259', '130', '547', '655', '325', '968', '615', '93', '10', '218', '232', '357', '486', '219', '609', '89', '422', '287', '946', '267', '514', '927', '754', '724', '974', '622', '686', '599', '270', '476', '658', '959', '386', '268', '540', '541', '589', '470', '463', '294', '688', '963', '371', '169', '867', '841', '780', '504', '97', '763', '832', '701', '9', '145', '389', '905', '953', '518', '385', '699', '123', '558', '67', '511', '452', '598', '506', '301', '782', '585', '340', '42', '512', '728', '126', '249', '410', '579', '623', '911', '289', '773', '753', '45', '416', '755', '292', '400', '8', '611', '499', '121',