In [None]:
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer, CrossEncoder, models 
import os, torch, re, html, tqdm
import pandas as pd 
from unidecode import unidecode
import unicodedata

def normalize_text(text):
    """Return `text` with HTML entities decoded, transliterated, and whitespace collapsed.

    Steps, in order: `html.unescape`, Unicode NFKD normalization, `unidecode`
    transliteration, then every run of whitespace is replaced by one space.
    Leading/trailing whitespace is collapsed to a single space, not removed.
    """
    decoded = html.unescape(text)
    transliterated = unidecode(unicodedata.normalize('NFKD', decoded))
    return re.sub(r"\s+", " ", transliterated)
def clean_text(s):
  """Normalize `s`, blank out markup placeholders, and return it lowercased.

  The input first goes through `normalize_text`; the tokens `<i>`, `</i>`,
  `<NOTE>` and `NONLATINALPHABET` are then replaced by a space, whitespace
  runs are collapsed, and surrounding spaces are stripped before lowercasing.
  """
  # A further alternative, \d+\^PAGE[S]*\^MISSING, was left disabled in the
  # original pattern and is intentionally not matched here either.
  placeholder_pattern = r"</i>|<NOTE>|NONLATINALPHABET|<i>"
  without_placeholders = re.sub(placeholder_pattern, " ", normalize_text(s))
  collapsed = re.sub(r"\s+", " ", without_placeholders)
  return collapsed.strip(" ").lower()


# Root of the app checkout; model weights and the vector DB are read from
# paths built on top of it in later cells.
main_dir = '/Users/amycweng/SERMONS_APP/app'

# Base transformer checkpoint name handed to sentence-transformers.
model_checkpoint = 'emanjavacas/MacBERTh'

# Together these identify which fine-tuned state dict (and matching
# precomputed embedding files) later cells load.
model_round, epoch = "ALL", "1"
state_dict_path = f"EEPS_{model_round}_MacBERTh_Epoch{epoch}"

In [None]:
# Bi-encoder: the base transformer (inputs truncated to 128 tokens) followed by
# mean pooling, with fine-tuned weights loaded from the local state dict.
word_embedding_model = models.Transformer(model_checkpoint, max_seq_length=128)
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    "mean",
)
bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
# NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
bi_encoder.load_state_dict(
    torch.load(
        f'{main_dir}/static/data/{state_dict_path}.pt',
        map_location=torch.device('cpu'),
    )
)

# Cross-encoder used to re-score candidate pairs, loaded from a local checkpoint directory.
cross_encoder = CrossEncoder(
    f"{main_dir}/static/data/EEPS_cross-encoder_emanjavacas_MacBERTh/checkpoint-1000"
)

# Off-the-shelf sentence encoder; used for the title queries further down.
basic_encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

In [9]:
# Open the on-disk Chroma database that lives under the app's static data folder.
persist_directory = f'{main_dir}/static/data/VECTOR_DB'
client_settings = Settings(
    anonymized_telemetry=False,
    is_persistent=True,
    persist_directory=persist_directory,
)
queryclient = chromadb.PersistentClient(
    path=persist_directory,
    settings=client_settings,
)

In [58]:
data_folder = "/Users/amycweng/SERMONS_APP/db/data"

# Build one tcpID -> title lookup from both header-less CSVs (column 0 is the
# id, column 3 the title). Files are read in order, so an id that appears in
# the second file overwrites the title taken from the first.
tcpID_titles = {}
for sermons_csv in (f"{data_folder}/sermons.csv", f"{data_folder}/sermons_missing.csv"):
    sermons = pd.read_csv(sermons_csv, header=None)
    tcpID_titles.update(zip(sermons[0], sermons[3]))
len(tcpID_titles)

5729

In [None]:
folder = '../../'

# Ids (whole verses or "<key> - <n>" parts) flagged for exclusion.
# NOTE(review): `is True` only matches the Python bool singleton, so this
# relies on pandas handing back native booleans for the `to_remove` column;
# confirm against the CSV. Left unchanged because altering the filter would
# change which passages exist and could misalign the precomputed vectors.
to_remove = {}
items = pd.read_csv(f"{folder}/EEPS/overly_vague.csv").to_dict(orient='records')
for entry in items:
    if entry['to_remove'] is True:
        to_remove[entry['verse_id']] = None

print('To Remove', len(to_remove))

b_versions = ['AKJV','ODRV','Geneva', 'Douay-Rheims', 'Tyndale', 'Wycliffe','Vulgate']
ODRV_books = pd.read_csv(f"{folder}/Bibles/ODRV.csv",header=None)
ODRV_books = set(ODRV_books[3])

