In [None]:
from openai import OpenAI
from google.cloud import bigquery
import os
from dotenv import load_dotenv
import json
import pandas as pd
from pprint import pprint

load_dotenv()

# All configuration comes from environment variables; load_dotenv() populates
# os.environ from a local .env file, if one exists.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
FEEDBACK_PROJECT_ID = os.getenv("FEEDBACK_PROJECT_ID")
PUBLISHING_PROJECT_ID = os.getenv("PUBLISHING_PROJECT_ID")
FEEDBACK_TABLE = os.getenv("FEEDBACK_TABLE")
PUBLISHING_TABLE = os.getenv("PUBLISHING_TABLE")
LABELLED_FEEDBACK_TABLE = os.getenv("LABELLED_FEEDBACK_TABLE")
# NOTE: the Python name differs from the environment variable it is read from.
OPENAI_LABEL_FEEDBACK_TABLE = os.getenv("OPENAI_LABELLED_FEEDBACK_TABLE")

# Fail fast on missing configuration. Table names are interpolated into SQL text,
# so an unset variable would otherwise surface much later as a query against a
# table literally called "None".
_missing_env_vars = [
    env_name
    for env_name, env_value in [
        ("OPENAI_API_KEY", OPENAI_API_KEY),
        ("FEEDBACK_PROJECT_ID", FEEDBACK_PROJECT_ID),
        ("PUBLISHING_PROJECT_ID", PUBLISHING_PROJECT_ID),
        ("PUBLISHING_TABLE", PUBLISHING_TABLE),
        ("LABELLED_FEEDBACK_TABLE", LABELLED_FEEDBACK_TABLE),
        ("OPENAI_LABELLED_FEEDBACK_TABLE", OPENAI_LABEL_FEEDBACK_TABLE),
    ]
    if not env_value
]
if _missing_env_vars:
    raise RuntimeError(
        "Missing required environment variables (set them or add them to .env): "
        + ", ".join(_missing_env_vars)
    )

In [None]:
def get_feedback_by_record_id(N: int) -> list:
    """
    Extract a random sample of feedback records from BigQuery.

    Rows of PUBLISHING_TABLE created on or after 2024-01-01 are grouped by
    feedback_record_id, their response_value strings are concatenated (space
    separated, ordered by `created`), the groups are shuffled with RAND() and
    the first N are returned.

    :param N: number of feedback records to return (non-negative int)
    :return: list of dicts, one per record, with keys feedback_record_id,
        concatenated_response_value and r (the random sort key)
    :raises ValueError: if N is not a non-negative int, or PUBLISHING_TABLE is unset
    """
    # N is interpolated straight into the SQL text, so only accept a genuine int
    # (bool is an int subclass and is rejected explicitly).
    if isinstance(N, bool) or not isinstance(N, int) or N < 0:
        raise ValueError(f"N must be a non-negative int, got {N!r}")
    if not PUBLISHING_TABLE:
        raise ValueError("PUBLISHING_TABLE is not set; check your environment / .env file")

    client = bigquery.Client(project=PUBLISHING_PROJECT_ID, location="europe-west2")
    query = f"""
        SELECT
          feedback_record_id,
          STRING_AGG(response_value, ' ' ORDER BY created) AS concatenated_response_value,
          rand() AS r
        FROM
          {PUBLISHING_TABLE}
        WHERE
          DATE(created) >= "2024-01-01"
        GROUP BY
          feedback_record_id
        ORDER BY
          r
        LIMIT {N}
    """
    query_job = client.query(query=query)

    # Each BigQuery Row is converted to a plain dict for downstream JSON handling.
    return [dict(row) for row in query_job.result()]

In [None]:
# Draw a small random sample of unlabelled feedback for OpenAI to label
records = get_feedback_by_record_id(N=5)
type(records)

In [None]:
# Display the sampled records so they can be checked for PII before anything is sent to OpenAI.
records

### Check the sampled feedback for PII — if any is present, regenerate the sample before continuing.

In [None]:
def get_labelled_feedback_sample(N: int) -> list:
    """
    Extract a random sample of already-labelled feedback records from BigQuery.

    All rows of LABELLED_FEEDBACK_TABLE are shuffled with RAND() and the first N
    are returned.

    :param N: number of records to return (non-negative int)
    :return: list of dicts, one per row, containing every column of
        LABELLED_FEEDBACK_TABLE plus r (the random sort key)
    :raises ValueError: if N is not a non-negative int, or LABELLED_FEEDBACK_TABLE is unset
    """
    # N is interpolated straight into the SQL text, so only accept a genuine int
    # (bool is an int subclass and is rejected explicitly).
    if isinstance(N, bool) or not isinstance(N, int) or N < 0:
        raise ValueError(f"N must be a non-negative int, got {N!r}")
    if not LABELLED_FEEDBACK_TABLE:
        raise ValueError("LABELLED_FEEDBACK_TABLE is not set; check your environment / .env file")

    client = bigquery.Client(project=FEEDBACK_PROJECT_ID, location="europe-west2")
    query = f"""
      SELECT
        *,
        rand() AS r
      FROM
        {LABELLED_FEEDBACK_TABLE}
      ORDER BY
        r
      LIMIT {N}
    """
    query_job = client.query(query=query)

    # Each BigQuery Row is converted to a plain dict for downstream JSON handling.
    return [dict(row) for row in query_job.result()]

