In [1]:
import numpy as np
import torch    
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
import torch.optim as optim
import matplotlib.pyplot as plt

In [35]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import TrainingArguments, Trainer

In [2]:
train_dir = r'./DATASET/TRAIN'
test_dir = r'./DATASET/TEST'

classes = ['O', 'R']

In [30]:
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

normalize = transforms.Normalize(mean=feature_extractor.image_mean,
                                 std=feature_extractor.image_std)

transform = transforms.Compose([transforms.RandomResizedCrop(224),
                               transforms.RandomHorizontalFlip(),
                               transforms.ToTensor(),
                               normalize])

train_data = datasets.ImageFolder(train_dir, transform=transform)
test_data = datasets.ImageFolder(test_dir, transform=transform)

#now check if they've loaded correctly
print("Number of train images: ", (len(train_data)))
print("Number of test images: ", len(test_data))

Number of train images:  22564
Number of test images:  2513


In [42]:
orig_n = len(train_data)  # total number of examples
n_test = int(0.1 * orig_n)  # take ~10% for val
val_set = torch.utils.data.Subset(train_data, range(n_test))  # take first 10%
train_set = torch.utils.data.Subset(train_data, range(n_test, orig_n))  # take the rest   

In [31]:
class ImageClassificationCollator:
    def __init__(self, feature_extractor): 
        self.feature_extractor = feature_extractor
    def __call__(self, batch):  
        encodings = self.feature_extractor([x[0] for x in batch],
          return_tensors='pt')   
        encodings['labels'] = torch.tensor([x[1] for x in batch],    
          dtype=torch.long)
        return encodings

In [43]:
#prepare data loaders
batch_size = 20
collator = ImageClassificationCollator(feature_extractor)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, collate_fn=collator, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, collate_fn=collator, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collator, shuffle=True)

In [33]:
id2label = {id:label for id, label in enumerate(classes)}
label2id = {label:id for id,label in id2label.items()}

In [34]:


model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                  num_labels=2,
                                                  id2label=id2label,
                                                  label2id=label2id)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [36]:
metric_name = "accuracy"

args = TrainingArguments(
    f"test-waste",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    remove_unused_columns=False,
)

In [38]:
from datasets import load_metric
import numpy as np

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

  metric = load_metric("accuracy")


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

In [45]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=collator,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

In [47]:
trainer.train()

***** Running training *****
  Num examples = 20308
  Num Epochs = 3
  Instantaneous batch size per device = 10
  Total train batch size (w. parallel, distributed & accumulation) = 10
  Gradient Accumulation steps = 1
  Total optimization steps = 6093
  Number of trainable parameters = 85800194


Epoch,Training Loss,Validation Loss,Accuracy
1,0.1912,0.087445,0.969415
2,0.1533,0.112211,0.960993
3,0.1216,0.121194,0.96055


***** Running Evaluation *****
  Num examples = 2256
  Batch size = 4
Saving model checkpoint to test-waste/checkpoint-2031
Configuration saved in test-waste/checkpoint-2031/config.json
Model weights saved in test-waste/checkpoint-2031/pytorch_model.bin
Feature extractor saved in test-waste/checkpoint-2031/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 2256
  Batch size = 4
Saving model checkpoint to test-waste/checkpoint-4062
Configuration saved in test-waste/checkpoint-4062/config.json
Model weights saved in test-waste/checkpoint-4062/pytorch_model.bin
Feature extractor saved in test-waste/checkpoint-4062/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 2256
  Batch size = 4
Saving model checkpoint to test-waste/checkpoint-6093
Configuration saved in test-waste/checkpoint-6093/config.json
Model weights saved in test-waste/checkpoint-6093/pytorch_model.bin
Feature extractor saved in test-waste/checkpoint-6093/preprocessor_config.json



TrainOutput(global_step=6093, training_loss=0.17826861496259677, metrics={'train_runtime': 29273.476, 'train_samples_per_second': 2.081, 'train_steps_per_second': 0.208, 'total_flos': 4.721121975279403e+18, 'train_loss': 0.17826861496259677, 'epoch': 3.0})

In [49]:
outputs = trainer.predict(test_data)
print(outputs.metrics)

***** Running Prediction *****
  Num examples = 2513
  Batch size = 4


{'test_loss': 0.30848002433776855, 'test_accuracy': 0.9052924791086351, 'test_runtime': 188.9854, 'test_samples_per_second': 13.297, 'test_steps_per_second': 3.328}


Reference:
- https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb
- https://medium.com/@kenjiteezhen/image-classification-using-huggingface-vit-261888bfa19f