# Chunked Pooling
This notebook explains how chunked pooling can be implemented. First, install the requirements:

In [None]:
# Remove any already-installed torch packages before installing the requirements.
# NOTE(review): presumably so that the torch build pulled in by requirements.txt is the one used - confirm.
%pip uninstall -y torch torchvision torchaudio

In [None]:
# Install the notebook's dependencies into the kernel's environment (%pip targets the running kernel).
%pip install -r requirements.txt

Then we load the model that we want to use for the embeddings. We use `jinaai/jina-embeddings-v3` here, but any other model that supports mean pooling will work. Models with a large maximum context length are preferred.

In [None]:
import torch
from transformers import AutoModel
from transformers import AutoTokenizer

from chunked_pooling import chunked_pooling, chunk_by_sentences

In [None]:
# Load the tokenizer and the embedding model from one and the same checkpoint.
# trust_remote_code=True is forwarded to both loaders, exactly as before.
checkpoint = 'jinaai/jina-embeddings-v3'
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

Now we define the text which we want to encode and split it into chunks. The `chunk_by_sentences` function also returns the span annotations. Those specify the number of tokens per chunk which is needed for the chunked pooling.

In [None]:
input_text = "Москва — столица России, город федерального значения, административный центр Центрального федерального округа и центр Московской области, в состав которой не входит. Мегаполис; крупнейший по численности населения город России и её субъект — 13 149 803 человека (2024), что делает Москву 22-й среди городов мира по численности населения. Центр Московской городской агломерации. Самый крупный город Европы по площади и населению."
# Alternative English demo text; uncomment the next line to use it instead of the text assigned above.
#input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

# determine chunks: chunk_by_sentences returns the chunk strings together with their span annotations
chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)
# Show the chunks as a quoted bullet list, one chunk per line.
print('Chunks:\n- "' + '"\n- "'.join(chunks) + '"')

Now we encode the chunks with the traditional and the context-sensitive chunked pooling method:

In [None]:
# chunk before: every chunk is embedded on its own, without the rest of the text
embeddings_traditional_chunking = model.encode(chunks)

# chunk afterwards (context-sensitive chunked pooling):
# run the model once over the whole text, then pool the output per chunk span
inputs = tokenizer(input_text, return_tensors='pt')
# Inference only - no gradients are needed, so do not build the autograd graph.
with torch.no_grad():
    model_output = model(**inputs)
embeddings = chunked_pooling(model_output, [span_annotations])[0]

Finally, we compare the similarity of a test word with each chunk. The cell below uses "Москва" to match the Russian input text (use "Berlin" with the English example). The similarity should be higher for the context-sensitive chunked pooling method:

In [None]:
import numpy as np


def cos_sim(x, y):
    """Return the cosine similarity of two vectors (dot product over the product of their norms)."""
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))


# Query word to compare against every chunk (the English example uses "Berlin").
test_text = 'Москва'
berlin_embedding = model.encode(test_text)

# One pair of lines per chunk: chunked-pooling similarity first, traditional second.
for chunk, new_embedding, trad_embeddings in zip(chunks, embeddings, embeddings_traditional_chunking):
    sim_new = cos_sim(berlin_embedding, new_embedding)
    sim_trad = cos_sim(berlin_embedding, trad_embeddings)
    print(f'similarity_new("{test_text}", "{chunk}"):', sim_new)
    print(f'similarity_trad("{test_text}", "{chunk}"):', sim_trad)

In [None]:
# Print the Euclidean norm of both embeddings of every chunk, then of the test-word embedding.
for chunk, new_embedding, trad_embeddings in zip(chunks, embeddings, embeddings_traditional_chunking):
    norm_new = np.linalg.norm(new_embedding)
    norm_trad = np.linalg.norm(trad_embeddings)
    print(f'abs new("{chunk}"):', norm_new)
    print(f'abs trad("{chunk}"):', norm_trad)

