### Chronological split generation.

The following code generates the chronological train/val/test splits, choosing the validation and test windows so that they contain both positive and negative samples. Windows missing one of the labels are mostly a concern for the speech/sentence tasks, but the same approach is also used for the volume and optical flow tasks.

In [1]:
%load_ext autoreload
%autoreload 2


In [None]:
from barista.data.metadata import Metadata
from collections import Counter, defaultdict
import numpy as np
import os
from pathlib import Path

In [3]:
def load_metadata(metadata_path):
    """Construct a ``Metadata`` object for the metadata file at ``metadata_path``.

    The path is handed to ``Metadata`` through its ``load_path`` keyword and the
    resulting object is returned unchanged.
    """
    metadata = Metadata(load_path=metadata_path)
    return metadata

In [4]:
def generate_folds(subject_rows_indices, per_label_subject_rows_indices,
                   bucket_size=0.05, step_size=1, base_step_size=1,
                   window=4, base_window=1, **folds_kwargs):
    """Count labels per chronological bucket and search for valid folds.

    The rows of one subject/experiment are cut into buckets of
    ``int(bucket_size * len(subject_rows_indices))`` rows, the number of
    negative (0) and positive (1) samples is counted per bucket, and the counts
    are handed to ``_find_folds``.

    Args:
        subject_rows_indices: Row indices of all segments of the subject. Only
            the first element, the last element and the length are used.
            NOTE(review): this assumes the indices are sorted and contiguous --
            confirm against ``Metadata.get_indices_matching_cols_values``.
        per_label_subject_rows_indices: Sequence indexed by label (0 and 1)
            holding the row indices of the segments carrying that label.
        bucket_size: Bucket length as a fraction of the number of rows.
        step_size: Forwarded to ``_find_folds``.
        base_step_size: Stride, in buckets, between the buckets that get
            counted. Buckets that are stepped over keep a count of zero.
        window: Forwarded to ``_find_folds``; must be divisible by 4.
        base_window: Number of bucket lengths covered by each count.
        **folds_kwargs: Extra keyword arguments forwarded to ``_find_folds``.

    Returns:
        Whatever ``_find_folds`` returns, i.e. ``(all_folds, all_folds_splits)``.

    Raises:
        ValueError: If ``window`` is not divisible by 4, ``base_step_size`` is
            smaller than 1, or there are too few rows to form a single bucket.
    """
    if window % 4 != 0:
        raise ValueError("Window should be divisible by 4")
    if base_step_size < 1:
        raise ValueError(f"base_step_size must be >= 1, got {base_step_size}")

    n_rows = len(subject_rows_indices)
    if n_rows == 0:
        raise ValueError("subject_rows_indices is empty; there is nothing to split")

    bucket_len = int(bucket_size * n_rows)  # bucket size in samples
    if bucket_len < 1:
        raise ValueError(
            f"bucket_size={bucket_size} yields empty buckets for {n_rows} rows"
        )

    ## NOTE(review): the stop value is the last index and np.arange excludes it, so
    ## when the row span is an exact multiple of bucket_len the final bucket edge is
    ## dropped and the last bucket absorbs one extra bucket_len of samples through
    ## the residual step below. Left untouched so existing folds stay reproducible.
    buckets = np.arange(subject_rows_indices[0], subject_rows_indices[-1], bucket_len)
    print(f"Buckets: {buckets}")

    n_buckets = len(buckets) - 1
    if n_buckets < 1:
        raise ValueError(f"Need at least one full bucket, got edges {buckets}")

    ## Magic number 2 everywhere corresponds to the 0/1 (negative/positive) labels.
    label_rows = [np.asarray(per_label_subject_rows_indices[label]) for label in range(2)]

    ## First, sum the unique label counts per bucket according to the specifications provided.
    ## Every bucket starts at zero so that buckets skipped by base_step_size stay usable.
    bucket_counts = [{0: 0, 1: 0} for _ in range(n_buckets)]
    counted_end = None
    for bucket_ind in range(0, n_buckets, base_step_size):
        bucket_start = buckets[bucket_ind]
        counted_end = bucket_start + base_window * bucket_len
        for label, rows in enumerate(label_rows):
            in_bucket = (rows >= bucket_start) & (rows < counted_end)
            bucket_counts[bucket_ind][label] = int(np.count_nonzero(in_bucket))

    ## Count the residual samples (everything past the last counted interval) in the last bucket.
    for label, rows in enumerate(label_rows):
        bucket_counts[-1][label] += int(np.count_nonzero(rows >= counted_end))
    print(f"bucket_counts: {bucket_counts}")

    return _find_folds(bucket_counts, step_size, window, bucket_size, **folds_kwargs)


