# Load

In [1]:
import panel as pn
from dotenv import load_dotenv
import param

import sys
sys.path.append('..')

load_dotenv()

pn.extension()

# Lance Playground

In [2]:
import lancedb
import pandas as pd
import pyarrow as pa

uri = "data/sample-lancedb"
db = lancedb.connect(uri)

# LanceDb offers both a synchronous and an asynchronous client.  There are still a
# few operations that are only supported by the synchronous client (e.g. embedding
# functions, full text search) but both APIs should soon be equivalent

# In this guide we will give examples of both clients.  In other guides we will
# typically only provide examples with one client or the other.
# uri = "data/sample-lancedb"
# async_db = await lancedb.connect_async(uri)

In [16]:
data = [
    {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
    {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
]

# Synchronous client
tbl = db.create_table("my_table_test", data=data, exist_ok=True)
# Asynchronous client
# async_tbl = await async_db.create_table("my_table2", data=data)

In [20]:
tbl.to_lance

[0;31mType:[0m        LanceTable
[0;31mString form:[0m LanceTable(connection=LanceDBConnection(/workspaces/pyllments/dev_sandbox/data/sample-lancedb), name="my_table_test")
[0;31mLength:[0m      2
[0;31mFile:[0m        ~/miniconda3/envs/pyllments/lib/python3.12/site-packages/lancedb/table.py
[0;31mDocstring:[0m  
A table in a LanceDB database.

This can be opened in two modes: standard and time-travel.

Standard mode is the default. In this mode, the table is mutable and tracks
the latest version of the table. The level of read consistency is controlled
by the `read_consistency_interval` parameter on the connection.

Time-travel mode is activated by specifying a version number. In this mode,
the table is immutable and fixed to a specific version. This is useful for
querying historical versions of the table.

In [5]:
async_db.op

AttributeError: 'AsyncConnection' object has no attribute 'url'

In [22]:
import duckdb
tbl = tbl.to_lance()

In [34]:
sample = tbl.sample(2)
# Main

In [47]:
sample.column('item').to_pylist()

['foo', 'bar']

In [49]:
sample.to_pydict()

{'vector': [[3.0999999046325684, 4.099999904632568],
  [5.900000095367432, 26.5]],
 'item': ['foo', 'bar'],
 'price': [10.0, 20.0]}

In [12]:
type(tbl)

lance.dataset.LanceDataset

In [11]:
duckdb.query("select item from tbl ORDER BY RANDOM() LIMIT 1")

┌─────────┐
│  item   │
│ varchar │
├─────────┤
│ bar     │
└─────────┘

# Main

In [26]:
pa.field('embedding', pa.list_(pa.float32(), 768))

pyarrow.Field<embedding: fixed_size_list<item: float>[768]>

In [62]:
class Collection(param.Parameterized):
    url = param.String(default="", doc="""
        The url of the data folder""")
    collection_name = param.String(default="default", doc="""
        The name of the collection""")
    db = param.Parameter(default=None)
    collection = param.Parameter(default=None)

    def load_db(self, url: str):
        """Loads a database from a url or creates a new one"""
        pass

    def load_collection(self, collection_name: str):
        """Loads a collection from the database"""
        pass

    def add_items(self, items: list[dict]):
        pass

import lancedb
import pyarrow as pa
import numpy as np

default_lance_db_schema = pa.schema([
    pa.field('text', pa.string()),
    pa.field('embedding', pa.list_(pa.float32(), 768)),
    pa.field('source_file', pa.string()),
    pa.field('start_idx', pa.int32()),
    pa.field('end_idx', pa.int32())
])

class LanceDBCollection(Collection):
    url = param.String(default="data/lancedb", doc="""
        The url of the database""")
    schema = param.Parameter(
        default=default_lance_db_schema,
        doc="""The pydantic schema of the collection""")
    metric = param.String(default="cosine", doc="""
        The metric used to search the collection""")
    n = param.Integer(default=5, doc="""
        The number of results to return""")
    
    def __init__(self, **params):
        super().__init__(**params)
        self.load_collection(self.collection_name)
    
    def load_collection(self, collection_name: str):
        """Loads a collection from the database"""
        self.db = lancedb.connect(self.url)
        self.collection = self.db.create_table(
            name=self.collection_name,
            schema=self.schema,
            exist_ok=True)
        
    def add_item(self, item: dict):
        """Adds an item to the collection"""
        # item['embedding'] = list(item['embedding'].astype(np.float32))
        self.collection.add([item])

    def add_items(self, items: list[dict]):
        """Adds items to the collection"""
        self.collection.add(items)

    def query(self, embedding: np.ndarray, n: int = None, metric: str = None):
        """Queries the collection. If n or metric are not provided, uses the class defaults"""
        if n is None:
            n = self.n
        if metric is None:
            metric = self.metric
        return self.collection.search(embedding) \
            .metric(metric) \
            .limit(n) \
            .to_list()

    def get_random_items(self, n: int, column_name: str = 'text', get_dict: bool = False):
        """
        Gets random items from the collection. If column_name provided, returns an
        n-length list of values. If get_dict is True, returns a dictionary.
        """
        lance_table = self.collection.to_lance()
        if get_dict:
            return lance_table.sample(n).to_pydict()
        else:
            return lance_table.sample(n).column(column_name).to_pylist()
    


In [72]:
etl = default_lance_db_schema.empty_table().column_names

In [73]:
etl

['text', 'embedding', 'source_file', 'start_idx', 'end_idx']

In [61]:
default_lance_db_schema.

TypeError: equals() takes at least 1 positional argument (0 given)

In [45]:
# Instantiate the LanceDBCollection with a specific collection name
lance_db_collection = LanceDBCollection(
    collection_name='test_collection',
    url='data/lancedb0',
    schema=pa.schema([
        pa.field('text', pa.string()),
        pa.field('embedding', pa.list_(pa.float32(), 3)),
        pa.field('source_file', pa.string())
    ])
)

# Create 5 items to add to the collection
items_to_add = [
    {'text': 'Item 1', 'embedding': np.random.rand(3), 'source_file': 'test'},
    {'text': 'Item 2', 'embedding': np.random.rand(3), 'source_file': 'test'},
    {'text': 'Item 3', 'embedding': np.random.rand(3), 'source_file': 'test'},
    {'text': 'Item 4', 'embedding': np.random.rand(3), 'source_file': 'test'},
    {'text': 'Item 5', 'embedding': np.random.rand(3), 'source_file': 'test'}
]

# Add the items to the LanceDBCollection
lance_db_collection.add_items(items_to_add)


In [79]:
lance_db_collection.query(np.random.rand(3)).metric("cosine").limit(2).to_list()


[{'text': 'Item 2',
  'embedding': [0.4641132652759552, 0.015456879511475563, 0.16316711902618408],
  'source_file': 'test',
  '_distance': 0.04401075839996338},
 {'text': 'Item 2',
  'embedding': [0.9206529259681702, 0.20832976698875427, 0.7067242860794067],
  'source_file': 'test',
  '_distance': 0.1752963662147522}]

In [83]:
lance_db_collection.query(np.random.rand(3))#.metric("cosine").limit(2).to_list()


<lancedb.query.LanceVectorQueryBuilder at 0x7fabc0f25a30>

In [82]:
lancedb.__version__

'0.12.0'

In [43]:
lance_db_collection.collection.search(np.random.rand(3)).to_list()

[{'text': 'Item 5',
  'embedding': [0.3201635777950287, 0.1651465892791748, 0.6056264638900757],
  'source_file': 'test',
  '_distance': 0.055993568152189255},
 {'text': 'Item 1',
  'embedding': [0.8767399787902832, 0.04871548339724541, 0.6977120041847229],
  'source_file': 'test',
  '_distance': 0.1152745857834816},
 {'text': 'Item 4',
  'embedding': [0.19384464621543884, 0.20555397868156433, 0.5176759958267212],
  'source_file': 'test',
  '_distance': 0.14991918206214905},
 {'text': 'Item 2',
  'embedding': [0.4641132652759552, 0.015456879511475563, 0.16316711902618408],
  'source_file': 'test',
  '_distance': 0.2482559084892273},
 {'text': 'Item 3',
  'embedding': [0.23947496712207794, 0.582857608795166, 0.2254280298948288],
  'source_file': 'test',
  '_distance': 0.4924866557121277},
 {'text': 'Item 1',
  'embedding': [0.06988758593797684, 0.736587643623352, 0.8949791789054871],
  'source_file': 'test',
  '_distance': 0.6807531714439392}]

In [25]:
from lancedb.pydantic import LanceModel, Vector

# model = create_model('test',__base__=LanceModel, vector=(Vector(768), ...))
# 
db.create_table('text_v', schema=model)


LanceTable(connection=LanceDBConnection(/workspaces/pyllments/dev_nbs/data/sample-lancedb), name="text_v")

In [33]:
import pyarrow as pa

schema = pa.schema([
    pa.field('text', pa.string()),
    pa.field('vector', pa.list_(pa.float32(), 768))
])

tbl = db.create_table('text_v', schema=schema, mode='overwrite')

In [40]:
import numpy as np

# Create a 768-dimensional array filled with zeros
array_768 = np.ones(758)


In [61]:
vec = np.random.rand(768)

In [None]:
# Uses default metric l2
tbl.search(vec).to_list()

# cosine metric forced
tbl.search(vec) \
    .metric("cosine") \
    .to_list()


In [66]:
tbl.create

[0;31mSignature:[0m
[0mtbl[0m[0;34m.[0m[0msearch[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mquery[0m[0;34m:[0m [0;34m"Optional[Union[VEC, str, 'PIL.Image.Image', Tuple]]"[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mvector_column_name[0m[0;34m:[0m [0;34m'Optional[str]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mquery_type[0m[0;34m:[0m [0;34m'str'[0m [0;34m=[0m [0;34m'auto'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mordering_field_name[0m[0;34m:[0m [0;34m'Optional[str]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;34m'LanceQueryBuilder'[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
and [full-text search][search].

Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [
...    {"ori

In [41]:
tbl.add([{'text':'some test text', 'vector': array_768}])

ArrowTypeError: Size of FixedSizeList is not the same. input list: fixed_size_list<item: float>[758] output list: fixed_size_list<item: float>[768]

In [39]:
tbl.head()

pyarrow.Table
text: string
vector: fixed_size_list<item: float>[768]
  child 0, item: float
----
text: [["some test text"]]
vector: [[[1,1,1,1,1,...,1,1,1,1,1]]]

In [26]:
import lancedb

db = lancedb.connect("./.lancedb")

data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]

db.create_table("my_table", data)

db["my_table"].head()

pyarrow.Table
vector: fixed_size_list<item: float>[2]
  child 0, item: float
lat: double
long: double
----
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]

In [29]:
db["my_table"].to_pandas()

Unnamed: 0,vector,lat,long
0,"[1.1, 1.2]",45.5,-122.7
1,"[0.2, 1.8]",40.1,-74.1


In [19]:
from pyllments.base.model_base import Model
# from pyllments.common.tokenizers import get_token_len
from pyllments.payloads.chunk import ChunkPayload
from pyllments.payloads.message import MessagePayload

default_lance_db_schema = pa.schema([
    pa.field('text', pa.string()),
    pa.field('embedding', pa.list_(pa.float32(), 768)),
    pa.field('source_file', pa.string()),
    pa.field('start_idx', pa.int32()),
    pa.field('end_idx', pa.int32())
])

def pa_schema_to_col_list(schema: pa.Schema):
    return schema.empty_table().column_names

class RetrieverModel(Model):
    collection = param.ClassSelector(class_=Collection, doc="""
        The collection to retrieve from. Based on a a DB backend for storage""")
    collection_name = param.String(default="", doc="""
        The name of the collection""")
    url = param.String(default="", doc="""
        The url of the database""")
    embedding_dims = param.Integer(default=768, doc="""
        The dimension of the embedding""")
    schema = param.Parameter(default=default_lance_db_schema, doc="""
    The schema used with the collection. If the Collection is based on
    LanceDB, pyarrow schemas are preferred.
    """)
    schema_cols = param.List()
    metric = param.String(default="cosine", doc="""
        The metric used to search the collection""")
    retrieval_n = param.Integer(default=5, doc="""
        The number of results to return""")
    # TODO: Implement token limits for retrieval if necessary
    retrieval_token_limit = param.Integer(default=None, doc="""
        The token limit of the model""")
    retrieval_tokenizer_model = param.String(default="gpt-4o-mini", doc="""
        The model used to tokenize the text""")

    def __init__(self, retrieval_token_limit=None, **params):
        super().__init__(**params)
        if not self.collection_name:
            # Uses default param-generated RetrievelModel name if not set
            self.collection_name = self.name
        if not self.schema:
            schema = default_lance_db_schema
        self.schema_cols = pa_schema_to_col_list(self.schema)
        self.collection = LanceDBCollection(
            collection_name=self.collection_name,
            schema=self.schema
        )
    
    def add_item(self, chunk_payload: ChunkPayload):
        item = {col: getattr(chunk_payload.model, col) for col in self.schema_cols}
        self.collection.add_item(item)
    
    def retrieve(self, message_payload: MessagePayload):
        embedding = message_payload.model.embedding
        chunk_payloads = [
            ChunkPayload(**item)
            for item in self.collection.query(
                embedding,
                n=self.retrieval_n,
                metric=self.metric)
        ]
        return chunk_payloads
        

from typing import Union
from pyllments.base.element_base import Element
from pyllments.payloads.chunk import ChunkPayload


class RetrieverElement(Element):
    # Needs two col viz, one for the created chunks, and one for the retrieved chunks

    def __init__(self, **params):
        super().__init__(**params)
        if not self.collection_name:
            self.collection_name = self.name
        self.model = RetrieverModel()
        
    def _chunk_load_input_setup(self):
        """For the collection populating process"""
        def unpack(payload: Union[ChunkPayload, list[ChunkPayload]]):
            chunks = payload if isinstance(payload, list) else [payload]
            for chunk in chunks:
                self.model.add_item(chunk)
        
        self.ports.add_input('chunk_input', unpack)

    def _message_query_input_setup(self):
        """The input query used for retrieval"""
        def unpack(payload: MessagePayload):
            chunks = self.model.retrieve(payload)
            self.ports.output['chunk_output'].stage_emit(chunks)
        
        self.ports.add_input('message_input', unpack)

    def _chunk_result_output_setup(self):
        """The output of the retrieval process"""
        def pack(chunk_payload: list[ChunkPayload]):
            return chunk_payload
        
        self.ports.add_output('chunk_output', pack)



In [20]:
r_model = RetrieverModel()

In [None]:
r_model

In [11]:
retriever = RetrieverElement()


In [12]:
retriever.name

'RetrieverElement00117'

In [4]:
import param

class Grandparent(param.Parameterized):
    grandparent_name = param.String(default='grandparent')
    def __init__(self, **params):
        print(params)

class Parent(Grandparent):
    parent_name = param.String(default='parent')

class Test(Parent):
    name = param.String(default='test')

    def __init__(self, **params):

test = Test(name='asdfsfdsf', boogaloo='asdf')

TypeError: Test.__init__() got an unexpected keyword argument 'boogaloo'

In [2]:
from pyllments.payloads.file import FilePayload

file_payload = FilePayload(filename='test.txt')
file_payload

{'filename': 'test.txt'}


FilePayload(css_cache={}, id='0f0761e1-4dd6-49cd-9ed8-3a5dae972367', model=FileModel(b_file=None, filename='test.txt', local_path='', mime_type='', name='FileModel00117', remote_path=''), name='FilePayload', view_cache={})

In [3]:
from pyllments.elements.retriever.retriever_model import LanceDBCollection

collection = LanceDBCollection(
    collection_name='RetrieverModel00137',
    url='/workspaces/pyllments/dev_sandbox/retrieval step testing/data/lancedb',
)
collection

LanceDBCollection(collection=LanceTable(connection=LanceDBConnection(/workspaces/pyllments/dev_sandbox/retrieval step testing/data/lancedb), name="RetrieverModel00137"), collection_name='RetrieverModel00137', db=LanceDBConnection(/workspaces/pyllments/dev_sandbox/retrieval step testing/data/lancedb), metric='cosine', n=5, name='LanceDBCollection00118', schema=text: string
embedding: fixed_size_list<item: float>[768]
  child 0, item: float
source_filepath: string
start_idx: int32
end_idx: int32, url='/workspaces/pyllments/dev_sandbox/retrieval step testing/data/lancedb')

In [7]:
collection.query(np.random.rand(768), n=2, metric='cosine')

[{'text': 'l',
  'embedding': [0.21259000897407532,
   0.46376603841781616,
   0.057312097400426865,
   -0.4112528860569,
   0.42801469564437866,
   -1.1532549858093262,
   0.17037732899188995,
   0.2252465784549713,
   0.599388062953949,
   0.026341136544942856,
   -0.29634273052215576,
   -1.3212788105010986,
   0.5012282133102417,
   0.5871407985687256,
   0.8400241732597351,
   -0.23059634864330292,
   -0.7151937484741211,
   -1.4625110626220703,
   -0.5266277194023132,
   -0.7396470904350281,
   0.4119313061237335,
   1.001466155052185,
   0.6968907117843628,
   0.6948626041412354,
   0.009133252315223217,
   -0.15152376890182495,
   -0.198019877076149,
   1.4987369775772095,
   0.04665743559598923,
   0.7496033906936646,
   0.19130262732505798,
   0.30900681018829346,
   -0.11059275269508362,
   -0.028584390878677368,
   -0.5241007804870605,
   0.06153571233153343,
   0.23127877712249756,
   -0.3602466285228729,
   -0.386247843503952,
   -1.8052417039871216,
   0.3750823140144348

In [6]:
import numpy as np
collection.collection.search(np.random.rand(768)).metric('cosine').to_list()

[{'text': 'e',
  'embedding': [0.26205679774284363,
   0.6268020272254944,
   0.5425217747688293,
   0.10101582854986191,
   0.9048588275909424,
   -0.9096052646636963,
   0.19324102997779846,
   -0.43715330958366394,
   1.0208261013031006,
   -0.15846462547779083,
   -0.5061513185501099,
   -1.1818417310714722,
   0.3123464286327362,
   0.5497028231620789,
   0.19928452372550964,
   -0.408832848072052,
   -0.9135117530822754,
   -0.3973250687122345,
   -0.3507107198238373,
   0.5919840335845947,
   0.2946459650993347,
   0.1651860624551773,
   0.29724517464637756,
   0.09476714581251144,
   0.1395830363035202,
   0.3564080595970154,
   -0.28191283345222473,
   0.792989194393158,
   -0.3158460259437561,
   -0.09308665245771408,
   0.529658854007721,
   0.021622471511363983,
   0.799915611743927,
   0.6316210627555847,
   -0.2716239392757416,
   -0.03547493368387222,
   0.05121799185872078,
   -0.06390446424484253,
   -0.6362295150756836,
   -1.5501104593276978,
   -0.05256140977144241,