In [None]:
import os
import numpy as np

# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
from pydantic import BaseModel
from typing import Optional, List
from transformers import AutoTokenizer, AutoModel


class TextSpan(BaseModel):
    """Describes one chunk span; `spans_list` below is annotated as List[List[TextSpan]]."""
    # s / e: start and end offsets of the span. Later in this notebook they are used
    # as a half-open slice over token ids: text_ids[s:e].
    s: int
    e: int
    # Optional text attached to the span; defaults to None when not provided.
    text: Optional[str] = None
    # NOTE(review): presumably names the model component that produced the span's
    # vector -- not visible in this notebook, confirm against the model code.
    module_name: str


# Instruction prefixes passed as `prompt` when encoding queries and passages.
RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"

# Single source of truth for the checkpoint location, shared by the model and the
# tokenizer so the two can never point at different directories.
# Set the DEWEY_MODEL_PATH environment variable to run on another machine;
# the default is the original local path.
MODEL_PATH = os.environ.get(
    "DEWEY_MODEL_PATH",
    "/processing_data/search/zengziyang/models/infgrad/dewey_en_beta",
)

# trust_remote_code=True executes the model code shipped with the checkpoint.
# The model is loaded first and only then moved to the GPU and cast to bfloat16,
# which is why transformers prints the "Flash Attention 2.0 with a model not
# initialized on GPU" warning recorded in this cell's output.
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
).cuda().bfloat16()

# The tokenizer is attached to the model object; later cells read it back as
# `model.tokenizer`.
model.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Maximum sequence length handed to `model.encode` (32K tokens).
max_seq_length = 32 * 1024

# Example queries; passed as `sentences` to `model.encode` with RETRIEVE_Q_PROMPT.
q_list = ["why the sky is blue"]
# Example passages; passed as `sentences` to `model.encode` with RETRIEVE_P_PROMPT,
# where each passage is encoded chunk by chunk.
p_list = [
    """
    I’ve been trying to understand why the sky changes colors, and I think I understand most of it, but something in the online explanations doesn’t make it clear for me:

I’ve read:

sky is blue because blue light gets scattered the most during the day.

in the evening it turns red because now even more of the blue light gets scattered

So a few questions:

The scattering of light during the day: does it mean that blue light gets reflected off air particles and reaches our eyes, while the rest of the frequencies pass through and reach the ground?

Surely some of the other frequencies also get scattered during the day, just in much smaller amounts?

So during the evening blue light gets scattered even more, to the point where even less of it reaches the eyes?

And so it gets red because now we can see the lower frequencies being scattered without blue overshadowing them?

Trying to word it myself: during the day only the highest frequencies get filtered, but during the evening also lower frequencies get filtered, because now the “light strainer” (air) is just catching more of it?

It gets darker in the evening without a good ability to see colors because there’s is no blue and so on light to reflect off of objects?

Is it ok to speak about light as a frequency? Or it’s only correct to say “wave length”?

Blue light is scattered in all directions by the tiny molecules of air in Earth's atmosphere. Blue is scattered more than other colors because it travels as shorter, smaller waves. 
This is why we see a blue sky most of the time. Closer to the horizon, the sky fades to a lighter blue or white.
    """
]

  from .autonotebook import tqdm as notebook_tqdm
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.




In [18]:
# Each query must map to a single, un-chunked representation, so chunking is
# disabled with chunk_size=-1. In that mode the model returns, per query, an array
# of shape (2, 2048) holding the cls_vector and the mean_vector (the mean of all
# token embeddings).
query_encode_kwargs = {
    "sentences": q_list,
    "use_cuda": True,
    "show_progress_bar": True,
    "chunk_size": -1,
    "chunk_overlap": 32,
    "convert_to_tensor": False,
    "max_seq_length": max_seq_length,
    "batch_size": 8,
    "normalize_embeddings": True,
    "prompt": RETRIEVE_Q_PROMPT,
    "fast_chunk": False,
}
query_outputs = model.encode(**query_encode_kwargs)

