In [4]:
import chromadb
import json
import sys
import torch
sys.path.append('../')

In [5]:
# Open the Chroma database stored at the given directory; every later cell
# reads or writes collections through this client.
# NOTE(review): the path is absolute and machine-specific — consider moving it
# to a single configurable constant.
client = chromadb.PersistentClient(path="/Users/amycweng/DH/db")

In [6]:
# Collections that hold only the body texts, one (or two) per era.
# The "...2" names are additional collections for the same eras.
extra_eras = ["Elizabethan2", "Jacobean2", "Carolinian2","Interregnum2"]
eras = ["pre-Elizabethan","Elizabethan","Jacobean","Carolinian","CivilWar","Interregnum"] + extra_eras

# Map each era name to its already-existing Chroma collection.
# First-time setup (kept for reference) created them with cosine distance:
#   client.create_collection(name=era,metadata={"hnsw:space": "cosine"})
collections = {}
for era in eras:
    collections[era] = client.get_collection(name=era)
    

In [None]:
# Collection that holds only the marginalia.
# First-time setup (kept for reference) created it with cosine distance:
# pre1660marginalia = client.create_collection(name="pre1660marginalia",metadata={"hnsw:space": "cosine"})
pre1660marginalia = client.get_collection(name="pre1660marginalia")

In [None]:
# Report how many items each era collection currently holds.
# count() asks Chroma for the size directly; the previous get() call fetched
# every stored record only to measure the length of the id list.
for era, collection in collections.items():
    print(era, collection.count())

In [None]:
# Load the pre-1660 catalogue. The JSON file unpacks into six per-era
# mappings, in the order listed below.
with open(f'../assets/pre1660.json') as file:
    pre1660 = json.load(file)
preE,E,J,C,CW,IR = pre1660

# Era name -> that era's mapping (the values of each mapping are iterated as
# lists of tcpIDs below).
eras = dict(zip(
    ("pre-Elizabethan", "Elizabethan", "Jacobean", "Carolinian", "CivilWar", "Interregnum"),
    (preE, E, J, C, CW, IR),
))

# Reverse lookup: tcpID -> era name. A tcpID listed under several eras keeps
# the era that is visited last.
tcpID_era = {}
for era, era_dict in eras.items():
    for id_list in era_dict.values():
        tcpID_era.update((tcpID, era) for tcpID in id_list)

In [None]:
import re

def split_sentence(sentence):
    """Cut a sentence into clause-sized pieces at a fixed set of markers.

    Every marker is a comma followed by a connective (", but", ", and then",
    ...). The sentence is cut wherever a marker occurs. The piece before the
    cut gets " , " appended, and the connective (the marker without its
    leading ", ") is put at the front of the piece after the cut.

    Returns a list of strings; a sentence without any marker comes back as a
    one-element list containing the sentence unchanged.
    """
    to_segment = [", but", ", while", ", let", ", they", ", NONLATINALPHABET",
                    ", then", ", yet", ", than", ', and yet', ', and though',
                    ', at least', ', and to', ', this be', ', for', ', therefore',
                    ', that', ', and we', ', and i ', ', when', ', and say', ', and this',
                    ', and then', ', and than', ', and they', ', i say', ', as the apostle',
                    ', otherwise', ', how', ', according', ', accordi^^', ', say',', and when',
                    ', and he', ', and she', ', he say', ', she say', ', lest', ', and where',
                    ', and how', ', and what', ', and there', ', and therefore', ', and thus',
                    ', and if', ', and because', ', and I ', ', he will', ', they will', ', she will']
    # A capturing group makes re.split return the markers as well, interleaved
    # with the text between them: [text, marker, text, marker, ..., text].
    splitter = re.compile("(" + "|".join(re.escape(marker) for marker in to_segment) + ")")
    tokens = splitter.split(sentence)
    fragments, markers = tokens[0::2], tokens[1::2]
    if not markers:
        return fragments

    pieces = []
    carried = ""  # connective taken from the previous marker
    for fragment, marker in zip(fragments, markers):
        pieces.append(carried + fragment + " , ")
        carried = marker.replace(", ", "")
    # The final fragment has no marker after it, so it gets no " , " suffix.
    pieces.append(carried + fragments[-1])
    return pieces

