In [1]:
import os
import re
import numpy as np

from typing import List, Tuple
from tqdm.notebook import tqdm
from fastembed import TextEmbedding
from qdrant_client.http import models
from qdrant_client import QdrantClient

In [2]:
def read_refined_content(file_path: str) -> str:
  """Return the entire text of the file at ``file_path``, decoded as UTF-8."""
  with open(file_path, mode='r', encoding='utf-8') as source:
    text = source.read()
  return text

In [3]:
def extract_subjects(content: str) -> List[Tuple[str, str]]:
  """Split markdown text into one ``(header, section)`` pair per "## " heading.

  Args:
    content: Markdown text whose sections are introduced by lines starting
      with "##" followed by whitespace.

  Returns:
    A list of ``(header, section)`` tuples in document order. ``section`` is
    the full section text including its heading line; ``header`` is that
    first line. Every section consistently keeps its "## " marker. Returns an
    empty list when ``content`` is empty or whitespace-only.
  """
  # Split on the newline that precedes each "##<whitespace>" heading. The
  # lookahead leaves the marker attached to its section; splitting on
  # r'\n##\s' would strip it from every section except the first.
  sections = re.split(r'\n(?=##\s)', content)

  # The first element is whatever precedes the first heading split point.
  if not sections[0].strip():
    # Nothing (or only whitespace) before the first heading: drop it.
    sections = sections[1:]
  elif not sections[0].startswith("##"):
    # Leading text without a heading is kept as its own section, with a
    # "## " marker added so it has the same shape as the others.
    sections[0] = "## " + sections[0]

  # The header is the first line of each section.
  return [(section.split('\n', 1)[0], section) for section in sections]


In [4]:
def generate_embeddings(subjects: List[Tuple[str, str]]) -> List[Tuple[str, List[float]]]:
  """Embed the content of every subject with a default ``TextEmbedding`` model.

  Args:
    subjects: ``(header, content)`` pairs; only ``content`` is embedded.

  Returns:
    ``(header, vector)`` pairs in the same order as ``subjects``, where
    ``vector`` is the embedding converted to a plain list via ``.tolist()``.
    Returns an empty list, without instantiating the model, when ``subjects``
    is empty.
  """
  if not subjects:
    # Avoid constructing the embedding model when there is nothing to embed.
    return []

  embedding_model = TextEmbedding()
  headers = [header for header, _ in subjects]
  documents = [content for _, content in subjects]

  # Hand all documents to a single embed() call so the library can batch
  # them, instead of invoking it once per document. embed() is consumed
  # lazily, so tqdm needs an explicit total to render progress.
  vectors = tqdm(
    embedding_model.embed(documents),
    total=len(documents),
    desc="Generating embeddings",
  )

  return [(header, vector.tolist()) for header, vector in zip(headers, vectors)]

In [5]:
def store_in_qdrant(embeddings: List[Tuple[str, List[float]]], collection_name: str, subjects: List[Tuple[str, str]], batch_size: int = 100):
  """Upsert one Qdrant point per subject, creating the collection if missing.

  Args:
    embeddings: ``(header, vector)`` pairs, positionally aligned with
      ``subjects``. Only the vector is used; the payload header comes from
      ``subjects``.
    collection_name: Target Qdrant collection. It is created with cosine
      distance when it does not already exist.
    subjects: ``(header, content)`` pairs stored as each point's payload.
    batch_size: Maximum number of points sent per upsert call.

  Raises:
    ValueError: If ``embeddings`` and ``subjects`` differ in length, if
      ``batch_size`` is not positive, or if the ``QDRANT_API_KEY`` /
      ``QDRANT_ENDPOINT`` environment variables are not set.
  """
  # Validate before any network access: zip() would otherwise silently drop
  # the unmatched tail and store an incomplete collection.
  if len(embeddings) != len(subjects):
    raise ValueError(
      f"Length mismatch: got {len(embeddings)} embeddings for {len(subjects)} subjects"
    )
  if batch_size < 1:
    raise ValueError("batch_size must be a positive integer")

  api_key = os.getenv('QDRANT_API_KEY')
  endpoint = os.getenv('QDRANT_ENDPOINT')

  if not api_key or not endpoint:
    raise ValueError("Please set the QDRANT_API_KEY and QDRANT_ENDPOINT environment variables")

  client = QdrantClient(url=endpoint, api_key=api_key)

  # Size the collection from the actual vectors rather than a fixed number.
  # 384 is the previously hard-coded value, kept as the fallback for empty
  # input -- presumably the default model's dimension; TODO confirm.
  vector_size = len(embeddings[0][1]) if embeddings else 384

  existing_collections = client.get_collections().collections
  if not any(existing.name == collection_name for existing in existing_collections):
    client.create_collection(
      collection_name=collection_name,
      vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
    )

  # Point ids are positional indexes, so re-running with the same document
  # order overwrites the same ids.
  points = [
    models.PointStruct(
      id=point_id,
      vector=vector,
      payload={"header": header, "content": content}
    )
    for point_id, ((header, content), (_, vector)) in enumerate(zip(subjects, embeddings))
  ]

  for start in tqdm(range(0, len(points), batch_size), desc="Storing in Qdrant"):
    batch = points[start:start + batch_size]
    client.upsert(collection_name=collection_name, points=batch)

In [6]:
def main(input_file: str, collection_name: str):
  """Run the indexing pipeline for a single markdown file.

  Reads ``input_file``, splits it into subjects, embeds each subject and
  stores the results in the Qdrant collection ``collection_name``, printing
  a short status line after the split and after the upload.
  """
  subjects = extract_subjects(read_refined_content(input_file))
  print(f"Extracted {len(subjects)} subjects")

  store_in_qdrant(generate_embeddings(subjects), collection_name, subjects)
  print("Finished storing embeddings in Qdrant")

In [7]:
# Markdown file to index; the relative path is resolved against the kernel's working directory.
input_file = "../markdowns/grading_doc-edit.md"
# Name of the Qdrant collection the section embeddings are upserted into.
collection_name = "grading-doc"

# Run the pipeline: read the file, split it on "##" headings, embed each section, store in Qdrant.
main(input_file, collection_name)

Extracted 52 subjects


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Generating embeddings:   0%|          | 0/52 [00:00<?, ?it/s]

Storing in Qdrant:   0%|          | 0/1 [00:00<?, ?it/s]

Finished storing embeddings in Qdrant
