# Topological Data Analysis - Exploration

Explore building attention graphs and computing persistent homology on encoder attention patterns.

**Methodology:**
1. Extract last layer attention
2. Average across attention heads
3. (Optional) Filter special tokens and renormalize
4. Symmetrize attention matrix
5. Convert to distance: `d_ij = 1 - attention_ij` (with the diagonal set to 0)
6. Compute Vietoris-Rips persistent homology
7. Extract persistence diagrams for H₀ (connected components, β₀) and H₁ (loops, β₁)

## 1. Install and Import Dependencies

In [None]:
# Install TDA libraries if not already installed
# !pip install ripser persim

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings

# Plot labels contain Chinese tokens, so select a sans-serif font intended to
# cover them and disable the Unicode minus glyph for axis labels.
plt.rcParams.update({
    'font.sans-serif': ['Arial Unicode MS'],
    'axes.unicode_minus': False,
})

# H0 diagrams are expected to contain one component that persists forever
# (infinite death time); silence the warning raised about such points.
warnings.filterwarnings(action='ignore', message='.*non-finite death times.*')

# TDA libraries (Scikit-TDA)
from ripser import ripser
from persim import plot_diagrams, wasserstein

print("✓ Libraries imported")

In [None]:
import matplotlib

def plot_diagrams_with_font(diagrams, show=True, **kwargs):
    """
    Plot persistence diagrams with the notebook's Chinese-capable font settings.

    Calls persim.plot_diagrams inside a temporary matplotlib rc_context that
    sets 'font.sans-serif' to ['Arial Unicode MS'] and 'axes.unicode_minus'
    to False; the previous rc settings are restored when the call returns.

    Args:
        diagrams: Persistence diagrams to plot (e.g. the 'dgms' list from ripser)
        show: Passed through to plot_diagrams as its `show` argument
        **kwargs: Any other keyword arguments, forwarded to plot_diagrams unchanged

    NOTE(review): if matplotlib resolves font families at draw time, the
    override may no longer be active when a figure created with show=False
    is rendered later — confirm; the notebook also sets the same values
    globally in plt.rcParams.
    """
    font_overrides = {
        'font.sans-serif': ['Arial Unicode MS'],
        'axes.unicode_minus': False,
    }
    # Scope the overrides to this single plotting call.
    with matplotlib.rc_context(rc=font_overrides):
        plot_diagrams(diagrams, show=show, **kwargs)

print("✓ Helper function defined: plot_diagrams_with_font()")

## 2. Load Attention Data

In [None]:
# Load the extracted attention maps
data_path = Path("../data/attention_maps_zh_en/all_encoder_attention.pkl")

print(f"Loading data from {data_path}...")
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by a trusted pipeline.
with open(data_path, 'rb') as f:
    results = pickle.load(f)

print(f"✓ Loaded {len(results)} sentence pairs")
if len(results) > 0:
    # The first two axes of the attention array are (layers, heads).
    num_layers, num_heads = results[0]['en_attention'].shape[:2]
    print(f"\nModel architecture: {num_layers} layers, {num_heads} heads")
else:
    # Avoid an IndexError on results[0] when the file holds no sentence pairs.
    print("\n⚠ No sentence pairs loaded; model architecture unknown")

## 3. Define Graph Construction Functions

In [None]:
def build_distance_matrix(attention, tokens, layer=-1, filter_special=True,
                          special_tokens=None):
    """
    Build a symmetric distance matrix from attention weights.

    The chosen layer is averaged over heads, optionally restricted to content
    tokens (with rows renormalized to sum to 1), symmetrized, and converted to
    distances via d = 1 - attention, with the diagonal forced to 0.

    Args:
        attention: Attention tensor (num_layers, num_heads, seq_len, seq_len)
        tokens: List of token strings
        layer: Which layer to use (-1 for last layer)
        filter_special: Whether to filter out special tokens
        special_tokens: Optional collection of token strings to treat as
            special. Defaults to {'</s>', '<s>', '<pad>', 'eng_Latn', 'zho_Hans'}.

    Returns:
        distance_matrix: (N, N) array where N = number of tokens (or content tokens if filtered)
        filtered_tokens: List of token strings (content tokens if filtered, all tokens otherwise)
    """
    if special_tokens is None:
        special_tokens = {'</s>', '<s>', '<pad>', 'eng_Latn', 'zho_Hans'}

    # 1. Select the layer and average over heads -> (seq_len, seq_len)
    attn = attention[layer].mean(axis=0)

    # Defaults used when filtering is disabled or would remove every token.
    attn_filtered = attn
    filtered_tokens = tokens

    # 2. Filter special tokens (optional)
    if filter_special:
        content_mask = np.array(
            [tok not in special_tokens for tok in tokens], dtype=bool
        )

        if content_mask.any():  # Only filter if there are content tokens
            attn_filtered = attn[content_mask][:, content_mask]
            filtered_tokens = [tok for tok, keep in zip(tokens, content_mask) if keep]

            # 3. Renormalize each row over the remaining content tokens.
            # A content token whose attention went entirely to special tokens
            # has a zero row sum; dividing by it would give NaN distances, so
            # such rows are left as all-zero attention (distance 1 to others).
            row_sums = attn_filtered.sum(axis=1, keepdims=True)
            safe_row_sums = np.where(row_sums == 0, 1.0, row_sums)
            attn_filtered = attn_filtered / safe_row_sums

    # 4. Symmetrize (make undirected)
    attn_sym = (attn_filtered + attn_filtered.T) / 2

    # 5. Convert to distance: d = 1 - attention
    distance_matrix = 1 - attn_sym

    # A token is at distance 0 from itself regardless of self-attention.
    np.fill_diagonal(distance_matrix, 0)

    return distance_matrix, filtered_tokens


