In [41]:
import numpy as np
import os
import csv
from typing import List, Dict, Tuple
import h5py
import warnings
from collections import defaultdict
import matplotlib.pyplot as plt



In [42]:
# Example recording directories from the formalin and PBS groups
formalin_example = "/mnt/hd0/Pain_ML_data/videos/formalin/formalin-trimmed_2024-07-15_22-04-39_1-1blue"
pbs_example = "/mnt/hd0/Pain_ML_data/videos/pbs/trimmed_2024-07-15_22-04-39_1-2blue"

# Per-recording feature file used for the exploration below
ex = os.path.join(pbs_example, "features.h5")

In [43]:
# Display the path of the example features.h5 file
ex

'/mnt/hd0/Pain_ML_data/videos/pbs/trimmed_2024-07-15_22-04-39_1-2blue/features.h5'

In [44]:
# def traverse_hdf5(name, node, group_data):
#     group_name = name.split('/')[0]  # Get the top-level group name
#     if isinstance(node, h5py.Dataset):
#         if group_name not in group_data:
#             group_data[group_name] = {}
#         dataset_name = name.split('/')[-1]  # Get the dataset name
#         group_data[group_name][dataset_name] = {
#             'shape': node.shape,
#             'dtype': str(node.dtype),
#             # 'sample': node[(0,) * len(node.shape)][:5].tolist()  # Convert to list for easier printing
#         }

# def explore_h5_file(file_path):
#     group_data = defaultdict(dict)
#     with h5py.File(file_path, 'r') as f:
#         f.visititems(lambda name, node: traverse_hdf5(name, node, group_data))
    
#     print(f"Exploring file: {file_path}\n")
#     for group_name, datasets in group_data.items():
#         print(f"Group: {group_name}")
#         print(f"Number of datasets: {len(datasets)}")
#         total_data_points = sum(info['shape'][0] for info in datasets.values())
#         print(f"Total data points: {total_data_points}")
#         print("Datasets:")
#         for dataset_name, info in datasets.items():
#             print(f"  - {dataset_name}")
#             print(f"    Shape: {info['shape']}")
#             print(f"    Type: {info['dtype']}")
#             # print(f"    First few values: {info['sample']}")
#         print()
#     return group_data

# explore_h5_file(ex)

In [45]:
class HDF5Manager:
    """Read-only helper that indexes and lazily reads the datasets of one HDF5 file.

    On construction the file is opened once to build ``group_data``, a mapping
    ``{top-level group name: {dataset name: full path inside the file}}``, and is
    then closed again. Data-access methods reopen the file on demand and keep that
    handle until :meth:`close` is called (explicitly, on leaving a ``with`` block,
    or when the object is garbage-collected).
    """

    def __init__(self, file_path):
        self.file_path = file_path
        self.file = None  # lazily-opened h5py.File handle, see _ensure_file_open
        self.group_data = defaultdict(dict)
        self._traverse_file()

    def _traverse_file(self):
        """Populate ``group_data`` by visiting every object in the file."""
        with h5py.File(self.file_path, 'r') as f:
            f.visititems(self._collect_datasets)

    def _collect_datasets(self, name, node):
        """``visititems`` callback that records each dataset under its top-level group.

        Only the first and last components of ``name`` are used as keys. Datasets
        that share a name in different sub-groups of one top-level group therefore
        overwrite each other. A dataset stored at the file root is indexed under a
        "group" with its own name.
        """
        if isinstance(node, h5py.Dataset):
            parts = name.split('/')
            self.group_data[parts[0]][parts[-1]] = name  # store the full path

    def _ensure_file_open(self):
        """(Re)open the file read-only when there is no valid handle."""
        if self.file is None or not self.file.id.valid:
            self.file = h5py.File(self.file_path, 'r')

    def _resolve_path(self, group_name, dataset_name):
        """Return the full in-file path of a dataset.

        :raises KeyError: if the group or the dataset is not in the index
        """
        # Use membership tests rather than plain indexing: indexing the defaultdict
        # with an unknown group would silently insert an empty entry into
        # group_data, corrupting later iteration over the groups.
        if group_name not in self.group_data:
            raise KeyError(
                f"Group {group_name!r} not found. Available groups: {list(self.group_data)}"
            )
        datasets = self.group_data[group_name]
        if dataset_name not in datasets:
            raise KeyError(
                f"Dataset {dataset_name!r} not found in group {group_name!r}. "
                f"Available datasets: {list(datasets)}"
            )
        return datasets[dataset_name]

    def close(self):
        """Close the open file handle, if any. Calling it again is a no-op."""
        if self.file is not None:
            self.file.close()
            self.file = None

    def get_dataset(self, group_name, dataset_name) -> "h5py.Dataset":
        """Return the h5py dataset object for ``group_name/dataset_name``.

        The object is not read into memory here; index it (or use
        :meth:`get_data`) to load values.

        :raises KeyError: if the group or dataset is not in the index
        """
        full_path = self._resolve_path(group_name, dataset_name)
        self._ensure_file_open()
        return self.file[full_path]

    def get_dataset_for_all_groups(self, dataset_name) -> Dict[str, np.ndarray]:
        """Read ``dataset_name`` from every indexed group.

        :return: dict mapping group name to the fully-read data
        :raises KeyError: if any group lacks the dataset
        """
        return {
            group_name: self.get_data(group_name, dataset_name)
            for group_name in list(self.group_data)
        }

    def get_data(self, group_name, dataset_name, slice=None):
        """Read a dataset, or a selection of it, into memory.

        :param slice: optional index/slice applied to the dataset; ``None`` reads all of it
        :return: the selected data
        """
        dataset = self.get_dataset(group_name, dataset_name)
        if slice is None:
            # ``[()]`` reads the whole dataset and, unlike ``[:]``, also works for
            # scalar (0-d) datasets.
            return dataset[()]
        return dataset[slice]

    def get_metadata(self, group_name, dataset_name):
        """Return shape, dtype (as str), chunk layout and compression of a dataset."""
        dataset = self.get_dataset(group_name, dataset_name)
        return {
            'shape': dataset.shape,
            'dtype': str(dataset.dtype),
            'chunks': dataset.chunks,
            'compression': dataset.compression
        }

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        # Best-effort cleanup. The finalizer may run on a partially initialised
        # object or during interpreter shutdown, where raising only produces noise.
        try:
            self.close()
        except Exception:
            pass


