In [None]:
import fitz
from langchain_experimental.text_splitter import SemanticChunker
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
import json
from docx import Document as Doc
import os
from langchain.prompts import PromptTemplate
from langchain_chroma import Chroma
from langchain_ollama import OllamaLLM
from langchain.chains import RetrievalQA
from langchain.schema import Document
import io
import zipfile
import base64
from PIL import Image
import numpy as np
import cv2
import pytesseract
import pandas as pd
import openpyxl
import datetime
from langchain_community.document_loaders import UnstructuredPowerPointLoader
import tempfile
from unstructured.partition.pptx import partition_pptx
import traceback
from PIL import Image as PILImage
from uuid import uuid4
import importnb

# Suppress tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# importnb's context manager lets a plain `import` load a notebook as a module;
# db_utils is presumably a sibling notebook providing the metadata-DB helpers
# used further down (is_uploaded, update_insert, clear_metadata_db).
with importnb.Notebook():
    import db_utils
# NOTE: this is db_utils' own init_db(), not the Chroma init_db() defined in this notebook.
db_utils.init_db()

# NOTE(review): this is set after the `unstructured` imports at the top of the cell;
# confirm the library reads the variable lazily and not at import time.
os.environ["UNSTRUCTURED_POWERPOINT_ONLY_USE_PYTHON"] = "1"


In [2]:
def extract_docx_images(file):
    """Collect the embedded raster images of a .docx archive.

    Args:
        file: Path, or seekable binary file-like object, holding the .docx
            (which is a zip archive). A passed-in file object is left open.

    Returns:
        list[tuple[str, PIL.Image.Image]]: ``(archive_member_name, image)``
        pairs for every entry under ``word/media`` that Pillow can decode.
        Entries Pillow cannot read (vector formats, OLE blobs, directory
        entries, ...) are skipped instead of aborting the whole extraction.
    """
    images = []
    # The context manager releases the archive handle even if a read fails.
    with zipfile.ZipFile(file) as doc_zip:
        for name in doc_zip.namelist():
            if not name.startswith("word/media"):
                continue
            try:
                image = Image.open(io.BytesIO(doc_zip.read(name)))
                # Image.open is lazy: decode now so that unsupported payloads
                # fail here rather than later in the OCR step.
                image.load()
            except OSError:
                # Pillow's UnidentifiedImageError is an OSError subclass.
                continue
            images.append((name, image))
    return images

In [None]:
def _to_grayscale_array(image):
    """Return ``image`` as an array suitable for blurring and OCR.

    PIL images in modes whose raw array is not plain luminance/RGB(A) data
    are first converted to RGB; otherwise the channel-count heuristic below
    would misread them (CMYK has 4 channels like RGBA, YCbCr has 3 like RGB,
    palette images hold palette indices, 1-bit images give bool arrays, and
    LA gives a 2-channel array).
    """
    if getattr(image, "mode", None) in ("1", "P", "PA", "LA", "CMYK", "YCbCr"):
        image = image.convert("RGB")

    img_array = np.array(image)
    if img_array.ndim == 3:
        channels = img_array.shape[2]
        if channels == 4:  # RGBA
            return cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
        if channels == 3:  # RGB
            return cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    # Already single channel, or an unknown layout: use as-is.
    return img_array


def process_image_and_ocr(images):
    """Run OCR over a list of named images.

    Args:
        images: Iterable of ``(name, image)`` pairs. ``image`` is a PIL image
            or anything ``np.array`` can turn into a pixel array.

    Returns:
        list[tuple]: ``(name, extracted_text)`` pairs in input order, where
        the text is whatever ``pytesseract.image_to_string`` returns for the
        grayscale, lightly blurred image.
    """
    final_image_data = []
    for name, image in images:
        img_bw = _to_grayscale_array(image)
        # Light 3x3 blur before handing the image to Tesseract.
        final_image = cv2.GaussianBlur(img_bw, (3, 3), 0)
        final_image_data.append((name, pytesseract.image_to_string(final_image)))
    return final_image_data

In [None]:
def _read_all_bytes(file):
    """Return the full contents of ``file``, rewinding first when possible."""
    try:
        # An earlier consumer may have left the stream at EOF.
        file.seek(0)
    except (AttributeError, OSError, ValueError):
        # Non-seekable stream: fall back to whatever is left to read.
        pass
    return file.read()


