# Chunked Pooling
This notebook explains how chunked pooling can be implemented. First, you need to install the requirements:

In [None]:
# Remove any already-installed torch / torchvision / torchaudio packages
# (-y skips the confirmation prompt).
# NOTE(review): presumably done so the next cell installs the versions listed
# in requirements.txt — confirm.
%pip uninstall -y torch torchvision torchaudio

In [None]:
# Install the dependencies listed in requirements.txt (path is relative to the
# notebook's working directory).
%pip install -r requirements.txt

Then we load the model that we want to use for the embedding. We choose `jinaai/jina-embeddings-v3`, but any other model that supports mean pooling is possible. However, models with a large maximum context length are preferred.

In [None]:
from chunked_pooling import chunked_pooling, chunk_by_sentences
from transformers import AutoModel
from transformers import AutoTokenizer

In [None]:
# Hugging Face model id, shared by the tokenizer and the model so the two
# cannot drift apart.
MODEL_NAME = 'jinaai/jina-embeddings-v3'

# trust_remote_code=True lets transformers run the custom code shipped in the
# model repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)

Now we define the text which we want to encode and split it into chunks. The `chunk_by_sentences` function also returns the span annotations. Those specify the number of tokens per chunk which is needed for the chunked pooling.

In [None]:
input_text = "Москва — столица России, город федерального значения, административный центр Центрального федерального округа и центр Московской области, в состав которой не входит. Мегаполис; крупнейший по численности населения город России и её субъект — 13 149 803 человека (2024), что делает Москву 22-й среди городов мира по численности населения. Центр Московской городской агломерации. Самый крупный город Европы по площади и населению."
#input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

# Split the text into chunks; the span annotations returned alongside them are
# consumed later by the chunked-pooling step.
chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)

# Show the chunks as a quoted bullet list.
bullet_separator = '"\n- "'
print(f'Chunks:\n- "{bullet_separator.join(chunks)}"')

Now we encode the chunks with the traditional and the context-sensitive chunked pooling method:

In [None]:
# Traditional approach ("chunk before"): every chunk string is encoded on its
# own, so each embedding only sees the text of that chunk.
embeddings_traditional_chunking = model.encode(chunks)

# Context-sensitive approach ("chunk afterwards"): tokenize and run the model
# once over the complete text, then pool the output per chunk using the span
# annotations produced by chunk_by_sentences.
inputs = tokenizer(input_text, return_tensors='pt')
model_output = model(**inputs)
# The spans are wrapped in a one-element list and [0] selects the result for
# this single text — presumably chunked_pooling works on batches; confirm.
embeddings = chunked_pooling(model_output, [span_annotations])[0]

Finally, we compare the similarity of a test word (`Москва` for the Russian text below; use `Berlin` with the commented-out English example) with the chunks. The similarity should be higher for the context-sensitive chunked pooling method:

In [None]:
import numpy as np


def cos_sim(x, y):
    """Return the cosine similarity of the vectors ``x`` and ``y``."""
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))


test_text = 'Москва'  # alternative for the English example: "Berlin"
# The variable keeps its historical name; it holds the embedding of test_text.
berlin_embedding = model.encode(test_text)

# Compare the test word against each chunk under both pooling strategies.
for chunk, new_embedding, trad_embeddings in zip(
    chunks, embeddings, embeddings_traditional_chunking
):
    print(
        f'similarity_new("{test_text}", "{chunk}"):',
        cos_sim(berlin_embedding, new_embedding),
    )
    print(
        f'similarity_trad("{test_text}", "{chunk}"):',
        cos_sim(berlin_embedding, trad_embeddings),
    )

In [None]:
# Print the Euclidean norm of every chunk embedding under both strategies.
for chunk, new_embedding, trad_embeddings in zip(
    chunks, embeddings, embeddings_traditional_chunking
):
    new_norm = np.linalg.norm(new_embedding)
    print(f'abs new("{chunk}"):', new_norm)
    trad_norm = np.linalg.norm(trad_embeddings)
    print(f'abs trad("{chunk}"):', trad_norm)

