In [None]:
!pip install langchain_core==0.3.75 langchain_huggingface==0.3.1 langchain-community sentence-transformers
!pip install torch torchvision --upgrade

In [None]:
import os
# import toml

from pymilvus import connections
from pymilvus import Collection, utility, connections, FieldSchema, CollectionSchema, DataType

import ibm_boto3
from ibm_botocore.client import Config, ClientError

from langchain_huggingface import HuggingFaceEmbeddings  # Updated import

from dotenv import load_dotenv

## Load connection information from the project

In [None]:
# List the connections available to this notebook.
# NOTE(review): `wslib` is not defined in any visible cell — presumably it is
# created by a project-token / runtime cell that is not shown; confirm it
# exists before running the notebook top to bottom.
wslib.list_connections()

In [None]:
# Fetch the connection details by name; make sure these names match the
# Milvus and COS connections defined in your project.
milvus_conn = wslib.get_connection('milvus_connection')
cos_conn = wslib.get_connection('cos_connection')

## Load env.txt file with configuration

In [None]:
# Copy the 'env.txt' data asset into a local '.env_all' file so that
# python-dotenv can read it from disk.
with open('.env_all', 'wb') as env_file:
    env_file.write(wslib.load_data('env.txt').read())
# Environment variables store credentials and configuration.
# NOTE(review): '.env_all' is left on disk after this cell and may contain
# secrets — consider removing it once the variables are loaded.
load_dotenv('.env_all')

## COS Resource

In [None]:
# Constants for IBM COS values
# The URL stored in the connection may or may not already include a scheme;
# prepend "https://" only when it is missing so we never build an invalid
# "https://https://..." endpoint.
_cos_url = cos_conn['url'].strip()
COS_ENDPOINT = _cos_url if _cos_url.startswith(("http://", "https://")) else f"https://{_cos_url}"
COS_API_KEY_ID = cos_conn['api_key']
COS_INSTANCE_CRN = cos_conn['resource_instance_id']

### Make sure COS_ENDPOINT matches the endpoint of your buckets (it must be the same for all buckets); if it does not, override it in the next cell

In [None]:
# Show the endpoint taken from the connection so it can be compared with the
# endpoint of the buckets.
print(COS_ENDPOINT)
# To override it, put the https-prefixed endpoint below and uncomment the line:
# COS_ENDPOINT = "enter https-prepended endpoint and uncomment"

In [None]:
# Create the COS resource client: an S3-compatible resource built from the
# API key, the service instance CRN and the endpoint defined above, using the
# "oauth" signature version.
cos_resource = ibm_boto3.resource("s3",
    ibm_api_key_id=COS_API_KEY_ID,
    ibm_service_instance_id=COS_INSTANCE_CRN,
    config=Config(signature_version="oauth"),
    endpoint_url=COS_ENDPOINT
)


## For text splitting and embedding

In [None]:
# Name of the sentence-transformer model used to embed text.
embeddings_model_name = os.getenv("SENTENCE_TRANSFORMER")
# Required integers: index os.environ directly (as done for INPUT_BUCKET) so
# that a missing variable raises a KeyError naming it, instead of the opaque
# TypeError produced by int(None).
chunk_size = int(os.environ["TEXT_SPLITTER_CHUNK_SIZE"])
chunk_overlap = int(os.environ["TEXT_SPLITTER_CHUNK_OVERLAP"])

# TEXT - SPLITTER RELATED
text_splitter_type = os.getenv("TEXT_SPLITTER_TYPE")
# Optional JSON strings (parsed later with json.loads):
# - TEXT_SPLITTER_SEPARATORS: separators handed to the text splitter
# - TEXT_REPLACEMENTS: {"old": "new"} substitutions applied to page content
text_splitter_separators = os.getenv("TEXT_SPLITTER_SEPARATORS")
text_replacements = os.getenv("TEXT_REPLACEMENTS")

# Metric used both when building the Milvus index and when searching it.
similarity_metric = os.getenv("SIMILARITY_METRIC")

# Bucket that holds the source documents (required).
INPUT_BUCKET = os.environ["INPUT_BUCKET"]

