In [None]:
from utils import *
from google.cloud import bigquery
from openai import OpenAI
import os
import plotly.graph_objects as go

In [None]:
# Runtime configuration, read from environment variables.
# The BigQuery settings are required: os.environ[...] raises KeyError when the
# cell runs if any of them is unset.
FEEDBACK_PROJECT = os.environ["FEEDBACK_PROJECT"]  # project passed to every BigQuery call below
BERT_TOPICS_TABLE = os.environ["BERT_TOPICS_TABLE"]  # per-document topic assignments (read as TOPICS_TABLE / table)
BERT_TERMS_TABLE = os.environ["BERT_TERMS_TABLE"]  # topic terms; passed to the labels query as N_GRAM_TABLE
BERT_TOT_TABLE = os.environ["BERT_TOT_TABLE"]  # Where TOT = topics over time (Topic / Timestamp / Frequency rows)
PROCESS_TABLE = os.environ["PROCESS_TABLE"]  # joined in the sentiment query as process_table
# Optional: os.getenv returns None when unset; the value is handed unchanged to
# OpenAI(api_key=...) in generate_completion.
API_KEY = os.getenv("API_KEY")

In [None]:
# Run id of the FaaS BERTopic pipeline output for our 6-month dataset; it is
# substituted into each SQL template below (as @UUID / @uuid) via replace_kwargs.
UUID = "0db212fe-06fb-11ef-b6d9-acde48001122"

In [None]:
def _format_text(text: str) -> str:
    """Remove newlines, multiple whitespace, and leading/trailing whitespace."""
    text = text.replace("\n", " ")
    text = " ".join(text.split())
    text = text.strip()
    return text


