In [1]:
model_checkpoint = "google/vit-base-patch16-224-in21k"  # pre-trained model from which to fine-tune

In [2]:
from datasets import load_dataset

In [3]:
data = load_dataset('imagefolder', data_dir="new")

Resolving data files:   0%|          | 0/12719 [00:00<?, ?it/s]

In [4]:
data

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 12719
    })
})

In [5]:
data = data['train'].train_test_split(test_size = 0.1, shuffle = True)
train_ds = data['train']
val_ds = data['test']

In [6]:
train_ds

Dataset({
    features: ['image', 'label'],
    num_rows: 11447
})

In [7]:
val_ds

Dataset({
    features: ['image', 'label'],
    num_rows: 1272
})

In [8]:
test_ds = load_dataset('imagefolder', data_dir='Benchmark/')['train']

Resolving data files:   0%|          | 0/600 [00:00<?, ?it/s]

In [9]:
train_ds[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=L size=182x182>,
 'label': 1}

In [10]:
labels = data['train'].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

label2id

{'Mild_Demented': '0',
 'Moderate_Demented': '1',
 'Non_Demented': '2',
 'Very_Mild_Demented': '3'}

In [11]:
from transformers import AutoImageProcessor

checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [12]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [13]:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

In [14]:
train_ds = train_ds.with_transform(transforms)
val_ds = val_ds.with_transform(transforms)
test_ds = test_ds.with_transform(transforms)

In [15]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

In [16]:
import evaluate

accuracy = evaluate.load("accuracy")

In [17]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [18]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
training_args = TrainingArguments(output_dir="alzheimer-image-classification-google-vit-base-patch16",
                                  remove_unused_columns=False,
                                  evaluation_strategy="epoch",
                                  overwrite_output_dir=True,
                                  save_strategy="epoch",
                                  learning_rate=5e-5,
                                  per_device_train_batch_size=4,
                                  gradient_accumulation_steps=4,
                                  per_device_eval_batch_size=4,
                                  num_train_epochs=10,
                                  warmup_ratio=0.1,
                                  logging_steps=10,
                                  load_best_model_at_end=True,
                                  metric_for_best_model="accuracy",
                                  report_to='tensorboard',
                                  push_to_hub=True)

trainer = Trainer(model=model,
                  args=training_args,
                  data_collator=data_collator,
                  train_dataset=train_ds,
                  eval_dataset=val_ds,
                  tokenizer=image_processor,
                  compute_metrics=compute_metrics)

trainer.train()

Cloning https://huggingface.co/AhmadHakami/alzheimer-image-classification-google-vit-base-patch16 into local empty directory.


  0%|          | 0/7150 [00:00<?, ?it/s]

