In [None]:
# Reload edited modules automatically before each cell runs, so changes to the
# local package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('..')

from wildlife_datasets import datasets, loader, splits

# Uncomment to download the dataset into data/MacaqueFaces on first use:
# datasets.MacaqueFaces.download.get_data('data/MacaqueFaces')
dataset = datasets.MacaqueFaces('data/MacaqueFaces')

# For testing purposes only: drop the first four rows and overwrite the date of
# the next 1000 rows (presumably so the time-based splits below have a distinct
# period to work with -- TODO confirm).
# The explicit .copy() plus a single positional .iloc assignment replaces the
# chained form df['date'].iloc[:1000] = ..., which triggers
# SettingWithCopyWarning and leaves df unchanged under pandas copy-on-write.
df = dataset.df.iloc[4:].copy()
df.iloc[:1000, df.columns.get_loc('date')] = '2016-01-01'

# Shared seed passed to every splitter below.
seed = 100

In [None]:
def analyze_split(df, idx_train, idx_test):
    """Print summary statistics of a train/test split of ``df``.

    Args:
        df: DataFrame with an ``'identity'`` column.
        idx_train: Index labels (usable with ``df.loc``) of the training rows.
        idx_test: Index labels (usable with ``df.loc``) of the test rows.

    Returns:
        None. The report is printed to stdout: set sizes, number of rows of
        ``df`` assigned to neither set, identity counts (total, shared,
        train-only, test-only) and two percentages relative to the number of
        assigned rows -- the share of training rows and the share of rows
        whose identity occurs only in the test set.
    """
    df_train = df.loc[idx_train]
    df_test = df.loc[idx_test]

    ids_train = set(df_train['identity'])
    ids_test = set(df_test['identity'])
    ids_train_only = ids_train - ids_test
    ids_test_only = ids_test - ids_train

    n_train = len(idx_train)
    n = n_train + len(idx_test)
    # One vectorised membership test instead of a Python-level sum over one
    # boolean Series per test-only identity.
    n_test_only = int(df_test['identity'].isin(list(ids_test_only)).sum())

    # Guard against an empty split so the report is still printed instead of
    # raising ZeroDivisionError.
    ratio_train = n_train / n if n else 0.0
    ratio_test_only = n_test_only / n if n else 0.0

    print('Dataset size      = %d' % len(df))
    print('Train size        = %d' % len(df_train))
    print('Test size         = %d' % len(df_test))
    print('Unassigned        = %d' % (len(df)-len(df_train)-len(df_test)))
    print('')
    print('Total individuals = %d' % len(ids_train | ids_test))
    print('Joint individuals = %d' % len(ids_train & ids_test))
    print('Only in train     = %d' % len(ids_train_only))
    print('Only in test      = %d' % len(ids_test_only))
    print('')
    print('Fraction of train set = %1.2f%%' % (100*ratio_train))
    print('Fraction of test set only = %1.2f%%' % (100*ratio_test_only))

# Closed-set split

In [None]:
# Split df with ClosedSetSplit using the shared seed and print the summary.
# NOTE(review): 0.5 is presumably the target training fraction -- confirm
# against splits.ClosedSetSplit.split.
splitter = splits.ClosedSetSplit(df, seed)
idx_train, idx_test = splitter.split(0.5)
analyze_split(df, idx_train, idx_test)

# Open-set split

In [None]:
# Split df with OpenSetSplit using two positional ratios and print the summary.
# NOTE(review): 0.5 is presumably the training fraction and 0.1 the fraction
# reserved for identities that appear only in the test set -- confirm against
# splits.OpenSetSplit.split.
splitter = splits.OpenSetSplit(df, seed)
idx_train, idx_test = splitter.split(0.5, 0.1)
analyze_split(df, idx_train, idx_test)

In [None]:
# Same OpenSetSplit, but passing n_class_test=5 instead of a second ratio.
# NOTE(review): n_class_test presumably fixes the number of test-only
# identities -- confirm against splits.OpenSetSplit.split.
splitter = splits.OpenSetSplit(df, seed)
idx_train, idx_test = splitter.split(0.5, n_class_test=5)
analyze_split(df, idx_train, idx_test)

# Disjoint split

In [None]:
# Split df with DisjointSetSplit using a single ratio and print the summary.
# NOTE(review): the meaning of 0.5 is not visible here (presumably the share
# of samples assigned to the test-side identities) -- confirm against
# splits.DisjointSetSplit.split.
splitter = splits.DisjointSetSplit(df, seed)
idx_train, idx_test = splitter.split(0.5)
analyze_split(df, idx_train, idx_test)

In [None]:
# Same DisjointSetSplit, but passing n_class_test=10 instead of a ratio.
# NOTE(review): n_class_test presumably fixes the number of identities placed
# in the test set -- confirm against splits.DisjointSetSplit.split.
splitter = splits.DisjointSetSplit(df, seed)
idx_train, idx_test = splitter.split(n_class_test=10)
analyze_split(df, idx_train, idx_test)

# Time-proportion split

In [None]:
# Split df with TimeProportionSplit (split() called with its defaults) and
# print the summary. splitter, idx_train and idx_test are reused by the next
# cell.
splitter = splits.TimeProportionSplit(df, seed)
idx_train, idx_test = splitter.split()
analyze_split(df, idx_train, idx_test)

In [None]:
# Pass the split from the previous cell through splitter.resplit_random and
# print the summary of the result.
# Depends on splitter, idx_train and idx_test left by the previous cell and
# overwrites idx_train/idx_test, so re-running this cell feeds its own output
# back in.
idx_train, idx_test = splitter.resplit_random(idx_train, idx_test)
analyze_split(df, idx_train, idx_test)

# Time-cutoff split

In [None]:
# Split df with TimeCutoffSplit and print the summary.
# NOTE(review): 2015 is presumably the cutoff year separating train from test
# -- confirm against splits.TimeCutoffSplit.split.
# splitter, idx_train and idx_test are reused by the following cells.
splitter = splits.TimeCutoffSplit(df, seed)
idx_train, idx_test = splitter.split(2015)
analyze_split(df, idx_train, idx_test)

In [None]:
# Pass the split from the previous cell through splitter.resplit_random and
# print the summary of the result.
# Depends on splitter, idx_train and idx_test left by the previous cell and
# overwrites idx_train/idx_test, so re-running this cell feeds its own output
# back in.
idx_train, idx_test = splitter.resplit_random(idx_train, idx_test)
analyze_split(df, idx_train, idx_test)

In [None]:
# Summarise every (idx_train, idx_test) pair in the first element returned by
# splitter.splits_all(), using the splitter left by an earlier cell.
# NOTE(review): what index [0] selects is not visible here -- confirm against
# the splits_all implementation.
# The loop rebinds the notebook-level idx_train/idx_test on each iteration.
for (idx_train, idx_test) in splitter.splits_all()[0]:
    analyze_split(df, idx_train, idx_test)