# Research with Gemini

We're going to be testing Google's Gemini API. 

Credentials are located in config.yaml in the ai-research root folder.

## Setup

In [1]:
# Imports
from chiefai.ai import make_gemini_request
from chiefai.db import query
import polars as pl

# Notebook formatting
from IPython.display import display, HTML, Markdown

In [2]:
data = query("SELECT * FROM web_visit_results LIMIT 5")
data.head(3)

id,core_client,project_id,user_id,session_id,unique_key,postal_code,region,dma,dma_code,city,country,browser,device,device_type,search_engine,medium,source,platform,platform_version,bounce,session_date_time,ip_address,pages,search_terms,language,latitude,longitude,organization,referrer,session_length,created_at,updated_at,core_product,last_access,content,session
i64,str,i64,str,str,str,str,str,str,i64,str,str,str,str,str,str,str,str,str,str,str,datetime[μs],str,list[struct[2]],str,str,f64,f64,str,str,str,datetime[μs],datetime[μs],str,datetime[μs],null,null
99735578,"""OPOS""",82884616,"""50182511673882""","""3288761085059530754""","""828846165018251167388232887610…",,,,0,,"""USA""","""Unknown""",,"""Desktop/laptop""",,"""paidsocial""","""snapchat""","""Unknown""","""Unknown""","""true""",2024-01-03 13:11:13,"""34.123.204.87""","[{""collections/vaginal-health"",0}]",,"""Unknown""",37.751,-97.822,"""Google""","""Direct""","""0""",2024-01-03 13:13:00.337214,2024-01-03 13:13:00.337784,"""""",2024-01-03 13:11:13,,
99735579,"""OPOS""",82884616,"""78135251450464""","""5120671839057608705""","""828846167813525145046451206718…","""32163""","""Florida""","""Orlando, FL""",534,"""The Villages""","""USA""","""Safari""","""Apple iPhone""","""Mobile""",,"""paidsocial""","""ig""","""iOS""","""iOS 17.1""","""true""",2024-01-03 13:11:13,"""68.205.39.5""","[{""collections/vaginal-health"",0}]",,"""English (United States)""",28.9265,-81.9928,"""Spectrum""","""instagram.com""","""0""",2024-01-03 13:13:00.337214,2024-01-03 13:13:00.337784,"""""",2024-01-03 13:11:13,,
99735580,"""OPOS""",82884616,"""103390995781948""","""6775832299565744285""","""828846161033909957819486775832…",,,,0,,"""USA""","""Chrome""","""Generic Android""","""Mobile""",,,"""Direct""","""Android""","""Android 9.0""","""true""",2024-01-03 13:11:13,"""147.160.184.123""","[{""tools/recurring/login"",0}]",,"""English (United States)""",37.751,-97.822,"""Unknown""","""Direct""","""0""",2024-01-03 13:13:00.337214,2024-01-03 13:13:00.337784,"""""",2024-01-03 13:11:13,,


## Assess AI Approach to Data Analysis

We're going to use the make_gemini_request function in chiefmedai.ai to send data. 

The approach we're going to try is to create summary statistics from the database and programmatically generate a text query to feed to the AI model. We'll use prompt injection of the data to do this. 

Firstly we'll try getting simple counts of orders by platform and ask the AI model to assess the results. 


In [3]:
query_str = """
WITH results AS (
    select 
        p.unique_key,
        COALESCE(s.station_name, p.station) AS station, 
        p.date,
        CASE 
            WHEN EXTRACT(ISODOW FROM p.date) = 1 THEN 'Monday'
            WHEN EXTRACT(ISODOW FROM p.date) = 2 THEN 'Tuesday'
            WHEN EXTRACT(ISODOW FROM p.date) = 3 THEN 'Wednesday'
            WHEN EXTRACT(ISODOW FROM p.date) = 4 THEN 'Thursday'
            WHEN EXTRACT(ISODOW FROM p.date) = 5 THEN 'Friday'
            WHEN EXTRACT(ISODOW FROM p.date) = 6 THEN 'Saturday'
            ELSE 'Sunday'
        END AS weekday,
        cbd.cdaypart AS daypart, 
        p.length, 
        p.buyrate AS spend, 
        lam.online_visits, 
        lam.online_orders, 
        lam.online_revenue, 
        lam.online_leads, 
        lam.impressions AS target_demo_impressions, 
        lam.impressions2 AS total_impressions
    from core_post_time p
    left outer join core_tape_details td ON td.tapecd = p.tape
        AND td.cmedia = p.media
        AND td.cclient = p.client
        AND td.cproduct = p.product
        AND td.startdate <= p.bcdate
        AND td.enddate >= p.bcdate
    left outer join core_tape_parent tpp ON tpp.ctpparent = td.ctpparent
        AND tpp.cclient = p.client
        AND tpp.cproduct = p.product
    left outer join linear_attribution_metrics lam ON lam.unique_key = p.unique_key
    join core_estimate est ON est.cmedia = p.media
        AND est.cclient = p.client
        AND est.cproduct = p.product
        AND est.cestimate = p.estimate
    join core_buy_detail cbd ON cbd.nbuydetid = p.buydetid
    join core_buy_table cbt ON cbt.nbuyid = p.buyid
    join stations s ON s.core_label = p.station
    where p.client = 'OPOS'
        AND p.media != 'OL'
        AND p.date >= NOW() - INTERVAL '4 weeks'
)

