In [None]:
import sys
import os
sys.path.append(os.path.abspath(".."))

from elastic import elastic_session, ScrollingCorpus, ElasticDocument
from sentence_transformers import SentenceTransformer
from sentence import doc_to_sentences

In [None]:
# Create the session through the project's `elastic_session` helper. The
# credentials and CA-certificate files are looked up one directory above the
# current working directory.
# NOTE(review): "arxiv-index" is presumably the index name — confirm against
# `elastic_session`.
session = elastic_session("arxiv-index", "../credentials.json", "../http_ca.crt")
# NOTE(review): `0` looks like the document id and "article" the field that
# holds the body text — confirm against `ElasticDocument`.
doc = ElasticDocument(session, 0, text_path="article")
# Pretrained sentence-embedding model handed to `doc_to_sentences` below.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Sentence objects for the document; later cells rely on each one exposing
# `.text` and `.sim(other)`.
sentences = doc_to_sentences(doc, model)

In [None]:
from itertools import starmap, pairwise
from rich.table import Table
from rich.console import Console
from rich.markdown import Markdown

def print_pairs(sentences):
    """Render every pair of neighbouring sentences with its similarity score.

    The console is cleared first, then a two-column table is printed: the
    left column shows both sentences of a pair as a Markdown bullet list, the
    right column shows the value of ``first.sim(second)``.

    Args:
        sentences: Sequence of objects exposing ``.text`` and ``.sim(other)``.
    """
    table = Table()
    table.add_column("Sentence Pair")
    table.add_column("Similarity", vertical="top")

    console = Console()
    console.clear()

    for first, second in pairwise(sentences):
        similarity = first.sim(second)
        # The Markdown layout (blank lines, trailing rule) is part of the
        # rendered output and must stay exactly as is.
        pair_markdown = f'''
- {first.text.strip()}


- {second.text.strip()}

---
        '''
        # Two leading newlines push the score down so it lines up with the
        # first bullet of the Markdown cell.
        table.add_row(Markdown(pair_markdown), "\n\n"+str(similarity))

    console.print(table)

# Keep only the first 20 sentences so the preview table stays readable.
# NOTE: this rebinds `sentences`, so every later cell that reads it sees only
# this 20-sentence slice, not the whole document.
sentences = sentences[:20]
print_pairs(sentences)







In [None]:
from sentence import SimilarityPair, SentenceChain, Sentence, SentenceLike
from itertools import chain

def iterative_merge(sentences: list[SentenceLike],*, threshold: float, round_limit: int | None = 1, pooling_method="average"):
    """Greedily merge runs of adjacent, similar sentences into chains.

    In each round the similarity of every pair of neighbours is computed.
    Consecutive items whose similarity is at least ``threshold`` are grouped
    into one chain; an item that is not similar enough to its predecessor
    starts a new chain. Every group (including single-item groups) is wrapped
    in a ``SentenceChain`` and the result becomes the input of the next round.

    Args:
        sentences: Ordered sentence-like items to merge.
        threshold: Minimum pair similarity (inclusive) for two neighbours to
            end up in the same chain.
        round_limit: Maximum number of merge rounds. ``None`` keeps merging
            until no neighbouring pair reaches the threshold. Values below 1
            are treated as 1.
        pooling_method: Forwarded to ``SentenceChain`` in every round.

    Returns:
        The input list itself when no neighbouring pair reaches the threshold
        in the first round (this includes empty and single-item input);
        otherwise a list of ``SentenceChain`` objects.
    """
    current = sentences
    rounds_done = 0

    # A loop replaces the former recursion, which passed the keyword-only
    # arguments positionally (TypeError) and dropped `pooling_method`.
    while round_limit is None or rounds_done < max(round_limit, 1):
        pairs = [SimilarityPair.from_sentences(s1, s2) for s1, s2 in pairwise(current)]

        # No more merging can happen. Uses the same `>=` test as the grouping
        # loop below so the two decisions can never disagree.
        if not any(pair.sim >= threshold for pair in pairs):
            break

        # The first item always opens the first chain; after that only `s2`
        # of each pair is new, because `s1` was placed by the previous step.
        chains = [[pairs[0].s1]]
        for pair in pairs:
            if pair.sim >= threshold:
                # Similar enough: extend the chain that already holds s1.
                chains[-1].append(pair.s2)
            else:
                # Not similar enough: s2 starts a chain of its own.
                chains.append([pair.s2])

        current = [SentenceChain(c, pooling_method) for c in chains]
        rounds_done += 1

    return current

In [None]:
# `merged` was not assigned in any cell, so this cell raised NameError on a
# fresh kernel (Restart & Run All). Build it here before printing.
# NOTE(review): 0.5 is a placeholder threshold — tune it for the corpus.
merged = iterative_merge(sentences, threshold=0.5)
print_pairs(merged)