def _find_folds(bucket_counts, step_size, window, bucket_size, num_folds=5):
    """Logic to find all legitimate folds such that train and test are separated with valid, e.g.,
    
    [train, valid, test]
    [test, valid, train]
    [train, valid (0.05), test, valid(0.05), train]
    """
    all_folds, all_folds_splits = [], []
    head, tail = 0, len(bucket_counts) - window
    use_tail, quad_window = 0, int(window / 4)
    while len(all_folds) < num_folds:
        curr_ind = tail if use_tail else head
        found = False
        while not found and curr_ind >= 0 and curr_ind <= len(bucket_counts) - window:
            ## Check that any of the validation buckets has both sets of labels.
            val_found = False
            for check_i in range(quad_window):
                val_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0
            for check_i in range(window - quad_window, window):
                val_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0

            ## Check that any of the test buckets for test data has both labels.
            test_found = False
            for check_i in range(quad_window, 3*quad_window):
                test_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0

            found = val_found & test_found
            if found:
                found_ind = curr_ind
            curr_ind += -step_size if use_tail else step_size

        val_test_interval = np.array([found_ind, found_ind + window]) * bucket_size
    
        this_fold = [bucket_size, (window-2)*bucket_size, bucket_size]
        this_fold_splits = ["val", "test", "val"]
        if 1.0 - val_test_interval[-1] > 0:
            this_fold.append(1.0 - val_test_interval[-1])
            this_fold_splits.append('train')
        if val_test_interval[0] > 0:
            this_fold = [val_test_interval[0]] + this_fold
            this_fold_splits = ['train'] + this_fold_splits

        assert np.sum(this_fold) == 1.0
        all_folds.append(this_fold)
        all_folds_splits.append(this_fold_splits)

        if use_tail:
            tail = curr_ind - 1 * step_size
        else:
            head = curr_ind + 1 * step_size
        use_tail = 1 - use_tail

    return all_folds, all_folds_splits


In [None]:
## Specify all subjects to compute the chronological folds for.
## By default we have the held out sessions (val/test) listed here.
ALL_SUBJECTS = [
    "HOLDSUBJ_1_HS1_1",
    "HOLDSUBJ_2_HS2_6",
    "HOLDSUBJ_3_HS3_0",
    "HOLDSUBJ_4_HS4_0",
    "HOLDSUBJ_6_HS6_4",
    "HOLDSUBJ_7_HS7_0",
    "HOLDSUBJ_10_HS10_0",

    # "SUBJ_2_S2_5",
    # "SUBJ_4_S4_2",
]

## List all the metadata files that correspond to the segments to preprocess. Can optionally use
## keyword identifiers for each of the metadata files that need to be processed.
_METADATA_FNAMES = {
    'default_metadata': 'metadata_ee8e0.csv',
}

## List all experiments for which the folds should be computed.
# _ALL_EXPERIMENTS = ["sentence_onset_time", "speech_vs_nonspeech_time", "volume", "optical_flow"]
_ALL_EXPERIMENTS = ["sentence_onset_time", "speech_vs_nonspeech_time"]

## Directory template holding the segments of one experiment; {0} is the experiment name.
_SEGMENT_DIR = 'braintreebank_data_segments/{0}'

## These are the recommended default settings for computing the folds.
bucket_size = 0.05 # Each bucket is 5% duration in samples
base_step_size = 1 # We take increments of base_step_size * 5% in samples when constructing buckets.
base_window = 1 # Count number of samples per base_window * 5% interval per bucket. Should match base_step_size ideally.
step_size = 2 # We take increments of step_size * bucket_size (5%) when looking for buckets.
window = 4 # Targeting 20% of data for val and test (i.e., 4 buckets combined for val and test).
num_folds = 5 # Number of folds to generate.

