# ViT Fine-Tuning on CIFAR-10

## Import Libraries

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datasets import load_dataset,load_metric

from transformers import ViTFeatureExtractor,ViTForImageClassification,TrainingArguments, Trainer

## Load Dataset (CIFAR10)

In [2]:
# Load the CIFAR-10 training split (columns: 'img', 'label').
dataset = load_dataset(path='cifar10', split='train')
dataset

Found cached dataset cifar10 (C:/Users/USER/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


Dataset({
    features: ['img', 'label'],
    num_rows: 50000
})

In [3]:
# Class name -> integer class id, handed to the model config.
label2id = dict(plane=0, car=1, bird=2, cat=3, deer=4, dog=5, frog=6, horse=7, ship=8, truck=9)
# Inverse lookup (id -> name), built from label2id so the two tables always agree.
id2label = {class_id: class_name for class_name, class_id in label2id.items()}

In [4]:
# Checkpoint identifier shared by the preprocessor below and the model loaded later on.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# Preprocessor configured from that checkpoint; `transform` calls it to turn images into tensors.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path,
)

In [5]:
def transform(example_batch):
    """Convert a batch of dataset rows into model inputs.

    Runs the images stored under ``example_batch['img']`` through the
    module-level ``feature_extractor`` (PyTorch tensors requested) and
    attaches the batch's ``'label'`` values to the result before returning it.
    """
    images = list(example_batch['img'])
    encoded = feature_extractor(images, return_tensors='pt')
    # Carry the targets along so they are still available after the transform.
    encoded['label'] = example_batch['label']
    return encoded

In [6]:
# Attach the on-the-fly preprocessing, then hold out 10% of the training images for evaluation.
# The fixed seed keeps the train/held-out partition identical across kernel restarts, so the
# reported evaluation numbers can be reproduced.
prepared_ds = dataset.with_transform(transform)
prepared_ds = prepared_ds.train_test_split(test_size=0.1, seed=42)
prepared_ds

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 45000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 5000
    })
})

## Training

In [7]:
def collate_fn(batch):
    """Merge per-example dicts into a single batch dict.

    Stacks every example's ``'pixel_values'`` tensor along a new leading
    dimension and gathers the ``'label'`` entries into one tensor that is
    returned under the key ``'labels'``.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels}

In [8]:
def compute_metrics(p):
    """Return top-1 accuracy for a set of evaluation predictions.

    Accuracy is computed directly with NumPy instead of going through
    ``datasets.load_metric``, which is deprecated, has been removed from
    recent ``datasets`` releases, and has to fetch a metric script at runtime.

    Args:
        p: Object exposing ``predictions`` (per-class scores, one row per
            example) and ``label_ids`` (the reference class ids).

    Returns:
        ``{"accuracy": float}`` -- the fraction of rows whose arg-max class
        equals the reference label.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted_ids == references))}

In [9]:
# Load the pretrained ViT checkpoint with a classification head sized to the label map.
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
# Fine-tuning configuration.
# NOTE(review): the run logged below has 11,250 optimisation steps, so a 10,000-step
# eval/save interval yields exactly one evaluation; `load_best_model_at_end` therefore
# always restores checkpoint-10000. Consider a shorter interval if more than one
# candidate checkpoint is wanted.
training_args = TrainingArguments(
    # Output location and reporting.
    output_dir="./",
    save_total_limit=2,
    push_to_hub=False,
    report_to="tensorboard",
    # Optimisation.
    num_train_epochs=2,
    per_device_train_batch_size=8,
    learning_rate=2e-4,
    fp16=True,
    # Step-based schedule: log every 100 steps, evaluate and checkpoint every 10,000.
    evaluation_strategy="steps",
    eval_steps=10000,
    save_steps=10000,
    logging_steps=100,
    load_best_model_at_end=True,
    # Presumably kept off so the raw 'img' column still reaches `transform` -- confirm.
    remove_unused_columns=False,
)

