Investigating hidden state representations of the semantic meaning in LLMs

In [None]:
# new, improved version
import numpy as np
import torch
import json
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
import warnings
warnings.filterwarnings('ignore')

from sklearn.cluster import KMeans, DBSCAN, SpectralClustering, AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist, squareform
import umap.umap_ as umap

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

@dataclass
class ClusterResult:
    """Outcome of one clustering run: per-sample labels plus summary information."""
    # Cluster label for every sample; -1 marks noise / filtered-out samples.
    labels: np.ndarray
    # Cluster centres when the algorithm exposes them (KMeans does), otherwise None.
    centroids: Optional[np.ndarray]
    # Quality scores keyed by metric name (e.g. "silhouette"); may be empty.
    metrics: Dict[str, float]
    # Name of the algorithm that produced the labels.
    algorithm: str
    # Number of clusters reported for this run.
    n_clusters: int

class DimensionalityReducer(ABC):
    """Interface for anything that projects a sample matrix to fewer dimensions."""

    @abstractmethod
    def fit_transform(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Fit on ``X`` and return its low-dimensional projection.

        Concrete reducers must override this; the base implementation does
        nothing and returns ``None``.
        """

class PCAReducer(DimensionalityReducer):
    """Reducer backed by scikit-learn's ``PCA``."""

    def __init__(self, n_components: int = 2):
        self.n_components = n_components
        # Holds the most recently created PCA estimator (None until first use).
        self.pca = None

    def fit_transform(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Fit a fresh PCA on ``X`` and return the projected samples.

        The estimator is kept on ``self.pca`` so callers can inspect it afterwards.
        """
        self.pca = estimator = PCA(n_components=self.n_components)
        return estimator.fit_transform(X)

class TSNEReducer(DimensionalityReducer):
    """Reducer backed by scikit-learn's ``TSNE`` with a fixed random seed."""

    def __init__(self, n_components: int = 2, perplexity: float = 30.0):
        self.n_components = n_components
        self.perplexity = perplexity

    def fit_transform(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Embed ``X`` with t-SNE and return the low-dimensional coordinates.

        The iteration count is deliberately left at the library default (1000).
        Passing it explicitly as ``n_iter`` breaks on newer scikit-learn, where
        the parameter was renamed to ``max_iter``; omitting it gives the same
        result on every version.
        """
        tsne = TSNE(
            n_components=self.n_components,
            perplexity=self.perplexity,
            random_state=42,  # fixed seed so repeated plots use the same layout
        )
        return tsne.fit_transform(X)

class UMAPReducer(DimensionalityReducer):
    """Reducer backed by ``umap.UMAP`` with a fixed random seed."""

    def __init__(self, n_components: int = 2, n_neighbors: int = 15, min_dist: float = 0.1):
        self.n_components = n_components
        self.n_neighbors = n_neighbors
        self.min_dist = min_dist

    def fit_transform(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Fit a new UMAP model on ``X`` and return the embedded coordinates."""
        settings = {
            "n_components": self.n_components,
            "n_neighbors": self.n_neighbors,
            "min_dist": self.min_dist,
            "random_state": 42,
        }
        return umap.UMAP(**settings).fit_transform(X)

class ClusteringAlgorithm(ABC):
    """Interface for clustering back-ends that return a ``ClusterResult``."""

    @abstractmethod
    def cluster(self, X: np.ndarray, **kwargs) -> ClusterResult:
        """Cluster the rows of ``X``.

        Concrete algorithms must override this; the base implementation does
        nothing and returns ``None``.
        """

class KMeansClusterer(ClusteringAlgorithm):
    """K-means clustering (seeded, 10 restarts) with internal quality scores."""

    def __init__(self, n_clusters: int = 8):
        self.n_clusters = n_clusters

    def cluster(self, X: np.ndarray, **kwargs) -> ClusterResult:
        """Fit k-means on ``X`` and return labels, centroids and metrics."""
        model = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
        assignments = model.fit_predict(X)
        scores = self._compute_metrics(X, assignments)

        return ClusterResult(
            labels=assignments,
            centroids=model.cluster_centers_,
            metrics=scores,
            algorithm="KMeans",
            n_clusters=self.n_clusters,
        )

    def _compute_metrics(self, X: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
        """Score a labelling; fall back to neutral values when it is degenerate.

        The three scores are only computed when at least two distinct labels
        are present; otherwise fixed placeholder values are returned.
        """
        if np.unique(labels).size >= 2:
            return {
                "silhouette": silhouette_score(X, labels),
                "calinski_harabasz": calinski_harabasz_score(X, labels),
                "davies_bouldin": davies_bouldin_score(X, labels),
            }

        return {"silhouette": 0.0, "calinski_harabasz": 0.0, "davies_bouldin": float('inf')}

class DBSCANClusterer(ClusteringAlgorithm):
    """DBSCAN clustering; noise points (label -1) are excluded from scoring."""

    def __init__(self, eps: float = 0.5, min_samples: int = 5):
        self.eps = eps
        self.min_samples = min_samples

    def cluster(self, X: np.ndarray, **kwargs) -> ClusterResult:
        """Run DBSCAN on ``X``; the reported cluster count ignores the noise label."""
        model = DBSCAN(eps=self.eps, min_samples=self.min_samples)
        assignments = model.fit_predict(X)

        distinct = set(assignments)
        found_clusters = len(distinct) - int(-1 in distinct)
        scores = self._compute_metrics(X, assignments)

        return ClusterResult(
            labels=assignments,
            centroids=None,
            metrics=scores,
            algorithm="DBSCAN",
            n_clusters=found_clusters,
        )

    def _compute_metrics(self, X: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
        """Score the non-noise points, or return placeholders for degenerate results.

        A result is degenerate when fewer than two real (non-noise) clusters
        exist, or when fewer than two points remain after dropping noise.
        """
        placeholder = {"silhouette": 0.0, "calinski_harabasz": 0.0, "davies_bouldin": float('inf')}

        real_clusters = set(labels) - {-1}
        if len(real_clusters) < 2:
            return placeholder

        keep = labels != -1
        if np.sum(keep) < 2:
            return placeholder

        kept_points, kept_labels = X[keep], labels[keep]
        return {
            "silhouette": silhouette_score(kept_points, kept_labels),
            "calinski_harabasz": calinski_harabasz_score(kept_points, kept_labels),
            "davies_bouldin": davies_bouldin_score(kept_points, kept_labels),
        }

class SpectralClusterer(ClusteringAlgorithm):
    """Spectral clustering (seeded) with internal quality scores; no centroids."""

    def __init__(self, n_clusters: int = 8):
        self.n_clusters = n_clusters

    def cluster(self, X: np.ndarray, **kwargs) -> ClusterResult:
        """Fit spectral clustering on ``X`` and return labels and metrics."""
        assignments = SpectralClustering(
            n_clusters=self.n_clusters, random_state=42
        ).fit_predict(X)
        scores = self._compute_metrics(X, assignments)

        return ClusterResult(
            labels=assignments,
            centroids=None,
            metrics=scores,
            algorithm="Spectral",
            n_clusters=self.n_clusters,
        )

    def _compute_metrics(self, X: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
        """Return the three quality scores, or placeholders if only one label exists."""
        single_cluster = len(np.unique(labels)) < 2
        if single_cluster:
            return {"silhouette": 0.0, "calinski_harabasz": 0.0, "davies_bouldin": float('inf')}

        scores = {}
        scores["silhouette"] = silhouette_score(X, labels)
        scores["calinski_harabasz"] = calinski_harabasz_score(X, labels)
        scores["davies_bouldin"] = davies_bouldin_score(X, labels)
        return scores

class HierarchicalClusterer(ClusteringAlgorithm):
    """Agglomerative clustering with a configurable linkage; no centroids."""

    def __init__(self, n_clusters: int = 8, linkage_method: str = 'ward'):
        self.n_clusters = n_clusters
        self.linkage_method = linkage_method

    def cluster(self, X: np.ndarray, **kwargs) -> ClusterResult:
        """Fit agglomerative clustering on ``X`` and return labels and metrics."""
        model = AgglomerativeClustering(
            n_clusters=self.n_clusters,
            linkage=self.linkage_method,
        )
        assignments = model.fit_predict(X)

        return ClusterResult(
            labels=assignments,
            centroids=None,
            metrics=self._compute_metrics(X, assignments),
            algorithm="Hierarchical",
            n_clusters=self.n_clusters,
        )

    def _compute_metrics(self, X: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
        """Return the three quality scores, or placeholders if only one label exists."""
        n_distinct = len(np.unique(labels))
        if n_distinct < 2:
            return {"silhouette": 0.0, "calinski_harabasz": 0.0, "davies_bouldin": float('inf')}

        sil = silhouette_score(X, labels)
        ch = calinski_harabasz_score(X, labels)
        db = davies_bouldin_score(X, labels)
        return {"silhouette": sil, "calinski_harabasz": ch, "davies_bouldin": db}

class StructuralFilter:
    """Heuristic detector for clusters made of formatting/punctuation tokens.

    A cluster is called "structural" when its tokens are (almost) all
    punctuation, markdown markers, whitespace or numbers, as opposed to words.

    Args:
        variance_threshold: Mean per-dimension embedding variance below which a
            mostly-structural cluster is accepted as structural.
        isolation_threshold: Stored on the instance but not used by any check
            in this class.
    """

    # Characters that, on their own, never make a token non-structural.
    _PURE_STRUCTURAL_CHARS = frozenset('*.,;:-\n "\'()[]{}')

    # Characters treated as formatting when looking for semantic content.
    _FORMATTING_CHARS = frozenset('*.,-:;!?()[]{}"\'\n ')

    # Exact tokens that are always structural. Each entry must be separated by
    # a comma: adjacent string literals are silently concatenated by Python,
    # which previously hid both '**\n' and the begin-of-sentence marker.
    _STRUCTURAL_TOKENS = frozenset({
        '.', ',', '-', '--', '\n', ':', '*', '**', '"', '\n\n', ' ',
        '.\n', '**\n\n', '*\n\n', '**\n',
        '<\uff5cbegin\u2581of\u2581sentence\uff5c>',
    })

    def __init__(self, variance_threshold: float = 0.1, isolation_threshold: float = 0.8):
        self.variance_threshold = variance_threshold
        self.isolation_threshold = isolation_threshold

    def identify_structural_clusters(self, embeddings: np.ndarray, tokens: List[str],
                                   cluster_labels: np.ndarray) -> List[int]:
        """Return the labels of all clusters judged to be structural.

        Args:
            embeddings: One row per token.
            tokens: Token strings, aligned with ``embeddings`` and ``cluster_labels``.
            cluster_labels: Cluster label per token; label -1 (noise) is skipped.

        Returns:
            Structural cluster labels, in ascending label order.
        """
        structural_clusters = []

        for label in np.unique(cluster_labels):
            if label == -1:  # noise is never reported as a cluster
                continue

            member_indices = np.where(cluster_labels == label)[0]
            cluster_embeddings = embeddings[member_indices]
            cluster_tokens = [tokens[i] for i in member_indices]

            if self._is_structural_cluster(cluster_embeddings, cluster_tokens):
                structural_clusters.append(label)

        return structural_clusters

    def _is_structural_cluster(self, embeddings: np.ndarray, tokens: List[str]) -> bool:
        """Decide whether one cluster is structural.

        Checks are applied in order; the first one that matches wins.
        """
        # Clusters with fewer than three members are always treated as structural.
        if len(embeddings) < 3:
            return True

        # Clusters made only of formatting characters need no further checks.
        if self._is_pure_structural_cluster(tokens):
            return True

        embedding_variance = np.mean(np.var(embeddings, axis=0))
        structural_ratio = self._detect_structural_patterns(tokens)

        if structural_ratio > 0.8:
            return True

        # A tight cluster needs a smaller share of structural tokens to qualify.
        if embedding_variance < self.variance_threshold and structural_ratio > 0.6:
            return True

        if self._has_semantic_diversity(tokens):
            return False

        unique_tokens = set(tokens)
        if len(unique_tokens) == 1:
            return True

        repetition_ratio = 1 - (len(unique_tokens) / len(tokens))
        return repetition_ratio > 0.9 and structural_ratio > 0.5

    def _is_pure_structural_cluster(self, tokens: List[str]) -> bool:
        """Return True if every character of every token is a structural character."""
        allowed = self._PURE_STRUCTURAL_CHARS
        return all(all(c in allowed for c in token) for token in tokens)

    def _has_semantic_diversity(self, tokens: List[str]) -> bool:
        """Return True if more than 40% of the non-blank tokens look like words.

        A token counts as a word when, stripped and lower-cased, it is purely
        alphabetic and longer than two characters.
        """
        content_tokens = [t.strip() for t in tokens if t.strip()]
        if not content_tokens:
            return False

        formatting = self._FORMATTING_CHARS
        semantic_indicators = 0
        for token in content_tokens:
            # Pure formatting/punctuation tokens never count as semantic.
            if all(c in formatting for c in token):
                continue

            clean_token = token.lower()
            if clean_token.isalpha() and len(clean_token) > 2:
                semantic_indicators += 1

        return semantic_indicators > len(content_tokens) * 0.4

    def _detect_structural_patterns(self, tokens: List[str]) -> float:
        """Return the fraction of tokens that are structural (0.0 for no tokens)."""
        if not tokens:
            return 0.0
        structural_count = sum(1 for token in tokens if self._is_structural_token(token))
        return structural_count / len(tokens)

    def _is_structural_token(self, token: str) -> bool:
        """Return True for punctuation, known formatting tokens and plain numbers."""
        # Any single non-alphanumeric character.
        if len(token) == 1 and not token.isalnum():
            return True

        if token.lower() in self._STRUCTURAL_TOKENS:
            return True

        # Digits, optionally with '.' or ',' separators (e.g. "3.14", "1,000").
        if token.isdigit() or token.replace('.', '').replace(',', '').isdigit():
            return True

        return False

class SemanticAnalyzer:
    """Loads token embeddings and runs clustering, structural filtering and analysis.

    State is filled by ``load_data``; every other method expects it to have
    been called first.
    """

    def __init__(self):
        self.embeddings = None
        self.tokens = None
        self.cosine_sim_matrix = None
        self.scaler = StandardScaler()
        self.structural_filter = StructuralFilter()

    def load_data(self, tokens_path: str, embeddings_path: str, cosine_sim_path: str):
        """Load tokens, embeddings and the cosine-similarity matrix from disk.

        Args:
            tokens_path: JSON file containing the token list.
            embeddings_path: ``torch.save``-d tensor of rank 2 or 3.
            cosine_sim_path: ``torch.save``-d similarity tensor; it is stored on
                the instance but not used by the other methods of this class.

        Raises:
            ValueError: If the embedding tensor is not rank 2 or rank 3.
        """
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(tokens_path, 'r', encoding='utf-8') as f:
            self.tokens = json.load(f)

        # NOTE(security): torch.load unpickles the file; only load trusted files.
        embeddings_tensor = torch.load(embeddings_path, map_location='cpu')
        # detach() lets tensors saved with requires_grad=True be converted to numpy.
        embeddings_tensor = embeddings_tensor.detach()

        if embeddings_tensor.dim() == 3:
            # (batch, seq_len, hidden_dim) -> (batch*seq_len, hidden_dim).
            # reshape (unlike view) also works for non-contiguous tensors.
            self.embeddings = embeddings_tensor.reshape(-1, embeddings_tensor.size(-1)).numpy()
        elif embeddings_tensor.dim() == 2:
            # Treated as (hidden_dim, num_tokens) and transposed to (num_tokens, hidden_dim).
            self.embeddings = embeddings_tensor.T.numpy()
        else:
            raise ValueError(f"Unexpected embedding tensor shape: {embeddings_tensor.shape}")

        # Tokens and embeddings are aligned by truncating both to the shorter length.
        if len(self.tokens) != self.embeddings.shape[0]:
            min_len = min(len(self.tokens), self.embeddings.shape[0])
            self.tokens = self.tokens[:min_len]
            self.embeddings = self.embeddings[:min_len]

        self.cosine_sim_matrix = torch.load(cosine_sim_path, map_location='cpu').detach().numpy()

        # Standardise every embedding dimension (zero mean, unit variance).
        self.embeddings = self.scaler.fit_transform(self.embeddings)

    def find_optimal_clusters(self, clustering_algorithms: List[ClusteringAlgorithm],
                            k_range: range = range(2, 15)) -> Dict[str, ClusterResult]:
        """Run every algorithm and keep, per algorithm, the best-scoring result.

        For KMeans, Spectral and Hierarchical clusterers the cluster count is
        swept over ``k_range`` and the result with the highest silhouette
        score is kept. Other algorithms are run once as configured.

        Results are keyed by class name, so two instances of the same class
        overwrite each other.

        Raises:
            ValueError: If no value in ``k_range`` can be used with the
                available number of samples.
        """
        results = {}
        n_samples = self.embeddings.shape[0]
        sweepable = (KMeansClusterer, SpectralClusterer, HierarchicalClusterer)

        for algorithm in clustering_algorithms:
            name = algorithm.__class__.__name__

            if not isinstance(algorithm, sweepable):
                results[name] = algorithm.cluster(self.embeddings)
                continue

            best_result = None
            best_score = float('-inf')

            for k in k_range:
                # Silhouette scoring needs fewer clusters than samples.
                if k >= n_samples:
                    continue

                algorithm.n_clusters = k
                result = algorithm.cluster(self.embeddings)

                if result.metrics['silhouette'] > best_score:
                    best_score = result.metrics['silhouette']
                    best_result = result

            if best_result is None:
                raise ValueError(
                    f"No cluster count in {k_range!r} is usable for {name} "
                    f"with {n_samples} samples"
                )

            # Leave the instance configured as the winning run was, not as the last sweep step.
            algorithm.n_clusters = best_result.n_clusters
            results[name] = best_result

        return results

    def filter_semantic_clusters(self, cluster_result: ClusterResult) -> Tuple[np.ndarray, List[int]]:
        """Relabel structural clusters as -1.

        Returns:
            ``(filtered_labels, structural_clusters)``: a copy of the labels
            with structural clusters set to -1, and the structural labels.
        """
        structural_clusters = self.structural_filter.identify_structural_clusters(
            self.embeddings, self.tokens, cluster_result.labels
        )

        filtered_labels = cluster_result.labels.copy()
        for struct_cluster in structural_clusters:
            filtered_labels[filtered_labels == struct_cluster] = -1

        return filtered_labels, structural_clusters

    def refine_semantic_clusters(self, filtered_labels: np.ndarray,
                               algorithm: ClusteringAlgorithm) -> ClusterResult:
        """Re-cluster only the tokens that survived structural filtering.

        With fewer than 10 surviving tokens no re-clustering is done: the
        input labels are returned unchanged in a result whose ``n_clusters``
        is 0 and whose metrics are empty.
        """
        semantic_mask = filtered_labels != -1
        if np.sum(semantic_mask) < 10:  # not enough semantic tokens to re-cluster
            return ClusterResult(filtered_labels, None, {}, "Refined", 0)

        semantic_embeddings = self.embeddings[semantic_mask]
        refined_result = algorithm.cluster(semantic_embeddings)

        # Scatter the new labels back to their original positions; all other tokens stay -1.
        final_labels = np.full(len(self.tokens), -1)
        semantic_indices = np.where(semantic_mask)[0]
        final_labels[semantic_indices] = refined_result.labels

        return ClusterResult(
            labels=final_labels,
            centroids=refined_result.centroids,
            metrics=refined_result.metrics,
            algorithm=f"Refined_{refined_result.algorithm}",
            n_clusters=refined_result.n_clusters
        )

    def analyze_clusters(self, cluster_result: ClusterResult) -> Dict[str, Any]:
        """Summarise a clustering: sizes and up to 20 sample tokens per label.

        The -1 label, if present, is included like any other label.
        """
        analysis = {
            'cluster_sizes': {},
            'cluster_tokens': {},
            'cluster_centroids': cluster_result.centroids,
            'metrics': cluster_result.metrics,
            'total_tokens': len(self.tokens),
            'n_clusters': cluster_result.n_clusters
        }

        for label in np.unique(cluster_result.labels):
            member_indices = np.where(cluster_result.labels == label)[0]
            cluster_tokens = [self.tokens[i] for i in member_indices]

            # Plain int keys keep the summary usable with ordinary dict lookups.
            analysis['cluster_sizes'][int(label)] = len(cluster_tokens)
            analysis['cluster_tokens'][int(label)] = cluster_tokens[:20]

        return analysis

class Visualizer:
    """Builds Plotly figures for cluster scatter plots, metric bars and size bars."""

    def __init__(self, figsize: Tuple[int, int] = (12, 8)):
        self.figsize = figsize
        plt.style.use('seaborn-v0_8')

    def plot_2d_clusters(self, embeddings: np.ndarray, labels: np.ndarray,
                        tokens: List[str], reducer: DimensionalityReducer,
                        title: str = "Cluster Visualization") -> go.Figure:
        """Scatter-plot the reduced embeddings, one trace per cluster label.

        Points labelled -1 are drawn in gray and named 'Noise'. Hovering a
        point shows its token and cluster name.
        """
        projected = reducer.fit_transform(embeddings)

        distinct_labels = np.unique(labels)
        palette = px.colors.qualitative.Set3[:len(distinct_labels)]
        reducer_name = reducer.__class__.__name__

        figure = go.Figure()

        for position, label in enumerate(distinct_labels):
            member_idx = np.where(labels == label)[0]
            points = projected[labels == label]
            hover_tokens = [tokens[j] for j in member_idx]

            if label != -1:
                trace_color = palette[position % len(palette)]
                trace_name = f'Cluster {label}'
            else:
                trace_color = 'gray'
                trace_name = 'Noise'

            scatter = go.Scatter(
                x=points[:, 0],
                y=points[:, 1],
                mode='markers',
                marker=dict(color=trace_color, size=6, opacity=0.7),
                text=hover_tokens,
                name=trace_name,
                hovertemplate=f'<b>%{{text}}</b><br>Cluster: {trace_name}<extra></extra>',
            )
            figure.add_trace(scatter)

        figure.update_layout(
            title=title,
            xaxis_title=f'{reducer_name} 1',
            yaxis_title=f'{reducer_name} 2',
            showlegend=True,
            width=800,
            height=600,
        )

        return figure

    def plot_cluster_metrics(self, results: Dict[str, ClusterResult]) -> go.Figure:
        """Draw one bar chart per quality metric, comparing all algorithms."""
        algorithm_names = list(results)

        # (metrics key, trace name) for subplot columns 1..3.
        panels = (
            ('silhouette', 'Silhouette'),
            ('calinski_harabasz', 'Calinski-Harabasz'),
            ('davies_bouldin', 'Davies-Bouldin'),
        )
        series = [
            [results[alg].metrics.get(metric_key, 0) for alg in algorithm_names]
            for metric_key, _ in panels
        ]

        figure = make_subplots(
            rows=1, cols=3,
            subplot_titles=('Silhouette Score', 'Calinski-Harabasz', 'Davies-Bouldin'),
            specs=[[{"secondary_y": False} for _ in panels]],
        )

        for column, ((_, trace_name), values) in enumerate(zip(panels, series), start=1):
            figure.add_trace(go.Bar(x=algorithm_names, y=values, name=trace_name),
                             row=1, col=column)

        figure.update_layout(title="Clustering Algorithm Comparison", showlegend=False)
        return figure

    def plot_cluster_distribution(self, cluster_result: ClusterResult) -> go.Figure:
        """Bar chart of cluster sizes, largest first; label -1 is shown as 'Noise'."""
        label_values, sizes = np.unique(cluster_result.labels, return_counts=True)

        # Largest cluster first.
        order = np.argsort(sizes)[::-1]
        ordered_labels = label_values[order]
        ordered_sizes = sizes[order]

        bar_names = []
        for label in ordered_labels:
            bar_names.append('Noise' if label == -1 else f'Cluster {label}')

        figure = go.Figure(data=[go.Bar(x=bar_names, y=ordered_sizes)])
        figure.update_layout(
            title="Cluster Size Distribution",
            xaxis_title="Clusters",
            yaxis_title="Number of Tokens",
        )

        return figure

class SemanticPipeline:
    """End-to-end run: load data, cluster, filter structural clusters, refine, plot."""

    def __init__(self):
        self.analyzer = SemanticAnalyzer()
        self.visualizer = Visualizer()

        # Method name -> clusterer class (instantiated with defaults on use).
        self.clustering_algorithms = {
            'kmeans': KMeansClusterer,
            'dbscan': DBSCANClusterer,
            'spectral': SpectralClusterer,
            'hierarchical': HierarchicalClusterer
        }

        # Method name -> reducer class (instantiated with defaults on use).
        self.dimensionality_reducers = {
            'pca': PCAReducer,
            'tsne': TSNEReducer,
            'umap': UMAPReducer
        }

    def run_analysis(self, tokens_path: str, embeddings_path: str, cosine_sim_path: str,
                    clustering_methods: Optional[List[str]] = None,
                    visualization_method: str = 'umap') -> Dict[str, Any]:
        """Run the whole pipeline and return results, analyses and figures.

        Args:
            tokens_path: JSON file with the token list.
            embeddings_path: Saved embedding tensor.
            cosine_sim_path: Saved cosine-similarity tensor.
            clustering_methods: Keys of ``self.clustering_algorithms`` to
                compare. Defaults to ``['kmeans', 'spectral']``.
            visualization_method: Key of ``self.dimensionality_reducers``.

        Returns:
            Dict with the per-algorithm results, the winning algorithm name,
            the structural cluster labels, the initial and refined analyses,
            the figures, and the filtered and refined label arrays.

        Raises:
            ValueError: If ``clustering_methods`` is empty.
            KeyError: If a method name is not registered.
        """
        # None instead of a list literal avoids a shared mutable default.
        if clustering_methods is None:
            clustering_methods = ['kmeans', 'spectral']
        if not clustering_methods:
            raise ValueError("clustering_methods must contain at least one method name")

        self.analyzer.load_data(tokens_path, embeddings_path, cosine_sim_path)

        algorithms = [self.clustering_algorithms[method]() for method in clustering_methods]

        print("Finding optimal clusters...")
        clustering_results = self.analyzer.find_optimal_clusters(algorithms)

        # The winner is the algorithm with the highest silhouette score.
        best_algorithm = max(clustering_results.keys(),
                           key=lambda x: clustering_results[x].metrics.get('silhouette', 0))
        best_result = clustering_results[best_algorithm]

        print(f"Best clustering algorithm: {best_algorithm}")

        print("Filtering structural clusters...")
        filtered_labels, structural_clusters = self.analyzer.filter_semantic_clusters(best_result)

        print("Refining semantic clusters...")
        # NOTE(review): refinement uses the FIRST requested method with default
        # settings, not the winning algorithm - confirm this is intended.
        refined_algorithm = self.clustering_algorithms[clustering_methods[0]]()
        refined_result = self.analyzer.refine_semantic_clusters(filtered_labels, refined_algorithm)

        initial_analysis = self.analyzer.analyze_clusters(best_result)
        refined_analysis = self.analyzer.analyze_clusters(refined_result)

        reducer = self.dimensionality_reducers[visualization_method]()

        initial_viz = self.visualizer.plot_2d_clusters(
            self.analyzer.embeddings, best_result.labels, self.analyzer.tokens,
            reducer, f"Initial Clustering ({best_algorithm})"
        )

        refined_viz = self.visualizer.plot_2d_clusters(
            self.analyzer.embeddings, refined_result.labels, self.analyzer.tokens,
            reducer, "Refined Semantic Clustering"
        )

        metrics_viz = self.visualizer.plot_cluster_metrics(clustering_results)
        distribution_viz = self.visualizer.plot_cluster_distribution(refined_result)

        return {
            'clustering_results': clustering_results,
            'best_algorithm': best_algorithm,
            'structural_clusters': structural_clusters,
            'initial_analysis': initial_analysis,
            'refined_analysis': refined_analysis,
            'visualizations': {
                'initial_clustering': initial_viz,
                'refined_clustering': refined_viz,
                'metrics_comparison': metrics_viz,
                'cluster_distribution': distribution_viz
            },
            'filtered_labels': filtered_labels,
            'refined_labels': refined_result.labels
        }

This code clusters and visualizes the hidden states. It tries to find the "structural clusters", which correspond to tokens like

{'**', '**\n', '**\n\n', ' **', ':**'} (there is some hard-coding used here but for the most part we used ml algorithms)

which don't carry semantic meaning and are uninteresting to us. It then excludes them from the analysis.

Finally, it can compare different clustering methods to determine which is best, based on standard clustering quality measures such as silhouette, calinski_harabasz, and davies_bouldin.

we also look at things like the cluster size distribution: if there is a cluster with 90% of the tokens and then a lot of clusters with a single token, then it is a bad clustering technique for our use case.

In this way we investigate the clusterings and find some methods that produce semantically meaningful clusters.

For visualization we used PCA, t-SNE and UMAP. We relied on libraries that implement them, so it was straightforward to try each one out.

In [None]:
# show clusters with different visualization methods
def visualize(clustering_methods, visualization_method, tokens_path, embeddings_path, cosine_sim_path):
    """Run the pipeline once and show the initial and refined cluster plots.

    Prints the structural cluster labels, both cluster counts and the
    visualization method that was used.
    """
    results = SemanticPipeline().run_analysis(
        tokens_path=tokens_path,
        embeddings_path=embeddings_path,
        cosine_sim_path=cosine_sim_path,
        clustering_methods=clustering_methods,
        visualization_method=visualization_method,
    )

    # Text summary.
    initial_count = results['initial_analysis']['n_clusters']
    refined_count = results['refined_analysis']['n_clusters']
    print(f"Structural clusters identified: {results['structural_clusters']}")
    print(f"Initial clusters: {initial_count}")
    print(f"Refined semantic clusters: {refined_count}")

    print(f"Visualization Method: {visualization_method}")

    # Figures.
    figures = results['visualizations']
    for figure_key in ('initial_clustering', 'refined_clustering'):
        figures[figure_key].show()
    print("\n\n")


In [None]:
# pipeline for evaluating and displaying results
def usage(clustering_methods, visualization_method, tokens_path, embeddings_path, cosine_sim_path):
    """Run the pipeline once, show its figures and print a per-cluster report.

    The report lists sample tokens for every refined cluster (noise excluded)
    and for every structural cluster that was eliminated.
    """
    results = SemanticPipeline().run_analysis(
        tokens_path=tokens_path,
        embeddings_path=embeddings_path,
        cosine_sim_path=cosine_sim_path,
        clustering_methods=clustering_methods,
        visualization_method=visualization_method,
    )

    structural_ids = results['structural_clusters']
    initial = results['initial_analysis']
    refined = results['refined_analysis']
    figures = results['visualizations']

    # Text summary.
    print(f"Structural clusters identified: {structural_ids}")
    print(f"Initial clusters: {initial['n_clusters']}")
    print(f"Refined semantic clusters: {refined['n_clusters']}")

    # Figures.
    for figure_key in ('initial_clustering', 'refined_clustering', 'cluster_distribution'):
        figures[figure_key].show()

    # Refined clusters, skipping the noise label.
    print("\nRefined Semantic Clusters:")
    for cluster_id, tokens in refined['cluster_tokens'].items():
        if cluster_id == -1:
            continue
        print(f"Cluster {cluster_id} ({refined['cluster_sizes'][cluster_id]} tokens):")
        print(f"  Sample tokens: {set(tokens[:15])}")
        print()

    # Eliminated structural clusters.
    if not structural_ids:
        print("\nNo structural clusters were eliminated.")
        return

    print("\nEliminated Structural Clusters:")
    for cluster_id in structural_ids:
        if cluster_id not in initial['cluster_tokens']:
            continue
        tokens = initial['cluster_tokens'][cluster_id]
        size = initial['cluster_sizes'][cluster_id]
        print(f"Structural Cluster {cluster_id} ({size} tokens):")
        print(f"  Sample tokens: {set(tokens[:15])}")
        print()

Which visualization method is best for the clusters? We use k-means and look at the results:
(we already know that k-means works well for clustering this data)

We concluded that t-SNE is the best dimensionality-reduction method for visualization. t-SNE is a nonlinear dimensionality-reduction method that, with high probability, models similar objects by nearby points and dissimilar objects by distant points.

it was reasonably fast.

What is the best clustering method?

We look at the clusters we obtain visually and also look at representatives of the group in order to decide.

FIRST LAYER

In [None]:
# Run the full report once per clustering method on the "first" hidden-state
# files, always projecting with t-SNE.
for cluster_meth in ['kmeans', 'dbscan', 'spectral', 'hierarchical']:
  print(f"Cluster Method: {cluster_meth}")
  usage([cluster_meth], 'tsne', "tokens.json", "hidden_states_first.pt", "cosine_sim_first_first.pt")

Cluster Method: kmeans
Finding optimal clusters...
Best clustering algorithm: KMeansClusterer
Filtering structural clusters...
Refining semantic clusters...
Structural clusters identified: [np.int32(0), np.int32(1), np.int32(2), np.int32(11)]
Initial clusters: 13
Refined semantic clusters: 8



Refined Semantic Clusters:
Cluster 0 (27 tokens):
  Sample tokens: {' Positive', ' negative', 'Positive', ' positive'}

Cluster 1 (27 tokens):
  Sample tokens: {' self'}

Cluster 2 (195 tokens):
  Sample tokens: {' like', ' in', ' of', ' on', ' without', ' and'}

Cluster 3 (94 tokens):
  Sample tokens: {'-esteem', ' personality', ' skills', ' traits', ' behaviors', ' styles'}

Cluster 4 (39 tokens):
  Sample tokens: {' how', ' who'}

Cluster 5 (869 tokens):
  Sample tokens: {' including', '<｜begin▁of▁sentence｜>', ' education', 'amine', ' experiences', ' shape', 'Ex', ' influences', ' various', ' friendships', ' development', ' Discuss', ' childhood', ' family', ' environment'}

Cluster 6 (118 tokens):
  Sample tokens: {' a', ' each', ' these', ' the', ' it'}

Cluster 7 (80 tokens):
  Sample tokens: {' can', ' has', "'ll", ' need', ' should', ' have'}


Eliminated Structural Clusters:
Structural Cluster 0 (34 tokens):
  Sample tokens: {' to'}

Structural Cluster 1 (50 tokens):
  Sample


Refined Semantic Clusters:
Cluster 0 (21 tokens):
  Sample tokens: {' how'}

Cluster 1 (5 tokens):
  Sample tokens: {' childhood'}

Cluster 2 (19 tokens):
  Sample tokens: {' experiences'}

Cluster 3 (9 tokens):
  Sample tokens: {' shape'}

Cluster 4 (26 tokens):
  Sample tokens: {' personality'}

Cluster 5 (9 tokens):
  Sample tokens: {' development'}

Cluster 6 (9 tokens):
  Sample tokens: {' family'}

Cluster 7 (15 tokens):
  Sample tokens: {' environment'}

Cluster 8 (8 tokens):
  Sample tokens: {' education'}

Cluster 9 (7 tokens):
  Sample tokens: {' friendships'}

Cluster 10 (55 tokens):
  Sample tokens: {' and'}

Cluster 11 (6 tokens):
  Sample tokens: {' life'}

Cluster 12 (6 tokens):
  Sample tokens: {' events'}

Cluster 13 (5 tokens):
  Sample tokens: {' concepts'}

Cluster 14 (8 tokens):
  Sample tokens: {' like'}

Cluster 15 (19 tokens):
  Sample tokens: {' attachment'}

Cluster 16 (6 tokens):
  Sample tokens: {' theory'}

Cluster 17 (6 tokens):
  Sample tokens: {' nature


Refined Semantic Clusters:
Cluster 0 (1 tokens):
  Sample tokens: {'.'}

Cluster 1 (1 tokens):
  Sample tokens: {' ensuring'}

Cluster 2 (1 tokens):
  Sample tokens: {'Moving'}

Cluster 3 (1 tokens):
  Sample tokens: {'Ex'}

Cluster 4 (1 tokens):
  Sample tokens: {' Relationship'}

Cluster 5 (2 tokens):
  Sample tokens: {' Long', 'Long'}

Cluster 6 (1789 tokens):
  Sample tokens: {' including', '<｜begin▁of▁sentence｜>', 'amine', ' experiences', ' shape', '.', ' influences', ' various', ' how', ' personality', ' development', ' Discuss', ' childhood', ' family', ' environment'}

Cluster 7 (1 tokens):
  Sample tokens: {' Shape'}


Eliminated Structural Clusters:
Structural Cluster 1 (1 tokens):
  Sample tokens: {' Impact'}

Cluster Method: hierarchical
Finding optimal clusters...
Best clustering algorithm: HierarchicalClusterer
Filtering structural clusters...
Refining semantic clusters...
Structural clusters identified: [np.int64(3), np.int64(10), np.int64(11), np.int64(12)]
Initial clu


Refined Semantic Clusters:
Cluster 0 (396 tokens):
  Sample tokens: {' like', ':', ' can', ' to', ' in', ' how', ' it', ' of', ' on'}

Cluster 1 (834 tokens):
  Sample tokens: {' including', '<｜begin▁of▁sentence｜>', ' education', 'amine', ' experiences', ' shape', 'Ex', ' influences', ' various', ' friendships', ' development', ' Discuss', ' childhood', ' family', ' environment'}

Cluster 2 (73 tokens):
  Sample tokens: {' a', ' the'}

Cluster 3 (47 tokens):
  Sample tokens: {'-esteem', ' attachment'}

Cluster 4 (55 tokens):
  Sample tokens: {' and'}

Cluster 5 (33 tokens):
  Sample tokens: {' personality', ' Personality'}

Cluster 6 (32 tokens):
  Sample tokens: {' terms', '-term', ' term'}

Cluster 7 (27 tokens):
  Sample tokens: {' self'}


Eliminated Structural Clusters:
Structural Cluster 3 (96 tokens):
  Sample tokens: {'.\n', '.\n\n', '.'}

Structural Cluster 10 (50 tokens):
  Sample tokens: {'**', '**\n', '**\n\n', ' **', ':**'}

Structural Cluster 11 (133 tokens):
  Sample to

kmeans: nice

dbscan: didn't produce any semantically meaningful clusters

spectral: made a single huge cluster and then tiny (single token) clusters

hierarchical: nice

The fact that spectral and dbscan did not do well may be our fault for not choosing the right hyperparameters.

note:
Structural Cluster 3 (96 tokens):
  Sample tokens: {'.\n', '.\n\n', '.'}

It says 96 tokens because many tokens are repeated (in general, these are slightly different hidden_state vectors that get decoded to the same token).

Which is better, kmeans or hierarchical? We use an automatic comparator to decide.

In [None]:
# Pass both methods in one run so the pipeline picks the winner by silhouette score.
usage(['kmeans', 'hierarchical'], 'tsne', "tokens.json", "hidden_states_first.pt", "cosine_sim_first_first.pt")

Finding optimal clusters...
Best clustering algorithm: HierarchicalClusterer
Filtering structural clusters...
Refining semantic clusters...
Structural clusters identified: [np.int64(3), np.int64(10), np.int64(11), np.int64(12)]
Initial clusters: 14
Refined semantic clusters: 8



Refined Semantic Clusters:
Cluster 0 (9 tokens):
  Sample tokens: {' supportive', ' support'}

Cluster 1 (139 tokens):
  Sample tokens: {' a', ' each', ' the', ' how', ' it'}

Cluster 2 (161 tokens):
  Sample tokens: {' to', ' in', ' of', ' on', ' and'}

Cluster 3 (328 tokens):
  Sample tokens: {' nurture', '<｜begin▁of▁sentence｜>', ' education', 'amine', ' psychological', ' attachment', ' influences', ' friendships', ' personality', ' adult', ' traits', ' Discuss', ' resilience', ' childhood'}

Cluster 4 (689 tokens):
  Sample tokens: {' significant', ' including', ' like', ' life', ' shape', ' experiences', ' Explain', 'Ex', ' various', ' theory', ' development', ' events', ' concepts', ' family', ' environment'}

Cluster 5 (32 tokens):
  Sample tokens: {' terms', '-term', ' term'}

Cluster 6 (27 tokens):
  Sample tokens: {' Positive', ' negative', 'Positive', ' positive'}

Cluster 7 (112 tokens):
  Sample tokens: {"'s", ':', ' can', ' has', ' "', ' should', '"', ' have', ' vs'}


El

We read the clusters manually, and both kmeans and hierarchical seem to be of similar quality.

We now visualize what the hierarchical clustering looks like:

In [None]:
import numpy as np
import torch
import json
from typing import List, Optional
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

class DendrogramVisualizer:
    """Build Plotly dendrograms for a random sample of token embeddings.

    Only a sample of the tokens is drawn because a dendrogram with every token
    would be unreadable; ``max_display_tokens`` caps the sample size used by
    :meth:`create_dendrogram`.
    """

    # Sample size used by ``create_comparison_plot`` (kept small because several
    # dendrograms are stacked in one figure).
    _COMPARISON_SAMPLE_SIZE = 30

    def __init__(self, max_display_tokens: int = 50):
        self.max_display_tokens = max_display_tokens

    def load_data(self, tokens_path: str, embeddings_path: str) -> tuple[List[str], np.ndarray]:
        """Load the token list and the matching embedding matrix from disk.

        Args:
            tokens_path: Path to a JSON file containing the list of tokens.
            embeddings_path: Path to a tensor saved with ``torch.save``.

        Returns:
            ``(tokens, embeddings)`` truncated to a common length, where
            ``embeddings`` has one row per token.

        Raises:
            ValueError: If the loaded tensor is neither 2-D nor 3-D.
        """
        # JSON is UTF-8 by specification; do not depend on the platform default
        # encoding (the tokens contain non-ASCII characters).
        with open(tokens_path, 'r', encoding='utf-8') as f:
            tokens = json.load(f)

        # NOTE(security): torch.load unpickles the file; only load files you trust
        # (or pass weights_only=True on PyTorch versions that support it).
        embeddings_tensor = torch.load(embeddings_path, map_location='cpu')

        # .numpy() refuses tensors that require grad, and NumPy has no bfloat16 dtype.
        embeddings_tensor = embeddings_tensor.detach()
        if embeddings_tensor.dtype == torch.bfloat16:
            embeddings_tensor = embeddings_tensor.float()

        if embeddings_tensor.dim() == 3:
            # Collapse all leading dimensions into one row axis; reshape (unlike
            # view) also works for non-contiguous tensors.
            embeddings = embeddings_tensor.reshape(-1, embeddings_tensor.size(-1)).numpy()
        elif embeddings_tensor.dim() == 2:
            # NOTE(review): the transpose assumes 2-D tensors are stored as
            # (hidden_dim, n_tokens) — confirm against the code that saves them.
            embeddings = embeddings_tensor.T.numpy()
        else:
            raise ValueError(f"Unexpected tensor shape: {embeddings_tensor.shape}")

        min_len = min(len(tokens), embeddings.shape[0])
        return tokens[:min_len], embeddings[:min_len]

    def _sample(self, embeddings: np.ndarray, tokens: List[str],
                max_samples: int) -> tuple[np.ndarray, List[str]]:
        """Draw up to ``max_samples`` random (embedding, token) pairs without replacement."""
        # Never index past the shorter of the two inputs.
        n_available = min(len(tokens), len(embeddings))
        n_samples = min(n_available, max_samples)
        if n_samples < 2:
            raise ValueError("At least two tokens with embeddings are required to build a dendrogram")

        sample_indices = np.random.choice(n_available, n_samples, replace=False)
        return embeddings[sample_indices], [tokens[i] for i in sample_indices]

    def create_dendrogram(self, embeddings: np.ndarray, tokens: List[str],
                         linkage_method: str = 'ward',
                         metric: str = 'euclidean',
                         truncate_mode: Optional[str] = 'lastp') -> go.Figure:
        """Cluster a random sample of tokens hierarchically and plot the dendrogram.

        Args:
            embeddings: Array with one embedding row per token.
            tokens: Token strings aligned with the rows of ``embeddings``.
            linkage_method: Linkage method forwarded to ``scipy.cluster.hierarchy.linkage``.
                Note that SciPy defines 'ward' correctly only for Euclidean distances.
            metric: Distance metric forwarded to ``scipy.spatial.distance.pdist``.
            truncate_mode: Truncation mode forwarded to ``dendrogram``; when set,
                at most 30 leaves/levels are requested.

        Returns:
            The Plotly figure containing the dendrogram.

        Raises:
            ValueError: If fewer than two samples are available.
        """
        sample_embeddings, sample_tokens = self._sample(
            embeddings, tokens, self.max_display_tokens
        )

        distances = pdist(sample_embeddings, metric=metric)
        linkage_matrix = linkage(distances, method=linkage_method)

        # Only pass ``p`` when truncation is requested; otherwise keep SciPy's default.
        dendrogram_kwargs = dict(labels=sample_tokens, no_plot=True, truncate_mode=truncate_mode)
        if truncate_mode:
            dendrogram_kwargs['p'] = min(30, len(sample_tokens))
        dend_data = dendrogram(linkage_matrix, **dendrogram_kwargs)

        return self._build_plotly_dendrogram(dend_data, linkage_method, metric)

    def _build_plotly_dendrogram(self, dend_data: dict,
                                linkage_method: str, metric: str) -> go.Figure:
        """Turn the dict returned by ``scipy`` ``dendrogram(no_plot=True)`` into a Plotly figure."""
        fig = go.Figure()

        # Add dendrogram branches
        for xs, ys in zip(dend_data['icoord'], dend_data['dcoord']):
            fig.add_trace(go.Scatter(
                x=xs, y=ys,
                mode='lines',
                line=dict(color='#2E86AB', width=1.5),
                hoverinfo='skip',
                showlegend=False
            ))

        # Add leaf labels. SciPy places the i-th leaf at x = 5 + 10 * i; there is
        # one fewer link than leaves, so link midpoints cannot serve as label positions.
        leaf_labels = dend_data['ivl']
        leaf_positions = [5 + 10 * i for i in range(len(leaf_labels))]

        fig.add_trace(go.Scatter(
            x=leaf_positions,
            y=[0] * len(leaf_labels),
            mode='text',
            text=leaf_labels,
            textposition='bottom center',
            textfont=dict(size=9),
            hovertemplate='<b>%{text}</b><extra></extra>',
            showlegend=False
        ))

        fig.update_layout(
            title=f'Hierarchical Clustering Dendrogram<br><sub>{linkage_method.title()} linkage, {metric} distance</sub>',
            xaxis=dict(
                showticklabels=False,
                showgrid=False,
                zeroline=False
            ),
            yaxis=dict(
                title='Distance',
                showgrid=True,
                gridcolor='lightgray',
                gridwidth=0.5
            ),
            plot_bgcolor='white',
            height=700,
            margin=dict(b=120, t=80),
            font=dict(size=11)
        )

        return fig

    def create_comparison_plot(self, embeddings: np.ndarray, tokens: List[str],
                              methods: Optional[List[str]] = None) -> go.Figure:
        """Stack one dendrogram per linkage method, all built from the same sample.

        Args:
            embeddings: Array with one embedding row per token.
            tokens: Token strings aligned with the rows of ``embeddings``.
            methods: Linkage methods to compare; defaults to
                ``['ward', 'complete', 'average']``. 'ward' uses Euclidean
                distances, every other method uses cosine distances.

        Returns:
            A Plotly figure with one subplot row per method.

        Raises:
            ValueError: If ``methods`` is empty or fewer than two samples are available.
        """
        # A None default avoids sharing one mutable list between calls.
        if methods is None:
            methods = ['ward', 'complete', 'average']
        if not methods:
            raise ValueError("At least one linkage method is required")

        n_methods = len(methods)
        fig = make_subplots(
            rows=n_methods, cols=1,
            subplot_titles=[f'{method.title()} Linkage' for method in methods],
            vertical_spacing=0.08
        )

        sample_embeddings, sample_tokens = self._sample(
            embeddings, tokens, self._COMPARISON_SAMPLE_SIZE
        )

        palette = px.colors.qualitative.Set2
        for i, method in enumerate(methods):
            metric = 'euclidean' if method == 'ward' else 'cosine'
            distances = pdist(sample_embeddings, metric=metric)
            linkage_matrix = linkage(distances, method=method)

            dend_data = dendrogram(linkage_matrix, labels=sample_tokens, no_plot=True)

            # Cycle through the palette so long method lists do not raise IndexError.
            branch_color = palette[i % len(palette)]

            # Add branches for this subplot
            for xs, ys in zip(dend_data['icoord'], dend_data['dcoord']):
                fig.add_trace(go.Scatter(
                    x=xs, y=ys,
                    mode='lines',
                    line=dict(color=branch_color, width=1.2),
                    hoverinfo='skip',
                    showlegend=False
                ), row=i+1, col=1)

        fig.update_layout(
            title='Hierarchical Clustering Method Comparison',
            height=300 * n_methods,
            showlegend=False
        )

        for i in range(n_methods):
            fig.update_xaxes(showticklabels=False, row=i+1, col=1)
            fig.update_yaxes(title_text='Distance', row=i+1, col=1)

        return fig

def visualize_hierarchical_clustering(tokens_path: str, embeddings_path: str,
                                    linkage_method: str = 'ward',
                                    show_comparison: bool = False,
                                    max_display_tokens: int = 60):
    """Load tokens and embeddings from disk and display their dendrogram(s).

    Args:
        tokens_path: Path to the JSON file with the token list.
        embeddings_path: Path to the saved embeddings tensor.
        linkage_method: Linkage method used for the main dendrogram.
        show_comparison: When true, also display the figure comparing several
            linkage methods.
        max_display_tokens: Maximum number of tokens sampled for the main
            dendrogram (previously hard-coded to 60).
    """
    viz = DendrogramVisualizer(max_display_tokens=max_display_tokens)
    tokens, embeddings = viz.load_data(tokens_path, embeddings_path)

    print(f"Loaded {len(tokens)} tokens with {embeddings.shape[1]}D embeddings")

    fig = viz.create_dendrogram(embeddings, tokens, linkage_method=linkage_method)
    fig.show()

    if show_comparison:
        comparison_fig = viz.create_comparison_plot(embeddings, tokens)
        comparison_fig.show()

if __name__ == "__main__":
    # Demo run: Ward-linkage dendrogram of the first-layer hidden states,
    # followed by the linkage-method comparison figure.
    visualize_hierarchical_clustering(
        tokens_path="tokens.json",
        embeddings_path="hidden_states_first.pt",
        linkage_method='ward',
        show_comparison=True,
    )

Loaded 1798 tokens with 1536D embeddings


**LAST** LAYER hidden_states: more challenging than expected :/

Here we got worse clusters, and no clusters were labeled as "structural".

We believe the reason is that the hidden_state vectors at the last layer have incorporated a lot of context and new meaning, so they are not as straightforward as they are at the first layer.

In [None]:
# Cluster the last-layer hidden states once per algorithm, printing a header before each run
# so the two reports can be told apart in the cell output.
for cluster_meth in ('kmeans', 'hierarchical'):
    print(f"Cluster Method: {cluster_meth}")
    usage(
        [cluster_meth],
        'tsne',
        "tokens.json",
        "hidden_states_last.pt",
        "cosine_sim_last_last.pt",
    )

Cluster Method: kmeans
Finding optimal clusters...
Best clustering algorithm: KMeansClusterer
Filtering structural clusters...
Refining semantic clusters...
Structural clusters identified: []
Initial clusters: 13
Refined semantic clusters: 8



Refined Semantic Clusters:
Cluster 0 (317 tokens):
  Sample tokens: {' terms', ' Provide', '<｜begin▁of▁sentence｜>', ' Explain', ' in', '.', ',', ' Discuss', ' accessible', ' examples', ' concepts', ' Also'}

Cluster 1 (191 tokens):
  Sample tokens: {' like', '-term', ' to', ' shape', ' affect', ',', ' adult', ' on', ' and'}

Cluster 2 (190 tokens):
  Sample tokens: {' "', '.\n\n', '.'}

Cluster 3 (27 tokens):
  Sample tokens: {' self'}

Cluster 4 (290 tokens):
  Sample tokens: {' significant', ' including', ' life', ' shape', '.', ' various', ' how', ',', ' of', ' childhood', ' family', ' and', ' vs'}

Cluster 5 (195 tokens):
  Sample tokens: {' like', ' parent', "'s", ' a', ' versus', ' one', ' who', ' very', ' harsh', ' vs'}

Cluster 6 (454 tokens):
  Sample tokens: {' nurture', ' education', 'amine', ' psychological', 'Ex', ' influences', ' attachment', ' friendships', ' how', ' theory', ' personality', ' nature', ' development', ' events', ' environment'}

Cluster 7 (134 tokens):



Refined Semantic Clusters:
Cluster 0 (350 tokens):
  Sample tokens: {' like', ' parent', ' mention', ' about', ' a', ' the', ' who', ' how', ' of', ' despite', ' and'}

Cluster 1 (496 tokens):
  Sample tokens: {' nurture', ' education', ' psychological', 'Ex', ' influences', ' attachment', ' friendships', ' theory', ' personality', ' development', ' events', ' concepts', ' family', ' environment'}

Cluster 2 (310 tokens):
  Sample tokens: {' terms', '<｜begin▁of▁sentence｜>', 'amine', ' shape', ' Explain', ' in', '.', ' various', ' how', ' Discuss', ' accessible', ' examples', ' Provide'}

Cluster 3 (247 tokens):
  Sample tokens: {' significant', '-term', ' life', ' to', ' positive', ' negative', ' affect', ' personality', ' adult', ' childhood', ' lead'}

Cluster 4 (27 tokens):
  Sample tokens: {' self'}

Cluster 5 (84 tokens):
  Sample tokens: {' "', '**', '\n\n', '.\n\n'}

Cluster 6 (109 tokens):
  Sample tokens: {' school', ' can', ' schooling', ' experiences', ' these', ' dynamics'

Summary: we explored clustering and visualization techniques for high-dimensional data. We found suitable methods for doing this and presenting the results. We also classified clusters based on how semantically interesting they were.

Our method was most effective for the hidden_states of the first layer of the LLM, which we believe is because they are more representative of the individual meaning of each token. The hidden_states at the last layer carry context and more meaning than just that of the individual token they correspond to.