In [None]:
# install packages
%pip install -U langchain-ollama
%pip install langchain.prompts
%pip install langchain_community
%pip install langchain_community.llms
%pip install seaborn

# Import Packages

In [64]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain.prompts import ChatPromptTemplate
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pandas.api.types import is_datetime64_any_dtype, is_numeric_dtype, is_object_dtype
import os
from pydantic import BaseModel
import json
import numpy as np

## Preprocessing
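- Extract each JSON file's headers as a description of the data

### Extract each JSON file's headers as a description of the data

- Recursively scans `./Data` for `.json` files so the LLM knows which files are available and which headers (column names) each file contains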

In [None]:
import os
import json


def extract_headers(file_path):
    """
    Extract headers (top-level keys) from a JSON file.

    Parameters:
        file_path (str): The path to the JSON file.

    Returns:
        list: For a JSON object, its keys in file order. For a JSON array,
        the union of the keys of every object element, in first-seen order
        (records do not always share the same keys, and pandas will create a
        column for each of them). An empty list for any other JSON value or
        for an array without object elements.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); "utf-8-sig" also accepts a leading BOM,
    # which json.load would otherwise reject. Don't rely on the OS default.
    with open(file_path, "r", encoding="utf-8-sig") as f:
        data = json.load(f)

    if isinstance(data, dict):
        # For a dictionary, the headers are its keys
        return list(data.keys())

    if isinstance(data, list):
        # A dict used as an ordered set: keeps first-seen order, drops duplicates
        seen = {}
        for item in data:
            if isinstance(item, dict):
                for key in item:
                    seen.setdefault(key, None)
        return list(seen)

    return []


def recursive_json_schema_extractor(directory):
    """
    Recursively walks through the given directory and extracts headers
    from all JSON files found.

    Entries are visited in sorted order: ``os.listdir`` returns them in an
    arbitrary, filesystem-dependent order, and this mapping's order becomes
    the order of files in the LLM prompt, so sorting keeps runs reproducible.

    Parameters:
        directory (str): The root directory to start the recursive search.

    Returns:
        dict: A dictionary mapping JSON file paths (``directory`` joined with
        the relative path) to their header lists. Files that cannot be read
        or parsed are reported on stdout and skipped.

    Raises:
        OSError: If ``directory`` or one of its subdirectories cannot be listed.
    """
    schemas = {}
    for entry in sorted(os.listdir(directory)):
        full_path = os.path.join(directory, entry)
        if os.path.isdir(full_path):
            # Recursively process subdirectories
            schemas.update(recursive_json_schema_extractor(full_path))
        elif entry.lower().endswith(".json"):
            try:
                schemas[full_path] = extract_headers(full_path)
            except Exception as e:
                # Best effort: one bad file should not abort the whole scan
                print(f"Error processing file {full_path}: {e}")
    return schemas


# Scan the data directory once and keep every JSON file's headers
data_dir = "./Data"
json_schemas = recursive_json_schema_extractor(data_dir)

# Show what was found: one section per file, headers indented underneath
for file_path, headers in json_schemas.items():
    header_block = "\n".join(f"  {header}" for header in headers)
    print(f"\nFile: {file_path}\nHeaders:")
    if header_block:
        print(header_block)
    print("-" * 40)

## Visualization Creator

### Select the file and columns to create the visualization

In [None]:
import os
from pydantic import BaseModel


# Pydantic model for chart information
class ChartInfo(BaseModel):
    """Chart selection (file + axis columns) built by parse_llm_response from the LLM reply."""

    # JSON file name as it appears in the prompt (base name only, no directory)
    file_name: str
    # Column to plot on the x-axis
    x_axis: str
    # Column to plot on the y-axis
    y_axis: str


# Format json_schemas into a readable string for the prompt (without type info)
def format_schemas_for_prompt(schemas):
    """
    Render discovered JSON schemas as prompt text, one file per line.

    Each line reads ``<file name>: [<header>, <header>, ...]`` and ends with a
    newline. Only the file's base name is shown; its directory is dropped.

    Parameters:
        schemas (dict): Mapping of JSON file path -> list of header names.

    Returns:
        str: The formatted listing ("" for an empty mapping).
    """
    return "".join(
        f"{os.path.basename(path)}: [{', '.join(columns)}]\n"
        for path, columns in schemas.items()
    )


# Prompt for the LLM: list every discovered file with its headers, state the
# chart request, and demand three "key: value" lines for parse_llm_response.
template = "\n".join(
    [
        "Given the following JSON file headers, determine the most appropriate file and columns to create the visualization.",
        "",
        "Available JSON files and their headers:",
        "{json_headers}",
        "",
        "Chart request: {query}",
        "",
        "Respond only with the following information in this exact format:",
        "file_name: [selected json file name]",
        "x_axis: [column name for x-axis]",
        "y_axis: [column name for y-axis]",
    ]
)

prompt = ChatPromptTemplate.from_template(template)
# Local Ollama model that picks the file and columns
model = OllamaLLM(model="llama3.2")


# Function to parse LLM output into ChartInfo
def parse_llm_response(response: str) -> ChartInfo:
    """
    Parse the model's ``key: value`` reply into a ChartInfo.

    Only lines whose key names a ChartInfo field (file_name, x_axis, y_axis)
    are used, so blank lines, chatty preambles ("Here is the information:")
    and markdown decoration no longer make parsing crash. Each line is split
    on its first ":" only, so values may themselves contain colons. Placeholder
    brackets, quotes, backticks and asterisks the model may copy from the
    template are stripped from values. If a key repeats, the last one wins.

    Parameters:
        response (str): Raw text returned by the LLM.

    Returns:
        ChartInfo: The selected file name and axis columns.

    Raises:
        ValueError: If any of the three fields is missing or empty.
    """
    wanted = ("file_name", "x_axis", "y_axis")
    parsed = {}
    for line in response.splitlines():
        raw_key, sep, raw_value = line.partition(":")
        if not sep:
            continue  # not a "key: value" line
        # Normalise keys like "- **File Name**" or "x-axis" to field names
        key = (
            raw_key.strip()
            .strip("*-` ")
            .lower()
            .replace(" ", "_")
            .replace("-", "_")
        )
        value = raw_value.strip().strip("[]\"'`*").strip()
        if key in wanted and value:
            parsed[key] = value

    missing = [field for field in wanted if field not in parsed]
    if missing:
        raise ValueError(
            f"LLM response is missing {missing}; raw response: {response!r}"
        )
    return ChartInfo(**parsed)


# Pipe the prompt into the model, and the model's text into the parser
chain = prompt | model | parse_llm_response

# Example request: ask the LLM to choose a file and axes for a revenue chart
formatted_headers = format_schemas_for_prompt(json_schemas)
chain_inputs = {
    "json_headers": formatted_headers,
    "query": "Show me the revenue trends over time",
}
result = chain.invoke(chain_inputs)

for label, value in (
    ("Selected file:", result.file_name),
    ("X-axis:", result.x_axis),
    ("Y-axis:", result.y_axis),
):
    print(label, value)

### Load the selected file into a DataFrame

In [68]:
# The prompt shows only bare file names, while json_schemas is keyed by the
# full paths that were scanned (files may sit in subdirectories of the data
# directory). Map the LLM's choice back to a scanned path instead of assuming
# ./Data/<name>, and fail clearly if the model named a file that doesn't exist.
# NOTE(review): files with the same name in different subdirectories look
# identical in the prompt; the first scanned match is used.
selected_paths = [
    path for path in json_schemas if os.path.basename(path) == result.file_name
]
if not selected_paths:
    raise FileNotFoundError(
        f"LLM selected {result.file_name!r}, which is not one of the scanned files: "
        f"{sorted(os.path.basename(path) for path in json_schemas)}"
    )
df = pd.read_json(selected_paths[0])

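### Code for Chart Generator

`generate_graph` chooses a chart type from the data types of the two selected columns. It:

1. Uses the pandas dtype-checking functions imported in the setup cell (`is_datetime64_any_dtype`, `is_numeric_dtype`, `is_object_dtype`)
2. Handles different data type combinations:
    - Time series data, i.e. datetime or string-based time columns (line plots)
    - Numeric vs numeric (scatter plots)
    - Categorical vs numeric (box plots)
    - Numeric vs categorical (bar plots)
    - Categorical vs categorical (heatmaps)
    - Any other combination raises a `ValueError`
3. Includes automatic handling of:
    - Label rotation for better readability
    - Layout adjustments to prevent cutoff
    - Proper sorting for time series data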


In [69]:
def _maybe_parse_datetime(series: pd.Series) -> pd.Series:
    """
    Convert a text column of date strings (e.g. "2023-01-31") to datetime64.

    The column is returned unchanged unless it is object/string typed, every
    non-null value parses as a date, and the values are not plain numbers
    (numeric codes such as "1", "2" would otherwise parse as dates).
    """
    import warnings

    if not (is_object_dtype(series) or isinstance(series.dtype, pd.StringDtype)):
        return series
    non_null = series.dropna()
    if non_null.empty:
        return series
    try:
        with warnings.catch_warnings():
            # pandas warns when it must fall back to per-element format inference
            warnings.simplefilter("ignore", UserWarning)
            if pd.to_numeric(non_null, errors="coerce").notna().all():
                return series
            parsed = pd.to_datetime(series, errors="coerce")
    except (TypeError, ValueError, OverflowError):
        # e.g. nested dicts/lists in the column, or mixed time zones
        return series
    # Accept the conversion only if it yields datetimes and loses no values
    if not is_datetime64_any_dtype(parsed) or parsed[series.notna()].isna().any():
        return series
    return parsed


def _is_numeric(series: pd.Series) -> bool:
    """True for numeric columns; booleans are excluded and treated as categories."""
    return is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(series)


def _is_categorical(series: pd.Series) -> bool:
    """True for columns drawn as discrete categories (object, category, string, bool)."""
    return (
        is_object_dtype(series)
        or isinstance(series.dtype, (pd.CategoricalDtype, pd.StringDtype))
        or pd.api.types.is_bool_dtype(series)
    )


def generate_graph(df: pd.DataFrame, x_col: str, y_col: str) -> None:
    """
    Generate an appropriate visualization based on the data types of input columns.

    Chart chosen for (x type, y type):
        datetime or date strings, numeric -> line plot, sorted by x
        numeric, numeric                  -> scatter plot
        categorical, numeric              -> box plot
        numeric, categorical              -> bar plot (categories on the y-axis)
        categorical, categorical          -> heatmap of the contingency table

    Parameters:
        df (pd.DataFrame): Input dataframe; it is not modified.
        x_col (str): Column name for x-axis
        y_col (str): Column name for y-axis

    Raises:
        KeyError: If either column is not in ``df``.
        ValueError: If the combination of column types is not supported.
    """
    # Column names come from an LLM, so report invented names clearly
    missing = [col for col in (x_col, y_col) if col not in df.columns]
    if missing:
        raise KeyError(
            f"Column(s) {missing} not found in DataFrame; available: {list(df.columns)}"
        )

    # Work on a copy of just the two columns so date parsing never touches df
    data = df.loc[:, list(dict.fromkeys((x_col, y_col)))].copy()
    data[x_col] = _maybe_parse_datetime(data[x_col])
    x, y = data[x_col], data[y_col]

    # Determine data types
    x_is_datetime = is_datetime64_any_dtype(x)
    x_is_numeric = _is_numeric(x)
    x_is_categorical = _is_categorical(x)
    y_is_numeric = _is_numeric(y)
    y_is_categorical = _is_categorical(y)

    # Style must be set before the axes are created for it to apply
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        if x_is_datetime and y_is_numeric:
            # Time series: sort explicitly so the line follows chronological order
            sns.lineplot(data=data.sort_values(x_col), x=x_col, y=y_col, ax=ax)
            ax.tick_params(axis="x", labelrotation=90)
        elif x_is_numeric and y_is_numeric:
            sns.scatterplot(data=data, x=x_col, y=y_col, ax=ax)
        elif x_is_categorical and y_is_numeric:
            sns.boxplot(data=data, x=x_col, y=y_col, ax=ax)
            ax.tick_params(axis="x", labelrotation=90)
        elif y_is_categorical and x_is_numeric:
            # Bar plot with the categories on the y-axis
            sns.barplot(data=data, x=x_col, y=y_col, ax=ax)
            ax.tick_params(axis="x", labelrotation=90)
        elif x_is_categorical and y_is_categorical:
            # Count co-occurrences of each (x, y) category pair
            contingency = pd.crosstab(data[x_col], data[y_col])
            sns.heatmap(contingency, annot=True, fmt="d", cmap="YlOrRd", ax=ax)
            ax.tick_params(axis="x", labelrotation=90)
        else:
            raise ValueError(
                "Unsupported combination of data types: "
                f"{x_col!r} ({x.dtype}) vs {y_col!r} ({y.dtype})"
            )

        # Add labels and adjust layout
        ax.set(xlabel=x_col, ylabel=y_col, title=f"{y_col} vs {x_col}")
        fig.tight_layout()  # Prevent label cutoff
    except Exception:
        # Don't leave an empty, half-drawn figure open when plotting fails
        plt.close(fig)
        raise

    # Show the plot
    plt.show()

### Plot Chart

In [None]:
# Draw the chart for the file and columns the LLM selected
generate_graph(df, x_col=result.x_axis, y_col=result.y_axis)