In [None]:
from utils import *
from google.cloud import bigquery
from openai import OpenAI
import os
import plotly.graph_objects as go

In [None]:
# Required settings: os.environ[...] raises KeyError straight away if a variable is missing.
# FEEDBACK_PROJECT is the BigQuery project queries are run against, and
# GSDMM_TOPICS_TABLE is substituted for the @topics_table placeholder in the SQL templates.
FEEDBACK_PROJECT = os.environ["FEEDBACK_PROJECT"]
GSDMM_TOPICS_TABLE = os.environ["GSDMM_TOPICS_TABLE"]
# NOTE(review): PROCESS_TABLE is not referenced in the visible cells — confirm it is still needed.
PROCESS_TABLE = os.environ["PROCESS_TABLE"]
# Read leniently: API_KEY is None when the variable is unset. It is handed to
# OpenAI(api_key=...) in generate_completion.
API_KEY = os.getenv("API_KEY")

In [None]:
# UUID of the pipeline run analysed in this notebook: the output of the FaaS subtopic
# pipeline run on our 6-month FaaS topic pipeline outputs.
UUID = "dd526f3c-efbc-11ee-a5ac-4200a9fe0102"

In [None]:
def _format_text(text: str) -> str:
    """Remove newlines, multiple whitespace, and leading/trailing whitespace."""
    text = text.replace("\n", " ")
    text = " ".join(text.split())
    text = text.strip()
    return text


def read_sql_file(file_path: str) -> str:
    """
    Read a SQL file and return its full contents as a string.

    Args:
        file_path: path of the SQL file to read.

    Returns:
        The text of the file, unmodified.

    Raises:
        OSError: if the file cannot be opened (for example, it does not exist).
    """
    # Decode explicitly as UTF-8 so the query text does not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


def replace_kwargs(query, **kwargs) -> str:
    """
    Fill ``@name`` placeholders in a query template.

    Every keyword argument ``name=value`` replaces all occurrences of ``@name``
    in ``query`` with ``str(value)``.

    Args:
        query: template text containing ``@name`` placeholders.
        **kwargs: placeholder names mapped to their replacement values.

    Returns:
        The query with every supplied placeholder substituted. Placeholders
        with no matching keyword argument are left untouched.
    """
    # NOTE: values are spliced into the text verbatim, with no quoting or
    # escaping, so only pass trusted values (configuration, constants).
    #
    # Longest names first: otherwise "@project" would also rewrite the start of
    # "@project_id" before the longer placeholder gets its turn.
    for key in sorted(kwargs, key=len, reverse=True):
        # str() lets callers pass non-string values such as integers.
        query = query.replace("@" + key, str(kwargs[key]))
    return query


def load_data(query: str, project: str) -> list[dict] | None:
    try:
        client = bigquery.Client(project=project)
        query_job = client.query(query)
        result = [dict(row) for row in query_job.result()]
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    return result


def create_bq_client() -> bigquery.Client:
    """Construct a BigQuery client with no explicit arguments and hand it back."""
    client = bigquery.Client()
    return client


def run_query(query: str, project: str):
    """
    Submit ``query`` under ``project`` and return the finished job's result.

    The client comes from ``create_bq_client`` and the job is configured with
    ``use_legacy_sql`` switched off.
    """
    bq = create_bq_client()
    config = bigquery.QueryJobConfig()
    config.use_legacy_sql = False
    return bq.query(query, job_config=config, project=project).result()