PyTorch: setting up devices


In [20]:
# Bundle the model, data splits, batching function and metric code into one Trainer.
# The preprocessor is passed as `tokenizer`.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

Using cuda_amp half precision backend


In [12]:
# Run fine-tuning; the returned object carries the run's summary metrics.
train_results = trainer.train()
# Write the model currently held by the trainer to the configured output directory.
trainer.save_model()
# Print the training metrics, then persist them and the trainer state.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

***** Running training *****
  Num examples = 45000
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 11250
  Number of trainable parameters = 85806346
  1%|          | 100/11250 [03:08<5:46:19,  1.86s/it]

{'loss': 2.191, 'learning_rate': 0.00019873777777777778, 'epoch': 0.02}


  2%|▏         | 200/11250 [06:16<5:45:58,  1.88s/it]

{'loss': 1.7366, 'learning_rate': 0.00019696, 'epoch': 0.04}


  3%|▎         | 300/11250 [09:24<5:42:42,  1.88s/it]

{'loss': 1.3208, 'learning_rate': 0.0001952, 'epoch': 0.05}


  4%|▎         | 400/11250 [12:32<5:39:51,  1.88s/it]

{'loss': 0.9913, 'learning_rate': 0.00019342222222222224, 'epoch': 0.07}


  4%|▍         | 500/11250 [15:40<5:37:00,  1.88s/it]

{'loss': 0.6846, 'learning_rate': 0.00019164444444444445, 'epoch': 0.09}


  5%|▌         | 600/11250 [18:48<5:33:57,  1.88s/it]

{'loss': 0.5575, 'learning_rate': 0.00018986666666666668, 'epoch': 0.11}


  6%|▌         | 700/11250 [21:56<5:30:55,  1.88s/it]

{'loss': 0.4374, 'learning_rate': 0.00018810666666666667, 'epoch': 0.12}


  7%|▋         | 800/11250 [25:05<5:27:55,  1.88s/it]

{'loss': 0.3414, 'learning_rate': 0.0001863288888888889, 'epoch': 0.14}


  8%|▊         | 900/11250 [28:13<5:24:27,  1.88s/it]

{'loss': 0.3199, 'learning_rate': 0.00018455111111111112, 'epoch': 0.16}


  9%|▉         | 1000/11250 [31:21<5:21:19,  1.88s/it]

{'loss': 0.2407, 'learning_rate': 0.00018277333333333335, 'epoch': 0.18}


 10%|▉         | 1100/11250 [34:29<5:18:10,  1.88s/it]

{'loss': 0.2432, 'learning_rate': 0.00018099555555555556, 'epoch': 0.2}


 11%|█         | 1200/11250 [37:37<5:14:53,  1.88s/it]

{'loss': 0.2275, 'learning_rate': 0.0001792177777777778, 'epoch': 0.21}


 12%|█▏        | 1300/11250 [40:45<5:11:51,  1.88s/it]

{'loss': 0.2158, 'learning_rate': 0.00017744, 'epoch': 0.23}


 12%|█▏        | 1400/11250 [43:53<5:08:35,  1.88s/it]

{'loss': 0.2419, 'learning_rate': 0.00017566222222222224, 'epoch': 0.25}


 13%|█▎        | 1500/11250 [47:01<5:05:06,  1.88s/it]

{'loss': 0.2794, 'learning_rate': 0.00017388444444444445, 'epoch': 0.27}


 14%|█▍        | 1600/11250 [50:09<5:02:10,  1.88s/it]

{'loss': 0.2544, 'learning_rate': 0.0001721066666666667, 'epoch': 0.28}


 15%|█▌        | 1700/11250 [53:17<4:58:46,  1.88s/it]

{'loss': 0.2293, 'learning_rate': 0.0001703288888888889, 'epoch': 0.3}


 16%|█▌        | 1800/11250 [56:24<4:55:46,  1.88s/it]

