<a href="https://colab.research.google.com/github/NoCodeProgram/deepLearning/blob/main/transformer/vitTransfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
import torch.nn as nn
import torch.optim as optim

In [2]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [3]:
# Reproducibility: seed PyTorch's RNG before anything random happens, so the
# randomly initialised classifier head and the shuffled batch order repeat on
# every Restart & Run All.
# NOTE(review): some GPU kernels may still be non-deterministic; this does not
# promise bit-identical results across runs.
seed = 42
torch.manual_seed(seed)

# Hyperparameters
num_epochs = 10        # number of full passes over the training set
batch_size = 64        # samples per mini-batch, used by both data loaders
learning_rate = 0.001  # learning rate passed to the optimizer

In [4]:
# Data preprocessing
# Every image is resized, converted to a tensor and then normalised per channel.
transform = transforms.Compose([
    # ViT requires 224x224 input. An explicit (height, width) pair guarantees
    # that shape; a bare int would only rescale the shorter edge and leave
    # non-square images non-square.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Per-channel mean/std; these are the usual ImageNet statistics, matching
    # the ImageNet-pretrained weights this notebook loads for the model.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [5]:
# CIFAR10 training split (downloaded into ./data if needed) and its loader.
# Training batches are reshuffled every epoch.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# CIFAR10 test split and its loader. Evaluation keeps the dataset order.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Build ViT-B/16 initialised with the IMAGENET1K_V1 pre-trained weights.
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Transfer-learning setup: only parameters that belong to the classification
# head (model.heads) keep gradients; every other parameter is frozen.
head_param_ids = {id(head_param) for head_param in model.heads.parameters()}
for param in model.parameters():
    param.requires_grad = id(param) in head_param_ids


In [7]:
# Replace the final classification layer with a fresh linear layer that has
# one output per CIFAR10 class (10). It keeps the input width of the layer it
# replaces, and a newly created nn.Linear is trainable by default.
num_features = model.heads.head.in_features
model.heads.head = nn.Linear(in_features=num_features, out_features=10)

# Move the whole model, including the new layer, to the selected device.
model = model.to(device)


In [8]:
# Loss and optimizer
# Cross-entropy between the model's class scores and the integer class labels.
criterion = nn.CrossEntropyLoss()

# Give Adam only the parameters that still require gradients. Frozen
# parameters never receive a gradient, so registering them with the optimizer
# adds nothing and hides which weights are actually being trained.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [None]:
print(f'Training on {device}')

log_every = 100                      # batches between training progress reports
steps_per_epoch = len(train_loader)  # number of batches in one training epoch

for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Clear old gradients, then forward pass, backward pass, weight update
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # running_loss covers the current reporting window only;
        # correct/total accumulate over the whole epoch so far
        running_loss += loss.item()
        predicted = outputs.max(1)[1]  # index of the highest score per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Report the window's mean loss and the epoch-to-date accuracy
        if (batch_idx + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{steps_per_epoch}], '
                  f'Loss: {running_loss/log_every:.4f}, Acc: {100.*correct/total:.2f}%')
            running_loss = 0.0

    # ---- Validation phase ----
    # Evaluate on the test loader in eval mode with gradient tracking disabled
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.max(1)[1]
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Epoch [{epoch+1}/{num_epochs}] Test Accuracy: {100.*correct/total:.2f}%')

print('Training finished!')

Training on cuda
Epoch [1/10], Step [100/782], Loss: 0.5673, Acc: 86.89%
Epoch [1/10], Step [200/782], Loss: 0.2328, Acc: 89.91%
Epoch [1/10], Step [300/782], Loss: 0.1948, Acc: 91.24%
Epoch [1/10], Step [400/782], Loss: 0.1743, Acc: 92.10%
Epoch [1/10], Step [500/782], Loss: 0.1767, Acc: 92.49%
Epoch [1/10], Step [600/782], Loss: 0.1631, Acc: 92.88%
Epoch [1/10], Step [700/782], Loss: 0.1669, Acc: 93.16%
Epoch [1/10] Test Accuracy: 94.83%
Epoch [2/10], Step [100/782], Loss: 0.1397, Acc: 95.47%
Epoch [2/10], Step [200/782], Loss: 0.1439, Acc: 95.48%
Epoch [2/10], Step [300/782], Loss: 0.1355, Acc: 95.54%
Epoch [2/10], Step [400/782], Loss: 0.1413, Acc: 95.47%
Epoch [2/10], Step [500/782], Loss: 0.1383, Acc: 95.46%
Epoch [2/10], Step [600/782], Loss: 0.1364, Acc: 95.46%
Epoch [2/10], Step [700/782], Loss: 0.1409, Acc: 95.48%
Epoch [2/10] Test Accuracy: 95.17%
Epoch [3/10], Step [100/782], Loss: 0.1225, Acc: 96.06%
Epoch [3/10], Step [200/782], Loss: 0.1192, Acc: 96.12%
Epoch [3/10], Ste