# `bible` maps passage id -> normalized passage text. Insertion order matters:
# the precomputed vectors loaded below are matched to it purely by position.
bible = {}
for bname in b_versions:
    data = pd.read_csv(f"{folder}/Bibles/{bname}.csv",header=None).to_dict(orient="records")
    for entry in tqdm.tqdm(data):
        key = entry[0]  # e.g. "Galatians 4.30 (ODRV)"
        if key in to_remove:
            continue
        # Skip Douay-Rheims rows whose column-3 value also occurs in ODRV.csv.
        if "Douay-Rheims" in key and entry[3] in ODRV_books:
            continue
        v_id = key.split(" (")[0]  # drop the trailing " (<version>)"
        text = entry[6]

        # Whole verse: kept only when it has fewer than 200 space-separated tokens.
        if len(text.split(" ")) < 200:
            bible[key] = normalize_text(f"{v_id}: {text}")

        # Split after ./? when followed by a capital letter, and after !, : or ;
        parts = re.split(r'(?<=[\.\?]) (?=[A-Z])|(?<=[\!\:\;])', text)
        parts = [re.sub(r'\s+', ' ', p).strip() for p in parts if len(p.strip(" ")) > 0]
        if not parts:
            # Empty / whitespace-only verse text: nothing to split
            # (indexing parts[0] below would otherwise raise IndexError).
            continue

        # No parts are stored when the first or last piece is very short or the
        # first piece still contains an HTML entity, nor when there is only one piece.
        if (len(parts[0].split(" ")) <= 5 or len(parts[-1].split(" ")) <= 5
                or re.search(r"\&\w+\;", parts[0])):
            continue
        if len(parts) == 1:
            continue

        for pidx, p in enumerate(parts):
            p_id = f"{key} - {pidx}"
            if p_id in to_remove:
                continue
            if re.search(r"\&\w+\;", p) or len(p.split(" ")) <= 5:
                continue
            bible[p_id] = normalize_text(f"Part {pidx+1} of {v_id}: {p}")

bible_verses = list(bible.values())
bible_ids = list(bible.keys())

# NOTE(review): torch.load unpickles the file; only load trusted files.
bible_vectors = torch.load(f"{folder}/EEPS/Bibles_{state_dict_path}.pt",map_location='cpu')
bible_vectors = bible_vectors[:-1]  # last row is dropped — reason not visible here, TODO confirm

# Vectors are paired with ids by position only, so a count mismatch means the
# embeddings below may be attached to the wrong passages.
if len(bible_vectors) != len(bible_ids):
    print(f"WARNING: {len(bible_vectors)} vectors loaded for {len(bible_ids)} passages; "
          "embeddings may be misaligned with passage ids")

# Keep only whole verses (ids without the " - <n>" part suffix) for the vector DB.
# `bible` itself still holds the parts as well.
full_ids = []
full_verses = []
full_vectors = []
for v_id, verse, vector in zip(bible_ids, bible_verses, bible_vectors):
    if " - " not in v_id:
        full_ids.append(v_id)
        full_vectors.append(vector)
        full_verses.append(verse)
print(len(full_ids),len(full_vectors),len(full_verses))
bible_ids, bible_verses = full_ids, full_verses
bible_vectors = full_vectors

To Remove 0


100%|██████████| 36702/36702 [00:02<00:00, 15783.42it/s]
100%|██████████| 14736/14736 [00:00<00:00, 17677.20it/s]
100%|██████████| 31090/31090 [00:01<00:00, 15859.67it/s]
100%|██████████| 35811/35811 [00:01<00:00, 24852.96it/s]
100%|██████████| 7954/7954 [00:00<00:00, 16134.86it/s]
100%|██████████| 9622/9622 [00:00<00:00, 18838.00it/s]
100%|██████████| 35809/35809 [00:02<00:00, 12922.86it/s]


In [6]:
import json 
# Persist the id -> text mapping of the retained full verses as JSON for the app.
save_full = dict(zip(bible_ids, bible_verses))
with open('/Users/amycweng/SERMONS_APP/app/static/data/full_bibles.json', 'w+') as file:
    json.dump(save_full, file)

In [10]:
# Open (or create on first run) the "Bible" collection, using cosine distance.
bible_collection = queryclient.get_or_create_collection(
    name="Bible",
    metadata={"hnsw:space": "cosine"},
)
bible_collection.name

'Bible'

