# Train a First Model
The aim of this notebook is to train a first model to predict the classes of the FashionMNIST dataset and to evaluate it with cross-validation. The cross-validation results determine the next step: if the mean validation accuracy across the folds is above 0.75, the model is refined further by tuning its hyperparameters.
## 1. Imports

In [1]:
from sklearn.model_selection import train_test_split

import torch
from torch import optim
from torchinfo import summary

from utils.model import get_model
from utils.train import train_cross_validation

## 2. Load Training Data

In [2]:
%%time
# load data and labels
data, labels = torch.load("data/fashion_mnist_dataset.pt", weights_only=False)

CPU times: total: 12.3 s
Wall time: 16.2 s
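The cell above expects `data/fashion_mnist_dataset.pt` to hold a single `(data, labels)` tuple. As a minimal sketch of that layout, assuming both entries are tensors (the actual dtypes and image size stored in the file are not shown here), the snippet below writes a tiny dummy dataset to a temporary file and reads it back the same way. `weights_only=False` makes `torch.load` use the full unpickler, so it should only be used on files from a trusted source.

```python
import os
import tempfile

import torch

# hypothetical stand-in for the real file: 8 grayscale 28x28 images and their class ids
dummy_data = torch.randint(0, 256, (8, 1, 28, 28), dtype=torch.uint8)
dummy_labels = torch.randint(0, 10, (8,))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "dummy_dataset.pt")
    # save both tensors together as one tuple
    torch.save((dummy_data, dummy_labels), path)
    # unpack the tuple on load, as in the notebook cell
    loaded_data, loaded_labels = torch.load(path, weights_only=False)

print(loaded_data.shape, loaded_labels.shape)  # torch.Size([8, 1, 28, 28]) torch.Size([8])
```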


In [3]:
%%time
# use a stratified 20% subset of the training data for a quick first run (stratify keeps the class balance);
# the remaining 80% (test_size=0.8) is held back for later
train_data, later_data, train_labels, later_labels = train_test_split(
    data, labels, test_size=0.8, stratify=labels, random_state=42
)

print(f"Size of the training set before performing cross validation: {len(train_data)}")

Size of the training set before performing cross validation: 12000
CPU times: total: 31.2 ms
Wall time: 55.3 ms
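`stratify=labels` makes `train_test_split` keep each class's share the same in both parts. The snippet below is a minimal pure-Python sketch of that idea, not scikit-learn's implementation: it shuffles the indices of each class separately and takes the same fraction from every class. The label vector assumes the FashionMNIST training-set layout of 60,000 images with 6,000 per class, which matches the 12,000 examples (20%) printed above; the selected indices will not match scikit-learn's for `random_state=42`.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, train_fraction, seed=42):
    """Split indices so that every class keeps (roughly) the same share in both parts."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, rest_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)  # shuffle within the class only
        n_train = round(len(indices) * train_fraction)
        train_idx.extend(indices[:n_train])
        rest_idx.extend(indices[n_train:])
    return sorted(train_idx), sorted(rest_idx)


# FashionMNIST-like label vector: 10 classes with 6,000 examples each
example_labels = [c for c in range(10) for _ in range(6000)]
train_idx, rest_idx = stratified_split(example_labels, train_fraction=0.2)

print(len(train_idx), len(rest_idx))  # 12000 48000
print(set(Counter(example_labels[i] for i in train_idx).values()))  # {1200}
```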


## 3. Model Training

In [4]:
# set up device-agnostic code: use the GPU if one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
# get model
model = get_model(device, dropout_rate=0, freeze=False)
summary(model=model,
        input_size=(64, 1, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
EfficientNet (EfficientNet)                                  [64, 1, 224, 224]    [64, 10]             --                   True
├─Sequential (features)                                      [64, 1, 224, 224]    [64, 1280, 7, 7]     --                   True
│    └─Conv2dNormActivation (0)                              [64, 1, 224, 224]    [64, 32, 112, 112]   --                   True
│    │    └─Conv2d (0)                                       [64, 1, 224, 224]    [64, 32, 112, 112]   288                  True
│    │    └─BatchNorm2d (1)                                  [64, 32, 112, 112]   [64, 32, 112, 112]   64                   True
│    │    └─SiLU (2)                                         [64, 32, 112, 112]   [64, 32, 112, 112]   --                   --
│    └─Sequential (1)                                        [64, 32, 112, 112]   [64, 16, 112

In [6]:
%%time
# define hyperparameter combination for a first training run
config = {
    "batch_size": 64,
    "dropout": 0,
    "epochs": 5,
    "learning_rate": 0.001,
    "freeze": False
}

# perform cross-validation
cv_results = train_cross_validation(config, device, train_labels, train_data)

-----------------------------------
Training using config: {'batch_size': 64, 'dropout': 0, 'epochs': 5, 'learning_rate': 0.001, 'freeze': False}
--------------Fold 1----------------


Epoch 1: Train Loss=0.9443, Train Acc=0.6465, Train F1=0.6394,Train Precision=0.6371, Train Recall=0.6465
           Val Loss=0.7341, Val Acc=0.7467, Val F1=0.7431, Val Precision=0.7532, Val Recall=0.7467
Epoch 2: Train Loss=0.5463, Train Acc=0.7984, Train F1=0.7941,Train Precision=0.7929, Train Recall=0.7984
           Val Loss=0.4594, Val Acc=0.8271, Val F1=0.8268, Val Precision=0.8359, Val Recall=0.8271
Epoch 3: Train Loss=0.4174, Train Acc=0.8484, Train F1=0.8462,Train Precision=0.8456, Train Recall=0.8484
           Val Loss=0.3636, Val Acc=0.8608, Val F1=0.8592, Val Precision=0.8647, Val Recall=0.8608
Epoch 4: Train Loss=0.3579, Train Acc=0.8690, Train F1=0.8676,Train Precision=0.8671, Train Recall=0.8690
           Val Loss=0.3983, Val Acc=0.8617, Val F1=0.8604, Val Precision=0.8681, Val Recall=0.8617
Epoch 5: Train Loss=0.3054, Train Acc=0.8882, Train F1=0.8874,Train Precision=0.8869, Train Recall=0.8882
           Val Loss=0.3958, Val Acc=0.8546, Val F1=0.8575, Val Precision=0

Epoch 1: Train Loss=0.9879, Train Acc=0.6312, Train F1=0.6244,Train Precision=0.6207, Train Recall=0.6312
           Val Loss=0.7336, Val Acc=0.7296, Val F1=0.6964, Val Precision=0.7546, Val Recall=0.7296
Epoch 2: Train Loss=0.5497, Train Acc=0.7925, Train F1=0.7875,Train Precision=0.7876, Train Recall=0.7925
           Val Loss=0.6467, Val Acc=0.7754, Val F1=0.7656, Val Precision=0.7994, Val Recall=0.7754
Epoch 3: Train Loss=0.4318, Train Acc=0.8445, Train F1=0.8422,Train Precision=0.8418, Train Recall=0.8445
           Val Loss=0.3946, Val Acc=0.8667, Val F1=0.8649, Val Precision=0.8760, Val Recall=0.8667
Epoch 4: Train Loss=0.3647, Train Acc=0.8670, Train F1=0.8656,Train Precision=0.8651, Train Recall=0.8670
           Val Loss=0.8044, Val Acc=0.7458, Val F1=0.7429, Val Precision=0.8099, Val Recall=0.7458
Epoch 5: Train Loss=0.3255, Train Acc=0.8800, Train F1=0.8793,Train Precision=0.8788, Train Recall=0.8800
           Val Loss=0.8254, Val Acc=0.7425, Val F1=0.7347, Val Precision=0

Epoch 1: Train Loss=0.9631, Train Acc=0.6297, Train F1=0.6251,Train Precision=0.6227, Train Recall=0.6297
           Val Loss=0.6190, Val Acc=0.7833, Val F1=0.7765, Val Precision=0.7859, Val Recall=0.7833
Epoch 2: Train Loss=0.5296, Train Acc=0.8077, Train F1=0.8051,Train Precision=0.8043, Train Recall=0.8077
           Val Loss=0.5112, Val Acc=0.7783, Val F1=0.7538, Val Precision=0.8071, Val Recall=0.7783
Epoch 3: Train Loss=0.4457, Train Acc=0.8396, Train F1=0.8383,Train Precision=0.8377, Train Recall=0.8396
           Val Loss=0.4156, Val Acc=0.8483, Val F1=0.8447, Val Precision=0.8535, Val Recall=0.8483
Epoch 4: Train Loss=0.3700, Train Acc=0.8640, Train F1=0.8630,Train Precision=0.8627, Train Recall=0.8640
           Val Loss=1.1064, Val Acc=0.5942, Val F1=0.5983, Val Precision=0.7036, Val Recall=0.5942
Epoch 5: Train Loss=0.3230, Train Acc=0.8835, Train F1=0.8829,Train Precision=0.8826, Train Recall=0.8835
           Val Loss=0.3099, Val Acc=0.8858, Val F1=0.8871, Val Precision=0

Epoch 1: Train Loss=0.9684, Train Acc=0.6385, Train F1=0.6316,Train Precision=0.6295, Train Recall=0.6385
           Val Loss=0.8690, Val Acc=0.7096, Val F1=0.7115, Val Precision=0.7565, Val Recall=0.7096
Epoch 2: Train Loss=0.5444, Train Acc=0.8007, Train F1=0.7961,Train Precision=0.7959, Train Recall=0.8007
           Val Loss=0.4448, Val Acc=0.8246, Val F1=0.8090, Val Precision=0.8345, Val Recall=0.8246
Epoch 3: Train Loss=0.4289, Train Acc=0.8431, Train F1=0.8406,Train Precision=0.8403, Train Recall=0.8431
           Val Loss=0.3818, Val Acc=0.8529, Val F1=0.8514, Val Precision=0.8621, Val Recall=0.8529
Epoch 4: Train Loss=0.3598, Train Acc=0.8684, Train F1=0.8671,Train Precision=0.8667, Train Recall=0.8684
           Val Loss=0.3436, Val Acc=0.8746, Val F1=0.8743, Val Precision=0.8762, Val Recall=0.8746
Epoch 5: Train Loss=0.3194, Train Acc=0.8857, Train F1=0.8850,Train Precision=0.8847, Train Recall=0.8857
           Val Loss=0.3455, Val Acc=0.8683, Val F1=0.8640, Val Precision=0

Epoch 1: Train Loss=0.9714, Train Acc=0.6388, Train F1=0.6318,Train Precision=0.6299, Train Recall=0.6388
           Val Loss=0.6093, Val Acc=0.7712, Val F1=0.7612, Val Precision=0.7837, Val Recall=0.7712
Epoch 2: Train Loss=0.5206, Train Acc=0.8093, Train F1=0.8055,Train Precision=0.8056, Train Recall=0.8093
           Val Loss=0.4282, Val Acc=0.8342, Val F1=0.8331, Val Precision=0.8522, Val Recall=0.8342
Epoch 3: Train Loss=0.4022, Train Acc=0.8554, Train F1=0.8540,Train Precision=0.8534, Train Recall=0.8554
           Val Loss=0.4139, Val Acc=0.8496, Val F1=0.8449, Val Precision=0.8597, Val Recall=0.8496
Epoch 4: Train Loss=0.3381, Train Acc=0.8790, Train F1=0.8781,Train Precision=0.8779, Train Recall=0.8790
           Val Loss=0.3676, Val Acc=0.8675, Val F1=0.8644, Val Precision=0.8726, Val Recall=0.8675
Epoch 5: Train Loss=0.3016, Train Acc=0.8897, Train F1=0.8893,Train Precision=0.8891, Train Recall=0.8897
           Val Loss=0.3522, Val Acc=0.8771, Val F1=0.8780, Val Precision=0
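`train_cross_validation` also lives in the project's utilities and its source is not shown. The log suggests five folds, each training a fresh model for `config["epochs"]` epochs, and the accuracy in the summary table equals the mean of the five final-epoch validation accuracies. A minimal sketch of such a loop, assuming stratified folds (the log does not say whether the folds are stratified) and using a hypothetical `train_and_evaluate` callback in place of the actual model training, could look like this:

```python
from collections import defaultdict
from statistics import mean


def stratified_k_fold(labels, n_splits):
    """Assign indices to folds round-robin within each class (deterministic, no shuffling)."""
    folds = [[] for _ in range(n_splits)]
    per_class_position = defaultdict(int)
    for idx, label in enumerate(labels):
        folds[per_class_position[label] % n_splits].append(idx)
        per_class_position[label] += 1
    return folds


def cross_validate(labels, n_splits, train_and_evaluate):
    """Train once per fold and average each validation metric over the folds."""
    folds = stratified_k_fold(labels, n_splits)
    fold_metrics = []
    for k, val_idx in enumerate(folds):
        # every index outside the current validation fold is used for training
        train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
        fold_metrics.append(train_and_evaluate(train_idx, val_idx))
    return {name: mean(m[name] for m in fold_metrics) for name in fold_metrics[0]}


# hypothetical callback: a real one would build a fresh model, train it and return its
# final-epoch validation metrics; this dummy only reports the validation-set share
def dummy_train_and_evaluate(train_idx, val_idx):
    return {"val_fraction": len(val_idx) / (len(train_idx) + len(val_idx))}


example_labels = [c for c in range(10) for _ in range(1200)]  # 12,000 labels, as above
print(cross_validate(example_labels, 5, dummy_train_and_evaluate))  # {'val_fraction': 0.2}
```

In practice, scikit-learn's `StratifiedKFold(n_splits=5, shuffle=True, random_state=...)` provides shuffled stratified folds instead of the deterministic round-robin assignment used here.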

In [7]:
# display the mean cross-validation results over all folds
cv_results

|   | Metric    | Mean Value |
|---|-----------|------------|
| 0 | accuracy  | 0.845667   |
| 1 | f1_score  | 0.844272   |
| 2 | precision | 0.869639   |
| 3 | recall    | 0.845667   |
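In both the epoch logs and the table, recall is identical to accuracy. This is what one would expect if the metrics are support-weighted averages over the classes (an assumption; the averaging mode used by the project's training code is not shown): weighting each class's recall TP_c / n_c by its share n_c / N and summing gives (sum over c of TP_c) / N, which is exactly the accuracy. Precision and F1 do not simplify this way, so they can differ from accuracy. The pure-Python check below computes the weighted metrics on made-up predictions.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Support-weighted precision, recall and F1, plus plain accuracy."""
    n = len(y_true)
    support = Counter(y_true)
    precision = recall = f1 = 0.0
    for c, n_c in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        predicted_c = sum(1 for p in y_pred if p == c)
        prec_c = tp / predicted_c if predicted_c else 0.0  # 0 if class c is never predicted
        rec_c = tp / n_c
        f1_c = 2 * prec_c * rec_c / (prec_c + rec_c) if prec_c + rec_c else 0.0
        weight = n_c / n  # share of class c among the true labels
        precision += weight * prec_c
        recall += weight * rec_c
        f1 += weight * f1_c
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# made-up labels for three classes
y_true = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 2, 1, 1, 0, 2, 2, 2]
print(weighted_metrics(y_true, y_pred))
```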


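With a mean cross-validation accuracy of about 0.846, clearly above the 0.75 threshold, the results look promising, so the next step is to carry out hyperparameter tuning for the EfficientNet model. The validation accuracy does vary noticeably between folds and epochs, however: in the third fold, for example, it drops to 0.5942 in epoch 4 before recovering to 0.8858 in epoch 5.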
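The decision rule from the introduction can be checked directly against the training log. The sketch below takes the final-epoch validation accuracy of each of the five folds, copied from the output above, and compares their mean with the 0.75 threshold; the standard deviation quantifies how much the folds disagree.

```python
from statistics import mean, stdev

# final-epoch validation accuracy of folds 1-5, copied from the training log above
final_val_acc = [0.8546, 0.7425, 0.8858, 0.8683, 0.8771]
THRESHOLD = 0.75  # tuning threshold stated at the top of the notebook

mean_acc = mean(final_val_acc)
print(f"mean accuracy: {mean_acc:.4f}")  # mean accuracy: 0.8457
print(f"std across folds: {stdev(final_val_acc):.4f}")
print("tune hyperparameters:", mean_acc > THRESHOLD)  # tune hyperparameters: True
```

The mean differs from the table's 0.845667 only in the sixth decimal place, because the log rounds each fold's accuracy to four decimals.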