In [None]:
# Sample already-labelled feedback to serve as few-shot examples in the prompt
labelled_records = get_labelled_feedback_sample(N=10)

In [None]:
# Display the labelled sample that will be serialised into the prompt as few-shot examples.
labelled_records

### NOTE: the labelled feedback has already been reviewed for PII, and any offending records have been excluded.

In [None]:
# Serialise feedback records to JSON for injection into the prompt
def jsonify_feedback(records: list, labelled: bool = False) -> str:
    """
    Create a JSON string from feedback records.

    :param records: list of dicts with at least "feedback_record_id" and
        "concatenated_response_value"; when labelled is True each dict must
        also have a "labels" entry
    :param labelled: if True, include each record's existing labels; otherwise
        emit an empty label list for the model to fill in
    :return: JSON array (indent=4) of {"id", "feedback", "label"} objects,
        where "label" is always a flat list
    """

    def _as_label_list(labels) -> list:
        # Labels may already be a list (e.g. a BigQuery REPEATED column);
        # wrapping that again would produce a nested [[...]] in the prompt.
        if labels is None:
            return []
        if isinstance(labels, (list, tuple)):
            return list(labels)
        return [labels]

    subs = [
        {
            "id": item["feedback_record_id"],
            "feedback": item["concatenated_response_value"],
            "label": _as_label_list(item["labels"]) if labelled else [],
        }
        for item in records
    ]
    return json.dumps(subs, indent=4)

In [None]:
# Few-shot examples carry their labels; the new records are serialised without them.
labelled_subs_json = jsonify_feedback(records=labelled_records, labelled=True)
new_subs_json = jsonify_feedback(records=records)

In [None]:
# Inspect the serialised new records that will be injected into the labelling prompt.
print(new_subs_json)

In [None]:
# System prompt for the labelling request. The output shape requested here
# ({"results": [{"id": ..., "label": [...]}]}) is what the parsing cell reads back.
# Doubled braces are literal braces inside the f-string.
label_prompt = f"""
    You are an expert at providing consistent categorisation of user feedback left on the UK government website www.gov.uk.
    You are given user feedback records, each with an arbitrary id, and you must provide a label or set of labels for each record to categorise it.
    In the rare event there is no coherent theme within a feedback record, label it as "Unknown".
    If the feedback is clearly spam, label it as "Spam".
    Always return valid JSON in exactly this shape, with one entry per input record:
    {{"results": [{{"id": "<record id>", "label": ["<label>", "..."]}}]}}

    Labelled examples:
    {labelled_subs_json}

    Label the following data. Only return the id and the labels, do not return the feedback.

    {new_subs_json}
"""

In [None]:
# Review the complete prompt, including the feedback text, before it is sent to OpenAI.
print(label_prompt)

In [None]:
# Call OpenAI to generate labels for feedback records
client = OpenAI(api_key=OPENAI_API_KEY)

# Initialise up front so later cells see None, not a NameError or a stale value
# from an earlier run, if the request below fails.
open_labelled_records = None
try:
    completion = client.chat.completions.create(
        messages=[{"role": "system", "content": label_prompt}],  # type: ignore
        max_tokens=250,
        temperature=0.75,
        model="gpt-3.5-turbo-0125",
        # JSON mode: asks the API for a JSON object (the prompt must mention JSON).
        response_format={"type": "json_object"},
    )
    choice = completion.choices[0]
    # finish_reason == "length" means max_tokens cut the reply off, so the JSON
    # is probably incomplete and will fail to parse downstream.
    if choice.finish_reason == "length":
        print("Warning: response hit max_tokens and may be truncated; consider raising max_tokens.")
    open_labelled_records = choice.message.content
except Exception as e:
    # Best-effort in a notebook: report and leave open_labelled_records as None.
    print(f"OpenAI request failed: {e}")

In [None]:
# Inspect the raw reply from the labelling request before parsing it.
print(open_labelled_records)

### Write labels to BigQuery

