In [1]:
from transformers import AutoModel
from transformers import AutoTokenizer
from pathlib import Path
import torch

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3', trust_remote_code=True)
model = AutoModel.from_pretrained('BAAI/bge-m3', trust_remote_code=True)
model.eval()
model.to('cuda')
model.half()

book = '109 East Palace'
book_file_path = Path('..') / 'data' / 'bookcompanion'  / f'{book}.txt'
book_content = book_file_path.read_text(encoding='utf-8')
print(f"{book} has {len(book_content)} characters.")

109 East Palace has 932925 characters.


In [2]:
test_text = "The quick brown fox jumps over the lazy dog."
print(len(tokenizer(test_text, return_tensors='pt')['input_ids'][0]))
print(len(tokenizer(test_text, return_tensors='pt', add_special_tokens=False)['input_ids'][0]))



15
13


In [None]:
# import nltk
# nltk.download('punkt')
# nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Fergons\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Fergons\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt_tab.zip.


True

In [325]:
from nltk import sent_tokenize
           
def sentence_chunking(input_text: str, tokenizer: callable):
    sentences = sent_tokenize(input_text)
    last_offset_index = 0
    final_chunk = []
    spans = []
    for sentence in sentences:
        tokens = tokenizer(sentence, return_tensors='pt', add_special_tokens=False)
        input_ids = tokens['input_ids'][0]
        final_chunk.extend(input_ids)
        spans.append((last_offset_index, last_offset_index + len(input_ids)))
        last_offset_index += len(input_ids)
    
    return {'sentences': sentences, 'input_ids': final_chunk, 'spans': spans}
   



def late_chunking(
    model_output: 'BatchEncoding', span_annotation: list, max_length=None
):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go bejond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
      
        outputs.append(pooled_embeddings)

    return outputs


In [326]:
import numpy as np
input_text = book_content
# determine chunks
out = sentence_chunking(input_text, tokenizer)
sentences, tokens, spans = out['sentences'], out['input_ids'], out['spans']

In [327]:
for chunk, span in zip(sentences[68:75], spans[68:75]):
   print('---')
   print(chunk)
   x,y = span
   print(tokenizer.decode(tokens[x:y]))

---
As the gatekeeper to Los Alamos, she presented herself as a peculiarly compelling witness to history, registering the full scope of the momentous change and moral upheaval the scientists’ work unleashed.
As the gatekeeper to Los Alamos, she presented herself as a peculiarly compelling witness to history, registering the full scope of the momentous change and moral upheaval the scientists’ work unleashed.
---
She was not objective in any real sense, but for that matter, neither am I.
She was not objective in any real sense, but for that matter, neither am I.
---
She was smitten with Robert Oppenheimer from the moment they met and unreservedly embraced both him and his brilliant crew of scientists, including my grandfather, whom she liked and admired.
She was smitten with Robert Oppenheimer from the moment they met and unreservedly embraced both him and his brilliant crew of scientists, including my grandfather, whom she liked and admired.
---
But as an intelligent, articulate, and k

In [328]:
print(tokens[0:10])
print(spans[0:5])

[tensor(9804), tensor(13055), tensor(57766), tensor(6), tensor(56959), tensor(17777), tensor(14452), tensor(33996), tensor(1371), tensor(4878)]
[(0, 84), (84, 122), (122, 145), (145, 149), (149, 158)]


In [345]:
def split_doc_to_chunks(tokens: list[int], spans: list[(int, int)], max_length: int):
    """
        Assembles the max possible chunks (up to (max_length - 2) of tokens) from whole document tokens
    while respecting max_length of model's context and sentence boundries or other boundaries set by spans.
    (e.g. last sentence that fits whole into the context will be included, the next sentece is going to end up in the next chunk).
    tokens: whole document tokens
    spans: list of (start, end) tuples
    max_length: max number of tokens per chunk usually set it to model's context - 2 to leave room for [CLS] and [SEP] tokens
    """
    chunks = []
    chunk_spans = []
    last_chunk_span = (0,0)
    for i, span in enumerate(spans):
        start, end = span
        if (end - last_chunk_span[1] > max_length):
            new_chunk_span = (last_chunk_span[1], spans[i-1][1])
            chunk_spans.append(new_chunk_span)
            last_chunk_span = new_chunk_span
        if end >= len(tokens):
            new_chunk_span = (last_chunk_span[1], end)
            chunk_spans.append(new_chunk_span)
            
    return {'batched_input_ids': [tokens[start:end] for (start, end) in chunk_spans], 'spans':chunk_spans}  

