In [3]:
import os
import pickle
import re

import pandas as pd
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings.base import OpenAIEmbeddings
from tqdm import tqdm

# Configuration. Values from a .env file are loaded into the process
# environment first, so the API key lookup below can see them.
load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL = "gpt-3.5-turbo"

# Filesystem layout: DATA_PATH holds one sub-directory per company; each one
# contains a pickled vector store named VECTOR_DB_NAME, and a results CSV is
# written next to it with the RESULTS_HEADER columns.
DATA_PATH = "data/handcrafted"
CRITERIA_PATH = "criteria.csv"
VECTOR_DB_NAME = "vector_db.pkl"
RESULTS_HEADER = ["Goal", "Compliance", "Explanation"]

# Goal/criteria table, read once at import time. The code that consumes it
# reads the 'Goal', 'HIGH Compliance', 'MEDIUM Compliance' and
# 'LOW Compliance' columns.
criteria_df = pd.read_csv(CRITERIA_PATH)

# Function to Load Vector Database for Company
def load_vector_db(company_dir):
    """Rebuild the FAISS vector store saved for one company.

    Reads ``company_dir/VECTOR_DB_NAME``, unpickles its payload (presumably the
    bytes produced by ``FAISS.serialize_to_bytes`` -- confirm against the
    indexing step) and deserializes it with a fresh ``OpenAIEmbeddings``.

    Args:
        company_dir: Directory that should contain the pickled vector store.

    Returns:
        The deserialized ``FAISS`` vector store.

    Raises:
        FileNotFoundError: If the pickle file is missing.
    """
    pickle_file = os.path.join(company_dir, VECTOR_DB_NAME)
    if os.path.exists(pickle_file):
        # SECURITY NOTE(review): pickle.load and allow_dangerous_deserialization
        # can both execute arbitrary code embedded in the file; only load
        # vector stores produced by this project, never third-party files.
        with open(pickle_file, "rb") as handle:
            serialized_store = pickle.load(handle)
        return FAISS.deserialize_from_bytes(
            serialized_store,
            OpenAIEmbeddings(),
            allow_dangerous_deserialization=True,
        )
    raise FileNotFoundError(f"Vector database not found for {company_dir}")

# Function to Test Compliance for a Goal
#
# Patterns used to pull the grade out of the model's free-text answer.
# _LABELLED_LEVEL_RE matches an explicit label such as "Compliance: HIGH",
# "Compliance level - medium" or "**Compliance:** LOW" (markdown bold
# tolerated); _BARE_LEVEL_RE is the fallback and matches a standalone
# uppercase level word.
_LABELLED_LEVEL_RE = re.compile(
    r"\bcompliance(?:\s+level)?\s*\**\s*[:\-]\s*\**\s*(HIGH|MEDIUM|LOW)\b",
    re.IGNORECASE,
)
_BARE_LEVEL_RE = re.compile(r"\b(HIGH|MEDIUM|LOW)\b")


def _parse_compliance_level(answer, default="LOW"):
    """Extract the HIGH/MEDIUM/LOW grade from a free-text model answer.

    An explicit ``Compliance: <LEVEL>`` label wins. Otherwise the *first*
    standalone uppercase level word (by position in the text) is used, so an
    answer like "LOW - it does not meet the HIGH criteria" yields LOW.
    Returns ``default`` when no level is present (e.g. "I don't know").
    """
    match = _LABELLED_LEVEL_RE.search(answer) or _BARE_LEVEL_RE.search(answer)
    return match.group(1).upper() if match else default