{'loss': 0.2482, 'learning_rate': 0.00016855111111111114, 'epoch': 0.32}


 17%|█▋        | 1900/11250 [59:32<4:52:33,  1.88s/it]

{'loss': 0.2676, 'learning_rate': 0.00016677333333333334, 'epoch': 0.34}


 18%|█▊        | 2000/11250 [1:02:40<4:49:47,  1.88s/it]

{'loss': 0.2406, 'learning_rate': 0.00016499555555555558, 'epoch': 0.36}


 19%|█▊        | 2100/11250 [1:05:50<4:45:31,  1.87s/it]

{'loss': 0.2633, 'learning_rate': 0.00016323555555555557, 'epoch': 0.37}


 20%|█▉        | 2200/11250 [1:08:57<4:42:06,  1.87s/it]

{'loss': 0.2199, 'learning_rate': 0.00016145777777777778, 'epoch': 0.39}


 20%|██        | 2300/11250 [1:12:04<4:38:48,  1.87s/it]

{'loss': 0.2378, 'learning_rate': 0.00015968000000000002, 'epoch': 0.41}


 21%|██▏       | 2400/11250 [1:15:11<4:35:41,  1.87s/it]

{'loss': 0.2102, 'learning_rate': 0.00015790222222222225, 'epoch': 0.43}


 22%|██▏       | 2500/11250 [1:18:18<4:32:49,  1.87s/it]

{'loss': 0.2108, 'learning_rate': 0.00015612444444444446, 'epoch': 0.44}


 23%|██▎       | 2600/11250 [1:21:25<4:29:34,  1.87s/it]

{'loss': 0.2149, 'learning_rate': 0.00015434666666666667, 'epoch': 0.46}


 24%|██▍       | 2700/11250 [1:24:32<4:26:35,  1.87s/it]

{'loss': 0.2108, 'learning_rate': 0.0001525688888888889, 'epoch': 0.48}


 25%|██▍       | 2800/11250 [1:27:39<4:23:27,  1.87s/it]

{'loss': 0.1954, 'learning_rate': 0.00015079111111111112, 'epoch': 0.5}


 26%|██▌       | 2900/11250 [1:30:46<4:20:19,  1.87s/it]

{'loss': 0.2179, 'learning_rate': 0.00014901333333333333, 'epoch': 0.52}


 27%|██▋       | 3000/11250 [1:33:53<4:17:18,  1.87s/it]

{'loss': 0.1819, 'learning_rate': 0.00014723555555555556, 'epoch': 0.53}


 28%|██▊       | 3100/11250 [1:37:00<4:14:08,  1.87s/it]

{'loss': 0.1748, 'learning_rate': 0.00014545777777777777, 'epoch': 0.55}


 28%|██▊       | 3200/11250 [1:40:07<4:10:56,  1.87s/it]

{'loss': 0.1989, 'learning_rate': 0.00014368, 'epoch': 0.57}


 29%|██▉       | 3300/11250 [1:43:14<4:07:34,  1.87s/it]

{'loss': 0.1889, 'learning_rate': 0.00014190222222222222, 'epoch': 0.59}


 30%|███       | 3400/11250 [1:46:21<4:04:32,  1.87s/it]

{'loss': 0.1996, 'learning_rate': 0.00014012444444444445, 'epoch': 0.6}


 31%|███       | 3500/11250 [1:49:28<4:01:29,  1.87s/it]

{'loss': 0.1734, 'learning_rate': 0.00013834666666666666, 'epoch': 0.62}


 32%|███▏      | 3600/11250 [1:52:35<3:58:20,  1.87s/it]

{'loss': 0.1494, 'learning_rate': 0.0001365688888888889, 'epoch': 0.64}


 33%|███▎      | 3700/11250 [1:55:42<3:55:15,  1.87s/it]

{'loss': 0.1817, 'learning_rate': 0.0001347911111111111, 'epoch': 0.66}


 34%|███▍      | 3800/11250 [1:58:49<3:52:11,  1.87s/it]

