In [1]:
import numpy as np
import hashlib

In [2]:
def hashsplit(X, splits, salt=1, N=5):
    """Deterministically assign the positions of ``X`` to weighted splits.

    Every item is hashed with SHA-512 over ``str(item) + str(salt)``.  The
    first ``N`` decimal-digit characters of the hex digest are read as a
    fraction ``u`` in [0, 1).  The normalised weights are laid end to end on
    [0, 1] in sorted-key order.  An item goes to the split whose interval
    ``(lo, hi]`` contains its ``u``.  ``u == 0`` is assigned to the first
    split with a positive weight, so every position of ``X`` lands in exactly
    one split.

    Parameters
    ----------
    X : iterable
        Items to split (e.g. directory paths); only ``str(item)`` is hashed.
    splits : dict
        Split name -> weight.  Weights must be non-negative with a positive
        sum.  They are normalised, so they need not add up to 1.
    salt : object, default 1
        Mixed into every hash; a different salt gives a different partition.
    N : int, default 5
        Number of hash digits used to build ``u`` (``u`` has N decimal places).

    Returns
    -------
    dict
        Split name -> ascending list of integer positions into ``X`` (the
        indices, not the items).  Keys appear in sorted order.  A split with
        zero weight maps to an empty list.

    Raises
    ------
    ValueError
        If a weight is negative or all weights are zero.
    """
    if not splits:
        return {}

    total = sum(splits.values())
    if total <= 0 or any(w < 0 for w in splits.values()):
        raise ValueError("split weights must be non-negative with a positive sum")

    names = sorted(splits)
    # Zero-weight splits occupy an empty interval.  Leaving them out of the
    # layout does not move any boundary (adding 0.0 in a cumsum is exact), and
    # it stops them from claiming u == 0.
    active = [k for k in names if splits[k] > 0]
    upper = np.cumsum([splits[k] / total for k in active])

    def unit(item):
        # Same digit extraction as the original, so existing partitions stay stable.
        digest = hashlib.sha512((str(item) + str(salt)).encode('utf-8')).hexdigest()
        digits = [c for c in digest if c.isdigit()][:N]
        return float("".join(digits)) / 10**N

    nums = np.array([unit(item) for item in X], dtype=float)

    # side='left' yields i with upper[i-1] < u <= upper[i], i.e. (lo, hi] bins.
    # Clipping keeps u == 0 in the first active split.  It also catches a final
    # cumulative edge that rounds to just below the largest u.
    which = np.clip(np.searchsorted(upper, nums, side='left'), 0, len(active) - 1)

    result = {k: [] for k in names}
    for i, k in enumerate(active):
        # tolist() gives plain Python ints rather than numpy scalars.
        result[k] = np.flatnonzero(which == i).tolist()
    return result

In [3]:
# Items to split -- could equally be directory paths or anything with a meaningful str().
somelist = list(range(20))

# Target splits and their relative weights (hashsplit normalises them).
somesplits = {'train': 0.7,
              'test': 0.2,
              'valid': 0.1}

In [4]:
# The hash is deterministic: the same salt and N always reproduce this partition.
hashsplit(X=somelist, splits=somesplits, N=5, salt=1)

{'test': [2, 4, 6, 11, 16, 17, 18],
 'train': [0, 1, 3, 5, 7, 8, 10, 12, 13, 14, 15],
 'valid': [9, 19]}