# TriTopic: Complete Feature Demo

This notebook demonstrates **every feature** of the TriTopic library using the 20 Newsgroups dataset. It serves as a practical reference for how to use TriTopic in your own projects.

**What we will cover:**

1. Loading and preparing a news dataset
2. Fitting a model with dimensionality reduction and iterative refinement
3. Inspecting topics: keywords, representative documents, topic info
4. Dimensionality reduction: full vs. reduced embeddings
5. Soft topic assignments (probabilities)
6. Outlier reduction (embedding and neighbor strategies)
7. Topic merging (`reduce_topics` and `merge_topics`)
8. Evaluation metrics
9. All 5 visualization types
10. Predicting topics for new documents (hard labels + soft probabilities)
11. LLM-powered and rule-based topic labeling
12. Custom configuration
13. Save and load
14. Comparison with ground truth

---
## 1. Setup and Data Loading

We use scikit-learn's **20 Newsgroups** dataset, selecting 6 distinct categories to get clear, well-separated topics. We remove headers, footers, and quotes so the model works on article body text only.

In [None]:
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_20newsgroups

# Select 6 distinct categories
categories = [
    "sci.med",
    "sci.space",
    "rec.sport.baseball",
    "rec.autos",
    "comp.graphics",
    "talk.politics.guns",
]

newsgroups = fetch_20newsgroups(
    subset="all",
    remove=("headers", "footers", "quotes"),
    categories=categories,
)

documents = newsgroups.data
true_labels = newsgroups.target
target_names = newsgroups.target_names

print(f"Loaded {len(documents)} documents from {len(target_names)} categories")
print(f"Categories: {target_names}")
print(f"\nSample document (first 300 chars):\n{documents[0][:300]}...")

Some newsgroup posts are very short or empty after header removal. Let's filter to documents with at least 50 characters.

In [None]:
# Filter out very short documents
mask = np.array([len(doc.strip()) >= 50 for doc in documents])
documents = [doc for doc, keep in zip(documents, mask) if keep]
true_labels = true_labels[mask]

print(f"After filtering: {len(documents)} documents")
print(f"Documents per category:")
for i, name in enumerate(target_names):
    print(f"  {name}: {(true_labels == i).sum()}")

---
## 2. Fitting the Model

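We create a TriTopic model, passing `n_neighbors`, `use_iterative_refinement`, and `verbose` explicitly and leaving every other setting at its default. Fitting will: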
1. Encode documents with `all-MiniLM-L6-v2` (384d embeddings)
2. Reduce to 10d with UMAP for better kNN neighbor quality
3. Build a hybrid multi-view graph (semantic + lexical)
4. Run consensus Leiden clustering (10 runs)
5. Iteratively refine embeddings until convergence
6. Extract c-TF-IDF keywords and compute soft probabilities
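
Step 6 relies on class-based TF-IDF (c-TF-IDF). As a rough illustration, here is a minimal NumPy sketch of the weighting popularized by BERTopic: all documents of a topic are concatenated into one "class document", and each term gets `tf * log(1 + A / f_t)`, where `A` is the average word count per topic and `f_t` is the term's total frequency. Whether TriTopic uses exactly this formula (or adds normalization, stop-word removal, or n-grams) is an assumption, and `ctfidf_keywords` is a hypothetical helper, not part of the library.

```python
import numpy as np

def ctfidf_keywords(docs, labels, n_keywords=3):
    """Hypothetical helper: top terms per topic using BERTopic-style c-TF-IDF weighting."""
    kept = [(doc, int(lab)) for doc, lab in zip(docs, labels) if lab != -1]  # drop outliers
    topics = sorted({lab for _, lab in kept})
    row_of = {t: r for r, t in enumerate(topics)}
    # Naive whitespace tokenization; a real pipeline would use a proper vectorizer
    vocab = sorted({w for doc, _ in kept for w in doc.lower().split()})
    col_of = {w: c for c, w in enumerate(vocab)}

    # tf[c, t]: count of term t in the concatenation of all documents of topic c
    tf = np.zeros((len(topics), len(vocab)))
    for doc, lab in kept:
        for w in doc.lower().split():
            tf[row_of[lab], col_of[w]] += 1

    avg_words = tf.sum() / len(topics)   # A: average number of words per topic
    term_freq = tf.sum(axis=0)           # f_t: frequency of each term across all topics
    weights = tf * np.log(1 + avg_words / term_freq)

    return {
        t: [vocab[c] for c in np.argsort(-weights[r])[:n_keywords]]
        for t, r in row_of.items()
    }

demo_docs = [
    "rocket launch orbit", "orbit satellite launch",
    "pitcher inning homer", "inning pitcher bat",
    "misc unrelated words",
]
demo_labels = [0, 0, 1, 1, -1]
print(ctfidf_keywords(demo_docs, demo_labels, n_keywords=2))
```

Terms that are frequent inside one topic but rare elsewhere ("launch", "orbit" for topic 0) get the highest weights, which is why c-TF-IDF keywords tend to be distinctive rather than merely common.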

In [None]:
from tritopic import TriTopic

model = TriTopic(
    n_neighbors=15,
    use_iterative_refinement=True,
    verbose=True,
)

labels = model.fit_transform(documents)

---
## 3. Inspecting Topics

After fitting, we can explore the discovered topics through several interfaces.

### 3a. Topic overview table

`get_topic_info()` returns a DataFrame with topic IDs, sizes, keywords, labels, and coherence scores.

In [None]:
topic_df = model.get_topic_info()
topic_df[["Topic", "Size", "Keywords"]]

### 3b. Single topic detail

`get_topic(topic_id)` returns a `TopicInfo` object with all fields: keywords, scores, representative doc indices, centroid, etc.

In [None]:
# Pick the largest non-outlier topic
largest_topic = max((t for t in model.topics_ if t.topic_id != -1), key=lambda t: t.size)
topic = model.get_topic(largest_topic.topic_id)

print(f"Topic {topic.topic_id}")
print(f"  Size: {topic.size} documents")
print(f"  Top 10 keywords: {topic.keywords[:10]}")
print(f"  Keyword scores:  {[f'{s:.4f}' for s in topic.keyword_scores[:10]]}")
print(f"  Centroid shape:  {topic.centroid.shape}")
print(f"  Representative doc indices: {topic.representative_docs}")

### 3c. Representative documents

`get_representative_docs()` returns the documents closest to the topic centroid.
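
The idea can be sketched in a few lines of NumPy, assuming the centroid is the mean embedding of the topic's documents and "closest" means highest cosine similarity. TriTopic may define the centroid or the ranking differently; `representative_indices` is a hypothetical helper for illustration only.

```python
import numpy as np

def representative_indices(embeddings, labels, topic_id, n_docs=2):
    """Hypothetical helper: indices of the n_docs documents closest to a topic centroid."""
    labels = np.asarray(labels)
    members = np.flatnonzero(labels == topic_id)   # document indices in this topic
    member_emb = embeddings[members]
    centroid = member_emb.mean(axis=0)
    # Cosine similarity = dot product of L2-normalized vectors
    unit_members = member_emb / np.linalg.norm(member_emb, axis=1, keepdims=True)
    sims = unit_members @ (centroid / np.linalg.norm(centroid))
    order = np.argsort(-sims)[:n_docs]             # most similar first
    return members[order].tolist()

demo_emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.6, 0.4]])
demo_labels = [0, 0, 1, 0]
print(representative_indices(demo_emb, demo_labels, topic_id=0, n_docs=1))
```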

In [None]:
# Show representative docs for the first 3 non-outlier topics
non_outlier_topics = [t for t in model.topics_ if t.topic_id != -1]
for topic_info in non_outlier_topics[:3]:

    print(f"\n{'='*70}")
    print(f"Topic {topic_info.topic_id} | Keywords: {', '.join(topic_info.keywords[:5])}")
    print(f"{'='*70}")

    rep_docs = model.get_representative_docs(topic_info.topic_id, n_docs=2)
    for idx, text in rep_docs:
        print(f"  [Doc {idx}] {text[:200].strip()}...")
        print()

---
## 4. Dimensionality Reduction

By default, TriTopic reduces embeddings from 384d to 10d using UMAP before building the kNN graph. This combats the curse of dimensionality. Full-dimensional embeddings are preserved for centroid computation and keyword extraction.
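
The "curse of dimensionality" argument can be illustrated with a quick NumPy experiment that is independent of TriTopic: for random points, the relative gap between the nearest and the farthest neighbor shrinks as dimensionality grows, so "nearest" neighbors become less distinctive. The data here is synthetic isotropic Gaussian noise; real sentence embeddings are not isotropic, so treat this as intuition rather than a measurement of TriTopic's behavior.

```python
import numpy as np

def relative_contrast(dim, n_points=500, seed=0):
    """(farthest - nearest) / nearest distance from one random query to random points."""
    rng = np.random.default_rng(seed)
    points = rng.standard_normal((n_points, dim))
    query = rng.standard_normal(dim)
    dists = np.linalg.norm(points - query, axis=1)
    return (dists.max() - dists.min()) / dists.min()

for dim in (2, 10, 384):
    print(f"dim={dim:4d}  relative contrast={relative_contrast(dim):.2f}")
```

The contrast drops sharply between 10 and 384 dimensions, which is the motivation for building the kNN graph on reduced embeddings.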

In [None]:
print(f"Full embeddings shape:    {model.embeddings_.shape}")
print(f"Reduced embeddings shape: {model.reduced_embeddings_.shape}")
print(f"Topic centroids shape:    {model.topic_embeddings_.shape}  (computed from full embeddings)")
print(f"\nDim reduction method: {model.config.dim_reduction_method}")
print(f"Target dimensions:    {model.config.reduced_dims}")
print(f"UMAP min_dist:        {model.config.umap_min_dist}  (0.0 = optimized for clustering)")

---
## 5. Soft Topic Assignments (Probabilities)

Every document receives a probability distribution over all topics, not just a hard label. This is computed via cosine similarity to topic centroids, followed by softmax.
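
A minimal sketch of that computation follows, assuming cosine similarity on L2-normalized vectors and a softmax temperature. The `temperature` argument and its default are hypothetical and not documented TriTopic settings; `soft_assignments` is an illustrative helper.

```python
import numpy as np

def soft_assignments(doc_embeddings, centroids, temperature=0.1):
    """Cosine similarity to each centroid, turned into a distribution with softmax."""
    docs = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    cents = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
    sims = docs @ cents.T                         # shape: (n_docs, n_topics)
    logits = sims / temperature                   # hypothetical temperature scaling
    logits -= logits.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
probs = soft_assignments(rng.standard_normal((4, 384)), rng.standard_normal((6, 384)))
print(probs.shape, probs.sum(axis=1))
```

Some scaling of this kind matters because cosine similarities live in [-1, 1]; a plain softmax over such a narrow range yields fairly flat distributions, while a small temperature sharpens them toward the most similar topic.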

In [None]:
print(f"Probabilities shape: {model.probabilities_.shape}")
print(f"  (n_documents={model.probabilities_.shape[0]}, n_topics={model.probabilities_.shape[1]})")
print(f"\nRow sums (should be ~1.0):")
print(f"  First 5 rows: {model.probabilities_[:5].sum(axis=1)}")

In [None]:
# Show the probability distribution for a specific document
doc_idx = 0
non_outlier_ids = [t.topic_id for t in model.topics_ if t.topic_id != -1]

print(f"Document {doc_idx}: \"{documents[doc_idx][:100].strip()}...\"")
print(f"Hard label: {model.labels_[doc_idx]}")
print(f"\nSoft probabilities:")

for i, tid in enumerate(non_outlier_ids):
    prob = model.probabilities_[doc_idx, i]
    topic = model.get_topic(tid)
    bar = '#' * int(prob * 50)
    print(f"  Topic {tid:2d} ({', '.join(topic.keywords[:3]):30s}): {prob:.4f} {bar}")

---
## 6. Outlier Reduction

Leiden clustering combined with small-cluster removal often produces outliers. TriTopic provides two strategies to reassign them.

In [None]:
n_outliers_before = int((model.labels_ == -1).sum())
n_total = len(model.labels_)
print(f"Before outlier reduction: {n_outliers_before} outliers ({100*n_outliers_before/n_total:.1f}%)")

### Strategy 1: Embeddings

Each outlier is assigned to the topic whose centroid is most similar (cosine similarity), if the similarity exceeds a threshold.
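
A minimal sketch of this strategy, assuming each centroid is the normalized mean of its members' L2-normalized embeddings (TriTopic's exact centroid definition is an assumption). `reassign_by_centroid` is a hypothetical helper, not the library's implementation.

```python
import numpy as np

def reassign_by_centroid(embeddings, labels, threshold=0.05):
    """Hypothetical sketch: move outliers (-1) to the most cosine-similar topic centroid."""
    labels = np.asarray(labels).copy()
    outliers = np.flatnonzero(labels == -1)
    if len(outliers) == 0:
        return labels
    topic_ids = sorted(set(labels.tolist()) - {-1})
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Assumption: centroid = normalized mean of the normalized member embeddings
    centroids = np.stack([normed[labels == t].mean(axis=0) for t in topic_ids])
    centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)

    sims = normed[outliers] @ centroids.T            # (n_outliers, n_topics)
    best = sims.argmax(axis=1)
    accept = sims[np.arange(len(outliers)), best] >= threshold
    labels[outliers[accept]] = np.asarray(topic_ids)[best[accept]]
    return labels

demo_emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [0.8, 0.3], [-1.0, -1.0]])
demo_labels = [0, 0, 1, 1, -1, -1]
print(reassign_by_centroid(demo_emb, demo_labels, threshold=0.05))
```

The last document points away from both centroids (negative similarity), so it stays an outlier even with a low threshold.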

In [None]:
# Use a low threshold for aggressive reassignment
model.reduce_outliers(strategy="embeddings", threshold=0.05)

n_outliers_after = int((model.labels_ == -1).sum())
print(f"After 'embeddings' strategy: {n_outliers_after} outliers ({100*n_outliers_after/n_total:.1f}%)")
print(f"Reassigned: {n_outliers_before - n_outliers_after} documents")

### Strategy 2: Neighbors

Each remaining outlier is assigned by majority vote of its k nearest non-outlier neighbors. This is threshold-free.
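
As a sketch under stated assumptions (cosine similarity as the neighbor metric and a hypothetical default of `k=3`), the vote can be written with NumPy and `collections.Counter`. `reassign_by_neighbors` is illustrative only; TriTopic's neighbor source (for example the kNN graph it already built) may differ.

```python
from collections import Counter

import numpy as np

def reassign_by_neighbors(embeddings, labels, k=3):
    """Hypothetical sketch: label each outlier by majority vote of its k nearest non-outliers."""
    labels = np.asarray(labels).copy()
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    inliers = np.flatnonzero(labels != -1)          # fixed before any reassignment
    for i in np.flatnonzero(labels == -1):
        sims = normed[inliers] @ normed[i]          # cosine similarity to every inlier
        nearest = inliers[np.argsort(-sims)[:k]]    # k most similar non-outliers
        votes = Counter(labels[nearest].tolist())
        labels[i] = votes.most_common(1)[0][0]      # majority topic wins
    return labels

demo_emb = np.array([[1.0, 0.0], [0.95, 0.05], [0.9, 0.2], [0.0, 1.0], [0.1, 0.9], [0.85, 0.15]])
demo_labels = [0, 0, 0, 1, 1, -1]
print(reassign_by_neighbors(demo_emb, demo_labels, k=3))
```

Because every outlier receives some neighbor's label, this strategy leaves no outliers behind, which is why it works well as a second pass after the threshold-based one.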

In [None]:
n_before_neighbors = int((model.labels_ == -1).sum())

model.reduce_outliers(strategy="neighbors")

n_after_neighbors = int((model.labels_ == -1).sum())
print(f"After 'neighbors' strategy: {n_after_neighbors} outliers ({100*n_after_neighbors/n_total:.1f}%)")
print(f"Reassigned: {n_before_neighbors - n_after_neighbors} additional documents")
print(f"\nTotal outliers reduced: {n_outliers_before} -> {n_after_neighbors}")

In [None]:
# Verify that downstream state was refreshed
print("Updated topic sizes after outlier reduction:")
for t in model.topics_:
    label = "Outliers" if t.topic_id == -1 else f"Topic {t.topic_id}"
    print(f"  {label}: {t.size} docs")

print(f"\nProbabilities shape: {model.probabilities_.shape}  (recomputed automatically)")

---
## 7. Topic Merging

TriTopic offers two ways to merge topics after fitting: automatic merging to a target count, and manual merging of specific topic IDs.

In [None]:
n_topics_before = len([t for t in model.topics_ if t.topic_id != -1])
print(f"Topics before merging: {n_topics_before}")
print(f"Topic IDs: {[t.topic_id for t in model.topics_ if t.topic_id != -1]}")

### 7a. Automatic: `reduce_topics(n_topics)`

Iteratively merges the two most cosine-similar topic centroids until the target count is reached.
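
The greedy loop can be sketched as follows. Two details are assumptions rather than documented TriTopic behavior: the merged centroid is the size-weighted mean of the two centroids, and the surviving ID is the larger topic's (the convention the text states for `merge_topics`). `reduce_by_centroid_merging` is a hypothetical helper.

```python
import numpy as np

def reduce_by_centroid_merging(centroids, sizes, n_topics):
    """Hypothetical sketch: merge the two most cosine-similar topics until n_topics remain.

    centroids: dict topic_id -> vector; sizes: dict topic_id -> document count.
    Returns a mapping old_topic_id -> surviving topic_id.
    """
    centroids = {t: np.asarray(c, dtype=float) for t, c in centroids.items()}
    sizes = dict(sizes)
    mapping = {t: t for t in centroids}
    while len(centroids) > n_topics:
        ids = list(centroids)
        unit = np.stack([centroids[t] / np.linalg.norm(centroids[t]) for t in ids])
        sims = unit @ unit.T
        np.fill_diagonal(sims, -np.inf)               # ignore self-similarity
        i, j = np.unravel_index(np.argmax(sims), sims.shape)
        # Assumption: the larger topic keeps its ID
        keep, drop = sorted((ids[i], ids[j]), key=lambda t: sizes[t], reverse=True)
        total = sizes[keep] + sizes[drop]
        # Assumption: size-weighted average of the two centroids
        centroids[keep] = (centroids[keep] * sizes[keep] + centroids[drop] * sizes[drop]) / total
        sizes[keep] = total
        del centroids[drop], sizes[drop]
        for t, target in mapping.items():
            if target == drop:
                mapping[t] = keep
    return mapping

demo_centroids = {0: [1.0, 0.0], 1: [0.9, 0.1], 2: [0.0, 1.0]}
demo_sizes = {0: 50, 1: 20, 2: 40}
print(reduce_by_centroid_merging(demo_centroids, demo_sizes, n_topics=2))
```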

In [None]:
# Reduce to 6 topics (matching our 6 ground-truth categories)
model.reduce_topics(6)

n_topics_after = len([t for t in model.topics_ if t.topic_id != -1])
print(f"Topics after reduce_topics(6): {n_topics_after}")
print()
model.get_topic_info()[["Topic", "Size", "Keywords"]]

### 7b. Manual: `merge_topics(topic_ids)`

Merge specific topic IDs together. The largest topic's ID is kept.
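
The label bookkeeping behind a manual merge is small enough to sketch directly. This covers only the relabeling; refreshing centroids, keywords, and probabilities afterward is left out. `merge_labels` is a hypothetical helper, not TriTopic's implementation.

```python
import numpy as np

def merge_labels(labels, topic_ids):
    """Hypothetical sketch: relabel all documents in topic_ids to the largest topic's ID."""
    labels = np.asarray(labels).copy()
    sizes = {t: int((labels == t).sum()) for t in topic_ids}
    keep = max(topic_ids, key=lambda t: sizes[t])     # largest topic's ID survives
    labels[np.isin(labels, topic_ids)] = keep
    return labels, keep

merged, kept_id = merge_labels([0, 0, 1, 2, 2, 2, -1], [0, 2])
print(kept_id, merged.tolist())
```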

In [None]:
# Demonstrate manual merge on two topics (pick the two smallest)
non_outlier = [t for t in model.topics_ if t.topic_id != -1]
if len(non_outlier) >= 3:
    # Sort by size to find the two smallest
    smallest = sorted(non_outlier, key=lambda t: t.size)[:2]
    ids_to_merge = [t.topic_id for t in smallest]
    print(f"Merging topics {ids_to_merge} (sizes: {[t.size for t in smallest]})")

    model.merge_topics(ids_to_merge)

    n_after_merge = len([t for t in model.topics_ if t.topic_id != -1])
    print(f"Topics after manual merge: {n_after_merge}")
    print()
    print(model.get_topic_info()[["Topic", "Size", "Keywords"]].to_string())
else:
    print("Not enough topics to demonstrate manual merge (need >= 3).")

---
## 8. Evaluation

`evaluate()` computes coherence, diversity, stability, and outlier ratio.
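
Of these, diversity is the easiest to reproduce by hand. A common definition is the fraction of unique words among the top-k keywords of all topics (1.0 means no keyword is shared between topics). Whether `evaluate()` uses exactly this definition, and which k it uses, is an assumption; `topic_diversity` is a hypothetical helper.

```python
def topic_diversity(topic_keywords, top_k=10):
    """Fraction of unique words among the top_k keywords of every topic (1.0 = no overlap)."""
    top_words = [w for kws in topic_keywords for w in kws[:top_k]]
    return len(set(top_words)) / len(top_words)

demo_keywords = [["space", "nasa", "orbit"], ["car", "engine", "orbit"]]
print(topic_diversity(demo_keywords, top_k=3))  # 5 unique words out of 6
```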

In [None]:
metrics = model.evaluate()

print("\nMetrics dictionary:")
for key, value in metrics.items():
    if value is not None:
        print(f"  {key:20s}: {value:.4f}" if isinstance(value, float) else f"  {key:20s}: {value}")
    else:
        print(f"  {key:20s}: N/A")

### Additional standalone metrics

In [None]:
from tritopic.utils.metrics import compute_silhouette, compute_downstream_score

# Silhouette score (cluster separation quality)
sil = compute_silhouette(model.embeddings_, model.labels_)
print(f"Silhouette score: {sil:.4f}")

# Downstream classification F1 (using ground truth)
f1 = compute_downstream_score(
    model.embeddings_, model.labels_, true_labels, task="classification"
)
print(f"Downstream F1 (macro): {f1:.4f}")

---
## 9. Visualizations

TriTopic provides 5 interactive visualization types, all powered by Plotly.

### 9a. Document map

2D UMAP projection where each point is a document, colored by topic assignment.

In [None]:
fig = model.visualize(
    method="umap",
    show_outliers=True,
    title="TriTopic Document Map (20 Newsgroups)",
)
fig.show()

### 9b. Topic keywords

Horizontal bar charts showing the top keywords and their importance scores for each topic.

In [None]:
fig = model.visualize_topics(
    n_keywords=8,
    title="Topic Keywords Overview",
)
fig.show()

### 9c. Topic hierarchy (dendrogram)

Shows how topics relate to each other based on cosine distance between centroids.
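
The same kind of tree can be built with SciPy's agglomerative clustering on cosine distances between centroids. The linkage method TriTopic uses is an assumption here (average linkage is shown); the resulting matrix `Z` could be passed to `scipy.cluster.hierarchy.dendrogram` for plotting.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist

demo_centroids = np.array([[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 0.0, 1.0]])
distances = pdist(demo_centroids, metric="cosine")  # condensed matrix of 1 - cosine similarity
Z = linkage(distances, method="average")            # rows: [idx_a, idx_b, distance, n_items]
print(Z)
```

The first row joins the two nearly parallel centroids (0 and 1) at a small distance; the second row joins that new cluster with topic 2.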

In [None]:
fig = model.visualize_hierarchy(
    title="Topic Hierarchy (Dendrogram)",
)
fig.show()

### 9d. Topic similarity heatmap

Cosine similarity between all topic centroid pairs.

In [None]:
from tritopic import TopicVisualizer

viz = TopicVisualizer()
fig = viz.plot_topic_similarity(
    model.topic_embeddings_,
    model.topics_,
    title="Topic Similarity Matrix",
)
fig.show()

### 9e. Topics over time

Stacked area chart showing topic prevalence over time. Since 20 Newsgroups doesn't have timestamps, we generate synthetic dates for demonstration purposes.
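
Under the hood, a chart like this needs a (time bin × topic) table. A pandas sketch of that aggregation follows; whether `plot_topic_over_time` plots raw counts or per-period shares, and how it bins timestamps, is an assumption. The labels and timestamps below are hypothetical stand-ins, named so they do not clash with the notebook's variables.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for model.labels_ and the synthetic monthly timestamps
rng = np.random.default_rng(42)
demo_labels = rng.integers(0, 3, size=200)
months = pd.date_range("2024-01-01", periods=12, freq="MS")
demo_timestamps = rng.choice(months.to_numpy(), size=200)

# Count documents per (month, topic): rows are time bins, columns are topics
counts = pd.crosstab(
    pd.Series(pd.to_datetime(demo_timestamps).to_period("M"), name="month"),
    pd.Series(demo_labels, name="topic"),
)
# Optionally normalize each month so the topic shares sum to 1
shares = counts.div(counts.sum(axis=1), axis=0)
print(shares.round(2).head())
```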

In [None]:
# Generate synthetic timestamps for demonstration
np.random.seed(42)
fake_dates = pd.date_range("2024-01-01", periods=12, freq="MS")  # 12 months
timestamps = np.random.choice(fake_dates, size=len(documents))

fig = viz.plot_topic_over_time(
    labels=model.labels_,
    timestamps=timestamps,
    topics=model.topics_,
    title="Topic Distribution Over Time (synthetic dates)",
)
fig.show()

---
## 10. Predicting Topics for New Documents

After fitting, you can assign topics to documents the model has never seen.

### 10a. Hard labels via `transform()`

In [None]:
new_documents = [
    "NASA's Perseverance rover discovered organic molecules on Mars, raising hopes for ancient microbial life.",
    "The Yankees clinched the division title with a walk-off home run in the bottom of the ninth inning.",
    "NVIDIA announced its next-generation GPU with real-time ray tracing and AI-accelerated rendering.",
    "A new study in the Lancet shows promising results for an mRNA-based cancer vaccine in phase 3 trials.",
    "The debate over gun control legislation intensified after another mass shooting incident.",
    "Ford revealed its all-electric F-150 Lightning truck with 300 miles of range and fast charging.",
]

new_labels = model.transform(new_documents)

print("Predicted topics for new documents:\n")
for doc, label in zip(new_documents, new_labels):
    if label == -1:
        topic_keywords = "(outlier)"
    else:
        topic = model.get_topic(label)
        topic_keywords = ", ".join(topic.keywords[:4])
    print(f"  Topic {label:2d} [{topic_keywords}]")
    print(f"    -> {doc[:90]}...")
    print()

### 10b. Soft probabilities via `transform_proba()`

In [None]:
new_proba = model.transform_proba(new_documents)

print(f"Probabilities shape: {new_proba.shape}")
print(f"Row sums: {new_proba.sum(axis=1)}\n")

# Show probabilities as a DataFrame
non_outlier_ids = [t.topic_id for t in model.topics_ if t.topic_id != -1]
proba_df = pd.DataFrame(
    new_proba,
    columns=[f"Topic {tid}" for tid in non_outlier_ids],
    index=[f"Doc {i}: {doc[:50]}..." for i, doc in enumerate(new_documents)],
)
proba_df.style.background_gradient(cmap="YlOrRd", axis=1).format("{:.3f}")

---
## 11. Topic Labeling

TriTopic supports LLM-powered labeling (Claude, GPT-4) and a simple rule-based fallback.

### 11a. SimpleLabeler (no API key needed)

Creates labels from the top keywords. Good for quick exploration.

In [None]:
from tritopic import SimpleLabeler

simple_labeler = SimpleLabeler(n_words=3)
model.generate_labels(simple_labeler)

print("Topics with simple labels:\n")
for topic in model.topics_:
    if topic.topic_id == -1:
        continue
    print(f"  Topic {topic.topic_id}: {topic.label}")
    print(f"    {topic.description}")
    print()

### 11b. LLM Labeler (requires API key)

Uncomment one of the options below and set the matching environment variable (`ANTHROPIC_API_KEY` or `OPENAI_API_KEY`) to generate higher-quality labels with Claude or GPT-4.

In [None]:
# ---- Uncomment one of the options below ----
# Reading the key from an environment variable avoids hard-coding secrets in the notebook.

# Option A: Claude (Anthropic)
# import os
# from tritopic import LLMLabeler
# labeler = LLMLabeler(
#     provider="anthropic",
#     api_key=os.environ["ANTHROPIC_API_KEY"],
#     model="claude-3-haiku-20240307",
#     language="english",
#     domain_hint="news articles",
# )
# model.generate_labels(labeler)

# Option B: GPT-4 (OpenAI)
# import os
# from tritopic import LLMLabeler
# labeler = LLMLabeler(
#     provider="openai",
#     api_key=os.environ["OPENAI_API_KEY"],
#     model="gpt-4o-mini",
#     language="english",
# )
# model.generate_labels(labeler)

# View results (these are the SimpleLabeler labels unless an option above was run)
model.get_topic_info()[["Topic", "Label", "Description"]]

---
## 12. Custom Configuration

TriTopic is highly configurable. Here we show how to create a model with a custom `TriTopicConfig`.

In [None]:
from tritopic import TriTopicConfig

custom_config = TriTopicConfig(
    # Embedding model: must match any precomputed embeddings passed to fit_transform below
    # (swap to "all-mpnet-base-v2" for higher quality, but then recompute the embeddings)
    embedding_model="all-MiniLM-L6-v2",
    embedding_batch_size=64,

    # Dimensionality reduction
    use_dim_reduction=True,
    reduced_dims=8,
    dim_reduction_method="umap",
    umap_n_neighbors=20,
    umap_min_dist=0.0,

    # Graph construction
    n_neighbors=20,
    graph_type="hybrid",
    snn_weight=0.4,

    # Multi-view fusion
    use_lexical_view=True,
    semantic_weight=0.6,
    lexical_weight=0.4,

    # Clustering
    resolution=1.0,
    n_consensus_runs=15,
    min_cluster_size=10,

    # Iterative refinement
    use_iterative_refinement=True,
    max_iterations=7,
    convergence_threshold=0.97,

    # Keywords
    keyword_method="ctfidf",
    n_keywords=15,
    n_representative_docs=5,

    # Misc
    outlier_threshold=0.1,
    random_state=42,
    verbose=True,
)

print("Custom config created. Key settings:")
print(f"  Dim reduction:  {custom_config.reduced_dims}d ({custom_config.dim_reduction_method})")
print(f"  Graph:          {custom_config.graph_type} (k={custom_config.n_neighbors})")
print(f"  Consensus:      {custom_config.n_consensus_runs} runs")
print(f"  Refinement:     max {custom_config.max_iterations} iterations")
print(f"  Keywords:       {custom_config.keyword_method} (n={custom_config.n_keywords})")

In [None]:
# Fit a model with the custom config
# (uses the same embeddings we already computed to save time)
model_custom = TriTopic(config=custom_config)
labels_custom = model_custom.fit_transform(documents, embeddings=model.embeddings_)

In [None]:
model_custom.get_topic_info()[["Topic", "Size", "Keywords"]]

---
## 13. Save and Load

All model state is preserved: embeddings, reduced embeddings, probabilities, the fitted UMAP reducer, topics, labels, and keyword extractor state.

In [None]:
import os

save_path = "tritopic_demo_model.pkl"

# Save
model.save(save_path)
file_size_mb = os.path.getsize(save_path) / (1024 * 1024)
print(f"Model saved to {save_path} ({file_size_mb:.1f} MB)")

In [None]:
# Load
loaded_model = TriTopic.load(save_path)

print(f"Loaded model: {loaded_model}")
print(f"  Labels shape:              {loaded_model.labels_.shape}")
print(f"  Embeddings shape:          {loaded_model.embeddings_.shape}")
print(f"  Reduced embeddings shape:  {loaded_model.reduced_embeddings_.shape}")
print(f"  Probabilities shape:       {loaded_model.probabilities_.shape}")
print(f"  Topic centroids shape:     {loaded_model.topic_embeddings_.shape}")
print(f"  Dim reducer present:       {loaded_model._dim_reducer is not None}")
print(f"  Number of topics:          {len([t for t in loaded_model.topics_ if t.topic_id != -1])}")

In [None]:
# Verify loaded model produces the same predictions
test_docs = ["Hubble telescope captured images of a distant galaxy"]
original_pred = model.transform(test_docs)
loaded_pred = loaded_model.transform(test_docs)

print(f"Original model prediction:  Topic {original_pred[0]}")
print(f"Loaded model prediction:    Topic {loaded_pred[0]}")
print(f"Match: {np.array_equal(original_pred, loaded_pred)}")

# Clean up
os.remove(save_path)
print(f"\nCleaned up {save_path}")

---
## 14. Comparison with Ground Truth

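Since we know the true newsgroup categories, we can measure how well TriTopic's discovered topics align with the ground truth. Keep in mind that after Section 7 the model has at most five topics (`reduce_topics(6)` followed by one manual merge), so at least one discovered topic must span two or more of the six categories, which keeps both scores below their maximum of 1.0.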
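
ARI and NMI compare groupings, not IDs, so TriTopic's topic 3 does not need to correspond to category 3. A tiny scikit-learn check makes both properties concrete: relabeling a perfect grouping still scores 1.0, while collapsing two true categories into one predicted topic lowers both scores. The toy label lists are illustrative only.

```python
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

true_cats = [0, 0, 1, 1, 2, 2]
same_grouping = [7, 7, 3, 3, 5, 5]   # identical partition, arbitrary topic IDs
two_merged = [0, 0, 0, 0, 5, 5]      # categories 0 and 1 collapsed into one topic

ari_same = adjusted_rand_score(true_cats, same_grouping)
nmi_same = normalized_mutual_info_score(true_cats, same_grouping)
ari_merged = adjusted_rand_score(true_cats, two_merged)
nmi_merged = normalized_mutual_info_score(true_cats, two_merged)
print(f"same grouping: ARI={ari_same:.3f}, NMI={nmi_same:.3f}")
print(f"two merged:    ARI={ari_merged:.3f}, NMI={nmi_merged:.3f}")
```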

In [None]:
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    confusion_matrix,
)

# Filter out outliers for a fair comparison
non_outlier_mask = model.labels_ != -1
pred = model.labels_[non_outlier_mask]
gt = true_labels[non_outlier_mask]

ari = adjusted_rand_score(gt, pred)
nmi = normalized_mutual_info_score(gt, pred)

print(f"Comparison with ground truth ({non_outlier_mask.sum()}/{len(model.labels_)} non-outlier docs):")
print(f"  Adjusted Rand Index (ARI):       {ari:.4f}")
print(f"  Normalized Mutual Info (NMI):    {nmi:.4f}")

In [None]:
# Cross-tabulation: how do discovered topics map to true categories?
ct = pd.crosstab(
    pd.Series(gt, name="True Category").map(dict(enumerate(target_names))),
    pd.Series(pred, name="Predicted Topic"),
)
ct.style.background_gradient(cmap="Blues")

---
## 15. Summary

This notebook demonstrated all major features of TriTopic:

| Feature | Method / Attribute | Section |
|---|---|---|
| Fit model | `fit()`, `fit_transform()` | 2 |
| Topic inspection | `get_topic_info()`, `get_topic()`, `get_representative_docs()` | 3 |
| Dimensionality reduction | `reduced_embeddings_`, config `use_dim_reduction` | 4 |
| Soft assignments | `probabilities_`, `transform_proba()` | 5 |
| Outlier reduction | `reduce_outliers(strategy=...)` | 6 |
| Topic merging | `reduce_topics()`, `merge_topics()` | 7 |
| Evaluation | `evaluate()`, `compute_silhouette()`, `compute_downstream_score()` | 8 |
| Visualizations | `visualize()`, `visualize_topics()`, `visualize_hierarchy()`, `TopicVisualizer.plot_topic_similarity()`, `TopicVisualizer.plot_topic_over_time()` | 9 |
| Prediction | `transform()`, `transform_proba()` | 10 |
| Topic labeling | `generate_labels()` with `SimpleLabeler` / `LLMLabeler` | 11 |
| Custom config | `TriTopicConfig(...)` | 12 |
| Save / Load | `save()`, `TriTopic.load()` | 13 |
| Ground truth comparison | scikit-learn `adjusted_rand_score`, `normalized_mutual_info_score`, `pd.crosstab` | 14 |

For more details, see the [README](../README.md) and [API documentation](https://tritopic.readthedocs.io).