def _docx_text(content):
    """Extract table rows, paragraphs and OCR'd image text from .docx bytes."""
    byte_content = io.BytesIO(content)
    doc = Doc(byte_content)
    paras = [para.text for para in doc.paragraphs]
    tables = [
        "\t".join(cell.text.strip() for cell in row.cells)
        for table in doc.tables
        for row in table.rows
    ]
    parts = tables + paras

    images = extract_docx_images(byte_content)
    if images:
        image_data = process_image_and_ocr(images)
        parts.append("\n".join(f"{name}\n{ocr_text}\n\n" for name, ocr_text in image_data))
    return "\n".join(parts)


def _pdf_text(content):
    """Extract page text plus OCR text of embedded images from PDF bytes."""
    ocr_texts = []
    # The context manager closes the PyMuPDF document when done.
    with fitz.open(stream=content, filetype="pdf") as doc:
        text = "\n".join(page.get_text() for page in doc)
        for page_num, page in enumerate(doc):
            for img_index, img in enumerate(page.get_images(full=True)):
                xref = img[0]
                image_bytes = doc.extract_image(xref)["image"]
                image = Image.open(io.BytesIO(image_bytes))
                label = f"page{page_num}_img{img_index}"
                for name, ocr_text in process_image_and_ocr([(label, image)]):
                    ocr_texts.append(f"{name}\n{ocr_text}\n")
    if ocr_texts:
        text += "\n" + "\n".join(ocr_texts)
    return text


def _pptx_text(content):
    """Extract shape text plus OCR text of pictures from .pptx bytes."""
    from pptx import Presentation

    # Presentation accepts a file-like object, so no temporary file is needed
    # (the previous delete=False temp file was never removed).
    prs = Presentation(io.BytesIO(content))
    slide_texts = []
    for slide in prs.slides:
        slide_text = "\n".join(
            [shape.text for shape in slide.shapes if hasattr(shape, "text") and shape.text]
        )
        ocr_texts = []
        for shape in slide.shapes:
            try:
                img = shape.image
            except (AttributeError, ValueError):
                # AttributeError: not a picture shape. ValueError: picture
                # shape without an embedded image (e.g. linked or empty).
                continue
            pil_img = PILImage.open(io.BytesIO(img.blob))
            for _, ocr_text in process_image_and_ocr([(None, pil_img)]):
                if ocr_text.strip():
                    ocr_texts.append(ocr_text.strip())
        if ocr_texts:
            slide_text += "\n" + "\n".join(ocr_texts)
        slide_texts.append(slide_text)
    return "\n".join(slide_texts)


def get_and_read_file(file):
    """Extract all text (including OCR of embedded images) from a document.

    Args:
        file: Binary file-like object with a ``name`` attribute ending in
            ``.docx``, ``.pdf`` or ``.pptx``. The stream is rewound before
            reading when it supports ``seek``.

    Returns:
        str: The extracted text. For .docx this is table rows, then
        paragraphs, then OCR text; for .pdf page text followed by OCR text;
        for .pptx per-slide shape text followed by that slide's OCR text.
        An empty string is returned for any other extension.
    """
    content = _read_all_bytes(file)
    name = file.name
    if name.endswith(".docx"):
        return _docx_text(content)
    if name.endswith(".pdf"):
        return _pdf_text(content)
    if name.endswith(".pptx"):
        return _pptx_text(content)
    return ""

In [5]:
# splitting file into chunks
# Lazily created and reused: loading the embedding model is expensive and the
# splitter holds no per-document state.
_semantic_splitter = None


def _get_semantic_splitter():
    """Return the shared SemanticChunker, building it on first use."""
    global _semantic_splitter
    if _semantic_splitter is None:
        _semantic_splitter = SemanticChunker(
            embeddings=HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
        )
    return _semantic_splitter


def _semantic_chunks(file):
    """Read ``file`` and split its text into numbered semantic chunks."""
    text = get_and_read_file(file)
    if not text or not text.strip():
        # Nothing extractable: avoid producing an empty chunk.
        return []
    chunks = _get_semantic_splitter().split_text(text)
    return [{"chunk_number": i, "chunk_content": chunk} for i, chunk in enumerate(chunks)]


def split_pdf_to_chunks(file):
    """Split a PDF into semantic chunks.

    Args:
        file: Binary file-like object with a ``name`` ending in ``.pdf``.

    Returns:
        list[dict]: ``{"chunk_number": int, "chunk_content": str}`` entries
        in document order; empty when no text could be extracted.
    """
    return _semantic_chunks(file)


