### Dataset Split

**packages**

In [22]:
import errno
import json
import math
import os
import os.path as osp

import h5py
import numpy as np

#### **)) Util functions**

In [4]:
def mkdir_if_missing(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    Args:
        directory: Path of the directory to create. An empty string -- which
            is what ``osp.dirname`` returns for a bare file name such as
            ``'splits.json'`` -- means "current directory", so nothing is
            created.

    Raises:
        OSError: If the directory cannot be created for a reason other than
            it already existing.
    """
    # Nothing to do for the current directory or for a path that is already there.
    if not directory or osp.exists(directory):
        return
    # exist_ok=True tolerates another process creating the directory between
    # the existence check above and this call.
    os.makedirs(directory, exist_ok=True)
                
def write_json(obj, fpath):
    """Write ``obj`` to ``fpath`` as JSON indented by four spaces.

    The parent directory of ``fpath`` is created first via
    ``mkdir_if_missing``; the target file is opened in write mode, so any
    existing content is replaced.
    """
    parent_dir = osp.dirname(fpath)
    mkdir_if_missing(parent_dir)

    json_options = {'indent': 4, 'separators': (',', ': ')}
    with open(fpath, 'w') as out_file:
        json.dump(obj, out_file, **json_options)

def write_yaml(obj, fpath):
    """Write ``obj`` to ``fpath`` using ``yaml.dump``.

    The parent directory of ``fpath`` is created first via
    ``mkdir_if_missing``; the target file is opened in write mode, so any
    existing content is replaced.
    """
    # NOTE(review): `yaml` is not imported in the visible import cell, so it
    # must be in the notebook namespace before this function is called.
    parent_dir = osp.dirname(fpath)
    mkdir_if_missing(parent_dir)

    with open(fpath, 'w') as out_file:
        yaml.dump(obj, out_file)
        

#### **)) Train Test Split modules**

In [14]:
def train_test_split(keys, num_videos, num_train, seed=None):
    """Randomly partition ``keys`` into a train list and a test list.

    ``num_train`` distinct positions are drawn without replacement from
    ``range(num_videos)``. Keys at the drawn positions go to the train list,
    every other key goes to the test list; both lists keep the iteration
    order of ``keys``.

    Args:
        keys: Iterable of keys to split.
        num_videos: Size of the index range the train positions are drawn from.
        num_train: Number of positions to draw for the train list.
        seed: Optional seed. When ``None`` (the default) NumPy's global random
            state is used, so ``np.random.seed`` still controls the result.
            Otherwise a private ``RandomState(seed)`` is used, making the
            split reproducible without touching the global state.

    Returns:
        Tuple ``(train_keys, test_keys)`` of lists.

    Raises:
        ValueError: Propagated from NumPy when ``num_train`` positions cannot
            be drawn without replacement from ``range(num_videos)``.
        AssertionError: If a key ends up in both lists.
    """
    sampler = np.random if seed is None else np.random.RandomState(seed)
    drawn = sampler.choice(range(num_videos), size=num_train, replace=False)
    # A set gives O(1) membership tests instead of scanning the array per key.
    train_positions = set(drawn.tolist())

    train_keys, test_keys = [], []
    for position, key in enumerate(keys):
        if position in train_positions:
            train_keys.append(key)
        else:
            test_keys.append(key)

    assert len(set(train_keys) & set(test_keys)) == 0, "Error: train_keys and test_keys overlap"

    return train_keys, test_keys

In [25]:
def create_splits(train_percent=0.8, num_splits=1, dataset_path='extracted_features/normal/TVSum.h5', save_name='tvsum_splits', save_dir='splits'):
    """Generate random train/test splits of an HDF5 dataset's keys and save them as JSON.

    Args:
        train_percent: Fraction of the keys assigned to the train side of each
            split; the train count is rounded up.
        num_splits: Number of independent random splits to generate.
        dataset_path: Path of the HDF5 file whose top-level keys are split.
        save_name: Base name (without extension) of the JSON file to write.
        save_dir: Directory the JSON file is written to.

    Raises:
        ValueError: If ``train_percent`` is outside ``[0, 1]``.

    The output file holds a list with one ``{'train_keys': [...],
    'test_keys': [...]}`` dict per split.
    """
    if not 0.0 <= train_percent <= 1.0:
        raise ValueError("train_percent must be within [0, 1], got {}".format(train_percent))

    # Only the key names are needed, so copy them out and close the file
    # before doing any splitting or writing.
    with h5py.File(dataset_path, 'r') as dataset:
        keys = list(dataset.keys())

    num_videos = len(keys)
    num_train = int(math.ceil(num_videos * train_percent))
    num_test = num_videos - num_train

    print("Split breakdown: # total videos {}. # train videos {}. # test videos {}".format(num_videos, num_train, num_test))
    splits = []

    for split_idx in range(num_splits):
        train_keys, test_keys = train_test_split(keys, num_videos, num_train)
        splits.append({
            'train_keys': train_keys,
            'test_keys': test_keys,
            })
        print('Split-', split_idx+1, ' completed')

    savetojson = osp.join(save_dir, save_name + '.json')
    write_json(splits, savetojson)
    print("Splits saved to {}".format(save_dir))


#### **)) TVSum Split**

In [26]:
# Build the TVSum train/test split from the prebuilt GoogLeNet pool5 features.
create_splits(dataset_path='extracted_features/Prebuilt/eccv16_dataset_tvsum_google_pool5.h5')

Split breakdown: # total videos 50. # train videos 40. # test videos 10
Split- 0  completed
Splits saved to splits
