<a href="https://colab.research.google.com/github/Maximilianwte/Image-tutorial/blob/main/Image_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --quiet transformers datasets optuna

#### Classification

In [None]:
from datasets import load_dataset, Dataset, Image

# Root folder of the labelled images, read with the generic "imagefolder" builder.
# The resulting DatasetDict has train/test splits with 'image' and 'label' columns
# (see the output below); presumably one sub-folder per class — confirm layout.
DATASET_DIR = "Path/to/Dataset"
dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)

In [None]:
# Show the split names, columns and row counts of the loaded DatasetDict.
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 149
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 40
    })
})

In [None]:
from transformers import AutoImageProcessor
import torch

# Pre-trained ViT checkpoint that is fine-tuned below.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# ViTFeatureExtractor is deprecated in favour of image processors; AutoImageProcessor
# resolves to the matching ViT processor (the Inference section loads it the same way).
# The variable keeps the name `feature_extractor` because later cells refer to it.
feature_extractor = AutoImageProcessor.from_pretrained(model_name_or_path)

In [None]:
def transform_without_augmentation(example_batch):
    """Convert a batch of dataset rows into model inputs on the fly.

    Args:
        example_batch: dict of column lists as passed by ``Dataset.with_transform``;
            reads the ``'image'`` column (decoded PIL images) and the ``'label'`` column.

    Returns:
        The image processor's output (``pixel_values`` as a PyTorch tensor) with an
        added ``'labels'`` entry holding the batch's labels.
    """
    # The ViT processor normalises 3-channel images; greyscale, RGBA or CMYK files
    # would make the batch fail, so convert every image to RGB first.
    images = [image.convert('RGB') for image in example_batch['image']]
    inputs = feature_extractor(images, return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs


# Apply the transform lazily whenever rows are accessed (no up-front preprocessing pass).
prepared_ds = dataset.with_transform(transform_without_augmentation)

In [None]:
def collate_fn(batch):
    """Merge a list of per-example dicts into one training batch for the Trainer.

    Every element must carry a ``'pixel_values'`` tensor (all of identical shape)
    and an integer ``'labels'`` value. The result holds the pixel tensors stacked
    along a new leading dimension and the labels as a 1-D tensor.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    label_ids = [example['labels'] for example in batch]
    return {'pixel_values': pixel_values, 'labels': torch.tensor(label_ids)}

In [None]:
# For classification
from transformers import ViTForImageClassification
import numpy as np
def compute_metrics(p):
    """Accuracy of the arg-max predictions, in the dict format ``Trainer`` expects.

    Computed directly with numpy instead of ``datasets.load_metric("accuracy")``:
    ``load_metric`` was removed from ``datasets`` (3.0), so the original cell failed
    with an ImportError on a fresh, unpinned install.

    Args:
        p: an ``EvalPrediction``-like object with ``predictions`` (2-D logits,
            examples x classes, or a tuple whose first item is those logits) and
            ``label_ids`` (the integer class id of each example).

    Returns:
        ``{"accuracy": float}``; the Trainer prefixes the key (``eval_accuracy``).
    """
    # Some models return extra outputs; the logits are then the first tuple item.
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    predictions = np.argmax(logits, axis=1)
    return {"accuracy": float(np.mean(predictions == np.asarray(p.label_ids)))}

# Class names as stored on the training split's label feature.
labels = dataset['train'].features['label'].names
# Label maps stored in the model config; indices are stringified here.
id2label = {str(index): name for index, name in enumerate(labels)}
label2id = {name: index for index, name in id2label.items()}

# Load the backbone with a newly initialised classification head (see the warning
# printed below) sized to the number of classes, and move it to the GPU.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
).to('cuda')


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from transformers import EarlyStoppingCallback

# Stop training once the monitored metric stops improving. The TrainingArguments below
# set load_best_model_at_end=True without metric_for_best_model, so the monitored
# metric defaults to the evaluation loss.
early_stopping_callback = EarlyStoppingCallback(
    # Number of consecutive *evaluations* without improvement before stopping.
    # Evaluation runs every `eval_steps` optimizer steps (eval strategy "steps"),
    # so this is NOT a number of epochs.
    early_stopping_patience=3,
    # An evaluation only counts as an improvement if the loss beats the best value
    # seen so far by more than 0.01 (absolute).
    early_stopping_threshold=0.01,
)
# NOTE(review): the callback keeps its patience counter on the instance, so a single
# instance should not be shared by several Trainer runs (e.g. Optuna trials) — confirm
# for your transformers version and prefer a new instance per Trainer.

In [None]:
# OPTUNA
import optuna
from transformers import TrainingArguments
from transformers import Trainer


def _trial_model_init():
    """Return a fresh pre-trained ViT classifier with a newly initialised head.

    Passed to ``Trainer(model_init=...)`` so every trial starts from the same
    pre-trained weights instead of continuing from the previous trial's model,
    and so the head initialisation is seeded by ``TrainingArguments.seed``.
    """
    return ViTForImageClassification.from_pretrained(
        model_name_or_path,
        num_labels=len(labels),
        id2label={str(i): c for i, c in enumerate(labels)},
        label2id={c: str(i) for i, c in enumerate(labels)},
    )


def objective(trial):
    """Fine-tune one model with hyper-parameters sampled by ``trial``.

    Samples the learning rate (log-uniform in [1e-5, 1e-3]) and the per-device
    batch size, trains a fresh model for up to 12 epochs with early stopping, and
    returns the accuracy on ``prepared_ds['test']`` for Optuna to maximise.
    Each trial writes its checkpoints and final model to its own sub-directory.
    """
    # The learning-rate range spans two orders of magnitude: sample on a log scale.
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])

    training_args = TrainingArguments(
        # Per-trial directory so checkpoint rotation / load_best_model_at_end never
        # see checkpoints written by other trials.
        output_dir=f"./vit-model/trial-{trial.number}",
        per_device_train_batch_size=batch_size,
        eval_strategy="steps",  # formerly `evaluation_strategy`, removed in newer transformers
        num_train_epochs=12,
        fp16=True,
        save_steps=30,
        eval_steps=30,
        logging_steps=10,
        learning_rate=lr,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=False,
        report_to='tensorboard',
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model_init=_trial_model_init,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_ds['train'],
        eval_dataset=prepared_ds['test'],
        processing_class=feature_extractor,  # replaces the deprecated `tokenizer=` argument
        # The callback stores its patience counter on the instance, so each trial gets
        # its own copy configured like the shared `early_stopping_callback`.
        callbacks=[EarlyStoppingCallback(
            early_stopping_patience=early_stopping_callback.early_stopping_patience,
            early_stopping_threshold=early_stopping_callback.early_stopping_threshold,
        )],
    )

    train_results = trainer.train()
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()

    # NOTE(review): the test split doubles as the validation set, so hyper-parameters
    # are selected on the same data that is later used to report accuracy.
    metrics = trainer.evaluate(prepared_ds['test'])

    # Release this trial's model before the next trial allocates a new one.
    del trainer
    torch.cuda.empty_cache()
    return metrics['eval_accuracy']

In [None]:
# Maximise the test accuracy returned by `objective`. The TPE sampler is seeded so
# the sequence of sampled hyper-parameters is reproducible across runs.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=42),
)
study.optimize(objective, n_trials=30)

study.best_params

In [None]:
# Finally train a model with the best hyper-parameters found by Optuna. Use the live
# study when the search cell has run in this kernel; otherwise fall back to the values
# recorded from an earlier search, so this cell can also run on its own.
RECORDED_BEST_PARAMS = {'lr': 0.0008027780105461245, 'batch_size': 64}
best_params = study.best_params if 'study' in globals() else RECORDED_BEST_PARAMS


def _final_model_init():
    """Fresh pre-trained ViT with a new head, so the final run does not continue
    from weights left over by earlier training cells (e.g. the Optuna trials)."""
    return ViTForImageClassification.from_pretrained(
        model_name_or_path,
        num_labels=len(labels),
        id2label={str(i): c for i, c in enumerate(labels)},
        label2id={c: str(i) for i, c in enumerate(labels)},
    )


training_args = TrainingArguments(
    # Separate from the search's checkpoints so rotation/best-model loading ignore them.
    output_dir="./vit-model/final",
    per_device_train_batch_size=best_params['batch_size'],
    eval_strategy="steps",  # formerly `evaluation_strategy`, removed in newer transformers
    num_train_epochs=20,
    fp16=True,
    save_steps=5,
    eval_steps=5,
    logging_steps=10,
    learning_rate=best_params['lr'],
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

trainer = Trainer(
    model_init=_final_model_init,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds['train'],
    eval_dataset=prepared_ds['test'],
    processing_class=feature_extractor,  # replaces the deprecated `tokenizer=` argument
    # New callback instance: a shared one may carry patience state from earlier runs.
    callbacks=[EarlyStoppingCallback(
        early_stopping_patience=early_stopping_callback.early_stopping_patience,
        early_stopping_threshold=early_stopping_callback.early_stopping_threshold,
    )],
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
# Evaluate the final trainer's model (the best checkpoint, since load_best_model_at_end=True)
# on the test split, then print the metrics and save them under the split name "eval".
metrics = trainer.evaluate(prepared_ds['test'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


In [None]:
# Persist the fine-tuned model (plus the processor passed to the Trainer) to Google Drive.
# NOTE(review): assumes Drive is already mounted at /content/drive (not done in this notebook).
# NOTE(review): the Inference section loads ".../model-28042023", not this "model-04052023"
# directory — update one of the paths if this newly trained model is the one to use.
trainer.save_model("/content/drive/MyDrive/Research/22 Brand Fit Project/Classifier/GoodImage/model-04052023")

### Inference

In [None]:
from PIL import Image
from transformers import AutoImageProcessor
from transformers import ViTForImageClassification

# Preprocess with the same base checkpoint used for training.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)

# Fine-tuned classifier previously saved to Google Drive.
# NOTE(review): this loads "model-28042023", while the training section saves to
# "model-04052023" — confirm which model is intended.
FINE_TUNED_MODEL_DIR = "/content/drive/MyDrive/Research/22 Brand Fit Project/Classifier/GoodImage/model-28042023"
model = ViTForImageClassification.from_pretrained(FINE_TUNED_MODEL_DIR)



In [None]:
import glob

# Test images to classify. glob returns paths in arbitrary, filesystem-dependent order,
# so sort them to make an index such as `i` always select the same file.
images_test = sorted(glob.glob("Path/To/TestData/*.jpg"))

In [None]:
# Index of the test image to inspect; the next cell classifies the same image via `i`.
i = 4
# Display the selected image (the cell's last expression is rendered inline).
Image.open(images_test[i])

In [None]:
# Classify the selected test image with the fine-tuned model.
# convert('RGB'): the processor expects 3-channel input, but JPEGs can be greyscale
# or CMYK. The `with` block closes the underlying file handle once the image is loaded.
with Image.open(images_test[i]) as raw_image:
    image = raw_image.convert('RGB')
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():  # inference only, no autograd graph needed
    logits = model(**inputs).logits

# A single image was passed, so argmax yields one class id.
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])