In [46]:
file_path = ex
with HDF5Manager(file_path) as hdf:
    # Show every top-level group found in the file
    print("Groups:", list(hdf.group_data))

    # Inspect the first group in the index
    group_name = list(hdf.group_data)[0]
    print(f"\nDatasets in {group_name}:", list(hdf.group_data[group_name]))

    # Read one feature fully into memory for the cells below
    data = hdf.get_data(group_name, 'both_front_paws_lifted')

Groups: ['trimmed_2024-07-15_22-04-39_1-2blue']

Datasets in trimmed_2024-07-15_22-04-39_1-2blue: ['ankle_distance', 'background_luminance', 'both_front_paws_lifted', 'cheek_distance', 'chest_head_angle', 'distance_delta', 'fps', 'frame_count', 'front_left_luminance', 'front_paws_distance', 'front_right_luminance', 'hind_left_luminance', 'hind_left_luminance_scaled', 'hind_paws_distance', 'hind_right_luminance', 'hind_right_luminance_scaled', 'hip_chest_angle', 'hip_sternumtail_distance', 'hip_tailbase_distance', 'hip_tailbase_hlpaw_angle', 'hip_tailbase_hrpaw_angle', 'hip_width', 'luminance_logratio', 'neck_snout_distance', 'paw_guarding', 'shoulder_width', 'sternumhead_neck_distance', 'sternumtail_sternumhead_distance', 'tail_hip_angle', 'tailbase_tailtip_distance']


In [47]:
# Display the both_front_paws_lifted values read in the previous cell
data

array([False, False, False, ..., False, False, False])

In [48]:
# Count frames per distinct value: returns (unique values, counts)
np.unique(data, return_counts=True)

(array([False,  True]), array([12902,   421]))

In [49]:
# Dataset names (as stored in features.h5) to extract for the analysis
pain_datasets = [
    'paw_guarding',
    'both_front_paws_lifted',
    'luminance_logratio',
    'hip_tailbase_hlpaw_angle',
    'hind_paws_distance',
    'neck_snout_distance',
]