In [None]:
# Embeddings function: a HuggingFace embedding model selected through the
# SENTENCE_TRANSFORMER environment variable (change the function or its
# parameters here if a different embedder is needed).
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

## For milvus connection

In [None]:
# Authentication for Milvus within wx.data:
# host, port and credentials are read from the project's Milvus connection.
host = milvus_conn["host"]
port = milvus_conn["port"]
password = milvus_conn["password"]
user = milvus_conn["username"]

# Name of the Milvus collection used by this notebook; required, so a missing
# MV_COLLECTION_NAME raises KeyError here.
mv_collection = os.environ["MV_COLLECTION_NAME"]


In [None]:
# Keyword arguments later unpacked into the Milvus connect call.
connection_args = dict(
    host=host,
    port=port,
    user=user,
    password=password,
    secure=True,
)

# 1. Load and split documents

In [None]:
import json
import os
import posixpath

from ibm_botocore.client import ClientError

from langchain.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader,
    UnstructuredHTMLLoader,
)
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
)

# Maps the TEXT_SPLITTER_TYPE setting to the splitter class to instantiate.
text_splitter_type_dict = dict(
    CharacterTextSplitter=CharacterTextSplitter,
    RecursiveCharacterTextSplitter=RecursiveCharacterTextSplitter,
)

# Maps a lower-cased file extension to the loader class used to read it;
# csv files are read as plain text.
_document_loaders_dict = dict(
    docx=Docx2txtLoader,
    html=UnstructuredHTMLLoader,
    pdf=PyPDFLoader,
    csv=TextLoader,
    txt=TextLoader,
)

def get_bucket_contents(cur_conn, bucket_name):
    """
    Return the keys of all objects stored in ``bucket_name``.

    Errors are printed rather than raised; whatever keys were collected
    before the failure (possibly none) are returned.
    """
    print("Retrieving bucket contents from:{0}".format(bucket_name))
    object_keys = []
    try:
        bucket = cur_conn.Bucket(bucket_name)
        for cos_object in bucket.objects.all():
            object_keys.append("{0}".format(cos_object.key))
    except ClientError as client_err:
        print("CLIENT ERROR: {0}".format(client_err))
    except Exception as err:
        print("Unable to retrieve bucket contents: {0}".format(err))
    return object_keys

def get_item(cur_conn, bucket_name, item_name):
    """
    Fetch one object from COS, identified by ``bucket_name`` and ``item_name``.

    Returns the result of the object's ``get()`` call on success. On any
    error the problem is printed and ``None`` is returned.
    """
    announcement = "Retrieving item from bucket: {0}, key: {1}".format(
        bucket_name, item_name
    )
    print(announcement)
    try:
        response = cur_conn.Object(bucket_name, item_name).get()
        print("File Contents retrieved")
        return response
    except ClientError as client_err:
        print("CLIENT ERROR: {0}".format(client_err))
    except Exception as err:
        print("Unable to retrieve file contents: {0}".format(err))
    return None

def _save_doc_linux(cur_conn, bucket_name, doc_path, tmp_location):
    """
    Save a file from COS to a temporary local path.

    - cur_conn is the connection established to COS
    - bucket_name / doc_path identify the object to download
    - tmp_location is the local path the bytes are written to

    Raises RuntimeError when the object cannot be retrieved.
    """
    cos_object = get_item(cur_conn, bucket_name, doc_path)
    # get_item prints the underlying error and returns None on failure; fail
    # with an explicit message instead of "'NoneType' object is not subscriptable".
    if cos_object is None:
        raise RuntimeError(
            f"Could not retrieve '{doc_path}' from bucket '{bucket_name}'"
        )
    bytes_obj = cos_object["Body"].read()
    with open(tmp_location, "wb") as f:
        f.write(bytes_obj)


def _replace_characters(text_replacements, init_text: str) -> str:
    """
    Replaces characters in the init_text based on text_replacements that contains mapping {"old_char": "new_char"}.
    Returns text with the replaced character
    """
    replace_json_dict = json.loads(text_replacements)
    output_text = init_text
    for old_char, new_char in replace_json_dict.items():
        output_text = output_text.replace(old_char, new_char)
    return output_text


