# ***AI-Compliance-Inspector: Automated Compliance Verification using RAG and LLaVA — Author: Rawan Alahmadi***

# ***Section 1: Downloads - Libraries and Model***

In [None]:
# Install necessary libraries
!pip install torch transformers faiss-cpu pytesseract pdfplumber Pillow llama-index

# Import libraries
import torch
from transformers import pipeline
from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex, LLMPredictor, PromptHelper
import pytesseract
import pdfplumber
from PIL import Image
import os

# Download and load the LLaVA model.
# NOTE(review): confirm that "LLaVA/llava-7b" is a valid Hugging Face Hub
# model id for the "image-to-text" pipeline; loading fails otherwise.
model_name = "LLaVA/llava-7b"

# Defined up front so the name always exists, even when loading fails below;
# later cells can then check for None instead of hitting a NameError.
llava_pipeline = None
try:
    llava_pipeline = pipeline("image-to-text", model=model_name)
    print("LLaVA model loaded successfully.")
except Exception as e:
    # Best-effort: report the failure and keep the notebook running.
    print(f"Error loading LLaVA model: {e}")

# Only claim success when the model really is available.
if llava_pipeline is not None:
    print("Libraries and model downloaded successfully.")
else:
    print("Libraries imported, but the LLaVA model is not available.")

# ***Section 2: Data Preparation - Extracting Text from Images and PDFs***

In [None]:

def extract_text_from_image(image_path):
    """
    Extract text from an image using OCR (pytesseract).

    The image is opened inside a context manager so its file handle is
    released as soon as OCR finishes, instead of staying open until the
    object is garbage-collected.

    Args:
        image_path (str): Path to the image file.
    Returns:
        str: Extracted text with surrounding whitespace stripped, or an
            empty string if the image could not be opened or OCR failed.
    """
    try:
        with Image.open(image_path) as img:
            # "eng+ara" asks the OCR engine for English and Arabic together.
            text = pytesseract.image_to_string(img, lang="eng+ara")
    except Exception as e:
        # Best-effort extraction: report the problem and return no text so
        # one bad file does not stop a batch run.
        print(f"Error extracting text from image: {e}")
        return ""
    else:
        print(f"Text extracted from image: {image_path}")
        return text.strip()

def extract_text_from_pdf(pdf_path):
    """
    Pull the text out of every page of a PDF with pdfplumber.

    Pages for which pdfplumber yields no text are skipped; the remaining
    page texts are concatenated, each followed by a newline.

    Args:
        pdf_path (str): Path to the PDF file.
    Returns:
        str: The combined page text with surrounding whitespace stripped,
            or an empty string if the PDF could not be read.
    """
    try:
        with pdfplumber.open(pdf_path) as document:
            page_texts = [page.extract_text() for page in document.pages]
        # Keep only pages that produced text; terminate each with "\n".
        combined = "".join(f"{chunk}\n" for chunk in page_texts if chunk)
        print(f"Text extracted from PDF: {pdf_path}")
        return combined.strip()
    except Exception as err:
        print(f"Error extracting text from PDF: {err}")
        return ""

def describe_image_with_llava(image_path):
    """
    Describe the content of an image using the LLaVA model.

    Relies on the module-level ``llava_pipeline`` object. The image is opened
    inside a context manager so its file handle is closed once the pipeline
    has produced its output.

    Args:
        image_path (str): Path to the image file.
    Returns:
        str: The ``generated_text`` of the first pipeline result, or the
            fallback string "Unable to describe the image." on any failure
            (missing file, unavailable model, unexpected output shape).
    """
    try:
        with Image.open(image_path) as img:
            # Run the model while the image file is still open.
            outputs = llava_pipeline(img)
        description = outputs[0]['generated_text']
    except Exception as e:
        # Best-effort: callers always get a string back.
        print(f"Error describing image: {e}")
        return "Unable to describe the image."
    else:
        print(f"Description generated for image: {image_path}")
        return description

