In [None]:
import os
import glob
import torch

# Directory containing saved batches (relative to the notebook's working directory)
save_dir = "../Results/EmbeddingData"

# Find the first .pt file (sorted so the choice is deterministic across runs)
pt_files = sorted(glob.glob(os.path.join(save_dir, "*.pt")))
if not pt_files:
    raise FileNotFoundError(f"No .pt files found in {save_dir}")

path = pt_files[0]  # first .pt file in lexicographic order
print(f"Loading: {path}")

# Load the saved object. weights_only=True restricts unpickling to tensors and
# plain containers (per the torch.load docs) rather than arbitrary objects.
batch = torch.load(path, weights_only=True)

# The summary below iterates key/value pairs; fail with a clear message instead
# of an opaque AttributeError if the file holds something other than a dict.
if not isinstance(batch, dict):
    raise TypeError(
        f"Expected {path} to contain a dict, got {type(batch).__name__}"
    )

print("Top-level keys and shapes:")
for key, value in batch.items():
    if torch.is_tensor(value):
        print(f"{key}: {tuple(value.shape)}")
    elif isinstance(value, dict):
        print(f"{key}: dict with keys {list(value.keys())}")
    elif isinstance(value, (list, tuple)):
        # Report sequence length for both lists and tuples (e.g. "list of length 3")
        print(f"{key}: {type(value).__name__} of length {len(value)}")
    else:
        print(f"{key}: {type(value)}")


In [None]:
import os

def checkFileExists(save_path: str) -> bool:
    """
    Check whether ``save_path`` names an existing, non-empty regular file.

    Despite the original wording, the check is not specific to ``.pt`` files:
    any regular file (symlinks are followed) with a nonzero size qualifies.

    Args:
        save_path (str): Path to the file.

    Returns:
        bool: True if the path is a regular file with nonzero size, False
        otherwise (missing path, directory, empty file, or an unusable path).
    """
    import stat  # local import keeps this cell self-contained

    # A single stat() call avoids the race in isfile() + getsize(): if the file
    # vanished between those two calls, getsize() raised instead of returning False.
    try:
        info = os.stat(save_path)
    except (OSError, ValueError):
        # Mirrors os.path.isfile, which returns False for missing/inaccessible
        # paths and for paths containing invalid characters (e.g. a NUL byte).
        return False
    return stat.S_ISREG(info.st_mode) and info.st_size > 0

# Sanity-check the file chosen in the previous cell; relies on `path` defined there.
checkFileExists(path)
