In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
from copy import deepcopy
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from datasets import load_dataset
from umap import UMAP
import re
from hdbscan import HDBSCAN
from bertopic.representation import KeyBERTInspired
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import CountVectorizer
import plotly.io as pio
pio.renderers.default = 'iframe'
import dill

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [2]:
class EmbeddingsClusterTopics:
    def __init__(self, model_name, dataset_path, documents_column_name, embeddings_column_name, clustering_type = 'hdbscan', random_state = None):
        self.model_name = model_name
        self.embeddings_model = SentenceTransformer(self.model_name)
        custom_umap_model = UMAP(n_neighbors=15, n_components=10, random_state=random_state)  # Change 10 to the desired number of dimensions
        if clustering_type == 'hdbscan':
            custom_hdbscan_model = HDBSCAN(metric = 'manhattan')
        elif clustering_type == 'kmeans':
            custom_hdbscan_model = KMeans(n_clusters=25, random_state=random_state)
        vectorizer_model = CountVectorizer(stop_words="english", min_df=2, ngram_range=(1, 2))
        representation_model = KeyBERTInspired()
        self.bertopic_model = BERTopic(# representation_model=representation_model,
                                      embedding_model=self.embeddings_model,
                                      calculate_probabilities = True)
        self.dataset = load_dataset('parquet',data_files=dataset_path)['train']
        self.documents = self._load_documents_from_parquet(documents_column_name)
        self.documents = [self._remove_numeric_words(doc) for doc in self.documents]
        self.embeddings = self._load_embeddings_from_parquet(embeddings_column_name)
        self.create_clusters_topics()
        self.generate_topic_names()

    def _remove_numeric_words(self, text):
        # Remove currency-based numbers like $123.1, currency symbols like €, and rupee symbol ₹
        currency_pattern = r'\$\s*\d+(\.\d+)?|\€\s*\d+(\.\d+)?|₹\s*\d+(\.\d+)?'

        # Match numeric words or currency-based numbers
        numeric_pattern = r'\b\d+(\.\d+)?\b'

        # Combine both patterns using negative lookahead to exclude percentages
        combined_pattern = rf'(?!(?:\d+(\.\d+)?%))({currency_pattern}|{numeric_pattern})'

        cleaned_text = re.sub(combined_pattern, '', text)
        return cleaned_text

    def _load_embeddings_from_parquet(self, embeddings_column_name):
        return np.array(self.dataset[embeddings_column_name])

    def _load_documents_from_parquet(self, documents_column_name):
        return self.dataset[documents_column_name]

    def create_clusters_topics(self):
        topics, _ = self.bertopic_model.fit_transform(documents = self.documents, embeddings = self.embeddings)
        self.hierarchical_topics = self.bertopic_model.hierarchical_topics(self.documents)

    def generate_topic_names(self):
        return self.bertopic_model.generate_topic_labels(nr_words=5, separator=", ")

In [3]:
class TopicHierarchy:
    def __init__(self, df, topic_to_doc_indices):
        self.df = df
        # self.levels = {}
        self.raw_leaf_points_count = {}
        self.raw_leaf_points_list = {}
        self.topic_to_doc_indices = topic_to_doc_indices

    def compute_levels(self, parent_id, level, levels):
        levels[parent_id] = level
        children = self.df[self.df['Parent_ID'] == parent_id]

        for _, child in children.iterrows():
            self.compute_levels(child['Child_Left_ID'], level + 1, levels)
            self.compute_levels(child['Child_Right_ID'], level + 1, levels)

    def compute_raw_leaf_points(self, parent_id):
        if parent_id in self.raw_leaf_points_count:
            return self.raw_leaf_points_count[parent_id], self.raw_leaf_points_list[parent_id]

        children = self.df[self.df['Parent_ID'] == parent_id]

        if children.empty:
            parent_id_int = int(parent_id)
            if parent_id_int in self.topic_to_doc_indices:
                doc_indices = [idx for idx, x in enumerate(self.topic_to_doc_indices) if x == parent_id_int]
                count = len(doc_indices)
            else:
                count = 0
                doc_indices = []
            self.raw_leaf_points_count[parent_id] = count
            self.raw_leaf_points_list[parent_id] = doc_indices
            return count, doc_indices

        total_leaf_points = 0
        all_leaf_points = []

        for _, child in children.iterrows():
            left_count, left_list = self.compute_raw_leaf_points(child['Child_Left_ID'])
            right_count, right_list = self.compute_raw_leaf_points(child['Child_Right_ID'])

            total_leaf_points += left_count + right_count
            all_leaf_points.extend(left_list)
            all_leaf_points.extend(right_list)

        self.raw_leaf_points_count[parent_id] = total_leaf_points
        self.raw_leaf_points_list[parent_id] = all_leaf_points

        return total_leaf_points, all_leaf_points

    def get_levels(self):
        levels = {}
        all_child_ids = set(self.df['Child_Left_ID']).union(set(self.df['Child_Right_ID']))
        roots = self.df[~self.df['Parent_ID'].isin(all_child_ids)]
        for _, root in roots.iterrows():
            self.compute_levels(root['Parent_ID'], 0, levels)
        return levels

    def get_raw_leaf_points(self):
        roots = self.df[~self.df['Parent_ID'].isin(self.df['Child_Left_ID']) & ~self.df['Parent_ID'].isin(self.df['Child_Right_ID'])]
        for _, root in roots.iterrows():
            self.compute_raw_leaf_points(root['Parent_ID'])

        return self.raw_leaf_points_count, self.raw_leaf_points_list