In [19]:
# Upsert ids, vectors and texts into the Bible collection in aligned slices
# of at most `batch_size` items.
batch_size = 40000
bible_batches = [
    (
        bible_ids[start:start + batch_size],
        bible_vectors[start:start + batch_size],
        bible_verses[start:start + batch_size],
    )
    for start in range(0, len(bible_ids), batch_size)
]
for batchids, bvectors, batchtexts in tqdm.tqdm(bible_batches):
    bible_collection.upsert(
        ids=batchids,
        embeddings=[vector.tolist() for vector in bvectors],
        documents=batchtexts,
    )

100%|██████████| 4/4 [05:10<00:00, 77.68s/it]


In [None]:
# Thresholds. A hit is printed when its cross-encoder score reaches
# `pos_threshold`, or when it reaches `neg_threshold` while the bi-encoder
# similarity is at least `neg_sim_threshold`. Candidates whose similarity is
# below `pos_sim_threshold` are never re-scored.
pos_threshold, neg_threshold = 0.6, 0.4
pos_sim_threshold, neg_sim_threshold = 0.65, 0.7

query = "the bond-woman and her son were cast out;"
q_embedding = bi_encoder.encode([query])
results = bible_collection.query(
    query_embeddings=q_embedding.tolist(),
    n_results=25,
    include=["distances"],
)
for hit_ids, distances in zip(results['ids'], results['distances']):
    # The collection was created with cosine space; similarity is taken as 1 - distance.
    scores = [1 - distance for distance in distances]
    hitlist = [
        {'v_id': hit_id, 'score': score}
        for hit_id, score in zip(hit_ids, scores)
        if score >= pos_sim_threshold
    ]
    if not hitlist:
        continue

    # Re-score every surviving candidate against the query with the cross-encoder.
    cross_inp = [[query, bible[hit['v_id']]] for hit in hitlist]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hitlist, cross_scores):
        hit['cross-score'] = cross_score
    hitlist.sort(key=lambda hit: hit['cross-score'], reverse=True)

    for hit in hitlist:
        v_id = hit['v_id']
        cross_score, sim_score = hit['cross-score'], hit['score']
        confident = cross_score >= pos_threshold
        corroborated = cross_score >= neg_threshold and sim_score >= neg_sim_threshold
        if corroborated or confident:
            print(v_id, round(cross_score, 3), round(sim_score, 3), bible[v_id])

Galatians 4.30 (ODRV) 0.838 0.673 Galatians 4.30: But what saith the Scripture? Cast out the bond-woman and her sonne. For the sonne of the bond-woman shal not be heire with the sonne of the free-woman.


In [None]:
# Store the precomputed title embeddings under their tcpIDs.
# NOTE(review): this assumes the rows of the .pt file are in the same order as
# the keys of `tcpID_titles` — nothing here checks it; confirm.
title_collection = queryclient.get_or_create_collection(
    name="Titles",
    metadata={"hnsw:space": "cosine"},
)
title_vectors = torch.load(f"{folder}/EEPS/titles_all-mpnet-base-v2.pt", map_location='cpu')
title_collection.upsert(
    ids=list(tcpID_titles),
    embeddings=title_vectors.tolist(),
)

In [109]:
# NOTE(review): whether HNSW parameters can be changed on an existing
# collection this way, and whether the new metadata replaces the earlier
# "hnsw:space" entry, depends on the Chroma version — confirm in the Chroma docs.
title_collection.modify(
    metadata={
        "hnsw:M": 256,  # Max connectivity
        "hnsw:ef_construction": 2000,  # Thorough graph build
        "hnsw:ef": 10000,  # Search deep enough
    }
)

In [None]:
import numpy as np 
# Brute-force title search: score every stored title embedding against the
# query, keep those at or above the threshold, then re-rank with the cross-encoder.
query = "marriage"
pos_sim_threshold = 0.35

q_embedding = basic_encoder.encode([query])
query_embedding = np.array(q_embedding)

all_data = title_collection.get(include=["embeddings"])
all_ids = all_data["ids"]
all_embeddings = np.array(all_data["embeddings"])  # shape: [5729, dim]

# Dot product of each stored embedding with the query, sorted high to low.
similarities = np.dot(all_embeddings, query_embedding.T).flatten()
sorted_indices = np.argsort(similarities)[::-1]
sorted_sim = similarities[sorted_indices]
sorted_ids = [all_ids[i] for i in sorted_indices]