# ...and the norm of the test word's embedding for reference.
test_norm = np.linalg.norm(berlin_embedding)
print(f'\nabs test_text("{test_text}"):', test_norm)

# Bench

In [1]:
from chunked_pooling import chunked_pooling, chunk_by_sentences
from transformers import AutoModel
from transformers import AutoTokenizer
import pandas as pd
import torch
import numpy as np
import os

In [2]:
# Directory the notebook is run from (os.path.abspath('') resolves to the
# current working directory).
basePath = os.path.abspath('')

# Build paths with os.path.join instead of hard-coded "\\" separators so the
# notebook also runs on Linux/macOS; on Windows the resulting paths are unchanged.
dataDir = os.path.join(basePath, "ai-forever-ria-news-retrieval")

# Each file is JSON Lines (one JSON object per line).
queries = pd.read_json(os.path.join(dataDir, "queries.jsonl"), lines=True)

# The 'title' column of the corpus and the 'score' column of the test split are
# not used by this notebook, so they are dropped right after loading.
corpus = pd.read_json(os.path.join(dataDir, "corpus.jsonl"), lines=True).drop(columns=['title'])
test = pd.read_json(os.path.join(dataDir, "test.jsonl"), lines=True).drop(columns=['score'])

In [3]:
# Display the loaded corpus frame (remaining columns: _id, text).
corpus

