In [1]:
import numpy as np
import hashlib

In [2]:
def split_by_hash(X, splits, salt=""):
    """Deterministically partition the positions of ``X`` into weighted splits.

    Each item is converted with ``str()``, hashed with SHA-512, and mapped to a
    number in [0, 1) built from the first five decimal digits of the hex digest.
    [0, 1] is cut into consecutive bins whose widths are proportional to the
    split weights (laid out in sorted key order), and each item goes to the bin
    its number falls in. An item's split depends only on its string form, the
    weights and the salt -- not on its position in ``X`` or on the other items --
    so adding items to ``X`` never moves existing items between splits.

    Parameters
    ----------
    X : iterable
        Items to split (e.g. directory paths). Items with the same ``str()``
        (such as ``1`` and ``"1"``) always land in the same split.
    splits : dict
        Split name -> non-negative weight. Weights are normalized, so they do
        not have to sum to 1.
    salt : str, optional
        Prefix mixed into every hash to obtain a different, equally
        deterministic partition. The default ``""`` reproduces the unsalted
        behaviour.

    Returns
    -------
    dict
        Split name -> increasing list of integer positions into ``X``, with keys
        in sorted order. Every position appears in exactly one list.

    Raises
    ------
    ValueError
        If a weight is negative, or the weights do not sum to a positive number
        (this includes an empty ``splits``).
    """
    if any(w < 0 for w in splits.values()):
        raise ValueError("split weights must be non-negative")
    # Sum in the dict's own (insertion) order, exactly as before, so the bin edges
    # are bit-for-bit unchanged and previously produced splits stay reproducible.
    weight_sum = np.sum(list(splits.values()))
    if not weight_sum > 0:  # also rejects NaN totals and an empty `splits`
        raise ValueError("split weights must sum to a positive value")

    # Bin edges in [0, 1]; sorted keys make the layout independent of dict order.
    keys = sorted(splits)
    edges = np.cumsum([0.0] + [splits[k] / weight_sum for k in keys])

    # One number in [0, 1) per item (five significant figures): the first five
    # decimal-digit characters of the SHA-512 hex digest of salt + str(item).
    hashes = [hashlib.sha512((salt + str(item)).encode("utf-8")).hexdigest()
              for item in X]
    hash_vals = np.array([float("".join([c for c in h if c.isdigit()][:5])) / 10**5
                          for h in hashes])

    # Bins are (lo, hi]. A value of exactly 0.0 sits on the open lower edge of the
    # first bin and would otherwise be dropped from every split, so it is given to
    # the first split that has positive width (zero-weight splits stay empty).
    is_zero = hash_vals == 0.0
    zero_assigned = False
    split_inds = {}
    for k, lo, hi in zip(keys, edges[:-1], edges[1:]):
        in_bin = (hash_vals > lo) & (hash_vals <= hi)
        if not zero_assigned and hi > lo:
            in_bin |= is_zero
            zero_assigned = True
        # .tolist() yields plain Python ints rather than numpy scalars
        split_inds[k] = np.flatnonzero(in_bin).tolist()

    return split_inds

In [5]:
# Example input to split -- in practice e.g. directory paths; each item is
# hashed through its str() form.
somelist = list(range(20))

# Target splits with their relative weights (split_by_hash normalizes them).
somesplits = dict(train=0.7, test=0.2, valid=0.1)

In [6]:
# Split the example list; the resulting dict is the cell's displayed output.
split_by_hash(X=somelist, splits=somesplits)

{'test': [5, 7, 9, 18],
 'train': [0, 1, 2, 3, 4, 6, 8, 10, 11, 12, 13, 14, 16, 17, 19],
 'valid': [15]}