# Analyse GSDMM subtopics if they exist
def check_uuid_exists(uuid):
    """
    Report whether ``uuid`` appears among the rows returned by ``check_uuid.sql``.

    Args:
        uuid: identifier of the pipeline run to look for.

    Returns:
        True if any returned row has this value in its "uuid" field, else False.

    Raises:
        RuntimeError: if the lookup query could not be run (``load_data``
            returned None), so existence cannot be determined.
    """
    query = replace_kwargs(
        query=read_sql_file("check_uuid.sql"),
        project=FEEDBACK_PROJECT,
        topics_table=GSDMM_TOPICS_TABLE,
    )
    results = load_data(query, FEEDBACK_PROJECT)
    if results is None:
        # load_data reports a failed query as None; iterating it would only
        # produce an opaque "NoneType is not iterable" error further down.
        raise RuntimeError(
            "Could not check whether the UUID exists: the check_uuid.sql query failed"
        )
    return any(row["uuid"] == uuid for row in results)


def get_subtopics(uuid: str):
    """
    Load the rows produced by ``get_subtopics.sql`` for one pipeline run.

    Returns:
        Whatever ``load_data`` returns for the filled-in query, or None when
        ``check_uuid_exists`` reports that the UUID is not present.
    """
    if not check_uuid_exists(uuid):
        return None

    sql = replace_kwargs(
        query=read_sql_file("get_subtopics.sql"),
        project=FEEDBACK_PROJECT,
        topics_table=GSDMM_TOPICS_TABLE,
        UUID=uuid,
    )
    return load_data(sql, FEEDBACK_PROJECT)


def extract_bigrams(uuid):
    """
    Run ``extract_bigrams.sql`` for the given pipeline run and return the rows.

    Per the original author's note, the query is expected to yield one instance
    of the terms_list array per topic (not verifiable from this notebook).

    Returns:
        The value returned by ``load_data``: a list of row dicts, or None on failure.
    """
    sql = replace_kwargs(
        query=read_sql_file("extract_bigrams.sql"),
        project=FEEDBACK_PROJECT,
        topics_table=GSDMM_TOPICS_TABLE,
        UUID=uuid,
    )
    return load_data(sql, FEEDBACK_PROJECT)


# Format results to return a dictionary of the form {topic: {subtopic: [bigram1, bigram2, ...]}}
def format_bigrams(results):
    """
    Nest bigram rows as {BERT_topic: {GSDMM_topic: unique_terms}}.

    When the same (BERT_topic, GSDMM_topic) pair occurs more than once the last
    row wins. Afterwards any subtopic named "Other" has its terms replaced by
    ["Other"], because bigrams for the catch-all subtopic are not useful.
    """
    nested = {}
    for row in results:
        parent, child, terms = (
            row["BERT_topic"],
            row["GSDMM_topic"],
            row["unique_terms"],
        )
        nested.setdefault(parent, {})[child] = terms

    for children in nested.values():
        if "Other" in children:
            children["Other"] = ["Other"]

    return nested


# Extract results and visualise
def extract_top_docs(uuid):
    """
    Run ``extract_top_docs.sql`` for the given pipeline run and return the rows.

    Returns:
        The value returned by ``load_data``: a list of row dicts, or None on failure.
    """
    sql = replace_kwargs(
        query=read_sql_file("extract_top_docs.sql"),
        project=FEEDBACK_PROJECT,
        topics_table=GSDMM_TOPICS_TABLE,
        UUID=uuid,
    )
    return load_data(sql, FEEDBACK_PROJECT)


# Format results to return a dictionary of the form {BERT_topic: {GSDMM_topic: [sentence1, sentence2, ...]}}
def format_top_docs(results):
    """
    Group sentences as {BERT_topic: {GSDMM_topic: [sentence, ...]}}.

    Sentences keep the order in which their rows appear in ``results``.
    """
    grouped = {}
    for row in results:
        parent, child, sentence = (
            row["BERT_topic"],
            row["GSDMM_topic"],
            row["sentence"],
        )
        # Create the nested dict / list on first sight, then append.
        grouped.setdefault(parent, {}).setdefault(child, []).append(sentence)
    return grouped


