## Setup

In [1]:
%%capture
%load_ext kedro.ipython

In [2]:
import os

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from alive_progress import alive_bar
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from transformers.models.bert import BertTokenizerFast

In [3]:
# Parameters
# The Hugging Face model id is built as f"{OWNER}/{MODEL_NAME}" when the model is loaded.
MODEL_NAME: str = "multi-qa-mpnet-base-cos-v1"
OWNER: str = "sentence-transformers"
# How chunk embeddings are pooled into one vector per text: "mean" or "max".
POOLING_STRATEGY: str = "mean"
# Only rows whose `pr_name` equals this contributor are kept.
CONTRIBUTOR: str = "Health Promotion Board"

# Content category to process; set to "all" to run across every category.
CONTENT_CATEGORY: str = "live-healthy-articles"

# NOTE(review): THRESHOLD is not referenced in the visible cells — presumably a
# similarity cut-off used downstream; confirm its purpose or remove it.
THRESHOLD: float = 0.7

In [23]:
# Ground-truth workbook location (relative to the notebook directory).
INPUT_GROUNDTRUTH_PATH = os.path.join(
    "..",
    "data",
    "01_raw",
    "Synapxe Content Prioritisation - Live Healthy_020724.xlsx",
)

# Main output folder: one subfolder per content category.
OUTPUT_FOLDER_PATH = os.path.join(
    "..",
    "data",
    "07_model_output",
    f"{CONTENT_CATEGORY}",
)

# exist_ok=True makes the cell idempotent and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(OUTPUT_FOLDER_PATH, exist_ok=True)

# Embeddings output (full dataframe with embedding columns)
OUTPUT_EMBEDDING_PATH = os.path.join(
    OUTPUT_FOLDER_PATH,
    f"{CONTENT_CATEGORY}_{MODEL_NAME}_embeddings.parquet",
)

# Embeddings output (reduced column set, pickled for neo4j clustering)
OUTPUT_EMBEDDING_NEO4J_PATH = os.path.join(
    OUTPUT_FOLDER_PATH,
    f"{CONTENT_CATEGORY}_{MODEL_NAME}_embeddings_neo4j.pkl",
)

In [5]:
# Load the "merged_data" dataset through the Kedro data catalog.
# `catalog` is not defined in this notebook — presumably it is injected by the
# kedro.ipython extension loaded in the first cell (hence the noqa); confirm.
df = catalog.load("merged_data")  # noqa
print(df.shape)
df.head(2)

(2613, 33)


Unnamed: 0,id,content_name,title,article_category_names,cover_image_url,full_url,full_url2,friendly_url,category_description,content_body,...,percentage_total_views,cumulative_percentage_total_views,content_category,to_remove,has_table,has_image,related_sections,extracted_links,extracted_headers,extracted_content_body
0,1435040,Breast Screening Subsidies in Singapore,Breast Screening Subsidies in Singapore,"Conditions and Illnesses,",https://ch-api.healthhub.sg/api/public/content...,https://www.healthhub.sg/a-z/costs-and-financi...,www.healthhub.sg/a-z/costs-and-financing/breas...,breast-cancer-screening-subsidies,Here’s all you need to know about breast cance...,"<div class=""ExternalClass07C58E0D957B4AA7B14FC...",...,0.216244,0.216244,cost-and-financing,False,True,False,[Cancer Facts You Cannot Ignore],"[[Cancer Facts You Cannot Ignore, https://www....","[[Breast Cancer Screening, h2], [Subsidy for M...",Breast cancer is the number one cancer among w...
1,1435071,Marriage and Parenthood Schemes,Marriage and Parenthood Schemes,"Body Care,",https://ch-api.healthhub.sg/api/public/content...,https://www.healthhub.sg/a-z/costs-and-financi...,www.healthhub.sg/a-z/costs-and-financing/marri...,marriage_parenthood_scheme,New parents and couples looking to conceive ca...,"<div class=""ExternalClassE1D82270F17241E495537...",...,0.11118,0.327423,cost-and-financing,False,True,False,"[MediSave, Baby Bonus What You Need to Know, I...","[[How to Submit Claims, https://crms.moh.gov.s...","[[MediSave Maternity Package, h2], [Examples o...",MediSave Maternity Package\nWith the MediSave ...