## {metadata_setting: {subject_session: {experiment: (all_folds, all_folds_splits)}}}
subject_folds = {}
for metadata_setting in _METADATA_FNAMES.keys():
    metadata_setting_folds = defaultdict(dict)
    metadata_fname = _METADATA_FNAMES[metadata_setting]

    ## The metadata file only depends on the experiment, so load it once per experiment
    ## instead of once per (subject, experiment) pair.
    experiment_metadata = {}

    for subject_session in ALL_SUBJECTS:
        for experiment in _ALL_EXPERIMENTS:

            if experiment not in experiment_metadata:
                fpath = _SEGMENT_DIR.format(experiment)
                experiment_metadata[experiment] = load_metadata(
                    os.path.join(fpath, metadata_fname)
                )
            metadata = experiment_metadata[experiment]

            subject_rows_indices = metadata.get_indices_matching_cols_values(
                ["subject_session", "experiment"], [subject_session, experiment]
            )

            ## One index array per label; 2 = negative/positive labels.
            per_label_subject_rows_indices = [
                metadata.get_indices_matching_cols_values(
                    ["subject_session", "experiment", "label"],
                    [subject_session, experiment, label],
                )
                for label in range(2)
            ]

            all_folds, all_folds_splits = generate_folds(
                subject_rows_indices,
                per_label_subject_rows_indices,
                bucket_size=bucket_size,
                step_size=step_size,
                base_step_size=base_step_size,
                window=window,
                base_window=base_window,
                num_folds=num_folds,
            )

            metadata_setting_folds[subject_session][experiment] = (all_folds, all_folds_splits)

    subject_folds[metadata_setting] = metadata_setting_folds

Buckets: [   0  154  308  462  616  770  924 1078 1232 1386 1540 1694 1848 2002
 2156 2310 2464 2618 2772 2926 3080]
bucket_counts: [{0: 68, 1: 86}, {0: 92, 1: 62}, {0: 123, 1: 31}, {0: 42, 1: 112}, {0: 25, 1: 129}, {0: 76, 1: 78}, {0: 65, 1: 89}, {0: 81, 1: 73}, {0: 65, 1: 89}, {0: 33, 1: 121}, {0: 23, 1: 131}, {0: 65, 1: 89}, {0: 75, 1: 79}, {0: 106, 1: 48}, {0: 51, 1: 103}, {0: 103, 1: 51}, {0: 74, 1: 80}, {0: 62, 1: 92}, {0: 154, 1: 0}, {0: 160, 1: 0}]
Buckets: [   0  165  330  495  660  825  990 1155 1320 1485 1650 1815 1980 2145
 2310 2475 2640 2805 2970 3135]
