In [1]:
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document as LCDocument
from langchain_milvus.utils.sparse import BM25SparseEmbedding
from typing import Iterator
import torch

In [2]:
from langchain.embeddings.base import Embeddings
from transformers import AutoTokenizer, AutoModel
import torch

class CustomDenseEmbedding(Embeddings):
    def __init__(self, model_name='sentence-transformers/stsb-xlm-r-multilingual'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def embed_documents(self, texts):
        inputs = self.tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
        return embeddings

    def embed_query(self, text):
        inputs = self.tokenizer([text], return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
        return embedding.tolist()


  from .autonotebook import tqdm as notebook_tqdm


In [15]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from langchain_milvus.retrievers import MilvusCollectionHybridSearchRetriever
from pymilvus.client.abstract import BaseRanker
from typing import List


tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-v2-m3')

class CustomReranker(BaseRanker):
    def __init__(self):
        super().__init__()

    def rerank(self, query_text: str, retrieved_texts: List[str]) -> List[str]:
        pairs = [[query_text, retrieved_text] for retrieved_text in retrieved_texts]
        results = []

        with torch.no_grad():
            for pair in pairs:
                inputs = tokenizer(pair, padding=True, truncation=True, return_tensors='pt', max_length=512)
                outputs = model(**inputs, return_dict=True)
                logits = outputs.logits.squeeze(0)  #
                if len(logits) == 1:
                    score = logits[0].item()  
                else:
                    score = abs(logits[0] - logits[1]).item()  
                results.append((pair[1], score))  

        sorted_results = sorted(results, key=lambda x: x[1], reverse=True)

        reranked_texts = [text for text, score in sorted_results]

        return reranked_texts

In [3]:
from datetime import date
import pandas as pd
from io import StringIO
import os
from dotenv import load_dotenv
import psycopg2

load_dotenv(dotenv_path="../env")

db_host = os.getenv("DB_HOST")
db_name = os.getenv("DB_NAME")
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")

conn = psycopg2.connect(
    host=db_host,
    database=db_name,
    user=db_user,
    password=db_password
)

In [4]:
def load_wrong_questions(class_id, student_id,conn):
    cur = conn.cursor()

    query = """
                SELECT qa.*
                FROM quiz_answers qa
                JOIN mcq_questions mq ON qa.mcq_question_id = mq.id
                WHERE qa.user_id = %s
                AND mq.class_id = %s
                AND qa.answer != mq.answer
                ORDER BY qa.created_at DESC
                LIMIT 10;
            """


    cur.execute(query, (student_id, class_id))

    rows = cur.fetchall()

    col_names = [desc[0] for desc in cur.description]

    result = [dict(zip(col_names, row)) for row in rows]

    cur.close()
    return result


In [5]:
wrong_questions = load_wrong_questions(1, 4, conn)
print(wrong_questions)

[{'id': 4, 'user_id': 4, 'quiz_id': 2, 'mcq_question_id': 7, 'answer': 3, 'created_at': datetime.datetime(2024, 11, 28, 16, 20, 18, 132971)}, {'id': 8, 'user_id': 4, 'quiz_id': 3, 'mcq_question_id': 6, 'answer': 4, 'created_at': datetime.datetime(2024, 11, 28, 16, 20, 18, 132971)}, {'id': 10, 'user_id': 4, 'quiz_id': 2, 'mcq_question_id': 18, 'answer': 3, 'created_at': datetime.datetime(2024, 11, 28, 16, 20, 18, 132971)}, {'id': 12, 'user_id': 4, 'quiz_id': 1, 'mcq_question_id': 16, 'answer': 4, 'created_at': datetime.datetime(2024, 11, 28, 16, 20, 18, 132971)}, {'id': 19, 'user_id': 4, 'quiz_id': 2, 'mcq_question_id': 8, 'answer': 4, 'created_at': datetime.datetime(2024, 11, 28, 16, 20, 18, 132971)}, {'id': 20, 'user_id': 4, 'quiz_id': 3, 'mcq_question_id': 8, 'answer': 4, 'created_at': datetime.datetime(2024, 11, 28, 16, 20, 18, 132971)}, {'id': 22, 'user_id': 4, 'quiz_id': 1, 'mcq_question_id': 13, 'answer': 1, 'created_at': datetime.datetime(2024, 11, 28, 16, 20, 18, 132971)}]


In [6]:
wrong_mcq_question_ids = [q['mcq_question_id'] for q in wrong_questions]
print(wrong_mcq_question_ids)

[7, 6, 18, 16, 8, 8, 13]


In [7]:
def get_mcq_questions(wrong_mcq_question_ids, conn):
    query = """SELECT * FROM mcq_questions WHERE id IN %s"""

    cur = conn.cursor()
    cur.execute(query, (tuple(wrong_mcq_question_ids),))
    rows = cur.fetchall()

    col_names = [desc[0] for desc in cur.description]
    wrong_mcq_questions = [dict(zip(col_names, row)) for row in rows]

    return wrong_mcq_questions

wrong_mcq_questions = get_mcq_questions(wrong_mcq_question_ids, conn)
print(wrong_mcq_questions)

[{'id': 16, 'user_id': 2, 'question': [{'id': 'cea7c819-02ab-4074-9e2d-354832f38ed4', 'type': 'paragraph', 'props': {'textColor': 'default', 'textAlignment': 'left', 'backgroundColor': 'default'}, 'content': [{'text': 'Hi ', 'type': 'text', 'styles': {}}, {'text': 'Hello Yaluwane', 'type': 'text', 'styles': {'textColor': 'orange'}}, {'text': ' Oyalata ', 'type': 'text', 'styles': {}}, {'text': 'kohomada', 'type': 'text', 'styles': {'italic': True}}, {'text': ' ithin bath ', 'type': 'text', 'styles': {}}, {'text': 'kawada', 'type': 'text', 'styles': {'bold': True}}, {'text': '???', 'type': 'text', 'styles': {}}], 'children': []}], 'answer': 1, 'class_id': 1, 'topic': '6', 'sub_topic': '6', 'mcq_answers': [{'index': 0, 'value': '3'}, {'index': 1, 'value': '34'}, {'index': 2, 'value': '45'}, {'index': 3, 'value': 'hghg'}, {'index': 4, 'value': '12'}], 'activation': 0}, {'id': 6, 'user_id': 2, 'question': [{'id': 'a999eb34-f80f-4b07-9a75-4163fee5b388', 'type': 'paragraph', 'props': {'textC

In [8]:
texts = [q['question'][0]['content'][0]['text']for q in wrong_mcq_questions]

In [None]:
    
def embed_documents(wrong_mcq_questions, dense_embedding_instance, sparse_embedding_func):
    embeddings_with_metadata = []

    for doc in wrong_mcq_questions:
        question_id = doc.get('id') 
        content_text = " ".join([content.get('text', '') for content in doc.get('question', [{}])[0].get('content', [])])
        
        embeddings_with_metadata.append({
            "id": question_id,
            "text": content_text,
            "embedding": dense_embedding_instance.embed_documents(content_text),  # Assuming embedding_func is defined
            "sparse": sparse_embedding_func.embed_documents(content_text)[0],  # Assuming sparse_embedding_func is defined
            "metadata": {
                "topic": doc.get('topic'),
                "sub_topic": doc.get('sub_topic'),
                "class_id": doc.get('class_id')
            }
        })

    for doc in embeddings_with_metadata:
        doc["embedding"] = doc["embedding"].tolist()

    return embeddings_with_metadata


In [10]:
sparse_embedding_func = BM25SparseEmbedding(corpus=texts)
dense_embedding_instance = CustomDenseEmbedding()
embeddings_with_metadata = embed_documents(wrong_mcq_questions, dense_embedding_instance, sparse_embedding_func)

In [None]:
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility

connections.connect("default", host="localhost", port="19530")

collection_name = "questioncollection"

if utility.has_collection(collection_name):
    utility.drop_collection(collection_name)

: 

In [12]:
if not utility.has_collection(collection_name):
    field1 = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False)
    field2 = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768)  
    field3 = FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR)  
    field4 = FieldSchema(name="text", dtype=DataType.VARCHAR, is_primary=False, max_length=5000)  # Question content
    field5 = FieldSchema(name="topic", dtype=DataType.VARCHAR, is_primary=False, max_length=256)  # Topic
    field6 = FieldSchema(name="sub_topic", dtype=DataType.VARCHAR, is_primary=False, max_length=256)  # Sub-topic
    field7 = FieldSchema(name="class_id", dtype=DataType.INT64, is_primary=False)  # Class ID

    # Create a collection schema
    schema = CollectionSchema(
        fields=[field1, field2, field3, field4, field5, field6, field7],
        description="Collection for storing MCQ question embeddings with metadata",
        enable_dynamic_field=True
    )

    # Create the collection
    collection = Collection(name=collection_name, schema=schema)

    # Define and create indexes
    embedding_index_params = {"metric_type": "IP", "index_type": "IVF_FLAT", "params": {"nlist": 128}}
    collection.create_index(field_name="embedding", index_params=embedding_index_params)

    sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
    collection.create_index(field_name="sparse", index_params=sparse_index)

    # Flush the collection to persist changes
    collection.flush()

    print(f"Collection '{collection_name}' created successfully.")
else:
    # Load the existing collection
    collection = Collection(name=collection_name)
    print(f"Collection '{collection_name}' already exists.")

# Confirm the collection is loaded
print(f"Collection loaded: {collection.name}")


Collection 'questioncollection' created successfully.
Collection loaded: questioncollection


In [13]:
data = [
    {
        "id": embeddings_with_metadata[i]["id"],
        "embedding": embeddings_with_metadata[i]["embedding"],
        "sparse": sparse_embedding_func.embed_documents([embeddings_with_metadata[i]["text"]])[0],  # Generate sparse embedding
        "text": embeddings_with_metadata[i]["text"],
        "topic": embeddings_with_metadata[i]["metadata"]["topic"],
        "sub_topic": embeddings_with_metadata[i]["metadata"]["sub_topic"],
        "class_id": embeddings_with_metadata[i]["metadata"]["class_id"]
    }
    for i in range(len(embeddings_with_metadata))
]

In [14]:
collection.insert(data)
collection.load()

In [16]:
from langchain_milvus.retrievers import MilvusCollectionHybridSearchRetriever


sparse_search_params = {"metric_type": "IP"}
dense_search_params = {"metric_type": "IP", "params": {}}

reranker = CustomReranker()

retriever = MilvusCollectionHybridSearchRetriever(
    collection=collection,
    rerank=reranker,
    anns_fields=["embedding", "sparse"],  # Fields to search in Milvus
    field_embeddings=[dense_embedding_instance, sparse_embedding_func],  # Embedding functions
    field_search_params=[dense_search_params, sparse_search_params],  # Search parameters for embeddings
    top_k=5,  
    text_field="text",  
    filter_fields=["topic", "sub_topic", "class_id"],  
)

filter_criteria = {
    "topic": 7,  
    "sub_topic": 7 ,  
    "class_id": 1 
}

query = "OK OK OK OK"
retrieved_docs = retriever.invoke(query, filters=filter_criteria)

print(retrieved_docs)


[Document(metadata={'id': 7, 'topic': '7', 'sub_topic': '7', 'class_id': 1}, page_content='ok ok ok'), Document(metadata={'id': 6, 'topic': '7', 'sub_topic': '7', 'class_id': 1}, page_content='wefwefwe'), Document(metadata={'id': 18, 'topic': '7', 'sub_topic': '7', 'class_id': 1}, page_content='uefb w  aeifhpiauewp'), Document(metadata={'id': 13, 'topic': '6', 'sub_topic': '6', 'class_id': 1}, page_content='Hi Hello Yaluwane Oyalata kohomada ithin bath kawada???'), Document(metadata={'id': 16, 'topic': '6', 'sub_topic': '6', 'class_id': 1}, page_content='Hi  Hello Yaluwane  Oyalata  kohomada  ithin bath  kawada ???')]


In [17]:
extracted_ids = [doc.metadata['id'] for doc in retrieved_docs]
print(extracted_ids)

[7, 6, 18, 13, 16]