def test_compliance(company, goal, criteria_row, llm):
    """Grade one company's compliance with one goal using retrieval + an LLM.

    Loads the company's vector store from ``DATA_PATH/<company>``, retrieves the
    passages most similar to a compliance question about ``goal``, and asks
    ``llm`` to grade the company against the goal's HIGH/MEDIUM/LOW criteria.

    Args:
        company: Name of the company sub-directory under ``DATA_PATH``.
        goal: Goal text being assessed; echoed back as the first result field.
        criteria_row: Row (e.g. a pandas Series) providing the
            ``'HIGH Compliance'``, ``'MEDIUM Compliance'`` and
            ``'LOW Compliance'`` criteria texts.
        llm: LangChain language model used to run the assessment prompt.

    Returns:
        ``[goal, compliance_level, explanation]``. ``compliance_level`` is
        ``"HIGH"``, ``"MEDIUM"`` or ``"LOW"``. It is ``"LOW"`` when no context
        is retrieved or no level can be parsed from the answer.
        ``explanation`` is the raw model answer. Returns ``None`` if the
        company has no vector database.
    """
    company_dir = os.path.join(DATA_PATH, company)
    try:
        # NOTE(review): the store is re-read and re-deserialized for every
        # goal of the same company; consider loading it once per company.
        db = load_vector_db(company_dir)
    except FileNotFoundError:
        print(f"Skipping {company}: vector database not found.")
        return None

    # Goal-specific grading rubric.
    high_criteria = criteria_row['HIGH Compliance']
    medium_criteria = criteria_row['MEDIUM Compliance']
    low_criteria = criteria_row['LOW Compliance']

    # The question doubles as the similarity-search query and the prompt query.
    question = f"Based on the context, what is the company's level of compliance (HIGH, MEDIUM, or LOW) with respect to the goal '{goal}'? Provide a brief explanation."

    docs = db.similarity_search(question)
    if not docs:
        return [goal, "LOW", "No relevant context found."]
    context = "\n".join(doc.page_content for doc in docs)

    # The explicit "Compliance: <LEVEL>" line makes the grade machine-readable;
    # without it, answers that mention several levels were mis-graded.
    system_prompt = """
You are an ESG compliance expert. Based on the provided context and compliance criteria, determine the company's level of compliance with respect to the specified goal. You must be very strict and harsh in your assessment. Use only the retrieved context to support your assessment. If the context lacks sufficient information, say "I don't know".
Otherwise, start your answer with a line of the form "Compliance: <LEVEL>", where <LEVEL> is exactly one of HIGH, MEDIUM or LOW, and then give your explanation.

Question:
{query}

Compliance Criteria:
- HIGH Compliance: {high_criteria}
- MEDIUM Compliance: {medium_criteria}
- LOW Compliance: {low_criteria}

Context:
{context}
"""

    prompt_template = PromptTemplate(
        input_variables=["context", "high_criteria", "medium_criteria", "low_criteria", "query"],
        template=system_prompt
    )
    chain = LLMChain(
        llm=llm,
        prompt=prompt_template
    )
    result = chain.run(
        context=context,
        high_criteria=high_criteria,
        medium_criteria=medium_criteria,
        low_criteria=low_criteria,
        query=question
    )

    # Unparseable answers (including "I don't know") default to LOW.
    compliance_level = _parse_compliance_level(result)
    return [goal, compliance_level, result]

# Main function to iterate over companies and goals
def generate_compliance_results():
    """Grade every company under ``DATA_PATH`` against every goal in ``criteria_df``.

    For each company sub-directory, runs ``test_compliance`` for every
    criteria row and writes the collected rows to
    ``DATA_PATH/<company>/results.csv`` with the ``RESULTS_HEADER`` columns.
    Companies without a vector database are reported once and skipped, and
    no (empty) results file is written for them.
    """
    llm = ChatOpenAI(model_name=MODEL, temperature=0, openai_api_key=OPENAI_API_KEY)
    # os.listdir order is arbitrary; sort so runs are reproducible.
    companies = sorted(
        entry
        for entry in os.listdir(DATA_PATH)
        if os.path.isdir(os.path.join(DATA_PATH, entry))
    )

    for company in tqdm(companies, desc="Processing Companies"):
        company_dir = os.path.join(DATA_PATH, company)
        # Check up front (same path load_vector_db uses) so a missing store is
        # reported once instead of once per goal.
        if not os.path.exists(os.path.join(company_dir, VECTOR_DB_NAME)):
            print(f"Skipping {company}: vector database not found.")
            continue

        results = []
        for _, criteria_row in criteria_df.iterrows():
            result = test_compliance(company, criteria_row['Goal'], criteria_row, llm)
            if result is not None:
                results.append(result)

        result_path = os.path.join(company_dir, "results.csv")
        results_df = pd.DataFrame(results, columns=RESULTS_HEADER)
        results_df.to_csv(result_path, index=False)
    

# Execute the pipeline when run as a script or notebook cell (where __name__
# is also "__main__"), but not when this module is imported.
if __name__ == "__main__":
    generate_compliance_results()


Processing Companies: 100%|██████████| 8/8 [04:56<00:00, 37.07s/it]
