## Vision Transformer (ViT) Fine-Tuning for Image Classification

The Vision Transformer (ViT) model was proposed in *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale* by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. It is the first paper to successfully train a Transformer encoder on ImageNet, attaining very good results compared to familiar convolutional architectures.

- For more details, refer to: https://huggingface.co/docs/transformers/en/model_doc/vit
- Original Paper: https://arxiv.org/abs/2010.11929

About this Notebook
```
Developer: Chintan Patel
Date: January 2025
Description: This notebook demonstrates the fine-tuning of a Vision Transformer (ViT) model for image classification using the Hugging Face Transformers library. The model is trained on a custom dataset of images, and the training process includes evaluation and saving the best model.
```

In [15]:
# Import necessary libraries
import os
import warnings

import torch
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    Trainer,
    TrainingArguments,
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

In [16]:
# Training configuration shared by the cells below
learning_rate = 2e-4
train_batch_size = 32
eval_batch_size = 16
seed = 42
num_epochs = 20
num_classes = 9
root_folder = r"D:\Chintan\AI_capstone\Data\Images_300"  # scanned below: one sub-folder per class

# Seed PyTorch's CPU and CUDA random number generators for repeatable runs
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
# (This only picks the device; mixed precision is configured in TrainingArguments.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
# Load the image preprocessor published with the pretrained checkpoint.
# ViTImageProcessor replaces ViTFeatureExtractor, which is deprecated (the
# deprecation notice is hidden here because warnings are suppressed above).
# The variable keeps its original name because later cells refer to it.
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# Custom Dataset class for handling image data
class CustomImageDataset(Dataset):
    """
    Map-style dataset that loads an image from disk and preprocesses it on access.

    Args:
        image_paths (list): Paths to the image files.
        labels (list): Integer class label for each entry of ``image_paths``.
        feature_extractor: Preprocessor invoked as
            ``feature_extractor(images=image, return_tensors="pt")``; its result
            must expose a ``'pixel_values'`` tensor.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """
    def __init__(self, image_paths, labels, feature_extractor):
        # A length mismatch would otherwise surface much later, as an
        # IndexError in the middle of training or as silently unused labels.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.feature_extractor = feature_extractor

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load and preprocess the sample at position ``idx``.

        Returns:
            dict: ``'pixel_values'`` (the preprocessed image tensor with the
            leading batch axis removed) and ``'labels'`` (a scalar
            ``torch.long`` tensor).
        """
        # The context manager releases the file handle as soon as the pixels
        # are read; convert() returns an independent in-memory RGB image.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        # Use the feature extractor to preprocess the image
        encoding = self.feature_extractor(images=image, return_tensors="pt")

        # Return the processed data as a dictionary
        return {
            # squeeze(0) drops only the batch axis added for the single image;
            # a bare squeeze() would also collapse any other size-1 dimension.
            'pixel_values': encoding['pixel_values'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long),
        }

# Collect image paths and integer labels: each sub-folder of root_folder is one class
image_paths = []
labels = []

# Only real directories are classes: a stray file in root_folder must not
# receive a class index. Sorting pins the class -> index mapping, because
# os.listdir() returns entries in arbitrary, OS-dependent order.
disease_folders = sorted(
    entry for entry in os.listdir(root_folder)
    if os.path.isdir(os.path.join(root_folder, entry))
)
class_map = {disease: idx for idx, disease in enumerate(disease_folders)}

for disease in disease_folders:
    disease_folder_path = os.path.join(root_folder, disease)
    # Sorted so the sample order, and therefore the seeded split below, is
    # the same on every run and every machine.
    for img_name in sorted(os.listdir(disease_folder_path)):
        img_path = os.path.join(disease_folder_path, img_name)
        # Nested directories cannot be opened as images, so skip them.
        if os.path.isfile(img_path):
            image_paths.append(img_path)
            labels.append(class_map[disease])

# Split into train and validation sets (80/20 split)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    image_paths, labels, test_size=0.2, random_state=seed
)



In [18]:
# Wrap the train / validation splits in dataset objects
train_dataset = CustomImageDataset(
    image_paths=train_paths,
    labels=train_labels,
    feature_extractor=feature_extractor,
)
val_dataset = CustomImageDataset(
    image_paths=val_paths,
    labels=val_labels,
    feature_extractor=feature_extractor,
)

# Instantiate the pretrained ViT backbone with a classification head of
# num_classes outputs
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
)

# Place the model on the selected device (GPU when available)
model.to(device)

# Sanity check: report the types of the objects handed to the Trainer later
print(f"Train Dataset Type: {type(train_dataset)}")
print(f"Eval Dataset Type: {type(val_dataset)}")
print(f"Model Type: {type(model)}")

# Record the class-name <-> index mappings in the model config so they are
# available through model.config when the model is reloaded
model.config.label2id = class_map
model.config.id2label = {index: name for name, index in class_map.items()}

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train Dataset Type: <class '__main__.CustomImageDataset'>
Eval Dataset Type: <class '__main__.CustomImageDataset'>
Model Type: <class 'transformers.models.vit.modeling_vit.ViTForImageClassification'>


In [19]:
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=num_epochs,     # number of training epochs
    per_device_train_batch_size=train_batch_size,  # batch size for training
    per_device_eval_batch_size=eval_batch_size,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,                # log every 10 steps
    evaluation_strategy="epoch",     # evaluate every epoch
    save_strategy="epoch",           # save checkpoint every epoch (must match the evaluation strategy)
    load_best_model_at_end=True,     # load the best model when finished training
    metric_for_best_model="accuracy",  # use accuracy to determine the best model
    greater_is_better=True,          # higher accuracy is better
    seed=seed,                       # seed applied by the Trainer
    # fp16 mixed precision requires a CUDA device; enabling it unconditionally
    # makes TrainingArguments raise on CPU-only machines, so enable it only
    # when a GPU is actually available.
    fp16=torch.cuda.is_available(),
    lr_scheduler_type="linear",      # linear decay after warmup
    learning_rate=learning_rate,     # peak learning rate
)

# Define the metric computation function
def compute_metrics(p):
    """
    Score one evaluation pass by classification accuracy.

    Args:
        p (EvalPrediction): Carries ``predictions`` (per-class scores, reduced
            with argmax over axis 1) and ``label_ids`` (reference labels).

    Returns:
        dict: ``{"accuracy": <accuracy score>}``.
    """
    # The highest-scoring class along axis 1 is the predicted label
    predicted_classes = p.predictions.argmax(axis=1)
    return {"accuracy": accuracy_score(p.label_ids, predicted_classes)}



In [20]:
# Sanity check: show the type of the arguments object handed to the Trainer
args_type = type(training_args)
print(f"Training Arguments Type: {args_type}")

TrainingArguments Type: <class 'transformers.training_args.TrainingArguments'>


In [22]:
# Assemble everything the Trainer needs: model, configuration, data, metric
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**trainer_kwargs)

# Run the fine-tuning loop
trainer.train()

# Write the model currently held by the trainer to its own directory
# (with load_best_model_at_end=True this is the best checkpoint's weights)
trainer.save_model('vit_finetuned_best_model')

# Report which checkpoint the trainer recorded as the best one
best_checkpoint = trainer.state.best_model_checkpoint
print(f"Best Model Checkpoint: {best_checkpoint}")

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss,Accuracy
1,2.0271,1.912062,0.542593
2,1.2698,1.21289,0.67963
3,0.9036,0.96793,0.690741
4,0.616,0.873984,0.705556
5,0.4861,0.946642,0.688889
6,0.4021,0.908524,0.705556
7,0.4218,1.042436,0.666667
8,0.3312,0.988834,0.692593
9,0.17,1.116554,0.690741
10,0.1902,1.184392,0.696296




Best Model Checkpoint: ./results\checkpoint-1156


In [23]:
# Write the feature extractor's configuration to the local capstone directory
capstone_dir = r'D:\Chintan\AI_capstone'
feature_extractor.save_pretrained(capstone_dir)

['D:\\Chintan\\AI_capstone\\preprocessor_config.json']

In [24]:
# Evaluate on the validation set, then both print and persist the metrics
# under the "eval" split name
metrics = trainer.evaluate()
for report in (trainer.log_metrics, trainer.save_metrics):
    report("eval", metrics)



***** eval metrics *****
  epoch                   =       20.0
  eval_accuracy           =     0.7444
  eval_loss               =     1.3149
  eval_runtime            = 0:00:07.78
  eval_samples_per_second =     69.366
  eval_steps_per_second   =      4.368


In [26]:
# Load the fine-tuned model and the matching image preprocessor
model_path = './vit_finetuned_best_model'  # Path to the fine-tuned model
model = ViTForImageClassification.from_pretrained(model_path)
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor. The
# preprocessing config is taken from the same base checkpoint that was used
# for training.
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# Ensure that model is in evaluation mode (kept as the last expression so the
# notebook still displays the module summary)
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [36]:
# Load your image
image_path = "psoriasis1.jpg"  # Test image of psoriasis
# The context manager closes the file handle promptly; convert() returns an
# independent in-memory RGB image.
with Image.open(image_path) as img:
    image = img.convert("RGB")  # Ensure RGB format

# Preprocess the image and move the tensors to the model's device, so the
# cell works whether the model in scope lives on the CPU or the GPU.
inputs = feature_extractor(images=image, return_tensors="pt").to(model.device)

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

# Convert the logits of the single image in the batch to probabilities
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
# torch.topk raises when k exceeds the number of classes, so cap it
top_k = min(5, probabilities.numel())
top5_probs, top5_indices = torch.topk(probabilities, top_k)

# Get the class label mapping from the model config
id2label = model.config.id2label  # This contains the mapping from index to class label

# Print the top predictions with labels (five, unless the model has fewer classes)
print(f"Top {top_k} predictions:")
for rank, (prob, idx) in enumerate(zip(top5_probs, top5_indices), start=1):
    label = id2label[idx.item()]  # Directly access the label using the integer index
    print(f"{rank}. {label}: {prob.item()*100:.2f}%")


Top 5 predictions:
1. psoriasis: 98.87%
2. f_infection: 0.51%
3. eczema: 0.15%
4. alopecia: 0.09%
5. skincancer: 0.09%
