## reference implementation


In [2]:
from transformers.utils import TRANSFORMERS_CACHE
print(f"Transformers cache directory: {TRANSFORMERS_CACHE}")

Transformers cache directory: /Users/jan/.cache/huggingface/hub


In [None]:
def late_chunking(
    model_output: 'BatchEncoding', span_annotation: list, max_length=None
):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go bejond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs

def chunk_by_sentences(input_text: str, tokenizer: callable):
    """
    Split the input text into sentences using the tokenizer
    :param input_text: The text snippet to split into sentences
    :param tokenizer: The tokenizer to use
    :return: A tuple containing the list of text chunks and their corresponding token spans
    """
    inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
    punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
    sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
    token_offsets = inputs['offset_mapping'][0]
    token_ids = inputs['input_ids'][0]
    chunk_positions = [
        (i, int(start + 1))
        for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
        if token_id == punctuation_mark_id
        and (
            token_offsets[i + 1][0] - token_offsets[i][1] > 0
            or token_ids[i + 1] == sep_id
        )
    ]
    chunks = [
        input_text[x[1] : y[1]]
        for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    span_annotations = [
        (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    return chunks, span_annotations

In [None]:
input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

# determine chunks
chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)
print('Chunks:\n- "' + '"\n- "'.join(chunks) + '"')

# chunk afterwards (context-sensitive chunked pooling)
inputs = tokenizer(input_text, return_tensors='pt')
model_output = model(**inputs)
embeddings = late_chunking(model_output, [span_annotations])[0]

## test implementation

In [None]:
import sys
import os

# Get the current working directory
current_dir = os.getcwd()

# Assuming your notebook is in a subdirectory of the main project folder,
# go up two levels to reach the project root
project_root = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
sys.path.append(project_root)


from app.doc_processing import process_doc, ProcessDocConfig
from app.vectorstore import get_chroma_store_as_retriever, add_docs_to_store
import os

pdf_path = '/Users/jan/Desktop/advanced_rag/dev_tests/test_data/el_nino.pdf'
config = ProcessDocConfig(
        tag="test_late_chunking",
        filepath=pdf_path,
        late_chunking=False
    )

    # Process the document
processed_docs = process_doc(config)

In [46]:
from app.doc_processing.late_chunking import apply_late_chunking
docs = apply_late_chunking(processed_docs)

In [None]:
from app.vectorstore.experimental import get_faiss_store_as_retriever, add_docs_to_faiss_store
retriever = get_faiss_store_as_retriever()
add_docs_to_faiss_store(retriever, docs)

In [56]:
test_query = "What is El Niño?"
results = retriever.get_relevant_documents(test_query)

  results = retriever.get_relevant_documents(test_query)
