# This notebook trains a model with cross-validation on the entire test dataset, using the true labels as the training labels (instead of the consensus labels), which makes this the "perfect" model
- Gets pred_probs on holdout and saves results as numpy files
- Make sure you run ``preprocess_data`` and ``create_labels_df`` on local and push/pull newest ``cifar10_test_consensus_dataset`` first

In [1]:
# %load_ext autoreload
# %autoreload 2

import sys 

sys.path.insert(0, "../")

from autogluon.vision import ImagePredictor, ImageDataset
import numpy as np
import pandas as pd
import pickle
import datetime
from pathlib import Path
import cleanlab
from utils.cross_validation_autogluon import cross_val_predict_autogluon_image_dataset

# Remove pandas' display limits so DataFrames are shown without truncating
# rows, columns, or cell contents.
for _display_option in ('display.max_rows', 'display.max_columns', 'display.max_colwidth'):
    pd.set_option(_display_option, None)

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [2]:
# Load the CIFAR-10 test-set table. Note that the file read here is
# cifar10_test_dataset.csv; the code below relies on its `label` and `image` columns.
data_filepath = './data/benchmark_data/cifar10_test_dataset.csv'
df = pd.read_csv(data_filepath)

# Create a mini dataset for quick testing: the first `num_from_each_group` rows of each label.
# NOTE(review): the cross-validation cell passes `df`, not `mini_df`, as its dataset.
num_from_each_group = 15
mini_df = df.groupby("label").head(num_from_each_group)
# Displayed as the cell output: count of non-null `image` entries per label in `mini_df`.
mini_df.groupby("label")["image"].count().reset_index()

Unnamed: 0,label,image
0,0,15
1,1,15
2,2,15
3,3,15
4,4,15
5,5,15
6,6,15
7,7,15
8,8,15
9,9,15


**Model and data saving params**

In [3]:
# Save/load location: everything for a model goes under f"{model_folder}_{model}/"
model_folder = './data/cifar10_truelabels'

# Models to generate cross-validated predicted probabilities for
models = ["swin_base_patch4_window7_224"]

# Cross-validation parameters
num_cv_folds = 5  # number K in stratified K-folds cross-validation (passed as n_splits)
verbose = 1  # truthy -> print each pickle file name while the per-fold results are loaded

# Model parameters shared by every model in `models`
epochs = 100
holdout_frac = 0.2
time_limit = 6 * 60 * 60  # == 21600; presumably seconds (6 hours) — TODO confirm against the helper
random_state = 123

## Run cross validation on `models`

In [4]:
%%time
# run cross-validation for each model
for model in models:
    
    print("----")
    print(f"Running cross-validation for model: {model}")

    MODEL_PARAMS = {
        "model": model,
        "epochs": epochs,
        "holdout_frac": holdout_frac,
    }

    # results of cross-validation will be saved to pickle files for each model/fold
    _ = \
        cross_val_predict_autogluon_image_dataset(
            dataset=df,
            out_folder=f"{model_folder}_{model}/", # save results of cross-validation in pickle files for each fold
            n_splits=num_cv_folds,
            model_params=MODEL_PARAMS,
            time_limit=time_limit,
            random_state=random_state,
        )

modified configs(<old> != <new>): {
root.img_cls.model   resnet101 != swin_base_patch4_window7_224
root.train.early_stop_patience -1 != 10
root.train.early_stop_baseline 0.0 != -inf
root.train.early_stop_max_value 1.0 != inf
root.train.epochs    200 != 100
root.train.batch_size 32 != 16
root.misc.seed       42 != 331
root.misc.num_workers 4 != 64
}
Saved config to /datasets/ulyana/multiannotator_benchmarks/09373de0/.trial_0/config.yaml


----
Running cross-validation for model: swin_base_patch4_window7_224
----
Running Cross-Validation on Split: 0