def load_split_cos_docs(
    cur_conn,
    bucket_name,
    chunk_size,
    chunk_overlap,
    text_splitter_type,
    text_splitter_separators,
    text_replacements,
):
    """
    Load and split the supported documents (docx, html, pdf, csv, txt) of a COS bucket.

    - cur_conn is the connection established to COS
    - bucket_name is the name of the bucket with documents to load and split
    - chunk_size is the size of text to split by
    - chunk_overlap is the number of characters to overlap between chunks
    - text_splitter_type is a key of text_splitter_type_dict
    - text_splitter_separators is an optional JSON list (as a string) of separators
    - text_replacements is an optional JSON object (as a string) {"old": "new"}
      applied to the content of every loaded document

    Only objects whose key contains the COS_FOLDER environment variable are
    processed; when COS_FOLDER is not set, every object in the bucket is processed.

    Returns {"splitted_loaded_docs": [chunk, ...]}.
    Raises ValueError when text_splitter_type is not supported.
    """
    # list to save all docs into a list
    split_loaded_docs = []
    # location for temp files
    temp_folder = os.path.join(".", "temp")

    if text_splitter_separators is not None and text_splitter_separators != "":
        text_splitter_separators = json.loads(text_splitter_separators, strict=False)

    bucket_contents = get_bucket_contents(cur_conn, bucket_name)

    # defining the splitter
    if text_splitter_type not in text_splitter_type_dict:
        print(f"Text splitter of type {text_splitter_type} is not supported")
        # Fail here with a clear error instead of hitting an unbound
        # `text_splitter` while processing the first document.
        raise ValueError(
            f"Text splitter of type {text_splitter_type} is not supported, "
            f"accepted types are {list(text_splitter_type_dict)}"
        )
    print(f"Using {text_splitter_type}")
    splitter_cls = text_splitter_type_dict[text_splitter_type]
    splitter_kwargs = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
    if (
        isinstance(text_splitter_separators, list)
        and len(text_splitter_separators) > 0
    ):
        if splitter_cls is CharacterTextSplitter:
            # CharacterTextSplitter takes one `separator` string and rejects a
            # `separators` keyword, so only the first entry can be used.
            splitter_kwargs["separator"] = text_splitter_separators[0]
        else:
            splitter_kwargs["separators"] = text_splitter_separators
        print(
            f"Specified separators for the text splitter {text_splitter_separators}"
        )
    else:
        print(
            "Separators for text splitters are either not specified or specified incorrectly"
        )
    text_splitter = splitter_cls(**splitter_kwargs)

    # to check only files within project sources; read once, outside the loop.
    # An empty filter is contained in every key, so an unset COS_FOLDER selects
    # the whole bucket instead of raising TypeError.
    sources_folder_path = os.getenv("COS_FOLDER") or ""
    if not sources_folder_path:
        print("COS_FOLDER is not set, all objects in the bucket will be processed")

    # going through files
    print("Looping through bucket contents")
    for doc_path in bucket_contents:
        if sources_folder_path not in doc_path:
            continue
        print(f"Processing {doc_path}")
        cur_doc_format = doc_path.split(".")[-1].lower()
        # check for formatting before downloading anything
        if cur_doc_format not in _document_loaders_dict:
            # `logger` is not defined in this notebook; report with print like
            # the rest of the function.
            print(
                f"Data in {doc_path} is not supported, accepted formats are {_document_loaders_dict.keys()}"
            )
            continue
        # to create temp folder if it doesn't exist
        os.makedirs(temp_folder, exist_ok=True)
        tmp_location = os.path.join(temp_folder, doc_path.split("/")[-1])
        print(f"Attempting to load {doc_path}")
        try:
            # to load document
            _save_doc_linux(cur_conn, bucket_name, doc_path, tmp_location)
            doc_loader = _document_loaders_dict[cur_doc_format](tmp_location)
            loaded_docs = doc_loader.load()
            print(f"Successfully loaded {doc_path}")

            for loaded_doc in loaded_docs:
                # to change metadata of loaded documents to include only filename for the source
                loaded_doc.metadata["source"] = doc_path.split("/")[-1]
                # to replace characters in page_content according to dict stored in TEXT_REPLACEMENTS
                if text_replacements != "" and text_replacements is not None:
                    loaded_doc.page_content = _replace_characters(
                        text_replacements, loaded_doc.page_content
                    )
            # to split documents
            print(f"Splitting docs loaded from {doc_path}")
            splitted_docs = text_splitter.split_documents(loaded_docs)
            # to add current document to total list
            split_loaded_docs.extend(splitted_docs)
        finally:
            # remove the temp copy even when loading or splitting fails
            if os.path.exists(tmp_location):
                os.remove(tmp_location)
                print(f"File {tmp_location} was removed")
    # after loop through COS documents ended
    print("Documents from Cloud Object Storage are loaded and splitted")
    return {"splitted_loaded_docs": split_loaded_docs}