{'loss': 1.403, 'learning_rate': 6.993006993006994e-07, 'epoch': 0.01}
{'loss': 1.3965, 'learning_rate': 1.3986013986013987e-06, 'epoch': 0.03}
{'loss': 1.396, 'learning_rate': 2.0979020979020983e-06, 'epoch': 0.04}
{'loss': 1.3901, 'learning_rate': 2.7972027972027974e-06, 'epoch': 0.06}
{'loss': 1.3696, 'learning_rate': 3.496503496503497e-06, 'epoch': 0.07}
{'loss': 1.3617, 'learning_rate': 4.195804195804197e-06, 'epoch': 0.08}
{'loss': 1.3571, 'learning_rate': 4.895104895104895e-06, 'epoch': 0.1}
{'loss': 1.3373, 'learning_rate': 5.594405594405595e-06, 'epoch': 0.11}
{'loss': 1.3147, 'learning_rate': 6.2937062937062944e-06, 'epoch': 0.13}
{'loss': 1.2924, 'learning_rate': 6.993006993006994e-06, 'epoch': 0.14}
{'loss': 1.2683, 'learning_rate': 7.692307692307694e-06, 'epoch': 0.15}
{'loss': 1.2507, 'learning_rate': 8.391608391608393e-06, 'epoch': 0.17}
{'loss': 1.2024, 'learning_rate': 9.090909090909091e-06, 'epoch': 0.18}
{'loss': 1.1831, 'learning_rate': 9.79020979020979e-06, 'epoch'

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.7519570589065552, 'eval_accuracy': 0.64937106918239, 'eval_runtime': 11.4049, 'eval_samples_per_second': 111.531, 'eval_steps_per_second': 27.883, 'epoch': 1.0}
{'loss': 0.8059, 'learning_rate': 4.996114996114996e-05, 'epoch': 1.01}
{'loss': 0.6811, 'learning_rate': 4.988344988344989e-05, 'epoch': 1.02}
{'loss': 0.6353, 'learning_rate': 4.9805749805749805e-05, 'epoch': 1.03}
{'loss': 0.6931, 'learning_rate': 4.972804972804973e-05, 'epoch': 1.05}
{'loss': 0.711, 'learning_rate': 4.9650349650349656e-05, 'epoch': 1.06}
{'loss': 0.6704, 'learning_rate': 4.9572649572649575e-05, 'epoch': 1.08}
{'loss': 0.6915, 'learning_rate': 4.94949494949495e-05, 'epoch': 1.09}
{'loss': 0.6958, 'learning_rate': 4.941724941724942e-05, 'epoch': 1.1}
{'loss': 0.682, 'learning_rate': 4.9339549339549344e-05, 'epoch': 1.12}
{'loss': 0.6294, 'learning_rate': 4.926184926184926e-05, 'epoch': 1.13}
{'loss': 0.6122, 'learning_rate': 4.918414918414919e-05, 'epoch': 1.15}
{'loss': 0.6943, 'learning_rate

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.646718442440033, 'eval_accuracy': 0.7091194968553459, 'eval_runtime': 11.968, 'eval_samples_per_second': 106.284, 'eval_steps_per_second': 26.571, 'epoch': 2.0}
{'loss': 0.5131, 'learning_rate': 4.4366744366744365e-05, 'epoch': 2.01}
{'loss': 0.5388, 'learning_rate': 4.428904428904429e-05, 'epoch': 2.03}
{'loss': 0.5453, 'learning_rate': 4.4211344211344216e-05, 'epoch': 2.04}
{'loss': 0.4609, 'learning_rate': 4.4133644133644134e-05, 'epoch': 2.05}
{'loss': 0.6504, 'learning_rate': 4.405594405594406e-05, 'epoch': 2.07}
{'loss': 0.6436, 'learning_rate': 4.3978243978243985e-05, 'epoch': 2.08}
{'loss': 0.5886, 'learning_rate': 4.39005439005439e-05, 'epoch': 2.1}
{'loss': 0.5337, 'learning_rate': 4.382284382284382e-05, 'epoch': 2.11}
{'loss': 0.5122, 'learning_rate': 4.374514374514375e-05, 'epoch': 2.12}
{'loss': 0.4732, 'learning_rate': 4.366744366744367e-05, 'epoch': 2.14}
{'loss': 0.6697, 'learning_rate': 4.358974358974359e-05, 'epoch': 2.15}
{'loss': 0.6293, 'learning_ra

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.543040931224823, 'eval_accuracy': 0.7594339622641509, 'eval_runtime': 9.1489, 'eval_samples_per_second': 139.033, 'eval_steps_per_second': 34.758, 'epoch': 3.0}
{'loss': 0.5556, 'learning_rate': 3.885003885003885e-05, 'epoch': 3.0}
{'loss': 0.5045, 'learning_rate': 3.8772338772338775e-05, 'epoch': 3.02}
{'loss': 0.4953, 'learning_rate': 3.8694638694638694e-05, 'epoch': 3.03}
{'loss': 0.4843, 'learning_rate': 3.861693861693862e-05, 'epoch': 3.05}
{'loss': 0.4979, 'learning_rate': 3.8539238539238544e-05, 'epoch': 3.06}
{'loss': 0.4425, 'learning_rate': 3.846153846153846e-05, 'epoch': 3.07}
{'loss': 0.4815, 'learning_rate': 3.838383838383838e-05, 'epoch': 3.09}
{'loss': 0.4499, 'learning_rate': 3.830613830613831e-05, 'epoch': 3.1}
{'loss': 0.3848, 'learning_rate': 3.822843822843823e-05, 'epoch': 3.12}
{'loss': 0.4735, 'learning_rate': 3.815073815073815e-05, 'epoch': 3.13}
{'loss': 0.5089, 'learning_rate': 3.8073038073038076e-05, 'epoch': 3.14}
{'loss': 0.4805, 'learning_ra

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.43717989325523376, 'eval_accuracy': 0.8144654088050315, 'eval_runtime': 9.7602, 'eval_samples_per_second': 130.325, 'eval_steps_per_second': 32.581, 'epoch': 4.0}
{'loss': 0.4023, 'learning_rate': 3.3255633255633253e-05, 'epoch': 4.01}
{'loss': 0.4009, 'learning_rate': 3.317793317793318e-05, 'epoch': 4.03}
{'loss': 0.4081, 'learning_rate': 3.3100233100233104e-05, 'epoch': 4.04}
{'loss': 0.4303, 'learning_rate': 3.302253302253302e-05, 'epoch': 4.05}
{'loss': 0.399, 'learning_rate': 3.294483294483295e-05, 'epoch': 4.07}
{'loss': 0.3252, 'learning_rate': 3.2867132867132866e-05, 'epoch': 4.08}
{'loss': 0.4219, 'learning_rate': 3.278943278943279e-05, 'epoch': 4.1}
{'loss': 0.4289, 'learning_rate': 3.271173271173271e-05, 'epoch': 4.11}
{'loss': 0.3918, 'learning_rate': 3.2634032634032635e-05, 'epoch': 4.12}
{'loss': 0.3622, 'learning_rate': 3.255633255633256e-05, 'epoch': 4.14}
{'loss': 0.3696, 'learning_rate': 3.247863247863248e-05, 'epoch': 4.15}
{'loss': 0.302, 'learning_r

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.36814603209495544, 'eval_accuracy': 0.8427672955974843, 'eval_runtime': 9.7072, 'eval_samples_per_second': 131.036, 'eval_steps_per_second': 32.759, 'epoch': 5.0}


