# SUPER IMPORTANT: the embeddings must be normalized here, before the PCA step below

In [None]:
import torch
import numpy as np
from pathlib import Path
from sklearn.decomposition import IncrementalPCA

# --- CONFIG ---
batched_dir = Path("../Results/EmbeddingDataBatched/resnet34")
reduced_dir = Path("../Results/EmbeddingDataReduced")
reduced_dir.mkdir(parents=True, exist_ok=True)

n_components = 50  # change this to however many dims you want
# batch_size = 5000  # known per-batch frame size


def _load_batch_as_numpy(path: Path) -> np.ndarray:
    """Load one saved batch tensor from *path* and return it as a NumPy array.

    The tensor is mapped to CPU and detached before conversion, so this also
    works for files saved from a GPU tensor or from a tensor that tracks
    gradients (both of which make a plain ``.numpy()`` call raise).
    """
    batch_tensor = torch.load(path, map_location="cpu", weights_only=True)
    return batch_tensor.detach().numpy()


# --- COLLECT BATCH FILES ---
# Sorted so the fit and transform passes visit the files in the same order.
batch_files = sorted(batched_dir.glob("batch_*.pt"))
print(f"Found {len(batch_files)} batch files")

# Fail early with a clear message: with no batches the PCA is never fitted and
# the save step below would die on a missing `components_` attribute instead.
if not batch_files:
    raise FileNotFoundError(f"No batch_*.pt files found in {batched_dir}")

# --- INIT INCREMENTAL PCA ---
# NOTE(review): the batches are passed to the PCA exactly as stored. Any
# normalization/scaling the embeddings need is NOT applied in this cell; if it
# is added, it has to be applied identically in the fit and transform loops.
ipca = IncrementalPCA(n_components=n_components)

# --- PARTIAL FIT ON ALL BATCHES ---
# First pass: stream every batch through partial_fit so only one batch is held
# in memory at a time.
for f in batch_files:
    print(f"Fitting on {f.name} ...")
    ipca.partial_fit(_load_batch_as_numpy(f))

# --- SAVE PCA WEIGHTS ---
# Store the fitted parameters so the projection can be rebuilt later without
# re-running the fit.
np.savez(
    reduced_dir / "pca_weights.npz",
    components=ipca.components_,
    mean=ipca.mean_,
    explained_variance=ipca.explained_variance_,
    explained_variance_ratio=ipca.explained_variance_ratio_,
    singular_values=ipca.singular_values_,
    n_components=ipca.n_components,
    n_features=ipca.n_features_in_,
)
print("Saved PCA weights.")

# --- TRANSFORM & SAVE REDUCED BATCHES ---
# Second pass: project each batch with the fitted PCA and write it under the
# same file name in the reduced directory.
for f in batch_files:
    print(f"Transforming {f.name} ...")
    reduced = ipca.transform(_load_batch_as_numpy(f))
    reduced_tensor = torch.from_numpy(reduced).to(torch.float32)

    out_file = reduced_dir / f.name
    torch.save(reduced_tensor, out_file)
    print(f"Saved reduced batch to {out_file}")


In [None]:
import torch
from pathlib import Path

# --- CONFIG ---
# Directory that holds the PCA-reduced batch files to inspect.
reduced_dir = Path("../Results/EmbeddingDataReduced")

# --- CHECK ---
# Sanity check: load the first reduced batch (in sorted name order) and print
# its type, shape, dtype, a few rows, and its value range.
reduced_files = sorted(reduced_dir.glob("batch_*.pt"))
if not reduced_files:
    print("No reduced batch files found in the directory.")
else:
    first_file = reduced_files[0]
    print(f"Loading {first_file.name} ...")
    reduced_tensor = torch.load(first_file, weights_only=True)

    print(f"Type: {type(reduced_tensor)}")
    print(f"Shape: {reduced_tensor.shape}")
    print(f"Dtype: {reduced_tensor.dtype}")
    print(f"First 3 rows:\n{reduced_tensor[:3]}")
    # Extremes are computed only now, after the prints above, so the output
    # order is unchanged even if the reduction fails.
    smallest = reduced_tensor.min().item()
    largest = reduced_tensor.max().item()
    print(f"Min/max values: {smallest:.6f} / {largest:.6f}")
