In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Config
REAL_DATA_PATH = "/data/yyang409/bowen/imagenet_feature/swin_base/patch4_window7_224/image_features_w_label_train.npz"
GENERATED_PATH = "/scratch/bowenxi/dit/data_gen/0404_all_data/full_dataset.h5" #"../generated_latents.npz"
SELECTED_CLASSES = [7, 12, 24, 35, 47, 68, 73, 89, 91, 99]
OUTPUT_PATH = "tsne_comparison.png"

def load_and_visualize(output_path=OUTPUT_PATH):
    # Load data with NaN checks
    real_data = np.load(REAL_DATA_PATH)
    generated_data = np.load(GENERATED_PATH)
    
    # ========== New NaN Verification Section ==========
    # Check generated data
    gen_features = generated_data["generated"]
    gen_labels = generated_data["labels"]
    
    print("\n=== NaN Check ===")
    print(f"Generated features NaN count: {np.isnan(gen_features).sum()}")
    print(f"Generated labels NaN count: {np.isnan(gen_labels).sum()}")
    
    if np.isnan(gen_features).any() or np.isnan(gen_labels).any():
        raise ValueError("NaN values detected in generated data! Aborting visualization.")
    
    # Check real data
    real_features = real_data["features"]
    real_labels = real_data["labels"]
    print(f"Real features NaN count: {np.isnan(real_features).sum()}")
    print(f"Real labels NaN count: {np.isnan(real_labels).sum()}")
    print("="*40 + "\n")
    # ========== End of NaN Checks ==========

    # Filter real data for selected classes
    real_latents, real_labels = [], []
    for class_idx in SELECTED_CLASSES:
        mask = (real_data["labels"] == class_idx)
        real_latents.append(real_data["features"][mask][:100].reshape(100, -1))
        real_labels.extend([class_idx] * 100)
    
    # Combine data
    combined = np.concatenate([
        np.concatenate(real_latents),
        generated_data["generated"]
    ])
    labels = np.concatenate([
        np.array(real_labels),
        generated_data["labels"]
    ])
    
    # t-SNE
    tsne = TSNE(n_components=2, perplexity=30, random_state=42)
    embeddings = tsne.fit_transform(combined)
    
    # Plotting (remainder unchanged)
    # ... [rest of your plotting code] ...
    # Plotting
    plt.figure(figsize=(15, 10))
    colors = plt.cm.tab10(np.linspace(0, 1, len(SELECTED_CLASSES)))
    
    for idx, class_id in enumerate(SELECTED_CLASSES):
        # Real samples (circles)
        mask = (labels == class_id) & (np.arange(len(labels)) < len(real_labels))
        plt.scatter(embeddings[mask, 0], embeddings[mask, 1],
                    color=colors[idx], marker='o', label=f'Class {class_id} (Real)')
        
        # Generated samples (crosses)
        mask = (labels == class_id) & (np.arange(len(labels)) >= len(real_labels))
        plt.scatter(embeddings[mask, 0], embeddings[mask, 1],
                    color=colors[idx], marker='x', label=f'Class {class_id} (Gen)')
    
    plt.title("Real (○) vs Generated (×) Latent Vectors")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    
    # Save and show
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    print(f"Plot saved to {output_path}")
    plt.show()

if __name__ == "__main__":
    load_and_visualize()

ValueError: Cannot load file containing pickled data when allow_pickle=False