In [None]:
import torch
from pathlib import Path

# --- CONFIGURATION ---
extractor = "resnet34"  # key of the embedding tensor inside each input file
batch_size = 5000  # number of frames per new .pt file

input_folder = Path("../Results/EmbeddingData")
# Derived from `extractor` so that changing the extractor can never write its
# batches into another extractor's folder.
output_folder = Path("../Results/EmbeddingDataBatched") / extractor


# --- HELPERS ---
def _save_batch(chunks, dst_dir, index):
    """Concatenate `chunks` along dim 0 and write them as batch file `index`.

    Args:
        chunks: Non-empty list of tensors that agree on every dim but the first.
        dst_dir: Existing directory (a `Path`) that receives the file.
        index: Zero-based batch number used in the ``batch_XXXXX.pt`` file name.

    Returns:
        Path of the file that was written.
    """
    # Always go through torch.cat, even for a single chunk: the chunks are
    # slices (views) of a whole input tensor, and saving a view would serialize
    # the full underlying storage instead of just the selected rows.
    batch_tensor = torch.cat(chunks, dim=0)
    out_file = dst_dir / f"batch_{index:05d}.pt"
    torch.save(batch_tensor, out_file)
    print(f"Saved {out_file} ({batch_tensor.shape[0]} frames)")
    return out_file


def rebatch_embeddings(src_dir, dst_dir, key, frames_per_file):
    """Re-pack per-file embeddings into files of `frames_per_file` frames each.

    Every ``*.pt`` file in `src_dir` is loaded in sorted order and the object
    stored under `key` is read as a sequence of frames along its first
    dimension. Frames are written, in that order, to ``batch_00000.pt``,
    ``batch_00001.pt``, ... inside `dst_dir`; only the last file may hold fewer
    than `frames_per_file` frames.

    Args:
        src_dir: Folder containing the input ``.pt`` files.
        dst_dir: Folder for the batch files; created if it does not exist.
        key: Entry to read from each loaded file (the extractor name).
        frames_per_file: Number of frames per output file; must be >= 1.

    Returns:
        List of the written file paths in batch order (empty if no frames
        were found).

    Raises:
        ValueError: If `frames_per_file` is smaller than 1.
        KeyError: If an input file has no `key` entry.
    """
    if frames_per_file < 1:
        raise ValueError(f"frames_per_file must be >= 1, got {frames_per_file}")

    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)

    written = []        # paths of the batch files saved so far
    pending = []        # chunks collected for the batch currently being filled
    pending_frames = 0  # total number of frames held in `pending`

    for pt_file in sorted(src_dir.glob("*.pt")):
        data = torch.load(pt_file, weights_only=True)
        try:
            embeddings = data[key]
        except KeyError:
            # Same exception type as a plain lookup, but naming the file.
            raise KeyError(f"{pt_file} has no entry for extractor {key!r}") from None

        # Also accept a plain sequence of per-frame tensors, which a
        # frame-by-frame loop would handle as well.
        if not torch.is_tensor(embeddings):
            frames = list(embeddings)
            if not frames:
                continue
            embeddings = torch.stack(frames)

        num_frames = embeddings.shape[0]
        start = 0
        while start < num_frames:
            # Take as many frames as still fit into the current batch.
            take = min(frames_per_file - pending_frames, num_frames - start)
            pending.append(embeddings[start:start + take])
            pending_frames += take
            start += take

            if pending_frames == frames_per_file:
                written.append(_save_batch(pending, dst_dir, len(written)))
                pending, pending_frames = [], 0

    # Save any leftover frames
    if pending:
        written.append(_save_batch(pending, dst_dir, len(written)))

    return written


# --- RUN ---
batch_files = []
if input_folder.is_dir():
    batch_files = rebatch_embeddings(input_folder, output_folder, extractor, batch_size)
    if not batch_files:
        print(f"No frames found under {input_folder}; nothing was written")
else:
    print(f"Input folder {input_folder} does not exist; nothing was written")


In [None]:
# --- CONFIG ---
# Batch file to inspect.
file_path = Path("../Results/EmbeddingDataBatched/resnet34/batch_00000.pt")

# --- LOAD ---
# weights_only=True keeps unpickling limited to tensors and plain containers.
batch_tensor = torch.load(file_path, weights_only=True)

# --- VERIFY ---
# Container type, dimensions and element type of the loaded object.
tensor_type = type(batch_tensor)
print(f"Type: {tensor_type}")
print(f"Shape: {batch_tensor.shape}")
print(f"Dtype: {batch_tensor.dtype}")

# A small sample of the content: the first 3 frames.
leading_rows = batch_tensor[:3]
print(f"First 3 embeddings:\n{leading_rows}")

# Global value range across the whole batch.
value_min = batch_tensor.min().item()
value_max = batch_tensor.max().item()
print(f"Min/max values: {value_min:.6f} / {value_max:.6f}")