In [None]:
from typing import List, Iterable, Union

from langchain_core.embeddings import Embeddings

import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig

def _last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

class MyEmbeddings(Embeddings):
    def __init__(
        self,
        model: str = "Qwen/Qwen3-Embedding-8B",
        *,
        max_length: int = 8192,
        device = None
    ):
        self.model_name = model
        self.max_length = max_length

        self._tokenizer = tokenizer = AutoTokenizer.from_pretrained(self.model_name, padding_side='left')

        bnb_config = BitsAndBytesConfig(load_in_4bit=True)
        self._model = AutoModel.from_pretrained(self.model_name,quantization_config=bnb_config,).eval()

        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self._device = torch.device(device)
        self._model.to(self._device)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self._embed_texts(texts)

    def embed_query(self, text: str) -> List[float]:
        return self._embed_texts([text])[0]

    @torch.inference_mode()
    def _embed_texts(self, texts: Union[List[str], Iterable[str]]) -> List[List[float]]:
        batch_dict = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        batch_dict.to(self._device)
        outputs = self._model(**batch_dict)
        

        embeddings = _last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

        out = embeddings.tolist()
        return out

In [2]:
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings

my_embeddings = MyEmbeddings()

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
import json
from langchain_core.documents import Document

output_file = "../../../data/chunked/markdown_header_text_splitter/markdown_header_text_splitter.jsonl"

docs = []
with open(output_file, "r", encoding="utf-8") as f:
    for line in f:
        d = json.loads(line) 
        docs.append(Document(**d))

In [4]:
from langchain_chroma import Chroma

DB_PATH = "../../../data/embedded/chroma_db"

chroma_db = Chroma(
    collection_name="samsung_quarterly_report",
    embedding_function=my_embeddings,
    persist_directory=DB_PATH,
)

batch_size = 1
for i in range(0, len(docs), batch_size):
    batch = docs[i : i + batch_size]
    chroma_db.add_documents(batch)
    print(f"Batch {i // batch_size + 1} ({len(batch)} docs) 추가 완료")

print("모든 문서 임베딩 완료 및 저장 완료")

Batch 1 (1 docs) 추가 완료
Batch 2 (1 docs) 추가 완료
Batch 3 (1 docs) 추가 완료
Batch 4 (1 docs) 추가 완료
Batch 5 (1 docs) 추가 완료
Batch 6 (1 docs) 추가 완료
Batch 7 (1 docs) 추가 완료
Batch 8 (1 docs) 추가 완료
Batch 9 (1 docs) 추가 완료
Batch 10 (1 docs) 추가 완료
Batch 11 (1 docs) 추가 완료
Batch 12 (1 docs) 추가 완료
Batch 13 (1 docs) 추가 완료
Batch 14 (1 docs) 추가 완료
Batch 15 (1 docs) 추가 완료
Batch 16 (1 docs) 추가 완료
Batch 17 (1 docs) 추가 완료
Batch 18 (1 docs) 추가 완료
Batch 19 (1 docs) 추가 완료
Batch 20 (1 docs) 추가 완료
Batch 21 (1 docs) 추가 완료
Batch 22 (1 docs) 추가 완료
Batch 23 (1 docs) 추가 완료
Batch 24 (1 docs) 추가 완료
Batch 25 (1 docs) 추가 완료
Batch 26 (1 docs) 추가 완료
Batch 27 (1 docs) 추가 완료
Batch 28 (1 docs) 추가 완료
Batch 29 (1 docs) 추가 완료
Batch 30 (1 docs) 추가 완료
Batch 31 (1 docs) 추가 완료
Batch 32 (1 docs) 추가 완료
Batch 33 (1 docs) 추가 완료
Batch 34 (1 docs) 추가 완료
Batch 35 (1 docs) 추가 완료
Batch 36 (1 docs) 추가 완료
Batch 37 (1 docs) 추가 완료
Batch 38 (1 docs) 추가 완료
Batch 39 (1 docs) 추가 완료
Batch 40 (1 docs) 추가 완료
Batch 41 (1 docs) 추가 완료
Batch 42 (1 docs) 추가 완료
B

In [3]:
from langchain_chroma import Chroma

DB_PATH = "../../../data/embedded/chroma_db"

persist_db = Chroma(
    persist_directory=DB_PATH,
    embedding_function=MyEmbeddings(),
    collection_name="samsung_quarterly_report",
)

print(len(persist_db.get()['ids']))

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

801


In [14]:
collection = persist_db._collection

# 벡터 꺼내기
results = collection.get(limit=5, include=["embeddings", "documents", "metadatas"])
print(results["embeddings"])

[[-2.45117188 -1.70996094 -2.4609375  ... -0.52050781 -0.51953125
  -0.1661377 ]
 [-3.34179688 -1.203125   -2.40429688 ...  0.31420898 -0.33349609
   0.11706543]
 [-1.08203125  3.65820312 -0.92578125 ... -2.5         0.90039062
   2.52734375]
 [ 3.52929688 -1.5625     -0.83203125 ...  0.36865234  2.390625
   1.54882812]
 [ 2.20703125 -1.40234375 -1.86328125 ...  0.39550781  1.234375
   3.45898438]]


In [6]:
embedded_documents = my_embeddings.embed_documents("안녕하세요")

In [8]:
print(embedded_documents)

[[4.85546875, 1.6015625, 0.93408203125, -7.62109375, -0.353515625, -1.626953125, -2.572265625, 1.556640625, -0.58935546875, 2.693359375, -1.7802734375, -0.90283203125, 6.90625, -1.7021484375, -0.55859375, 2.5, -4.8203125, -0.96875, 3.1953125, 4.421875, -2.984375, 1.4443359375, 3.669921875, -0.23193359375, 2.09765625, 6.9609375, -5.9140625, -3.185546875, 0.0129547119140625, 2.45703125, -5.45703125, 1.5869140625, 0.609375, 2.796875, -2.544921875, 1.736328125, -2.734375, 0.26806640625, 2.279296875, -1.423828125, -1.5634765625, 0.89794921875, 1.5634765625, 0.33740234375, -1.388671875, -2.623046875, -3.0859375, -1.2333984375, 0.630859375, 1.2041015625, -0.22119140625, -4.01953125, -3.765625, -1.8115234375, -0.72314453125, -0.348388671875, -0.6416015625, -2.318359375, -3.234375, -1.4189453125, -5.4921875, 3.525390625, -1.98046875, -1.6259765625, 1.8681640625, 1.0166015625, 1.4736328125, -2.1484375, 1.048828125, -1.7939453125, 0.5263671875, -1.1591796875, 2.939453125, 1.3212890625, -0.5747070