# In [4]:
# Enable importing Jupyter notebooks (.ipynb) as modules (presumably needed for `Utils` below)
import import_ipynb
# Import required libraries
import json
import google.generativeai as genai
import os
import random
from Utils import extract_chunk  # Import the utility function

# Load secrets from secrets.json.
# The file is decoded explicitly as UTF-8 so that reading it does not depend
# on the platform's default locale encoding.
with open("secrets.json", "r", encoding="utf-8") as f:
    secrets = json.load(f)

# Configure the Gemini API with the key stored in the secrets file
GOOGLE_API_KEY = secrets["GOOGLE_API_KEY"]
genai.configure(api_key=GOOGLE_API_KEY)

# Initialize the GenerativeModel
model = genai.GenerativeModel('gemini-pro')

# Load the saved best hyperparameters (same explicit UTF-8 decoding as above)
with open('GEMINI_best_model_parameters.json', 'r', encoding='utf-8') as file:
    best_hyperparameters = json.load(file)

# Chunk settings that are passed to extract_chunk when classifying a document
MAX_CHUNK_SIZE = best_hyperparameters["CHUNK_SIZE"]
ALIGN_SENTENCES = best_hyperparameters["SENTENCE_ALIGNMENT"]
# Prompt template; it is filled in via str.format(chunk=...) at classification time
prompt_template = best_hyperparameters["PROMPT"]



# Define the classification function
def classify_document_text(document_text):
    """
    Predict the category of a document with the Gemini model.

    A chunk is extracted from the document using the saved chunk settings,
    inserted into the saved prompt template, and sent to the model. Only the
    first line of the model's reply is kept, with any "Category:" text removed.

    Args:
        document_text (str): Full text of the document.

    Returns:
        str or None: The predicted category, or None if the model call or the
        handling of its response raised an exception (the error is printed).
    """
    excerpt = extract_chunk(
        document_text,
        chunk_length=MAX_CHUNK_SIZE,
        respect_sentence_boundaries=ALIGN_SENTENCES,
    )
    request_text = prompt_template.format(chunk=excerpt)

    try:
        reply = model.generate_content(request_text)
        # The reply may carry extra text after the category, so keep only
        # the first line of the stripped response.
        first_line = reply.text.strip().partition("\n")[0].strip()
        # Drop a "Category:" label if the model echoed one back.
        return first_line.replace("Category:", "").strip()
    except Exception as exc:
        print(f"Error generating text: {exc}")
        return None

# Example usage: classify a single document from the Documents folder
if __name__ == "__main__":
    DOCUMENTS_DIR = "Documents"  # Folder containing text documents
    doc_id = "000101"  # Replace with an actual DocID from your dataset
    doc_path = os.path.join(DOCUMENTS_DIR, f"{doc_id}.txt")

    if not os.path.exists(doc_path):
        # Nothing to classify; report the missing file
        print(f"Document {doc_id}.txt not found in {DOCUMENTS_DIR}.")
    else:
        # Read the whole document into memory
        with open(doc_path, "r", encoding="utf-8") as file:
            document_text = file.read()

        # Run the classifier and show its prediction
        predicted_category = classify_document_text(document_text)
        print(f"Predicted Category: {predicted_category}")

# Output:
# Predicted Category: Philosophy
