In [1]:
import torch
import numpy as np
import pickle
from tqdm import tqdm
from scipy.sparse import csr_matrix
from scipy.sparse import vstack

from sparc.feature_extract.extract_open_images import OpenImagesDataset

from sparc.post_analysis import HDF5AnalysisResultsDataset
from tqdm import tqdm
from itertools import combinations

import json
import os
import pandas as pd

In [2]:
# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

# Root directory of the local Open Images data. Defaults to the location this
# notebook was written against; set the OPENIMAGES_ROOT environment variable
# to run it on another machine without editing the cell.
OPENIMAGES_ROOT = os.environ.get('OPENIMAGES_ROOT', '/home/ubuntu/Projects/OpenImages/')

# 'test' split of the dataset, read from OPENIMAGES_ROOT.
dataset = OpenImagesDataset(OPENIMAGES_ROOT, 'test')

# Class hierarchy file; later cells walk it to map each label to its ancestors.
with open(os.path.join(OPENIMAGES_ROOT, 'bbox_labels_600_hierarchy.json'), 'r', encoding='utf-8') as f:
    taxonomy_json = json.load(f)

Loading caption data from /home/ubuntu/Projects/OpenImages/captions/test/simplified_open_images_test_localized_narratives.json...
Loading label data...
Total number of classes: 601
Loading annotations from /home/ubuntu/Projects/OpenImages/labels/test-annotations-human-imagelabels-boxable.csv...
Loaded labels for 112194 images


In [3]:
# Inverse of dataset.class_to_idx: look a class up by its index.
idx_to_class = dict(
    (class_index, class_key) for class_key, class_index in dataset.class_to_idx.items()
)

In [4]:
# Gather one sparse label row for every sample whose image has a label tensor,
# and remember which dataset index each kept row came from.
labels = []
sample_indices = []
for idx in tqdm(range(len(dataset))):
    image_id, _caption_idx = dataset.samples[idx]
    if image_id not in dataset.image_to_label_tensor:
        continue  # this image has no label tensor; leave the sample out
    sample_indices.append(idx)
    labels.append(csr_matrix(dataset.image_to_label_tensor[image_id]))

# Row r of label_matrix_sparse belongs to dataset index sample_indices[r].
sample_indices = np.array(sample_indices)
label_matrix_sparse = vstack(labels).tocsr()


100%|████████████████████████████████| 126020/126020 [00:04<00:00, 27296.08it/s]


In [5]:
def _open_analysis_cache(variant):
    """
    Open the `analysis_cache_val.h5` file of the `msae_open_<variant>` run.

    256 is forwarded unchanged as the second positional argument; its meaning
    is defined by HDF5AnalysisResultsDataset.
    """
    return HDF5AnalysisResultsDataset(
        '../../final_results/msae_open_' + variant + '/analysis_cache_val.h5', 256
    )


# One results handle per run: {global, local} x {with_cross, no_cross}.
analysis_results_global_cross = _open_analysis_cache('global_with_cross')
analysis_results_global_no_cross = _open_analysis_cache('global_no_cross')
analysis_results_local_cross = _open_analysis_cache('local_with_cross')
analysis_results_local_no_cross = _open_analysis_cache('local_no_cross')

In [6]:
# Class description table with (at least) the columns LabelName and DisplayName.
class_file = os.path.join(dataset.labels_dir, 'oidv7-class-descriptions-boxable.csv')
class_df = pd.read_csv(class_file)

# LabelName -> DisplayName, and the reverse lookup.
label_to_class = dict(zip(class_df['LabelName'], class_df['DisplayName']))
class_to_label = {display_name: label_name for label_name, display_name in label_to_class.items()}

In [7]:
import json
import numpy as np
from itertools import combinations
from tqdm import tqdm