def read_sql_file(file_path: str) -> str:
    """
    Read a SQL template file and return its full text.

    Args:
        file_path: Path to the ``.sql`` file.

    Returns:
        The file contents as a string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the platform's
    # default locale encoding (e.g. non-ASCII characters in SQL comments/literals).
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


def replace_kwargs(query, **kwargs) -> str:
    """
    Substitute ``@name`` placeholders in a SQL template.

    Each keyword argument ``name=value`` replaces every occurrence of ``@name``
    in ``query`` with ``str(value)``. This is plain text substitution, not
    BigQuery query parameterisation: values are spliced in verbatim, so only
    pass trusted values (here: environment configuration and pipeline ids).

    Args:
        query: SQL text containing ``@name`` placeholders.
        **kwargs: Placeholder names mapped to their replacement values.

    Returns:
        The query with every supplied placeholder substituted.
    """
    # Replace longer names first so a short key such as ``project`` cannot
    # clobber the prefix of a longer placeholder such as ``@project_id``.
    # (sorted is stable, so equal-length keys keep their call order.)
    for key in sorted(kwargs, key=len, reverse=True):
        query = query.replace("@" + key, str(kwargs[key]))
    return query


def load_data(query: str, project: str) -> list[dict] | None:
    """
    Run ``query`` in BigQuery under ``project`` and materialise every row as a dict.

    Any failure (client creation, job submission or result fetching) is
    reported with ``print`` and signalled by returning ``None`` instead of
    raising, so callers need to check for ``None`` before iterating.

    Args:
        query: Standard-SQL query text.
        project: Project used to create the BigQuery client.

    Returns:
        A list with one ``dict`` per result row, or ``None`` on error.
    """
    try:
        client = bigquery.Client(project=project)
        rows = client.query(query).result()
        records = [dict(row) for row in rows]
    except Exception as err:
        print(f"An error occurred: {err}")
        records = None
    return records


def create_bq_client(project: str | None = None) -> bigquery.Client:
    """
    Create and return a BigQuery client.

    Args:
        project: Project for the client. ``None`` (the default) leaves it to
            the client library to infer the project from the environment.

    Returns:
        A new ``bigquery.Client``.
    """
    return bigquery.Client(project=project)


def run_query(query: str, project: str):
    """
    Run ``query`` as standard SQL in ``project`` and return the results.

    Args:
        query: Standard-SQL query text.
        project: Project the query job runs in.

    Returns:
        The row iterator from ``QueryJob.result()``; this call blocks until the
        job has finished.
    """
    # Bind the client to the target project (as load_data does). A bare
    # Client() has to infer a project from the environment, which can fail
    # even though the job below names its project explicitly.
    client = create_bq_client(project=project)
    job_config = bigquery.QueryJobConfig(use_legacy_sql=False)
    query_job = client.query(query, job_config=job_config, project=project)
    return query_job.result()


def update_topic_dict(topic_dict: dict[str, dict[str, list]], row) -> None:
    """
    Append one query row's text and term to its topic's bucket, in place.

    ``row`` must expose ``topics``, ``text_value`` and ``term`` attributes.
    The first time a topic is seen it gets a fresh
    ``{"text_value": [], "keywords": []}`` bucket.
    """
    bucket = topic_dict.setdefault(row.topics, {"text_value": [], "keywords": []})
    bucket["text_value"].append(row.text_value)
    bucket["keywords"].append(row.term)


def format_text_values(text_values: list) -> str:
    """
    Render a topic's sample texts as a newline-separated bullet list.

    Each entry is whitespace-normalised with ``_format_text`` and prefixed
    with ``"- "``; an empty input yields an empty string.
    """
    bullets = []
    for raw in text_values:
        bullets.append("- " + _format_text(raw))
    return "\n".join(bullets)


def format_keywords(keywords: list) -> str:
    """
    Format a topic's keywords as a comma-separated string for the prompt.

    Duplicates are dropped (first occurrence wins, order preserved) and
    ``None`` entries are skipped. ``update_topic_dict`` appends one term per
    query row, so if the labels query returns several rows per term, the raw
    list repeats keywords. Repeating them in the prompt only adds tokens.

    Args:
        keywords: Keyword strings, possibly with repeats or ``None``.

    Returns:
        The unique keywords joined with ``", "`` (empty string for no keywords).
    """
    unique = dict.fromkeys(kw for kw in keywords if kw is not None)
    return ", ".join(unique)


def generate_completion(
    prompt: str,
    model: str = "gpt-3.5-turbo",
    max_tokens: int = 150,
    temperature: float = 0.5,
):
    """
    Send ``prompt`` to the OpenAI Chat Completions API and return the response.

    The whole prompt is sent as a single ``system`` message, and the module
    level ``API_KEY`` is used for authentication.

    Args:
        prompt: Full prompt text.
        model: Chat model name (defaults to the notebook's original choice).
        max_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.

    Returns:
        The object returned by ``client.chat.completions.create``; the generated
        text is at ``.choices[0].message.content``.

    Raises:
        Any exception raised by the OpenAI client (network, auth, API errors).
    """
    client = OpenAI(api_key=API_KEY)
    return client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": prompt,
            }
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        model=model,
    )


def _build_label_prompt(text_values: str, keywords: str) -> str:
    """
    Build the few-shot prompt asking the model for a short title for one topic.

    Args:
        text_values: Bullet list of sample texts (see ``format_text_values``).
        keywords: Comma-separated keywords (see ``format_keywords``).
    """
    return f"""
        This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title less than ten words long.
        Instances of PII in the texts have been replaced with the type of PII surrounded by square brackets e.g. [PERSON_NAME]; these do not have any influence on the topic description.
        ---
        - I tried to renew the vehicle tax online several times today around 3.30pm, but on each occasion, after entering the 16 digit reference number, the system goes back to "START" again and again
        - not accepting the reference number
        - taxing my motor bike
        - Would not accept 16 digit reference no either by phone or website

        Keywords: tax car, try tax, tax vehicle, car tax, vehicle tax, try pay, road tax, not accept, debit card, tax try
        Topic Name: Online vehicle tax
        ---
        - My phone number no longer works and I can't go to my Universal Credit account
        - No text messages received, all other details are correct but I cannot update my journal and I won't get paid this month if it isn't done soon.
        - I am unable to sign in to my universal credit. The page is telling me that it is unavailable.
        - My name is [PERSON_NAME], I am trying to sign into my account.

        Keywords: universal credit, try sign, try log, account try, student finance, credit account, can not, sign universal, sign account, sign try
        Topic Name: Universal credit / student finance sign in
        ---
        Topic:
        Sample texts from this topic:
        {text_values}
        Keywords: {keywords}
        Topic name:
        """


def create_topic_labels(
    query: str,
    UUID: str,
    FEEDBACK_PROJECT: str,
    TOPICS_TABLE: str,
    N_GRAM_TABLE: str,
) -> dict:
    """
    Label each BERTopic topic with a short natural-language title using an LLM.

    The function proceeds as follows:

    1. Fill the ``@UUID``, ``@FEEDBACK_PROJECT``, ``@TOPICS_TABLE`` and
       ``@N_GRAM_TABLE`` placeholders in ``query``.
    2. Run the query in ``FEEDBACK_PROJECT``.
    3. Group the rows by their ``topics`` column, collecting ``text_value``
       and ``term``.
    4. For each topic, ask the chat model (via ``generate_completion``) for a
       title, given the sample texts and keywords.

    Selecting or ranking which texts represent a topic (e.g. top rows by
    probability) is presumably done by the SQL in ``query``. Nothing is
    written back to BigQuery.

    Args:
        query: SQL template, e.g. the contents of ``labels.sql``.
        UUID: Pipeline run id substituted for ``@UUID``.
        FEEDBACK_PROJECT: Project to run the query in (also ``@FEEDBACK_PROJECT``).
        TOPICS_TABLE: Substituted for ``@TOPICS_TABLE``.
        N_GRAM_TABLE: Substituted for ``@N_GRAM_TABLE``.

    Returns:
        ``{topic: {"text_value": [...], "keywords": [...], "label": str}}``.
        Every topic gets a ``"label"``. If the API call fails or returns no
        text, the label falls back to ``"Topic <topic>"``.
    """
    query = replace_kwargs(
        query=query,
        UUID=UUID,
        FEEDBACK_PROJECT=FEEDBACK_PROJECT,
        TOPICS_TABLE=TOPICS_TABLE,
        N_GRAM_TABLE=N_GRAM_TABLE,
    )

    # Group rows by topic: topic -> {"text_value": [...], "keywords": [...]}.
    topic_dict = {}
    for row in run_query(query=query, project=FEEDBACK_PROJECT):
        update_topic_dict(topic_dict, row)

    # Mutating the nested dicts (not topic_dict's keys) while iterating is safe.
    for topic, data in topic_dict.items():
        prompt = _build_label_prompt(
            text_values=format_text_values(data["text_value"]),
            keywords=format_keywords(data["keywords"]),
        )

        # The label is resolved inside the try so a failed request can never
        # leave it unbound or reuse the previous topic's response.
        try:
            completion = generate_completion(prompt)
            label = (completion.choices[0].message.content or "").strip()
        except Exception as e:
            print(f"OpenAI request failed for topic {topic}: {e}")
            label = ""

        data["label"] = label or f"Topic {topic}"

    return topic_dict

In [None]:
# Call the OpenAI API to label our BERT topics
query = read_sql_file("labels.sql")

labels = create_topic_labels(
    query=query,
    UUID=UUID,
    FEEDBACK_PROJECT=FEEDBACK_PROJECT,
    TOPICS_TABLE=BERT_TOPICS_TABLE,
    N_GRAM_TABLE=BERT_TERMS_TABLE,
)

# Display only topic -> label. `labels` also holds every raw feedback text per
# topic, which would otherwise be dumped into (and saved with) the notebook output.
{topic: info.get("label") for topic, info in labels.items()}

In [None]:
# Load topics over time data
query = read_sql_file("generic_read_table.sql")
query = replace_kwargs(
    query=query,
    project=FEEDBACK_PROJECT,
    table=BERT_TOT_TABLE,
    uuid=UUID,
)

tot_data = load_data(query, FEEDBACK_PROJECT)
# load_data prints the error and returns None instead of raising; stop with a
# clear message rather than a TypeError when iterating None below.
if tot_data is None:
    raise RuntimeError("Could not load topics-over-time data; see the error printed above.")

# plotly validates marker.symbol against a fixed set of numeric codes, so cycle
# through the first few instead of passing an ever-growing index per topic.
N_MARKER_SYMBOLS = 25

# Visualise topics over time with natural language labels as legend.
# First-seen order (rather than set order) keeps trace order/colours stable across runs.
topics = list(dict.fromkeys(item["Topic"] for item in tot_data))
plot_data = {topic: {"date": [], "value": []} for topic in topics}

for item in tot_data:
    topic = item["Topic"]
    timestamp = item["Timestamp"]
    frequency = item["Frequency"]

    plot_data[topic]["date"].append(timestamp)
    plot_data[topic]["value"].append(frequency)

# Sort each topic's points chronologically so the lines are drawn left to right
for topic in plot_data:
    sorted_data = sorted(zip(plot_data[topic]["date"], plot_data[topic]["value"]))
    sorted_dates, sorted_values = zip(*sorted_data)
    plot_data[topic]["date"] = sorted_dates
    plot_data[topic]["value"] = sorted_values

# Create figure
fig = go.Figure()

# Adding traces
for i, (topic, vals) in enumerate(plot_data.items()):
    # Topics absent from `labels` (or without a label) get a generic legend
    # entry instead of raising KeyError.
    label = labels.get(topic, {}).get("label", f"Topic {topic}")
    fig.add_trace(
        go.Scatter(
            x=vals["date"],
            y=vals["value"],
            mode="lines+markers",
            name=label,
            marker=dict(symbol=i % N_MARKER_SYMBOLS),
            line_shape="spline",
        )
    )

# Layout settings
fig.update_layout(
    # title="Topics Over Time Analysis",
    xaxis_title="Date",
    yaxis_title="Frequency",
    legend_title="Topic",
    legend=dict(orientation="h", x=0.0, y=-0.15),
    font=dict(family="Helvetica", size=12, color="black"),
    plot_bgcolor="white",
    xaxis=dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickmode="auto",
        nticks=15,
        tickformat="%b %d",
    ),
    yaxis=dict(showline=True, showgrid=True, linecolor="black"),
    margin=dict(l=20, r=20, t=40, b=80),
)

# Show the figure
fig.show()

In [None]:
# Visualise the counts of topics
query = read_sql_file("generic_read_table.sql")
query = replace_kwargs(
    query=query,
    project=FEEDBACK_PROJECT,
    table=BERT_TOPICS_TABLE,
    uuid=UUID,
)

results = load_data(query, FEEDBACK_PROJECT)
# load_data returns None on failure instead of raising.
if results is None:
    raise RuntimeError("Could not load topic assignments; see the error printed above.")

# Number of rows assigned to each topic
topic_counts = {}
for row in results:
    topic_counts[row["topics"]] = topic_counts.get(row["topics"], 0) + 1

# Ascending by count: a horizontal bar chart draws its first category at the
# bottom, so the largest topics end up at the top.
sorted_counts = sorted(topic_counts.items(), key=lambda kv: kv[1])
# Topics absent from `labels` (or without a label) fall back to a generic name.
topics = [labels.get(topic, {}).get("label", f"Topic {topic}") for topic, _ in sorted_counts]
counts = [count for _, count in sorted_counts]

# Create the bar chart (a single trace, so no legend name is needed)
fig = go.Figure(
    data=go.Bar(
        y=topics,
        x=counts,
        orientation="h",
    )
)

# Horizontal bars: counts run along x, topic labels along y.
fig.update_layout(
    title="Topic Counts",
    xaxis_title="Frequency",
    yaxis_title="Topic",
    font=dict(family="Helvetica", size=12, color="black"),
    plot_bgcolor="white",
    xaxis=dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickmode="auto",
        nticks=15,
    ),
    yaxis=dict(showline=True, showgrid=True, linecolor="black"),
    margin=dict(l=20, r=20, t=40, b=80),
    # Grow with the number of topics so every label gets room (~25px per bar is
    # a guess; tune if labels still overlap).
    height=max(1000, 25 * len(topics)),
)

# Show the chart
fig.show()

# TODO: confirm every y-axis label is visible with the dynamic height above.

In [None]:
# Visualise sentiment over time
query = read_sql_file("sentiment.sql")
query = replace_kwargs(
    query=query,
    project=FEEDBACK_PROJECT,
    topics_table=BERT_TOPICS_TABLE,
    process_table=PROCESS_TABLE,
    uuid=UUID,
)

results = load_data(query, FEEDBACK_PROJECT)
# load_data returns None on failure instead of raising.
if results is None:
    raise RuntimeError("Could not load sentiment data; see the error printed above.")

# {date: {sentiment: count}}; assumes `created` is a datetime (it has .date()).
sentiment_counts = {}
for record in results:
    per_day = sentiment_counts.setdefault(record["created"].date(), {})
    per_day[record["sentiment"]] = per_day.get(record["sentiment"], 0) + 1

dates = sorted(sentiment_counts)
positive_counts = [sentiment_counts[date].get("POSITIVE", 0) for date in dates]
negative_counts = [sentiment_counts[date].get("NEGATIVE", 0) for date in dates]
neutral_counts = [sentiment_counts[date].get("NEUTRAL", 0) for date in dates]

# One line per sentiment, added in the same order as before (legend order).
fig = go.Figure()
for trace_name, series_counts in (
    ("Positive", positive_counts),
    ("Negative", negative_counts),
    ("Neutral", neutral_counts),
):
    fig.add_trace(
        go.Scatter(
            x=dates,
            y=series_counts,
            mode="lines+markers",
            name=trace_name,
            line_shape="spline",
        )
    )

fig.update_layout(
    title="Daily count of Positive, Neutral, and Negative comments",
    xaxis_title="Date",
    yaxis_title="Count",
    legend=dict(orientation="h", x=0.0, y=-0.15),
    font=dict(family="Helvetica", size=12, color="black"),
    plot_bgcolor="white",
    xaxis=dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickmode="auto",
        nticks=15,
        tickformat="%b %d",
    ),
    yaxis=dict(showline=True, showgrid=True, linecolor="black"),
    margin=dict(l=20, r=20, t=40, b=80),
)

fig.show()