In [None]:
class Sermons:
    """Container for one corpus prefix; further attributes are attached by the caller."""

    def __init__(self, prefix):
        # Only the prefix is stored at construction time.
        self.prefix = prefix

def get_docs(prefix):
    """Build the list of cleaned, clause-split passages for one corpus prefix.

    Reads ``../assets/processed/{prefix}.json``, which must unpack into four
    values; only the first two (sentence ids and lemmatized texts) are used.
    For every sentence id whose first component contains ``prefix``:

    * the lemmatized text is looked up through ``sent_id_to_idx`` (if two
      entries share the same key, the text of the last one is used),
    * every character other than letters, ``^``, ``*`` and ``,`` becomes a
      space and runs of whitespace are collapsed,
    * texts shorter than two space-separated tokens are dropped,
    * the text is cut with ``split_sentence`` and pieces shorter than three
      space-separated tokens are dropped.

    Prints the number of passages and returns them as a list of strings, in
    the order they were produced.
    """
    corpus = Sermons(prefix)
    with open(f'../assets/processed/{prefix}.json','r') as file:
        sentence_ids, lemmatized_texts, _chunks, _fw_subchunks = json.load(file)
    corpus.sent_id = sentence_ids
    corpus.lemmatized = lemmatized_texts
    corpus.sent_id_to_idx = {
        (tuple(entry[0]), entry[1]): position
        for position, entry in enumerate(sentence_ids)
    }

    passages = []
    for entry in corpus.sent_id:
        if prefix not in entry[0][0]:
            continue
        row = corpus.sent_id_to_idx[(tuple(entry[0]), entry[1])]
        cleaned = re.sub(r"[^A-Za-z\^\*,]"," ",corpus.lemmatized[row])
        cleaned = re.sub(r"\s+"," ", cleaned).strip(" ")
        if len(cleaned.split(" ")) < 2:
            continue
        passages.extend(
            piece for piece in split_sentence(cleaned)
            if len(piece.split(" ")) >= 3
        )
    print("Passages:", len(passages))
    return passages

In [None]:
# Spot check: fetch the records stored for a single tcpID in the CivilWar collection.
# Alternative lookup by record id (kept for reference):
# collections['pre-Elizabethan'].get(ids=["A0_217167"])
collections['CivilWar'].get(where={"tcpID":"A67876"})

In [None]:
# tcpIDs that process() has already written to Chroma in this kernel session.
# Re-running this cell resets that record.
processed = {} 

In [None]:
def _upsert_batch(tcpID, ids, embeddings, metadatas, documents, message):
    """Write one buffered batch to the era collection that ``tcpID`` belongs to."""
    era = tcpID_era[tcpID]
    collections[era].upsert(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)
    print(message, tcpID, era)