{'loss': 0.1472, 'learning_rate': 0.00013301333333333334, 'epoch': 0.68}


 35%|███▍      | 3900/11250 [2:01:56<3:48:55,  1.87s/it]

{'loss': 0.1353, 'learning_rate': 0.00013123555555555555, 'epoch': 0.69}


 36%|███▌      | 4000/11250 [2:05:03<3:45:56,  1.87s/it]

{'loss': 0.18, 'learning_rate': 0.0001294577777777778, 'epoch': 0.71}


 36%|███▋      | 4100/11250 [2:08:04<3:35:38,  1.81s/it]

{'loss': 0.23, 'learning_rate': 0.00012769777777777778, 'epoch': 0.73}


 37%|███▋      | 4200/11250 [2:11:05<3:32:33,  1.81s/it]

{'loss': 0.1658, 'learning_rate': 0.00012592000000000001, 'epoch': 0.75}


 38%|███▊      | 4300/11250 [2:14:06<3:29:37,  1.81s/it]

{'loss': 0.1675, 'learning_rate': 0.00012414222222222222, 'epoch': 0.76}


 39%|███▉      | 4400/11250 [2:17:07<3:26:29,  1.81s/it]

{'loss': 0.1729, 'learning_rate': 0.00012236444444444446, 'epoch': 0.78}


 40%|████      | 4500/11250 [2:20:08<3:23:33,  1.81s/it]

{'loss': 0.1738, 'learning_rate': 0.00012058666666666667, 'epoch': 0.8}


 41%|████      | 4600/11250 [2:23:09<3:20:27,  1.81s/it]

{'loss': 0.2052, 'learning_rate': 0.0001188088888888889, 'epoch': 0.82}


 42%|████▏     | 4700/11250 [2:26:10<3:17:28,  1.81s/it]

{'loss': 0.1892, 'learning_rate': 0.00011703111111111111, 'epoch': 0.84}


 43%|████▎     | 4800/11250 [2:29:10<3:14:31,  1.81s/it]

{'loss': 0.1597, 'learning_rate': 0.00011525333333333335, 'epoch': 0.85}


 44%|████▎     | 4900/11250 [2:32:11<3:11:26,  1.81s/it]

{'loss': 0.1797, 'learning_rate': 0.00011347555555555556, 'epoch': 0.87}


 44%|████▍     | 5000/11250 [2:35:12<3:08:25,  1.81s/it]

{'loss': 0.1738, 'learning_rate': 0.0001116977777777778, 'epoch': 0.89}


 45%|████▌     | 5100/11250 [2:38:13<3:05:31,  1.81s/it]

{'loss': 0.1624, 'learning_rate': 0.00010993777777777779, 'epoch': 0.91}


 46%|████▌     | 5200/11250 [2:41:14<3:02:25,  1.81s/it]

{'loss': 0.1803, 'learning_rate': 0.00010816, 'epoch': 0.92}


 47%|████▋     | 5300/11250 [2:44:15<2:59:18,  1.81s/it]

{'loss': 0.1679, 'learning_rate': 0.00010638222222222223, 'epoch': 0.94}


 48%|████▊     | 5400/11250 [2:47:16<2:56:18,  1.81s/it]

{'loss': 0.1374, 'learning_rate': 0.00010460444444444447, 'epoch': 0.96}


 49%|████▉     | 5500/11250 [2:50:17<2:53:19,  1.81s/it]

{'loss': 0.1559, 'learning_rate': 0.00010282666666666668, 'epoch': 0.98}


 50%|████▉     | 5600/11250 [2:53:18<2:50:21,  1.81s/it]

{'loss': 0.1805, 'learning_rate': 0.0001010488888888889, 'epoch': 1.0}


 51%|█████     | 5700/11250 [2:56:19<2:47:14,  1.81s/it]