def split_docx_to_chunks(file):
    """Split a .docx document into semantic chunks.

    Args:
        file: Binary file-like object with a ``name`` ending in ``.docx``.

    Returns:
        list[dict]: ``{"chunk_number": int, "chunk_content": str}`` entries
        in document order; empty when no text could be extracted.
    """
    return _semantic_chunks(file)

def split_excel_to_chunks(file):
    """Turn each row of a spreadsheet's first sheet into one chunk.

    Args:
        file: Binary file-like object whose ``name`` ends in ``.xlsx``,
            ``.xls`` or ``.xlsb``.

    Returns:
        list[dict]: One ``{"chunk_number": row_index, "chunk_content": [...]}``
        per data row, where ``chunk_content`` is a list of single-entry
        ``{header: value}`` dicts in column order. Missing values are
        forward-filled from the row above and any that remain become ``""``;
        date/datetime values are stringified. An empty list is returned for
        unsupported extensions.
    """
    # Extension -> pandas reader engine. ".xlsb" was previously unreachable.
    engines = {".xlsx": "openpyxl", ".xls": "xlrd", ".xlsb": "pyxlsb"}
    engine = next(
        (eng for suffix, eng in engines.items() if file.name.endswith(suffix)),
        None,
    )
    if engine is None:
        return []

    df = pd.read_excel(io.BytesIO(file.read()), engine=engine)
    df = df.ffill(axis=0)
    headers = df.columns.tolist()

    chunks = []
    for i, row in enumerate(df.values.tolist()):
        cells = []
        for header, value in zip(headers, row):
            if pd.isna(value):
                value = ""
            elif isinstance(value, datetime.date):
                # Covers datetime.datetime and pd.Timestamp (both subclasses).
                value = str(value)
            cells.append({header: value})
        chunks.append({"chunk_number": i, "chunk_content": cells})
    return chunks

def split_pptx_to_chunks(file):
    """Split a PowerPoint deck into fixed-size text chunks.

    The deck's text is cut with a recursive character splitter
    (512 characters per chunk, 50 characters of overlap).

    Args:
        file: Binary file-like object with a ``name`` ending in ``.pptx``.

    Returns:
        list[dict]: ``{"chunk_number": int, "chunk_content": str}`` entries
        in order of appearance.
    """
    deck_text = get_and_read_file(file)
    pieces = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50).split_text(deck_text)
    return [
        {"chunk_number": index, "chunk_content": piece}
        for index, piece in enumerate(pieces)
    ]

In [6]:
def get_chunks(file, file_name):
    """Dispatch a file to the chunker matching its extension.

    Args:
        file: Binary file-like object handed to the chunker.
        file_name: Name used to pick the chunker (``.pdf``, ``.docx``,
            ``.xlsx`` or ``.pptx``).

    Returns:
        Whatever the selected chunker returns; an empty list (after printing
        a notice) when the extension is not supported.
    """
    chunkers = (
        (".pdf", split_pdf_to_chunks),
        (".docx", split_docx_to_chunks),
        (".xlsx", split_excel_to_chunks),
        (".pptx", split_pptx_to_chunks),
    )
    for suffix, chunker in chunkers:
        if file_name.endswith(suffix):
            return chunker(file)

    print("Unsupported file format.")
    return []

In [None]:
# optional
def save_chunks_to_file(chunks, file_name):
    """Dump chunks as indented JSON to ``../chunk_files/<file_name>_chunks.txt``.

    Args:
        chunks: JSON-serialisable chunk list. Values json cannot encode
            natively (e.g. ``datetime.time`` cells from a spreadsheet) are
            written via ``str``.
        file_name: Source file name; only its final path component is used so
            the output always lands inside the chunk directory.
    """
    out_dir = "../chunk_files"
    # The directory may not exist on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)
    safe_name = os.path.basename(file_name)
    with open(f"{out_dir}/{safe_name}_chunks.txt", "w", encoding="utf-8") as write_file:
        json.dump(chunks, write_file, indent=2, default=str)