print("✓ Function defined: build_distance_matrix()")

## 4. Test on Sample Example

In [None]:
# Work with the first sentence pair in the loaded results
idx = 0
example = results[idx]

print(f"Example {idx}:")
print(f"English: {example['en_text']}")
print(f"Chinese: {example['zh_text']}")
print()

# One distance matrix per language (special tokens filtered by default)
en_dist, en_tokens = build_distance_matrix(example['en_attention'], example['en_tokens'])
zh_dist, zh_tokens = build_distance_matrix(example['zh_attention'], example['zh_tokens'])

# Report shape, surviving tokens and basic distance statistics per language,
# separated by a blank line.
_per_language = (
    ("English", en_dist, en_tokens),
    ("Chinese", zh_dist, zh_tokens),
)
for _position, (_label, _dist, _toks) in enumerate(_per_language):
    if _position > 0:
        print()
    print(f"{_label} distance matrix: {_dist.shape}")
    print(f"  Content tokens: {_toks}")
    print(f"  Min distance: {_dist.min():.4f}")
    print(f"  Max distance: {_dist.max():.4f}")
    print(f"  Mean distance: {_dist.mean():.4f}")

### Visualize Distance Matrices

In [None]:
# Side-by-side heatmaps of the English and Chinese distance matrices
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