def extract_datasets_from_h5(filepath: os.PathLike, target_datasets: List[str]) -> Dict[str,np.ndarray]:
    """Read the requested datasets from a single-group HDF5 features file.

    :param filepath: path (str or os.PathLike) to a ``.h5`` file
    :param target_datasets: names of datasets to read from the file's only top-level group
    :return: dict mapping each requested name to its fully-loaded contents
    :raises AssertionError: if the path does not end with ``.h5``
    :raises ValueError: if the file does not contain exactly one top-level group
    :raises KeyError: if a requested dataset is missing from the group
    """
    # os.fspath lets pathlib.Path and other PathLike inputs through;
    # Path objects have no ``endswith`` method.
    filepath = os.fspath(filepath)
    if not filepath.endswith('.h5'):
        # Explicit raise rather than ``assert``, so the check survives ``python -O``.
        # AssertionError is kept so existing callers see the same exception type.
        raise AssertionError(f' Expecting h5 file, but file {os.path.basename(filepath)} passed in')
    with h5py.File(filepath, 'r') as f:
        # Get top level groups
        groups = [key for key in f.keys() if isinstance(f[key], h5py.Group)]

        # Make sure there is only one group
        if len(groups) != 1:
            raise ValueError(f' expecting file to have exactly one group, as is typical in palmreader outputs, but found groups: {groups}')
        group = f[groups[0]]

        extracted_datasets = {}
        for dataset in target_datasets:
            if dataset not in group:
                raise KeyError(f"Dataset {dataset} not found in the file. Found the following keys: {list(group.keys())}")
            # ``[()]`` reads only this dataset, fully, into memory
            extracted_datasets[dataset] = group[dataset][()]

        return extracted_datasets
    

In [50]:
extracted_datasets = extract_datasets_from_h5(ex,pain_datasets)
for k,v in extracted_datasets.items():
    print('----',k,'----')
    print('val type:', type(v))
    print('len:', len(v))
    print('first item:', v[0])
    print('first item type',type(v[0]))
    

---- paw_guarding ----
val type: <class 'numpy.ndarray'>
len: 13323
first item: False
first item type <class 'numpy.bool_'>
---- both_front_paws_lifted ----
val type: <class 'numpy.ndarray'>
len: 13323
first item: False
first item type <class 'numpy.bool_'>
---- luminance_logratio ----
val type: <class 'numpy.ndarray'>
len: 13323
first item: -0.34682081129222675
first item type <class 'numpy.float64'>
---- hip_tailbase_hlpaw_angle ----
val type: <class 'numpy.ndarray'>
len: 13500
first item: 36.953868629412376
first item type <class 'numpy.float64'>
---- hind_paws_distance ----
val type: <class 'numpy.ndarray'>
len: 13500
first item: 83.2689600266507
first item type <class 'numpy.float64'>
---- neck_snout_distance ----
val type: <class 'numpy.ndarray'>
len: 13500
first item: 38.7669519779217
first item type <class 'numpy.float64'>


In [55]:
import os
import numpy as np
import h5py
from typing import List, Dict
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm

def process_single_file(args):
    """Worker entry point: unpack ``(filepath, dataset_names)`` and extract them.

    Takes one tuple so it can be passed directly to ``executor.map``.
    """
    path, names = args
    return extract_datasets_from_h5(path, names)

def process_h5_files(file_list: List[str], pain_datasets: List[str]) -> Dict[str, List[np.ndarray]]:
    """
    Extract the same datasets from many HDF5 files in parallel.

    :param file_list: candidate file paths (str or os.PathLike). Entries that do not
                      exist or do not end in ``.h5`` are skipped with a warning.
    :param pain_datasets: dataset names to extract from each file
    :return: dict mapping each dataset name to a list with one array per valid file,
             in ``file_list`` order. A file lacking a dataset contributes an empty
             array (and a warning), so all lists stay aligned.
    :raises ValueError: if no valid .h5 file remains
    """
    # Validate input files, reporting the ones that are skipped
    valid_files = []
    for file in file_list:
        path = os.fspath(file)  # accept pathlib.Path as well as str
        if os.path.exists(path) and path.endswith('.h5'):
            valid_files.append(path)
        else:
            warnings.warn(f"Skipping {path}: not an existing .h5 file")

    if not valid_files:
        raise ValueError("No valid .h5 files provided.")

    all_data = {dataset: [] for dataset in pain_datasets}

    # Process files in parallel; executor.map yields results in input order,
    # so results[i] belongs to valid_files[i].
    jobs = [(path, pain_datasets) for path in valid_files]
    with ProcessPoolExecutor() as executor:
        results = list(tqdm(executor.map(process_single_file, jobs), total=len(jobs)))

    for path, result in zip(valid_files, results):
        for dataset in pain_datasets:
            if dataset in result:
                all_data[dataset].append(result[dataset])
            else:
                warnings.warn(f"Dataset {dataset} not found in file {path}")
                all_data[dataset].append(np.array([]))  # placeholder keeps per-file alignment

    return all_data