norm_test = np.linalg.norm(berlin_embedding)
print(f'\nabs test_text("{test_text}"):', norm_test)

# Bench

In [1]:
from chunked_pooling import chunked_pooling, chunk_by_sentences
from transformers import AutoModel
from transformers import AutoTokenizer
import pandas as pd
import torch
import numpy as np
import os

In [2]:
# Directory the notebook runs in; the dataset folder is expected right below it.
basePath = os.path.abspath('')
# Build paths with os.path.join so they work on any OS
# (the previous version concatenated Windows-only "\\" separators).
queries = pd.read_json(os.path.join(basePath, "ai-forever-ria-news-retrieval", "queries.jsonl"), lines=True)
corpus = pd.read_json(os.path.join(basePath, "ai-forever-ria-news-retrieval", "corpus.jsonl"), lines=True)
test = pd.read_json(os.path.join(basePath, "ai-forever-ria-news-retrieval", "test.jsonl"), lines=True)
# Drop columns that the rest of this notebook does not read.
del corpus['title']
del test['score']

### То, что мы слишком поздно заметили

In [5]:
# Show the corpus rows whose text is the empty string.
corpus.loc[corpus['text'].eq("")]

Unnamed: 0,_id,text
3840,3840,


In [3]:
def do_ch_emb(calc_chunk):
    """Split every document of ``calc_chunk`` into chunks and embed each chunk twice.

    For each row the text is split with ``chunk_by_sentences``. Every chunk then
    gets two embeddings:

    * ``trad_chunk_embedding`` - ``model.encode`` applied to the chunk strings
      on their own;
    * ``new_chunk_embedding`` - ``chunked_pooling`` applied to the model output
      of the whole document text, using the document's span annotations.

    Args:
        calc_chunk: DataFrame with at least the columns ``'text'`` and ``'_id'``.

    Returns:
        DataFrame with the object-dtype columns ``chunk``,
        ``trad_chunk_embedding``, ``new_chunk_embedding`` and ``doc_id``
        (the row's ``'_id'``), one row per chunk, indexed 0..n-1.
        Documents with empty text, or for which no chunk is produced,
        contribute no rows.
    """
    # Prefer the GPU, but stay usable on machines without CUDA.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    # load model and tokenizer (the tokenizer holds no tensors, so it gets no device_map)
    tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v3', trust_remote_code=True)
    model = AutoModel.from_pretrained('jinaai/jina-embeddings-v3', trust_remote_code=True, device_map=device_name)

    columns = ["chunk", "trad_chunk_embedding", "new_chunk_embedding", "doc_id"]
    # Collect plain records and build the frame once at the end; growing a
    # DataFrame row by row with .loc gets slower and slower on large partitions.
    records = []
    for _, doc in calc_chunk.iterrows():
        text = doc['text']
        # The corpus is known to contain a document with empty text: nothing to embed.
        if not isinstance(text, str) or not text.strip():
            continue
        doc_chunks, doc_span_annotations = chunk_by_sentences(text, tokenizer)
        if not doc_chunks:
            continue
        doc_trad_chunk_embeddings = model.encode(doc_chunks)
        doc_inputs = tokenizer(text, return_tensors='pt').to(device)
        # Inference only - skip autograd bookkeeping to save memory.
        with torch.no_grad():
            doc_model_output = model(**doc_inputs)
        doc_new_chunk_embeddings = chunked_pooling(doc_model_output, [doc_span_annotations])[0]
        for chunk, trad_chunk_embedding, new_chunk_embedding in zip(
            doc_chunks, doc_trad_chunk_embeddings, doc_new_chunk_embeddings
        ):
            records.append({
                "chunk": chunk,
                "trad_chunk_embedding": trad_chunk_embedding,
                "new_chunk_embedding": new_chunk_embedding,
                "doc_id": doc['_id'],
            })
    # dtype=object keeps every column as object, which the earlier
    # (discarded) ``result.astype('object')`` call was aiming for.
    return pd.DataFrame(records, columns=columns, dtype=object)

