In [1]:
from fastai.vision.all import *
from fastai.callback.all import EarlyStoppingCallback

In [2]:
import warnings

warnings.filterwarnings(
    "ignore",
    message="User provided device_type of 'cuda', but CUDA is not available.*"
)

In [3]:
from datasets import load_dataset
ds = load_dataset("uoft-cs/cifar10")
ds

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
})

In [4]:
from fastai.torch_core import defaults
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
defaults.device = device
device

device(type='cpu')

In [5]:
from sklearn.model_selection import StratifiedKFold

n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

In [6]:
hf_train = ds['train'].with_format("torch")
hf_valid = ds['test'].with_format("torch")

In [7]:
class Ensure3Channels(Transform):
    def encodes(self, x:TensorImage):
        if x.ndim==3 and x.shape[0]==1:
            return x.repeat(3,1,1).contiguous()
        return TensorImage(x)

In [8]:
from torch.utils.data import Dataset

class HFtoFAI(Dataset):
    def __init__(self, hf_ds): self.ds = hf_ds
    def __len__(self): return len(self.ds)
    def __getitem__(self, i):
        ex = self.ds[i]
        x = TensorImage(ex['img'])
        y = int(ex['label'])
        return x, y


In [9]:
from torch.utils.data import Subset
from fastai.vision.all import *
import torch
from fastai.torch_core import defaults


item_tfms  = [Ensure3Channels(), Resize(32)]
batch_tfms = [IntToFloatTensor(),
              *aug_transforms(size=32, max_rotate=15, max_zoom=1.1, p_affine=0.75, p_lighting=0.8),
             Normalize.from_stats(*imagenet_stats)]

valid_batch_tfms = [IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]

fold_metrics = []
learners = []

for fold, (train_idx, valid_idx) in enumerate(skf.split(hf_train['img'], hf_train['label'])):
    print(f'==== Fold {fold+1} ====')
    train_ds = Subset(hf_train, train_idx)
    valid_ds = Subset(hf_train, valid_idx)

    train_dl = TfmdDL(
        HFtoFAI(train_ds), bs=128, shuffle=True,
        after_item=item_tfms, after_batch=batch_tfms,
        num_workers=1, device='cpu',
    )

    valid_dl = TfmdDL(
        HFtoFAI(valid_ds), bs=128, shuffle=False,
        after_item=item_tfms, after_batch=valid_batch_tfms,
        num_workers=1, device='cpu',
    )

    dls = DataLoaders(train_dl, valid_dl)
    dls.c = 10
    dls.vocab = list(range(10))


    learn = vision_learner(
        dls, resnet18,
        pretrained=False,
        n_out=10,
        loss_func=CrossEntropyLossFlat(),
        metrics=[accuracy, error_rate],
        cbs=[EarlyStoppingCallback(patience=3)],
    ).to_fp16()

    learn.fit_one_cycle(30)

    
    val_res = learn.validate()
    val_loss, val_acc, val_err = val_res
    print(f'==== End Fold {fold+1} with mean acc={val_acc:.3f} ====')

    fold_metrics.append({
        'fold': fold,
        'val_loss': float(val_loss),
        'acc': float(val_acc),
        'err': float(val_err),
    })
    learners.append(learn)

print('==== Summary ====')
print('Mean accuracy:', np.mean([m['acc'] for m in fold_metrics]))
print('Std accuracy:', np.std([m['acc'] for m in fold_metrics]))


==== Fold 1 ====




epoch,train_loss,valid_loss,accuracy,error_rate,time
0,2.583992,1.907416,0.364,0.636,02:57
1,2.03453,1.619279,0.4413,0.5587,02:56
2,1.611417,1.387483,0.4987,0.5013,03:01
3,1.415727,1.22633,0.5671,0.4329,03:00
4,1.296909,1.266184,0.5655,0.4345,03:01
5,1.229784,1.37655,0.5442,0.4558,02:52
6,1.138074,1.016678,0.6332,0.3668,02:49
7,1.038641,1.050419,0.6509,0.3491,02:50
8,0.954033,0.914216,0.6894,0.3106,02:50
9,0.889096,0.89489,0.6884,0.3116,02:50


