In [1]:
from transformers import AutoModel
from transformers import AutoTokenizer
from pathlib import Path
import torch

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3', trust_remote_code=True)
model = AutoModel.from_pretrained('BAAI/bge-m3', trust_remote_code=True)
model.eval()
model.to('cuda')
model.half()

book = '109 East Palace'
book_file_path = Path('..') / 'data' / 'bookcompanion'  / f'{book}.txt'
book_content = book_file_path.read_text(encoding='utf-8')
print(f"{book} has {len(book_content)} characters.")

109 East Palace has 932927 characters.


In [None]:
test_text = "The quick brown fox jumps over the lazy dog."
print(len(tokenizer(test_text, return_tensors='pt')['input_ids'][0]))
print(len(tokenizer(test_text, return_tensors='pt', add_special_tokens=False)['input_ids'][0]))



15
13


In [75]:

           
from nltk import sent_tokenize

def sentence_chunking(input_text: str, tokenizer: callable):
    sentences = sent_tokenize(input_text)
    last_offset_index = 0
    final_chunk = []
    spans = []
    for sentence in sentences:
        tokens = tokenizer(sentence, return_tensors='pt', add_special_tokens=False)
        input_ids = tokens['input_ids'][0]
        final_chunk.extend(input_ids)
        spans.append((last_offset_index, last_offset_index + len(input_ids)))
        last_offset_index += len(input_ids)
    return sentences, final_chunk, spans



def late_chunking(
    model_output: 'BatchEncoding', span_annotation: list, max_length=None
):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go bejond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs


In [77]:
import numpy as np
input_text = book_content
# determine chunks
chunks, tokenized, spans = sentence_chunking(input_text, tokenizer)

In [78]:
for chunk, span in zip(chunks[68:75], spans[68:75]):
   print('---')
   print(chunk)
   x,y = span
   print(tokenizer.decode(tokenized[x:y]))

---
As the gatekeeper to Los Alamos, she presented herself as a peculiarly compelling witness to history, registering the full scope of the momentous change and moral upheaval the scientists’ work unleashed.
As the gatekeeper to Los Alamos, she presented herself as a peculiarly compelling witness to history, registering the full scope of the momentous change and moral upheaval the scientists’ work unleashed.
---
She was not objective in any real sense, but for that matter, neither am I.
She was not objective in any real sense, but for that matter, neither am I.
---
She was smitten with Robert Oppenheimer from the moment they met and unreservedly embraced both him and his brilliant crew of scientists, including my grandfather, whom she liked and admired.
She was smitten with Robert Oppenheimer from the moment they met and unreservedly embraced both him and his brilliant crew of scientists, including my grandfather, whom she liked and admired.
---
But as an intelligent, articulate, and k

In [None]:
# chunk before
batch_size = 8
embeddings_traditional_chunking = []
with torch.no_grad():
    for i in range(0, 16, batch_size):
        tokenized_chunks = tokenizer(chunks[i:i+batch_size], return_tensors='pt', padding=True, truncation=True).to('cuda')
        outputs=model(**tokenized_chunks)[0].detach().cpu().numpy()
        embeddings_traditional_chunking.extend(outputs)
        print(outputs.shape)

In [44]:
# chunk afterwards (context-sensitive chunked pooling)
max_length = 8192
embeddings = []
with torch.no_grad():
    for i in range(0, len(chunks), batch_size):
        inputs = tokenizer(chunks[i:i+batch_size], return_tensors='pt', padding=True, truncation=True).to('cuda')
        model_output = model(**inputs)
        embeddings.extend(late_chunking(model_output, span_annotations[i:i+batch_size], max_length=max_length)[0].detach().cpu().numpy())

TypeError: cannot unpack non-iterable int object

In [36]:
embeddings

[array([-0.07544, -0.0696 , -0.6836 , ..., -0.311  , -0.702  , -0.3406 ],
       dtype=float16),
 array([-0.05313, -0.0354 , -0.00892, ...,  0.00399, -0.08466,  0.02725],
       dtype=float16),
 array([0., 0., 0., ..., 0., 0., 0.], dtype=float16),
 array([0., 0., 0., ..., 0., 0., 0.], dtype=float16),
 array([0., 0., 0., ..., 0., 0., 0.], dtype=float16),
 array([0., 0., 0., ..., 0., 0., 0.], dtype=float16),
 array([0., 0., 0., ..., 0., 0., 0.], dtype=float16),
 array([0., 0., 0., ..., 0., 0., 0.], dtype=float16),
 array([-0.0882 ,  0.00818, -0.5522 , ...,  0.7456 , -0.6885 ,  0.3816 ],
       dtype=float16),
 array([-0.0882 ,  0.00818, -0.5522 , ...,  0.7456 , -0.689  ,  0.3816 ],
       dtype=float16),
 array([-0.0882 ,  0.00818, -0.5522 , ...,  0.7456 , -0.6885 ,  0.3816 ],
       dtype=float16)]

In [None]:
import numpy as np

cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
q = 'Main protagonist character in the narrative'
tokenized_q = tokenizer(q, return_tensors='pt', padding=True, truncation=True).to('cuda')
q_embedding = model(**tokenized_q)[0].detach().cpu().numpy()

# for chunk, new_embedding, trad_embeddings in zip(chunks, embeddings, embeddings_traditional_chunking):
#     print(f'similarity_new("{q}", "{chunk}"):', cos_sim(q_embedding, new_embedding))
#     print(f'similarity_trad("{q}", "{chunk}"):', cos_sim(q_embedding, trad_embeddings))


AttributeError: 'XLMRobertaModel' object has no attribute 'encode'

In [43]:
for c in embeddings:
    print(c.shape)



(1024,)
(1024,)
(1024,)
(1024,)
(1024,)
(1024,)
(1024,)
(1024,)
(1024,)
(1024,)
(1024,)