{'loss': 0.1612, 'learning_rate': 9.927111111111112e-05, 'epoch': 1.01}


 52%|█████▏    | 5800/11250 [2:59:19<2:44:18,  1.81s/it]

{'loss': 0.1792, 'learning_rate': 9.749333333333333e-05, 'epoch': 1.03}


 52%|█████▏    | 5900/11250 [3:02:20<2:41:14,  1.81s/it]

{'loss': 0.1727, 'learning_rate': 9.571555555555555e-05, 'epoch': 1.05}


 53%|█████▎    | 6000/11250 [3:05:21<2:38:13,  1.81s/it]

{'loss': 0.163, 'learning_rate': 9.393777777777778e-05, 'epoch': 1.07}


 54%|█████▍    | 6100/11250 [3:08:22<2:35:12,  1.81s/it]

{'loss': 0.1727, 'learning_rate': 9.216e-05, 'epoch': 1.08}


 55%|█████▌    | 6200/11250 [3:11:23<2:32:17,  1.81s/it]

{'loss': 0.1708, 'learning_rate': 9.038222222222222e-05, 'epoch': 1.1}


 56%|█████▌    | 6300/11250 [3:14:24<2:29:12,  1.81s/it]

{'loss': 0.1469, 'learning_rate': 8.860444444444444e-05, 'epoch': 1.12}


 57%|█████▋    | 6400/11250 [3:17:25<2:26:09,  1.81s/it]

{'loss': 0.186, 'learning_rate': 8.682666666666667e-05, 'epoch': 1.14}


 58%|█████▊    | 6500/11250 [3:20:26<2:23:12,  1.81s/it]

{'loss': 0.1515, 'learning_rate': 8.504888888888889e-05, 'epoch': 1.16}


 59%|█████▊    | 6600/11250 [3:23:27<2:20:14,  1.81s/it]

{'loss': 0.1674, 'learning_rate': 8.327111111111111e-05, 'epoch': 1.17}


 60%|█████▉    | 6700/11250 [3:26:28<2:17:15,  1.81s/it]

{'loss': 0.182, 'learning_rate': 8.149333333333333e-05, 'epoch': 1.19}


 60%|██████    | 6800/11250 [3:29:28<2:14:10,  1.81s/it]

{'loss': 0.1947, 'learning_rate': 7.971555555555556e-05, 'epoch': 1.21}


 61%|██████▏   | 6900/11250 [3:32:29<2:11:06,  1.81s/it]

{'loss': 0.1817, 'learning_rate': 7.793777777777778e-05, 'epoch': 1.23}


 62%|██████▏   | 7000/11250 [3:35:30<2:08:08,  1.81s/it]

{'loss': 0.1969, 'learning_rate': 7.616e-05, 'epoch': 1.24}


 63%|██████▎   | 7100/11250 [3:38:31<2:05:08,  1.81s/it]

{'loss': 0.1851, 'learning_rate': 7.438222222222223e-05, 'epoch': 1.26}


 64%|██████▍   | 7200/11250 [3:41:32<2:02:06,  1.81s/it]

{'loss': 0.1922, 'learning_rate': 7.260444444444445e-05, 'epoch': 1.28}


 65%|██████▍   | 7300/11250 [3:44:33<1:59:08,  1.81s/it]

{'loss': 0.1542, 'learning_rate': 7.084444444444445e-05, 'epoch': 1.3}


 66%|██████▌   | 7400/11250 [3:47:34<1:56:03,  1.81s/it]

{'loss': 0.1757, 'learning_rate': 6.906666666666667e-05, 'epoch': 1.32}


 67%|██████▋   | 7500/11250 [3:50:35<1:53:01,  1.81s/it]

{'loss': 0.1475, 'learning_rate': 6.72888888888889e-05, 'epoch': 1.33}


 68%|██████▊   | 7600/11250 [3:53:36<1:50:02,  1.81s/it]

