Additions

In [None]:
import matplotlib.pyplot as plt

# Histogram of the 'probability' column (30 bins) to show how the values are spread.
# Series.hist returns the Axes it drew on, so the labels are set on that object directly.
prob_ax = data['probability'].hist(bins=30, edgecolor='black', alpha=0.7)
prob_ax.set_title("Distribution of Topic Probabilities")
prob_ax.set_xlabel("Probability")
prob_ax.set_ylabel("Frequency")
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Histogram of the 'sentiment_score' column (30 bins).
# NOTE(review): assumes 'sentiment_score' is the sentiment column of `data` — confirm.
sentiment_ax = data['sentiment_score'].hist(bins=30, edgecolor='black', alpha=0.7)
sentiment_ax.set_title("Distribution of Sentiment Scores")
sentiment_ax.set_xlabel("Sentiment Score")
sentiment_ax.set_ylabel("Frequency")
plt.show()

In [None]:
from bertopic import BERTopic
import plotly.express as px

# Locations of the saved model and of the exported visualization
MODEL_PATH = "bertopic_model_update_3"
VISUALIZATION_HTML = "topic_model_visualization.html"

# Restore the saved BERTopic model
topic_model = BERTopic.load(MODEL_PATH)

# Build the topic visualization figure and display it in the notebook
fig = topic_model.visualize_topics()
fig.show()

# Also export the same figure as a standalone HTML file
fig.write_html(VISUALIZATION_HTML)

In [None]:
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
from dotenv import dotenv_values
import plotly.graph_objects as go

# --- Load DB config ---
# Connection settings are read through dotenv_values().
config = dotenv_values()

# Fail early with one readable message listing every missing setting,
# instead of a bare KeyError on whichever key happens to be looked up first.
required_keys = (
    'POSTGRES_USER', 'POSTGRES_HOST', 'POSTGRES_PORT',
    'POSTGRES_DB', 'POSTGRES_SCHEMA', 'POSTGRES_PASS',
)
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
    raise KeyError(f"Missing database settings: {', '.join(missing_keys)}")

pg_user = config['POSTGRES_USER']
pg_host = config['POSTGRES_HOST']
pg_port = config['POSTGRES_PORT']
pg_db = config['POSTGRES_DB']
pg_schema = config['POSTGRES_SCHEMA']
pg_pass = config['POSTGRES_PASS']

# --- Create DB connection ---
# Build the URL from its components rather than by string interpolation:
# a user name or password containing '@', ':' or '/' would otherwise yield
# a malformed connection string.
url = URL.create(
    drivername='postgresql',
    username=pg_user,
    password=pg_pass,
    host=pg_host,
    port=int(pg_port) if pg_port else None,
    database=pg_db,
)
engine = create_engine(url)

# --- Read tables ---
# The schema name is interpolated unquoted into the SQL text; it comes from
# the local config, not from user input.
df_3 = pd.read_sql(f'SELECT * FROM {pg_schema}."table_test_3"', con=engine)
df_4 = pd.read_sql(f'SELECT * FROM {pg_schema}."table_test_4"', con=engine)

# --- Rename topics ---
# Maps raw topic_name strings to short display labels; names not listed are left unchanged.
topic_name_map = dict([
    ('7_thai_pho_chinese_soup', 'Asian Cuisine'),
    ('5_pizza_crust_pizzas_best pizza', 'Pizza'),
    ('4_hair_massage_dress_salon', 'Hair Salon'),
    ('2_dr_dentist_office_dental', 'Dentist / Doctors'),
    ('11_nails_nail_salon_gel', 'Nail Salon Positive'),
    ('9_room_hotel_stay_desk', 'Hotels'),
    ('58_used_changed_gone_quality', 'Change for the Worse'),
    ('31_dog_vet_dogs_cat', 'Pets'),
    ('14_nails_nail_gel_polish', 'Nail Salon Negative'),
    ('111_pharmacy_prescription_walgreens_cvs', 'Pharmacy'),
])

# --- Clean and merge ---
# Keep only the columns used below and drop rows with a missing value in any of them.
metro_columns = ['topic', 'topic_name', 'metro', 'percentage']
category_columns = ['topic', 'topic_name', 'key_category', 'percentage']
df_3 = df_3.loc[:, metro_columns].dropna()
df_4 = df_4.loc[:, category_columns].dropna()

# Inner join on the topic keys; the two 'percentage' columns become
# 'percentage_metro' and 'percentage_category'. The topic names are then
# replaced by their display labels.
df = df_3.merge(
    df_4,
    on=['topic', 'topic_name'],
    suffixes=('_metro', '_category'),
).assign(topic_name=lambda merged: merged['topic_name'].replace(topic_name_map))

# --- Aggregate top combinations ---
# Sum 'percentage_metro' for every (key_category, metro) pair, then keep the
# 30 pairs with the largest totals.
grouped = df.groupby(['key_category', 'metro'])['percentage_metro'].sum().reset_index()
# .copy() makes the top-30 frame independent of `grouped`, so adding columns
# below does not trigger pandas' SettingWithCopyWarning.
top_combinations = grouped.sort_values('percentage_metro', ascending=False).head(30).copy()

# --- Encode labels as integers for axis ---
# factorize numbers the labels in order of first appearance, which matches
# looking each label up in .unique(), without a linear search per row.
metro_codes, metro_index = pd.factorize(top_combinations['metro'])
category_codes, category_index = pd.factorize(top_combinations['key_category'])

# Distinct labels in first-appearance order, kept as arrays for the axis ticks.
metro_list = metro_index.to_numpy()
category_list = category_index.to_numpy()

top_combinations['x'] = metro_codes
top_combinations['y'] = category_codes
top_combinations['z'] = top_combinations['percentage_metro']

# --- Create Plotly 3D Bubble Chart ---
# One hover label per bubble: metro, category and the z value with two decimals.
hover_text = []
for metro_name, category_name, popularity in zip(
    top_combinations['metro'],
    top_combinations['key_category'],
    top_combinations['z'],
):
    hover_text.append(
        f"Metro: {metro_name}<br>Category: {category_name}<br>Popularity: {popularity:.2f}%"
    )

# Bubble size and colour both follow z; sizeref is derived from the largest
# z value and the constant 40.
bubble_marker = dict(
    size=top_combinations['z'],
    sizemode='area',
    sizeref=2.*max(top_combinations['z'])/(40.**2),
    color=top_combinations['z'],
    colorscale='Viridis',
    opacity=0.8
)

bubbles = go.Scatter3d(
    x=top_combinations['x'],
    y=top_combinations['y'],
    z=top_combinations['z'],
    mode='markers',
    marker=bubble_marker,
    text=hover_text
)

# x and y hold integer codes, so the tick labels are mapped back to the
# metro and category names.
scene_axes = dict(
    xaxis=dict(title='Metro', tickvals=list(range(len(metro_list))), ticktext=metro_list),
    yaxis=dict(title='Key Category', tickvals=list(range(len(category_list))), ticktext=category_list),
    zaxis=dict(title='Popularity %'),
)

fig = go.Figure(data=[bubbles])
fig.update_layout(
    title="Top Metro × Key Category Popularity (3D Bubble Chart)",
    scene=scene_axes,
    margin=dict(l=0, r=0, b=0, t=40)
)

fig.show()