In [2]:
import os
import time

import polars as pl
import polars_distance as plds
import polars_ds as pld

In [3]:
# Configuration: which parquet file to load and how wide each embedding is.
width = 3072
data_name = "big_1m_test_dataset_dbpedia.parquet"
# data_name = "small_100k_test_dataset_dbpedia.parquet"

# The dataset is looked up in a "data" folder directly under the working directory.
current_dir = os.getcwd()
cache_dir = os.path.join(current_dir, "data")
data_path = os.path.join(cache_dir, data_name)

In [4]:
# Load the whole parquet file into an in-memory DataFrame.
df = pl.read_parquet(source=data_path)


In [5]:
# Sanity check: show the number of elements in each row's embedding list.
df.select(len=pl.col("embeddings").list.len())

len
u32
3072
3072
3072
3072
3072
3072
3072
3072
3072
3072


In [6]:
# Keep only the text and the embeddings, convert each embedding list into a
# fixed-width array of `width` elements, and prepend a row index column "id".
# NOTE: this rebinds `df`, so the cell is meant to run once per kernel session.
df = (
    df.select(
        pl.col("combined_text"),
        pl.col("embeddings").list.to_array(width=width).alias("embeddings"),
    )
    .with_row_index(name="id")
)
df

id,combined_text,embeddings
u32,str,"array[f64, 3072]"
0,"""Parabolic refl…","[-0.022388, 0.028537, … -0.013797]"
1,"""John Baird (Ca…","[-0.009061, -0.011509, … 0.014185]"
2,"""The 80s: A Loo…","[-0.006117, -0.002699, … -0.000185]"
3,"""Shin Sang-ok S…","[0.033835, -0.006922, … -0.023406]"
4,"""Géza Anda Géza…","[-0.013956, 0.007346, … 0.002309]"
5,"""Marge vs. the …","[-0.016115, -0.016613, … 0.002089]"
6,"""Quebec general…","[0.029057, 0.005818, … 0.005066]"
7,"""D. P. Todd Sec…","[0.024967, 0.003702, … 0.02632]"
8,"""Dennis Gamsy D…","[0.024082, 0.017386, … 0.007041]"
9,"""Gourmet Night …","[-0.015987, 0.029655, … 0.010463]"


In [17]:
# Split the frame in two: the first num_rows - 1 rows are the search corpus (vf)
# and the single row at index num_rows - 1 is held out as the query.
num_rows = 10000
vf = df.head(num_rows - 1)
query = df.slice(offset=num_rows - 1, length=1)
query_embeddings = query.get_column("embeddings")

In [18]:
# First attempt: pass the fixed-width array column directly to knn_ptwise.
# The recorded output of this cell shows the plugin rejecting it
# ("cannot cast list type"), so the failure itself is the demonstration here.
# The `.num` expression namespace is only available after polars_ds has been
# imported, so that import must run before this cell.
try:
    knn = vf.select(
        pl.col("id"),
        pl.col("id").num.knn_ptwise(
            pl.col("embeddings"),  # one array column holding every coordinate
            k=10,
            dist="l2",  # actually this is squared l2
            leaf_size=40,
            parallel=True,
        ).alias("top_k_neighbors"),
    )
    print(knn)
except Exception as e:
    # Deliberately broad so the notebook keeps running, but include the
    # exception class so a plugin failure can be told apart from, e.g., a
    # missing namespace (AttributeError).
    print(f"{type(e).__name__}: {e}")
    print("Error in knn")

the plugin failed with message: cannot cast list type
Error in knn


In [19]:
import polars_ds as pld


In [20]:
# Spread the array column into one scalar column per dimension
# (embeddings_0, embeddings_1, ...) and remove the original array column.
evf = vf.with_columns(
    [pl.col("embeddings").arr.get(i).alias(f"embeddings_{i}") for i in range(width)]
).drop("embeddings")

# Names of the per-dimension columns, in dimension order.
embedding_columns = [f"embeddings_{i}" for i in range(width)]
# embedding_columns = [f"embeddings_{i}" for i in range(1)]
print(embedding_columns)

['embeddings_0', 'embeddings_1', 'embeddings_2', 'embeddings_3', 'embeddings_4', 'embeddings_5', 'embeddings_6', 'embeddings_7', 'embeddings_8', 'embeddings_9', 'embeddings_10', 'embeddings_11', 'embeddings_12', 'embeddings_13', 'embeddings_14', 'embeddings_15', 'embeddings_16', 'embeddings_17', 'embeddings_18', 'embeddings_19', 'embeddings_20', 'embeddings_21', 'embeddings_22', 'embeddings_23', 'embeddings_24', 'embeddings_25', 'embeddings_26', 'embeddings_27', 'embeddings_28', 'embeddings_29', 'embeddings_30', 'embeddings_31', 'embeddings_32', 'embeddings_33', 'embeddings_34', 'embeddings_35', 'embeddings_36', 'embeddings_37', 'embeddings_38', 'embeddings_39', 'embeddings_40', 'embeddings_41', 'embeddings_42', 'embeddings_43', 'embeddings_44', 'embeddings_45', 'embeddings_46', 'embeddings_47', 'embeddings_48', 'embeddings_49', 'embeddings_50', 'embeddings_51', 'embeddings_52', 'embeddings_53', 'embeddings_54', 'embeddings_55', 'embeddings_56', 'embeddings_57', 'embeddings_58', 'embed

In [21]:
# Timed k-nearest-neighbour search with polars_ds over the per-dimension
# columns of evf. Each row gets the list of its k nearest row ids.
leaf_size = 1000
k = 10
# Report the size of the frame that is actually searched (evf).
print("knn with polars_ds num_rows {}, num_embedding_columns {}, k {}, leaf_size {}".format(evf.height, len(embedding_columns), k, leaf_size))

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
knn = evf.select(
    pl.col("id"),
    pl.col("id").num.knn_ptwise(
        *(pl.col(name) for name in embedding_columns),  # Columns used as the coordinates in n-d space
        k=k,
        dist="l2",  # actually this is squared l2
        leaf_size=leaf_size,
        parallel=True,
    ).alias("top_k_neighbors"),
)
print("--- %s seconds ---" % (time.perf_counter() - start_time))
knn

knn with polars_ds num_rows 9999, num_embedding_columns 3072, k 10, leaf_size 1000
--- 27.41622495651245 seconds ---


id,top_k_neighbors
u32,list[u64]
0,"[0, 3726, … 6006]"
1,"[1, 2486, … 2622]"
2,"[2, 8217, … 5688]"
3,"[3, 9569, … 6742]"
4,"[4, 1044, … 3945]"
5,"[5, 9173, … 9775]"
6,"[6, 825, … 5645]"
7,"[7, 2519, … 7956]"
8,"[8, 6168, … 6354]"
9,"[9, 3565, … 3271]"


In [None]:
import polars_distance as plds

In [13]:
# Manual approach: attach the query embedding to every row of vf.
# Step 1 adds the query vector as a literal list column named "query_embeddings";
# step 2 converts that list column into a fixed-width array of `width` elements.
vf = (
    vf.with_columns(
        pl.lit(query_embeddings.arr.to_list()).alias("query_embeddings")
    )
    .with_columns(
        pl.col("query_embeddings").list.to_array(width=width).alias("query_embeddings")
    )
)
vf

id,combined_text,embeddings,query_embeddings
u32,str,"array[f64, 3072]","array[f64, 3072]"
0,"""Parabolic refl…","[-0.022388, 0.028537, … -0.013797]","[-0.011691, 0.011382, … 0.005522]"
1,"""John Baird (Ca…","[-0.009061, -0.011509, … 0.014185]","[-0.011691, 0.011382, … 0.005522]"
2,"""The 80s: A Loo…","[-0.006117, -0.002699, … -0.000185]","[-0.011691, 0.011382, … 0.005522]"
3,"""Shin Sang-ok S…","[0.033835, -0.006922, … -0.023406]","[-0.011691, 0.011382, … 0.005522]"
4,"""Géza Anda Géza…","[-0.013956, 0.007346, … 0.002309]","[-0.011691, 0.011382, … 0.005522]"
5,"""Marge vs. the …","[-0.016115, -0.016613, … 0.002089]","[-0.011691, 0.011382, … 0.005522]"
6,"""Quebec general…","[0.029057, 0.005818, … 0.005066]","[-0.011691, 0.011382, … 0.005522]"
7,"""D. P. Todd Sec…","[0.024967, 0.003702, … 0.02632]","[-0.011691, 0.011382, … 0.005522]"
8,"""Dennis Gamsy D…","[0.024082, 0.017386, … 0.007041]","[-0.011691, 0.011382, … 0.005522]"
9,"""Gourmet Night …","[-0.015987, 0.029655, … 0.010463]","[-0.011691, 0.011382, … 0.005522]"


In [15]:
# Text of the held-out query row, printed in full.
query_str = query.get_column("combined_text")[0]
print(query_str)

WWMB WWMB is CW-affiliated television station for South Carolina's Pee Dee and Grand Strand regions that is licensed to Florence. It broadcasts a 720p high definition digital signal on UHF channel 21 from a transmitter on Pee Dee Church Road in Floydale. Owned by Howard Stirk Holdings, WWMB is operated through a local marketing agreement (LMA) by the Sinclair Broadcast Group.


In [16]:
# Brute-force search: Euclidean distance between each row's embedding and the
# query embedding, sorted ascending so the closest texts come first.
size = 1000000  # requested row cap; NOTE(review): larger than vf, so slice() presumably returns all of vf — confirm
sf = vf.slice(0, size)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
pre_time = time.perf_counter()
out = sf.select(
    pl.col("combined_text"),
    plds.col("embeddings").dist_arr.euclidean("query_embeddings").alias("l2"),
).sort("l2")
print("Time taken: ", time.perf_counter() - pre_time)
out

Time taken:  0.004846811294555664


combined_text,l2
str,f64
"""WWCR WWCR is a…",1.005102
"""WVVA WVVA, cha…",1.034929
"""KCPM (TV) KCPM…",1.047232
"""WDFX WDFX may …",1.102579
"""WLR FM WLR FM …",1.133165
"""DWBR DWBR (104…",1.135726
"""List of ABS-CB…",1.162832
"""Newsworld Inte…",1.183512
"""The Inter-Moun…",1.186985
"""All News Chann…",1.196639


In [17]:
# Full text of the closest match (first row of the distance-sorted result).
out_0_str = out.get_column("combined_text")[0]
print(out_0_str)

WWCR WWCR is a shortwave radio station located in Nashville, Tennessee in the United States. WWCR uses four 100 kW transmitters to broadcast on about a dozen frequencies.WWCR mainly leases out its four transmitters to religious organizations and speakers. However, it does air a few hours of original programming per week.F.W. Robbert Broadcasting also owns the AM (mediumwave) stations WNQM in Nashville, WMQM and WLRM in Memphis, WITA in Knoxville, and WVOG in New Orleans.
