In [15]:
# ---------------- Standard libraries ----------------
import os
import re
import time
import shutil
import warnings
from collections import defaultdict
import pickle

# ---------------- Data handling ----------------
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds

# ---------------- Machine learning & NLP ----------------
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics.pairwise import cosine_similarity
import hdbscan
import umap

# ---------------- Deep learning & Transformers ----------------
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer

# ---------------- Visualization ----------------
import plotly.express as px
import plotly.io as pio

Embedding text from Parquet data frames using Gemma300m

In [5]:
# ---------------- Settings ----------------
INPUT_PARQUET = "output/raw_posts_kafka"
BASE_DIR = "output/raw_posts_embeddings_gemma"
MODEL_ID = "google/embeddinggemma-300m"
BATCH_SIZE = 256
MAX_SEQ_LEN = 128
NORMALIZE = False

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# ---------------- NSFW / filtering ----------------
NSFW_KEYWORDS = set([
    "porn","sex","nude","sexy","fuck","cock","cum","blowjob",
    "dick","tits","ass","horny","slut","nsfw","onlyfans"
])

MULTILINGUAL_STOPWORDS = set([
    "the","and","a","of","in","to","is","it","for","on",
    "que","de","le","la","el","en","und","der","die"
])

URL_PATTERN = re.compile(r"(https?://\S+|www\.\S+)")