In [4]:
def get_balanced_clusters(df, parent_id, max_points=4000):
    balanced_clusters = []
    cluster_row = df[df['Parent_ID'] == parent_id].iloc[0]
    num_points = cluster_row['num_points']

    if num_points <= max_points:
        balanced_clusters.append(cluster_row['Parent_ID'])
        return balanced_clusters

    children = df[df['Parent_ID'] == parent_id]

    for _, child in children.iterrows():
        balanced_clusters += get_balanced_clusters(df, child['Child_Left_ID'], max_points)
        balanced_clusters += get_balanced_clusters(df, child['Child_Right_ID'], max_points)

    return balanced_clusters

In [5]:
def assign_outliers_to_balanced_clusters(balanced_clusters_df, document_topic, probabilities):
    # Prepare a topic-cluster mapping for fast lookups
    topic_cluster_map = {}
    for _, row in balanced_clusters_df.iterrows():
        for topic in row['Topics']:
            topic_cluster_map[topic] = row['Parent_ID']

    # Initialize lists to store results
    original_indices = []
    new_clusters = []
    total_probs = []

    # Find the indices of documents that are outliers (-1)
    # outlier_indices = np.where(np.array(document_topic) == -1)[0]
    outlier_indices = np.where(np.array(document_topic) > -2)[0]

    # For each outlier, find the most probable cluster
    for idx in outlier_indices:
        topic_probs = np.array(probabilities[idx])
        cluster_indices = [topic_cluster_map.get(t, -1) for t in range(len(topic_probs))]

        # Create a DataFrame for aggregation
        df = pd.DataFrame({
            'Cluster': cluster_indices,
            'Probability': topic_probs
        })

        # Sum probabilities by cluster
        df_grouped = df.groupby('Cluster').sum()

        # Find the cluster with the maximum total probability
        best_cluster = df_grouped['Probability'].idxmax()

        # Append to lists
        original_indices.append(idx)
        new_clusters.append(best_cluster)
        total_probs.append(df_grouped.loc[best_cluster, 'Probability'])

    # Create a DataFrame for the results
    return pd.DataFrame({
        'original_index': original_indices,
        'new_cluster': new_clusters,
        'total_probability': total_probs
    })

In [6]:
folder_path = '/Users/ravi.tej/Desktop/ML/Recommendations/Embedding Model Selection/Embeddings/'
file = 'formatted_articles_data_2023_embeddings_bge_small_en.parquet'
ect_bge_small = EmbeddingsClusterTopics(model_name = 'BAAI/bge-small-en', 
                                        dataset_path = folder_path + file, 
                                        documents_column_name = 'title_summary',
                                        embeddings_column_name = 'embeddings', 
                                        clustering_type='hdbscan',
                                        random_state=86)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
100%|██████████| 962/962 [00:05<00:00, 192.06it/s]


In [7]:
with open("ect_bge_small_search.dill", "wb") as f:
    dill.dump(ect_bge_small,f)

In [367]:
word = 'agriculture'

In [368]:
word_encoding = ect_bge_small.embeddings_model.encode(word)

sim_cos = cosine_similarity(word_encoding.reshape(1, -1), ect_bge_small.bertopic_model.topic_embeddings_).flatten()

sim_man = manhattan_distances(word_encoding.reshape(1, -1), ect_bge_small.bertopic_model.topic_embeddings_).flatten()

ids = np.argsort(sim_man)[0:900]
similarity_man = [sim_man[i] for i in ids]
similar_topics = [list(ect_bge_small.bertopic_model.topic_labels_.keys())[index] for index in ids]
similar_topics_cos, similarity_cos = ect_bge_small.bertopic_model.find_topics(word,top_n=900)

