# Vision Transformer Implementation

[ViT Docs](https://huggingface.co/docs/transformers/en/model_doc/vit)

## Abstract
While the Transformer architecture has become the de-facto standard for natural language processing tasks, its applications to computer vision remain limited. In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks.[[1]](#1)

![ViT Architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vit_architecture.jpg)

## Installation

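Activate the virtual environment and install the required libraries. The notebook also imports `torch`, `torchvision`, `pandas`, `numpy`, and `Pillow`, so those packages must be available in the same environment.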
```shell
conda activate env/cmil
conda install transformers datasets accelerate -c conda-forge
```
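
Before running the notebook, it may help to confirm that the packages it imports can actually be found. The sketch below uses only the standard library; the `REQUIRED` list and the `missing_packages` helper are illustrative names that are not part of the original notebook.

```python
import importlib.util

# Top-level modules imported later in this notebook (assumed list; adjust as needed).
REQUIRED = ["transformers", "datasets", "accelerate", "torch",
            "torchvision", "pandas", "numpy", "PIL"]

def missing_packages(names):
    """Return the names that cannot be located on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages(REQUIRED)
print("missing packages:", missing if missing else "none")
```

An empty result only means the modules are discoverable; it does not check versions.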


## Hugging Face Hub and Credentials

This notebook needs a Hugging Face API token with read/write access. Run the cell below and paste your token to access the model checkpoint.

In [1]:
from huggingface_hub import notebook_login
notebook_login()

In [55]:
import os
import torch
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

from PIL import Image
from datasets import Dataset, load_metric
from transformers import AutoImageProcessor, DefaultDataCollator, Trainer, TrainingArguments, AutoModelForImageClassification
from torchvision.transforms import RandomResizedCrop, ToTensor, Normalize, Compose

In [3]:
df = pd.read_csv('public/public.csv')
num_classes = ['non_globally_sclerotic_glomeruli', 'globally_sclerotic_glomeruli']  # class names indexed by label id (a list, not a count)
df['path'] = df.apply(lambda x: f"public/{num_classes[x['ground truth']]}/{x['name']}", axis=1)  # public/<class folder>/<file name>
df.head()

| Unnamed: 0 | name | ground truth | path |
|---|---|---|---|
| 0 | S-2006-005094_PAS_1of2_64552732435c92704a3d37c... | 0 | public/non_globally_sclerotic_glomeruli/S-2006... |
| 1 | S-2006-005094_PAS_1of2_64552732435c92704a3d37c... | 0 | public/non_globally_sclerotic_glomeruli/S-2006... |
| 2 | S-2006-005094_PAS_1of2_64552732435c92704a3d37c... | 0 | public/non_globally_sclerotic_glomeruli/S-2006... |
| 3 | S-2006-005094_PAS_1of2_64552732435c92704a3d37c... | 0 | public/non_globally_sclerotic_glomeruli/S-2006... |
| 4 | S-2006-005094_PAS_1of2_64552732435c92704a3d37d... | 0 | public/non_globally_sclerotic_glomeruli/S-2006... |
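
The `path` column is assembled from the label id and the file name. A minimal pure-Python sketch of that mapping follows; `CLASS_NAMES`, `build_path`, and the file name `example_0001.png` are hypothetical and used only for illustration.

```python
# Class folder names, indexed by the integer label in the `ground truth` column.
CLASS_NAMES = ["non_globally_sclerotic_glomeruli", "globally_sclerotic_glomeruli"]

def build_path(name, label, root="public"):
    """Join the dataset root, the class folder for `label`, and the file name."""
    return f"{root}/{CLASS_NAMES[label]}/{name}"

print(build_path("example_0001.png", 0))
# public/non_globally_sclerotic_glomeruli/example_0001.png
```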


In [4]:
dataset = Dataset.from_pandas(df)
dataset = dataset.map(lambda x: {'image': Image.open(x['path']).convert('RGB'), 'label': x['ground truth']}, remove_columns=['path', 'name', 'ground truth'])

Map:   0%|          | 0/5758 [00:00<?, ? examples/s]

In [7]:
dataset = dataset.train_test_split(test_size=0.2)
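
No `seed` is passed to the split above, so it will likely differ between runs; `train_test_split` accepts a `seed` argument if a repeatable split is wanted. The sketch below is a pure-Python illustration of an 80/20 shuffle split, not the `datasets` implementation. `split_sizes` and `shuffled_split` are hypothetical helpers, and rounding the test share up is an assumption, although it agrees with the 144 evaluation batches of 8 shown in the training output.

```python
import math
import random

def split_sizes(n_examples, test_size):
    """Return (n_train, n_test), rounding the test share up (assumed rule)."""
    n_test = math.ceil(n_examples * test_size)
    return n_examples - n_test, n_test

def shuffled_split(indices, test_size, seed):
    """Shuffle indices with a fixed seed, then cut them into train and test parts."""
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)
    n_train, _ = split_sizes(len(shuffled), test_size)
    return shuffled[:n_train], shuffled[n_train:]

n_train, n_test = split_sizes(5758, 0.2)
print(n_train, n_test)  # 4606 1152
```

Calling `shuffled_split` twice with the same seed returns the same partition, which is the property a seeded split provides.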

In [8]:
label2id, id2label = {}, {}
for i, label in enumerate(num_classes):
    label2id[label] = i
    id2label[i] = label

In [76]:
checkpoint = "google/vit-base-patch16-224-in21k"
image_size = 224
processor = AutoImageProcessor.from_pretrained(checkpoint, size=image_size)
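
The checkpoint name encodes a 16-pixel patch size and a 224-pixel input, which fixes the length of the token sequence the transformer sees. The arithmetic can be sketched as follows; `vit_sequence_length` is an illustrative helper, and the extra token is the classification token described in the ViT paper.

```python
def vit_sequence_length(image_size, patch_size, channels=3):
    """Return (num_patches, flattened_patch_dim, tokens_including_cls)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    patch_dim = patch_size * patch_size * channels
    return num_patches, patch_dim, num_patches + 1

print(vit_sequence_length(224, 16))  # (196, 768, 197)
```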

In [77]:
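_transforms = Compose([  # NOTE: defined but never applied below; `processor` already resizes, rescales, and normalizes
    ToTensor(),
    Normalize(mean=processor.image_mean, std=processor.image_std)
])

def transform(examples):
    # Runs lazily on each accessed batch and returns `pixel_values` plus the labels.
    inputs = processor([image for image in examples['image']], return_tensors="pt")
    inputs['label'] = examples['label']
    return inputs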
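
To make the normalization step concrete, here is a minimal sketch of what happens to a single 8-bit pixel value. It assumes a per-channel mean and standard deviation of 0.5; that is an assumption about this checkpoint, so check `processor.image_mean` and `processor.image_std` for the actual values. `normalize_pixel` is an illustrative helper.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Rescale an 8-bit intensity to [0, 1], then standardize it."""
    scaled = value / 255.0
    return (scaled - mean) / std

print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```

Under that assumption the inputs land in the range [-1, 1].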

In [78]:
dataset['train'][0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=788x807>,
 'label': 0}

In [79]:
ds = dataset.with_transform(transform)

In [80]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

In [82]:
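metric = load_metric("accuracy", trust_remote_code=True)  # `load_metric` is deprecated in newer `datasets` releases in favor of the `evaluate` library

def compute_metrics(p):
    # p.predictions holds the logits, one row per example; take the top class per row.
    preds = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=preds, references=p.label_ids)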
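
The metric above reduces each row of logits to its highest-scoring class and compares it with the reference label. A dependency-free sketch of the same calculation follows; the logits and labels are made-up toy values, and `argmax` and `accuracy` are illustrative helpers.

```python
def argmax(row):
    """Index of the largest value in a row (first index wins on ties)."""
    return max(range(len(row)), key=row.__getitem__)

def accuracy(logits, labels):
    """Fraction of rows whose top-scoring class matches the label."""
    preds = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)

logits = [[2.0, -1.0], [0.1, 0.3], [1.5, 1.4], [-0.2, 0.9]]  # toy values
labels = [0, 1, 1, 1]
print(accuracy(logits, labels))  # 0.75
```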

In [83]:
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(num_classes),
    id2label=id2label,
    label2id=label2id
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [84]:
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    remove_unused_columns=False,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    tokenizer=processor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics
)
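
The step counts in the training output can be checked from these settings. The rough worked example below assumes a single device, no gradient accumulation, and a 4,606/1,152 train/test split of the 5,758 mapped examples; `steps_per_epoch` is an illustrative helper.

```python
import math

def steps_per_epoch(n_examples, batch_size):
    """Number of batches needed to cover the data once (last batch may be partial)."""
    return math.ceil(n_examples / batch_size)

train_steps = steps_per_epoch(4606, 16) * 3   # 3 epochs, train batch size 16
eval_batches = steps_per_epoch(1152, 8)       # eval batch size 8
evaluations = train_steps // 10               # one evaluation every 10 steps
print(train_steps, eval_batches, evaluations)  # 864 144 86
```

These figures agree with the `0/864` and `0/144` progress totals in the output. They also show that the full evaluation set is scored 86 times, which accounts for much of the total runtime.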

In [85]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

  0%|          | 0/864 [00:00<?, ?it/s]

{'loss': 0.4518, 'grad_norm': 0.9761220812797546, 'learning_rate': 4.94212962962963e-05, 'epoch': 0.03}


  0%|          | 0/144 [00:00<?, ?it/s]

{'eval_loss': 0.3402031660079956, 'eval_accuracy': 0.8012152777777778, 'eval_runtime': 86.232, 'eval_samples_per_second': 13.359, 'eval_steps_per_second': 1.67, 'epoch': 0.03}
{'loss': 0.2895, 'grad_norm': 2.2736270427703857, 'learning_rate': 4.8842592592592595e-05, 'epoch': 0.07}


{'eval_loss': 0.2407713234424591, 'eval_accuracy': 0.8871527777777778, 'eval_runtime': 89.7032, 'eval_samples_per_second': 12.842, 'eval_steps_per_second': 1.605, 'epoch': 0.07}
{'loss': 0.1526, 'grad_norm': 0.7643176913261414, 'learning_rate': 4.8263888888888895e-05, 'epoch': 0.1}


{'eval_loss': 0.12461209297180176, 'eval_accuracy': 0.9722222222222222, 'eval_runtime': 83.0902, 'eval_samples_per_second': 13.864, 'eval_steps_per_second': 1.733, 'epoch': 0.1}
{'loss': 0.1406, 'grad_norm': 5.452765464782715, 'learning_rate': 4.768518518518519e-05, 'epoch': 0.14}


{'eval_loss': 0.1441803276538849, 'eval_accuracy': 0.9444444444444444, 'eval_runtime': 43.03, 'eval_samples_per_second': 26.772, 'eval_steps_per_second': 3.347, 'epoch': 0.14}
{'loss': 0.0901, 'grad_norm': 0.19336548447608948, 'learning_rate': 4.710648148148149e-05, 'epoch': 0.17}


{'eval_loss': 0.08282507210969925, 'eval_accuracy': 0.9817708333333334, 'eval_runtime': 44.2907, 'eval_samples_per_second': 26.01, 'eval_steps_per_second': 3.251, 'epoch': 0.17}
{'loss': 0.0958, 'grad_norm': 2.5958755016326904, 'learning_rate': 4.652777777777778e-05, 'epoch': 0.21}


{'eval_loss': 0.09175758063793182, 'eval_accuracy': 0.9774305555555556, 'eval_runtime': 43.3705, 'eval_samples_per_second': 26.562, 'eval_steps_per_second': 3.32, 'epoch': 0.21}
{'loss': 0.2788, 'grad_norm': 3.240398406982422, 'learning_rate': 4.594907407407408e-05, 'epoch': 0.24}


{'eval_loss': 0.10169175267219543, 'eval_accuracy': 0.9635416666666666, 'eval_runtime': 42.6389, 'eval_samples_per_second': 27.018, 'eval_steps_per_second': 3.377, 'epoch': 0.24}
{'loss': 0.0883, 'grad_norm': 0.16685526072978973, 'learning_rate': 4.5370370370370374e-05, 'epoch': 0.28}


{'eval_loss': 0.06528379768133163, 'eval_accuracy': 0.9817708333333334, 'eval_runtime': 41.1051, 'eval_samples_per_second': 28.026, 'eval_steps_per_second': 3.503, 'epoch': 0.28}
{'loss': 0.0991, 'grad_norm': 3.9426655769348145, 'learning_rate': 4.4791666666666673e-05, 'epoch': 0.31}


{'eval_loss': 0.08623868972063065, 'eval_accuracy': 0.9730902777777778, 'eval_runtime': 41.8329, 'eval_samples_per_second': 27.538, 'eval_steps_per_second': 3.442, 'epoch': 0.31}
{'loss': 0.1133, 'grad_norm': 0.1730802059173584, 'learning_rate': 4.4212962962962966e-05, 'epoch': 0.35}


{'eval_loss': 0.048651501536369324, 'eval_accuracy': 0.9852430555555556, 'eval_runtime': 39.9884, 'eval_samples_per_second': 28.808, 'eval_steps_per_second': 3.601, 'epoch': 0.35}
{'loss': 0.0985, 'grad_norm': 2.0847182273864746, 'learning_rate': 4.3634259259259266e-05, 'epoch': 0.38}


{'eval_loss': 0.0744868814945221, 'eval_accuracy': 0.9765625, 'eval_runtime': 43.3684, 'eval_samples_per_second': 26.563, 'eval_steps_per_second': 3.32, 'epoch': 0.38}
{'loss': 0.0651, 'grad_norm': 0.1199178695678711, 'learning_rate': 4.305555555555556e-05, 'epoch': 0.42}


{'eval_loss': 0.04724985361099243, 'eval_accuracy': 0.9878472222222222, 'eval_runtime': 42.371, 'eval_samples_per_second': 27.188, 'eval_steps_per_second': 3.399, 'epoch': 0.42}
{'loss': 0.0313, 'grad_norm': 0.192734032869339, 'learning_rate': 4.247685185185186e-05, 'epoch': 0.45}


{'eval_loss': 0.04165603220462799, 'eval_accuracy': 0.9887152777777778, 'eval_runtime': 41.0327, 'eval_samples_per_second': 28.075, 'eval_steps_per_second': 3.509, 'epoch': 0.45}
{'loss': 0.0774, 'grad_norm': 9.349303245544434, 'learning_rate': 4.1898148148148145e-05, 'epoch': 0.49}


{'eval_loss': 0.03480932116508484, 'eval_accuracy': 0.9913194444444444, 'eval_runtime': 42.5074, 'eval_samples_per_second': 27.101, 'eval_steps_per_second': 3.388, 'epoch': 0.49}
{'loss': 0.0374, 'grad_norm': 2.474944829940796, 'learning_rate': 4.1319444444444445e-05, 'epoch': 0.52}


{'eval_loss': 0.040861599147319794, 'eval_accuracy': 0.9904513888888888, 'eval_runtime': 42.5826, 'eval_samples_per_second': 27.053, 'eval_steps_per_second': 3.382, 'epoch': 0.52}
{'loss': 0.0534, 'grad_norm': 0.1629900336265564, 'learning_rate': 4.074074074074074e-05, 'epoch': 0.56}


{'eval_loss': 0.03371104598045349, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 40.9305, 'eval_samples_per_second': 28.145, 'eval_steps_per_second': 3.518, 'epoch': 0.56}
{'loss': 0.0883, 'grad_norm': 1.6173553466796875, 'learning_rate': 4.016203703703704e-05, 'epoch': 0.59}


{'eval_loss': 0.0307486392557621, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 41.5897, 'eval_samples_per_second': 27.699, 'eval_steps_per_second': 3.462, 'epoch': 0.59}
{'loss': 0.0168, 'grad_norm': 0.09957311302423477, 'learning_rate': 3.958333333333333e-05, 'epoch': 0.62}


{'eval_loss': 0.037585336714982986, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 41.265, 'eval_samples_per_second': 27.917, 'eval_steps_per_second': 3.49, 'epoch': 0.62}
{'loss': 0.0117, 'grad_norm': 0.08811041712760925, 'learning_rate': 3.900462962962963e-05, 'epoch': 0.66}


{'eval_loss': 0.050885483622550964, 'eval_accuracy': 0.9852430555555556, 'eval_runtime': 39.9183, 'eval_samples_per_second': 28.859, 'eval_steps_per_second': 3.607, 'epoch': 0.66}
{'loss': 0.011, 'grad_norm': 0.06184639781713486, 'learning_rate': 3.8425925925925924e-05, 'epoch': 0.69}


{'eval_loss': 0.039528738707304, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 41.8687, 'eval_samples_per_second': 27.515, 'eval_steps_per_second': 3.439, 'epoch': 0.69}
{'loss': 0.0458, 'grad_norm': 0.09029455482959747, 'learning_rate': 3.7847222222222224e-05, 'epoch': 0.73}


{'eval_loss': 0.06940127909183502, 'eval_accuracy': 0.9800347222222222, 'eval_runtime': 41.1807, 'eval_samples_per_second': 27.974, 'eval_steps_per_second': 3.497, 'epoch': 0.73}
{'loss': 0.077, 'grad_norm': 1.0053964853286743, 'learning_rate': 3.726851851851852e-05, 'epoch': 0.76}


{'eval_loss': 0.0383620448410511, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 43.4326, 'eval_samples_per_second': 26.524, 'eval_steps_per_second': 3.315, 'epoch': 0.76}
{'loss': 0.1143, 'grad_norm': 0.35228249430656433, 'learning_rate': 3.6689814814814816e-05, 'epoch': 0.8}


{'eval_loss': 0.027611752972006798, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 41.4076, 'eval_samples_per_second': 27.821, 'eval_steps_per_second': 3.478, 'epoch': 0.8}
{'loss': 0.0189, 'grad_norm': 7.034470081329346, 'learning_rate': 3.611111111111111e-05, 'epoch': 0.83}


{'eval_loss': 0.06573370099067688, 'eval_accuracy': 0.9774305555555556, 'eval_runtime': 42.2918, 'eval_samples_per_second': 27.239, 'eval_steps_per_second': 3.405, 'epoch': 0.83}
{'loss': 0.0752, 'grad_norm': 0.06566165387630463, 'learning_rate': 3.553240740740741e-05, 'epoch': 0.87}


{'eval_loss': 0.030190574005246162, 'eval_accuracy': 0.9904513888888888, 'eval_runtime': 41.136, 'eval_samples_per_second': 28.005, 'eval_steps_per_second': 3.501, 'epoch': 0.87}
{'loss': 0.0524, 'grad_norm': 7.775566577911377, 'learning_rate': 3.49537037037037e-05, 'epoch': 0.9}


{'eval_loss': 0.03139081969857216, 'eval_accuracy': 0.9921875, 'eval_runtime': 41.942, 'eval_samples_per_second': 27.466, 'eval_steps_per_second': 3.433, 'epoch': 0.9}
{'loss': 0.0488, 'grad_norm': 6.464156627655029, 'learning_rate': 3.4375e-05, 'epoch': 0.94}


{'eval_loss': 0.026334941387176514, 'eval_accuracy': 0.9956597222222222, 'eval_runtime': 41.5278, 'eval_samples_per_second': 27.74, 'eval_steps_per_second': 3.468, 'epoch': 0.94}
{'loss': 0.0943, 'grad_norm': 0.32471320033073425, 'learning_rate': 3.3796296296296295e-05, 'epoch': 0.97}


{'eval_loss': 0.049042537808418274, 'eval_accuracy': 0.9809027777777778, 'eval_runtime': 40.5408, 'eval_samples_per_second': 28.416, 'eval_steps_per_second': 3.552, 'epoch': 0.97}
{'loss': 0.0624, 'grad_norm': 0.10303625464439392, 'learning_rate': 3.3217592592592595e-05, 'epoch': 1.01}


{'eval_loss': 0.0438123494386673, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 43.0916, 'eval_samples_per_second': 26.734, 'eval_steps_per_second': 3.342, 'epoch': 1.01}
{'loss': 0.0165, 'grad_norm': 0.37413454055786133, 'learning_rate': 3.263888888888889e-05, 'epoch': 1.04}


{'eval_loss': 0.03435048088431358, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 41.1764, 'eval_samples_per_second': 27.977, 'eval_steps_per_second': 3.497, 'epoch': 1.04}
{'loss': 0.0177, 'grad_norm': 0.11163243651390076, 'learning_rate': 3.206018518518519e-05, 'epoch': 1.08}


{'eval_loss': 0.029102176427841187, 'eval_accuracy': 0.9921875, 'eval_runtime': 42.6162, 'eval_samples_per_second': 27.032, 'eval_steps_per_second': 3.379, 'epoch': 1.08}
{'loss': 0.0239, 'grad_norm': 0.06347991526126862, 'learning_rate': 3.148148148148148e-05, 'epoch': 1.11}


{'eval_loss': 0.02855781652033329, 'eval_accuracy': 0.9921875, 'eval_runtime': 42.6433, 'eval_samples_per_second': 27.015, 'eval_steps_per_second': 3.377, 'epoch': 1.11}
{'loss': 0.0435, 'grad_norm': 0.05664515122771263, 'learning_rate': 3.090277777777778e-05, 'epoch': 1.15}


{'eval_loss': 0.03050212375819683, 'eval_accuracy': 0.9913194444444444, 'eval_runtime': 40.1796, 'eval_samples_per_second': 28.671, 'eval_steps_per_second': 3.584, 'epoch': 1.15}
{'loss': 0.0116, 'grad_norm': 0.05008790269494057, 'learning_rate': 3.0324074074074077e-05, 'epoch': 1.18}


{'eval_loss': 0.03456183150410652, 'eval_accuracy': 0.9904513888888888, 'eval_runtime': 42.8483, 'eval_samples_per_second': 26.886, 'eval_steps_per_second': 3.361, 'epoch': 1.18}
{'loss': 0.0065, 'grad_norm': 0.06178167834877968, 'learning_rate': 2.9745370370370373e-05, 'epoch': 1.22}


{'eval_loss': 0.032364923506975174, 'eval_accuracy': 0.9904513888888888, 'eval_runtime': 42.6746, 'eval_samples_per_second': 26.995, 'eval_steps_per_second': 3.374, 'epoch': 1.22}
{'loss': 0.0367, 'grad_norm': 12.993435859680176, 'learning_rate': 2.916666666666667e-05, 'epoch': 1.25}


{'eval_loss': 0.05039579048752785, 'eval_accuracy': 0.9835069444444444, 'eval_runtime': 41.209, 'eval_samples_per_second': 27.955, 'eval_steps_per_second': 3.494, 'epoch': 1.25}
{'loss': 0.0067, 'grad_norm': 0.04145192354917526, 'learning_rate': 2.8587962962962966e-05, 'epoch': 1.28}


{'eval_loss': 0.035078540444374084, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 41.9353, 'eval_samples_per_second': 27.471, 'eval_steps_per_second': 3.434, 'epoch': 1.28}
{'loss': 0.0301, 'grad_norm': 0.04039226844906807, 'learning_rate': 2.8009259259259263e-05, 'epoch': 1.32}


{'eval_loss': 0.04108380526304245, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 41.4154, 'eval_samples_per_second': 27.816, 'eval_steps_per_second': 3.477, 'epoch': 1.32}
{'loss': 0.0245, 'grad_norm': 0.21439437568187714, 'learning_rate': 2.743055555555556e-05, 'epoch': 1.35}


{'eval_loss': 0.03259219974279404, 'eval_accuracy': 0.9921875, 'eval_runtime': 41.1253, 'eval_samples_per_second': 28.012, 'eval_steps_per_second': 3.501, 'epoch': 1.35}
{'loss': 0.0055, 'grad_norm': 0.03972407802939415, 'learning_rate': 2.6851851851851855e-05, 'epoch': 1.39}


{'eval_loss': 0.02879941090941429, 'eval_accuracy': 0.9921875, 'eval_runtime': 42.3057, 'eval_samples_per_second': 27.23, 'eval_steps_per_second': 3.404, 'epoch': 1.39}
{'loss': 0.0059, 'grad_norm': 0.16452594101428986, 'learning_rate': 2.627314814814815e-05, 'epoch': 1.42}


{'eval_loss': 0.02736344188451767, 'eval_accuracy': 0.9921875, 'eval_runtime': 39.4559, 'eval_samples_per_second': 29.197, 'eval_steps_per_second': 3.65, 'epoch': 1.42}
{'loss': 0.0053, 'grad_norm': 0.04301031306385994, 'learning_rate': 2.5694444444444445e-05, 'epoch': 1.46}


{'eval_loss': 0.028202347457408905, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 43.327, 'eval_samples_per_second': 26.588, 'eval_steps_per_second': 3.324, 'epoch': 1.46}
{'loss': 0.0161, 'grad_norm': 0.03813305124640465, 'learning_rate': 2.511574074074074e-05, 'epoch': 1.49}


{'eval_loss': 0.029860682785511017, 'eval_accuracy': 0.9921875, 'eval_runtime': 41.0297, 'eval_samples_per_second': 28.077, 'eval_steps_per_second': 3.51, 'epoch': 1.49}
{'loss': 0.0065, 'grad_norm': 0.03586605191230774, 'learning_rate': 2.4537037037037038e-05, 'epoch': 1.53}


{'eval_loss': 0.03111533261835575, 'eval_accuracy': 0.9913194444444444, 'eval_runtime': 42.0455, 'eval_samples_per_second': 27.399, 'eval_steps_per_second': 3.425, 'epoch': 1.53}
{'loss': 0.0209, 'grad_norm': 6.252248764038086, 'learning_rate': 2.3958333333333334e-05, 'epoch': 1.56}


{'eval_loss': 0.03683473542332649, 'eval_accuracy': 0.9887152777777778, 'eval_runtime': 39.77, 'eval_samples_per_second': 28.967, 'eval_steps_per_second': 3.621, 'epoch': 1.56}
{'loss': 0.0265, 'grad_norm': 0.03513942286372185, 'learning_rate': 2.337962962962963e-05, 'epoch': 1.6}


{'eval_loss': 0.034548819065093994, 'eval_accuracy': 0.9887152777777778, 'eval_runtime': 43.4486, 'eval_samples_per_second': 26.514, 'eval_steps_per_second': 3.314, 'epoch': 1.6}
{'loss': 0.0049, 'grad_norm': 0.03586023673415184, 'learning_rate': 2.2800925925925927e-05, 'epoch': 1.63}


{'eval_loss': 0.03528197109699249, 'eval_accuracy': 0.9904513888888888, 'eval_runtime': 42.5925, 'eval_samples_per_second': 27.047, 'eval_steps_per_second': 3.381, 'epoch': 1.63}
{'loss': 0.0057, 'grad_norm': 0.03579947352409363, 'learning_rate': 2.2222222222222223e-05, 'epoch': 1.67}


{'eval_loss': 0.038615599274635315, 'eval_accuracy': 0.9921875, 'eval_runtime': 42.7069, 'eval_samples_per_second': 26.975, 'eval_steps_per_second': 3.372, 'epoch': 1.67}
{'loss': 0.0047, 'grad_norm': 0.03090183064341545, 'learning_rate': 2.164351851851852e-05, 'epoch': 1.7}


{'eval_loss': 0.032972920686006546, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 40.9165, 'eval_samples_per_second': 28.155, 'eval_steps_per_second': 3.519, 'epoch': 1.7}
{'loss': 0.0269, 'grad_norm': 0.03390669822692871, 'learning_rate': 2.1064814814814816e-05, 'epoch': 1.74}


{'eval_loss': 0.060344431549310684, 'eval_accuracy': 0.9835069444444444, 'eval_runtime': 43.2289, 'eval_samples_per_second': 26.649, 'eval_steps_per_second': 3.331, 'epoch': 1.74}
{'loss': 0.1035, 'grad_norm': 5.133935928344727, 'learning_rate': 2.0486111111111113e-05, 'epoch': 1.77}


{'eval_loss': 0.033066000789403915, 'eval_accuracy': 0.9921875, 'eval_runtime': 42.2059, 'eval_samples_per_second': 27.295, 'eval_steps_per_second': 3.412, 'epoch': 1.77}
{'loss': 0.0046, 'grad_norm': 0.042429935187101364, 'learning_rate': 1.990740740740741e-05, 'epoch': 1.81}


{'eval_loss': 0.02902667038142681, 'eval_accuracy': 0.9913194444444444, 'eval_runtime': 42.297, 'eval_samples_per_second': 27.236, 'eval_steps_per_second': 3.404, 'epoch': 1.81}
{'loss': 0.0201, 'grad_norm': 0.1985769271850586, 'learning_rate': 1.9328703703703705e-05, 'epoch': 1.84}


{'eval_loss': 0.03182784467935562, 'eval_accuracy': 0.9904513888888888, 'eval_runtime': 42.6538, 'eval_samples_per_second': 27.008, 'eval_steps_per_second': 3.376, 'epoch': 1.84}
{'loss': 0.0564, 'grad_norm': 6.4880852699279785, 'learning_rate': 1.8750000000000002e-05, 'epoch': 1.88}


{'eval_loss': 0.03001459501683712, 'eval_accuracy': 0.9939236111111112, 'eval_runtime': 42.557, 'eval_samples_per_second': 27.07, 'eval_steps_per_second': 3.384, 'epoch': 1.88}
{'loss': 0.0317, 'grad_norm': 0.03362669795751572, 'learning_rate': 1.8171296296296298e-05, 'epoch': 1.91}


{'eval_loss': 0.037695273756980896, 'eval_accuracy': 0.9887152777777778, 'eval_runtime': 43.2952, 'eval_samples_per_second': 26.608, 'eval_steps_per_second': 3.326, 'epoch': 1.91}
{'loss': 0.0482, 'grad_norm': 0.05317641794681549, 'learning_rate': 1.7592592592592595e-05, 'epoch': 1.94}


{'eval_loss': 0.04715777933597565, 'eval_accuracy': 0.984375, 'eval_runtime': 43.7517, 'eval_samples_per_second': 26.33, 'eval_steps_per_second': 3.291, 'epoch': 1.94}
{'loss': 0.0083, 'grad_norm': 0.04213929921388626, 'learning_rate': 1.701388888888889e-05, 'epoch': 1.98}


{'eval_loss': 0.026546621695160866, 'eval_accuracy': 0.9913194444444444, 'eval_runtime': 42.3128, 'eval_samples_per_second': 27.226, 'eval_steps_per_second': 3.403, 'epoch': 1.98}
{'loss': 0.0382, 'grad_norm': 0.032483622431755066, 'learning_rate': 1.6435185185185187e-05, 'epoch': 2.01}


{'eval_loss': 0.024998579174280167, 'eval_accuracy': 0.9921875, 'eval_runtime': 42.3842, 'eval_samples_per_second': 27.18, 'eval_steps_per_second': 3.397, 'epoch': 2.01}
{'loss': 0.0168, 'grad_norm': 0.02994435280561447, 'learning_rate': 1.5856481481481484e-05, 'epoch': 2.05}


{'eval_loss': 0.02744544856250286, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 42.0652, 'eval_samples_per_second': 27.386, 'eval_steps_per_second': 3.423, 'epoch': 2.05}
{'loss': 0.0046, 'grad_norm': 0.045989394187927246, 'learning_rate': 1.527777777777778e-05, 'epoch': 2.08}


{'eval_loss': 0.033793721348047256, 'eval_accuracy': 0.9895833333333334, 'eval_runtime': 41.1175, 'eval_samples_per_second': 28.017, 'eval_steps_per_second': 3.502, 'epoch': 2.08}
{'loss': 0.0053, 'grad_norm': 0.037199102342128754, 'learning_rate': 1.4699074074074073e-05, 'epoch': 2.12}


{'eval_loss': 0.04000992327928543, 'eval_accuracy': 0.9878472222222222, 'eval_runtime': 41.9834, 'eval_samples_per_second': 27.439, 'eval_steps_per_second': 3.43, 'epoch': 2.12}
{'loss': 0.0038, 'grad_norm': 0.030377719551324844, 'learning_rate': 1.412037037037037e-05, 'epoch': 2.15}


{'eval_loss': 0.042036157101392746, 'eval_accuracy': 0.9878472222222222, 'eval_runtime': 41.7025, 'eval_samples_per_second': 27.624, 'eval_steps_per_second': 3.453, 'epoch': 2.15}
{'loss': 0.0038, 'grad_norm': 0.03173423931002617, 'learning_rate': 1.3541666666666666e-05, 'epoch': 2.19}


{'eval_loss': 0.0421002171933651, 'eval_accuracy': 0.9878472222222222, 'eval_runtime': 39.3896, 'eval_samples_per_second': 29.246, 'eval_steps_per_second': 3.656, 'epoch': 2.19}
{'loss': 0.0039, 'grad_norm': 0.03938976302742958, 'learning_rate': 1.2962962962962962e-05, 'epoch': 2.22}


{'eval_loss': 0.04204228147864342, 'eval_accuracy': 0.9887152777777778, 'eval_runtime': 41.0367, 'eval_samples_per_second': 28.072, 'eval_steps_per_second': 3.509, 'epoch': 2.22}
{'loss': 0.0053, 'grad_norm': 0.028670627623796463, 'learning_rate': 1.2384259259259259e-05, 'epoch': 2.26}


{'eval_loss': 0.0352412648499012, 'eval_accuracy': 0.9921875, 'eval_runtime': 43.196, 'eval_samples_per_second': 26.669, 'eval_steps_per_second': 3.334, 'epoch': 2.26}
{'loss': 0.0035, 'grad_norm': 0.026586705818772316, 'learning_rate': 1.1805555555555555e-05, 'epoch': 2.29}


{'eval_loss': 0.03286649286746979, 'eval_accuracy': 0.9913194444444444, 'eval_runtime': 40.8179, 'eval_samples_per_second': 28.223, 'eval_steps_per_second': 3.528, 'epoch': 2.29}
{'loss': 0.0035, 'grad_norm': 0.03485596925020218, 'learning_rate': 1.1226851851851852e-05, 'epoch': 2.33}


{'eval_loss': 0.03221860155463219, 'eval_accuracy': 0.9913194444444444, 'eval_runtime': 43.1922, 'eval_samples_per_second': 26.671, 'eval_steps_per_second': 3.334, 'epoch': 2.33}
{'loss': 0.0048, 'grad_norm': 0.03873318061232567, 'learning_rate': 1.0648148148148148e-05, 'epoch': 2.36}


{'eval_loss': 0.03107506036758423, 'eval_accuracy': 0.9921875, 'eval_runtime': 40.9278, 'eval_samples_per_second': 28.147, 'eval_steps_per_second': 3.518, 'epoch': 2.36}
{'loss': 0.0119, 'grad_norm': 0.04348960518836975, 'learning_rate': 1.0069444444444445e-05, 'epoch': 2.4}


{'eval_loss': 0.028881318867206573, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 6982.4989, 'eval_samples_per_second': 0.165, 'eval_steps_per_second': 0.021, 'epoch': 2.4}
{'loss': 0.0035, 'grad_norm': 0.027058659121394157, 'learning_rate': 9.490740740740741e-06, 'epoch': 2.43}


{'eval_loss': 0.028635522350668907, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 39.1168, 'eval_samples_per_second': 29.45, 'eval_steps_per_second': 3.681, 'epoch': 2.43}
{'loss': 0.004, 'grad_norm': 0.02511575073003769, 'learning_rate': 8.912037037037037e-06, 'epoch': 2.47}


{'eval_loss': 0.02478478103876114, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 38.7073, 'eval_samples_per_second': 29.762, 'eval_steps_per_second': 3.72, 'epoch': 2.47}
{'loss': 0.0035, 'grad_norm': 0.027820324525237083, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}


{'eval_loss': 0.02211960218846798, 'eval_accuracy': 0.9947916666666666, 'eval_runtime': 38.7188, 'eval_samples_per_second': 29.753, 'eval_steps_per_second': 3.719, 'epoch': 2.5}
{'loss': 0.0032, 'grad_norm': 0.025755584239959717, 'learning_rate': 7.75462962962963e-06, 'epoch': 2.53}


{'eval_loss': 0.021948399022221565, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 38.7441, 'eval_samples_per_second': 29.734, 'eval_steps_per_second': 3.717, 'epoch': 2.53}
{'loss': 0.0033, 'grad_norm': 0.031100640073418617, 'learning_rate': 7.1759259259259266e-06, 'epoch': 2.57}


{'eval_loss': 0.02203122340142727, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 38.8371, 'eval_samples_per_second': 29.662, 'eval_steps_per_second': 3.708, 'epoch': 2.57}
{'loss': 0.0031, 'grad_norm': 0.026160968467593193, 'learning_rate': 6.597222222222223e-06, 'epoch': 2.6}


{'eval_loss': 0.022169897332787514, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 38.8317, 'eval_samples_per_second': 29.666, 'eval_steps_per_second': 3.708, 'epoch': 2.6}
{'loss': 0.0032, 'grad_norm': 0.026460448279976845, 'learning_rate': 6.0185185185185185e-06, 'epoch': 2.64}


{'eval_loss': 0.022552574053406715, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 40.5097, 'eval_samples_per_second': 28.438, 'eval_steps_per_second': 3.555, 'epoch': 2.64}
{'loss': 0.0032, 'grad_norm': 0.025686468929052353, 'learning_rate': 5.439814814814815e-06, 'epoch': 2.67}


{'eval_loss': 0.02270754799246788, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 39.0816, 'eval_samples_per_second': 29.477, 'eval_steps_per_second': 3.685, 'epoch': 2.67}
{'loss': 0.0031, 'grad_norm': 0.025646982714533806, 'learning_rate': 4.861111111111111e-06, 'epoch': 2.71}


{'eval_loss': 0.022809602320194244, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 39.1597, 'eval_samples_per_second': 29.418, 'eval_steps_per_second': 3.677, 'epoch': 2.71}
{'loss': 0.008, 'grad_norm': 0.02547714300453663, 'learning_rate': 4.282407407407408e-06, 'epoch': 2.74}


{'eval_loss': 0.02510293386876583, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 39.1622, 'eval_samples_per_second': 29.416, 'eval_steps_per_second': 3.677, 'epoch': 2.74}
{'loss': 0.0031, 'grad_norm': 0.024022487923502922, 'learning_rate': 3.7037037037037037e-06, 'epoch': 2.78}


{'eval_loss': 0.02620147354900837, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 39.117, 'eval_samples_per_second': 29.45, 'eval_steps_per_second': 3.681, 'epoch': 2.78}
{'loss': 0.0031, 'grad_norm': 0.024733655154705048, 'learning_rate': 3.125e-06, 'epoch': 2.81}


{'eval_loss': 0.026592271402478218, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 39.1139, 'eval_samples_per_second': 29.452, 'eval_steps_per_second': 3.682, 'epoch': 2.81}
{'loss': 0.0043, 'grad_norm': 0.03886941820383072, 'learning_rate': 2.546296296296296e-06, 'epoch': 2.85}


{'eval_loss': 0.025932209566235542, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 41.6272, 'eval_samples_per_second': 27.674, 'eval_steps_per_second': 3.459, 'epoch': 2.85}
{'loss': 0.0032, 'grad_norm': 0.35465630888938904, 'learning_rate': 1.9675925925925925e-06, 'epoch': 2.88}


{'eval_loss': 0.02577088586986065, 'eval_accuracy': 0.9921875, 'eval_runtime': 40.5866, 'eval_samples_per_second': 28.384, 'eval_steps_per_second': 3.548, 'epoch': 2.88}
{'loss': 0.0042, 'grad_norm': 0.024415601044893265, 'learning_rate': 1.388888888888889e-06, 'epoch': 2.92}


{'eval_loss': 0.026823166757822037, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 41.656, 'eval_samples_per_second': 27.655, 'eval_steps_per_second': 3.457, 'epoch': 2.92}
{'loss': 0.0032, 'grad_norm': 0.023707451298832893, 'learning_rate': 8.101851851851852e-07, 'epoch': 2.95}


{'eval_loss': 0.027684058994054794, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 40.7364, 'eval_samples_per_second': 28.279, 'eval_steps_per_second': 3.535, 'epoch': 2.95}
{'loss': 0.003, 'grad_norm': 0.02588501386344433, 'learning_rate': 2.3148148148148148e-07, 'epoch': 2.99}


{'eval_loss': 0.027843892574310303, 'eval_accuracy': 0.9930555555555556, 'eval_runtime': 41.3565, 'eval_samples_per_second': 27.855, 'eval_steps_per_second': 3.482, 'epoch': 2.99}
{'train_runtime': 11539.1431, 'train_samples_per_second': 1.197, 'train_steps_per_second': 0.075, 'train_loss': 0.042453370087228166, 'epoch': 3.0}
***** train metrics *****
  epoch                    =         3.0
  total_flos               = 997245602GF
  train_loss               =      0.0425
  train_runtime            =  3:12:19.14
  train_samples_per_second =       1.197
  train_steps_per_second   =       0.075


In [86]:
metrics = trainer.evaluate(ds['test'])
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)

***** test metrics *****
  epoch                   =        3.0
  eval_accuracy           =     0.9835
  eval_loss               =     0.0603
  eval_runtime            = 0:00:42.28
  eval_samples_per_second =     27.245
  eval_steps_per_second   =      3.406
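
The test accuracy (0.9835) and loss (0.0603) match the evaluation logged at epoch 1.74, not the best logged accuracy (0.9957 at epoch 0.94). One possible explanation, offered as an assumption and not as a confirmed diagnosis, is that `save_steps` is unset and falls back to 500, so only one checkpoint is written and `load_best_model_at_end` can restore nothing else. The sketch below shows the arithmetic; `checkpoint_steps` is an illustrative helper, and 288 steps per epoch assumes 4,606 training examples at batch size 16.

```python
def checkpoint_steps(total_steps, save_steps):
    """Steps at which a checkpoint would be written during training."""
    return list(range(save_steps, total_steps + 1, save_steps))

saved = checkpoint_steps(864, 500)   # save_steps=500 is an assumed default
print(saved)                         # [500]
print(round(saved[0] / 288, 2))      # 1.74 (epoch of that checkpoint)
```

If this explanation holds, setting `save_steps` to match `eval_steps`, or evaluating less often, would let the best-scoring evaluation step be restored.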


## References

<a id="1">1. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2020). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ArXiv, abs/2010.11929.</a>