In [None]:
def init_db():
    """Open the persistent Chroma collection and build a retriever on it.

    Returns:
        tuple: ``(vectordb, retriever)`` where ``vectordb`` is the
        "internship" collection stored under ``./my_db`` (embedded with
        BAAI/bge-base-en-v1.5) and ``retriever`` fetches the top 5 matches.
    """
    store = Chroma(
        collection_name="internship",
        persist_directory="./my_db",
        embedding_function=HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5"),
    )
    return store, store.as_retriever(search_kwargs={"k": 5})

In [None]:
def store_to_vectordb(vectordb, chunks, file_name):
    """Embed and store chunks in the vector store.

    Args:
        vectordb: Vector store exposing ``add_documents(docs, ids=...)``.
        chunks: List of ``{"chunk_number", "chunk_content"}`` dicts.
            ``chunk_content`` is either a string or, for spreadsheet rows, a
            list of single-entry ``{header: value}`` dicts, which is
            flattened to ``"header: value, header: value"``.
        file_name: Stored as the ``source`` metadata of every document.

    Returns:
        list[str]: The generated document ids, in chunk order (empty when
        ``chunks`` is empty).
    """
    if not chunks:
        return []

    chunk_ids = []
    docs = []
    for chunk in chunks:
        content = chunk["chunk_content"]
        # Decide by shape rather than by file extension so every spreadsheet
        # format produced by the Excel chunker is handled.
        if isinstance(content, list):
            content = ", ".join(
                f"{header}: {value}"
                for cell in content
                for header, value in list(cell.items())[:1]
            )
        chunk_id = str(uuid4())
        chunk_ids.append(chunk_id)
        docs.append(
            Document(
                page_content=content,
                id=chunk_id,
                metadata={"source": file_name, "chunk_number": chunk["chunk_number"]},
            )
        )

    # Pass the ids explicitly so the stored ids always match the returned
    # ones, which are used later to delete the chunks.
    vectordb.add_documents(docs, ids=chunk_ids)
    return chunk_ids

In [None]:
# creating RAG chain 
def create_qa_chain(retriever):

    custom_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=(
            "You are a document assistant. Carefully analyze the provided context and answer the question using only the information found in that context.\n\n"
            "INSTRUCTIONS:\n"
            "1. Read through the entire context thoroughly to find relevant information\n"
            "2. Extract and present information from the context, organizing it clearly\n"
            "3. You may rephrase or reorganize the information for clarity, but stay true to the original meaning\n"
            "4. Include specific details, lists, requirements, and examples as they appear in the context\n"
            "5. If the context contains the information but it's scattered, bring the relevant pieces together\n"
            "6. If the answer is not available in the context, state: 'The provided documents do not contain this information'\n"
            "7. Base your response only on what is written in the context below\n\n"
            "CONTEXT:\n{context}\n\n"
            "QUESTION: {question}\n\n"
        )
    )
    try:
        llm = OllamaLLM(model="llama3:8B")
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            chain_type_kwargs={"prompt": custom_prompt}
        )
        return qa_chain
    except Exception as e:
        print(f"LLM Error: {e}")

In [None]:
def handle_query(query, qa_chain, file_name):
    """Run a question through the QA chain, optionally scoped to one file.

    Args:
        query: The user's question.
        qa_chain: Chain exposing ``invoke``; when it also exposes
            ``retriever.search_kwargs`` (a dict) the file filter is applied
            there for the duration of the call.
        file_name: When truthy, retrieval is restricted to documents whose
            ``source`` metadata equals this name.

    Returns:
        The chain's result, or ``None`` if the call raised (the error is
        printed).
    """
    try:
        if not file_name:
            return qa_chain.invoke(query)

        source_filter = {"source": file_name}
        payload = {"query": query, "filter": source_filter}

        # An extra "filter" key in the chain input is not forwarded to the
        # retriever, so the filter has to be set on the retriever's search
        # kwargs to take effect.
        search_kwargs = getattr(getattr(qa_chain, "retriever", None), "search_kwargs", None)
        if not isinstance(search_kwargs, dict):
            return qa_chain.invoke(payload)

        saved_kwargs = dict(search_kwargs)
        search_kwargs["filter"] = source_filter
        try:
            return qa_chain.invoke(payload)
        finally:
            # Restore so later unfiltered queries are not affected.
            # NOTE: mutates shared retriever state; not safe for concurrent calls.
            search_kwargs.clear()
            search_kwargs.update(saved_kwargs)
    except Exception as e:
        print("Error:", e)
        return None

