In [1]:
import chromadb
from pprint import pprint

In [2]:
chroma_client = chromadb.Client()

In [3]:
# collection = chroma_client.create_collection(name='collection1')
collection = chroma_client.get_or_create_collection(name='collection1') # use default embedding

collection.add(
    documents=[
        "this is doc1 about apple",
        "this is doc1 about mango"
    ],ids=['id1','id2']
)


results = collection.query(
    query_texts=['this is a query document about a fruit with red color'],
    n_results=2,
    where_document={'$contains':'apple'}
)

pprint(results)

{'data': None,
 'distances': [[1.3395774364471436]],
 'documents': [['this is doc1 about apple']],
 'embeddings': None,
 'ids': [['id1']],
 'included': [<IncludeEnum.distances: 'distances'>,
              <IncludeEnum.documents: 'documents'>,
              <IncludeEnum.metadatas: 'metadatas'>],
 'metadatas': [[None]],
 'uris': None}


# Dataset

In [4]:
import pandas as pd
articles = pd.read_csv('Articles.csv', encoding='ISO-8859-1', index_col=None)
articles['id'] = articles.index.astype(str)
articles = articles[:50]
articles.head()

Unnamed: 0,Article,Date,Heading,NewsType,id
0,KARACHI: The Sindh government has decided to b...,1/1/2015,sindh govt decides to cut public transport far...,business,0
1,HONG KONG: Asian markets started 2015 on an up...,1/2/2015,asia stocks up in new year trad,business,1
2,HONG KONG: Hong Kong shares opened 0.66 perce...,1/5/2015,hong kong stocks open 0.66 percent lower,business,2
3,HONG KONG: Asian markets tumbled Tuesday follo...,1/6/2015,asian stocks sink euro near nine year,business,3
4,NEW YORK: US oil prices Monday slipped below $...,1/6/2015,us oil prices slip below 50 a barr,business,4


# Ollama embedding

In [5]:
from langchain_ollama import OllamaEmbeddings

embedding_model = OllamaEmbeddings(model="llama2")

embedded_texts = embedding_model.embed_documents([
    "This is my first text to embed",
    "This is my second document"
])

print(embedded_texts)  # Check the output embeddings


[[-0.008455402, -0.0035663992, 0.01067029, 0.01913489, -0.005256403, -0.0054485956, -0.0016198481, -0.011148195, -0.012819734, 0.01290301, -0.00013121741, -0.000757624, 0.01680066, -0.0042573307, 0.022457005, 0.0009943492, -0.013622382, 0.007460264, 0.020176942, 0.0028079774, 0.008686881, -0.013345441, 0.0014763755, -0.011706626, -0.03652688, 0.0012336946, -0.005989819, -0.008728215, 0.0029214171, 0.0042052013, 0.002348707, 0.011738965, -0.016671944, -0.0019127182, -0.02028514, 0.004037694, 0.00086588284, 0.012585342, 0.0045988373, 0.01443806, 0.012112478, -0.015555765, 0.0053322366, -0.0064943978, -0.015778948, -0.01769562, 0.023536252, 0.0035874117, 0.01983626, 0.037085667, -0.0146520315, 0.012565124, 0.0037460704, 0.011661753, 0.025868794, 0.001146059, -0.0085400855, -0.007850899, 0.0016011017, 0.008660299, -0.010870487, 0.0067320107, -0.0040965574, -0.0065832203, -0.010745564, -0.011079591, 0.0065585603, 0.00898175, 0.024416313, -0.030575795, -0.0070669106, 0.02233246, -0.007307821

In [6]:
class ChromaDBEmbeddingFunction:    # Define a custom embedding function for ChromaDB using Ollama
    def __init__(self, langchain_embeddings):
        self.langchain_embeddings = langchain_embeddings

    def __call__(self, input):  # The __call__ method allows the object to be used like a function.
        if isinstance(input, str):  # Ensure the input is in a list format for processing
            input = [input]
        return self.langchain_embeddings.embed_documents(input)


In [7]:
ollama_embedding = OllamaEmbeddings(
        model='deepseek-r1:1.5b',
        base_url="http://localhost:11434"
    )

embedding_funct = ChromaDBEmbeddingFunction(ollama_embedding)  # initialize custom embedding function

In [8]:
# Create or get collection
collection = chroma_client.get_or_create_collection(
    name='collection2',
    embedding_function=embedding_funct)

collection.add(
    documents=list(articles['Article']),
    ids=list(articles['id'])
)

print("Documents stored successfully!")

Documents stored successfully!


In [9]:
query = 'public transport fares by 7 per cent'

out = collection.query(query_texts=query, n_results=1)
pprint(out)

{'data': None,
 'distances': [[0.9999998807907104]],
 'documents': [['ISLAMABAD:  Federal Minister for Finance Ishaq Dar on '
                'Saturday announced a five percent increase in the General '
                'Sales Tax (GST) on petroleum products.Dar said that the '
                'increment would enable a recovery of 12 billion rupees.The '
                'minister, however, went on to say that the ministry would '
                'still face a loss of 40 billion rupees.Earlier today, Prime '
                'Minister Nawaz Sharif announced a decrease in the price of '
                'petroleum products.Petrol has been decreased by Rs 7.99, '
                'Hi-Octane by Rs 11.82, Light Diesel by Rs 9.56 and kerosene '
                'oil by Rs 10.48 per litre. \r\n'
                '\r\n'
                '\r\n'
                '\r\n'
                '\r\n'
                '\r\n'
                '\r\n'
                '\r\n'
                '\r\n'
                '\r\n

# Persistent

In [10]:
chroma_client = chromadb.PersistentClient(path='./vectordb')

In [11]:
collection = chroma_client.get_or_create_collection(name='docs', embedding_function=embedding_funct)

In [12]:
collection.add(
    documents=list(articles['Article']),
    ids=list(articles['id'])
)

Add of existing embedding ID: 0
Add of existing embedding ID: 1
Add of existing embedding ID: 2
Add of existing embedding ID: 3
Add of existing embedding ID: 4
Add of existing embedding ID: 5
Add of existing embedding ID: 6
Add of existing embedding ID: 7
Add of existing embedding ID: 8
Add of existing embedding ID: 9
Add of existing embedding ID: 10
Add of existing embedding ID: 11
Add of existing embedding ID: 12
Add of existing embedding ID: 13
Add of existing embedding ID: 14
Add of existing embedding ID: 15
Add of existing embedding ID: 16
Add of existing embedding ID: 17
Add of existing embedding ID: 18
Add of existing embedding ID: 19
Add of existing embedding ID: 20
Add of existing embedding ID: 21
Add of existing embedding ID: 22
Add of existing embedding ID: 23
Add of existing embedding ID: 24
Add of existing embedding ID: 25
Add of existing embedding ID: 26
Add of existing embedding ID: 27
Add of existing embedding ID: 28
Add of existing embedding ID: 29
Add of existing embe