SELECT 
    station,
    weekday,
    daypart,
    COUNT(unique_key) AS spot_count,
    SUM(spend) AS total_spend,
    ROUND(AVG(spend)) AS average_spend,
    SUM(online_visits) AS total_visits,
    ROUND(AVG(online_visits)) AS average_visits,
    SUM(online_leads) AS total_leads,
    ROUND(AVG(online_leads)) AS average_leads,
    SUM(online_orders) AS total_orders,
    ROUND(AVG(online_orders)) AS average_orders,
    ROUND(SUM(online_revenue)::NUMERIC, 2) AS total_revenue,
    ROUND(AVG(online_revenue)::NUMERIC, 2) AS average_order_value,
    SUM(target_demo_impressions) AS total_target_demo_impressions,
    ROUND(AVG(target_demo_impressions)) AS average_target_demo_impressions,
    CASE 
        WHEN SUM(target_demo_impressions) = 0 THEN 0 
        ELSE ROUND(SUM(spend) / SUM(target_demo_impressions), 2)
    END AS cpm_target_demo,
    SUM(total_impressions) AS total_all_demo_impressions,
    ROUND(AVG(total_impressions)) AS average_all_demo_impressions,
    CASE 
        WHEN SUM(total_impressions) = 0 THEN 0 
        ELSE ROUND(SUM(spend) / SUM(total_impressions), 2)
    END AS cpm_all_demo
FROM results
GROUP BY     
    station,
    weekday,
    daypart