def format_docs_and_grams(formatted_bigrams, formatted_docs):
    """
    Merge per-subtopic bigrams and documents into a single nested dict.

    Args:
        formatted_bigrams: {bert_topic: {subtopic: [bigram, ...]}}.
        formatted_docs: {bert_topic: {subtopic: [doc, ...]}}.

    Returns:
        {bert_topic: {subtopic: {"bigrams": [bigram, ...], "docs": [doc, ...]}}}.
        Subtopics are taken from ``formatted_bigrams``; one with no documents
        gets an empty "docs" list. If the two inputs do not cover the same
        BERT topics, a message is printed and an empty dict is returned.
    """
    if formatted_bigrams.keys() != formatted_docs.keys():
        print("Topics are not equivalent in docs and bigrams")
        return {}

    topics_dict = {}
    for bert_topic, subtopic_bigrams in formatted_bigrams.items():
        subtopic_docs = formatted_docs[bert_topic]
        topics_dict[bert_topic] = {
            subtopic: {
                "bigrams": bigrams,
                # Only the top-level keys were compared above, so a subtopic may
                # be absent from the docs; default to no documents instead of
                # failing with KeyError.
                "docs": subtopic_docs.get(subtopic, []),
            }
            for subtopic, bigrams in subtopic_bigrams.items()
        }
    return topics_dict


def generate_completion(
    prompt: str,
    model: str = "gpt-3.5-turbo",
    max_tokens: int = 150,
    temperature: float = 0.5,
):
    """
    Request a chat completion from the OpenAI API.

    Args:
        prompt: text sent as the single "system" message of the conversation.
        model: name of the chat model to use.
        max_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature passed to the API.

    Returns:
        The completion object returned by ``client.chat.completions.create``.
    """
    # API_KEY comes from the notebook configuration and may be None when the
    # environment variable is unset.
    client = OpenAI(api_key=API_KEY)
    return client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": prompt,
            }
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        model=model,
    )


def create_subtopic_labels(topics_dict: dict, bert_topic: int):
    if bert_topic in topic_cache.keys():
        return

    labels = {}
    for key, value in topics_dict[bert_topic].items():
        topic_names = key

        topic_values = value
        for topic_name in [topic_names]:
            bigrams = topic_values["bigrams"]
            docs = topic_values["docs"]
            prompt_bigrams = ", ".join(x for x in bigrams)
            prompt_docs = "\n".join([f"- {_format_text(text)}" for text in docs])

            # # Create a prompt
            prompt = f"""
            This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title less than ten words long.
            Instances of PII in the texts have been replaced with the type of PII surrounded by square brackets e.g. [PERSON_NAME]; these do not have any influence on the topic description.
            ---
            - I tried to renew the vehicle tax online several times today around 3.30pm, but on each occasion, after entering the 16 digit reference number, the system goes back to "START" again and again
            - not accepting the reference number
            - taxing my motor bike
            - Would not accept 16 digit reference no either by phone or website

            Keywords: tax car, try tax, tax vehicle, car tax, vehicle tax, try pay, road tax, not accept, debit card, tax try
            Topic Name: Online vechile tax
            ---
            - My phone number no longer works and I can't go to my Universal Credit account
            - No text messages received, all other details are correct but I cannot update my journal and I won't get paid this month if it isn't done soon.
            - I am unable to sign in to my universal credit. The page is telling me that it is unavailable.
            - My name is [PERSON_NAME], I am trying to sign into my account.

            Keywords: universal credit, try sign, try log, account try, student finance, credit account, can not, sign universal, sign account, sign try
            Topic Name: Universal credit / student finance sign in
            ---
            Topic:
            Sample texts from this topic:
            {prompt_docs}
            Keywords: {prompt_bigrams}
            Topic name:
            """

            # Check if the topic is in the cache
            if topic_name not in topic_cache.keys():
                try:
                    completion = generate_completion(prompt)
                    topic_cache[topic_name] = completion.choices[0].message.content
                except Exception as e:
                    print(f"OpenAI request failed: {e}")

            # Add the topic and label to the dictionary
            labels[topic_name] = topic_cache[topic_name]
            # Add topic_name: labels to the topic_cache
            topic_cache[bert_topic] = labels


