In [44]:
from llama_index.llms.ollama import Ollama

# Load a local model (e.g., Mistral)
llm = Ollama(model="llama3.3:latest", temperature=0.2)


In [15]:
import os
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

os.environ["OPENAI_API_KEY"] = os.getenv('OAK')


# client = OpenAI(
#     api_key=os.getenv('OAK'),  # This is the default and can be omitted
# )

# chat_completion = client.chat.completions.create(
#     messages=[
#         {
#             "role": "user",
#             "content": "Say this is a test",
#         }
#     ],
#     model="gpt-3.5-turbo",
# )
# print(chat_completion)

import openai
openai.api_key = os.getenv('OAK')
print(openai.Client().models.list())

SyncPage[Model](data=[Model(id='gpt-4.5-preview', created=1740623059, object='model', owned_by='system'), Model(id='omni-moderation-2024-09-26', created=1732734466, object='model', owned_by='system'), Model(id='gpt-4.5-preview-2025-02-27', created=1740623304, object='model', owned_by='system'), Model(id='gpt-4o-mini-audio-preview-2024-12-17', created=1734115920, object='model', owned_by='system'), Model(id='dall-e-3', created=1698785189, object='model', owned_by='system'), Model(id='dall-e-2', created=1698798177, object='model', owned_by='system'), Model(id='gpt-4o-audio-preview-2024-10-01', created=1727389042, object='model', owned_by='system'), Model(id='gpt-4o-audio-preview', created=1727460443, object='model', owned_by='system'), Model(id='gpt-4o-mini-realtime-preview-2024-12-17', created=1734112601, object='model', owned_by='system'), Model(id='gpt-4o-2024-11-20', created=1739331543, object='model', owned_by='system'), Model(id='gpt-4o-mini-realtime-preview', created=1734387380, o

In [10]:

from typing import List, Dict, Any
import pandas as pd
# from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
# from llama_index.core.schema import Document
# from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.llms.openai import OpenAI
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.tools import FunctionTool
# import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np

class CSVAgent:
    def __init__(self, csv_path: str = None, csv_data: pd.DataFrame = None):
        """Initialize with either a path to a CSV file or a pandas DataFrame."""
        if csv_data is not None:
            self.df = csv_data
        elif csv_path is not None:
            self.df = pd.read_csv(csv_path)
        else:
            raise ValueError("Either csv_path or csv_data must be provided")
        
        self.csv_path = csv_path
        self.column_info = self._get_column_info()
        
        # Initialize LLM using OpenAI model (e.g., gpt-3.5-turbo)
        self.llm = OpenAI(model="gpt-3.5-turbo")
        
        # Create tools
        self.tools = [
            FunctionTool.from_defaults(fn=self.get_dataframe_info),
            FunctionTool.from_defaults(fn=self.get_column_data),
            FunctionTool.from_defaults(fn=self.query_data),
            FunctionTool.from_defaults(fn=self.filter_data),
            FunctionTool.from_defaults(fn=self.group_by_data),
            FunctionTool.from_defaults(fn=self.sort_data),
            FunctionTool.from_defaults(fn=self.get_statistics),
            # FunctionTool.from_defaults(fn=self.generate_chart)
        ]
        
        # Create agent
        self.agent = OpenAIAgent.from_tools(
            self.tools,
            llm=self.llm,
            verbose=True,
            system_prompt=(
                "You are a helpful assistant that analyzes CSV data. "
                "Use the available tools to explore, analyze, and visualize the data. "
                "When using tools that return data, summarize the results in a clear and helpful way. "
                "For numerical data, consider providing statistical insights. "
                "For categorical data, consider providing distributions or patterns."
            )
        )
    
    def _get_column_info(self) -> Dict[str, Any]:
        """Get information about the columns in the DataFrame."""
        info = {}
        for col in self.df.columns:
            info[col] = {
                "dtype": str(self.df[col].dtype),
                "sample_values": self.df[col].head(3).tolist(),
                "unique_count": self.df[col].nunique(),
                "null_count": self.df[col].isna().sum()
            }
        return info
    
    def get_dataframe_info(self) -> str:
        """Get basic information about the dataframe."""
        shape = self.df.shape
        columns = self.df.columns.tolist()
        dtypes = self.df.dtypes.to_dict()
        dtypes = {k: str(v) for k, v in dtypes.items()}
        
        missing_data = self.df.isna().sum().to_dict()
        
        info = {
            "rows": shape[0],
            "columns": shape[1],
            "column_names": columns,
            "dtypes": dtypes,
            "missing_values": missing_data,
            "sample_data": self.df.head(5).to_dict(orient="records")
        }
        
        return str(info)
    
    def get_column_data(self, column_name: str) -> str:
        """Get data from a specific column."""
        if column_name not in self.df.columns:
            return f"Column '{column_name}' not found. Available columns: {', '.join(self.df.columns)}"
        
        values = self.df[column_name].tolist()
        return str(values)
    
    def query_data(self, query: str) -> str:
        """Run a pandas query on the dataframe."""
        try:
            result = self.df.query(query)
            if len(result) > 10:
                return f"Query returned {len(result)} rows. First 10 rows:\n{result.head(10).to_string()}"
            return result.to_string()
        except Exception as e:
            return f"Error executing query: {str(e)}"
    
    def filter_data(self, column: str, value: str, operator: str = "==") -> str:
        """Filter dataframe based on column value."""
        if column not in self.df.columns:
            return f"Column '{column}' not found. Available columns: {', '.join(self.df.columns)}"
        
        try:
            # Convert value to appropriate type if numeric
            if pd.api.types.is_numeric_dtype(self.df[column]):
                try:
                    value = float(value)
                except ValueError:
                    pass
            
            # Handle different operators
            if operator == "==":
                filtered_df = self.df[self.df[column] == value]
            elif operator == "!=":
                filtered_df = self.df[self.df[column] != value]
            elif operator == ">":
                filtered_df = self.df[self.df[column] > value]
            elif operator == ">=":
                filtered_df = self.df[self.df[column] >= value]
            elif operator == "<":
                filtered_df = self.df[self.df[column] < value]
            elif operator == "<=":
                filtered_df = self.df[self.df[column] <= value]
            elif operator == "contains":
                filtered_df = self.df[self.df[column].astype(str).str.contains(str(value))]
            else:
                return f"Unsupported operator: {operator}"
            
            if len(filtered_df) > 10:
                return f"Filter returned {len(filtered_df)} rows. First 10 rows:\n{filtered_df.head(10).to_string()}"
            return filtered_df.to_string()
        except Exception as e:
            return f"Error filtering data: {str(e)}"
    
    def group_by_data(self, group_cols: str, agg_dict: str) -> str:
        """Group data by columns and aggregate."""
        try:
            group_cols = [col.strip() for col in group_cols.split(",")]
            for col in group_cols:
                if col not in self.df.columns:
                    return f"Column '{col}' not found. Available columns: {', '.join(self.df.columns)}"
            
            # Parse the aggregation dictionary
            import ast
            agg_dict = ast.literal_eval(agg_dict)
            
            result = self.df.groupby(group_cols).agg(agg_dict).reset_index()
            if len(result) > 10:
                return f"Groupby returned {len(result)} rows. First 10 rows:\n{result.head(10).to_string()}"
            return result.to_string()
        except Exception as e:
            return f"Error in group by operation: {str(e)}"
    
    def sort_data(self, columns: str, ascending: bool = True) -> str:
        """Sort dataframe by specified columns."""
        try:
            sort_cols = [col.strip() for col in columns.split(",")]
            for col in sort_cols:
                if col not in self.df.columns:
                    return f"Column '{col}' not found. Available columns: {', '.join(self.df.columns)}"
            
            result = self.df.sort_values(by=sort_cols, ascending=ascending)
            if len(result) > 10:
                return f"Sorted data ({len(result)} rows). First 10 rows:\n{result.head(10).to_string()}"
            return result.to_string()
        except Exception as e:
            return f"Error sorting data: {str(e)}"
    
    def get_statistics(self, columns: str = "all") -> str:
        """Get descriptive statistics for numeric columns."""
        try:
            if columns.lower() == "all":
                numeric_df = self.df.select_dtypes(include=np.number)
                if numeric_df.empty:
                    return "No numeric columns found in the dataset."
                return numeric_df.describe().to_string()
            
            columns = [col.strip() for col in columns.split(",")]
            stats_df = pd.DataFrame()
            
            for col in columns:
                if col not in self.df.columns:
                    return f"Column '{col}' not found. Available columns: {', '.join(self.df.columns)}"
                
                if not pd.api.types.is_numeric_dtype(self.df[col]):
                    return f"Column '{col}' is not numeric. Can only get statistics for numeric columns."
                
                stats_df[col] = self.df[col]
            
            return stats_df.describe().to_string()
        except Exception as e:
            return f"Error getting statistics: {str(e)}"
    
    # def generate_chart(self, chart_type: str, x_col: str, y_col: str = None, title: str = "Chart") -> str:
    #     """Generate a chart based on the data."""
    #     try:
    #         # Validate columns
    #         if x_col not in self.df.columns:
    #             return f"Column '{x_col}' not found. Available columns: {', '.join(self.df.columns)}"
            
    #         if y_col and y_col not in self.df.columns:
    #             return f"Column '{y_col}' not found. Available columns: {', '.join(self.df.columns)}"
            
    #         # Create different chart types
    #         plt.figure(figsize=(10, 6))
            
    #         if chart_type.lower() == "bar":
    #             if y_col:
    #                 sns.barplot(x=self.df[x_col], y=self.df[y_col])
    #             else:
    #                 self.df[x_col].value_counts().plot(kind='bar')
                
    #         elif chart_type.lower() == "histogram":
    #             sns.histplot(self.df[x_col])
                
    #         elif chart_type.lower() == "scatter":
    #             if not y_col:
    #                 return "Scatter plot requires both x and y columns."
    #             sns.scatterplot(x=self.df[x_col], y=self.df[y_col])
                
    #         elif chart_type.lower() == "line":
    #             if y_col:
    #                 sns.lineplot(x=self.df[x_col], y=self.df[y_col])
    #             else:
    #                 self.df[x_col].plot(kind='line')
                
    #         elif chart_type.lower() == "box":
    #             sns.boxplot(x=self.df[x_col])
                
    #         elif chart_type.lower() == "violin":
    #             if y_col:
    #                 sns.violinplot(x=self.df[x_col], y=self.df[y_col])
    #             else:
    #                 sns.violinplot(y=self.df[x_col])
                
    #         else:
    #             return f"Chart type '{chart_type}' not supported. Supported types: bar, histogram, scatter, line, box, violin"
            
    #         plt.title(title)
    #         plt.xticks(rotation=45)
    #         plt.tight_layout()
            
    #         # Save the chart temporarily
    #         temp_file = "temp_chart.png"
    #         plt.savefig(temp_file)
    #         plt.close()
            
    #         return f"Chart generated and saved as {temp_file}"
    #     except Exception as e:
    #         return f"Error generating chart: {str(e)}"
    
    def chat(self, message: str) -> str:
        """Chat with the agent about the CSV data."""
        response = self.agent.chat(message)
        return response.response


In [11]:
agent = CSVAgent(csv_path='./csv/daily-for-last-7-days.csv')

In [12]:
def main():    
    # Example queries
    queries = [
        "Which Channel Name has the most Total Viewership Minutes",
        # "Which channel has the highest average viewership minutes per session",
        # "How does the average session count vary across different days of the week? Draw a chart",
        # "How does viewership change during different weeks of the month?",
        # "How does the average session duration per session compare across different channels? Draw a chart",
        # "What is the demographic breakdown of viewers for each channel?",
        # "Which channel had the highest total viewership minutes for the month?",
        # "Which day of the month had the highest unique viewership?",
        # "What is the average session duration per viewer for each channel?",
        # "Which channel had the highest average session count for the month?",
        # "Which week of the month saw the most significant increase in viewership?",
        # "How does viewership compare between weekdays and weekends?",
        # "Which channel had the highest number of unique viewers for the month?",
        # "For the top 3 channels with the most viewership minutes, how do their average session durations and unique viewer numbers compare? Are they successful for different reasons?",
        # "If we look at all channels, is there a relationship between the number of sessions and the total viewership minutes? Do channels with more sessions generally get more total watch time?",
        # "Thinking about efficiency, which channels are getting the most viewership minutes per session started? Which channels are the most efficient at converting sessions into watch time?",
        # "Which channels show the largest discrepancy between high Session Count but low Total Viewership Minutes, potentially indicating content that attracts clicks but fails to maintain viewer interest?"
    ]
    
    for query in queries:
        print(f"\nQuestion: {query}")
        response = agent.chat(query)
        print(f"Answer: {response}")

if __name__ == "__main__":
    main()


Question: Which Channel Name has the most Total Viewership Minutes
Added user message to memory: Which Channel Name has the most Total Viewership Minutes


Retrying llama_index.llms.openai.base.OpenAI._chat in 1.0 seconds as it raised APIConnectionError: Connection error..
Retrying llama_index.llms.openai.base.OpenAI._chat in 1.6127800327782784 seconds as it raised APIConnectionError: Connection error..


APIConnectionError: Connection error.