In [None]:
def _labels_response_to_df(raw) -> pd.DataFrame:
    """
    Convert the model's JSON reply into a DataFrame matching the BigQuery load schema.

    JSON mode only yields *some* JSON object, so several shapes are accepted:
      - {"results": [{"id": ..., "label" or "labels": ...}, ...]}
      - an object with a single key whose value is such a list of objects
      - a bare list of such objects
      - otherwise, an object is treated as an {id: labels} mapping
    :param raw: JSON text returned by the model
    :return: DataFrame with exactly the columns feedback_record_id (str) and
        label (list of str)
    :raises ValueError: if there is no reply, or it lacks id/label fields
    """
    if not raw:
        raise ValueError("No OpenAI response to parse - did the labelling request fail?")

    parsed = json.loads(raw)
    if isinstance(parsed, dict):
        values = list(parsed.values())
        if "results" in parsed:
            parsed = parsed["results"]
        elif len(values) == 1 and isinstance(values[0], list) and all(isinstance(v, dict) for v in values[0]):
            parsed = values[0]
        else:
            parsed = [{"id": key, "label": value} for key, value in parsed.items()]

    df = pd.DataFrame(parsed).rename(columns={"id": "feedback_record_id", "labels": "label"})
    missing = {"feedback_record_id", "label"} - set(df.columns)
    if missing:
        raise ValueError(f"OpenAI response is missing expected fields: {sorted(missing)}")

    def to_label_list(value) -> list:
        # The destination column is REPEATED STRING, so every cell must be a list of strings.
        if isinstance(value, (list, tuple)):
            return [str(v) for v in value]
        if value is None or (isinstance(value, float) and pd.isna(value)):
            return []
        return [str(value)]

    # Keep only the schema columns; anything else the model adds is dropped.
    return df[["feedback_record_id", "label"]].assign(
        feedback_record_id=df["feedback_record_id"].astype(str),
        label=df["label"].map(to_label_list),
    )


openai_labelled___df = _labels_response_to_df(open_labelled_records)

In [None]:
def write_to_bigquery(table_id: str, df: pd.DataFrame, write_disposition: str = "WRITE_TRUNCATE"):
    """
    Load a DataFrame of feedback labels into a BigQuery table.

    :param table_id: destination table ID, e.g. "project.dataset.table"
    :param df: DataFrame with a feedback_record_id column and a label column
        holding lists of strings (loaded as a REPEATED STRING field)
    :param write_disposition: BigQuery write disposition. The default,
        "WRITE_TRUNCATE", replaces the table's existing contents; pass
        "WRITE_APPEND" to add rows instead.
    :raises ValueError: if table_id is empty or None
    """
    if not table_id:
        raise ValueError("table_id must be a non-empty BigQuery table ID")

    client = bigquery.Client(project=FEEDBACK_PROJECT_ID)

    schema = [
        bigquery.SchemaField("feedback_record_id", "STRING"),
        bigquery.SchemaField("label", "STRING", mode="REPEATED"),
    ]
    job_config = bigquery.LoadJobConfig(schema=schema, write_disposition=write_disposition)

    # Start the load job and block until it completes (raises on failure).
    job = client.load_table_from_dataframe(df, table_id, job_config=job_config)
    job.result()

    print(f"Loaded {len(df)} rows into {table_id} ({write_disposition})")

In [None]:
# Persist the OpenAI-generated labels to their BigQuery table
write_to_bigquery(df=openai_labelled___df, table_id=OPENAI_LABEL_FEEDBACK_TABLE)

### Optional - Generate synthetic data based on labelled feedback

In [None]:
# Prompt for generating synthetic feedback from the few-shot examples.
# It embeds labelled_subs_json, so the examples shown to the model include their labels.
# NOTE(review): no output format is requested here, so expect free-text output.
synth_prompt = f"""
    You are an expert at generating synthetic records based on a few short examples.
    The following examples are feedback records from the UK government website www.gov.uk.
    You are given a few examples of feedback records and asked to generate a synthetic dataset of 10 feedback records that are similar in theme and tone to the examples.

    Short examples:
    {labelled_subs_json}

    Synthetic records:
    """

# Review the prompt before sending it.
print(synth_prompt)

In [None]:
# Ask the model for synthetic feedback modelled on the labelled examples.
# The reply goes into its own variable so it cannot overwrite the real labelling
# reply (open_labelled_records) that is parsed and written to BigQuery.
synth_records = None
try:
    completion = client.chat.completions.create(
        messages=[{"role": "system", "content": synth_prompt}],  # type: ignore
        # The prompt asks for 10 records; the previous 150-token cap could not hold them.
        max_tokens=1500,
        temperature=0.75,
        model="gpt-3.5-turbo",
    )
    choice = completion.choices[0]
    # finish_reason == "length" means max_tokens cut the reply off.
    if choice.finish_reason == "length":
        print("Warning: response hit max_tokens and may be truncated; consider raising max_tokens.")
    synth_records = choice.message.content
except Exception as e:
    # Best-effort in a notebook: report and leave synth_records as None.
    print(f"OpenAI request failed: {e}")