In [None]:
# Install required packages if needed
# !pip install -e .[test]
# !pip install transformers datasets torch scikit-learn tqdm matplotlib seaborn


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA

# dPrune imports
from dprune.scorers.unsupervised import KMeansCentroidDistanceScorer, get_cls_embeddings
from dprune.pruners.selection import TopKPruner, BottomKPruner, StratifiedPruner
from dprune.pipeline import PruningPipeline

# Sanity check: this line is only reached if every import above succeeded.
print("All imports successful!")


In [None]:
# Fetch the training split of TREC from Hugging Face
from datasets import load_dataset
import random

trec_dataset = load_dataset("CogComp/trec", split="train")

# The 'coarse_label' feature carries the readable names; they are looked up
# below using each example's 'coarse_label' value.
coarse_label_names = trec_dataset.features['coarse_label'].names
print(f"Coarse label categories: {coarse_label_names}")


def add_category_name(example):
    """Attach a readable 'category' entry to a single example.

    The name is taken from ``coarse_label_names`` at the position given by the
    example's 'coarse_label'. The example is updated in place and returned,
    which is the form passed to ``Dataset.map`` below.
    """
    label_id = example['coarse_label']
    example['category'] = coarse_label_names[label_id]
    return example


# Every row gains a 'category' column alongside its numeric label
raw_dataset = trec_dataset.map(add_category_name)

# Draw a seeded random subset for demonstration (200 examples keeps the run fast).
# Note this is a random sample, not the first 200 rows.
# In practice, you can use the full dataset.
sample_size = 200
indices = list(range(len(raw_dataset)))
random.seed(42)  # fixed seed so the same subset is drawn on every run
random.shuffle(indices)
sample_indices = indices[:sample_size]  # slicing tolerates datasets shorter than sample_size

# Create a smaller dataset for this example
raw_dataset = raw_dataset.select(sample_indices)

print(f"Dataset loaded with {len(raw_dataset)} examples (sampled from {len(trec_dataset)} total)")
# Sorted so the printed order is stable instead of following set iteration order
print(f"Categories: {sorted(set(raw_dataset['category']))}")

# Count examples per category.
# Counter is a dict subclass that keeps first-seen order, so the printed order
# matches what a hand-rolled dict counter would give.
from collections import Counter

category_counts = Counter(raw_dataset['category'])

print("\nExamples per category:")
for category, count in category_counts.items():
    print(f"  {category}: {count}")

# Show the first text encountered for each category present in the sample
print("\nSample texts from each category:")
seen_categories = set()
for example in raw_dataset:
    if example['category'] in seen_categories:
        continue
    print(f"{example['category']}: '{example['text']}'")
    seen_categories.add(example['category'])
    # Stop once every category that occurs in the sample has been shown.
    # Derived from the data rather than a hard-coded category count.
    if len(seen_categories) == len(category_counts):
        break


In [None]:
# Load a pre-trained model for embeddings
model_name = 'distilbert-base-uncased'
# Tokenizer and model are both loaded from the same checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

print(f"Loaded model: {model_name}")

# Extract embeddings using the helper function
# NOTE(review): presumably returns one vector per row of raw_dataset, in row order,
# as a 2-D array — the code below reads .shape[1] and calls .tolist(); confirm
# against the dprune implementation.
embeddings = get_cls_embeddings(raw_dataset, model, tokenizer, text_column='text')

print(f"Extracted embeddings shape: {embeddings.shape}")
print(f"Embedding dimension: {embeddings.shape[1]}")

# Add embeddings to dataset
# The result is bound to a new name, so raw_dataset stays bound to the
# version without the 'embedding' column.
dataset_with_embeddings = raw_dataset.add_column('embedding', embeddings.tolist())
print(f"Dataset now has columns: {dataset_with_embeddings.column_names}")


In [None]:
# Reduce dimensionality for visualization
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(embeddings)

# One category label per embedded example.
# Assumes `embeddings` keeps the row order of raw_dataset, from which it was computed.
categories = list(raw_dataset['category'])