In [None]:
# Load the documents of the input bucket and split them into chunks, using
# the splitter settings read from the environment.
splitted_docs = load_split_cos_docs(
    cur_conn=cos_resource,
    bucket_name=INPUT_BUCKET,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    text_splitter_type=text_splitter_type,
    text_splitter_separators=text_splitter_separators,
    text_replacements=text_replacements,
)

In [None]:
# Flatten the chunks into three parallel lists: source file name, page number
# and chunk text.
sources = []
pages = []
docs = []
for splitted_doc in splitted_docs['splitted_loaded_docs']:
    sources.append(splitted_doc.metadata['source'])
    # Not every loader records a 'page' entry (the plain-text and docx loaders
    # do not), so fall back to 0 instead of raising KeyError for those documents.
    pages.append(splitted_doc.metadata.get('page', 0))
    docs.append(splitted_doc.page_content)

# 2. Milvus

## Connect

In [None]:
# Open the default Milvus connection with the arguments assembled earlier.
connections.connect(alias="default", **connection_args)

## To drop collection if exists

In [None]:
# WARNING: destructive — an existing collection with this name is dropped so
# that the following cells start from a freshly created one.
if utility.has_collection(mv_collection): # check if collection exists
    utility.drop_collection(mv_collection)

## To create a new collection
> you can update schema naming to correspond to your requirements

In [None]:
# The embedding model is configurable (SENTENCE_TRANSFORMER), so derive the
# vector dimension from the model instead of hard-coding 384; a mismatch would
# make every insert fail.
embedding_dim = len(embeddings.embed_query("dimension probe"))

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),  # generated by Milvus
    FieldSchema(name="text_embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=500),  # source file name
    FieldSchema(name="page", dtype=DataType.INT64),
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),  # Storing original content as well
]
schema = CollectionSchema(fields)
collection = Collection(name=mv_collection, schema=schema)

In [None]:
# Show the fields of the collection schema to verify what was created.
print(collection.schema.fields)

## To insert documents into collection, create index and load data into memory

In [None]:
# Convert documents into vector embeddings
doc_embeddings = embeddings.embed_documents(docs)

# Insert documents and embeddings into Milvus
collection.insert([doc_embeddings, sources, pages, docs]) 

In [None]:
# Create the vector index on the embedding field.
index_params = {
    "index_type": "IVF_FLAT",  # Can also use "IVF_PQ", "HNSW", etc.
    "metric_type": similarity_metric,       # from SIMILARITY_METRIC: "L2" for Euclidean distance, or "IP" for Inner Product
    "params": {"nlist": 110},  # nlist = number of clusters; per the original note, 110 was sized for roughly 800 vectors — TODO confirm for your data volume
}

collection.create_index(field_name="text_embedding", index_params=index_params)

In [None]:
# Load the collection into memory; Milvus needs the collection to be loaded
# before it can be searched.
collection.load()

## To perform semantic search

In [None]:
# Embed the question with the same model used for the documents and fetch the
# five closest chunks together with their title, page and text.
cur_query = "What is ETF?"
search_params = {
    "metric_type": similarity_metric,
    "params": {"nprobe": 10},
}
query_result = collection.search(
    data=[embeddings.embed_query(cur_query)],
    anns_field="text_embedding",
    param=search_params,
    limit=5,
    output_fields=["title", "page", "text"],
)
# Print the query result (hits for the single query vector)
for entity in query_result[0]:
    print(entity)