In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install -U sentence_transformers
!pip install --user annoy
!pip install -q findspark



In [3]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!cp /content/drive/MyDrive/spark/spark-3.1.2-bin-hadoop3.2.tgz . 
!tar xf spark-3.1.2-bin-hadoop3.2.tgz

In [4]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "spark-3.1.2-bin-hadoop3.2"

In [5]:
import findspark
findspark.init()
findspark.find()

'spark-3.1.2-bin-hadoop3.2'

In [6]:
import nltk
nltk.download('punkt')
from sentence_transformers import SentenceTransformer, util
# model = SentenceTransformer('all-MiniLM-L6-v2') 
model_asym = SentenceTransformer('msmarco-distilbert-base-tas-b')
model_sym = SentenceTransformer('all-MiniLM-L6-v2') 

print("Max Sequence Length:", model_asym.max_seq_length)
print("Max Sequence Length:", model_sym.max_seq_length)
symmetricSearch = False

if(symmetricSearch):
  model = model_sym
else:
  model = model_asym

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Max Sequence Length: 512
Max Sequence Length: 256


In [7]:
import pyspark
from pyspark.sql import SparkSession

from pyspark.sql.functions import monotonically_increasing_id 
from pyspark.sql.functions import col, pandas_udf
import pyspark.sql.functions as f
from pyspark.sql.types import *

import pandas as pd
from annoy import AnnoyIndex

In [8]:
spark = SparkSession.builder.master("local") \
                    .appName('sbert_encode') \
                    .getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")
# spark.conf.set("spark.executor.memory", "5g")
sc = spark.sparkContext

print(sc._conf.get('spark.driver.memory'))


None


In [9]:
def map_encode(s):
     sentences = nltk.tokenize.sent_tokenize(s[1])
     sentences = [x.replace('\n', '') for x in sentences]
     sentences = [x.replace('"', '') for x in sentences]
     corpus_embeddings = model.encode(sentences)
     coll = []
     return [os.path.basename(s[0]),(sentences,corpus_embeddings)]

In [10]:
rdd = sc.wholeTextFiles("/content/drive/MyDrive/spark/data/set1/*.story")
rdd = rdd.map(map_encode)
rdd = rdd.flatMapValues(lambda x:zip(x[0],x[1]))
rdd = rdd.map( lambda  x: (x[0],x[1][0], x[1][1].tolist()))

In [11]:
news_df = spark.createDataFrame(rdd)
news_df = news_df.select("*").withColumn("id", monotonically_increasing_id())
news_df = (
    news_df
    .withColumnRenamed("_1","file")
    .withColumnRenamed("_2","text")
    .withColumnRenamed("_3","features")
)

lookupdf = news_df.drop("features")

news_df.show(10,truncate=True) 
news_df.printSchema()