In [5]:
# Split the corpus into 22 row partitions of (almost) equal size so that each one can be
# embedded and pickled separately, e.g. on different machines.
# (The original note spoke of 5179 machines, the first two being merged as an example;
# the small chunks0/chunks1 outputs further down apparently stem from such a run.)
# Splitting the row positions and slicing with .iloc gives the same partitions as
# np.array_split(corpus, 22) but avoids passing the DataFrame itself to NumPy,
# which recent pandas versions answer with a FutureWarning.
chunked_corpus = [corpus.iloc[rows] for rows in np.array_split(np.arange(len(corpus)), 22)]

  return bound(*args, **kwds)


In [5]:
%%time
# Embed partition 12 of the corpus and show the shape of the resulting frame.
# The recorded run took about 4 h 53 min wall time and produced 252887 chunk rows.
chunks12 = do_ch_emb(chunked_corpus[12])
chunks12.shape

flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn i

CPU times: total: 2h 54min 26s
Wall time: 4h 53min 9s


(252887, 4)

In [6]:
# Persist the embeddings of partition 12; os.path.join replaces the Windows-only "\\" separators.
chunks12.to_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "chunks_embedded_12.pkl"))

In [8]:
# Sanity check: look up the '_id' stored under index label 28346 in the first partition
# (the recorded output is 28346).
chunked_corpus[0]['_id'][28346]

28346

In [6]:
%%time
# Embed the first partition and display the resulting chunk table.
chunks0 = do_ch_emb(chunked_corpus[0])
chunks0

flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn i

CPU times: total: 4.58 s
Wall time: 21.1 s


Unnamed: 0,chunk,trad_chunk_embedding,new_chunk_embedding,doc_id
0,"премьер-министр украины, кандидат в президенты...","[0.005603487, -0.027358998, -0.024784992, 0.03...","[-0.203125, -0.83203125, -0.80859375, 1.023437...",0
1,17 января в украине состоялся первый тур выб...,"[0.067110054, -0.111722454, -0.07669922, 0.078...","[0.21484375, -0.8984375, -1.171875, 1.140625, ...",0
2,второй тур выборов президента украины состои...,"[0.0342565, -0.074376695, -0.039508007, 0.0665...","[0.41210938, -0.9140625, -0.68359375, 1.21875,...",0
3,парламент украины по инициативе партии регион...,"[-0.011512435, -0.041095216, -0.08615855, 0.05...","[-0.013244629, -0.84765625, -1.109375, 1.07031...",0
4,министра.,"[0.030810941, -0.1456895, 0.027225856, 0.01233...","[-0.11425781, -1.21875, -1.3125, 0.7109375, -1...",0
...,...,...,...,...
964,налбандян из-за операции на бедре в мае прошл...,"[0.036205754, -0.12215991, -0.17237176, -0.035...","[0.45117188, -1.4921875, -2.390625, 0.58203125...",135
965,аргентинец собирался вернуться на корт в нача...,"[0.124536484, -0.12044175, -0.15277605, 0.1076...","[0.85546875, -1.5546875, -2.484375, 0.8203125,...",135
966,однако на турнире в новозеландском окленде на...,"[0.08003434, -0.11424159, -0.1379746, -0.04670...","[0.60546875, -1.203125, -2.421875, 0.453125, 1...",135
967,на данный момент последним официальным матче...,"[0.11851942, -0.14924, -0.15777057, 0.07895663...","[0.8125, -1.3984375, -2.390625, 0.6328125, 1.0...",135


In [7]:
%%time
# Embed the second partition and display the resulting chunk table.
chunks1 = do_ch_emb(chunked_corpus[1])
chunks1

flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn i

CPU times: total: 3.7 s
Wall time: 17 s