{'loss': 0.1983, 'learning_rate': 6.55111111111111e-05, 'epoch': 1.35}


 68%|██████▊   | 7700/11250 [3:56:37<1:47:02,  1.81s/it]

{'loss': 0.163, 'learning_rate': 6.373333333333333e-05, 'epoch': 1.37}


 69%|██████▉   | 7800/11250 [3:59:37<1:43:57,  1.81s/it]

{'loss': 0.181, 'learning_rate': 6.195555555555555e-05, 'epoch': 1.39}


 70%|███████   | 7900/11250 [4:02:38<1:41:02,  1.81s/it]

{'loss': 0.1835, 'learning_rate': 6.017777777777778e-05, 'epoch': 1.4}


 71%|███████   | 8000/11250 [4:05:39<1:37:59,  1.81s/it]

{'loss': 0.1677, 'learning_rate': 5.8399999999999997e-05, 'epoch': 1.42}


 72%|███████▏  | 8100/11250 [4:08:40<1:35:00,  1.81s/it]

{'loss': 0.1862, 'learning_rate': 5.662222222222222e-05, 'epoch': 1.44}


 73%|███████▎  | 8200/11250 [4:11:41<1:31:59,  1.81s/it]

{'loss': 0.1896, 'learning_rate': 5.484444444444444e-05, 'epoch': 1.46}


 74%|███████▍  | 8300/11250 [4:14:42<1:28:58,  1.81s/it]

{'loss': 0.1758, 'learning_rate': 5.3066666666666665e-05, 'epoch': 1.48}


 75%|███████▍  | 8400/11250 [4:17:43<1:25:55,  1.81s/it]

{'loss': 0.1639, 'learning_rate': 5.128888888888889e-05, 'epoch': 1.49}


 76%|███████▌  | 8500/11250 [4:20:44<1:22:54,  1.81s/it]

{'loss': 0.1991, 'learning_rate': 4.951111111111112e-05, 'epoch': 1.51}


 76%|███████▋  | 8600/11250 [4:23:45<1:19:54,  1.81s/it]

{'loss': 0.1806, 'learning_rate': 4.773333333333333e-05, 'epoch': 1.53}


 77%|███████▋  | 8700/11250 [4:26:46<1:16:53,  1.81s/it]

{'loss': 0.1889, 'learning_rate': 4.5955555555555555e-05, 'epoch': 1.55}


 78%|███████▊  | 8800/11250 [4:29:47<1:13:53,  1.81s/it]

{'loss': 0.1971, 'learning_rate': 4.417777777777778e-05, 'epoch': 1.56}


 79%|███████▉  | 8900/11250 [4:32:48<1:10:50,  1.81s/it]

{'loss': 0.2007, 'learning_rate': 4.24e-05, 'epoch': 1.58}


 80%|████████  | 9000/11250 [4:35:49<1:07:49,  1.81s/it]

{'loss': 0.1837, 'learning_rate': 4.062222222222222e-05, 'epoch': 1.6}


 81%|████████  | 9100/11250 [4:38:49<1:04:47,  1.81s/it]

{'loss': 0.1968, 'learning_rate': 3.8844444444444446e-05, 'epoch': 1.62}


 82%|████████▏ | 9200/11250 [4:41:50<1:01:47,  1.81s/it]

{'loss': 0.1665, 'learning_rate': 3.706666666666667e-05, 'epoch': 1.64}


 83%|████████▎ | 9300/11250 [4:44:51<58:50,  1.81s/it]  

{'loss': 0.2175, 'learning_rate': 3.528888888888889e-05, 'epoch': 1.65}


 84%|████████▎ | 9400/11250 [4:47:52<55:47,  1.81s/it]

{'loss': 0.2102, 'learning_rate': 3.3528888888888895e-05, 'epoch': 1.67}


 84%|████████▍ | 9500/11250 [4:50:53<52:45,  1.81s/it]

{'loss': 0.1689, 'learning_rate': 3.175111111111112e-05, 'epoch': 1.69}


 85%|████████▌ | 9600/11250 [4:53:54<49:46,  1.81s/it]