+--------------------+--------------------+--------------------+---+
|                file|                text|            features| id|
+--------------------+--------------------+--------------------+---+
|000beb8706bc2ad4c...|By Lucy Osborne a...|[0.05371571332216...|  0|
|000beb8706bc2ad4c...|Sir Irvine Patnic...|[0.06469271332025...|  1|
|000beb8706bc2ad4c...|He was criticised...|[0.03838087990880...|  2|
|000beb8706bc2ad4c...|Rest in peace: Si...|[0.11042502522468...|  3|
|000beb8706bc2ad4c...|A statement from ...|[0.09176587313413...|  4|
|000beb8706bc2ad4c...|He was a much lov...|[0.03039967268705...|  5|
|000beb8706bc2ad4c...|They had two chil...|[-0.0323729217052...|  6|
|000beb8706bc2ad4c...|He was awarded an...|[0.02839423529803...|  7|
|000beb8706bc2ad4c...|Sir Irvine was vi...|[0.07804574072360...|  8|
|000beb8706bc2ad4c...|@highlightSir Irv...|[0.02460888400673...|  9|
+--------------------+--------------------+--------------------+---+
only showing top 10 rows

root
 |-

In [12]:

lookupdf.show(10,truncate=True) 
lookupdf.printSchema()

+--------------------+--------------------+---+
|                file|                text| id|
+--------------------+--------------------+---+
|000beb8706bc2ad4c...|By Lucy Osborne a...|  0|
|000beb8706bc2ad4c...|Sir Irvine Patnic...|  1|
|000beb8706bc2ad4c...|He was criticised...|  2|
|000beb8706bc2ad4c...|Rest in peace: Si...|  3|
|000beb8706bc2ad4c...|A statement from ...|  4|
|000beb8706bc2ad4c...|He was a much lov...|  5|
|000beb8706bc2ad4c...|They had two chil...|  6|
|000beb8706bc2ad4c...|He was awarded an...|  7|
|000beb8706bc2ad4c...|Sir Irvine was vi...|  8|
|000beb8706bc2ad4c...|@highlightSir Irv...|  9|
+--------------------+--------------------+---+
only showing top 10 rows

root
 |-- file: string (nullable = true)
 |-- text: string (nullable = true)
 |-- id: long (nullable = false)



In [13]:
feature = news_df.take(3)[0][2]
len(feature)

384

In [14]:
# print(len(news_df.index)) 
print(news_df.count())
embeddingSize = 0
if(symmetricSearch):
  embeddingSize = 384
  ann = AnnoyIndex(embeddingSize, 'angular')
else:
  embeddingSize = 768
  ann = AnnoyIndex(embeddingSize, 'dot')

for row in news_df.select("id","features").collect():
  ann.add_item(row.id, row.features)

# news_df[news_df.columns[0]].count() 

617


In [None]:
ann.build(100)
ann.save("/content/drive/MyDrive/spark/news_tree.ann")

In [16]:
ann.save("/content/drive/MyDrive/spark/news_tree.ann")

True

In [24]:
query = 'Pamela Hobley went missing in which year'
query_embeddings = model.encode([query])

# print(query_embeddings)
closestitem = 10
response = ann.get_nns_by_vector(query_embeddings[0], closestitem, search_k=-1, include_distances=False)
df_output = lookupdf[lookupdf['id'].isin(response)]
df_output.show(truncate=True)

# for res in response:
#   lookupdf.where(news_df.id == res).collect()
# news_df.filter("id" == str(response[0])).show(false)
# print( news_df.loc[df['id'] == response[0]])


AttributeError: ignored

In [23]:
#cross encoder
from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')
if(symmetricSearch):
  model = model_sym
else:
  model = model_asym

pair = []
for rows in df_output.collect():
   pair.append((rows["text"],query))
scores = model.predict(pair)

print(pair)

print(scores)


('Patricia Spencer, 17, and her 15-year-old friend Pamela Hobley went missing on October 31, 1969 after going to a high school football game in Oscoda, Michigan.', 'Pamela Hobley went missing in which year')
('Scroll down for video\xa0Pamela Hobley, 15, (left) and Patricia Spencer, 17, (right) went missing on October 31, 1969 after going to a high school football game in Oscoda, MichiganThe girls, who did not usually hang out together, relatives said, left the school early on Halloween to go to a party but never arrived\xa0Oscoda High School was subject to a bomb threat that day and family members believed that the girls may have left school early that afternoon.', 'Pamela Hobley went missing in which year')
("Pamela's sister, Mary Buehrle, told\xa048 Hours' Crimesider: 'I can't move on.", 'Pamela Hobley went missing in which year')
('She said her mother, a single mom, was told by Pamela that she was going to a Halloween party with her friends and boyfriend, to whom she had recently be

In [None]:
cols = ['id','_3']
annoydata = df_index.select(*cols)

annoydata.show()

In [None]:
// version: 0.1.4
// spark.executor.instances = 100
// spark.executor.memory = 8g
// spark.driver.memory = 8g
fraction = 0.00086 // for about 100k samples
numTrees = 2
numPartitions = 100
annoyModel = new Annoy().setFraction(fraction).setNumTrees(numTrees).fit(dataset)
annoyModel.saveAsAnnoyBinary("/hdfs/path/to/index", numPartitions)