In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.embeddings import JinaEmbeddings
from sentence_transformers import SentenceTransformer


In [None]:
import pickle as pkl


# Load the pre-computed document splits from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a 'splitDocuments.pkl' that was produced by a trusted source.
with open('splitDocuments.pkl', 'rb') as split_file:
    all_splits = pkl.load(split_file)

In [None]:
# Maps a group name to integer indices; the indices are later used to index
# into `all_splits` (presumably one entry per source document -- TODO confirm).
groups = {
    "Academic": [0],
    "LTI": [7, 8, 9],
    "Calendar": [3],
    "Facts": [1, 2, 4, 5, 6, 10],
}

In [None]:
# Embedding model shared by every per-group Chroma store.
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

# Write each group's splits into that group's own persist directory
# ("<group name>DB/"), one `from_documents` call per split index.
# NOTE(review): re-running this cell presumably adds the same documents to the
# persisted stores again -- confirm before executing it more than once.
for group_name, split_indices in groups.items():
    target_dir = "{}DB/".format(group_name)
    for split_index in split_indices:
        vectorstore = Chroma.from_documents(
            documents=all_splits[split_index],
            embedding=embedding_function,
            persist_directory=target_dir,
        )


In [None]:
from langchain_core.retrievers import BaseRetriever, RetrieverLike, RetrieverOutputLike
from langchain_core.language_models import BaseLLM
from langchain_core.embeddings import Embeddings
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from typing import List

from transformers import AutoTokenizer, AutoModel
import torch
import transformers

class CustomRetriever(BaseRetriever):
    """Retriever that pools hits from several retrievers and re-ranks them.

    Every retriever in ``vectorstore`` is queried with the same query string.
    The pooled documents are scored against the query using the hidden states
    produced by ``model`` / ``tokenizer`` and returned best score first.
    """

    # Retrievers that supply the candidate documents (retrievers, despite the name).
    vectorstore: List[RetrieverLike]
    # Encoder whose `last_hidden_state` provides the embeddings used for scoring.
    model: transformers.models.bert.modeling_bert.BertModel
    # Tokenizer used to prepare inputs for `model`.
    tokenizer: transformers.models.bert.tokenization_bert_fast.BertTokenizerFast

    def maxsim(self, query_embedding, document_embedding):
        """Score documents against queries with a max-similarity reduction.

        For each query vector the highest cosine similarity over all document
        vectors is taken; those maxima are then averaged per batch item.

        Args:
            query_embedding: 3-D tensor, (batch, n_query_vectors, dim).
            document_embedding: 3-D tensor, (batch, n_document_vectors, dim).

        Returns:
            1-D tensor with one score per batch item.
        """
        # Singleton axes make every query vector broadcast against every
        # document vector: result is (batch, n_query_vectors, n_document_vectors).
        pairwise = torch.nn.functional.cosine_similarity(
            query_embedding.unsqueeze(2), document_embedding.unsqueeze(1), dim=-1
        )
        best_per_query_vector = pairwise.max(dim=2).values
        return best_per_query_vector.mean(dim=1)

    def flatten_extend(self, matrix):
        """Concatenate the rows of ``matrix`` into a single flat list."""
        return [item for row in matrix for item in row]

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Collect candidates from every retriever and return them re-ranked.

        Args:
            query: the search string.
            run_manager: callback manager supplied by the base class (unused here).

        Returns:
            All candidate documents ordered by descending score. Documents with
            equal scores keep the order in which they were retrieved.
        """
        # `k=3` is forwarded as given; whether the underlying retriever honours
        # a call-time `k` depends on that retriever's implementation.
        candidates = self.flatten_extend(
            [retriever.get_relevant_documents(query, k=3) for retriever in self.vectorstore]
        )
        if not candidates:
            return []

        scores = []
        # Inference only: no_grad avoids building autograd state for every forward pass.
        with torch.no_grad():
            # Truncate the query the same way documents are truncated.
            query_encoding = self.tokenizer(query, return_tensors='pt', truncation=True, max_length=512)
            # The mean over dim 1 pools the query into a single vector per batch
            # item; unsqueeze(0) restores the 3-D layout that maxsim expects.
            query_embedding = self.model(**query_encoding).last_hidden_state.mean(dim=1).unsqueeze(0)

            for document in candidates:
                document_encoding = self.tokenizer(document.page_content, return_tensors='pt', truncation=True, max_length=512)
                document_embedding = self.model(**document_encoding).last_hidden_state
                scores.append(self.maxsim(query_embedding, document_embedding).item())

        # Sort on the score alone: comparing bare (score, document) tuples would
        # fall back to comparing Document objects whenever two scores are equal.
        ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
        return [document for _, document in ranked]
        
        

        

In [None]:
# Tokenizer and encoder handed to CustomRetriever for re-ranking.
tokenizer = AutoTokenizer.from_pretrained("colbert-ir/colbertv2.0")
model = AutoModel.from_pretrained("colbert-ir/colbertv2.0")

# One retriever per persisted group store. The result count is fixed on the
# retriever itself because a `k=` passed to `get_relevant_documents` at call
# time is not guaranteed to reach the vector-store search (it depends on the
# installed langchain-core version).
vectorstores = []
for group_name in groups:
    store = Chroma(persist_directory="{}DB".format(group_name), embedding_function=embedding_function)
    vectorstores.append(store.as_retriever(search_kwargs={"k": 3}))

In [None]:
# Combine the per-group retrievers with the re-ranking model and tokenizer.
customRetriever = CustomRetriever(
    vectorstore=vectorstores,
    model=model,
    tokenizer=tokenizer,
)

In [None]:
# Try the retriever on a sample question; the returned documents are the cell output.
sample_query = 'When will the classes begin in the Fall 2024 semester'
customRetriever.get_relevant_documents(sample_query)