Unnamed: 0,_id,text
0,0,"премьер-министр украины, кандидат в президенты..."
1,1,группа вооруженных людей в ночь с субботы на ...
2,2,немецкий теннисист михаэль беррер стал победи...
3,3,генеральный секретарь оон пан ги мун заявил в...
4,4,"леверкузенский ""байер"" со счетом 3:1 на свое..."
...,...,...
704339,704339,главными стратегическими учениями для армии ро...
704340,704340,ракетные войска стратегического назначения (р...
704341,704341,сухопутные войска россии в 2015 году примут у...
704342,704342,полиция мексиканского города чилапа в штате ге...


In [4]:
def do_ch_emb(calc_chunk):
    """Embed every chunk of every document in ``calc_chunk`` with both strategies.

    For each row the text is split with ``chunk_by_sentences``. Each chunk is
    embedded twice: independently via ``model.encode`` ("traditional"), and via
    ``chunked_pooling`` over a single forward pass of the whole document
    ("new", context-sensitive).

    Parameters
    ----------
    calc_chunk : pandas.DataFrame
        Slice of the corpus; must provide the columns ``'text'`` and ``'_id'``.

    Returns
    -------
    pandas.DataFrame
        One row per chunk with the columns ``chunk``, ``trad_chunk_embedding``,
        ``new_chunk_embedding`` and ``doc_id``. Empty (but with these columns)
        when ``calc_chunk`` has no rows.

    Notes
    -----
    A tokenizer and a model are loaded on every call, so each worker that runs
    this function holds its own copy of the model.
    """
    # Use the GPU when present; fall back to the CPU instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load model and tokenizer (a tokenizer holds no tensors, so it needs no
    # device placement)
    tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v3', trust_remote_code=True)
    model = AutoModel.from_pretrained('jinaai/jina-embeddings-v3', trust_remote_code=True, device_map=device.type)

    # Collect rows in a list and build the frame once at the end. The previous
    # implementation wrote to `result.loc[i]` with `i` restarting at 0 for each
    # document, so every document overwrote the rows of the one before it.
    records = []

    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        for _, doc in calc_chunk.iterrows():
            doc_chunks, doc_span_annotations = chunk_by_sentences(doc['text'], tokenizer)

            # Traditional: each chunk is encoded on its own.
            doc_trad_chunk_embeddings = model.encode(doc_chunks)

            # Context-sensitive: one forward pass over the whole document,
            # then pooling per chunk span.
            doc_inputs = tokenizer(doc['text'], return_tensors='pt').to(device)
            doc_model_output = model(**doc_inputs)
            doc_new_chunk_embeddings = chunked_pooling(doc_model_output, [doc_span_annotations])[0]

            for chunk, trad_chunk_embedding, new_chunk_embedding in zip(
                doc_chunks, doc_trad_chunk_embeddings, doc_new_chunk_embeddings
            ):
                records.append({
                    "chunk": chunk,
                    "trad_chunk_embedding": trad_chunk_embedding,
                    "new_chunk_embedding": new_chunk_embedding,
                    "doc_id": doc['_id'],
                })

    # Passing `columns` keeps the schema stable even when no rows were produced.
    return pd.DataFrame(
        records,
        columns=["chunk", "trad_chunk_embedding", "new_chunk_embedding", "doc_id"],
    )

In [5]:
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool as Pool
from threading import Lock

In [6]:
# Start from the number of CPUs reported by multiprocessing.cpu_count() and
# display it.
num_processes = cpu_count()
num_processes

20

In [7]:
# Override the CPU-count value with a fixed 2 workers.
# NOTE(review): presumably because every do_ch_emb call loads its own copy of
# the model onto the GPU, so the worker count is bounded by GPU memory — confirm.
num_processes = 2

In [8]:
# Number of documents (rows) in the corpus.
len(corpus)

704344

In [9]:
# Rows handed to each worker: corpus size divided by the worker count, rounded down.
calc_chunk_size = corpus.shape[0] // num_processes
calc_chunk_size

352172

In [10]:
# The last worker takes whatever remains after the first (num_processes - 1)
# regular slices, i.e. a regular slice plus the division remainder.
last_calc_chunk_size = corpus.shape[0] - (num_processes - 1) * calc_chunk_size
last_calc_chunk_size

352172

In [11]:
# Row offset at which each worker's slice of the corpus begins.
slice_starts = [worker * calc_chunk_size for worker in range(num_processes)]

# Every slice has the regular size except the final one, which also absorbs
# the remainder rows.
slice_lengths = [calc_chunk_size] * (num_processes - 1) + [last_calc_chunk_size]

# One positional (iloc) slice of the corpus per worker, in order.
calc_chunks = [
    corpus.iloc[start:start + length]
    for start, length in zip(slice_starts, slice_lengths)
]

In [12]:
# Create the worker pool with `num_processes` workers. `Pool` is imported in
# this notebook as multiprocessing.pool.ThreadPool, so the workers are threads
# of this kernel process, not separate processes.
pool = Pool(processes=num_processes)

flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn i

In [None]:
%%time
# Run do_ch_emb on every corpus slice through the pool. `result` is a list with
# one DataFrame per slice, in the same order as `calc_chunks`.
result = pool.map(do_ch_emb, calc_chunks)

In [None]:
# Stack the per-slice DataFrames into one frame with a fresh 0..n-1 index.
# NOTE: this rebinds `chunks`, which earlier in the notebook held the list of
# chunk strings returned by chunk_by_sentences.
chunks = pd.concat(result, ignore_index=True)

In [None]:
# Persist the embedded chunks. os.path.join replaces the hard-coded "\\"
# separators so the path is also valid on Linux/macOS.
# NOTE: pickle files should only be loaded back from trusted sources.
chunks.to_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "chunks_embedded.pkl"))

In [None]:
%%time
queries['embedding'] = queries.apply(lambda q: model.encode(q['text']), axis=1)

In [None]:
# Preview the first rows of the queries frame.
queries.head()

In [None]:
# Persist the queries frame. os.path.join replaces the hard-coded "\\"
# separators so the path is also valid on Linux/macOS.
# NOTE: pickle files should only be loaded back from trusted sources.
queries.to_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "queries_embedded.pkl"))