In [1]:
import os
import json
import pickle
import warnings
from itertools import combinations

import numpy as np
import torch
from scipy import sparse
from scipy.sparse import csr_matrix, vstack
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm.auto import tqdm

from sparc.feature_extract.extract_open_images import OpenImagesDataset
from sparc.post_analysis import HDF5AnalysisResultsDataset

# ────────────────────────────────────────────────────────────────────────
def silence_solver_convergence_warnings():
    """Install warning filters that hide solver non-convergence messages.

    sklearn's ConvergenceWarning subclasses UserWarning, so matching on
    ``category=UserWarning`` plus a message regex covers it without importing
    ``sklearn.exceptions``. ``message`` is a regex matched (case-insensitively)
    against the *start* of the warning text.
    """
    # sag/saga solvers: "The max_iter was reached which means ..."
    warnings.filterwarnings("ignore", category=UserWarning,
                            message="The max_iter was reached")
    # lbfgs (LogisticRegression's default solver) and liblinear emit
    # "lbfgs failed to converge ..." / "Liblinear failed to converge ...";
    # the filter above does not match those, so they need their own entry.
    warnings.filterwarnings("ignore", category=UserWarning,
                            message=r".*failed to converge")


# Suppress convergence warnings for clean output
silence_solver_convergence_warnings()

In [2]:
# Print small floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

# Root of the local Open Images copy. Set the OPEN_IMAGES_ROOT environment
# variable to point elsewhere instead of editing a hard-coded absolute path;
# the default keeps the original location.
OPEN_IMAGES_ROOT = os.environ.get('OPEN_IMAGES_ROOT', '/home/ubuntu/Projects/OpenImages/')

# Test split (the loader reports loading both captions and image-level labels).
dataset = OpenImagesDataset(OPEN_IMAGES_ROOT, 'test')

Loading caption data from /home/ubuntu/Projects/OpenImages/captions/test/simplified_open_images_test_localized_narratives.json...
Loading label data...
Total number of classes: 601
Loading annotations from /home/ubuntu/Projects/OpenImages/labels/test-annotations-human-imagelabels-boxable.csv...
Loaded labels for 112194 images


In [3]:
# Build one sparse label row per dataset sample whose image has labels.
# Samples whose image is missing from `image_to_label_tensor` are skipped, so
# row r of `label_matrix_sparse` belongs to dataset sample `sample_indices[r]`.
labels, kept_positions = [], []
for sample_pos in tqdm(range(len(dataset))):
    img_id, _caption = dataset.samples[sample_pos]
    if img_id not in dataset.image_to_label_tensor:
        continue
    labels.append(csr_matrix(dataset.image_to_label_tensor[img_id]))
    kept_positions.append(sample_pos)

sample_indices = np.array(kept_positions)
label_matrix_sparse = vstack(labels).tocsr()


  0%|          | 0/126020 [00:00<?, ?it/s]

In [4]:
def load_analysis_cache(run_dir):
    """Open ``../../final_results/<run_dir>/analysis_cache_val.h5``.

    The second argument (``256``) is forwarded unchanged; its meaning is
    defined by HDF5AnalysisResultsDataset and is not visible here.
    """
    cache_path = f'../../final_results/{run_dir}/analysis_cache_val.h5'
    return HDF5AnalysisResultsDataset(cache_path, 256)


# One cached analysis per training variant (global/local x with/without cross).
analysis_results_global_cross = load_analysis_cache('msae_open_global_with_cross')
analysis_results_global_no_cross = load_analysis_cache('msae_open_global_no_cross')
analysis_results_local_cross = load_analysis_cache('msae_open_local_with_cross')
analysis_results_local_no_cross = load_analysis_cache('msae_open_local_no_cross')

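# Probe

For each sufficiently frequent label, fit one-feature logistic-regression probes on a shortlist of latent dimensions, keep the latent with the lowest validation log-loss, and report its test log-loss (lower is better).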

In [5]:
def get_XY(analysis_results, stream, labeled_indices=None, label_matrix=None):
    """Return the (latents, labels) pair for the labeled samples of one stream.

    Parameters
    ----------
    analysis_results : object
        Must provide ``get_all_features_for_stream(stream, 'latents',
        return_sparse=True)`` and ``get_all_original_dataset_indices()``.
    stream : str
        Stream name, e.g. one of ``analysis_results.streams``.
    labeled_indices : array-like of int, optional
        Original dataset indices of the labeled samples, in label-row order.
        Defaults to the notebook-level ``sample_indices``.
    label_matrix : sparse matrix, optional
        Label matrix whose row r belongs to ``labeled_indices[r]``.
        Defaults to the notebook-level ``label_matrix_sparse``.

    Returns
    -------
    X : sparse matrix
        Latent rows reordered so that row r corresponds to ``labeled_indices[r]``.
    Y : sparse matrix
        ``label_matrix`` itself (not copied).

    Raises
    ------
    KeyError
        If a labeled sample is not present in the analysis cache.
    """
    if labeled_indices is None:
        labeled_indices = sample_indices
    if label_matrix is None:
        label_matrix = label_matrix_sparse

    # Load the latent matrix exactly once; it is the expensive call here.
    stream_X_full = analysis_results.get_all_features_for_stream(stream, 'latents', return_sparse=True)
    cached_indices = analysis_results.get_all_original_dataset_indices()

    # original dataset index -> row position inside the cache
    position_of = {original_idx: pos for pos, original_idx in enumerate(cached_indices)}
    # Explicit integer dtype so an empty selection still indexes correctly.
    positions = np.array([position_of[idx] for idx in labeled_indices], dtype=np.intp)

    X = stream_X_full[positions]
    return X, label_matrix

In [6]:
def _balanced_split(y_all, t, min_count):
    """Class-balance label column ``t`` and split it 70/15/15 (stratified).

    Returns ``(train_idx, val_idx, test_idx)`` as row indices into ``y_all``,
    or ``None`` when the minority class has ``min_count`` or fewer samples.
    """
    rng_task = np.random.RandomState(1000 + t)
    pos_idx = np.flatnonzero(y_all == 1)
    neg_idx = np.flatnonzero(y_all == 0)

    n_per_class = min(pos_idx.size, neg_idx.size)
    if n_per_class <= min_count:
        return None

    if neg_idx.size >= pos_idx.size:
        # Usual case: keep every positive and downsample the negatives.
        pos_keep = pos_idx
        neg_keep = rng_task.choice(neg_idx, size=n_per_class, replace=False)
    else:
        # Label present on most images: downsample positives instead, since
        # drawing pos_idx.size negatives without replacement is impossible.
        pos_keep = rng_task.choice(pos_idx, size=n_per_class, replace=False)
        neg_keep = neg_idx

    balanced_idx = np.concatenate([pos_keep, neg_keep])
    balanced_y = y_all[balanced_idx]

    # 70% train, then split the remaining 30% evenly into val / test.
    sss1 = StratifiedShuffleSplit(n_splits=1, train_size=0.70, test_size=0.30,
                                  random_state=2000 + t)
    train_rel, temp_rel = next(sss1.split(np.zeros(balanced_idx.size), balanced_y))
    sss2 = StratifiedShuffleSplit(n_splits=1, train_size=0.50, test_size=0.50,
                                  random_state=3000 + t)
    val_rel, test_rel = next(sss2.split(np.zeros(temp_rel.size), balanced_y[temp_rel]))

    return (balanced_idx[train_rel],
            balanced_idx[temp_rel[val_rel]],
            balanced_idx[temp_rel[test_rel]])


def _shortlist_latents(X_train, y_train, top_k):
    """Return up to ``top_k`` latents active on the train split.

    Candidates are ranked by the number of *train* positives on which the
    latent is > 0 (ties: lower latent index first). Only training rows are
    used, so validation/test labels never influence which latents are tried.
    """
    active = np.asarray(X_train.sum(axis=0)).ravel() > 0
    pos_counts = np.asarray(
        (X_train > 0).astype(np.int64).T @ (y_train == 1).astype(np.int64)
    ).ravel()

    candidates = np.flatnonzero(active)
    if candidates.size > top_k:
        order = np.argsort(-pos_counts[candidates], kind='stable')
        candidates = candidates[order[:top_k]]
    return candidates


def _fit_best_latent(X_train, y_train, X_val, y_val, candidates, max_iter):
    """Fit a 1-D logistic regression per candidate latent.

    Returns ``(best_clf, best_latent)`` with the lowest validation log-loss,
    or ``(None, None)`` when no candidate produced a finite loss.
    """
    best_ce, best_clf, best_i = np.inf, None, None
    for i in candidates:
        Xi = X_train[:, i].toarray().reshape(-1, 1)
        clf = LogisticRegression(max_iter=max_iter).fit(Xi, y_train)
        Xv = X_val[:, i].toarray().reshape(-1, 1)
        ce = log_loss(y_val, clf.predict_proba(Xv)[:, 1], labels=[0, 1])
        if ce < best_ce:
            best_ce, best_clf, best_i = ce, clf, i
    return best_clf, best_i


def compute_probe_loss(X, Y, min_count=50, top_k=20, max_iter=200):
    """Test log-loss of the best single-latent probe for every frequent label.

    For each label column t of ``Y`` with more than ``min_count`` positives:

    1. Class-balance the rows by downsampling the majority class, then split
       them 70/15/15 (stratified) into train / val / test.
    2. Shortlist up to ``top_k`` latents active on the train split, ranked by
       how many train positives activate them (train rows only).
    3. Fit a 1-D LogisticRegression on each shortlisted latent, keep the one
       with the lowest validation log-loss, and record its test log-loss.

    Parameters
    ----------
    X : scipy.sparse matrix, shape (N, D)
        Latent activations, row-aligned with ``Y``.
    Y : scipy.sparse matrix, shape (N, T)
        Binary (0/1) label matrix.
    min_count : int, default 50
        A label is probed only if it has strictly more than this many positives.
    top_k : int, default 20
        Maximum number of candidate latents swept per label.
    max_iter : int, default 200
        ``max_iter`` passed to ``LogisticRegression``.

    Returns
    -------
    list of float
        One entry per probed label, in increasing label-column order. Labels
        that cannot be probed (too few minority samples after balancing, or
        no latent active on the train split) get ``-1.0``; filter negative
        values before aggregating.
    """
    Y_cols = Y.tocsc()  # column slicing below is much cheaper on CSC
    label_counts = np.asarray(Y.sum(axis=0)).ravel()
    tasks = np.flatnonzero(label_counts > min_count)
    probe_losses = []

    for t in tqdm(tasks, desc="Probing tasks", unit="task"):
        y_all = Y_cols[:, t].toarray().ravel()

        split = _balanced_split(y_all, t, min_count)
        if split is None:
            probe_losses.append(-1.0)  # sentinel: label skipped
            continue
        train_idx, val_idx, test_idx = split

        # Select train/val rows once per task; CSC for per-latent columns.
        X_train = X[train_idx].tocsc()
        y_train = y_all[train_idx]
        candidates = _shortlist_latents(X_train, y_train, top_k)

        best_clf, best_i = _fit_best_latent(
            X_train, y_train, X[val_idx].tocsc(), y_all[val_idx], candidates, max_iter
        )
        if best_clf is None:
            probe_losses.append(-1.0)  # sentinel: no usable latent
            continue

        # test eval
        Xt = X[test_idx, best_i].toarray().reshape(-1, 1)
        preds = best_clf.predict_proba(Xt)[:, 1]
        probe_losses.append(log_loss(y_all[test_idx], preds, labels=[0, 1]))
    return probe_losses

In [7]:
experiments = {
    "global_cross": analysis_results_global_cross,
    "global_no_cross": analysis_results_global_no_cross,
    "local_cross": analysis_results_local_cross,
    "local_no_cross": analysis_results_local_no_cross,
}

# probe_results[experiment][stream] -> list of per-label probe losses
probe_results = {}
for exp_name, results_ds in experiments.items():
    print(f"Running for exp: {exp_name}")
    per_stream = {}
    for stream_name in results_ds.streams:
        X_stream, Y_stream = get_XY(results_ds, stream_name)
        per_stream[stream_name] = compute_probe_loss(X_stream, Y_stream)
    probe_results[exp_name] = per_stream

Running for exp: global_cross


Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Running for exp: global_no_cross


Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Running for exp: local_cross


Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Running for exp: local_no_cross


Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

Probing tasks:   0%|          | 0/432 [00:00<?, ?task/s]

In [11]:
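# Save results. Currently disabled; the next cell reloads the previously saved
# file instead. Uncomment to persist freshly computed `probe_results`.
# os.makedirs('../../final_results/', exist_ok=True)
# with open('../../final_results/open_images_probe.json', 'w') as f:
#     json.dump(probe_results, f)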

In [12]:
# Reload saved probe results; this replaces any `probe_results` already in memory.
PROBE_RESULTS_PATH = '../../final_results/open_images_probe.json'
with open(PROBE_RESULTS_PATH) as results_file:
    probe_results = json.load(results_file)

In [13]:
# Mean probe loss per (experiment, stream); negative entries are excluded.
for exp_name, per_stream in probe_results.items():
    for stream_name, losses in per_stream.items():
        losses_arr = np.array(losses)
        kept_mean = losses_arr[losses_arr >= 0].mean()
        print(f"Performance for {exp_name} on stream {stream_name}: {kept_mean}")

Performance for global_cross on stream clip_img: 0.535445801401678
Performance for global_cross on stream clip_txt: 0.5646201720348464
Performance for global_cross on stream dino: 0.5409099481012206
Performance for global_no_cross on stream clip_img: 0.5336463733083144
Performance for global_no_cross on stream clip_txt: 0.4942256545857633
Performance for global_no_cross on stream dino: 0.5193776675324345
Performance for local_cross on stream clip_img: 0.49900418200192076
Performance for local_cross on stream clip_txt: 0.5363410356381009
Performance for local_cross on stream dino: 0.5170213532142913
Performance for local_no_cross on stream clip_img: 0.5238217005884307
Performance for local_no_cross on stream clip_txt: 0.4903718857503919
Performance for local_no_cross on stream dino: 0.5264730380063344