def build_parent_depth_maps(node):
    """
    Walk a taxonomy tree and record the parent and the depth of every label code.

    Args:
      node: root of the tree, a dict with a 'LabelName' entry and an optional
            'Subcategory' list holding child dicts of the same form.

    Returns:
      (parent_map, depth_map): two dicts keyed by label code. The root maps to
      parent None and depth 0.

    A code can occur at several places in the tree (it happens for Teddy Bear,
    for example). The deepest occurrence is the one that is kept; between
    equally deep occurrences the first one met in pre-order wins.
    """
    parent_map = {}
    depth_map = {}

    # Explicit stack of (subtree, parent code, depth) in place of recursion.
    pending = [(node, None, 0)]
    while pending:
        current, parent_code, level = pending.pop()
        label = current['LabelName']

        recorded_level = depth_map.get(label)
        if recorded_level is None or level > recorded_level:
            parent_map[label] = parent_code
            depth_map[label] = level

        # Children are pushed right-to-left so that they are popped
        # left-to-right, i.e. visited in the same pre-order as a recursive walk.
        for child in reversed(current.get('Subcategory', [])):
            pending.append((child, label, level + 1))

    return parent_map, depth_map


# Parent and depth lookup tables for the hierarchy loaded into taxonomy_json.
_parent_map, _depth_map = build_parent_depth_maps(taxonomy_json)
# Deepest level that occurs in the hierarchy (the root sits at depth 0).
_max_depth = max(_depth_map.values())

# ——————————————————————————————————————————————
# 2) Helper to build counts and activation‐counts
# ——————————————————————————————————————————————
def _build_counts_matrix(analysis_results, topk=None):
    """
    Count, for every latent and stream, the class labels of the samples on
    which that latent has a stored activation.

    Reads the notebook-level `sample_indices` and `label_matrix_sparse`.

    Args:
      analysis_results: object providing `get_all_original_dataset_indices()`,
                        `streams` and
                        `get_all_features_for_stream(s, 'latents', return_sparse=True)`.
      topk:             If None, we include *all* stored activations of each
                        latent. If an integer, we take only the `topk` most
                        strongly activated samples.

    Returns:
      counts_arr:        np.array shape (L, S, C)
      activation_counts: np.array shape (L, S)
      streams:           list of stream names

    Raises:
      ValueError: if `label_matrix_sparse` has neither one row per entry of
                  `sample_indices` nor one row per analysed sample.
      KeyError:   if an entry of `sample_indices` does not occur in the
                  analysis results.
    """
    all_indices = analysis_results.get_all_original_dataset_indices()
    idx_to_row = {orig: r for r, orig in enumerate(all_indices)}
    # Row of every labelled sample inside the analysis-result matrices.
    sample_rows = np.array([idx_to_row[i] for i in sample_indices], dtype=np.intp)

    total_rows, num_classes = label_matrix_sparse.shape
    # The label matrix built in this notebook has one row per entry of
    # sample_indices, in that order, so it already lines up with sample_rows.
    # That layout is tested first: if both lengths happen to be equal, the
    # matrix must not be mistaken for the "one row per analysed sample"
    # layout and re-indexed a second time.
    if total_rows == len(sample_indices):
        labels_sub = label_matrix_sparse.tocsr()
    elif total_rows == len(all_indices):
        labels_sub = label_matrix_sparse.tocsr()[sample_rows]
    else:
        raise ValueError(
            f"label_matrix_sparse has {total_rows} rows; expected "
            f"{len(all_indices)} or {len(sample_indices)}"
        )

    streams = analysis_results.streams
    # Column-compressed copies restricted to the labelled samples, so that the
    # entries of one latent are contiguous.
    latent_mats = {
        s: analysis_results
             .get_all_features_for_stream(s, 'latents', return_sparse=True)
             [sample_rows]
             .tocsc()
        for s in streams
    }

    L = latent_mats[streams[0]].shape[1]
    S = len(streams)
    C = num_classes

    counts_arr        = np.zeros((L, S, C), dtype=int)
    activation_counts = np.zeros((L, S),   dtype=int)

    for s_idx, s in enumerate(streams):
        mat = latent_mats[s]
        # Read the CSC buffers directly; building a one-column matrix per
        # latent with getcol() is far slower and yields the same entries.
        indptr, indices, data = mat.indptr, mat.indices, mat.data
        for l in range(L):
            start, stop = indptr[l], indptr[l + 1]
            # Number of stored entries of this column.
            # NOTE(review): explicitly stored zeros, if the matrix has any,
            # are counted as activations here - confirm none are stored.
            nnz = stop - start
            activation_counts[l, s_idx] = nnz
            if nnz == 0:
                continue

            rows = indices[start:stop]
            if topk is not None and topk < nnz:
                # Keep only the k largest activations; argpartition on the
                # negated values puts them in the first k positions.
                k = int(topk)
                strongest = np.argpartition(-data[start:stop], k - 1)[:k]
                rows = rows[strongest]

            # Sum the class labels over the selected samples. np.asarray(...)
            # handles both the matrix and the array flavour of scipy.sparse.
            counts_arr[l, s_idx, :] = np.asarray(labels_sub[rows].sum(axis=0)).ravel()

    return counts_arr, activation_counts, streams