Model swin_base_patch4_window7_224 created, param count:                                         86753474
AMP not enabled. Training in float32.
Disable EMA as it is not supported for now.
Start training from [Epoch 0]
Epoch[0] Batch [49]	Speed: 70.610260 samples/sec	accuracy=0.145000	lr=0.000100
Epoch[0] Batch [99]	Speed: 103.269302 samples/sec	accuracy=0.180625	lr=0.000100
Epoch[0] Batch [149]	Speed: 103.583590 samples/sec	accuracy=0.235417	lr=0.000100
Epoch[0] Batch [199]	Speed: 103.142438 samples/sec	accuracy=0.301563	lr=0.000100
Epoch[0] Batch [249]	Speed: 102.781737 samples/sec	accuracy=0.361000	lr=0.000100
Epoch[0] Batch [299]	Speed: 102.498664 samples/sec	accuracy=0.410625	lr=0.000100
Epoch[0] Batch [349]	Speed: 102.620982 samples/sec	accuracy=0.455357	lr=0.000100
Epoch[0] Batch [399]	Speed: 102.558473 samples/sec	accuracy=0.490938	lr=0.000100
Epoch[0] Batch [449]	Speed: 102.394425 samples/sec	accuracy=0.524444	lr=0.000100
[Epoch 0] training: accuracy=0.524444
[Epoch 0] speed: 9

[Epoch 8] speed: 100 samples/sec	time cost: 71.628242
[Epoch 8] validation: top1=0.931250 top5=0.996250
Epoch[9] Batch [49]	Speed: 97.604310 samples/sec	accuracy=0.837500	lr=0.010000
Epoch[9] Batch [99]	Speed: 101.038402 samples/sec	accuracy=0.836250	lr=0.010000
Epoch[9] Batch [149]	Speed: 101.202269 samples/sec	accuracy=0.840000	lr=0.010000
Epoch[9] Batch [199]	Speed: 100.852692 samples/sec	accuracy=0.832812	lr=0.010000
Epoch[9] Batch [249]	Speed: 100.828572 samples/sec	accuracy=0.834500	lr=0.010000
Epoch[9] Batch [299]	Speed: 100.762317 samples/sec	accuracy=0.835833	lr=0.010000
Epoch[9] Batch [349]	Speed: 101.006904 samples/sec	accuracy=0.832321	lr=0.010000
Epoch[9] Batch [399]	Speed: 101.050747 samples/sec	accuracy=0.833750	lr=0.010000
Epoch[9] Batch [449]	Speed: 100.985984 samples/sec	accuracy=0.835417	lr=0.010000
[Epoch 9] training: accuracy=0.835417
[Epoch 9] speed: 100 samples/sec	time cost: 71.592460
[Epoch 9] validation: top1=0.936250 top5=0.996250
Epoch[10] Batch [49]	Speed: 

Folder ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/ already exists!
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_pred_probs_split_0
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_pred_features_split_0
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_labels_split_0
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_image_files_split_0
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_indices_split_0


modified configs(<old> != <new>): {
root.img_cls.model   resnet101 != swin_base_patch4_window7_224
root.train.early_stop_patience -1 != 10
root.train.early_stop_baseline 0.0 != -inf
root.train.early_stop_max_value 1.0 != inf
root.train.epochs    200 != 100
root.train.batch_size 32 != 16
root.misc.seed       42 != 468
root.misc.num_workers 4 != 64
}
Saved config to /datasets/ulyana/multiannotator_benchmarks/e5f44713/.trial_0/config.yaml


----
Running Cross-Validation on Split: 1


Model swin_base_patch4_window7_224 created, param count:                                         86753474
AMP not enabled. Training in float32.
Disable EMA as it is not supported for now.
Start training from [Epoch 0]
Epoch[0] Batch [49]	Speed: 72.825903 samples/sec	accuracy=0.126250	lr=0.000100
Epoch[0] Batch [99]	Speed: 101.462994 samples/sec	accuracy=0.201875	lr=0.000100
Epoch[0] Batch [149]	Speed: 101.430170 samples/sec	accuracy=0.294583	lr=0.000100
Epoch[0] Batch [199]	Speed: 101.270042 samples/sec	accuracy=0.362812	lr=0.000100
Epoch[0] Batch [249]	Speed: 101.161101 samples/sec	accuracy=0.416500	lr=0.000100
Epoch[0] Batch [299]	Speed: 101.064580 samples/sec	accuracy=0.461458	lr=0.000100
Epoch[0] Batch [349]	Speed: 101.052430 samples/sec	accuracy=0.497321	lr=0.000100
Epoch[0] Batch [399]	Speed: 101.050762 samples/sec	accuracy=0.527500	lr=0.000100
Epoch[0] Batch [449]	Speed: 101.013205 samples/sec	accuracy=0.555833	lr=0.000100
[Epoch 0] training: accuracy=0.555833
[Epoch 0] speed: 9

