## Import packages

In [None]:
from transformers import ViTForImageClassification, ViTFeatureExtractor
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm



## Pre-trained Vision Transformer (ViT) Model Setup for Image Classification

In [29]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained Vision Transformer model
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224', 
    num_labels=45,  
    ignore_mismatched_sizes=True  # Ignore size mismatches in the classifier
)

model.to(device)

# Load the feature extractor (for resizing and normalizing the input images)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([45]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([45, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Dataset Preparation and DataLoader Setup for ViT Image Classification

In [52]:

# Define image transforms for train and validation sets
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the images to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)  # Use ViT's mean and std
])

# Load the training and validation datasets
train_data_dir = 'NWPU-RESISC45-classification/train'
test_data_dir = 'NWPU-RESISC45-classification/test'

train_dataset = datasets.ImageFolder(train_data_dir, transform=transform)
test_dataset = datasets.ImageFolder(test_data_dir, transform=transform)

# Create DataLoaders for batching the data
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


## Optimizer and Loss Function Setup for ViT Training


In [53]:
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=2e-5)  # Use Adam with a small learning rate
criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for multi-class classification


## Training Loop for Vision Transformer Model

In [54]:
def train_model(model, train_loader, val_loader, num_epochs=5):
    for epoch in range(num_epochs):
        print(f"Starting Epoch {epoch+1}/{num_epochs}")
        model.train()  # Set the model to training mode
        running_loss = 0.0
        correct = 0
        total = 0

        # Use tqdm for progress bar visualization
        for batch_idx, (inputs, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            loss.backward()  # Backpropagation
            optimizer.step()  # Update the weights

            # Track loss and accuracy
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Calculate accuracy and loss for this epoch
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f} Accuracy: {epoch_acc:.2f}%")

        # Validate the model at the end of each epoch
        validate_model(model, val_loader)


In [None]:
# Train the model for 5 epochs 
train_model(model, train_loader, val_loader, num_epochs=5)


Starting Epoch 1/5


Epoch 1/5: 100%|████████████████████████████| 788/788 [4:41:25<00:00, 21.43s/it]


Epoch [1/5] Loss: 0.8180 Accuracy: 84.06%
Validation Accuracy: 93.51%

Starting Epoch 2/5


Epoch 2/5:  90%|█████████████████████████   | 707/788 [4:08:00<24:02, 17.81s/it]

In [48]:
# Define the path where the model will be saved
save_path = 'vit_model.pth'

# Save the trained model
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

# Load the model back (if needed)
model.load_state_dict(torch.load(save_path, weights_only=True))

model.to(device)  # Ensure the model is on the appropriate device


Model saved to vit_model.pth


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

## Optimize the model

In [None]:
import torch.quantization as quantization

def quantize_model(model):
    model.eval()  # Ensure model is in evaluation mode
    model.qconfig = quantization.default_qconfig  # Define quantization configuration
    print("Preparing the model for quantization...")
    quantization.prepare(model, inplace=True)

    # Dummy input data to calibrate the quantized model
    input_tensor = torch.randn(1, 3, 224, 224)
    model(input_tensor)

    print("Converting model to quantized version...")
    quantization.convert(model, inplace=True)

    return model

# Quantize the trained model
quantized_model = quantize_model(model)

# Save the quantized model
torch.save(quantized_model.state_dict(), 'quantized_vit_model.pth')
print("Quantized model saved to quantized_vit_model.pth")


## Model Pruning:

In [None]:
import torch.nn.utils.prune as prune

def prune_model(model, amount=0.2):
    # Define pruning on specific layers (for example, the ViT classifier layer)
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)  # Apply L1-norm unstructured pruning

    return model

# Prune 20% of the weights
pruned_model = prune_model(model, amount=0.2)

# Save the pruned model
torch.save(pruned_model.state_dict(), 'pruned_vit_model.pth')
print("Pruned model saved to pruned_vit_model.pth")