# Only the first element of the result (the vectors) is needed for queries.
query_vectors = query_outputs[0]

encoding text...: 100%|██████████| 1/1 [00:00<00:00, 40.82it/s]


In [22]:
# Keep only the vectors of the first (and only) query.
# Before this cell `query_vectors` holds one (2, 2048) entry per query, i.e. it is
# three-dimensional overall; afterwards it is the (2, 2048) array itself.
# The dimensionality guard makes the cell idempotent: re-running it no longer
# indexes a second time, which would silently turn the (2, 2048) array into a
# single (2048,) vector and break the `query_vectors[1]` lookups further down.
if np.ndim(query_vectors) == 3:
    query_vectors = query_vectors[0]

In [24]:
# Sanity check: the recorded output is (2, 2048), i.e. two vectors for the single query.
query_vectors.shape

(2, 2048)

In [26]:

# Encode the passages chunk by chunk. `model.encode` yields two parallel lists:
#   passage_vectors_list[i] -- the vectors computed for p_list[i]
#   spans_list[i]           -- the span of every chunk of p_list[i]; a span can be
#                              used to recover the chunk's text
passage_vectors_list: List[np.ndarray]
spans_list: List[List[TextSpan]]

passage_encode_kwargs = {
    "sentences": p_list,
    "use_cuda": True,
    "show_progress_bar": True,
    "chunk_size": 64,
    "chunk_overlap": 8,
    "convert_to_tensor": False,
    "max_seq_length": max_seq_length,
    "batch_size": 8,
    "normalize_embeddings": True,
    "prompt": RETRIEVE_P_PROMPT,
    # With fast_chunk=True the text is chunked directly on the input ids;
    # otherwise RecursiveCharacterTextSplitter is used.
    "fast_chunk": True,
}
passage_vectors_list, spans_list = model.encode(**passage_encode_kwargs)

# Both lists have one entry per passage, so
#   len(spans_list) == len(p_list) and len(spans_list) == len(passage_vectors_list).
# Within one passage every span corresponds to exactly one vector (1*2048), so
#   len(spans_list[idx]) == len(passage_vectors_list[idx]).

encoding text...: 100%|██████████| 1/1 [00:00<00:00, 22.33it/s]


In [32]:
# Sanity check on the first passage: the recorded output is (8, 2048), i.e. 8 vectors of width 2048.
passage_vectors_list[0].shape

(8, 2048)

In [28]:
# Score the query against every vector of the first passage and report the best match.
# Index 1 selects the second of the two query vectors (the mean_vector, according to
# the note in the query-encoding cell).
query_vector = query_vectors[1]
first_passage_vectors = passage_vectors_list[0]

print(f"query_vectors shape: {query_vector.shape}")
print(f"passage_vectors_list shape: {first_passage_vectors.shape}")

# One dot product per passage vector; the maximum is the passage's best score.
chunk_scores = query_vector @ first_passage_vectors.T
print(chunk_scores.max())

query_vectors shape: (2048,)
passage_vectors_list shape: (8, 2048)
0.73425025


In [None]:
# NOTE: the value "0.7331543" that used to head this cell appears to refer to the
# similarity printed by the previous cell (this run printed 0.73425025), not to
# anything computed here.

# Recover the text of every chunk from its span.
# chunk_texts_list[i][j] is the decoded text of the j-th span of p_list[i]; the
# results are kept instead of being decoded and thrown away.
chunk_texts_list: List[List[str]] = []
for spans, passage in zip(spans_list, p_list):
    # Spans are applied as half-open slices [s, e) over the token ids of
    # prompt + passage. This assumes the model prepends the same prompt before
    # chunking on input ids (fast_chunk=True) -- TODO confirm against the model code.
    text_ids = model.tokenizer.encode(RETRIEVE_P_PROMPT + passage)
    chunk_texts_list.append(
        [
            model.tokenizer.decode(
                text_ids[span.s:span.e],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ).strip()
            for span in spans
        ]
    )