In [27]:
import os, json
from tqdm import tqdm
from pymilvus import connections,Collection, utility, FieldSchema, CollectionSchema, DataType
from dotenv import load_dotenv
import torch
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Embedding model used for every chunk. BAAI/bge-m3 is the larger alternative.
# NOTE(review): a different model presumably has a different embedding width, so
# the `dim` of the collection's vector field would have to follow — confirm.
model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)
# NOTE(review): BATCH_SIZE is reassigned to 256 in the embedding cell further down.
BATCH_SIZE = 100
# Read variables from a local .env file into the process environment.
load_dotenv()


True

In [28]:
# Show which device the embedding model was placed on.
print(device)

cuda


In [29]:
# CUDA version reported by the installed PyTorch build.
print(torch.version.cuda)

11.8


In [30]:

# Register the "default" pymilvus connection against the local Milvus server.
milvus_endpoint = {"host": "127.0.0.1", "port": "19530"}
connections.connect(alias="default", timeout=5, secure=False, **milvus_endpoint)

# Asking the server for its version doubles as a check that the connection works.
server_version = utility.get_server_version()
print("Connected to Milvus:", server_version)


Connected to Milvus: v2.3.21


In [31]:

# Start from a clean slate: if an earlier run left the collection behind, drop it.
# WARNING: this deletes all data stored in the collection.
collection_name = "publications"

if not utility.has_collection(collection_name):
    print(f"Collection '{collection_name}' does not exist.")
else:
    utility.drop_collection(collection_name)
    print(f"Collection '{collection_name}' deleted.")

Collection 'publications' deleted.