def preprocess(text):
    text = str(text).lower()
    text = re.sub(r"@\w+", "", text)
    text = re.sub(r"#", " ", text)
    text = re.sub(r"[0-9]+", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    text_words = [w for w in text.split() if w not in MULTILINGUAL_STOPWORDS]
    return " ".join(text_words)

def is_meaningful(text):
    if not re.search(r"\w", text):
        return False
    tokens = text.split()
    if len(tokens) == 0:
        return False
    url_count = sum(bool(URL_PATTERN.match(tok)) for tok in tokens)
    if url_count / len(tokens) > 0.5:
        return False
    return True

def is_nsfw(text):
    text_lower = text.lower()
    return any(kw in text_lower for kw in NSFW_KEYWORDS)

# ---------------- Load posts ----------------
df = pd.read_parquet(INPUT_PARQUET, columns=["did", "rkey", "text"])
print(f"Before filtering: {len(df)} posts.")

df["post_id"] = df["did"].astype(str) + "/" + df["rkey"].astype(str)
df = df[df["text"].notna()]
df["text"] = df["text"].astype(str).str.strip()
df = df[df["text"].str.len() > 0]
df["text_clean"] = df["text"].apply(preprocess)

# ---------------- Preprocessing ----------------
df = df[df["text_clean"].apply(is_meaningful)]
df = df[~df["text_clean"].apply(is_nsfw)]
df = df[df["text_clean"].str.len() > 0]

print(f"After filtering: {len(df)} posts remain.")

# ---------------- Load model ----------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer(MODEL_ID, device=device,truncate_dim=256).eval()
try:
    model.max_seq_length = MAX_SEQ_LEN
except Exception:
    pass

if device == "cuda":
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        pass

# ---------------- Encode embeddings ----------------
texts = df["text_clean"].tolist()
embeddings = model.encode(
    texts,
    prompt_name="Clustering",
    batch_size=BATCH_SIZE,
    show_progress_bar=True,
    normalize_embeddings=False,
    convert_to_numpy=True
).astype(np.float32)

# ---------------- Validate embeddings ----------------
finite_mask = np.isfinite(embeddings).all(axis=1)
if not finite_mask.all():
    n_bad = (~finite_mask).sum()
    print(f"Found {n_bad} rows with NaN/Inf; re-encoding just those rows…")
    bad_idx = np.where(~finite_mask)[0]
    bad_texts = [texts[i] for i in bad_idx]

    fixed = model.encode(
        bad_texts,
        batch_size=max(64, BATCH_SIZE // 2),
        show_progress_bar=True,
        normalize_embeddings=False,
        convert_to_numpy=True
    ).astype(np.float32)

    embeddings[bad_idx] = fixed
    assert np.isfinite(embeddings).all(), "Still found NaN/Inf after repair"

if NORMALIZE:
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    nz = norms.squeeze() > 0
    embeddings[nz] = embeddings[nz] / norms[nz]

# ---------------- Save embeddings ----------------
shutil.rmtree(BASE_DIR, ignore_errors=True)
os.makedirs(BASE_DIR, exist_ok=True)

emb_list_arrays = [pa.array(row, type=pa.float32()) for row in embeddings]
table = pa.table({
    "post_id": pa.array(df["post_id"].tolist()),
    "text": pa.array(df["text"].tolist()),
    "text_clean": pa.array(df["text_clean"].tolist()),
    "embedding": pa.array(emb_list_arrays, type=pa.list_(pa.float32()))
})

fmt = ds.ParquetFileFormat()
opts = fmt.make_write_options(compression=os.environ.get("PARQUET_COMPRESSION", "zstd"))

ds.write_dataset(
    data=table,
    base_dir=BASE_DIR,
    format=fmt,
    file_options=opts,
    existing_data_behavior="overwrite_or_ignore"
)

print(f"✓ Wrote dataset to directory: {BASE_DIR}")

Before filtering: 17771 posts.
After filtering: 16177 posts remain.


Batches:   0%|          | 0/64 [00:00<?, ?it/s]

✓ Wrote dataset to directory: output/raw_posts_embeddings_gemma


Cluster the posts and prep for visualisation

In [None]:
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore",message="n_jobs value .* overridden to 1 by setting random_state")
BASE_DIR = "output/raw_posts_embeddings_gemma"
OUTPUT_DIR = "output"

# ---------------- Load embeddings ----------------
print("Loading Embeddings...")
dataset = ds.dataset(BASE_DIR, format="parquet")
table = dataset.to_table(columns=["post_id", "text", "text_clean", "embedding"])

post_id = table.column("post_id").to_pylist()
texts = table.column("text").to_pylist()
texts_clean = table.column("text_clean").to_pylist()
embeddings = np.vstack(table.column("embedding").to_pylist()).astype(np.float32)

# Normalize embeddings
print("Normalizing Embeddings...")
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
nonzero = norms.squeeze() > 0
embeddings[nonzero] /= norms[nonzero]

#remove comment below if you want further dim reduction before clustering or not
"""
# ---------------- Dimensionality reduction ----------------

print("Dimensionality reduction...")
print("Embedding Dimension ---> 64D...")
start_time = time.time()
reducer_high = umap.UMAP(random_state=2025,n_components=64, metric="cosine")
X_64d = reducer_high.fit_transform(embeddings)
end_time = time.time()
print(f"Dimensionality reduction done in {end_time - start_time:.2f} seconds")
embeddings = X_64d
"""

# ---------------- HDBSCAN clustering ---------------- 
print("HDBSCAN clustering...")
print(f"Number of embeddings: {len(embeddings)}")
start_time = time.time()

clusterer = hdbscan.HDBSCAN(min_cluster_size=10,metric="euclidean",cluster_selection_method="eom")
labels = clusterer.fit_predict(embeddings) 

end_time = time.time()
print(f"HDBSCAN clustering done in {end_time - start_time:.2f} seconds")
# ---------------- Group clusters ----------------
clusters = defaultdict(list)
for i, c in enumerate(labels):
    if c == -1:
        continue
    clusters[c].append(i)

print(f"Found {len(clusters)} clusters")
# ---------------- Compute cluster centroids in embedding space ----------------
cluster_ids = []
cluster_centroids = []

for c, idxs in clusters.items():
    cluster_ids.append(c)
    cluster_centroids.append(embeddings[idxs].mean(axis=0))

cluster_centroids = np.vstack(cluster_centroids)

# ---------------- Reduce centroids to 2D for visualization ----------------
print("Reducing cluster centroids to 2D...")
reducer_2d = umap.UMAP(random_state=2025,n_components=2,metric="cosine",  min_dist=0.1,spread=1.0,)
X_2d_centroids = reducer_2d.fit_transform(cluster_centroids)
# Map back to clusters for plotting
cluster_positions = {c: X_2d_centroids[i] for i, c in enumerate(cluster_ids)}

Loading Embeddings...
Normalizing Embeddings...
Dimensionality reduction...
Embedding Dimension ---> 64D...
HDBSCAN clustering...
Number of embeddings: 16177
HDBSCAN clustering done in 4.21 seconds
Found 118 clusters
Reducing cluster centroids to 2D...


Create labels for clusters using GenAI

In [18]:
# ---------------- Gemma 2 Cluster Labeling ----------------
print("Generating cluster labels with Gemma 2...")
model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
model = model.to("cuda")
device = model.device
def clean_post(text):
    text = re.sub(r"http\S+", "", text)  # remove URLs
    text = re.sub(r"[^\w\s]", "", text)  # remove punctuation
    text = re.sub(r"\s+", " ", text)     # normalize whitespace
    return text.strip()

def generate_cluster_label(posts, max_new_tokens=16):
    cleaned_posts = [clean_post(p) for p in posts[:5]]

    prompt = (
        "Summarize the following posts into 1 concise topic of 3–5 words in English. "
        "Do NOT include emojis, quotes, punctuation, hashtags, URLs, or extra commentary.\n\n"
        + "\n".join(cleaned_posts)
    )

    user_prompt = {"role": "user", "content": prompt}
    assistant_prompt = {"role": "assistant", "content": "Provide a concise topic."}

    # Returns a single tensor
    chat_prompt = tokenizer.apply_chat_template(
        [user_prompt, assistant_prompt],
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True
    ).to(device)  # single tensor

    outputs = model.generate(chat_prompt, max_new_tokens=max_new_tokens)
    label = tokenizer.decode(outputs[0], skip_special_tokens=True)
    label = re.sub(r'\s+', ' ', label).strip()
    parts = re.split(r'model\s+Provide a concise topic\.?\s*model', label, flags=re.IGNORECASE)
    label = parts[-1].strip()
    return label


# ---------------- Generate cluster summaries ----------------
cluster_summary = []

for i, c in enumerate(cluster_ids):
    doc_idx = clusters[c]
    if not doc_idx:
        continue

    E_cluster = embeddings[doc_idx]
    centroid_hd = E_cluster.mean(axis=0)
    sim = cosine_similarity(E_cluster, centroid_hd.reshape(1, -1)).ravel()
    top_idx = np.argsort(-sim)[:3]
    top_posts = [texts_clean[doc_idx[j]] for j in top_idx]

    
    label_text = generate_cluster_label(top_posts)
    

    centroid_2d = X_2d_centroids[i]
    cluster_summary.append({
        "cluster": int(c),
        "size": len(doc_idx),
        "x": float(centroid_2d[0]),
        "y": float(centroid_2d[1]),
        "summary": label_text
    })

# ---------------- Save cluster summary ----------------
with open(os.path.join(OUTPUT_DIR, "clusters.pkl"), "wb") as f:
    pickle.dump(cluster_summary, f)

print(f"✓ Saved {len(cluster_summary)} clusters with Gemma 2 labels to {os.path.join(OUTPUT_DIR, 'clusters.pkl')}")

Generating cluster labels with Gemma 2...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Saved 118 clusters with Gemma 2 labels to output\clusters.pkl


Visualize

In [19]:
OUTPUT_DIR = "output"

# ---------------- Load cluster summary ----------------
with open(os.path.join(OUTPUT_DIR, "clusters.pkl"), "rb") as f:
    cluster_summary = pickle.load(f)

df_plot = pd.DataFrame(cluster_summary)

# ---------------- Scale positions to [0,1] for visualization ----------------
scaler = MinMaxScaler(feature_range=(0, 1))
df_plot[["x", "y"]] = scaler.fit_transform(df_plot[["x", "y"]])

# ---------------- Scale bubble size for visibility ----------------
df_plot["size_scaled"] = np.sqrt(df_plot["size"])

# ---------------- Assign random color per cluster ----------------
unique_clusters = df_plot["cluster"].unique()
colors = px.colors.sample_colorscale(
    "Rainbow", [i / (len(unique_clusters) - 1) for i in range(len(unique_clusters))]
)
color_map = dict(zip(unique_clusters, colors))
df_plot["color"] = df_plot["cluster"].map(color_map)

# ---------------- Add readable label column ----------------
df_plot["label"] = df_plot["cluster"].apply(lambda c: f"Cluster {c}")

# ---------------- Plot with summaries overlayed ----------------
fig = px.scatter(
    df_plot,
    x="x",
    y="y",
    size="size_scaled",
    color="color",
    text="summary",  # overlay AI-generated summaries
    hover_data={"label": True, "size": True, "cluster": True},
    title="Clusters of Posts",
    width=1000,
    height=700,
)

# ---------------- Improve visualization ----------------
fig.update_traces(textposition="middle center", textfont_size=10)
fig.update_layout(
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    showlegend=False  
)

# ---------------- Use browser renderer to avoid nbformat error ----------------
pio.renderers.default = "browser"
fig.show()



