In [1]:
import os
import numpy as np

from datasets import load_dataset
import evaluate

from transformers import (
    ViTImageProcessor ,
    ViTMSNForImageClassification,
    TrainingArguments,
    Trainer
)
# Silence Weights & Biases logging. `os` is already imported at the top of this cell.
# This env var is deprecated in transformers (it triggers a deprecation warning when
# TrainingArguments is built); `report_to="none"` is the supported switch. Setting the
# variable is harmless and still honoured by transformers 4.x.
os.environ["WANDB_DISABLED"] = "true"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root of the local image dataset, loaded with the `imagefolder` builder.
# The next cell reads both a "train" and a "validation" split from this one
# directory, so it is assumed to contain both splits — confirm the layout on disk.
train_dir = "train-dataset"
# NOTE(review): `val_dir` is never read — validation images come from inside
# `train_dir`. Kept only so any later reference still resolves; remove if unused.
val_dir = "val-dataset"
dataset = load_dataset("imagefolder", data_dir=train_dir)

In [3]:
# Pull the training and validation splits out of the loaded DatasetDict.
train_ds, val_ds = dataset["train"], dataset["validation"]

In [4]:
# Class names are read from the `label` ClassLabel feature of the training split.
labels = train_ds.features["label"].names
num_labels = len(labels)

# Bidirectional lookups between integer class ids and class-name strings.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

print("Number of classes:", num_labels)
print("Label names:", labels)

# Hub checkpoint used for both the image processor and the classifier backbone.
pretrained_model_name = "facebook/vit-msn-base"

Number of classes: 19
Label names: ['AD-AZ-60', 'AD-FC-30', 'AD-OW-35', 'AD-SS-30', 'AD-UT-55', 'BA-BS-50', 'BE-HS-50', 'CO-CH-40', 'DI-DW-50', 'DR-PL-35', 'RP-XX-30', 'RP-XX-35', 'SK-SP-35', 'SL-SL-50', 'SL-SL-55', 'SO-SG-45', 'SO-XX-60', 'SW-XX-60', 'VA-ST-55']


In [5]:
feature_extractor = ViTImageProcessor.from_pretrained(pretrained_model_name)

In [6]:
def preprocess_images(examples):
    """Turn a batch of raw images into model-ready pixel values.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch dict whose ``"image"`` entry is a list of images
            (presumably PIL images from the ``imagefolder`` loader — confirm).

    Returns:
        The image processor's output, whose ``pixel_values`` entry ``map`` adds
        as a new column. The ``label`` column is NOT touched here; renaming it
        to ``labels`` happens in a separate cell after mapping.
    """
    return feature_extractor(examples["image"], return_tensors="pt")

In [7]:
# Run the image processor over both splits, feeding it whole batches of images.
train_ds, val_ds = (
    split.map(preprocess_images, batched=True) for split in (train_ds, val_ds)
)

In [8]:
# The model/Trainer look for the target under the key `labels`, so rename the
# `label` column produced by `imagefolder` in both splits.
train_ds, val_ds = (
    split.rename_column("label", "labels") for split in (train_ds, val_ds)
)

In [9]:
# Sanity check: after mapping and renaming, the split should expose `image`, `labels` and `pixel_values`.
train_ds

Dataset({
    features: ['image', 'labels', 'pixel_values'],
    num_rows: 11379
})

In [10]:
# Format both splits in place so indexing yields torch tensors for just the two
# columns the model consumes; the raw `image` column is hidden, not dropped.
for split_ds in (train_ds, val_ds):
    split_ds.set_format(type="torch", columns=["pixel_values", "labels"])


In [11]:
# Load the ViT-MSN backbone; its classification head is newly initialised (see the
# load warning) and sized to our classes. Passing the real id<->name mappings built
# earlier stores the class names in the model config, so the saved model reports
# e.g. 'AD-AZ-60' instead of bare index strings like '0'.
model = ViTMSNForImageClassification.from_pretrained(
    pretrained_model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
accuracy_metric = evaluate.load("accuracy")

Some weights of ViTMSNForImageClassification were not initialized from the model checkpoint at facebook/vit-msn-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for a Trainer evaluation pass.

    ``eval_pred`` unpacks into (logits, reference label ids); the predicted
    class is the arg-max of the logits over their last axis.
    """
    raw_logits, reference_ids = eval_pred
    predicted_ids = np.argmax(raw_logits, axis=-1)
    return accuracy_metric.compute(predictions=predicted_ids, references=reference_ids)

In [13]:
# Fine-tuning configuration. Evaluation and checkpointing share a 100-step
# cadence so `load_best_model_at_end` can choose among the saved checkpoints.
training_args = TrainingArguments(
    output_dir="./vit-model",
    per_device_train_batch_size=16,
    eval_strategy="steps",
    num_train_epochs=2,
    fp16=True,  # mixed precision; needs a CUDA-capable device
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep dataset columns as formatted above instead of pruning them by the
    # model's forward signature.
    remove_unused_columns=False,
    push_to_hub=False,
    # The string "none" disables every logging integration. `None` (the previous
    # value) falls back to all installed integrations in transformers 4.x.
    report_to="none",
    load_best_model_at_end=True,
    seed=42,  # explicit for reproducibility (same as the library default)
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [14]:
# Wire model, data, metric and image processor together. The processor is passed
# as `processing_class` so the Trainer saves it alongside the model checkpoints.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    processing_class=feature_extractor,
    compute_metrics=compute_metrics,
)

In [15]:
# Fine-tune the model; the returned TrainOutput (global step, aggregate train metrics) is the cell's output.
trainer.train()

Step,Training Loss,Validation Loss,Accuracy
100,3.2604,3.207176,0.040665
200,2.1036,1.91135,0.45841
300,1.3366,1.674114,0.572089
400,0.9392,0.541656,0.837338
500,0.6524,0.53378,0.838262
600,0.458,0.262601,0.918669
700,0.3733,0.43247,0.875231
800,0.6165,0.298643,0.909427
900,0.0809,0.159183,0.949168
1000,0.0976,0.256344,0.938078


TrainOutput(global_step=1424, training_loss=0.8571953054496579, metrics={'train_runtime': 1400.5462, 'train_samples_per_second': 16.249, 'train_steps_per_second': 1.017, 'total_flos': 1.7638314059780628e+18, 'train_loss': 0.8571953054496579, 'epoch': 2.0})

In [16]:
# Evaluate the model the Trainer holds after training (the best checkpoint, since
# `load_best_model_at_end=True`) on the validation split.
# NOTE(review): val_ds also drove checkpoint selection, so this accuracy is an
# optimistic estimate; use a held-out test split for an unbiased number.
eval_results = trainer.evaluate(eval_dataset=val_ds)
print("Evaluation results:", eval_results)

Evaluation results: {'eval_loss': 0.01432074699550867, 'eval_accuracy': 0.9963031423290203, 'eval_runtime': 25.2894, 'eval_samples_per_second': 42.785, 'eval_steps_per_second': 5.378, 'epoch': 2.0}


In [17]:
# Persist the fine-tuned model to a local directory for later inference.
trainer.save_model(output_dir='trained_model-msn')