[Epoch 8] speed: 100 samples/sec	time cost: 71.636907
[Epoch 8] validation: top1=0.916250 top5=0.991250
Epoch[9] Batch [49]	Speed: 96.963370 samples/sec	accuracy=0.803750	lr=0.010000
Epoch[9] Batch [99]	Speed: 100.956447 samples/sec	accuracy=0.816875	lr=0.010000
Epoch[9] Batch [149]	Speed: 101.111806 samples/sec	accuracy=0.824583	lr=0.010000
Epoch[9] Batch [199]	Speed: 100.886991 samples/sec	accuracy=0.824063	lr=0.010000
Epoch[9] Batch [249]	Speed: 100.929478 samples/sec	accuracy=0.784000	lr=0.010000
Epoch[9] Batch [299]	Speed: 100.893746 samples/sec	accuracy=0.683542	lr=0.010000
Epoch[9] Batch [349]	Speed: 100.887346 samples/sec	accuracy=0.620714	lr=0.010000
Epoch[9] Batch [399]	Speed: 101.032558 samples/sec	accuracy=0.577812	lr=0.010000
Epoch[9] Batch [449]	Speed: 100.957769 samples/sec	accuracy=0.544722	lr=0.010000
[Epoch 9] training: accuracy=0.544722
[Epoch 9] speed: 100 samples/sec	time cost: 71.652579
[Epoch 9] validation: top1=0.413750 top5=0.866250
Epoch[10] Batch [49]	Speed: 

Folder ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/ already exists!
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_pred_probs_split_1
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_pred_features_split_1
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_labels_split_1
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_image_files_split_1
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_indices_split_1


modified configs(<old> != <new>): {
root.img_cls.model   resnet101 != swin_base_patch4_window7_224
root.train.early_stop_patience -1 != 10
root.train.early_stop_baseline 0.0 != -inf
root.train.early_stop_max_value 1.0 != inf
root.train.epochs    200 != 100
root.train.batch_size 32 != 16
root.misc.seed       42 != 525
root.misc.num_workers 4 != 64
}
Saved config to /datasets/ulyana/multiannotator_benchmarks/01ea71e9/.trial_0/config.yaml


----
Running Cross-Validation on Split: 2


Model swin_base_patch4_window7_224 created, param count:                                         86753474
AMP not enabled. Training in float32.
Disable EMA as it is not supported for now.
Start training from [Epoch 0]
Epoch[0] Batch [49]	Speed: 72.772137 samples/sec	accuracy=0.130000	lr=0.000100
Epoch[0] Batch [99]	Speed: 101.578924 samples/sec	accuracy=0.210625	lr=0.000100
Epoch[0] Batch [149]	Speed: 101.384438 samples/sec	accuracy=0.283750	lr=0.000100
Epoch[0] Batch [199]	Speed: 101.152156 samples/sec	accuracy=0.345313	lr=0.000100
Epoch[0] Batch [249]	Speed: 101.005901 samples/sec	accuracy=0.396500	lr=0.000100
Epoch[0] Batch [299]	Speed: 100.903653 samples/sec	accuracy=0.450417	lr=0.000100
Epoch[0] Batch [349]	Speed: 101.004332 samples/sec	accuracy=0.496964	lr=0.000100
Epoch[0] Batch [399]	Speed: 101.157670 samples/sec	accuracy=0.532031	lr=0.000100
Epoch[0] Batch [449]	Speed: 101.045044 samples/sec	accuracy=0.563194	lr=0.000100
[Epoch 0] training: accuracy=0.563194
[Epoch 0] speed: 9

