In [None]:
from rich import print
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

# Create (or reuse) the Spark session used by every cell below.
# - spark.jars: the Cassandra connector jar is loaded from the working directory.
# - spark.cassandra.connection.host: the connector targets the loopback address.
spark = (
    SparkSession.builder.appName("BM25_Query")
    .config("spark.jars", "./spark-cassandra-connector.jar")
    .config("spark.cassandra.connection.host", "127.0.0.1")
    .getOrCreate()
)

In [12]:
# Per-document statistics read from the Cassandra table `search.doc_stats`.
# Only the three columns used later (id, title, length) are kept.
doc_stats = (
    spark.read.format("org.apache.spark.sql.cassandra")
    .option("keyspace", "search")
    .option("table", "doc_stats")
    .load()
    .select("doc_id", "doc_title", "doc_length")
)

In [13]:
# Preview the first rows of the document statistics as a sanity check.
doc_stats.show()

+--------+--------------------+----------+
|  doc_id|           doc_title|doc_length|
+--------+--------------------+----------+
|25604523|BBC Knowledge (ma...|       110|
|13114764|        BAO på turné|       396|
|40782306|       B. Deniswaran|       247|
| 2458143|              B road|        76|
| 6729513|        B. J. Porter|       276|
|62111958|   B.J. and the A.C.|       866|
|50537017|        BA25 (album)|       243|
|21654437|              BACPAC|       315|
|32833925|         B. B. Davis|       396|
|68461256|             BASIC-8|       993|
|73920812|     B. J. Callaghan|       411|
|11801169|BC Spartak Saint ...|       400|
| 3611959|                 BCY|        24|
|66827058|       B.L.E.S.S.E.D|        87|
|15065816|                BAP1|       545|
|14597703| B Sides and C Sides|        95|
|35627419|    B-class lifeboat|        45|
|10615081|        B. Kothakota|        72|
|16869388|    B8 road (Cyprus)|        72|
|  384414|            BBC UKTV|       547|
+--------+-

In [14]:
# Inverted index read from the Cassandra table `search.inverted_index`:
# one row per (term, doc_id) pair together with its term frequency `tf`.
inv_index = (
    spark.read.format("org.apache.spark.sql.cassandra")
    .option("keyspace", "search")
    .option("table", "inverted_index")
    .load()
    .select("term", "doc_id", "tf")
)

In [15]:
# Preview the first rows of the inverted index as a sanity check.
inv_index.show()

+------------+--------+---+
|        term|  doc_id| tf|
+------------+--------+---+
|    rohinton|39297323|  1|
|    spelling| 3160315|  1|
|    spelling|17682449|  1|
|    spelling|18849559|  1|
|    spelling|19097751|  3|
|    spelling|30838130|  3|
|    spelling|62305198|  1|
|     arabian|13453391|  1|
|       aspel| 3226985|  1|
|       aspel|18700118|  1|
|       aspel|28444278|  1|
|     rounder|44918624|  1|
|     rounder|73587954|  1|
|      herrin|39297323|  1|
|       bulis| 3182741|  1|
|       madam| 1604479|  1|
|       madam|27475501|  3|
|preprocessed|10875123|  1|
|  palliative|  105391|  1|
|  palliative| 9385717|  1|
+------------+--------+---+
only showing top 20 rows



In [16]:
# Vocabulary read from the Cassandra table `search.vocabulary`:
# one row per term with its `df` value (joined onto the postings later).
vocab = (
    spark.read.format("org.apache.spark.sql.cassandra")
    .option("keyspace", "search")
    .option("table", "vocabulary")
    .load()
    .select("term", "df")
)

In [17]:
# Preview the first rows of the vocabulary as a sanity check.
vocab.show()

+-----------+---+
|       term| df|
+-----------+---+
|   karnatik|  1|
|   snapshot|  3|
|  reworking|  2|
|        gip|  1|
|      damir|  1|
|     werner|  2|
|    mashhad|  1|
|  prompting|  7|
| subbraayan|  1|
|hydrophobic|  3|
| nonenglish|  1|
|   derailed|  1|
|     molded|  1|
|    brewham|  1|
|     column| 17|
|     tokuma|  1|
|disaffected|  2|
|    impairs|  1|
|proposition|  2|
|  inbavalli|  1|
+-----------+---+
only showing top 20 rows



In [21]:
# Corpus-level statistics required by BM25:
#   N      - number of rows (documents) in doc_stats
#   avg_dl - sum of doc_length divided by N (average document length)
# Both values come from a single aggregation so doc_stats is scanned once
# instead of once for count() and once more for the sum.
N, _total_doc_length = doc_stats.agg(
    F.count(F.lit(1)),
    F.sum("doc_length"),
).first()

# An empty table would otherwise surface as an opaque `None / 0` TypeError.
if N == 0:
    raise ValueError("doc_stats is empty: cannot compute the average document length")

avg_dl = _total_doc_length / N

print(f"{N = }\n{avg_dl = }")

In [30]:
# Free-text query, normalised the simple way: lowercase, then split on runs of
# whitespace. str.split() with no argument never yields empty strings, so no
# extra filtering of blank tokens is needed.
query_text = "Armstrong moon landing"
query_terms = query_text.lower().split()

query_terms

['armstrong', 'moon', 'landing']

#### Filter the inverted index for query terms only.


In [40]:
# Keep only the postings whose term occurs in the query.
matches_query = inv_index["term"].isin(query_terms)
query_index = inv_index.where(matches_query)
print(query_index.count())
query_index.show()



+---------+--------+---+
|     term|  doc_id| tf|
+---------+--------+---+
|armstrong|  322487| 16|
|armstrong| 3061257|  1|
|armstrong| 3160315|  1|
|armstrong| 9237209|  1|
|armstrong| 9278580|  1|
|armstrong|14597703|  1|
|armstrong|33244357|  1|
|armstrong|34832338|  1|
|armstrong|70547608|  2|
|  landing|  336647|  1|
|  landing|  451011|  1|
|  landing| 1467695|  1|
|  landing| 3395088|  5|
|  landing| 4082240|  1|
|  landing|13495191|  1|
|  landing|18226518|  1|
|  landing|67323043|  1|
|     moon|  171049|  1|
|     moon|  931899|  2|
|     moon| 1235674|  1|
+---------+--------+---+
only showing top 20 rows



#### Join with vocabulary table to get df for each term.


In [41]:
# Attach `df` to every posting. The left join keeps a posting even when its
# term is absent from the vocabulary (df is then null).
# NOTE: this cell rebinds `query_index`; re-run the filter cell above before
# re-running this one, otherwise vocab is joined a second time.
query_index = query_index.join(vocab, "term", "left")
print(query_index.count())
query_index.show()

+---------+--------+---+---+
|     term|  doc_id| tf| df|
+---------+--------+---+---+
|     moon|  171049|  1| 21|
|     moon|  931899|  2| 21|
|     moon| 1235674|  1| 21|
|     moon| 1604479|  1| 21|
|  landing|  336647|  1|  8|
|  landing|  451011|  1|  8|
|  landing| 1467695|  1|  8|
|  landing| 3395088|  5|  8|
|  landing| 4082240|  1|  8|
|  landing|13495191|  1|  8|
|  landing|18226518|  1|  8|
|  landing|67323043|  1|  8|
|armstrong|  322487| 16|  9|
|armstrong| 3061257|  1|  9|
|armstrong| 3160315|  1|  9|
|armstrong| 9237209|  1|  9|
|armstrong| 9278580|  1|  9|
|armstrong|14597703|  1|  9|
|armstrong|33244357|  1|  9|
|armstrong|34832338|  1|  9|
+---------+--------+---+---+
only showing top 20 rows



#### Join with doc_stats to get document lengths and titles.


In [42]:
# Attach the title and length of each matching document. The left join keeps a
# posting even when its doc_id has no row in doc_stats (title/length are then null).
# NOTE: this cell rebinds `query_index`; re-run the preceding cells before
# re-running this one, otherwise doc_stats is joined a second time.
query_index = query_index.join(doc_stats, "doc_id", "left")
print(query_index.count())
query_index.show()

+--------+---------+---+---+--------------------+----------+
|  doc_id|     term| tf| df|           doc_title|doc_length|
+--------+---------+---+---+--------------------+----------+
| 1467695|  landing|  1|  8|   BBC controversies|      7370|
| 4082240|  landing|  1|  8|BBC Allied Expedi...|       213|
| 9278580|armstrong|  1|  9|BC Junior A Lacro...|       622|
|13495191|  landing|  1|  8|         B. G. Henry|       563|
|67323043|  landing|  1|  8|       B&D Australia|       219|
| 9237209|armstrong|  1|  9|BAFTA Award for O...|      2604|
|70547608|armstrong|  2|  9|BBC's 100 Greates...|      1095|
|34832338|armstrong|  1|  9|BAFTA Award for B...|      1311|
|  171049|     moon|  1| 21|        B. J. Thomas|       855|
| 3160315|armstrong|  1|  9|         BBC Weather|       987|
|14597703|armstrong|  1|  9| B Sides and C Sides|        95|
|  451011|  landing|  1|  8|B-17, Queen of th...|       576|
| 1235674|     moon|  1| 21|BAFTA Award for B...|      4094|
| 1604479|     moon|  1|

## BM25


In [44]:
import math


def compute_bm25(
    tf: int,
    df: int,
    doc_length: int,
    avg_dl: float,
    N: int,
    k1: float = 1.0,
    b: float = 0.75,
) -> float:
    """Return the BM25 contribution of a single term to a single document.

    score = log((N - df + 0.5) / (df + 0.5))
            * (tf * (k1 + 1)) / (tf + k1 * ((1 - b) + b * doc_length / avg_dl))

    Args:
        tf: Frequency of the term in the document.
        df: Document-frequency value of the term.
        doc_length: Length of the document.
        avg_dl: Average document length over the corpus.
        N: Number of documents in the corpus.
        k1: Multiplier applied to the length-normalisation term.
        b: Weight of doc_length / avg_dl in the length normalisation.

    Returns:
        The score as a Python float. The logarithm is negative when
        (N - df + 0.5) < (df + 0.5), so the score can be negative.

    Raises:
        ZeroDivisionError: If avg_dl is 0.
        ValueError: If the logarithm's argument is not positive.
    """
    # Rarity weight of the term across the corpus.
    rarity = math.log((N - df + 0.5) / (df + 0.5))
    # Length normalisation relative to the corpus average.
    length_norm = (1 - b) + b * (doc_length / avg_dl)
    # Term-frequency component, damped by k1 and the length normalisation.
    tf_component = (tf * (k1 + 1)) / (tf + k1 * length_norm)
    return rarity * tf_component

In [45]:
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType


def bm25_udf(
    k1: float,
    b: float,
    avg_dl: float,
    N: int
):
    """Build a Spark UDF that scores one (tf, df, doc_length) row with BM25.

    The corpus statistics and tuning parameters are captured in the closure,
    so the returned UDF only takes the three per-row columns.

    Args:
        k1: BM25 parameter forwarded to compute_bm25.
        b: BM25 parameter forwarded to compute_bm25.
        avg_dl: Average document length over the corpus.
        N: Number of documents in the corpus.

    Returns:
        A UDF taking (tf, df, doc_length) columns and returning a FloatType
        column. FloatType is single precision, so scores are rounded to
        32-bit floats.
    """

    def _bm25_score(tf, df, doc_length):
        # A null in any input scores 0.0, so such a row adds nothing to a sum.
        if tf is None or df is None or doc_length is None:
            return 0.0
        # Delegate to the shared helper so the formula lives in one place.
        return float(compute_bm25(tf, df, doc_length, avg_dl, N, k1, b))

    return udf(_bm25_score, FloatType())

#### Instantiate the BM25 UDF with the corpus statistics

In [46]:
# BM25 tuning parameters passed to bm25_udf together with the corpus
# statistics (avg_dl, N) computed earlier.
k1, b = 1.0, 0.75
bm25 = bm25_udf(k1=k1, b=b, avg_dl=avg_dl, N=N)

In [47]:
# Score every (term, document) posting with the BM25 UDF.
# NOTE: this cell rebinds `query_index`; re-running it simply recomputes the
# `bm25` column.
bm25_column = bm25(*(F.col(name) for name in ("tf", "df", "doc_length")))
query_index = query_index.withColumn("bm25", bm25_column)
print(query_index.count())
query_index.show()

+--------+---------+---+---+--------------------+----------+----------+
|  doc_id|     term| tf| df|           doc_title|doc_length|      bm25|
+--------+---------+---+---+--------------------+----------+----------+
| 1467695|  landing|  1|  8|   BBC controversies|      7370|0.64472777|
| 4082240|  landing|  1|  8|BBC Allied Expedi...|       213|  5.802819|
| 9278580|armstrong|  1|  9|BC Junior A Lacro...|       622| 3.8882837|
|13495191|  landing|  1|  8|         B. G. Henry|       563| 4.1709514|
|67323043|  landing|  1|  8|       B&D Australia|       219| 5.7641582|
| 9237209|armstrong|  1|  9|BAFTA Award for O...|      2604| 1.5427063|
|70547608|armstrong|  2|  9|BBC's 100 Greates...|      1095|  4.366089|
|34832338|armstrong|  1|  9|BAFTA Award for B...|      1311| 2.5437808|
|  171049|     moon|  1| 21|        B. J. Thomas|       855| 2.7103758|
| 3160315|armstrong|  1|  9|         BBC Weather|       987| 3.0377252|
|14597703|armstrong|  1|  9| B Sides and C Sides|        95| 6.5

                                                                                

#### Sum the BM25 score for each document (if more than one query term matches)

In [48]:
# One row per document: the per-term BM25 contributions are added up.
scores = query_index.groupBy("doc_id", "doc_title").agg(
    F.sum("bm25").alias("bm25_score")
)
print(scores.count())
scores.show()

+--------+--------------------+------------------+
|  doc_id|           doc_title|        bm25_score|
+--------+--------------------+------------------+
|  171049|        B. J. Thomas|2.7103757858276367|
| 3061257|       BBC Breakfast|1.6011868715286255|
|62288674|         BDC (group)| 6.724920749664307|
| 1235674|BAFTA Award for B...| 0.872083842754364|
|  336647|             BC Rail|1.4255017042160034|
| 5301416|BBC Radio 1's Big...|0.7367796301841736|
| 7547396|              BAR 01| 3.915747880935669|
|14597703| B Sides and C Sides| 6.526942729949951|
| 1467695|   BBC controversies|0.6447277665138245|
| 9385828|B movies since th...| 1.172965168952942|
| 1604479|             BD Wong|1.8398765325546265|
| 9237209|BAFTA Award for O...|2.8102774620056152|
|13453391|BBC Studios Natur...|0.5210062861442566|
|33244357|         B InTune TV|2.6557371616363525|
|  322487|     B. J. Armstrong| 8.640856742858887|
| 1607443|           B. Kliban|3.9869720935821533|
|34832338|BAFTA Award for B...|

#### Retrieve top 10 documents by BM25 score

In [49]:
# Ten best-scoring documents. doc_id is a secondary sort key so that documents
# with equal scores are ordered deterministically: top_docs is evaluated by
# several separate actions (count, show, and a later collect), and without a
# tie-break each of them could pick a different set of tied rows.
top_docs = scores.orderBy(
    F.col("bm25_score").desc(),
    F.col("doc_id").asc(),
).limit(10)
print(top_docs.count())
top_docs.show()

+--------+--------------------+------------------+
|  doc_id|           doc_title|        bm25_score|
+--------+--------------------+------------------+
|67323043|       B&D Australia|10.388636112213135|
|  322487|     B. J. Armstrong| 8.640856742858887|
|62288674|         BDC (group)| 6.724920749664307|
|14597703| B Sides and C Sides| 6.526942729949951|
| 4082240|BBC Allied Expedi...| 5.802818775177002|
|70547608|BBC's 100 Greates...|   4.3660888671875|
|13495191|         B. G. Henry|4.1709513664245605|
|  451011|B-17, Queen of th...| 4.127835273742676|
| 1745369|B-Sides & Raritie...| 4.008989334106445|
| 1607443|           B. Kliban|3.9869720935821533|
+--------+--------------------+------------------+



In [58]:
# Print the final ranking as a plain-text table.
print("Query: " + query_text)
top = top_docs.collect()
print("Top Documents (doc_id, title, BM25 score):")
for row in top:
    # doc_title is null when the doc_stats left join found no matching row, so
    # fall back to an empty string instead of failing on None[:30]. The title is
    # truncated AND padded to 30 characters so the score column stays aligned.
    print(
        f"\t{row['doc_id']:<10}\t{(row['doc_title'] or '')[:30]:<30}\t{row['bm25_score']:.2f}"
    )

In [59]:
# Stop the Spark session; cells above need a fresh session to run again.
spark.stop()