# Validate Results
This notebook tests whether the results of the hyperparameter tuning generalize to different train-test splits. To this end, the best hyperparameter combination is retrained and evaluated with 5-fold cross-validation (Goodfellow et al., 2016, p. 122).
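
To make the mechanics of 5-fold cross-validation concrete, the following is a minimal sketch in plain Python. It is not the implementation behind `train_cross_validation`, whose internals are not shown in this notebook; the helper name `k_fold_indices` and the shuffling strategy are assumptions made for illustration only. The sample count of 57000 matches the training set size used later in this notebook.

```python
import random


def k_fold_indices(n_samples, k=5, seed=42):
    """Yield (train_idx, val_idx) pairs for shuffled k-fold cross-validation.

    Hypothetical helper for illustration; not part of utils.train.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # The first (n_samples % k) folds receive one extra sample,
    # so fold sizes differ by at most one.
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        stop = start + base + (1 if fold < extra else 0)
        val_idx = indices[start:stop]
        train_idx = indices[:start] + indices[stop:]
        yield train_idx, val_idx
        start = stop


folds = list(k_fold_indices(n_samples=57000, k=5))
for fold, (train_idx, val_idx) in enumerate(folds, start=1):
    print(f"Fold {fold}: {len(train_idx)} training, {len(val_idx)} validation samples")
```

Each sample appears in exactly one validation fold, so every data point is used for validation once and for training four times.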
## 1. Imports

In [1]:
from sklearn.model_selection import train_test_split
import torch

from utils.train import set_global_seed, train_cross_validation

## 2. Load Training Data

In [2]:
# set a global seed
SEED = 42
set_global_seed(SEED)
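
The body of `set_global_seed` lives in `utils.train` and is not shown here. As a rough sketch of what such a helper commonly does, the function below seeds Python's `random` module and, if they are installed, NumPy and PyTorch. The name `set_global_seed_sketch` is hypothetical, and the real helper may set additional options.

```python
import random


def set_global_seed_sketch(seed: int) -> None:
    """Hypothetical stand-in for utils.train.set_global_seed (assumption)."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch

        torch.manual_seed(seed)
        # Silently ignored when CUDA is not available
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


# Re-seeding with the same value reproduces the same random draw
set_global_seed_sketch(42)
first_draw = random.random()
set_global_seed_sketch(42)
second_draw = random.random()
print(first_draw == second_draw)
```

Seeding alone may not make GPU training fully deterministic, since some CUDA operations are non-deterministic by default.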

In [3]:
%%time
# load data and labels
data, labels = torch.load("data/fashion_mnist_dataset.pt", weights_only=False)

CPU times: total: 32 s
Wall time: 35.9 s


In [4]:
%%time
# hold out 5% of the data (stratified by label) for later image ingestion
train_data, later_data, train_labels, later_labels = train_test_split(
    data, labels, test_size=0.05, stratify=labels, random_state=SEED
)

print(f"size of the training set: {len(train_data)}")

size of the training set: 57000
CPU times: total: 46.9 ms
Wall time: 48.2 ms
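
The `stratify=labels` argument keeps the class proportions the same in both parts of the split. The sketch below illustrates the idea in plain Python on placeholder labels. It is not how scikit-learn implements `train_test_split`, the helper name `stratified_holdout` is hypothetical, and the balanced toy layout (10 classes with 6000 samples each) is an assumption rather than something read from the dataset file.

```python
import random
from collections import Counter, defaultdict


def stratified_holdout(labels, holdout_fraction=0.05, seed=42):
    """Return (train_idx, holdout_idx) with matching class ratios in both parts.

    Hypothetical helper for illustration only.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train_idx, holdout_idx = [], []
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        # Take the same fraction from every class
        n_holdout = round(len(members) * holdout_fraction)
        holdout_idx.extend(members[:n_holdout])
        train_idx.extend(members[n_holdout:])
    return train_idx, holdout_idx


# Placeholder labels: 10 balanced classes (assumed layout, not loaded from disk)
toy_labels = [c for c in range(10) for _ in range(6000)]
toy_train_idx, toy_holdout_idx = stratified_holdout(toy_labels)

print(len(toy_train_idx), len(toy_holdout_idx))
print(Counter(toy_labels[i] for i in toy_holdout_idx))
```

With these placeholder labels, every class contributes the same number of samples to the held-out part, which is the property the stratified split is meant to guarantee.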


## 3. 5-Fold Cross-Validation

In [5]:
# set up device-agnostic code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [6]:
%%time
# best hyperparameter combination
config = {
    "batch_size": 32,
    "dropout": 0.5,
    "epochs": 13,
    "learning_rate": 0.001,
    "freeze": False
}

# perform cross-validation
cv_results = train_cross_validation(config, device, train_labels, train_data)

-----------------------------------
Training using config: {'batch_size': 32, 'dropout': 0.5, 'epochs': 13, 'learning_rate': 0.001, 'freeze': False}
--------------Fold 1----------------


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1: Train Loss=0.5664, Train Acc=0.7944, Train F1=0.7919,Train Precision=0.7911, Train Recall=0.7944
           Val Loss=0.3592, Val Acc=0.8714, Val F1=0.8618, Val Precision=0.8821, Val Recall=0.8714
Epoch 2: Train Loss=0.3309, Train Acc=0.8811, Train F1=0.8804,Train Precision=0.8800, Train Recall=0.8811
           Val Loss=0.2589, Val Acc=0.9045, Val F1=0.9042, Val Precision=0.9052, Val Recall=0.9045
Epoch 3: Train Loss=0.2776, Train Acc=0.9016, Train F1=0.9013,Train Precision=0.9011, Train Recall=0.9016
           Val Loss=0.2348, Val Acc=0.9146, Val F1=0.9143, Val Precision=0.9163, Val Recall=0.9146
Epoch 4: Train Loss=0.2487, Train Acc=0.9098, Train F1=0.9094,Train Precision=0.9092, Train Recall=0.9098
           Val Loss=0.2233, Val Acc=0.9216, Val F1=0.9216, Val Precision=0.9222, Val Recall=0.9216
Epoch 5: Train Loss=0.2224, Train Acc=0.9196, Train F1=0.9194,Train Precision=0.9192, Train Recall=0.9196
           Val Loss=0.2123, Val Acc=0.9275, Val F1=0.9275, Val Precision=0

  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1: Train Loss=0.5972, Train Acc=0.7807, Train F1=0.7785,Train Precision=0.7778, Train Recall=0.7807
           Val Loss=0.4109, Val Acc=0.8528, Val F1=0.8467, Val Precision=0.8682, Val Recall=0.8528
Epoch 2: Train Loss=0.3337, Train Acc=0.8818, Train F1=0.8811,Train Precision=0.8808, Train Recall=0.8818
           Val Loss=0.2781, Val Acc=0.9036, Val F1=0.9029, Val Precision=0.9059, Val Recall=0.9036
Epoch 3: Train Loss=0.2752, Train Acc=0.9018, Train F1=0.9014,Train Precision=0.9012, Train Recall=0.9018
           Val Loss=0.2335, Val Acc=0.9149, Val F1=0.9161, Val Precision=0.9186, Val Recall=0.9149
Epoch 4: Train Loss=0.2395, Train Acc=0.9138, Train F1=0.9136,Train Precision=0.9135, Train Recall=0.9138
           Val Loss=0.2286, Val Acc=0.9232, Val F1=0.9217, Val Precision=0.9245, Val Recall=0.9232
Epoch 5: Train Loss=0.2219, Train Acc=0.9216, Train F1=0.9213,Train Precision=0.9212, Train Recall=0.9216
           Val Loss=0.2101, Val Acc=0.9285, Val F1=0.9285, Val Precision=0

  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1: Train Loss=0.5822, Train Acc=0.7898, Train F1=0.7874,Train Precision=0.7863, Train Recall=0.7898
           Val Loss=0.3437, Val Acc=0.8796, Val F1=0.8775, Val Precision=0.8813, Val Recall=0.8796
Epoch 2: Train Loss=0.3242, Train Acc=0.8837, Train F1=0.8831,Train Precision=0.8828, Train Recall=0.8837
           Val Loss=0.2523, Val Acc=0.9087, Val F1=0.9083, Val Precision=0.9091, Val Recall=0.9087
Epoch 3: Train Loss=0.2738, Train Acc=0.9039, Train F1=0.9035,Train Precision=0.9034, Train Recall=0.9039
           Val Loss=0.2383, Val Acc=0.9148, Val F1=0.9149, Val Precision=0.9177, Val Recall=0.9148
Epoch 4: Train Loss=0.2384, Train Acc=0.9162, Train F1=0.9159,Train Precision=0.9157, Train Recall=0.9162
           Val Loss=0.2156, Val Acc=0.9216, Val F1=0.9215, Val Precision=0.9215, Val Recall=0.9216
Epoch 5: Train Loss=0.2186, Train Acc=0.9217, Train F1=0.9215,Train Precision=0.9214, Train Recall=0.9217
           Val Loss=0.2150, Val Acc=0.9254, Val F1=0.9253, Val Precision=0

  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1: Train Loss=0.5961, Train Acc=0.7834, Train F1=0.7800,Train Precision=0.7787, Train Recall=0.7834
           Val Loss=0.3194, Val Acc=0.8860, Val F1=0.8845, Val Precision=0.8852, Val Recall=0.8860
Epoch 2: Train Loss=0.3287, Train Acc=0.8835, Train F1=0.8829,Train Precision=0.8825, Train Recall=0.8835
           Val Loss=0.2922, Val Acc=0.8939, Val F1=0.8935, Val Precision=0.8951, Val Recall=0.8939
Epoch 3: Train Loss=0.2715, Train Acc=0.9038, Train F1=0.9034,Train Precision=0.9033, Train Recall=0.9038
           Val Loss=0.2630, Val Acc=0.9031, Val F1=0.8996, Val Precision=0.9042, Val Recall=0.9031
Epoch 4: Train Loss=0.2408, Train Acc=0.9148, Train F1=0.9144,Train Precision=0.9143, Train Recall=0.9148
           Val Loss=0.2363, Val Acc=0.9180, Val F1=0.9180, Val Precision=0.9195, Val Recall=0.9180
Epoch 5: Train Loss=0.2156, Train Acc=0.9230, Train F1=0.9227,Train Precision=0.9226, Train Recall=0.9230
           Val Loss=0.2068, Val Acc=0.9253, Val F1=0.9249, Val Precision=0

  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1: Train Loss=0.5816, Train Acc=0.7866, Train F1=0.7841,Train Precision=0.7833, Train Recall=0.7866
           Val Loss=0.3531, Val Acc=0.8654, Val F1=0.8669, Val Precision=0.8742, Val Recall=0.8654
Epoch 2: Train Loss=0.3220, Train Acc=0.8850, Train F1=0.8844,Train Precision=0.8840, Train Recall=0.8850
           Val Loss=0.2825, Val Acc=0.8977, Val F1=0.8978, Val Precision=0.9009, Val Recall=0.8977
Epoch 3: Train Loss=0.2727, Train Acc=0.9027, Train F1=0.9023,Train Precision=0.9021, Train Recall=0.9027
           Val Loss=0.2499, Val Acc=0.9123, Val F1=0.9121, Val Precision=0.9144, Val Recall=0.9123
Epoch 4: Train Loss=0.2361, Train Acc=0.9158, Train F1=0.9155,Train Precision=0.9153, Train Recall=0.9158
           Val Loss=0.2263, Val Acc=0.9200, Val F1=0.9212, Val Precision=0.9247, Val Recall=0.9200
Epoch 5: Train Loss=0.2162, Train Acc=0.9226, Train F1=0.9224,Train Precision=0.9223, Train Recall=0.9226
           Val Loss=0.1987, Val Acc=0.9275, Val F1=0.9283, Val Precision=0

In [7]:
# display cross-validation results
cv_results

| Unnamed: 0 | Metric    | Mean Value |
|------------|-----------|------------|
| 0          | accuracy  | 0.933947   |
| 1          | f1_score  | 0.934092   |
| 2          | precision | 0.935643   |
| 3          | recall    | 0.933947   |
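
The mean accuracy and mean recall reported above are identical, and the same holds for every epoch in the training log. This is what one would expect if recall is averaged over classes weighted by class support, because support-weighted recall is mathematically equal to accuracy. Whether `utils.train` actually uses weighted averaging is an assumption; the sketch below only demonstrates the identity on made-up toy labels.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of samples whose prediction matches the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def weighted_recall(y_true, y_pred):
    """Per-class recall, averaged with weights proportional to class support."""
    support = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    n = len(y_true)
    return sum((support[c] / n) * (correct[c] / support[c]) for c in support)


# Made-up toy labels for illustration, unrelated to the notebook's data
y_true = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 2, 2, 2, 2, 0, 2]

acc = accuracy(y_true, y_pred)
rec = weighted_recall(y_true, y_pred)
print(f"accuracy={acc:.4f}, weighted recall={rec:.4f}")
```

The weights `support[c] / n` cancel the per-class denominators, leaving the total number of correct predictions divided by the number of samples, which is the definition of accuracy.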


## 4. References

Goodfellow, I., Bengio, Y., & Courville, A. (2016). _Deep learning_. Adaptive Computation and Machine Learning series. The MIT Press. https://lccn.loc.gov/2016022992