def count_subtopics(gibbs_results: list[dict], bert_topic: int, topic_cache: dict):
    filter_relevant_data = [x for x in gibbs_results if x["BERT_topic"] == bert_topic]
    gibbs_topics = set(x["GSDMM_topic"] for x in filter_relevant_data)
    gibbs_topic_counts = {topic: 0 for topic in gibbs_topics}
    for record in filter_relevant_data:
        gibbs_topic_counts[record["GSDMM_topic"]] += 1
    return gibbs_topic_counts


def create_gibbs_barchart(
    gibbs_results: list[dict], bert_topic: int, topic_cache: dict
):
    """
    For a given bert topic, create a bar chart of subtopic counts and subtopic labels so you can see what's going on.

    Args:
        gibbs_results: list[dict] = list of dictionaries containing the gibbs results queried from BQ
        bert_topic: int = bert topic of interest that we want to visualise
    """

    gibbs_topic_counts = count_subtopics(
        gibbs_results=gibbs_results, bert_topic=bert_topic, topic_cache=topic_cache
    )

    # Define ordered list: Will always be 10 unless we change the subtopic pipeline
    desired_order = [
        "Other",
        "Topic 0",
        "Topic 1",
        "Topic 2",
        "Topic 3",
        "Topic 4",
        "Topic 5",
        "Topic 6",
        "Topic 7",
        "Topic 8",
        "Topic 9",
    ]

    topics = [x for x in desired_order if x in gibbs_topic_counts]
    counts = [gibbs_topic_counts[x] for x in topics]
    try:
        topics = [topic_cache[key] for key in desired_order if key in topic_cache]
    except Exception as e:
        print(f"Error: {e}, most likely topic_cache doesn't exist")

    # Create the bar chart
    fig = go.Figure(data=go.Bar(x=topics, y=counts))

    # Set the chart title and axis labels
    fig.update_layout(
        title=f"Subtopics for BERT topic: {bert_topic}",
        xaxis_title="Subtopic",
        yaxis_title="Count",
        legend=dict(orientation="h", x=0.0, y=-0.15),
        font=dict(family="Helvetica", size=12, color="black"),
        plot_bgcolor="white",
        xaxis=dict(
            showline=True,
            showgrid=True,
            linecolor="black",
            tickmode="auto",
            nticks=15,
            tickformat="%b %d",
        ),
        yaxis=dict(showline=True, showgrid=True, linecolor="black"),
        margin=dict(l=20, r=20, t=40, b=80),
    )

    # Show the chart
    fig.show()

In [None]:
# Query BigQuery for the chosen pipeline run: subtopic rows, bigrams and top documents.
gibbs_results = get_subtopics(UUID)
raw_bigrams = extract_bigrams(UUID)
# Reshape the raw rows into nested dicts keyed by BERT topic, then by GSDMM subtopic.
formatted_bigrams = format_bigrams(raw_bigrams)
raw_top_docs = extract_top_docs(UUID)
formatted_docs = format_top_docs(raw_top_docs)
# Combine bigrams and documents per subtopic (empty if the BERT topics differ between the two).
topics_dict = format_docs_and_grams(formatted_bigrams, formatted_docs)
# Seed the label cache so the catch-all "Other" subtopic is never sent to OpenAI.
# create_subtopic_labels reads and updates this notebook-level variable.
topic_cache = {"Other": "Other"}
create_subtopic_labels(
    topics_dict, 0
)  # Change number here to label a different BERT topic <- if already labelled it won't call OpenAI again though!
print(topic_cache)  # To see our labels

In [None]:
# Plot subtopic counts for BERT topic 0; change the number to chart a different BERT topic.
create_gibbs_barchart(
    gibbs_results, 0, topic_cache
)  # Visualise a subtopic, if it exists in topic cache it will use the labels instead of generic topic names