## Vision Transformer (ViT) Training with PyTorch

In this notebook, we train a Vision Transformer (ViT) model for real-time sign language detection using PyTorch. Albumentations handles data augmentation, torchvision supplies the pre-trained backbone, and `torch.nn` and `torch.optim` provide the remaining model and training components.

The training will involve the following steps:
1. Data augmentation and preprocessing
2. Loading the dataset
3. Defining the Vision Transformer model
4. Training the model
5. Saving the trained model


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
from albumentations import Compose, Normalize, Resize, RandomCrop, HorizontalFlip, RandomBrightnessContrast
from albumentations.pytorch import ToTensorV2
import os
import numpy as np  # Needed by AlbumentationsDataset to convert PIL images to arrays


### Data Augmentation

We use Albumentations to resize, crop, flip, adjust brightness and contrast, and normalize each image. The random augmentations expose the model to more variation during training, which tends to improve generalization. We also create a custom `AlbumentationsDataset` class that wraps torchvision's `ImageFolder` and applies these transformations to each image as it is loaded.


In [2]:
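# Define data augmentation pipeline
def get_augmentations():
    return Compose([
        Resize(224, 224),  # Resize images to the input size of ViT
        RandomCrop(224, 224),  # Same size as the resize above, so this crop leaves the image unchanged
        HorizontalFlip(),  # Mirror left/right with probability 0.5
        RandomBrightnessContrast(),  # Random brightness and contrast jitter
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # Scale pixel values to [-1, 1]
        ToTensorV2()  # HWC NumPy array -> CHW torch tensor
    ])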

class AlbumentationsDataset(torch.utils.data.Dataset):
    """Wrap an ImageFolder so that Albumentations transforms can be applied."""

    def __init__(self, image_folder, transform=None):
        # ImageFolder expects one subdirectory per class and yields (PIL image, label)
        self.image_folder = datasets.ImageFolder(image_folder)
        self.transform = transform

    def __len__(self):
        return len(self.image_folder)

    def __getitem__(self, idx):
        image, label = self.image_folder[idx]
        # Albumentations operates on NumPy arrays, not PIL images
        image = np.array(image)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label
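
To make the last two steps of the pipeline concrete, the sketch below reproduces their arithmetic with plain NumPy. It assumes Albumentations' default `max_pixel_value=255.0` for `Normalize`; the helper name `normalize_and_transpose` is hypothetical and not part of any library, and it returns a NumPy array rather than the torch tensor that `ToTensorV2` produces.

```python
import numpy as np

def normalize_and_transpose(image, mean=0.5, std=0.5, max_pixel_value=255.0):
    """Mimic Normalize followed by ToTensorV2 on an HWC uint8 image (illustrative only)."""
    scaled = image.astype(np.float32) / max_pixel_value  # [0, 255] -> [0, 1]
    normalized = (scaled - mean) / std                   # [0, 1]   -> [-1, 1]
    return normalized.transpose(2, 0, 1)                 # HWC -> CHW

# Synthetic 224x224 RGB image containing both extreme pixel values
image = np.zeros((224, 224, 3), dtype=np.uint8)
image[:112] = 255

result = normalize_and_transpose(image)
print(result.shape, result.min(), result.max())
```

With a mean and standard deviation of 0.5, black pixels map to -1 and white pixels map to 1, and the channel axis moves to the front as PyTorch models expect.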


### Data Loading

The `load_data` function builds the training and validation datasets and wraps each one in a PyTorch `DataLoader`. `ImageFolder` reads the images from class-named subdirectories, and the transformations are applied on-the-fly as each batch is assembled.


In [3]:
# Load and preprocess the dataset
def load_data(train_dir, val_dir, batch_size):
    train_transforms = get_augmentations()
    # Validation uses only deterministic preprocessing so that accuracy
    # is comparable from epoch to epoch.
    val_transforms = Compose([
        Resize(224, 224),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2()
    ])

    train_dataset = AlbumentationsDataset(train_dir, transform=train_transforms)
    val_dataset = AlbumentationsDataset(val_dir, transform=val_transforms)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader
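
As a quick illustration of what a loader yields, the sketch below batches a synthetic dataset instead of real images. The tensor sizes, the three classes, and the use of `TensorDataset` are assumptions made for the example; the notebook's own loaders read from `train_dir` and `val_dir`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in: 10 random "images" with labels drawn from 3 classes
images = torch.randn(10, 3, 224, 224)
labels = torch.randint(0, 3, (10,))
dataset = TensorDataset(images, labels)

# num_workers=0 keeps the sketch single-process
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
for image_shape, label_shape in batch_shapes:
    print(image_shape, label_shape)
```

Each batch is a pair of tensors, images of shape `(batch, 3, 224, 224)` and integer labels of shape `(batch,)`. The final batch is smaller when the dataset size is not a multiple of `batch_size`.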


### Model Definition

We build the model on torchvision's pre-trained ViT-B/16. The original classification head is replaced with an identity layer, so the backbone returns the 768-dimensional class-token embedding for each image. A small two-layer classifier then maps that embedding to our specific number of classes.


In [4]:
# Define the Vision Transformer model
class ViTModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # ImageNet-pretrained ViT-B/16 (the `pretrained=True` flag is deprecated)
        self.base_model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        self.base_model.heads = nn.Identity()  # Remove the original classification head
        self.classifier = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        # torchvision's ViT already selects the class (CLS) token before `heads`,
        # so the backbone output has shape (batch, 768); no extra indexing is needed.
        x = self.base_model(x)
        x = self.classifier(x)
        return x
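
The shapes flowing through the custom head can be checked without downloading any weights. In the sketch below, a random tensor stands in for the backbone output, which has shape `(batch, 768)` once torchvision's ViT head is an identity layer. The value `num_classes = 26` is a hypothetical choice for illustration only.

```python
import torch
import torch.nn as nn

num_classes = 26  # hypothetical; use the number of classes in your dataset

classifier = nn.Sequential(
    nn.Linear(768, 128),
    nn.ReLU(),
    nn.Linear(128, num_classes),
)

# Stand-in for the backbone output: one 768-dimensional embedding per image
features = torch.randn(2, 768)
logits = classifier(features)

head_params = sum(p.numel() for p in classifier.parameters())
print(tuple(logits.shape), head_params)
```

The head adds only about a hundred thousand parameters on top of the backbone, so nearly all of the model's capacity comes from the pre-trained ViT.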


### Model Training

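The `train_model` function takes the data loaders, builds the Vision Transformer model, and trains it with the Adam optimizer and cross-entropy loss. The model is evaluated on the validation set after each epoch, and the weights with the highest validation accuracy so far are saved as `best_vit_model.pth`. A `ReduceLROnPlateau` scheduler lowers the learning rate when the training loss stops improving.

**Note**: Pass the actual locations of your dataset as `train_dir` and `val_dir` to `load_data`, then hand the resulting loaders to `train_model`.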


In [5]:
# Main function to train the model
def train_model(train_loader, val_loader, num_classes, epochs=10, learning_rate=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTModel(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1, min_lr=1e-6)

    best_accuracy = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)  # Index of the highest logit per image
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = 100 * correct / total
        scheduler.step(epoch_loss)

        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Save the best model
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_vit_model.pth')

    print(f'Training complete. Best validation accuracy: {best_accuracy:.2f}%')
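
To see how the scheduler settings used in `train_model` behave, the sketch below feeds `ReduceLROnPlateau` a loss that never improves. The dummy parameter and the constant loss of 1.0 are assumptions made purely for illustration; no model is trained.

```python
import torch
import torch.optim as optim

# A single dummy parameter is enough to construct an optimizer
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', patience=3, factor=0.1, min_lr=1e-6
)

lrs = []
for epoch in range(10):
    optimizer.step()     # no gradients exist, so no parameter is updated
    scheduler.step(1.0)  # report a loss that never improves
    lrs.append(optimizer.param_groups[0]['lr'])

for epoch, lr in enumerate(lrs, start=1):
    print(f'Epoch {epoch}: lr = {lr:.1e}')
```

With `patience=3`, the scheduler tolerates three epochs without improvement and cuts the learning rate by `factor` on the next one, never going below `min_lr`. In this run the rate should therefore fall by a factor of ten at the fifth epoch and again at the ninth.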
