In [4]:
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

### Metrics

In [24]:
def classification_metrics(y_true, y_pred):
    """
    Compute the binary confusion matrix plus accuracy, precision, recall and F1.

    Labels are compared against 1 (positive) and 0 (negative); entries holding
    any other value are not counted in any confusion-matrix cell.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_pred : array-like
        Predicted labels, same shape as ``y_true``.

    Returns
    -------
    dict
        Keys ``"tp"``, ``"tn"``, ``"fp"``, ``"fn"`` (int counts) and
        ``"acc"``, ``"prec"``, ``"rec"``, ``"f1"`` (floats). A ratio whose
        denominator is zero is reported as 0.0 instead of NaN.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Mismatched shapes could otherwise broadcast silently and inflate counts.
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")

    # conf matrix (plain ints so the ratios below use Python division rules)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))

    def _ratio(num, den):
        # Undefined ratios (empty denominator) are reported as 0.0.
        return num / den if den else 0.0

    acc = _ratio(tp + tn, tp + tn + fp + fn)
    prec = _ratio(tp, tp + fp)
    rec = _ratio(tp, tp + fn)
    f1 = _ratio(2 * prec * rec, prec + rec)

    metrics = {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "acc": acc, "prec": prec, "rec": rec, "f1": f1,
    }
    # The original built this dict but never returned it.
    return metrics

### Test, Train, Validation

Test, training, and validation sets can be built in three ways:
- A completely random split
- A stratified split, which preserves the class proportions in every subset
- A non-random split, chosen so that the test and/or validation sets more accurately represent the data the system will see once deployed


In [25]:
# 1. randomly sampled split
def train_test_index_split_random(n, k, seed=42):
    """
    Produce k random (train_idx, test_idx) index pairs over ``range(n)``.

    The indices 0..n-1 are shuffled once with a generator seeded by ``seed``
    and cut into ``k`` nearly equal chunks. Each chunk serves as the test set
    of one pair, with the remaining chunks concatenated as its training set.

    Returns a list of ``k`` tuples ``(train_idx, test_idx)`` of index arrays.
    """
    generator = np.random.default_rng(seed)
    shuffled = np.arange(n)
    generator.shuffle(shuffled)
    chunks = np.array_split(shuffled, k)

    pairs = []
    for held_out in range(k):
        # every chunk except the held-out one goes to training
        remaining = chunks[:held_out] + chunks[held_out + 1:]
        pairs.append((np.concatenate(remaining), chunks[held_out]))
    return pairs

In [26]:
#2. stratified split
def train_test_index_split_stratified(y, k, seed=42):
    """
    Build k stratified (train_idx, test_idx) index pairs from labels ``y``.

    For every distinct label, the positions holding that label are shuffled
    (generator seeded by ``seed``) and cut into ``k`` nearly equal chunks;
    chunk ``i`` of every class goes to fold ``i``. Fold ``i`` is the test set
    of pair ``i`` and the remaining folds, concatenated, are its training set.

    Parameters
    ----------
    y : array-like
        Class labels, one per sample.
    k : int
        Number of folds; must be at least 2.
    seed : int, optional
        Seed for the random generator (default 42).

    Returns
    -------
    list of (train_idx, test_idx)
        ``k`` tuples of integer index arrays. The arrays stay integer-typed
        even when a fold is empty, so they can always be used for indexing.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 2 (no training data would remain).
    """
    if k < 2:
        raise ValueError("k must be at least 2 to form a train/test split")

    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    fold_chunks = [[] for _ in range(k)]
    # distribute classes evenly across folds
    for cls in np.unique(y):
        indexes = rng.permutation(np.where(y == cls)[0])
        for chunks, part in zip(fold_chunks, np.array_split(indexes, k)):
            chunks.append(part)

    # Prepending an empty integer array keeps the dtype integral even for a
    # fold that received no samples (a bare empty list would become float64,
    # which cannot be used as an index array).
    no_idx = np.empty(0, dtype=np.intp)
    folds = [np.concatenate([no_idx, *chunks]) for chunks in fold_chunks]

    splits = []
    for i in range(k):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        splits.append((train_idx, test_idx))
    return splits

In [27]:
import numpy as np

def cross_validate(X, y, k=5, split="stratified", seed=42, model_fn=None):
    """
    Run k-fold cross-validation over ``X`` / ``y``.

    Parameters
    ----------
    X, y : array-like
        Features and labels; both are converted with ``np.asarray``.
    k : int
        Number of folds.
    split : {"stratified", "random"}
        Which index splitter to use.
    seed : int
        Seed forwarded to the splitter.
    model_fn : callable or None
        When given, it is called as
        ``model_fn(X_train, y_train, X_test, y_test)`` for every fold and must
        return a dict-like object; the 1-based fold number is stored on it
        under ``"fold"``. When ``None``, the raw fold data is returned instead.

    Returns
    -------
    list of dict
        One entry per fold, in fold order.

    Raises
    ------
    ValueError
        If ``split`` is neither ``"stratified"`` nor ``"random"``.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    # pick the index splitter
    if split == "random":
        index_pairs = train_test_index_split_random(len(y), k, seed)
    elif split == "stratified":
        index_pairs = train_test_index_split_stratified(y, k, seed)
    else:
        raise ValueError("split type not found")

    results = []
    fold = 0
    for train_idx, test_idx in index_pairs:
        fold += 1  # folds are numbered from 1
        X_train = X[train_idx]
        X_test = X[test_idx]
        y_train = y[train_idx]
        y_test = y[test_idx]

        if model_fn is not None:
            # TODO: create model_fn for the three methods
            record = model_fn(X_train, y_train, X_test, y_test)
            record["fold"] = fold
            results.append(record)
            continue

        # no model supplied: hand back the fold's indices and data slices
        results.append({
            "fold": fold,
            "train_idx": train_idx, "test_idx": test_idx,
            "X_train": X_train, "y_train": y_train,
            "X_test": X_test, "y_test": y_test,
        })
    return results

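Sanity check: run stratified 5-fold cross-validation on the breast-cancer dataset and inspect the resulting folds.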

In [29]:
# Imported here so this cell runs on a fresh kernel: none of the import
# cells visible above bring `load_breast_cancer` into scope.
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
X, y = data.data, data.target
folds = cross_validate(X, y, k=5, split="stratified", seed=42)
folds

[{'fold': 1,
  'train_idx': array([323,   4, 237,  36,   3,  32, 446, 514, 171,   8, 168, 517, 117,
         218, 197, 182,  30, 184, 141,  82, 138, 205,   2,  35, 417, 321,
         193,  78, 105,  73,  31, 366,  12, 244, 351,   0, 282, 265,  94,
         261, 146, 186,  43, 360, 482, 511, 101, 338, 388, 188, 270,  19,
          92, 267, 442, 289, 463, 398, 555, 347, 395, 556, 332, 316, 155,
          71, 428, 224, 151, 456, 539, 243, 269, 307, 506, 491, 142, 469,
         248, 419, 383, 518, 303,  97, 149, 354, 421, 553, 558, 143, 322,
         513, 165, 295, 440, 505, 462, 409, 550, 324,  96, 120, 145, 281,
         381, 528, 110, 544, 333, 320, 341, 386, 311, 179,  46, 435, 259,
         408, 509,  13, 373, 207, 233, 180,   5, 162,  86,  64, 449, 254,
         236, 260, 302, 330, 492, 214, 230, 433, 317, 156,  56, 199, 172,
         262,  18, 219, 300, 134, 167, 257, 487, 122, 535, 213, 177, 196,
         126, 327, 548, 227, 375, 204, 112, 175, 530, 166, 240, 445, 115,
         346