{'loss': 0.2015, 'learning_rate': 2.9973333333333337e-05, 'epoch': 1.71}


 86%|████████▌ | 9700/11250 [4:56:55<46:44,  1.81s/it]

{'loss': 0.1989, 'learning_rate': 2.819555555555556e-05, 'epoch': 1.72}


 87%|████████▋ | 9800/11250 [4:59:56<43:44,  1.81s/it]

{'loss': 0.1946, 'learning_rate': 2.641777777777778e-05, 'epoch': 1.74}


 88%|████████▊ | 9900/11250 [5:02:57<40:43,  1.81s/it]

{'loss': 0.1902, 'learning_rate': 2.464e-05, 'epoch': 1.76}


 89%|████████▉ | 10000/11250 [5:05:58<37:41,  1.81s/it]***** Running Evaluation *****
  Num examples = 5000
  Batch size = 8


{'loss': 0.1925, 'learning_rate': 2.2862222222222224e-05, 'epoch': 1.78}


                                                       
 89%|████████▉ | 10000/11250 [5:12:38<37:41,  1.81s/it]Saving model checkpoint to ./checkpoint-10000
Configuration saved in ./checkpoint-10000\config.json


{'eval_loss': 0.1873820275068283, 'eval_accuracy': 0.9448, 'eval_runtime': 400.5671, 'eval_samples_per_second': 12.482, 'eval_steps_per_second': 1.56, 'epoch': 1.78}


Model weights saved in ./checkpoint-10000\pytorch_model.bin
Image processor saved in ./checkpoint-10000\preprocessor_config.json
Deleting older checkpoint [checkpoint-300] due to args.save_total_limit
 90%|████████▉ | 10100/11250 [5:15:42<34:40,  1.81s/it]    

{'loss': 0.1996, 'learning_rate': 2.1084444444444447e-05, 'epoch': 1.8}


 91%|█████████ | 10200/11250 [5:18:43<31:39,  1.81s/it]

{'loss': 0.215, 'learning_rate': 1.9306666666666666e-05, 'epoch': 1.81}


 92%|█████████▏| 10300/11250 [5:21:44<28:38,  1.81s/it]

{'loss': 0.1783, 'learning_rate': 1.752888888888889e-05, 'epoch': 1.83}


 92%|█████████▏| 10400/11250 [5:24:45<25:37,  1.81s/it]

{'loss': 0.1935, 'learning_rate': 1.575111111111111e-05, 'epoch': 1.85}


 93%|█████████▎| 10500/11250 [5:27:46<22:37,  1.81s/it]

{'loss': 0.1977, 'learning_rate': 1.3973333333333332e-05, 'epoch': 1.87}


 94%|█████████▍| 10600/11250 [5:30:47<19:35,  1.81s/it]

{'loss': 0.1936, 'learning_rate': 1.2195555555555557e-05, 'epoch': 1.88}


 95%|█████████▌| 10700/11250 [5:33:48<16:34,  1.81s/it]

{'loss': 0.1675, 'learning_rate': 1.0417777777777778e-05, 'epoch': 1.9}


 96%|█████████▌| 10800/11250 [5:36:49<13:33,  1.81s/it]

{'loss': 0.2239, 'learning_rate': 8.64e-06, 'epoch': 1.92}


 97%|█████████▋| 10900/11250 [5:39:50<10:33,  1.81s/it]

{'loss': 0.1786, 'learning_rate': 6.862222222222223e-06, 'epoch': 1.94}


 98%|█████████▊| 11000/11250 [5:42:50<07:32,  1.81s/it]

{'loss': 0.2025, 'learning_rate': 5.084444444444445e-06, 'epoch': 1.96}


 99%|█████████▊| 11100/11250 [5:45:51<04:31,  1.81s/it]

{'loss': 0.196, 'learning_rate': 3.306666666666667e-06, 'epoch': 1.97}


