In [None]:
# !pip install torch sentence-transformers python-dotenv google-cloud-bigquery -q

In [None]:
from sentence_transformers import SentenceTransformer
from google.cloud import bigquery
import os
from dotenv import load_dotenv
from typing import List

import torch
import time
import pandas as pd
import numpy as np

load_dotenv()  # load variables from a local .env file (if present) into os.environ

# GCP project / table identifiers; each is None when the variable is unset.
PUBLISHING_PROJECT_ID = os.environ.get("PUBLISHING_PROJECT_ID")
FEEDBACK_PROJECT_ID = os.environ.get("FEEDBACK_PROJECT_ID")
FEEDBACK_TABLE = os.environ.get("FEEDBACK_TABLE")

# Explicitly enable Hugging Face tokenizers parallelism. This is mostly set to
# suppress the tokenizers warning, and may also speed up tokenization. If it
# causes issues with multiprocessing, set it to "false" instead.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
# Report whether a CUDA GPU is available and, if so, basic properties of device 0.
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print("cpu")
else:
    print("__CUDNN VERSION:", torch.backends.cudnn.version())
    print("__Number CUDA Devices:", torch.cuda.device_count())
    print("__CUDA Device Name:", torch.cuda.get_device_name(0))
    device0_total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print("__CUDA Device Total Memory [GB]:", device0_total_memory_gb)

In [None]:
def get_feedback_records(date: str) -> pd.DataFrame:
    """
    Extract free-text, non-spam feedback from BigQuery, one row per feedback record.

    All matching text responses of a feedback record are joined with '. ' (in
    ``created`` order), and any resulting '..' is collapsed to '.'. Records are
    returned ordered by their earliest matching ``created`` timestamp.

    :param date: inclusive lower bound on ``DATE(created)``, as a 'YYYY-MM-DD' string
    :return: DataFrame with columns ``feedback_record_id`` and ``response_value``;
        rows whose ``response_value`` is empty after stripping whitespace are dropped
    """
    client = bigquery.Client(project=FEEDBACK_PROJECT_ID)
    # Raw string so the regex `\.\.` reaches BigQuery verbatim, without Python
    # "invalid escape sequence" warnings.
    # `created` is selected in the CTE so the outer query can order by it.
    # NOTE(review): in LIKE, `_` matches any single character, so e.g.
    # 'what_doing' also matches 'whatXdoing'; `=` may have been intended.
    query = r"""
      WITH CTE AS (
        SELECT
          feedback_record_id,
          response_value,
          created
        FROM
          @feedback_table
        WHERE
          DATE(created) >= @start_date
          AND response_type = 'text'
          AND response_value IS NOT NULL
          AND spam_classification = 'not spam'
          AND ( prompt_value LIKE '%Please do not include personal or financial information, eg your National Insurance number or credit card details.%'
            OR prompt_value LIKE 'what_doing'
            OR prompt_value LIKE 'what_wrong'
            OR prompt_value LIKE 'description'
            OR prompt_value LIKE 'details' )
          AND TYPE NOT IN ('ServiceFeedback',
            'AggregatedServiceFeedback')
      )
      SELECT
        feedback_record_id,
        REGEXP_REPLACE(STRING_AGG(response_value, '. ' ORDER BY created), r'\.\.', '.') AS response_value
      FROM
        CTE
      GROUP BY
        feedback_record_id
      ORDER BY
        MIN(created)
    """
    # Table identifiers cannot be query parameters, so the table name (read from
    # the environment) is substituted textually. The caller-supplied date is bound
    # as a typed DATE parameter, so it cannot alter the SQL text.
    query = query.replace("@feedback_table", str(FEEDBACK_TABLE))
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("start_date", "DATE", date)]
    )
    df = client.query(query, job_config=job_config).result().to_dataframe()
    non_blank = df["response_value"].str.strip().str.len() > 0
    filtered_df = df.loc[non_blank]
    print(f"Num records in df: {len(filtered_df)} with cols {filtered_df.columns}")
    return filtered_df

In [None]:
# Models already loaded in this kernel, keyed by model name. get_embeddings is
# called once per chunk by split_and_embed_docs, so caching avoids re-loading
# the model weights for every chunk.
_SENTENCE_MODEL_CACHE: dict = {}


def _load_sentence_model(model_name: str) -> "SentenceTransformer":
    """Return a SentenceTransformer for ``model_name``, loading it at most once per kernel."""
    model = _SENTENCE_MODEL_CACHE.get(model_name)
    if model is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SentenceTransformer(model_name, device=device)
        _SENTENCE_MODEL_CACHE[model_name] = model
    return model


def get_embeddings(
    data: pd.DataFrame,
    default_model="sentence-transformers/all-mpnet-base-v2",
) -> pd.DataFrame:
    """
    Embed the ``response_value`` text of each row with a SentenceTransformer model.

    The model runs on CUDA when available, otherwise on CPU, and is cached after
    the first call for a given model name.

    :param data: DataFrame with ``feedback_record_id`` and ``response_value`` columns
    :param default_model: name of the SentenceTransformer model to use
    :return: DataFrame with columns ``feedback_record_id`` and ``embeddings``, one
        row per input row in input order; each ``embeddings`` value is that row's
        entry of the ``model.encode`` output
    """
    model = _load_sentence_model(default_model)
    response_value_list = data["response_value"].tolist()
    feedback_record_id_list = data["feedback_record_id"].tolist()
    print("Generating embeddings...")
    embeddings = model.encode(response_value_list)
    print(f"Embeddings len {len(embeddings)}")
    output_df = pd.DataFrame(
        list(zip(feedback_record_id_list, embeddings)),
        columns=["feedback_record_id", "embeddings"],
    )
    print("Generated embeddings")
    return output_df