Several commits (2) will be pushed upstream.


{'loss': 0.3785, 'learning_rate': 2.773892773892774e-05, 'epoch': 5.0}
{'loss': 0.3433, 'learning_rate': 2.7661227661227664e-05, 'epoch': 5.02}
{'loss': 0.3235, 'learning_rate': 2.7583527583527586e-05, 'epoch': 5.03}
{'loss': 0.2571, 'learning_rate': 2.7505827505827507e-05, 'epoch': 5.05}
{'loss': 0.2699, 'learning_rate': 2.7428127428127433e-05, 'epoch': 5.06}
{'loss': 0.3551, 'learning_rate': 2.7350427350427355e-05, 'epoch': 5.07}
{'loss': 0.2332, 'learning_rate': 2.7272727272727273e-05, 'epoch': 5.09}
{'loss': 0.2341, 'learning_rate': 2.7195027195027195e-05, 'epoch': 5.1}
{'loss': 0.2365, 'learning_rate': 2.7117327117327117e-05, 'epoch': 5.12}
{'loss': 0.3338, 'learning_rate': 2.7039627039627042e-05, 'epoch': 5.13}
{'loss': 0.3511, 'learning_rate': 2.6961926961926964e-05, 'epoch': 5.14}
{'loss': 0.322, 'learning_rate': 2.6884226884226886e-05, 'epoch': 5.16}
{'loss': 0.2362, 'learning_rate': 2.680652680652681e-05, 'epoch': 5.17}
{'loss': 0.278, 'learning_rate': 2.6728826728826726e-05,

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.37463584542274475, 'eval_accuracy': 0.8514150943396226, 'eval_runtime': 4.9828, 'eval_samples_per_second': 255.278, 'eval_steps_per_second': 63.819, 'epoch': 6.0}
{'loss': 0.2376, 'learning_rate': 2.2144522144522145e-05, 'epoch': 6.01}
{'loss': 0.226, 'learning_rate': 2.2066822066822067e-05, 'epoch': 6.02}
{'loss': 0.2983, 'learning_rate': 2.1989121989121992e-05, 'epoch': 6.04}
{'loss': 0.2962, 'learning_rate': 2.191142191142191e-05, 'epoch': 6.05}
{'loss': 0.2501, 'learning_rate': 2.1833721833721836e-05, 'epoch': 6.07}
{'loss': 0.3864, 'learning_rate': 2.1756021756021758e-05, 'epoch': 6.08}
{'loss': 0.1608, 'learning_rate': 2.1678321678321677e-05, 'epoch': 6.09}
{'loss': 0.2555, 'learning_rate': 2.1600621600621602e-05, 'epoch': 6.11}
{'loss': 0.1928, 'learning_rate': 2.1522921522921524e-05, 'epoch': 6.12}
{'loss': 0.1793, 'learning_rate': 2.1445221445221446e-05, 'epoch': 6.14}
{'loss': 0.2341, 'learning_rate': 2.1367521367521368e-05, 'epoch': 6.15}
{'loss': 0.2691, 'le

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.2835550010204315, 'eval_accuracy': 0.8907232704402516, 'eval_runtime': 9.4354, 'eval_samples_per_second': 134.812, 'eval_steps_per_second': 33.703, 'epoch': 7.0}
{'loss': 0.146, 'learning_rate': 1.6627816627816627e-05, 'epoch': 7.0}
{'loss': 0.1414, 'learning_rate': 1.6550116550116552e-05, 'epoch': 7.02}
{'loss': 0.271, 'learning_rate': 1.6472416472416474e-05, 'epoch': 7.03}
{'loss': 0.1984, 'learning_rate': 1.6394716394716396e-05, 'epoch': 7.04}
{'loss': 0.2371, 'learning_rate': 1.6317016317016318e-05, 'epoch': 7.06}
{'loss': 0.1775, 'learning_rate': 1.623931623931624e-05, 'epoch': 7.07}
{'loss': 0.2255, 'learning_rate': 1.6161616161616165e-05, 'epoch': 7.09}
{'loss': 0.1779, 'learning_rate': 1.6083916083916083e-05, 'epoch': 7.1}
{'loss': 0.1834, 'learning_rate': 1.600621600621601e-05, 'epoch': 7.11}
{'loss': 0.1953, 'learning_rate': 1.592851592851593e-05, 'epoch': 7.13}
{'loss': 0.2078, 'learning_rate': 1.585081585081585e-05, 'epoch': 7.14}
{'loss': 0.1432, 'learning_

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.27977997064590454, 'eval_accuracy': 0.8954402515723271, 'eval_runtime': 5.0118, 'eval_samples_per_second': 253.801, 'eval_steps_per_second': 63.45, 'epoch': 8.0}
{'loss': 0.1821, 'learning_rate': 1.1033411033411034e-05, 'epoch': 8.01}
{'loss': 0.178, 'learning_rate': 1.0955710955710955e-05, 'epoch': 8.02}
{'loss': 0.1503, 'learning_rate': 1.0878010878010879e-05, 'epoch': 8.04}
{'loss': 0.1729, 'learning_rate': 1.0800310800310801e-05, 'epoch': 8.05}
{'loss': 0.2206, 'learning_rate': 1.0722610722610723e-05, 'epoch': 8.06}
{'loss': 0.1739, 'learning_rate': 1.0644910644910645e-05, 'epoch': 8.08}
{'loss': 0.1034, 'learning_rate': 1.0567210567210568e-05, 'epoch': 8.09}
{'loss': 0.1535, 'learning_rate': 1.048951048951049e-05, 'epoch': 8.11}
{'loss': 0.2187, 'learning_rate': 1.0411810411810412e-05, 'epoch': 8.12}
{'loss': 0.2155, 'learning_rate': 1.0334110334110336e-05, 'epoch': 8.13}
{'loss': 0.0786, 'learning_rate': 1.0256410256410256e-05, 'epoch': 8.15}
{'loss': 0.1243, 'lea

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.23014533519744873, 'eval_accuracy': 0.9158805031446541, 'eval_runtime': 9.0361, 'eval_samples_per_second': 140.768, 'eval_steps_per_second': 35.192, 'epoch': 9.0}
{'loss': 0.1255, 'learning_rate': 5.516705516705517e-06, 'epoch': 9.0}
{'loss': 0.1976, 'learning_rate': 5.4390054390054395e-06, 'epoch': 9.01}
{'loss': 0.115, 'learning_rate': 5.361305361305361e-06, 'epoch': 9.03}
{'loss': 0.1257, 'learning_rate': 5.283605283605284e-06, 'epoch': 9.04}
{'loss': 0.1573, 'learning_rate': 5.205905205905206e-06, 'epoch': 9.06}
{'loss': 0.1423, 'learning_rate': 5.128205128205128e-06, 'epoch': 9.07}
{'loss': 0.0864, 'learning_rate': 5.050505050505051e-06, 'epoch': 9.08}
{'loss': 0.0992, 'learning_rate': 4.972804972804973e-06, 'epoch': 9.1}
{'loss': 0.1676, 'learning_rate': 4.895104895104895e-06, 'epoch': 9.11}
{'loss': 0.1576, 'learning_rate': 4.817404817404818e-06, 'epoch': 9.13}
{'loss': 0.241, 'learning_rate': 4.73970473970474e-06, 'epoch': 9.14}
{'loss': 0.162, 'learning_rate': 

  0%|          | 0/318 [00:00<?, ?it/s]

{'eval_loss': 0.21270409226417542, 'eval_accuracy': 0.9261006289308176, 'eval_runtime': 5.5086, 'eval_samples_per_second': 230.91, 'eval_steps_per_second': 57.728, 'epoch': 9.99}
{'train_runtime': 2384.637, 'train_samples_per_second': 48.003, 'train_steps_per_second': 2.998, 'train_loss': 0.394019101883148, 'epoch': 9.99}


TrainOutput(global_step=7150, training_loss=0.394019101883148, metrics={'train_runtime': 2384.637, 'train_samples_per_second': 48.003, 'train_steps_per_second': 2.998, 'train_loss': 0.394019101883148, 'epoch': 9.99})

In [20]:
trainer.push_to_hub()

Several commits (2) will be pushed upstream.
The progress bars may be unreliable.


Upload file pytorch_model.bin:   0%|          | 1.00/327M [00:00<?, ?B/s]

Upload file runs/Aug14_17-07-38_AhmadHakami/events.out.tfevents.1692022063.AhmadHakami.5056.0:   0%|          …

To https://huggingface.co/AhmadHakami/alzheimer-image-classification-google-vit-base-patch16
   1ce4208..b51bc55  main -> main

To https://huggingface.co/AhmadHakami/alzheimer-image-classification-google-vit-base-patch16
   b51bc55..fac7c42  main -> main



'https://huggingface.co/AhmadHakami/alzheimer-image-classification-google-vit-base-patch16/commit/b51bc553c28d79e0a286f4ce5938893000591000'