In [369]:
ect_bge_small.bertopic_model.representative_docs_[similar_topics[1]]

['EU climate chief says greater responsibility on rich nations EU’s climate policy chief cautioned against protectionist moves while looking for ways to address climate change.',
 'Resources for resilience Funding at the scale needed for climate-change adaptation in developing countries is not immediately forthcoming from the international financial system. But, these countries can’t just sit back and wait for a global agreement on climate finance\n',
 "The once-bold UK has begun to look like a climate laggard The UK, once a climate champion, is now jeopardizing its global standing by failing to meet climate targets and lacking leadership in tackling the climate crisis, according to the UK's Climate Change Committee. Delays in policy development and implementation, as well as a narrow approach to solutions, are cited as reasons for the country's failure. The committee warns that urgent action is needed to quadruple emissions reduction outside the electricity supply sector. The UK risks

In [371]:
similar_topics_cos[0]

351

In [370]:
ect_bge_small.bertopic_model.representative_docs_[similar_topics_cos[0]]

['Govt lines up millet-centric activities as international year of millets kicks in New Delhi Millets are also an integral part of the G- meetings and delegates will be given a true millet experience through tasting, meeting farmers and interactive sessions with start-ups and FPOs',
 "‘Very creative’: PM Modi praises ‘Abundance in Millets’ song A new song titled 'Abundance in Millets' has been launched to create awareness about the importance of millets during the International Year of Millets. The song caught the attention of Indian Prime Minister Narendra Modi. ",
 "Sowing millet-led growth globally The initiative is part of India's efforts to get the world to observe  as the International Year of Millets, as declared by the UN. Sarthak Ray takes a look at the millets push by India"]

In [72]:
from sklearn.metrics.pairwise import cosine_similarity, manhattan_distances

In [345]:
similar_topics, similarity = ect_bge_small.bertopic_model.find_topics('investments',top_n=900)

In [346]:
[max(old_score - 0.7,0)/(1.0 - 0.7) for old_score in similarity]

[0.6763882682951663,
 0.6462609171875484,
 0.614483880219628,
 0.6064809648348508,
 0.6019449174501988,
 0.5960523120841517,
 0.5915867456063532,
 0.5843446051615501,
 0.5800919036722062,
 0.5796914755420708,
 0.5788138113374637,
 0.5752046037668094,
 0.572004854238478,
 0.5700823804716618,
 0.5638436407301912,
 0.5599670111430491,
 0.5504976470920518,
 0.5476722512438136,
 0.5442087665417549,
 0.5429378357955331,
 0.5415443234233648,
 0.5414928435950316,
 0.539320208576357,
 0.5366212823575044,
 0.5365849095723972,
 0.5358683926004951,
 0.5339270186773133,
 0.5290929105544709,
 0.5279587874391729,
 0.5270943227207557,
 0.5252875993038856,
 0.5248814487445644,
 0.5236435077355593,
 0.5222942406282359,
 0.5209302978913636,
 0.5204293334807101,
 0.5187363357231252,
 0.5184917603067961,
 0.5184166482537654,
 0.5183615876521089,
 0.5172381993540094,
 0.516404516595836,
 0.5156618868079902,
 0.5155741177666565,
 0.5153463689501931,
 0.5145437331310042,
 0.5141702857833859,
 0.51352047341173

In [344]:
similarity

[0.8440139836450539,
 0.8361345390457295,
 0.8323181205119495,
 0.8322973824201345,
 0.8301979401504347,
 0.8294949339350829,
 0.8290360863703325,
 0.8263221905977494,
 0.8260972608785327,
 0.8259761723451816,
 0.8244963343018781,
 0.8239937929515138,
 0.8238578166147089,
 0.823239386699018,
 0.8231422533504211,
 0.8229360408395252,
 0.8225244695246363,
 0.8224246899855906,
 0.8222789259239556,
 0.8218831853561271,
 0.8213861380748975,
 0.8211492715143665,
 0.8208367456895088,
 0.8206848689068706,
 0.8206734157723039,
 0.8201595765264851,
 0.8197474342892739,
 0.8196901190338448,
 0.8195755590367815,
 0.8195117993373938,
 0.8194590016318857,
 0.8193470659125385,
 0.8192927496151257,
 0.819262594186748,
 0.819196288368943,
 0.8191704527114396,
 0.818981249763037,
 0.818800015798456,
 0.8187707637480892,
 0.81820126542078,
 0.8181731388628415,
 0.8178073795933355,
 0.8176094719635116,
 0.8175857456137362,
 0.8175758087763337,
 0.817558522784699,
 0.8175352866362416,
 0.8173888179194667,


In [264]:
similar_topics

[460,
 533,
 -1,
 23,
 228,
 27,
 153,
 301,
 101,
 295,
 901,
 0,
 498,
 561,
 36,
 354,
 1,
 618,
 80,
 733,
 867,
 39,
 655,
 65,
 565,
 562,
 435,
 189,
 351,
 690,
 850,
 9,
 237,
 196,
 246,
 902,
 599,
 919,
 694,
 40,
 930,
 16,
 540,
 4,
 719,
 931,
 14,
 911,
 248,
 269,
 516,
 366,
 365,
 408,
 129,
 341,
 302,
 579,
 558,
 316,
 58,
 305,
 430,
 46,
 219,
 17,
 830,
 402,
 312,
 149,
 62,
 921,
 688,
 174,
 560,
 37,
 235,
 166,
 230,
 18,
 73,
 419,
 279,
 114,
 505,
 285,
 401,
 619,
 157,
 391,
 286,
 26,
 718,
 240,
 903,
 106,
 33,
 661,
 587,
 954,
 147,
 331,
 24,
 183,
 123,
 2,
 380,
 553,
 532,
 439,
 165,
 122,
 857,
 675,
 192,
 108,
 118,
 213,
 281,
 695,
 630,
 556,
 336,
 143,
 212,
 64,
 525,
 625,
 241,
 164,
 288,
 421,
 549,
 631,
 620,
 643,
 168,
 847,
 61,
 715,
 52,
 757,
 175,
 322,
 648,
 792,
 904,
 276,
 138,
 54,
 5,
 538,
 432,
 734,
 450,
 233,
 254,
 313,
 383,
 76,
 82,
 184,
 394,
 225,
 509,
 100,
 144,
 386,
 874,
 349,
 257,
 195,
 835,


In [27]:
from copy import deepcopy

In [29]:
df = deepcopy(ect_bge_small.hierarchical_topics)

In [372]:
jagannath_yatra = [idx for idx, x in enumerate(ect_bge_small.bertopic_model.topics_) if x == 351]

In [373]:
[ect_bge_small.documents[x] for x in jagannath_yatra]

['G20 meeting of agri scientists focuses on technological intervention for agri-food transformation MACS  unanimously agreed to launch a Millet Initiative - MAHARISHI which was proposed by India for research in the field of millets',
 'Spirit of a farmer is like that of an Indian soldier: Minister Narendra Singh Tomar said that the way soldiers protect the nation by standing bravely on the borders, farmers generate agricultural produce, making an exemplary contribution to food security in the same way',
 'PM Modi adds a new dimension to millets by terming it Shree Anna: Tomar Millets are an alternative food system in times of increasing demand for vegetarian foods as it contributes to a balanced diet as well as a safe environment',
 "How an orphan crop became a rich man's favourite snack Millets were ignored and nearly forgotten. Now, they have gained popularity with chefs and the urban rich.After India’s green revolution to ensure food security and reduce dependence on food-aid in the

In [265]:
ect_bge_small.bertopic_model.representative_docs_[460]

['RIL shares jumped 1.3% today | The Financial Express Reliance Industries Ltd share price jumped over 1% on Monday after Mukesh Ambani-led company’s net profit surged 19.1% to Rs , crore on-year in Q4FY23.',
 "RIL Q4 Results: Earnings in line with estimates, cons PAT at record  , cr Reliance Industries Q4 results: Billionaire Mukesh Ambani-backed Reliance Industries has announced its Q4 earnings for FY23 on Thursday. RIL's share price closed in the green.\xa0",
 "Reliance Industries share price jumps 1% today after strong Q4 earnings; Should you buy, hold or sell RIL stock? RIL shares rose 1% today. Reliance Industries Ltd's net profit surged 19.1% to Rs , crore on-year in Q4FY23, beating analysts' expectations."]

### Observations

- The similarity values are largely ranging betweeen 0.7 and 1 (reported by others as well - https://huggingface.co/intfloat/multilingual-e5-large/discussions/10#:~:text=An%20embedding%20model%20is%20good,0.01%20for%20InfoNCE%20contrastive%20loss.)
- The top topic is usually relevant if there is content with the keyword or similar meaning
- If there is not, there are many instances when there is a keyword similarity without semantic similarity (jai bajrang bali gets jaishankar topic as top result, while char dham yatra related results are 3rd or 4th result)
- Manhattam distance and cosine similarity are giving very similar results, with some differences in the rankings
- To handle the issue with non uniformity, we'll use a scaling factor similar to one mentioned in the HF discussion of [max(score - 0.7,0)/(1.0 - 0.7)]

### Plotting Beta

In [620]:
from scipy.stats import beta

import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1)

a, b = 1,2
x = np.linspace(beta.ppf(0.01, a, b), beta.ppf(0.99, a, b), 100)
plt.plot(x, beta.pdf(x, a, b), 'r-', lw=5, alpha=0.6, label='beta pdf')

In [335]:
len([i for i in x if i < 0.5])/len(x)

0.17

In [336]:
len([i for i in x if i < 0.25])/len(x)

0.0

### Observations
- At 4,1 there is only a 27% chance of a sample being less than 50%. For 5,1 that is 0.17. 
- Hence we'll cap the max prior we'll assign a cluster to be 5,1. 5 if the similarity is 1


#### Open Question
- Next question is what is the logic which should be used for calculating the priors
    - taking top 5/10 topics and assign the clusters with these topics the corresponding priors
    - calculating the average of similarity of all the topics for each cluster and calculate prior
- What to do with the outlier documents? They don't figure as part of this

### Exploring using the average of similarity approach

In [374]:
def get_balanced_clusters(df, parent_id, max_points=4000):
    balanced_clusters = []
    cluster_row = df[df['Parent_ID'] == parent_id].iloc[0]
    num_points = cluster_row['num_points']

    if num_points <= max_points:
        balanced_clusters.append(cluster_row['Parent_ID'])
        return balanced_clusters

    children = df[df['Parent_ID'] == parent_id]

    for _, child in children.iterrows():
        balanced_clusters += get_balanced_clusters(df, child['Child_Left_ID'], max_points)
        balanced_clusters += get_balanced_clusters(df, child['Child_Right_ID'], max_points)

    return balanced_clusters

In [375]:
def assign_outliers_to_balanced_clusters(balanced_clusters_df, document_topic, probabilities):
    # Prepare a topic-cluster mapping for fast lookups
    topic_cluster_map = {}
    for _, row in balanced_clusters_df.iterrows():
        for topic in row['Topics']:
            topic_cluster_map[topic] = row['Parent_ID']

    # Initialize lists to store results
    original_indices = []
    new_clusters = []
    total_probs = []

    # Find the indices of documents that are outliers (-1)
    # outlier_indices = np.where(np.array(document_topic) == -1)[0]
    outlier_indices = np.where(np.array(document_topic) > -2)[0]

    # For each outlier, find the most probable cluster
    for idx in outlier_indices:
        topic_probs = np.array(probabilities[idx])
        cluster_indices = [topic_cluster_map.get(t, -1) for t in range(len(topic_probs))]

        # Create a DataFrame for aggregation
        df = pd.DataFrame({
            'Cluster': cluster_indices,
            'Probability': topic_probs
        })

        # Sum probabilities by cluster
        df_grouped = df.groupby('Cluster').sum()

        # Find the cluster with the maximum total probability
        best_cluster = df_grouped['Probability'].idxmax()

        # Append to lists
        original_indices.append(idx)
        new_clusters.append(best_cluster)
        total_probs.append(df_grouped.loc[best_cluster, 'Probability'])

    # Create a DataFrame for the results
    return pd.DataFrame({
        'original_index': original_indices,
        'new_cluster': new_clusters,
        'total_probability': total_probs
    })

In [376]:
bge_hierarchy = deepcopy(ect_bge_small.hierarchical_topics)

In [377]:
bge_topic_hierarchy = TopicHierarchy(bge_hierarchy, ect_bge_small.bertopic_model.topics_)
bge_levels = bge_topic_hierarchy.get_levels()
num_points, points = bge_topic_hierarchy.get_raw_leaf_points()

# Updating DataFrame as before
bge_hierarchy['Level'] = bge_hierarchy['Parent_ID'].map(bge_levels)
bge_hierarchy['num_points'] = bge_hierarchy['Parent_ID'].map(num_points)
bge_hierarchy['points'] = bge_hierarchy['Parent_ID'].map(points)

In [383]:
balanced_clusters = get_balanced_clusters(bge_hierarchy,parent_id = '1934', max_points = 5000)
len(balanced_clusters)

34

In [385]:
balanced_cluster_df = bge_hierarchy[bge_hierarchy.Parent_ID.isin(balanced_clusters)]

In [386]:
balanced_cluster_df

Unnamed: 0,Parent_ID,Parent_Name,Topics,Child_Left_ID,Child_Left_Name,Child_Right_ID,Child_Right_Name,Distance,Level,num_points,points
959,1927,closed_yesterday_stock_reacts_monitor,"[120, 126, 198, 200, 204, 216, 235, 243, 273, ...",1907,tcs_infosys_lrs_remittance_remittances,1925,closed_yesterday_stock_reacts_monitor,2.158897,1,2856,"[102, 290, 491, 524, 619, 631, 815, 1339, 1500..."
952,1920,tax_sebi_insurance_income_mutual,"[16, 18, 22, 28, 34, 39, 44, 59, 71, 77, 95, 1...",1908,tax_sebi_insurance_income_itr,1829,mutual_sip_funds_fund_returns,1.940746,8,4645,"[2555, 11011, 11541, 19844, 29184, 29602, 2975..."
937,1905,maruti_suzuki_hyundai_electric_hero,"[9, 62, 169, 183, 185, 194, 199, 239, 251, 253...",1816,maruti_suzuki_jimny_fronx_toyota,1888,electric_hero_ev_hyundai_kia,1.758566,6,2026,"[8979, 12218, 12316, 14310, 14390, 14399, 1471..."
936,1904,gold_fed_crude_inflation_rate,"[11, 20, 29, 66, 85, 94, 113, 133, 167, 191, 1...",1896,gold_fed_inflation_rate_rates,1709,crude_oil_barrel_cents_brent,1.756502,9,1672,"[93, 1406, 1782, 4282, 4555, 5506, 7839, 10431..."
934,1902,apple_5g_oneplus_iphone_galaxy,"[80, 100, 125, 136, 164, 168, 187, 196, 229, 2...",1725,apple_iphone_wwdc_headset_ios,1835,5g_oneplus_galaxy_samsung_nord,1.75378,9,1329,"[6203, 6738, 12692, 12704, 12834, 12857, 13287..."
931,1899,vande_pawar_metro_train_imd,"[10, 24, 26, 30, 33, 47, 52, 69, 73, 76, 84, 8...",1777,airline_air_aircraft_aviation_go,1892,vande_pawar_metro_train_imd,1.742215,12,4504,"[411, 3801, 3837, 4252, 8453, 9170, 9346, 9784..."
923,1891,vaishali_buy_parekh_sell_axis,"[121, 141, 154, 339, 529, 642, 662, 687, 829]",1666,vaishali_parekh_buy_sell_recommended,1289,axis_citibank_citi_bank_picks,1.640322,10,337,"[23242, 26099, 26555, 27164, 27670, 28180, 284..."
919,1887,trump_pakistan_rupee_biden_imran,"[21, 31, 43, 45, 50, 51, 60, 67, 68, 75, 78, 8...",1869,rupee_pakistan_imran_dollar_sitharaman,1853,trump_biden_donald_president_visit,1.627469,15,2904,"[380, 622, 658, 898, 4669, 7338, 7525, 15757, ..."
914,1882,chatgpt_google_ai_twitter_musk,"[6, 38, 102, 112, 127, 137, 188, 255, 261, 350...",1866,chatgpt_google_ai_openai_chatbot,1783,twitter_musk_elon_tesla_blue,1.601761,11,1320,"[4124, 8283, 10553, 19571, 32285, 39577, 40694..."
911,1879,coal_hydrogen_solar_power_mw,"[36, 81, 123, 438, 463, 553, 570, 619, 779, 79...",1251,coal_mines_cil_mt_tonne,1828,hydrogen_solar_power_mw_suzlon,1.5751,10,557,"[597, 1380, 6978, 7993, 8777, 36270, 46055, 52..."


In [391]:
def get_topic_wise_similarity(word,embedding_model, topic_embeddings):
    word_encoding = embedding_model.encode(word)
    sims = cosine_similarity(word_encoding.reshape(1, -1), topic_embeddings).flatten()
    return [max(original_score - 0.7,0)/(1.0 - 0.7) for original_score in sims]

In [392]:
similarity = get_topic_wise_similarity(word = 'mutual funds', embedding_model=ect_bge_small.embeddings_model, topic_embeddings=ect_bge_small.bertopic_model.topic_embeddings_)

In [399]:
{row['Parent_ID']: np.mean([similarity[x] for x in row['Topics']]) for index, row in balanced_cluster_df.iterrows()}

{'1927': 0.35376407275562743,
 '1920': 0.37039839818192855,
 '1905': 0.37106296670345923,
 '1904': 0.3951376409233743,
 '1902': 0.3858706068592208,
 '1899': 0.37532969916025155,
 '1891': 0.34900186811061723,
 '1887': 0.38407867566120435,
 '1882': 0.36014478772471376,
 '1879': 0.3883794070635088,
 '1873': 0.37977214677414156,
 '1870': 0.37641507895804394,
 '1862': 0.38508966643578113,
 '1844': 0.3664370165440857,
 '1826': 0.4610711902742933,
 '1823': 0.4039322068451179,
 '1822': 0.3858014314236751,
 '1815': 0.4224159448093688,
 '1800': 0.3538193344366628,
 '1790': 0.34587496759325487,
 '1785': 0.3611043114464614,
 '1784': 0.34448557189012163,
 '1781': 0.30750701638201783,
 '1774': 0.34593392561955977,
 '1732': 0.36142419099536793,
 '1652': 0.3658714209444416,
 '1637': 0.38832048187360446,
 '1568': 0.40834781967942124,
 '1555': 0.3888749283530757,
 '1547': 0.42122536249553216,
 '1414': 0.34161297505723737,
 '1304': 0.32437355783150335,
 '1290': 0.36562993007021216,
 '1236': 0.36710982822

### Trying out the max method

In [408]:
topic_cluster_map = {}
for _, row in bge_hierarchy[bge_hierarchy.Parent_ID.isin(balanced_clusters)].iterrows():
    for topic in row['Topics']:
        topic_cluster_map[topic] = row['Parent_ID']

In [581]:
similar_topics, similarity = ect_bge_small.bertopic_model.find_topics('stocks',top_n=10)
modified_similarity = [((original_score-0.6)/(1-0.6))*min(1/(np.log10(idx + 2)/np.log10(5)),1) for idx, original_score in enumerate(similarity)]
similar_clusters = [topic_cluster_map[topic] if topic != -1 else -1 for topic in similar_topics]

In [582]:
similarity

[0.895295390779443,
 0.8952000444822616,
 0.890217544302979,
 0.8897570078252569,
 0.8873383257354643,
 0.885461127939464,
 0.8829983025173158,
 0.8735331987988386,
 0.873039218262138,
 0.871755670448549]

In [587]:
similar_topics

[322, 335, 228, 58, 219, 154, 49, 844, 121, 829]

In [584]:
modified_similarity

[0.7382384769486074,
 0.7380001112056541,
 0.7255438607574474,
 0.7243925195631423,
 0.6452501062171508,
 0.5902533090428357,
 0.5475847578503239,
 0.5008986165969642,
 0.47711555893147456,
 0.4559976867009102]

In [586]:
ect_bge_small.bertopic_model.representative_docs_[322]

['Stocks to Watch: SBI, RIL, Bajaj Finserv, L&T, Glenmark, Vedanta Delta Corp will be among the stocks in focus as the company will be declaring its quarterly earnings report on Tuesday.',
 'Stocks to Watch: Kotak Bank, Bajaj Auto, Adani Ports, Tata Consumer, Maruti Maruti Suzuki, Bajaj Finance, HDFC Life Insurance, SBI Life Insurance, L&T Technology Services, and Voltas will be among the stocks in focus as they will be declaring their March quarter earnings today.',
 'Stocks to Watch: TCS, RIL, HDFC Ltd, NTPC, Adani Transmission, Britannia Infosys will be among the stocks in focus as it will be declaring its March quarter earnings today.']

In [490]:
from collections import defaultdict

In [588]:
result = defaultdict(float)  # Initialize a defaultdict with default value as float (0.0)

for cluster, sim in zip(similar_clusters, modified_similarity):
    result[cluster] += sim

In [589]:
result

defaultdict(float,
            {'1870': 3.394617312979184,
             '1862': 0.7243925195631423,
             '1891': 1.5233665546752204,
             '1904': 0.5008986165969642})

In [590]:
balanced_cluster_df[balanced_cluster_df.Parent_ID == '1891']

Unnamed: 0,Parent_ID,Parent_Name,Topics,Child_Left_ID,Child_Left_Name,Child_Right_ID,Child_Right_Name,Distance,Level,num_points,points
923,1891,vaishali_buy_parekh_sell_axis,"[121, 141, 154, 339, 529, 642, 662, 687, 829]",1666,vaishali_parekh_buy_sell_recommended,1289,axis_citibank_citi_bank_picks,1.640322,10,337,"[23242, 26099, 26555, 27164, 27670, 28180, 284..."


### Function timing

In [592]:
def get_cluster_scores_for_word(word):
    similar_topics, similarity = ect_bge_small.bertopic_model.find_topics(word,top_n=10)
    modified_similarity = [((original_score-0.6)/(1-0.6))*min(1/(np.log10(idx + 2)/np.log10(5)),1) for idx, original_score in enumerate(similarity)]
    similar_clusters = [topic_cluster_map[topic] if topic != -1 else -1 for topic in similar_topics]
    result = defaultdict(float)  # Initialize a defaultdict with default value as float (0.0)
    for cluster, sim in zip(similar_clusters, modified_similarity):
        result[cluster] += sim
    return result

In [614]:
def update_cluster_priors_for_keyword(word):
    cluster_priors = {cluster: {'a': 1, 'b': 1} for cluster in balanced_clusters}
    cluster_scores = get_cluster_scores_for_word(word)
    for cluster in cluster_scores.keys():
        cluster_priors[cluster]['a'] = np.round(min(5, cluster_scores[cluster] + cluster_priors[cluster]['a']),2)
    return cluster_priors

* high b - higher prior assumption that user won't like it
* high a - higher prior that user will like it 

In [615]:
update_cluster_priors_for_keyword('elections')

{'1927': {'a': 1, 'b': 1},
 '1862': {'a': 1, 'b': 1},
 '1781': {'a': 1, 'b': 1},
 '1891': {'a': 1, 'b': 1},
 '1826': {'a': 1, 'b': 1},
 '1873': {'a': 1, 'b': 1},
 '1870': {'a': 1, 'b': 1},
 '1732': {'a': 1, 'b': 1},
 '1904': {'a': 1, 'b': 1},
 '1414': {'a': 1, 'b': 1},
 '1815': {'a': 1, 'b': 1},
 '1905': {'a': 1, 'b': 1},
 '1902': {'a': 1, 'b': 1},
 '1882': {'a': 1, 'b': 1},
 '1899': {'a': 5, 'b': 1},
 '1555': {'a': 1, 'b': 1},
 '1304': {'a': 1, 'b': 1},
 '1637': {'a': 1, 'b': 1},
 '1844': {'a': 1, 'b': 1},
 '1290': {'a': 1, 'b': 1},
 '1800': {'a': 1.45, 'b': 1},
 '1785': {'a': 1, 'b': 1},
 '1568': {'a': 1, 'b': 1},
 '1547': {'a': 1, 'b': 1},
 '1822': {'a': 1, 'b': 1},
 '1790': {'a': 1, 'b': 1},
 '1887': {'a': 1.65, 'b': 1},
 '1823': {'a': 1, 'b': 1},
 '1879': {'a': 1, 'b': 1},
 '1774': {'a': 1, 'b': 1},
 '1652': {'a': 1, 'b': 1},
 '1920': {'a': 1, 'b': 1},
 '1236': {'a': 1, 'b': 1},
 '1784': {'a': 1, 'b': 1}}

In [595]:
%timeit get_cluster_scores_for_word('elections')

32.1 ms ± 5.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [611]:
get_cluster_scores_for_word('elections')

defaultdict(float,
            {'1899': 4.358135728919208,
             '1887': 0.6489026354269658,
             '1800': 0.4456312300603669})

In [601]:
for i in get_cluster_scores_for_word('elections').keys():
    print(i)

1899
1887
1800


In [616]:
a = """UK has no plans to change migration policy for free trade deal with India
Sanchari Ghosh
The UK has no intentions of altering its strategy for reducing net migration in pursuit of a free trade agreement with India, a spokesperson for British Prime Minister Rishi Sunak confirmed in a recent announcement. This comes only a day ahead Sunak is headed for India to attend the G20 summit.
"The prime minister believes that the current levels of migration are too high ... To be crystal clear, there are no plans to change our immigration policy to achieve this free trade agreement and that includes student visas," Sunak's spokesperson told journalists.
Last year, Interior Minister Suella Braverman stirred controversy by expressing concerns about the potential impact of Indian migrants on trade talks. She cited worries about an "open borders migration policy with India" and individuals overstaying visas.
However, as the talks between the two countries started, trade minister Kemi Badenoch, earlier this year, had asserted Britain would discuss temporary business visas as part of trade talks but not broader immigration commitments or access to Britain's labour market for Indian workers.
‘Visas were never part of our ask’
However India's High Commissioner to Britain Vikram Doriaswamy said that the notion India wanted more visas had been in the British press but not in Indian media.
"We never said that the visas are part of our ask," he told Times radio, adding that India instead sought simpler ways for companies to move UK and Indian nationals between the countries.
"We are not asking for migrants to be able to come here."
Modern, forward-looking free trade agreement: UK PM Rishi Sunak
Both India and UK are optimistic about finalising a trade deal this year, although several challenging topics still need to be addressed. 
About the deal, Sunak said, a modern, forward-looking free trade agreement can put us firmly on the path to our shared ambition of doubling UK-India trade by 2030,"
"It's very exciting to have this opportunity to expand our trade relationship, and to be the first European country that India has negotiated a free trade deal with," he said.
The prime minister’s responses to PTI’s questions were sent by email."""

In [618]:
len(a.split(' '))/150

2.36