# List of all "/m/..." label codes, in the key order of dataset.label_to_class.
# NOTE(review): assumed to be the same order as the label matrix columns -
# confirm against how OpenImagesDataset builds its label tensors.
_label_codes = list(dataset.label_to_class.keys())
# Code of the taxonomy root; the single bucket used when collapsing to depth 0.
root_code = taxonomy_json["LabelName"]   

# ——————————————————————————————————————————————
# 3) Ancestor‐collapsed Jaccard at chosen depth
# ——————————————————————————————————————————————
def compute_jaccard_from_counts(counts_arr, activation_counts, collapse_depth):
    """
    Per-latent mean pairwise weighted Jaccard between streams, computed on
    class counts that were first merged into ancestor buckets.

    Reads the notebook-level `_max_depth`, `_depth_map`, `_parent_map`,
    `_label_codes` and `root_code`.

    Args:
      counts_arr:        np.array shape (L, S, C) of per-class counts.
      activation_counts: np.array shape (L, S); a stream takes part for a
                         latent only where its entry is > 0.
      collapse_depth:    integer in [0, _max_depth]; Python ints and NumPy
                         integers are both accepted. 0 merges every class
                         into the root bucket. Otherwise a class is replaced
                         by its parent until its recorded depth is no greater
                         than `collapse_depth`; shallower classes keep their
                         own bucket.

    Returns:
      np.array shape (L,) of floats. For each latent, the average over all
      pairs of participating streams of sum(min(a, b)) / sum(max(a, b)), where
      a pair with an all-zero union scores 0.0. Latents with fewer than two
      participating streams are left at 0.0.

    Raises:
      TypeError:  if `collapse_depth` is not an integer.
      ValueError: if `collapse_depth` lies outside [0, _max_depth].
    """
    # 1) validate (np.integer covers depths coming from np.arange and the like)
    if not isinstance(collapse_depth, (int, np.integer)):
        raise TypeError(f"collapse_depth must be int, got {type(collapse_depth)}")
    collapse_depth = int(collapse_depth)
    if collapse_depth < 0 or collapse_depth > _max_depth:
        raise ValueError(f"collapse_depth must be in [0, {_max_depth}]")

    L, S, C = counts_arr.shape

    # 2) choose the bucket of every class column
    if collapse_depth == 0:
        # everything merges into the single root bucket
        bucket_codes = [root_code] * C
    else:
        def lift(code):
            # climb towards the root until the code is shallow enough
            d = _depth_map[code]
            while d > collapse_depth:
                code = _parent_map[code]
                d = _depth_map[code]
            return code
        # NOTE(review): column i of counts_arr is taken to be _label_codes[i];
        # the two lengths are not checked against each other here.
        bucket_codes = [lift(c) for c in _label_codes]

    unique_buckets = sorted(set(bucket_codes))
    b2i = {bc: i for i, bc in enumerate(unique_buckets)}
    B = len(unique_buckets)

    # 3) add every class column into the column of its bucket
    collapsed = np.zeros((L, S, B), dtype=int)
    for ci, bc in enumerate(bucket_codes):
        collapsed[:, :, b2i[bc]] += counts_arr[:, :, ci]

    # 4) average the pairwise weighted Jaccard over the participating streams
    j_scores = np.zeros(L, dtype=float)
    for l in range(L):
        active = [s for s in range(S) if activation_counts[l, s] > 0]
        if len(active) < 2:
            continue  # nothing to compare; the score stays 0.0
        pair_scores = []
        for i, j in combinations(active, 2):
            first = collapsed[l, i]
            second = collapsed[l, j]
            num = np.minimum(first, second).sum()
            den = np.maximum(first, second).sum()
            pair_scores.append((num / den) if den > 0 else 0.0)
        j_scores[l] = sum(pair_scores) / len(pair_scores)
    return j_scores