"""

In [4]:
data = query(query_str)
print(data.head(3))

shape: (3, 20)
┌────────────┬─────────┬─────────┬────────────┬───┬────────────┬───────────┬───────────┬───────────┐
│ station    ┆ weekday ┆ daypart ┆ spot_count ┆ … ┆ cpm_target ┆ total_all ┆ average_a ┆ cpm_all_d │
│ ---        ┆ ---     ┆ ---     ┆ ---        ┆   ┆ _demo      ┆ _demo_imp ┆ ll_demo_i ┆ emo       │
│ str        ┆ str     ┆ str     ┆ i64        ┆   ┆ ---        ┆ ressions  ┆ mpression ┆ ---       │
│            ┆         ┆         ┆            ┆   ┆ decimal[*, ┆ ---       ┆ s         ┆ decimal[* │
│            ┆         ┆         ┆            ┆   ┆ 2]         ┆ i64       ┆ ---       ┆ ,2]       │
│            ┆         ┆         ┆            ┆   ┆            ┆           ┆ decimal[* ┆           │
│            ┆         ┆         ┆            ┆   ┆            ┆           ┆ ,0]       ┆           │
╞════════════╪═════════╪═════════╪════════════╪═══╪════════════╪═══════════╪═══════════╪═══════════╡
│ A&E        ┆ Friday  ┆ DA      ┆ 1          ┆ … ┆ null       ┆ null      ┆

### Descriptive Prompt

In [5]:
def generate_descriptive_prompt(
    df, 
    group_cols, 
    currency_symbol="$", 
    add_comparisons=True,
    decimal_places=2
):
    """
    Generate a descriptive text prompt from a dataframe with grouped statistics.
    
    Parameters:
    -----------
    df : polars.DataFrame or pandas.DataFrame
        The dataframe containing the grouped statistics
    group_cols : list or str
        Column name(s) that represent categorical groupings
    currency_symbol : str, optional
        Symbol to use for currency formatting (default: "$")
    add_comparisons : bool, optional
        Whether to add comparisons between stats of the same measure (default: True)
    decimal_places : int, optional
        Number of decimal places to round numeric values (default: 2)
        
    Returns:
    --------
    str
        Formatted descriptive text suitable for an AI prompt
    """
    # Convert group_cols to list if it's a string
    if isinstance(group_cols, str):
        group_cols = [group_cols]
    
    # Validate inputs
    for col in group_cols:
        if col not in df.columns:
            raise ValueError(f"Group column '{col}' not found in dataframe")
    
    # Infer stat columns (all columns that aren't group columns)
    stat_cols = [col for col in df.columns if col not in group_cols]
    
    # Check if all stat columns are numeric
    for col in stat_cols:
        # Check column type for both polars and pandas
        try:
            # For polars
            if hasattr(df, 'dtypes'):
                col_type = df[col].dtype
                is_numeric = any(num_type in str(col_type) for num_type in ['Int', 'Float', 'int', 'float'])
            # For pandas
            else:
                is_numeric = df[col].dtype.kind in 'ifc'
                
            if not is_numeric:
                raise ValueError(f"Stat column '{col}' must be numeric, but got type {df[col].dtype}")
        except Exception as e:
            raise ValueError(f"Error checking if column '{col}' is numeric: {str(e)}")
    
    # Determine if we're using polars or pandas
    is_polars = hasattr(df, 'iter_rows')
    
    # Get rows for iteration
    if is_polars:
        rows = list(df.iter_rows(named=True))
    else:  # pandas
        rows = df.to_dict('records')
    
    # Group stat columns by their measure
    measures = {}
    for col in stat_cols:
        # Try to split the column name into statistic type and measure
        parts = col.split('_', 1)
        if len(parts) >= 2:
            stat_type, measure = parts[0], parts[1]
            if measure not in measures:
                measures[measure] = {}
            measures[measure][stat_type] = col
    
    # Start building the prompt
    prompt_text = ""
    
    # Format a value based on its measure
    def format_value(value, measure):
        if value is None:
            return "N/A"
        
        value = round(value, decimal_places) if isinstance(value, (int, float)) else value
        
        # Format based on measure name
        if any(term in measure for term in ['revenue', 'price', 'cost', 'income', 'expense']):
            return f"{currency_symbol}{value:.{decimal_places}f}" if isinstance(value, float) else f"{currency_symbol}{value}"
        elif any(term in measure for term in ['percent', 'rate', 'ratio']):
            return f"{value:.{decimal_places}f}%" if isinstance(value, float) else f"{value}%"
        else:
            return f"{value:.{decimal_places}f}" if isinstance(value, float) else f"{value}"
    
    for row in rows:
        # Create the group description
        group_desc = ""
        for col in group_cols:
            label = col.replace('_', ' ').title()
            group_desc += f"{label}: {row[col]}\n"
        
        # Add the statistics descriptions
        stats_desc = ""
        
        # Process each measure and its stats
        for measure, stats in measures.items():
            measure_label = measure.replace('_', ' ')
            
            # Process each statistic for this measure
            for stat_type, col in stats.items():
                value = row[col]
                formatted_value = format_value(value, measure)
                
                # Create description based on statistic type
                if stat_type == 'count':
                    stats_desc += f"There were {formatted_value} {measure_label}.\n"
                elif stat_type == 'mean':
                    stats_desc += f"The average (mean) {measure_label} was {formatted_value}.\n"
                elif stat_type == 'median':
                    stats_desc += f"The median {measure_label} was {formatted_value}.\n"
                elif stat_type == 'sum' or stat_type == 'total':
                    stats_desc += f"The total {measure_label} was {formatted_value}.\n"
                elif stat_type == 'min':
                    stats_desc += f"The minimum {measure_label} was {formatted_value}.\n"
                elif stat_type == 'max':
                    stats_desc += f"The maximum {measure_label} was {formatted_value}.\n"
                else:
                    stats_desc += f"The {stat_type} of {measure_label} was {formatted_value}.\n"
            
            # Add comparisons if requested and if we have both mean and median
            if add_comparisons and 'mean' in stats and 'median' in stats:
                mean_value = row[stats['mean']]
                median_value = row[stats['median']]
                
                if mean_value is not None and median_value is not None:
                    if mean_value > median_value:
                        difference = round(mean_value - median_value, decimal_places)
                        formatted_diff = format_value(difference, measure)
                        stats_desc += f"The mean is {formatted_diff} higher than the median, suggesting some high-value outlier {measure_label}.\n"
                    elif median_value > mean_value:
                        difference = round(median_value - mean_value, decimal_places)
                        formatted_diff = format_value(difference, measure)
                        stats_desc += f"The median is {formatted_diff} higher than the mean, suggesting some low-value outlier {measure_label}.\n"
                    else:
                        stats_desc += f"The mean and median are identical, suggesting a symmetrical distribution of {measure_label}.\n"
        
        # Combine group and stats descriptions
        section = group_desc + stats_desc + ("-" * 40 + "\n")
        prompt_text += section
    
    # Remove the final separator
    prompt_text = prompt_text.rstrip("-" * 40 + "\n")
    
    # Add a summary
    groups_count = len(rows)
    if len(group_cols) == 1:
        group_label = group_cols[0].replace('_', ' ').lower()
        prompt_text += f"\n\nSummary: Analyzed statistics for {groups_count} different {group_label}s."
    else:
        group_labels = [col.replace('_', ' ').lower() for col in group_cols]
        group_desc = ", ".join(group_labels)
        prompt_text += f"\n\nSummary: Analyzed statistics for {groups_count} different groups based on {group_desc}."
    
    return prompt_text

### Tabular Prompt

In [6]:
import polars as pl

def generate_table_prompt(
    df,
    title="Data Analysis",
    description=None,
    format_currency=True,
    currency_cols=None, 
    currency_symbol="$",
    decimal_places=2,
    analysis_request=None
):
    """
    Generate an AI prompt with a markdown table from a Polars dataframe.
    
    Parameters:
    -----------
    df : polars.DataFrame
        The Polars dataframe to convert to a table
    title : str, optional
        Title to use in the prompt (default: "Data Analysis")
    description : str, optional
        Custom description of the data (default: auto-generated based on columns)
    format_currency : bool, optional
        Whether to automatically format columns that appear to be currency (default: True)
    currency_cols : list, optional
        Specific columns to format as currency (default: None, auto-detect)
    currency_symbol : str, optional
        Symbol to use for currency formatting (default: "$")
    decimal_places : int, optional
        Number of decimal places for numeric values (default: 2)
    analysis_request : str or list, optional
        Custom analysis instructions (default: standard analysis request)
        
    Returns:
    --------
    str
        A prompt with the data formatted as a markdown table
    """
    # Create a copy of the dataframe to avoid modifying the original
    formatted_df = df.clone()
    
    # Format currency columns if requested
    if format_currency:
        # Auto-detect currency columns if not specified
        if currency_cols is None:
            currency_cols = []
            for col in formatted_df.columns:
                col_lower = str(col).lower()
                if any(term in col_lower for term in ['revenue', 'price', 'cost', 'value', 'amount', 'income', 'expense']):
                    currency_cols.append(col)
        
        # Format the currency columns
        for col in currency_cols:
            if col in formatted_df.columns:
                # Check if column is numeric
                col_type = formatted_df[col].dtype
                is_numeric = any(num_type in str(col_type) for num_type in ['Int', 'Float', 'int', 'float'])
                
                if is_numeric:
                    # Format currency values with specified symbol and decimal places
                    formatted_df = formatted_df.with_columns(
                        pl.col(col).map_elements(
                            lambda x: f"{currency_symbol}{x:.{decimal_places}f}" if x is not None else "N/A",
                            return_dtype=pl.Utf8
                        ).alias(col)
                    )
    
    # Format other numeric columns to have consistent decimal places
    for col in formatted_df.columns:
        if col not in currency_cols:
            col_type = formatted_df[col].dtype
            is_numeric = any(num_type in str(col_type) for num_type in ['Int', 'Float', 'int', 'float'])
            
            if is_numeric:
                formatted_df = formatted_df.with_columns(
                    pl.col(col).map_elements(
                        lambda x: f"{x:.{decimal_places}f}" if isinstance(x, float) else str(x),
                        return_dtype=pl.Utf8
                    ).alias(col)
                )
    
    # Generate the markdown table
    # Format the header
    header = "| " + " | ".join(formatted_df.columns) + " |"
    
    # Format the separator
    separator = "| " + " | ".join(["---" for _ in formatted_df.columns]) + " |"
    
    # Format each row
    rows = []
    for row in formatted_df.iter_rows(named=True):
        formatted_row = "| " + " | ".join([str(row[col]) for col in formatted_df.columns]) + " |"
        rows.append(formatted_row)
    
    # Combine all parts
    markdown_table = "\n".join([header, separator] + rows)
    
    # Auto-generate description if not provided
    if description is None:
        # Try to identify what the data is about
        col_list = list(formatted_df.columns)
        
        # Extract key column types
        group_cols = [col for col in col_list if col not in (currency_cols or [])]
        
        if group_cols:
            description = f"I'm sharing data grouped by {', '.join(group_cols)}."
            if currency_cols:
                description += f" The table includes {', '.join(currency_cols)} values."
        else:
            description = "I'm sharing the following data for analysis."
    
    # Add a paragraph with some "expert" context to make the returned insights more helpful. 
    expert_context = [
        "Target demo and all demo impressions should not be related to each other, as target demo is simply a subset of all demos",
        "While the overnight (ON) is cheap, it does not scale well. Be cautious in suggesting moving dollars to the overnight",
        "Be hyper-vigilant to sample size. We do not want to make optimizations using only a few instances of a data point",
        "Avoid providing generic advice such as 'move budget to stations with low CPMs.' Be descriptive in your suggestions and make refernce to specifics"
    ]
    expert_context = "Here is some expert context to help you:\n" + "\n".join([f"{i+1}. {item}" for i, item in enumerate(expert_context)])

    # Default analysis request if not provided
    if analysis_request is None:
        analysis_request = [
            "Key insights and patterns in the data",
            "Notable outliers or anomalies",
            "Interpretations of the trends observed",
            "Strategic recommendations based on this data"
        ]
    
    # Format the analysis request
    if isinstance(analysis_request, list):
        formatted_request = "Please provide:\n" + "\n".join([f"{i+1}. {item}" for i, item in enumerate(analysis_request)])
    else:
        formatted_request = analysis_request
    
    # Construct the full prompt
    prompt = f"""# {title}

{description}

{markdown_table}

{expert_context}

{formatted_request}
"""
    
    return prompt

In [7]:
prompt = generate_table_prompt(data)
#display(Markdown(prompt))

In [8]:
response = make_gemini_request(request_text=prompt)
display(Markdown(response))

Error calling Gemini API: 503 UNAVAILABLE. {'error': {'code': 503, 'message': 'The model is overloaded. Please try again later.', 'status': 'UNAVAILABLE'}}


ServerError: 503 UNAVAILABLE. {'error': {'code': 503, 'message': 'The model is overloaded. Please try again later.', 'status': 'UNAVAILABLE'}}

## Hybrid Prompt

In [None]:
from chiefai.ai import analyze_campaign_data
import polars as pl

# Run analysis
results = analyze_campaign_data(
    data,
    temperature=0.3,
    min_sample_size=5,
    confidence_level=0.95
)