100%|█████████▉| 11200/11250 [5:49:00<01:34,  1.88s/it]

{'loss': 0.1743, 'learning_rate': 1.5288888888888889e-06, 'epoch': 1.99}


100%|██████████| 11250/11250 [5:50:37<00:00,  2.06s/it]

Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from ./checkpoint-10000 (score: 0.1873820275068283).
100%|██████████| 11250/11250 [5:50:37<00:00,  1.87s/it]
Saving model checkpoint to ./
Configuration saved in ./config.json


{'train_runtime': 21037.8953, 'train_samples_per_second': 4.278, 'train_steps_per_second': 0.535, 'train_loss': 0.25218826421101886, 'epoch': 2.0}


Model weights saved in ./pytorch_model.bin
Image processor saved in ./preprocessor_config.json


***** train metrics *****
  epoch                    =        2.0
  train_loss               =     0.2522
  train_runtime            = 5:50:37.89
  train_samples_per_second =      4.278
  train_steps_per_second   =      0.535


In [13]:
# Load the CIFAR-10 test split, which is separate from the train/held-out data used during training.
dataset_test = load_dataset(path='cifar10', split='test')
dataset_test

Found cached dataset cifar10 (C:/Users/USER/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


Dataset({
    features: ['img', 'label'],
    num_rows: 10000
})

## Testing

In [21]:
# Reload the fine-tuned weights from the checkpoint written at step 10,000.
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path='./checkpoint-10000/',
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

loading configuration file ./checkpoint-10000/config.json
Model config ViTConfig {
  "_name_or_path": "google/vit-base-patch16-224-in21k",
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "plane",
    "1": "car",
    "2": "bird",
    "3": "cat",
    "4": "deer",
    "5": "dog",
    "6": "frog",
    "7": "horse",
    "8": "ship",
    "9": "truck"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "bird": 2,
    "car": 1,
    "cat": 3,
    "deer": 4,
    "dog": 5,
    "frog": 6,
    "horse": 7,
    "plane": 0,
    "ship": 8,
    "truck": 9
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "problem_type": "single_label_classification",
  "qkv_bias": true,
  "torch_

In [23]:
# Apply the same on-the-fly preprocessing to the test split as was used for the training data.
prepared_ds_test = dataset_test.with_transform(transform)

In [24]:
# Predictions
# `trainer.predict` would run on whatever model object `trainer` was built with, so the
# checkpoint reloaded into `model` would never be used. Wrap the reloaded model in its own
# Trainer so the reported test metrics belong to the checkpoint loaded in this section.
test_trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)
y_test_predict = test_trainer.predict(prepared_ds_test)

# Take a look at the predictions
y_test_predict

***** Running Prediction *****
  Num examples = 10000
  Batch size = 8
100%|██████████| 1250/1250 [14:25<00:00,  1.44it/s]


PredictionOutput(predictions=array([[ 2.65  ,  1.558 ,  1.853 , ...,  2.688 ,  1.09  ,  1.026 ],
       [ 3.559 ,  4.49  ,  1.596 , ...,  1.493 , 12.81  ,  2.729 ],
       [ 4.184 ,  3.684 ,  1.752 , ...,  1.954 , 10.734 ,  2.432 ],
       ...,
       [ 1.27  , -0.089 ,  1.393 , ...,  2.295 ,  0.6973,  1.25  ],
       [ 6.598 ,  7.43  ,  2.408 , ...,  2.807 ,  3.254 ,  3.709 ],
       [ 1.143 ,  1.122 ,  2.422 , ..., 11.66  ,  1.349 ,  1.61  ]],
      dtype=float16), label_ids=array([3, 8, 8, ..., 5, 1, 7], dtype=int64), metrics={'test_loss': 0.20327411592006683, 'test_accuracy': 0.9399, 'test_runtime': 866.4282, 'test_samples_per_second': 11.542, 'test_steps_per_second': 1.443})