In [361]:
out = split_doc_to_chunks(tokens, spans, max_length=256)
chunks = out['batched_input_ids']
chunk_spans = out['spans']
assert all(map(lambda x: x <= 8192, map(lambda x: x[1]-x[0], chunk_spans))), "All chunks must respect max_length constraints"

print(f'{len(chunks)=}')
print(chunk_spans)
print(len(tokens))

len(chunks)=995
[(0, 250), (250, 260), (260, 554), (554, 792), (792, 1025), (1025, 1219), (1219, 1474), (1474, 1676), (1676, 1914), (1914, 2160), (2160, 2393), (2393, 2631), (2631, 2849), (2849, 3099), (3099, 3298), (3298, 3549), (3549, 3783), (3783, 4039), (4039, 4261), (4261, 4509), (4509, 4721), (4721, 4951), (4951, 5178), (5178, 5430), (5430, 5652), (5652, 5908), (5908, 6149), (6149, 6398), (6398, 6612), (6612, 6860), (6860, 7079), (7079, 7319), (7319, 7535), (7535, 7725), (7725, 7854), (7854, 8069), (8069, 8301), (8301, 8551), (8551, 8762), (8762, 8962), (8962, 9200), (9200, 9453), (9453, 9686), (9686, 9900), (9900, 10134), (10134, 10387), (10387, 10598), (10598, 10854), (10854, 11108), (11108, 11355), (11355, 11609), (11609, 11829), (11829, 12068), (12068, 12292), (12292, 12510), (12510, 12747), (12747, 12944), (12944, 13162), (13162, 13376), (13376, 13626), (13626, 13881), (13881, 14137), (14137, 14352), (14352, 14592), (14592, 14848), (14848, 15092), (15092, 15339), (15339, 155

In [362]:
sentence_index_window = slice(0,len(sentences)+1)
# chunk before
batch_size = 8
embeddings_traditional_chunking = []
with torch.no_grad():
    for i in range(0, len(sentences[sentence_index_window]), batch_size):
        tokenized_chunks = tokenizer(sentences[sentence_index_window][i:i+batch_size], return_tensors='pt', padding=True, truncation=True).to('cuda')
        cls_token_outputs=model(**tokenized_chunks).last_hidden_state[:, 0].detach().cpu().numpy()
        embeddings_traditional_chunking.extend(cls_token_outputs.tolist())
embeddings_traditional_chunking = [np.array(embd) for embd in embeddings_traditional_chunking]
        

In [363]:
# chunk afterwards (context-sensitive chunked pooling)
max_length = 8192
embeddings = []
doc_ids = {i for i, (doc_start, doc_end) in enumerate(chunk_spans) for (sentence_start, sentence_end) in spans[sentence_index_window] if sentence_end <= doc_end and sentence_start >= doc_start}
print(f'{doc_ids=}')
num_sent = []
with torch.no_grad():
    for i in doc_ids:
        inputs = tokenizer.prepare_for_model(chunks[i], return_tensors='pt', padding=True, truncation=True, max_length=tokenizer.model_max_length).to('cuda')
        inputs['input_ids'] = inputs['input_ids'].unsqueeze(0)
        inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(0)
        model_output = model(**inputs)
        print(chunk_spans[i])
        # +1 for [CLS] token offset in model outputs
        doc_sentence_spans = [(start-chunk_spans[i][0]+1, end-chunk_spans[i][0]+1) for start,end in spans[sentence_index_window] if start >= chunk_spans[i][0] and end <= chunk_spans[i][1]]
        # print(f'{model_output[0].shape=} vs {len(chunks[i])=}')
        num_sent.append(len(doc_sentence_spans))
        embeddings.extend(late_chunking(model_output, [doc_sentence_spans])[0])

print(f'Processed sentences: {sum(num_sent)}')
print(f'{len(embeddings)=}')

doc_ids={0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 2

In [365]:
print(len(embeddings_traditional_chunking))
print(len(embeddings))

print(f'{embeddings_traditional_chunking[0].shape=}')
print(f'{embeddings[0].shape=}')

5818
5818
embeddings_traditional_chunking[0].shape=(1024,)
embeddings[0].shape=(1024,)


In [374]:
import numpy as np

cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
q = 'What is the name of the person who was a in love with Oppenheimer?'
tokenized_q = tokenizer(q, return_tensors='pt', padding=True, truncation=True).to('cuda')
q_embedding = model(**tokenized_q).last_hidden_state[:, 0].squeeze(0).detach().cpu().numpy()
q_embedding_late_chunk = late_chunking(model(**tokenized_q), [[(0, len(q))]], max_length=8192)[0][0] #(shape: [num_batches, queries_in_batch, embedding_of_query])

print(f'{len(sentences[sentence_index_window])}')
print(f'{len(embeddings)}')
print(f'{len(embeddings_traditional_chunking)}')

assert len(sentences[sentence_index_window]) == len(embeddings) == len(embeddings_traditional_chunking), 'len(sentences[sentence_index_window]) != len(embeddings) != len(embeddings_traditional_chunking)'

# For small text chunks manual inspection uncomment code bellow
# for sent, new_embeddings, trad_embeddings in zip(sentences[sentence_index_window], embeddings, embeddings_traditional_chunking):
#     print(f'similarity_new("{q}", "{sent}"):', cos_sim(q_embedding_late_chunk, new_embeddings))
#     print(f'similarity_trad("{q}", "{sent}"):', cos_sim(q_embedding, trad_embeddings))


5818
5818
5818


In [375]:
late_embeddings = torch.stack([torch.from_numpy(embd) for embd in embeddings])
trad_embeddings = torch.stack([torch.from_numpy(embd) for embd in embeddings_traditional_chunking])
print(f'{late_embeddings.shape=}')
print(f'{trad_embeddings.shape=}')


late_q_embedding = torch.from_numpy(q_embedding_late_chunk)
trad_q_embedding = torch.from_numpy(q_embedding)


late_topk = torch.topk(torch.cosine_similarity(late_q_embedding, late_embeddings), 20)
trad_topk = torch.topk(torch.cosine_similarity(trad_q_embedding, trad_embeddings), 5)


print('--LATE TOPK--')
print(f'{q}:')
for i,score in enumerate(late_topk.values):
    print(f'late: {sentences[sentence_index_window][late_topk.indices[i]]}, sim: {score}')
print(f'---')

print('--TRAD TOPK--')
print(f'{q}:')
for i,score in enumerate(trad_topk.values):
    print(f'trad: {sentences[sentence_index_window][trad_topk.indices[i]]}, sim: {score}')
print(f'---')





late_embeddings.shape=torch.Size([5818, 1024])
trad_embeddings.shape=torch.Size([5818, 1024])
--LATE TOPK--
What is the name of the person who was a in love with Oppenheimer?:
late: Oppenheimer, who was besotted, called her “Golden.” His close-knit circle was less charitable, considering the poetic young Wunderkind—who was so bereft after his mother’s death in 1930 that he described himself to a friend as “the loneliest man in the world”—easy prey for a calculating woman., sim: 0.81005859375
late: “Oppenheimer stretched me,” recalled Bob Wilson., sim: 0.8076171875
late: In one way or another, everyone became caught up in the Oppenheimer charisma., sim: 0.80126953125
late: Oppenheimer had the powerful charisma of those who know from birth that they are especially gifted., sim: 0.79833984375
late: She had fallen for Oppenheimer almost as quickly as Dorothy McKibbin had., sim: 0.796875
late: The faculty wives who had doted on Oppie, who was known for bringing flowers to dinner, took an in