In [None]:
def save_embeddings_to_bigquery(data: pd.DataFrame, output_prefix: str) -> None:
    """
    Append a DataFrame of feedback text and embeddings to a BigQuery table.

    The destination is ``{PUBLISHING_PROJECT_ID}.embeddings.{output_prefix}_feedback_embeddings``.
    The load uses WRITE_APPEND, so this can be called once per chunk of feedback
    and every chunk accumulates in the same table.

    :param data: DataFrame to load, expected to hold ``feedback_record_id``,
        ``response_value`` and ``embeddings`` columns (matching the load schema)
    :param output_prefix: prefix used to build the destination table name
    :return: None; waits for the load job via ``job.result()``
    """
    destination = f"{PUBLISHING_PROJECT_ID}.embeddings.{output_prefix}_feedback_embeddings"
    print(destination)
    bq_client = bigquery.Client(project=PUBLISHING_PROJECT_ID)
    embedding_schema = [
        bigquery.SchemaField("feedback_record_id", "STRING"),
        bigquery.SchemaField("response_value", "STRING"),
        bigquery.SchemaField("embeddings", "FLOAT", mode="REPEATED"),
    ]
    load_config = bigquery.LoadJobConfig(
        schema=embedding_schema,
        write_disposition="WRITE_APPEND",
    )
    load_job = bq_client.load_table_from_dataframe(data, destination, job_config=load_config)
    load_job.result()

In [None]:
def check_null_responses(df: pd.DataFrame) -> object:
    """Return a Series with the count of null (NaN/None) values in each column of ``df``."""
    return df.isna().sum()


def check_df_len(df: pd.DataFrame) -> int:
    """Return the number of rows in ``df``."""
    row_count, _ = df.shape
    return row_count


def check_len_embeddings(df: pd.DataFrame) -> float:
    """
    Return the total length of all vectors in the ``embeddings`` column.

    This is the sum of ``shape[0]`` over every embedding (rows x dimension for
    1-D vectors), not the number of embedding rows.
    """
    return sum(vector.shape[0] for vector in df["embeddings"])

In [None]:
def split_and_embed_docs(
    data: pd.DataFrame,
    N: int,
    text_column: str,
    output_prefix: str = "re_concat_mpnetv2",
) -> None:
    """
    Split ``data`` into ``N`` roughly equal row chunks, embed each and append it to BigQuery.

    Each chunk is embedded with ``get_embeddings``, left-joined back onto its rows
    by ``feedback_record_id`` and written with ``save_embeddings_to_bigquery``.
    Chunks are processed sequentially. If a chunk fails, the chunks before it have
    already been appended, and the exception is re-raised so the failure is visible.

    :param data: DataFrame with at least ``feedback_record_id`` and ``response_value`` columns
    :param N: number of chunks to split ``data`` into (must be >= 1)
    :param text_column: column containing the text to embed. NOTE: currently unused;
        ``get_embeddings`` always reads ``response_value``.
    :param output_prefix: prefix of the destination BigQuery table
    :return: None
    :raises ValueError: if ``N`` < 1
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    # Split row *positions*, not the DataFrame itself: np.array_split on a
    # DataFrame goes through DataFrame.swapaxes, which recent pandas deprecates.
    for i, positions in enumerate(np.array_split(np.arange(len(data)), N)):
        chunk = data.iloc[positions]
        if chunk.empty:
            # Happens when N > len(data); nothing to embed or upload.
            continue
        try:
            start = time.time()
            print("Embedding data...")
            result = get_embeddings(data=chunk)
            print(f"Number of null responses: {check_null_responses(result)}")
            print(f"Number of responses: {check_df_len(result)}")
            print(f"Number of embeddings: {check_len_embeddings(result)}")
            output_chunk = pd.merge(chunk, result, on="feedback_record_id", how="left")
            print(f"Created embeddings for chunk {i}")
            save_embeddings_to_bigquery(data=output_chunk, output_prefix=output_prefix)
            print("Embeddings saved to BQ")
            end = time.time()
            print(f"Time taken: {end - start}")
            # Only the shape is printed: rows hold raw user feedback text (possible PII).
            print(f"Output chunk shape: {output_chunk.shape}")
        except Exception as e:
            print(f"Error: {e}")
            print(f"Failed on chunk {i} ({len(chunk)} rows)")
            raise

In [None]:
# Smoke-test BigQuery access: fetch the `embeddings` dataset in the publishing
# project, then print its description, labels and tables.
client = bigquery.Client(project=PUBLISHING_PROJECT_ID)
dataset = client.get_dataset("embeddings")  # API request

full_dataset_id = "{}.{}".format(dataset.project, dataset.dataset_id)
friendly_name = dataset.friendly_name
print("Got dataset '{}' with friendly_name '{}'.".format(full_dataset_id, friendly_name))

# Dataset properties.
print("Description: {}".format(dataset.description))
print("Labels:")
labels = dataset.labels
if not labels:
    print("\tDataset has no labels defined.")
else:
    for label, value in labels.items():
        print("\t{}: {}".format(label, value))

# Tables in the dataset.
print("Tables:")
tables = list(client.list_tables(dataset))  # API request(s)
if not tables:
    print("\tThis dataset does not contain any tables.")
else:
    for table in tables:
        print("\t{}".format(table.table_id))

In [None]:
# Print the active gcloud CLI configuration (e.g. account/project) to help debug BigQuery auth or project issues.
! gcloud config list

In [None]:
# Embed all text feedback created on/after this date, uploading it in N_CHUNKS pieces.
FEEDBACK_START_DATE = "2023-08-01"
N_CHUNKS = 20

df = get_feedback_records(FEEDBACK_START_DATE)
split_and_embed_docs(df, N=N_CHUNKS, text_column="response_value")