In [7]:
from openai import OpenAI
from swarm import Swarm, Agent
import pandas as pd
import time
from langchain.schema import HumanMessage
from concurrent.futures import ThreadPoolExecutor
import logging
from langchain_ollama import ChatOllama
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Chroma
from langchain_ollama import OllamaEmbeddings
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document
import ollama
from langchain.retrievers.multi_query import MultiQueryRetriever
import re

# Configure logging
logging.basicConfig(level=logging.INFO)

# Constants
DOC_PATH = "../data/BA_AirlineReviews.csv"
MODEL_NAME = "llama3.1:8b"
EMBEDDING_MODEL = "nomic-embed-text"
VECTOR_STORE_NAME = "simple-rag"

def load_data(data_frame):
    data_summary = [
        f"ReviewHeader: {row['ReviewHeader']}, ReviewBody: {row['ReviewBody']}"
        for _, row in data_frame.iterrows()
    ]
    documents = [Document(page_content=entry) for entry in data_summary]
    return documents


def split_documents(documents):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    chunks = text_splitter.split_documents(documents)
    logging.info("Documents split into chunks.")
    return chunks


def create_vector_db(chunks):
    ollama.pull(EMBEDDING_MODEL)
    vector_db = Chroma.from_documents(
        documents=chunks,
        embedding=OllamaEmbeddings(model=EMBEDDING_MODEL),
        collection_name=VECTOR_STORE_NAME,
    )
    logging.info("Vector database created.")
    return vector_db


def extract_themes_from_feedback(df):
    documents = load_data(df)
    chunks = split_documents(documents)
    vector_db = create_vector_db(chunks)
    llm = ChatOllama(model=MODEL_NAME)

    QUERY_PROMPT = ChatPromptTemplate.from_template("""
    You are an AI language model assistant specializing in detailed text analysis. 
    Identify and retrieve all relevant themes from the text database based on the question:
    Original question: {question}
    Provide a comprehensive list of themes identified from the database.
    Return the themes in this format:
    1. Theme Name: Explanation
    2. Theme Name: Explanation
    """)

    retriever = MultiQueryRetriever.from_llm(vector_db.as_retriever(), llm, prompt=QUERY_PROMPT)

    template = """Answer the question based ONLY on the following context:
    {context}
    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)

    chain = (
        {"context": retriever, "question": RunnablePassthrough()} | prompt | llm | StrOutputParser()
    )

    question = "What are the overall themes from the feedback?"
    response = chain.invoke(input=question)
    themes = response.strip()

    pattern = r'(\d+\.\s[^:]+:\s[^0-9]+)'
    matches = re.findall(pattern, themes)

    themes_column_name_pattern = r'\d+\.\s*([^:]+):'
    match_column_name = re.findall(themes_column_name_pattern, themes)

    logging.info("Themes extracted successfully.")
    print("Extracted Themes:", match_column_name)
    return matches, match_column_name


def classify_row_with_theme(row, theme, chat_model):
    text = row["ReviewBody"]
    prompt = f"Does the following text mention issues related to '{theme}'? Respond with 'Yes' or 'No'. Text: '{text}'"
    try:
        response = chat_model([HumanMessage(content=prompt)])
        response_text = response.content.strip()
    except Exception as e:
        logging.error(f"Error processing row {row['RowID']}: {e}")
        response_text = "Cant classify"
    return row['RowID'], response_text


def classify_all_rows_with_theme(df, theme, chat_model):
    theme_tags = []
    with ThreadPoolExecutor(max_workers=5) as executor:
        async_tasks = {
            executor.submit(classify_row_with_theme, row, theme, chat_model): row["RowID"]
            for _, row in df.iterrows()
        }

        for task in async_tasks:
            row_id = async_tasks[task]
            try:
                result = task.result()
                theme_tags.append(result[1])
            except Exception as e:
                logging.error(f"Cannot classify row {row_id}: {e}")
                theme_tags.append("Cant classify")

    df[f'{theme}'] = theme_tags
    return df


def run_demo_loop(agent_a, agent_b):
    print("Starting Ollama Swarm CLI 🐝")

    df = pd.read_csv(DOC_PATH)
    df = df.sample(n=20)
    df = df.rename(columns={"Unnamed: 0": "RowID"})

    themes, theme_columns = extract_themes_from_feedback(df)

    chat_model = ChatOllama(model=MODEL_NAME)

    for theme in theme_columns:
        df = classify_all_rows_with_theme(df, theme, chat_model)

    print("Classification complete. Data saved to CSV.")
    df.to_csv("../data/classified_feedback.csv", index=False)


if __name__ == "__main__":
    client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
    swarm_client = Swarm(client=client)

    agent_a = Agent(
        name="Agent A",
        model=MODEL_NAME,
        instructions="You are an AI agent that extracts themes from feedback data.",
        functions=[extract_themes_from_feedback]
    )

    agent_b = Agent(
        name="Agent B",
        model=MODEL_NAME,
        instructions="You are an AI agent that classifies rows based on extracted themes.",
        functions=[classify_all_rows_with_theme]
    )

    run_demo_loop(agent_a, agent_b)


INFO:root:Documents split into chunks.


Starting Ollama Swarm CLI 🐝


INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/pull "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"
INFO:root:Vector database created.
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:langchain.retrievers.multi_query:Generated queries: ['After conducting an in-depth analysis of the text database, I have identified and retrieved the following relevant themes based on the original question "What are the overall themes from the feedback?":', "1. **Satisfaction with Product Quality**: Feedback indicates a general satisfaction with the product's performance, durability, and features, with many customers expressing their happiness with its ability to meet their needs.", '2. **User Experience (UX) Improvement Suggestions**: Several customers provided constructive feedback on how to improve the user experience, including suggestions for simplifying navigation, enhancing design aesthetics, and strea

Extracted Themes: ['**Disappointment in service decline**', '**Comparison to other airlines**', '**Poor cabin design and seating**', '**Cost-cutting measures negatively impacting customer experience**', '**Lack of attention to detail**']


INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/chat "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127

Classification complete. Data saved to CSV.