## Pre-processing

In [6]:
# Filter by contributor, content_category and to_remove.
# The contributor / to_remove conditions always apply; the category condition
# is added only when a specific category (anything other than "all") is set.
row_mask = (df["pr_name"] == CONTRIBUTOR) & (~df["to_remove"])
if CONTENT_CATEGORY != "all":
    row_mask &= df["content_category"] == CONTENT_CATEGORY

# Keep only required columns
cols_to_keep = [
    "id",
    "content_name",
    "title",
    "article_category_names",
    "cover_image_url",
    "full_url",
    "category_description",
    "content_body",
    "keywords",
    "feature_title",
    "pr_name",
    "date_modified",
    "page_views",
    "engagement_rate",
    "content_category",
    "has_table",
    "has_image",
    "related_sections",
    "extracted_links",
    "extracted_headers",
    "extracted_content_body",
]

# Explicit copy: embedding columns are assigned onto df_filtered in a later
# cell, so it must be an independent frame rather than a slice of `df`.
df_filtered = df.loc[row_mask, cols_to_keep].copy()
print(df_filtered.shape)
df_filtered.head(2)

(623, 21)


Unnamed: 0,id,content_name,title,article_category_names,cover_image_url,full_url,category_description,content_body,keywords,feature_title,...,date_modified,page_views,engagement_rate,content_category,has_table,has_image,related_sections,extracted_links,extracted_headers,extracted_content_body
367,1444475,"Weight, BMI and Health Problems","Weight, BMI and Health Problems","Food and Nutrition,",https://ch-api.healthhub.sg/api/public/content...,https://www.healthhub.sg/live-healthy/weight_p...,What’s your Body Mass Index (BMI)? Learn how t...,"<div class=""ExternalClassE93BEC3784C545A286BB8...","PGM_Obesity Prevention,PGM_HealthAmbassador,AG...",BMI and Your Health,...,2023-05-10T09:39:54.0000000Z,19977,0.690791,live-healthy-articles,False,False,"[BMI Calculator, What is a Healthy Weight?, An...","[[BMI Calculator, https://www.healthhub.sg/pro...","[[What's a Healthy Body Mass Index?, h2], [Why...",What's a Healthy Body Mass Index?\nWe have all...
368,1445137,7-month-baby Diet: An Authoritative Guide by O...,7-month-baby Diet: An Authoritative Guide by O...,"Food and Nutrition,",https://ch-api.healthhub.sg/api/public/content...,https://www.healthhub.sg/live-healthy/meal-ide...,Your little one is now 7 months of age. Should...,"<div class=""ExternalClass46E64333542C4D8CBEA23...",,,...,2022-11-15T08:35:41.0000000Z,18876,0.688392,live-healthy-articles,True,True,"[Nutrition for Your Toddler, No Wholegrain, No...","[[Nutrition for Your Toddler, https://www.heal...","[[Recommended Number of Servings (7 months), h...",By Health Promotion Board in collaboration wit...


## Data understanding

In [7]:
# Count missing values per column of the filtered frame.
df_filtered.isna().sum()