Epoch[8] Batch [449]	Speed: 101.004730 samples/sec	accuracy=0.839028	lr=0.010000
[Epoch 8] training: accuracy=0.839028
[Epoch 8] speed: 100 samples/sec	time cost: 71.503580
[Epoch 8] validation: top1=0.941250 top5=0.998750
Epoch[9] Batch [49]	Speed: 97.271429 samples/sec	accuracy=0.857500	lr=0.010000
Epoch[9] Batch [99]	Speed: 101.028041 samples/sec	accuracy=0.846250	lr=0.010000
Epoch[9] Batch [149]	Speed: 101.192332 samples/sec	accuracy=0.844167	lr=0.010000
Epoch[9] Batch [199]	Speed: 100.834581 samples/sec	accuracy=0.844375	lr=0.010000
Epoch[9] Batch [249]	Speed: 100.919667 samples/sec	accuracy=0.845250	lr=0.010000
Epoch[9] Batch [299]	Speed: 100.881652 samples/sec	accuracy=0.846875	lr=0.010000
Epoch[9] Batch [349]	Speed: 100.925243 samples/sec	accuracy=0.845536	lr=0.010000
Epoch[9] Batch [399]	Speed: 100.997817 samples/sec	accuracy=0.844531	lr=0.010000
Epoch[9] Batch [449]	Speed: 101.003180 samples/sec	accuracy=0.844861	lr=0.010000
[Epoch 9] training: accuracy=0.844861
[Epoch 9] spe

[Epoch 17] validation: top1=0.951250 top5=0.997500
Epoch[18] Batch [49]	Speed: 98.043480 samples/sec	accuracy=0.873750	lr=0.010000
Epoch[18] Batch [99]	Speed: 101.015540 samples/sec	accuracy=0.881250	lr=0.010000
Epoch[18] Batch [149]	Speed: 101.196455 samples/sec	accuracy=0.888750	lr=0.010000
Epoch[18] Batch [199]	Speed: 100.913166 samples/sec	accuracy=0.887500	lr=0.010000
Epoch[18] Batch [249]	Speed: 100.960578 samples/sec	accuracy=0.887000	lr=0.010000
Epoch[18] Batch [299]	Speed: 100.914908 samples/sec	accuracy=0.886042	lr=0.010000
Epoch[18] Batch [349]	Speed: 101.073953 samples/sec	accuracy=0.884286	lr=0.010000
Epoch[18] Batch [399]	Speed: 101.035068 samples/sec	accuracy=0.884844	lr=0.010000
Epoch[18] Batch [449]	Speed: 101.046611 samples/sec	accuracy=0.885556	lr=0.010000
[Epoch 18] training: accuracy=0.885556
[Epoch 18] speed: 100 samples/sec	time cost: 71.521473
[Epoch 18] validation: top1=0.938750 top5=1.000000
Epoch[19] Batch [49]	Speed: 97.738561 samples/sec	accuracy=0.885000	l

Folder ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/ already exists!
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/_test_pred_probs_split_2
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/_test_pred_features_split_2
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/_test_labels_split_2
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/_test_image_files_split_2
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/_test_indices_split_2


modified configs(<old> != <new>): {
root.img_cls.model   resnet101 != swin_base_patch4_window7_224
root.train.early_stop_patience -1 != 10
root.train.early_stop_baseline 0.0 != -inf
root.train.early_stop_max_value 1.0 != inf
root.train.epochs    200 != 100
root.train.batch_size 32 != 16
root.misc.seed       42 != 554
root.misc.num_workers 4 != 64
}
Saved config to /datasets/ulyana/multiannotator_benchmarks/dfc70db1/.trial_0/config.yaml


----
Running Cross-Validation on Split: 3


Model swin_base_patch4_window7_224 created, param count:                                         86753474
AMP not enabled. Training in float32.
Disable EMA as it is not supported for now.
Start training from [Epoch 0]
Epoch[0] Batch [49]	Speed: 72.362471 samples/sec	accuracy=0.123750	lr=0.000100
Epoch[0] Batch [99]	Speed: 101.640667 samples/sec	accuracy=0.214375	lr=0.000100
Epoch[0] Batch [149]	Speed: 101.345929 samples/sec	accuracy=0.300000	lr=0.000100
Epoch[0] Batch [199]	Speed: 101.024439 samples/sec	accuracy=0.375937	lr=0.000100
Epoch[0] Batch [249]	Speed: 100.925829 samples/sec	accuracy=0.438000	lr=0.000100
Epoch[0] Batch [299]	Speed: 100.782697 samples/sec	accuracy=0.486667	lr=0.000100
Epoch[0] Batch [349]	Speed: 100.882432 samples/sec	accuracy=0.523214	lr=0.000100
Epoch[0] Batch [399]	Speed: 101.006898 samples/sec	accuracy=0.560000	lr=0.000100
Epoch[0] Batch [449]	Speed: 100.917552 samples/sec	accuracy=0.589028	lr=0.000100
[Epoch 0] training: accuracy=0.589028
[Epoch 0] speed: 9

