In [14]:
import json
import torch
import torch.nn.functional as F
from torch import Tensor
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModel

from config import RAW_DATA_DIR, PROCESSED_DATA_DIR

In [12]:
# Load the raw JSON for one uid and keep only the 'markdown' field of each entry
# stored under that uid (the entry keys are not needed here).
uid = 'CHE152876230'
# Pass the encoding explicitly: without it open() falls back to the platform
# locale, which can mis-decode non-ASCII text in the JSON on some systems.
with open(RAW_DATA_DIR / f'{uid}.json', encoding='utf-8') as f:
    data = [content['markdown'] for content in json.load(f)[uid].values()]

In [17]:
# (marker, metadata key) pairs for the three heading levels the header splitter
# is configured with.
headers_to_split_on = list(zip(
    ("#", "##", "###"),
    ("Header 1", "Header 2", "Header 3"),
))
markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=headers_to_split_on,
    strip_headers=False,
)
recursive_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=24,
)

# Two-stage split for every markdown document: first by the configured headers,
# then the recursive splitter is run over those header-level pieces.
chunks = []
for markdown in data:
    md_header_splits = markdown_splitter.split_text(markdown)
    chunks += recursive_splitter.split_documents(md_header_splits)

In [18]:
# Display the collected chunks to inspect the result of the two-stage split.
chunks

