In [1]:
# Cell 1: Import necessary libraries
import os

import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from transformers import ViTFeatureExtractor, ViTImageProcessor, ViTModel


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Configuration: data locations and hyperparameters.
# Set the DATA_ROOT environment variable to use a different copy of the split data;
# the default is the original local path.
data_root = os.environ.get("DATA_ROOT", "E:/MIT AOE/BTech/Deep Learning/Encoder Decoder/splitted_data")
train_data_dir = os.path.join(data_root, "train")
test_data_dir = os.path.join(data_root, "val")

# Count only sub-directories, the same way ImageFolder discovers classes.
# os.listdir would also count stray files (e.g. desktop.ini) and inflate the head size.
with os.scandir(train_data_dir) as entries:
    num_classes = sum(1 for entry in entries if entry.is_dir())

batch_size = 32
image_size = (224, 224)  # ViT expects 224x224 resolution

SEED = 42
# Seed torch's global RNG: it drives the classifier-head init and the DataLoader shuffle order.
torch.manual_seed(SEED)


In [None]:
# Preprocessor for the same checkpoint as the backbone.
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor and accepts the same call
# arguments. The variable name `feature_extractor` is kept because the dataset code calls it.
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Resize, then convert to a float tensor in [0, 1].
# Any further preprocessing is left to `feature_extractor`.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),  # Converts the image to a tensor in range [0, 1]
])




In [None]:
class CustomImageDataset(Dataset):
    """ImageFolder-backed dataset that yields `(pixel_values, label)` pairs for a ViT.

    Images are loaded with `torchvision.datasets.ImageFolder` (one sub-directory per class).
    `transform` is applied first, then an image processor, whose `'pixel_values'` output
    is returned.

    Args:
        root_dir: directory whose sub-directories are the classes.
        transform: optional torchvision transform, applied by ImageFolder.
        processor: image processor called on each image. Defaults to the notebook-level
            `feature_extractor`, which is looked up at access time.
    """

    def __init__(self, root_dir, transform=None, processor=None):
        self.dataset = datasets.ImageFolder(root=root_dir, transform=transform)
        self.transform = transform  # kept for introspection; ImageFolder is what applies it
        self.processor = processor
        # Expose the label mapping so predicted indices can be mapped back to folder names.
        self.classes = self.dataset.classes
        self.class_to_idx = self.dataset.class_to_idx

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        processor = self.processor if self.processor is not None else feature_extractor
        # Float tensors (e.g. from transforms.ToTensor) are already scaled to [0, 1],
        # so rescaling again would be wrong.
        # PIL images (transform=None) and integer tensors still hold 0-255 values
        # and must be rescaled by the processor.
        already_scaled = torch.is_tensor(img) and img.is_floating_point()
        pixel_values = processor(images=img, return_tensors="pt", do_rescale=not already_scaled)['pixel_values']
        return pixel_values[0], label


# Initialize datasets from the class-per-folder layout.
train_dataset = CustomImageDataset(train_data_dir, transform=transform)
test_dataset = CustomImageDataset(test_data_dir, transform=transform)

# Both splits must share one label mapping, and the classifier head (sized from
# num_classes) must match it. Otherwise label indices would silently mean different classes.
assert train_dataset.dataset.class_to_idx == test_dataset.dataset.class_to_idx, \
    "train and val folders contain different class sub-directories"
assert len(train_dataset.dataset.classes) == num_classes, \
    f"num_classes={num_classes} but ImageFolder found {len(train_dataset.dataset.classes)} classes"

# Create dataloaders; only the training data needs shuffling.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
class ViTClassifier(nn.Module):
    """Linear classification head on top of a pretrained ViT backbone.

    By default the backbone is frozen: its parameters have `requires_grad=False`, and it
    stays in eval mode even when `model.train()` is called. Only `self.classifier` learns.

    Args:
        num_classes: number of output logits.
        model_name: Hugging Face checkpoint to load the backbone from.
        freeze_base: freeze the backbone (default True). If False, the backbone is
            trained together with the head.
    """

    def __init__(self, num_classes, model_name='google/vit-base-patch16-224', freeze_base=True):
        super().__init__()
        # NOTE: the checkpoint ships no pooler weights (see the load warning), so the pooler
        # is randomly initialised. It is unused: forward() reads the CLS token directly.
        self.base_model = ViTModel.from_pretrained(model_name)
        self.freeze_base = freeze_base
        if freeze_base:
            for param in self.base_model.parameters():
                param.requires_grad = False
            # eval() does not freeze anything by itself; it only switches train-time
            # behaviour off. train() below keeps it that way.
            self.base_model.eval()
        self.classifier = nn.Linear(self.base_model.config.hidden_size, num_classes)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode."""
        super().train(mode)
        if self.freeze_base:
            # Without this, model.train() would switch the frozen backbone back to train mode.
            self.base_model.eval()
        return self

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        `x` is the pixel_values batch produced by the image processor
        (assumed (batch, 3, 224, 224) — TODO confirm for other checkpoints).
        """
        outputs = self.base_model(pixel_values=x)
        cls_output = outputs.last_hidden_state[:, 0, :]  # CLS token output
        return self.classifier(cls_output)

# Build the classifier and place it on the GPU when one is available, otherwise on the CPU.
model = ViTClassifier(num_classes=num_classes).to(
    torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Give Adam only the parameters that require gradients.
# With the frozen backbone, that is just the classifier head.
# NOTE(review): 3e-5 is a fine-tuning-scale learning rate. A freshly initialised linear
# head usually tolerates a much larger one, so it is worth tuning.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=3e-5)
criterion = nn.CrossEntropyLoss()


In [None]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    """Train `model` for `num_epochs` passes over `train_loader`, printing each epoch's loss.

    Batches are moved to the device of the model's parameters, so this works wherever the
    model was placed.

    Args:
        model: the `nn.Module` to train; it is put into train mode here.
        train_loader: iterable of `(images, labels)` batches.
        criterion: loss function called as `criterion(outputs, labels)`.
        optimizer: optimizer over the trainable parameters of `model`.
        num_epochs: number of passes over `train_loader`.

    Returns:
        List with the mean per-sample training loss of each epoch (`nan` for an epoch
        that saw no samples).
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch doesn't skew the epoch mean.
            # Assumes the criterion averages over the batch (nn.CrossEntropyLoss does by default).
            batch_n = labels.size(0)
            running_loss += loss.item() * batch_n
            seen += batch_n

        epoch_loss = running_loss / seen if seen else float('nan')
        history.append(epoch_loss)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    return history

# Train the classifier for five epochs over the training split.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=5,
)


Epoch [1/5], Loss: 1.2884
Epoch [2/5], Loss: 0.9765
Epoch [3/5], Loss: 0.8126
Epoch [4/5], Loss: 0.7056
Epoch [5/5], Loss: 0.6311


In [None]:
def evaluate_model(model, test_loader):
    """Compute, print and return the top-1 accuracy (in percent) of `model` on `test_loader`.

    The model is put into eval mode and run without gradient tracking. Batches are moved
    to the device of the model's parameters.

    Args:
        model: the `nn.Module` to evaluate.
        test_loader: iterable of `(images, labels)` batches.

    Returns:
        Accuracy in percent, or `nan` if the loader yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else float('nan')
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Measure accuracy of the trained classifier on the held-out validation split.
evaluate_model(model=model, test_loader=test_loader)


Test Accuracy: 81.10%
