In [3]:
from utils import CLASS_TO_ID
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
import torch
import numpy as np
from datasets import load_metric, load_dataset

In [4]:
train = load_dataset('datasets/augmented/PlantVillage/train')
val = load_dataset('datasets/augmented/PlantVillage/val')
train = train.shuffle(seed=42)

Resolving data files:   0%|          | 0/46541 [00:00<?, ?it/s]

Using custom data configuration train-4a5d2e024927d5d3
Found cached dataset imagefolder (C:/Users/ASROCK/.cache/huggingface/datasets/imagefolder/train-4a5d2e024927d5d3/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
Using custom data configuration val-c4c340b5bde682f0
Found cached dataset imagefolder (C:/Users/ASROCK/.cache/huggingface/datasets/imagefolder/val-c4c340b5bde682f0/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
Loading cached shuffled indices for dataset at C:/Users/ASROCK/.cache/huggingface/datasets/imagefolder/train-4a5d2e024927d5d3/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f\cache-9c9db930684d816b.arrow


  0%|          | 0/1 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/11096 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
model_name_or_path = 'WinKawaks/vit-tiny-patch16-224'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

In [6]:
def transform(example_batch):
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['label']

    return inputs

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

prepared_train = train.with_transform(transform)
prepared_val = train.with_transform(transform)

  metric = load_metric("accuracy")


In [7]:
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(CLASS_TO_ID),
    id2label={y: x for x, y in CLASS_TO_ID.items()},
    label2id=CLASS_TO_ID,
    ignore_mismatched_sizes=True
)

training_args = TrainingArguments(
  output_dir="./vit-base-plant-village",
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=5,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_train['train'],
    eval_dataset=prepared_val['train'],
    tokenizer=feature_extractor,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([39, 192]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([39]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using cuda_amp half precision backend


In [8]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()



***** Running training *****
  Num examples = 46541
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 29090
  Number of trainable parameters = 5531943
***** Running Evaluation *****
  Num examples = 46541
  Batch size = 8
Saving model checkpoint to ./vit-base-plant-village\checkpoint-100
Configuration saved in ./vit-base-plant-village\checkpoint-100\config.json
Model weights saved in ./vit-base-plant-village\checkpoint-100\pytorch_model.bin
Image processor saved in ./vit-base-plant-village\checkpoint-100\preprocessor_config.json
***** Running Evaluation *****
  Num examples = 46541
  Batch size = 8
Saving model checkpoint to ./vit-base-plant-village\checkpoint-200
Configuration saved in ./vit-base-plant-village\checkpoint-200\config.json
Model weights saved in ./vit-base-plant-village\checkpoint-200\pytorch_model.bin
Image processor saved in ./v

Step,Training Loss,Validation Loss


KeyboardInterrupt: 