_panels = (
    (ax1, 'English', en_dist, en_tokens, example["en_text"]),
    (ax2, 'Chinese', zh_dist, zh_tokens, example["zh_text"]),
)
for _ax, _language, _dist, _tick_labels, _sentence in _panels:
    sns.heatmap(
        _dist,
        xticklabels=_tick_labels,
        yticklabels=_tick_labels,
        cmap='RdYlBu_r',
        ax=_ax,
        square=True,
        cbar_kws={'label': 'Distance (1 - attention)'},
    )
    # Title shows the first 50 characters of the sentence
    _ax.set_title(f'{_language} Distance Matrix\n{_sentence[:50]}...')
    _ax.set_xlabel('Tokens')
    _ax.set_ylabel('Tokens')
    # Slant the x tick labels so long tokens do not overlap
    plt.setp(_ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

plt.tight_layout()
plt.show()

## 5. Compute Persistent Homology

In [None]:
# Announce the persistence computation performed in the following cells:
# ripser with maxdim=1 gives β₀ (connected components) and β₁ (loops).

print(
    "✓ Computing Vietoris-Rips persistence using ripser",
    "  Computing: β₀ (connected components) and β₁ (1-dimensional holes/loops)",
    sep="\n",
)

In [None]:
# Persistent homology of the English distance matrix.
# ripser is told the input is already a distance matrix; its result dict
# stores the diagrams under 'dgms' as a list [H0, H1].
en_result = ripser(en_dist, distance_matrix=True, maxdim=1)
en_diagrams = en_result['dgms']

_en_h0 = en_diagrams[0]
_en_h1 = en_diagrams[1]

print("English persistence diagrams:")
print(f"  H0 (β₀) shape: {_en_h0.shape}")
print(f"  H1 (β₁) shape: {_en_h1.shape}")
print("  Format: (birth, death)")
print("\nFirst 5 H0 features:")
print(_en_h0[:5])
print("\nFirst 5 H1 features:")
# Show a placeholder instead of an empty array when there are no loops
print("None" if len(_en_h1) == 0 else _en_h1[:5])

In [None]:
# Persistent homology of the Chinese distance matrix (same settings as English)
zh_result = ripser(zh_dist, distance_matrix=True, maxdim=1)
zh_diagrams = zh_result['dgms']

_zh_h0 = zh_diagrams[0]
_zh_h1 = zh_diagrams[1]

print("Chinese persistence diagrams:")
print(f"  H0 (β₀) shape: {_zh_h0.shape}")
print(f"  H1 (β₁) shape: {_zh_h1.shape}")
print("\nFirst 5 H0 features:")
print(_zh_h0[:5])
print("\nFirst 5 H1 features:")
# Show a placeholder instead of an empty array when there are no loops
print("None" if len(_zh_h1) == 0 else _zh_h1[:5])

### Analyze Persistence Features

In [None]:
# Count features by dimension
def analyze_diagram(diagrams, language=""):
    """
    Print a summary of the finite features in a pair of persistence diagrams.

    For H0 and H1 the number of features with finite birth and death is
    printed, followed by the maximum and mean persistence (death - birth)
    when at least one such feature exists. An explicit note is printed when
    H1 has no finite features.

    Args:
        diagrams: Sequence [H0, H1]; each entry is an (n_features, 2) array
            of (birth, death) pairs
        language: Label placed at the start of the summary heading
    """
    # Keep only rows whose birth and death are both finite (this drops the
    # H0 component that never dies), then measure each feature's lifetime.
    finite_features = [
        dgm[np.isfinite(dgm).all(axis=1)] for dgm in (diagrams[0], diagrams[1])
    ]
    lifetimes = [rows[:, 1] - rows[:, 0] for rows in finite_features]

    print(f"{language} Persistence Summary:")
    print("=" * 50)

    # (heading, message printed when there are no finite features, or None)
    sections = (
        ("β₀ features (connected components)", None),
        ("\nβ₁ features (loops/holes)", "  No H1 features detected"),
    )
    for (heading, empty_note), rows, spans in zip(sections, finite_features, lifetimes):
        print(f"{heading}: {len(rows)}")
        if len(spans) > 0:
            print(f"  Max persistence: {spans.max():.4f}")
            print(f"  Mean persistence: {spans.mean():.4f}")
        elif empty_note is not None:
            print(empty_note)
    print()

# Summarize the English diagrams first, then the Chinese ones
for _language, _diagrams in (("English", en_diagrams), ("Chinese", zh_diagrams)):
    analyze_diagram(_diagrams, _language)

## 6. Visualize Persistence Diagrams

In [None]:
# English persistence diagram: draw without showing so the title and layout
# can be applied before the figure is displayed.
plot_diagrams_with_font(en_diagrams, show=False)
_en_title = f'English Persistence Diagram\n{example["en_text"][:60]}...'
plt.title(_en_title)
plt.tight_layout()
plt.show()

In [None]:
# Chinese persistence diagram: draw without showing so the title and layout
# can be applied before the figure is displayed.
plot_diagrams_with_font(zh_diagrams, show=False)
_zh_title = f'Chinese Persistence Diagram\n{example["zh_text"][:60]}...'
plt.title(_zh_title)
plt.tight_layout()
plt.show()

## 7. Compare Persistence Diagrams

Use the Wasserstein distance between persistence diagrams to compare the English and Chinese attention topologies (a lower distance means more similar).

In [None]:
# Compare the English and Chinese diagrams with persim's Wasserstein distance.
# H0 and H1 are compared separately and then combined.

# Wasserstein distance for H0 (connected components)
w_dist_h0 = wasserstein(en_diagrams[0], zh_diagrams[0])

# Wasserstein distance for H1 (loops). The same call is made whether or not
# one of the diagrams is empty, so no emptiness check is needed.
# NOTE(review): this relies on persim accepting empty diagrams — confirm
# against the persim documentation.
w_dist_h1 = wasserstein(en_diagrams[1], zh_diagrams[1])

# Total distance (can weight differently, but simple sum for now)
total_w_dist = w_dist_h0 + w_dist_h1

print("Wasserstein Distances:")
print("="*50)
print(f"H0 (connected components): {w_dist_h0:.6f}")
print(f"H1 (loops/holes):          {w_dist_h1:.6f}")
print(f"Total:                     {total_w_dist:.6f}")

## 8. Process Multiple Examples

In [None]:
# Process the first 10 examples, or all of them if fewer were loaded, so that
# indexing results[idx] can never run past the end of the data.
num_examples = min(10, len(results))
wasserstein_distances = []

print(f"Computing persistent homology for {num_examples} sentence pairs...\n")

# NOTE: this loop rebinds `idx`, `example`, `en_dist`, `zh_dist`,
# `en_diagrams`, `zh_diagrams` and the `w_dist_*` names, so afterwards they
# refer to the last processed pair rather than the sample example above.
for idx in range(num_examples):
    example = results[idx]

    # Build distance matrices (the filtered token lists are not needed here)
    en_dist, _ = build_distance_matrix(example['en_attention'], example['en_tokens'])
    zh_dist, _ = build_distance_matrix(example['zh_attention'], example['zh_tokens'])

    # Compute persistence (H0 and H1) for each language
    en_result = ripser(en_dist, maxdim=1, distance_matrix=True)
    zh_result = ripser(zh_dist, maxdim=1, distance_matrix=True)

    en_diagrams = en_result['dgms']
    zh_diagrams = zh_result['dgms']

    # Compute Wasserstein distances per homology dimension and their sum
    w_dist_h0 = wasserstein(en_diagrams[0], zh_diagrams[0])
    w_dist_h1 = wasserstein(en_diagrams[1], zh_diagrams[1])
    total_w_dist = w_dist_h0 + w_dist_h1

    # Keep the diagrams alongside the distances so later cells can plot them
    wasserstein_distances.append({
        'idx': idx,
        'en_text': example['en_text'],
        'zh_text': example['zh_text'],
        'wasserstein_distance': total_w_dist,
        'wasserstein_h0': w_dist_h0,
        'wasserstein_h1': w_dist_h1,
        'en_diagrams': en_diagrams,
        'zh_diagrams': zh_diagrams
    })

    print(f"[{idx}] W-dist: {total_w_dist:.6f} (H0: {w_dist_h0:.4f}, H1: {w_dist_h1:.4f}) | EN: {example['en_text'][:40]}...")

print(f"\n✓ Processed {num_examples} examples")

In [None]:
# Summary statistics of the total Wasserstein distance per sentence pair
w_dists = [d['wasserstein_distance'] for d in wasserstein_distances]

# Report the number of pairs actually processed instead of a fixed "10"
print(f"Wasserstein Distance Statistics (first {len(w_dists)} pairs):")
print("="*50)
if w_dists:
    print(f"Min:    {np.min(w_dists):.6f}")
    print(f"Max:    {np.max(w_dists):.6f}")
    print(f"Mean:   {np.mean(w_dists):.6f}")
    print(f"Median: {np.median(w_dists):.6f}")
    print(f"Std:    {np.std(w_dists):.6f}")
else:
    # np.min / np.max raise ValueError on an empty sequence
    print("No sentence pairs were processed; nothing to summarize")

In [None]:
# Histogram of the per-pair Wasserstein distances, with the mean marked
_mean_w_dist = np.mean(w_dists)

plt.figure(figsize=(10, 5))
plt.hist(w_dists, bins=10, alpha=0.7, color='blue', edgecolor='black')
plt.axvline(
    _mean_w_dist,
    color='red',
    linestyle='--',
    label=f'Mean: {_mean_w_dist:.6f}',
)
plt.xlabel('Wasserstein Distance')
plt.ylabel('Frequency')
plt.title('Distribution of Topological Similarity (English vs Chinese)\nLower = More Similar')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## 9. Compare High vs Low Similarity Examples

In [ ]:
# Rank the processed pairs by total Wasserstein distance (ascending), then
# report the two extremes.
sorted_by_similarity = sorted(
    wasserstein_distances,
    key=lambda record: record['wasserstein_distance'],
)


def _print_pair(pair):
    """Print the index, distance and both sentences of one result record."""
    print(f"Pair {pair['idx']}: W-dist = {pair['wasserstein_distance']:.6f}")
    print(f"  EN: {pair['en_text']}")
    print(f"  ZH: {pair['zh_text']}")


print("Most topologically similar (lowest Wasserstein distance):")
print("="*70)
most_similar = sorted_by_similarity[0]
_print_pair(most_similar)
print()

print("Least topologically similar (highest Wasserstein distance):")
print("="*70)
least_similar = sorted_by_similarity[-1]
_print_pair(least_similar)

In [None]:
# Persistence diagrams of the most similar pair, English left and Chinese right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

_subplots = (
    (ax1, 'en_diagrams', 'English'),
    (ax2, 'zh_diagrams', 'Chinese'),
)
for _ax, _diagram_key, _panel_title in _subplots:
    # Make this subplot the current axes before drawing into it
    plt.sca(_ax)
    plot_diagrams_with_font(most_similar[_diagram_key], show=False)
    _ax.set_title(_panel_title)

fig.suptitle(f"Most Topologically Similar Pair (W-dist: {most_similar['wasserstein_distance']:.6f})")
plt.tight_layout()
plt.show()

## Summary

✅ **Successfully computed topological features!**

**What we learned:**
- Built distance matrices from attention weights
- Computed Vietoris-Rips persistent homology (β₀, β₁)
- Visualized persistence diagrams
- Measured topological similarity using Wasserstein distance
- Analyzed first 10 sentence pairs

**Next steps:**
1. Scale to all 2000 sentence pairs
2. Compute BLEU scores for translation quality
3. Correlate Wasserstein distance with BLEU scores
4. Statistical analysis and visualization

**Key findings so far:**
- Wasserstein distances vary across sentence pairs
- Both languages show β₀ and β₁ features
- Topological structure varies across sentence pairs