def extract_datasets_from_h5(filepath: os.PathLike, target_datasets: List[str]) -> Dict[str, np.ndarray]:
    """Read the requested datasets from the first top-level group of an HDF5 file.

    Lenient variant: it redefines, and therefore shadows, the stricter version from
    the earlier cell. Requested datasets that are missing from the group are left
    out of the result, so ``process_h5_files`` can report them and substitute
    placeholders.

    :param filepath: path (str or os.PathLike) to a features ``.h5`` file
    :param target_datasets: dataset names to read
    :return: dict mapping each dataset name that was found to its fully-loaded contents
    :raises ValueError: if the file has no top-level group
    """
    with h5py.File(filepath, 'r') as f:
        # Only consider groups. An empty file has none, and the first key could be
        # a dataset rather than a group.
        groups = [key for key in f.keys() if isinstance(f[key], h5py.Group)]
        if not groups:
            raise ValueError(
                f"No top-level group found in {os.fspath(filepath)}; keys: {list(f.keys())}"
            )
        group = f[groups[0]]
        return {name: group[name][()] for name in target_datasets if name in group}

In [68]:
def calculate_bins(arrays: List[np.ndarray], num_bins: int) -> np.ndarray:
    """
    Calculate shared, equal-width bin edges from the pooled data of all input arrays.

    Non-finite values (NaN, +/-inf) are ignored: ``np.histogram`` cannot
    auto-detect a range that contains them and would raise.

    :param arrays: List of NumPy arrays containing the data (e.g. one per file)
    :param num_bins: Number of bins to create
    :return: NumPy array of ``num_bins + 1`` bin edges spanning the finite data range
    :raises ValueError: if ``arrays`` is empty or contains no finite values
    """
    # Pool all arrays so every array can later be binned on the same edges
    combined = np.concatenate(arrays)

    finite = combined[np.isfinite(combined)]
    if finite.size == 0:
        # Without this guard np.histogram would silently fall back to the range (0, 1)
        raise ValueError("calculate_bins: no finite values to bin")

    _, bin_edges = np.histogram(finite, bins=num_bins)
    return bin_edges

def calculate_frequencies(arrays: List[np.ndarray], bin_edges: np.ndarray) -> List[np.ndarray]:
    """
    Calculate the relative frequency of each bin for every input array.

    Binning follows ``np.histogram``: bins are ``[e_i, e_{i+1})`` except the last,
    which is closed, so a value equal to the final edge lands in the last bin.
    Values outside ``[bin_edges[0], bin_edges[-1]]`` can occur when the edges come
    from other data; they are counted in the first or last bin, so each result
    sums to 1. NaN and +/-inf values are ignored.

    :param arrays: List of NumPy arrays (or array-likes) containing the data
    :param bin_edges: Increasing bin edges, e.g. from ``calculate_bins``
    :return: List with one 1-D array of ``len(bin_edges) - 1`` frequencies per input
             array. An array with no finite values yields all-NaN frequencies.
    """
    edges = np.asarray(bin_edges)
    num_bins = len(edges) - 1
    # Search only the interior edges so every value maps to 0..num_bins-1:
    # below edges[1] -> 0, at or above edges[-2] -> num_bins-1. Searching edges[1:]
    # would send values equal to edges[-1] (the data maximum) to an extra bin.
    interior = edges[1:-1]

    frequencies = []
    for arr in arrays:
        values = np.asarray(arr, dtype=float).ravel()
        values = values[np.isfinite(values)]
        if values.size == 0:
            # Nothing to count: report NaN explicitly instead of dividing 0 by 0
            frequencies.append(np.full(num_bins, np.nan))
            continue
        indices = np.searchsorted(interior, values, side='right')
        bin_counts = np.bincount(indices, minlength=num_bins)
        frequencies.append(bin_counts / values.size)

    return frequencies

In [69]:
# Get all features.h5 files regardless of experimental group.
# Another function will split them by group once the bins are decided upon.

data_dir = '/mnt/hd0/Pain_ML_data'
video_dir = os.path.join(data_dir, 'videos')