Epoch[9] Batch [99]	Speed: 100.998887 samples/sec	accuracy=0.858125	lr=0.010000
Epoch[9] Batch [149]	Speed: 101.104981 samples/sec	accuracy=0.860000	lr=0.010000
Epoch[9] Batch [199]	Speed: 100.798752 samples/sec	accuracy=0.852187	lr=0.010000
Epoch[9] Batch [249]	Speed: 100.864661 samples/sec	accuracy=0.850750	lr=0.010000
Epoch[9] Batch [299]	Speed: 100.842651 samples/sec	accuracy=0.852292	lr=0.010000
Epoch[9] Batch [349]	Speed: 100.860929 samples/sec	accuracy=0.851786	lr=0.010000
Epoch[9] Batch [399]	Speed: 101.029163 samples/sec	accuracy=0.847812	lr=0.010000
Epoch[9] Batch [449]	Speed: 100.970394 samples/sec	accuracy=0.841667	lr=0.010000
[Epoch 9] training: accuracy=0.841667
[Epoch 9] speed: 100 samples/sec	time cost: 71.623932
[Epoch 9] validation: top1=0.942500 top5=1.000000
Epoch[10] Batch [49]	Speed: 97.135191 samples/sec	accuracy=0.870000	lr=0.010000
Epoch[10] Batch [99]	Speed: 101.077373 samples/sec	accuracy=0.865000	lr=0.010000
Epoch[10] Batch [149]	Speed: 101.200969 samples/se

Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_3/_test_pred_probs_split_3
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_3/_test_pred_features_split_3
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_3/_test_labels_split_3
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_3/_test_image_files_split_3
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_3/_test_indices_split_3


modified configs(<old> != <new>): {
root.img_cls.model   resnet101 != swin_base_patch4_window7_224
root.train.early_stop_patience -1 != 10
root.train.early_stop_baseline 0.0 != -inf
root.train.early_stop_max_value 1.0 != inf
root.train.epochs    200 != 100
root.train.batch_size 32 != 16
root.misc.seed       42 != 479
root.misc.num_workers 4 != 64
}
Saved config to /datasets/ulyana/multiannotator_benchmarks/c2314db8/.trial_0/config.yaml


----
Running Cross-Validation on Split: 4


Model swin_base_patch4_window7_224 created, param count:                                         86753474
AMP not enabled. Training in float32.
Disable EMA as it is not supported for now.
Start training from [Epoch 0]
Epoch[0] Batch [49]	Speed: 72.559596 samples/sec	accuracy=0.130000	lr=0.000100
Epoch[0] Batch [99]	Speed: 101.671449 samples/sec	accuracy=0.193750	lr=0.000100
Epoch[0] Batch [149]	Speed: 101.485629 samples/sec	accuracy=0.285000	lr=0.000100
Epoch[0] Batch [199]	Speed: 101.226962 samples/sec	accuracy=0.366563	lr=0.000100
Epoch[0] Batch [249]	Speed: 101.073573 samples/sec	accuracy=0.426750	lr=0.000100
Epoch[0] Batch [299]	Speed: 100.993601 samples/sec	accuracy=0.475000	lr=0.000100
Epoch[0] Batch [349]	Speed: 101.069676 samples/sec	accuracy=0.517857	lr=0.000100
Epoch[0] Batch [399]	Speed: 101.076721 samples/sec	accuracy=0.545625	lr=0.000100
Epoch[0] Batch [449]	Speed: 101.086264 samples/sec	accuracy=0.572639	lr=0.000100
[Epoch 0] training: accuracy=0.572639
[Epoch 0] speed: 9