[Document(metadata={}, page_content='[![Chiron Logo](https://cdn.prod.website-files.com/643483a340aff3191dff2edc/6435d5a6afbf3c4ecdef1f79_Logo_CHIRON-2023-01%201.png)](https://www.chiron-services.ch/</>)[Über uns](https://www.chiron-services.ch/</ueber-uns>)[Services](https://www.chiron-services.ch/</services>)[Produkte](https://www.chiron-services.ch/</produkte>)[NEWS](https://www.chiron-services.ch/</news>)\n[DE](https://www.chiron-services.ch/</>)[|](https://www.chiron-services.ch/<#>)[EN](https://www.chiron-services.ch/</en/home>)'),
 Document(metadata={'Header 1': 'Unlocking the Value of Data'}, page_content='# Unlocking the Value of Data\n[Services](https://www.chiron-services.ch/</services>)[Produkte](https://www.chiron-services.ch/</produkte>)\n![Lock](https://cdn.prod.website-files.com/643483a340aff3191dff2edc/6435383994b6c658b6bd27d9_chiron_lock_open2%201.webp)\nDigital solutions by digital natives'),
 Document(metadata={'Header 1': 'Unlocking the Value of Data', 'Header 2': 

In [None]:
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings over the positions selected by ``attention_mask``.

    Args:
        last_hidden_states: Tensor indexed as (batch, sequence, hidden); the mask
            is broadcast over the last axis and the reduction runs over ``dim=1``.
        attention_mask: Tensor indexed as (batch, sequence); non-zero entries mark
            the positions that contribute to the average.

    Returns:
        One pooled vector per batch row. A row whose mask is entirely zero yields
        a zero vector rather than NaN.
    """
    # Zero out masked positions so they add nothing to the sum.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Clamp the per-row token count to at least 1: an all-masked row would
    # otherwise compute 0 / 0 and propagate NaN into every downstream score.
    token_counts = attention_mask.sum(dim=1).clamp(min=1)
    return last_hidden.sum(dim=1) / token_counts[..., None]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Build the two-line prompt that pairs a task instruction with a query.

    Args:
        task_description: Text placed after ``Instruct: `` on the first line.
        query: Text placed after ``Query: `` on the second line.

    Returns:
        The combined prompt string.
    """
    prompt = f'Instruct: {task_description}\nQuery: {query}'
    return prompt

# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'
queries = [
    get_detailed_instruct(task, 'how much protein should a female eat'),
    get_detailed_instruct(task, '南瓜的家常做法')
]
# No need to add instruction for retrieval documents
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "1.清炒南瓜丝 原料:嫩南瓜半个 调料:葱、盐、白糖、鸡精 做法: 1、南瓜用刀薄薄的削去表面一层皮,用勺子刮去瓤 2、擦成细丝(没有擦菜板就用刀慢慢切成细丝) 3、锅烧热放油,入葱花煸出香味 4、入南瓜丝快速翻炒一分钟左右,放盐、一点白糖和鸡精调味出锅 2.香葱炒南瓜 原料:南瓜1只 调料:香葱、蒜末、橄榄油、盐 做法: 1、将南瓜去皮,切成片 2、油锅8成热后,将蒜末放入爆香 3、爆香后,将南瓜片放入,翻炒 4、在翻炒的同时,可以不时地往锅里加水,但不要太多 5、放入盐,炒匀 6、南瓜差不多软和绵了之后,就可以关火 7、撒入香葱,即可出锅"
]
input_texts = queries + documents
# Queries come first in input_texts; remember where the documents start so the
# score matrix below stays correct if either list changes length.
num_queries = len(queries)

tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large-instruct')
model = AutoModel.from_pretrained('intfloat/multilingual-e5-large-instruct')

# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')

# Inference only: disable autograd so the forward pass does not record a graph
# and keep intermediate activations alive.
with torch.no_grad():
    outputs = model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
# Rows are queries, columns are documents; scaled by 100 for readability.
scores = (embeddings[:num_queries] @ embeddings[num_queries:].T) * 100
print(scores.tolist())

In [None]:
class InstructEmbeddings:
    """Empty placeholder class; it defines no state or behaviour yet."""

    def __init__(self) -> None:
        """Create an instance; takes no arguments and sets no attributes."""

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Pick, for every batch row, the hidden state of its last unmasked position.

    Args:
        last_hidden_states: Tensor indexed as (batch, sequence, ...).
        attention_mask: Tensor indexed as (batch, sequence) whose per-row sum is
            used as the count of real tokens.

    Returns:
        One hidden-state vector per batch row.
    """
    num_rows = attention_mask.shape[0]
    # If the final position is unmasked in every row, the batch is treated as
    # left-padded and the final position can be taken directly.
    final_column_all_set = attention_mask[:, -1].sum() == num_rows
    if final_column_all_set:
        return last_hidden_states[:, -1]

    # Otherwise index each row at (number of unmasked tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a task instruction and a query as a two-line prompt.

    NOTE(review): a function with this exact name and body is already defined in
    an earlier cell of this notebook; this later definition shadows it.

    Args:
        task_description: Text placed after ``Instruct: `` on the first line.
        query: Text placed after ``Query: `` on the second line.

    Returns:
        The formatted prompt string.
    """
    instructed_query = f'Instruct: {task_description}\nQuery: {query}'
    return instructed_query


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a question, retrieve Wikipedia passages that answer the question'
queries = [
    get_detailed_instruct(task, '최초의 원자력 발전소는 무엇인가?'),
    get_detailed_instruct(task, 'Who invented Hangul?')
]
# No need to add instruction for retrieval documents
passages = [
    "현재 사용되는 핵분열 방식을 이용한 전력생산은 1948년 9월 미국 테네시주 오크리지에 설치된 X-10 흑연원자로에서 전구의 불을 밝히는 데 사용되면서 시작되었다. 그리고 1954년 6월에 구소련의 오브닌스크에 건설된 흑연감속 비등경수 압력관형 원자로를 사용한 오브닌스크 원자력 발전소가 시험적으로 전력생산을 시작하였고, 최초의 상업용 원자력 엉더이로를 사용한 영국 셀라필드 원자력 단지에 위치한 콜더 홀(Calder Hall) 원자력 발전소로, 1956년 10월 17일 상업 운전을 시작하였다.",
    "Hangul was personally created and promulgated by the fourth king of the Joseon dynasty, Sejong the Great.[1][2] Sejong's scholarly institute, the Hall of Worthies, is often credited with the work, and at least one of its scholars was heavily involved in its creation, but it appears to have also been a personal project of Sejong."
]

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('Linq-AI-Research/Linq-Embed-Mistral')
model = AutoModel.from_pretrained('Linq-AI-Research/Linq-Embed-Mistral')

max_length = 4096
input_texts = [*queries, *passages]
# Queries come first in input_texts; remember where the passages start so the
# score matrix below stays correct if either list changes length.
num_queries = len(queries)
# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors="pt")
# Inference only: disable autograd so the forward pass does not record a graph
# and keep intermediate activations alive.
with torch.no_grad():
    outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# Normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
# Rows are queries, columns are passages; scaled by 100 for readability.
scores = (embeddings[:num_queries] @ embeddings[num_queries:].T) * 100
print(scores.tolist())