# os.walk yields directory entries in an arbitrary, filesystem-dependent order.
# Sort the paths so the per-file lists in concatenated_features (which follow
# this order) are reproducible across runs and machines.
features_files = sorted(
    os.path.join(root, file)
    for root, _dirs, files in os.walk(video_dir)
    for file in files
    if file == 'features.h5'
)
print('number of files found:', len(features_files))

concatenated_features = process_h5_files(features_files, pain_datasets)
for k, v in concatenated_features.items():
    print('key:', k, 'has ', len(v), 'items')

number of files found: 38


100%|███████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:00<00:00, 1387.15it/s]

key: paw_guarding has  38 items
key: both_front_paws_lifted has  38 items
key: luminance_logratio has  38 items
key: hip_tailbase_hlpaw_angle has  38 items
key: hind_paws_distance has  38 items
key: neck_snout_distance has  38 items





In [70]:
# Edges come from the pooled values of all files; recompute them separately
# when comparing experimental groups.
bin_edges = calculate_bins(arrays=concatenated_features['luminance_logratio'], num_bins=15)
calculate_frequencies(arrays=concatenated_features['luminance_logratio'], bin_edges=bin_edges)

[array([0.        , 0.13050962, 0.01391903, 0.06465614, 0.27867994,
        0.19965577, 0.10110005, 0.19336975, 0.00718402, 0.00523834,
        0.00209534, 0.00074833, 0.00074833, 0.00209534, 0.        ]),
 array([0.        , 0.0009863 , 0.00087671, 0.0019726 , 0.01490411,
        0.40991781, 0.36      , 0.19391781, 0.00975342, 0.00421918,
        0.00060274, 0.00049315, 0.00065753, 0.00169863, 0.        ]),
 array([0.00268373, 0.01336705, 0.00263212, 0.01594756, 0.17258464,
        0.35156895, 0.34336292, 0.08133774, 0.00717382, 0.00433526,
        0.00129026, 0.00056771, 0.0020128 , 0.00113543, 0.        ]),
 array([0.        , 0.00717796, 0.00099387, 0.00204296, 0.02545414,
        0.21760256, 0.05891447, 0.65545801, 0.01280989, 0.00833747,
        0.00364419, 0.00149081, 0.00281597, 0.00325769, 0.        ]),
 array([0.        , 0.0241842 , 0.00353376, 0.00706753, 0.18938766,
        0.30522887, 0.28358456, 0.15973718, 0.00640495, 0.00563194,
        0.00502457, 0.00320247, 0.005024

In [77]:
def process_one_hot_datasets(arrays: List[np.ndarray]) -> List[float]:
    """
    Return, for each boolean per-frame array, the fraction of True entries.

    :param arrays: List of boolean arrays (or array-likes), e.g. one per file
    :return: One fraction per input array. An empty array yields ``nan`` without a
             RuntimeWarning; ``process_h5_files`` inserts empty arrays for missing
             datasets.
    """
    fractions = []
    for arr in arrays:
        values = np.asarray(arr)
        if values.size == 0:
            fractions.append(float('nan'))
        else:
            # count_nonzero avoids materialising an int copy of the array
            fractions.append(np.count_nonzero(values) / values.size)
    return fractions
    

In [79]:
# Fraction of frames flagged as paw guarding, one value per file
process_one_hot_datasets(concatenated_features['paw_guarding'])

[0.04609743321110529,
 0.0038356164383561643,
 0.016515276630883566,
 0.016233228424714263,
 0.038098393241676326,
 0.05028335414465469,
 0.012383561643835616,
 0.04778599558159639,
 0.03713036389238865,
 0.031505138795504754,
 0.05124387666890789,
 0.014082191780821918,
 0.02957266721717589,
 0.1031314230443719,
 0.028901734104046242,
 0.03963666391412056,
 0.013904001934469834,
 0.024219178082191782,
 0.08715802063088653,
 0.011498544015530502,
 0.005614612966012876,
 0.01448340856108328,
 0.021211741617604075,
 0.0022568269013766643,
 0.008351079377009478,
 0.007588909280293636,
 0.0920706501315295,
 0.01179741508866847,
 0.003006388575723412,
 0.008729725708050196,
 0.022766727840880907,
 0.029496776106129104,
 0.008781805899572169,
 0.002764212660093983,
 0.01563657922039425,
 0.024455313472654512,
 0.006000889020595644,
 0.006297695784248352]