hitlist = [
    {'v_id': title_id, 'score': similarity}
    for title_id, similarity in zip(sorted_ids, sorted_sim)
    if similarity >= pos_sim_threshold
]
cross_inp = [[query, tcpID_titles[hit['v_id']]] for hit in hitlist]
if cross_inp:
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hitlist, cross_scores):
        hit['cross-score'] = cross_score
    hitlist.sort(key=lambda hit: hit['cross-score'], reverse=True)
    for hit in hitlist:
        v_id = hit['v_id']
        print(v_id, round(hit['cross-score'], 3), round(hit['score'], 3), tcpID_titles[v_id])

In [19]:
import re, json
from tqdm import tqdm 
import math, re
import torch
def add_to_db(era):
  """Load the marginal-note passages of `era` and upsert them into Chroma.

  Args:
    era: used as a regular expression against the file names in
      `<folder>/unique`, and as the prefix of the collection names
      (`{era}_margin_<n>`) and embedding files (`{era}_margin_<batch index>`).

  Steps:
    1. Merge every matching JSON file whose name also contains "margin" into
       one dict keyed by passage key; locations of repeated keys are appended.
    2. Cut the passages into slices of 40000, giving each passage a running
       integer id.
    3. Create the Chroma collections and assign each slice to one of them.
    4. For each slice, load its precomputed embeddings from
       `{data_folder}/embeddings/` and upsert ids, embeddings and documents.

  Relies on the notebook globals `queryclient`, `data_folder` and the `tqdm`
  callable. Prints progress information; returns None.

  Raises:
    ValueError: if an embedding file does not hold exactly one vector per
      passage of its slice.
  """
  output = f"{era}_margin"
  folder = "/Users/amycweng/DH/Early-Modern-Sermons/assets"

  # 1. Merge the per-file passage dictionaries.
  #    corpus[key] = (v[0], list of locations, None)
  corpus = {}
  for fp in tqdm(os.listdir(f"{folder}/unique")):
    # `output` always ends in "_margin", so only "margin" files are ever used.
    if not re.search(era, fp) or "margin" not in fp:
      continue
    with open(f"{folder}/unique/{fp}", "r") as file:
      entries = json.load(file)
    for k, v in entries.items():
      if k in corpus:
        corpus[k][1].extend(v[1])
      elif len(v[0]) > 0:  # entries with an empty first field are skipped
        corpus[k] = (v[0], v[1], None)

  # 2. Slice the passages; `idx` runs across slices and becomes the Chroma id.
  passages = list(corpus.keys())  # materialized once, not once per slice
  rel_batches = []
  batch_size = 40000
  idx = 0
  for i in range(0, len(passages), batch_size):
    batch = []
    for p in passages[i: i + batch_size]:
      first_field, locations, extra = corpus[p]
      original = first_field[0]
      # De-duplicate locations in first-seen order so the stored document
      # string is the same on every run (a set gives no stable order).
      unique_locations = list(dict.fromkeys(tuple(c) for c in locations))
      batch.append((idx, original, unique_locations, extra))
      idx += 1
    rel_batches.append(batch)
  print(sum(len(batch) for batch in rel_batches))

  # 3. Map slice index -> collection number, creating the collections.
  # NOTE(review): the number of slices per collection is derived from the
  # corpus size as ceil(n / 200000) + 1; kept exactly as-is because existing
  # collections depend on this mapping — confirm it is the intended grouping.
  chroma_batches = {}
  batch_size = math.ceil(len(corpus)/200000) + 1
  batch_num = 0
  for i in range(0, len(rel_batches), batch_size):
    print(f"{output}_{batch_num}")
    queryclient.get_or_create_collection(name=f"{output}_{batch_num}",metadata={"hnsw:space": "cosine"})
    for j in range(i, min(i + batch_size, len(rel_batches))):
      chroma_batches[j] = batch_num
    batch_num += 1
  print(chroma_batches)

  # 4. Upsert each slice together with its precomputed embeddings.
  for bidx, batch in enumerate(rel_batches):
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    p_embedding = torch.load(f"{data_folder}/embeddings/{output}_{bidx}",map_location=torch.device('cpu'))
    print(len(p_embedding))
    if len(p_embedding) != len(batch):
      raise ValueError(
        f"{output}_{bidx} holds {len(p_embedding)} embeddings "
        f"but the batch has {len(batch)} passages"
      )
    cidx = chroma_batches[bidx]
    collection = queryclient.get_collection(name=f"{output}_{cidx}")
    print(collection)
    # Document = the passage's locations, fields joined by "_" and locations by ";".
    docs = [";".join("_".join(key) for key in b[2]) for b in batch]
    collection.upsert(
      embeddings=p_embedding.tolist(),
      ids=[str(b[0]) for b in batch],
      documents=docs
    )
    print("finished inserting to my Chroma collection")