# Build the colour map from the categories actually present, so it always
# matches the loaded dataset instead of a hard-coded label list.
unique_categories = sorted(set(categories))
palette = sns.color_palette('tab10', n_colors=len(unique_categories))
category_colors = dict(zip(unique_categories, palette))
colors = [category_colors[cat] for cat in categories]

# Plot the embeddings
fig, ax = plt.subplots(figsize=(12, 8))
scatter = ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=colors, alpha=0.7, s=100)

# Label each point with its row index
for i, (x, y) in enumerate(embeddings_2d):
    ax.annotate(f"{i}", (x, y), xytext=(5, 5), textcoords='offset points', fontsize=8)

# Empty scatters serve as legend handles, one per category.
# `color=` (not `c=`) because each entry is a single RGB tuple.
for category, color in category_colors.items():
    ax.scatter([], [], color=color, label=category, s=100)

ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
ax.set_title('Text Embeddings Visualization (PCA)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

print(f"PCA explains {pca.explained_variance_ratio_.sum():.1%} of the total variance")


In [None]:
# Create the K-means scorer
# num_clusters=4 is a modelling choice, not a match to the label count: the
# data-loading cell expects six coarse TREC categories.
kmeans_scorer = KMeansCentroidDistanceScorer(num_clusters=4, random_state=42)

# Score the dataset
scored_dataset = kmeans_scorer.score(dataset_with_embeddings)

print("Dataset scored with K-means centroid distances!")
print(f"Scored dataset columns: {scored_dataset.column_names}")

# Examine the scores
scores = scored_dataset['score']
print("\nScore statistics:")
print(f"  Min score: {min(scores):.3f}")
print(f"  Max score: {max(scores):.3f}")
print(f"  Mean score: {np.mean(scores):.3f}")
print(f"  Std score: {np.std(scores):.3f}")

print("\nFirst few examples with scores:")
# Read each column once instead of once per loop iteration
preview_categories = scored_dataset['category']
preview_texts = scored_dataset['text']
# min() keeps the preview valid for datasets with fewer than 5 rows
for i in range(min(5, len(scores))):
    print(f"  Score: {scores[i]:.3f}, Category: {preview_categories[i]}, Text: '{preview_texts[i][:60]}...'")


In [None]:
# Get cluster assignments from the scorer
cluster_labels = kmeans_scorer.kmeans.labels_
centroids = kmeans_scorer.kmeans.cluster_centers_

# Per-example values are read from their datasets here, so this cell does not
# depend on globals that other cells may not have defined or may have rebound.
# Assumes rows are in the same order as the embeddings used for clustering.
categories = list(dataset_with_embeddings['category'])
point_scores = list(scored_dataset['score'])

# Project centroids to 2D for visualization
centroids_2d = pca.transform(centroids)

# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Plot 1: Clusters found by K-means
scatter1 = ax1.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=cluster_labels, cmap='viridis', alpha=0.7, s=100)
ax1.scatter(centroids_2d[:, 0], centroids_2d[:, 1], c='red', marker='x', s=200, linewidths=3, label='Centroids')
ax1.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
ax1.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
ax1.set_title('K-Means Clusters')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Add cluster labels: "<row index>(<cluster id>)"
for i, (x, y) in enumerate(embeddings_2d):
    ax1.annotate(f"{i}({cluster_labels[i]})", (x, y), xytext=(5, 5), textcoords='offset points', fontsize=8)

