In [None]:
import torch
from pathlib import Path

# --- CONFIGURATION ---
input_folder = Path("../Results/EmbeddingData")
extractor = "resnet34"  # key of the per-frame embedding tensor inside each source .pt file
# Derived from `extractor` so switching extractors cannot silently write into
# another model's output folder.
output_folder = Path("../Results/EmbeddingDataBatched") / extractor
batch_size = 5000  # number of frames per new .pt file


def _save_batch(frames, meta, out_file):
    """Stack buffered per-frame embeddings into one tensor and save it with its metadata.

    Args:
        frames: list of equally shaped per-frame embedding tensors.
        meta: list of (source filename, frame index) tuples aligned with `frames`.
        out_file: destination path for the ``{"embeddings", "meta"}`` dict.
    """
    # stack copies the rows, so only this batch's data is serialized
    # (not the full storage of the source tensor the rows are views of).
    batch_tensor = torch.stack(frames, dim=0)  # [len(frames), *frame_shape]
    torch.save({"embeddings": batch_tensor, "meta": meta}, out_file)
    print(f"Saved {out_file} ({batch_tensor.shape[0]} frames)")


def rebatch_embeddings(input_folder, output_folder, extractor, batch_size):
    """Re-chunk per-file embedding tensors into fixed-size batch files.

    Reads every ``*.pt`` file in `input_folder` (sorted by name). For each file it takes
    the tensor stored under key `extractor` and iterates over its first dimension, one
    row per frame. It writes the rows, in order, to ``batch_00000.pt``,
    ``batch_00001.pt``, ... in `output_folder`.

    Each output file holds at most `batch_size` frames and has the form
    ``{"embeddings": Tensor, "meta": [(source filename, frame index), ...]}``.
    Frames from consecutive source files may share a batch, and the last batch may be
    smaller than `batch_size`.

    Args:
        input_folder: folder containing the source ``.pt`` files.
        output_folder: folder to write batch files into (created if missing).
        extractor: key of the embedding tensor inside each loaded source dict.
        batch_size: maximum number of frames per output file; must be >= 1.

    Returns:
        Number of batch files written by this call.

    Raises:
        ValueError: if `batch_size` < 1.
        KeyError: if a loaded source dict has no `extractor` key.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    frames, meta = [], []
    n_written = 0
    for pt_file in sorted(Path(input_folder).glob("*.pt")):
        # map_location="cpu" keeps the outputs loadable on machines without the
        # device the inputs may have been saved from.
        data = torch.load(pt_file, weights_only=True, map_location="cpu")
        for frame_idx, frame_emb in enumerate(data[extractor]):
            frames.append(frame_emb)
            meta.append((pt_file.name, frame_idx))
            if len(frames) == batch_size:
                _save_batch(frames, meta, output_folder / f"batch_{n_written:05d}.pt")
                n_written += 1
                frames, meta = [], []

    # Flush the final, partially filled batch.
    if frames:
        _save_batch(frames, meta, output_folder / f"batch_{n_written:05d}.pt")
        n_written += 1

    # An earlier run that produced more batches leaves higher-numbered files behind.
    # Those are not overwritten, and any downstream glob over batch_*.pt would silently
    # mix them in, so report them instead of deleting anything.
    expected = {f"batch_{i:05d}.pt" for i in range(n_written)}
    stale = sorted(p.name for p in output_folder.glob("batch_*.pt") if p.name not in expected)
    if stale:
        print(
            f"WARNING: {len(stale)} batch file(s) in {output_folder} were not written "
            f"by this run (stale?): {stale}"
        )
    return n_written


n_batches_written = rebatch_embeddings(input_folder, output_folder, extractor, batch_size)


In [None]:
import torch
from pathlib import Path

# --- CONFIGURATION ---
batch_folder = Path("../Results/EmbeddingDataBatched/resnet34")
n_examples = 5  # how many (meta, embedding) sample entries to print

# Pick a batch file (the first one by name). Fail with an actionable message
# instead of a bare IndexError when no batch files have been produced yet.
batch_files = sorted(batch_folder.glob("batch_*.pt"))
if not batch_files:
    raise FileNotFoundError(
        f"No batch_*.pt files found in {batch_folder}; run the batching cell first."
    )
batch_file = batch_files[0]

# Load the saved dictionary. weights_only=True restricts unpickling to tensors and
# plain containers (dict/list/tuple/str/int), which is what the batching cell writes;
# map_location="cpu" lets the file load on a machine without a GPU.
data = torch.load(batch_file, weights_only=True, map_location="cpu")

embeddings = data["embeddings"]  # tensor [N, dim]
meta = data["meta"]              # list of (filename, frame_idx)
# Each metadata entry describes the embedding row with the same index.
assert embeddings.shape[0] == len(meta), (
    f"{batch_file}: {embeddings.shape[0]} embeddings but {len(meta)} meta entries"
)

print(f"Loaded batch: {batch_file}")
print(f"Embeddings shape: {embeddings.shape}")
print(f"Number of metadata entries: {len(meta)}")

# Print a few sample entries; guard against batches holding fewer than n_examples frames.
for i in range(min(n_examples, len(meta))):
    source_name, frame_idx = meta[i]
    print(f"Entry {i}:")
    print(f"  Source file: {source_name}")
    print(f"  Frame index: {frame_idx}")
    print(f"  Embedding (first 5 values): {embeddings[i, :5].tolist()}")
