In [1]:
from chunking import *
from transformers import AutoModel
from transformers import AutoTokenizer
from pathlib import Path
import torch
import numpy as np
from pprint import pprint

In [2]:
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3', trust_remote_code=True)
model = AutoModel.from_pretrained('BAAI/bge-m3', trust_remote_code=True)
model.eval()
model.to('cuda')
model.half()

XLMRobertaModel(
  (embeddings): XLMRobertaEmbeddings(
    (word_embeddings): Embedding(250002, 1024, padding_idx=1)
    (position_embeddings): Embedding(8194, 1024, padding_idx=1)
    (token_type_embeddings): Embedding(1, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): XLMRobertaEncoder(
    (layer): ModuleList(
      (0-23): 24 x XLMRobertaLayer(
        (attention): XLMRobertaAttention(
          (self): XLMRobertaSdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): XLMRobertaSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-05, elem

In [3]:
doc_file_path= Path('..') / 'data' / 'bookcompanion'  / 'A Caribbean Mystery.txt'
doc_name = doc_file_path.name
doc_content = doc_file_path.read_text(encoding='utf-8')
print(f"{doc_name} has {len(doc_content)} characters.")

A Caribbean Mystery.txt has 302700 characters.


In [38]:
doc_out = sentence_chunking(doc_content, tokenizer)
doc_sentences = doc_out['sentences']
doc_sentence_spans = doc_out['spans']
doc_input_ids = doc_out['input_ids']

chunked_doc = split_doc_to_chunks(doc_input_ids, doc_sentence_spans, max_length=512, overlap=64)
chunks = chunked_doc['batched_input_ids']
chunk_spans = chunked_doc['spans']

print(f"""
Document {doc_name}:
    {len(doc_content)} characters
    {len(doc_input_ids)} tokens
    {len(doc_sentences)} sentences
    {len(chunked_doc['spans'])} chunks
""")


Document A Caribbean Mystery.txt:
    302700 characters
    81729 tokens
    3805 sentences
    166 chunks



In [5]:
def batch_sentence_token_offsets(chunk_spans, sentence_spans):
    last_end = chunk_spans[-1][1]
    filtered_sent_spans = [(start, end) for start, end in sentence_spans if end <= last_end]
    for chunk_span in chunk_spans:
        yield [(start-chunk_span[0], end-chunk_span[0]) for start, end in filtered_sent_spans if start >= chunk_span[0] and end <= chunk_span[1]]

    

In [18]:
batch_size = 8
chunks = chunked_doc['batched_input_ids']
chunk_spans = chunked_doc['spans']
late_embeddings = []
with torch.no_grad():
    for i in (range(0, len(chunks), batch_size)):
        text_chunks = tokenizer.batch_decode(chunks[i:i+batch_size], skip_special_tokens=True)
        inputs = tokenizer(text_chunks, return_tensors='pt',padding=True, truncation=True, max_length=tokenizer.model_max_length).to('cuda')
        model_output = model(**inputs)
        batch_sentence_spans = list(batch_sentence_token_offsets(chunk_spans[i:i+batch_size], doc_sentence_spans))
        # print(f'{model_output[0].shape=} vs {len(chunks[i])=}')
        for late_embed in late_chunking(model_output, batch_sentence_spans):
            late_embeddings.extend(late_embed)
    late_embeddings = torch.stack([torch.from_numpy(embd) for embd in late_embeddings])

print(f'{len(late_embeddings)=}')
print(f'{len(late_embeddings) == len(doc_sentences)=}')

KeyboardInterrupt: 

In [7]:
trad_embeddings = []
batch_size = 8
with torch.no_grad():
    for i in range(0, len(doc_sentences), batch_size):
        tokenized_chunks = tokenizer(doc_sentences[i:i+batch_size], return_tensors='pt', padding=True, truncation=True).to('cuda')
        cls_token_outputs=model(**tokenized_chunks).last_hidden_state[:, 0].detach().cpu().numpy()
        trad_embeddings.extend(cls_token_outputs.tolist())
    trad_embeddings = [np.array(embd) for embd in trad_embeddings]
trad_embeddings = torch.stack([torch.from_numpy(embd) for embd in trad_embeddings])

In [39]:
trad_chunk_embeddings = []
with torch.no_grad():
    for i in range(0, len(chunks), batch_size):
        tokenized_chunks = tokenizer(tokenizer.batch_decode(chunks[i:i+batch_size]), return_tensors='pt', padding=True, truncation=True).to('cuda')
        cls_token_outputs=model(**tokenized_chunks).last_hidden_state[:, 0].detach().cpu().numpy()
        trad_chunk_embeddings.extend(cls_token_outputs.tolist())
trad_chunk_embeddings = [np.array(embd) for embd in trad_chunk_embeddings]
trad_chunk_embeddings = torch.stack([torch.from_numpy(embd) for embd in trad_chunk_embeddings])

In [9]:
def get_sentences_for_chunks(chunk_spans, sentence_spans):
    for chunk_span in chunk_spans:
        yield [i for i, sent in enumerate(sentence_spans) if sent[0] >= chunk_span[0] and sent[1] <= chunk_span[1]]
        

In [None]:
q = """Query: Major Palgrave is an old soldier who tells endless stories about his past that no-one cares about and a few people doubt. The key story, in which he (or someone else) met a murderer is an important plot point and is true."""
tokenized_q = tokenizer(q, return_tensors='pt').to('cuda')
q_embedding = model(**tokenized_q).last_hidden_state[:, 0].squeeze(0).detach().cpu()

q_embedding_late_chunk = late_chunking(model(**tokenized_q), [[(0, len(q))]])[0][0]
q_embedding_late_chunk = torch.from_numpy(q_embedding_late_chunk)


cosine_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
late_topk = torch.topk(cosine_sim(q_embedding_late_chunk, late_embeddings), 100)


for i,score in enumerate(late_topk.values):
    pprint(f'{score.item():.3f} {doc_sentences[late_topk.indices[i]]}')
    print(f'---')


TypeError: cosine_similarity(): argument 'x2' (position 2) must be Tensor, not list

In [41]:

trad_topk = torch.topk(cosine_sim(q_embedding, trad_embeddings), 30)

for i,score in enumerate(trad_topk.values):
    pprint(f'{score.item():.3f} {doc_sentences[trad_topk.indices[i]]}')
    print(f'---')


('0.711 Miss Marple did not resent it, because people seldom did resent Mr. '
 'Rafiel’s somewhat arbitrary method of doing things.')
---
('0.681 “I once came across a very curious case—not exactly personally.”\n'
 'Miss Marple smiled encouragingly.')
---
('0.671 It could be, of course, because Mr. Rafiel would not have liked it, '
 'but Miss Marple didn’t think Mr. Rafiel would really mind in the least.')
---
('0.670 Of course—he was quite old.”\n'
 '“He seemed quite well and cheerful yesterday,” said Miss Marple, slightly '
 'resenting this calm assumption that everyone of advanced years was liable to '
 'die at any minute.')
---
('0.668 He needn’t have died if he’d looked after himself properly.”\n'
 '“Oh come now, Mr. Rafiel,” said Mrs. Walters.')
---
('0.663 And if people decided the food was bad—and left—or told their '
 'friends—”\n'
 '“I really don’t think you need worry,” said Miss Marple kindly.')
---
('0.662 “I’m afraid I used to escape from him whenever I could.”\n'
 '“Miss

In [91]:
q = """Passage: Conviction by Counterfactual Clue:"""
tokenized_q = tokenizer(q, return_tensors='pt').to('cuda')
q_embedding = model(**tokenized_q).last_hidden_state[:, 0].squeeze(0).detach().cpu()
trad_chunk_topk = torch.topk(cosine_sim(q_embedding, trad_chunk_embeddings), 10)

for i,score in enumerate(trad_chunk_topk.values):
    print(f'{score.item():.3f}')
    pprint(f'{tokenizer.decode(chunks[trad_chunk_topk.indices[i]])}')
    print(f'---')

0.461
('He could even admit a likeness, he could say: ‘Yes, I do look rather like '
 'that fellow, don’t I! Ha, ha!’ Nobody’s going to take old Palgrave’s '
 'identification seriously. Don’t tell me so, because I won’t believe it. No, '
 'the chap, if it was the chap, had nothing to fear—nothing whatever. It’s the '
 'kind of accusation he can just laugh off. Why on earth should he proceed to '
 'murder old Palgrave? It’s absolutely unnecessary. You must see that.” “Oh I '
 'do see that,” said Miss Marple. “I couldn’t agree with you more. That’s what '
 'makes me uneasy. So very uneasy that I really couldn’t sleep last night.” '
 'Mr. Rafiel stared at her. “Let’s hear what’s on your mind,” he said quietly. '
 '“I may be entirely wrong,” said Miss Marple hesitantly. “Probably you are,” '
 'said Mr. Rafiel with his usual lack of courtesy, “but at any rate let’s hear '
 'what you’ve thought up in the small hours.” “There could be a very powerful '
 'motive if—” “If what?” “If there was go

In [23]:
chunk_sentence_map = list(get_sentences_for_chunks(chunk_spans, doc_sentence_spans))
print(f'{len(chunk_sentence_map)=}')
chunk_scores = np.zeros(len(chunk_sentence_map))
for i, chunk in enumerate(chunk_sentence_map):
    for late_topk_index, sent_id in enumerate(late_topk.indices):
        if sent_id in chunk:
            chunk_scores[i] += late_topk.values[late_topk_index]
top_chunks = np.argpartition(chunk_scores, -10)[-10:][::-1]
for i, chunk_index in enumerate(top_chunks):
    print(f'sim: {chunk_scores[chunk_index]}')
    pprint(f'{tokenizer.decode(chunks[chunk_index])}')
    print(f'---')

    
    


len(chunk_sentence_map)=344
sim: 4.53515625
('them a good deal.” “For how long?” “Oh, I don’t know. About—oh, I suppose a '
 'month—perhaps longer. She—we—thought they were just—well, nightmares, you '
 'know.” “Yes, yes, I quite understand. But what’s a much more serious sign is '
 'the fact that she seems to have felt afraid of someone. Did she complain '
 'about that to you?” “Well, yes. She said once or twice that—oh, people were '
 'following her.” “Ah! Spying on her?” “Yes, she did use that term once. She '
 'said they were her enemies and they’d followed her here.” “Did she have '
 'enemies, Mr. Kendal?—” “No. Of course she didn’t.” “No incident in England, '
 'anything you know about before you were married?” “Oh no, nothing of that '
 'kind. She didn’t get on with her family very well, that was all. Her mother '
 'was rather an eccentric woman, difficult to live with perhaps, but....” “Any '
 'signs of mental instability in her family?” Tim opened his mouth '
 'impulsively, th