def prepare_data(input_path):
    """
    Collect the text of every supported document in a folder.

    Image files (.png/.jpg/.jpeg) go through OCR, PDF files through the PDF
    extractor; any other entry is reported and left out of the result.
    Extension matching is case-insensitive.

    Args:
        input_path (str): Path to the folder containing documents.
    Returns:
        dict: Maps each supported file name to its extracted text.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    extracted = {}
    for entry in os.listdir(input_path):
        full_path = os.path.join(input_path, entry)
        lowered = entry.lower()

        if lowered.endswith(image_suffixes):
            extracted[entry] = extract_text_from_image(full_path)
            continue

        if lowered.endswith(".pdf"):
            extracted[entry] = extract_text_from_pdf(full_path)
            continue

        print(f"Unsupported file format: {entry}")
    return extracted

# ***Section 3: RAG Setup - Adding Compliance Files***

In [None]:
from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex, LLMPredictor, PromptHelper
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_compliance_files(compliance_dir):
    """
    Load compliance documents from a specified directory.

    Only ``.txt`` and ``.pdf`` files are loaded (case-insensitive); other
    entries are reported and skipped. Each file is handled independently, so
    one unreadable file (for example a text file that is not valid UTF-8) is
    reported and skipped without discarding the files that follow it. Files
    are processed in sorted name order so the result is deterministic.

    Args:
        compliance_dir (str): Path to the compliance files directory.
    Returns:
        dict: Dictionary of compliance texts indexed by file name. Empty if
            the directory cannot be listed.
    """
    compliance_data = {}
    try:
        file_names = sorted(os.listdir(compliance_dir))
    except Exception as e:
        # Keep the original never-raises contract for a bad directory.
        print(f"Error loading compliance files: {e}")
        return compliance_data

    for file_name in file_names:
        file_path = os.path.join(compliance_dir, file_name)
        lowered = file_name.lower()

        if not lowered.endswith((".txt", ".pdf")):
            print(f"Skipping unsupported file format: {file_name}")
            continue

        try:
            if lowered.endswith(".pdf"):
                compliance_text = extract_text_from_pdf(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as file:
                    compliance_text = file.read()
        except (OSError, UnicodeError) as e:
            # Skip just this file; the rest of the directory is still loaded.
            print(f"Error loading compliance file {file_name}: {e}")
            continue

        compliance_data[file_name] = compliance_text
        print(f"Loaded compliance file: {file_name}")
    return compliance_data

def create_rag_index(compliance_data):
    """
    Create a RAG index for compliance documents using Allam model from SDAIA.

    The texts in ``compliance_data`` are wrapped in llama_index ``Document``
    objects and indexed directly. (The dict used to be handed to
    ``SimpleDirectoryReader``, which reads files from a directory path, and
    the result was wrapped in an extra list.)

    Args:
        compliance_data (dict): Dictionary of compliance texts, keyed by
            file name.
    Returns:
        GPTVectorStoreIndex: Index object for RAG, or None if there is no
            text to index or any step of the setup fails.
    """
    try:
        # Local import keeps this block self-contained; Document comes from
        # the same package the file already imports.
        from llama_index import Document

        # Build the documents first: it is cheap and lets us bail out before
        # loading a 7B model when there is nothing to index.
        documents = [
            Document(text=text)
            for text in compliance_data.values()
            if text and text.strip()
        ]
        if not documents:
            print("Error creating RAG index with Allam model: no compliance text to index.")
            return None

        # Load Allam model from SDAIA
        model_name = "SDAIA/Allam-7B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # NOTE(review): the legacy LLMPredictor appears to expect an `llm`
        # wrapper object rather than model/tokenizer/temperature/max_length
        # keywords — confirm against the installed llama_index version. If
        # the call is rejected, the except below returns None.
        llm_predictor = LLMPredictor(
            model=model,
            tokenizer=tokenizer,
            temperature=0.5,
            max_length=256
        )

        prompt_helper = PromptHelper(max_input_size=1024, num_output=256, max_chunk_overlap=20)

        # Create RAG index using Allam model
        index = GPTVectorStoreIndex.from_documents(
            documents,
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper
        )
        print("RAG index created successfully using Allam model.")
        return index
    except Exception as e:
        print(f"Error creating RAG index with Allam model: {e}")
        return None

def search_compliance(query, index):
    """
    Search compliance documents using RAG.

    The object returned by ``index.query`` is converted to ``str`` so this
    function really returns text, as documented; callers that lower-case or
    substring-match the result need a string rather than a response object.

    Args:
        query (str): User query for compliance check.
        index (GPTVectorStoreIndex): The RAG index object.
    Returns:
        str: The most relevant compliance text, or the fallback
            "No relevant compliance found." if the query fails.
    """
    try:
        response = index.query(query)
    except Exception as e:
        # Also covers index being None (attribute lookup fails).
        print(f"Error during compliance search: {e}")
        return "No relevant compliance found."
    print("Compliance search completed.")
    return str(response)

# Example usage
# These statements run as soon as the cell/module executes. The names they
# bind (compliance_data, rag_index, ...) are module-level and rag_index is
# read again by the example further down in this file.
compliance_files_path = "./compliance_files"
compliance_data = load_compliance_files(compliance_files_path)
rag_index = create_rag_index(compliance_data)

# Test RAG search with Allam model
query = "What are the data privacy requirements?"
# create_rag_index returns None on failure, so the search is skipped when no
# index could be built.
if rag_index:
    result = search_compliance(query, rag_index)
    print(f"Search Result: {result}")

# ***Section 4: Compliance Calculation using SDAIA Allam Model***


In [None]:
def calculate_compliance(query, index):
    """
    Calculate the compliance score using the RAG model with Allam.

    The RAG answer is scanned for known compliance phrases and the first
    matching phrase determines the score. Phrases are checked most-specific
    first: "partially compliant", "non-compliant" and "not compliant" all
    contain the substring "compliant", so the generic "compliant" entry must
    come after them or negative answers would be scored as compliant.

    Args:
        query (str): The compliance check query.
        index (GPTVectorStoreIndex): The RAG index object.
    Returns:
        dict: Keys "query", "result", "score" (int percentage) and "status"
            ("Compliant" for scores >= 90, "Non-Compliant" for scores <= 30,
            otherwise "Partially Compliant"). On an unexpected error:
            result "Error", score 0 and status "Unknown".
    """
    try:
        # Perform RAG-based search using Allam model
        result = search_compliance(query, index)

        # Example keywords for compliance classification, ordered so that
        # longer phrases are tested before the shorter phrase they contain.
        compliance_keywords = (
            ("fully compliant", 100),
            ("partially compliant", 70),
            ("non-compliant", 30),
            ("not compliant", 10),
            ("compliant", 90),
            ("unknown", 50),
        )

        # str() makes matching work even if the search returned a response
        # object instead of a plain string.
        result_text = str(result).lower()

        # First matching phrase wins; 50 is the default for ambiguous cases.
        compliance_score = next(
            (score for keyword, score in compliance_keywords if keyword in result_text),
            50,
        )

        if compliance_score >= 90:
            status = "Compliant"
        elif compliance_score <= 30:
            status = "Non-Compliant"
        else:
            status = "Partially Compliant"

        # Compile compliance result
        compliance_result = {
            "query": query,
            "result": result,
            "score": compliance_score,
            "status": status
        }

        # Print and return the compliance result
        print(f"\nCompliance Check: {query}")
        print(f"Result: {result}")
        print(f"Compliance Score: {compliance_score}%")
        print(f"Status: {compliance_result['status']}\n")
        return compliance_result
    except Exception as e:
        print(f"Error calculating compliance: {e}")
        return {"query": query, "result": "Error", "score": 0, "status": "Unknown"}

# Example usage
# Runs at cell/module execution time and depends on the module-level
# rag_index created earlier in this file; nothing happens if it is falsy
# (for example None after a failed index build).
user_query = "Does the document comply with data protection regulations?"
if rag_index:
    compliance_result = calculate_compliance(user_query, rag_index)
    print(f"Final Compliance Result: {compliance_result}")


# ***Section 5: User Interface for Compliance Check***

In [None]:
def main():
    """
    Main function to interact with the user for compliance checks.

    Loads the compliance files from "./compliance_files", builds the RAG
    index, then reads queries from standard input until the user types
    'exit' (case-insensitive, surrounding whitespace ignored) or the input
    stream ends. Blank queries are ignored rather than sent to the model.
    """
    print("\nWelcome to AI-Compliance-Inspector")
    print("Automated Compliance Verification using RAG and Allam Model\n")

    # Load compliance files and create RAG index
    compliance_files_path = "./compliance_files"
    compliance_data = load_compliance_files(compliance_files_path)
    rag_index = create_rag_index(compliance_data)

    if not rag_index:
        print("Failed to create RAG index. Please check your files and try again.")
        return

    while True:
        print("\nEnter your compliance query (or type 'exit' to quit):")
        try:
            user_query = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: leave cleanly instead of a traceback.
            print("\nExiting the compliance inspector. Goodbye!")
            break

        if not user_query:
            # Nothing typed; prompt again without running a search.
            continue

        if user_query.lower() == 'exit':
            print("Exiting the compliance inspector. Goodbye!")
            break

        # Calculate compliance
        compliance_result = calculate_compliance(user_query, rag_index)

        # Display result
        print("\n----- Compliance Check Result -----")
        print(f"Query: {compliance_result['query']}")
        print(f"Result: {compliance_result['result']}")
        print(f"Compliance Score: {compliance_result['score']}%")
        print(f"Status: {compliance_result['status']}")
        print("-----------------------------------\n")

# Run the main function
# The guard starts the interactive loop only when this file is executed as
# the main module, not when it is imported.
if __name__ == "__main__":
    main()
