In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import ToUndirected
from torch_geometric.utils import to_networkx, from_networkx
from transformers import AutoModel, AutoTokenizer
import networkx as nx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict

# scikit-learn imports
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import HDBSCAN
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

#Use Apple Silicon M3 chip
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using MPS (Apple Silicon GPU)


First, we will limit our analysis to a single species for simplicity. In this case, we will choose the Coto species (commonly known as Townsend's big-eared bat), since it is the species with the largest number of examples in our dataset (n = 116696)

In [8]:

df = pd.read_feather("src/data/orig_chirps_2024-06-25T12_55_03.feather")
df_Coto = df[df['species'] == 'Coto'].copy()

# Check is_daytime distribution
print("is_daytime value counts:")
daytime_counts = df_Coto['is_daytime'].value_counts()
print(daytime_counts)
print(f"\nFalse (nighttime): {daytime_counts.get(False, 0)} ({daytime_counts.get(False, 0) / len(df_Coto) * 100:.2f}%)")
print(f"True (daytime): {daytime_counts.get(True, 0)} ({daytime_counts.get(True, 0) / len(df_Coto) * 100:.2f}%)")

# Drop redundant/non-feature columns:
# - species: all same (Coto)
# - TimeInFile: will be used for edge features only
# - PrecedingIntrvl: redundant with TimeInFile differences
# - CallsPerSec: constant within recording
# - file_id, chirp_idx, split, rec_datetime: metadata, not chirp features
# - MinAccpQuality, Max#CallsConsidered, cntxt_sz: processing parameters
# - sin_year, cos_year: too coarse temporal info
df_Coto = df_Coto.drop(['species', 'PrecedingIntrvl', 'CallsPerSec', 'chirp_idx', 'split', 'rec_datetime', 'MinAccpQuality', 'Max#CallsConsidered', 'cntxt_sz', 'sin_year', 'cos_year'], axis=1)
print(f"\nDataframe shape: {df_Coto.shape}")
print(f"Number of feature columns (excluding TimeInFile and file_id): {len(df_Coto.columns) - 2}")


is_daytime value counts:
is_daytime
False    111383
True       5313
Name: count, dtype: int64

False (nighttime): 111383 (95.45%)
True (daytime): 5313 (4.55%)

Dataframe shape: (116696, 111)
Number of feature columns (excluding TimeInFile and file_id): 109


The original dataset is a compilation of a number of recordings taken at different dates and times. We want to analyze only within each recording, since it is unreasonable to assume that there should be any causal connection between bat chirps which are not temporally connected. We also will filter out all recordings which occur during the daytime, since the context for those chirps is different from chirps which occur at night. This brings us down from 5350 files to 5080 files, and from 116696 total chirps to 111383 total chirps. 


After splitting by recording and removing daytime recordings, we create individual graphs for each recording, which are indexed with the variable 'file_id.' This yields a total of 5350 unique graphs, with mean size of around 22 chirps, and a minimum and maximum of 6 and 54 chirps respectively. Clusters will be compared across graphs to determine patterns which occur within the species. 



In [9]:
# Group by file_id and check if all chirps in each file are nighttime
file_id_all_nighttime = df_Coto.groupby('file_id')['is_daytime'].apply(lambda x: (~x).all())
nighttime_file_ids = file_id_all_nighttime[file_id_all_nighttime].index.tolist()

print(f"Total unique file_ids: {df_Coto['file_id'].nunique()}")
print(f"File_ids with all nighttime chirps: {len(nighttime_file_ids)}")

# Filter df_Coto to only include nighttime file_ids
df_Coto_filtered = df_Coto[df_Coto['file_id'].isin(nighttime_file_ids)].copy()
df_Coto_filtered = df_Coto_filtered.drop(['is_daytime'], axis=1)
print(f"\nTotal chirps before filtering: {len(df_Coto)}")
print(f"Total chirps after filtering: {len(df_Coto_filtered)}")

# Split df_Coto_filtered into separate dataframes by file_id
unique_file_ids = df_Coto_filtered['file_id'].unique()
n = len(unique_file_ids)

# Create dictionary to store dataframes
dataframes = {}
lengths = []
file_id_to_index = {}

# Split dataframe by file_id and store each
for i, file_id in enumerate(unique_file_ids):
    df_name = f'df_{i}'
    dataframes[df_name] = df_Coto_filtered[df_Coto_filtered['file_id'] == file_id].copy()
    lengths.append(len(dataframes[df_name]))
    file_id_to_index[file_id] = i

# Print results
print(f"\n Number of nighttime recordings = {n}")
print(f"Min chirps per recording: {min(lengths)}")
print(f"Max chirps per recording: {max(lengths)}")
print(f"Mean chirps per recording: {np.mean(np.array(lengths)):.2f}")



Total unique file_ids: 5350
File_ids with all nighttime chirps: 5080

Total chirps before filtering: 116696
Total chirps after filtering: 111383

 Number of nighttime recordings = 5080
Min chirps per recording: 6
Max chirps per recording: 54
Mean chirps per recording: 21.93


In [10]:
# Create 1D graphs for each recording (file_id)
# Each graph is a chain where nodes are chirps and edges connect consecutive chirps
graphs = []  

for i in range(n):
    df_name = f'df_{i}'
    df_current = dataframes[df_name]
    
    # Get number of nodes (chirps) in this graph
    num_nodes = len(df_current)
    
    # Extract node features (all columns except TimeInFile, file_id, and any non-numeric columns)
    feature_cols = [col for col in df_current.columns 
                    if col not in ['TimeInFile', 'file_id']]
    node_features = torch.tensor(df_current[feature_cols].values, dtype=torch.float32)
    
    # Create sequential edge indices for 1D chain graph
    # Edge from node i to node i+1 for all i in [0, num_nodes-2]
    if num_nodes > 1:
        edge_index = torch.tensor([
            list(range(num_nodes - 1)),  # source nodes
            list(range(1, num_nodes))     # target nodes
        ], dtype=torch.long)
        
        # Calculate edge features: time difference between consecutive chirps
        time_values = df_current['TimeInFile'].values
        time_diffs = time_values[1:] - time_values[:-1]
        edge_attr = torch.tensor(time_diffs, dtype=torch.float32).unsqueeze(1)  # Shape: [num_edges, 1]
    else:
        # Single node graph - no edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float32)
    
    # Create PyG Data object
    graph = Data(
        x=node_features,           # Node features
        edge_index=edge_index,     # Edge connectivity
        edge_attr=edge_attr,       # Edge features (time differences)
        num_nodes=num_nodes
    )
    
    # Store original file_id as graph attribute
    graph.file_id = df_current['file_id'].iloc[0]
    
    graphs.append(graph)

print(f"Created {len(graphs)} graphs")
print(f"Example graph 0:")
print(f"  - Number of nodes: {graphs[0].num_nodes}")
print(f"  - Number of edges: {graphs[0].edge_index.shape[1]}")
print(f"  - Node feature dimension: {graphs[0].x.shape[1]}")
print(f"  - Edge feature dimension: {graphs[0].edge_attr.shape[1]}")
print(f"  - File ID: {graphs[0].file_id}")
print(f"\nExample graph 10:")
print(f"  - Number of nodes: {graphs[10].num_nodes}")
print(f"  - Number of edges: {graphs[10].edge_index.shape[1]}")
print(f"  - Node feature dimension: {graphs[10].x.shape[1]}")
print(f"  - Edge feature dimension: {graphs[0].edge_attr.shape[1]}")
print(f"  - First 3 edge features (time diffs in ms): {graphs[10].edge_attr[:3].squeeze().tolist()}")


Created 5080 graphs
Example graph 0:
  - Number of nodes: 18
  - Number of edges: 17
  - Node feature dimension: 108
  - Edge feature dimension: 1
  - File ID: 39

Example graph 10:
  - Number of nodes: 19
  - Number of edges: 18
  - Node feature dimension: 108
  - Edge feature dimension: 1
  - First 3 edge features (time diffs in ms): [92.0, 68.0, 57.0]


In [11]:
"""
GraphSAGE Implementation for Chirp Cluster Identification

ARCHITECTURE OVERVIEW:
======================

1. GRAPH CONSTRUCTION (Hybrid Approach):
   - Temporal edges: Connect consecutive chirps in time sequence
   - k-NN edges: Connect each chirp to k=5 nearest neighbors based on feature similarity
   - Edge features: Time differences for temporal edges, feature distances for k-NN edges
   
2. GRAPHSAGE MODEL:
   - Input dimension: 108 (chirp features)
   - Hidden dimension: 64
   - Output dimension: 32 (embedding space)
   - Number of layers: 2
   - Aggregation: Mean aggregation (robust to varying neighborhood sizes)
   - Activation: ReLU
   - Dropout: 0.2 (regularization)
   
3. TRAINING STRATEGY (Self-Supervised):
   - Contrastive learning: Predict whether chirp pairs are temporally adjacent
   - Positive pairs: Consecutive chirps in same recording
   - Negative pairs: Random chirps from same or different recordings
   - Loss: Binary cross-entropy
   - Optimizer: Adam with lr=0.001
   - Epochs: 50-100 (should complete in 5-10 minutes)
   
4. CLUSTERING APPROACH:
   - Algorithm: HDBSCAN (hierarchical density-based clustering)
   - Benefits: Automatically determines number of clusters, handles noise
   - Applied to: Node embeddings from trained GraphSAGE
   
5. CROSS-RECORDING ANALYSIS:
   - Extract cluster distributions per recording (graph-level signatures)
   - Compute cluster co-occurrence matrix across recordings
   - Identify common "sentence" patterns that appear across multiple recordings
"""

print("Building hybrid graphs with temporal + k-NN edges...")

# Step 1: Rebuild graphs with k-NN edges
def add_knn_edges(graph, k=5):
    """Add k-NN edges based on feature similarity to existing temporal edges"""
    x = graph.x.numpy()
    
    # Normalize features for better distance computation
    scaler = StandardScaler()
    x_normalized = scaler.fit_transform(x)
    
    # Find k nearest neighbors
    nbrs = NearestNeighbors(n_neighbors=min(k+1, len(x)), algorithm='ball_tree').fit(x_normalized)
    distances, indices = nbrs.kneighbors(x_normalized)
    
    # Build k-NN edge list (skip first neighbor which is the node itself)
    knn_edges = []
    knn_edge_features = []
    
    for i in range(len(x)):
        for j, neighbor_idx in enumerate(indices[i][1:]):  # Skip self
            knn_edges.append([i, neighbor_idx])
            # Use normalized distance as edge feature
            knn_edge_features.append([distances[i][j+1]])
    
    if len(knn_edges) > 0:
        knn_edge_index = torch.tensor(knn_edges, dtype=torch.long).t()
        knn_edge_attr = torch.tensor(knn_edge_features, dtype=torch.float32)
        
        # Combine temporal and k-NN edges
        combined_edge_index = torch.cat([graph.edge_index, knn_edge_index], dim=1)
        
        # Normalize temporal edge features to similar scale as k-NN distances
        temporal_edge_norm = graph.edge_attr / graph.edge_attr.max() if graph.edge_attr.numel() > 0 else graph.edge_attr
        combined_edge_attr = torch.cat([temporal_edge_norm, knn_edge_attr], dim=0)
        
        graph.edge_index = combined_edge_index
        graph.edge_attr = combined_edge_attr
    
    return graph

# Add k-NN edges to all graphs
hybrid_graphs = [add_knn_edges(graph.clone(), k=5) for graph in graphs]

print(f"Original graph 0: {graphs[0].edge_index.shape[1]} edges")
print(f"Hybrid graph 0: {hybrid_graphs[0].edge_index.shape[1]} edges")
print(f"Edge increase: {hybrid_graphs[0].edge_index.shape[1] - graphs[0].edge_index.shape[1]} k-NN edges added")
print()

# Step 2: Define GraphSAGE Model
class ChirpGraphSAGE(nn.Module):
    """
    GraphSAGE model for learning chirp embeddings
    
    Architecture:
    - Layer 1: SAGEConv(108 -> 64) + ReLU + Dropout(0.2)
    - Layer 2: SAGEConv(64 -> 32)
    
    The mean aggregation combines features from temporal and k-NN neighbors,
    allowing the model to learn both sequential patterns and acoustic similarities.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.2):
        super(ChirpGraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr='mean')
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr='mean')
        self.dropout = dropout
        
    def forward(self, x, edge_index):
        # Layer 1: Aggregate from 1-hop neighbors
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Layer 2: Aggregate from 2-hop neighbors
        x = self.conv2(x, edge_index)
        
        return x
    
    def get_embeddings(self, x, edge_index):
        """Get final node embeddings (without dropout for inference)"""
        self.eval()
        with torch.no_grad():
            return self.forward(x, edge_index)

print("Initializing GraphSAGE model...")
model = ChirpGraphSAGE(in_channels=108, hidden_channels=64, out_channels=32, dropout=0.2)
model = model.to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
print()

# Step 3: Prepare training data for contrastive learning
print("Preparing contrastive learning dataset...")

def create_training_pairs(graphs, num_negatives_per_positive=1):
    """
    Create positive and negative pairs for contrastive learning
    Positive: temporally adjacent chirps
    Negative: random non-adjacent chirps
    """
    pairs = []
    labels = []
    graph_indices = []
    
    for graph_idx, graph in enumerate(graphs):
        num_nodes = graph.num_nodes
        
        if num_nodes < 2:
            continue
        
        # Positive pairs: consecutive chirps (from temporal edges only)
        temporal_edges = graphs[graph_idx].edge_index.shape[1]  # Original temporal edges
        for i in range(min(temporal_edges, num_nodes - 1)):
            pairs.append([i, i + 1])
            labels.append(1)
            graph_indices.append(graph_idx)
            
            # Add negatives: random non-adjacent pairs
            for _ in range(num_negatives_per_positive):
                node1 = np.random.randint(0, num_nodes)
                # Sample node2 that's not adjacent to node1
                possible_nodes = [j for j in range(num_nodes) if abs(j - node1) > 1]
                if len(possible_nodes) > 0:
                    node2 = np.random.choice(possible_nodes)
                    pairs.append([node1, node2])
                    labels.append(0)
                    graph_indices.append(graph_idx)
    
    return pairs, labels, graph_indices

pairs, labels, graph_indices = create_training_pairs(hybrid_graphs, num_negatives_per_positive=1)
print(f"Created {len(pairs)} training pairs ({sum(labels)} positive, {len(labels) - sum(labels)} negative)")
print()

# Step 4: Training loop (OPTIMIZED - pre-compute embeddings per epoch)
print("Training GraphSAGE model...")
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.BCEWithLogitsLoss()

num_epochs = 100
batch_size = 1024  # Increased batch size for efficiency

# Convert pairs and labels to tensors for faster processing
pairs_tensor = torch.tensor(pairs, dtype=torch.long)
labels_tensor = torch.tensor(labels, dtype=torch.float32)
graph_indices_tensor = torch.tensor(graph_indices, dtype=torch.long)

model.train()
losses = []

import time
start_time = time.time()

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    
    # OPTIMIZATION: Pre-compute embeddings for all graphs once per epoch
    all_graph_embeddings = []
    for graph in hybrid_graphs:
        graph = graph.to(device)
        embeddings = model(graph.x, graph.edge_index)
        all_graph_embeddings.append(embeddings)
    
    # Shuffle training data
    indices = torch.randperm(len(pairs))
    
    for batch_start in range(0, len(pairs), batch_size):
        optimizer.zero_grad()
        
        batch_indices = indices[batch_start:batch_start + batch_size]
        batch_pairs = pairs_tensor[batch_indices]
        batch_labels = labels_tensor[batch_indices].to(device)
        batch_graph_indices = graph_indices_tensor[batch_indices]
        
        # Compute scores for all pairs in batch
        scores = []
        for i, idx in enumerate(batch_indices):
            graph_idx = batch_graph_indices[i].item()
            node1, node2 = batch_pairs[i]
            
            # Look up pre-computed embeddings
            emb1 = all_graph_embeddings[graph_idx][node1]
            emb2 = all_graph_embeddings[graph_idx][node2]
            score = torch.sum(emb1 * emb2)
            scores.append(score)
        
        scores = torch.stack(scores)
        
        # Compute loss for entire batch
        loss = criterion(scores, batch_labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        num_batches += 1
    
    avg_loss = epoch_loss / num_batches
    losses.append(avg_loss)
    
    if (epoch + 1) % 10 == 0:
        elapsed = time.time() - start_time
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Time: {elapsed:.1f}s")

print("\nTraining complete!")

# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('GraphSAGE Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

print("\nGenerating embeddings for all chirps...")



Building hybrid graphs with temporal + k-NN edges...
Original graph 0: 17 edges
Hybrid graph 0: 107 edges
Edge increase: 90 k-NN edges added

Initializing GraphSAGE model...
Model parameters: 18016

Preparing contrastive learning dataset...
Created 212606 training pairs (106303 positive, 106303 negative)

Training GraphSAGE model...


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [MPSFloatType [32, 64]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [None]:
# Step 5: Generate embeddings for all chirps
model.eval()
all_embeddings = []
all_graph_ids = []

with torch.no_grad():
    for graph_idx, graph in enumerate(hybrid_graphs):
        graph = graph.to(device)
        embeddings = model(graph.x, graph.edge_index)
        all_embeddings.append(embeddings.cpu().numpy())
        all_graph_ids.extend([graph_idx] * graph.num_nodes)

# Concatenate all embeddings
all_embeddings_concat = np.vstack(all_embeddings)
all_graph_ids = np.array(all_graph_ids)

print(f"Generated embeddings for {len(all_embeddings_concat)} chirps")
print(f"Embedding shape: {all_embeddings_concat.shape}")
print()

# Step 6: Apply HDBSCAN clustering
print("Applying HDBSCAN clustering...")
print("HDBSCAN Parameters:")
print("  - min_cluster_size: 50 (minimum chirps to form a cluster)")
print("  - min_samples: 10 (neighborhood size for density estimation)")
print("  - metric: euclidean")
print()

clusterer = HDBSCAN(min_cluster_size=50, min_samples=10, metric='euclidean')
cluster_labels = clusterer.fit_predict(all_embeddings_concat)

# Analyze clustering results
unique_clusters = set(cluster_labels)
num_clusters = len(unique_clusters - {-1})  # Exclude noise (-1)
num_noise = sum(cluster_labels == -1)

print(f"Clustering Results:")
print(f"  - Number of clusters found: {num_clusters}")
print(f"  - Noise points (unassigned): {num_noise} ({num_noise/len(cluster_labels)*100:.2f}%)")
print(f"  - Clustered points: {len(cluster_labels) - num_noise} ({(len(cluster_labels)-num_noise)/len(cluster_labels)*100:.2f}%)")
print()

# Cluster size distribution
cluster_sizes = Counter(cluster_labels)
del cluster_sizes[-1]  # Remove noise
cluster_sizes_sorted = sorted(cluster_sizes.items(), key=lambda x: x[1], reverse=True)

print("Top 10 largest clusters:")
for cluster_id, size in cluster_sizes_sorted[:10]:
    print(f"  Cluster {cluster_id}: {size} chirps ({size/len(cluster_labels)*100:.2f}%)")
print()

# Visualize cluster size distribution
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
sizes = [size for _, size in cluster_sizes_sorted]
plt.hist(sizes, bins=30, edgecolor='black', alpha=0.7)
plt.xlabel('Cluster Size (number of chirps)')
plt.ylabel('Frequency')
plt.title('Cluster Size Distribution')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.bar(range(min(20, len(cluster_sizes_sorted))), [s for _, s in cluster_sizes_sorted[:20]])
plt.xlabel('Cluster ID (top 20)')
plt.ylabel('Number of Chirps')
plt.title('Top 20 Clusters by Size')
plt.xticks(range(min(20, len(cluster_sizes_sorted))), [c for c, _ in cluster_sizes_sorted[:20]], rotation=45)
plt.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nComputing graph-level signatures (cluster distributions per recording)...")



In [None]:
# Step 7: Graph-level signature analysis
# For each recording, compute which clusters appear and how often

recording_cluster_signatures = defaultdict(lambda: defaultdict(int))
cluster_to_recordings = defaultdict(set)

# Build cluster signatures for each recording
idx = 0
for graph_idx, graph in enumerate(hybrid_graphs):
    num_nodes = graph.num_nodes
    graph_cluster_labels = cluster_labels[idx:idx + num_nodes]
    
    for cluster_id in graph_cluster_labels:
        if cluster_id != -1:  # Ignore noise
            recording_cluster_signatures[graph_idx][cluster_id] += 1
            cluster_to_recordings[cluster_id].add(graph_idx)
    
    idx += num_nodes

print(f"Built signatures for {len(recording_cluster_signatures)} recordings")
print()

# Analyze cluster prevalence across recordings
print("Cluster Prevalence Analysis:")
print("How many recordings contain each cluster (i.e., how common is each 'sentence')?")
print()

cluster_prevalence = [(cluster_id, len(recordings)) 
                      for cluster_id, recordings in cluster_to_recordings.items()]
cluster_prevalence_sorted = sorted(cluster_prevalence, key=lambda x: x[1], reverse=True)

print("Top 15 most common clusters (appear in most recordings):")
for cluster_id, num_recordings in cluster_prevalence_sorted[:15]:
    cluster_size = cluster_sizes[cluster_id]
    print(f"  Cluster {cluster_id}: appears in {num_recordings} recordings ({num_recordings/len(recording_cluster_signatures)*100:.2f}%), "
          f"total {cluster_size} chirps")
print()

# Find recordings with similar cluster profiles
print("Finding recordings with similar cluster profiles...")

# Convert signatures to vectors for similarity computation
all_cluster_ids = sorted(set(cluster_labels) - {-1})
cluster_id_to_idx = {cid: i for i, cid in enumerate(all_cluster_ids)}

signature_vectors = []
signature_graph_ids = []

for graph_idx in recording_cluster_signatures:
    vec = np.zeros(len(all_cluster_ids))
    for cluster_id, count in recording_cluster_signatures[graph_idx].items():
        vec[cluster_id_to_idx[cluster_id]] = count
    # Normalize by total chirps in recording to get proportions
    vec = vec / vec.sum() if vec.sum() > 0 else vec
    signature_vectors.append(vec)
    signature_graph_ids.append(graph_idx)

signature_vectors = np.array(signature_vectors)
print(f"Created signature vectors of dimension {signature_vectors.shape[1]} for {len(signature_vectors)} recordings")
print()

# Compute similarity between recordings (cosine similarity)
similarity_matrix = cosine_similarity(signature_vectors)

# Find pairs of recordings with high similarity (excluding self-similarity)
print("Most similar recording pairs (share similar cluster patterns):")
similarity_pairs = []
for i in range(len(similarity_matrix)):
    for j in range(i + 1, len(similarity_matrix)):
        similarity_pairs.append((i, j, similarity_matrix[i, j]))

similarity_pairs_sorted = sorted(similarity_pairs, key=lambda x: x[2], reverse=True)

print("\nTop 10 most similar recording pairs:")
for i, j, sim in similarity_pairs_sorted[:10]:
    graph_i = signature_graph_ids[i]
    graph_j = signature_graph_ids[j]
    file_id_i = hybrid_graphs[graph_i].file_id
    file_id_j = hybrid_graphs[graph_j].file_id
    
    # Find shared clusters
    clusters_i = set(recording_cluster_signatures[graph_i].keys())
    clusters_j = set(recording_cluster_signatures[graph_j].keys())
    shared_clusters = clusters_i & clusters_j
    
    print(f"  Recording {graph_i} (file {file_id_i}) <-> Recording {graph_j} (file {file_id_j})")
    print(f"    Similarity: {sim:.3f}, Shared clusters: {len(shared_clusters)}/{len(clusters_i | clusters_j)}")
print()

# Visualize similarity matrix (sample)
print("Visualizing recording similarity matrix (first 100 recordings)...")
plt.figure(figsize=(12, 10))
sample_size = min(100, len(similarity_matrix))
plt.imshow(similarity_matrix[:sample_size, :sample_size], cmap='viridis', aspect='auto')
plt.colorbar(label='Cosine Similarity')
plt.xlabel('Recording Index')
plt.ylabel('Recording Index')
plt.title(f'Recording Similarity Matrix (first {sample_size} recordings)')
plt.tight_layout()
plt.show()

# Analyze cluster co-occurrence
print("\nCluster Co-occurrence Analysis:")
print("Which clusters tend to appear together in the same recordings?")
print()

# Build co-occurrence matrix
cluster_cooccurrence = np.zeros((len(all_cluster_ids), len(all_cluster_ids)))

for graph_idx in recording_cluster_signatures:
    clusters_in_recording = list(recording_cluster_signatures[graph_idx].keys())
    for i, c1 in enumerate(clusters_in_recording):
        for c2 in clusters_in_recording[i:]:
            idx1 = cluster_id_to_idx[c1]
            idx2 = cluster_id_to_idx[c2]
            cluster_cooccurrence[idx1, idx2] += 1
            if idx1 != idx2:
                cluster_cooccurrence[idx2, idx1] += 1

# Normalize by cluster prevalence to get co-occurrence strength
for i in range(len(all_cluster_ids)):
    for j in range(len(all_cluster_ids)):
        cluster_i = all_cluster_ids[i]
        cluster_j = all_cluster_ids[j]
        max_possible = min(len(cluster_to_recordings[cluster_i]), 
                          len(cluster_to_recordings[cluster_j]))
        if max_possible > 0:
            cluster_cooccurrence[i, j] /= max_possible

# Find strongest co-occurrences
print("Strongest cluster co-occurrences (top 15):")
cooccurrence_pairs = []
for i in range(len(all_cluster_ids)):
    for j in range(i + 1, len(all_cluster_ids)):
        cooccurrence_pairs.append((all_cluster_ids[i], all_cluster_ids[j], 
                                  cluster_cooccurrence[i, j]))

cooccurrence_pairs_sorted = sorted(cooccurrence_pairs, key=lambda x: x[2], reverse=True)

for c1, c2, strength in cooccurrence_pairs_sorted[:15]:
    recordings_with_both = len(cluster_to_recordings[c1] & cluster_to_recordings[c2])
    print(f"  Clusters {c1} & {c2}: co-occur in {recordings_with_both} recordings, "
          f"strength = {strength:.3f}")
print()

# Summary statistics
print("="*70)
print("SUMMARY OF FINDINGS:")
print("="*70)
print(f"Total recordings analyzed: {len(recording_cluster_signatures)}")
print(f"Total clusters discovered: {num_clusters}")
print(f"Average clusters per recording: {np.mean([len(sig) for sig in recording_cluster_signatures.values()]):.2f}")
print(f"Max clusters in a single recording: {max([len(sig) for sig in recording_cluster_signatures.values()])}")
print(f"Min clusters in a single recording: {min([len(sig) for sig in recording_cluster_signatures.values()])}")
print()
print(f"Most prevalent cluster (Cluster {cluster_prevalence_sorted[0][0]}): "
      f"appears in {cluster_prevalence_sorted[0][1]} recordings "
      f"({cluster_prevalence_sorted[0][1]/len(recording_cluster_signatures)*100:.1f}%)")
print()
print("Interpretation:")
print("  - Each cluster represents a distinct 'sentence' pattern in bat communication")
print("  - Recordings with high similarity share similar communication patterns")
print("  - Prevalent clusters appearing across many recordings suggest")
print("    common communication behaviors across the species")
print("  - Co-occurring clusters may represent sequential patterns or context-dependent communication")
print("="*70)



In [None]:
# Utility functions for exploring clusters

def visualize_cluster_embeddings(cluster_id, method='tsne'):
    """
    Visualize embeddings of chirps in a specific cluster
    
    Args:
        cluster_id: The cluster to visualize
        method: 'tsne' or 'pca' for dimensionality reduction
    """
    cluster_mask = cluster_labels == cluster_id
    cluster_embeddings = all_embeddings_concat[cluster_mask]
    
    if len(cluster_embeddings) < 2:
        print(f"Cluster {cluster_id} has too few points to visualize")
        return
    
    if method == 'tsne':
        reducer = TSNE(n_components=2, random_state=42)
        reduced = reducer.fit_transform(cluster_embeddings)
        title = f'Cluster {cluster_id} - t-SNE Visualization'
    else:  # pca
        reducer = PCA(n_components=2, random_state=42)
        reduced = reducer.fit_transform(cluster_embeddings)
        title = f'Cluster {cluster_id} - PCA Visualization'
    
    plt.figure(figsize=(10, 8))
    plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.5, s=20)
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print(f"Cluster {cluster_id} contains {len(cluster_embeddings)} chirps")
    print(f"Appears in {len(cluster_to_recordings[cluster_id])} recordings")

def get_cluster_statistics(cluster_id):
    """Get detailed statistics about a specific cluster"""
    cluster_mask = cluster_labels == cluster_id
    cluster_indices = np.where(cluster_mask)[0]
    
    print(f"\n{'='*60}")
    print(f"CLUSTER {cluster_id} STATISTICS")
    print(f"{'='*60}")
    print(f"Total chirps in cluster: {len(cluster_indices)}")
    print(f"Appears in {len(cluster_to_recordings[cluster_id])} recordings")
    print(f"Prevalence: {len(cluster_to_recordings[cluster_id])/len(recording_cluster_signatures)*100:.2f}% of recordings")
    
    # Find which recordings contain this cluster
    recordings_with_cluster = sorted(list(cluster_to_recordings[cluster_id]))[:10]
    print(f"\nExample recordings containing this cluster (first 10):")
    for rec_idx in recordings_with_cluster:
        file_id = hybrid_graphs[rec_idx].file_id
        num_chirps_in_cluster = recording_cluster_signatures[rec_idx][cluster_id]
        total_chirps = hybrid_graphs[rec_idx].num_nodes
        print(f"  Recording {rec_idx} (file {file_id}): {num_chirps_in_cluster}/{total_chirps} chirps "
              f"({num_chirps_in_cluster/total_chirps*100:.1f}%)")
    
    # Analyze temporal distribution within recordings
    print(f"\nTemporal distribution:")
    temporal_positions = []
    for rec_idx in cluster_to_recordings[cluster_id]:
        # Find positions of cluster chirps within this recording
        start_idx = sum(hybrid_graphs[i].num_nodes for i in range(rec_idx))
        end_idx = start_idx + hybrid_graphs[rec_idx].num_nodes
        rec_cluster_mask = cluster_labels[start_idx:end_idx] == cluster_id
        positions = np.where(rec_cluster_mask)[0]
        # Normalize positions to [0, 1]
        normalized_positions = positions / len(rec_cluster_mask) if len(rec_cluster_mask) > 0 else []
        temporal_positions.extend(normalized_positions)
    
    if len(temporal_positions) > 0:
        print(f"  Mean position in recording: {np.mean(temporal_positions):.2f} (0=start, 1=end)")
        print(f"  Std position: {np.std(temporal_positions):.2f}")
        
        plt.figure(figsize=(10, 4))
        plt.hist(temporal_positions, bins=20, edgecolor='black', alpha=0.7)
        plt.xlabel('Normalized Position in Recording')
        plt.ylabel('Frequency')
        plt.title(f'Cluster {cluster_id} - Temporal Distribution within Recordings')
        plt.grid(True, alpha=0.3)
        plt.show()

def compare_clusters(cluster_id1, cluster_id2):
    """Compare two clusters to understand their differences"""
    print(f"\n{'='*60}")
    print(f"COMPARING CLUSTER {cluster_id1} vs CLUSTER {cluster_id2}")
    print(f"{'='*60}")
    
    # Size comparison
    size1 = cluster_sizes[cluster_id1]
    size2 = cluster_sizes[cluster_id2]
    print(f"Cluster {cluster_id1}: {size1} chirps")
    print(f"Cluster {cluster_id2}: {size2} chirps")
    
    # Prevalence comparison
    prev1 = len(cluster_to_recordings[cluster_id1])
    prev2 = len(cluster_to_recordings[cluster_id2])
    print(f"\nCluster {cluster_id1}: appears in {prev1} recordings")
    print(f"Cluster {cluster_id2}: appears in {prev2} recordings")
    
    # Overlap
    recordings1 = cluster_to_recordings[cluster_id1]
    recordings2 = cluster_to_recordings[cluster_id2]
    overlap = len(recordings1 & recordings2)
    print(f"\nRecordings with both clusters: {overlap}")
    print(f"Overlap coefficient: {overlap / min(prev1, prev2):.3f}")
    
    # Embedding distance
    mask1 = cluster_labels == cluster_id1
    mask2 = cluster_labels == cluster_id2
    centroid1 = all_embeddings_concat[mask1].mean(axis=0)
    centroid2 = all_embeddings_concat[mask2].mean(axis=0)
    distance = np.linalg.norm(centroid1 - centroid2)
    print(f"\nEmbedding space distance between centroids: {distance:.3f}")

# Example usage:
print("\n" + "="*70)
print("CLUSTER EXPLORATION UTILITIES")
print("="*70)
print("\nAvailable functions:")
print("1. visualize_cluster_embeddings(cluster_id, method='tsne')")
print("   - Visualize chirps in a cluster using dimensionality reduction")
print()
print("2. get_cluster_statistics(cluster_id)")
print("   - Get detailed statistics about a specific cluster")
print()
print("3. compare_clusters(cluster_id1, cluster_id2)")
print("   - Compare two clusters to understand their differences")
print()
print("Example usage:")
print("  get_cluster_statistics(0)  # Explore cluster 0")
print("  visualize_cluster_embeddings(0)  # Visualize cluster 0")
print("  compare_clusters(0, 1)  # Compare clusters 0 and 1")
print()

# Automatically explore the top 3 clusters
if len(cluster_sizes_sorted) >= 3:
    print("\nAutomatically exploring top 3 largest clusters...")
    for i in range(3):
        cluster_id = cluster_sizes_sorted[i][0]
        get_cluster_statistics(cluster_id)
        print()