In [None]:
def delete_chunks_from_vectordb(chunk_ids):
    """Delete the given chunk ids from the persistent "internship" collection.

    Args:
        chunk_ids: Iterable of document ids as returned when the chunks were
            stored. ``None`` or an empty iterable is a no-op.

    Raises:
        Exception: Any error from opening the store or deleting is printed
            and re-raised.
    """
    ids = list(chunk_ids) if chunk_ids else []
    if not ids:
        print("No chunk IDs supplied; nothing to delete from vector database")
        return

    try:
        # Deleting by id needs no embeddings, so the store is opened without
        # loading an embedding model.
        vectordb = Chroma(
            collection_name="internship",
            persist_directory="./my_db"
        )

        # Delete the specific chunks by their IDs
        vectordb.delete(ids=ids)
        print(f"Successfully deleted {len(ids)} chunks from vector database")

    except Exception as e:
        print(f"Error deleting chunks from vector database: {e}")
        raise

In [None]:
def pipeline(vectordb, file):
    """Ingest one file: chunk it, store the chunks and record the upload.

    Progress is printed between two 100-dash separator lines. A file is only
    processed when ``db_utils.is_uploaded`` returns 0.

    Args:
        vectordb: Vector store passed on to ``store_to_vectordb``.
        file: File-like object with a ``name`` attribute.

    Returns:
        list: The generated chunks on success; an empty list when the file is
        skipped, yields no chunks, or any step raises (the error and its
        traceback are printed).
    """
    rule = "-" * 100
    print(rule)
    try:
        print(f"Processing file: {file.name}")
        action_num = db_utils.is_uploaded(file)
        print(f"Action number: {action_num}")

        if action_num == 0:
            print("File needs processing...")
            chunks = get_chunks(file, file.name)
            print("Chunks generated successfully")

            if chunks:
                chunk_ids = store_to_vectordb(vectordb, chunks, file.name)
                print(f"Stored {len(chunk_ids)} chunks to vector DB")

                db_utils.update_insert(chunk_ids, file)
                save_chunks_to_file(chunks, file.name)
                print("Pipeline completed successfully")
                outcome = chunks
            else:
                print("No chunks generated from file")
                outcome = []
        else:
            print("File already processed, skipping...")
            outcome = []
    except Exception as exc:
        print(f"Pipeline error: {exc}")
        traceback.print_exc()
        outcome = []

    print(rule)
    return outcome

In [None]:
def flush_db():
    """Drop the "internship" vector collection and clear the metadata DB.

    The metadata database is cleared right after the collection so the two
    stay in sync. Any error is printed, not raised; a failure while dropping
    the collection leaves the metadata untouched.
    """
    try:
        # Clear vector database
        Chroma(collection_name="internship", persist_directory="./my_db").delete_collection()
        print("Vector database cleared.")

        # Also clear metadata database to stay in sync
        db_utils.clear_metadata_db()
        print("Metadata database cleared.")
    except Exception as exc:
        print(f"Flush DB Error: {exc}")

In [None]:
def generate_enhanced_response(query, response, reason, qa_chain):
    """Re-ask the QA chain after an unsatisfactory answer.

    Builds a prompt that embeds the original question, the previous response
    and the reason it was judged unsatisfactory, and sends it to the chain.

    Args:
        query: The original question.
        response: The previous (rejected) answer.
        reason: Why the previous answer was considered bad.
        qa_chain: Chain exposing ``invoke``; its result must be indexable
            with ``"result"``.

    Returns:
        The ``"result"`` entry of the chain's output, or the string
        ``"Error: <exception>"`` if the call raised.
    """
    # The prompt's leading indentation is part of the string sent to the chain.
    enhanced_prompt = f"""
    Previous attempt at answering this question was not satisfactory.
    
    Original question: {query}
    Previous response: {response}
    Reason for why the previous attempt was bad: {reason}
    
    Please provide a better response by:
    1. Being more specific and detailed
    2. Checking if you have relevant information in the documents
    3. If information is not available, clearly state that
    4. Provide actionable insights where possible
    
    Question: {query}
    """
    try:
        # NOTE(review): the whole enhanced prompt is passed as the chain's "query";
        # presumably it is then also used for retrieval — confirm that is intended.
        result = qa_chain.invoke({"query": enhanced_prompt})
        return result["result"]
    except Exception as e:
        # Errors are returned as text, not raised.
        return f"Error: {e}"