No improvement since epoch 20: early stopping




==== End Fold 1 with mean acc=0.823 ====
==== Fold 2 ====


epoch,train_loss,valid_loss,accuracy,error_rate,time
0,2.623492,1.974229,0.3552,0.6448,02:49
1,2.046145,1.661489,0.4329,0.5671,02:50
2,1.648183,1.352481,0.5124,0.4876,02:48
3,1.449096,1.309228,0.5286,0.4714,02:50
4,1.330417,1.376283,0.5273,0.4727,02:50
5,1.176519,1.245238,0.581,0.419,02:49
6,1.082883,1.012022,0.6566,0.3434,02:47
7,1.017276,0.902915,0.6891,0.3109,02:47
8,0.927123,0.896815,0.6969,0.3031,02:47
9,0.86478,0.809049,0.7205,0.2795,02:47


==== End Fold 2 with mean acc=0.830 ====
==== Fold 3 ====


epoch,train_loss,valid_loss,accuracy,error_rate,time
0,2.61613,1.998519,0.3459,0.6541,03:11
1,2.097606,1.79901,0.4181,0.5819,03:26
2,1.655926,1.713225,0.4538,0.5462,03:20
3,1.438125,1.354157,0.5421,0.4579,03:13
4,1.336398,1.273668,0.5474,0.4526,03:17
5,1.214622,1.210189,0.5789,0.4211,03:19
6,1.103137,1.008435,0.6467,0.3533,03:16
7,0.988787,0.94884,0.6701,0.3299,03:20
8,0.892649,0.938599,0.6699,0.3301,03:26
9,0.853736,0.816944,0.7095,0.2905,03:22


No improvement since epoch 24: early stopping


==== End Fold 3 with mean acc=0.821 ====
==== Fold 4 ====


epoch,train_loss,valid_loss,accuracy,error_rate,time
0,2.642653,2.036526,0.3518,0.6482,03:24
1,2.051587,1.609221,0.4446,0.5554,03:09
2,1.634056,1.508921,0.4964,0.5036,02:48
3,1.438296,1.496986,0.5224,0.4776,02:48
4,1.294888,4.002858,0.3971,0.6029,02:49
5,1.151338,1.255735,0.5749,0.4251,02:51
6,1.088853,0.977914,0.6642,0.3358,02:54
7,0.988658,0.987338,0.6607,0.3393,02:55
8,0.893804,0.883236,0.6879,0.3121,02:54
9,0.830674,0.853159,0.7131,0.2869,02:55


No improvement since epoch 25: early stopping


==== End Fold 4 with mean acc=0.821 ====
==== Fold 5 ====


epoch,train_loss,valid_loss,accuracy,error_rate,time
0,2.554208,1.941119,0.3694,0.6306,03:00
1,2.046623,1.643749,0.4391,0.5609,02:51
2,1.65373,1.496313,0.4889,0.5111,02:49
3,1.450456,1.419786,0.5196,0.4804,02:47
4,1.297464,1.188523,0.5913,0.4087,02:47
5,1.169388,1.200063,0.5978,0.4022,02:46
6,1.080171,1.011576,0.6469,0.3531,02:46
7,1.004274,0.969997,0.6693,0.3307,02:47
8,0.884769,0.930501,0.6846,0.3154,02:50
9,0.856195,0.805574,0.7212,0.2788,02:49


No improvement since epoch 23: early stopping


==== End Fold 5 with mean acc=0.828 ====
==== Summary ====
Mean accuracy: 0.8246399998664856
Std accuracy: 0.0037382230264735366


In [10]:
learn = learners[np.argmax([ _['acc'] for _ in fold_metrics]).item()]
learn

<fastai.learner.Learner at 0x7f09269d47d0>

In [11]:
valid_ds = HFtoFAI(hf_valid)
valid_dl = TfmdDL(valid_ds, bs=1, shuffle=False,
                  after_item=item_tfms, after_batch=valid_batch_tfms,
                  num_workers=1, device=device,)
val_los, accuracy, error_rate = learn.validate(dl=valid_dl)
print(f"Accuracy: {accuracy:.3f}")

Accuracy: 0.821