def process(prefix, docs, resume_after=67876, batch_size=20000):
    """Load the embeddings of one prefix into the per-era Chroma collections.

    Reads ``{prefix}_corpus_embeddings_segmented.pth`` and ``{prefix}_ids.json``
    from the embeddings directory. Each entry of the ids file is
    ``[[tcpID, chunk_id, is_note], part_id]``. Entries are buffered per tcpID
    and upserted into ``collections[tcpID_era[tcpID]]`` when the tcpID changes,
    when the buffer reaches ``batch_size`` items, and once more at the end.

    Parameters
    ----------
    prefix : str
        Corpus prefix; selects the embedding/id files and prefixes record ids
        (``f"{prefix}_{idx}"``).
    docs : sequence of str
        Passage texts; ``docs[idx]`` is stored next to ``vectors[idx]``.
        NOTE(review): assumes docs is aligned one-to-one with the ids file —
        confirm against get_docs.
    resume_after : int or None, default 67876
        Entries whose ``int(tcpID[1:])`` is <= this value are counted but not
        uploaded (resume checkpoint). The comparison ignores the leading
        letter of the tcpID. Pass ``None`` to upload everything.
    batch_size : int, default 20000
        Largest number of buffered items sent in a single upsert.

    Side effects: upserts into Chroma, marks every fully written tcpID in the
    module-level ``processed`` dict, prints progress and finally the number of
    entries whose tcpID is dated before 1660.
    """
    vectors = torch.load(f"/Users/amycweng/DH/embeddings/{prefix}_corpus_embeddings_segmented.pth",map_location="cpu")
    with open(f'/Users/amycweng/DH/embeddings/{prefix}_ids.json') as file:
        ids = json.load(file)

    count = 0
    e, m, i, d = [], [], [], []
    prev_tcpID = None
    for idx, label in enumerate(ids):
        tcpID, chunk_id, is_note = label[0]
        part_id = label[1]

        # Only books dated before 1660 are listed in tcpID_era.
        if tcpID not in tcpID_era:
            continue
        count += 1
        if resume_after is not None and int(tcpID[1:]) <= resume_after:
            continue

        if prev_tcpID is None:
            prev_tcpID = tcpID

        if tcpID != prev_tcpID:
            # The previous book is complete: write what is buffered (unless it
            # was already written earlier) and start a fresh buffer.
            if prev_tcpID not in processed:
                _upsert_batch(prev_tcpID, i, e, m, d, 'Processed')
                processed[prev_tcpID] = True
            prev_tcpID = tcpID
            e, m, i, d = [], [], [], []
        elif len(i) >= batch_size:
            # Very long book: write it in parts to bound the request size.
            _upsert_batch(prev_tcpID, i, e, m, d, 'Processed part of')
            e, m, i, d = [], [], [], []

        e.append(vectors[idx].tolist())
        m.append({"tcpID": tcpID, 'chunk_id': chunk_id, 'is_note':is_note, 'part_id':part_id})
        i.append(f'{prefix}_{idx}')
        d.append(docs[idx])

    # Flush the last book. This uses upsert like every other write (it used to
    # be add) and records the book in `processed`, so a re-run treats the last
    # book the same way as all the others.
    if i and prev_tcpID not in processed:
        _upsert_batch(prev_tcpID, i, e, m, d, 'Processed')
        processed[prev_tcpID] = True
    print(count)

In [None]:
# Upload one prefix: rebuild its passages, write them with their embeddings,
# then drop the passage list to free kernel memory.
prefix = "A9" 
data = get_docs(prefix)
process(prefix,data)
del data 

In [7]:
# Fetch the records of the Elizabethan collection, asking only for metadata
# to be included (no documents or embeddings).
items =  collections['Elizabethan'].get(include=["metadatas"])

In [8]:
# First two characters of every record's tcpID (e.g. "A1", "B0").
prefixes = [item["tcpID"][:2] for item in items["metadatas"]]

In [9]:
# Tally how many records fall under each two-character tcpID prefix.
# The import targets the standard-library module; it does not touch the
# notebook's own `collections` dict.
from collections import Counter 
Counter(prefixes)

Counter({'A1': 431384,
         'A0': 173710,
         'B1': 27034,
         'A2': 10803,
         'B0': 10554,
         'A6': 10250})

In [10]:
# Separate collection for the Elizabethan records whose tcpID starts with "A1".
# get_or_create_collection keeps this cell re-runnable: create_collection
# fails once the collection already exists in the persistent database.
collections["Elizabethan_A1"] = client.get_or_create_collection(name="Elizabethan_A1",metadata={"hnsw:space": "cosine"})

In [None]:
for idx, item in enumerate(items["metadatas"]): 
    if item["tcpID"][:2] == "A1": 
        embedding = collections["Elizabethan"].get(ids=[items["ids"][idx]])
        collections["Elizabethan_A1"].add