In [None]:
import os
import numpy as np

def compute_mean_direction_vector(data_dir):
    """Compute the mean direction vector from all NPZ files in a directory.

    Every file in ``data_dir`` whose name ends with
    ``_direction_vectors.npz`` is loaded, its ``'direction_vectors'`` array
    is stacked with the others along the first (frame) axis, and the mean
    over all frames is taken.

    Args:
        data_dir: Path of the directory to scan (not recursive).

    Returns:
        A 1-D NumPy array holding the per-connection mean, flattened from
        ``(connections, components)`` to ``(connections * components,)``.

    Raises:
        ValueError: If no matching file exists, if a file's array is not
            3-D, or if the arrays disagree on their trailing dimensions.
        KeyError: If a file has no ``'direction_vectors'`` entry.
    """
    suffix = '_direction_vectors.npz'
    # os.listdir order is arbitrary; sorting keeps the stacking (and hence
    # floating-point summation) order reproducible across runs and machines.
    npz_files = sorted(f for f in os.listdir(data_dir) if f.endswith(suffix))
    print(f"Computing mean from {len(npz_files)} files")

    if not npz_files:
        # Fail with an actionable message instead of np.concatenate's
        # generic "need at least one array to concatenate".
        raise ValueError(
            f"No '*{suffix}' files found in {data_dir!r}"
        )

    all_vectors = []
    for npz_file in npz_files:
        npz_path = os.path.join(data_dir, npz_file)
        # NpzFile keeps the archive open until closed; the array itself is
        # fully read into memory on access, so it stays valid afterwards.
        with np.load(npz_path) as data:
            dir_vectors = data['direction_vectors']  # (frames, connections, 3)

        if dir_vectors.ndim != 3:
            raise ValueError(
                f"Expected a 3-D 'direction_vectors' array in {npz_path!r}, "
                f"got shape {dir_vectors.shape}"
            )
        all_vectors.append(dir_vectors)

    # Stack the frames of every file, then average over the frame axis.
    all_vectors = np.concatenate(all_vectors, axis=0)
    mean_dir_vec = np.mean(all_vectors, axis=0)

    # Reshape to match the flat format expected by the model.
    mean_dir_vec_flat = mean_dir_vec.reshape(-1)

    print(f"Mean direction vector shape: {mean_dir_vec.shape}")
    print(f"Flattened mean direction vector shape: {mean_dir_vec_flat.shape}")

    return mean_dir_vec_flat