[Epoch 8] speed: 100 samples/sec	time cost: 71.589633
[Epoch 8] validation: top1=0.942500 top5=0.997500
Epoch[9] Batch [49]	Speed: 97.291750 samples/sec	accuracy=0.856250	lr=0.010000
Epoch[9] Batch [99]	Speed: 101.109530 samples/sec	accuracy=0.860625	lr=0.010000
Epoch[9] Batch [149]	Speed: 101.198673 samples/sec	accuracy=0.857500	lr=0.010000
Epoch[9] Batch [199]	Speed: 100.934807 samples/sec	accuracy=0.853750	lr=0.010000
Epoch[9] Batch [249]	Speed: 100.877409 samples/sec	accuracy=0.852250	lr=0.010000
Epoch[9] Batch [299]	Speed: 100.906532 samples/sec	accuracy=0.851667	lr=0.010000
Epoch[9] Batch [349]	Speed: 100.975444 samples/sec	accuracy=0.854286	lr=0.010000
Epoch[9] Batch [399]	Speed: 101.091609 samples/sec	accuracy=0.851562	lr=0.010000
Epoch[9] Batch [449]	Speed: 100.972099 samples/sec	accuracy=0.850694	lr=0.010000
[Epoch 9] training: accuracy=0.850694
[Epoch 9] speed: 100 samples/sec	time cost: 71.593703
[Epoch 9] validation: top1=0.952500 top5=1.000000
[Epoch 9] Current best top-1

Epoch[18] Batch [99]	Speed: 100.988920 samples/sec	accuracy=0.875000	lr=0.010000
Epoch[18] Batch [149]	Speed: 101.216391 samples/sec	accuracy=0.876667	lr=0.010000
Epoch[18] Batch [199]	Speed: 100.942862 samples/sec	accuracy=0.877188	lr=0.010000
Epoch[18] Batch [249]	Speed: 100.933574 samples/sec	accuracy=0.877500	lr=0.010000
Epoch[18] Batch [299]	Speed: 100.902697 samples/sec	accuracy=0.878542	lr=0.010000
Epoch[18] Batch [349]	Speed: 101.039528 samples/sec	accuracy=0.878750	lr=0.010000
Epoch[18] Batch [399]	Speed: 101.072568 samples/sec	accuracy=0.879531	lr=0.010000
Epoch[18] Batch [449]	Speed: 101.029449 samples/sec	accuracy=0.879306	lr=0.010000
[Epoch 18] training: accuracy=0.879306
[Epoch 18] speed: 100 samples/sec	time cost: 71.562829
[Epoch 18] validation: top1=0.932500 top5=0.996250
Epoch[19] Batch [49]	Speed: 97.809194 samples/sec	accuracy=0.900000	lr=0.010000
Epoch[19] Batch [99]	Speed: 101.056858 samples/sec	accuracy=0.891875	lr=0.010000
Epoch[19] Batch [149]	Speed: 101.328349

Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_4/_test_pred_probs_split_4
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_4/_test_pred_features_split_4
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_4/_test_labels_split_4
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_4/_test_image_files_split_4
Saving ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_4/_test_indices_split_4
CPU times: user 1h 21min 1s, sys: 19min 38s, total: 1h 40min 39s
Wall time: 1h 39min 7s


## Read per-fold pickle files from xvalidation and save data as numpy arrays

In [5]:
# Utility for reading the pickle files written during cross-validation
def load_pickle(pickle_file_name, verbose=1):
    """Read one pickle file and return the object stored in it.

    Parameters
    ----------
    pickle_file_name : path of the pickle file to open.
    verbose : when truthy, print ``Loading <pickle_file_name>`` before reading.

    Returns
    -------
    The unpickled object.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code, so only call this on files
    from a trusted source.
    """
    if verbose:
        print(f"Loading {pickle_file_name}")

    with open(pickle_file_name, mode='rb') as pickle_fh:
        return pickle.load(pickle_fh)

# Original label name -> integer label index.
# Defined before the lookup below so the cell reads top to bottom.
label_name_to_idx_map = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9,
}

# Get the original label index (aka "true label" y) from an image file path.
# The label name is taken from the path's parent directory (second-to-last
# path component) and looked up in `label_name_to_idx_map`; an unknown
# directory name raises KeyError.
# `otypes=[int]` fixes the output dtype up front, so an empty input returns an
# empty integer array instead of the ValueError that np.vectorize raises on
# size-0 inputs when it has to infer the dtype from the first element.
get_orig_label_idx_from_file_path = np.vectorize(
    lambda f: label_name_to_idx_map[Path(f).parts[-2]],
    otypes=[int],
)