In [32]:
def load_docs(folder_path:str) -> list:
    """Load every ``*.json`` publication found in ``folder_path``.

    Files are visited in sorted filename order so the returned list (and
    anything derived from it by position, such as cached embeddings) is
    reproducible from one run to the next.

    Args:
        folder_path: Directory containing one JSON object per file.

    Returns:
        A list of the parsed dicts. Documents whose ``"text"`` field is
        missing, ``None`` or empty are skipped. Every returned dict gets an
        extra ``"PMC_code"`` key holding the filename without its ``.json``
        suffix.

    Raises:
        OSError: If the folder or one of its files cannot be read.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    suffix = ".json"
    docs = []

    # os.listdir() returns entries in arbitrary order; sort for determinism.
    for file in sorted(os.listdir(folder_path)):
        if not file.endswith(suffix):
            continue
        with open(os.path.join(folder_path, file), "r", encoding="utf-8") as doc:
            data = json.load(doc)
        # Skip documents that have nothing to embed (missing, None or "").
        if not data.get("text"):
            continue
        # Remove only the trailing extension, not every ".json" in the name.
        data["PMC_code"] = file[: -len(suffix)].strip()
        docs.append(data)
    return docs

In [33]:

# Directory with the raw publication JSON files (relative to the working directory).
folder_path = "./data/publications_raw/"
# Parse every JSON file with non-empty text into a list of dicts.
documents = load_docs(folder_path)


In [34]:
def clean_pub_name(publication: dict) -> None:
    """Tidy the ``"name"`` field of a publication dict in place.

    Removes every literal two-character sequence backslash + ``n`` from the
    name, then strips leading and trailing whitespace (which also removes real
    newline characters at either end). Real newlines inside the name are kept.

    Args:
        publication: Publication record with a string ``"name"`` entry.

    Returns:
        None. The dict passed in is modified.

    Raises:
        KeyError: If the record has no ``"name"`` key.
    """
    # NOTE(review): the raw string matches the escaped text "\n", not a newline
    # character — confirm that this is what the scraped data contains.
    publication["name"] = publication["name"].replace(r"\n", "").strip()
def clean_text(publication: dict) -> None:
    """Strip leading and trailing whitespace from ``publication["text"]`` in place.

    Args:
        publication: Publication record with a string ``"text"`` entry.

    Returns:
        None. The dict passed in is modified.

    Raises:
        KeyError: If the record has no ``"text"`` key.
    """
    publication["text"] = publication["text"].strip()


In [35]:
# Normalise every loaded publication; both helpers modify the dict they are given.
for doc in documents:
    clean_pub_name(doc)
    clean_text(doc)
    

In [36]:



# Split each publication into overlapping character windows and attach the
# publication's metadata to every resulting chunk.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=120)
all_chunks = []

for doc in tqdm(documents, desc="Processing Docs"):
    chunks = text_splitter.create_documents([doc["text"]])
    # One flat record per chunk; the metadata values are shared with the source doc.
    all_chunks.extend(
        {
            "PMC_code": doc["PMC_code"],
            "name": doc["name"],
            "text": piece.page_content,
            "authors": doc["authors"],
            "date": doc["date"],
            "doi": doc["doi"],
        }
        for piece in chunks
    )
    

Processing Docs: 100%|██████████| 559/559 [00:03<00:00, 176.93it/s]


In [17]:
# Number of chunk records currently held in all_chunks.
len(all_chunks)



50552

In [18]:
# Peek at the first chunk record to check its fields.
all_chunks[0]

{'PMC_code': 'PMC10020673',
 'name': 'Microbial isolation and characterization from two flex lines from the urine processor assembly onboard the international space station',
 'text': 'Urine, humidity condensate, and other sources of non-potable water are processed onboard the International Space Station (ISS) by the Water Recovery System (WRS) yielding potable water. While some means of microbial control are in place, including a phosphoric acid/hexavalent chromium urine pretreatment solution, many areas within the WRS are not available for routine microbial monitoring. Due to refurbishment needs, two flex lines from the Urine Processor Assembly (UPA) within the WRS were removed and returned to Earth. The water from within these lines, as well as flush water, was',
 'authors': ['Brian Crucian',
  'Yo-Ann Velez Justiniano',
  'Hang Ngoc Nguyen',
  'Aubrie O’Rourke',
  'Chelsea McCool',
  'Sarah L Castro-Wallace',
  'Miten Jain',
  'Christian L Castro',
  'Michael D Lee',
  'Sarah Stahl

In [37]:
# Keep a chunk only when its text does not contain the substring "List of" and
# is longer than 100 characters once stripped; everything else is counted as dropped.
all_chunks_length_cleaned = []
dropped_chunks = 0
for chunk in all_chunks:
    body = chunk["text"]
    if "List of" in body or len(body.strip()) <= 100:
        dropped_chunks += 1
        continue
    all_chunks_length_cleaned.append(chunk)
print("dropped_chunks:" + str(dropped_chunks))


dropped_chunks:248


In [38]:
import pickle


def save_to_pickle(filename,object_to_save):
    """Serialise ``object_to_save`` to ``filename`` with :mod:`pickle`.

    Args:
        filename: Target path as a string. ``".pkl"`` is appended unless the
            path already ends with it.
        object_to_save: Any picklable object.

    Raises:
        OSError: If the file cannot be opened for writing.
        pickle.PicklingError: If the object cannot be pickled.
    """
    # Test the suffix, not the whole path: a directory called "cache.pkl.d"
    # must not stop the extension from being added to the file name.
    if not filename.endswith(".pkl"):
        filename = filename + ".pkl"
    with open(filename, "wb") as f:
        pickle.dump(object_to_save, f)

def load_pickle(filename):
    """Load and return the object stored in a pickle file.

    Args:
        filename: Source path as a string. ``".pkl"`` is appended unless the
            path already ends with it.

    Returns:
        The unpickled object.

    Raises:
        OSError: If the file cannot be opened for reading.
        pickle.UnpicklingError: If the file is not a valid pickle.
    """
    # Test the suffix, not the whole path (see a directory such as "cache.pkl.d").
    if not filename.endswith(".pkl"):
        filename = filename + ".pkl"
    with open(filename, "rb") as f:
        # SECURITY: unpickling can execute arbitrary code; only load files that
        # this notebook (or another trusted source) produced.
        return pickle.load(f)
# save_to_pickle("semantic_chunks.pkl", all_chunks)

In [39]:
# Size of all_chunks before it is replaced by the filtered list.
print(len(all_chunks))
# From here on `all_chunks` refers to the filtered chunks.
# NOTE(review): rebinding the name makes the result of re-running earlier cells
# depend on execution order.
all_chunks = all_chunks_length_cleaned

50552


In [40]:
import numpy as np 

# Overrides the BATCH_SIZE of 100 set in the setup cell; used as the default
# batch size of embed_chunks_in_batches below.
BATCH_SIZE = 256
# Point all_chunks at the filtered list (same rebinding as in the previous cell).
all_chunks = all_chunks_length_cleaned

def embed_chunks_in_batches(chunks, batch_size=BATCH_SIZE):
    """Embed the ``"text"`` of every chunk with the module-level ``model``.

    Batching is delegated to ``model.encode``, which is asked to normalise the
    embeddings and to show a progress bar.

    Args:
        chunks: Sequence of dicts that each carry a ``"text"`` entry.
        batch_size: Batch size forwarded to ``model.encode``.

    Returns:
        A NumPy array built from the encoded rows, in the order of ``chunks``.
    """
    encoded = model.encode(
        [chunk["text"] for chunk in chunks],
        normalize_embeddings=True,
        batch_size=batch_size,
        show_progress_bar=True,
    )
    # Collect the rows into a list first, then stack them into a single array.
    return np.array(list(encoded))

In [45]:
# Embed all filtered chunks; row i of `vectors` belongs to all_chunks[i].
vectors = embed_chunks_in_batches(all_chunks)

Batches:   0%|          | 0/197 [00:00<?, ?it/s]

In [48]:
# Cache the embeddings on disk so they can be reloaded without re-encoding.
save_to_pickle("vectors_publications_v1.pkl", vectors)


In [41]:
# Reload the cached embeddings.
# NOTE(review): the cached rows only match all_chunks if the chunks were built
# in the same order as when the cache was written — confirm before inserting.
vectors = load_pickle("vectors_publications_v1.pkl")

In [42]:
# Length in characters of every chunk text, and the largest of them.
text_lengths = list(map(len, (record["text"] for record in all_chunks)))
max_text_length = max(text_lengths)
print(max_text_length)


600


In [43]:
# Sanity check: show up to five chunks whose text is longer than 600 characters.
bad_chunks = list(filter(lambda record: len(record["text"]) > 600, all_chunks))
for oversized in bad_chunks[:5]:
    preview = oversized["text"][:120]
    print(len(oversized["text"]), repr(preview))


In [44]:


# Collection layout: an auto-generated integer key, the embedding vector, and
# the chunk text plus publication metadata stored as VARCHAR columns.
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    # NOTE(review): a metric type is normally supplied with the index/search
    # parameters; confirm that passing it to FieldSchema has any effect.
    FieldSchema(
        name="embedding",
        dtype=DataType.FLOAT_VECTOR,
        dim=384,
        metric_type="COSINE",
    ),
    *(
        FieldSchema(name=column, dtype=DataType.VARCHAR, max_length=limit)
        for column, limit in (
            ("PMC_code", 20),
            ("name", 300),
            ("content", 2000),
            ("authors", 2000),  # store as comma-separated or JSON
            ("date", 20),       # e.g. "2024-09-30"
            ("doi", 200),
        )
    ),
]

schema = CollectionSchema(fields, description="RAG collection with publication metadata")
collection = Collection("publications", schema)



In [45]:
from datetime import datetime

def _format_publication_date(raw_date):
    """Turn a "%Y %b %d" date (e.g. "2024 Sep 30") into "2024-09-30"; empty values become "None"."""
    if not raw_date:
        return "None"
    # Strict on purpose: a date in any other format raises ValueError.
    return datetime.strptime(raw_date, "%Y %b %d").strftime("%Y-%m-%d")


# The embeddings may have been loaded from an older pickle; refuse to insert
# when they cannot line up one-to-one with the chunks.
if len(vectors) != len(all_chunks):
    raise ValueError(
        f"vectors ({len(vectors)}) and all_chunks ({len(all_chunks)}) differ in length; "
        "re-embed the chunks before inserting"
    )

for i in range(0, len(all_chunks), BATCH_SIZE):
    # Slice once per batch and derive every column from the same window.
    batch_chunks = all_chunks[i:i+BATCH_SIZE]
    batch_vectors = vectors[i:i+BATCH_SIZE]
    batch_pmc_codes = [c["PMC_code"] for c in batch_chunks]
    batch_names = [c["name"] for c in batch_chunks]
    batch_texts = [c["text"] for c in batch_chunks]
    # Authors are stored as a single comma-separated string.
    batch_authors = [",".join(c["authors"]) for c in batch_chunks]
    batch_dates = [_format_publication_date(c["date"]) for c in batch_chunks]
    # Missing DOIs are stored as the string "None", like missing dates.
    batch_doi = [c["doi"] or "None" for c in batch_chunks]

    # Column order follows the schema: embedding, PMC_code, name, content,
    # authors, date, doi (the primary key is generated by Milvus).
    collection.insert([batch_vectors, batch_pmc_codes, batch_names, batch_texts,batch_authors,batch_dates,batch_doi])

# Seal and persist the inserted data once the bulk load is complete.
collection.flush()