id                          [1;36m0[0m
content_name                [1;36m0[0m
title                       [1;36m0[0m
article_category_names     [1;36m58[0m
cover_image_url             [1;36m7[0m
full_url                    [1;36m0[0m
category_description        [1;36m1[0m
content_body                [1;36m0[0m
keywords                  [1;36m347[0m
feature_title             [1;36m257[0m
pr_name                     [1;36m0[0m
date_modified               [1;36m0[0m
page_views                  [1;36m0[0m
engagement_rate             [1;36m0[0m
content_category            [1;36m0[0m
has_table                   [1;36m0[0m
has_image                   [1;36m0[0m
related_sections            [1;36m0[0m
extracted_links             [1;36m0[0m
extracted_headers           [1;36m0[0m
extracted_content_body      [1;36m0[0m
dtype: int64

In [8]:
# Compare content_name against title row by row (only 7 rows differ).
# assign() returns a new frame, so df_filtered itself is left untouched.
df_explore = df_filtered.assign(
    contentname_vs_title=lambda frame: frame["content_name"] == frame["title"]
)
df_explore["contentname_vs_title"].value_counts()


contentname_vs_title
[3;92mTrue[0m     [1;36m616[0m
[3;91mFalse[0m      [1;36m7[0m
Name: count, dtype: int64

Similarity embeddings are generated for the following columns: <br>
1. title, <br>
2. article_category_names, <br>
3. category_description, <br>
4. extracted_content_body <br>

## Load embedding model

In [9]:
# Load the embedding model and its matching tokenizer from the same model id
hf_model_id = f"{OWNER}/{MODEL_NAME}"
sentence_transformer = SentenceTransformer(hf_model_id)
tokenizer = AutoTokenizer.from_pretrained(hf_model_id)

# Token limit of the embedding model; used as the chunk size budget
max_length = sentence_transformer.max_seq_length

In [10]:
def split_into_chunks(
    sentences: list[str], max_length: int, tokenizer: BertTokenizerFast
) -> list[str]:
    """Greedily pack consecutive sentences into chunks within a token budget.

    Sentences are appended to the current chunk until adding the next one
    would push the running token count above ``max_length``; the chunk is then
    closed (sentences joined with a single space) and a new one is started.

    Args:
        sentences: Sentences in document order.
        max_length: Maximum number of tokens allowed per chunk.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry
            for a single string.

    Returns:
        The list of chunk strings, in order. Empty if ``sentences`` is empty.

    Notes:
        * Token counts are taken per sentence exactly as the tokenizer returns
          them (including any special tokens it adds), so the running total is
          a conservative estimate of the joined chunk's length.
        * A single sentence longer than ``max_length`` still becomes its own
          chunk; it is not split further here.
    """
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        # Only the token count is needed, so plain Python lists suffice
        # (no tensor conversion).
        num_tokens = len(tokenizer(sentence)["input_ids"])

        # Close the current chunk if this sentence would overflow it. The
        # `current_chunk` guard prevents emitting an empty "" chunk when the
        # very first sentence of a chunk is itself longer than max_length.
        if current_chunk and current_length + num_tokens > max_length:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0

        current_chunk.append(sentence)
        current_length += num_tokens

    # Add the last chunk if any
    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks


def pool_embeddings(embeddings: np.ndarray, strategy: str = "mean") -> np.ndarray:
    """Pool a stack of embeddings into a single vector along the first axis.

    Args:
        embeddings: A 2-D array, or a sequence of equally sized 1-D arrays,
            with one embedding per entry.
        strategy: ``"mean"`` for the element-wise average or ``"max"`` for the
            element-wise maximum.

    Returns:
        The pooled embedding (reduction over axis 0).

    Raises:
        ValueError: If ``embeddings`` is empty/None or ``strategy`` is unknown.
    """
    # len() works for both lists and ndarrays; `not embeddings` would raise an
    # "ambiguous truth value" error for a multi-element ndarray.
    if embeddings is None or len(embeddings) == 0:
        raise ValueError("The embeddings are empty.")

    if strategy == "mean":
        return np.mean(embeddings, axis=0)
    elif strategy == "max":
        return np.max(embeddings, axis=0)
    else:
        raise ValueError(
            "Pooling strategy not recognized. The strategy must be either 'mean' or 'max'."
        )

In [11]:
# Initialise dict to store embeddings for respective columns.
# Each list ends up with one embedding per row of df_filtered, in row order.
embedding_dict = {
    "title": [],
    "article_category_names": [],
    "category_description": [],
    "extracted_content_body": [],
}

# Resolved once; used to build the placeholder for rows without usable text.
embedding_dim = sentence_transformer.get_sentence_embedding_dimension()

# One representative row per id: the first occurrence supplies the text.
unique_articles = df_filtered.drop_duplicates(subset="id")

with alive_bar(
    (len(unique_articles) * len(embedding_dict)), force_tty=True
) as bar:
    for col_name, embedding_list in embedding_dict.items():
        print(col_name)
        embedding_by_id = {}

        for article_id, text in zip(unique_articles["id"], unique_articles[col_name]):
            if not isinstance(text, str) or not text.strip():
                # Missing (None/NaN) or blank text: store a deterministic
                # all-zero placeholder. np.empty would expose uninitialised
                # memory as if it were an embedding.
                embeddings = np.zeros((embedding_dim,), dtype=np.float32)
            else:
                # Step 1: Split the text into sentences
                sentences = sent_tokenize(text)

                # Step 2: Pack sentences into chunks of at most `max_length`
                # tokens (the model's max sequence length)
                chunks = split_into_chunks(sentences, max_length, tokenizer)

                # Step 3: Encode each chunk to get their embeddings
                chunk_embeddings = [
                    sentence_transformer.encode(chunk) for chunk in chunks
                ]

                # Step 4: Aggregate chunk embeddings to form a single embedding for the entire text
                embeddings = pool_embeddings(
                    chunk_embeddings, strategy=POOLING_STRATEGY
                )

            embedding_by_id[article_id] = embeddings
            bar()

        # Emit embeddings in df_filtered's row order so the later column
        # assignment lines up even if an id appears on several rows.
        embedding_list.extend(
            embedding_by_id[article_id] for article_id in df_filtered["id"]
        )

on 0: title                                                                     
on 623: article_category_names                                                   ▅▃▁ 26/2492 [1%] in 2s (~3:34, 11.5/s
on 1246: category_description                                                    ▅▇▇ 626/2492 [25%] in 57s (~2:50, 11. ▆█▆ 854/2492 [34%] in 1:14 (~2:22, 11 █▆▄ 873/2492 [35%] in 1:16 (~2:20, 11
on 1869: extracted_content_body                                                  ▃▅▇ 1300/2492 [52%] in 1:52 (~1:43, 1 ▂▄▆ 1362/2492 [55%] in 2:00 (~1:39, 1 ▆█▆ 1645/2492 [66%] in 2:41 (~1:23, 1 ▇▇▅ 1646/2492 [66%] in 2:41 (~1:23, 1 ▃▅▇ 1679/2492 [67%] in 2:46 (~1:20, 1 ▇▇▅ 1691/2492 [68%] in 2:48 (~1:19, 1 ▃▅▇ 1742/2492 [70%] in 2:55 (~1:15, 1 █▆▄ 1800/2492 [72%] in 3:04 (~1:11, 9
|████████████████████████████████████████| 2492/2492 [100%] in 48:22.1 (0.86/s)  ▆█▆ 1880/2492 [75%] in 4:01 (~1:18, 7 ▂▄▆ 1899/2492 [76%] in 12:29 (~3:54,  ▁▃▅ 1975/2492 [79%] in 19:50 (~5:11,  ▄▂▂ 2004/2492 [80%] in 22

In [12]:
# Attach one embedding column per source column, named
# "<source column>_<model name>_embeddings"
for source_col in embedding_dict:
    df_filtered[f"{source_col}_{MODEL_NAME}_embeddings"] = embedding_dict[source_col]

df_filtered.head(2)

Unnamed: 0,id,content_name,title,article_category_names,cover_image_url,full_url,category_description,content_body,keywords,feature_title,...,has_table,has_image,related_sections,extracted_links,extracted_headers,extracted_content_body,title_multi-qa-mpnet-base-cos-v1_embeddings,article_category_names_multi-qa-mpnet-base-cos-v1_embeddings,category_description_multi-qa-mpnet-base-cos-v1_embeddings,extracted_content_body_multi-qa-mpnet-base-cos-v1_embeddings
367,1444475,"Weight, BMI and Health Problems","Weight, BMI and Health Problems","Food and Nutrition,",https://ch-api.healthhub.sg/api/public/content...,https://www.healthhub.sg/live-healthy/weight_p...,What’s your Body Mass Index (BMI)? Learn how t...,"<div class=""ExternalClassE93BEC3784C545A286BB8...","PGM_Obesity Prevention,PGM_HealthAmbassador,AG...",BMI and Your Health,...,False,False,"[BMI Calculator, What is a Healthy Weight?, An...","[[BMI Calculator, https://www.healthhub.sg/pro...","[[What's a Healthy Body Mass Index?, h2], [Why...",What's a Healthy Body Mass Index?\nWe have all...,"[0.026768425, 0.032185856, 0.011938091, -0.011...","[0.033130124, 0.04989962, -0.023408694, 0.0173...","[0.032101344, 0.038946442, 0.00013390426, -0.0...","[0.028888216, 0.035280235, 0.014185197, 0.0169..."
368,1445137,7-month-baby Diet: An Authoritative Guide by O...,7-month-baby Diet: An Authoritative Guide by O...,"Food and Nutrition,",https://ch-api.healthhub.sg/api/public/content...,https://www.healthhub.sg/live-healthy/meal-ide...,Your little one is now 7 months of age. Should...,"<div class=""ExternalClass46E64333542C4D8CBEA23...",,,...,True,True,"[Nutrition for Your Toddler, No Wholegrain, No...","[[Nutrition for Your Toddler, https://www.heal...","[[Recommended Number of Servings (7 months), h...",By Health Promotion Board in collaboration wit...,"[0.025946843, 0.01707308, -0.004286879, -0.029...","[0.033130124, 0.04989962, -0.023408694, 0.0173...","[0.013637194, -0.027155086, 0.0138272345, -0.0...","[-0.0034918329, 0.0077161994, 0.009509856, -0...."


In [16]:
# Display the parquet destination before writing.
# NOTE(review): the saved output of this cell shows an extra MODEL_NAME
# subfolder that the path-building cell does not produce — the output looks
# stale; re-run the notebook top to bottom to refresh it.
OUTPUT_EMBEDDING_PATH

[32m'..\\data\\07_model_output\\live-healthy-articles\\multi-qa-mpnet-base-cos-v1\\live-healthy-articles_multi-qa-mpnet-base-cos-v1_embeddings.parquet'[0m

In [27]:
# Persist the filtered dataframe (including embedding columns) as parquet
pq.write_table(pa.Table.from_pandas(df_filtered), OUTPUT_EMBEDDING_PATH)

In [28]:
# Save for clustering in neo4j
# Single source of truth for the exported columns: source column name -> name
# used in the pickled export. "id" and "title" are exported unrenamed.
neo4j_column_names = {
    "extracted_content_body": "content",
    "category_description": "meta_description",
    f"title_{MODEL_NAME}_embeddings": "vector_title",
    f"article_category_names_{MODEL_NAME}_embeddings": "vector_article_category_names",
    f"category_description_{MODEL_NAME}_embeddings": "vector_category_description",
    f"extracted_content_body_{MODEL_NAME}_embeddings": "vector_extracted_content_body",
}

# Select then rename in one step; column order matches the mapping above.
df_neo4j = df_filtered[["id", "title", *neo4j_column_names]].rename(
    columns=neo4j_column_names
)

df_neo4j.to_pickle(OUTPUT_EMBEDDING_NEO4J_PATH)

## End