bucket_counts: [{0: 75, 1: 90}, {0: 117, 1: 48}, {0: 116, 1: 49}, {0: 35, 1: 130}, {0: 19, 1: 146}, {0: 40, 1: 125}, {0: 86, 1: 79}, {0: 48, 1: 117}, {0: 115, 1: 50}, {0: 50, 1: 115}, {0: 28, 1: 137}, {0: 26, 1: 139}, {0: 121, 1: 44}, {0: 95, 1: 70}, {0: 73, 1: 92}, {0: 83, 1: 82}, {0: 105, 1: 60}, {0: 88, 1: 77}, {0: 330, 1: 0}]
Buckets: [3086 3187 3288 3389 3490 3591 3692 3793 3894 3995 4096 4197 4298 4399
 4500 4601 4702

In [6]:
## Report the per-split label statistics of every fold and collect the folds for saving.
all_output_dicts = {}
for metadata_setting, metadata_fname in _METADATA_FNAMES.items():
    output_dict = {} # {experiment_name: {subject_session: [(ratio1, split1), (ratio2, split2), ...]}}
    for experiment in _ALL_EXPERIMENTS:
        output_dict[experiment] = {}

        ## One metadata file per experiment, shared by all subjects below.
        fpath = _SEGMENT_DIR.format(experiment)
        metadata = load_metadata(os.path.join(fpath, metadata_fname))

        for subject_session in ALL_SUBJECTS:
            print(
                f'metadata_setting:{metadata_setting}, '
                f'subject_session:{subject_session}, '
                f'experiment:{experiment}\n'
            )

            subject_rows_indices = metadata.get_indices_matching_cols_values(
                ["subject_session", "experiment"], [subject_session, experiment]
            )
            n_segments = len(subject_rows_indices)

            fold_ratios, fold_splits = subject_folds[metadata_setting][subject_session][experiment]
            out_tuples = []
            for run_ratio, run_splits in zip(fold_ratios, fold_splits):
                ## Turn ratios into segment counts; the last segment takes the rounding remainder.
                seg_counts = (np.array(run_ratio) * n_segments).astype(int)
                seg_counts[-1] = n_segments - seg_counts[:-1].sum()

                print(f"Run_ratio: {run_ratio}")

                ## Consecutive [start, end) offsets of each segment within the subject's rows.
                seg_ends = np.cumsum(seg_counts)
                seg_starts = seg_ends - seg_counts

                agg_split_counts = {name: Counter() for name in ('train', 'val', 'test')}
                for seg_start, seg_end, split in zip(seg_starts, seg_ends, run_splits):
                    seg_rows = subject_rows_indices[seg_start:seg_end]
                    seg_labels = metadata._df.iloc[seg_rows].label.to_numpy()
                    agg_split_counts[split].update(seg_labels)

                print(f'Split statistics: {agg_split_counts}')
                out_tuples.append((run_ratio, run_splits))
            print('\n')

            output_dict[experiment][subject_session] = out_tuples

    all_output_dicts[metadata_setting] = output_dict

metadata_setting:default_metadata, subject_session:HOLDSUBJ_1_HS1_1, experiment:sentence_onset_time

Run_ratio: [0.05, 0.1, 0.05, 0.8]
Split statistics: {'train': Counter({1: 1252, 0: 1218}), 'val': Counter({1: 198, 0: 110}), 'test': Counter({0: 215, 1: 93})}
Run_ratio: [0.8, 0.05, 0.1, 0.05]
Split statistics: {'train': Counter({1: 1371, 0: 1097}), 'val': Counter({0: 228, 1: 82}), 'test': Counter({0: 218, 1: 90})}
Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]
Split statistics: {'train': Counter({0: 1296, 1: 1174}), 'val': Counter({1: 202, 0: 106}), 'test': Counter({1: 167, 0: 141})}
Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]
Split statistics: {'train': Counter({1: 1262, 0: 1208}), 'val': Counter({0: 178, 1: 130}), 'test': Counter({0: 157, 1: 151})}
Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]
Split statistics: {'train': Counter({0: 1355, 1: 1115}), 'val': Counter({1: 175, 0: 133}), 'test': Counter({1: 253, 0: 55})}


metadata_setting:default_metadata, subje

In [7]:
## Write one "<metadata stem>_folds.pkl" per experiment directory, in the format
## expected in braintreebank_dataset.py.
import pickle

for fb_setting, fb_setting_output in all_output_dicts.items():
    out_fname = Path(_METADATA_FNAMES[fb_setting]).stem + "_folds.pkl"

    for experiment in fb_setting_output:
        out_path = _SEGMENT_DIR.format(experiment)
        print(out_path)
        target = os.path.join(out_path, out_fname)
        with open(target, 'wb') as file:
            pickle.dump(fb_setting_output[experiment], file)

/data/seyedesa/njepa/public_release_test/data_nov30_15_00/sentence_onset_time
/data/seyedesa/njepa/public_release_test/data_nov30_15_00/speech_vs_nonspeech_time


In [8]:
## Read every saved folds file back and print it to check the output was correct.
for fb_setting, fb_setting_output in all_output_dicts.items():
    out_fname = Path(_METADATA_FNAMES[fb_setting]).stem + "_folds.pkl"

    for experiment in fb_setting_output:
        saved_path = os.path.join(_SEGMENT_DIR.format(experiment), out_fname)
        with open(saved_path, 'rb') as file:
            datatmp = pickle.load(file)
        ## Same output as printing the name, the data and '\n' one after another.
        print(experiment, datatmp, '\n', sep='\n')

sentence_onset_time
{'HOLDSUBJ_1_HS1_1': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_2_HS2_6': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_3_HS3_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], 