# Cross-validation

In [12]:
import ipytest
import random
ipytest.autoconfig()

## Task

Given a set of instances (by their IDs), divide them into k folds to perform cross-validation.

Each fold should enumerate the instances for the train and test splits.
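
As a rough illustration of the expected structure, the hand-written example below splits six instance IDs into k=3 folds. The IDs and the name `example_folds` are made up for this sketch; the `'train'` and `'test'` keys follow the convention used in the rest of this notebook.

```python
# Hand-written illustration: six hypothetical instance IDs split into k=3 folds.
instances = [0, 1, 2, 3, 4, 5]
example_folds = [
    {'train': [2, 3, 4, 5], 'test': [0, 1]},
    {'train': [0, 1, 4, 5], 'test': [2, 3]},
    {'train': [0, 1, 2, 3], 'test': [4, 5]},
]

# Each fold lists the instances used for training and for testing.
for i, fold in enumerate(example_folds):
    print(f"fold {i}: train={fold['train']} test={fold['test']}")
```

Every instance appears in the test split of exactly one fold and in the train split of the other two.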

In [37]:
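def create_folds(instances, k=5):
    """Given a set of instances, it returns k splits of train and test."""
    # Shuffle a copy so the caller's collection is left untouched.
    shuffled = list(instances)
    random.shuffle(shuffled)
    folds = []
    for i in range(k):
        # Every k-th instance, starting at offset i, forms the test split,
        # so each instance is in exactly one test split and k-1 train splits.
        test = shuffled[i::k]
        train = [instance for j, instance in enumerate(shuffled) if j % k != i]
        folds.append({
            'train': train,
            'test': test
        })
    return folds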

### Tests

One simple test is provided, which merely checks if the required number of folds is generated and that each contains the correct number of train and test instances.

Part of the exercise is to create some more advanced tests. 

  - One test should check coverage, that is, that every instance is part of exactly one test fold and k-1 train folds.
  - Another test should check that the folds are sufficiently random, i.e., that you're not always returning the exact same partitioning of instances.
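
The two checks described above could be sketched as plain helper functions, shown here as one possible approach rather than the intended solution. The names `reference_partition`, `check_coverage`, and `check_randomness` are hypothetical and introduced only for this example; `reference_partition` is a stand-in partitioner so that the block runs on its own.

```python
import random
from collections import Counter

def reference_partition(instances, k):
    """Hypothetical helper: shuffle, then deal instances into k test splits."""
    shuffled = list(instances)
    random.shuffle(shuffled)
    return [
        {'train': [x for j, x in enumerate(shuffled) if j % k != i],
         'test': shuffled[i::k]}
        for i in range(k)
    ]

def check_coverage(folds, instances):
    """Each instance must be in exactly one test split and k-1 train splits."""
    k = len(folds)
    test_counts = Counter(x for fold in folds for x in fold['test'])
    train_counts = Counter(x for fold in folds for x in fold['train'])
    return all(test_counts[x] == 1 and train_counts[x] == k - 1
               for x in instances)

def check_randomness(make_folds, instances, k):
    """Two independent calls should yield different test splits."""
    first = [set(fold['test']) for fold in make_folds(instances, k)]
    second = [set(fold['test']) for fold in make_folds(instances, k)]
    return first != second

instances = list(range(100))
print(check_coverage(reference_partition(instances, 5), instances))
print(check_randomness(reference_partition, instances, 5))
```

Counting occurrences with `Counter` also catches an instance that is duplicated within a split. The randomness check can in principle fail by chance, but with 100 instances the probability of two identical shuffles is negligible.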

In [57]:
%%run_pytest[clean]

def test_fold_size():
    instances = list(range(100))
    folds = create_folds(instances, k=5)
    assert len(folds) == 5
    for fold in folds:
        assert len(fold['train']) == 80
        assert len(fold['test']) == 20

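def test_coverage():
    instances = list(range(100))
    folds = create_folds(instances, k=5)
    instances_set = set(instances)
    for fold in folds:
        # Train and test must not overlap and must together cover all instances.
        assert not set(fold['train']) & set(fold['test'])
        assert set(fold['train']) | set(fold['test']) == instances_set

def test_randomization():
    instances = list(range(100))
    # Two independent calls should not return the same partitioning.
    first = [set(fold['test']) for fold in create_folds(instances, k=5)]
    second = [set(fold['test']) for fold in create_folds(instances, k=5)]
    assert first != second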

...                                                                      [100%]
3 passed in 0.02s