Unnamed: 0,chunk,trad_chunk_embedding,new_chunk_embedding,doc_id
0,хоккейный судья погиб в субботу во время мат...,"[0.063722946, -0.06450773, -0.03177504, 0.0524...","[0.87109375, -0.87109375, -0.546875, 0.6054687...",136
1,"как поясняет шведское телевидение, несчастныи...","[0.038227305, -0.1228597, -0.0351257, 0.002795...","[0.77734375, -0.7734375, -0.4453125, 0.6015625...",136
2,"судья, находившийся у борта, около скамьи дл...","[0.09723083, -0.09776237, -0.029085813, 0.1034...","[0.90234375, -0.796875, -0.18554688, 0.8476562...",136
3,несчастный случай произошел в первом период...,"[0.040227614, -0.15165919, -0.029836316, 0.088...","[1.0625, -0.78125, -0.22851562, 0.9609375, 2.0...",136
4,"пресс-служба местной полиции сообщает, что н...","[0.04868422, -0.08835515, 0.03624171, 0.085426...","[0.73828125, -0.63671875, 0.026489258, 0.97656...",136
...,...,...,...,...
1001,второй тур президентских выборов назначен на...,"[0.08194656, -0.09576612, -0.042479217, 0.0796...","[1.140625, -1.0703125, -0.734375, 1.0859375, -...",271
1002,"украинские эксперты не исключают, что проигра...","[0.1288577, -0.117042825, -0.07240841, 0.02360...","[1.3046875, -1.375, -0.859375, 0.5546875, -0.0...",271
1003,"""я не буду прогнозировать разрывы, я буду про...","[0.10990891, -0.0764635, -0.05408496, 0.023852...","[1.109375, -1.203125, -0.6953125, 0.6015625, -...",271
1004,"премьер украины в очередной раз заявила, что...","[0.00061494106, -0.08204302, -0.049697183, -0....","[0.84375, -1.1796875, -1.0234375, 0.32617188, ...",271


In [8]:
# Persist the embeddings of partition 0; os.path.join replaces the Windows-only "\\" separators.
chunks0.to_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "chunks_embedded_0.pkl"))

In [9]:
# Persist the embeddings of partition 1; os.path.join replaces the Windows-only "\\" separators.
chunks1.to_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "chunks_embedded_1.pkl"))

### Обратно загрузка с диска и сливание (части будто с разных компов, но сложенные в одну папку)

In [10]:
# Reload partition 0 from disk (portable path via os.path.join).
# Unpickling can execute arbitrary code - only load pickle files you produced yourself.
unpickled_chunks0 = pd.read_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "chunks_embedded_0.pkl"))

In [11]:
# Reload partition 1 from disk (portable path via os.path.join).
# Unpickling can execute arbitrary code - only load pickle files you produced yourself.
unpickled_chunks1 = pd.read_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "chunks_embedded_1.pkl"))

In [12]:
# Stack the two reloaded partitions into one frame; ignore_index=True renumbers the rows from 0.
unpickled_chunks = pd.concat([unpickled_chunks0, unpickled_chunks1], ignore_index=True)

In [14]:
# Display the merged chunk table.
unpickled_chunks

Unnamed: 0,chunk,trad_chunk_embedding,new_chunk_embedding,doc_id
0,"премьер-министр украины, кандидат в президенты...","[0.005603487, -0.027358998, -0.024784992, 0.03...","[-0.203125, -0.83203125, -0.80859375, 1.023437...",0
1,17 января в украине состоялся первый тур выб...,"[0.067110054, -0.111722454, -0.07669922, 0.078...","[0.21484375, -0.8984375, -1.171875, 1.140625, ...",0
2,второй тур выборов президента украины состои...,"[0.0342565, -0.074376695, -0.039508007, 0.0665...","[0.41210938, -0.9140625, -0.68359375, 1.21875,...",0
3,парламент украины по инициативе партии регион...,"[-0.011512435, -0.041095216, -0.08615855, 0.05...","[-0.013244629, -0.84765625, -1.109375, 1.07031...",0
4,министра.,"[0.030810941, -0.1456895, 0.027225856, 0.01233...","[-0.11425781, -1.21875, -1.3125, 0.7109375, -1...",0
...,...,...,...,...
1970,второй тур президентских выборов назначен на...,"[0.08194656, -0.09576612, -0.042479217, 0.0796...","[1.140625, -1.0703125, -0.734375, 1.0859375, -...",271
1971,"украинские эксперты не исключают, что проигра...","[0.1288577, -0.117042825, -0.07240841, 0.02360...","[1.3046875, -1.375, -0.859375, 0.5546875, -0.0...",271
1972,"""я не буду прогнозировать разрывы, я буду про...","[0.10990891, -0.0764635, -0.05408496, 0.023852...","[1.109375, -1.203125, -0.6953125, 0.6015625, -...",271
1973,"премьер украины в очередной раз заявила, что...","[0.00061494106, -0.08204302, -0.049697183, -0....","[0.84375, -1.1796875, -1.0234375, 0.32617188, ...",271


# Трэш далее -- игнор

In [None]:
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool as Pool
from threading import Lock

In [None]:
# Start from the CPU count reported by multiprocessing.cpu_count(); the next cell overrides this value.
num_processes = cpu_count()
num_processes

In [None]:
# Manual override: use two workers regardless of the CPU count determined above.
num_processes = 2

In [None]:
# Number of documents (rows) in the corpus.
corpus.shape[0]

In [None]:
# Rows per worker: row count divided by the number of workers, truncated to an integer.
calc_chunk_size = int(corpus.shape[0]/num_processes)
calc_chunk_size

In [None]:
# Size of the final chunk: one regular chunk plus the rows left over after the truncating division above.
last_calc_chunk_size = corpus.shape[0] - num_processes * calc_chunk_size + calc_chunk_size
last_calc_chunk_size

In [None]:
# Build the equally sized leading chunks in one pass over their start offsets.
regular_starts = [k * calc_chunk_size for k in range(num_processes - 1)]
calc_chunks = [corpus.iloc[offset:offset + calc_chunk_size] for offset in regular_starts]

# The final chunk is sliced separately because its size is last_calc_chunk_size.
start_idx = (num_processes - 1) * calc_chunk_size
end_idx = start_idx + last_calc_chunk_size
calc_chunks.append(corpus.iloc[start_idx:end_idx])

In [None]:
# create our pool with `num_processes` workers
# (Pool is imported above as an alias of multiprocessing.pool.ThreadPool, so the workers are threads)
pool = Pool(processes=num_processes)

In [None]:
%%time
# apply our function to each chunk in the list
# NOTE(review): every do_ch_emb call loads its own tokenizer and model, so each worker
# ends up with a separate model copy - confirm that this fits into memory.
result = pool.map(do_ch_emb, calc_chunks)

In [None]:
# Merge the per-worker frames into one; ignore_index=True renumbers the rows from 0.
# NOTE(review): this rebinds `chunks`, which the demo at the top of the notebook uses for a list of strings.
chunks = pd.concat(result, ignore_index=True)

In [None]:
# Persist the merged chunk embeddings; os.path.join replaces the Windows-only "\\" separators.
chunks.to_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "chunks_embedded.pkl"))

In [None]:
%%time
# Embed every query text, one row at a time, and store the vector in a new 'embedding' column.
# NOTE(review): this relies on a global `model`. In this notebook it is only created in the demo
# cell near the top (do_ch_emb loads its own local model), so a fresh kernel that runs only the
# Bench section has no `model` here - confirm.
queries['embedding'] = queries.apply(lambda q: model.encode(q['text']), axis=1)

In [None]:
# Preview the first query rows, including the new 'embedding' column.
queries.head()

In [None]:
# Persist the embedded queries; os.path.join replaces the Windows-only "\\" separators.
queries.to_pickle(os.path.join(basePath, "ai-forever-ria-news-retrieval", "queries_embedded.pkl"))