# Plot 2: Scores as colors
scatter2 = ax2.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=point_scores, cmap='coolwarm', alpha=0.7, s=100)
ax2.scatter(centroids_2d[:, 0], centroids_2d[:, 1], c='black', marker='x', s=200, linewidths=3, label='Centroids')
ax2.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
ax2.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
ax2.set_title('Distance Scores (Blue=Low, Red=High)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Add score labels: "<row index>(<score>)"
for i, (x, y) in enumerate(embeddings_2d):
    ax2.annotate(f"{i}({point_scores[i]:.2f})", (x, y), xytext=(5, 5), textcoords='offset points', fontsize=8)

# Add colorbars
plt.colorbar(scatter1, ax=ax1, label='Cluster ID')
plt.colorbar(scatter2, ax=ax2, label='Distance Score')

plt.tight_layout()
plt.show()

# Analyze cluster composition
print("Cluster composition by category:")
# One iteration per centroid, so the loop follows the scorer's cluster count
for cluster_id in range(len(centroids)):
    cluster_mask = cluster_labels == cluster_id
    cluster_categories = [cat for cat, in_cluster in zip(categories, cluster_mask) if in_cluster]
    # Sorted keys give a stable print order; the local name avoids overwriting
    # the notebook-level `category_counts`.
    composition = {cat: cluster_categories.count(cat) for cat in sorted(set(cluster_categories))}
    print(f"  Cluster {cluster_id}: {composition}")


In [None]:
# Three selection strategies applied to the same scorer and dataset, each with k=0.5.

# Strategy 1: bottom 50% by score (lowest distances) -> closest to centroids, most representative
bottom_pruner = BottomKPruner(k=0.5)
pipeline_representative = PruningPipeline(scorer=kmeans_scorer, pruner=bottom_pruner)
representative_examples = pipeline_representative.run(dataset_with_embeddings)

# Strategy 2: top 50% by score (highest distances) -> farthest from centroids, most diverse/outliers
top_pruner = TopKPruner(k=0.5)
pipeline_diverse = PruningPipeline(scorer=kmeans_scorer, pruner=top_pruner)
diverse_examples = pipeline_diverse.run(dataset_with_embeddings)

# Strategy 3: stratified sampling across 4 score strata
stratified_pruner = StratifiedPruner(k=0.5, num_strata=4)
pipeline_stratified = PruningPipeline(scorer=kmeans_scorer, pruner=stratified_pruner)
stratified_examples = pipeline_stratified.run(dataset_with_embeddings)

# Size report: the unpruned dataset first, then one line per strategy
print("Pruning Results:")
print(f"Original dataset: {len(scored_dataset)} examples")
pruning_summary = [
    ("Representative examples (closest to centroids)", representative_examples),
    ("Diverse examples (farthest from centroids)", diverse_examples),
    ("Stratified examples (balanced across score ranges)", stratified_examples),
]
for description, kept in pruning_summary:
    print(f"{description}: {len(kept)} examples")

def analyze_category_distribution(dataset, name):
    """Print how many examples of each category ``dataset`` holds, with percentages.

    Args:
        dataset: Object whose ``dataset['category']`` yields one label per
            example and which supports ``len()``.
        name: Label used in the printed heading.

    Returns:
        None. The report is written to stdout, categories in first-seen order.
    """
    tally = {}
    for label in dataset['category']:
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1

    print(f"\n{name} category distribution:")
    for label, n_examples in tally.items():
        share = (n_examples / len(dataset)) * 100
        print(f"  {label}: {n_examples} ({share:.1f}%)")

# Report the category mix kept by each pruning strategy, in the same order as above
for pruned_subset, report_name in (
    (representative_examples, "Representative examples"),
    (diverse_examples, "Diverse examples"),
    (stratified_examples, "Stratified examples"),
):
    analyze_category_distribution(pruned_subset, report_name)


In [None]:
# Map each text to the first row where it appears in scored_dataset.
# This gives the same index as list.index(text), but the column is read and
# scanned once instead of once per pruned example.
first_index_of_text = {}
for row, text in enumerate(scored_dataset['text']):
    first_index_of_text.setdefault(text, row)


def _print_examples(examples, original_indices):
    """Print score, category, cluster and text for every row of ``examples``.

    Args:
        examples: Pruned dataset exposing 'score', 'category' and 'text' columns.
        original_indices: For each row of ``examples``, its row in
            ``scored_dataset``; used to look up the cluster in ``cluster_labels``.
    """
    example_scores = examples['score']
    example_categories = examples['category']
    example_texts = examples['text']
    for i, (idx, score) in enumerate(zip(original_indices, example_scores)):
        print(f"\nExample {i+1} (Original index: {idx}, Score: {score:.3f}):")
        print(f"  Category: {example_categories[i]}")
        print(f"  Cluster: {cluster_labels[idx]}")
        print(f"  Text: '{example_texts[i]}'")


print("=== REPRESENTATIVE EXAMPLES (Closest to Centroids) ===")
print("These are the most 'typical' examples of each cluster:")
rep_scores = representative_examples['score']
rep_indices = [first_index_of_text[text] for text in representative_examples['text']]
_print_examples(representative_examples, rep_indices)

print("\n" + "="*60)
print("=== DIVERSE EXAMPLES (Farthest from Centroids) ===")
print("These are the most 'unusual' or outlier examples:")
div_scores = diverse_examples['score']
div_indices = [first_index_of_text[text] for text in diverse_examples['text']]
_print_examples(diverse_examples, div_indices)

print("\n" + "="*60)
print("=== SCORE DISTRIBUTION ANALYSIS ===")

# Plot score distributions for different strategies: one histogram per panel,
# filled row by row (original, representative, diverse, stratified).
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
histogram_panels = [
    (scored_dataset, 'Original Dataset Scores', 'gray'),
    (representative_examples, 'Representative Examples Scores', 'blue'),
    (diverse_examples, 'Diverse Examples Scores', 'red'),
    (stratified_examples, 'Stratified Examples Scores', 'green'),
]
for ax, (panel_data, panel_title, panel_color) in zip(axes.flat, histogram_panels):
    ax.hist(panel_data['score'], bins=10, alpha=0.7, color=panel_color, edgecolor='black')
    ax.set_title(panel_title)
    ax.set_xlabel('Distance Score')
    ax.set_ylabel('Frequency')

plt.tight_layout()
plt.show()


In [None]:
# Test different numbers of clusters
cluster_numbers = [2, 3, 4, 6, 8]
fig, axes = plt.subplots(1, len(cluster_numbers), figsize=(20, 4))

# Each scorer is fit once; its score statistics are kept here for the table
# printed after the figure, so no model is fit a second time.
score_summaries = []

# np.atleast_1d keeps the loop valid if cluster_numbers ever has a single entry
for ax, n_clusters in zip(np.atleast_1d(axes), cluster_numbers):
    # Create scorer with different number of clusters
    scorer = KMeansCentroidDistanceScorer(num_clusters=n_clusters, random_state=42)
    temp_scored = scorer.score(dataset_with_embeddings)
    temp_labels = scorer.kmeans.labels_

    # Plot the clustering
    scatter = ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                         c=temp_labels, cmap='tab10', alpha=0.7, s=50)

    # Plot centroids, projected with the PCA fitted on the embeddings
    temp_centroids_2d = pca.transform(scorer.kmeans.cluster_centers_)
    ax.scatter(temp_centroids_2d[:, 0], temp_centroids_2d[:, 1],
               c='red', marker='x', s=100, linewidths=2)

    ax.set_title(f'{n_clusters} Clusters')
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.grid(True, alpha=0.3)

    # Show inertia (within-cluster sum of squares) in the panel corner
    inertia = scorer.kmeans.inertia_
    ax.text(0.02, 0.98, f'Inertia: {inertia:.1f}',
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # A dedicated name avoids rebinding the notebook-level `scores`
    # that the 4-cluster visualisation cell reads.
    temp_scores = temp_scored['score']
    score_summaries.append(
        (n_clusters, np.mean(temp_scores), np.std(temp_scores), min(temp_scores), max(temp_scores))
    )

plt.tight_layout()
plt.show()

# Analyze how scores change with different cluster numbers
print("Score statistics for different numbers of clusters:")
print("Clusters | Mean Score | Std Score | Min Score | Max Score")
print("-" * 55)

for n_clusters, mean_score, std_score, min_score, max_score in score_summaries:
    print(f"{n_clusters:8d} | {mean_score:10.3f} | {std_score:9.3f} | {min_score:9.3f} | {max_score:9.3f}")