#### Save pickle files per fold as single files per model

In [6]:
results_list = []

for model in models:

    # Per-split pieces, collected in split order and stitched together below
    pred_probs, labels, images = [], [], []

    for split_num in range(num_cv_folds):

        out_subfolder = f"{model_folder}_{model}/split_{split_num}/"

        def get_pickle_file_name(object_name):
            """Build the pickle path for `object_name` inside the current split's folder."""
            return f"{out_subfolder}_{object_name}_split_{split_num}"

        # NOTE: the "test_" prefix in the pickle names refers to the "test" split
        # of each cross-validation fold.
        # `indices_split` is loaded (and its file name printed when verbose) but is
        # not used further in this cell.
        pred_probs_split, labels_split, images_split, indices_split = (
            load_pickle(get_pickle_file_name(object_name), verbose=verbose)
            for object_name in (
                "test_pred_probs",
                "test_labels",
                "test_image_files",
                "test_indices",
            )
        )

        # keep this split's data so all splits can be combined afterwards
        pred_probs.append(pred_probs_split)
        labels.append(labels_split)
        images.append(images_split)

    # combine the per-split lists into single arrays (rows follow split order)
    pred_probs = np.vstack(pred_probs)
    # labels the model was trained/evaluated with; the variable and result names
    # below call these "noisy labels (s)".
    # NOTE(review): the notebook header says the true labels are used in this run — confirm.
    labels = np.hstack(labels)
    images = np.hstack(images)

    # true labels (y), derived from each image's file path
    true_labels = get_orig_label_idx_from_file_path(images)

    # write the combined arrays next to the per-split folders
    numpy_out_folder = f"{model_folder}_{model}/"

    print(f"Saving to numpy files in this folder: {numpy_out_folder}")

    for arr_name, arr in (
        ("pred_probs", pred_probs),
        ("labels", labels),
        ("images", images),
        ("true_labels", true_labels),
    ):
        np.save(numpy_out_folder + arr_name, arr)

    # accuracy checks, all based on the argmax of the predicted probabilities
    predicted_labels = pred_probs.argmax(axis=1)
    acc_labels = (predicted_labels == labels).mean()  # vs labels (s)
    acc_true_labels = (predicted_labels == true_labels).mean()  # vs true labels (y)
    acc_noisy_vs_true_labels = (labels == true_labels).mean()  # agreement of the two label sets

    print(f"Model: {model}")
    print(f"  Accuracy (argmax pred vs labels): {acc_labels}")
    print(f"  Accuracy (argmax pred vs true labels) : {acc_true_labels}")
    print(f"  Accuracy (labels vs true labels)       : {acc_noisy_vs_true_labels}")

    results = {
        "model": model,
        "Accuracy (argmax pred vs noisy labels)": acc_labels,
        "Accuracy (argmax pred vs true labels)": acc_true_labels,
        "Accuracy (noisy vs true labels)": acc_noisy_vs_true_labels,
    }

    results_list.append(results)

Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_pred_probs_split_0
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_labels_split_0
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_image_files_split_0
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_0/_test_indices_split_0
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_pred_probs_split_1
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_labels_split_1
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_image_files_split_1
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_1/_test_indices_split_1
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/_test_pred_probs_split_2
Loading ./data/cifar10_truelabels_swin_base_patch4_window7_224/split_2/_test_labels_split_2
Loading ./data/cifar10_truelabels_swin_base_patch4_windo

In [7]:
# Sanity check: reload the saved numpy files for each model and report their shapes.
for model in models:
    numpy_out_folder = f"{model_folder}_{model}/"

    # Loaded in the order pred_probs, labels, true_labels
    pred_probs, labels, true_labels = (
        np.load(f"{numpy_out_folder}{arr_name}.npy")
        for arr_name in ("pred_probs", "labels", "true_labels")
    )

    shape_report = f'{model}\n pred_probs[{pred_probs.shape}],labels[{labels.shape}], true_labels[{true_labels.shape}]\n'
    print(shape_report)

swin_base_patch4_window7_224
 pred_probs[(10000, 10)],labels[(10000,)], true_labels[(10000,)]