In [8]:
from tqdm import tqdm
import numpy as np
from itertools import combinations

# The four runs under comparison, keyed by a short case name.
cases = dict(
    global_cross=analysis_results_global_cross,
    global_no_cross=analysis_results_global_no_cross,
    local_cross=analysis_results_local_cross,
    local_no_cross=analysis_results_local_no_cross,
)

# results[metric][depth][case_name] -> array of per-latent scores.
# Only the 'jaccard' slots are filled below; the 'dice' slots stay empty.
results = {
    metric_name: {depth: {} for depth in range(_max_depth + 1)}
    for metric_name in ('jaccard', 'dice')
}

# The counts are built once per case and then reused for every collapse depth.
for case_name, ar in tqdm(cases.items(), desc="Case", total=len(cases)):
    counts_arr, activation_counts, streams = _build_counts_matrix(ar)

    depth_bar = tqdm(range(_max_depth + 1), desc=f"Depth [{case_name}]", leave=False)
    for depth in depth_bar:
        results['jaccard'][depth][case_name] = compute_jaccard_from_counts(
            counts_arr, activation_counts, depth
        )

Case:   0%|                                               | 0/4 [00:00<?, ?it/s]
Depth [global_cross]:   0%|                               | 0/6 [00:00<?, ?it/s][A
Depth [global_cross]:  33%|███████▋               | 2/6 [00:00<00:00, 10.34it/s][A
Depth [global_cross]:  67%|███████████████▎       | 4/6 [00:00<00:00,  7.13it/s][A
Depth [global_cross]:  83%|███████████████████▏   | 5/6 [00:00<00:00,  6.26it/s][A
Depth [global_cross]: 100%|███████████████████████| 6/6 [00:00<00:00,  5.75it/s][A
Case:  25%|█████████▊                             | 1/4 [00:08<00:25,  8.47s/it][A
Depth [global_no_cross]:   0%|                            | 0/6 [00:00<?, ?it/s][A
Depth [global_no_cross]:  33%|██████▋             | 2/6 [00:00<00:00,  9.72it/s][A
Depth [global_no_cross]:  50%|██████████          | 3/6 [00:00<00:00,  7.36it/s][A
Depth [global_no_cross]:  67%|█████████████▎      | 4/6 [00:00<00:00,  6.06it/s][A
Depth [global_no_cross]:  83%|████████████████▋   | 5/6 [00:00<00:00,  5.44it/s

In [9]:
# Print the mean per-latent score of every case, one line per collapse depth.
metric = 'jaccard'
print(f"\n=== {metric.upper()} SUMMARY ===")
for depth, case_dict in results[metric].items():
    looked_up = ((name, case_dict.get(name)) for name in cases)
    # Cases without a stored score for this depth are left out of the line.
    stats = [
        f"{name}: μ={values.mean():.4f}"
        for name, values in looked_up
        if values is not None
    ]
    print(f" depth={depth:2d} | " + " | ".join(stats))


=== JACCARD SUMMARY ===
 depth= 0 | global_cross: μ=0.8118 | global_no_cross: μ=0.8054 | local_cross: μ=0.3238 | local_no_cross: μ=0.3458
 depth= 1 | global_cross: μ=0.8056 | global_no_cross: μ=0.7633 | local_cross: μ=0.2866 | local_no_cross: μ=0.2031
 depth= 2 | global_cross: μ=0.8032 | global_no_cross: μ=0.7451 | local_cross: μ=0.2706 | local_no_cross: μ=0.1783
 depth= 3 | global_cross: μ=0.8020 | global_no_cross: μ=0.7359 | local_cross: μ=0.2615 | local_no_cross: μ=0.1666
 depth= 4 | global_cross: μ=0.8018 | global_no_cross: μ=0.7344 | local_cross: μ=0.2600 | local_no_cross: μ=0.1652
 depth= 5 | global_cross: μ=0.8018 | global_no_cross: μ=0.7344 | local_cross: μ=0.2599 | local_no_cross: μ=0.1651


# Dead latents

In [10]:
# Cell 1: Precompute all the expensive matrix operations
from tqdm import tqdm

print("Precomputing matrices...")

# case name -> {'counts_all', 'activation_counts_all', 'streams'}
precomputed_data = {}
for case_name, ar in tqdm(cases.items(), desc="Precomputing matrices", total=len(cases)):
    # topk=None: every stored activation of a latent contributes to its counts.
    counts_all, activation_counts_all, streams = _build_counts_matrix(ar, topk=None)
    precomputed_data[case_name] = dict(
        counts_all=counts_all,
        activation_counts_all=activation_counts_all,
        streams=streams,
    )

Precomputing matrices...


Precomputing matrices: 100%|██████████████████████| 4/4 [00:29<00:00,  7.32s/it]


In [11]:
# Check patterns across all cases
#
# For every latent, a stream is "alive" when the class counts of that stream
# sum to more than zero. A latent is all-alive when every stream is alive,
# all-dead when none is, and mixed otherwise. The number of streams is read
# from the counts array instead of being fixed at three.
#
# NOTE(review): "alive" is derived from the class-count totals in counts_all,
# not from 'activation_counts_all'; a latent that only fires on samples whose
# label rows are all zero would be reported dead - confirm this is intended.
case_patterns = {}
for case_name, data in precomputed_data.items():
    counts_all = data['counts_all']                # (latents, streams, classes)
    n_latents, n_streams = counts_all.shape[:2]

    stream_alive = counts_all.sum(axis=2) > 0      # (latents, streams)
    num_alive = stream_alive.sum(axis=1)           # alive streams per latent

    all_alive = num_alive == n_streams
    all_dead = num_alive == 0
    mixed = ~(all_alive | all_dead)                # at least one alive, at least one dead

    case_patterns[case_name] = {
        'num_all_alive': int(all_alive.sum()),
        'num_all_dead': int(all_dead.sum()),
        'num_mixed': int(mixed.sum()),
        'total_latents': int(n_latents),
    }

# Summary
for case_name, stats in case_patterns.items():
    print(f"{case_name}: {stats['num_all_alive']} all-alive, {stats['num_all_dead']} all-dead, {stats['num_mixed']} mixed, {stats['total_latents']} total")

global_cross: 6912 all-alive, 1280 all-dead, 0 mixed, 8192 total
global_no_cross: 8044 all-alive, 131 all-dead, 17 mixed, 8192 total
local_cross: 3575 all-alive, 0 all-dead, 4617 mixed, 8192 total
local_no_cross: 7019 all-alive, 0 all-dead, 1173 mixed, 8192 total


In [12]:
# Comprehensive Alignment and Dead Neuron Analysis
print("=" * 70)
print("COMPREHENSIVE NEURON ACTIVATION PATTERN ANALYSIS")
print("Goal: Complete breakdown of neuron activation patterns across all streams")
print("=" * 70)

# Display names of the streams, in the order of axis 1 of counts_all.
# NOTE(review): assumed to follow the order of data['streams'] - confirm.
stream_names = ['CLIP-img', 'CLIP-txt', 'DINO']
# Weight of each stream in the per-latent alive mask:
# CLIP-img = 1, CLIP-txt = 2, DINO = 4, so every combination gets its own code 0..7.
stream_bits = np.array([1, 2, 4])

for case_name, data in precomputed_data.items():
    counts_all = data['counts_all']                     # (latents, streams, classes)
    total_neurons = int(counts_all.shape[0])
    n_streams = int(counts_all.shape[1])
    # The report below is written for exactly these three streams; stop with a
    # clear message rather than printing mislabelled numbers.
    if n_streams != len(stream_names):
        raise ValueError(
            f"{case_name}: expected {len(stream_names)} streams "
            f"({', '.join(stream_names)}), got {n_streams}"
        )

    # A stream is alive for a latent when its class counts sum to more than zero.
    # NOTE(review): this uses class-count totals, not 'activation_counts_all' -
    # confirm that is the intended definition of a dead latent.
    per_stream_totals = counts_all.sum(axis=2)          # (latents, streams)
    stream_alive = per_stream_totals > 0

    # Stream-specific dead counts
    stream_dead_counts = [int(n) for n in (per_stream_totals == 0).sum(axis=0)]

    # Histogram of alive-masks: index = sum of the bits of the alive streams.
    mask_counts = np.bincount(stream_alive.astype(int) @ stream_bits, minlength=8)

    exactly_0_alive = int(mask_counts[0])   # no stream alive
    exactly_3_alive = int(mask_counts[7])   # all three alive

    # 1-out-of-3 patterns
    pattern_ci_only = int(mask_counts[1])   # Only CLIP-img alive
    pattern_ct_only = int(mask_counts[2])   # Only CLIP-txt alive
    pattern_d_only = int(mask_counts[4])    # Only DINO alive
    exactly_1_alive = pattern_ci_only + pattern_ct_only + pattern_d_only

    # 2-out-of-3 patterns
    pattern_ci_ct = int(mask_counts[3])     # CLIP-img + CLIP-txt alive, DINO dead
    pattern_ci_d = int(mask_counts[5])      # CLIP-img + DINO alive, CLIP-txt dead
    pattern_ct_d = int(mask_counts[6])      # CLIP-txt + DINO alive, CLIP-img dead
    exactly_2_alive = pattern_ci_ct + pattern_ci_d + pattern_ct_d

    print(f"\n{case_name.upper()} ({total_neurons} total neurons)")
    print("─" * 50)

    # Overall Pattern Summary
    print("OVERALL ACTIVATION PATTERNS:")
    print(f"  All dead (0/3):          {exactly_0_alive:4d} ({exactly_0_alive/total_neurons*100:5.1f}%)")
    print(f"  All alive (3/3):         {exactly_3_alive:4d} ({exactly_3_alive/total_neurons*100:5.1f}%)")
    print(f"  Mixed (1/3):             {exactly_1_alive:4d} ({exactly_1_alive/total_neurons*100:5.1f}%)")
    print(f"  Mixed (2/3):             {exactly_2_alive:4d} ({exactly_2_alive/total_neurons*100:5.1f}%)")

    # Stream-Specific Dead Counts
    print("\nSTREAM-SPECIFIC DEAD COUNTS:")
    for name, count in zip(stream_names, stream_dead_counts):
        pct = (count / total_neurons) * 100
        print(f"  {name:12}: {count:4d} dead ({pct:5.1f}%)")

    # Detailed Pattern Breakdowns (percentages are relative to the 1/3 or 2/3 group)
    if exactly_1_alive > 0:
        print("\n1/3 PATTERN BREAKDOWN:")
        print(f"  CLIP-img only:           {pattern_ci_only:3d} ({pattern_ci_only/exactly_1_alive*100:5.1f}%)")
        print(f"  CLIP-txt only:           {pattern_ct_only:3d} ({pattern_ct_only/exactly_1_alive*100:5.1f}%)")
        print(f"  DINO only:               {pattern_d_only:3d} ({pattern_d_only/exactly_1_alive*100:5.1f}%)")

    if exactly_2_alive > 0:
        print("\n2/3 PATTERN BREAKDOWN:")
        print(f"  CLIP-img + CLIP-txt:     {pattern_ci_ct:3d} ({pattern_ci_ct/exactly_2_alive*100:5.1f}%)")
        print(f"  CLIP-img + DINO:         {pattern_ci_d:3d} ({pattern_ci_d/exactly_2_alive*100:5.1f}%)")
        print(f"  CLIP-txt + DINO:         {pattern_ct_d:3d} ({pattern_ct_d/exactly_2_alive*100:5.1f}%)")

print("\n" + "=" * 70)

COMPREHENSIVE NEURON ACTIVATION PATTERN ANALYSIS
Goal: Complete breakdown of neuron activation patterns across all streams

GLOBAL_CROSS (8192 total neurons)
──────────────────────────────────────────────────
OVERALL ACTIVATION PATTERNS:
  All dead (0/3):          1280 ( 15.6%)
  All alive (3/3):         6912 ( 84.4%)
  Mixed (1/3):                0 (  0.0%)
  Mixed (2/3):                0 (  0.0%)

STREAM-SPECIFIC DEAD COUNTS:
  CLIP-img    : 1280 dead ( 15.6%)
  CLIP-txt    : 1280 dead ( 15.6%)
  DINO        : 1280 dead ( 15.6%)

GLOBAL_NO_CROSS (8192 total neurons)
──────────────────────────────────────────────────
OVERALL ACTIVATION PATTERNS:
  All dead (0/3):           131 (  1.6%)
  All alive (3/3):         8044 ( 98.2%)
  Mixed (1/3):                2 (  0.0%)
  Mixed (2/3):               15 (  0.2%)

STREAM-SPECIFIC DEAD COUNTS:
  CLIP-img    :  135 dead (  1.6%)
  CLIP-txt    :  144 dead (  1.8%)
  DINO        :  133 dead (  1.6%)

1/3 PATTERN BREAKDOWN:
  CLIP-img only:      

In [13]:
# Partial Alignment Patterns  
print("ALIGNMENT PATTERNS ANALYSIS")
print("Goal: Understand which specific stream combinations are active (1/3, 2/3, 3/3)")
print("=" * 60)

# The breakdown is written for three streams, taken in this order along axis 1
# of counts_all: CLIP-img, CLIP-txt, DINO.
# NOTE(review): order assumed to follow data['streams'] - confirm.
expected_streams = 3
# Weight of each stream in the per-latent alive mask (CLIP-img=1, CLIP-txt=2, DINO=4).
alive_bits = np.array([1, 2, 4])

for case_name, data in precomputed_data.items():
    counts_all = data['counts_all']                     # (latents, streams, classes)
    total_neurons = int(counts_all.shape[0])
    # Stop with a clear message instead of printing mislabelled combinations
    # when the data does not have exactly the three streams named below.
    if counts_all.shape[1] != expected_streams:
        raise ValueError(
            f"{case_name}: expected {expected_streams} streams "
            f"(CLIP-img, CLIP-txt, DINO), got {counts_all.shape[1]}"
        )

    # A stream is alive for a latent when its class counts sum to more than zero.
    # NOTE(review): this uses class-count totals, not 'activation_counts_all' -
    # confirm that is the intended definition.
    stream_alive = counts_all.sum(axis=2) > 0           # (latents, streams)

    # Histogram of alive-masks: index = sum of the bits of the alive streams.
    mask_counts = np.bincount(stream_alive.astype(int) @ alive_bits, minlength=8)

    exactly_0_alive = int(mask_counts[0])   # no stream alive
    exactly_3_alive = int(mask_counts[7])   # all three alive

    # 1-out-of-3 patterns
    pattern_ci_only = int(mask_counts[1])   # Only CLIP-img alive
    pattern_ct_only = int(mask_counts[2])   # Only CLIP-txt alive
    pattern_d_only = int(mask_counts[4])    # Only DINO alive
    exactly_1_alive = pattern_ci_only + pattern_ct_only + pattern_d_only

    # 2-out-of-3 patterns
    pattern_ci_ct = int(mask_counts[3])     # CLIP-img + CLIP-txt alive, DINO dead
    pattern_ci_d = int(mask_counts[5])      # CLIP-img + DINO alive, CLIP-txt dead
    pattern_ct_d = int(mask_counts[6])      # CLIP-txt + DINO alive, CLIP-img dead
    exactly_2_alive = pattern_ci_ct + pattern_ci_d + pattern_ct_d

    print(f"\n{case_name} ({total_neurons} total neurons):")
    print(f"  All dead (0/3):          {exactly_0_alive:4d} ({exactly_0_alive/total_neurons*100:5.1f}%)")
    print(f"  All alive (3/3):         {exactly_3_alive:4d} ({exactly_3_alive/total_neurons*100:5.1f}%)")
    print(f"  Mixed (1/3):             {exactly_1_alive:4d} ({exactly_1_alive/total_neurons*100:5.1f}%)")
    print(f"  Mixed (2/3):             {exactly_2_alive:4d} ({exactly_2_alive/total_neurons*100:5.1f}%)")

    # The group headers are always printed; the rows only when the group is
    # non-empty (percentages are relative to the group size).
    print("\n  1/3 Breakdown:")
    if exactly_1_alive > 0:
        print(f"    CLIP-img only:         {pattern_ci_only:3d} ({pattern_ci_only/exactly_1_alive*100:5.1f}%)")
        print(f"    CLIP-txt only:         {pattern_ct_only:3d} ({pattern_ct_only/exactly_1_alive*100:5.1f}%)")
        print(f"    DINO only:             {pattern_d_only:3d} ({pattern_d_only/exactly_1_alive*100:5.1f}%)")

    print("\n  2/3 Breakdown:")
    if exactly_2_alive > 0:
        print(f"    CLIP-img + CLIP-txt:   {pattern_ci_ct:3d} ({pattern_ci_ct/exactly_2_alive*100:5.1f}%)")
        print(f"    CLIP-img + DINO:       {pattern_ci_d:3d} ({pattern_ci_d/exactly_2_alive*100:5.1f}%)")
        print(f"    CLIP-txt + DINO:       {pattern_ct_d:3d} ({pattern_ct_d/exactly_2_alive*100:5.1f}%)")

print("\n" + "=" * 60)

ALIGNMENT PATTERNS ANALYSIS
Goal: Understand which specific stream combinations are active (1/3, 2/3, 3/3)

global_cross (8192 total neurons):
  All dead (0/3):          1280 ( 15.6%)
  All alive (3/3):         6912 ( 84.4%)
  Mixed (1/3):                0 (  0.0%)
  Mixed (2/3):                0 (  0.0%)

  1/3 Breakdown:

  2/3 Breakdown:

global_no_cross (8192 total neurons):
  All dead (0/3):           131 (  1.6%)
  All alive (3/3):         8044 ( 98.2%)
  Mixed (1/3):                2 (  0.0%)
  Mixed (2/3):               15 (  0.2%)

  1/3 Breakdown:
    CLIP-img only:           0 (  0.0%)
    CLIP-txt only:           0 (  0.0%)
    DINO only:               2 (100.0%)

  2/3 Breakdown:
    CLIP-img + CLIP-txt:     2 ( 13.3%)
    CLIP-img + DINO:        11 ( 73.3%)
    CLIP-txt + DINO:         2 ( 13.3%)

local_cross (8192 total neurons):
  All dead (0/3):             0 (  0.0%)
  All alive (3/3):         3575 ( 43.6%)
  Mixed (1/3):              622 (  7.6%)
  Mixed (2/3):      